package defpackage;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.sequential.BatchSequenceClassifierLearner;
import edu.cmu.minorthird.classify.sequential.CMMLearner;
import edu.cmu.minorthird.classify.sequential.DatasetSequenceClassifierTeacher;
import edu.cmu.minorthird.classify.sequential.SequenceClassifier;
import edu.cmu.minorthird.classify.sequential.SequenceClassifierLearner;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.classify.transform.AugmentedInstance;
import edu.cmu.minorthird.util.MathUtil;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ParallelViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.util.Iterator;
import org.apache.log4j.Logger;
import org.apache.log4j.helpers.DateLayout;

/* loaded from: input_file:StackedSequenceLearner.class */
/**
 * A {@link BatchSequenceClassifierLearner} that trains a stack of sequence classifiers.
 *
 * <p>Level 0 is trained on the original dataset with the base learner. Every further level is
 * trained on a copy of the previous level's dataset in which each instance has been augmented
 * (see {@link #stackInstance}) with features describing the predicted labels at nearby positions
 * of its sequence. Those predictions come from classifiers trained on the other folds of
 * {@link StackingParams#splitter} (see {@link #stackDataset}). At classification time the same
 * augmentation is applied using the previous level's actual predictions.
 */
public class StackedSequenceLearner implements BatchSequenceClassifierLearner {
    private static final Logger log = Logger.getLogger(StackedSequenceLearner.class);

    /**
     * Label used in feature names for window positions that fall outside the sequence. This is
     * the same string ("NULL") the decompiled code obtained from log4j's
     * {@code DateLayout.NULL_DATE_FORMAT}, which had nothing to do with date formatting here.
     */
    private static final String OUT_OF_SEQUENCE_LABEL = "NULL";

    private final SequenceClassifierLearner baseLearner;
    private final StackingParams params;

    /**
     * The classifier produced by {@link #batchTrain}: one sequence classifier per stacking level.
     *
     * <p>Static because it never needs the enclosing learner; it only uses the static
     * {@link #stackInstance} helper.
     */
    private static class StackedSequenceClassifier implements SequenceClassifier, Visible {
        /** Classifier for each level; index 0 is the base level. */
        private final SequenceClassifier[] levels;
        /** Parameters shared with the learner that produced this classifier. */
        private final StackingParams params;

        StackedSequenceClassifier(SequenceClassifier[] levels, StackingParams params) {
            this.levels = levels;
            this.params = params;
        }

        /**
         * Classifies a sequence by running every level in turn.
         *
         * <p>Level 0 sees the raw instances. Level {@code k > 0} sees level {@code k-1}'s input,
         * augmented with features from level {@code k-1}'s predictions. This mirrors how
         * {@link StackedSequenceLearner#batchTrain} builds each level's training data, which
         * stacks on top of the previous level's already-stacked dataset.
         *
         * @param instanceArr the sequence to classify
         * @return the labels predicted by the top level
         */
        @Override // edu.cmu.minorthird.classify.sequential.SequenceClassifier
        public ClassLabel[] classification(Instance[] instanceArr) {
            ClassLabel[] labels = levels[0].classification(instanceArr);
            Instance[] levelInput = instanceArr;
            for (int level = 1; level < levels.length; level++) {
                // A fresh array per level: the previous level's input must stay intact
                // because the next level stacks on top of it.
                Instance[] stacked = new Instance[levelInput.length];
                for (int pos = 0; pos < levelInput.length; pos++) {
                    stacked[pos] = stackInstance(pos, levelInput[pos], labels, params);
                }
                labels = levels[level].classification(stacked);
                levelInput = stacked;
            }
            return labels;
        }

        @Override // edu.cmu.minorthird.classify.sequential.SequenceClassifier
        public String explain(Instance[] instanceArr) {
            return "not implemented";
        }

        /** Builds a viewer with one sub-view per stacking level. */
        @Override // edu.cmu.minorthird.util.gui.Visible
        public Viewer toGUI() {
            ParallelViewer parallelViewer = new ParallelViewer();
            for (int level = 0; level < levels.length; level++) {
                final int k = level;
                parallelViewer.addSubView(
                        "Level " + k + " classifier",
                        new TransformedViewer(new SmartVanillaViewer(levels[k])) {
                            @Override // edu.cmu.minorthird.util.gui.TransformedViewer
                            public Object transform(Object obj) {
                                return ((StackedSequenceClassifier) obj).levels[k];
                            }
                        });
            }
            parallelViewer.setContent(this);
            return parallelViewer;
        }
    }

    /** Tunable parameters that control how predictions are stacked into features. */
    public static class StackingParams {
        /** Number of preceding positions whose predictions become features. */
        public int historySize = 5;
        /** Number of following positions whose predictions become features. */
        public int futureSize = 5;
        /** Number of levels trained on top of the base level. */
        public int stackingDepth = 1;
        /** When confidences are used, pass them through {@code MathUtil.logistic} first. */
        public boolean useLogistic = true;
        /** Also add a feature for the prediction at the position itself. */
        public boolean useTargetPrediction = true;
        /** Use the best label's weight as the feature value instead of 1.0. */
        public boolean useConfidence = true;
        /** Splitter used to obtain out-of-fold predictions in {@code stackDataset}. */
        public Splitter splitter = new CrossValSplitter(5);
        /** Fold count; {@link #setCrossValSplits} keeps {@link #splitter} in sync with it. */
        int crossValSplits = 5;

        public int getHistorySize() {
            return this.historySize;
        }

        public void setHistorySize(int i) {
            this.historySize = i;
        }

        public int getFutureSize() {
            return this.futureSize;
        }

        public void setFutureSize(int i) {
            this.futureSize = i;
        }

        public boolean getUseLogisticOnConfidences() {
            return this.useLogistic;
        }

        public void setUseLogisticOnConfidences(boolean z) {
            this.useLogistic = z;
        }

        public boolean getUseConfidences() {
            return this.useConfidence;
        }

        public void setUseConfidences(boolean z) {
            this.useConfidence = z;
        }

        public boolean getUseTargetPrediction() {
            return this.useTargetPrediction;
        }

        public void setUseTargetPrediction(boolean z) {
            this.useTargetPrediction = z;
        }

        public int getStackingDepth() {
            return this.stackingDepth;
        }

        public void setStackingDepth(int i) {
            this.stackingDepth = i;
        }

        public int getCrossValSplits() {
            return this.crossValSplits;
        }

        /** Sets the fold count and replaces {@link #splitter} with a matching cross-validation splitter. */
        public void setCrossValSplits(int i) {
            this.splitter = new CrossValSplitter(i);
            this.crossValSplits = i;
        }
    }

    @Override // edu.cmu.minorthird.classify.sequential.SequenceClassifierLearner
    public int getHistorySize() {
        return this.params.historySize;
    }

    public void setHistorySize(int i) {
        this.params.setHistorySize(i);
    }

    /** Returns the live (mutable) parameter object used by this learner. */
    public StackingParams getParams() {
        return this.params;
    }

    /** Stacks a {@code CMMLearner} over a {@code VotedPerceptron}, using default parameters. */
    public StackedSequenceLearner() {
        this(new CMMLearner(new VotedPerceptron(), 0), new StackingParams());
    }

    /** Stacks {@code baseLearner} {@code stackingDepth} levels deep. */
    public StackedSequenceLearner(SequenceClassifierLearner baseLearner, int stackingDepth) {
        this(baseLearner, new StackingParams());
        this.params.setStackingDepth(stackingDepth);
    }

    /** Stacks a {@code CMMLearner} built from {@code classifierLearner} {@code stackingDepth} levels deep. */
    public StackedSequenceLearner(ClassifierLearner classifierLearner, int stackingDepth) {
        this(new CMMLearner(classifierLearner, 0), new StackingParams());
        this.params.setStackingDepth(stackingDepth);
    }

    /** As {@link #StackedSequenceLearner(SequenceClassifierLearner, int)}, with {@code windowSize} used for both history and future. */
    public StackedSequenceLearner(SequenceClassifierLearner baseLearner, int stackingDepth, int windowSize) {
        this(baseLearner, stackingDepth);
        this.params.setHistorySize(windowSize);
        this.params.setFutureSize(windowSize);
    }

    /** As {@link #StackedSequenceLearner(ClassifierLearner, int)}, with {@code windowSize} used for both history and future. */
    public StackedSequenceLearner(ClassifierLearner classifierLearner, int stackingDepth, int windowSize) {
        this(classifierLearner, stackingDepth);
        this.params.setHistorySize(windowSize);
        this.params.setFutureSize(windowSize);
    }

    /** Common initialization; every public constructor funnels through here. */
    private StackedSequenceLearner(SequenceClassifierLearner baseLearner, StackingParams params) {
        this.baseLearner = baseLearner;
        this.params = params;
    }

    /** Intentionally ignored: this learner does not use an externally supplied schema. */
    @Override // edu.cmu.minorthird.classify.sequential.SequenceClassifierLearner
    public void setSchema(ExampleSchema exampleSchema) {
    }

    /**
     * Trains {@code stackingDepth + 1} levels of classifiers.
     *
     * <p>Note: calls {@code setHistorySize(0)} on the supplied dataset itself before training.
     *
     * @param sequenceDataset the training data (level 0 input)
     * @return a classifier that applies all levels in order
     */
    @Override // edu.cmu.minorthird.classify.sequential.BatchSequenceClassifierLearner
    public SequenceClassifier batchTrain(SequenceDataset sequenceDataset) {
        int depth = this.params.stackingDepth;
        SequenceClassifier[] levels = new SequenceClassifier[depth + 1];
        SequenceDataset levelData = sequenceDataset;
        levelData.setHistorySize(0);
        ProgressCounter progressCounter = new ProgressCounter("training stacked learner", "stacking level", depth + 1);
        for (int level = 0; level <= depth; level++) {
            levels[level] = new DatasetSequenceClassifierTeacher(levelData).train(this.baseLearner);
            // Only build the next level's data if a next level will actually be trained.
            if (level < depth) {
                levelData = stackDataset(levelData);
            }
            progressCounter.progress();
        }
        progressCounter.finished();
        return new StackedSequenceClassifier(levels, this.params);
    }

    /**
     * Returns a new dataset in which every example of {@code sequenceDataset} is augmented with
     * out-of-fold prediction features.
     *
     * <p>For each partition of {@code params.splitter}, the base learner is trained on the
     * train part. Each sequence of the test part is then classified, and its examples are
     * stacked with those predictions; the original labels are kept. Sequences are added
     * partition by partition.
     *
     * @param sequenceDataset dataset to stack; it is not modified
     * @return the stacked dataset, with history size set to 0
     */
    public SequenceDataset stackDataset(SequenceDataset sequenceDataset) {
        SequenceDataset stacked = new SequenceDataset();
        Dataset.Split split = sequenceDataset.split(this.params.splitter);
        int numFolds = split.getNumPartitions();
        ProgressCounter progressCounter = new ProgressCounter("labeling for stacking", "fold", numFolds);
        for (int fold = 0; fold < numFolds; fold++) {
            SequenceDataset trainData = (SequenceDataset) split.getTrain(fold);
            SequenceDataset testData = (SequenceDataset) split.getTest(fold);
            log.info("splitting with " + this.params.splitter + ", preparing to train on "
                    + trainData.size() + " and test on " + testData.size());
            SequenceClassifier foldClassifier = new DatasetSequenceClassifierTeacher(trainData).train(this.baseLearner);
            for (Iterator it = testData.sequenceIterator(); it.hasNext(); ) {
                Example[] sequence = (Example[]) it.next();
                ClassLabel[] predictions = foldClassifier.classification(sequence);
                Example[] stackedSequence = new Example[sequence.length];
                for (int pos = 0; pos < sequence.length; pos++) {
                    Instance augmented = stackInstance(pos, sequence[pos].asInstance(), predictions, this.params);
                    stackedSequence[pos] = new Example(augmented, sequence[pos].getLabel());
                }
                stacked.addSequence(stackedSequence);
            }
            log.info("splitting with " + this.params.splitter + ", stored classified dataset");
            progressCounter.progress();
        }
        progressCounter.finished();
        stacked.setHistorySize(0);
        return stacked;
    }

    /**
     * Augments the instance at {@code position} with one feature per window offset.
     *
     * <p>Offsets run from {@code -historySize} to {@code +futureSize}. Offset 0 is included only
     * when {@code useTargetPrediction} is set.
     *
     * <ul>
     *   <li>For an offset whose position lies outside {@code labels}, the feature name uses the
     *       label {@code "NULL"} and the value is 1.0.</li>
     *   <li>Otherwise the name uses that position's best class name. The value is 1.0, or, when
     *       {@code useConfidence} is set, the best weight (optionally passed through
     *       {@code MathUtil.logistic}).</li>
     * </ul>
     *
     * @param position index of {@code instance} within its sequence
     * @param instance the instance to augment
     * @param labels predicted labels for the whole sequence
     * @param stackingParams window and feature-value settings
     * @return an {@code AugmentedInstance} wrapping {@code instance}
     */
    public static Instance stackInstance(int position, Instance instance, ClassLabel[] labels, StackingParams stackingParams) {
        int numFeatures = stackingParams.historySize + stackingParams.futureSize
                + (stackingParams.useTargetPrediction ? 1 : 0);
        String[] names = new String[numFeatures];
        double[] values = new double[numFeatures];
        int k = 0;
        for (int offset = -stackingParams.historySize; offset <= stackingParams.futureSize; offset++) {
            if (offset == 0 && !stackingParams.useTargetPrediction) {
                continue; // the position's own prediction is excluded
            }
            int neighbor = position + offset;
            if (neighbor < 0 || neighbor >= labels.length) {
                names[k] = stackFeatureName(offset, OUT_OF_SEQUENCE_LABEL);
                values[k] = 1.0d;
            } else {
                ClassLabel label = labels[neighbor];
                names[k] = stackFeatureName(offset, label.bestClassName());
                double value = 1.0d;
                if (stackingParams.useConfidence) {
                    double bestWeight = label.bestWeight();
                    value = stackingParams.useLogistic ? MathUtil.logistic(bestWeight) : bestWeight;
                }
                values[k] = value;
            }
            k++;
        }
        return new AugmentedInstance(instance, names, values);
    }

    /**
     * Returns the stacked-feature name for a relative offset and label: {@code pred.prev.N.L},
     * {@code pred.next.N.L} or {@code pred.here.L}.
     */
    private static String stackFeatureName(int offset, String label) {
        if (offset < 0) {
            return "pred.prev." + (-offset) + "." + label;
        }
        if (offset > 0) {
            return "pred.next." + offset + "." + label;
        }
        return "pred.here." + label;
    }
}
