/**
 * The AVENUE Project
 * Language Technologies Institute
 * School of Computer Science
 * (c) 2007 Carnegie Mellon University
 * 
 * Corpus Navigator
 * Written by Jonathan Clark
 */
package edu.cmu.cs.lti.avenue.corpus;

import info.jonclark.lang.Pair;
import info.jonclark.util.ArrayUtils;
import info.jonclark.util.StringUtils;

import java.text.ParseException;
import java.util.ArrayList;
import java.util.Collections;

import edu.cmu.cs.lti.avenue.trees.smart.SmartTree;
import edu.cmu.cs.lti.avenue.trees.smart.TreeNode;

/**
 * A word alignment between the tokens of a source sentence and the tokens of
 * a target sentence.
 * <p>
 * The alignment is built from a serialized string consisting of an outer pair
 * of parentheses around a comma-separated list of parenthesized groups, each
 * of the form <code>(sourceIndices,targetIndices)</code>, e.g.
 * <code>((1,2),(3,1))</code>. Each side of a group is split with
 * <code>StringUtils.tokenize</code> (presumably on whitespace -- confirm
 * against that helper). The empty string and <code>()</code> both denote an
 * alignment with no links.
 * <p>
 * All lexical indices in the serialized form and in the <code>int[]</code>
 * returned by this class are 1-based. Methods that return
 * <code>ArrayList&lt;Integer&gt;</code> of target indices return 0-based
 * indices, as documented on each method.
 * <p>
 * The alignment tables never change after construction. The raw alignment
 * groups are parsed and cached lazily without synchronization, and arrays
 * handed out by the getters are the internal ones, so callers must not modify
 * them.
 */
public class Alignment {

	/** For each source token (0-based slot), its sorted, unique 1-based target indices. */
	private final int[][] sourceToTarget;

	/** For each target token (0-based slot), its sorted, unique 1-based source indices. */
	private final int[][] targetToSource;

	/** Lazily parsed alignment groups; null until first requested. */
	private RawAlignment[] rawAlignments;

	/** Lazily filled per-source-token cache of the groups containing that token. */
	private ArrayList<RawAlignment>[] rawAlignmentsBySource;

	/** Lazily filled per-target-token cache of the groups containing that token. */
	private ArrayList<RawAlignment>[] rawAlignmentsByTarget;

	/** The trimmed serialized form this alignment was built from. */
	private final String serialized;

	/**
	 * Builds the alignment tables from a serialized alignment.
	 * 
	 * @param serialized
	 *            the serialized alignment; surrounding whitespace is ignored
	 *            and the empty string or <code>()</code> means "no links"
	 * @param nSourceTokens
	 *            number of tokens in the source sentence
	 * @param nTargetTokens
	 *            number of tokens in the target sentence
	 * @throws ParseException
	 *             if the string cannot be parsed, or if a link refers to a
	 *             token index outside <code>1..nSourceTokens</code> /
	 *             <code>1..nTargetTokens</code>
	 */
	public Alignment(String serialized, int nSourceTokens, int nTargetTokens) throws ParseException {

		// Every later step (emptiness check, parsing, toString) works on the
		// trimmed form so that they cannot disagree with each other.
		this.serialized = serialized.trim();

		ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>> groups;
		try {
			groups = parseAlignments(this.serialized);
		} catch (ParseException e) {
			ParseException wrapped =
					new ParseException("ParseException for alignment: " + this.serialized, -1);
			// ParseException has no cause-taking constructor
			wrapped.initCause(e);
			throw wrapped;
		}

		// The tables are always sized by the sentence lengths, even when there
		// are no links, so that the raw-alignment caches and transpose() see
		// the real number of tokens.
		ArrayList<Integer>[] s2t = createArray(nSourceTokens);
		ArrayList<Integer>[] t2s = createArray(nTargetTokens);

		for (final Pair<ArrayList<Integer>, ArrayList<Integer>> group : groups) {
			for (final int sourceIndex : group.first) {
				for (final int targetIndex : group.second) {
					checkTokenIndex(sourceIndex, nSourceTokens, "source");
					checkTokenIndex(targetIndex, nTargetTokens, "target");
					s2t[sourceIndex - 1].add(targetIndex);
					t2s[targetIndex - 1].add(sourceIndex);
				}
			}
		}

		this.sourceToTarget = toSortedUniqueArrays(s2t);
		this.targetToSource = toSortedUniqueArrays(t2s);
	}

	/**
	 * Rejects a 1-based token index that does not exist in a sentence of the
	 * given length.
	 */
	private void checkTokenIndex(int index, int nTokens, String side) throws ParseException {
		if (index < 1 || index > nTokens) {
			throw new ParseException("Alignment " + serialized + " refers to " + side
					+ " token " + index + " but the " + side + " sentence has " + nTokens
					+ " token(s)", -1);
		}
	}

	/**
	 * Sorts each list, removes duplicate entries and converts it to an array.
	 */
	private static int[][] toSortedUniqueArrays(ArrayList<Integer>[] lists) {
		// TODO: This is painfully inefficient...
		int[][] result = new int[lists.length][];
		for (int i = 0; i < lists.length; i++) {
			Collections.sort(lists[i]);
			ArrayUtils.pruneNonUniqueFromSortedVector(lists[i]);
			result[i] = ArrayUtils.toArray(lists[i]);
		}
		return result;
	}

	/**
	 * Creates an array of <code>size</code> empty lists.
	 */
	@SuppressWarnings("unchecked")
	private static ArrayList<Integer>[] createArray(int size) {
		// generic array creation is not allowed, hence the unchecked cast
		ArrayList<Integer>[] arr = (ArrayList<Integer>[]) new ArrayList[size];
		for (int i = 0; i < size; i++) {
			arr[i] = new ArrayList<Integer>();
		}
		return arr;
	}

	/**
	 * @param nSourceLexicalIndex
	 *            1-based index of a source token
	 * @return the sorted, unique 1-based indices of the target tokens aligned
	 *         to the given source token; empty if the token is unaligned or
	 *         the index lies beyond the end of the source sentence. The
	 *         returned array must not be modified.
	 * @throws CorpusException
	 *             if the index is zero or negative
	 */
	public int[] getTargetIndices(int nSourceLexicalIndex) throws CorpusException {
		if (nSourceLexicalIndex <= 0) {
			throw new CorpusException("Alignment Index out of bounds: " + nSourceLexicalIndex);
		}

		if (nSourceLexicalIndex > sourceToTarget.length) {
			return new int[0];
		}
		return sourceToTarget[nSourceLexicalIndex - 1];
	}

	/**
	 * @param nTargetLexicalIndex
	 *            1-based index of a target token
	 * @return the sorted, unique 1-based indices of the source tokens aligned
	 *         to the given target token; empty if the token is unaligned or
	 *         the index lies beyond the end of the target sentence. The
	 *         returned array must not be modified.
	 * @throws CorpusException
	 *             if the index is zero or negative
	 */
	public int[] getSourceIndices(int nTargetLexicalIndex) throws CorpusException {
		if (nTargetLexicalIndex <= 0) {
			throw new CorpusException("Alignment Index out of bounds: " + nTargetLexicalIndex);
		}

		if (nTargetLexicalIndex > targetToSource.length) {
			return new int[0];
		}
		return targetToSource[nTargetLexicalIndex - 1];
	}

	/**
	 * Collects the target tokens aligned to any terminal under a source node.
	 * 
	 * @param sourceConstituentNode
	 *            a source node; may be a phrase (e.g. an NP node) covering
	 *            several terminals
	 * @return A list of zero-based target indices, in ascending order and
	 *         without duplicates
	 * @throws CorpusException
	 *             if one of the node's terminal indices is zero or negative
	 */
	public ArrayList<Integer> getTargetIndices(TreeNode sourceConstituentNode)
			throws CorpusException {

		// we need a list here in case we are asked about a phrase
		ArrayList<Integer> sourceLexicalIndices = sourceConstituentNode.getTerminalIndices();

		// one flag per target token; using a bitmap yields a sorted,
		// duplicate-free result for free
		boolean[] targetBitmap = new boolean[targetToSource.length];

		for (final int sourceLexicalIndex : sourceLexicalIndices) {

			assert sourceLexicalIndex != -1 : "sourceLexicalIndex == -1";

			for (final int targetIndex : getTargetIndices(sourceLexicalIndex)) {
				targetBitmap[targetIndex - 1] = true; // stored indices are 1-based
			}
		}

		ArrayList<Integer> indexList = new ArrayList<Integer>();
		for (int i = 0; i < targetBitmap.length; i++) {
			if (targetBitmap[i]) {
				indexList.add(i);
			}
		}

		return indexList;
	}

	/**
	 * Creates one orphan target node for every target token aligned to the
	 * given source node.
	 * 
	 * @param fSentence
	 *            the target tokens, indexed from zero
	 * @param sourceConstituentNode
	 *            the source node whose aligned target tokens are wanted
	 * @return the new target nodes in ascending target order; each carries the
	 *         first value of the source node followed by its target token
	 * @throws CorpusException
	 *             if one of the node's terminal indices is zero or negative
	 */
	public ArrayList<TreeNode> getTargetNodes(String[] fSentence, TreeNode sourceConstituentNode)
			throws CorpusException {

		ArrayList<Integer> targetIndices = getTargetIndices(sourceConstituentNode);

		// give all corresponding target nodes that same POS as this node
		// TODO: This is a dummy POS value only
		String sourceValue0 = sourceConstituentNode.getValues().get(0);

		ArrayList<TreeNode> targetNodes = new ArrayList<TreeNode>();

		for (final int targetIndex : targetIndices) {
			TreeNode targetNode =
					TreeNode.createOrphanNode(targetIndex, SmartTree.TARGET_C_STRUCT_LABEL);
			targetNode.addValue(sourceValue0);
			targetNode.addValue(fSentence[targetIndex]); // 0-based
			targetNodes.add(targetNode);
		}

		return targetNodes;
	}

	/**
	 * Looks up the target tokens aligned to any terminal under a source node.
	 * 
	 * @param pair
	 *            the sentence pair supplying the normalized target tokens
	 * @param sourceConstituentNode
	 *            the source node whose aligned target tokens are wanted
	 * @return the aligned target tokens in ascending target order, each target
	 *         position appearing at most once
	 * @throws CorpusException
	 *             if the target indices cannot be resolved for this node
	 */
	public ArrayList<String> getTargetLexicons(SentencePair pair, TreeNode sourceConstituentNode)
			throws CorpusException {

		ArrayList<String> lexicons = new ArrayList<String>();
		String[] fSentence = pair.getNormalizedTargetTokens();

		try {
			// getTargetIndices(TreeNode) already converts the stored 1-based
			// indices into zero-based positions and removes duplicates
			for (final int targetIndex : getTargetIndices(sourceConstituentNode)) {
				lexicons.add(fSentence[targetIndex]);
			}
		} catch (CorpusException e) {
			throw new CorpusException(
					"Could not resolve target lexicons for sourceConstituentNode "
							+ sourceConstituentNode.getNodeLabel() + " given sentence "
							+ pair.getSourceConstituentStructure() + " and alignment "
							+ pair.getDisplayAlignment(), e);
		}

		return lexicons;
	}

	/**
	 * Looks up the source tokens aligned to the terminals under a target node.
	 * 
	 * @param pair
	 *            the sentence pair supplying the normalized source tokens
	 * @param targetConstituentNode
	 *            the target node whose aligned source tokens are wanted
	 * @return the aligned source tokens, grouped by target terminal in the
	 *         order the node reports its terminals; a source token aligned to
	 *         several of those terminals appears once per terminal
	 * @throws CorpusException
	 *             if one of the node's terminal indices is zero or negative
	 */
	public ArrayList<String> getSourceLexicons(SentencePair pair, TreeNode targetConstituentNode)
			throws CorpusException {

		// we need a list here in case we are asked about a phrase
		ArrayList<Integer> targetLexicalIndices = targetConstituentNode.getTerminalIndices();
		ArrayList<String> lexicons = new ArrayList<String>();
		String[] eSentence = pair.getNormalizedSourceTokens();

		try {
			for (final int targetLexicalIndex : targetLexicalIndices) {
				for (final int sourceIndex : getSourceIndices(targetLexicalIndex)) {
					// stored indices are 1-based, the token array is 0-based
					lexicons.add(eSentence[sourceIndex - 1]);
				}
			}
		} catch (CorpusException e) {
			throw new CorpusException(
					"Could not resolve source lexicons for targetConstituentNode "
							+ targetConstituentNode.getNodeLabel() + " given SOURCE sentence "
							+ pair.getSourceConstituentStructure() + " and alignment "
							+ pair.getDisplayAlignment(), e);
		}

		return lexicons;
	}

	/**
	 * @return the trimmed serialized form this alignment was created from
	 */
	@Override
	public String toString() {
		return serialized;
	}

	/**
	 * Gathers the raw alignment groups that mention any terminal under the
	 * given source node.
	 * 
	 * @param sourceNode
	 *            a source node whose terminal indices are 1-based
	 * @return the matching groups, concatenated terminal by terminal; a group
	 *         that contains several of the node's terminals appears once per
	 *         terminal
	 */
	public ArrayList<RawAlignment> getRawAlignmentsForSource(TreeNode sourceNode) {

		// lazy instantiation / caching
		if (rawAlignmentsBySource == null) {
			rawAlignmentsBySource = createRawAlignmentCache(sourceToTarget.length);
		}
		return collectRawAlignments(sourceNode, rawAlignmentsBySource, true);
	}

	/**
	 * Gathers the raw alignment groups that mention any terminal under the
	 * given target node.
	 * 
	 * @param targetNode
	 *            a target node whose terminal indices are 1-based
	 * @return the matching groups, concatenated terminal by terminal; a group
	 *         that contains several of the node's terminals appears once per
	 *         terminal
	 */
	public ArrayList<RawAlignment> getRawAlignmentsForTarget(TreeNode targetNode) {

		// lazy instantiation / caching
		if (rawAlignmentsByTarget == null) {
			rawAlignmentsByTarget = createRawAlignmentCache(targetToSource.length);
		}
		return collectRawAlignments(targetNode, rawAlignmentsByTarget, false);
	}

	/**
	 * Creates an empty (all null) per-token cache with one slot per token.
	 */
	@SuppressWarnings("unchecked")
	private static ArrayList<RawAlignment>[] createRawAlignmentCache(int size) {
		// generic array creation is not allowed, hence the unchecked cast
		return (ArrayList<RawAlignment>[]) new ArrayList[size];
	}

	/**
	 * Shared implementation of the two per-node raw alignment lookups: fills
	 * the cache slot of every terminal under the node on first use and
	 * concatenates the cached lists.
	 * 
	 * @param node
	 *            the node whose (1-based) terminals are looked up
	 * @param cache
	 *            the per-token cache belonging to the node's side
	 * @param sourceSide
	 *            true to match against the source terminals of each group,
	 *            false to match against the target terminals
	 */
	private ArrayList<RawAlignment> collectRawAlignments(TreeNode node,
			ArrayList<RawAlignment>[] cache, boolean sourceSide) {

		ArrayList<RawAlignment> alignments = new ArrayList<RawAlignment>();

		// terminalIndices are 1-based
		for (final int terminalIndex : node.getTerminalIndices()) {

			ArrayList<RawAlignment> cached = cache[terminalIndex - 1];

			// lazy instantiation / caching
			if (cached == null) {
				cached = new ArrayList<RawAlignment>();
				for (final RawAlignment alignmentGroup : getRawAlignments()) {
					final boolean containsTerminal;
					if (sourceSide) {
						containsTerminal =
								ArrayUtils.unsortedArrayContains(alignmentGroup.sourceTerminals,
										terminalIndex);
					} else {
						containsTerminal =
								ArrayUtils.unsortedArrayContains(alignmentGroup.targetTerminals,
										terminalIndex);
					}
					if (containsTerminal) {
						cached.add(alignmentGroup);
					}
				}
				cache[terminalIndex - 1] = cached;
			}

			// gathering alignments for the specified node
			alignments.addAll(cached);
		}

		return alignments;
	}

	/**
	 * Returns the alignment groups exactly as they appear in the serialized
	 * form, parsing them on first use.
	 * 
	 * @return one entry per serialized group, in serialized order; this is the
	 *         internal cached array and must not be modified
	 * @throws RuntimeException
	 *             if the serialized form unexpectedly fails to parse
	 */
	public RawAlignment[] getRawAlignments() {
		if (rawAlignments == null) {
			try {
				ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>> groups =
						parseAlignments(serialized);

				RawAlignment[] parsed = new RawAlignment[groups.size()];
				for (int i = 0; i < parsed.length; i++) {
					RawAlignment group = new RawAlignment();
					group.sourceTerminals = ArrayUtils.toArray(groups.get(i).first);
					group.targetTerminals = ArrayUtils.toArray(groups.get(i).second);
					parsed[i] = group;
				}

				// publish only once the array is completely filled in
				this.rawAlignments = parsed;
			} catch (ParseException e) {
				throw new RuntimeException("Alignments should have already been parsed.", e);
			}
		}
		return rawAlignments;
	}

	/**
	 * Parses a serialized alignment into its groups without interpreting or
	 * validating the indices.
	 * 
	 * @param alignment
	 *            the serialized alignment; surrounding whitespace is ignored
	 *            and the empty string or <code>()</code> yields no groups
	 * @return one pair per group, in serialized order, holding the source
	 *         indices (first) and the target indices (second) as written
	 * @throws ParseException
	 *             if the string cannot be parsed
	 */
	public static ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>> parseAlignments(
			String alignment) throws ParseException {

		ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>> result =
				new ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>>();

		String trimmed = alignment.trim();
		if (trimmed.equals("") || trimmed.equals("()")) {
			return result;
		}

		String strList = StringUtils.substringBetweenMatching(trimmed, '(', ')');

		// tokenize the list on the commas OUTSIDE the parens
		ArrayList<String> nodes =
				StringUtils.tokenizeQuotedValues(strList, ",", "(", ")", false, false,
						Integer.MAX_VALUE);

		for (final String strPair : nodes) {

			// strPair has the form 1,2 (note that parentheses have
			// already been stripped)

			String strSourceIndices = StringUtils.substringBefore(strPair, ",");
			String strTargetIndices = StringUtils.substringAfter(strPair, ",");
			ArrayList<Integer> sourceIndices =
					ArrayUtils.toArrayList(StringUtils.toIntArray(StringUtils.tokenize(strSourceIndices)));
			ArrayList<Integer> targetIndices =
					ArrayUtils.toArrayList(StringUtils.toIntArray(StringUtils.tokenize(strTargetIndices)));

			result.add(new Pair<ArrayList<Integer>, ArrayList<Integer>>(sourceIndices,
					targetIndices));
		}

		return result;
	}

	/**
	 * Creates the inverse alignment, in which the source and target sides of
	 * every group (and the two sentence lengths) are swapped.
	 * 
	 * @return a new alignment; this alignment is left untouched
	 * @throws RuntimeException
	 *             if the serialized form unexpectedly fails to parse
	 */
	public Alignment transpose() {

		try {
			ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>> groups =
					parseAlignments(serialized);
			ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>> invertedGroups =
					new ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>>(groups.size());

			for (final Pair<ArrayList<Integer>, ArrayList<Integer>> group : groups) {
				invertedGroups.add(new Pair<ArrayList<Integer>, ArrayList<Integer>>(
						group.second, group.first));
			}

			return new Alignment(serialize(invertedGroups), this.targetToSource.length,
					this.sourceToTarget.length);
		} catch (ParseException e) {
			throw new RuntimeException("Parse Exception should have been caught earlier.", e);
		}
	}

	/**
	 * Writes alignment groups in the serialized form accepted by the
	 * constructor and by {@link #parseAlignments(String)}.
	 * 
	 * @param alignment
	 *            the groups to write; first is the source side, second the
	 *            target side
	 * @return the serialized alignment, or the empty string if there are no
	 *         groups
	 */
	public static String serialize(ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>> alignment) {

		if (alignment.size() == 0) {
			return "";
		}

		StringBuilder builder = new StringBuilder("(");
		for (final Pair<ArrayList<Integer>, ArrayList<Integer>> a : alignment) {
			builder.append('(').append(StringUtils.untokenize(a.first)).append(',').append(
					StringUtils.untokenize(a.second)).append("),");
		}
		// replace the trailing group separator with the closing parenthesis
		builder.setLength(builder.length() - 1);
		builder.append(')');
		return builder.toString();
	}
}
