package edu.cmu.cs.lti.letras.corpus;

import info.jonclark.util.ArrayUtils;
import info.jonclark.util.StringUtils;

import java.text.ParseException;
import java.util.ArrayList;
import java.util.Collections;

import edu.cmu.cs.lti.letras.trees.TreeNode;

/**
 * An immutable word alignment between a source sentence and a target sentence,
 * parsed from a serialized string.
 * <p>
 * For every source token index the alignment stores the sorted target token
 * indices linked to it, and for every target token index the sorted source
 * token indices linked to it. Index accessors return copies, so callers cannot
 * modify the alignment after construction.
 */
public class Alignment {

	/** sourceToTarget[i] holds the sorted target indices linked to source token i. */
	private final int[][] sourceToTarget;
	/** targetToSource[j] holds the sorted source indices linked to target token j. */
	private final int[][] targetToSource;
	/** The serialized form this alignment was parsed from; returned by toString(). */
	private final String serialized;

	/**
	 * Parses an alignment from its serialized form.
	 * <p>
	 * The outermost parenthesized list is split on commas into nodes; each node's
	 * parenthesized content is split at its first comma into a group of source
	 * indices and a group of target indices, and every source index in a node is
	 * linked to every target index in the same node. NOTE(review): the exact
	 * grammar depends on the StringUtils helpers used below — confirm against
	 * their implementation before relying on a particular format.
	 *
	 * @param serialized
	 *            the serialized alignment; kept verbatim for {@link #toString()}
	 * @param nSourceTokens
	 *            number of source tokens; valid source indices are
	 *            0..nSourceTokens-1
	 * @param nTargetTokens
	 *            number of target tokens; valid target indices are
	 *            0..nTargetTokens-1
	 * @throws ParseException
	 *             if a linked index lies outside the valid range for its side
	 */
	public Alignment(String serialized, int nSourceTokens, int nTargetTokens) throws ParseException {

		this.serialized = serialized;

		ArrayList<Integer>[] s2t = createArray(nSourceTokens);
		ArrayList<Integer>[] t2s = createArray(nTargetTokens);

		String strList = StringUtils.substringBetweenMatching(serialized, '(', ')');
		String[] nodes = StringUtils.tokenize(strList, ",");

		for (final String node : nodes) {
			String strPair = StringUtils.substringBetweenMatching(node, '(', ')');

			String strSourceIndices = StringUtils.substringBefore(strPair, ",");
			String strTargetIndices = StringUtils.substringAfter(strPair, ",");
			int[] sourceIndices = StringUtils.toIntArray(StringUtils.tokenize(strSourceIndices));
			int[] targetIndices = StringUtils.toIntArray(StringUtils.tokenize(strTargetIndices));

			// A node with an empty side produces no links, so its indices are
			// never used; skip it (and its validation) as before.
			if (sourceIndices.length == 0 || targetIndices.length == 0) {
				continue;
			}

			// Report bad indices as a parse error instead of letting the array
			// stores below fail with an uninformative index exception.
			checkIndices(sourceIndices, nSourceTokens, "source", serialized);
			checkIndices(targetIndices, nTargetTokens, "target", serialized);

			// Every source index in the node is linked to every target index.
			for (int sourceIndex : sourceIndices) {
				for (int targetIndex : targetIndices) {
					s2t[sourceIndex].add(targetIndex);
					t2s[targetIndex].add(sourceIndex);
				}
			}
		}

		this.sourceToTarget = toSortedUniqueArrays(s2t);
		this.targetToSource = toSortedUniqueArrays(t2s);
	}

	/**
	 * Creates an array of {@code size} empty lists.
	 */
	// Generic array creation is not allowed in Java; the cast is safe because
	// every element is initialized to an ArrayList<Integer> before returning.
	@SuppressWarnings("unchecked")
	private static ArrayList<Integer>[] createArray(int size) {
		ArrayList<Integer>[] arr = (ArrayList<Integer>[]) new ArrayList[size];
		for (int i = 0; i < size; i++)
			arr[i] = new ArrayList<Integer>();
		return arr;
	}

	/**
	 * Verifies that every index lies in the half-open range [0, limit).
	 *
	 * @param indices
	 *            the indices to check
	 * @param limit
	 *            the exclusive upper bound (token count of that side)
	 * @param side
	 *            "source" or "target", used only in the error message
	 * @param serialized
	 *            the full serialized alignment, used only in the error message
	 * @throws ParseException
	 *             if any index is negative or not less than {@code limit}
	 */
	private static void checkIndices(int[] indices, int limit, String side, String serialized)
			throws ParseException {
		for (int index : indices) {
			if (index < 0 || index >= limit) {
				throw new ParseException("Alignment " + side + " index " + index + " out of range [0, "
						+ limit + ") in: " + serialized, 0);
			}
		}
	}

	/**
	 * Converts each list of linked indices into a sorted int array.
	 * <p>
	 * Each list is sorted in place and then passed to
	 * ArrayUtils.pruneNonUniqueFromSortedVector, which presumably removes
	 * duplicate entries (per its name) — TODO confirm against ArrayUtils.
	 */
	// TODO: This is painfully inefficient...
	private static int[][] toSortedUniqueArrays(ArrayList<Integer>[] lists) {
		int[][] result = new int[lists.length][];
		for (int i = 0; i < lists.length; i++) {
			Collections.sort(lists[i]);
			ArrayUtils.pruneNonUniqueFromSortedVector(lists[i]);
			result[i] = ArrayUtils.toArray(lists[i]);
		}
		return result;
	}

	/**
	 * Returns a copy of {@code arr}, or null when {@code arr} is null, so that
	 * callers can never mutate the alignment's internal state.
	 */
	private static int[] copyOf(int[] arr) {
		return (arr == null) ? null : arr.clone();
	}

	/**
	 * Returns the target token indices linked to a source token.
	 *
	 * @param nSourceLexicalIndex
	 *            a source token index in 0..nSourceTokens-1
	 * @return a fresh, sorted copy of the linked target indices
	 */
	public int[] getTargetIndices(int nSourceLexicalIndex) {
		return copyOf(sourceToTarget[nSourceLexicalIndex]);
	}

	/**
	 * Returns the source token indices linked to a target token.
	 *
	 * @param nTargetLexicalIndex
	 *            a target token index in 0..nTargetTokens-1
	 * @return a fresh, sorted copy of the linked source indices
	 */
	public int[] getSourceIndices(int nTargetLexicalIndex) {
		return copyOf(targetToSource[nTargetLexicalIndex]);
	}

	/**
	 * Collects the target-sentence words aligned to the terminals under a
	 * source constituent.
	 *
	 * @param pair
	 *            the sentence pair supplying the target sentence tokens
	 * @param sourceConstituentNode
	 *            a source-side node; its terminal indices are used as source
	 *            token indices
	 * @return the aligned target words, in order of terminal index and then
	 *         target index; words may repeat if several terminals share links
	 */
	public ArrayList<String> getTargetLexicons(SentencePair pair, TreeNode sourceConstituentNode) {

		// Resolve the constituent to its terminal (lexical) source indices; a
		// phrase node (e.g. an NP) may cover several of them.
		ArrayList<Integer> sourceLexicalIndices = sourceConstituentNode.getTerminalIndices();
		ArrayList<String> lexicons = new ArrayList<String>();
		String[] fSentence = pair.getTargetSentence();

		for (int sourceLexicalIndex : sourceLexicalIndices) {
			for (int targetIndex : getTargetIndices(sourceLexicalIndex)) {
				lexicons.add(fSentence[targetIndex]);
			}
		}

		return lexicons;
	}

	/**
	 * Collects the source-sentence words aligned to the terminals under a
	 * target constituent.
	 *
	 * @param pair
	 *            the sentence pair supplying the source sentence tokens
	 * @param targetConstituentNode
	 *            a target-side node; its terminal indices are used as target
	 *            token indices
	 * @return the aligned source words, in order of terminal index and then
	 *         source index; words may repeat if several terminals share links
	 */
	public ArrayList<String> getSourceLexicons(SentencePair pair, TreeNode targetConstituentNode) {

		// Resolve the constituent to its terminal (lexical) target indices; a
		// phrase node (e.g. an NP) may cover several of them.
		ArrayList<Integer> targetLexicalIndices = targetConstituentNode.getTerminalIndices();
		ArrayList<String> lexicons = new ArrayList<String>();
		String[] eSentence = pair.getSourceSentence();

		for (int targetLexicalIndex : targetLexicalIndices) {
			for (int sourceIndex : getSourceIndices(targetLexicalIndex)) {
				lexicons.add(eSentence[sourceIndex]);
			}
		}

		return lexicons;
	}

	/**
	 * Returns the serialized string this alignment was constructed from.
	 */
	@Override
	public String toString() {
		return serialized;
	}
}
