/*
 * Created on Jun 9, 2007
 */
package info.jonclark.stat;

import info.jonclark.io.StringTable;
import info.jonclark.util.StringUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map.Entry;

/**
 * Given the outputs of a classifier system, calculates the F1 score (and the
 * precision and recall it is derived from) for every node of a hierarchy of
 * contexts, with the root node summarizing the whole system.
 */
public class F1HierarchicalCalculator {

	/**
	 * Tallies for one node of the context hierarchy: how many outcomes were
	 * fired (observed), how many were correct, and how many were expected,
	 * plus the tallies of each child context keyed by its name.
	 */
	private static class Counts {
		public int nFired = 0;
		public int nCorrect = 0;
		public int nExpected = 0;
		// LinkedHashMap keeps children in first-seen order, so report rows come
		// out in a reproducible order (plain HashMap order is unspecified).
		public final HashMap<String, Counts> children = new LinkedHashMap<String, Counts>();

		/** Returns recall as the literal string "correct/expected". */
		public String getTotalRecallAsFraction() {
			return nCorrect + "/" + nExpected;
		}

		/** Returns precision as the literal string "correct/fired". */
		public String getTotalPrecisionAsFraction() {
			return nCorrect + "/" + nFired;
		}

		/** Returns correct/expected in [0,1], or 0.0 when nothing was expected. */
		public double getTotalRecall() {
			return ratio(nCorrect, nExpected);
		}

		/** Returns correct/fired in [0,1], or 0.0 when nothing was fired. */
		public double getTotalPrecision() {
			return ratio(nCorrect, nFired);
		}

		/**
		 * Returns the harmonic mean of precision and recall, or 0.0 when both
		 * are zero (where the formula would otherwise be 0/0 = NaN).
		 */
		public double getTotalF1Score() {
			double p = getTotalPrecision();
			double r = getTotalRecall();
			double sum = p + r;
			if (sum == 0.0) {
				return 0.0;
			}
			return 2 * p * r / sum;
		}

		/** Divides, treating a zero denominator as a 0.0 result instead of NaN/Infinity. */
		private static double ratio(int numerator, int denominator) {
			return denominator == 0 ? 0.0 : (double) numerator / (double) denominator;
		}
	}

	private final Counts counts = new Counts();

	/**
	 * Records one correct outcome for the whole system and for every context
	 * along the given path (outermost first). An empty path updates only the
	 * system-wide total.
	 */
	public void addCorrectOutcome(String... contexts) {
		counts.nCorrect++;
		addCorrectOutcome(counts, 0, contexts);
	}

	/**
	 * Records one expected outcome for the whole system and for every context
	 * along the given path (outermost first). An empty path updates only the
	 * system-wide total.
	 */
	public void addExpectedOutcome(String... contexts) {
		counts.nExpected++;
		addExpectedOutcome(counts, 0, contexts);
	}

	/**
	 * Records one observed (fired) outcome for the whole system and for every
	 * context along the given path (outermost first). An empty path updates
	 * only the system-wide total.
	 */
	public void addObservedOutcome(String... contexts) {
		counts.nFired++;
		addObservedOutcome(counts, 0, contexts);
	}

	/** Increments nCorrect on every node below parent along contexts[offset..]. */
	protected void addCorrectOutcome(Counts parent, int offset, String[] contexts) {
		for (Counts node : descend(parent, offset, contexts)) {
			node.nCorrect++;
		}
	}

	/** Increments nExpected on every node below parent along contexts[offset..]. */
	protected void addExpectedOutcome(Counts parent, int offset, String[] contexts) {
		for (Counts node : descend(parent, offset, contexts)) {
			node.nExpected++;
		}
	}

	/** Increments nFired on every node below parent along contexts[offset..]. */
	protected void addObservedOutcome(Counts parent, int offset, String[] contexts) {
		for (Counts node : descend(parent, offset, contexts)) {
			node.nFired++;
		}
	}

	/**
	 * Walks down from {@code parent} following {@code contexts[offset..]},
	 * creating any missing child nodes, and returns the visited nodes in order
	 * (excluding {@code parent} itself). Returns an empty list when
	 * {@code offset} is at or past the end of {@code contexts}.
	 */
	private static List<Counts> descend(Counts parent, int offset, String[] contexts) {
		List<Counts> path = new ArrayList<Counts>();
		Counts node = parent;
		for (int i = offset; i < contexts.length; i++) {
			String context = contexts[i];
			Counts child = node.children.get(context);
			if (child == null) {
				child = new Counts();
				node.children.put(context, child);
			}
			path.add(child);
			node = child;
		}
		return path;
	}

	/** Returns system-wide recall as "correct/expected". */
	public String getTotalRecallAsFraction() {
		return counts.getTotalRecallAsFraction();
	}

	/** Returns system-wide precision as "correct/fired". */
	public String getTotalPrecisionAsFraction() {
		return counts.getTotalPrecisionAsFraction();
	}

	/** Returns system-wide recall in [0,1] (0.0 if nothing was expected). */
	public double getTotalRecall() {
		return counts.getTotalRecall();
	}

	/** Returns system-wide precision in [0,1] (0.0 if nothing was fired). */
	public double getTotalPrecision() {
		return counts.getTotalPrecision();
	}

	/** Returns system-wide F1 in [0,1] (0.0 if precision and recall are both 0). */
	public double getTotalF1Score() {
		return counts.getTotalF1Score();
	}

	/** Returns the report for the entire hierarchy (no depth limit). */
	public StringTable getF1Report() {
		return getF1Report(Integer.MAX_VALUE);
	}

	/**
	 * Generates a report with one row per context node, intended for export to
	 * a spreadsheet program such as Excel. The root node is depth 1 and is
	 * always included. Child nodes are included while their depth does not
	 * exceed {@code maxDepth}. Percentage columns are scaled to 0-100.
	 */
	public StringTable getF1Report(int maxDepth) {
		StringTable table = new StringTable();
		table.addRow("Context", "F1 (%)", "Precision (%)", "Precision (correct/fired)",
				"Recall (%)", "Recall (correct/expected)");

		ArrayList<String> context = new ArrayList<String>();
		context.add("root");

		getF1Report(table, context, counts, 1, maxDepth);

		return table;
	}

	/**
	 * Appends the row for {@code me} and then, depth permitting, the rows for
	 * its children. {@code context} is used as a path stack and is restored to
	 * its original contents on return.
	 */
	private void getF1Report(StringTable table, ArrayList<String> context, Counts me, int depth,
			int maxDepth) {
		table.addRow(StringUtils.untokenize(context, "/"), toPercent(me.getTotalF1Score()),
				toPercent(me.getTotalPrecision()), me.getTotalPrecisionAsFraction(),
				toPercent(me.getTotalRecall()), me.getTotalRecallAsFraction());

		if (depth < maxDepth) {
			for (final Entry<String, Counts> c : me.children.entrySet()) {
				context.add(c.getKey());
				getF1Report(table, context, c.getValue(), depth + 1, maxDepth);
				context.remove(context.size() - 1);
			}
		}
	}

	/** Formats a [0,1] fraction as a percentage string, matching the "(%)" headers. */
	private static String toPercent(double fraction) {
		return String.valueOf(100.0 * fraction);
	}
}
