package edu.cmu.cs.lti.avenue.navigation.search.generation1;

import info.jonclark.stat.SecondTimer;
import info.jonclark.util.ArrayUtils;
import info.jonclark.util.FormatUtils;
import info.jonclark.util.HashUtils;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.TreeMap;
import java.util.Map.Entry;

import edu.cmu.cs.lti.avenue.corpus.Corpus;
import edu.cmu.cs.lti.avenue.corpus.CorpusException;
import edu.cmu.cs.lti.avenue.corpus.SentencePair;
import edu.cmu.cs.lti.avenue.corpus.Serializer;
import edu.cmu.cs.lti.avenue.navigation.search.generation1.AlignmentScorer.DiscontinuityType;
import edu.cmu.cs.lti.avenue.navigation.search.generation1.DiversityScorer.SmoothingMode;

/**
 * Command-line driver that ranks every sentence pair of an elicited corpus by one heuristic and
 * prints the top n-best pairs together with per-pair details.
 * <p>
 * Arguments: {@code <elicited_data_file> <ranking_function> <n-best> [--verbose]}, where
 * {@code <ranking_function>} is one of {@code --crossings}, {@code --diversity},
 * {@code --reordering} or {@code --discontinuity}. The corpus file is read as UTF-8.
 */
public class HeuristicTester {

	private static final String USAGE =
			"Usage: program <elicitated_data_file> <ranking_function> <n-best> [--verbose]";

	/** Ranking-function names accepted as the second command-line argument. */
	private static final String CROSSINGS = "--crossings";
	private static final String DIVERSITY = "--diversity";
	private static final String REORDERING = "--reordering";
	private static final String DISCONTINUITY = "--discontinuity";

	/** The only value accepted as the optional fourth argument; it currently has no effect. */
	private static final String VERBOSE_FLAG = "--verbose";

	/** Blending weights handed to the ArithmeticBlender used by the diversity ranking. */
	private static final double AVERAGE_WEIGHT = .5;
	private static final double MAX_WEIGHT = .5;
	private static final double PRODUCT_WEIGHT = 0;
	private static final double SUM_WEIGHT = 0;

	/**
	 * Prints the details of one ranked sentence pair.
	 *
	 * @param <K> the type of the ranking score
	 */
	private interface RankPrinter<K> {
		void print(int rank, K score, SentencePair pair) throws Exception;
	}

	/**
	 * Loads the corpus, ranks it with the requested heuristic and prints the top n-best pairs.
	 * <p>
	 * Exits with status 1 (after printing a message to stderr) when the argument count is wrong,
	 * the optional fourth argument is not {@code --verbose}, the ranking function is unknown, or
	 * n-best is not an integer. Only the requested ranking is computed.
	 *
	 * @param args see the class documentation
	 * @throws Exception if loading the corpus or scoring/serializing a sentence pair fails
	 */
	public static void main(String[] args) throws Exception {
		// Previously any fourth argument was accepted silently, so a mistyped flag went unnoticed.
		final boolean validArity =
				args.length == 3 || (args.length == 4 && VERBOSE_FLAG.equals(args[3]));
		if (!validArity) {
			System.err.println(USAGE);
			System.exit(1);
		}

		final String encoding = "UTF-8";
		final File elicitedFile = new File(args[0]);
		final String rankingFunction = args[1];

		// Validate before the (potentially slow) corpus load. Previously an unknown function was
		// only reported after every heuristic had run, and the process still exited with status 0.
		if (!isKnownRankingFunction(rankingFunction)) {
			System.err.println("Unknown ranking function: " + rankingFunction);
			System.exit(1);
		}

		final int nBest;
		try {
			nBest = Integer.parseInt(args[2]);
		} catch (NumberFormatException e) {
			System.err.println("Invalid n-best (expected an integer): " + args[2]);
			System.err.println(USAGE);
			System.exit(1);
			return; // not reached; keeps nBest definitely assigned for the compiler
		}

		final SecondTimer sentencesTime = new SecondTimer(true, true);
		final Corpus corpus = Serializer.loadSentencePairs(elicitedFile, encoding);
		final ArrayList<SentencePair> sentences = corpus.getSentences();
		System.out.println("Loaded " + sentences.size() + " sentences in "
				+ sentencesTime.getSecondsFormatted() + " seconds");

		final AlignmentScorer alignmentScorer = new AlignmentScorer();
		final SecondTimer compTime = new SecondTimer(true, true);

		// TODO: Rank based on frequency of discontinuous words
		// TODO: Find lost sheep of Urdu data
		// TODO: Rank based on structure vs other rules as vamshi does
		// TODO: Does it affect Urdu TO ENGLISH translation?
		// TODO: Fix crossings metric

		if (CROSSINGS.equals(rankingFunction)) {
			printCrossingRanking(sentences, alignmentScorer, nBest);
		} else if (DIVERSITY.equals(rankingFunction)) {
			printDiversityRanking(sentences, nBest);
		} else if (REORDERING.equals(rankingFunction)) {
			printReorderingRanking(sentences, alignmentScorer, nBest);
		} else if (DISCONTINUITY.equals(rankingFunction)) {
			printDiscontinuityRanking(sentences, alignmentScorer, nBest);
		} else {
			// Unreachable: isKnownRankingFunction() was checked above.
			throw new IllegalStateException("Unhandled ranking function: " + rankingFunction);
		}

		System.out.println("Evaluated heuristics in " + compTime.getSecondsFormatted() + " seconds");
	}

	/** Returns true if {@code rankingFunction} names one of the supported rankings. */
	private static boolean isKnownRankingFunction(String rankingFunction) {
		return CROSSINGS.equals(rankingFunction) || DIVERSITY.equals(rankingFunction)
				|| REORDERING.equals(rankingFunction) || DISCONTINUITY.equals(rankingFunction);
	}

	/**
	 * Walks {@code ranking} in ascending key order and hands the {@code nBest} pairs with the
	 * highest keys to {@code printer}.
	 * <p>
	 * Ranks count down from (number of pairs in the ranking - 1) to 0 as the walk proceeds, so the
	 * highest-scoring pair gets rank 0 and is printed last. Pairs that share a key are ranked in
	 * their list order, later entries receiving lower ranks.
	 */
	private static <K> void printTopN(TreeMap<K, ArrayList<SentencePair>> ranking, int nBest,
			RankPrinter<K> printer) throws Exception {

		// Count the ranking's own entries so ranks always start at exactly (count - 1).
		int rank = 0;
		for (final ArrayList<SentencePair> pairs : ranking.values()) {
			rank += pairs.size();
		}

		for (final Entry<K, ArrayList<SentencePair>> entry : ranking.entrySet()) {
			for (final SentencePair pair : entry.getValue()) {
				rank--;
				if (rank < nBest) {
					printer.print(rank, entry.getKey(), pair);
				}
			}
		}
	}

	/** Ranks {@code sentences} by alignment crossings and prints the top {@code nBest}. */
	private static void printCrossingRanking(ArrayList<SentencePair> sentences,
			AlignmentScorer alignmentScorer, int nBest) throws Exception {

		final TreeMap<Integer, ArrayList<SentencePair>> crossingRanking =
				new TreeMap<Integer, ArrayList<SentencePair>>();
		// Constituent counts are gathered alongside the crossings but are not reported yet.
		final TreeMap<String, Integer> allCounts = new TreeMap<String, Integer>();
		analyzeCrossings(sentences, alignmentScorer, crossingRanking, allCounts);

		printTopN(crossingRanking, nBest, new RankPrinter<Integer>() {
			@Override
			public void print(int rank, Integer crossings, SentencePair pair) throws Exception {
				System.out.println("Rank " + rank);
				System.out.println("CROSSINGS: " + crossings);
				System.out.println(pair.serialize());
				// Blank-line separator, consistent with the other rankings.
				System.out.println("\n");
			}
		});
	}

	/**
	 * Ranks {@code sentences} by diversity score and prints the top {@code nBest}, each followed by
	 * its individual constituent ranking.
	 */
	private static void printDiversityRanking(ArrayList<SentencePair> sentences, int nBest)
			throws Exception {

		final DiversityScorer diversityScorer =
				new DiversityScorer(sentences, new ArithmeticBlender(AVERAGE_WEIGHT, MAX_WEIGHT,
						PRODUCT_WEIGHT, SUM_WEIGHT), SmoothingMode.ADD1);

		printTopN(analyzeDiversity(sentences, diversityScorer), nBest, new RankPrinter<Double>() {
			@Override
			public void print(int rank, Double diversity, SentencePair pair) throws Exception {
				System.out.println("Rank " + rank);
				System.out.println(pair.serialize());
				System.out.println("DIVERSITY: " + FormatUtils.formatDoubleExp(diversity));

				final TreeMap<Double, String> types =
						diversityScorer.getIndividualConstituentRanking(pair);
				for (final Entry<Double, String> typeEntry : types.entrySet()) {
					System.out.println(typeEntry.getValue() + ":\t"
							+ FormatUtils.formatDoubleExp(typeEntry.getKey()));
				}
				System.out.println("\n");
			}
		});
	}

	/**
	 * Ranks {@code sentences} by the total number of within-constituent reorderings and prints the
	 * top {@code nBest}, each followed by its per-constituent reordering counts.
	 */
	private static void printReorderingRanking(ArrayList<SentencePair> sentences,
			final AlignmentScorer alignmentScorer, int nBest) throws Exception {

		printTopN(analyzeReorderings(sentences, alignmentScorer), nBest,
				new RankPrinter<Integer>() {
					@Override
					public void print(int rank, Integer reorderingCount, SentencePair pair)
							throws Exception {
						System.out.println("Rank " + rank);
						System.out.println(pair.serialize());
						System.out.println("REORDERINGS: " + reorderingCount);

						final HashMap<String, Integer> reorderings =
								alignmentScorer.getReorderingsWithinConstituents(pair);
						for (final Entry<String, Integer> reordering : reorderings.entrySet()) {
							System.out.println(reordering.getKey() + ": " + reordering.getValue());
						}
						System.out.println("\n");
					}
				});
	}

	/**
	 * Ranks {@code sentences} by the total number of EXTERNAL_INTERRUPTION discontinuities within
	 * constituents and prints the top {@code nBest}.
	 */
	private static void printDiscontinuityRanking(ArrayList<SentencePair> sentences,
			AlignmentScorer alignmentScorer, int nBest) throws Exception {

		final TreeMap<Integer, ArrayList<SentencePair>> discontinuityRanking =
				analyzeDiscontinuities(sentences, alignmentScorer,
						DiscontinuityType.EXTERNAL_INTERRUPTION);

		printTopN(discontinuityRanking, nBest, new RankPrinter<Integer>() {
			@Override
			public void print(int rank, Integer discontinuities, SentencePair pair)
					throws Exception {
				System.out.println("Rank " + rank);
				System.out.println(pair.serialize());
				System.out.println("DISCONTINUITIES: " + discontinuities);
				// TODO: print per-constituent details. An earlier (commented-out) draft looped over
				// DiscontinuityType.values(), called
				// alignmentScorer.getDiscontinuitiesWithinConstituentL(pair, d), and printed each
				// constituent's CFG rule together with the target words that interrupted it.
				System.out.println("\n");
			}
		});
	}

	/**
	 * Scores every pair with {@code diversityScorer.getDiversityScore} and files it under that
	 * score via HashUtils.append.
	 *
	 * @return pairs grouped by diversity score, in ascending score order
	 */
	private static TreeMap<Double, ArrayList<SentencePair>> analyzeDiversity(
			ArrayList<SentencePair> sentences, DiversityScorer diversityScorer) {

		final TreeMap<Double, ArrayList<SentencePair>> diversityRanking =
				new TreeMap<Double, ArrayList<SentencePair>>();
		for (final SentencePair pair : sentences) {
			final double diversityScore = diversityScorer.getDiversityScore(pair);
			HashUtils.append(diversityRanking, diversityScore, pair);
		}
		return diversityRanking;
	}

	/**
	 * Sums, for every pair, the per-constituent discontinuity counts of type {@code d} and files
	 * the pair under that total via HashUtils.append.
	 *
	 * @return pairs grouped by total discontinuity count, in ascending order
	 * @throws CorpusException declared for the underlying alignment scoring
	 */
	private static TreeMap<Integer, ArrayList<SentencePair>> analyzeDiscontinuities(
			ArrayList<SentencePair> sentences, AlignmentScorer alignmentScorer, DiscontinuityType d)
			throws CorpusException {

		final TreeMap<Integer, ArrayList<SentencePair>> discontinuityRanking =
				new TreeMap<Integer, ArrayList<SentencePair>>();
		for (final SentencePair pair : sentences) {
			final HashMap<String, Integer> discontinuities =
					alignmentScorer.getDiscontinuitiesWithinConstituentC(pair, d);
			HashUtils.append(discontinuityRanking, sumValues(discontinuities), pair);
		}
		return discontinuityRanking;
	}

	/**
	 * Sums, for every pair, the per-constituent reordering counts and files the pair under that
	 * total via HashUtils.append.
	 *
	 * @return pairs grouped by total reordering count, in ascending order
	 * @throws CorpusException declared for the underlying alignment scoring
	 */
	private static TreeMap<Integer, ArrayList<SentencePair>> analyzeReorderings(
			ArrayList<SentencePair> sentences, AlignmentScorer alignmentScorer)
			throws CorpusException {

		final TreeMap<Integer, ArrayList<SentencePair>> reorderingRanking =
				new TreeMap<Integer, ArrayList<SentencePair>>();
		for (final SentencePair pair : sentences) {
			final HashMap<String, Integer> reorderings =
					alignmentScorer.getReorderingsWithinConstituents(pair);
			HashUtils.append(reorderingRanking, sumValues(reorderings), pair);
		}
		return reorderingRanking;
	}

	/**
	 * Files every pair in {@code crossingRanking} under its {@code getCrossings} count, and feeds
	 * each pair's constituent counts into {@code allCounts} via HashUtils.add (presumably summing
	 * per constituent label; confirm against HashUtils).
	 *
	 * @param crossingRanking output: pairs grouped by crossing count
	 * @param allCounts output: accumulated constituent counts across all pairs
	 * @throws CorpusException declared for the underlying alignment scoring
	 */
	private static void analyzeCrossings(ArrayList<SentencePair> sentences,
			AlignmentScorer alignmentScorer,
			TreeMap<Integer, ArrayList<SentencePair>> crossingRanking,
			TreeMap<String, Integer> allCounts) throws CorpusException {

		for (final SentencePair pair : sentences) {
			final int nCrossings = alignmentScorer.getCrossings(pair);
			HashUtils.append(crossingRanking, nCrossings, pair);

			final HashMap<String, Integer> counts = alignmentScorer.getConstituentCounts(pair);
			for (final Entry<String, Integer> count : counts.entrySet()) {
				HashUtils.add(allCounts, count.getKey(), count.getValue());
			}
		}
	}

	/** Returns the sum of all values in {@code counts} (0 for an empty map). */
	private static int sumValues(HashMap<String, Integer> counts) {
		int total = 0;
		for (final Integer count : counts.values()) {
			total += count;
		}
		return total;
	}
}
