package edu.cmu.minorthird.classify;

import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.experiments.Evaluation;
import edu.cmu.minorthird.classify.experiments.Tester;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:edu/cmu/minorthird/classify/CascadingBinaryLearner.class */
/**
 * A one-vs-all learner that trains one binary learner per class and then
 * orders those learners by their cross-validated kappa, best first.
 *
 * <p>Each example given to {@link #addExample} is relabelled as positive for
 * its own class and negative for every other class. The relabelled example is
 * passed to the corresponding inner learner and also recorded in a per-class
 * dataset. {@link #completeTraining} trains the inner learners, evaluates each
 * one by cross-validation on its recorded dataset, and reorders the learners
 * (and the matching names in {@link #sortedClassNames}) by descending kappa.
 * {@link #getClassifier} wraps the resulting classifiers, in that order, in a
 * {@link OneVsAllClassifier}.
 */
public class CascadingBinaryLearner extends OneVsAllLearner {
    /** Number of cross-validation folds used when ranking the per-class learners. */
    private static final int RANKING_FOLDS = 9;

    /**
     * Class names in the same order as the inner learners once
     * {@link #completeTraining} has run; {@code null} before that.
     */
    public String[] sortedClassNames;

    /** Per-class binary datasets, index i matching inner learner i in schema order. */
    private List<Dataset> data = null;

    /** Cross-validation results, index i matching inner learner i before sorting. */
    private List<Evaluation> eval = null;

    /** Creates a learner via the superclass's no-argument constructor. */
    public CascadingBinaryLearner() {
    }

    /**
     * Creates a learner whose superclass is configured from the given factory.
     *
     * @param classifierLearnerFactory factory handed to the superclass constructor
     */
    public CascadingBinaryLearner(ClassifierLearnerFactory classifierLearnerFactory) {
        super(classifierLearnerFactory);
    }

    /**
     * Creates a learner whose superclass is configured from the given string.
     *
     * @param str value handed to the superclass constructor
     */
    public CascadingBinaryLearner(String str) {
        super(str);
    }

    /**
     * Creates a learner that copies {@code batchClassifierLearner} for every class.
     *
     * @param batchClassifierLearner prototype copied once per class in {@link #setSchema}
     */
    public CascadingBinaryLearner(BatchClassifierLearner batchClassifierLearner) {
        this.learner = batchClassifierLearner;
        this.learnerName = batchClassifierLearner.toString();
        this.learnerFactory = new ClassifierLearnerFactory(this.learnerName);
    }

    /**
     * Resets the learner for a new schema.
     *
     * <p>This creates one fresh binary inner learner (a copy of the prototype
     * learner) and one empty dataset per class, in schema order.
     *
     * @param exampleSchema schema whose classes define the inner learners
     */
    @Override
    public void setSchema(ExampleSchema exampleSchema) {
        this.schema = exampleSchema;
        this.innerLearner = new ArrayList();
        this.data = new ArrayList<Dataset>();
        // Drop any ranking from a previous run: the new inner learners are in
        // schema order until completeTraining sorts them again.
        this.eval = null;
        this.sortedClassNames = null;
        int numClasses = exampleSchema.getNumberOfClasses();
        for (int i = 0; i < numClasses; i++) {
            ClassifierLearner binaryLearner =
                    (ClassifierLearner) ((BatchClassifierLearner) this.learner).copy();
            binaryLearner.setSchema(ExampleSchema.BINARY_EXAMPLE_SCHEMA);
            this.innerLearner.add(binaryLearner);
            this.data.add(new BasicDataset());
        }
    }

    /**
     * Cross-validates every inner learner on its own binary dataset and stores
     * the results in {@link #eval}, index-aligned with the inner learners.
     */
    private void createRankings() {
        CrossValSplitter splitter = new CrossValSplitter(RANKING_FOLDS);
        int numLearners = this.innerLearner.size();
        this.eval = new ArrayList<Evaluation>(numLearners);
        for (int i = 0; i < numLearners; i++) {
            // NOTE(review): the already-trained learner object itself is passed to
            // Tester.evaluate; if cross-validation retrains it in place, the learner
            // later used by getClassifier is affected — confirm against Tester.
            this.eval.add(Tester.evaluate(
                    (ClassifierLearner) this.innerLearner.get(i), this.data.get(i), splitter));
        }
    }

    /**
     * Reorders the inner learners by descending kappa and records the matching
     * class names in {@link #sortedClassNames}.
     *
     * <p>Learners, class names and evaluations are kept in three parallel lists
     * and removed from together. This keeps every evaluation attached to its own
     * learner as the pool shrinks.
     */
    private void sortLearners() {
        String[] classNames = this.schema.validClassNames();
        int numLearners = this.innerLearner.size();
        List<ClassifierLearner> pendingLearners = new ArrayList<ClassifierLearner>(numLearners);
        List<String> pendingNames = new ArrayList<String>(numLearners);
        List<Evaluation> pendingEvals = new ArrayList<Evaluation>(this.eval);
        for (int i = 0; i < numLearners; i++) {
            pendingLearners.add((ClassifierLearner) this.innerLearner.get(i));
            pendingNames.add(classNames[i]);
        }

        this.sortedClassNames = new String[numLearners];
        this.innerLearner.clear();
        for (int rank = 0; !pendingLearners.isEmpty(); rank++) {
            int best = indexOfHighestKappa(pendingEvals);
            this.innerLearner.add(pendingLearners.remove(best));
            this.sortedClassNames[rank] = pendingNames.remove(best);
            pendingEvals.remove(best);
        }
    }

    /**
     * Returns the index of the evaluation with the highest kappa.
     *
     * <p>On ties the later index wins. Evaluations whose kappa throws or is NaN
     * are skipped. If none qualify, 0 is returned so the caller still makes
     * progress.
     *
     * @param evaluations non-empty list of evaluations to compare
     * @return index of the best-scoring evaluation, or 0 if none could be scored
     */
    private static int indexOfHighestKappa(List<Evaluation> evaluations) {
        int bestIndex = -1;
        double bestKappa = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < evaluations.size(); i++) {
            double kappa;
            try {
                kappa = evaluations.get(i).kappa();
            } catch (Exception e) {
                // Unscorable learner: report it and rank it after every scorable one.
                e.printStackTrace();
                continue;
            }
            // A NaN kappa fails this comparison, so it is never selected here.
            if (kappa >= bestKappa) {
                bestKappa = kappa;
                bestIndex = i;
            }
        }
        return bestIndex < 0 ? 0 : bestIndex;
    }

    /**
     * Feeds a relabelled copy of {@code example} to every inner learner.
     *
     * <p>The copy is labelled positive for the example's own class and negative
     * for all others. Each copy is also recorded in that class's dataset for
     * later cross-validation.
     *
     * @param example training example labelled with one of the schema's classes
     */
    @Override
    public void addExample(Example example) {
        int positiveIndex = this.schema.getClassIndex(example.getLabel().bestClassName());
        for (int i = 0; i < this.innerLearner.size(); i++) {
            ClassLabel label = (i == positiveIndex)
                    ? ClassLabel.positiveLabel(1.0d)
                    : ClassLabel.negativeLabel(-1.0d);
            Example binaryExample = new Example(example.asInstance(), label);
            ((ClassifierLearner) this.innerLearner.get(i)).addExample(binaryExample);
            this.data.get(i).add(binaryExample);
        }
    }

    /**
     * Completes training of every inner learner, then ranks and reorders them
     * by cross-validated kappa.
     */
    @Override
    public void completeTraining() {
        for (int i = 0; i < this.innerLearner.size(); i++) {
            ((ClassifierLearner) this.innerLearner.get(i)).completeTraining();
        }
        createRankings();
        sortLearners();
    }

    /**
     * Builds a {@link OneVsAllClassifier} from the inner learners' classifiers.
     *
     * @return classifier whose class names are aligned with the current learner order
     */
    @Override
    public Classifier getClassifier() {
        int numLearners = this.innerLearner.size();
        Classifier[] classifiers = new Classifier[numLearners];
        for (int i = 0; i < numLearners; i++) {
            classifiers[i] = ((ClassifierLearner) this.innerLearner.get(i)).getClassifier();
        }
        // Until completeTraining sorts them, the inner learners are in schema
        // order, so the schema's own class names line up with them.
        String[] classNames = (this.sortedClassNames != null)
                ? this.sortedClassNames
                : this.schema.validClassNames();
        return new OneVsAllClassifier(classNames, classifiers);
    }
}
