package edu.cmu.cs.lti.avenue.projection;

import java.util.ArrayList;
import java.util.Arrays;

import info.jonclark.util.ArrayUtils;
import info.jonclark.util.MathUtils;
import edu.cmu.cs.lti.avenue.corpus.Alignment;
import edu.cmu.cs.lti.avenue.corpus.CorpusException;
import edu.cmu.cs.lti.avenue.corpus.RawAlignment;
import edu.cmu.cs.lti.avenue.corpus.SentencePair;
import edu.cmu.cs.lti.avenue.trees.smart.SmartTree;
import edu.cmu.cs.lti.avenue.trees.smart.TreeNode;

/**
 * Extracts alignment-based features (reorderings, dual alignments and several
 * kinds of discontinuities) describing how a source-side constituent projects
 * onto the target side of a word-aligned {@code SentencePair}.
 * <p>
 * Index conventions, as implied by the {@code +1}/{@code -1} conversions in
 * this class: source terminal indices (from {@code getTerminalIndices()} and
 * {@code Alignment.getSourceIndices(..)}) are ONE-BASED; target terminal
 * indices (from {@code Alignment.getTargetIndices(..)}) are ZERO-BASED; and
 * {@code Alignment.getSourceIndices(..)} is called with a ONE-BASED target
 * index.
 * <p>
 * See also AlignmentScorer, the prototype of this class.
 * 
 * @author jon
 */
public class ProjectionFeatureExtractor {

	/**
	 * Position of span A relative to span B, where each span is given by the
	 * minimum and maximum of a set of terminal indices (see
	 * {@link ProjectionFeatureExtractor#getRelation(int[], int[])}).
	 */
	public enum AlignmentRelation {
		/** A ends before B starts: {@code maxA < minB}. */
		BEFORE,
		/** A starts after B ends: {@code maxB < minA}. */
		AFTER,
		/** A lies strictly inside B: {@code minA > minB && maxA < maxB}. */
		INSIDE,
		/** A starts before B and ends inside B or at B's end. */
		BEFORE_AND_INSIDE,
		/** A starts inside B (or at B's start) and ends after B. */
		AFTER_AND_INSIDE,
		/** A starts after B's start and ends exactly at B's end. */
		INSIDE_AT_END,
		/** A starts exactly at B's start and ends before B's end. */
		INSIDE_AT_START,
		/** A starts before B and ends after B: {@code minA < minB && maxA > maxB}. */
		AROUND,
		/** A and B cover exactly the same span. */
		EQUAL
	}

	/**
	 * Creates an empty, strictly triangular (no main diagonal) matrix that
	 * records which pairwise alignment interactions have already been
	 * reported by {@link #countReorderings(SentencePair, boolean[][], int, int)}.
	 * <p>
	 * For {@code n} normalized alignments the result has {@code n - 1} rows;
	 * row {@code i} has {@code i + 1} entries, and entry {@code [i][j]} stands
	 * for the pair {@code (alignments[i + 1], alignments[j])}. Fewer than two
	 * alignments yield an empty array.
	 * 
	 * @param pair
	 *            sentence pair whose normalized alignment determines the size
	 * @return a fresh cache with every entry {@code false}
	 */
	public static boolean[][] getEmptyReorderingCache(SentencePair pair) {

		int n = pair.getNormalizedAlignment().getRawAlignments().length;

		if (n < 2) {
			return new boolean[0][0];
		}

		boolean[][] reorderingCache = new boolean[n - 1][];
		for (int i = 0; i < n - 1; i++) {
			reorderingCache[i] = new boolean[i + 1];
		}

		return reorderingCache;
	}

	/**
	 * Counts the reorderings for the specified range of source terminals.
	 * <p>
	 * This method should be run after projection has already been performed
	 * since it uses projection metadata.
	 * 
	 * @param pair
	 *            sentence pair whose normalized alignments are examined
	 * @param reorderingCache
	 *            stores which alignment interactions have already been
	 *            reported (see {@link #getEmptyReorderingCache(SentencePair)});
	 *            entries for newly examined pairs are set to {@code true}.
	 *            null indicates that all interactions should be reported.
	 * @param nFirstSourceWord
	 *            ZERO-BASED index of the first source word of the range
	 * @param nLastSourceWord
	 *            ZERO-BASED index of the last source word of the range
	 *            (inclusive)
	 * @return number of not-yet-reported alignment pairs, both lying within
	 *         the range, whose source-side and target-side relations differ
	 */
	public static int countReorderings(SentencePair pair, boolean[][] reorderingCache,
			int nFirstSourceWord, int nLastSourceWord) {

		// convert to ONE-BASED to match RawAlignment source indices
		nFirstSourceWord++;
		nLastSourceWord++;

		RawAlignment[] alignments = pair.getNormalizedAlignment().getRawAlignments();

		return countReorderings(reorderingCache, nFirstSourceWord, nLastSourceWord, alignments);
	}

	/**
	 * Counts the reorderings among {@code alignments} whose source span lies
	 * within the given range, without any reporting cache.
	 * <p>
	 * Unlike the {@link SentencePair} overload, the bounds are NOT converted:
	 * they are compared directly against {@code getSourceMin()} /
	 * {@code getSourceMax()}. NOTE(review): since the other overload converts
	 * to one-based first, callers presumably pass ONE-BASED bounds here —
	 * confirm.
	 * 
	 * @param alignments
	 *            alignments to examine
	 * @param nFirstSourceWord
	 *            first source word of the range (inclusive)
	 * @param nLastSourceWord
	 *            last source word of the range (inclusive)
	 * @return number of alignment pairs in range whose source-side and
	 *         target-side relations differ
	 */
	public static int countReorderings(RawAlignment[] alignments, int nFirstSourceWord,
			int nLastSourceWord) {
		return countReorderings(null, nFirstSourceWord, nLastSourceWord, alignments);
	}

	/**
	 * Shared implementation of both public {@code countReorderings} overloads.
	 * Bounds are used as given (ONE-BASED when coming from the
	 * {@link SentencePair} overload).
	 */
	private static int countReorderings(boolean[][] reorderingCache, int nFirstSourceWord,
			int nLastSourceWord, RawAlignment[] alignments) {

		// Visit every unordered pair of alignments exactly once: row i pairs
		// alignments[i + 1] with each earlier alignments[j] (j <= i). This is
		// the same layout as the cache from getEmptyReorderingCache().
		int count = 0;
		for (int i = 0; i < alignments.length - 1; i++) {
			RawAlignment alignmentA = alignments[i + 1];
			if (!isWithinSourceSpan(alignmentA, nFirstSourceWord, nLastSourceWord)) {
				// no pair involving alignmentA can qualify
				continue;
			}

			for (int j = 0; j <= i; j++) {
				RawAlignment alignmentB = alignments[j];
				if (!isWithinSourceSpan(alignmentB, nFirstSourceWord, nLastSourceWord)) {
					continue;
				}

				if (reorderingCache != null) {
					if (reorderingCache[i][j]) {
						// already reported by an earlier call sharing this cache
						continue;
					}
					reorderingCache[i][j] = true;
				}

				// a reordering is a pair whose relative position on the source
				// side differs from its relative position on the target side
				AlignmentRelation sourceRelation =
						getRelation(alignmentA.sourceTerminals, alignmentB.sourceTerminals);
				AlignmentRelation targetRelation =
						getRelation(alignmentA.targetTerminals, alignmentB.targetTerminals);
				if (sourceRelation != targetRelation) {
					count++;
				}
			}
		}
		return count;
	}

	/** Returns whether the alignment's whole source span lies in [first, last]. */
	private static boolean isWithinSourceSpan(RawAlignment alignment, int nFirstSourceWord,
			int nLastSourceWord) {
		return alignment.getSourceMin() >= nFirstSourceWord
				&& alignment.getSourceMax() <= nLastSourceWord;
	}

	/**
	 * Gives the position of the span of {@code sourceTerminalsA} relative to
	 * the span of {@code sourceTerminalsB}. Only the minimum and maximum of
	 * each array matter; see {@link AlignmentRelation} for the meaning of each
	 * result.
	 * 
	 * @param sourceTerminalsA
	 *            source or target terminals from one alignment group
	 *            (NOTE(review): presumably must be non-empty, since its min and
	 *            max are taken)
	 * @param sourceTerminalsB
	 *            source or target terminals from another alignment group
	 * @return the relation of A's span to B's span
	 */
	public static AlignmentRelation getRelation(int[] sourceTerminalsA, int[] sourceTerminalsB) {

		int minA = MathUtils.min(sourceTerminalsA);
		int maxA = MathUtils.max(sourceTerminalsA);
		int minB = MathUtils.min(sourceTerminalsB);
		int maxB = MathUtils.max(sourceTerminalsB);

		// The order of these tests matters: e.g. BEFORE_AND_INSIDE relies on
		// BEFORE having been ruled out already.
		if (maxA < minB) {
			return AlignmentRelation.BEFORE;
		} else if (maxB < minA) {
			return AlignmentRelation.AFTER;
		} else if (minA == minB && maxA == maxB) {
			return AlignmentRelation.EQUAL;
		} else if (minA > minB && maxA < maxB) {
			return AlignmentRelation.INSIDE;
		} else if (minA < minB && maxA <= maxB) {
			return AlignmentRelation.BEFORE_AND_INSIDE;
		} else if (minA >= minB && maxA > maxB) {
			return AlignmentRelation.AFTER_AND_INSIDE;
		} else if (minA < minB && maxA > maxB) {
			return AlignmentRelation.AROUND;
		} else if (minA > minB && maxA == maxB) {
			return AlignmentRelation.INSIDE_AT_END;
		} else if (minA == minB && maxA < maxB) {
			return AlignmentRelation.INSIDE_AT_START;
		} else {
			// Unreachable: the branches above cover all nine (<,=,>)
			// combinations of (minA vs minB, maxA vs maxB).
			throw new AssertionError("Unknown relation: " + minA + "-" + maxA + "," + minB + "-" + maxB);
		}
	}

	/**
	 * Maps target terminals back to every source terminal aligned to them.
	 * 
	 * @param targetTerminals
	 *            ZERO-BASED target terminal indices
	 * @param alignment
	 *            alignment used for the lookup
	 * @return ONE-BASED source indices as returned by
	 *         {@code alignment.getSourceIndices(..)}, in lookup order
	 *         (duplicates are not removed)
	 * @throws CorpusException
	 *             if the alignment lookup fails
	 */
	protected static int[] getMirroredSourceTerminals(int[] targetTerminals, Alignment alignment)
			throws CorpusException {

		ArrayList<Integer> mirroredSourceTerminals = new ArrayList<Integer>();

		for (final int nTargetTerminal : targetTerminals) {
			// getSourceIndices() takes a 1-based target index
			for (final int nSourceTerminal : alignment.getSourceIndices(nTargetTerminal + 1)) {
				mirroredSourceTerminals.add(nSourceTerminal);
			}
		}

		return ArrayUtils.toArray(mirroredSourceTerminals);
	}

	/**
	 * Computes all projection features for one source constituent.
	 * 
	 * @param pair
	 *            the aligned sentence pair
	 * @param sourceNode
	 *            the source constituent being projected
	 * @param targetNodeCache
	 *            passed through to
	 *            {@link #findDiscontinuitiesByUnalignedTargetWord}
	 * @param targetTerminalCoverage
	 *            ZERO-BASED flags of target terminals already consumed;
	 *            read (not written) by
	 *            {@link #findDiscontinuitiesByUnalignedTargetWord}
	 * @param reorderingCache
	 *            may be null; otherwise updated by
	 *            {@link #countReorderings(SentencePair, boolean[][], int, int)}
	 * @return the populated features
	 * @throws CorpusException
	 *             if an alignment lookup fails
	 */
	protected static ProjectionFeatures extractProjectionFeatures(SentencePair pair,
			TreeNode sourceNode, TreeNode[] targetNodeCache, boolean[] targetTerminalCoverage,
			boolean[][] reorderingCache) throws CorpusException {

		Alignment alignment = pair.getNormalizedAlignment();

		int[] insideSourceTerminals = ArrayUtils.toArray(sourceNode.getTerminalIndices());
		int[] targetTerminals = ArrayUtils.toArray(alignment.getTargetIndices(sourceNode));
		int[] mirroredSourceTerminals = getMirroredSourceTerminals(targetTerminals, alignment);

		// terminal indices are 1-based; countReorderings() expects 0-based
		int nFirstWord = MathUtils.min(insideSourceTerminals) - 1;
		int nLastWord = MathUtils.max(insideSourceTerminals) - 1;

		ProjectionFeatures features = new ProjectionFeatures();
		features.r = countReorderings(pair, reorderingCache, nFirstWord, nLastWord);
		features.dd = countDualAlignments(pair, insideSourceTerminals, mirroredSourceTerminals);
		features.di = countDiscontinuitiesByInterruption(targetTerminals, alignment);
		features.dus = countDiscontinuitiesByUnalignedSourceWord(insideSourceTerminals);
		findDiscontinuitiesByUnalignedTargetWord(targetTerminals, pair, sourceNode,
				targetNodeCache, targetTerminalCoverage, features);

		return features;
	}

	/**
	 * Counts the distinct source words that are reached from this
	 * constituent's target words but lie OUTSIDE the constituent, i.e. the
	 * elements of {@code mirroredSourceTerminals} that are not in
	 * {@code insideSourceTerminals}. Each outside source word is counted once.
	 * 
	 * @param pair
	 *            used only to size the per-source-word coverage array
	 * @param insideSourceTerminals
	 *            ONE-BASED source terminals of the constituent; must be sorted
	 *            ascending (required by the binary search)
	 * @param mirroredSourceTerminals
	 *            ONE-BASED source terminals aligned to the constituent's
	 *            target words
	 * @return the number of distinct outside source words
	 * @throws CorpusException
	 *             declared for API compatibility
	 */
	protected static int countDualAlignments(SentencePair pair, int[] insideSourceTerminals,
			int[] mirroredSourceTerminals) throws CorpusException {

		assert ArrayUtils.isSorted(insideSourceTerminals) : "insideSourceTerminals not sorted";

		// only count each terminal once (indexed 0-based)
		boolean[] coverage = new boolean[pair.getNormalizedSourceTokens().length];

		int count = 0;
		for (final int mirroredSourceTerminal : mirroredSourceTerminals) {
			boolean isInside = Arrays.binarySearch(insideSourceTerminals, mirroredSourceTerminal) >= 0;
			if (!isInside && !coverage[mirroredSourceTerminal - 1]) {
				coverage[mirroredSourceTerminal - 1] = true;
				count++;
			}
		}

		return count;
	}

	/**
	 * Counts target words lying strictly <b>between</b> the first and last
	 * of {@code targetTerminals} that are not themselves in
	 * {@code targetTerminals} but are aligned to at least one source word.
	 * Assuming {@code targetTerminals} holds every target word aligned to the
	 * constituent, such a gap word cannot be aligned into the constituent, so
	 * its alignment necessarily goes outside it.
	 * 
	 * @param targetTerminals
	 *            ZERO-BASED target terminals of the constituent; must be
	 *            sorted ascending
	 * @param alignment
	 *            alignment used to look up the gap words
	 * @return the number of aligned gap words (0 for empty input)
	 * @throws CorpusException
	 *             if an alignment lookup fails
	 */
	protected static int countDiscontinuitiesByInterruption(int[] targetTerminals,
			Alignment alignment) throws CorpusException {

		if (targetTerminals.length == 0) {
			return 0;
		}

		assert ArrayUtils.isSorted(targetTerminals) : "targetTerminals is not sorted";

		int min = targetTerminals[0];
		int max = targetTerminals[targetTerminals.length - 1];

		int count = 0;
		for (int i = min + 1; i < max; i++) {
			if (Arrays.binarySearch(targetTerminals, i) < 0) {
				// i is a gap inside the target span; lookup is 1-based
				int[] mirroredSourceIndices = alignment.getSourceIndices(i + 1);
				if (mirroredSourceIndices.length > 0) {
					count++;
				}
			}
		}

		return count;
	}

	/**
	 * Counts the indices strictly between the first and last source terminal
	 * of the constituent that are missing from {@code insideSourceTerminals}.
	 * 
	 * @param insideSourceTerminals
	 *            source terminals of the constituent; must be sorted ascending
	 * @return the number of gaps (0 for empty input)
	 */
	protected static int countDiscontinuitiesByUnalignedSourceWord(int[] insideSourceTerminals) {

		if (insideSourceTerminals.length == 0) {
			return 0;
		}

		assert ArrayUtils.isSorted(insideSourceTerminals) : "insideSourceTerminals is not sorted";

		int min = insideSourceTerminals[0];
		int max = insideSourceTerminals[insideSourceTerminals.length - 1];

		int count = 0;
		for (int i = min + 1; i < max; i++) {
			if (Arrays.binarySearch(insideSourceTerminals, i) < 0) {
				count++;
			}
		}

		return count;
	}

	/**
	 * Finds target words strictly between the first and last of
	 * {@code targetTerminals} that are not in {@code targetTerminals}, are not
	 * aligned to any source word, and are not yet marked in
	 * {@code targetTerminalCoverage}. For each such word an orphan lexical
	 * target node is created. The nodes are passed through
	 * {@code ConstituentStructureProjector.ensureUnique(..)} and stored in
	 * {@code features.unalignedTargets}; the number found is stored in
	 * {@code features.dut}.
	 * <p>
	 * This method only reads {@code targetTerminalCoverage}; it does not mark
	 * the words it finds.
	 * 
	 * @param targetTerminals
	 *            ZERO-BASED target terminals of the constituent; must be
	 *            sorted ascending
	 * @param pair
	 *            supplies the alignment and the target tokens
	 * @param sourceNode
	 *            the source constituent (not used by the current
	 *            implementation)
	 * @param targetNodeCache
	 *            passed to {@code ConstituentStructureProjector.ensureUnique}
	 * @param targetTerminalCoverage
	 *            ZERO-BASED flags of target words already consumed
	 * @param features
	 *            receives {@code unalignedTargets} and {@code dut}
	 * @throws CorpusException
	 *             if an alignment lookup fails
	 */
	protected static void findDiscontinuitiesByUnalignedTargetWord(int[] targetTerminals,
			SentencePair pair, TreeNode sourceNode, TreeNode[] targetNodeCache,
			boolean[] targetTerminalCoverage, ProjectionFeatures features) throws CorpusException {

		Alignment alignment = pair.getNormalizedAlignment();
		String[] fSentence = pair.getNormalizedTargetTokens();

		int count = 0;
		ArrayList<TreeNode> unalignedTargets = new ArrayList<TreeNode>();

		if (targetTerminals.length > 0) {
			assert ArrayUtils.isSorted(targetTerminals) : "targetTerminals is not sorted";

			int first = targetTerminals[0];
			int last = targetTerminals[targetTerminals.length - 1];

			for (int i = first + 1; i < last; i++) {
				if (Arrays.binarySearch(targetTerminals, i) >= 0) {
					continue;
				}

				// i is a gap inside the target span; lookup is 1-based
				int[] mirroredSourceIndices = alignment.getSourceIndices(i + 1);

				// unaligned gap word that no other constituent has already
				// consumed (the coverage check prevents double counting)
				if (mirroredSourceIndices.length == 0 && !targetTerminalCoverage[i]) {
					count++;

					TreeNode targetNode =
							TreeNode.createOrphanNode(i, SmartTree.TARGET_C_STRUCT_LABEL);
					targetNode.addValue(ConstituentStructureProjector.LEX);
					targetNode.addValue(fSentence[i]); // 0-based
					unalignedTargets.add(targetNode);
				}
			}

			ConstituentStructureProjector.ensureUnique(unalignedTargets, targetNodeCache);
		}

		features.unalignedTargets = unalignedTargets;
		features.dut = count;
	}

	/**
	 * Computes the Levenshtein edit distance between {@code s} and the
	 * reference sequence {@code 1, 2, ..., s.length} (insertions, deletions
	 * and substitutions each cost 1).
	 * <p>
	 * NOTE(review): the reference is the sequence 1..n, not a sorted copy of
	 * {@code s}; the two coincide only when {@code s} is a permutation of
	 * 1..n. Confirm that callers only pass such arrays.
	 * <p>
	 * Arrays with fewer than two elements always yield 0, even e.g.
	 * {@code [5]}. Based on code from Chas Emerick at
	 * http://www.merriampark.com/ldjava.htm
	 * 
	 * @param s
	 *            the sequence to compare against 1..s.length
	 * @return the edit distance; 0 when {@code s.length < 2}
	 */
	public static int getLevenshteinEditDistance(int[] s) {

		int n = s.length;

		if (n < 2) {
			return 0;
		}

		// Two rolling rows of the dynamic-programming table: prev holds the
		// row for reference prefix j-1, curr is filled for prefix j.
		int[] prev = new int[n + 1];
		int[] curr = new int[n + 1];

		for (int i = 0; i <= n; i++) {
			prev[i] = i;
		}

		for (int j = 1; j <= n; j++) {
			// the j-th element of the reference sequence 1..n is j itself
			curr[0] = j;

			for (int i = 1; i <= n; i++) {
				int cost = (s[i - 1] == j) ? 0 : 1;
				// min of: left + 1 (insert), up + 1 (delete), diagonal + cost
				curr[i] = Math.min(Math.min(curr[i - 1] + 1, prev[i] + 1), prev[i - 1] + cost);
			}

			int[] swap = prev;
			prev = curr;
			curr = swap;
		}

		// after the final swap, prev holds the last computed row
		return prev[n];
	}
}
