package edu.cmu.minorthird.classify.sequential;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.sequential.CollinsPerceptronLearner;
import edu.cmu.minorthird.util.ProgressCounter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

/* loaded from: input_file:edu/cmu/minorthird/classify/sequential/MarginPerceptronLearner.class */
/**
 * Sequence learner that trains a {@link CollinsPerceptronLearner.MultiClassVPClassifier} with a
 * relative-margin variant of the Collins sequence perceptron.
 *
 * <p>For each training sequence, up to {@link #topK} labelings produced by a {@link BeamSearcher}
 * are examined in the order the searcher returns them. The early stop below assumes that order is
 * best-first; confirm against {@code BeamSearcher}. Examination stops at the first labeling whose
 * score falls below the margin threshold derived from the true labeling's score and
 * {@link #beta}. Every <em>incorrect</em> labeling seen before that point becomes a "violator".
 *
 * <p>At each position where some violator disagrees with the true labels, either at that position
 * or within the preceding {@code getHistorySize()} positions, the classifier is updated:
 * <ul>
 *   <li>toward the true label with weight {@code 1};</li>
 *   <li>away from each violator's label with weight {@code -1 / #violators}.</li>
 * </ul>
 *
 * <p>Training runs for at most {@code getNumberOfEpochs()} epochs. It stops early after an epoch
 * in which no update was made. The result is a {@link CMM} over the classifier in vote mode.
 */
public class MarginPerceptronLearner extends CollinsPerceptronLearner {

    /**
     * Relative margin. For a non-negative correct score {@code s}, incorrect labelings scoring at
     * least {@code s * (1 - beta)} are penalized. For a negative {@code s}, the threshold is
     * {@code s * (1 + beta)}, so the margin always extends below the correct score.
     */
    float beta = 0.05f;

    /** Maximum number of beam-search labelings examined per training sequence. */
    int topK = 10;

    /** Equivalent to {@code MarginPerceptronLearner(3, 5, 0.05f, 10)}. */
    public MarginPerceptronLearner() {
        this(3, 5, 0.05f);
    }

    /** Equivalent to {@code MarginPerceptronLearner(3, numberOfEpochs, 0.05f, 10)}. */
    public MarginPerceptronLearner(int numberOfEpochs) {
        this(3, numberOfEpochs, 0.05f);
    }

    /** Equivalent to {@code MarginPerceptronLearner(historySize, numberOfEpochs, beta, 10)}. */
    public MarginPerceptronLearner(int historySize, int numberOfEpochs, float beta) {
        this(historySize, numberOfEpochs, beta, 10);
    }

    /**
     * Creates a margin perceptron learner.
     *
     * <p>NOTE(review): the first two arguments are passed unchanged to
     * {@code CollinsPerceptronLearner(int, int)}. They are assumed to be (history size, number of
     * epochs), matching {@code getHistorySize()} / {@code getNumberOfEpochs()}; confirm against the
     * superclass.
     *
     * @param historySize    first argument forwarded to the superclass constructor
     * @param numberOfEpochs second argument forwarded to the superclass constructor
     * @param beta           relative margin, see {@link #beta}
     * @param topK           maximum number of beam-search labelings examined per sequence
     */
    public MarginPerceptronLearner(int historySize, int numberOfEpochs, float beta, int topK) {
        super(historySize, numberOfEpochs);
        this.beta = beta;
        this.topK = topK;
    }

    /**
     * Trains a margin perceptron on {@code sequenceDataset} and wraps it in a {@link CMM}.
     *
     * <p>After each epoch, a summary line (sequence errors and transition errors over total
     * positions) is printed to {@code System.out}.
     *
     * @param sequenceDataset the training sequences
     * @return a {@link CMM} over the trained classifier, with vote mode enabled
     */
    @Override
    public SequenceClassifier batchTrain(SequenceDataset sequenceDataset) {
        ExampleSchema schema = sequenceDataset.getSchema();
        CollinsPerceptronLearner.MultiClassVPClassifier classifier =
                new CollinsPerceptronLearner.MultiClassVPClassifier(schema);
        ProgressCounter progress = new ProgressCounter("training sequence perceptron", "sequence",
                getNumberOfEpochs() * sequenceDataset.numberOfSequences());
        // Reused across sequences; cleared at the start of each sequence.
        List<ClassLabel[]> violators = new ArrayList<ClassLabel[]>();

        for (int epoch = 0; epoch < getNumberOfEpochs(); epoch++) {
            int sequenceErrors = 0;
            int transitionErrors = 0;
            int transitions = 0;
            for (Iterator<?> it = sequenceDataset.sequenceIterator(); it.hasNext();) {
                Example[] sequence = (Example[]) it.next();
                int errorsInSequence = marginTrainSequence(sequence, classifier, schema, violators);
                if (errorsInSequence > 0) {
                    sequenceErrors++;
                }
                transitionErrors += errorsInSequence;
                transitions += sequence.length;
                progress.progress();
            }
            System.out.println("Epoch " + epoch + ": sequenceErr=" + sequenceErrors
                    + " transitionErrors=" + transitionErrors + "/" + transitions);
            if (transitionErrors == 0) {
                // No perceptron update was made during this epoch, so stop early.
                break;
            }
        }
        progress.finished();
        classifier.setVoteMode(true);
        return new CMM(classifier, getHistorySize(), schema);
    }

    /**
     * Runs one margin-perceptron step on a single training sequence.
     *
     * @return the number of positions at which an update was made (transition errors)
     */
    private int marginTrainSequence(Example[] sequence,
            CollinsPerceptronLearner.MultiClassVPClassifier classifier, ExampleSchema schema,
            List<ClassLabel[]> violators) {
        BeamSearcher searcher = new BeamSearcher(classifier, getHistorySize(), schema);
        searcher.doSearch(sequence);
        float correctScore = getScore(sequence, classifier);
        if (DEBUG) {
            log.debug("corrScore: " + correctScore);
        }
        collectMarginViolators(searcher, sequence, correctScore, violators);
        if (DEBUG) {
            log.debug("added: " + violators.size());
        }
        int errors = 0;
        if (!violators.isEmpty()) {
            for (int pos = 0; pos < sequence.length; pos++) {
                if (violatorErrsNear(violators, sequence, pos)) {
                    errors++;
                    marginUpdateAt(sequence, pos, classifier, violators);
                }
            }
        }
        classifier.completeUpdate();
        return errors;
    }

    /**
     * Replaces the contents of {@code violators} with the incorrect labelings among the first
     * {@code topK} search results that precede the first result scoring below the margin threshold.
     */
    private void collectMarginViolators(BeamSearcher searcher, Example[] sequence,
            float correctScore, List<ClassLabel[]> violators) {
        violators.clear();
        float threshold = marginThreshold(correctScore);
        int candidates = Math.min(searcher.getNumberOfSolutionsFound(), this.topK);
        for (int k = 0; k < candidates; k++) {
            ClassLabel[] labeling = searcher.viterbi(k);
            float candidateScore = searcher.score(k);
            if (DEBUG) {
                log.debug("viterbi: " + k + " score " + candidateScore);
                log.debug(sequenceToString(labeling));
            }
            if (candidateScore < threshold) {
                break;
            }
            if (!isCorrect(labeling, sequence)) {
                violators.add(labeling);
            }
        }
    }

    /**
     * Returns the lowest candidate score still inside the margin of {@code correctScore}.
     *
     * <p>Multiplying a negative score by {@code (1 - beta)} would put the threshold above the
     * correct score. Incorrect labelings that outscore the correct one would then be skipped. For
     * negative scores the threshold is therefore {@code score * (1 + beta)}. For non-negative
     * scores the original {@code score * (1 - beta)} is kept exactly.
     */
    private float marginThreshold(float correctScore) {
        return correctScore >= 0.0f
                ? correctScore * (1.0f - this.beta)
                : correctScore * (1.0f + this.beta);
    }

    /**
     * Returns true if some violator mislabels {@code pos} or any of the preceding
     * {@code getHistorySize()} positions (clipped at the start of the sequence).
     */
    private boolean violatorErrsNear(List<ClassLabel[]> violators, Example[] sequence, int pos) {
        int historySize = getHistorySize();
        for (ClassLabel[] labeling : violators) {
            if (!labeling[pos].isCorrect(sequence[pos].getLabel())) {
                return true;
            }
            for (int back = 1; pos - back >= 0 && back <= historySize; back++) {
                if (!labeling[pos - back].isCorrect(sequence[pos - back].getLabel())) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Applies the margin-perceptron update at {@code pos}: +1 toward the true label (with the true
     * history), and -1/#violators away from each violator's label (with that violator's history).
     */
    private void marginUpdateAt(Example[] sequence, int pos,
            CollinsPerceptronLearner.MultiClassVPClassifier classifier,
            List<ClassLabel[]> violators) {
        // this.history is a shared buffer. Each fillHistory is immediately followed by the
        // InstanceFromSequence that reads it, as in the original code.
        InstanceFromSequence.fillHistory(this.history, sequence, pos);
        classifier.update(sequence[pos].getLabel().bestClassName(),
                new InstanceFromSequence(sequence[pos], this.history), 1.0d);
        double penalty = (-1.0d) / violators.size();
        for (ClassLabel[] labeling : violators) {
            InstanceFromSequence.fillHistory(this.history, labeling, pos);
            classifier.update(labeling[pos].bestClassName(),
                    new InstanceFromSequence(sequence[pos], this.history), penalty);
        }
    }

    /**
     * Scores the true labeling of {@code exampleArr} under {@code classifier}. The score is the
     * sum, over positions, of the weight the classifier assigns to the true class given the true
     * history.
     *
     * @param exampleArr the labeled sequence
     * @param classifier the classifier to score with
     * @return the summed weight, accumulated in {@code float}
     */
    float getScore(Example[] exampleArr, Classifier classifier) {
        float total = 0.0f;
        for (int pos = 0; pos < exampleArr.length; pos++) {
            InstanceFromSequence.fillHistory(this.history, exampleArr, pos);
            total += classifier.classification(new InstanceFromSequence(exampleArr[pos], this.history))
                    .getWeight(exampleArr[pos].getLabel().bestClassName());
        }
        return total;
    }

    /**
     * Returns true if every label in {@code classLabelArr} is correct for the corresponding example.
     * Only the first {@code exampleArr.length} labels are checked.
     */
    boolean isCorrect(ClassLabel[] classLabelArr, Example[] exampleArr) {
        for (int i = 0; i < exampleArr.length; i++) {
            if (!classLabelArr[i].isCorrect(exampleArr[i].getLabel())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Renders a labeling as its best class names. Each name is followed by
     * {@link AbstractFormatter#DEFAULT_COLUMN_SEPARATOR}, including the last one.
     */
    String sequenceToString(ClassLabel[] classLabelArr) {
        StringBuilder sb = new StringBuilder();
        for (ClassLabel label : classLabelArr) {
            sb.append(label.bestClassName()).append(AbstractFormatter.DEFAULT_COLUMN_SEPARATOR);
        }
        return sb.toString();
    }
}
