/*
 * Decompiled with CFR 0.152.
 */
package multeval;

import com.google.common.base.Charsets;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
import jannopts.ConfigurationException;
import jannopts.Configurator;
import jannopts.Option;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import multeval.HypothesisManager;
import multeval.Module;
import multeval.MultEval;
import multeval.NbestEntry;
import multeval.metrics.BLEU;
import multeval.metrics.METEOR;
import multeval.metrics.Metric;
import multeval.metrics.SuffStats;
import multeval.metrics.TER;
import multeval.parallel.MetricWorkerPool;
import multeval.parallel.SynchronizedBufferedReader;
import multeval.parallel.SynchronizedPrintStream;
import multeval.util.FileUtils;
import multeval.util.StringUtils;
import multeval.util.SuffStatUtils;

/**
 * Module that scores every entry of an n-best list against the references
 * with each configured metric.
 *
 * <p>For each sentence the hypotheses are scored, ranked per metric, and
 * written to standard output in their original order. When a rank directory
 * is configured, one additional file per metric is written with the
 * hypotheses of each sentence sorted by that metric. Finally, corpus-level
 * "topbest", "oracle" and "worst-oracle" scores are reported on standard
 * error for each metric.
 */
public class NbestModule implements Module {
    /** Initial capacity of the list holding the hypotheses of one sentence. */
    private static final int DEFAULT_NUM_HYPS = 1000;

    @Option(shortName="v", longName="verbosity", usage="Verbosity level", defaultValue="0")
    public int verbosity;
    @Option(shortName="o", longName="metrics", usage="Space-delimited list of metrics to use. Any of: bleu, meteor, ter, length", defaultValue="bleu meteor ter", arrayDelim=" ")
    public String[] metricNames;
    @Option(shortName="N", longName="nbest", usage="File containing tokenized, fullform hypotheses, one per line")
    public String nbestList;
    @Option(shortName="R", longName="refs", usage="Space-delimited list of files containing tokenized, fullform references, one per line", arrayDelim=" ")
    public String[] refFiles;
    @Option(shortName="r", longName="rankDir", usage="Rank hypotheses of median optimization run of each system with regard to improvement/decline over median baseline system and output to the specified directory for analysis", required=false)
    private String rankDir;
    @Option(shortName="t", longName="threads", usage="Number of threads to use. This will be reset to 1 thread if you choose to use any thread-unsafe metrics such as TER (Zero means use all available cores)", defaultValue="0")
    private int threads;

    /**
     * Returns the metric classes listed as dynamically configurable for this module.
     *
     * @return the BLEU, METEOR and TER classes
     */
    @Override
    public Iterable<Class<?>> getDynamicConfigurables() {
        // Explicit type witness so this also compiles without Java 8 target typing.
        return ImmutableList.<Class<?>>of(BLEU.class, METEOR.class, TER.class);
    }

    /**
     * Reads the n-best list, scores each sentence's hypotheses on a worker
     * pool, writes the annotated entries, and prints aggregate scores.
     *
     * @param opts configurator used to load the requested metrics
     * @throws ConfigurationException declared by the contract; propagated from metric loading
     * @throws IOException if the n-best list or a rank file cannot be opened or read
     * @throws InterruptedException if interrupted while queueing work or waiting for completion
     */
    @Override
    public void run(Configurator opts) throws ConfigurationException, IOException, InterruptedException {
        final List<Metric<?>> metrics = MultEval.loadMetrics(this.metricNames, opts);
        final String[] submetricNames = NbestModule.getSubmetricNames(metrics);
        this.threads = MultEval.initThreads(metrics, this.threads);

        // The sentence id on the last line determines how many sentences there are.
        String lastLine = FileUtils.getLastLine(this.nbestList);
        NbestEntry lastEntry = NbestEntry.parse(lastLine, -1, 0);
        int numHyps = lastEntry.sentId + 1;
        List<List<String>> allRefs = HypothesisManager.loadRefs(this.refFiles, numHyps);
        System.err.println("Found " + numHyps + " hypotheses with " + allRefs.get(0).size() + " references");

        final SynchronizedPrintStream out = new SynchronizedPrintStream(System.out);
        final SynchronizedPrintStream[] metricRankFiles = this.rankDir == null ? null : new SynchronizedPrintStream[metrics.size()];
        if (this.rankDir != null) {
            new File(this.rankDir).mkdirs();
            for (int iMetric = 0; iMetric < metrics.size(); ++iMetric) {
                // NOTE(review): assumes metricNames is index-aligned with the loaded metrics -- confirm.
                metricRankFiles[iMetric] = new SynchronizedPrintStream(new PrintStream(new File(this.rankDir, this.metricNames[iMetric] + ".sorted"), "UTF-8"));
            }
        }

        // One list of per-sentence sufficient statistics for each metric.
        final List<List<SuffStats<?>>> oracleStatsByMetric = new ArrayList<List<SuffStats<?>>>(metrics.size());
        final List<List<SuffStats<?>>> woracleStatsByMetric = new ArrayList<List<SuffStats<?>>>(metrics.size());
        final List<List<SuffStats<?>>> topbestStatsByMetric = new ArrayList<List<SuffStats<?>>>(metrics.size());
        for (int i = 0; i < metrics.size(); ++i) {
            oracleStatsByMetric.add(new ArrayList<SuffStats<?>>());
            woracleStatsByMetric.add(new ArrayList<SuffStats<?>>());
            topbestStatsByMetric.add(new ArrayList<SuffStats<?>>());
        }

        // Keep a handle on the underlying reader so it can be closed once processing is done.
        BufferedReader nbestReader = new BufferedReader(new InputStreamReader((InputStream)new FileInputStream(this.nbestList), Charsets.UTF_8));
        try {
            SynchronizedBufferedReader in = new SynchronizedBufferedReader(nbestReader);
            MetricWorkerPool<NbestTask, List<Metric<?>>> work = new MetricWorkerPool<NbestTask, List<Metric<?>>>(this.threads, new Supplier<List<Metric<?>>>(){

                @Override
                public List<Metric<?>> get() {
                    // Build a separate list of metric clones for the requesting worker.
                    List<Metric<?>> copy = new ArrayList<Metric<?>>(metrics.size());
                    for (Metric<?> metric : metrics) {
                        copy.add(metric.threadClone());
                    }
                    return copy;
                }
            }){

                @Override
                public void doWork(List<Metric<?>> localMetrics, NbestTask t) {
                    try {
                        NbestModule.this.processHyp(localMetrics, submetricNames, t.myHyps, t.sentRefs, out, metricRankFiles, oracleStatsByMetric, woracleStatsByMetric, topbestStatsByMetric);
                    }
                    catch (InterruptedException e) {
                        e.printStackTrace();
                        System.exit(1);
                    }
                }
            };
            work.start();

            List<NbestEntry> hyps = new ArrayList<NbestEntry>(DEFAULT_NUM_HYPS);
            int curHyp = 0;
            int iLine = 0;
            String line;
            while ((line = in.readLine()) != null) {
                ++iLine;
                NbestEntry entry = NbestEntry.parse(line, hyps.size(), metrics.size());
                if (curHyp != entry.sentId) {
                    // Sentence id changed: hand the finished group to the pool.
                    work.addTask(new NbestTask(hyps, allRefs.get(curHyp)));
                    // Progress is only checked here, i.e. on lines that start a new sentence.
                    if (iLine % 10000 == 0) {
                        System.err.println("Processed " + iLine + " lines (" + curHyp + " hypotheses) so far...");
                    }
                    hyps = new ArrayList<NbestEntry>(hyps.size());
                    // The entry was parsed with the old group's size as its rank; it is first in the new group.
                    entry.origRank = 0;
                    curHyp = entry.sentId;
                }
                hyps.add(entry);
            }
            // Submit the trailing group, which no sentence-id change has flushed.
            work.addTask(new NbestTask(hyps, allRefs.get(curHyp)));
            work.waitForCompletion();
        }
        finally {
            nbestReader.close();
        }

        out.close();
        if (this.rankDir != null) {
            System.err.println("Wrote n-best list ranked by metrics to: " + this.rankDir);
            for (int iMetric = 0; iMetric < metrics.size(); ++iMetric) {
                metricRankFiles[iMetric].close();
            }
        }

        for (int i = 0; i < metrics.size(); ++i) {
            Metric<?> metric = metrics.get(i);

            SuffStats<?> topbestStats = SuffStatUtils.sumStats(topbestStatsByMetric.get(i));
            double topbestScore = metric.scoreStats(topbestStats);
            String topbestSub = metric.scoreSubmetricsString(topbestStats);
            System.err.println(String.format("%s topbest score: %.2f (%s)", metric.toString(), topbestScore, topbestSub));

            SuffStats<?> oracleStats = SuffStatUtils.sumStats(oracleStatsByMetric.get(i));
            double oracleScore = metric.scoreStats(oracleStats);
            String oracleSub = metric.scoreSubmetricsString(oracleStats);
            System.err.println(String.format("%s oracle score: %.2f (%s)", metric.toString(), oracleScore, oracleSub));

            SuffStats<?> woracleStats = SuffStatUtils.sumStats(woracleStatsByMetric.get(i));
            double woracleScore = metric.scoreStats(woracleStats);
            String woracleSub = metric.scoreSubmetricsString(woracleStats);
            System.err.println(String.format("%s worst-oracle score: %.2f (%s)", metric.toString(), woracleScore, woracleSub));
        }
    }

    /**
     * Concatenates the submetric names of all metrics, in metric order.
     *
     * @param metrics metrics whose submetric names are collected
     * @return a new array holding every submetric name of every metric
     */
    public static String[] getSubmetricNames(List<Metric<?>> metrics) {
        int numSubmetrics = 0;
        for (Metric<?> metric : metrics) {
            numSubmetrics += metric.getSubmetricNames().length;
        }
        String[] submetricNames = new String[numSubmetrics];
        int next = 0;
        for (Metric<?> metric : metrics) {
            for (String name : metric.getSubmetricNames()) {
                submetricNames[next++] = name;
            }
        }
        return submetricNames;
    }

    /**
     * Scores all hypotheses of one sentence and emits the results.
     *
     * <p>Every entry receives its per-metric statistics, scores and submetric
     * scores. For each metric, the statistics of the first hypothesis in input
     * order (topbest), of the best-scoring hypothesis (oracle) and of the
     * worst-scoring hypothesis (worst-oracle) are appended to the shared
     * accumulators, and every entry's rank under that metric is recorded.
     * The entries are then written to {@code out} in original order and, when
     * {@code metricRankFiles} is non-null, to each rank file in metric order.
     *
     * <p>{@code hyps} is sorted in place and must be non-empty; an empty list
     * fails on the first element access.
     *
     * @param metricCopies metric instances used for scoring in this call
     * @param submetricNames names of all submetrics, passed through to entry formatting
     * @param hyps hypotheses of a single sentence, in original n-best order
     * @param sentRefs references for that sentence
     * @param out stream receiving the entries in original order
     * @param metricRankFiles per-metric streams for sorted output, or {@code null}
     * @param oracleStatsByMetric shared accumulator of best-hypothesis statistics
     * @param woracleStatsByMetric shared accumulator of worst-hypothesis statistics
     * @param topbestStatsByMetric shared accumulator of first-hypothesis statistics
     * @throws InterruptedException declared by the signature; presumably raised by the synchronized output streams
     */
    private void processHyp(List<Metric<?>> metricCopies, String[] submetricNames, List<NbestEntry> hyps, List<String> sentRefs, SynchronizedPrintStream out, SynchronizedPrintStream[] metricRankFiles, List<List<SuffStats<?>>> oracleStatsByMetric, List<List<SuffStats<?>>> woracleStatsByMetric, List<List<SuffStats<?>>> topbestStatsByMetric) throws InterruptedException {
        // Pass 1: score every hypothesis with every metric.
        for (int iRank = 0; iRank < hyps.size(); ++iRank) {
            List<SuffStats<?>> metricStats = new ArrayList<SuffStats<?>>(metricCopies.size());
            double[] metricScores = new double[metricCopies.size()];
            double[] submetricScores = new double[submetricNames.length];
            NbestEntry entry = hyps.get(iRank);
            entry.hyp = StringUtils.normalizeWhitespace(entry.hyp);
            int iSubmetric = 0;
            for (int iMetric = 0; iMetric < metricCopies.size(); ++iMetric) {
                Metric<?> metric = metricCopies.get(iMetric);
                SuffStats<?> stats = metric.stats(entry.hyp, sentRefs);
                metricStats.add(stats);
                metricScores[iMetric] = metric.scoreStats(stats);
                for (double sub : metric.scoreSubmetricsStats(stats)) {
                    submetricScores[iSubmetric++] = sub;
                }
            }
            entry.metricStats = metricStats;
            entry.metricScores = metricScores;
            entry.submetricScores = submetricScores;
        }

        // Pass 2: per metric, record topbest / oracle / worst-oracle statistics and ranks.
        for (int iMetric = 0; iMetric < metricCopies.size(); ++iMetric) {
            // Still in input order here, so index 0 is the first hypothesis as read.
            synchronized (topbestStatsByMetric) {
                topbestStatsByMetric.get(iMetric).add(hyps.get(0).metricStats.get(iMetric));
            }
            this.sortByMetricScore(hyps, iMetric, metricCopies.get(iMetric).isBiggerBetter());
            synchronized (oracleStatsByMetric) {
                oracleStatsByMetric.get(iMetric).add(hyps.get(0).metricStats.get(iMetric));
            }
            synchronized (woracleStatsByMetric) {
                woracleStatsByMetric.get(iMetric).add(hyps.get(hyps.size() - 1).metricStats.get(iMetric));
            }
            for (int iRank = 0; iRank < hyps.size(); ++iRank) {
                hyps.get(iRank).metricRank[iMetric] = iRank;
            }
        }

        // Restore the original order before writing the main output.
        Collections.sort(hyps, new Comparator<NbestEntry>(){

            @Override
            public int compare(NbestEntry a, NbestEntry b) {
                int ra = a.origRank;
                int rb = b.origRank;
                // Full three-way comparison: equal ranks must compare as 0 to honour the Comparator contract.
                return ra < rb ? -1 : (ra == rb ? 0 : 1);
            }
        });
        int sentId = hyps.get(0).sentId;
        for (NbestEntry entry : hyps) {
            out.println(sentId, entry.toString(this.metricNames, submetricNames));
        }
        out.finishUnit(sentId);

        if (metricRankFiles != null) {
            for (int iMetric = 0; iMetric < metricCopies.size(); ++iMetric) {
                this.sortByMetricScore(hyps, iMetric, metricCopies.get(iMetric).isBiggerBetter());
                for (NbestEntry entry : hyps) {
                    metricRankFiles[iMetric].println(sentId, entry.toString(this.metricNames, submetricNames));
                }
                metricRankFiles[iMetric].finishUnit(sentId);
            }
        }
    }

    /**
     * Sorts hypotheses in place from best to worst score under one metric.
     *
     * <p>Entries whose score is {@code NaN} are placed after all other entries
     * regardless of direction, and compare equal to each other. Without this,
     * {@code NaN} would compare as "greater" in both directions, which breaks
     * the comparator contract.
     *
     * @param hyps hypotheses to sort; each must already carry {@code metricScores}
     * @param i index of the metric within {@code metricScores}
     * @param isBiggerBetter {@code true} if higher scores are better for this metric
     */
    private void sortByMetricScore(List<NbestEntry> hyps, final int i, final boolean isBiggerBetter) {
        Collections.sort(hyps, new Comparator<NbestEntry>(){

            @Override
            public int compare(NbestEntry a, NbestEntry b) {
                double da = a.metricScores[i];
                double db = b.metricScores[i];
                boolean aNaN = Double.isNaN(da);
                boolean bNaN = Double.isNaN(db);
                if (aNaN || bNaN) {
                    if (aNaN && bNaN) {
                        return 0;
                    }
                    return aNaN ? 1 : -1;
                }
                if (da == db) {
                    return 0;
                }
                boolean aIsBetter = isBiggerBetter ? da > db : da < db;
                return aIsBetter ? -1 : 1;
            }
        });
    }

    /**
     * Unit of work for the worker pool: the hypotheses of one sentence
     * together with that sentence's references.
     */
    public static class NbestTask {
        public final List<NbestEntry> myHyps;
        public final List<String> sentRefs;

        /**
         * @param myHyps hypotheses of one sentence
         * @param sentRefs references for the same sentence
         */
        public NbestTask(List<NbestEntry> myHyps, List<String> sentRefs) {
            this.myHyps = myHyps;
            this.sentRefs = sentRefs;
        }
    }
}

