package edu.cmu.cs.lti.avenue.navigation.tools;

import info.jonclark.lang.Pair;
import info.jonclark.util.ArrayUtils;
import info.jonclark.util.FileUtils;
import info.jonclark.util.StringUtils;

import java.io.File;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Locale;

import edu.cmu.cs.lti.avenue.atavi.AtaviWrapper;
import edu.cmu.cs.lti.avenue.corpus.RawAlignment;
import edu.cmu.cs.lti.avenue.corpus.SentencePair;
import edu.cmu.cs.lti.avenue.corpus.SentencePairFactory;
import edu.cmu.cs.lti.avenue.projection.ProjectionFeatureExtractor;
import edu.cmu.cs.lti.avenue.trees.smart.SmartTree;
import edu.cmu.cs.lti.avenue.trees.smart.SmartTree.LabelMode;

/**
 * Command-line tool that compares system-output segments against reference translations.
 * <p>
 * The system output is either plain lines or METEOR-style {@code <seg ...>} markup. For each
 * segment the tool does the following:
 * <ul>
 * <li>counts added/missing content words, function words and punctuation;</li>
 * <li>builds a heuristic token alignment to the reference;</li>
 * <li>counts reorderings in that alignment.</li>
 * </ul>
 * Segments with at least one reordering are printed in descending order of reordering count, and
 * a {@link SentencePair} is built for each of them. Corpus-wide totals are printed at the end.
 */
public class SortMeteor {

	/** Encoding used for every input file, so that all inputs are decoded consistently. */
	private static final Charset UTF8 = Charset.forName("UTF-8");

	/**
	 * Analysis of one system-output segment against its reference: token lists, the sets of
	 * added/missing content words, function words and punctuation, a heuristic alignment, and
	 * the number of reorderings in that alignment.
	 */
	private static class Sentence {
		/** The scored segment exactly as read from the 1-best/METEOR file. */
		public String seg;
		/** The (normalized) reference sentence. */
		public String ref;
		/** Source sentence from the xfer output/source file; null when no such file was given. */
		public String sourceSentence;

		/** Parse of the reference; null when no parsed-reference file was given. */
		public SmartTree refParsed;

		/** System-output tokens, as produced by StringUtils.tokenize. */
		public String[] actualTokens;
		/** Reference tokens, as produced by StringUtils.tokenize. */
		public String[] refTokens;

		/** METEOR score string taken from the markup, or "-1" when markup is not used. */
		public String meteorScore;
		/** Segment id taken from the markup, or "-1" when markup is not used. */
		public String segId;

		/**
		 * Tokens selected by UtfUtils.getUtfTokens (presumably untranslated / OOV words -- TODO
		 * confirm).
		 */
		public ArrayList<String> oovTokens;
		public ArrayList<String> lSentencesAddedLex;
		public ArrayList<String> lSentencesMissingLex;
		public ArrayList<String> lSentencesAddedFunc;
		public ArrayList<String> lSentencesMissingFunc;
		public ArrayList<String> lSentencesAddedPunc;
		public ArrayList<String> lSentencesMissingPunc;
		/** Output-to-reference alignment with 1-based token indices (see {@link SortMeteor#align}). */
		public RawAlignment[] alignment;
		/** Reorderings in {@link #alignment}, as counted by ProjectionFeatureExtractor. */
		public int nReorderings;

		/**
		 * Analyzes one segment.
		 *
		 * @param ref normalized reference sentence
		 * @param seg raw scored segment (only stored, for printing)
		 * @param actualSentence system-output sentence that is analyzed
		 * @param refParsed parse of the reference, or null
		 * @param sourceSentence source sentence, or null
		 * @param meteorScore score string to report
		 * @param segId segment id string to report
		 * @param refContentWords content words of the reference
		 * @param refFunctionWords function words of the reference
		 * @param refPunctuation punctuation tokens of the reference
		 * @param globalFunctionWords every token treated as a function word
		 * @param globalPunctuation every token treated as punctuation
		 * @param noAlign tokens that are never aligned
		 * @param group whether contiguous alignment links are merged (see {@link SortMeteor#align})
		 */
		public Sentence(String ref, String seg, String actualSentence, SmartTree refParsed,
				String sourceSentence, String meteorScore, String segId,
				HashSet<String> refContentWords, HashSet<String> refFunctionWords,
				HashSet<String> refPunctuation, HashSet<String> globalFunctionWords,
				HashSet<String> globalPunctuation, HashSet<String> noAlign, boolean group) {

			this.seg = seg;
			this.ref = ref;
			this.sourceSentence = sourceSentence;
			this.meteorScore = meteorScore;
			this.segId = segId;
			this.refParsed = refParsed;

			actualTokens = StringUtils.tokenize(actualSentence);
			refTokens = StringUtils.tokenize(ref);

			HashSet<String> actualContent = new HashSet<String>();
			HashSet<String> actualFunc = new HashSet<String>();
			HashSet<String> actualPunc = new HashSet<String>();
			classifyWords(globalFunctionWords, globalPunctuation, actualSentence, actualContent,
					actualFunc, actualPunc);

			oovTokens = UtfUtils.getUtfTokens(StringUtils.tokenize(actualSentence));
			lSentencesAddedLex = getAdded(actualContent, refContentWords);
			lSentencesMissingLex = getMissing(actualContent, refContentWords);
			lSentencesAddedFunc = getAdded(actualFunc, refFunctionWords);
			lSentencesMissingFunc = getMissing(actualFunc, refFunctionWords);
			lSentencesAddedPunc = getAdded(actualPunc, refPunctuation);
			lSentencesMissingPunc = getMissing(actualPunc, refPunctuation);
			alignment = align(actualTokens, refTokens, noAlign, group);
			nReorderings =
					ProjectionFeatureExtractor.countReorderings(alignment, 0,
							actualTokens.length - 1);
		}

		@Override
		public String toString() {
			return toString(true);
		}

		/**
		 * Renders a multi-line human-readable report for this segment.
		 *
		 * @param includeSentences whether to start the report with the segment and reference text
		 */
		public String toString(boolean includeSentences) {
			StringBuilder builder = new StringBuilder();
			if (includeSentences) {
				builder.append(seg).append("\n");
				builder.append(ref).append("\n");
			}
			builder.append(sourceSentence).append("\n\n");
			builder.append("METEOR SCORE: ").append(meteorScore).append("\n");
			builder.append("SEGMENT ID: ").append(segId).append("\n");
			appendList(builder, "OOV", oovTokens);
			appendList(builder, "ADDED LEX", lSentencesAddedLex);
			appendList(builder, "MISSING LEX", lSentencesMissingLex);
			appendList(builder, "ADDED FUNC", lSentencesAddedFunc);
			appendList(builder, "MISSING FUNC", lSentencesMissingFunc);
			appendList(builder, "ADDED PUNC", lSentencesAddedPunc);
			appendList(builder, "MISSING PUNC", lSentencesMissingPunc);
			builder.append("REORDERINGS: ").append(nReorderings).append("\n");
			builder.append("ALIGNMENTS: ").append(RawAlignment.toString(alignment)).append("\n");
			builder.append("\n");
			return builder.toString();
		}

		/** Appends one line of the form {@code LABEL: (a, b, c)}. */
		private static void appendList(StringBuilder builder, String label, ArrayList<String> items) {
			builder.append(label).append(": (").append(StringUtils.untokenize(items, ", "))
					.append(")\n");
		}
	}

	/**
	 * Entry point. Arguments are as follows:
	 * <ol>
	 * <li>function-word file;</li>
	 * <li>reference file;</li>
	 * <li>1-best or METEOR file;</li>
	 * <li>optionally, an xfer output / source file;</li>
	 * <li>optionally, a parsed-reference file.</li>
	 * </ol>
	 * All files are read as UTF-8.
	 */
	public static void main(String[] args) throws Exception {
		if (args.length < 3) {
			System.err.println("Usage: program <func_words_file> <ref_file> <1_best_or_meteor_file> [<xfer_output_or_source_file>]  [<parsed_ref_file>]");
			System.exit(1);
		}

		File fileFuncWords = new File(args[0]);
		File fileRef = new File(args[1]);
		File fileMeteor = new File(args[2]);
		File fileXferOutput = args.length > 3 ? new File(args[3]) : null;
		File fileParsedRef = args.length > 4 ? new File(args[4]) : null;

		// TODO: Expose these as commandline arguments
		final boolean group = true;
		final boolean includesMarkup = false;
		// true: print selected pairs as "newpair / context / srcsent / tgtsent" records;
		// false: print the full Sentence report. ("ec" presumably = elicitation corpus -- TODO
		// confirm.)
		final boolean ec = true;

		int nOOV = 0;
		int nAddedLex = 0;
		int nMissingLex = 0;
		int nAddedFunc = 0;
		int nMissingFunc = 0;
		int nAddedPunc = 0;
		int nMissingPunc = 0;
		int nBadOrdering = 0;

		boolean useParses = (fileParsedRef != null);
		ArrayList<SmartTree> refParses = new ArrayList<SmartTree>();
		if (useParses) {
			for (final String line : readLines(fileParsedRef)) {
				if (!line.equals("")) {
					SmartTree parsed =
							SmartTree.parse(line, SmartTree.SOURCE_C_STRUCT_LABEL,
									LabelMode.LABEL_ALL_NODES);
					// PennTreeBankAnalyzer.removeExtraJunk(parsed, true);
					refParses.add(parsed);
				}
			}
		}

		// Function words are lowercased with a fixed locale so that the result does not depend
		// on the JVM's default locale (e.g. Turkish dotless i).
		HashSet<String> globalFunctionWords =
				new HashSet<String>(Arrays.asList(StringUtils.tokenize(FileUtils.getFileAsString(
						fileFuncWords, UTF8).toLowerCase(Locale.ROOT))));
		globalFunctionWords.add("'s");

		HashSet<String> globalPunctuation =
				new HashSet<String>(Arrays.asList(".", ",", ";", "!", "?", "'", "\"", "-"));

		// Tokens that are never aligned: all punctuation plus very frequent function words.
		HashSet<String> noAlign = new HashSet<String>(globalPunctuation);
		noAlign.addAll(Arrays.asList("the", "a", "an", "that", "of", "in", "and", "to", "on"));

		ArrayList<String> sourceSentences = new ArrayList<String>();
		if (fileXferOutput != null) {
			for (final String line : readLines(fileXferOutput)) {
				if (!includesMarkup) {
					sourceSentences.add(line);
				} else if (line.startsWith("SrcSent")) {
					// the sentence is everything after the second space-separated field
					String[] tokens = StringUtils.tokenize(line, " ", 3);
					sourceSentences.add(tokens[2]);
				}
			}
		}

		ArrayList<String> refSegments;
		if (includesMarkup) {
			refSegments =
					StringUtils.allSubstringsBetween(FileUtils.getFileAsString(fileRef, UTF8),
							"<seg ", "</seg>", false);
		} else {
			refSegments = ArrayUtils.toArrayList(readLines(fileRef));
		}

		ArrayList<HashSet<String>> contentWords = new ArrayList<HashSet<String>>();
		ArrayList<HashSet<String>> functionWords = new ArrayList<HashSet<String>>();
		ArrayList<HashSet<String>> punctuation = new ArrayList<HashSet<String>>();
		for (int i = 0; i < refSegments.size(); i++) {
			String sentence = normalizeReference(refSegments.get(i), includesMarkup);
			refSegments.set(i, sentence);

			HashSet<String> sentContent = new HashSet<String>();
			HashSet<String> sentFunc = new HashSet<String>();
			HashSet<String> sentPunc = new HashSet<String>();
			classifyWords(globalFunctionWords, globalPunctuation, sentence, sentContent, sentFunc,
					sentPunc);
			contentWords.add(sentContent);
			functionWords.add(sentFunc);
			punctuation.add(sentPunc);
		}

		ArrayList<String> scoredSegments;
		if (includesMarkup) {
			scoredSegments =
					StringUtils.allSubstringsBetween(FileUtils.getFileAsString(fileMeteor, UTF8),
							"<seg ", "</seg>", true);
		} else {
			scoredSegments = ArrayUtils.toArrayList(readLines(fileMeteor));
		}

		System.err.println("Loaded " + scoredSegments.size() + " segments.");

		// Fail fast with a clear message instead of an IndexOutOfBoundsException mid-loop.
		int nSegments = scoredSegments.size();
		exitIfTooFew("reference segments", refSegments.size(), nSegments);
		if (fileXferOutput != null) {
			exitIfTooFew("source sentences", sourceSentences.size(), nSegments);
		}
		if (useParses) {
			exitIfTooFew("reference parses", refParses.size(), nSegments);
		}

		ArrayList<Pair<Double, Sentence>> sorted = new ArrayList<Pair<Double, Sentence>>();
		for (int i = 0; i < nSegments; i++) {
			String scoredSeg = scoredSegments.get(i);
			String strRef = refSegments.get(i);
			String sourceSentence = fileXferOutput == null ? null : sourceSentences.get(i);
			SmartTree refParsed = useParses ? refParses.get(i) : null;

			String strScore =
					includesMarkup ? StringUtils.substringBetween(scoredSeg, "score=", " ref=")
							: "-1";
			String segId =
					includesMarkup ? StringUtils.substringBetween(scoredSeg, "id=\"", "\" ") : "-1";

			String strActual =
					includesMarkup
							? StringUtils.substringBetween(scoredSeg, ">", "</seg>").toLowerCase(
									Locale.ROOT) : scoredSeg;
			Sentence sentence =
					new Sentence(strRef, scoredSeg, strActual, refParsed, sourceSentence, strScore,
							segId, contentWords.get(i), functionWords.get(i), punctuation.get(i),
							globalFunctionWords, globalPunctuation, noAlign, group);

			nOOV += sentence.oovTokens.size();
			nAddedLex += sentence.lSentencesAddedLex.size();
			nMissingLex += sentence.lSentencesMissingLex.size();
			nAddedFunc += sentence.lSentencesAddedFunc.size();
			nMissingFunc += sentence.lSentencesMissingFunc.size();
			nAddedPunc += sentence.lSentencesAddedPunc.size();
			nMissingPunc += sentence.lSentencesMissingPunc.size();
			nBadOrdering += sentence.nReorderings;

			// RANK BY REORDERINGS
			sorted.add(new Pair<Double, Sentence>((double) sentence.nReorderings, sentence));
		}

		// descending by reordering count
		Collections.sort(sorted, new Comparator<Pair<Double, Sentence>>() {
			@Override
			public int compare(Pair<Double, Sentence> o1, Pair<Double, Sentence> o2) {
				return Double.compare(o2.first, o1.first);
			}
		});

		int nSelected = 0;
		for (final Pair<Double, Sentence> p : sorted) {
			if (p.first <= 0) {
				// list is sorted descending, so no later entry has any reorderings either
				break;
			}
			Sentence s = p.second;

			if (ec) {
				System.out.println("newpair");
				System.out.println("context: reorderings = " + s.nReorderings);
				System.out.println("srcsent: " + s.sourceSentence);
				System.out.println("tgtsent: " + s.ref);
				System.out.println();
			} else {
				System.out.println(s);
			}

			// TODO: Incorporate partial parses from XFER system and add
			// dummy struct on top of it?

			// TODO: Determine which are needed rules

			String strAlignment = RawAlignment.toString(s.alignment);
			SentencePair pair =
					SentencePairFactory.getInstance(nSelected, s.refTokens, s.actualTokens,
							new boolean[s.refTokens.length],
							new boolean[s.actualTokens.length], strAlignment, "",
							s.toString(false), null, s.refParsed, null, null, "", "", "");
			if (!useParses) {
				AtaviWrapper.generateDummyCStruct(pair, false);
				AtaviWrapper.generateDummyCStruct(pair, true);
			} else {
				AtaviWrapper.project(pair, false, false);
			}
			nSelected++;
		}

		System.err.println(nSelected + " sentences selected.");

		System.out.println("nOOV = " + nOOV);
		System.out.println("nAddedLex = " + nAddedLex);
		System.out.println("nMissingLex = " + nMissingLex);
		System.out.println("nAddedFunc = " + nAddedFunc);
		System.out.println("nMissingFunc = " + nMissingFunc);
		System.out.println("nAddedPunc = " + nAddedPunc);
		System.out.println("nMissingPunc = " + nMissingPunc);
		System.out.println("nBadOrdering = " + nBadOrdering);

		System.out.println();
		System.out.println("TODO: Reorderings over sentence length.");
		System.out.println("TODO: Max sentence length parameter.");
	}

	/** Reads {@code file} as UTF-8 and splits it into lines with StringUtils.tokenize. */
	private static String[] readLines(File file) throws Exception {
		return StringUtils.tokenize(FileUtils.getFileAsString(file, UTF8), "\n");
	}

	/**
	 * Normalizes one reference segment.
	 * <p>
	 * In markup mode it first drops the rest of the opening {@code <seg ...>} tag, i.e. everything
	 * up to the first '>'. Plain-text references are left alone at this step, so a literal '>' in
	 * them is kept. It then rewrites typographic quotes as ASCII quotes.
	 */
	private static String normalizeReference(String segment, boolean includesMarkup) {
		String sentence = segment;
		if (includesMarkup) {
			sentence = StringUtils.substringAfter(sentence, ">");
		}
		sentence = StringUtils.replaceFast(sentence, "“", "\" ");
		sentence = StringUtils.replaceFast(sentence, "\"\"", "\"");
		sentence = StringUtils.replaceFast(sentence, "’", " '");
		return sentence;
	}

	/** Prints an error and exits with status 1 if {@code available < needed}. */
	private static void exitIfTooFew(String what, int available, int needed) {
		if (available < needed) {
			System.err.println("ERROR: found only " + available + " " + what + " but " + needed
					+ " scored segments.");
			System.exit(1);
		}
	}

	/**
	 * Splits the tokens of {@code sentence} into punctuation, function words and content words.
	 * <p>
	 * Punctuation is checked first, then function words; every other token is a content word.
	 * Tokens are compared as-is (case-sensitively).
	 */
	private static void classifyWords(HashSet<String> globalFunctionWords,
			HashSet<String> globalPunctuation, String sentence, HashSet<String> sentContent,
			HashSet<String> sentFunc, HashSet<String> sentPunc) {
		for (final String token : StringUtils.tokenize(sentence)) {
			if (globalPunctuation.contains(token)) {
				sentPunc.add(token);
			} else if (globalFunctionWords.contains(token)) {
				sentFunc.add(token);
			} else {
				sentContent.add(token);
			}
		}
	}

	/**
	 * Returns the elements of {@code expected} that are absent from {@code actual}, in the
	 * iteration order of {@code expected}.
	 */
	public static ArrayList<String> getMissing(HashSet<String> actual, HashSet<String> expected) {
		ArrayList<String> list = new ArrayList<String>();
		for (final String str : expected) {
			if (!actual.contains(str)) {
				list.add(str);
			}
		}
		return list;
	}

	/**
	 * Returns the elements of {@code actual} that are absent from {@code expected}, in the
	 * iteration order of {@code actual}.
	 */
	public static ArrayList<String> getAdded(HashSet<String> actual, HashSet<String> expected) {
		ArrayList<String> list = new ArrayList<String>();
		for (final String str : actual) {
			if (!expected.contains(str)) {
				list.add(str);
			}
		}
		return list;
	}

	/**
	 * Heuristically aligns system-output tokens to reference tokens.
	 * <p>
	 * Each output token that is not in {@code noAlign} is linked to the reference position chosen
	 * by {@link #findRefIndex}. Indices in the result are 1-based, the AVENUE convention.
	 * {@code targetTerminals} hold output positions and {@code sourceTerminals} hold reference
	 * positions.
	 * <p>
	 * When {@code group} is true, a link that directly follows the previous link on both sides is
	 * merged into the same RawAlignment. Otherwise every link is its own RawAlignment.
	 */
	public static RawAlignment[] align(String[] actualTokens, String[] refTokens,
			HashSet<String> noAlign, boolean group) {

		ArrayList<RawAlignment> rawAlignments = new ArrayList<RawAlignment>();
		ArrayList<Integer> actualIndices = new ArrayList<Integer>();
		ArrayList<Integer> refIndices = new ArrayList<Integer>();

		for (int i = 0; i < actualTokens.length; i++) {
			if (noAlign.contains(actualTokens[i])) {
				continue;
			}
			int nPosInRef = findRefIndex(i, actualTokens, refTokens);
			if (nPosInRef == -1) {
				continue;
			}

			if (!actualIndices.isEmpty()) {
				// Stored indices are 1-based, so "previous == current 0-based index" means the
				// new link directly follows the previous one on that side.
				int lastActual = actualIndices.get(actualIndices.size() - 1);
				int lastRef = refIndices.get(refIndices.size() - 1);
				boolean contiguous = lastActual == i && lastRef == nPosInRef;
				if (!group || !contiguous) {
					rawAlignments.add(toRawAlignment(actualIndices, refIndices));
					actualIndices = new ArrayList<Integer>();
					refIndices = new ArrayList<Integer>();
				}
			}
			actualIndices.add(i + 1);
			refIndices.add(nPosInRef + 1);
		}

		if (!actualIndices.isEmpty()) {
			rawAlignments.add(toRawAlignment(actualIndices, refIndices));
		}

		return rawAlignments.toArray(new RawAlignment[rawAlignments.size()]);
	}

	/** Packs one group of 1-based output/reference indices into a RawAlignment. */
	private static RawAlignment toRawAlignment(ArrayList<Integer> actualIndices,
			ArrayList<Integer> refIndices) {
		RawAlignment raw = new RawAlignment();
		raw.targetTerminals = ArrayUtils.toArray(actualIndices);
		raw.sourceTerminals = ArrayUtils.toArray(refIndices);
		return raw;
	}

	/**
	 * Finds the reference position for the output token at {@code nActualPos}.
	 * <p>
	 * If that token is the n-th occurrence of its string in {@code actualTokens}, this returns the
	 * 0-based index of the n-th occurrence of the same string in {@code refTokens}. It returns -1
	 * when the reference has fewer than n occurrences. Occurrences are matched in order, and
	 * tokens that are not unique are not rejected.
	 */
	public static int findRefIndex(int nActualPos, String[] actualTokens, String[] refTokens) {
		String target = actualTokens[nActualPos];

		// this is the nth instance of this token in actualTokens
		int n = 0;
		for (int i = 0; i <= nActualPos; i++) {
			if (actualTokens[i].equals(target)) {
				n++;
			}
		}

		int nFound = 0;
		for (int i = 0; i < refTokens.length; i++) {
			if (refTokens[i].equals(target)) {
				nFound++;
				if (nFound == n) {
					return i;
				}
			}
		}
		return -1;
	}
}
