package edu.cmu.cs.lti.avenue.navigation.search.generation1;

import info.jonclark.util.HashUtils;

import java.util.ArrayList;
import java.util.HashMap;

import edu.cmu.cs.lti.avenue.corpus.Alignment;
import edu.cmu.cs.lti.avenue.corpus.CorpusException;
import edu.cmu.cs.lti.avenue.corpus.RawAlignment;
import edu.cmu.cs.lti.avenue.corpus.SentencePair;
import edu.cmu.cs.lti.avenue.trees.smart.TreeNode;

/**
 * A heuristic for scoring the badness of alignment divergences (we measure
 * crossings and target-side discontinuities).
 * <p>
 * See also ProjectionFeatureExtractor, a newer version of this class.
 * 
 * @author jon
 */
public class AlignmentScorer {

	/**
	 * EXTERNAL_INTERRUPTION - Part of another constituent is aligned to this
	 * node<br>
	 * ANY_BREAK - The constituent's target indices are not continuous.<br>
	 */
	public enum DiscontinuityType {
		EXTERNAL_INTERRUPTION, ANY_BREAK
	}

	/**
	 * Placeholder for an overall alignment score. No scoring is implemented
	 * yet: the display alignment is fetched and then ignored.
	 * 
	 * @param pair
	 *            the sentence pair to score
	 * @return always 0.0
	 */
	public double getScore(SentencePair pair) {
		// Result intentionally unused; the call is kept so this stub behaves
		// exactly as before (same calls made on the pair).
		pair.getDisplayAlignment();
		return 0.0;
	}

	/**
	 * Counts crossings between the raw alignments of a sentence pair.
	 * <p>
	 * The underlying idea is crossings(S[a] -> T[b]) = count(S[i] -> T[j]),
	 * i>a, j<b. Concretely, for every raw alignment this method counts
	 * <ul>
	 * <li>each earlier raw alignment whose last target index is at or after
	 * this alignment's first target index, and</li>
	 * <li>each later raw alignment whose first target index is at or before
	 * this alignment's last target index.</li>
	 * </ul>
	 * Both checks test the same condition from the two sides of a pair, so
	 * every overlapping/crossing pair contributes 2 to the total. Equal target
	 * indices count as a crossing.
	 * <p>
	 * NOTE(review): this assumes the raw alignments are ordered by source
	 * position, that each has at least one target terminal, and that
	 * <code>targetTerminals</code> is sorted ascending (first element = min,
	 * last element = max) -- confirm against RawAlignment.
	 * 
	 * @param pair
	 *            the sentence pair whose normalized alignment is inspected
	 * @return the total number of crossings found
	 * @throws CorpusException
	 */
	public int getCrossings(SentencePair pair) throws CorpusException {
		RawAlignment[] rawAlignments = pair.getNormalizedAlignment().getRawAlignments();

		int crossings = 0;
		for (int i = 0; i < rawAlignments.length; i++) {

			int[] targetIndices = rawAlignments[i].targetTerminals;
			int minTargetIndex = targetIndices[0];
			int maxTargetIndex = targetIndices[targetIndices.length - 1];

			// for the alignments before this one: check for crossings (max
			// ending after/on our min target index)
			for (int j = 0; j < i; j++) {
				int[] prevTarget = rawAlignments[j].targetTerminals;
				int prevMaxTarget = prevTarget[prevTarget.length - 1];
				if (prevMaxTarget >= minTargetIndex) {
					crossings++;
				}
			}

			// for the alignments after this one: check for crossings (min
			// starting before/on our max target index). Start at i + 1 so an
			// alignment is never compared with itself.
			for (int j = i + 1; j < rawAlignments.length; j++) {
				int[] nextTarget = rawAlignments[j].targetTerminals;
				int nextMinTarget = nextTarget[0];
				if (nextMinTarget <= maxTargetIndex) {
					crossings++;
				}
			}
		}

		return crossings;
	}

	/**
	 * Counts how many labeled nodes of each constituent type occur in the
	 * source constituent structure.
	 * 
	 * @param pair
	 *            the sentence pair; its source constituent structure must not
	 *            be null (checked with an assertion only)
	 * @return a map from constituent type (the first value of each labeled
	 *         node) to the number of labeled nodes carrying it
	 * @throws CorpusException
	 */
	public HashMap<String, Integer> getConstituentCounts(SentencePair pair) throws CorpusException {

		assert pair.getSourceConstituentStructure() != null : "Null c-structure for: "
				+ pair.getDisplaySourceSentence();
		ArrayList<TreeNode> nodes = pair.getSourceConstituentStructure().getLabeledNodes();
		HashMap<String, Integer> counts = new HashMap<String, Integer>();

		for (final TreeNode node : nodes) {
			String constituentType = node.getValues().get(0);
			HashUtils.increment(counts, constituentType);
		}

		return counts;
	}

	/**
	 * Counts reorderings per constituent type.
	 * <p>
	 * For every labeled node, every source index <code>a</code> under its
	 * lexical terminals and every target index <code>b</code> aligned to
	 * <code>a</code>, each later source index <code>i</code> that has a
	 * target index <code>j &lt; b</code> increments the count for the node's
	 * constituent type.
	 * <p>
	 * NOTE(review): the scan over later source indices stops at
	 * <code>i &lt; nSourceLen</code>; if source indices are 1-based this skips
	 * the last token -- confirm the indexing convention of Alignment.
	 * 
	 * @param pair
	 *            the sentence pair to inspect
	 * @return a map from constituent type to the number of reorderings counted
	 *         for nodes of that type
	 * @throws CorpusException
	 */
	public HashMap<String, Integer> getReorderingsWithinConstituents(SentencePair pair)
			throws CorpusException {

		Alignment alignment = pair.getDisplayAlignment();
		int nSourceLen = pair.getNormalizedSourceTokens().length;

		ArrayList<TreeNode> nodes = pair.getSourceConstituentStructure().getLabeledNodes();
		HashMap<String, Integer> crossings = new HashMap<String, Integer>();

		// for each constituent in the whole tree, get its terminals...
		for (final TreeNode node : nodes) {
			for (final TreeNode terminal : node.getLexicalTerminals()) {

				// ...and see what they align to...
				for (final int a : terminal.getTerminalIndices()) {
					for (final int b : alignment.getTargetIndices(a)) {

						// ...and then see if there are reorderings
						for (int i = a + 1; i < nSourceLen; i++) {
							for (final int j : alignment.getTargetIndices(i)) {
								assert i > a : "Expected i>a";
								if (j < b) {
									String constituentType = node.getValues().get(0);
									HashUtils.increment(crossings, constituentType);
								} else {
									// this only works because the j's must be
									// ascending; it ends the scan of this i's
									// target indices only
									break;
								}
							}
						}
					}
				}
			}
		}

		return crossings;
	}

	/**
	 * Counts gaps in the target side of the raw alignments. Every position in
	 * a raw alignment's <code>targetTerminals</code> that is not exactly one
	 * greater than its predecessor is counted under the single key
	 * <code>"ALIGNMENT"</code>.
	 * <p>
	 * NOTE(review): the parameter <code>d</code> is not consulted by this
	 * implementation, and each raw alignment is assumed to have at least one
	 * target terminal.
	 * 
	 * @param pair
	 *            the sentence pair whose normalized alignment is inspected
	 * @param d
	 *            currently unused
	 * @return a map with at most one entry, keyed <code>"ALIGNMENT"</code>,
	 *         holding the number of gaps found; empty if there are none
	 * @throws CorpusException
	 */
	public HashMap<String, Integer> getDiscontinuitiesWithinConstituentC(SentencePair pair,
			DiscontinuityType d) throws CorpusException {

		RawAlignment[] rawAlignments = pair.getNormalizedAlignment().getRawAlignments();
		HashMap<String, Integer> m = new HashMap<String, Integer>();

		for (int i = 0; i < rawAlignments.length; i++) {

			int[] targetIndices = rawAlignments[i].targetTerminals;

			int prev = targetIndices[0];
			for (int j = 1; j < targetIndices.length; j++) {
				if (targetIndices[j] != prev + 1) {
					HashUtils.increment(m, "ALIGNMENT");
				}
				prev = targetIndices[j];
			}
		}

		return m;
	}

	/**
	 * Tree-based discontinuity count (the name suggests this is the older
	 * variant of {@link #getDiscontinuitiesWithinConstituentC}).
	 * <p>
	 * For each labeled source node, the span [min, max] of target indices
	 * aligned to its lexical terminals is computed. The node is discontinuous
	 * if, inside that span, some target index is aligned to a source index
	 * outside the node or (for ANY_BREAK only) is not aligned at all. Each
	 * discontinuous node increments the count of its constituent type once.
	 * Nodes with no aligned target indices are never counted.
	 * 
	 * @param pair
	 *            the sentence pair to inspect
	 * @param d
	 *            which kind of discontinuity to look for
	 * @return a map from constituent type to the number of discontinuous
	 *         nodes of that type
	 * @throws CorpusException
	 * @throws RuntimeException
	 *             if <code>d</code> is not a known discontinuity type and some
	 *             node has a non-empty target span
	 */
	public HashMap<String, Integer> getDiscontinuitiesWithinConstituentC_Old(SentencePair pair,
			DiscontinuityType d) throws CorpusException {

		Alignment alignment = pair.getDisplayAlignment();

		ArrayList<TreeNode> nodes = pair.getSourceConstituentStructure().getLabeledNodes();
		HashMap<String, Integer> discontinuities = new HashMap<String, Integer>();

		// for each constituent in the whole tree, get its terminals...
		for (final TreeNode node : nodes) {

			int maxTargetIndex = Integer.MIN_VALUE;
			int minTargetIndex = Integer.MAX_VALUE;

			// first find the min and max indices on the target side
			for (final TreeNode terminal : node.getLexicalTerminals()) {
				for (final int sourceIndex : terminal.getTerminalIndices()) {
					for (final int targetIndex : alignment.getTargetIndices(sourceIndex)) {
						maxTargetIndex = Math.max(maxTargetIndex, targetIndex);
						minTargetIndex = Math.min(minTargetIndex, targetIndex);
					}
				}
			}

			boolean contiguous = true;
			ArrayList<Integer> sourceLeaves = node.getTerminalIndices();

			// now check if any lexical items between the min and max
			// on the target side are aligned to a different source
			// item (or just aren't aligned for the case of ANY_BREAK).
			// If nothing was aligned, min > max and this loop is skipped.
			for (int t = minTargetIndex; t <= maxTargetIndex; t++) {
				int[] sourceIndices = alignment.getSourceIndices(t);

				if (d == DiscontinuityType.ANY_BREAK) {
					if (sourceIndices.length == 0) {
						contiguous = false;
					}
				}

				if (d == DiscontinuityType.ANY_BREAK
						|| d == DiscontinuityType.EXTERNAL_INTERRUPTION) {

					for (final int sourceIndex : sourceIndices) {
						// see if there are any outside words; one is enough
						if (!sourceLeaves.contains(sourceIndex)) {
							contiguous = false;
							break;
						}
					}
				} else {
					throw new RuntimeException("Unknown discontinuity type: " + d);
				}
				if (!contiguous)
					break;
			}

			if (!contiguous) {
				String constituentType = node.getValues().get(0);
				HashUtils.increment(discontinuities, constituentType);
			}
		}

		return discontinuities;
	}

	/**
	 * Finds, for each source constituent, the target indices that make it
	 * discontinuous on the target side.
	 * <p>
	 * For each labeled source node, the span [min, max] of target indices
	 * aligned to its lexical terminals is computed. A target index inside that
	 * span is recorded for the node if it is aligned to a source index outside
	 * the node or (for ANY_BREAK only) is not aligned at all. Each offending
	 * target index is recorded at most once per node.
	 * 
	 * @param pair
	 *            the sentence pair to inspect
	 * @param d
	 *            which kind of discontinuity to look for
	 * @return A map with each key corresponding to a source constituent
	 *         (TreeNode) that is discontinuous on the target side and each
	 *         value corresponding to the list of target side indices which
	 *         cause the constituent to be discontinuous.
	 * @throws CorpusException
	 * @throws RuntimeException
	 *             if <code>d</code> is not a known discontinuity type and some
	 *             node has a non-empty target span
	 */
	public HashMap<TreeNode, ArrayList<Integer>> getDiscontinuitiesWithinConstituentL(
			SentencePair pair, DiscontinuityType d) throws CorpusException {

		Alignment alignment = pair.getDisplayAlignment();

		ArrayList<TreeNode> nodes = pair.getSourceConstituentStructure().getLabeledNodes();
		HashMap<TreeNode, ArrayList<Integer>> discontinuities =
				new HashMap<TreeNode, ArrayList<Integer>>();

		// for each constituent in the whole tree, get its terminals...
		for (final TreeNode node : nodes) {

			int maxTargetIndex = Integer.MIN_VALUE;
			int minTargetIndex = Integer.MAX_VALUE;

			// first find the min and max indices on the target side
			for (final TreeNode terminal : node.getLexicalTerminals()) {
				for (final int sourceIndex : terminal.getTerminalIndices()) {
					for (final int targetIndex : alignment.getTargetIndices(sourceIndex)) {
						maxTargetIndex = Math.max(maxTargetIndex, targetIndex);
						minTargetIndex = Math.min(minTargetIndex, targetIndex);
					}
				}
			}

			ArrayList<Integer> sourceChildren = node.getTerminalIndices();

			// now check if any lexical items between the min and max
			// on the target side are aligned to a different source
			// item (or just aren't aligned for the case of ANY_BREAK).
			// If nothing was aligned, min > max and this loop is skipped.
			for (int t = minTargetIndex; t <= maxTargetIndex; t++) {
				int[] sourceIndices = alignment.getSourceIndices(t);

				if (d == DiscontinuityType.ANY_BREAK) {
					if (sourceIndices.length == 0) {
						HashUtils.append(discontinuities, node, t);
					}
				}

				if (d == DiscontinuityType.ANY_BREAK
						|| d == DiscontinuityType.EXTERNAL_INTERRUPTION) {
					for (final int sourceIndex : sourceIndices) {
						if (!sourceChildren.contains(sourceIndex)) {
							HashUtils.append(discontinuities, node, t);
							// record t only once, however many outside
							// source words align to it
							break;
						}
					}
				} else {
					throw new RuntimeException("Unknown discontinuity type: " + d);
				}
			}
		}

		return discontinuities;
	}
}
