package edu.cmu.cs.lti.avenue.projection;

import info.jonclark.log.LogUtils;
import info.jonclark.util.ArrayUtils;
import info.jonclark.util.DebugUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;

import edu.cmu.cs.lti.avenue.corpus.Alignment;
import edu.cmu.cs.lti.avenue.corpus.CorpusException;
import edu.cmu.cs.lti.avenue.corpus.SentencePair;
import edu.cmu.cs.lti.avenue.projection.ordering.Ordering;
import edu.cmu.cs.lti.avenue.projection.ordering.OrderingModel;
import edu.cmu.cs.lti.avenue.trees.smart.SmartTree;
import edu.cmu.cs.lti.avenue.trees.smart.TreeNode;
import edu.cmu.cs.lti.avenue.trees.smart.SmartTree.LabelMode;

/**
 * Given a sentence pair with a source-side constituent structure and alignments
 * to a target-side sentence, projects (induces) a target-side constituent
 * structure.
 * 
 * @author jon
 */
// TODO: Current reordering models base their positions on the values of
// terminals. Should we instead look at the positions of only a
// constituent's direct children?
public class ConstituentStructureProjector {

	/** Metadata key under which a source node stores the target nodes it is linked to. */
	protected static final String LINKS = "links";
	/** Metadata key under which a source node stores its extracted projection features. */
	protected static final String FEATURES = "features";
	protected static final String TARGET_POSITION = "target_position";
	/** Label given to leftover target terminals and to terminals whose POS labels conflict. */
	protected static final String LEX = "LEX";
	/** Metadata key under which a target node stores its reordering cache. */
	private static final String REORDERING_CACHE = "reordering_cache";

	private static final Logger log = LogUtils.getLogger();

	private final OrderingModel orderingModel;
	private final ProjectionConstraints projectionConstraints;

	/**
	 * @param orderingModel
	 *            model asked to order the target children of each projected
	 *            node
	 * @param projectionConstraints
	 *            constraints deciding whether a candidate projection is
	 *            accepted and whether unaligned target words are attached
	 */
	public ConstituentStructureProjector(OrderingModel orderingModel,
			ProjectionConstraints projectionConstraints) {

		this.orderingModel = orderingModel;
		this.projectionConstraints = projectionConstraints;
	}

	/**
	 * Projects the source constituent structure of <code>pair</code> onto the
	 * target sentence. The labeled source nodes are visited in the order
	 * returned by {@link SmartTree#getLabeledNodes()}; each visit may replace
	 * the current top target node. The last top node produced becomes the
	 * root of the returned tree.
	 * 
	 * @param pair
	 *            sentence pair providing the source tree, the normalized
	 *            target tokens and the normalized alignment
	 * @return the projected target-side tree
	 * @throws CorpusException
	 *             propagated from the corpus accessors and the helpers called
	 *             here
	 */
	public SmartTree project(SentencePair pair) throws CorpusException {

		SmartTree source = pair.getSourceConstituentStructure();
		int targetLength = pair.getNormalizedTargetTokens().length;

		TreeNode topTargetNode = null;
		// one slot per target token: the canonical node for that terminal
		TreeNode[] targetNodeCache = new TreeNode[targetLength];
		// one flag per target token: true once the terminal is part of the
		// target tree
		boolean[] targetTerminalCoverage = new boolean[targetLength];

		ArrayList<TreeNode> nodesBottomUp = source.getLabeledNodes();
		for (int i = 0; i < nodesBottomUp.size(); i++) {
			final TreeNode sourceNode = nodesBottomUp.get(i);

			// NOTE(review): the flag is named "second to last", but
			// processSourceNode uses it as the "top node" trigger for
			// attaching leftovers. Presumably the final labeled node is a
			// wrapper that should not receive them -- confirm.
			boolean isSecondToLastNode =
					(i == nodesBottomUp.size() - 2) || (nodesBottomUp.size() < 2);
			topTargetNode =
					processSourceNode(pair, topTargetNode, targetNodeCache, targetTerminalCoverage,
							sourceNode, isSecondToLastNode);

			// flush unused cache entries so that we don't set the POS to LEX
			// erroneously
			for (int j = 0; j < targetNodeCache.length; j++) {
				if (!targetTerminalCoverage[j]) {
					targetNodeCache[j] = null;
				}
			}
		}

		// encapsulate in a SmartTree object and (re-)assign tree/terminal
		// indices
		assert topTargetNode != null;
		SmartTree target =
				SmartTree.createTreeFromRoot(SmartTree.TARGET_C_STRUCT_LABEL,
						LabelMode.LABEL_ALL_NODES, topTargetNode);
		checkProjectionSanity(pair, target);
		return target;
	}

	/**
	 * Handles one source node: gathers its candidate target children,
	 * extracts projection features, asks the ordering model for an order and,
	 * if the order is unambiguous and the constraints accept the features,
	 * connects the children into the target tree.
	 * 
	 * @param pair
	 *            the sentence pair being projected
	 * @param topTargetNode
	 *            the current top of the target tree (may be null)
	 * @param targetNodeCache
	 *            canonical target terminal per target position; updated in
	 *            place
	 * @param targetTerminalCoverage
	 *            per-position coverage flags; updated in place when the node
	 *            is connected
	 * @param sourceNode
	 *            the source node to process; receives FEATURES (and possibly
	 *            LINKS) metadata
	 * @param isSecondToLastNode
	 *            when true, all still-uncovered target terminals are added as
	 *            children of this node
	 * @return the new top target node, or the incoming
	 *         <code>topTargetNode</code> if nothing was built
	 * @throws CorpusException
	 *             propagated from the corpus accessors and feature extraction
	 */
	private TreeNode processSourceNode(SentencePair pair, TreeNode topTargetNode,
			TreeNode[] targetNodeCache, boolean[] targetTerminalCoverage,
			final TreeNode sourceNode, boolean isSecondToLastNode) throws CorpusException {

		Alignment alignment = pair.getNormalizedAlignment();
		String[] fSentence = pair.getNormalizedTargetTokens();

		ArrayList<TreeNode> targetChildren = new ArrayList<TreeNode>();
		getLinkedTargetChildren(sourceNode, pair, targetNodeCache, targetChildren);
		ensureUnique(targetChildren, targetNodeCache);

		// NOTE: Unaligned target words can't affect our reordering metric,
		// so the cache is built before they are added below
		boolean[][] reorderingCache = getReorderingCacheUnion(pair, targetChildren);
		ProjectionFeatures features =
				ProjectionFeatureExtractor.extractProjectionFeatures(pair, sourceNode,
						targetNodeCache, targetTerminalCoverage, reorderingCache);
		sourceNode.putMetaData(FEATURES, features);

		if (projectionConstraints.includeUnalignedTargetWords) {
			ensureUnique(features.unalignedTargets, targetNodeCache);
			for (final TreeNode unalignedTarget : features.unalignedTargets) {
				if (!contains(targetChildren, unalignedTarget)) {
					targetChildren.add(unalignedTarget);
				}
			}
		}

		// if this is the top node, throw in all the leftovers
		if (isSecondToLastNode) {
			ArrayList<TreeNode> leftovers =
					getLeftovers(fSentence, targetTerminalCoverage, targetNodeCache);
			for (final TreeNode leftover : leftovers) {
				if (!contains(targetChildren, leftover)) {
					targetChildren.add(leftover);
					features.lo++;
				}
			}
		}

		// NOTE: terminals should never have reorderings
		// but we do need to link them when projecting
		Ordering order = orderingModel.getOrder(targetChildren, alignment);
		features.ag = order.ambiguitiesByGrouping;
		// only build the (potentially large) message when it will be emitted
		if (log.isLoggable(Level.FINE)) {
			log.fine("ORDER: " + Arrays.toString(order.relativePositions) + "; ambiguous="
					+ order.ambiguous);
		}

		// don't process non-terminals that have no children
		if (!sourceNode.isTerminal() && targetChildren.isEmpty()) {
			return topTargetNode;
		}

		// an unambiguous order that passes the constraints is required before
		// anything is attached to the target tree
		if (order.ambiguous || !projectionConstraints.accepts(features)) {
			if (log.isLoggable(Level.FINE)) {
				log.fine("REJECT FROM SOURCE: " + sourceNode + " ||| " + features);
			}
			return topTargetNode;
		}

		features.ed =
				ProjectionFeatureExtractor.getLevenshteinEditDistance(order.relativePositions);
		topTargetNode =
				connectTargetTree(topTargetNode, targetTerminalCoverage, sourceNode, features,
						order);
		// NOTE(review): for a terminal source node with no ordered target
		// nodes, connectTargetTree returns the previous top node unchanged,
		// so this overwrites that node's cache -- confirm this is intended.
		if (topTargetNode != null) {
			topTargetNode.putMetaData(REORDERING_CACHE, reorderingCache);
		}
		return topTargetNode;
	}

	/**
	 * Creates a LEX-labeled orphan terminal for every target position that is
	 * not yet covered, then passes the list through
	 * {@link #ensureUnique(ArrayList, TreeNode[])} so that nodes already
	 * cached for those positions are reused.
	 * 
	 * @param fSentence
	 *            the target tokens, indexed by target position
	 * @param targetTerminalCoverage
	 *            per-position coverage flags
	 * @param targetNodeCache
	 *            canonical target terminal per position; may gain entries
	 * @return the uncovered target terminals in ascending position order
	 */
	private ArrayList<TreeNode> getLeftovers(String[] fSentence, boolean[] targetTerminalCoverage,
			TreeNode[] targetNodeCache) {

		ArrayList<TreeNode> leftovers = new ArrayList<TreeNode>();

		for (int j = 0; j < targetTerminalCoverage.length; j++) {
			if (!targetTerminalCoverage[j]) {
				TreeNode leftoverTarget =
						TreeNode.createOrphanNode(j, SmartTree.TARGET_C_STRUCT_LABEL);
				leftoverTarget.addValue(LEX);
				leftoverTarget.addValue(fSentence[j]);
				leftovers.add(leftoverTarget);
			}
		}
		ensureUnique(leftovers, targetNodeCache);

		return leftovers;
	}

	/**
	 * Builds a fresh reordering cache that is the element-wise OR of the
	 * caches stored (under REORDERING_CACHE) on the given target nodes. Nodes
	 * without a stored cache are skipped.
	 * 
	 * @param pair
	 *            used to size the empty cache
	 * @param targetNodes
	 *            nodes whose caches are merged
	 * @return the merged cache; assumes each child cache fits within the
	 *         dimensions of the empty cache created for <code>pair</code>
	 */
	private static boolean[][] getReorderingCacheUnion(SentencePair pair,
			ArrayList<TreeNode> targetNodes) {

		boolean[][] union = ProjectionFeatureExtractor.getEmptyReorderingCache(pair);
		for (final TreeNode targetNode : targetNodes) {
			boolean[][] child = (boolean[][]) targetNode.getMetaData(REORDERING_CACHE);
			if (child == null) {
				continue;
			}
			for (int i = 0; i < child.length; i++) {
				for (int j = 0; j < child[i].length; j++) {
					union[i][j] |= child[i][j];
				}
			}
		}
		return union;
	}

	/**
	 * Attaches the ordered target nodes to the target tree and records the
	 * source-to-target link on <code>sourceNode</code>.
	 * <p>
	 * Every terminal index under the ordered nodes is marked as covered. For
	 * a terminal source node the ordered nodes themselves become its links
	 * and the first of them becomes the top node; if there are none, the
	 * incoming top node is returned unchanged. For a non-terminal source node
	 * a new orphan node carrying the source node's first value is created,
	 * linked, and given the ordered nodes as children.
	 * 
	 * @param topTargetNode
	 *            the current top target node
	 * @param targetTerminalCoverage
	 *            per-position coverage flags; updated in place
	 * @param sourceNode
	 *            the source node being projected
	 * @param features
	 *            features of the projection (only used for logging here)
	 * @param order
	 *            the ordering whose <code>orderedNodes</code> are attached
	 * @return the resulting top target node
	 */
	private TreeNode connectTargetTree(TreeNode topTargetNode, boolean[] targetTerminalCoverage,
			final TreeNode sourceNode, ProjectionFeatures features, Ordering order) {

		if (log.isLoggable(Level.FINE)) {
			log.fine("BUILDING FROM SOURCE: " + sourceNode + " ||| " + features);
		}

		for (final TreeNode targetNode : order.orderedNodes) {
			for (final int n : targetNode.getTerminalIndices()) {
				targetTerminalCoverage[n] = true;
			}
		}

		if (sourceNode.isTerminal()) {
			if (order.orderedNodes.size() > 0) {
				sourceNode.putMetaData(LINKS, order.orderedNodes);
				topTargetNode = order.orderedNodes.get(0);
			}
		} else {
			// new internal node, labeled with the source node's first value
			topTargetNode = TreeNode.createOrphanNode(-1, SmartTree.TARGET_C_STRUCT_LABEL);
			topTargetNode.addValue(sourceNode.getValues().get(0));
			sourceNode.putMetaData(LINKS, ArrayUtils.toArrayList(topTargetNode));

			for (final TreeNode node : order.orderedNodes) {
				topTargetNode.assimilateChild(node);
			}
		}

		if (log.isLoggable(Level.FINE)) {
			log.fine("BUILT TARGET: " + topTargetNode);
		}
		return topTargetNode;
	}

	/**
	 * When assertions are enabled, verifies that the terminals of the
	 * projected tree match the normalized target tokens in number and, at
	 * each position, in surface form (the value at index 1 of each
	 * terminal). Does nothing when assertions are disabled.
	 * 
	 * @param pair
	 *            the sentence pair that was projected
	 * @param target
	 *            the projected target tree
	 */
	private static void checkProjectionSanity(SentencePair pair, SmartTree target) {
		if (!DebugUtils.isAssertEnabled()) {
			return;
		}

		String[] expected = pair.getNormalizedTargetTokens();
		ArrayList<TreeNode> targetTerminals = target.getTerminalNodes();
		assert expected.length == targetTerminals.size() : "Length mismatch between target tree and target terminals";

		String[] actual = new String[expected.length];
		for (int i = 0; i < actual.length; i++) {
			actual[i] = targetTerminals.get(i).getValues().get(1);
		}
		for (int i = 0; i < expected.length; i++) {
			assert actual[i].equals(expected[i]) : pair.serialize()
					+ "\nTerminal mismatch at position " + i + ": " + actual[i] + " vs "
					+ expected[i];
		}
	}

	/**
	 * Ensures that target terminals are unique nodes since dual alignments
	 * cause ambiguity.
	 * <p>
	 * Each terminal in <code>nodes</code> is looked up by its first terminal
	 * index. If a node is already cached for that index, the list entry is
	 * replaced by the cached node, and the cached node's first value is set
	 * to {@link #LEX} when it differs from the replaced node's first value.
	 * Otherwise the terminal itself becomes the cached node. Non-terminals
	 * are left untouched.
	 * 
	 * @param nodes
	 *            list to canonicalize; modified in place
	 * @param targetNodeCache
	 *            canonical target terminal per target position; modified in
	 *            place
	 */
	protected static void ensureUnique(ArrayList<TreeNode> nodes, TreeNode[] targetNodeCache) {

		for (int i = 0; i < nodes.size(); i++) {
			TreeNode node = nodes.get(i);

			// non-terminals are guaranteed to be unique
			if (!node.isTerminal()) {
				continue;
			}

			int terminalIndex = node.getTerminalIndices().get(0);
			TreeNode cached = targetNodeCache[terminalIndex];
			if (cached == null) {
				targetNodeCache[terminalIndex] = node;
				continue;
			}

			nodes.set(i, cached);

			// if the POS's don't match, fall back to the generic label
			if (!cached.getValues().get(0).equals(node.getValues().get(0))) {
				cached.getValues().set(0, LEX);
			}
		}
	}

	/**
	 * Gets the linked children (those eligible to be connected to the source
	 * node on the target side) and appends them to
	 * <code>targetChildren</code>.
	 * <p>
	 * If the source node already carries LINKS metadata, those links are
	 * appended as-is. Otherwise, for a terminal, the aligned target terminals
	 * are canonicalized and appended unless an equal node is already present.
	 * For an unlinked non-terminal, the method recurses into its children.
	 * 
	 * @param sourceNode
	 *            the source node whose target-side children are wanted
	 * @param pair
	 *            provides the normalized alignment and target tokens
	 * @param targetNodeCache
	 *            canonical target terminal per target position; may gain
	 *            entries
	 * @param targetChildren
	 *            output list that receives the linked children
	 * @throws CorpusException
	 *             propagated from the corpus accessors
	 */
	private static void getLinkedTargetChildren(TreeNode sourceNode, SentencePair pair,
			TreeNode[] targetNodeCache, ArrayList<TreeNode> targetChildren) throws CorpusException {

		// LINKS metadata is only ever written by this class, always as a list
		// of TreeNodes
		@SuppressWarnings("unchecked")
		ArrayList<TreeNode> sourceToTargetLinks =
				(ArrayList<TreeNode>) sourceNode.getMetaData(LINKS);

		if (sourceToTargetLinks != null) {
			// existing links are added without a duplicate check
			targetChildren.addAll(sourceToTargetLinks);

		} else if (sourceNode.isTerminal()) {
			Alignment alignment = pair.getNormalizedAlignment();
			String[] fSentence = pair.getNormalizedTargetTokens();
			ArrayList<TreeNode> targetTerminals = alignment.getTargetNodes(fSentence, sourceNode);
			ensureUnique(targetTerminals, targetNodeCache);
			for (final TreeNode targetTerminal : targetTerminals) {
				if (!contains(targetChildren, targetTerminal)) {
					targetChildren.add(targetTerminal);
				}
			}
		} else {
			// we couldn't figure out the order for this node
			// so we have to dig deeper
			for (final TreeNode sourceChild : sourceNode.getChildren()) {
				getLinkedTargetChildren(sourceChild, pair, targetNodeCache, targetChildren);
			}
		}
	}

	/**
	 * Tests whether <code>list</code> holds a node equivalent to
	 * <code>a</code>, where equivalence means equal values and equal terminal
	 * indices (not object identity).
	 * 
	 * @param list
	 *            nodes to search
	 * @param a
	 *            node to look for
	 * @return true if an equivalent node is present
	 */
	private static boolean contains(Iterable<TreeNode> list, TreeNode a) {

		for (final TreeNode b : list) {
			if (a.getValues().equals(b.getValues())
					&& a.getTerminalIndices().equals(b.getTerminalIndices())) {
				return true;
			}
		}
		return false;
	}
}
