/**
 * The AVENUE Project
 * Language Technologies Institute
 * School of Computer Science
 * (c) 2007 Carnegie Mellon University
 * 
 * Corpus Navigator
 * Written by Jonathan Clark
 */
package edu.cmu.cs.lti.avenue.navigation.featuredetection.deductive;

import static edu.cmu.cs.lti.avenue.navigation.featuredetection.deductive.RuleConstants.COUNT_DIFF;
import info.jonclark.util.StringUtils;

import java.text.ParseException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;

import edu.cmu.cs.lti.avenue.corpus.CorpusException;
import edu.cmu.cs.lti.avenue.trees.smart.TreeNode;

/**
 * Evaluates count functions (currently only {@code count-diff}) that appear as
 * conditions in deductive feature-detection rules.
 * <p>
 * Each argument of a count function is evaluated to a list of candidate
 * {@link LexicalResult}s via {@code ConstituentFunctionEvaluator}. Every
 * combination of one candidate per argument that
 * {@code OverlapEvaluator.isProperPair(rule, true, ...)} accepts is turned into a
 * {@link NumericResult}. That result's count value is the number of mutually
 * distinct lexical results in the combination, minus one.
 * <p>
 * NOTE(review): the {@link #xx} progress counter is static and unsynchronized,
 * so this class should not be used from several threads at once.
 */
public class CountFunctionEvaluator {

	/**
	 * Bundles the state shared by every level of the count-diff enumeration.
	 * {@code proposedResult} is mutated in place as the recursion fills in one
	 * candidate per argument slot.
	 */
	private static class CountJob {
		public final Rule rule;
		public final TreeNode condition;
		public final ArrayList<TreeNode> args;
		/** Candidate results for each argument, indexed by argument position. */
		public final ArrayList<LexicalResult>[] resultsForArgs;
		/** Sink that receives one NumericResult per accepted combination. */
		public final List<NumericResult> diffResults;
		/** Scratch array holding the combination currently being built. */
		public final LexicalResult[] proposedResult;
		/**
		 * Index of the first argument slot this job enumerates. The synchronous
		 * path sets it to 0 and does not read it.
		 */
		public final int startArg;

		public CountJob(Rule rule, TreeNode condition, ArrayList<TreeNode> args,
				ArrayList<LexicalResult>[] resultsForArgs, List<NumericResult> diffResults,
				LexicalResult[] proposedResult, int startArg) {
			this.rule = rule;
			this.condition = condition;
			this.args = args;
			this.resultsForArgs = resultsForArgs;
			this.diffResults = diffResults;
			this.proposedResult = proposedResult;
			this.startArg = startArg;
		}
	}

	/**
	 * Evaluates a count function such as {@code count-diff}.
	 * 
	 * @param condition
	 *            the function node. Its first value is the function name and each
	 *            child is an argument expression.
	 * @param rule
	 *            the rule being evaluated. It is passed through to argument
	 *            evaluation and to {@code OverlapEvaluator.isProperPair}.
	 * @return one {@link NumericResult} per combination of argument results
	 *         accepted by {@code OverlapEvaluator.isProperPair}
	 * @throws ParseException
	 *             if the function has no arguments or its name is not a
	 *             recognized count function
	 * @throws CorpusException
	 *             if evaluating an argument or comparing lexical entries fails
	 */
	protected static ArrayList<NumericResult> evaluateCount(TreeNode condition, Rule rule)
			throws ParseException, CorpusException {

		assert condition.getValues().size() > 0 : "No function name for condition: " + condition;
		String functionName = condition.getValues().get(0);

		if (condition.getChildren().size() < 1) {
			throw new ParseException("function " + functionName + " takes at least one argument: "
					+ condition, -1);
		}

		// Use equals() rather than == so the match does not depend on the
		// function name string having been interned.
		if (!COUNT_DIFF.equals(functionName)) {
			// TODO: Give trace for file and line number
			throw new ParseException("Unrecognized count function name: " + functionName, -1);
		}

		// 1) evaluate each argument to its list of candidate lexical results
		ArrayList<TreeNode> args = condition.getChildren();
		ArrayList<LexicalResult>[] resultsForArgs = createArray(args.size());
		for (int k = 0; k < resultsForArgs.length; k++) {
			resultsForArgs[k] = ConstituentFunctionEvaluator.evaluateConstituent(args.get(k), rule);
		}

		// Optional pruning passes (pruneDuplicates, pruneByIndexCoverage and
		// OverlapEvaluator.pruneResultLattice) are currently disabled, so the
		// enumeration below visits the full cross product of candidates.

		// 2) enumerate every combination and count the differences in each
		ArrayList<NumericResult> results = new ArrayList<NumericResult>();
		evaluateCountDiffSynchronously(rule, condition, args, resultsForArgs, results);
		return results;
	}

	/**
	 * Optional pruning pass (currently disabled). It builds a reverse index from
	 * each result's untokenized lexical string to the set of arguments it occurs
	 * in. It then removes a result when another string with exactly the same
	 * argument set has already been kept. Pruning statistics and the expected
	 * number of search states are printed to standard output.
	 */
	private static void pruneByIndexCoverage(ArrayList<LexicalResult>[] resultsForArgs)
			throws CorpusException {

		long nSearchStates = 1;

		// 3) Make a reverse index of words to the (ascending) args they occur in
		HashMap<String, ArrayList<Integer>> argsForKey = new HashMap<String, ArrayList<Integer>>();
		for (int k = 0; k < resultsForArgs.length; k++) {
			for (int i = 0; i < resultsForArgs[k].size(); i++) {

				String key =
						StringUtils.untokenize(resultsForArgs[k].get(i).getCurrentResultLexicons());
				ArrayList<Integer> indices = argsForKey.get(key);
				if (indices == null) {
					indices = new ArrayList<Integer>();
					argsForKey.put(key, indices);
				}

				// Record each arg only once, even if the same string occurs several
				// times within it. Otherwise two strings covering the same args could
				// get different signatures. k only increases, so a repeat can only be
				// the last element.
				if (indices.isEmpty() || indices.get(indices.size() - 1).intValue() != k) {
					indices.add(k);
				}
			}
		}

		// 4) Prune words that occur in exactly the same args
		HashSet<String> argSequencesAlreadyCovered = new HashSet<String>();
		HashSet<String> wordsToKeep = new HashSet<String>();
		for (int k = 0; k < resultsForArgs.length; k++) {

			int nPruned = 0;
			final int nOrigSize = resultsForArgs[k].size();
			// iterate backwards so removal does not shift unvisited indices
			for (int i = resultsForArgs[k].size() - 1; i >= 0; i--) {

				String key =
						StringUtils.untokenize(resultsForArgs[k].get(i).getCurrentResultLexicons());
				ArrayList<Integer> indices = argsForKey.get(key);
				String argSequence = StringUtils.untokenize(indices);

				if (argSequencesAlreadyCovered.contains(argSequence)
						&& !wordsToKeep.contains(key)) {

					resultsForArgs[k].remove(i);
					nPruned++;
				} else {
					argSequencesAlreadyCovered.add(argSequence);
					wordsToKeep.add(key);
				}
			}
			System.out.println(nPruned + " of " + nOrigSize + " reverse index pruned.");
			nSearchStates *= resultsForArgs[k].size();
		}

		System.out.println("Expect " + nSearchStates + " search states.");
	}

	/**
	 * Optional pruning pass (currently disabled). Within each argument's list it
	 * removes every result for which {@link #isSame} is true against an earlier
	 * result in the same list. The number pruned per argument is printed to
	 * standard output.
	 */
	private static void pruneDuplicates(ArrayList<LexicalResult>[] resultsForArgs)
			throws CorpusException {

		for (int k = 0; k < resultsForArgs.length; k++) {
			ArrayList<LexicalResult> results = resultsForArgs[k];

			// Java initializes boolean arrays to false, so no explicit fill is needed.
			boolean[] killList = new boolean[results.size()];
			for (int i = 0; i < results.size(); i++) {
				for (int j = i + 1; j < results.size(); j++) {
					if (isSame(results.get(i), results.get(j))) {
						killList[j] = true;
					}
				}
			}

			// Remove from the back so indices of not-yet-removed entries stay valid.
			int nPruned = 0;
			for (int i = killList.length - 1; i >= 0; i--) {
				if (killList[i]) {
					results.remove(i);
					nPruned++;
				}
			}
			System.out.println(nPruned + " of " + killList.length + " dup pruned.");
		}
	}

	/**
	 * Number of calls made to evaluateCountDiffRecursively so far. It is used
	 * only for periodic progress output and is not synchronized.
	 */
	static long xx = 0;

	/**
	 * Enumerates all combinations of argument results on the calling thread.
	 * Each accepted combination's count-diff result is appended to
	 * {@code diffResults}.
	 */
	private static void evaluateCountDiffSynchronously(Rule rule, TreeNode condition,
			ArrayList<TreeNode> args, ArrayList<LexicalResult>[] resultsForArgs,
			List<NumericResult> diffResults) throws ParseException, CorpusException {
		LexicalResult[] proposedResult = new LexicalResult[resultsForArgs.length];
		CountJob job =
				new CountJob(rule, condition, args, resultsForArgs, diffResults, proposedResult, 0);
		evaluateCountDiffRecursively(job, 0);
	}

	/**
	 * Depth-first enumeration of the cross product of {@code job.resultsForArgs}.
	 * It fills {@code job.proposedResult[nArg]} with each candidate in turn and
	 * recurses into the next slot. Once every slot is filled, the combination is
	 * passed to {@code OverlapEvaluator.isProperPair}; if accepted, its
	 * {@link #countDiffs} result is appended to {@code job.diffResults}. The slot
	 * is reset to null before returning.
	 * 
	 * @param job
	 *            shared enumeration state
	 * @param nArg
	 *            index of the argument slot to fill at this level
	 */
	private static void evaluateCountDiffRecursively(CountJob job, int nArg) throws ParseException,
			CorpusException {

		xx++;
		if (xx % 10000 == 0) {
			System.out.println("Evaluated " + xx + " calls to evaluateCountDiffRecursively");
		}

		if (nArg == job.resultsForArgs.length) {
			if (OverlapEvaluator.isProperPair(job.rule, true, job.proposedResult)) {
				// this is a keeper now count the differences
				job.diffResults.add(countDiffs(job.proposedResult, job.condition));
			}
		} else {
			for (int i = 0; i < job.resultsForArgs[nArg].size(); i++) {
				job.proposedResult[nArg] = job.resultsForArgs[nArg].get(i);
				assert job.proposedResult[nArg] != null : "null element in proposedResult";
				evaluateCountDiffRecursively(job, nArg + 1);
			}
			job.proposedResult[nArg] = null;
		}
	}

	/**
	 * Builds the count-diff result for one complete combination of argument
	 * results. Entries for which {@link #isSame} is true against an earlier kept
	 * entry are dropped. The remaining entries become the result's operands, and
	 * the count value is (number of remaining entries - 1).
	 * 
	 * @param proposedResult
	 *            one result per argument; not modified
	 * @param condition
	 *            the count-diff condition node, attached to the result
	 */
	private static NumericResult countDiffs(LexicalResult[] proposedResult, TreeNode condition)
			throws CorpusException {

		// work on a copy so the caller's in-progress combination is left intact
		LexicalResult[] safeCopy = new LexicalResult[proposedResult.length];
		System.arraycopy(proposedResult, 0, safeCopy, 0, proposedResult.length);

		// 1) prune duplicates (those that evaluate to "true" for "same")
		// BETWEEN args
		for (int i = 0; i < safeCopy.length; i++) {
			for (int j = i + 1; j < safeCopy.length; j++) {
				if (safeCopy[i] != null && safeCopy[j] != null && isSame(safeCopy[i], safeCopy[j])) {
					safeCopy[j] = null;
				}
			}
		}

		// 2) Count remaining entries; starting at -1 makes a single distinct
		// entry count as zero differences.
		int nDiff = -1;
		// NOTE(review): the meaning of the constructor argument 10 is not visible
		// here; confirm against NumericResult.
		NumericResult result = new NumericResult(10);
		result.addComment("count-diff");
		result.setConditionNode(condition);
		for (int i = 0; i < safeCopy.length; i++) {
			if (safeCopy[i] != null) {

				// TODO: perhaps we want to add the whole proposed result
				// instead of just the differences we found?
				result.addOperand(safeCopy[i]);
				nDiff++;
			}
		}
		assert nDiff > -1 : "Negative nDiff";
		result.setCountValue(nDiff);
		return result;
	}

	/**
	 * Returns true when both results contain the same number of lexicon nodes
	 * and each corresponding pair of nodes has the same lexical entry, i.e. the
	 * same second value.
	 * 
	 * @throws CorpusException
	 *             if a lexicon node does not have exactly two values (POS tag and
	 *             lexical entry)
	 */
	private static boolean isSame(LexicalResult resultA, LexicalResult resultB)
			throws CorpusException {

		assert resultA != null : "null resultA";
		assert resultB != null : "null resultB";

		ArrayList<TreeNode> lexiconsA = resultA.getCurrentResult();
		ArrayList<TreeNode> lexiconsB = resultB.getCurrentResult();
		if (lexiconsA.size() != lexiconsB.size()) {
			return false;
		} else {
			for (int i = 0; i < lexiconsA.size(); i++) {
				ArrayList<String> valuesA = lexiconsA.get(i).getValues();
				ArrayList<String> valuesB = lexiconsB.get(i).getValues();

				if (valuesA.size() != 2)
					throw new CorpusException("Expected 2 values: POS tag and lexical entry");
				if (valuesB.size() != 2)
					throw new CorpusException("Expected 2 values: POS tag and lexical entry");

				String lexA = valuesA.get(1);
				String lexB = valuesB.get(1);

				// Identity comparison is deliberate on this hot path: lexical
				// entries are expected to be interned, as the asserts below check.
				assert lexA == lexA.intern() : "lexA not interned: " + lexA;
				assert lexB == lexB.intern() : "lexB not interned: " + lexB;

				if (lexA != lexB) {
					return false;
				}
			}

			return true;
		}
	}

	/**
	 * Hack to get around Java's generic array deficiency.
	 */
	@SuppressWarnings("unchecked")
	private static ArrayList<LexicalResult>[] createArray(int nSize) {
		return (ArrayList<LexicalResult>[]) new ArrayList[nSize];
	}

}
