package edu.cmu.cs.lti.avenue.navigation.search.generation1;

import info.jonclark.util.HashUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.TreeMap;
import java.util.Map.Entry;

import edu.cmu.cs.lti.avenue.corpus.SentencePair;
import edu.cmu.cs.lti.avenue.trees.cfg.CfgRule;
import edu.cmu.cs.lti.avenue.trees.smart.TreeNode;

/**
 * Scores sentence pairs by how unusual their source-side syntactic structure
 * is relative to a reference corpus.
 * <p>
 * Every labeled node of a source constituent structure is converted to a CFG
 * rule, and the string form of that rule is used as a counting key. A rule's
 * score is TFIDF-style: its count within the sentence divided by its
 * (optionally smoothed) count over the whole corpus.
 */
public class DiversityScorer {
	/** Corpus-wide occurrence count of each CFG rule, keyed by the rule's string form. */
	private final HashMap<String, Integer> corpusCounts;

	/** Strategies for smoothing the corpus-wide rule counts. */
	public enum SmoothingMode {
		/** Add one to the corpus-wide count of every rule before dividing. */
		ADD1
	}

	private final ArithmeticBlender blender;
	private final SmoothingMode smoothing;

	/**
	 * Builds a scorer whose corpus-wide statistics are gathered from
	 * <code>corpus</code>.
	 * 
	 * @param corpus
	 *            sentence pairs whose source trees provide the global rule
	 *            counts
	 * @param blender
	 *            combines the per-sentence average, max, product and sum of
	 *            rule scores into a single value
	 * @param smoothing
	 *            smoothing applied to the global counts; any value other than
	 *            {@link SmoothingMode#ADD1} (including <code>null</code>)
	 *            leaves the counts unsmoothed
	 */
	public DiversityScorer(ArrayList<SentencePair> corpus, ArithmeticBlender blender,
			SmoothingMode smoothing) {

		this.corpusCounts = getGlobalSourceTreeTypes(corpus);
		this.blender = blender;
		this.smoothing = smoothing;
	}

	/**
	 * Counts the CFG rules found in the source tree of a single sentence pair.
	 * 
	 * @param pair
	 *            the sentence pair to inspect
	 * @return a new map from rule string to its number of occurrences in
	 *         <code>pair</code>
	 */
	public static HashMap<String, Integer> getSourceTreeTypes(SentencePair pair) {
		HashMap<String, Integer> types = new HashMap<String, Integer>();
		countSourceRules(pair, types);
		return types;
	}

	/**
	 * Counts the CFG rules found in the source trees of every sentence pair in
	 * a corpus.
	 * 
	 * @param corpus
	 *            the sentence pairs to inspect
	 * @return a new map from rule string to its total number of occurrences
	 *         over all of <code>corpus</code>
	 */
	public static HashMap<String, Integer> getGlobalSourceTreeTypes(ArrayList<SentencePair> corpus) {
		HashMap<String, Integer> types = new HashMap<String, Integer>();

		for (final SentencePair pair : corpus) {
			countSourceRules(pair, types);
		}

		return types;
	}

	/** Adds one count to <code>types</code> for the CFG rule of each labeled source node of <code>pair</code>. */
	private static void countSourceRules(SentencePair pair, HashMap<String, Integer> types) {
		ArrayList<TreeNode> allNodes = pair.getSourceConstituentStructure().getLabeledNodes();
		for (final TreeNode node : allNodes) {
			CfgRule cfgRule = node.toCfgRule();
			HashUtils.increment(types, cfgRule.toString());
		}
	}

	/**
	 * Returns the corpus-wide count of <code>rule</code> with the configured
	 * smoothing applied. A rule that never occurred in the corpus counts as
	 * zero before smoothing, so that add-one smoothing can handle unseen rules.
	 * 
	 * @throws IllegalStateException
	 *             if the resulting count is zero, since it is used as a
	 *             denominator by the callers
	 */
	private int smoothedGlobalCount(String rule) {
		Integer seen = corpusCounts.get(rule);
		int globalCount = (seen == null) ? 0 : seen.intValue();

		if (smoothing == SmoothingMode.ADD1)
			globalCount += 1;

		if (globalCount == 0) {
			throw new IllegalStateException("CFG rule '" + rule
					+ "' does not occur in the corpus and no smoothing is enabled");
		}
		return globalCount;
	}

	/**
	 * Scores each distinct CFG rule of <code>pair</code> individually.
	 * <p>
	 * NOTE: the result is keyed by score, so when several rules have exactly
	 * the same score only the one stored last is retained.
	 * 
	 * @param pair
	 *            the sentence pair to inspect
	 * @return a map, sorted by ascending score, from rule score to a
	 *         description of the form <code>rule(pairCount/globalCount)</code>
	 * @throws IllegalStateException
	 *             if a rule of <code>pair</code> is absent from the corpus and
	 *             no smoothing is enabled
	 */
	public TreeMap<Double, String> getIndividualConstituentRanking(SentencePair pair) {
		HashMap<String, Integer> pairCounts = getSourceTreeTypes(pair);
		TreeMap<Double, String> ranking = new TreeMap<Double, String>();

		for (final Entry<String, Integer> entry : pairCounts.entrySet()) {

			// TFIDF-style
			int pairCount = entry.getValue();
			int globalCount = smoothedGlobalCount(entry.getKey());

			double ruleScore = (double) pairCount / (double) globalCount;
			ranking.put(ruleScore, entry.getKey() + "(" + pairCount + "/" + globalCount + ")");
		}

		return ranking;
	}

	/**
	 * Computes a single diversity score for <code>pair</code> by blending the
	 * average, maximum, product and sum of its per-rule scores.
	 * <p>
	 * A pair with no labeled source nodes has no rule scores at all; all four
	 * statistics are then passed to the blender as zero.
	 * 
	 * @param pair
	 *            the sentence pair to score
	 * @return the value produced by the blender for this pair's statistics
	 * @throws IllegalStateException
	 *             if a rule of <code>pair</code> is absent from the corpus and
	 *             no smoothing is enabled
	 */
	public double getDiversityScore(SentencePair pair) {
		HashMap<String, Integer> pairCounts = getSourceTreeTypes(pair);

		if (pairCounts.isEmpty()) {
			// Avoid the 0/0 average (NaN) and do not let the empty product (1.0)
			// make a structureless sentence look maximally diverse.
			return blender.blend(0.0, 0.0, 0.0, 0.0);
		}

		double sum = 0.0;
		// Multiplicative identity: starting from 0.0 would force the product to 0.
		double product = 1.0;
		double max = 0.0;

		// TODO: Weight by how frequently this structure occurs in real text
		for (final Entry<String, Integer> entry : pairCounts.entrySet()) {

			// TFIDF-style
			double pairCount = entry.getValue();
			double globalCount = smoothedGlobalCount(entry.getKey());

			double ruleScore = pairCount / globalCount;

			max = Math.max(max, ruleScore);
			sum += ruleScore;
			product *= ruleScore;
		}

		double average = sum / pairCounts.size();

		return blender.blend(average, max, product, sum);
	}
}
