package edu.cmu.minorthird.classify.sequential;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.OnlineBinaryClassifierLearner;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.Serializable;
import java.util.Iterator;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;

/* loaded from: input_file:edu/cmu/minorthird/classify/sequential/GenericCollinsLearnerV1.class */
/**
 * Sequence-classifier learner that trains one online binary classifier per class
 * with Collins-style (perceptron-like) sequence updates.
 *
 * <p>Each training epoch works as follows:
 * <ul>
 *   <li>Every training sequence is decoded with a {@link BeamSearcher} built from the
 *       current per-class classifiers.</li>
 *   <li>A position counts as an error when its predicted label is wrong, or when any of
 *       the preceding {@code historySize} predicted labels is wrong.</li>
 *   <li>At each error position, the learner for the true class gets a positive example
 *       whose features use the true label history.</li>
 *   <li>At the same position, the learner for the predicted class gets a negative example
 *       whose features use the predicted label history.</li>
 * </ul>
 *
 * <p>Training stops after {@code numberOfEpochs} epochs, or earlier if an epoch
 * produces no errors. The result is a {@link CMM} wrapping the final classifiers.
 */
public class GenericCollinsLearnerV1 implements BatchSequenceClassifierLearner, SequenceConstants {
    /** Prototype learner; copied and reset once per class at the start of each batchTrain call. */
    private OnlineBinaryClassifierLearner innerLearnerPrototype;
    /** Per-class learners created by the most recent batchTrain call, indexed by schema class index. */
    private OnlineBinaryClassifierLearner[] innerLearner;
    private int historySize;
    private int numberOfEpochs;
    /**
     * Scratch buffer handed to InstanceFromSequence.fillHistory.
     *
     * <p>It is always allocated with length {@code historySize}, so training features use
     * the same history length that is passed to BeamSearcher and CMM.
     * NOTE(review): this assumes fillHistory fills {@code history.length} entries; confirm.
     */
    private String[] history;

    /**
     * Multi-class classifier built from one binary classifier per schema class.
     *
     * <p>Class {@code i} is scored by {@code innerClassifier[i]}.
     */
    public static class MultiClassClassifier implements Classifier, Visible, Serializable {
        private static final long serialVersionUID = 1;
        private ExampleSchema schema;
        private BinaryClassifier[] innerClassifier;
        private int numClasses;

        /**
         * Wraps existing per-class binary classifiers.
         *
         * @param exampleSchema schema supplying the class names and the class count
         * @param binaryClassifierArr per-class classifiers, indexed like the schema's classes.
         *     The array is stored by reference, not copied.
         */
        public MultiClassClassifier(ExampleSchema exampleSchema, BinaryClassifier[] binaryClassifierArr) {
            this.schema = exampleSchema;
            this.numClasses = exampleSchema.getNumberOfClasses();
            this.innerClassifier = binaryClassifierArr;
        }

        /**
         * Snapshots the current classifier of each per-class online learner.
         *
         * <p>{@code getBinaryClassifier()} is called once per class, at construction time.
         *
         * @param exampleSchema schema supplying the class names and the class count
         * @param onlineBinaryClassifierLearnerArr per-class learners; must hold at least
         *     {@code exampleSchema.getNumberOfClasses()} entries
         */
        public MultiClassClassifier(ExampleSchema exampleSchema, OnlineBinaryClassifierLearner[] onlineBinaryClassifierLearnerArr) {
            this.schema = exampleSchema;
            this.numClasses = exampleSchema.getNumberOfClasses();
            this.innerClassifier = new BinaryClassifier[this.numClasses];
            for (int i = 0; i < this.numClasses; i++) {
                this.innerClassifier[i] = onlineBinaryClassifierLearnerArr[i].getBinaryClassifier();
            }
        }

        /**
         * Returns a label that assigns each class the score of its own binary classifier.
         */
        @Override // edu.cmu.minorthird.classify.Classifier
        public ClassLabel classification(Instance instance) {
            ClassLabel classLabel = new ClassLabel();
            for (int i = 0; i < this.numClasses; i++) {
                classLabel.add(this.schema.getClassName(i), this.innerClassifier[i].score(instance));
            }
            return classLabel;
        }

        /**
         * Concatenates each per-class classifier's explanation.
         *
         * <p>Each explanation is preceded by a header naming its class.
         */
        @Override // edu.cmu.minorthird.classify.Classifier
        public String explain(Instance instance) {
            StringBuilder explanation = new StringBuilder();
            for (int i = 0; i < this.numClasses; i++) {
                explanation.append("Classifier for class ").append(this.schema.getClassName(i)).append(":\n");
                explanation.append(this.innerClassifier[i].explain(instance));
                explanation.append("\n");
            }
            return explanation.toString();
        }

        /**
         * Builds a scrollable viewer with one titled sub-panel per class.
         *
         * <p>Each sub-panel shows that class's binary classifier.
         */
        @Override // edu.cmu.minorthird.util.gui.Visible
        public Viewer toGUI() {
            ComponentViewer componentViewer = new ComponentViewer(this) {
                @Override // edu.cmu.minorthird.util.gui.ComponentViewer
                public JComponent componentFor(Object obj) {
                    MultiClassClassifier classifier = (MultiClassClassifier) obj;
                    JPanel mainPanel = new JPanel();
                    // Iterate over the classes of the object being rendered. The decompiled
                    // original read them through a mistyped synthetic outer-instance field.
                    for (int i = 0; i < classifier.numClasses; i++) {
                        JPanel classPanel = new JPanel();
                        classPanel.setBorder(new TitledBorder("Class " + classifier.schema.getClassName(i)));
                        SmartVanillaViewer subViewer = new SmartVanillaViewer(classifier.innerClassifier[i]);
                        subViewer.setSuperView(this);
                        classPanel.add(subViewer);
                        mainPanel.add(classPanel);
                    }
                    return new JScrollPane(mainPanel);
                }
            };
            componentViewer.setContent(this);
            return componentViewer;
        }
    }

    /** Creates a learner with a VotedPerceptron inner learner, history size 3 and 5 epochs. */
    public GenericCollinsLearnerV1() {
        this(3, 5);
    }

    /**
     * Creates a learner that trains for 5 epochs.
     *
     * @param onlineBinaryClassifierLearner prototype copied once per class
     * @param i history size (number of previous labels used)
     */
    public GenericCollinsLearnerV1(OnlineBinaryClassifierLearner onlineBinaryClassifierLearner, int i) {
        this(onlineBinaryClassifierLearner, i, 5);
    }

    /**
     * Creates a learner with a VotedPerceptron inner learner.
     *
     * @param i history size
     * @param i2 maximum number of training epochs
     */
    public GenericCollinsLearnerV1(int i, int i2) {
        this(new VotedPerceptron(), i, i2);
    }

    /**
     * Creates a learner with the given settings.
     *
     * @param onlineBinaryClassifierLearner prototype copied once per class
     * @param i history size; a negative value causes NegativeArraySizeException
     * @param i2 maximum number of training epochs
     */
    public GenericCollinsLearnerV1(OnlineBinaryClassifierLearner onlineBinaryClassifierLearner, int i, int i2) {
        this.historySize = i;
        this.innerLearnerPrototype = onlineBinaryClassifierLearner;
        this.numberOfEpochs = i2;
        this.history = new String[i];
    }

    /** Ignored: batchTrain always uses the schema of the dataset it is given. */
    @Override // edu.cmu.minorthird.classify.sequential.SequenceClassifierLearner
    public void setSchema(ExampleSchema exampleSchema) {
    }

    /** Returns the prototype learner that is copied once per class during training. */
    public OnlineBinaryClassifierLearner getInnerLearner() {
        return this.innerLearnerPrototype;
    }

    /** Sets the prototype learner that is copied once per class during training. */
    public void setInnerLearner(OnlineBinaryClassifierLearner onlineBinaryClassifierLearner) {
        this.innerLearnerPrototype = onlineBinaryClassifierLearner;
    }

    /** Returns the number of previous labels used as history. */
    @Override // edu.cmu.minorthird.classify.sequential.SequenceClassifierLearner
    public int getHistorySize() {
        return this.historySize;
    }

    /**
     * Sets the number of previous labels used as history.
     *
     * <p>The history buffer is reallocated here so its length stays in step with the
     * history size passed to BeamSearcher and CMM.
     *
     * @param i new history size; a negative value causes NegativeArraySizeException
     */
    public void setHistorySize(int i) {
        this.history = new String[i];
        this.historySize = i;
    }

    /** Returns the maximum number of training epochs. */
    public int getNumberOfEpochs() {
        return this.numberOfEpochs;
    }

    /** Sets the maximum number of training epochs. */
    public void setNumberOfEpochs(int i) {
        this.numberOfEpochs = i;
    }

    /**
     * Trains one copy of the inner learner per class and returns a CMM that wraps them.
     *
     * <p>After each epoch, a line of error statistics is printed to {@code System.out}.
     * Training stops early once an epoch has no transition errors.
     *
     * @param sequenceDataset training sequences; its schema defines the classes
     * @return a CMM over the final per-class classifiers
     * @throws IllegalArgumentException if copying the prototype learner throws
     *     CloneNotSupportedException (that exception is kept as the cause)
     */
    @Override // edu.cmu.minorthird.classify.sequential.BatchSequenceClassifierLearner
    public SequenceClassifier batchTrain(SequenceDataset sequenceDataset) {
        ExampleSchema schema = sequenceDataset.getSchema();
        try {
            int numClasses = schema.getNumberOfClasses();
            this.innerLearner = new OnlineBinaryClassifierLearner[numClasses];
            for (int c = 0; c < numClasses; c++) {
                this.innerLearner[c] = (OnlineBinaryClassifierLearner) this.innerLearnerPrototype.copy();
                this.innerLearner[c].reset();
            }
            ProgressCounter progressCounter = new ProgressCounter(
                    "training sequential " + this.innerLearnerPrototype.toString(),
                    "sequence",
                    this.numberOfEpochs * sequenceDataset.numberOfSequences());
            for (int epoch = 0; epoch < this.numberOfEpochs; epoch++) {
                int sequenceErrors = 0;
                int transitionErrors = 0;
                int tokenCount = 0;
                Iterator<?> sequenceIterator = sequenceDataset.sequenceIterator();
                while (sequenceIterator.hasNext()) {
                    Example[] sequence = (Example[]) sequenceIterator.next();
                    // Decode with a snapshot of the learners' current classifiers.
                    ClassLabel[] predicted = new BeamSearcher(
                            new MultiClassClassifier(schema, this.innerLearner), this.historySize, schema)
                            .bestLabelSequence(sequence);
                    boolean sequenceHasError = false;
                    for (int pos = 0; pos < sequence.length; pos++) {
                        if (!hasErrorInHistoryWindow(sequence, predicted, pos)) {
                            continue;
                        }
                        transitionErrors++;
                        sequenceHasError = true;
                        // Positive update: true class, features built from the true label history.
                        InstanceFromSequence.fillHistory(this.history, sequence, pos);
                        this.innerLearner[schema.getClassIndex(sequence[pos].getLabel().bestClassName())]
                                .addExample(new Example(new InstanceFromSequence(sequence[pos], this.history),
                                        ClassLabel.binaryLabel(1.0d)));
                        // Negative update: predicted class, features built from the predicted label history.
                        InstanceFromSequence.fillHistory(this.history, predicted, pos);
                        this.innerLearner[schema.getClassIndex(predicted[pos].bestClassName())]
                                .addExample(new Example(new InstanceFromSequence(sequence[pos], this.history),
                                        ClassLabel.binaryLabel(-1.0d)));
                    }
                    if (sequenceHasError) {
                        sequenceErrors++;
                    }
                    tokenCount += sequence.length;
                    progressCounter.progress();
                }
                System.out.println("Epoch " + epoch + ": sequenceErr=" + sequenceErrors
                        + " transitionErrors=" + transitionErrors + "/" + tokenCount);
                if (transitionErrors == 0) {
                    break;
                }
            }
            progressCounter.finished();
            return new CMM(new MultiClassClassifier(schema, this.innerLearner), this.historySize, schema);
        } catch (CloneNotSupportedException e) {
            throw new IllegalArgumentException("innerLearner must be cloneable", e);
        }
    }

    /**
     * Returns true if the prediction at {@code pos} is wrong, or if any of the up to
     * {@code historySize} preceding predictions is wrong.
     *
     * <p>Positions are checked from {@code pos} backwards, stopping at the first error.
     */
    private boolean hasErrorInHistoryWindow(Example[] sequence, ClassLabel[] predicted, int pos) {
        if (!predicted[pos].isCorrect(sequence[pos].getLabel())) {
            return true;
        }
        for (int back = 1; pos - back >= 0 && back <= this.historySize; back++) {
            if (!predicted[pos - back].isCorrect(sequence[pos - back].getLabel())) {
                return true;
            }
        }
        return false;
    }
}
