%% Example code for EM to fit the prior probabilities and means of two
% univariate Gaussians.
%
% This script will generate some sample data and then run the EM
% algorithm for a mixture of Gaussians (with known variance) to estimate
% the class centers.  After each iteration, t, of the algorithm, it will plot
%
%   1) P(Z_i = 0|X_i, \theta_{t}) as red bars above each point X_i, with
%      length equal to P(Z_i = 0|X_i, \theta_{t})
%
%   2) P(Z_i = 1|X_i, \theta_{t}) as blue bars stacked on top of the red
%      bar of each point X_i, with length equal to P(Z_i = 1|X_i, \theta_{t})
%
%   3) The estimate for the mean for class 0 after the M step as a large red dot
%
%   4) The estimate for the mean for class 1 after the M step as a large blue dot
%
% It will also give the estimates for class priors in the title after each
% iteration.
%
% Author: wbishop AT cs.cmu.edu
%

%% ========================================================================
%  PARAMETERS GO HERE
%  ========================================================================
N = 100;        % Number of samples

P_0 = .7;       % True probability of class 0
P_1 = 1 - P_0;  % True probability of class 1

MU_0 = 5;       % True mean of class 0
MU_1 = 10;      % True mean of class 1

VAR = 3;        % Variance (which is known) and common to class 0 and 1

%% ========================================================================
%  GENERATE SIMULATED DATA
%  ========================================================================

% First, we draw class labels - remember, we don't get to see these in
% practice.  Z(i) is 1 with probability 1 - P_0 and 0 otherwise.
Z = (rand(1,N) > P_0);

% Now we draw X values; we do get to see these.  Each sample is Gaussian
% noise with variance VAR shifted by the mean of its class; Z + 1 maps the
% labels {0, 1} onto indices {1, 2} of the vector of class means.
classMeans = [MU_0, MU_1];
X = randn(1,N)*sqrt(VAR) + classMeans(Z + 1);

%% ========================================================================
%  Make initial plot
%  ========================================================================
figure();
markerHdls = nan(1,N);
probZEquals0LineHdls = nan(1,N);
probZEquals1LineHdls = nan(1,N);
for i = 1:N
    % One open circle per sample on the X axis
    markerHdls(i) = plot(X(i), 0, 'ko', 'MarkerSize', 15, ...
        'MarkerEdgeColor', [0 0 0], 'MarkerFaceColor', [1 1 1]);
    hold on;

    % Zero-length bars for now; their YData is filled in after each E step
    probZEquals0LineHdls(i) = plot([X(i), X(i)], [0,0], 'r-', 'lineWidth', 2);
    probZEquals1LineHdls(i) = plot([X(i), X(i)], [0,0], 'b-', 'lineWidth', 2);
end

% Markers for the estimated class means.  They are created exactly once
% (not once per sample) and start at y = -1, which is below the YLim set
% below, so they are not visible until the first M step moves them to y = 0.
mu0MarkerHdl = plot(0, -1, 'o', 'MarkerSize', 20, ...
    'MarkerEdgeColor', [1 0 0], 'MarkerFaceColor', [1 0 0]);
mu1MarkerHdl = plot(0, -1, 'o', 'MarkerSize', 20, ...
    'MarkerEdgeColor', [0 0 1], 'MarkerFaceColor', [0 0 1]);
hold off;

xlabel('X', 'FontSize', 30);
set(gca, 'FontSize', 30, 'YLim', [0, 2]);
t = title('EM Demo', 'FontSize', 30);

%% ========================================================================
%  RUN THE EM ALGORITHM HERE
%  ========================================================================

% Make initial guesses for parameters
MU_0t = 1;
MU_1t = 7;
P_0t = .5;
P_1t = .5;

for it = 1:1000 % Normally, you would apply a stopping criteria based on convergence of parameters

    % Perform the E Step: posterior class probabilities for every sample
    % under the current parameter estimates (Bayes' rule).
    likelihoodOfXGivenZEquals0 = (1/sqrt(2*pi*VAR))*exp(-(X - MU_0t).^2/(2*VAR));
    likelihoodOfXGivenZEquals1 = (1/sqrt(2*pi*VAR))*exp(-(X - MU_1t).^2/(2*VAR));

    probZEquals1GivenX = likelihoodOfXGivenZEquals1*P_1t./ ...
        (likelihoodOfXGivenZEquals0*P_0t + likelihoodOfXGivenZEquals1*P_1t);
    probZEquals0GivenX = 1 - probZEquals1GivenX;

    % Perform the M Step (class priors): average posterior per class
    P_0tPlus1 = sum(probZEquals0GivenX)/N;
    P_1tPlus1 = sum(probZEquals1GivenX)/N;

    % Perform the M Step (class means): posterior-weighted average of X
    MU_0tPlus1 = (X*probZEquals0GivenX')/sum(probZEquals0GivenX);
    MU_1tPlus1 = (X*probZEquals1GivenX')/sum(probZEquals1GivenX);

    % Update our plots.  The blue bar is stacked on top of the red bar, so
    % it spans from P(Z=0|X) up to P(Z=0|X) + P(Z=1|X).
    for i = 1:N
        set(probZEquals0LineHdls(i), 'YData', [0, probZEquals0GivenX(i)]);
        set(probZEquals1LineHdls(i), 'YData', ...
            [probZEquals0GivenX(i), probZEquals0GivenX(i) + probZEquals1GivenX(i)]);
    end

    set(mu0MarkerHdl, 'XData', MU_0tPlus1, 'YData', 0);
    set(mu1MarkerHdl, 'XData', MU_1tPlus1, 'YData', 0);

    set(t, 'String', ['Em Demo after iteration ', num2str(it), '. \pi_{0, t+1} = ', ...
        num2str(P_0tPlus1), ' \pi_{1, t+1} = ', num2str(P_1tPlus1)]);

    % Update parameters for next E-Step
    MU_0t = MU_0tPlus1;
    MU_1t = MU_1tPlus1;
    P_0t = P_0tPlus1;
    P_1t = P_1tPlus1;

    % Wait for the user.  The 's' option returns whatever is typed as text
    % instead of evaluating it as a MATLAB expression, and the trailing
    % semicolon keeps the result from being echoed every iteration.
    input('Press enter to advance.  Ctr C to stop.', 's');
end