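% EMDEMO -- demonstration of the EM (Expectation-Maximization) algorithm.
%
% Set the variable EmDataSet to 1, 2, 3, or 4 to choose the data set.
% Optionally set EmHeuristic to a nonzero value to restart under-used bumps,
% and HardBumps to override the number of bumps fitted.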

% Copyright (c) 1997 David S. Touretzky

% Optional workspace inputs; a default is echoed when one is applied.
%   EmDataSet   - data set number passed to emgenerate (default 1)
%   EmHeuristic - nonzero enables the bump-restart heuristic (default 0)
%   HardBumps   - if set, used as the number of bumps instead of the
%                 count returned by emgenerate (presumably its recommended
%                 number of components -- confirm against emgenerate)
% exist(...,'var') restricts the test to variables; the one-argument form
% also returns nonzero for an .m file or function of the same name.
if ~exist('EmDataSet','var'), EmDataSet=1, end
if ~exist('EmHeuristic','var'), EmHeuristic=0, end
[RecBumps,Patterns_nd] = emgenerate(EmDataSet);
% One pattern per row: Npats patterns of Nelts elements each.
[Npats,Nelts] = size(Patterns_nd);

if exist('HardBumps','var')
  HardBumps
  Nbumps = HardBumps;
else
  Nbumps = RecBumps;
end

% Constant vectors used to replicate rows/columns via outer products.
OnePats_n1 = ones(Npats,1);
OneElts_1d = ones(1,Nelts);
OneBumps_j1 = ones(Nbumps,1);

% Floor on the bump variances.  It is applied here as well as after every
% M-step so that no bump starts with zero width (which would divide by zero
% in the first E-step).
minSigmaSq = 1e-5;

% Place the bump centres on randomly chosen patterns.  When there are enough
% patterns, draw distinct ones: bumps that start at the same point with the
% same variance receive identical EM updates and can never separate.
if Nbumps <= Npats
  perm = randperm(Npats);
  Mu_jd = Patterns_nd(perm(1:Nbumps),:);
else
  Mu_jd = Patterns_nd(1+floor(rand(Nbumps,1)*Npats),:);
end

% Cleared first so a rerun with fewer bumps leaves no stale entries.
clear SigmaSq_j1
for j=1:Nbumps
  SigmaSq_j1(j,1) = max(minSigmaSq, emsigmasq(j,Mu_jd));
end

% Start with equal mixing proportions for every bump.
P_j1 = OneBumps_j1/Nbumps;
P_jn = P_j1 * OnePats_n1';

emshow(Mu_jd,SigmaSq_j1,Patterns_nd)

% These arrays are grown by indexed assignment below and in the EM loop;
% clearing them avoids stale rows from a previous run of the script.
clear DistSq_jn new_Mu_jd new_SigmaSq_j1
% Squared Euclidean distance from each centre to each pattern (Nbumps x Npats).
% Summing along dimension 2 (rather than sum(X')) keeps one distance per
% pattern even for 1-D data (Nelts == 1).
for j = 1:Nbumps
  DistSq_jn(j,:) = sum((OnePats_n1*Mu_jd(j,:)-Patterns_nd).^2, 2)';
end

% Run EM for at most 100 epochs, redrawing the bumps after every epoch.
% Convergence is judged only on the mixing proportions: iteration stops
% once their summed squared change falls below 1e-6.
for epoch = 1:100
  fprintf('  %d',epoch);

  % Expectation step.  P_nj(n,j) is the isotropic Gaussian density of
  % pattern n under bump j:
  %   (2*pi*SigmaSq_j)^(-Nelts/2) * exp(-DistSq_jn(j,n) / (2*SigmaSq_j))
  P_nj = OnePats_n1 * (1 ./ (2*pi*SigmaSq_j1').^(Nelts/2)) .* ...
              exp(- DistSq_jn' ./ (OnePats_n1*2*SigmaSq_j1'));
  Pnj_Pj = P_nj .* (OnePats_n1*P_j1');   % weight by mixing proportions
  P_n1 = sum(Pnj_Pj,2);                   % mixture density of each pattern
  % Posterior responsibility of bump j for pattern n.  The 1e-10 floor keeps
  % the division finite for patterns to which every bump assigns ~zero density.
  P_jn = Pnj_Pj' ./ max(1e-10,OneBumps_j1*P_n1');
  sum_Pj1 = sum(P_jn,2);                  % total responsibility of each bump

  % Maximization step:  eqns (2.96) to (2.98) from Bishop p.67.
  % Sums use an explicit dimension so that a single pattern (Npats == 1)
  % or 1-D data (Nelts == 1) is not summed along the wrong axis.
  for j = 1:Nbumps
    if sum_Pj1(j) > 0
      new_Mu_jd(j,:) = sum((P_jn(j,:)'*OneElts_1d).*Patterns_nd, 1) / sum_Pj1(j);
      DistSq_jn(j,:) = sum((OnePats_n1*new_Mu_jd(j,:)-Patterns_nd).^2, 2)';
      new_SigmaSq_j1(j) = 1/Nelts * sum(P_jn(j,:).*DistSq_jn(j,:)) / sum_Pj1(j);
    else
      % This bump has zero total responsibility (its densities or its
      % mixing proportion are zero).  Keep its parameters -- DistSq_jn(j,:)
      % already matches Mu_jd(j,:) -- instead of computing 0/0, whose NaNs
      % would reach every bump through P_n1 on the next E-step.
      new_Mu_jd(j,:) = Mu_jd(j,:);
      new_SigmaSq_j1(j) = SigmaSq_j1(j);
    end
  end
  new_Pj1 = (1/Npats) * sum_Pj1;

  diff_P = sum((P_j1 - new_Pj1).^2);

  Mu_jd = new_Mu_jd;
  % Floor the variances and turn the row built above back into a column.
  SigmaSq_j1 = max(minSigmaSq,new_SigmaSq_j1)';
  P_j1 = new_Pj1;

  emshow(Mu_jd,SigmaSq_j1,Patterns_nd)

  if (diff_P < 1e-6), break, end

  % Heuristic for restarting a bump at a new location if it captures
  % less than half of a "fair share" (1/Nbumps) of the data.
  if EmHeuristic
    for j = 1:Nbumps
      if P_j1(j) < 1/(2*Nbumps)
        fprintf('r(%d)\n',j);
        % Format the centre for any dimensionality ("x,y" for 2-D data).
        MuStr = sprintf('%4.3f,',Mu_jd(j,:));
        fprintf('#%d was (%s) by %4.3f\n',j,MuStr(1:end-1), ...
            sqrt(SigmaSq_j1(j)));
        Mu_jd(j,:) = Patterns_nd(1+floor(rand(1)*Npats),:);
        % Same variance floor as applied after every M-step.
        SigmaSq_j1(j) = max(minSigmaSq,emsigmasq(j,Mu_jd));
        DistSq_jn(j,:) = sum((OnePats_n1*Mu_jd(j,:)-Patterns_nd).^2, 2)';
        MuStr = sprintf('%4.3f,',Mu_jd(j,:));
        fprintf('#%d now (%s) by %4.3f\n',j,MuStr(1:end-1), ...
            sqrt(SigmaSq_j1(j)));
        emshow(Mu_jd,SigmaSq_j1,Patterns_nd)
        % Give the restarted bump a fair-share prior.  P_j1 need not sum to
        % 1 here: the E-step normalises by P_n1 and the M-step recomputes P_j1.
        P_j1(j) = 1/Nbumps;
      end
    end
  end

end

emprint
