package edu.cmu.minorthird.classify.sequential;

import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.experiments.Evaluation;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ParallelViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.util.TreeSet;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/sequential/CrossValidatedSequenceDataset.class */
/**
 * Trains a sequence learner on the training portion of every partition
 * produced by a {@link Splitter}, classifies the matching test portion, and
 * accumulates the results into a single {@link Evaluation}.
 *
 * <p>The per-partition classified datasets and the overall evaluation can be
 * browsed through {@link #toGUI()}.
 */
public class CrossValidatedSequenceDataset implements Visible {
    private static final Logger log = Logger.getLogger(CrossValidatedSequenceDataset.class);

    /** Classified test data, one entry per partition. */
    private final ClassifiedSequenceDataset[] cds;

    /** Classified training data, one entry per partition; {@code null} when not requested. */
    private final ClassifiedSequenceDataset[] trainCds;

    /** Evaluation extended with the test data of every partition. */
    private final Evaluation v;

    /**
     * Runs the experiment without retaining the classified training partitions.
     *
     * @param sequenceClassifierLearner learner trained on each training partition
     * @param sequenceDataset dataset to be split
     * @param splitter strategy used to split {@code sequenceDataset}
     */
    public CrossValidatedSequenceDataset(SequenceClassifierLearner sequenceClassifierLearner, SequenceDataset sequenceDataset, Splitter splitter) {
        this(sequenceClassifierLearner, sequenceDataset, splitter, false);
    }

    /**
     * Runs the experiment, optionally retaining the classified training partitions.
     *
     * @param sequenceClassifierLearner learner trained on each training partition
     * @param sequenceDataset dataset to be split
     * @param splitter strategy used to split {@code sequenceDataset}
     * @param z when {@code true}, the classifier learned on each partition is also
     *     applied to that partition's training data and the result is kept for
     *     display in {@link #toGUI()}
     */
    public CrossValidatedSequenceDataset(SequenceClassifierLearner sequenceClassifierLearner, SequenceDataset sequenceDataset, Splitter splitter, boolean z) {
        Dataset.Split split = sequenceDataset.split(splitter);
        int numPartitions = split.getNumPartitions();
        this.cds = new ClassifiedSequenceDataset[numPartitions];
        this.trainCds = z ? new ClassifiedSequenceDataset[numPartitions] : null;
        this.v = new Evaluation(sequenceDataset.getSchema());
        ProgressCounter progressCounter = new ProgressCounter("train/test", "fold", numPartitions);
        for (int i = 0; i < split.getNumPartitions(); i++) {
            SequenceDataset trainData = (SequenceDataset) split.getTrain(i);
            SequenceDataset testData = (SequenceDataset) split.getTest(i);
            log.info("splitting with " + splitter + ", preparing to train on " + trainData.size() + " and test on " + testData.size());
            SequenceClassifier classifier = new DatasetSequenceClassifierTeacher(trainData).train(sequenceClassifierLearner);
            this.cds[i] = new ClassifiedSequenceDataset(classifier, testData);
            if (this.trainCds != null) {
                this.trainCds[i] = new ClassifiedSequenceDataset(classifier, trainData);
            }
            // NOTE(review): the last argument is the constant 0 for every
            // partition; confirm against Evaluation.extend whether the
            // partition index was intended here.
            this.v.extend(this.cds[i].getClassifier(), testData, 0);
            log.info("splitting with " + splitter + ", stored classified dataset");
            progressCounter.progress();
        }
        progressCounter.finished();
    }

    /**
     * Returns the evaluation accumulated over the test data of all partitions.
     *
     * @return the overall evaluation
     */
    public Evaluation getEvaluation() {
        return this.v;
    }

    /**
     * Logs, at debug level, the sorted set of subpopulation ids found in a dataset.
     *
     * @param str prefix for the log message
     * @param sequenceDataset dataset whose examples are inspected
     */
    private void showSubpops(String str, SequenceDataset sequenceDataset) {
        TreeSet<Object> subpopulationIds = new TreeSet<Object>();
        Example.Looper it = sequenceDataset.iterator();
        while (it.hasNext()) {
            subpopulationIds.add(it.nextExample().getSubpopulationId());
        }
        log.debug(str + subpopulationIds.toString());
    }

    /**
     * Builds a viewer with one sub-view per test partition, one per training
     * partition (only when training results were retained), and one for the
     * overall evaluation.
     *
     * @return a viewer whose content is this object
     */
    @Override
    public Viewer toGUI() {
        ParallelViewer parallelViewer = new ParallelViewer();
        for (int i = 0; i < this.cds.length; i++) {
            final int k = i;
            // A fresh viewer is created from cds[0] for every partition; the
            // object actually displayed is selected by transform().
            parallelViewer.addSubView("Test Partition " + (i + 1), new TransformedViewer(this.cds[0].toGUI()) {
                @Override
                public Object transform(Object obj) {
                    return CrossValidatedSequenceDataset.this.cds[k];
                }
            });
        }
        if (this.trainCds != null) {
            for (int i = 0; i < this.trainCds.length; i++) {
                final int k = i;
                // NOTE(review): the wrapped viewer is built from cds[0], not
                // trainCds[0]; kept as in the original.
                parallelViewer.addSubView("Train Partition " + (i + 1), new TransformedViewer(this.cds[0].toGUI()) {
                    @Override
                    public Object transform(Object obj) {
                        return CrossValidatedSequenceDataset.this.trainCds[k];
                    }
                });
            }
        }
        parallelViewer.addSubView("Overall Evaluation", new TransformedViewer(this.v.toGUI()) {
            @Override
            public Object transform(Object obj) {
                return ((CrossValidatedSequenceDataset) obj).v;
            }
        });
        parallelViewer.setContent(this);
        return parallelViewer;
    }

    /**
     * Entry point placeholder; intentionally does nothing.
     *
     * @param strArr ignored
     */
    public static void main(String[] strArr) {
    }
}
