package edu.cmu.cs.lti.avenue.trees.smart;

import info.jonclark.util.ArrayUtils;
import info.jonclark.util.DebugUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import java.util.Vector;

import edu.cmu.cs.lti.avenue.trees.smart.SmartTree.LabelDisplay;

/**
 * A wrapper around a {@link SmartTree} that allows for certain nodes to be
 * designated as wildcards.
 * <p>
 * Two deltas are equal when each wildcard node of one can be paired with an
 * unused wildcard node of the other that has the same first value and occurs
 * in the same enclosing (ancestor) context. When exact minimal pairs are
 * requested, the remaining, non-wildcard parts of the two trees must match as
 * well; only the second value of a wildcard node is allowed to differ.
 * <p>
 * This is a time critical part of the software since it gets called millions of
 * times during the creation of minimal pairs.
 * <p>
 * Not thread-safe: {@link #toString(LabelDisplay)} temporarily mutates the
 * wrapped tree.
 */
public class SmartTreeDelta {
	/** The wrapped tree; sorted in place by the constructor. */
	private final SmartTree tree;
	/** Absolute tree indices of the wildcard nodes, kept sorted for binary search. */
	private final int[] wildcardNodes;
	/** Whether equality also requires the non-wildcard structure to match. */
	private final boolean useExactMinimalPairs;
	/** Randomness for the stochastic sanity assertion in {@link #equals(Object)}. */
	private static final Random rand = new Random();

	/**
	 * Wraps {@code tree} and records which of its nodes are wildcards.
	 *
	 * @param tree
	 *            the tree to wrap. It is sorted IN PLACE, so the caller's tree
	 *            is modified.
	 * @param wildcardNodes
	 *            An array containing the absoluteTreeLabeling indices of the
	 *            nodes whose values should always be considered matches. The
	 *            array is copied, so later changes by the caller have no effect.
	 * @param useExactMinimalPairs
	 *            if true, equality also requires the non-wildcard parts of the
	 *            two trees to match; otherwise only the wildcards are compared
	 */
	public SmartTreeDelta(SmartTree tree, int[] wildcardNodes, boolean useExactMinimalPairs) {

		this.useExactMinimalPairs = useExactMinimalPairs;

		this.tree = tree;
		// children must be in a canonical order so subtreeEquals() can walk
		// both trees in lockstep
		tree.sort();

		// defensive copy, sorted so membership tests can use binary search
		this.wildcardNodes = wildcardNodes.clone();
		Arrays.sort(this.wildcardNodes);

		if (DebugUtils.isAssertEnabled()) {
			final Vector<TreeNode> allNodes = tree.getAllNodes();
			for (final int n : this.wildcardNodes) {
				assert n >= 0 : "negative wildcard index: " + n;
				assert n < allNodes.size() : "wildcard index out of range: " + n;
			}
		}
	}

	/**
	 * Computes a hash code for use in hash-based collections.
	 * <p>
	 * In exact mode, every value of every non-wildcard node is hashed, along
	 * with the wildcard positions. Otherwise only the first value of each
	 * wildcard node is hashed, because that is all {@link #equals(Object)}
	 * compares in that mode. Wildcard nodes without values contribute nothing.
	 */
	@Override
	public int hashCode() {
		int hash = 0;

		if (useExactMinimalPairs) {
			for (final TreeNode node : tree.getAllNodes()) {
				// wildcard nodes are skipped: their values may legitimately
				// differ between equal deltas
				if (Arrays.binarySearch(wildcardNodes, node.getAbsoluteTreeIndex()) >= 0) {
					continue;
				}
				for (final String value : node.getValues()) {
					hash += value.hashCode();
				}
			}

			// mix in the wildcard positions (int overflow is harmless here)
			// NOTE(review): equals() pairs wildcards regardless of position, so
			// this assumes equal deltas have wildcards at the same indices
			for (final int n : wildcardNodes) {
				hash += n * n * n * n * n;
			}
		} else {
			// don't hash based on skeleton structure if we're not going to
			// compare against it
			for (final int n : wildcardNodes) {
				final ArrayList<String> values = tree.getByAbsoluteTreeIndex(n).getValues();
				if (!values.isEmpty()) {
					hash += values.get(0).hashCode();
				}
			}
		}

		return hash;
	}

	/**
	 * Checks whether a wildcard of one tree matches a wildcard of the other.
	 * They match when both carry the same first value and their ancestor
	 * contexts are equal.
	 *
	 * @param myWildcard
	 *            absolute tree index of the wildcard in {@code myDelta}
	 * @param otherWildcard
	 *            absolute tree index of the wildcard in {@code otherDelta}
	 * @param myDelta
	 *            delta owning {@code myWildcard}
	 * @param otherDelta
	 *            delta owning {@code otherWildcard}
	 * @return true if the wildcards match; false if either node has no values
	 */
	private static boolean wildcardsMatch(int myWildcard, int otherWildcard,
			SmartTreeDelta myDelta, SmartTreeDelta otherDelta) {

		final TreeNode myNode = myDelta.tree.getByAbsoluteTreeIndex(myWildcard);
		final TreeNode otherNode = otherDelta.tree.getByAbsoluteTreeIndex(otherWildcard);

		final ArrayList<String> myValues = myNode.getValues();
		final ArrayList<String> otherValues = otherNode.getValues();
		if (myValues.isEmpty() || otherValues.isEmpty()) {
			return false;
		}

		// reference comparison relies on the values being interned strings
		if (myValues.get(0) != otherValues.get(0)) {
			return false;
		}

		return parentContextEquals(myDelta, otherDelta, myNode.getParentNode(),
				otherNode.getParentNode());
	}

	/**
	 * Compares the value lists of two corresponding nodes.
	 * <p>
	 * If {@code myTree} is a wildcard of {@code myDelta}, its second value may
	 * differ, provided both nodes occur in the same ancestor context. All
	 * other values must be identical.
	 *
	 * @return true if the value lists match under the rules above
	 */
	private static boolean compareValues(SmartTreeDelta myDelta, SmartTreeDelta otherDelta,
			TreeNode myTree, TreeNode otherTree) {

		final ArrayList<String> myValues = myTree.getValues();
		final ArrayList<String> otherValues = otherTree.getValues();
		if (myValues.size() != otherValues.size()) {
			return false;
		}

		for (int i = 0; i < myValues.size(); i++) {
			// skip wildcards, but make sure they occur in the same enclosing
			// context
			// NOTE(review): only myDelta's wildcard set is consulted, so this
			// check is not symmetric in its two delta arguments
			if (i == 1
					&& Arrays.binarySearch(myDelta.wildcardNodes, myTree.getAbsoluteTreeIndex()) >= 0
					&& parentContextEquals(myDelta, otherDelta, myTree.getParentNode(),
							otherTree.getParentNode())) {
				continue;
			}

			// otherwise require all values to match (interned strings)
			if (myValues.get(i) != otherValues.get(i)) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Checks whether the context in which this feature value occurred is the
	 * same as the context in which the candidate feature value occurred.
	 * <p>
	 * The two nodes and their ancestors are compared level by level, up to
	 * (but not including) the roots. A {@code null} node is a missing parent:
	 * it only matches another {@code null}.
	 *
	 * @param myDelta
	 *            delta owning {@code myTree}
	 * @param otherDelta
	 *            delta owning {@code otherTree}
	 * @param myTree
	 *            node whose ancestry is compared; may be null
	 * @param otherTree
	 *            node whose ancestry is compared; may be null
	 * @return true if both ancestries match
	 */
	private static boolean parentContextEquals(SmartTreeDelta myDelta, SmartTreeDelta otherDelta,
			TreeNode myTree, TreeNode otherTree) {

		// a wildcard at the root has no parent: two such nodes share the same
		// (empty) context, while a root and a non-root do not
		if (myTree == null || otherTree == null) {
			return myTree == otherTree;
		}

		final TreeNode myParent = myTree.getParentNode();
		final TreeNode otherParent = otherTree.getParentNode();
		if (myParent == null && otherParent == null) {
			// reached both roots; the root's own values are not compared
			return true;
		}
		if (myParent == null || otherParent == null) {
			return false;
		}
		if (!compareValues(myDelta, otherDelta, myTree, otherTree)) {
			return false;
		}
		return parentContextEquals(myDelta, otherDelta, myParent, otherParent);
	}

	/**
	 * Checks whether all parts of the two subtrees other than the wildcard
	 * values match.
	 * <p>
	 * Both trees were sorted by the constructor and must have the same number
	 * of children. Children are therefore compared pairwise, in order; any
	 * mismatch makes the subtrees unequal.
	 *
	 * @param myDelta
	 *            delta owning {@code myTree}
	 * @param otherDelta
	 *            delta owning {@code otherTree}
	 * @param myTree
	 *            root of the first subtree
	 * @param otherTree
	 *            root of the second subtree
	 * @return true if the subtrees match
	 */
	private static boolean subtreeEquals(SmartTreeDelta myDelta, SmartTreeDelta otherDelta,
			TreeNode myTree, TreeNode otherTree) {

		final ArrayList<TreeNode> myChildren = myTree.getChildren();
		final ArrayList<TreeNode> otherChildren = otherTree.getChildren();
		if (myChildren.size() != otherChildren.size()) {
			return false;
		}

		if (!compareValues(myDelta, otherDelta, myTree, otherTree)) {
			return false;
		}

		for (int i = 0; i < myChildren.size(); i++) {
			final TreeNode myChild = myChildren.get(i);
			final TreeNode otherChild = otherChildren.get(i);

			// cheap early-out on differing first values. Skipping the child
			// instead (as a merge would) leaves children unpaired and wrongly
			// reports equality.
			final ArrayList<String> myChildValues = myChild.getValues();
			final ArrayList<String> otherChildValues = otherChild.getValues();
			if (!myChildValues.isEmpty() && !otherChildValues.isEmpty()
					&& !myChildValues.get(0).equals(otherChildValues.get(0))) {
				return false;
			}

			if (!subtreeEquals(myDelta, otherDelta, myChild, otherChild)) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Compares two deltas; see the class documentation for the rules.
	 * <p>
	 * When assertions are enabled, about 5% of exact comparisons are
	 * cross-checked against the wildcarded string renderings.
	 */
	@Override
	public boolean equals(Object obj) {
		if (this == obj) {
			return true;
		}
		if (!(obj instanceof SmartTreeDelta)) {
			return false;
		}
		final SmartTreeDelta other = (SmartTreeDelta) obj;

		// first, check to see if the wildcards match AND if they're in the
		// same context
		if (this.wildcardNodes.length != other.wildcardNodes.length) {
			return false;
		}

		// greedily pair each of our wildcards with a not-yet-used wildcard of
		// the other delta
		final boolean[] otherCoveredWildcards = new boolean[other.wildcardNodes.length];
		for (final int myWildcard : this.wildcardNodes) {
			boolean matched = false;
			for (int j = 0; j < other.wildcardNodes.length; j++) {
				if (!otherCoveredWildcards[j]
						&& wildcardsMatch(myWildcard, other.wildcardNodes[j], this, other)) {
					otherCoveredWildcards[j] = true;
					matched = true;
					break;
				}
			}
			if (!matched) {
				return false;
			}
		}

		if (!useExactMinimalPairs) {
			return true;
		}

		// now, check to make sure the rest of the structure matches
		final boolean result =
				subtreeEquals(this, other, tree.getRootNode(), other.tree.getRootNode());

		// randomly check the sanity of about 5% of the pairs (it makes
		// debugging difficult, but errors are at least detectable within
		// reasonable time limits this way)
		if (DebugUtils.isAssertEnabled() && rand.nextFloat() < 0.05) {
			final String wildcard1 = this.toString(LabelDisplay.NONE);
			final String wildcard2 = other.toString(LabelDisplay.NONE);
			assert !result || wildcard1.equals(wildcard2) : "Fired incorrect equality (STOCHASTIC ASSERT): "
					+ wildcard1 + "\n" + wildcard2;
		}
		return result;
	}

	/**
	 * Returns the absolute tree indices of the wildcard nodes, in ascending
	 * order.
	 *
	 * @return a copy of the wildcard indices; modifying it does not affect
	 *         this delta (whose hashing and lookups rely on the sorted
	 *         internal array)
	 */
	public int[] getWildcardNodes() {
		return wildcardNodes.clone();
	}

	@Override
	public String toString() {
		return toString(LabelDisplay.NONE);
	}

	/**
	 * Renders the wrapped tree with the second value of each wildcard node
	 * replaced by {@code "*"}.
	 * <p>
	 * The tree is mutated temporarily and always restored, even if rendering
	 * throws. Wildcard nodes with fewer than two values are left untouched.
	 *
	 * @param disp
	 *            label display mode passed through to the tree's renderer
	 * @return the wildcarded rendering of the tree
	 */
	public String toString(LabelDisplay disp) {
		final String[] saved = new String[wildcardNodes.length];
		int processed = 0;
		try {
			for (; processed < wildcardNodes.length; processed++) {
				final ArrayList<String> values =
						tree.getByAbsoluteTreeIndex(wildcardNodes[processed]).getValues();
				if (values.size() >= 2) {
					saved[processed] = values.set(1, "*");
				}
				// else: this is an invalid wildcard node, ignore it
			}
			return tree.toString(disp);
		} finally {
			// put things back like we found them. Restoring in reverse order
			// means a duplicated wildcard index ends up with its original
			// value rather than the "*" saved on its second visit.
			for (int i = processed - 1; i >= 0; i--) {
				final ArrayList<String> values =
						tree.getByAbsoluteTreeIndex(wildcardNodes[i]).getValues();
				if (values.size() >= 2) {
					values.set(1, saved[i]);
				}
			}
		}
	}
}
