package edu.cmu.minorthird.classify;

import edu.cmu.minorthird.classify.Instance;
import java.util.ArrayList;

/* loaded from: input_file:edu/cmu/minorthird/classify/OneVsAllLearner.class */
/**
 * A multi-class {@link ClassifierLearner} built from binary learners using the one-vs-all
 * reduction.
 *
 * <p>One binary inner learner is created per class of the schema. For each training example,
 * the inner learner for the example's class gets a positive example and every other inner
 * learner gets a negative one. The trained binary classifiers are combined into a
 * {@link OneVsAllClassifier}.
 *
 * <p>{@link #setSchema(ExampleSchema)} creates the inner learners. It must be called before any
 * other training method except {@link #reset()}; otherwise an {@link IllegalStateException} is
 * thrown.
 */
public class OneVsAllLearner implements ClassifierLearner {
    /** Produces a fresh binary learner for each class when the schema is set. */
    private final ClassifierLearnerFactory learnerFactory;
    /** One binary learner per class, indexed like the schema's classes; null until setSchema. */
    private ClassifierLearner[] innerLearner;
    /** The multi-class schema this learner was configured with. */
    private ExampleSchema schema;

    /**
     * Creates a learner whose inner learners come from a factory configured with the
     * expression {@code "new VotedPerceptron()"}.
     */
    public OneVsAllLearner() {
        this(new ClassifierLearnerFactory("new VotedPerceptron()"));
    }

    /**
     * Creates a learner whose per-class binary learners are obtained from the given factory.
     *
     * @param classifierLearnerFactory source of a new inner learner for each class
     */
    public OneVsAllLearner(ClassifierLearnerFactory classifierLearnerFactory) {
        this.learnerFactory = classifierLearnerFactory;
    }

    /**
     * Records the multi-class schema and creates one new binary inner learner per class.
     *
     * <p>Each inner learner is configured with {@code ExampleSchema.BINARY_EXAMPLE_SCHEMA}. Any
     * previously created inner learners are discarded.
     *
     * @param exampleSchema the multi-class schema to learn
     */
    @Override
    public void setSchema(ExampleSchema exampleSchema) {
        this.schema = exampleSchema;
        this.innerLearner = new ClassifierLearner[exampleSchema.getNumberOfClasses()];
        for (int i = 0; i < this.innerLearner.length; i++) {
            this.innerLearner[i] = this.learnerFactory.getLearner();
            this.innerLearner[i].setSchema(ExampleSchema.BINARY_EXAMPLE_SCHEMA);
        }
    }

    /**
     * Resets every inner learner.
     *
     * <p>This is a no-op if {@link #setSchema(ExampleSchema)} has not been called yet.
     */
    @Override
    public void reset() {
        if (this.innerLearner == null) {
            return;
        }
        for (ClassifierLearner learner : this.innerLearner) {
            learner.reset();
        }
    }

    /**
     * Shares the unlabeled instance pool with every inner learner.
     *
     * <p>The supplied looper is fully drained here. Each inner learner then receives its own new
     * looper over the buffered instances, because a single looper can only be consumed once.
     *
     * @param looper the pool of unlabeled instances
     * @throws IllegalStateException if {@link #setSchema(ExampleSchema)} has not been called
     */
    @Override
    public void setInstancePool(Instance.Looper looper) {
        ClassifierLearner[] learners = innerLearners();
        ArrayList buffered = new ArrayList();
        while (looper.hasNext()) {
            buffered.add(looper.next());
        }
        for (ClassifierLearner learner : learners) {
            learner.setInstancePool(new Instance.Looper(buffered));
        }
    }

    /**
     * Reports whether any inner learner has a pending query.
     *
     * @return true if at least one inner learner reports a pending query
     * @throws IllegalStateException if {@link #setSchema(ExampleSchema)} has not been called
     */
    @Override
    public boolean hasNextQuery() {
        for (ClassifierLearner learner : innerLearners()) {
            if (learner.hasNextQuery()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the next query of the first inner learner, in class order, that has one.
     *
     * <p>Once answered, the query comes back through {@link #addExample(Example)}, which
     * forwards it to every inner learner.
     *
     * @return the query, or null if no inner learner has a pending query
     * @throws IllegalStateException if {@link #setSchema(ExampleSchema)} has not been called
     */
    @Override
    public Instance nextQuery() {
        for (ClassifierLearner learner : innerLearners()) {
            if (learner.hasNextQuery()) {
                return learner.nextQuery();
            }
        }
        return null;
    }

    /**
     * Adds a labeled example to every inner learner, relabeled for that learner.
     *
     * <p>The learner whose index equals the schema index of the example's best class name gets
     * the instance with {@code ClassLabel.positiveLabel(1.0)}. All other learners get
     * {@code ClassLabel.negativeLabel(-1.0)}.
     *
     * <p>NOTE(review): if the schema does not recognize the class name, it presumably returns an
     * out-of-range index (e.g. -1), and the example then becomes a negative for every learner.
     * Confirm against {@code ExampleSchema.getClassIndex}.
     *
     * @param example the labeled example to learn from
     * @throws IllegalStateException if {@link #setSchema(ExampleSchema)} has not been called
     */
    @Override
    public void addExample(Example example) {
        ClassifierLearner[] learners = innerLearners();
        int classIndex = this.schema.getClassIndex(example.getLabel().bestClassName());
        // The instance is the same for every learner; only its binary label differs.
        Instance instance = example.asInstance();
        for (int i = 0; i < learners.length; i++) {
            ClassLabel label = (i == classIndex)
                    ? ClassLabel.positiveLabel(1.0d)
                    : ClassLabel.negativeLabel(-1.0d);
            learners[i].addExample(new Example(instance, label));
        }
    }

    /**
     * Tells every inner learner that training is complete.
     *
     * @throws IllegalStateException if {@link #setSchema(ExampleSchema)} has not been called
     */
    @Override
    public void completeTraining() {
        for (ClassifierLearner learner : innerLearners()) {
            learner.completeTraining();
        }
    }

    /**
     * Builds a {@link OneVsAllClassifier} from the inner learners' classifiers.
     *
     * <p>The classifiers are passed in the schema's class order, alongside the schema's valid
     * class names.
     *
     * @return the combined multi-class classifier
     * @throws IllegalStateException if {@link #setSchema(ExampleSchema)} has not been called
     */
    @Override
    public Classifier getClassifier() {
        ClassifierLearner[] learners = innerLearners();
        Classifier[] classifiers = new Classifier[learners.length];
        for (int i = 0; i < learners.length; i++) {
            classifiers[i] = learners[i].getClassifier();
        }
        return new OneVsAllClassifier(this.schema.validClassNames(), classifiers);
    }

    /**
     * Returns the inner learners.
     *
     * <p>Fails fast with a descriptive message, rather than an anonymous NullPointerException,
     * when {@link #setSchema(ExampleSchema)} has not been called yet.
     */
    private ClassifierLearner[] innerLearners() {
        if (this.innerLearner == null) {
            throw new IllegalStateException(
                    "OneVsAllLearner: setSchema must be called before training");
        }
        return this.innerLearner;
    }
}
