function [theta, alpha] = trainPDA(X, y, theta0, lambda)
% TRAINPDA  Fit class probabilities and one alpha vector per class.
%
%   [theta, alpha] = trainPDA(X, y, theta0, lambda)
%
%   Inputs
%     X      : n x V matrix; row i holds the counts for document i.
%     y      : n x 1 vector of class labels. The labels must be exactly the
%              integers 1..K (each present at least once), because they are
%              used directly as row indices into theta and alpha.
%     theta0 : prior parameter for the class probabilities; theta0 - 1 is added
%              to every class count before normalizing. May be a scalar or a
%              K-vector (presumably a Dirichlet concentration - confirm).
%     lambda : regularization weight forwarded unchanged to newtonRaphsonPDA.
%
%   Outputs
%     theta  : K x 1 vector of MAP class probabilities.
%     alpha  : K x V matrix (NOT a cell array); row k is the alpha vector fitted
%              by newtonRaphsonPDA to the rows of X with y == k.

  y = y(:);
  labels = unique(y);
  K = numel(labels);
  V = size(X, 2);

  % Class k is selected with y == k and stored in row k below, so any other
  % labelling (gaps, zero, non-integers) would silently misalign the outputs.
  if ~isequal(labels, (1:K)')
    error('trainPDA:badLabels', ...
          'y must contain exactly the integer labels 1..K.');
  end

  % MAP estimate for the class probabilities: (count_k + theta0 - 1), normalized.
  % accumarray avoids the Statistics Toolbox dependency of tabulate and does not
  % shadow MATLAB's built-in table function.
  classCounts = accumarray(y, 1, [K, 1]);
  adjustedFreqs = classCounts + theta0(:) - 1;
  theta = adjustedFreqs / sum(adjustedFreqs);

  % Fit each class's alpha vector independently on that class's documents.
  alpha = zeros(K, V);
  for k = 1:K
    Xk = X(y == k, :);
    alpha(k, :) = newtonRaphsonPDA(Xk, lambda);
  end

end

function [alpha_k] = newtonRaphsonPDA(Xk, lambda)
% NEWTONRAPHSONPDA  Fit the alpha vector of a single class with Newton's method.
%
%   alpha_k = newtonRaphsonPDA(Xk, lambda)
%
%   Inputs
%     Xk     : nk x V count matrix containing only the documents of one class.
%     lambda : penalty weight; the gradient contains -2*lambda*alpha, i.e. the
%              objective includes a -lambda*sum(alpha.^2) term.
%
%   Output
%     alpha_k: 1 x V vector; the final Newton iterate from the starting point
%              that scores highest under classLogJointProb(Xk, alpha, lambda).
%
%   The Hessian of the objective has the form diag(D) + z*ones(V), so each
%   Newton step H \ g is computed in O(V) with the Sherman-Morrison formula
%   instead of forming a V x V matrix.

  % Prelims
  numNRIters = 10;  % fixed number of Newton iterations (no convergence test)
  numInits = 1;     % number of starting points tried
  V = size(Xk, 2);  % size of vocabulary
  nk = size(Xk, 1); % number of training documents in this class
  m = sum(Xk, 2);   % nk x 1: total count of each document

  if nk == 0
    error('newtonRaphsonPDA:emptyClass', ...
          'Cannot fit alpha_k: Xk has no rows.');
  end

  % Explicit dim=1: with a single document, plain sum() would collapse the
  % 1 x V row into a scalar instead of returning per-column totals.
  colTotals = sum(Xk, 1);

  % Starting points, one per row, each normalized to sum to 1. Row 1 is the
  % class's column totals plus random jitter; any further rows are uniform rand.
  initPts = rand(numInits, V);
  initPts(1, :) = colTotals + 0.1 * sum(colTotals) * rand(1, V);
  initPts = bsxfun(@rdivide, initPts, sum(initPts, 2));

  bestVal = -inf;   % best objective value seen so far
  alpha_k = [];

  for initPtIter = 1:numInits

    currAlphak = initPts(initPtIter, :); % alpha_k in the current iteration
    for nrIter = 1:numNRIters
      Ak = sum(currAlphak);
      XplusAlpha = bsxfun(@plus, Xk, currAlphak);
      % Gradient (1 x V)
      g = nk * psi(Ak) - sum(psi(m + Ak)) + sum(psi(XplusAlpha), 1) ...
          - nk * psi(currAlphak) - 2 * lambda * currAlphak;
      % z: coefficient of the rank-one ones(V) part of the Hessian
      z = nk * psi(1, Ak) - sum(psi(1, m + Ak));
      % D: diagonal part of the Hessian (1 x V)
      D = sum(psi(1, XplusAlpha), 1) - nk * psi(1, currAlphak) - 2 * lambda;
      % Sherman-Morrison: (diag(D) + z*ones)^{-1} * g
      Hinvg = g ./ D - (1 ./ D) * sum(g ./ D) / (1 / z + sum(1 ./ D));

      % psi is only defined for non-negative input (and is -Inf at 0), so
      % shrink the Newton step until every component stays strictly positive.
      % The loop ends because the step eventually reaches 0, or a NaN compare
      % returns false.
      step = 1;
      while any(currAlphak - step * Hinvg <= 0)
        step = step / 2;
      end
      currAlphak = currAlphak - step * Hinvg;
    end

    currVal = classLogJointProb(Xk, currAlphak, lambda);
    % Always keep the first candidate so alpha_k is assigned even if the
    % objective evaluates to NaN.
    if isempty(alpha_k) || currVal > bestVal
      bestVal = currVal;
      alpha_k = currAlphak;
    end
  end
end

