package edu.cmu.minorthird.classify.ranking;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ViewerFrame;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.jfree.chart.axis.ValueAxis;

/* loaded from: input_file:edu/cmu/minorthird/classify/ranking/RankingBoosted.class */
public class RankingBoosted extends BatchRankingLearner {
    // Number of boosting rounds run by batchTrain(Dataset).
    private int numEpochs;
    // Number of examples taken from each ranking; also the column count of the example and margin matrices.
    private int exampleSize;
    // Feature -> Set of Index (i, j): the feature is on the first example of ranking i but not on example j.
    private Map A_pos;
    // Feature -> Set of Index (i, j): the feature is on example j of ranking i but not on the first example.
    private Map A_neg;
    // Every binary feature encountered while populate_A built A_pos and A_neg.
    private Set features;
    // Multiplied by the total exponential loss and added to both sums in the per-round weight update.
    private double SMOOTH_PARAM;
    // margins[i][j]: current margin for example j of ranking i; set by initializeMargins, adjusted by updateMargins.
    private double[][] margins;
    // The "walkerScore" feature; the log of its per-example weight seeds the initial margins.
    private Feature score;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/cmu/minorthird/classify/ranking/RankingBoosted$Index.class */
    public class Index {
        int i;
        int j;
        private final RankingBoosted this$0;

        public Index(RankingBoosted rankingBoosted, int i, int i2) {
            this.this$0 = rankingBoosted;
            this.i = i;
            this.j = i2;
        }
    }

    /**
     * Creates a learner with the default settings: 500 boosting epochs and
     * 20 examples per ranking.
     */
    public RankingBoosted() {
        // Plain literal on purpose: the epoch count is unrelated to chart axis
        // ticks, so it should not be read from JFreeChart's
        // ValueAxis.MAXIMUM_TICK_COUNT (which merely happens to equal 500).
        this(500, 20);
    }

    /**
     * Creates a learner with explicit settings.
     *
     * @param i  number of boosting epochs to run during training
     * @param i2 number of examples taken from each ranking
     */
    public RankingBoosted(int i, int i2) {
        this.numEpochs = i;
        // Assigned once from the argument; no separate default is stored first
        // because it would be overwritten immediately.
        this.exampleSize = i2;
        this.A_pos = new HashMap();
        this.A_neg = new HashMap();
        this.features = new HashSet();
        this.SMOOTH_PARAM = 0.005d;
        this.score = new Feature("walkerScore");
    }

    /**
     * Trains a hyperplane on the given dataset.
     *
     * <p>The dataset is split into rankings, each ranking is reordered by
     * {@code orderExamplesList}, and its first {@code exampleSize} examples
     * fill one row of an example matrix. The index sets are then built, the
     * initial weight of the score feature is chosen by {@code best_w0}, the
     * margins are initialised, and {@code numEpochs} boosting rounds are run.
     * Finally a {@code ViewerFrame} titled "hyperplane" is constructed for the
     * result.
     *
     * <p>A ranking with fewer than {@code exampleSize} examples makes the
     * {@code List.get} call below throw.
     *
     * @param dataset the training data
     * @return the trained hyperplane
     */
    @Override // edu.cmu.minorthird.classify.BatchClassifierLearner
    public Classifier batchTrain(Dataset dataset) {
        Map rankings = splitIntoRankings(dataset);
        Example[][] grid = new Example[rankings.size()][this.exampleSize];
        int row = 0;
        for (Object key : rankings.keySet()) {
            List ordered = orderExamplesList((List) rankings.get((String) key));
            for (int col = 0; col < this.exampleSize; col++) {
                grid[row][col] = (Example) ordered.get(col);
            }
            row++;
        }

        Hyperplane hyperplane = populate_A(grid, new Hyperplane());
        hyperplane.increment(this.score, best_w0(grid));
        this.margins = initializeMargins(grid, hyperplane);

        ProgressCounter epochProgress = new ProgressCounter("boosted perceptron training", "epoch", this.numEpochs);
        int epoch = 0;
        while (epoch < this.numEpochs) {
            hyperplane = batchTrain(hyperplane);
            epochProgress.progress();
            epoch++;
        }
        epochProgress.finished();

        new ViewerFrame("hyperplane", hyperplane.toGUI());
        return hyperplane;
    }

    /**
     * Fills {@code A_pos}, {@code A_neg} and {@code features} from the example matrix.
     *
     * <p>For every ranking {@code row}, the first example is compared with each
     * later example {@code col}: a binary feature found only on the later
     * example records {@code (row, col)} in {@code A_neg}; a binary feature
     * found only on the first example records {@code (row, col)} in
     * {@code A_pos}. Every feature visited during these comparisons is added
     * to {@code features}.
     *
     * @param exampleArr one row per ranking, {@code exampleSize} examples per row
     * @param hyperplane hyperplane to reset; it is multiplied by 0.0 before being returned
     * @return the same {@code hyperplane} instance
     */
    private Hyperplane populate_A(Example[][] exampleArr, Hyperplane hyperplane) {
        for (int row = 0; row < exampleArr.length; row++) {
            Example head = exampleArr[row][0];
            Set headFeatures = new HashSet();
            for (Feature.Looper it = head.binaryFeatureIterator(); it.hasNext(); ) {
                headFeatures.add(it.next());
            }

            for (int col = 1; col < this.exampleSize; col++) {
                Example other = exampleArr[row][col];
                Set otherFeatures = new HashSet();

                // Features carried by the later example but not by the first one.
                for (Feature.Looper it = other.binaryFeatureIterator(); it.hasNext(); ) {
                    Feature current = (Feature) it.next();
                    if (!headFeatures.contains(current)) {
                        update_A(this.A_neg, current, row, col);
                    }
                    otherFeatures.add(current);
                    this.features.add(current);
                }

                // Features carried by the first example but not by the later one.
                for (Feature.Looper it = head.binaryFeatureIterator(); it.hasNext(); ) {
                    Feature current = (Feature) it.next();
                    if (!otherFeatures.contains(current)) {
                        update_A(this.A_pos, current, row, col);
                    }
                    this.features.add(current);
                }
            }
        }
        hyperplane.multiply(0.0d);
        return hyperplane;
    }

    /**
     * Records the pair {@code (i, i2)} under {@code feature} in the given index map,
     * creating the feature's index set on first use.
     *
     * @param map     feature-to-index-set map to update (either {@code A_pos} or {@code A_neg})
     * @param feature key whose index set receives the new pair
     * @param i       ranking (row) index
     * @param i2      example (column) index within the ranking
     * @return the same {@code map} instance
     */
    private Map update_A(Map map, Feature feature, int i, int i2) {
        // Typed as Set, not HashSet: the value read back from the map is only
        // known to be a Set, so it cannot be assigned to a HashSet variable.
        Set indices = (Set) map.get(feature);
        if (indices == null) {
            indices = new HashSet();
            map.put(feature, indices);
        }
        indices.add(new Index(this, i, i2));
        return map;
    }

    /**
     * Searches for the initial weight of the score feature that minimises
     * {@link #initialExpLoss(double, Example[][])}.
     *
     * <p>Candidates start at 0.001 and grow by 0.001 while they stay below 10.0.
     * On ties the smallest candidate wins. If no candidate yields a loss that
     * compares below the running best (for example every loss is NaN or
     * infinite), 0.001 is returned.
     *
     * @param exampleArr one row per ranking, {@code exampleSize} examples per row
     * @return the candidate weight with the lowest initial exponential loss
     */
    private double best_w0(Example[][] exampleArr) {
        double bestWeight = 0.001d;
        // Start from +infinity rather than a finite sentinel so that large but
        // finite losses still take part in the comparison.
        double lowestLoss = Double.POSITIVE_INFINITY;
        for (double candidate = 0.001d; candidate < 10.0d; candidate += 0.001d) {
            double loss = initialExpLoss(candidate, exampleArr);
            if (loss < lowestLoss) {
                lowestLoss = loss;
                bestWeight = candidate;
            }
        }
        return bestWeight;
    }

    /**
     * Computes the exponential loss obtained when only the score feature
     * carries weight {@code d}.
     *
     * <p>For every example whose label string ends with {@code "NEG 1.0]"}, the
     * term {@code exp(-d * (log(w_first) - log(w_example)))} is added, where
     * {@code w_first} and {@code w_example} are the score-feature weights of the
     * ranking's first example and of the current example.
     *
     * @param d          candidate weight for the score feature
     * @param exampleArr one row per ranking, {@code exampleSize} examples per row
     * @return the summed loss over all matching examples
     */
    public double initialExpLoss(double d, Example[][] exampleArr) {
        double total = 0.0d;
        for (int row = 0; row < exampleArr.length; row++) {
            Example[] ranking = exampleArr[row];
            for (int col = 0; col < this.exampleSize; col++) {
                Example current = ranking[col];
                if (!current.getLabel().toString().endsWith("NEG 1.0]")) {
                    continue;
                }
                double logGap = Math.log(ranking[0].getWeight(this.score))
                        - Math.log(current.getWeight(this.score));
                total += Math.exp((-d) * logGap);
            }
        }
        return total;
    }

    /**
     * Sums {@code exp(-margin)} over the first {@code exampleSize} entries of
     * every row of the given margin matrix.
     *
     * @param dArr margin matrix, one row per ranking
     * @return the total exponential loss
     */
    private double expLoss(double[][] dArr) {
        double total = 0.0d;
        for (int row = 0; row < dArr.length; row++) {
            double[] rowMargins = dArr[row];
            int col = 0;
            while (col < this.exampleSize) {
                total += Math.exp(-rowMargins[col]);
                col++;
            }
        }
        return total;
    }

    /**
     * Builds the initial margin matrix from the score feature alone.
     *
     * <p>Each entry is the hyperplane's weight for the score feature multiplied
     * by {@code log(w_first) - log(w_example)}, where the two weights are the
     * score-feature weights of the ranking's first example and of the example
     * in that column. Every computed margin is also printed to standard output.
     *
     * @param exampleArr one row per ranking, {@code exampleSize} examples per row
     * @param hyperplane hyperplane supplying the score-feature weight
     * @return a new {@code exampleArr.length x exampleSize} margin matrix
     */
    private double[][] initializeMargins(Example[][] exampleArr, Hyperplane hyperplane) {
        double[][] result = new double[exampleArr.length][this.exampleSize];
        for (int row = 0; row < result.length; row++) {
            for (int col = 0; col < this.exampleSize; col++) {
                double scoreWeight = hyperplane.featureScore(this.score);
                double logGap = Math.log(exampleArr[row][0].getWeight(this.score))
                        - Math.log(exampleArr[row][col].getWeight(this.score));
                result[row][col] = scoreWeight * logGap;
                // Single-space separators written literally; no matrix-formatting
                // library constant is needed to print a debug line.
                System.out.println("margins: " + row + " " + col + " " + result[row][col]);
            }
        }
        return result;
    }

    /**
     * Runs one boosting round.
     *
     * <p>For every known feature the exponential losses of its {@code A_pos}
     * and {@code A_neg} cells are summed, and the feature with the largest
     * {@code |sqrt(posLoss) - sqrt(negLoss)|} is selected. Its weight update is
     * {@code 0.5 * log((posLoss + s) / (negLoss + s))} with
     * {@code s = SMOOTH_PARAM * expLoss(margins)}. The margins and the
     * hyperplane are both adjusted by that amount. If no feature has a gap
     * greater than zero, nothing is changed.
     *
     * @param hyperplane hyperplane to update in place
     * @return the same {@code hyperplane} instance
     */
    private Hyperplane batchTrain(Hyperplane hyperplane) {
        Feature bestFeature = null;
        double bestGap = 0.0d;
        double bestPosLoss = 0.0d;
        double bestNegLoss = 0.0d;
        // The feature set is a raw collection, so elements arrive as Object
        // and are cast explicitly.
        for (Object candidate : this.features) {
            Feature feature = (Feature) candidate;
            double posLoss = summedExpLoss(this.A_pos, feature);
            double negLoss = summedExpLoss(this.A_neg, feature);
            double gap = Math.abs(Math.sqrt(posLoss) - Math.sqrt(negLoss));
            if (gap > bestGap) {
                bestGap = gap;
                bestFeature = feature;
                bestPosLoss = posLoss;
                bestNegLoss = negLoss;
            }
        }
        if (bestFeature != null) {
            double smoothing = this.SMOOTH_PARAM * expLoss(this.margins);
            double delta = 0.5d * Math.log((bestPosLoss + smoothing) / (bestNegLoss + smoothing));
            updateMargins(bestFeature, delta);
            hyperplane.increment(bestFeature, delta);
        }
        return hyperplane;
    }

    /** Sums {@code exp(-margin)} over the cells recorded for {@code feature} in {@code indexSets}; 0 if none. */
    private double summedExpLoss(Map indexSets, Feature feature) {
        double total = 0.0d;
        Set indices = (Set) indexSets.get(feature);
        if (indices != null) {
            for (Object entry : indices) {
                Index index = (Index) entry;
                total += Math.exp(-this.margins[index.i][index.j]);
            }
        }
        return total;
    }

    /**
     * Applies a weight change for {@code feature} to the margin matrix: cells
     * recorded for the feature in {@code A_pos} grow by {@code d}, cells
     * recorded in {@code A_neg} shrink by {@code d}.
     *
     * @param feature the feature whose weight changed
     * @param d       the amount by which the weight changed
     */
    private void updateMargins(Feature feature, double d) {
        Set<Index> raised = (Set) this.A_pos.get(feature);
        Set<Index> lowered = (Set) this.A_neg.get(feature);

        if (raised != null) {
            Iterator<Index> it = raised.iterator();
            while (it.hasNext()) {
                Index cell = it.next();
                this.margins[cell.i][cell.j] += d;
            }
        }

        if (lowered != null) {
            Iterator<Index> it = lowered.iterator();
            while (it.hasNext()) {
                Index cell = it.next();
                this.margins[cell.i][cell.j] -= d;
            }
        }
    }

    /**
     * Reorders the examples of one ranking so that those whose label string
     * ends with {@code "POS 1.0]"} come first, followed by all others.
     *
     * <p>Both groups pass through a {@code HashSet}, so equal examples are
     * collapsed and the order inside each group is the set's iteration order.
     *
     * @param list the examples of one ranking
     * @return a new list with the matching examples ahead of the rest
     */
    private List orderExamplesList(List list) {
        Set positives = new HashSet();
        Set others = new HashSet();
        for (Object element : list) {
            Example example = (Example) element;
            boolean isPositive = example.getLabel().toString().endsWith("POS 1.0]");
            if (isPositive) {
                positives.add(example);
            } else {
                others.add(example);
            }
        }
        LinkedList ordered = new LinkedList(positives);
        ordered.addAll(others);
        return ordered;
    }
}
