package edu.cmu.cs.ark.compuframes;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectInputStream;
import java.io.ObjectOutput;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import org.apache.commons.io.FilenameUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.special.Gamma;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.util.FastMath;

import edu.cmu.cs.ark.compuframes.types.EpochData;

/**
 * @author Yanchuan Sim
 * @version 0.1
 * @since 0.1
 */
public class GibbsSampler
{
  // data related stuff
  // epoch_count / term_count are the sizes of the epoch and term lists;
  // ideo_count is the number of ideologies plus one for the synthetic root "v0".
  private final int             epoch_count, term_count, ideo_count;
  // terms_array[w] is the surface string of term w
  private final String[]        terms_array;
  // ideologies_array[0] is "v0"; ideologies_array[u] for u >= 1 is an ideology name
  private final String[]        ideologies_array;
  // epoches_array[e] is the name of epoch e
  private final String[]        epoches_array;
  // speeches_array[e][d] is the title of speech d in epoch e
  private final String[][]      speeches_array;

  // Indexed as [epoch][speech][position].
  // terms_index[0] is __START_OF_SPEECH__ (stored as -1)
  // terms_index[-1] is last word of speech
  // terms_lag[0] is 0
  // terms_lag[-1] is lag at the end of speech
  // terms_lag[i] is lag between word i-1 and i
  // terms_lag has one more entry per speech than terms_index (the end-of-speech lag).
  private final int[][][]       terms_index;
  private final int[][][]       terms_lag;

  // array for holding parameters
  // emission_hyper[u][w]: prior pseudo-count of term w under ideology u
  private final double[][]      emission_hyper;
  // emission_sage[u][w]: true when term w has a positive weight for ideology u
  private final boolean[][]     emission_sage;
  // stop_probs[e][u] and continue_probs[e][u] are kept complementary (sum to 1)
  private final double[][]      stop_probs;
  private final double[][]      continue_probs;
  // alpha: pseudo-count added to every tree-edge transition count
  private double                alpha;
  // gamma: parameter from which the restart / no-restart lookup tables are built
  private double                gamma;
  // prior pseudo-counts used for emission_hyper when emission_sage is true / false
  private double                emission_prior_ideo;
  private double                emission_prior_other;

  // array for holding current iteration samples
  // sample_x[e][d][i]: ideology state of token i; sample_r[e][d][i]: restart flag of token i
  private final int[][][]       sample_x;
  private final boolean[][][]   sample_r;

  // array for holding counts
  // n_ideowords[u][w]: tokens of term w currently assigned to ideology u
  private final int[][]         n_ideowords;
  // n_endstate[e][u]: tokens in epoch e currently assigned to ideology u
  private final int[][]         n_endstate;
  // n_transitions[e][u][v]: times the edge u -> v is traversed by current paths in epoch e
  private final int[][][]       n_transitions;

  // array for caching normalization stuff
  // Z_emission_hyper[u]: sum of emission_hyper[u] plus the number of tokens assigned to u
  private final double[]        Z_emission_hyper;
  // Z_transition_hyper[e][u]: alpha * (number of neighbors of u) plus transitions leaving u
  private final double[][]      Z_transition_hyper;

  // array for holding saved counts
  // Same layout as the n_* arrays above, accumulated over saved sweeps until resetSamples().
  private final int[][]         save_ideowords;
  private final int[][]         save_endstate;
  private final int[][][]       save_transitions;

  // tree_paths[u][v] is the vertex sequence from u to v in the ideology tree;
  // first and last vertices are start and end points anyway
  private final int[][][]       tree_paths;
  // tree_neighbors[u] lists the vertices adjacent to u
  private final int[][]         tree_neighbors;
  // index of the synthetic root vertex "v0"
  private static final int      IDEO_ROOT         = 0;

  // Lookup tables indexed by lag: log-probability of restarting / not restarting.
  private final double[]        restart_table;
  private final double[]        norestart_table;

  // random number source used for initialization and sampling
  private final RandomGenerator R                 = new Well19937c();

  // directory receiving the per-sweep sample dumps
  private final File            samples_dir;
  // sequence number used in the next sample dump's file name
  private int                   save_sample_count = 1;

  /**
   * Builds a sampler from the loaded application data: copies the corpus into
   * primitive arrays, sets up the ideology tree and its paths, draws a random
   * initial assignment for every token and precomputes the restart tables.
   *
   * @param app_data
   *          application-wide data (epochs, terms, ideologies, tree edges and
   *          settings) from which the sampler state is initialized
   */
  public GibbsSampler(AppGlobal app_data)
  {
    epoch_count = app_data.EPOCHES.size();
    term_count = app_data.TERMS.size();

    samples_dir = app_data.SETTINGS.SAMPLE_DIR;

    alpha = app_data.SETTINGS.ALPHA;
    gamma = app_data.SETTINGS.GAMMA;
    emission_prior_ideo = app_data.SETTINGS.EMISSION_IDEO;
    emission_prior_other = app_data.SETTINGS.EMISSION_OTHER;

    epoches_array = new String[epoch_count];
    speeches_array = new String[epoch_count][];
    for (int e = 0; e < epoch_count; e++)
    {
      epoches_array[e] = app_data.EPOCHES.get(e).name;
      // A zero-length seed array makes toArray size the result exactly; a
      // one-element seed would yield a bogus [null] entry for an empty epoch.
      speeches_array[e] = app_data.EPOCHES.get(e).speech_titles.toArray(new String[0]);
    }

    terms_array = new String[term_count];
    terms_index = new int[epoch_count][][];
    terms_lag = new int[epoch_count][][];
    setupTermsAndLags(app_data);

    // +1 for the synthetic root vertex stored at index IDEO_ROOT
    ideo_count = app_data.IDEOLOGIES.size() + 1;
    ideologies_array = new String[ideo_count];
    emission_hyper = new double[ideo_count][term_count];
    emission_sage = new boolean[ideo_count][term_count];
    setupIdeologiesAndEmissionsHyper(app_data);

    stop_probs = new double[epoch_count][ideo_count];
    continue_probs = new double[epoch_count][ideo_count];
    for (int e = 0; e < epoch_count; e++)
    {
      stop_probs[e][0] = 0.0; // always continue at root
      continue_probs[e][0] = 1.0;
      for (int i = 1; i < ideo_count; i++)
      {
        stop_probs[e][i] = 0.3;
        continue_probs[e][i] = 0.7;
      }
    }

    tree_paths = new int[ideo_count][ideo_count][];
    tree_neighbors = new int[ideo_count][];
    setupTreePaths(app_data);

    sample_x = new int[epoch_count][][];
    sample_r = new boolean[epoch_count][][];
    n_ideowords = new int[ideo_count][term_count];
    n_endstate = new int[epoch_count][ideo_count];
    n_transitions = new int[epoch_count][ideo_count][ideo_count];
    Z_emission_hyper = new double[ideo_count];
    Z_transition_hyper = new double[epoch_count][ideo_count];
    setupSamplesAndCounts();

    save_ideowords = new int[ideo_count][term_count];
    save_endstate = new int[epoch_count][ideo_count];
    save_transitions = new int[epoch_count][ideo_count][ideo_count];

    // Entry i holds the log-probabilities for lag i: the no-restart
    // probability is (1 - gamma)^((i + 1) / 50).
    restart_table = new double[1024];
    norestart_table = new double[1024];
    for (int i = 0; i < restart_table.length; i++)
    {
      double r = FastMath.pow(1.0 - gamma, (i + 1) / 50.0);
      restart_table[i] = FastMath.log(1.0 - r);
      norestart_table[i] = FastMath.log(r);

      // replace log(0) / invalid values by a large negative but finite number
      if (Double.isNaN(restart_table[i]) || Double.isInfinite(restart_table[i]))
        restart_table[i] = -1e5;
      if (Double.isNaN(norestart_table[i]) || Double.isInfinite(norestart_table[i]))
        norestart_table[i] = -1e5;
    }
  }

  /**
   * Draws a random initial state and restart flag for every token (position 0
   * of each speech is pinned to the root with the restart flag set) and
   * accumulates the matching emission, end-state and transition counts together
   * with their cached normalizers.
   */
  private void setupSamplesAndCounts()
  {
    // emission normalizers start out as the total prior mass of each ideology
    for (int u = 0; u < ideo_count; u++)
      for (double prior : emission_hyper[u])
        Z_emission_hyper[u] += prior;

    for (int e = 0; e < epoch_count; e++)
    {
      // transition normalizers start out as alpha per neighboring vertex
      for (int u = 0; u < ideo_count; u++)
        Z_transition_hyper[e][u] = alpha * tree_neighbors[u].length;

      final int speech_count = terms_index[e].length;
      sample_x[e] = new int[speech_count][];
      sample_r[e] = new boolean[speech_count][];

      for (int d = 0; d < speech_count; d++)
      {
        final int[] words = terms_index[e][d];
        final int[] states = new int[words.length];
        final boolean[] restarts = new boolean[words.length];
        sample_x[e][d] = states;
        sample_r[e][d] = restarts;

        states[0] = IDEO_ROOT;
        restarts[0] = true;

        for (int i = 1; i < states.length; i++)
        {
          // never the root: states are drawn from 1 .. ideo_count - 1
          final int state = R.nextInt(ideo_count - 1) + 1;
          final boolean restarted = R.nextBoolean();
          states[i] = state;
          restarts[i] = restarted;

          n_ideowords[state][words[i]]++;
          Z_emission_hyper[state]++;
          n_endstate[e][state]++;

          // a restart walks down from the root, otherwise from the previous state
          final int origin = restarted ? IDEO_ROOT : states[i - 1];
          final int[] path = tree_paths[origin][state];
          updateTransitionCounts(n_transitions[e], path, +1);
          updateTransitionCounts(Z_transition_hyper[e], path, +1);
        }
      }
    }
  }

  /**
   * Builds the undirected ideology graph from the configured edge list, records
   * each vertex's neighbors and fills {@code tree_paths[u][v]} with the vertex
   * sequence from {@code u} to {@code v} found by a breadth-first traversal
   * started at {@code u}.
   *
   * @param app_data
   *          supplies the ideology name pairs that form the edges
   */
  private void setupTreePaths(AppGlobal app_data)
  {
    final boolean[][] connected = new boolean[ideo_count][ideo_count];

    for (String[] edge : app_data.IDEOLOGY_PAIRS)
    {
      final int a = ArrayUtils.indexOf(ideologies_array, edge[0]);
      final int b = ArrayUtils.indexOf(ideologies_array, edge[1]);
      connected[a][b] = true;
      connected[b][a] = true;
    }

    // collect the neighbors of every vertex in ascending index order
    final int[] buffer = new int[ideo_count];
    for (int u = 0; u < ideo_count; u++)
    {
      int degree = 0;
      for (int v = 0; v < ideo_count; v++)
        if (connected[u][v])
          buffer[degree++] = v;

      tree_neighbors[u] = new int[degree];
      System.arraycopy(buffer, 0, tree_neighbors[u], 0, degree);
    }

    // assumes 0 is v0!
    for (int u = 0; u < ideo_count; u++)
    {
      // two parallel FIFO queues: the vertex to expand and the vertex it was reached from
      final LinkedList<Integer> frontier = new LinkedList<>();
      final LinkedList<Integer> parents = new LinkedList<>();
      final boolean[] seen = new boolean[ideo_count];

      frontier.addLast(u);
      parents.addLast(-1);
      tree_paths[u][u] = new int[] { u }; // can visit self

      while (!frontier.isEmpty())
      {
        final int v = frontier.removeFirst();
        final int parent = parents.removeFirst();

        if (v != u && !seen[v] && parent != -1)
        {
          // path to v = path to its parent, extended by v
          final int[] prefix = tree_paths[u][parent];
          final int[] path = new int[prefix.length + 1];
          System.arraycopy(prefix, 0, path, 0, prefix.length);
          path[prefix.length] = v;
          tree_paths[u][v] = path;
        }

        seen[v] = true;

        for (int w = 0; w < ideo_count; w++)
        {
          if (!connected[v][w] || seen[w])
            continue;
          frontier.addLast(w);
          parents.addLast(v);
        }
      }
    }
  }

  /**
   * Fills {@code ideologies_array} (slot 0 is the synthetic root "v0") and, for
   * every non-root ideology, marks which terms carry a positive weight for it
   * and assigns the corresponding emission prior.
   *
   * @param app_data
   *          supplies the ideology names and the per-term weight vectors
   */
  private void setupIdeologiesAndEmissionsHyper(AppGlobal app_data)
  {
    // column order of the per-term weight vectors
    final String[] term_weights_order = new String[] { "center", "left", "left-farl", "left-mainl", "left-prgrsv", "left-religl", "right", "right-farr", "right-lbtn", "right-nonrad", "right-poplstr", "right-religr" };

    ideologies_array[0] = "v0";
    final int named_count = app_data.IDEOLOGIES.size();
    for (int i = 0; i < named_count; i++)
      ideologies_array[i + 1] = app_data.IDEOLOGIES.get(i);

    for (int u = 1; u < ideo_count; u++)
    {
      // locate this ideology's weight column (the last matching entry wins)
      int ideo_index = -1;
      for (int i = term_weights_order.length - 1; i >= 0; i--)
        if (term_weights_order[i].equals(ideologies_array[u]))
        {
          ideo_index = i;
          break;
        }

      assert ideo_index >= 0;

      for (int w = 0; w < term_count; w++)
      {
        final boolean favoured = app_data.TERMS.get(w).weights[ideo_index] > 0;
        emission_sage[u][w] = favoured;
        if (favoured)
          emission_hyper[u][w] = emission_prior_ideo;
        else
          emission_hyper[u][w] = emission_prior_other;
      }
    }
  }

  /**
   * Copies the term vocabulary into {@code terms_array} and converts every
   * speech into two parallel arrays: term indexes (with -1 for the start marker)
   * and lags. The lag array has one extra trailing slot for the end-of-speech
   * lag.
   *
   * @param app_data
   *          supplies the vocabulary and the tokenized speeches of each epoch
   */
  private void setupTermsAndLags(AppGlobal app_data)
  {
    final Map<String, Integer> position_of_term = new HashMap<String, Integer>();
    for (int w = 0; w < term_count; w++)
    {
      final String term = app_data.TERMS.get(w).term;
      terms_array[w] = term;
      position_of_term.put(term, w);
    }

    for (int e = 0; e < epoch_count; e++)
    {
      final EpochData epoch_data = app_data.EPOCHES.get(e);
      final int title_count = epoch_data.speech_titles.size();
      terms_index[e] = new int[title_count][];
      terms_lag[e] = new int[title_count][];

      for (int d = 0; d < epoch_data.speeches.size(); d++)
      {
        // the token list carries a length marker and an end marker that get no
        // term slot; only the length marker gets no lag slot
        final int token_count = epoch_data.speeches.get(d).size();
        final int[] indexes = new int[token_count - 2];
        final int[] lags = new int[token_count - 1];
        terms_index[e][d] = indexes;
        terms_lag[e][d] = lags;

        int pos = 0;
        for (Pair<String, Integer> wt : epoch_data.speeches.get(d))
        {
          switch (wt.getLeft())
          {
            case "__SPEECH_LENGTH__":
              break;
            case "__END_OF_SPEECH__":
              lags[pos++] = wt.getRight();
              break;
            case "__START_OF_SPEECH__":
              indexes[pos] = -1;
              lags[pos++] = 0;
              break;
            default:
              indexes[pos] = position_of_term.get(wt.getLeft());
              lags[pos++] = wt.getRight();
              break;
          }
        }
      }
    }
  }

  /**
   * Restores the sampler state previously written by
   * {@link #saveState(AppState)}: hyperparameters, stop probabilities, current
   * samples, counts, cached normalizers and saved counts, in exactly the order
   * they were written.
   * <p>
   * The restart / no-restart lookup tables are rebuilt from the loaded gamma so
   * that they do not keep reflecting the value the sampler was constructed with.
   * <p>
   * NOTE: this uses Java native deserialization; only load files from a trusted
   * source.
   *
   * @param serialized_file
   *          gzip-compressed serialized state file
   * @return always {@code true}
   * @throws FileNotFoundException
   *           if the file cannot be opened
   * @throws IOException
   *           if reading or decompressing fails
   * @throws ClassNotFoundException
   *           if a serialized object's class cannot be resolved
   */
  public boolean loadState(File serialized_file) throws FileNotFoundException, IOException, ClassNotFoundException
  {
    // try-with-resources: the stream is closed even when a read fails midway
    try (ObjectInput bis = new ObjectInputStream(new GZIPInputStream(new FileInputStream(serialized_file))))
    {
      alpha = bis.readDouble(); // alpha
      gamma = bis.readDouble(); // gamma
      emission_prior_ideo = bis.readDouble(); // emission_prior_ideo
      emission_prior_other = bis.readDouble(); // emission_prior_other

      // stop_probs; continue_probs is read to keep the stream aligned but is
      // recomputed as the complement
      double[][] darr2 = (double[][]) bis.readObject();
      bis.readObject(); // continue_probs
      for (int e = 0; e < epoch_count; e++)
        for (int u = 0; u < ideo_count; u++)
        {
          stop_probs[e][u] = darr2[e][u];
          continue_probs[e][u] = 1.0 - darr2[e][u];
        }

      boolean[][][] barr = (boolean[][][]) bis.readObject(); // sample_r
      for (int e = 0; e < epoch_count; e++)
        for (int d = 0; d < terms_index[e].length; d++)
          System.arraycopy(barr[e][d], 0, sample_r[e][d], 0, terms_index[e][d].length);

      int[][][] iarr3 = (int[][][]) bis.readObject(); // sample_x
      for (int e = 0; e < epoch_count; e++)
        for (int d = 0; d < terms_index[e].length; d++)
          System.arraycopy(iarr3[e][d], 0, sample_x[e][d], 0, terms_index[e][d].length);

      int[][] iarr2 = (int[][]) bis.readObject(); // n_endstate
      for (int e = 0; e < epoch_count; e++)
        System.arraycopy(iarr2[e], 0, n_endstate[e], 0, ideo_count);

      iarr2 = (int[][]) bis.readObject(); // n_ideowords
      for (int u = 0; u < ideo_count; u++)
        System.arraycopy(iarr2[u], 0, n_ideowords[u], 0, term_count);

      iarr3 = (int[][][]) bis.readObject(); // n_transitions
      for (int e = 0; e < epoch_count; e++)
        for (int u = 0; u < ideo_count; u++)
          System.arraycopy(iarr3[e][u], 0, n_transitions[e][u], 0, ideo_count);

      darr2 = (double[][]) bis.readObject(); // emission_hyper
      for (int u = 0; u < ideo_count; u++)
        System.arraycopy(darr2[u], 0, emission_hyper[u], 0, term_count);

      double[] darr1 = (double[]) bis.readObject(); // Z_emission_hyper
      System.arraycopy(darr1, 0, Z_emission_hyper, 0, ideo_count);

      darr2 = (double[][]) bis.readObject(); // Z_transition_hyper
      for (int e = 0; e < epoch_count; e++)
        System.arraycopy(darr2[e], 0, Z_transition_hyper[e], 0, ideo_count);

      iarr2 = (int[][]) bis.readObject(); // save_ideowords
      for (int u = 0; u < ideo_count; u++)
        System.arraycopy(iarr2[u], 0, save_ideowords[u], 0, term_count);

      iarr2 = (int[][]) bis.readObject(); // save_endstate
      for (int e = 0; e < epoch_count; e++)
        System.arraycopy(iarr2[e], 0, save_endstate[e], 0, ideo_count);

      iarr3 = (int[][][]) bis.readObject(); // save_transitions
      for (int e = 0; e < epoch_count; e++)
        for (int u = 0; u < ideo_count; u++)
          System.arraycopy(iarr3[e][u], 0, save_transitions[e][u], 0, ideo_count);
    }

    // The lookup tables were built from the constructor-time gamma; refresh them
    // with the same formula so sampling uses the gamma that was just loaded.
    for (int i = 0; i < restart_table.length; i++)
    {
      double r = FastMath.pow(1.0 - gamma, (i + 1) / 50.0);
      restart_table[i] = FastMath.log(1.0 - r);
      norestart_table[i] = FastMath.log(r);

      if (Double.isNaN(restart_table[i]) || Double.isInfinite(restart_table[i]))
        restart_table[i] = -1e5;
      if (Double.isNaN(norestart_table[i]) || Double.isInfinite(norestart_table[i]))
        norestart_table[i] = -1e5;
    }

    return true;
  }

  /**
   * Writes the complete sampler state to {@code state.serialized.gz} (readable
   * by {@link #loadState(File)}) and dumps human-readable, gzip-compressed
   * summaries next to it: per-ideology word counts sorted by descending count,
   * end-state counts, transition counts, stop probabilities and
   * hyperparameters.
   * <p>
   * Every stream is opened in a try-with-resources statement so that it is
   * closed even when a write fails.
   *
   * @param app_state
   *          resolves the output file for each artifact name
   * @return always {@code true}
   * @throws FileNotFoundException
   *           if an output file cannot be created
   * @throws IOException
   *           if writing or compressing fails
   */
  public boolean saveState(AppState app_state) throws FileNotFoundException, IOException
  {
    // The write order below is the read order expected by loadState.
    try (ObjectOutput bos = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(app_state.getCurStateFile("state.serialized.gz")))))
    {
      bos.writeDouble(alpha);
      bos.writeDouble(gamma);
      bos.writeDouble(emission_prior_ideo);
      bos.writeDouble(emission_prior_other);

      bos.writeObject(stop_probs);
      bos.writeObject(continue_probs);

      bos.writeObject(sample_r);
      bos.writeObject(sample_x);

      bos.writeObject(n_endstate);
      bos.writeObject(n_ideowords);
      bos.writeObject(n_transitions);

      bos.writeObject(emission_hyper);
      bos.writeObject(Z_emission_hyper);
      bos.writeObject(Z_transition_hyper);

      bos.writeObject(save_ideowords);
      bos.writeObject(save_endstate);
      bos.writeObject(save_transitions);
    }

    // one file per non-root ideology: terms ordered by descending saved count
    for (int u = 1; u < ideo_count; u++)
    {
      List<Pair<String, Integer>> word_count_pairs = new ArrayList<>(term_count);
      for (int w = 0; w < term_count; w++)
        word_count_pairs.add(new ImmutablePair<String, Integer>(terms_array[w], save_ideowords[u][w]));

      Collections.sort(word_count_pairs, new Comparator<Pair<String, Integer>>()
      {
        @Override
        public int compare(Pair<String, Integer> o1, Pair<String, Integer> o2)
        {
          // compareTo instead of subtraction: cannot overflow
          return o2.getRight().compareTo(o1.getRight());
        }
      });

      try (OutputStreamWriter osw = new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(app_state.getCurStateFile("ideowords-" + ideologies_array[u] + ".txt.gz")))))
      {
        for (Pair<String, Integer> p : word_count_pairs)
          osw.write(p.getLeft() + '\t' + p.getRight().toString() + '\n');
      }
    }

    try (OutputStreamWriter osw = new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(app_state.getCurStateFile("endstates.txt.gz")))))
    {
      for (int e = 0; e < epoch_count; e++)
        for (int u = 0; u < ideo_count; u++)
          osw.write(String.format("%s\t%s\t%d\n", epoches_array[e], ideologies_array[u], save_endstate[e][u]));
    }

    try (OutputStreamWriter osw = new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(app_state.getCurStateFile("transitions.txt.gz")))))
    {
      for (int e = 0; e < epoch_count; e++)
        for (int u = 0; u < ideo_count; u++)
          for (int v = 0; v < ideo_count; v++)
            osw.write(String.format("%s\t%s\t%s\t%d\n", epoches_array[e], ideologies_array[u], ideologies_array[v], save_transitions[e][u][v]));
    }

    try (OutputStreamWriter osw = new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(app_state.getCurStateFile("stop_probs.txt.gz")))))
    {
      for (int e = 0; e < epoch_count; e++)
        for (int u = 0; u < ideo_count; u++)
          osw.write(String.format("%s\t%s\t%f\n", epoches_array[e], ideologies_array[u], stop_probs[e][u]));
    }

    try (OutputStreamWriter osw = new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(app_state.getCurStateFile("hyperparameters.txt.gz")))))
    {
      osw.write(String.format("alpha\t%f\n", alpha));
      osw.write(String.format("gamma\t%f\n", gamma));
      osw.write(String.format("beta_sage\t%f\n", emission_prior_ideo));
      osw.write(String.format("beta_def\t%f\n", emission_prior_other));
    }

    return true;
  }

  /**
   * Dumps the current assignment of every token to a numbered, gzip-compressed
   * file in the samples directory. Each line holds the epoch name, the speech
   * title and one {@code ideology/R} (restart) or {@code ideology/C} (continue)
   * entry per token. The file counter advances only after a successful write.
   *
   * @throws IOException
   *           if the file cannot be created or written
   */
  private void saveSamplesToFile() throws IOException
  {
    String sample_path = FilenameUtils.concat(samples_dir.getAbsolutePath(), String.format("%05d.samples.gz", save_sample_count));

    // try-with-resources: the writer is closed even when a write fails
    try (OutputStreamWriter osw = new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(sample_path))))
    {
      for (int e = 0; e < epoch_count; e++)
        for (int d = 0; d < terms_index[e].length; d++)
        {
          osw.write(epoches_array[e] + "\t" + speeches_array[e][d] + "\t");
          // position 0 is the start-of-speech marker and is not reported
          for (int i = 1; i < terms_index[e][d].length; i++)
            osw.write(ideologies_array[sample_x[e][d][i]] + "/" + (sample_r[e][d][i] ? 'R' : 'C') + " ");
          osw.write('\n');
        }
    }

    save_sample_count++;
  }

  /**
   * Adds {@code delta} to the count of every consecutive edge along
   * {@code path}. A single-vertex path is booked as a self-transition of that
   * vertex.
   *
   * @param matrix
   *          transition counts indexed as [from][to]
   * @param path
   *          vertex sequence to walk
   * @param delta
   *          amount to add to each traversed edge
   */
  private void updateTransitionCounts(int[][] matrix, int[] path, int delta)
  {
    final int hops = path.length - 1;
    if (hops == 0)
    {
      final int only = path[0];
      matrix[only][only] += delta;
      return;
    }

    for (int hop = 0; hop < hops; hop++)
      matrix[path[hop]][path[hop + 1]] += delta;
  }

  /**
   * Adds {@code delta} to the normalizer of every vertex that {@code path}
   * leaves, i.e. every vertex except the last. A single-vertex path updates
   * that vertex itself.
   *
   * @param Z_matrix
   *          per-vertex normalizers
   * @param path
   *          vertex sequence to walk
   * @param delta
   *          amount to add for each departure
   */
  private void updateTransitionCounts(double[] Z_matrix, int[] path, int delta)
  {
    final int hops = path.length - 1;
    if (hops == 0)
    {
      Z_matrix[path[0]] += delta;
      return;
    }

    for (int hop = 0; hop < hops; hop++)
      Z_matrix[path[hop]] += delta;
  }

  /**
   * Clears all saved counts (end states, transitions and ideology/word counts)
   * so that a fresh round of sample accumulation can start.
   */
  public void resetSamples()
  {
    for (int u = 0; u < ideo_count; u++)
    {
      final int[] word_counts = save_ideowords[u];
      for (int w = 0; w < term_count; w++)
        word_counts[w] = 0;
    }

    for (int e = 0; e < epoch_count; e++)
    {
      final int[] end_counts = save_endstate[e];
      final int[][] transition_counts = save_transitions[e];
      for (int u = 0; u < ideo_count; u++)
      {
        end_counts[u] = 0;
        for (int v = 0; v < ideo_count; v++)
          transition_counts[u][v] = 0;
      }
    }
  }

  /**
   * Performs one Gibbs sweep over every token of every speech. For each token
   * the current assignment is removed from the counts, a new (state, restart)
   * pair is drawn from its conditional distribution given the neighboring
   * tokens, and the counts are updated with the new assignment.
   * <p>
   * Optionally the resulting assignment is accumulated into the saved counts
   * and written to a samples file. Writing the file is best effort: a failure
   * is reported on standard error and does not abort the sweep.
   *
   * @param save_samples
   *          whether to add this sweep's assignment to the saved counts
   * @param save_samples_to_file
   *          whether to also dump the assignment to a file (only honored when
   *          {@code save_samples} is set)
   */
  public void samplingStep(boolean save_samples, boolean save_samples_to_file)
  {
    for (int e = 0; e < epoch_count; e++)
      for (int d = 0; d < terms_index[e].length; d++)
        for (int i = 1; i < terms_index[e][d].length; i++)
        {
          // remove stuff from n_counts first
          int old_x = sample_x[e][d][i];
          boolean old_r = sample_r[e][d][i];
          int w_di = terms_index[e][d][i];
          int left_x = sample_x[e][d][i - 1];
          // the next token only depends on this one when it did not restart;
          // IDEO_ROOT stands for "no dependent right neighbor"
          int right_x = (i + 1 < sample_r[e][d].length && !sample_r[e][d][i + 1]) ? sample_x[e][d][i + 1] : IDEO_ROOT;

          n_endstate[e][old_x]--;
          n_ideowords[old_x][w_di]--;
          Z_emission_hyper[old_x]--;
          assert (n_endstate[e][old_x] >= 0);
          assert (n_ideowords[old_x][w_di] >= 0);
          assert (Z_emission_hyper[old_x] >= 0);

          if (old_r)
          {
            updateTransitionCounts(n_transitions[e], tree_paths[IDEO_ROOT][old_x], -1);
            updateTransitionCounts(Z_transition_hyper[e], tree_paths[IDEO_ROOT][old_x], -1);
          }
          else
          {
            updateTransitionCounts(n_transitions[e], tree_paths[left_x][old_x], -1);
            updateTransitionCounts(Z_transition_hyper[e], tree_paths[left_x][old_x], -1);
          }

          if (right_x != IDEO_ROOT)
          {
            updateTransitionCounts(n_transitions[e], tree_paths[old_x][right_x], -1);
            updateTransitionCounts(Z_transition_hyper[e], tree_paths[old_x][right_x], -1);
          }

          // first half: candidates without restart, second half: with restart
          double[] log_probs = new double[ideo_count + ideo_count - 2];

          // start from 1 coz cannot go to root
          for (int k = 1; k < ideo_count; k++)
          {
            double log_stop = FastMath.log(stop_probs[e][k]);
            double log_emission = FastMath.log(n_ideowords[k][w_di] + emission_hyper[k][w_di]) - FastMath.log(Z_emission_hyper[k]);
            double log_restart = restart_table[terms_lag[e][d][i]];
            double log_norestart = norestart_table[terms_lag[e][d][i]];

            assert (!Double.isNaN(log_stop) && !Double.isInfinite(log_stop));
            assert (!Double.isNaN(log_emission) && !Double.isInfinite(log_emission));
            assert (!Double.isNaN(log_restart) && !Double.isInfinite(log_restart));
            assert (!Double.isNaN(log_norestart) && !Double.isInfinite(log_norestart));

            // reaching k from the root (restart)
            double log_left_restart = logPathProbability(e, tree_paths[IDEO_ROOT][k]);

            // reaching k from the previous state (no restart)
            double log_left_norestart = 0.0;
            if (sample_x[e][d][i - 1] == IDEO_ROOT)
              log_left_norestart = log_left_restart;
            else
              log_left_norestart = logPathProbability(e, tree_paths[left_x][k]);

            // reaching the dependent right neighbor from k
            double log_right_norestart = 0.0;
            if (right_x != IDEO_ROOT)
              log_right_norestart = logPathProbability(e, tree_paths[k][right_x]);

            assert (!Double.isNaN(log_left_restart) && !Double.isInfinite(log_left_restart));
            assert (!Double.isNaN(log_left_norestart) && !Double.isInfinite(log_left_norestart));
            assert (!Double.isNaN(log_right_norestart) && !Double.isInfinite(log_right_norestart));

            log_probs[k - 1] = log_stop + log_emission + log_norestart + log_left_norestart + log_right_norestart;
            log_probs[ideo_count + k - 2] = log_stop + log_emission + log_restart + log_left_restart + log_right_norestart;
          }

          // fall back to the old assignment if the draw below selects nothing
          int new_x = old_x;
          boolean new_r = old_r;

          // normalize probabilities (shift by the maximum before exponentiating)
          double max_value = StatUtils.max(log_probs);
          double total_value = 0.0;
          for (int k = 0; k < log_probs.length; k++)
          {
            log_probs[k] = FastMath.exp(log_probs[k] - max_value);
            total_value += log_probs[k];
          }

          // draw a point in [0, total) and walk the weights until it is used up
          total_value *= R.nextDouble();
          for (int k = 0; k < log_probs.length; k++)
          {
            total_value -= log_probs[k];

            if (total_value <= 0)
            {
              new_x = (k % (ideo_count - 1)) + 1;
              new_r = (k >= ideo_count - 1);
              break;
            }
          }

          assert (total_value <= 0);
          assert (!Double.isNaN(total_value));

          sample_x[e][d][i] = new_x;
          sample_r[e][d][i] = new_r;

          n_endstate[e][new_x]++;
          n_ideowords[new_x][w_di]++;
          Z_emission_hyper[new_x]++;

          if (new_r)
          {
            updateTransitionCounts(n_transitions[e], tree_paths[IDEO_ROOT][new_x], +1);
            updateTransitionCounts(Z_transition_hyper[e], tree_paths[IDEO_ROOT][new_x], +1);
          }
          else
          {
            updateTransitionCounts(n_transitions[e], tree_paths[left_x][new_x], +1);
            updateTransitionCounts(Z_transition_hyper[e], tree_paths[left_x][new_x], +1);
          }

          if (right_x != IDEO_ROOT)
          {
            updateTransitionCounts(n_transitions[e], tree_paths[new_x][right_x], +1);
            updateTransitionCounts(Z_transition_hyper[e], tree_paths[new_x][right_x], +1);
          }
        }

    if (save_samples)
    {
      for (int e = 0; e < epoch_count; e++)
        for (int d = 0; d < terms_index[e].length; d++)
          for (int i = 1; i < terms_index[e][d].length; i++)
          {
            int x = sample_x[e][d][i];
            save_endstate[e][x]++;
            save_ideowords[x][terms_index[e][d][i]]++;
            if (sample_r[e][d][i])
              updateTransitionCounts(save_transitions[e], tree_paths[IDEO_ROOT][x], +1);
            else
              updateTransitionCounts(save_transitions[e], tree_paths[sample_x[e][d][i - 1]][x], +1);
          }
    }

    if (save_samples && save_samples_to_file)
    {
      try
      {
        saveSamplesToFile();
      }
      catch (IOException ex)
      {
        // best effort: keep sampling, but do not hide the failure
        System.err.format("Failed to write samples file %05d in %s: %s%n", save_sample_count, samples_dir, ex);
      }
    }
  }

  /**
   * Sums, over every edge u -> v along {@code path}, the log-probability of
   * continuing at u and taking that edge under the current transition counts of
   * epoch {@code e}. A single-vertex path yields 0.
   */
  private double logPathProbability(int e, int[] path)
  {
    double log_prob = 0.0;
    for (int j = 1; j < path.length; j++)
    {
      int u = path[j - 1];
      int v = path[j];

      log_prob += FastMath.log(continue_probs[e][u]);
      log_prob += FastMath.log(n_transitions[e][u][v] + alpha);
      log_prob -= FastMath.log(Z_transition_hyper[e][u]);
    }
    return log_prob;
  }

  public void MStep()
  {
    for (int e = 0; e < epoch_count; e++)
    {
      stop_probs[e][0] = 0.0; // always continue at root
      continue_probs[e][0] = 1.0;
      for (int v = 1; v < ideo_count; v++)
      {
        int sum = 0;
        for (int u = 0; u < ideo_count; u++)
          sum += save_transitions[e][v][u]; // no of times we actually come out of v

        stop_probs[e][v] = (double) save_endstate[e][v] / (save_endstate[e][v] + sum);
        if (stop_probs[e][v] == 1.0)
          stop_probs[e][v] = 0.999;
        if (stop_probs[e][v] == 0.0)
          stop_probs[e][v] = 0.001;
        if (sum == 0 && save_endstate[e][v] == 0)
          stop_probs[e][v] = 0.5;

        continue_probs[e][v] = 1.0 - stop_probs[e][v];
      }

      // System.out.println(Arrays.toString(stop_probs[e]));
    }

    // m-step for alpha
    MultivariateFunction ll_alpha = new MultivariateFunction()
    {
      @Override
      public double value(double[] point)
      {
        double a = FastMath.exp(point[0]);
        double ll = 0.0;

        for (int e = 0; e < epoch_count; e++)
          for (int u = 0; u < ideo_count; u++)
          {
            double sum = 0.0;
            ll += Gamma.logGamma(a * tree_neighbors[u].length);
            for (int v : tree_neighbors[u])
            {
              ll += Gamma.logGamma(n_transitions[e][u][v] + a);
              sum += n_transitions[e][u][v] + a;
              ll -= Gamma.logGamma(a);
            }
            ll -= Gamma.logGamma(sum);
          }

        // System.out.format("a: %f, ll: %f\n", a, ll);

        return ll;
      }
    };

    MultivariateFunction ll_gamma = new MultivariateFunction()
    {
      @Override
      public double value(double[] point)
      {
        double g = FastMath.exp(point[0]);

        double[] temp_restart_table = new double[1024];
        double[] temp_norestart_table = new double[1024];
        for (int i = 0; i < temp_restart_table.length; i++)
        {
          double r = FastMath.pow(1.0 - g, (i + 1) / 50.0);
          temp_restart_table[i] = FastMath.log(1.0 - r);
          temp_norestart_table[i] = FastMath.log(r);

          if (Double.isNaN(temp_restart_table[i]) || Double.isInfinite(temp_restart_table[i]))
            temp_restart_table[i] = -1e5;
          if (Double.isNaN(temp_norestart_table[i]) || Double.isInfinite(temp_norestart_table[i]))
            temp_norestart_table[i] = -1e5;
        }

        double ll = 0.0;
        for (int e = 0; e < epoch_count; e++)
          for (int d = 0; d < terms_index[e].length; d++)
            for (int i = 1; i < terms_index[e][d].length; i++)
              if (sample_r[e][d][i])
                ll += temp_restart_table[terms_lag[e][d][i]];
              else
                ll += temp_norestart_table[terms_lag[e][d][i]];

        return ll;
      }
    };

    MultivariateFunction ll_betas = new MultivariateFunction()
    {
      @Override
      public double value(double[] point)
      {
        double beta_sage = FastMath.exp(point[0]);
        double beta_def = FastMath.exp(point[1]);
        double ll = 0.0;

        for (int e = 0; e < epoch_count; e++)
          for (int u = 1; u < ideo_count; u++)
          {
            double b_sum = 0.0;
            double f_sum = 0.0;
            for (int w = 0; w < term_count; w++)
            {
              double b = emission_sage[u][w] ? beta_sage : beta_def;

              b_sum += b;
              f_sum += n_ideowords[u][w];

              ll += Gamma.logGamma(n_ideowords[u][w] + b);
              ll -= Gamma.logGamma(b);
            }
            ll += Gamma.logGamma(b_sum);
            ll -= Gamma.logGamma(f_sum + b_sum);
          }

        return ll;
      }
    };

    List<double[]> history = SliceSampler.slice_sample(ll_alpha, new double[] { FastMath.log(alpha) }, new double[] { 0.1 }, 30);
    double new_alpha = FastMath.exp(history.get(history.size() - 1)[0]);
    assert new_alpha > 0;
    if (new_alpha > 10 || new_alpha < 0.1)
    {
      System.err.format("Degenerate alpha %f. ", new_alpha);
      new_alpha = alpha;
      for (double[] state : history)
        if (FastMath.exp(state[0]) < 10 && FastMath.exp(state[0]) > 0.1)
        {
          new_alpha = FastMath.exp(state[0]);
          break;
        }
      System.err.format("Using %f instead.\n", new_alpha);
    }
    for (int e = 0; e < epoch_count; e++)
      for (int u = 0; u < ideo_count; u++)
        Z_transition_hyper[e][u] += (new_alpha - alpha) * (tree_neighbors[u].length - 1);
    alpha = new_alpha;

    // System.out.format("New alpha: %f\n", alpha);
    // for (double[] state : history)
    // System.out.format("%f, ", state[0]);
    // System.out.println();
    history.clear();

    if (gamma > 0.0 && gamma < 1.0)
    {
      history = SliceSampler.slice_sample(ll_gamma, new double[] { FastMath.log(gamma) }, new double[] { 1 }, 30);
      gamma = FastMath.exp(history.get(history.size() - 1)[0]);
      assert (gamma > 0);
      for (int i = 0; i < restart_table.length; i++)
      {
        double r = FastMath.pow(1.0 - gamma, (i + 1) / 50.0);
        restart_table[i] = FastMath.log(1.0 - r);
        norestart_table[i] = FastMath.log(r);

        if (Double.isNaN(restart_table[i]) || Double.isInfinite(restart_table[i]))
          restart_table[i] = -1e5;
        if (Double.isNaN(norestart_table[i]) || Double.isInfinite(norestart_table[i]))
          norestart_table[i] = -1e5;
      }
      // System.out.format("New gamma: %f\n", gamma);
      history.clear();
    }

    history = SliceSampler.slice_sample(ll_betas, new double[] { FastMath.log(emission_prior_ideo), FastMath.log(emission_prior_other) }, new double[] { 1, 1 }, 30);
    double new_beta_sage = FastMath.exp(history.get(history.size() - 1)[0]);
    double new_beta_def = FastMath.exp(history.get(history.size() - 1)[1]);
    if (new_beta_def < 1e-7)
    {
      System.err.format("Degenerate beta_def %f. ", new_beta_def);
      new_beta_def = emission_prior_other;
      for (double[] state : history)
        if (FastMath.exp(state[1]) > 1e-7)
        {
          new_beta_def = FastMath.exp(state[1]);
          break;
        }
      System.err.format("Using %f instead.\n", new_beta_def);
    }
    assert (new_beta_sage > 0);
    assert (new_beta_def > 1e-7);
    for (int u = 0; u < ideo_count; u++)
      for (int w = 0; w < term_count; w++)
      {
        double new_beta = (emission_sage[u][w] ? new_beta_sage : new_beta_def);
        Z_emission_hyper[u] += new_beta - emission_hyper[u][w];
        emission_hyper[u][w] = new_beta;
      }
    emission_prior_ideo = new_beta_sage;
    emission_prior_other = new_beta_def;
    history.clear();
    // System.out.format("New beta_sage: %f\n", emission_prior_ideo);
    // System.out.format("New beta_def: %f\n", emission_prior_other);
  }

  /**
   * Computes the log-likelihood of the currently sampled state sequence, per epoch and in total.
   * <p>
   * Transition and emission distributions are point estimates: the hyperparameters ({@code alpha} and
   * {@code emission_hyper}) are added to {@code save_transitions} and {@code save_ideowords}
   * respectively and each row is normalized. Both tables are then stored in log space, so that every
   * term accumulated into the result (lag term, path term, emission term) is a log-probability.
   *
   * @return a map from each entry of {@code epoches_array} to that epoch's log-likelihood, plus the
   *         sum over all epochs stored under the empty-string key {@code ""}
   */
  public Map<String, Double> computeModelLogLikelihood()
  {
    Map<String, Double> model_ll = new HashMap<>();

    // log of the smoothed, row-normalized transition estimates
    double[][][] log_transition_probs = new double[epoch_count][ideo_count][ideo_count];
    for (int e = 0; e < epoch_count; e++)
      for (int u = 0; u < ideo_count; u++)
      {
        double denom = 0;
        for (int v = 0; v < ideo_count; v++)
        {
          log_transition_probs[e][u][v] = save_transitions[e][u][v] + alpha;
          denom += log_transition_probs[e][u][v];
        }
        // computeLogPathProb adds theta entries directly to log terms, so it must be given logs
        for (int v = 0; v < ideo_count; v++)
          log_transition_probs[e][u][v] = FastMath.log(log_transition_probs[e][u][v] / denom);
      }

    // log of the smoothed, row-normalized emission estimates
    double[][] log_emission_probs = new double[ideo_count][term_count];
    for (int u = 0; u < ideo_count; u++)
    {
      double denom = 0;
      for (int w = 0; w < term_count; w++)
      {
        log_emission_probs[u][w] = save_ideowords[u][w] + emission_hyper[u][w];
        denom += log_emission_probs[u][w];
      }
      for (int w = 0; w < term_count; w++)
        log_emission_probs[u][w] = FastMath.log(log_emission_probs[u][w] / denom);
    }

    double ll = 0.0;
    for (int e = 0; e < epoch_count; e++)
    {
      double epoch_ll = 0.0;
      for (int d = 0; d < terms_index[e].length; d++)
        // index 0 is skipped: the loop scores positions 1..length-1 only
        for (int i = 1; i < terms_index[e][d].length; i++)
        {
          double lag_ll;
          int[] path;
          if (sample_r[e][d][i])
          {
            // restart: the path is taken from IDEO_ROOT to the current state
            lag_ll = restart_table[terms_lag[e][d][i]];
            path = tree_paths[IDEO_ROOT][sample_x[e][d][i]];
          }
          else
          {
            // no restart: the path is taken from the previous position's state
            lag_ll = norestart_table[terms_lag[e][d][i]];
            path = tree_paths[sample_x[e][d][i - 1]][sample_x[e][d][i]];
          }

          epoch_ll += lag_ll + computeLogPathProb(e, log_transition_probs[e], path) + log_emission_probs[sample_x[e][d][i]][terms_index[e][d][i]];
        }

      ll += epoch_ll;
      model_ll.put(epoches_array[e], epoch_ll);
    }

    // the total over all epochs is stored under the empty-string key
    model_ll.put("", ll);

    return model_ll;
  }

  /**
   * Scores a path through the states for one epoch.
   * <p>
   * The score is the log of the stop probability at the path's final node, plus, for every
   * consecutive pair of nodes, the log of the continue probability at the first node of the pair and
   * the matching {@code theta} entry.
   * <p>
   * NOTE(review): {@code theta} entries are added as-is, with no logarithm taken here, so the result
   * is a log-probability only when {@code theta} already holds log-values.
   *
   * @param epoch index into {@code stop_probs} and {@code continue_probs}
   * @param theta per-pair transition scores, indexed as {@code theta[from][to]}
   * @param path sequence of node indices; must contain at least one node
   * @return the accumulated score of the path
   */
  private double computeLogPathProb(int epoch, double[][] theta, int[] path)
  {
    final int last_node = path[path.length - 1];
    double score = FastMath.log(stop_probs[epoch][last_node]);

    // walk every consecutive (from, to) pair along the path
    for (int step = 0; step + 1 < path.length; step++)
    {
      final int from = path[step];
      final int to = path[step + 1];
      score += FastMath.log(continue_probs[epoch][from]) + theta[from][to];
    }

    return score;
  }
}
