package edu.cmu.cs.ark.compuframes;

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.FalseFileFilter;
import org.apache.commons.io.filefilter.TrueFileFilter;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.stat.StatUtils;

import com.martiansoftware.jsap.JSAP;
import com.martiansoftware.jsap.JSAPException;
import com.martiansoftware.jsap.JSAPResult;
import com.martiansoftware.jsap.Switch;
import com.martiansoftware.jsap.UnflaggedOption;
import com.martiansoftware.jsap.stringparsers.FileStringParser;

import edu.cmu.cs.ark.compuframes.types.SamplerOutputLoader;
import edu.cmu.cs.ark.yc.config.AppConfig;

/**
 * Command-line app that loads the serialized sampler state from the highest-numbered iteration
 * directory of an Etch A Sketch sampler run and prints analyses of its Viterbi paths to standard
 * output. Which analyses run is selected with command-line switches (see {@link #main(String[])}).
 */
public class EtchASketchAnalysisApp extends SamplerOutputLoader
{
  /** Parsed command-line options; {@link #main(String[])} reads the switches from here. */
  private final JSAPResult jsap_results;

  /**
   * Parses the command line, locates the latest sampler iteration directory and loads its
   * {@code state.serialized.gz}.
   * <p>
   * Any failure (option registration, argument parsing, no iteration directory found, loading the
   * state) is reported on standard error and terminates the JVM with exit status -1.
   *
   * @param args command-line arguments: the sampler output folder plus optional analysis switches
   */
  public EtchASketchAnalysisApp(String[] args)
  {
    AppConfig app_config = new AppConfig("EtchASketchAnalysisApp", "App to analyze output of Etch A Sketch sampler.", false, true);

    try
    {
      app_config.registerParameter(new UnflaggedOption("output-folder", FileStringParser.getParser().setMustBeDirectory(true).setMustExist(true), true, "Location of sampler output directory."));
      app_config.registerParameter(new Switch("ideology-lags", JSAP.NO_SHORTFLAG, "ideology-lags", "Compute amount of lag time spent in each ideology."));
      app_config.registerParameter(new Switch("viterbi-states", JSAP.NO_SHORTFLAG, "viterbi-states", "Output viterbi states for all speeches."));
      app_config.registerParameter(new Switch("restart-correlations", JSAP.NO_SHORTFLAG, "restart-correlations", "Compute correlations between states and restarting."));
      app_config.registerParameter(new Switch("transition-stats", JSAP.NO_SHORTFLAG, "transition-stats", "Compute counts and relative frequencies of transitions between states."));
    }
    catch (JSAPException e)
    {
      System.err.println("Error registering argument options!");
      e.printStackTrace();
      System.exit(-1);
    }

    jsap_results = app_config.parse(args);

    if (!jsap_results.success())
    {
      // getErrorMessageIterator() returns a raw Iterator; an unbounded wildcard accepts it without
      // a raw-type warning and without needing a suppression annotation.
      Iterator<?> iter = jsap_results.getErrorMessageIterator();
      while (iter.hasNext())
        System.err.println(iter.next());
      System.exit(-1);
    }

    File last_iter_dir = findLastIterationDir(jsap_results.getFile("output-folder"));
    if (last_iter_dir == null)
    {
      System.err.println("Error finding iteration directories!");
      System.exit(-1);
    }

    try
    {
      System.err.println("Loading from " + last_iter_dir + "...");
      loadGzipSerializedData(new File(last_iter_dir, "state.serialized.gz"));
    }
    catch (Exception e)
    {
      e.printStackTrace();
      System.exit(-1);
    }
  }

  /**
   * Finds the numerically named directory with the largest non-negative number under
   * {@code output_folder}.
   * <p>
   * {@code FileUtils.listFilesAndDirs} is called with a {@code TrueFileFilter} directory filter.
   * It therefore recurses into nested directories and also returns {@code output_folder} itself,
   * so a numerically named directory at any depth is a candidate. Directories whose names do not
   * parse as an int are skipped.
   *
   * @param output_folder the sampler output directory to search
   * @return the directory with the largest iteration number, or {@code null} if none was found
   */
  private static File findLastIterationDir(File output_folder)
  {
    File last_iter_dir = null;
    int largest_iter = -1;
    for (File iter_dir : FileUtils.listFilesAndDirs(output_folder, FalseFileFilter.INSTANCE, TrueFileFilter.INSTANCE))
    {
      int iter;
      try
      {
        iter = Integer.parseInt(iter_dir.getName());
      }
      catch (NumberFormatException ignored)
      {
        continue; // not an iteration directory
      }

      if (iter > largest_iter)
      {
        largest_iter = iter;
        last_iter_dir = iter_dir;
      }
    }
    return last_iter_dir;
  }

  /**
   * Prints, per epoch, the lag attributed to each ideology (indices 1 and up) and that lag's share
   * of the epoch total.
   * <p>
   * Each term position {@code i >= 1} on a speech's {@code computeViterbiPath} result contributes
   * the mean of {@code terms_lag[e][d][i]} and {@code terms_lag[e][d][i + 1]} to the ideology
   * assigned at {@code i}. The share's denominator sums all ideology indices, including index 0,
   * which is not printed. Columns: epoch, ideology, lag, share. A blank line follows each epoch.
   */
  private void computeIdeologyLags()
  {
    for (int e = 0; e < epoch_count; e++)
    {
      double[] ideology_lags = new double[ideo_count];
      for (int d = 0; d < terms_index[e].length; d++)
      {
        List<Pair<Integer, Boolean>> states = computeViterbiPath(e, d);

        for (int i = 1; i < terms_index[e][d].length; i++)
        {
          int x = states.get(i).getLeft();
          ideology_lags[x] += 0.5 * (terms_lag[e][d][i] + terms_lag[e][d][i + 1]);
        }
      }

      double total_lags = StatUtils.sum(ideology_lags);
      for (int u = 1; u < ideo_count; u++)
      {
        // Avoid printing NaN/Infinity for an epoch with zero total lag, mirroring the guard used
        // for empty states in computeRestartCorrelation.
        double share = total_lags != 0.0 ? ideology_lags[u] / total_lags : 0.0;
        System.out.format("%s\t%s\t%.1f\t%f\n", epoches_array[e], ideologies_array[u], ideology_lags[u], share);
      }
      System.out.println();
    }
  }

  /**
   * Prints one line per speech with its Viterbi path.
   * <p>
   * Each line holds the epoch and speech name, then tab-separated {@code term:lag:ideology:flag}
   * tokens. The tokens are bracketed by a {@code __START_OF_SPEECH__} pseudo-token (labelled with
   * the root ideology) and an {@code __END_OF_SPEECH__} pseudo-token carrying the final lag.
   * {@code flag} is 1 when the path's boolean at that position is true. A blank line follows each
   * epoch.
   */
  private void displayViterbiStates()
  {
    for (int e = 0; e < epoch_count; e++)
    {
      for (int d = 0; d < terms_index[e].length; d++)
      {
        List<Pair<Integer, Boolean>> states = computeViterbiPath(e, d);

        System.out.format("%s\t%s\t", epoches_array[e], speeches_array[e][d]);
        System.out.format("%s:%d:%s:%d\t", "__START_OF_SPEECH__", 0, ideologies_array[IDEO_ROOT], 0);
        for (int i = 1; i < terms_index[e][d].length; i++)
        {
          int x = states.get(i).getLeft();
          boolean r = states.get(i).getRight();
          System.out.format("%s:%d:%s:%d\t", terms_array[terms_index[e][d][i]], terms_lag[e][d][i], ideologies_array[x], r ? 1 : 0);
        }
        System.out.format("%s:%d:%s:%d\n", "__END_OF_SPEECH__", terms_lag[e][d][terms_index[e][d].length], "", 0);
      }
      System.out.println();
    }
  }

  /**
   * Prints, per epoch and per ideology (indices 1 and up), how often the Viterbi path's boolean
   * flag was false and true while in that ideology, plus a derived ratio.
   * <p>
   * Columns: epoch, ideology, flag-false count, flag-true count, ratio, normalized ratio. The ratio
   * is the flag-false count divided by the total, or 0 for an ideology never visited. The
   * normalized ratio divides it by the sum of ratios over all printed ideologies. A blank line
   * follows each epoch.
   */
  private void computeRestartCorrelation()
  {
    for (int e = 0; e < epoch_count; e++)
    {
      // [x][0] counts positions whose flag is false, [x][1] positions whose flag is true.
      int[][] ideology_restarts = new int[ideo_count][2];
      for (int d = 0; d < terms_index[e].length; d++)
      {
        List<Pair<Integer, Boolean>> states = computeViterbiPath(e, d);

        for (int i = 1; i < terms_index[e][d].length; i++)
        {
          int x = states.get(i).getLeft();
          boolean r = states.get(i).getRight();

          ideology_restarts[x][r ? 1 : 0]++;
        }
      }

      double[] restart_probs = new double[ideo_count];
      for (int u = 1; u < ideo_count; u++)
      {
        double total = ideology_restarts[u][0] + ideology_restarts[u][1];
        // NOTE(review): the numerator is the flag-FALSE count. The flag appears to mark a restart
        // (computeTransitionsStats renders it as "##"), so "restart_probs" may be meant to use
        // [u][1]. Confirm before relying on this column.
        restart_probs[u] = total > 0 ? ideology_restarts[u][0] / total : 0.0;
      }
      double total_restart_probs = StatUtils.sum(restart_probs);
      for (int u = 1; u < ideo_count; u++)
      {
        double normalized = total_restart_probs != 0.0 ? restart_probs[u] / total_restart_probs : 0.0;
        System.out.format("%s\t%s\t%d\t%d\t%f\t%f\n", epoches_array[e], ideologies_array[u], ideology_restarts[u][0], ideology_restarts[u][1], restart_probs[u], normalized);
      }
      System.out.println();
    }
  }

  /**
   * Prints, per epoch, the observed transitions between consecutive Viterbi states, most frequent
   * first.
   * <p>
   * Transitions are counted from position {@code i - 1} to {@code i} for {@code i >= 2}. A
   * transition is labelled {@code "from->to"} when the flag at {@code i} is false and
   * {@code "from##to"} when it is true. Only non-zero counts are printed. Columns: epoch, label,
   * count, fraction of all transitions in the epoch. A blank line follows each epoch.
   */
  private void computeTransitionsStats()
  {
    for (int e = 0; e < epoch_count; e++)
    {
      int[][][] transition_counts = new int[ideo_count][2][ideo_count];
      int total = 0;

      for (int d = 0; d < terms_index[e].length; d++)
      {
        List<Pair<Integer, Boolean>> states = computeViterbiPath(e, d);

        for (int i = 2; i < terms_index[e][d].length; i++)
        {
          int x_prev = states.get(i - 1).getLeft();
          int x = states.get(i).getLeft();
          boolean r = states.get(i).getRight();

          transition_counts[x_prev][r ? 1 : 0][x]++;
          total++;
        }
      }

      List<Pair<String, Integer>> transition_list = new ArrayList<>(2 * ideo_count * ideo_count);
      for (int u = 0; u < ideo_count; u++)
        for (int v = 0; v < ideo_count; v++)
        {
          transition_list.add(new ImmutablePair<String, Integer>(ideologies_array[u] + "->" + ideologies_array[v], transition_counts[u][0][v]));
          transition_list.add(new ImmutablePair<String, Integer>(ideologies_array[u] + "##" + ideologies_array[v], transition_counts[u][1][v]));
        }

      // Descending by count. Integer.compare avoids the overflow hazard of subtracting ints, and
      // Collections.sort is stable, so ties keep their insertion order.
      Collections.sort(transition_list, new Comparator<Pair<String, Integer>>()
      {
        @Override
        public int compare(Pair<String, Integer> o1, Pair<String, Integer> o2)
        {
          return Integer.compare(o2.getRight(), o1.getRight());
        }
      });

      // A printed entry has a positive count, so total > 0 whenever the division below runs.
      for (Pair<String, Integer> trans : transition_list)
        if (trans.getRight() > 0)
          System.out.format("%s\t%s\t%d\t%f\n", epoches_array[e], trans.getLeft(), trans.getRight(), (double) trans.getRight() / total);
      System.out.println();
    }
  }

  /**
   * Entry point: loads the latest sampler iteration and runs each analysis whose switch is set.
   * The analyses run in this order: {@code --ideology-lags}, {@code --viterbi-states},
   * {@code --restart-correlations}, {@code --transition-stats}.
   *
   * @param args command-line arguments (see the constructor)
   */
  public static void main(String[] args)
  {
    EtchASketchAnalysisApp app = new EtchASketchAnalysisApp(args);

    if (app.jsap_results.getBoolean("ideology-lags"))
      app.computeIdeologyLags();

    if (app.jsap_results.getBoolean("viterbi-states"))
      app.displayViterbiStates();

    if (app.jsap_results.getBoolean("restart-correlations"))
      app.computeRestartCorrelation();

    if (app.jsap_results.getBoolean("transition-stats"))
      app.computeTransitionsStats();
  }
}
