/*
 * Decompiled with CFR 0.152.
 */
package multeval.significance;

import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicIntegerArray;
import multeval.metrics.Metric;
import multeval.metrics.SuffStats;
import multeval.parallel.MetricWorkerPool;
import multeval.util.SuffStatUtils;

/**
 * Approximate randomization significance test between two systems (A and B),
 * stratified by hypothesis: the data points of each system are laid out as
 * {@code numHyps * numOptRuns} rows, and shuffling only permutes rows that
 * belong to the same hypothesis across optimizer runs.
 *
 * <p>The outer lists of {@code suffStatsA}/{@code suffStatsB} are indexed by
 * metric (same order as {@code metrics}); the inner lists are indexed by data
 * point.
 */
public class StratifiedApproximateRandomizationTest {
    private final List<Metric<?>> masterMetrics;
    private final int threads;
    private final List<List<SuffStats<?>>> suffStatsA;
    private final List<List<SuffStats<?>>> suffStatsB;
    private final int totalDataPoints;
    private final int numHyps;
    private final int numOptRuns;
    private final boolean debug;

    /**
     * @param threads    number of worker threads handed to the {@link MetricWorkerPool}
     * @param metrics    metrics to test; must be non-empty
     * @param suffStatsA per-metric, per-data-point sufficient statistics of system A
     * @param suffStatsB per-metric, per-data-point sufficient statistics of system B
     * @param numHyps    number of hypotheses per optimizer run
     * @param numOptRuns number of optimizer runs
     * @param debug      when true, every sampled difference is printed to stderr
     * @throws IllegalArgumentException if any list is empty, the systems disagree on
     *         the number of data points, or the number of data points is not
     *         {@code numHyps * numOptRuns}
     */
    public StratifiedApproximateRandomizationTest(int threads, List<Metric<?>> metrics, List<List<SuffStats<?>>> suffStatsA, List<List<SuffStats<?>>> suffStatsB, int numHyps, int numOptRuns, boolean debug) {
        Preconditions.checkArgument(metrics.size() > 0, "Must have at least one metric.");
        Preconditions.checkArgument(suffStatsA.size() > 0, "Must have at least one data point.");
        // Checked explicitly so that an empty B list is reported as a bad argument
        // instead of surfacing as an IndexOutOfBoundsException from get(0) below.
        Preconditions.checkArgument(suffStatsB.size() > 0, "System B must have at least one list of sufficient statistics.");
        this.threads = threads;
        this.masterMetrics = metrics;
        this.suffStatsA = suffStatsA;
        this.suffStatsB = suffStatsB;
        this.totalDataPoints = suffStatsA.get(0).size();
        this.numHyps = numHyps;
        this.numOptRuns = numOptRuns;
        this.debug = debug;
        Preconditions.checkArgument(suffStatsA.get(0).size() == suffStatsB.get(0).size(), "System A and System B must have the same number of data points.");
        Preconditions.checkArgument(this.totalDataPoints > 0, "Need more than zero data points.");
        Preconditions.checkArgument(this.totalDataPoints == numHyps * numOptRuns, String.format("totalDataPoints (%d) in second list must == numHyps (%d) * numOptRuns (%d)", this.totalDataPoints, numHyps, numOptRuns));
        // getTwoSidedP() indexes both outer lists by metric and every inner list by
        // shuffled row, so validate all of them up front rather than failing later
        // inside a worker thread.
        Preconditions.checkArgument(suffStatsA.size() >= metrics.size() && suffStatsB.size() >= metrics.size(), "Need one list of sufficient statistics per metric (%s metrics), but got %s for System A and %s for System B.", metrics.size(), suffStatsA.size(), suffStatsB.size());
        for (int iMetric = 0; iMetric < metrics.size(); ++iMetric) {
            Preconditions.checkArgument(suffStatsA.get(iMetric).size() == this.totalDataPoints && suffStatsB.get(iMetric).size() == this.totalDataPoints, "Metric %s: System A has %s and System B has %s data points, but expected %s for both.", iMetric, suffStatsA.get(iMetric).size(), suffStatsB.get(iMetric).size(), this.totalDataPoints);
        }
    }

    /**
     * Estimates a two-sided p-value per metric by running {@code numShuffles}
     * random shufflings on the worker pool.
     *
     * <p>For each metric the observed statistic is {@code |score(A) - score(B)|}.
     * A shuffle counts against the null hypothesis only when its absolute score
     * difference is strictly greater than the observed one. The returned value is
     * {@code (count + 1) / (numShuffles + 1)}.
     *
     * @param numShuffles number of random shufflings to sample; must be non-negative
     * @return one p-value per metric, in the order of the metrics passed to the constructor
     * @throws IllegalArgumentException if {@code numShuffles} is negative
     * @throws InterruptedException propagated from the worker pool
     */
    public double[] getTwoSidedP(int numShuffles) throws InterruptedException {
        Preconditions.checkArgument(numShuffles >= 0, "numShuffles must be non-negative, but was %s", numShuffles);
        final int numMetrics = this.masterMetrics.size();

        // Observed (unshuffled) absolute score difference per metric.
        final double[] overallDiffs = new double[numMetrics];
        for (int iMetric = 0; iMetric < numMetrics; ++iMetric) {
            Metric<?> metric = this.masterMetrics.get(iMetric);
            double scoreA = metric.scoreStats(SuffStatUtils.sumStats(this.suffStatsA.get(iMetric)));
            double scoreB = metric.scoreStats(SuffStatUtils.sumStats(this.suffStatsB.get(iMetric)));
            overallDiffs[iMetric] = Math.abs(scoreA - scoreB);
        }

        // Shared by all tasks. The pool is created with `threads` workers, so
        // doWork() presumably runs concurrently (MetricWorkerPool is not visible
        // here); atomic increments keep the counts exact in that case.
        final AtomicIntegerArray diffsByChance = new AtomicIntegerArray(numMetrics);
        MetricWorkerPool<Integer, Shuffling> workers = new MetricWorkerPool<Integer, Shuffling>(this.threads, new Supplier<Shuffling>() {

            @Override
            public Shuffling get() {
                return new Shuffling(StratifiedApproximateRandomizationTest.this.numHyps, StratifiedApproximateRandomizationTest.this.numOptRuns);
            }
        }) {

            @Override
            public void doWork(Shuffling shuffling, Integer i) {
                shuffling.shuffle();
                for (int iMetric = 0; iMetric < StratifiedApproximateRandomizationTest.this.masterMetrics.size(); ++iMetric) {
                    Metric<?> metric = StratifiedApproximateRandomizationTest.this.masterMetrics.get(iMetric);
                    // X is the shuffled view, Y the inverted view of the same shuffling.
                    double scoreX = metric.scoreStats(StratifiedApproximateRandomizationTest.sumStats(shuffling, iMetric, StratifiedApproximateRandomizationTest.this.suffStatsA, StratifiedApproximateRandomizationTest.this.suffStatsB, false));
                    double scoreY = metric.scoreStats(StratifiedApproximateRandomizationTest.sumStats(shuffling, iMetric, StratifiedApproximateRandomizationTest.this.suffStatsA, StratifiedApproximateRandomizationTest.this.suffStatsB, true));
                    double sampleDiff = Math.abs(scoreX - scoreY);
                    // Strict comparison: a tie with the observed difference is not counted.
                    if (sampleDiff > overallDiffs[iMetric]) {
                        diffsByChance.incrementAndGet(iMetric);
                    }
                    if (StratifiedApproximateRandomizationTest.this.debug) {
                        System.err.println("DIFF metric " + iMetric + ": " + scoreX + " - " + scoreY + " --> " + sampleDiff + " >? " + overallDiffs[iMetric] + "; diffsByChance: " + diffsByChance.get(iMetric));
                    }
                }
            }
        };
        workers.start();
        for (int i = 0; i < numShuffles; ++i) {
            workers.addTask(i);
        }
        workers.waitForCompletion();

        double[] p = new double[numMetrics];
        for (int iMetric = 0; iMetric < numMetrics; ++iMetric) {
            p[iMetric] = ((double)diffsByChance.get(iMetric) + 1.0) / ((double)numShuffles + 1.0);
        }
        return p;
    }

    /**
     * Sums, for one metric, the sufficient statistics selected by
     * {@code shuffling} for every row.
     *
     * @param invert passed through to {@link Shuffling#at}; selects the inverted view
     * @return a fresh statistics object (created from the first row of system A)
     *         holding the sum over all rows
     */
    private static SuffStats<?> sumStats(Shuffling shuffling, int iMetric, List<List<SuffStats<?>>> suffStatsA, List<List<SuffStats<?>>> suffStatsB, boolean invert) {
        List<SuffStats<?>> metricStatsA = suffStatsA.get(iMetric);
        List<SuffStats<?>> metricStatsB = suffStatsB.get(iMetric);
        SuffStats<?> summedStats = metricStatsA.get(0).create();
        for (int iRow = 0; iRow < metricStatsA.size(); ++iRow) {
            SuffStats<?> row = shuffling.at(iRow, metricStatsA, metricStatsB, invert);
            summedStats.add(row);
        }
        return summedStats;
    }

    /**
     * One random assignment of rows to the two pseudo-systems: a per-row
     * "take from B instead of A" flag plus a permutation of row indices that only
     * moves a row between optimizer runs of the same hypothesis.
     *
     * <p>Instances are mutable scratch space reused across calls to
     * {@link #shuffle()} and are not safe for concurrent use.
     */
    static class Shuffling {
        private final boolean[] swap;
        private final int[] optRunPermutation;
        private final int[] optRunPermutationInv;
        private final int optRuns;
        private final int hyps;
        // One generator per instance: avoids contending on a shared static Random
        // when several Shuffling instances are shuffled from different threads.
        private final Random rnd = new Random();

        /**
         * @param hyps    number of hypotheses per optimizer run
         * @param optRuns number of optimizer runs
         */
        public Shuffling(int hyps, int optRuns) {
            this.swap = new boolean[hyps * optRuns];
            this.optRunPermutation = new int[hyps * optRuns];
            this.optRunPermutationInv = new int[hyps * optRuns];
            this.hyps = hyps;
            this.optRuns = optRuns;
        }

        /**
         * Returns the element this shuffling assigns to row {@code iRow}.
         *
         * <p>With {@code invert == false} the element comes from {@code b} if the
         * row's swap flag is set (otherwise {@code a}), at the permuted index. With
         * {@code invert == true} the swap flag is negated and the inverse
         * permutation is used.
         *
         * <p>NOTE(review): the two views together do not in general cover every
         * element of {@code a} and {@code b} exactly once; confirm that this is the
         * intended pairing before changing it.
         */
        public <T> T at(int iRow, List<T> a, List<T> b, boolean invert) {
            boolean shouldSwap;
            int idx;
            if (invert) {
                idx = this.optRunPermutationInv[iRow];
                shouldSwap = !this.swap[iRow];
            } else {
                idx = this.optRunPermutation[iRow];
                shouldSwap = this.swap[iRow];
            }
            List<T> list = shouldSwap ? b : a;
            return list.get(idx);
        }

        /** Draws a new random shuffling, replacing the previous one. */
        public void shuffle() {
            for (int i = 0; i < this.swap.length; ++i) {
                this.swap[i] = this.rnd.nextBoolean();
            }
            // Start from the identity permutation on every call.
            for (int i = 0; i < this.optRunPermutation.length; ++i) {
                this.optRunPermutation[i] = i;
            }
            // Fisher-Yates shuffle within each hypothesis: row (iHyp, run) lives at
            // index iHyp + hyps * run, so only entries of the same iHyp are exchanged.
            for (int iHyp = 0; iHyp < this.hyps; ++iHyp) {
                for (int iRun = this.optRuns; iRun > 1; --iRun) {
                    int swapRun1 = iRun - 1;
                    int swapRun2 = this.rnd.nextInt(iRun);
                    this.swap(this.optRunPermutation, iHyp + this.hyps * swapRun1, iHyp + this.hyps * swapRun2);
                }
            }
            // Build the inverse mapping: inv[perm[i]] = i.
            for (int origIdx = 0; origIdx < this.optRunPermutation.length; ++origIdx) {
                this.optRunPermutationInv[this.optRunPermutation[origIdx]] = origIdx;
            }
        }

        /** Exchanges {@code arr[i]} and {@code arr[j]}. */
        private void swap(int[] arr, int i, int j) {
            int tmp = arr[i];
            arr[i] = arr[j];
            arr[j] = tmp;
        }
    }
}

