package edu.cmu.minorthird.classify.sequential;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.FeatureIdFactory;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane;
import edu.cmu.minorthird.classify.sequential.SequenceUtils;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import iitb.CRF.CRF;
import iitb.CRF.DataIter;
import iitb.CRF.DataSequence;
import iitb.Model.EdgeFeatures;
import iitb.Model.EdgeLinearHistFeatures;
import iitb.Model.FeatureGenImpl;
import iitb.Model.FeatureImpl;
import iitb.Model.FeatureTypes;
import iitb.Model.StartFeatures;
import java.awt.BorderLayout;
import java.io.Serializable;
import java.util.Iterator;
import java.util.Properties;
import java.util.StringTokenizer;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;

/* loaded from: input_file:edu/cmu/minorthird/classify/sequential/CRFLearner.class */
/**
 * Sequence-classifier learner that trains an iitb {@link CRF} model and converts the learned
 * weights into a Minorthird {@link CMM} built from one {@link Hyperplane} per class.
 *
 * <p>The learner is also a {@link SequenceClassifier}:
 * <ul>
 *   <li>{@link #classification} decodes with the trained iitb model directly.
 *   <li>{@link #explain}, {@link #getExplanation} and {@link #toGUI} go through the converted CMM.
 * </ul>
 *
 * <p>CRF options live in {@link #options}, which is passed as-is to the iitb {@link CRF}
 * constructor. Built-in values ({@code modelGraph=naive}, {@code debugLvl=1}, {@code trainer=ll})
 * live in {@link #defaults}. {@link #options} is either that same table (no-arg constructor) or a
 * table whose {@link Properties#getProperty} falls back to it (the other constructors).
 */
public class CRFLearner implements BatchSequenceClassifierLearner, SequenceConstants, SequenceClassifier, Visible, Serializable {
    private static final long serialVersionUID = 1L;
    private final int CURRENT_SERIAL_VERSION = 1;

    /** History size handed to the CRF and the CMM; values above 1 add longer-history edge features. */
    int histsize;

    /** Schema of the most recently trained dataset; {@code null} until {@link #batchTrain} runs. */
    ExampleSchema schema;

    /** The iitb CRF model created by {@link #allocModel}. */
    CRF crfModel;

    /** Per-instance built-in option values; also the fallback table behind {@link #options}. */
    Properties defaults;

    /** Options given to the iitb CRF; either {@link #defaults} itself or a table backed by it. */
    Properties options;

    private FeatureIdFactory idFactory;
    private static final boolean CONVERT_TO_MINORTHIRD_HYPERPLANE = true;
    public String maxItersHelp;
    public String useHighPrecisionArithmeticHelp;

    /** Feature generator built by {@link #allocModel}; maps CRF weight indices back to features. */
    FeatureGenImpl featureGen;

    /** Cached CMM view of {@link #crfWs}; refreshed by every training run. */
    SequenceClassifier cmmClassifier;

    /** Weights returned by {@code CRF.train}; {@code null} before training. */
    double[] crfWs;

    /** Adapts a {@link SequenceDataset} to the iitb {@link DataIter} interface used during training. */
    public class CRFDataIter implements DataIter {
        Iterator<?> iter;
        SequenceDataset dataset;
        TrainDataSequenceC sequence;
        int dataSize;

        CRFDataIter(CRFLearner cRFLearner, SequenceDataset sequenceDataset) {
            dataset = sequenceDataset;
            dataSize = sequenceDataset.size();
            // A single sequence object is reused; next() re-initialises it in place.
            sequence = new TrainDataSequenceC(cRFLearner);
        }

        /** Restarts iteration over the dataset's sequences. */
        @Override
        public void startScan() {
            iter = dataset.sequenceIterator();
        }

        @Override
        public boolean hasNext() {
            return iter.hasNext();
        }

        /** Loads the next {@code Example[]} sequence into the shared sequence object and returns it. */
        @Override
        public DataSequence next() {
            sequence.init((Example[]) iter.next());
            return sequence;
        }
    }

    /** iitb {@link DataSequence} view over an array of Minorthird instances plus an int label buffer. */
    public class DataSequenceC implements DataSequence {
        Instance[] sequence;
        int[] labels;

        /** Learner whose schema subclasses use to translate between label names and indices. */
        final CRFLearner learner;

        DataSequenceC(CRFLearner cRFLearner) {
            learner = cRFLearner;
        }

        /**
         * Points this view at {@code instanceArr}. The label buffer is reallocated only when it
         * is missing or too short, so it may be longer than the current sequence.
         */
        void init(Instance[] instanceArr) {
            sequence = instanceArr;
            if (instanceArr != null && (labels == null || instanceArr.length > labels.length)) {
                labels = new int[instanceArr.length];
            }
        }

        @Override
        public int length() {
            return sequence.length;
        }

        @Override
        public int y(int i) {
            return labels[i];
        }

        @Override
        public Object x(int i) {
            return sequence[i];
        }

        @Override
        public void set_y(int i, int label) {
            labels[i] = label;
        }
    }

    /**
     * Feature generator that registers the following features:
     * <ul>
     *   <li>label-to-label edge features;
     *   <li>start features;
     *   <li>longer-history edge features, when {@code histsize > 1};
     *   <li>{@link MTFeatureTypes} for the Minorthird features of each token.
     * </ul>
     */
    public class MTFeatureGenImpl extends FeatureGenImpl {

        /**
         * @param cRFLearner learner supplying {@code histsize} and the feature-id factory
         * @param str model-graph specification forwarded to {@link FeatureGenImpl}
         * @param i number of labels forwarded to {@link FeatureGenImpl}
         * @param strArr label names used to name the edge/history features
         */
        public MTFeatureGenImpl(CRFLearner cRFLearner, String str, int i, String[] strArr) throws Exception {
            super(str, i, false);
            Feature[] edgeFeatures = new Feature[strArr.length];
            for (int label = 0; label < strArr.length; label++) {
                edgeFeatures[label] = new Feature(new String[] {SequenceConstants.HISTORY_FEATURE, "1", strArr[label]});
            }
            addFeature(new EdgeFeatures(this, edgeFeatures));
            addFeature(new StartFeatures(this, new Feature(new String[] {SequenceConstants.HISTORY_FEATURE, "1", "null"})));
            if (cRFLearner.histsize > 1) {
                // Row 0 stays null; rows 1..histsize-1 hold history features tagged "2".."histsize".
                Feature[][] historyFeatures = new Feature[cRFLearner.histsize][strArr.length];
                for (int offset = 1; offset < cRFLearner.histsize; offset++) {
                    for (int label = 0; label < strArr.length; label++) {
                        historyFeatures[offset][label] = new Feature(new String[] {SequenceConstants.HISTORY_FEATURE, Integer.toString(offset + 1), strArr[label]});
                    }
                }
                addFeature(new EdgeLinearHistFeatures(this, historyFeatures, cRFLearner.histsize));
            }
            addFeature(new MTFeatureTypes(cRFLearner, this));
        }
    }

    /**
     * Emits one iitb feature per (Minorthird feature of the token, state) pair.
     *
     * <p>The feature id is {@code idFactory.getID(feature) * numStates + stateId}. The value is
     * the instance's weight for that feature.
     */
    class MTFeatureTypes extends FeatureTypes {
        Feature.Looper featureLooper;
        Feature feature;
        int numStates;
        Instance example;
        int stateId;
        private final CRFLearner learner;

        public MTFeatureTypes(CRFLearner cRFLearner, FeatureGenImpl featureGenImpl) {
            super(featureGenImpl);
            learner = cRFLearner;
            numStates = model.numStates();
        }

        /** Moves to the next state, or to the next feature (state 0) once all states are used. */
        void advance() {
            stateId++;
            if (stateId < numStates) {
                return;
            }
            if (featureLooper.hasNext()) {
                feature = featureLooper.nextFeature();
                stateId = 0;
            } else {
                feature = null;
                featureLooper = null;
            }
        }

        /** Positions on the first (feature, state) pair; returns false if the token has no features. */
        public boolean startScan() {
            stateId = -1;
            if (!featureLooper.hasNext()) {
                feature = null;
                return false;
            }
            feature = featureLooper.nextFeature();
            advance();
            return true;
        }

        @Override
        public boolean startScanFeaturesAt(DataSequence dataSequence, int prevPos, int pos) {
            example = (Instance) dataSequence.x(pos);
            featureLooper = example.featureIterator();
            return startScan();
        }

        @Override
        public boolean hasNext() {
            return stateId < numStates && feature != null;
        }

        @Override
        public void next(FeatureImpl featureImpl) {
            featureImpl.yend = stateId;
            featureImpl.ystart = -1;
            featureImpl.val = (float) example.getWeight(feature);
            setFeatureIdentifier((learner.idFactory.getID(feature) * numStates) + stateId, stateId, feature, featureImpl);
            advance();
        }
    }

    /** Sequence used at decoding time; labels are filled in by the CRF and read back as class names. */
    class TestDataSequenceC extends DataSequenceC {

        TestDataSequenceC(CRFLearner cRFLearner, Instance[] instanceArr) {
            super(cRFLearner);
            init(instanceArr);
        }

        /** Converts the int label buffer to {@link ClassLabel}s using the learner's schema. */
        ClassLabel[] getLabels() {
            ClassLabel[] result = new ClassLabel[sequence.length];
            for (int i = 0; i < sequence.length; i++) {
                result[i] = new ClassLabel(learner.schema.getClassName(labels[i]));
            }
            return result;
        }
    }

    /** Sequence used at training time; labels come from each example's best class name. */
    class TrainDataSequenceC extends DataSequenceC {

        TrainDataSequenceC(CRFLearner cRFLearner) {
            super(cRFLearner);
        }

        void init(Example[] exampleArr) {
            super.init((Instance[]) exampleArr);
            if (exampleArr != null) {
                for (int i = 0; i < sequence.length; i++) {
                    labels[i] = learner.schema.getClassIndex(exampleArr[i].getLabel().bestClassName());
                }
            }
        }
    }

    /** Creates a learner with history size 1 whose options are exactly the built-in defaults. */
    public CRFLearner() {
        histsize = 1;
        maxItersHelp = "Number of training iterations over the training set; default set to 100";
        useHighPrecisionArithmeticHelp = "Make the learner use high precision arithmetic.";
        cmmClassifier = null;
        defaults = new Properties();
        defaults.setProperty("modelGraph", "naive");
        defaults.setProperty("debugLvl", "1");
        defaults.setProperty("trainer", "ll");
        options = defaults;
    }

    /** Same as {@code CRFLearner(str, 1)}. */
    public CRFLearner(String str) {
        this(str, 1);
    }

    /**
     * Creates a learner from alternating key/value tokens.
     *
     * @param str option tokens separated by {@code AbstractFormatter.DEFAULT_COLUMN_SEPARATOR};
     *     an unpaired trailing key is ignored, as in the {@code String[]} constructor
     * @param i history size
     */
    public CRFLearner(String str, int i) {
        this();
        histsize = i;
        options = new Properties(defaults);
        StringTokenizer tokens = new StringTokenizer(str, AbstractFormatter.DEFAULT_COLUMN_SEPARATOR);
        while (tokens.hasMoreTokens()) {
            String key = tokens.nextToken();
            if (!tokens.hasMoreTokens()) {
                break;
            }
            options.setProperty(key, tokens.nextToken());
        }
    }

    /** Creates a learner from alternating key/value array entries; an unpaired last entry is ignored. */
    public CRFLearner(String[] strArr) {
        this();
        options = new Properties(defaults);
        for (int i = 0; i < strArr.length - 1; i += 2) {
            options.setProperty(strArr[i], strArr[i + 1]);
        }
    }

    /** Selects the {@code ll} trainer. */
    public void setLogSpaceOption() {
        options.setProperty("trainer", "ll");
    }

    /** Clears the {@code trainer} option so the CRF no longer sees {@code ll}. */
    public void removeLogSpaceOption() {
        options.remove("trainer");
        // options may be backed by defaults; Properties.remove only clears the local table, so the
        // default "ll" would otherwise still be returned by options.getProperty("trainer").
        defaults.remove("trainer");
    }

    /** No-op: the schema is taken from the dataset in {@link #batchTrain}. */
    @Override
    public void setSchema(ExampleSchema exampleSchema) {
    }

    @Override
    public int getHistorySize() {
        return histsize;
    }

    /** Sets the {@code maxIters} option that is passed to the CRF. */
    public void setMaxIters(int i) {
        options.setProperty("maxIters", Integer.toString(i));
    }

    /** Returns the effective {@code maxIters} option, or 100 when it is unset. */
    public int getMaxIters() {
        String property = options.getProperty("maxIters");
        return property != null ? Integer.parseInt(property) : 100;
    }

    public String getMaxItersHelp() {
        return maxItersHelp;
    }

    /** Returns true when the effective {@code trainer} option passed to the CRF is {@code ll}. */
    public boolean getUseHighPrecisionArithmetic() {
        return "ll".equals(options.getProperty("trainer"));
    }

    public void setUseHighPrecisionArithmetic(boolean z) {
        if (z) {
            setLogSpaceOption();
        } else {
            removeLogSpaceOption();
        }
    }

    public String getUseHighPrecisionArithmeticHelp() {
        return useHighPrecisionArithmeticHelp;
    }

    /** Builds the feature generator and CRF model for {@link #schema} and wraps the dataset for training. */
    DataIter allocModel(SequenceDataset sequenceDataset) throws Exception {
        featureGen = new MTFeatureGenImpl(this, options.getProperty("modelGraph"), schema.getNumberOfClasses(), schema.validClassNames());
        System.out.println("Property: " + options.getProperty("trainer"));
        crfModel = new CRF(featureGen.numStates(), histsize, featureGen, options);
        return new CRFDataIter(this, sequenceDataset);
    }

    /**
     * Trains on {@code sequenceDataset} and returns the converted CMM classifier.
     *
     * @throws IllegalStateException wrapping (as its cause) any exception raised during training
     */
    @Override
    public SequenceClassifier batchTrain(SequenceDataset sequenceDataset) {
        try {
            idFactory = new FeatureIdFactory(sequenceDataset);
            schema = sequenceDataset.getSchema();
            return doTrain(allocModel(sequenceDataset));
        } catch (Exception e) {
            throw new IllegalStateException("error in CRF: " + e, e);
        }
    }

    /** Trains the feature generator and the CRF, then refreshes and returns the cached CMM. */
    public SequenceClassifier doTrain(DataIter dataIter) throws Exception {
        featureGen.train(dataIter);
        ProgressCounter progressCounter = new ProgressCounter("training CRF", "iteration");
        crfWs = crfModel.train(dataIter);
        progressCounter.finished();
        // Replace any classifier cached from an earlier training run.
        cmmClassifier = toMinorthirdClassifier();
        return cmmClassifier;
    }

    /**
     * Converts the CRF weights into a {@link CMM}.
     *
     * <p>Each weight is added to the hyperplane of the state recorded in its feature identifier,
     * under that identifier's feature name. All biases are set to 0.
     *
     * @throws IllegalStateException if the learner has not been trained
     */
    public SequenceClassifier toMinorthirdClassifier() {
        if (crfWs == null || schema == null) {
            throw new IllegalStateException("CRF has not been trained yet");
        }
        int numberOfClasses = schema.getNumberOfClasses();
        Hyperplane[] hyperplanes = new Hyperplane[numberOfClasses];
        for (int c = 0; c < numberOfClasses; c++) {
            hyperplanes[c] = new Hyperplane();
            hyperplanes[c].setBias(0.0d);
        }
        for (int w = 0; w < crfWs.length; w++) {
            hyperplanes[featureGen.featureIdentifier(w).stateId].increment((Feature) featureGen.featureIdentifier(w).name, crfWs[w]);
        }
        return new CMM(new SequenceUtils.MultiClassClassifier(schema, hyperplanes), histsize, schema);
    }

    /** Returns the cached CMM classifier, building it on first use. */
    private SequenceClassifier cachedCmmClassifier() {
        if (cmmClassifier == null) {
            cmmClassifier = toMinorthirdClassifier();
        }
        return cmmClassifier;
    }

    /** Labels {@code instanceArr} by applying the trained iitb CRF model. */
    @Override
    public ClassLabel[] classification(Instance[] instanceArr) {
        TestDataSequenceC testSequence = new TestDataSequenceC(this, instanceArr);
        crfModel.apply(testSequence);
        featureGen.mapStatesToLabels(testSequence);
        return testSequence.getLabels();
    }

    @Override
    public String explain(Instance[] instanceArr) {
        return cachedCmmClassifier().explain(instanceArr);
    }

    /** Wraps the CMM explanation (or its text form, if it has no tree) under a "CRF Explanation" node. */
    @Override
    public Explanation getExplanation(Instance[] instanceArr) {
        SequenceClassifier cmm = cachedCmmClassifier();
        Explanation.Node root = new Explanation.Node("CRF Explanation");
        Explanation.Node topNode = cmm.getExplanation(instanceArr).getTopNode();
        if (topNode == null) {
            topNode = new Explanation.Node(cmm.explain(instanceArr));
        }
        root.add(topNode);
        return new Explanation(root);
    }

    /** Shows the learner's history size above a viewer for its converted CMM classifier. */
    @Override
    public Viewer toGUI() {
        ComponentViewer componentViewer = new ComponentViewer(this) {
            @Override
            public JComponent componentFor(Object obj) {
                JPanel panel = new JPanel();
                panel.setLayout(new BorderLayout());
                panel.add(new JLabel("CRFLearner: historySize=" + CRFLearner.this.histsize), BorderLayout.NORTH);
                SmartVanillaViewer classifierViewer = new SmartVanillaViewer(CRFLearner.this.cachedCmmClassifier());
                classifierViewer.setSuperView(this);
                panel.add(classifierViewer, BorderLayout.SOUTH);
                panel.setBorder(new TitledBorder("CRFLearner"));
                return new JScrollPane(panel);
            }
        };
        componentViewer.setContent(this);
        return componentViewer;
    }

    public FeatureIdFactory getIdFactory() {
        return idFactory;
    }

    public void setIdFactory(FeatureIdFactory featureIdFactory) {
        idFactory = featureIdFactory;
    }
}
