package edu.cmu.cs.ark.compuframes.types;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.zip.GZIPInputStream;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.util.FastMath;

/**
 * Base class that restores a sampler state saved as a gzip-compressed Java-serialized stream and
 * derives the log-probability tables used to decode speeches with
 * {@link #computeViterbiPath(int, int)}.
 * <p>
 * Speaker states are ideologies; a move from ideology {@code u} to ideology {@code v} follows the
 * vertex path {@code tree_paths[u][v]}, and a restart moves along a path from {@link #IDEO_ROOT}.
 */
public abstract class SamplerOutputLoader
{
  /** Log-probability floor used in place of log(0), NaN and infinities, and as the "impossible" score. */
  private static final double LOG_ZERO = -1e8;

  /** Number of lag values whose restart / no-restart log-probabilities are cached. */
  private static final int LAG_TABLE_SIZE = 1024;

  /** Divisor applied to (lag + 1) in the exponent of the no-restart probability. */
  private static final double LAG_SCALE = 50.0;

  protected int epoch_count, term_count, ideo_count;
  protected String[] terms_array;
  protected String[] ideologies_array;
  protected String[] epoches_array;
  protected String[][] speeches_array;

  // Arrays are indexed [epoch][speech][position]:
  // terms_index[e][d][0] is __START_OF_SPEECH__
  // terms_index[e][d][length - 1] is the last word of the speech
  // terms_lag[e][d][0] is 0
  // terms_lag[e][d][length - 1] is the lag at the end of the speech
  // terms_lag[e][d][i] is the lag between word i-1 and word i
  protected int[][][] terms_index;
  protected int[][][] terms_lag;

  // model parameters
  protected double[][] emission_hyper;
  protected double[][] stop_probs;
  protected double[][] continue_probs;
  protected double alpha, gamma;

  // samples from the current iteration
  protected int[][][] sample_x;
  protected boolean[][][] sample_r;

  // counts
  protected int[][] n_ideowords;
  protected int[][] n_endstate;
  protected int[][][] n_transitions;

  // cached normalization constants
  protected double[] Z_emission_hyper;
  protected double[][] Z_transition_hyper;

  // saved counts; save_transitions and save_ideowords feed path_probs and emission_probs
  protected int[][] save_ideowords;
  protected int[][] save_endstate;
  protected int[][][] save_transitions;

  // tree_paths[u][v] is the vertex sequence from u to v; its first and last vertices are u and v
  protected int[][][] tree_paths;
  protected static final int IDEO_ROOT = 0;

  // log-probabilities of restarting / not restarting, indexed by lag (0 .. LAG_TABLE_SIZE - 1)
  protected double[] restart_table;
  protected double[] norestart_table;

  // path_probs[e][u][v]: log-probability of moving from ideology u to ideology v in epoch e
  protected double[][][] path_probs;
  // emission_probs[u][w]: log-probability of term w under ideology u
  protected double[][] emission_probs;

  public SamplerOutputLoader()
  {
  }

  /**
   * Restores the sampler state from a gzip-compressed Java-serialized file, then rebuilds the derived
   * tables ({@link #restart_table}, {@link #norestart_table}, {@link #path_probs},
   * {@link #emission_probs}).
   * <p>
   * SECURITY: this uses Java native deserialization, which may run class-defined code while reading.
   * Only load files from a trusted source.
   *
   * @param serialized_file file whose contents follow the field order read by this loader
   * @throws ClassNotFoundException if the class of a serialized object cannot be resolved
   * @throws IOException if the file cannot be read or is not a valid gzip / object stream
   */
  public void loadGzipSerializedData(File serialized_file) throws ClassNotFoundException, IOException
  {
    // Each stream is a separate resource so the file handle is released even when the gzip or
    // object-stream header is invalid, or when a read fails part-way through.
    try (FileInputStream file_in = new FileInputStream(serialized_file);
        GZIPInputStream gzip_in = new GZIPInputStream(file_in);
        ObjectInput bis = new ObjectInputStream(gzip_in))
    {
      readSerializedState(bis);
    }

    buildLagTables();

    // Force the root to always continue and never stop.
    for (int e = 0; e < epoch_count; e++)
    {
      stop_probs[e][IDEO_ROOT] = 0.0;
      continue_probs[e][IDEO_ROOT] = 1.0;
    }

    path_probs = computePathProbabilities();
    emission_probs = computeEmissionProbabilities();
  }

  /** Reads every serialized field in stream order; this order must match the writer exactly. */
  private void readSerializedState(ObjectInput in) throws ClassNotFoundException, IOException
  {
    alpha = in.readDouble();
    gamma = in.readDouble();

    epoches_array = (String[]) in.readObject();
    epoch_count = epoches_array.length;
    speeches_array = (String[][]) in.readObject();

    terms_array = (String[]) in.readObject();
    term_count = terms_array.length;

    ideologies_array = (String[]) in.readObject();
    ideo_count = ideologies_array.length;

    tree_paths = (int[][][]) in.readObject();

    terms_index = (int[][][]) in.readObject();
    terms_lag = (int[][][]) in.readObject();

    stop_probs = (double[][]) in.readObject();
    continue_probs = (double[][]) in.readObject();

    sample_r = (boolean[][][]) in.readObject();
    sample_x = (int[][][]) in.readObject();

    n_endstate = (int[][]) in.readObject();
    n_ideowords = (int[][]) in.readObject();
    n_transitions = (int[][][]) in.readObject();

    emission_hyper = (double[][]) in.readObject();
    Z_emission_hyper = (double[]) in.readObject();
    Z_transition_hyper = (double[][]) in.readObject();

    save_ideowords = (int[][]) in.readObject();
    save_endstate = (int[][]) in.readObject();
    save_transitions = (int[][][]) in.readObject();
  }

  /** Fills restart_table / norestart_table for lags 0 .. LAG_TABLE_SIZE - 1 from the loaded gamma. */
  private void buildLagTables()
  {
    restart_table = new double[LAG_TABLE_SIZE];
    norestart_table = new double[LAG_TABLE_SIZE];
    for (int lag = 0; lag < LAG_TABLE_SIZE; lag++)
    {
      restart_table[lag] = computeLagLogProb(gamma, lag, true);
      norestart_table[lag] = computeLagLogProb(gamma, lag, false);
    }
  }

  /**
   * Computes the restart (when {@code restart} is true) or no-restart log-probability for a lag.
   * With r = (1 - gamma)^((lag + 1) / LAG_SCALE), the no-restart value is log(r) and the restart
   * value is log(1 - r). Non-finite results are floored to LOG_ZERO.
   */
  private static double computeLagLogProb(double gamma, int lag, boolean restart)
  {
    // (lag + 1.0) is evaluated in double so very large lags cannot overflow int
    double no_restart = FastMath.pow(1.0 - gamma, (lag + 1.0) / LAG_SCALE);
    double log_p = FastMath.log(restart ? 1.0 - no_restart : no_restart);
    return (Double.isNaN(log_p) || Double.isInfinite(log_p)) ? LOG_ZERO : log_p;
  }

  /**
   * Looks up a lag log-probability in the cached table. Lags past the end of the table are computed
   * directly rather than indexing out of bounds.
   */
  private double lagLogProb(double[] table, int lag, boolean restart)
  {
    return lag < table.length ? table[lag] : computeLagLogProb(gamma, lag, restart);
  }

  /**
   * Computes the log emission distribution of every ideology. The smoothed count
   * {@code save_ideowords[u][w] + emission_hyper[u][w]} is normalized over {@code w}, then logged.
   *
   * @return {@code [ideology][term]} log-probabilities
   */
  private double[][] computeEmissionProbabilities()
  {
    double[][] probs = new double[ideo_count][term_count];
    for (int u = 0; u < ideo_count; u++)
    {
      double denom = 0;
      for (int w = 0; w < term_count; w++)
      {
        probs[u][w] = save_ideowords[u][w] + emission_hyper[u][w];
        denom += probs[u][w];
      }
      for (int w = 0; w < term_count; w++)
        probs[u][w] = FastMath.log(probs[u][w] / denom);
    }

    return probs;
  }

  /**
   * Rebuilds {@link #tree_paths} from an ideology description file by breadth-first search over the
   * undirected ideology graph. This method is currently unused (hence the suppression).
   * <p>
   * Layout expected by the parser below:
   * <ul>
   * <li>one line that is skipped;</li>
   * <li>{@code ideo_count} lines of ideology names, also skipped (names are taken from
   * {@link #ideologies_array});</li>
   * <li>then one edge per line, given as two ideology names separated by a space.</li>
   * </ul>
   * Blank lines and lines starting with '#' are ignored.
   *
   * @param ideology_file the ideology description file
   * @throws IOException if the file cannot be read, an edge line does not contain two names, or an
   *           edge names an ideology missing from {@link #ideologies_array}
   */
  @SuppressWarnings("unused")
  private void setupTreePaths(File ideology_file) throws IOException
  {
    List<String[]> ideology_pairs = new ArrayList<String[]>();

    try (BufferedReader br = new BufferedReader(new FileReader(ideology_file)))
    {
      br.readLine(); // skip the first line

      for (int i = 0; i < ideo_count; i++)
        br.readLine(); // skip the ideology names

      String line;
      while ((line = br.readLine()) != null)
      {
        line = line.trim();
        if (line.isEmpty() || line.startsWith("#"))
          continue;

        String[] edge = line.split(" ", 2);
        if (edge.length < 2)
          throw new IOException("Malformed edge line in " + ideology_file + ": \"" + line + "\"");
        edge[1] = edge[1].trim();
        ideology_pairs.add(edge);
      }
    }

    boolean[][] adj_matrix = new boolean[ideo_count][ideo_count];

    for (String[] edge : ideology_pairs)
    {
      int i = ArrayUtils.indexOf(ideologies_array, edge[0]);
      int j = ArrayUtils.indexOf(ideologies_array, edge[1]);
      if (i < 0 || j < 0)
        throw new IOException("Unknown ideology in edge \"" + edge[0] + " " + edge[1] + "\" of " + ideology_file);
      adj_matrix[i][j] = true;
      adj_matrix[j][i] = true;
    }

    if (tree_paths == null)
      tree_paths = new int[ideo_count][ideo_count][];

    for (int u = 0; u < ideo_count; u++)
    {
      boolean[] visited = new boolean[ideo_count];
      LinkedList<Integer> queue = new LinkedList<>();

      tree_paths[u][u] = new int[] { u };
      visited[u] = true;
      queue.add(u);

      while (!queue.isEmpty())
      {
        int v = queue.removeFirst();
        for (int w = 0; w < ideo_count; w++)
        {
          if (adj_matrix[v][w] && !visited[w])
          {
            // First discovery in BFS order fixes the path u -> ... -> v -> w.
            visited[w] = true;
            tree_paths[u][w] = ArrayUtils.add(tree_paths[u][v], w);
            queue.add(w);
          }
        }
      }
    }
  }

  /**
   * Compiles the per-edge transitions into a speaker-state transition matrix for each epoch.
   * <p>
   * Edge transitions are {@code save_transitions[e][u][v] + alpha}, normalized over {@code v}. The
   * probability of moving from state {@code u} to state {@code v} is built along the vertex path
   * {@code tree_paths[u][v]}: each departing vertex contributes its {@code continue_probs} times the
   * edge transition probability, and the result is multiplied by {@code stop_probs[e][v]}. Each row is
   * then renormalized (some mass is lost to illegal paths) and log-transformed.
   * <p>
   * Not very useful on its own, because nearby states receive much higher probabilities.
   *
   * @return {@code [epoch][from][to]} log-probabilities
   */
  private double[][][] computePathProbabilities()
  {
    double[][][] transition_probs = new double[epoch_count][ideo_count][ideo_count];
    for (int e = 0; e < epoch_count; e++)
      for (int u = 0; u < ideo_count; u++)
      {
        double denom = 0;
        for (int v = 0; v < ideo_count; v++)
        {
          transition_probs[e][u][v] = save_transitions[e][u][v] + alpha;
          denom += transition_probs[e][u][v];
        }
        for (int v = 0; v < ideo_count; v++)
          transition_probs[e][u][v] /= denom;
      }

    double[][][] state_probs = new double[epoch_count][ideo_count][ideo_count];
    for (int e = 0; e < epoch_count; e++)
      for (int u = 0; u < ideo_count; u++)
      {
        for (int v = 0; v < ideo_count; v++)
        {
          int[] path = tree_paths[u][v];

          double prob = stop_probs[e][v];
          for (int j = 1; j < path.length; j++)
            prob *= continue_probs[e][path[j - 1]] * transition_probs[e][path[j - 1]][path[j]];

          state_probs[e][u][v] = prob;
        }
        double denom = StatUtils.sum(state_probs[e][u]);
        for (int v = 0; v < ideo_count; v++)
          state_probs[e][u][v] = FastMath.log(state_probs[e][u][v] / denom);
      }

    return state_probs;
  }

  /**
   * Decodes the most likely sequence of speaker states for speech {@code d} of epoch {@code e}.
   * <p>
   * The lattice has {@code 2 * ideo_count} states per position:
   * <ul>
   * <li>state {@code u} means ideology {@code u} was reached by following the path from the previous
   * state;</li>
   * <li>state {@code u + ideo_count} means it was reached by restarting along the path from
   * {@link #IDEO_ROOT}.</li>
   * </ul>
   * Position 0 is pinned to the root.
   *
   * @param e epoch index
   * @param d speech index within the epoch
   * @return one (ideology, restarted) pair per position of {@code terms_index[e][d]}; empty if the
   *         speech has no positions
   */
  protected List<Pair<Integer, Boolean>> computeViterbiPath(int e, int d)
  {
    int N = ideo_count + ideo_count;
    int doc_length = terms_index[e][d].length;
    if (doc_length == 0)
      return new ArrayList<>();

    double[][] viterbi_matrix = new double[doc_length][N];
    int[][] back_matrix = new int[doc_length][N];

    for (int u = 0; u < N; u++)
      viterbi_matrix[0][u] = LOG_ZERO;
    // Root states (0 and ideo_count) are never re-entered after position 0.
    for (int i = 1; i < doc_length; i++)
    {
      viterbi_matrix[i][0] = LOG_ZERO;
      viterbi_matrix[i][ideo_count] = LOG_ZERO;
    }

    viterbi_matrix[0][IDEO_ROOT] = 0.0;

    for (int i = 1; i < doc_length; i++)
    {
      int w_di = terms_index[e][d][i];
      int lag = terms_lag[e][d][i];
      double log_norestart = lagLogProb(norestart_table, lag, false);
      double log_restart = lagLogProb(restart_table, lag, true);

      for (int u = 1; u < ideo_count; u++)
      {
        double emit = emission_probs[u][w_di];
        double from_root = path_probs[e][IDEO_ROOT][u];

        // NOTE(review): candidates scoring at or below LOG_ZERO never beat these initial values, so a
        // back-pointer can stay -1; only the asserts below (enabled with -ea) catch that.
        double max_p_norestart = LOG_ZERO;
        int max_v_norestart = -1;
        double max_p_restart = LOG_ZERO;
        int max_v_restart = -1;

        for (int v = 0; v < N; v++)
        {
          if (v == ideo_count)
            continue; // the restarted-root state only ever holds LOG_ZERO

          double p_norestart = viterbi_matrix[i - 1][v] + path_probs[e][v % ideo_count][u] + emit;
          if (p_norestart > max_p_norestart)
          {
            max_p_norestart = p_norestart;
            max_v_norestart = v;
          }

          double p_restart = viterbi_matrix[i - 1][v] + from_root + emit;
          if (p_restart > max_p_restart)
          {
            max_p_restart = p_restart;
            max_v_restart = v;
          }
        }

        viterbi_matrix[i][u] = max_p_norestart + log_norestart;
        back_matrix[i][u] = max_v_norestart;

        viterbi_matrix[i][u + ideo_count] = max_p_restart + log_restart;
        back_matrix[i][u + ideo_count] = max_v_restart;

        assert (!Double.isNaN(max_p_norestart));
        assert (!Double.isNaN(max_p_restart));
        assert (max_v_restart >= 0);
        assert (max_v_norestart >= 0);
      }
    }

    // Pick the best final state (ties keep the lowest index), then follow back-pointers.
    int[] path = new int[doc_length];
    for (int u = 1; u < N; u++)
      if (viterbi_matrix[doc_length - 1][u] > viterbi_matrix[doc_length - 1][path[doc_length - 1]])
        path[doc_length - 1] = u;

    for (int i = doc_length - 2; i >= 0; i--)
      path[i] = back_matrix[i + 1][path[i + 1]];

    List<Pair<Integer, Boolean>> speaker_states = new ArrayList<>(doc_length);
    for (int i = 0; i < doc_length; i++)
      speaker_states.add(new ImmutablePair<Integer, Boolean>(path[i] % ideo_count, path[i] >= ideo_count));

    return speaker_states;
  }

}
