/**
 * 
 */
package edu.cmu.cs.ark.compuframes;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Map;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.LineIterator;
import org.apache.commons.io.filefilter.FalseFileFilter;
import org.apache.commons.io.filefilter.TrueFileFilter;
import org.apache.commons.io.output.TeeOutputStream;

import com.martiansoftware.jsap.FlaggedOption;
import com.martiansoftware.jsap.JSAP;
import com.martiansoftware.jsap.JSAPException;
import com.martiansoftware.jsap.UnflaggedOption;
import com.martiansoftware.jsap.stringparsers.FileStringParser;

import edu.cmu.cs.ark.compuframes.types.EpochData;
import edu.cmu.cs.ark.compuframes.types.TermWeights;
import edu.cmu.cs.ark.yc.config.AppConfig;

/**
 * Command-line entry point for the Etch A Sketch Gibbs sampler.
 * <p>
 * The constructor parses the command line, prepares the output directory (unless resuming) and loads the terms,
 * ideology and speech data into an {@link AppGlobal}. {@link #run()} then drives the sampling iterations. After each
 * iteration it appends the model log-likelihood to {@code likelihood.txt`, and it periodically serializes the sampler
 * state through {@link AppState}.
 *
 * @author Yanchuan Sim
 * @version 0.1
 * @since 0.1
 */
public class EtchASketchSamplerApp
{
  private final AppGlobal app_data;
  private final AppState app_state;

  /**
   * Parses the command line and loads all input data.
   * <p>
   * Exits the JVM with status -1 if the arguments are invalid. When not resuming, the output directory is populated
   * with a {@code settings.txt} file, copies of the terms and ideology files, and the sample directory.
   *
   * @param args command-line arguments
   * @throws IOException if an input file cannot be read or an output file cannot be written
   */
  public EtchASketchSamplerApp(String[] args) throws IOException
  {
    AppConfig app_config = createAppConfig();
    AppSettings app_settings = new AppSettings(app_config, args);
    if (!app_settings.isValid())
    {
      app_config.displayHelp(app_settings.getJSAPResults());
      System.exit(-1);
    }

    app_data = new AppGlobal();
    app_data.SETTINGS = app_settings;

    app_state = new AppState(app_settings.OUTPUT_DIR);
    // Append when resuming so the log of earlier iterations is not truncated (likelihood.txt is appended as well).
    FileOutputStream log_file = new FileOutputStream(app_state.getBaseDirFile("logging.txt"), app_settings.isResume());
    app_data.LOGGING = new PrintStream(new TeeOutputStream(System.err, log_file), true);
    if (!app_settings.isResume())
    {
      createOutputDirectory(app_config);
      app_data.LOGGING.format("[Init] Initialized output directory <%s>.\n", app_settings.OUTPUT_DIR.getPath());
    }

    app_data.SETTINGS.displaySettings(app_data.LOGGING);

    loadTermsData(app_settings.TERMS_FILE);
    app_data.LOGGING.format("[Init] Loaded %d terms from <%s>.\n", app_data.TERMS.size(), app_settings.TERMS_FILE.getPath());

    loadIdeologyData(app_settings.IDEOLOGY_FILE);
    app_data.LOGGING.format("[Init] Loaded %d ideologies and %d ideology edges from <%s>.\n", app_data.IDEOLOGIES.size(), app_data.IDEOLOGY_PAIRS.size(), app_settings.IDEOLOGY_FILE.getPath());

    int speeches_count = loadSpeechData(app_settings.DATA_DIR);
    app_data.LOGGING.format("[Init] Loaded %d epoches (total %d speeches) and from <%s>.\n", app_data.EPOCHES.size(), speeches_count, app_settings.DATA_DIR.getPath());
  }

  /**
   * Runs the Gibbs sampler for the configured number of iterations.
   * <p>
   * When resuming and a saved state exists, sampling continues from the iteration after the latest saved one.
   * Otherwise the initial state is saved as iteration 0 and sampling starts at iteration 1. If the saved state cannot
   * be loaded, an error is logged and the method returns without sampling.
   *
   * @throws IOException if a state or output file cannot be read or written
   * @throws NumberFormatException if the latest state directory name is not an integer
   * @throws ClassNotFoundException presumably propagated from deserializing the saved sampler state
   */
  public void run() throws IOException, NumberFormatException, ClassNotFoundException
  {
    GibbsSampler gs = new GibbsSampler(app_data);
    double run_time_start = System.nanoTime();

    int start_from = 1;

    if (app_data.SETTINGS.isResume() && app_state.gotoLatestState())
    {
      if (gs.loadState(app_state.getCurStateFile("state.serialized.gz")))
      {
        // State directories are named by their iteration number.
        start_from = Integer.parseInt(app_state.getCurStateDirectory().getName());
        app_data.LOGGING.format("[Init] Resuming from serialized information in iteration %d...\n", start_from);
        start_from++;
      }
      else
      {
        app_data.LOGGING.format("[Init] ERROR while resuming from serialized information in iteration %d...\n", Integer.parseInt(app_state.getCurStateDirectory().getName()));
        return;
      }
    }
    else
    {
      app_data.LOGGING.println("[Save] Saving system state...");
      app_state.createState(0);
      gs.saveState(app_state);
    }

    app_data.LOGGING.format("[Time] running=%f seconds\n", (System.nanoTime() - run_time_start) / 1e9);
    app_data.LOGGING.println();

    // Report progress roughly every 10% of the samples. Clamp to 1 so that fewer than 10 samples per iteration
    // does not cause a division by zero in the modulo below.
    int progress_interval = Math.max(1, app_data.SETTINGS.SAMPLE_COUNT / 10);

    for (int iter = start_from; iter <= app_data.SETTINGS.ITERATIONS; iter++)
    {
      double iter_time_start = System.nanoTime();

      app_data.LOGGING.format("[Iter] Iteration %d%s\n", iter, iter < app_data.SETTINGS.BURNIN_ITERS ? " (burn-in iter)" : "");
      app_data.LOGGING.format("\t[E-step] Gibbs sampling...");

      // Discard samples accumulated during the burn-in iterations.
      if (iter == app_data.SETTINGS.BURNIN_ITERS)
        gs.resetSamples();

      for (int sample_iter = 0; sample_iter < app_data.SETTINGS.SAMPLE_COUNT; sample_iter++)
      {
        // Only every SAMPLE_INTERVALS-th sample after the within-iteration burn-in is collected.
        if (sample_iter >= app_data.SETTINGS.BURNIN_SAMPLES && sample_iter % app_data.SETTINGS.SAMPLE_INTERVALS == 0)
          gs.samplingStep(true, iter >= app_data.SETTINGS.BURNIN_ITERS);
        else
          gs.samplingStep(false, false);

        if ((sample_iter + 1) % progress_interval == 0)
          app_data.LOGGING.format("%d..", sample_iter + 1);
      }

      // Check the divisor before using it so a non-positive MSTEP_ITERS cannot cause a division by zero.
      if (app_data.SETTINGS.MSTEP_ITERS > 0 && iter % app_data.SETTINGS.MSTEP_ITERS == 0)
      {
        app_data.LOGGING.print("m-step..");
        gs.MStep();
      }

      app_data.LOGGING.println("done!");

      Map<String, Double> model_ll = gs.computeModelLogLikelihood();
      app_data.LOGGING.format("\t[Stat] Model LL=%s\n", model_ll);
      app_data.LOGGING.format("\t[Time] iteration=%f seconds, running=%f seconds\n", (System.nanoTime() - iter_time_start) / 1e9, (System.nanoTime() - run_time_start) / 1e9);
      app_data.LOGGING.println();

      appendLikelihood(iter, model_ll);

      if (iter % app_data.SETTINGS.SAVE_INTERVAL == 0)
      {
        app_data.LOGGING.println("[Save] Saving system state...");
        app_state.createState(iter);
        gs.saveState(app_state);
        app_data.LOGGING.println();
      }
    }
  }

  /**
   * Appends one line to {@code likelihood.txt} in the base output directory. The line contains the iteration number,
   * a tab, and the value stored under the empty-string key of {@code model_ll}.
   *
   * @param iter current iteration number
   * @param model_ll log-likelihood map returned by the sampler
   * @throws IOException if the file cannot be written
   */
  private void appendLikelihood(int iter, Map<String, Double> model_ll) throws IOException
  {
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(app_state.getBaseDirFile("likelihood.txt"), true)))
    {
      bw.write(String.format("%d\t%f\n", iter, model_ll.get("")));
    }
  }

  /**
   * Creates one {@link EpochData} for every directory found under {@code data_dir} (excluding {@code data_dir}
   * itself) and stores them in {@code app_data.EPOCHES}.
   * <p>
   * NOTE(review): with {@code TrueFileFilter} as the directory filter, {@code FileUtils.listFilesAndDirs} recurses, so
   * nested sub-directories are also treated as epochs. The order follows the file-system listing, not a sort. Confirm
   * whether epochs must be in a particular order.
   *
   * @param data_dir directory containing one sub-directory per epoch
   * @return total number of speeches across all epochs
   * @throws IOException if an epoch directory cannot be read
   */
  private int loadSpeechData(File data_dir) throws IOException
  {
    app_data.EPOCHES = new ArrayList<>();

    for (File epoch_dir : FileUtils.listFilesAndDirs(data_dir, FalseFileFilter.INSTANCE, TrueFileFilter.INSTANCE))
    {
      // listFilesAndDirs includes the starting directory itself.
      if (epoch_dir.equals(data_dir))
        continue;

      app_data.EPOCHES.add(new EpochData(epoch_dir));
    }

    int speeches_count = 0;
    for (EpochData epoch : app_data.EPOCHES)
      speeches_count += epoch.speech_titles.size();

    return speeches_count;
  }

  /**
   * Reads the ideology file into {@code app_data.IDEOLOGIES} and {@code app_data.IDEOLOGY_PAIRS}.
   * <p>
   * Expected format: the first line is the number of ideologies N, and the next N lines are ideology names. Every
   * following non-empty line that does not start with {@code #} is split at its first space into a two-element edge.
   *
   * @param ideology_file file to read
   * @throws IOException if the file cannot be read, is empty, or has fewer ideology lines than declared
   */
  private void loadIdeologyData(File ideology_file) throws IOException
  {
    app_data.IDEOLOGIES = new ArrayList<String>();
    app_data.IDEOLOGY_PAIRS = new ArrayList<String[]>();

    try (BufferedReader br = new BufferedReader(new FileReader(ideology_file)))
    {
      String header = br.readLine();
      if (header == null)
        throw new IOException("Ideology file <" + ideology_file.getPath() + "> is empty.");

      int ideo_count = Integer.parseInt(header.trim());

      for (int i = 0; i < ideo_count; i++)
      {
        String ideology = br.readLine();
        if (ideology == null)
          throw new IOException(String.format("Ideology file <%s> ended after %d of %d ideologies.", ideology_file.getPath(), i, ideo_count));
        app_data.IDEOLOGIES.add(ideology);
      }

      String line;
      while ((line = br.readLine()) != null)
      {
        if (line.isEmpty() || line.startsWith("#"))
          continue;
        app_data.IDEOLOGY_PAIRS.add(line.split(" ", 2));
      }
    }
  }

  /**
   * Reads the terms file into {@code app_data.TERMS}. Each non-empty line that does not start with {@code #} is
   * parsed by {@link TermWeights}.
   *
   * @param terms_file file to read
   * @throws IOException if the file cannot be read
   */
  private void loadTermsData(File terms_file) throws IOException
  {
    app_data.TERMS = new ArrayList<>();

    LineIterator it = FileUtils.lineIterator(terms_file);
    try
    {
      while (it.hasNext())
      {
        String line = it.nextLine();
        if (line.startsWith("#") || line.isEmpty())
          continue;

        app_data.TERMS.add(new TermWeights(line));
      }
    }
    finally
    {
      // Release the underlying reader even if a line fails to parse.
      it.close();
    }
  }

  /**
   * Builds the command-line configuration with every option this application accepts. Exits the JVM with status -1 if
   * an option cannot be registered.
   *
   * @return the populated configuration
   */
  private AppConfig createAppConfig()
  {
    AppConfig app_config = new AppConfig("EtchASketchSamplerApp", "Etch A Sketch sampler.", true, true);

    try
    {
      app_config.registerParameter(new UnflaggedOption("name", JSAP.STRING_PARSER, false, "Name of this sampling run. (default: output directory name)"));

      app_config.registerParameter(new FlaggedOption("data-dir", FileStringParser.getParser().setMustBeDirectory(true).setMustExist(true), JSAP.NO_DEFAULT, false, 'D', "data-dir", "Directory containing our speech-lag data."));
      app_config.registerParameter(new FlaggedOption("output-dir", FileStringParser.getParser().setMustBeDirectory(true).setMustExist(true), JSAP.NO_DEFAULT, false, 'o', "output-dir", "Directory to save progress of Gibbs sampler."));
      app_config.registerParameter(new FlaggedOption("resume-dir", FileStringParser.getParser().setMustBeDirectory(true).setMustExist(true), JSAP.NO_DEFAULT, false, 'r', "resume-dir", "Resume Gibbs sampling with data in this directory."));
      // app_config.registerParameter(new FlaggedOption("sample-dir", FileStringParser.getParser().setMustBeDirectory(true).setMustExist(true), JSAP.NO_DEFAULT, false, 's', "sample-dir", "Directory to save samples in."));

      app_config.registerParameter(new FlaggedOption("terms-file", FileStringParser.getParser().setMustBeFile(true).setMustExist(true), JSAP.NO_DEFAULT, false, JSAP.NO_SHORTFLAG, "terms-file", "File containing terms used in speeches (and their prior weights)."));
      app_config.registerParameter(new FlaggedOption("ideology-file", FileStringParser.getParser().setMustBeFile(true).setMustExist(true), JSAP.NO_DEFAULT, false, JSAP.NO_SHORTFLAG, "ideology-file", "File containing ideologies used in speeches (and their structure)."));

      // Hyperparameters
      app_config.registerParameter(new FlaggedOption("alpha", AppConfig.POSITIVE_DOUBLE_PARSER, "1", false, JSAP.NO_SHORTFLAG, "alpha", "Hyperparameter for transition probabilities."));
      app_config.registerParameter(new FlaggedOption("gamma", AppConfig.DOUBLE_PARSER, "0.5", false, JSAP.NO_SHORTFLAG, "gamma", "Hyperparameter for restart probabilities."));
      app_config.registerParameter(new FlaggedOption("emission-ideo", AppConfig.DOUBLE_PARSER, "1", false, JSAP.NO_SHORTFLAG, "emission-ideo", "Emission prior for terms that are ideologically loaded."));
      app_config.registerParameter(new FlaggedOption("emission-other", AppConfig.DOUBLE_PARSER, "0.001", false, JSAP.NO_SHORTFLAG, "emission-other", "Emission prior for terms."));

      // Gibbs sampler stuff
      app_config.registerParameter(new FlaggedOption("save-interval", AppConfig.POSITIVE_INTEGER_PARSER, "5", false, JSAP.NO_SHORTFLAG, "save-interval", "How often to save state of sampler."));
      app_config.registerParameter(new FlaggedOption("iterations", AppConfig.POSITIVE_INTEGER_PARSER, "100", false, JSAP.NO_SHORTFLAG, "iterations", "Number of iterations to run for."));
      app_config.registerParameter(new FlaggedOption("burnin-iters", AppConfig.INTEGER_PARSER, "0", false, JSAP.NO_SHORTFLAG, "burnin-iters", "Number of iterations to ignore for burn-in."));
      app_config.registerParameter(new FlaggedOption("burnin-samples", AppConfig.INTEGER_PARSER, "200", false, JSAP.NO_SHORTFLAG, "burnin-samples", "Number of samples to discard for burn-in."));
      app_config.registerParameter(new FlaggedOption("mstep-iters", AppConfig.POSITIVE_INTEGER_PARSER, "1", false, JSAP.NO_SHORTFLAG, "mstep-iters", "Conduct M-step whenever iterations is a multiple of this."));
      app_config.registerParameter(new FlaggedOption("sample-count", AppConfig.POSITIVE_INTEGER_PARSER, "500", false, JSAP.NO_SHORTFLAG, "sample-count", "Number of Gibbs samples to generate at each iteration."));
      app_config.registerParameter(new FlaggedOption("sample-interval", AppConfig.POSITIVE_INTEGER_PARSER, "3", false, JSAP.NO_SHORTFLAG, "sample-interval", "Collect samples every N-th interval."));
      // app_config.registerParameter(new Switch("burnin-everyiter", JSAP.NO_SHORTFLAG, "burnin-everyiter", "Discard burn-in samples at every iteration."));
    }
    catch (JSAPException e)
    {
      System.err.println("[Init] Error registering argument options!");
      e.printStackTrace();
      System.exit(-1);
    }

    return app_config;
  }

  /**
   * Initializes a fresh output directory. It writes {@code settings.txt} (one commented option per entry, with file
   * paths relative to the resume directory), copies the terms and ideology files into the output directory, and
   * creates the sample directory.
   *
   * @param app_config configuration used to look up each option's help text and usage name
   * @throws IOException if a file cannot be written or copied
   */
  private void createOutputDirectory(AppConfig app_config) throws IOException
  {
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(app_state.getBaseDirFile("settings.txt"))))
    {
      // The unflagged run name is written without a "--" prefix.
      bw.write(String.format("# %s\n%s\n\n", app_config.getByID("name").getHelp(), app_data.SETTINGS.NAME));

      bw.write("# All file paths here are relative to resume directory.\n");
      writeSetting(bw, app_config, "data-dir", "%s", app_data.SETTINGS.DATA_DIR);
      writeSetting(bw, app_config, "output-dir", "%s", ".");
      writeSetting(bw, app_config, "terms-file", "%s", "terms.txt");
      writeSetting(bw, app_config, "ideology-file", "%s", "ideology.txt");

      writeSetting(bw, app_config, "alpha", "%f", app_data.SETTINGS.ALPHA);
      writeSetting(bw, app_config, "gamma", "%f", app_data.SETTINGS.GAMMA);
      writeSetting(bw, app_config, "emission-ideo", "%f", app_data.SETTINGS.EMISSION_IDEO);
      writeSetting(bw, app_config, "emission-other", "%f", app_data.SETTINGS.EMISSION_OTHER);

      writeSetting(bw, app_config, "save-interval", "%d", app_data.SETTINGS.SAVE_INTERVAL);
      writeSetting(bw, app_config, "iterations", "%d", app_data.SETTINGS.ITERATIONS);
      writeSetting(bw, app_config, "burnin-iters", "%d", app_data.SETTINGS.BURNIN_ITERS);
      writeSetting(bw, app_config, "mstep-iters", "%d", app_data.SETTINGS.MSTEP_ITERS);
      writeSetting(bw, app_config, "sample-count", "%d", app_data.SETTINGS.SAMPLE_COUNT);
      writeSetting(bw, app_config, "sample-interval", "%d", app_data.SETTINGS.SAMPLE_INTERVALS);
      writeSetting(bw, app_config, "burnin-samples", "%d", app_data.SETTINGS.BURNIN_SAMPLES);
    }

    FileUtils.copyFile(app_data.SETTINGS.TERMS_FILE, new File(FilenameUtils.concat(app_data.SETTINGS.OUTPUT_DIR.getPath(), "terms.txt")));
    FileUtils.copyFile(app_data.SETTINGS.IDEOLOGY_FILE, new File(FilenameUtils.concat(app_data.SETTINGS.OUTPUT_DIR.getPath(), "ideology.txt")));

    app_data.SETTINGS.SAMPLE_DIR.mkdir();
  }

  /**
   * Writes one {@code settings.txt} entry: a {@code #}-comment line with the option's help text, then
   * {@code --<usage-name> <value>}, then a blank line.
   *
   * @param bw destination writer
   * @param app_config configuration the option is looked up in
   * @param id option ID as registered in {@link #createAppConfig()}
   * @param value_format format specifier for the value, e.g. {@code "%s"}, {@code "%f"} or {@code "%d"}
   * @param value value to write
   * @throws IOException if writing fails
   */
  private static void writeSetting(BufferedWriter bw, AppConfig app_config, String id, String value_format, Object value) throws IOException
  {
    String format = "# %s\n--%s " + value_format + "\n\n";
    bw.write(String.format(format, app_config.getByID(id).getHelp(), app_config.getByID(id).getUsageName(), value));
  }

  /**
   * Program entry point: builds the application from {@code args} and runs the sampler.
   *
   * @param args command-line arguments
   * @throws IOException if an input or output file cannot be accessed
   * @throws NumberFormatException if a state directory name is not an integer
   * @throws ClassNotFoundException presumably propagated from deserializing a saved sampler state
   */
  public static void main(String[] args) throws IOException, NumberFormatException, ClassNotFoundException
  {
    EtchASketchSamplerApp app = new EtchASketchSamplerApp(args);
    app.run();
  }
}
