package edu.cmu.cs.lti.avenue.corpus;

import info.jonclark.lang.Pair;
import info.jonclark.util.ArrayUtils;
import info.jonclark.util.StringUtils;

import java.text.ParseException;
import java.util.ArrayList;

/**
 * Retokenizes one side of a word-aligned sentence pair. Punctuation and
 * trailing English contractions are split off into separate tokens, and the
 * alignment (which uses ONE-BASED token indices) is rewritten so that it keeps
 * pointing at the right tokens.
 */
public class Retokenizer {

	// ' is intentionally left out for possessive ambiguity
	// colon left out for ambiguity within times
	// NOTE(review): this list also contains a space character (between ']' and
	// '{'); kept as-is, but confirm it is intended.
	private static final char[] PUNCTUATION_CHARS = ",.!?\"۔؟،()[] {}<>;".toCharArray();

	// Suffixes split off the end of a token, e.g. "isn't" -> "is" + "n't".
	// They are tried in this order and the first match wins.
	private static final String[] CONTRACTIONS = { "'s", "n't", "'ll", "'" };

	/**
	 * Splits off punctuation and trailing contractions from the specified
	 * tokens while updating the alignment so it still refers to the same words.
	 * Newly created punctuation tokens are left unaligned; word pieces inherit
	 * the alignment of the token they came from.
	 * 
	 * @param sourceTokens
	 *            the tokens of the side to retokenize
	 * @param alignment
	 *            the serialized alignment understood by
	 *            {@code Alignment.parseAlignments}; may be the empty string
	 * @param flip
	 *            If false, retokenizes the source side. If true, retokenizes
	 *            the target side.
	 * @return the new tokens paired with the re-serialized alignment (the empty
	 *         string when {@code alignment} was empty)
	 * @throws ParseException
	 *             presumably when {@code alignment} cannot be parsed by
	 *             {@code Alignment.parseAlignments}
	 */
	public Pair<String[], String> retokenize(String[] sourceTokens, String alignment, boolean flip)
			throws ParseException {

		// TODO: Make this class use the newer RawAlignment class and transpose() methods

		ArrayList<String> tokens = ArrayUtils.toArrayList(sourceTokens);
		ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>> align =
				Alignment.parseAlignments(alignment);

		// the helpers below only rewrite the first element of each pair, so
		// swap the sides when the target side is the one being retokenized
		if (flip)
			flip(align);

		// phase 1: retokenize punctuation
		retokenize(tokens, align);

		// phase 2: resegment
		resegment(tokens, align);

		// correct english spelling mistakes
		correct(tokens);

		// restore the original side order
		if (flip)
			flip(align);

		String[] t = tokens.toArray(new String[tokens.size()]);
		String a = alignment.isEmpty() ? "" : Alignment.serialize(align);
		return new Pair<String[], String>(t, a);
	}

	/**
	 * Phase 2: splits a trailing contraction (see CONTRACTIONS) off each token,
	 * e.g. "isn't" becomes "is" + "n't". Both halves inherit the original
	 * token's alignment.
	 */
	private void resegment(ArrayList<String> tokens,
			ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>> align) {
		for (int i = 0; i < tokens.size(); i++) {

			final String token = tokens.get(i);

			for (final String contraction : CONTRACTIONS) {
				if (token.equals(contraction)) {
					// already a bare contraction
					continue;
				}
				// NOTE(review): indexOf finds the FIRST occurrence, so a token that
				// contains the contraction earlier as well as at its end (e.g.
				// "o'clock'" with "'") is not split. endsWith would catch those but
				// would also start splitting tokens such as "''" -- confirm intent.
				final int n = token.indexOf(contraction);
				if (n == -1 || n != token.length() - contraction.length()) {
					continue;
				}

				// ends with the contraction: "isn't" -> "is" + "n't"
				tokens.set(i, token.substring(0, n));
				tokens.add(i + 1, contraction);
				// both pieces keep the original alignment
				appendStartingAt(align, i);

				// start over processing this (now shorter) token
				i--;
				break;
			}
		}
	}

	/**
	 * Phase 1: splits each character from PUNCTUATION_CHARS off into its own
	 * token. The word pieces keep the original token's alignment; the
	 * punctuation token itself is left unaligned.
	 */
	private void retokenize(ArrayList<String> tokens,
			ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>> align) {
		for (int i = 0; i < tokens.size(); i++) {

			final String token = tokens.get(i);

			if (token.length() == 1) {
				// a single character cannot be split any further
				continue;
			}

			for (final char c : PUNCTUATION_CHARS) {
				final int n = token.indexOf(c);
				if (n == -1) {
					continue;
				}

				final String punct = String.valueOf(c);
				if (n == 0) {
					// starts with: "(abc" -> "(" + "abc"
					tokens.set(i, punct);
					tokens.add(i + 1, token.substring(1));
					// the remainder (now at i + 1) takes over the old alignment
					incrementStartingAt(align, i);
				} else if (n < token.length() - 1) {
					// contains: "a.b" -> "a" + "." + "b"
					tokens.set(i, token.substring(0, n));
					tokens.add(i + 1, punct);
					tokens.add(i + 2, token.substring(n + 1));
					// First let the suffix inherit the alignment as if it directly
					// followed the prefix, then open an unaligned slot for the
					// punctuation between them. (Doing it the other way round leaves
					// nothing to copy and the suffix ends up unaligned.)
					appendStartingAt(align, i);
					incrementStartingAt(align, i + 1);
				} else {
					// ends with (n == length - 1): "abc." -> "abc" + "."
					tokens.set(i, token.substring(0, n));
					tokens.add(i + 1, punct);
					incrementStartingAt(align, i + 1);
				}

				// start over processing the token now at position i
				i--;
				break;
			}
		}
	}

	/**
	 * Swaps the two sides of every alignment pair in place.
	 */
	private static void flip(ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>> align) {
		for (final Pair<ArrayList<Integer>, ArrayList<Integer>> p : align) {
			ArrayList<Integer> temp = p.first;
			p.first = p.second;
			p.second = temp;
		}
	}

	/**
	 * Makes room for a new, unaligned token inserted at position n. Every
	 * first-side index at or after that position is shifted up by one.
	 * 
	 * n is the ZERO-BASED index to update while alignments is ONE-BASED
	 * 
	 * @param alignments
	 *            the alignment pairs; only {@code p.first} is modified
	 * @param n
	 *            zero-based position of the inserted token
	 */
	private static void incrementStartingAt(
			ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>> alignments, int n) {

		n++;

		for (final Pair<ArrayList<Integer>, ArrayList<Integer>> p : alignments) {
			ArrayList<Integer> sourceIndices = p.first;
			for (int i = 0; i < sourceIndices.size(); i++) {
				if (sourceIndices.get(i) >= n) {
					sourceIndices.set(i, sourceIndices.get(i) + 1);
				}
			}
		}
	}

	/**
	 * Accounts for the token at position n having been split into two
	 * adjacent tokens (n and n + 1). Every first-side index strictly after
	 * token n is shifted up by one. Wherever token n is referenced, a
	 * reference to the new token n + 1 is added right after it, so the new
	 * token inherits the same alignment.
	 * 
	 * n is the ZERO-BASED index to update while alignments is ONE-BASED
	 * 
	 * @param alignments
	 *            the alignment pairs; only {@code p.first} is modified
	 * @param n
	 *            zero-based position of the token that was split
	 */
	private static void appendStartingAt(
			ArrayList<Pair<ArrayList<Integer>, ArrayList<Integer>>> alignments, int n) {

		n++;

		for (final Pair<ArrayList<Integer>, ArrayList<Integer>> p : alignments) {
			ArrayList<Integer> sourceIndices = p.first;

			// first, increment elements strictly to the right of the split token
			for (int i = 0; i < sourceIndices.size(); i++) {
				if (sourceIndices.get(i) > n) {
					sourceIndices.set(i, sourceIndices.get(i) + 1);
				}
			}

			// then, on a second pass, add the new token next to every reference
			// to the split token (the inserted n + 1 is skipped on the next step)
			for (int i = 0; i < sourceIndices.size(); i++) {
				if (sourceIndices.get(i) == n) {
					sourceIndices.add(i + 1, n + 1);
				}
			}
		}
	}

	/**
	 * Replaces the "wo" left behind when "won't" is split into "wo" + "n't"
	 * with "will". Only a "wo" immediately followed by "n't" is changed, so
	 * an unrelated standalone "wo" token is left alone.
	 */
	private void correct(ArrayList<String> sourceTokens) {
		for (int i = 0; i + 1 < sourceTokens.size(); i++) {
			if (sourceTokens.get(i).equals("wo") && sourceTokens.get(i + 1).equals("n't")) {
				sourceTokens.set(i, "will");
			}
		}
	}

	public static void main(String[] args) throws Exception {
		String[] tokens = StringUtils.tokenize("It's face isn't, \"it is.\"");
		// String[] tokens = StringUtils.tokenize("won't mary have sung for
		// john?");
		String alignments = "((1 2 3, 1),(3,2 5),(4,3),(5,4))";
//		String alignments = "";

		System.out.println(StringUtils.untokenize(tokens));
		System.out.println(alignments);

		Pair<String[], String> p = new Retokenizer().retokenize(tokens, alignments, false);
		System.out.println(StringUtils.untokenize(p.first));
		System.out.println(p.second);
	}
}
