package edu.cmu.minorthird.classify.multi;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.multi.MultiDataset;
import edu.cmu.minorthird.classify.transform.PredictedClassTransform;
import edu.cmu.minorthird.classify.transform.TransformingMultiClassifier;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ParallelViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.text.DecimalFormat;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/multi/MultiCrossValidatedDataset.class */
/**
 * Cross-validates a {@link ClassifierLearner} on a {@link MultiDataset}.
 *
 * <p>The dataset is partitioned by a {@link Splitter}. For every fold a multi-classifier is
 * trained on the training partition and applied to the test partition. For each fold the
 * classified test partition is kept, and optionally the classified training partition too.
 * The test results of all folds are accumulated into a single {@link MultiEvaluation}.
 */
public class MultiCrossValidatedDataset implements Visible {
    private static final Logger log = Logger.getLogger(MultiCrossValidatedDataset.class);

    /** Classified test partition of each fold, indexed by fold number. */
    private final MultiClassifiedDataset[] cds;
    /** Classified training partition of each fold, or {@code null} if train partitions are not saved. */
    private final MultiClassifiedDataset[] trainCds;
    /** Evaluation extended with the classifier and test partition of every fold. */
    private final MultiEvaluation v;

    /**
     * Cross-validates without saving training partitions and without annotation.
     * Equivalent to {@code this(classifierLearner, multiDataset, splitter, false, false)}.
     */
    public MultiCrossValidatedDataset(ClassifierLearner classifierLearner, MultiDataset multiDataset, Splitter splitter) {
        this(classifierLearner, multiDataset, splitter, false, false);
    }

    /**
     * Cross-validates without annotation.
     * Equivalent to {@code this(classifierLearner, multiDataset, splitter, saveTrainPartitions, false)}.
     */
    public MultiCrossValidatedDataset(ClassifierLearner classifierLearner, MultiDataset multiDataset, Splitter splitter,
            boolean saveTrainPartitions) {
        this(classifierLearner, multiDataset, splitter, saveTrainPartitions, false);
    }

    /**
     * Runs the train/test loop over every partition produced by {@code splitter}.
     *
     * @param classifierLearner   learner trained once per fold
     * @param multiDataset        data to partition; its multi-schema defines the evaluation
     * @param splitter            strategy used to create the train/test partitions
     * @param saveTrainPartitions if true, each fold's classified training partition is also kept
     *                            and shown in {@link #toGUI()}
     * @param annotate            if true, each training partition is replaced by
     *                            {@code annotateData()}, and the learned classifier is wrapped in
     *                            a {@link TransformingMultiClassifier} using a
     *                            {@link PredictedClassTransform} built from that classifier
     */
    public MultiCrossValidatedDataset(ClassifierLearner classifierLearner, MultiDataset multiDataset, Splitter splitter,
            boolean saveTrainPartitions, boolean annotate) {
        MultiDataset.MultiSplit split = multiDataset.MultiSplit(splitter);
        int numPartitions = split.getNumPartitions();
        this.cds = new MultiClassifiedDataset[numPartitions];
        this.trainCds = saveTrainPartitions ? new MultiClassifiedDataset[numPartitions] : null;
        this.v = new MultiEvaluation(multiDataset.getMultiSchema());
        ProgressCounter progressCounter = new ProgressCounter("train/test", "fold", numPartitions);
        for (int k = 0; k < numPartitions; k++) {
            MultiDataset trainData = split.getTrain(k);
            if (annotate) {
                trainData = trainData.annotateData();
            }
            MultiDataset testData = split.getTest(k);
            log.info("splitting with " + splitter + ", preparing to train on " + trainData.size()
                    + " and test on " + testData.size());

            MultiClassifier classifier = new MultiDatasetClassifierTeacher(trainData).train(classifierLearner);
            if (annotate) {
                // The transform is built from the unwrapped classifier, then wraps it.
                classifier = new TransformingMultiClassifier(classifier, new PredictedClassTransform(classifier));
            }

            MultiDatasetIndex testIndex = new MultiDatasetIndex(testData);
            this.cds[k] = new MultiClassifiedDataset(classifier, testData, testIndex);
            if (this.trainCds != null) {
                // NOTE(review): the training partition is paired with the *test* partition's index,
                // not an index of trainData; confirm this is intended.
                this.trainCds[k] = new MultiClassifiedDataset(classifier, trainData, testIndex);
            }
            this.v.extend(classifier, testData);
            log.info("splitting with " + splitter + ", stored classified dataset");
            progressCounter.progress();
        }
        progressCounter.finished();
    }

    /**
     * Builds a {@code "; "}-separated list with one entry per class of every component schema.
     * Each entry is the class's count in {@code multiDatasetIndex} followed by the class name.
     * (Currently not called from within this class.)
     */
    private String classDistributionString(MultiExampleSchema multiExampleSchema, MultiDatasetIndex multiDatasetIndex) {
        StringBuilder buf = new StringBuilder();
        DecimalFormat fmt = new DecimalFormat("#####");
        for (ExampleSchema exampleSchema : multiExampleSchema.getSchemas()) {
            for (int i = 0; i < exampleSchema.getNumberOfClasses(); i++) {
                if (buf.length() > 0) {
                    buf.append("; ");
                }
                String label = exampleSchema.getClassName(i);
                buf.append(fmt.format(multiDatasetIndex.size(label)))
                        .append(AbstractFormatter.DEFAULT_COLUMN_SEPARATOR)
                        .append(label);
            }
        }
        return buf.toString();
    }

    /**
     * Builds a tabbed viewer containing:
     * <ul>
     *   <li>one "Test Partition n" view per fold;</li>
     *   <li>one "Train Partition n" view per fold, when training partitions were saved;</li>
     *   <li>an "Overall Evaluation" view.</li>
     * </ul>
     */
    @Override // edu.cmu.minorthird.util.gui.Visible
    public Viewer toGUI() {
        ParallelViewer viewer = new ParallelViewer();
        for (int i = 0; i < this.cds.length; i++) {
            final int k = i; // captured by the anonymous viewer below
            viewer.addSubView("Test Partition " + (i + 1), new TransformedViewer(this.cds[0].toGUI()) {
                @Override
                public Object transform(Object obj) {
                    return MultiCrossValidatedDataset.this.cds[k];
                }
            });
        }
        if (this.trainCds != null) {
            for (int i = 0; i < this.trainCds.length; i++) {
                final int k = i; // captured by the anonymous viewer below
                viewer.addSubView("Train Partition " + (i + 1), new TransformedViewer(this.trainCds[0].toGUI()) {
                    @Override
                    public Object transform(Object obj) {
                        return MultiCrossValidatedDataset.this.trainCds[k];
                    }
                });
            }
        }
        viewer.addSubView("Overall Evaluation", new TransformedViewer(this.v.toGUI()) {
            @Override
            public Object transform(Object obj) {
                return ((MultiCrossValidatedDataset) obj).v;
            }
        });
        viewer.setContent(this);
        return viewer;
    }

    /** Returns the evaluation accumulated over the test partitions of all folds. */
    public MultiEvaluation getEvaluation() {
        return this.v;
    }

    public static void main(String[] strArr) {
        System.out.println("CrossValidatedDataset");
    }
}
