package edu.cmu.cs.lti.avenue.navigation.featuredetection.inductive.evidence;

import info.jonclark.util.LatexUtils;
import info.jonclark.util.StringUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

import edu.cmu.cs.lti.avenue.corpus.CorpusException;
import edu.cmu.cs.lti.avenue.corpus.PhiPlusMapping;
import edu.cmu.cs.lti.avenue.corpus.SentencePair;
import edu.cmu.cs.lti.avenue.morphology.Paradigm;
import edu.cmu.cs.lti.avenue.morphology.Segmenter;
import edu.cmu.cs.lti.avenue.morphology.SegmenterException;
import edu.cmu.cs.lti.avenue.navigation.featuredetection.inductive.FeatureValueInteraction;
import edu.cmu.cs.lti.avenue.navigation.featuredetection.inductive.MinimalPairMapping;
import edu.cmu.cs.lti.avenue.trees.smart.TreeNode;

/**
 * Determines in what way two sentences are different and on what constituent a
 * given feature is marked. This class is also responsible for calculating
 * things such as added word and morpheme sets. It is critical that this class
 * be implemented very efficiently with regard to memory since thousands to
 * millions of instances will be created simultaneously.
 * <p>
 * All derived data (addition/reordering masks, added word and morpheme sets)
 * is computed lazily on first request and held in a per-instance cache that
 * can be released with {@link #clearCache()}; it is rebuilt on the next
 * request. The cache is populated without any synchronization.
 * <p>
 * The static {@link #init(boolean, boolean, boolean, boolean)} method must be
 * called before any instance is constructed.
 * 
 * @author jon
 */
public class FeatureMarking extends ArcEvidence {

	private final SentencePair pairA;
	private final SentencePair pairB;
	private final FeatureValueInteraction featureValueA;
	private final FeatureValueInteraction featureValueB;
	private final MinimalPairMapping minPairA;
	private final MinimalPairMapping minPairB;

	// XXX: Hack to reduce memory consumption: these settings are shared by all
	// instances, so they are stored statically rather than per object.
	// Static (global, mutable) configuration is a bad idea for software
	// engineering; it is tolerated here only for the memory savings.
	private static boolean allowMarkingsOnMe;
	private static boolean allowMarkingsOnMyDependents;
	private static boolean allowMarkingsOnMyGovernor;
	private static boolean allowMarkingsOnOthers;
	private static boolean initDone = false;

	/**
	 * Configures which target-side constituents an added word may be
	 * attributed to. Must be called before any instance is constructed; the
	 * settings apply to every instance.
	 * 
	 * @param allowMarkingsOnMe
	 *            keep added words aligned to the constituent ("me") that
	 *            carries the examined features
	 * @param allowMarkingsOnMyDependents
	 *            keep added words aligned to an immediate dependent of "me"
	 * @param allowMarkingsOnMyGovernor
	 *            keep added words aligned to the immediate head of "me"
	 * @param allowMarkingsOnOthers
	 *            keep added words that are aligned to none of the above
	 */
	public static void init(boolean allowMarkingsOnMe, boolean allowMarkingsOnMyDependents,
			boolean allowMarkingsOnMyGovernor, boolean allowMarkingsOnOthers) {

		FeatureMarking.allowMarkingsOnMe = allowMarkingsOnMe;
		FeatureMarking.allowMarkingsOnMyDependents = allowMarkingsOnMyDependents;
		FeatureMarking.allowMarkingsOnMyGovernor = allowMarkingsOnMyGovernor;
		FeatureMarking.allowMarkingsOnOthers = allowMarkingsOnOthers;
		FeatureMarking.initDone = true;
	}

	// Paradigm list attached to morphemes created without a segmenter. It is
	// shared by all such morphemes, so it must never be modified.
	private static final List<Paradigm> EMPTY_LIST = new ArrayList<Paradigm>(0);

	// Lazily computed hash code; -1 means "not computed yet" (a genuine hash of
	// -1 is simply recomputed on each call, which is harmless).
	int hashCode = -1;

	/**
	 * Lazily computed, releasable analysis results. Every boolean array is
	 * indexed by 0-based target token position of the corresponding sentence
	 * (A or B).
	 */
	private static class Cache {
		// target words that do not occur anywhere in the other sentence and
		// survived the "marked on" filtering
		private boolean[] additionsA;
		private boolean[] additionsB;
		// target words that line up with a different word in the other
		// sentence once additions are skipped
		private boolean[] reorderingsA;
		private boolean[] reorderingsB;
		// target words aligned to "me", "my dependent" and "my governor"
		private boolean[] meA;
		private boolean[] meB;
		private boolean[] depA;
		private boolean[] depB;
		private boolean[] govA;
		private boolean[] govB;

		private TreeSet<String> addedWordsA;
		private TreeSet<String> addedWordsB;
		private TreeSet<ObservedMorpheme> addedMorphemesA;
		private TreeSet<ObservedMorpheme> addedMorphemesB;
		private TreeSet<String> reorderedWords;
	}

	private Cache cache;

	/**
	 * Creates the evidence for a pair of sentences that differ in the given
	 * feature values. No analysis is performed until a result is requested.
	 * 
	 * @throws IllegalStateException
	 *             if {@link #init(boolean, boolean, boolean, boolean)} has not
	 *             been called yet
	 */
	public FeatureMarking(SentencePair pairA, SentencePair pairB,
			FeatureValueInteraction featureValueA, FeatureValueInteraction featureValueB,
			MinimalPairMapping minPairA, MinimalPairMapping minPairB) {

		if (!initDone) {
			throw new IllegalStateException(
					"You must first call the static init() method before using this class.");
		}

		this.pairA = pairA;
		this.pairB = pairB;
		this.featureValueA = featureValueA;
		this.featureValueB = featureValueB;
		this.minPairA = minPairA;
		this.minPairB = minPairB;

		analyze();
	}

	/**
	 * Placeholder for eager analysis; currently does nothing. The head /
	 * dependent / governor classification is performed lazily by
	 * {@link #calculate()}.
	 */
	private void analyze() {
		// determine if the marking is on the head, dependent, etc.
	}

	public SentencePair getPairA() {
		return pairA;
	}

	public SentencePair getPairB() {
		return pairB;
	}

	public FeatureValueInteraction getFeatureValueA() {
		return featureValueA;
	}

	public FeatureValueInteraction getFeatureValueB() {
		return featureValueB;
	}

	/**
	 * Sets {@code targetArray[t]} for every target word t aligned to the given
	 * source word, printing each projection to standard output.
	 * 
	 * @param sent
	 *            sentence pair whose normalized alignment is used
	 * @param sourceIndex
	 *            0-based index of the source word
	 * @param targetArray
	 *            per-target-word flags, indexed 0-based
	 * @throws CorpusException
	 *             propagated from the alignment lookup
	 */
	private static void mapSourceToTargetIndices(SentencePair sent, int sourceIndex,
			boolean[] targetArray) throws CorpusException {

		// the alignment is queried with 1-based indices: convert 0-based to 1-based
		int[] targetIndices = sent.getNormalizedAlignment().getTargetIndices(sourceIndex + 1);
		for (int oneBasedTargetIndex : targetIndices) {

			// convert 1-based back to 0-based
			int targetIndex = oneBasedTargetIndex - 1;
			targetArray[targetIndex] = true;

			String sourceWord = sent.getNormalizedSourceTokens()[sourceIndex];
			String targetWord = sent.getNormalizedTargetTokens()[targetIndex];
			System.out.println("FEATURE MARKING PROJECTION: sent #" + sent.getId() + " "
					+ sourceWord + " --> " + targetWord);
		}
	}

	/**
	 * Projects the source-side "me" / "my dependent" / "my governor"
	 * constituents for the examined features onto the target side, then clears
	 * every addition whose position is not permitted by the static
	 * configuration.
	 * 
	 * @param sent
	 *            sentence pair being analyzed
	 * @param mapping
	 *            minimal-pair mapping providing the wildcard feature nodes
	 * @param additions
	 *            in/out: added-word flags; disallowed entries are set to false
	 * @param me
	 *            out: target words aligned to "me"
	 * @param dep
	 *            out: target words aligned to an immediate dependent of "me"
	 * @param gov
	 *            out: target words aligned to the immediate head of "me"
	 * @throws CorpusException
	 *             propagated from alignment lookups
	 */
	private static void findMarkedOn(SentencePair sent, MinimalPairMapping mapping,
			boolean[] additions, boolean[] me, boolean[] dep, boolean[] gov) throws CorpusException {

		// 1) find the head ("me") for the features being examined in this
		// marking
		for (int wildcardNodeAbsoluteTreeIndex : mapping.wildcardFeatureStructure.getWildcardNodes()) {

			// convert absolute tree labeling index to terminal index
			TreeNode wildcardNode =
					sent.getFeatureStructure().getByAbsoluteTreeIndex(wildcardNodeAbsoluteTreeIndex);

			// only context nodes such as root/actor/undergoer/clause are
			// included in phi mapping, so use the wildcard node's parent
			TreeNode contextNode = wildcardNode.getParentNode();
			int wildcardTerminalNodeIndex = contextNode.getTreeLabelingIndex();

			assert wildcardTerminalNodeIndex != -1 : "Not a labeled node";

			// 2) get which source indices correspond to "me" for this feature
			int[] meSourceIndices =
					sent.getPhiPlusMapping().getPhiInverse(wildcardTerminalNodeIndex);

			for (int meSourceIndex : meSourceIndices) {

				assert meSourceIndex != -1 : "phi inverse returned -1";

				// NOTE(review): meSourceIndex is converted to 1-based before
				// comparing with NO_CONSTITUENT; confirm this matches the
				// convention of PhiPlusMapping.getPhiInverse().
				if (meSourceIndex + 1 == PhiPlusMapping.NO_CONSTITUENT) {
					continue;
				}

				// 3) project source dependency tree fragment for this
				// feature onto the target side, preserving ambiguity for
				// each of "me" "my dependent" and "my governor"
				System.out.println("FEATURE MARKING ME: "
						+ Arrays.toString(mapping.featureNames) + ":");
				mapSourceToTargetIndices(sent, meSourceIndex, me);

				// 4) determine which words represent "my dependent" and "my
				// governor" by following head mapping
				System.out.println("FEATURE MARKING GOV: "
						+ Arrays.toString(mapping.featureNames) + ":");
				// getImmediateHead() is called with a 1-based index and its
				// result is converted back to 0-based
				int govSourceIndex =
						sent.getPhiPlusMapping().getImmediateHead(meSourceIndex + 1) - 1;
				mapSourceToTargetIndices(sent, govSourceIndex, gov);

				System.out.println("FEATURE MARKING DEP: "
						+ Arrays.toString(mapping.featureNames) + ":");
				// NOTE(review): unlike getImmediateHead() above, this is passed
				// the 0-based index and its results are used as 0-based
				// indices; confirm getImmediateDependents()'s convention.
				int[] depSourceIndices =
						sent.getPhiPlusMapping().getImmediateDependents(meSourceIndex);
				for (int depSourceIndex : depSourceIndices) {
					mapSourceToTargetIndices(sent, depSourceIndex, dep);
				}
			}
		}

		// me, my dependent, and my governor may overlap; a word is kept if ANY
		// of its categories is allowed.

		// 5) for each added word, examine how the markings
		// align to the projected target dependency tree
		// 6) disallow certain additions based on these projections
		// TODO: Use these decisions in the feature value clustering stage
		for (int i = 0; i < additions.length; i++) {
			if (!additions[i]) {
				continue;
			}
			boolean onMe = me[i];
			boolean onDep = dep[i];
			boolean onGov = gov[i];
			// "others" means none of the three projected categories; without
			// this restriction allowMarkingsOnOthers would override the other
			// three flags entirely
			boolean onOther = !onMe && !onDep && !onGov;

			boolean allowed = (onMe && allowMarkingsOnMe)
					|| (onDep && allowMarkingsOnMyDependents)
					|| (onGov && allowMarkingsOnMyGovernor)
					|| (onOther && allowMarkingsOnOthers);
			if (!allowed) {
				additions[i] = false;
			}
		}
	}

	/**
	 * Builds the cache on first use: marks added words (target words absent
	 * from the other sentence), filters them with {@link #findMarkedOn}, and
	 * marks reordered words.
	 * 
	 * @throws CorpusException
	 *             propagated from alignment lookups; in that case no cache is
	 *             stored and the next call retries
	 */
	private void calculate() throws CorpusException {
		if (cache != null) {
			return;
		}

		String[] arrTargetA = pairA.getNormalizedTargetTokens();
		String[] arrTargetB = pairB.getNormalizedTargetTokens();

		// Populate a local cache and publish it only once complete, so an
		// exception part-way through cannot leave a half-built cache that
		// later calls would mistake for a finished one.
		Cache fresh = new Cache();
		fresh.additionsA = new boolean[arrTargetA.length];
		fresh.additionsB = new boolean[arrTargetB.length];
		fresh.reorderingsA = new boolean[arrTargetA.length];
		fresh.reorderingsB = new boolean[arrTargetB.length];
		fresh.meA = new boolean[arrTargetA.length];
		fresh.meB = new boolean[arrTargetB.length];
		fresh.depA = new boolean[arrTargetA.length];
		fresh.depB = new boolean[arrTargetB.length];
		fresh.govA = new boolean[arrTargetA.length];
		fresh.govB = new boolean[arrTargetB.length];

		TreeSet<String> wordsA = new TreeSet<String>(Arrays.asList(arrTargetA));
		TreeSet<String> wordsB = new TreeSet<String>(Arrays.asList(arrTargetB));

		// a word is "added" in one sentence if it never occurs in the other
		for (int i = 0; i < arrTargetA.length; i++) {
			fresh.additionsA[i] = !wordsB.contains(arrTargetA[i]);
		}
		for (int i = 0; i < arrTargetB.length; i++) {
			fresh.additionsB[i] = !wordsA.contains(arrTargetB[i]);
		}

		// Figure out if it's marked on "me" "dependent" or "governor" ...or
		// "other" by following word alignments back to the source side
		findMarkedOn(pairA, minPairA, fresh.additionsA, fresh.meA, fresh.depA, fresh.govA);
		findMarkedOn(pairB, minPairB, fresh.additionsB, fresh.meB, fresh.depB, fresh.govB);

		// Reorderings are currently only detected for equal-length sentences:
		// walk both sentences in parallel, skipping added words, and flag
		// positions whose words differ.
		if (arrTargetA.length == arrTargetB.length) {
			int i = 0;
			int j = 0;
			while (i < arrTargetA.length && j < arrTargetB.length) {
				if (fresh.additionsA[i]) {
					i++;
					continue;
				}
				if (fresh.additionsB[j]) {
					j++;
					continue;
				}

				if (!arrTargetA[i].equals(arrTargetB[j])) {
					fresh.reorderingsA[i] = true;
					fresh.reorderingsB[j] = true;
				}

				i++;
				j++;
			}
		}

		cache = fresh;
	}

	/** Adds {@code tokens[i]} to {@code into} for every i where {@code mask[i]} is set. */
	private static void addMaskedTokens(boolean[] mask, String[] tokens, Set<String> into) {
		for (int i = 0; i < mask.length; i++) {
			if (mask[i]) {
				into.add(tokens[i]);
			}
		}
	}

	/**
	 * @return the (cached, shared) set of words flagged as reordered in either
	 *         sentence
	 */
	public TreeSet<String> getReorderedWords() throws CorpusException {
		calculate();

		if (cache.reorderedWords == null) {
			TreeSet<String> reordered = new TreeSet<String>();
			addMaskedTokens(cache.reorderingsA, pairA.getNormalizedTargetTokens(), reordered);
			addMaskedTokens(cache.reorderingsB, pairB.getNormalizedTargetTokens(), reordered);
			cache.reorderedWords = reordered;
		}
		return cache.reorderedWords;
	}

	/** Fills the added-word sets from the addition masks; requires calculate(). */
	private void initAddedWords() {
		if (cache.addedWordsA == null) {
			TreeSet<String> wordsA = new TreeSet<String>();
			TreeSet<String> wordsB = new TreeSet<String>();
			addMaskedTokens(cache.additionsA, pairA.getNormalizedTargetTokens(), wordsA);
			addMaskedTokens(cache.additionsB, pairB.getNormalizedTargetTokens(), wordsB);
			cache.addedWordsA = wordsA;
			cache.addedWordsB = wordsB;
		}
	}

	/**
	 * Adds one unsegmented morpheme (the whole word, with no paradigms) per
	 * added word to {@code morphemes}.
	 */
	private static void generateMorphemeListWithoutSegmenting(Set<ObservedMorpheme> morphemes,
			SentencePair pair, boolean[] additions, boolean[] me, boolean[] dep, boolean[] gov) {

		for (int i = 0; i < additions.length; i++) {
			if (additions[i]) {

				String word = pair.getNormalizedTargetTokens()[i];
				morphemes.add(new ObservedMorpheme(EMPTY_LIST, word, new String[] { word }, me[i],
						dep[i], gov[i]));
			}
		}
	}

	/**
	 * Adds one morpheme per segment of every added word to {@code morphemes},
	 * using {@code segmenter} for paradigms and segmentation.
	 */
	private static void generateSegmentedMorphemes(Segmenter segmenter,
			Set<ObservedMorpheme> morphemes, SentencePair pair, boolean[] additions, boolean[] me,
			boolean[] dep, boolean[] gov) throws SegmenterException {

		for (int i = 0; i < additions.length; i++) {
			if (additions[i]) {

				String word = pair.getNormalizedTargetTokens()[i];
				List<Paradigm> paradigms = segmenter.getParadigms(word);
				String[] segmentedWord = segmenter.getCombinedSegmentation(word);

				for (String strMorpheme : segmentedWord) {
					ObservedMorpheme addedMorpheme =
							new ObservedMorpheme(paradigms, strMorpheme, segmentedWord, me[i],
									dep[i], gov[i]);
					morphemes.add(addedMorpheme);
				}
			}
		}
	}

	/**
	 * Fills the added-morpheme sets; requires calculate(). With a segmenter,
	 * morphemes present in both sentences are removed from both sets. The sets
	 * are only stored once fully built, so a SegmenterException leaves them
	 * uninitialized for a later retry.
	 */
	private void initAddedMorphemes(Segmenter segmenter) throws SegmenterException {

		if (cache.addedMorphemesA != null) {
			return;
		}

		TreeSet<ObservedMorpheme> morphemesA = new TreeSet<ObservedMorpheme>();
		TreeSet<ObservedMorpheme> morphemesB = new TreeSet<ObservedMorpheme>();

		if (segmenter == null) {

			generateMorphemeListWithoutSegmenting(morphemesA, pairA, cache.additionsA, cache.meA,
					cache.depA, cache.govA);
			generateMorphemeListWithoutSegmenting(morphemesB, pairB, cache.additionsB, cache.meB,
					cache.depB, cache.govB);

		} else {

			generateSegmentedMorphemes(segmenter, morphemesA, pairA, cache.additionsA, cache.meA,
					cache.depA, cache.govA);
			generateSegmentedMorphemes(segmenter, morphemesB, pairB, cache.additionsB, cache.meB,
					cache.depB, cache.govB);

			// make sure sets have zero intersection
			// TODO: optimize this process
			ArrayList<ObservedMorpheme> intersection = new ArrayList<ObservedMorpheme>();
			for (ObservedMorpheme morph : morphemesA) {
				if (morphemesB.contains(morph)) {
					intersection.add(morph);
				}
			}

			morphemesA.removeAll(intersection);
			morphemesB.removeAll(intersection);
		}

		cache.addedMorphemesA = morphemesA;
		cache.addedMorphemesB = morphemesB;
	}

	public TreeSet<String> getAddedWordsA() throws CorpusException {
		calculate();
		initAddedWords();
		return cache.addedWordsA;
	}

	public TreeSet<String> getAddedWordsB() throws CorpusException {
		calculate();
		initAddedWords();
		return cache.addedWordsB;
	}

	@Override
	public TreeSet<ObservedMorpheme> getAddedMorphemesA(Segmenter segmenter)
			throws SegmenterException, CorpusException {
		calculate();
		initAddedWords();
		initAddedMorphemes(segmenter);
		return cache.addedMorphemesA;
	}

	@Override
	public TreeSet<ObservedMorpheme> getAddedMorphemesB(Segmenter segmenter)
			throws SegmenterException, CorpusException {
		calculate();
		initAddedWords();
		initAddedMorphemes(segmenter);
		return cache.addedMorphemesB;
	}

	/**
	 * Hash over both sentences' lines and both feature value names; cached
	 * after first computation. Consistent with {@link #equals(Object)}, which
	 * compares those same objects by identity.
	 */
	@Override
	public int hashCode() {
		if (hashCode == -1) {
			String str =
					pairA.getMyLine() + pairB.getMyLine() + featureValueA.getName()
							+ featureValueB.getName();
			hashCode = str.hashCode();
		}
		return hashCode;
	}

	/**
	 * Two pieces of evidence are equal when they refer to the very same
	 * sentence pairs and feature values (reference equality).
	 */
	@Override
	public boolean equals(Object obj) {
		if (!(obj instanceof ArcEvidence)) {
			return false;
		}
		ArcEvidence other = (ArcEvidence) obj;
		return this.pairA == other.getPairA() && this.pairB == other.getPairB()
				&& this.featureValueA == other.getFeatureValueA()
				&& this.featureValueB == other.getFeatureValueB();
	}

	/**
	 * Renders both target sentences as a LaTeX tabular: added words in bold,
	 * reordered words emphasized, followed by the source sentence and id.
	 */
	public String toLatexString() throws CorpusException {
		calculate();

		String[] arrTargetA = pairA.getNormalizedTargetTokens();
		String[] arrTargetB = pairB.getNormalizedTargetTokens();

		// one column for the feature value name, one per token of the longer
		// sentence, and one for the source sentence
		int nTokenColumns = Math.max(arrTargetA.length, arrTargetB.length);
		int nColumns = nTokenColumns + 2;
		StringBuilder builder =
				new StringBuilder("\\begin{tabular}{"
						+ StringUtils.duplicateCharacter('c', nColumns) + "}\n");

		appendLatexRow(builder, featureValueA.getName(), pairA, arrTargetA, cache.additionsA,
				cache.reorderingsA, nTokenColumns);
		appendLatexRow(builder, featureValueB.getName(), pairB, arrTargetB, cache.additionsB,
				cache.reorderingsB, nTokenColumns);

		builder.append("\\hline");
		builder.append("\\end{tabular}\n");
		return builder.toString();
	}

	/** Appends one tabular row for a single sentence, padded to nTokenColumns tokens. */
	private static void appendLatexRow(StringBuilder builder, String featureName,
			SentencePair pair, String[] tokens, boolean[] additions, boolean[] reorderings,
			int nTokenColumns) {

		builder.append(featureName + ": & ");
		for (int i = 0; i < tokens.length; i++) {
			String cell = clean(tokens[i]);
			if (additions[i]) {
				builder.append("\\textbf{" + cell + "}");
			} else if (reorderings[i]) {
				builder.append("\\emph{" + cell + "}");
			} else {
				builder.append(cell);
			}
			builder.append(" & ");
		}

		// make sure translations line up
		for (int i = tokens.length; i < nTokenColumns; i++) {
			builder.append(" & ");
		}
		builder.append("(" + pair.getDisplaySourceSentence() + ") sentnum: " + pair.getId()
				+ " \\\\ \n");
	}

	/** Escapes characters that would break LaTeX output. */
	private static String clean(String s) {
		return LatexUtils.replaceLatexKillers(s);
	}

	/** Releases all cached analysis results; they are recomputed on demand. */
	public void clearCache() {
		this.cache = null;
	}

	@Override
	public String toString() {
		return featureValueA.getName() + ": " + pairA.toString() + " is different than "
				+ featureValueB.getName() + ": " + pairB.toString();
	}
}
