package edu.cmu.minorthird.classify.sequential;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.sequential.CollinsPerceptronLearner;
import edu.cmu.minorthird.classify.sequential.SegmentDataset;
import edu.cmu.minorthird.classify.sequential.Segmentation;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.awt.BorderLayout;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/sequential/SegmentCollinsPerceptronLearner.class */
public class SegmentCollinsPerceptronLearner implements BatchSegmenterLearner, SequenceConstants {
    // Class logger; assigned in the static initializer at the end of the class.
    private static Logger log;
    // log.isDebugEnabled() sampled once at class initialization; gates every debug log call
    // and the Viterbi lattice dump.
    private static final boolean DEBUG;
    // Number of passes batchTrain makes over the training data (stops early after an
    // error-free epoch).
    private int numberOfEpochs;
    // If true, batchTrain calls setVoteMode(true) on the classifier before the first epoch,
    // not only after the last one.
    private boolean updatedViterbi;
    // Cache for this class's Class object, populated by the static initializer; appears to be
    // decompiler residue of a pre-Java-5 class literal.
    static Class class$edu$cmu$minorthird$classify$sequential$SegmentCollinsPerceptronLearner;

    /**
     * Back-pointer for one cell (t, y) of the semi-Markov Viterbi lattice: the cell was
     * reached by a segment spanning tokens [lastT, t) whose preceding segment had label lastY.
     */
    public static class BackPointer {
        /** Start of the segment, i.e. the end position of the preceding segment. */
        public int lastT;
        /** End position (exclusive) of the segment. */
        public int t;
        /** Label index of the preceding segment. */
        public int lastY;
        /** Set to true while back-tracing when this pointer lies on the chosen path. */
        public boolean onBestPath;

        public BackPointer(int lastT, int t, int lastY) {
            this.lastT = lastT;
            this.t = t;
            this.lastY = lastY;
            this.onBestPath = false;
        }
    }

    /**
     * Semi-Markov Viterbi decoder.
     *
     * <p>A segment spanning tokens {@code [start, end)} with label {@code y} is scored by the
     * classifier's weight for class {@code y} on the segment's instance. The instance is wrapped
     * with a one-element history holding the previous segment's label. Segments labelled
     * {@link ExampleSchema#NEG_CLASS_NAME} are restricted to a single token; all other labels
     * may span up to {@code maxSegmentSize} tokens.
     */
    public static class ViterbiSearcher {
        private Classifier classifier;
        private ExampleSchema schema;
        private int maxSegmentSize;

        /**
         * @param classifier scores (segment, previous label) pairs
         * @param exampleSchema the label set to decode over
         * @param i maximum length of a segment whose label is not the negative class
         */
        public ViterbiSearcher(Classifier classifier, ExampleSchema exampleSchema, int i) {
            this.classifier = classifier;
            this.schema = exampleSchema;
            this.maxSegmentSize = i;
        }

        /**
         * Returns the highest-scoring segmentation of {@code candidateSegmentGroup}.
         *
         * <p>Back-pointers on the chosen path are marked {@code onBestPath}. When debug logging
         * is enabled, the whole lattice is printed via {@code dumpStuff}.
         *
         * @param candidateSegmentGroup the sequence and its candidate segments
         * @return the best segmentation, or an empty one when the schema has no classes or no
         *         chain of available candidate segments covers the whole sequence
         */
        public Segmentation bestSegments(CandidateSegmentGroup candidateSegmentGroup) {
            int sequenceLength = candidateSegmentGroup.getSequenceLength();
            int numberOfClasses = this.schema.getNumberOfClasses();
            int negClassIndex = this.schema.getClassIndex(ExampleSchema.NEG_CLASS_NAME);

            // Label names are loop-invariant; look them up once.
            String[] classNames = new String[numberOfClasses];
            for (int y = 0; y < numberOfClasses; y++) {
                classNames[y] = this.schema.getClassName(y);
            }

            // score[t][y] is the best total weight of a segmentation of [0, t) whose last
            // segment has label y. backPointers[t][y] records the segment that achieved it.
            // Cells in rows 1..n start at -infinity ("unreachable"). A finite sentinel would let
            // unreachable cells be extended, truncating the back-trace. It would also reject real
            // paths whose score falls below the sentinel, which can happen once perceptron
            // weights grow large.
            double[][] score = new double[sequenceLength + 1][numberOfClasses];
            BackPointer[][] backPointers = new BackPointer[sequenceLength + 1][numberOfClasses];
            for (int t = 1; t <= sequenceLength; t++) {
                for (int y = 0; y < numberOfClasses; y++) {
                    score[t][y] = Double.NEGATIVE_INFINITY;
                }
            }
            // Row 0 (the empty prefix) keeps Java's default 0.0 for every label.

            // One history array is reused; classification() runs before it is overwritten.
            String[] history = new String[1];
            for (int end = 1; end <= sequenceLength; end++) {
                for (int y = 0; y < numberOfClasses; y++) {
                    int maxLength = (y == negClassIndex) ? 1 : this.maxSegmentSize;
                    for (int prevY = 0; prevY < numberOfClasses; prevY++) {
                        for (int start = Math.max(0, end - maxLength); start < end; start++) {
                            double prevScore = score[start][prevY];
                            if (prevScore == Double.NEGATIVE_INFINITY) {
                                continue; // no segmentation of [0, start) ends with prevY
                            }
                            Instance segment = candidateSegmentGroup.getSubsequenceInstance(start, end);
                            if (segment == null) {
                                continue; // [start, end) is not a candidate segment
                            }
                            history[0] = classNames[prevY];
                            double weight = this.classifier
                                    .classification(new InstanceFromSequence(segment, history))
                                    .getWeight(classNames[y]);
                            double candidate = weight + prevScore;
                            // Strict '>' keeps the first-found best on ties.
                            if (candidate > score[end][y]) {
                                score[end][y] = candidate;
                                backPointers[end][y] = new BackPointer(start, end, prevY);
                            }
                        }
                    }
                }
            }

            // Choose the label of the final segment (first maximum wins).
            int bestY = -1;
            for (int y = 0; y < numberOfClasses; y++) {
                if (bestY < 0 || score[sequenceLength][y] > score[sequenceLength][bestY]) {
                    bestY = y;
                }
            }

            // Walk the back-pointers from the end of the sequence to position 0.
            Segmentation segmentation = new Segmentation(this.schema);
            if (bestY >= 0) {
                int label = bestY;
                for (BackPointer bp = backPointers[sequenceLength][bestY]; bp != null;
                        bp = backPointers[bp.lastT][bp.lastY]) {
                    bp.onBestPath = true;
                    segmentation.add(new Segmentation.Segment(bp.lastT, bp.t, label));
                    label = bp.lastY;
                }
            }
            if (SegmentCollinsPerceptronLearner.DEBUG) {
                SegmentCollinsPerceptronLearner.dumpStuff(candidateSegmentGroup, score, backPointers);
            }
            return segmentation;
        }
    }

    /**
     * Segmenter returned by {@code batchTrain}. Each candidate segment group is decoded with a
     * {@link ViterbiSearcher} over the trained classifier.
     */
    public static class ViterbiSegmenter implements Segmenter, Visible, Serializable {
        private static final long serialVersionUID = 1;
        private final int CURRENT_VERSION_NUMBER = 1;
        private Classifier c;
        private ExampleSchema schema;
        private int maxSegSize;

        /**
         * @param classifier the trained segment classifier
         * @param exampleSchema the label set to decode over
         * @param i maximum length of a segment whose label is not the negative class
         */
        public ViterbiSegmenter(Classifier classifier, ExampleSchema exampleSchema, int i) {
            this.c = classifier;
            this.schema = exampleSchema;
            this.maxSegSize = i;
        }

        /** Returns the best segmentation found by a fresh Viterbi search. */
        @Override // edu.cmu.minorthird.classify.sequential.Segmenter
        public Segmentation segmentation(CandidateSegmentGroup candidateSegmentGroup) {
            return new ViterbiSearcher(this.c, this.schema, this.maxSegSize).bestSegments(candidateSegmentGroup);
        }

        /** Explanations are not supported; always returns a fixed placeholder string. */
        @Override // edu.cmu.minorthird.classify.sequential.Segmenter
        public String explain(CandidateSegmentGroup candidateSegmentGroup) {
            return "not implemented yet";
        }

        /**
         * Builds a viewer that shows maxSegSize above a nested view of the classifier.
         */
        @Override // edu.cmu.minorthird.util.gui.Visible
        public Viewer toGUI() {
            // The anonymous viewer needs no outer-instance field: everything it displays is
            // read from the object passed to componentFor.
            ComponentViewer componentViewer = new ComponentViewer(this) {
                @Override // edu.cmu.minorthird.util.gui.ComponentViewer
                public JComponent componentFor(Object obj) {
                    ViterbiSegmenter segmenter = (ViterbiSegmenter) obj;
                    JPanel panel = new JPanel();
                    panel.setLayout(new BorderLayout());
                    panel.add(new JLabel("ViterbiSegmenter: maxSegSize=" + segmenter.maxSegSize), BorderLayout.NORTH);
                    SmartVanillaViewer classifierViewer = new SmartVanillaViewer(segmenter.c);
                    classifierViewer.setSuperView(this);
                    panel.add(classifierViewer, BorderLayout.SOUTH);
                    panel.setBorder(new TitledBorder("ViterbiSegmenter"));
                    return new JScrollPane(panel);
                }
            };
            componentViewer.setContent(this);
            return componentViewer;
        }
    }

    /**
     * Creates a learner that trains for {@code numberOfEpochs} epochs and switches the
     * classifier to vote mode only after training.
     *
     * @param numberOfEpochs number of passes over the training data
     */
    public SegmentCollinsPerceptronLearner(int numberOfEpochs) {
        this(numberOfEpochs, false);
    }

    /**
     * @param numberOfEpochs number of passes over the training data
     * @param updatedViterbi if true, vote mode is enabled before training starts
     */
    public SegmentCollinsPerceptronLearner(int numberOfEpochs, boolean updatedViterbi) {
        this.numberOfEpochs = numberOfEpochs;
        this.updatedViterbi = updatedViterbi;
    }

    /** Creates a learner with the default of 5 epochs. */
    public SegmentCollinsPerceptronLearner() {
        this(5);
    }

    /** No-op: batchTrain reads the schema from the dataset it is given. */
    @Override // edu.cmu.minorthird.classify.sequential.BatchSegmenterLearner
    public void setSchema(ExampleSchema exampleSchema) {
        // deliberately ignored
    }

    /** @return the number of training epochs */
    public int getNumberOfEpochs() {
        return numberOfEpochs;
    }

    /** @param i the number of training epochs to run */
    public void setNumberOfEpochs(int i) {
        numberOfEpochs = i;
    }

    /** @return always 1: only the previous segment's label is used as history */
    public int getHistorySize() {
        return 1;
    }

    /**
     * Trains a semi-Markov Collins perceptron on {@code segmentDataset}.
     *
     * <p>In every epoch, each candidate segment group is decoded with the current classifier.
     * Each predicted segment that is not matched by the gold segmentation gets a -1 update, and
     * each gold segment not matched by the prediction gets a +1 update. Per-epoch error counts
     * are printed to standard output, and training stops early after an epoch with no errors.
     *
     * @return a Viterbi segmenter over the trained classifier, with vote mode enabled
     */
    @Override // edu.cmu.minorthird.classify.sequential.BatchSegmenterLearner
    public Segmenter batchTrain(SegmentDataset segmentDataset) {
        final int maxSegSize = segmentDataset.getMaxWindowSize();
        final ExampleSchema schema = segmentDataset.getSchema();
        if (DEBUG) {
            log.debug("schema: " + schema);
        }
        CollinsPerceptronLearner.MultiClassVPClassifier learner = new CollinsPerceptronLearner.MultiClassVPClassifier(schema);
        ProgressCounter pc = new ProgressCounter("training semi-markov voted-perceptron", "sequence",
                this.numberOfEpochs * segmentDataset.getNumberOfSegmentGroups());
        if (this.updatedViterbi) {
            learner.setVoteMode(true);
        }

        for (int epoch = 0; epoch < this.numberOfEpochs; epoch++) {
            int sequenceErrors = 0;
            int transitionErrors = 0;
            int totalSegments = 0;
            for (SegmentDataset.Looper groups = segmentDataset.candidateSegmentGroupIterator(); groups.hasNext(); ) {
                CandidateSegmentGroup group = groups.nextCandidateSegmentGroup();
                if (DEBUG) {
                    log.debug("classifier is: " + learner);
                }
                Segmentation predicted = new ViterbiSearcher(learner, schema, maxSegSize).bestSegments(group);
                if (DEBUG) {
                    log.debug("viterbi:\n" + predicted);
                }
                Segmentation gold = correctSegments(group, schema, maxSegSize);
                if (DEBUG) {
                    log.debug("correct segments:\n" + gold);
                }
                // Penalize wrong predictions first, then reward missed gold segments.
                int falsePositives = compareSegmentsAndRevise(learner, schema, predicted, gold, -1.0d, group);
                int falseNegatives = compareSegmentsAndRevise(learner, schema, gold, predicted, 1.0d, group);
                if (falsePositives > 0 || falseNegatives > 0) {
                    sequenceErrors++;
                }
                transitionErrors += falsePositives + falseNegatives;
                totalSegments += gold.size();
                learner.completeUpdate();
                pc.progress();
            }
            System.out.println("Epoch " + epoch + ": sequenceErr=" + sequenceErrors
                    + " transitionErrors=" + transitionErrors + "/" + totalSegments);
            if (transitionErrors == 0) {
                break; // training data fits perfectly; further epochs change nothing
            }
        }
        pc.finished();
        learner.setVoteMode(true);
        return new ViterbiSegmenter(learner, schema, maxSegSize);
    }

    /**
     * Applies a perceptron update of size {@code delta} for every segment of {@code from} that
     * is not matched in {@code against}. A segment is matched when {@code against} contains it
     * and gives it the same previous-segment label.
     *
     * @return the number of updates applied
     */
    private int compareSegmentsAndRevise(CollinsPerceptronLearner.MultiClassVPClassifier classifier, ExampleSchema schema, Segmentation from, Segmentation against, double delta, CandidateSegmentGroup group) {
        Map fromPrev = previousClassMap(from, schema);
        Map againstPrev = previousClassMap(against, schema);
        String[] history = new String[1];
        int updates = 0;
        for (Iterator it = from.iterator(); it.hasNext(); ) {
            Segmentation.Segment seg = (Segmentation.Segment) it.next();
            String prevClass = (String) fromPrev.get(seg);
            if (seg.lo < 0) {
                continue;
            }
            if (against.contains(seg) && againstPrev.get(seg).equals(prevClass)) {
                continue; // same segment with the same history: nothing to learn
            }
            updates++;
            history[0] = prevClass;
            InstanceFromSequence instance = new InstanceFromSequence(group.getSubsequenceExample(seg.lo, seg.hi), history);
            if (DEBUG) {
                log.debug("update " + delta + " for: " + instance.getSource());
            }
            classifier.update(schema.getClassName(seg.y), instance, delta);
        }
        return updates;
    }

    /**
     * Maps each segment of {@code segmentation}, in iteration order, to the class name of the
     * segment before it. The first segment maps to the literal string "null".
     */
    private Map previousClassMap(Segmentation segmentation, ExampleSchema exampleSchema) {
        Map prevNames = new TreeMap();
        Segmentation.Segment previous = null;
        for (Iterator it = segmentation.iterator(); it.hasNext(); ) {
            Segmentation.Segment current = (Segmentation.Segment) it.next();
            String prevName;
            if (previous == null) {
                prevName = "null";
            } else {
                prevName = exampleSchema.getClassName(previous.y);
            }
            prevNames.put(current, prevName);
            previous = current;
        }
        return prevNames;
    }

    /**
     * Builds the gold segmentation of a candidate segment group.
     *
     * <p>The group is scanned left to right. At each position, the shortest candidate segment
     * (up to {@code i} tokens) that has an instance and a non-negative label is taken. If there
     * is none, a single-token segment labelled with the negative class is emitted instead.
     *
     * @param i maximum segment length to try
     * @return a segmentation covering the whole sequence
     */
    private Segmentation correctSegments(CandidateSegmentGroup candidateSegmentGroup, ExampleSchema exampleSchema, int i) {
        Segmentation segmentation = new Segmentation(exampleSchema);
        int sequenceLength = candidateSegmentGroup.getSequenceLength();
        int pos = 0;
        while (pos < sequenceLength) {
            int taken = 0;
            for (int len = 1; len <= i; len++) {
                if (candidateSegmentGroup.getSubsequenceInstance(pos, pos + len) == null) {
                    continue; // not a candidate segment
                }
                ClassLabel label = candidateSegmentGroup.getSubsequenceLabel(pos, pos + len);
                // A candidate with an instance but no label is treated as negative rather
                // than dereferenced.
                if (label != null && !label.isNegative()) {
                    segmentation.add(new Segmentation.Segment(pos, pos + len,
                            exampleSchema.getClassIndex(label.bestClassName())));
                    taken = len;
                    break;
                }
            }
            if (taken == 0) {
                // No positive segment starts here: emit one negative token.
                segmentation.add(new Segmentation.Segment(pos, pos + 1,
                        exampleSchema.getClassIndex(ExampleSchema.NEG_CLASS_NAME)));
                taken = 1;
            }
            pos += taken;
        }
        return segmentation;
    }

    /**
     * Debug helper: prints the Viterbi lattice to standard output.
     *
     * <p>Each cell (t, y) gets one line with its score, its back-pointer as lastT.lastY, the
     * source of the spanned segment, and a "<==" marker if it lies on the best path. Cells
     * without a back-pointer are shown as -1.-1 with a placeholder example built from "*NULL*".
     */
    public static void dumpStuff(CandidateSegmentGroup candidateSegmentGroup, double[][] dArr, BackPointer[][] backPointerArr) {
        Example placeholder = new Example(new MutableInstance("*NULL*"), new ClassLabel("*NULL*"));
        DecimalFormat fmt = new DecimalFormat("####.###");
        System.out.println("t.y\tf(t,y)\tt'.y'\tspan");
        for (int t = 0; t < dArr.length; t++) {
            double[] row = dArr[t];
            for (int y = 0; y < row.length; y++) {
                BackPointer bp = backPointerArr[t][y];
                Example span;
                if (bp == null) {
                    span = placeholder;
                    bp = new BackPointer(-1, -1, -1);
                } else {
                    span = candidateSegmentGroup.getSubsequenceExample(bp.lastT, bp.t);
                }
                String marker = bp.onBestPath ? "<==" : "";
                System.out.println(t + "." + y + "\t" + fmt.format(row[y]) + "\t"
                        + bp.lastT + "." + bp.lastY + "\t'" + span.getSource() + "' " + marker);
            }
        }
    }

    static Class class$(String str) {
        try {
            return Class.forName(str);
        } catch (ClassNotFoundException e) {
            throw new NoClassDefFoundError().initCause(e);
        }
    }

    static {
        Class cls;
        if (class$edu$cmu$minorthird$classify$sequential$SegmentCollinsPerceptronLearner == null) {
            cls = class$("edu.cmu.minorthird.classify.sequential.SegmentCollinsPerceptronLearner");
            class$edu$cmu$minorthird$classify$sequential$SegmentCollinsPerceptronLearner = cls;
        } else {
            cls = class$edu$cmu$minorthird$classify$sequential$SegmentCollinsPerceptronLearner;
        }
        log = Logger.getLogger(cls);
        DEBUG = log.isDebugEnabled();
    }
}
