package edu.cmu.cs.lti.avenue.navigation.tools;

import info.jonclark.util.ArrayUtils;
import info.jonclark.util.FileUtils;
import info.jonclark.util.FormatUtils;
import info.jonclark.util.HashUtils;
import info.jonclark.util.StringUtils;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.TreeMap;
import java.util.Map.Entry;
import java.util.regex.Pattern;

import edu.cmu.cs.lti.avenue.corpus.Serializer;
import edu.cmu.cs.lti.avenue.navigation.search.oracle.cfglm.CfgLanguageModelEntry;
import edu.cmu.cs.lti.avenue.trees.cfg.CfgRule;
import edu.cmu.cs.lti.avenue.trees.smart.SmartTree;
import edu.cmu.cs.lti.avenue.trees.smart.TreeNode;
import edu.cmu.cs.lti.avenue.trees.smart.SmartTree.LabelMode;

/**
 * Extracts EN (source side) CFG rules and their relative frequencies
 * 
 * @author jon
 */
/**
 * Command-line tool that reads every <code>.MRG</code> file found in the
 * subdirectories of a treebank directory, converts each tree to a
 * {@link SmartTree}, and tallies the CFG rule produced by each labeled node.
 * For every distinct rule it prints the rule count, the total count of all
 * rules sharing its left-hand side, and the two relative frequencies derived
 * from those counts.
 */
public class PennTreeBankAnalyzer {

	/**
	 * Matches a whitespace-delimited <code>word/TAG</code> token. Group 1 is
	 * the shortest run of non-space characters that is followed by a slash;
	 * group 2 is the rest of the token after that slash. Compiled once
	 * because {@link #invertPos(String)} is called for every tree.
	 */
	private static final Pattern WORD_SLASH_TAG = Pattern.compile("(\\S+?)/(\\S+)");

	/**
	 * Orders entries by ascending <code>globalFreq</code>. Uses an explicit
	 * three-way comparison rather than subtraction so the result can never be
	 * distorted by integer overflow.
	 */
	private static final Comparator<CfgLanguageModelEntry> BY_ASCENDING_FREQ =
			new Comparator<CfgLanguageModelEntry>() {
				public int compare(CfgLanguageModelEntry o1, CfgLanguageModelEntry o2) {
					int freq1 = o1.globalFreq;
					int freq2 = o2.globalFreq;
					return (freq1 < freq2) ? -1 : ((freq1 > freq2) ? 1 : 0);
				}
			};

	/**
	 * Entry point. Usage:
	 * <code>program &lt;parsed_dir&gt; [--include-terminals] [--include-x-rules] [--sort-by-freq]</code>
	 * <ul>
	 * <li><code>--include-terminals</code>: also count rules of terminal
	 * nodes (by default only non-terminal nodes contribute rules).</li>
	 * <li><code>--include-x-rules</code>: keep nodes whose rule mentions the
	 * <code>X</code> symbol (see {@link #removeExtraJunk(SmartTree, boolean)}).</li>
	 * <li><code>--sort-by-freq</code>: print rules in ascending order of
	 * count instead of in rule-string order.</li>
	 * </ul>
	 * Each processed file path is echoed to stderr; the rule table and the
	 * summary counters go to stdout. Any error raised while reading or parsing
	 * a tree propagates and aborts the run.
	 * 
	 * @param args
	 *            command-line arguments; <code>args[0]</code> is the treebank
	 *            directory
	 * @throws Throwable
	 *             if reading or parsing any treebank file fails
	 */
	public static void main(String[] args) throws Throwable {

		if (args.length < 1) {
			System.err.println("Usage: program <parsed_dir> [--include-terminals] [--include-x-rules] [--sort-by-freq]");
			System.exit(1);
		}

		boolean includeTerminals = ArrayUtils.unsortedArrayContains(args, "--include-terminals");
		boolean includeXRules = ArrayUtils.unsortedArrayContains(args, "--include-x-rules");
		boolean sortByFreq = ArrayUtils.unsortedArrayContains(args, "--sort-by-freq");

		// get treebank files
		ArrayList<File> subdirs = FileUtils.getSubdirectoriesRecursively(new File(args[0]));
		ArrayList<File> files = new ArrayList<File>(2000);
		for (final File subdir : subdirs) {
			files.addAll(Arrays.asList(FileUtils.getFilesWithExt(subdir, ".MRG")));
		}

		int nTerminals = 0;
		int nSentences = 0;
		int nFiles = 0;
		int nRules = 0;

		// prefixes handed to the Serializer as comment markers; identical for
		// every file, so built once
		final String[] commentsStartWith =
				new String[] { "( END_OF_TEXT_UNIT )", "( @", "*x*" };

		// parse sentences in each file and extract statistics
		// (rule string -> number of occurrences)
		TreeMap<String, Integer> stats = new TreeMap<String, Integer>();
		for (final File file : files) {
			System.err.println(file.getAbsolutePath());
			nFiles++;

			ArrayList<String> strTrees =
					Serializer.getTreeStringsFromFile(file, "(", commentsStartWith);

			for (String strTree : strTrees) {
				strTree = invertPos(strTree);
				SmartTree tree =
						SmartTree.parse(strTree, SmartTree.SOURCE_C_STRUCT_LABEL,
								LabelMode.LABEL_ALL_NODES);
				removeExtraJunk(tree, includeXRules);

				for (final TreeNode node : tree.getLabeledNodes()) {
					// skip empty outer brackets
					if (node.getValues().size() < 1)
						continue;

					boolean terminal = node.isTerminal();

					if (includeTerminals || !terminal) {
						CfgRule cfgRule = node.toCfgRule();
						HashUtils.increment(stats, cfgRule.toString());
						nRules++;
					}

					if (terminal) {
						nTerminals++;
					}
				}
				nSentences++;
			}
		}

		// determine how many rules we have for each LHS
		TreeMap<String, ArrayList<CfgLanguageModelEntry>> rulesByLhs =
				new TreeMap<String, ArrayList<CfgLanguageModelEntry>>();
		TreeMap<String, CfgLanguageModelEntry> rulesByCfg =
				new TreeMap<String, CfgLanguageModelEntry>();

		for (final Entry<String, Integer> entry : stats.entrySet()) {
			// the LHS is the text between "[" and "->" in the rule string
			String lhs = StringUtils.substringBetween(entry.getKey(), "[", "->").trim();
			CfgLanguageModelEntry rule = new CfgLanguageModelEntry(entry.getKey(), entry.getValue());
			// share of this rule among all counted rule occurrences
			rule.globalProb = (double) rule.globalFreq / (double) nRules;

			HashUtils.append(rulesByLhs, lhs, rule);
			rulesByCfg.put(entry.getKey(), rule);
		}

		// calculate LHS probs for each rule
		for (final ArrayList<CfgLanguageModelEntry> lhsRules : rulesByLhs.values()) {
			int nLhsFreq = 0;
			for (final CfgLanguageModelEntry lhsRule : lhsRules) {
				nLhsFreq += lhsRule.globalFreq;
			}
			for (final CfgLanguageModelEntry lhsRule : lhsRules) {
				lhsRule.lhsFreq = nLhsFreq;
				lhsRule.lhsProb = (double) lhsRule.globalFreq / (double) nLhsFreq;
			}
		}

		// output counts
		System.out.println("# CFG_RULE\tn(rule)\tn(lhs)\tP(rule | lhs)\tP(rule)");
		Collection<CfgLanguageModelEntry> ruleCollection = rulesByCfg.values();
		if (sortByFreq) {
			CfgLanguageModelEntry[] sorted =
					ruleCollection.toArray(new CfgLanguageModelEntry[ruleCollection.size()]);
			Arrays.sort(sorted, BY_ASCENDING_FREQ);
			for (final CfgLanguageModelEntry rule : sorted) {
				System.out.println(rule.toString());
			}
		} else {
			// TreeMap iteration order, i.e. sorted by rule string
			for (final CfgLanguageModelEntry rule : ruleCollection) {
				System.out.println(rule.toString());
			}
		}
		System.out.println("# FILES: " + nFiles);
		System.out.println("# TERMINALS: " + FormatUtils.formatLong(nTerminals));
		System.out.println("# SENTENCES: " + FormatUtils.formatLong(nSentences));
		System.out.println("# RULES: " + FormatUtils.formatLong(nRules));
	}

	/** Not updated anywhere in this class; retained for existing references. */
	static int terminalsCreated = 0;

	/**
	 * Rewrites every <code>word/TAG</code> token of a tree string as the
	 * bracketed constituent <code>( TAG word )</code>. Newlines are replaced
	 * by spaces first and the result is trimmed.
	 * <p>
	 * When assertions are enabled, the method also checks that parentheses are
	 * balanced before and after the rewrite and that the rewrite did not
	 * reduce the number of opening parentheses.
	 * 
	 * @param strTree
	 *            the tree string as read from the treebank file
	 * @return the single-line tree string with inverted POS tokens
	 */
	public static String invertPos(String strTree) {

		String noNewlines = StringUtils.replaceFast(strTree, "\n", " ");
		assert hasMatchedParens(noNewlines) : "paren mismatch";

		String invertedPos = WORD_SLASH_TAG.matcher(noNewlines).replaceAll("( $2 $1 )");
		assert hasMatchedParens(invertedPos) : "paren mismatch";

		// Compare the same bracket type on the input and on the OUTPUT; every
		// opening paren of the input must survive the rewrite.
		int nBefore = StringUtils.countOccurances(strTree, '(');
		int nAfter = StringUtils.countOccurances(invertedPos, '(');
		assert nBefore <= nAfter : "oops, we killed a constituent late (" + nBefore + " vs "
				+ nAfter + ")";

		return invertedPos.trim();
	}

	/**
	 * Prunes and normalizes the nodes of a tree, then calls
	 * {@link SmartTree#updateTreeStructure()}. For each node that has at least
	 * one value, the first matching case applies:
	 * <ol>
	 * <li><code>includeXRules</code> is false and the node's CFG rule string
	 * contains <code>" X "</code>: the node is removed;</li>
	 * <li>the first value both starts and ends with <code>-</code>: the node
	 * is removed;</li>
	 * <li>the first value contains <code>-</code>: it is replaced by the part
	 * returned by <code>StringUtils.substringBefore(value, "-")</code>;</li>
	 * <li>the first value equals <code>EDITED</code>: the node is removed.</li>
	 * </ol>
	 * Removing a node also removes any ancestors that are left without
	 * children. A node without a parent cannot be detached and is left in
	 * place.
	 * 
	 * @param tree
	 *            the tree to clean up in place
	 * @param includeXRules
	 *            whether nodes whose rule mentions <code>X</code> are kept
	 */
	public static void removeExtraJunk(SmartTree tree, boolean includeXRules) {

		for (final TreeNode node : tree.getAllNodes()) {

			if (node.getValues().size() == 0) {
				continue;
			}

			String val0 = node.getValues().get(0);
			CfgRule cfgRule = node.toCfgRule();

			boolean remove = false;
			if (!includeXRules && cfgRule.toString().contains(" X ")) {
				remove = true;
			} else if (val0.startsWith("-") && val0.endsWith("-")) {
				remove = true;
			} else if (val0.contains("-")) {
				node.getValues().set(0, StringUtils.substringBefore(val0, "-"));
			} else if (val0.equals("EDITED")) {
				remove = true;
			}

			if (remove) {
				// Capture the parent before detaching so the follow-up pruning
				// does not depend on the node still pointing at it, and so a
				// parentless (root) node does not cause a NullPointerException.
				TreeNode parent = node.getParentNode();
				if (parent != null) {
					parent.getChildren().remove(node);
					removeEmpty(parent);
				}
			}
		}

		tree.updateTreeStructure();
	}

	/**
	 * Detaches <code>node</code> from its parent if it has no children, and
	 * repeats the check on each ancestor in turn. Stops at the first node that
	 * still has children or that has no parent.
	 * 
	 * @param node
	 *            the node to start pruning at; may be null
	 */
	private static void removeEmpty(TreeNode node) {
		TreeNode current = node;
		while (current != null && current.getChildren().size() == 0) {
			TreeNode parent = current.getParentNode();
			if (parent == null) {
				return;
			}
			parent.getChildren().remove(current);
			current = parent;
		}
	}

	/**
	 * Tests whether a tree string contains as many opening as closing
	 * parentheses. Only the counts are compared, not the nesting order.
	 * 
	 * @param strTree
	 *            the tree string to check
	 * @return true if the number of <code>(</code> equals the number of
	 *         <code>)</code>
	 */
	public static boolean hasMatchedParens(String strTree) {
		int nOpen = StringUtils.countOccurances(strTree, '(');
		int nClose = StringUtils.countOccurances(strTree, ')');
		return (nOpen == nClose);
	}
}
