package edu.cmu.minorthird.classify.sequential;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.StringUtil;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.Serializable;
import java.util.Iterator;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/sequential/CollinsPerceptronLearner.class */
/**
 * Batch learner for sequence classifiers based on Collins' structured perceptron.
 *
 * <p>Each epoch shuffles the training data and decodes every sequence with a {@link BeamSearcher}
 * driven by the current weights. For every position where the predicted label, or any of the
 * preceding {@code historySize} predicted labels, is wrong, it performs a perceptron update:
 * <ul>
 *   <li>the instance built from the <em>true</em> label history is added to the true class's
 *       hyperplane;</li>
 *   <li>the instance built from the <em>predicted</em> label history is subtracted from the
 *       predicted class's hyperplane.</li>
 * </ul>
 * After every sequence the current weights are added into a running sum. The returned
 * {@link CMM} classifies with that sum ("vote mode"). Training stops early after an epoch with
 * no errors.
 *
 * <p>An instance must not run {@link #batchTrain} concurrently with itself: the {@link #history}
 * scratch buffer is an instance field written during training.
 */
public class CollinsPerceptronLearner implements BatchSequenceClassifierLearner, SequenceConstants {
    /** Logger for this learner. */
    protected static Logger log = Logger.getLogger(CollinsPerceptronLearner.class);

    /**
     * Cached value of {@code log.isDebugEnabled()}. It is evaluated once when the class is loaded,
     * so later changes to the log level do not switch the debug output on or off.
     */
    protected static final boolean DEBUG = log.isDebugEnabled();

    /**
     * Number of previous labels kept in the history. It is passed to {@link BeamSearcher} and
     * {@link CMM}, and is also the width of the error window used in {@link #batchTrain}.
     */
    protected int historySize;

    /** Maximum number of passes over the training data. */
    protected int numberOfEpochs;

    /**
     * Scratch buffer (length {@code historySize}) that {@link #batchTrain} fills with the label
     * history of the position being updated.
     */
    protected String[] history;

    /**
     * Multi-class linear classifier that scores each class with its own {@link Hyperplane}.
     *
     * <p>{@code w_t} holds the current perceptron weights. {@link #completeUpdate()} adds
     * {@code w_t} into the running sums {@code s_t}. In vote mode, scoring, explanations and the
     * GUI use {@code s_t}; otherwise they use {@code w_t}.
     */
    public static class MultiClassVPClassifier implements Classifier, Visible, Serializable {
        private static final long serialVersionUID = 1;
        private final ExampleSchema schema;
        /** Running sums of {@code w_t}, one hyperplane per class index. */
        private Hyperplane[] s_t;
        /** Current perceptron weights, one hyperplane per class index. */
        private Hyperplane[] w_t;
        private final int numClasses;
        private boolean voteMode = false;

        /**
         * Creates a classifier with fresh hyperplanes for every class in the schema.
         *
         * @param exampleSchema schema supplying the class names and the number of classes
         */
        public MultiClassVPClassifier(ExampleSchema exampleSchema) {
            this.schema = exampleSchema;
            this.numClasses = exampleSchema.getNumberOfClasses();
            reset();
        }

        /**
         * Selects which hyperplanes are used for scoring.
         *
         * @param z {@code true} to score with the accumulated sums {@code s_t}, {@code false} to
         *     score with the current weights {@code w_t}
         */
        public void setVoteMode(boolean z) {
            this.voteMode = z;
        }

        /**
         * Adds {@code d * instance} to the current hyperplane of the named class.
         *
         * @param str class name, resolved through {@code schema.getClassIndex}
         * @param instance features to add
         * @param d scale factor ({@code +1} promotes, {@code -1} demotes)
         */
        public void update(String str, Instance instance, double d) {
            w_t[schema.getClassIndex(str)].increment(instance, d);
        }

        /** Adds every current hyperplane {@code w_t[i]} into its running sum {@code s_t[i]}. */
        public void completeUpdate() {
            for (int i = 0; i < numClasses; i++) {
                s_t[i].increment(w_t[i], 1.0d);
            }
        }

        /** Returns the hyperplanes selected by the current vote mode. */
        private Hyperplane[] activeHyperplanes() {
            return voteMode ? s_t : w_t;
        }

        /**
         * Scores the instance against every class hyperplane.
         *
         * @param instance instance to classify
         * @return a label holding one score per class name
         */
        @Override // edu.cmu.minorthird.classify.Classifier
        public ClassLabel classification(Instance instance) {
            Hyperplane[] planes = activeHyperplanes();
            ClassLabel label = new ClassLabel();
            for (int i = 0; i < numClasses; i++) {
                label.add(schema.getClassName(i), planes[i].score(instance));
            }
            return label;
        }

        /**
         * Concatenates each class hyperplane's explanation of the instance.
         *
         * @param instance instance to explain
         * @return one section per class, each headed by {@code "Hyperplane for class <name>:"}
         */
        @Override // edu.cmu.minorthird.classify.Classifier
        public String explain(Instance instance) {
            Hyperplane[] planes = activeHyperplanes();
            StringBuilder buf = new StringBuilder();
            for (int i = 0; i < numClasses; i++) {
                buf.append("Hyperplane for class ").append(schema.getClassName(i)).append(":\n");
                buf.append(planes[i].explain(instance));
                buf.append(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
            }
            return buf.toString();
        }

        /**
         * Builds a viewer that shows one titled panel per class, each containing the GUI of that
         * class's active hyperplane, inside a scroll pane.
         */
        @Override // edu.cmu.minorthird.util.gui.Visible
        public Viewer toGUI() {
            ComponentViewer viewer = new ComponentViewer(this) {
                @Override // edu.cmu.minorthird.util.gui.ComponentViewer
                public JComponent componentFor(Object obj) {
                    // Render whatever classifier is the viewer's content, not a captured
                    // reference, so every field comes from the same object.
                    MultiClassVPClassifier classifier = (MultiClassVPClassifier) obj;
                    Hyperplane[] planes = classifier.activeHyperplanes();
                    JPanel mainPanel = new JPanel();
                    for (int i = 0; i < classifier.numClasses; i++) {
                        JPanel classPanel = new JPanel();
                        classPanel.setBorder(new TitledBorder("Class " + classifier.schema.getClassName(i)));
                        Viewer subView = planes[i].toGUI();
                        subView.setSuperView(this);
                        classPanel.add(subView);
                        mainPanel.add(classPanel);
                    }
                    return new JScrollPane(mainPanel);
                }
            };
            viewer.setContent(this);
            return viewer;
        }

        /** Discards all learned weights and sums, allocating fresh hyperplanes for every class. */
        public void reset() {
            s_t = new Hyperplane[numClasses];
            w_t = new Hyperplane[numClasses];
            for (int i = 0; i < numClasses; i++) {
                s_t[i] = new Hyperplane();
                w_t[i] = new Hyperplane();
            }
        }

        /** Lists the current (non-voted) hyperplanes {@code w_t}. */
        @Override
        public String toString() {
            return "[MultiClassVPClassifier:"
                    + StringUtil.toString(w_t, AbstractFormatter.DEFAULT_ROW_SEPARATOR, "\n]", "\n - ");
        }
    }

    /** Creates a learner with a history of 3 labels and 5 epochs. */
    public CollinsPerceptronLearner() {
        this(3, 5);
    }

    /**
     * Creates a learner with a history of 3 labels.
     *
     * @param numberOfEpochs maximum number of passes over the training data
     */
    public CollinsPerceptronLearner(int numberOfEpochs) {
        this(3, numberOfEpochs);
    }

    /**
     * Creates a learner.
     *
     * @param historySize number of previous labels kept as history; must be non-negative
     * @param numberOfEpochs maximum number of passes over the training data
     * @throws IllegalArgumentException if {@code historySize} is negative
     */
    public CollinsPerceptronLearner(int historySize, int numberOfEpochs) {
        if (historySize < 0) {
            throw new IllegalArgumentException("historySize must be non-negative: " + historySize);
        }
        this.historySize = historySize;
        this.numberOfEpochs = numberOfEpochs;
        this.history = new String[historySize];
    }

    /** Returns the maximum number of training epochs. */
    public int getNumberOfEpochs() {
        return numberOfEpochs;
    }

    /** Sets the maximum number of training epochs. */
    public void setNumberOfEpochs(int i) {
        this.numberOfEpochs = i;
    }

    @Override // edu.cmu.minorthird.classify.sequential.SequenceClassifierLearner
    public int getHistorySize() {
        return historySize;
    }

    /**
     * Intentionally does nothing: {@link #batchTrain} reads the schema from the dataset it is
     * given.
     */
    @Override // edu.cmu.minorthird.classify.sequential.SequenceClassifierLearner
    public void setSchema(ExampleSchema exampleSchema) {
    }

    /**
     * Trains a sequence classifier with the structured perceptron described on this class.
     *
     * <p>Side effects:
     * <ul>
     *   <li>shuffles {@code sequenceDataset} once per epoch;</li>
     *   <li>prints one summary line per epoch to {@code System.out};</li>
     *   <li>overwrites {@link #history}.</li>
     * </ul>
     *
     * @param sequenceDataset training sequences and their schema
     * @return a {@link CMM} wrapping the learned classifier in vote mode
     */
    @Override // edu.cmu.minorthird.classify.sequential.BatchSequenceClassifierLearner
    public SequenceClassifier batchTrain(SequenceDataset sequenceDataset) {
        ExampleSchema schema = sequenceDataset.getSchema();
        MultiClassVPClassifier classifier = new MultiClassVPClassifier(schema);
        ProgressCounter progress = new ProgressCounter("training sequence perceptron", "sequence",
                numberOfEpochs * sequenceDataset.numberOfSequences());
        try {
            for (int epoch = 0; epoch < numberOfEpochs; epoch++) {
                sequenceDataset.shuffle();
                int sequenceErrors = 0;   // sequences with at least one update
                int transitionErrors = 0; // positions that triggered an update
                int transitions = 0;      // positions seen this epoch
                for (Iterator it = sequenceDataset.sequenceIterator(); it.hasNext();) {
                    Example[] sequence = (Example[]) it.next();
                    ClassLabel[] predicted =
                            new BeamSearcher(classifier, historySize, schema).bestLabelSequence(sequence);
                    if (DEBUG) {
                        log.debug("classifier: " + classifier);
                        log.debug("viterbi:\n" + StringUtil.toString(predicted));
                    }
                    boolean sequenceHadError = false;
                    for (int pos = 0; pos < sequence.length; pos++) {
                        if (errorInWindow(predicted, sequence, pos, historySize)) {
                            transitionErrors++;
                            sequenceHadError = true;
                            perceptronUpdate(classifier, sequence, predicted, pos);
                        }
                    }
                    // Accumulate the weights after every sequence, whether or not it was updated.
                    classifier.completeUpdate();
                    if (sequenceHadError) {
                        sequenceErrors++;
                    }
                    transitions += sequence.length;
                    progress.progress();
                }
                System.out.println("Epoch " + epoch + ": sequenceErr=" + sequenceErrors
                        + " transitionErrors=" + transitionErrors + "/" + transitions);
                if (transitionErrors == 0) {
                    break; // converged: a full epoch without any update
                }
            }
        } finally {
            progress.finished();
        }
        classifier.setVoteMode(true);
        return new CMM(classifier, historySize, schema);
    }

    /**
     * Tells whether position {@code pos}, or any of the up to {@code window} positions before it,
     * is mispredicted. Checks {@code pos} first, then walks backwards.
     */
    private static boolean errorInWindow(ClassLabel[] predicted, Example[] sequence, int pos, int window) {
        if (!predicted[pos].isCorrect(sequence[pos].getLabel())) {
            return true;
        }
        for (int back = 1; pos - back >= 0 && back <= window; back++) {
            if (!predicted[pos - back].isCorrect(sequence[pos - back].getLabel())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Performs one perceptron update at {@code pos}:
     * <ul>
     *   <li>promotes the true class, using the true label history;</li>
     *   <li>demotes the predicted class, using the predicted label history.</li>
     * </ul>
     * Each instance is used, and logged, before {@link #history} is refilled for the next one.
     */
    private void perceptronUpdate(MultiClassVPClassifier classifier, Example[] sequence,
            ClassLabel[] predicted, int pos) {
        InstanceFromSequence.fillHistory(history, sequence, pos);
        InstanceFromSequence truthInstance = new InstanceFromSequence(sequence[pos], history);
        String trueClass = sequence[pos].getLabel().bestClassName();
        classifier.update(trueClass, truthInstance, 1.0d);
        if (DEBUG) {
            log.debug("+ update " + trueClass + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR
                    + truthInstance.getSource());
        }

        InstanceFromSequence.fillHistory(history, predicted, pos);
        InstanceFromSequence guessInstance = new InstanceFromSequence(sequence[pos], history);
        String guessedClass = predicted[pos].bestClassName();
        classifier.update(guessedClass, guessInstance, -1.0d);
        if (DEBUG) {
            log.debug("- update " + guessedClass + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR
                    + guessInstance.getSource());
        }
    }
}
