% Matlab Script for HW2 q2
%
% Driver script: loads a dataset, initializes and trains the PMM model,
% predicts on the test split, echoes the learned parameters / predictions
% and prints the test accuracy.
%
% The loaded .mat file must provide XTrain, XTest and yTest, since the
% script uses them without defining them anywhere else.
% Relies on the project functions emInitPMM, trainPMM, predictPMM and
% accuracy_score being on the path.

% Start from a clean workspace with no open figures.
clear all;
close all;

% Fix the random seeds so that repeated runs are reproducible.
% NOTE(review): this is the legacy seeding syntax; MathWorks recommends
% rng(seed) instead. Switching would change the random stream (and hence
% possibly the numbers produced below), so it is deliberately left as is.
randn('seed', 0);
rand('seed', 0);

% Select the dataset (toggle the comment to switch).
data = 'data1.mat';
% data = 'data2.mat';

% Brings the dataset variables into the workspace.
load(data);

% Set hyper params
K = 3; % for both models.
lambda = 0.1;
theta0 = 2 * ones(K, 1);

% First Initialize theta, alpha
[thetaInit, alphaInit] = emInitPMM(XTrain, K, theta0, lambda);

% Now Train the model (tic/toc prints the elapsed training time)
tic;
[theta, alpha] = trainPMM(XTrain, K, theta0, lambda, thetaInit, alphaInit);
toc;

% Obtain predictions for Test data
preds = predictPMM(XTest, theta, alpha);

% Echo the results. For data1 everything is printed; for any other dataset
% only a preview is shown: the first (up to) five columns of alpha and the
% first (up to) five predictions. min(5, end) clamps the preview so that
% inputs with fewer than five entries no longer raise an
% index-out-of-bounds error.
if strcmp(data, 'data1.mat')
  theta
  alpha
  preds
else
  theta
  alpha(:, 1:min(5, end))
  preds(1:min(5, end))
end

% Report the accuracy
acc = accuracy_score(yTest, preds);
fprintf('Accuracy: %0.4f\n', acc);

