package edu.cmu.minorthird.classify.multi;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.BatchVersion;
import edu.cmu.minorthird.classify.CascadingBinaryLearner;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.DatasetLoader;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.FeatureFactory;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.multi.MultiExample;
import edu.cmu.minorthird.util.Saveable;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import edu.cmu.minorthird.util.gui.ZoomedViewer;
import java.awt.Component;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;
import javax.swing.JComponent;
import javax.swing.JList;
import javax.swing.JScrollPane;
import javax.swing.ListCellRenderer;

/* loaded from: input_file:edu/cmu/minorthird/classify/multi/MultiDataset.class */
/**
 * A dataset of {@link MultiExample}s, i.e. examples that carry a label in each
 * of several dimensions.
 * <p>
 * Plain {@link Example}s are rejected: add data with {@link #addMulti(MultiExample)}
 * and read it back with {@link #multiIterator()}. For each label dimension, the
 * set of class names seen so far is tracked as examples are added. These sets
 * back {@link #getSchema()} and {@link #getMultiSchema()}.
 */
public class MultiDataset implements Dataset, Visible, Saveable {
    private static final long serialVersionUID = 1;
    private final int CURRENT_SERIAL_VERSION = 1;
    /** Labeled examples, stored in the form returned by {@code factory.compressMulti}. */
    protected ArrayList examples = new ArrayList();
    /** Unlabeled instances, stored in the form returned by {@code factory.compress}. */
    protected ArrayList unlabeledExamples = new ArrayList();
    /** One sorted set of class names per label dimension; {@code null} until the first {@link #addMulti}. */
    protected Set[] classNameSets = null;
    /** Factory used to compress examples and instances as they are added. */
    protected FeatureFactory factory = new FeatureFactory();
    /** Count of added examples whose {@code getLabel()} is positive (public for backward compatibility). */
    public int numPosExamples = 0;
    private static final String FORMAT_NAME = "Minorthird MultiDataset";
    /** Number of cross-validation folds used by {@link #annotateData()}. */
    private static final int DEFAULT_ANNOTATION_FOLDS = 9;

    /**
     * Split view whose train and test partitions are materialized as new
     * {@code MultiDataset}s, built from the iterators of the wrapped {@link Splitter}.
     */
    public class MultiSplit {
        Splitter splitter;
        private final MultiDataset owner;

        /**
         * @param multiDataset dataset used to rebuild partitions from iterators
         * @param splitter splitter whose partitions are exposed; presumably already
         *     primed with {@code splitter.split(...)}, as {@code MultiSplit(Splitter)} does
         */
        public MultiSplit(MultiDataset multiDataset, Splitter splitter) {
            this.owner = multiDataset;
            this.splitter = splitter;
        }

        /** @return number of partitions reported by the wrapped splitter */
        public int getNumPartitions() {
            return splitter.getNumPartitions();
        }

        /** @return a new MultiDataset holding the training examples of partition {@code i} */
        public MultiDataset getTrain(int i) {
            return owner.invertMultiIteration(splitter.getTrain(i));
        }

        /** @return a new MultiDataset holding the test examples of partition {@code i} */
        public MultiDataset getTest(int i) {
            return owner.invertMultiIteration(splitter.getTest(i));
        }
    }

    /** Scrollable list viewer that renders each MultiExample of a MultiDataset on one row. */
    public static class SimpleDatasetViewer extends ComponentViewer {
        /** Only MultiDatasets are accepted, because {@link #componentFor} casts to that type. */
        @Override
        public boolean canReceive(Object obj) {
            return obj instanceof MultiDataset;
        }

        @Override
        public JComponent componentFor(Object obj) {
            MultiDataset dataset = (MultiDataset) obj;
            final MultiExample[] rows = new MultiExample[dataset.size()];
            int next = 0;
            for (MultiExample.Looper it = dataset.multiIterator(); it.hasNext(); ) {
                rows[next++] = it.nextMultiExample();
            }
            JList list = new JList(rows);
            list.setCellRenderer(new ListCellRenderer() {
                @Override
                public Component getListCellRendererComponent(
                        JList source, Object value, int index, boolean isSelected, boolean cellHasFocus) {
                    // 60 is the width passed to the concise renderer.
                    return GUI.conciseMultiExampleRendererComponent(rows[index], 60, isSelected);
                }
            });
            monitorSelections(list);
            return new JScrollPane(list);
        }
    }

    /**
     * Returns the schema of the first label dimension. The shared
     * {@link ExampleSchema#BINARY_EXAMPLE_SCHEMA} is returned when it is equal.
     * At least one example must have been added first; otherwise
     * {@code classNameSets} is still {@code null}.
     */
    @Override
    public ExampleSchema getSchema() {
        return schemaFor(classNameSets[0]);
    }

    /**
     * Returns one schema per label dimension. Any dimension equal to the binary
     * schema is replaced by the shared {@link ExampleSchema#BINARY_EXAMPLE_SCHEMA}.
     * At least one example must have been added first.
     */
    public MultiExampleSchema getMultiSchema() {
        ExampleSchema[] schemas = new ExampleSchema[classNameSets.length];
        for (int i = 0; i < schemas.length; i++) {
            schemas[i] = schemaFor(classNameSets[i]);
        }
        return new MultiExampleSchema(schemas);
    }

    /** Builds a schema from a set of class names, returning the shared binary instance when equal. */
    private static ExampleSchema schemaFor(Set classNames) {
        ExampleSchema schema =
                new ExampleSchema((String[]) classNames.toArray(new String[classNames.size()]));
        // Presumably canonicalized so callers can compare against the shared instance by identity.
        return schema.equals(ExampleSchema.BINARY_EXAMPLE_SCHEMA) ? ExampleSchema.BINARY_EXAMPLE_SCHEMA : schema;
    }

    /** Adds an unlabeled instance, compressed with this dataset's factory. */
    public void addUnlabeled(Instance instance) {
        unlabeledExamples.add(factory.compress(instance));
    }

    /** @return a looper over the instances added with {@link #addUnlabeled} */
    public Instance.Looper iteratorOverUnlabeled() {
        return new Instance.Looper(unlabeledExamples);
    }

    /** @return number of unlabeled instances */
    public int sizeUnlabeled() {
        return unlabeledExamples.size();
    }

    /** @return true if at least one unlabeled instance has been added */
    public boolean hasUnlabeled() {
        return !unlabeledExamples.isEmpty();
    }

    /**
     * Not supported: a MultiDataset only holds MultiExamples.
     *
     * @throws IllegalArgumentException always; use {@link #addMulti(MultiExample)}
     */
    @Override
    public void add(Example example) {
        throw new IllegalArgumentException("You must add a MultiExample to a MultiDataset");
    }

    /**
     * Not supported: a MultiDataset only holds MultiExamples.
     *
     * @throws IllegalArgumentException always; use {@link #addMulti(MultiExample)}
     */
    @Override
    public void add(Example example, boolean z) {
        throw new IllegalArgumentException("You must add a MultiExample to a MultiDataset");
    }

    /**
     * Adds a labeled multi-example, compressed with this dataset's factory.
     * <p>
     * The first example fixes the number of label dimensions. Each dimension's
     * possible labels are merged into that dimension's class-name set, and
     * {@link #numPosExamples} is incremented when {@code getLabel()} is positive.
     *
     * @throws IllegalArgumentException if the example's dimension count differs
     *     from that of earlier examples
     */
    public void addMulti(MultiExample multiExample) {
        int dimensions = multiExample.getMultiLabel().numDimensions();
        if (classNameSets == null) {
            classNameSets = new Set[dimensions];
            for (int d = 0; d < dimensions; d++) {
                classNameSets[d] = new TreeSet();
            }
        }
        if (classNameSets.length != dimensions) {
            throw new IllegalArgumentException(
                    "This example does not have the same number of dimensions as previous examples");
        }
        examples.add(factory.compressMulti(multiExample));
        Set[] possibleLabels = multiExample.getMultiLabel().possibleLabels();
        for (int d = 0; d < classNameSets.length; d++) {
            classNameSets[d].addAll(possibleLabels[d]);
        }
        if (multiExample.getLabel().isPositive()) {
            numPosExamples++;
        }
    }

    /**
     * Splits this dataset into one {@link BasicDataset} per component example.
     * Dataset {@code d} receives {@code getExamples()[d]} from every MultiExample.
     *
     * @return one dataset per component, or an empty array if this dataset is empty
     */
    public Dataset[] separateDatasets() {
        if (examples.isEmpty()) {
            return new BasicDataset[0];
        }
        int width = ((MultiExample) examples.get(0)).getExamples().length;
        BasicDataset[] datasets = new BasicDataset[width];
        for (int d = 0; d < width; d++) {
            datasets[d] = new BasicDataset();
        }
        for (int i = 0; i < examples.size(); i++) {
            Example[] parts = ((MultiExample) examples.get(i)).getExamples();
            for (int d = 0; d < parts.length; d++) {
                datasets[d].add(parts[d]);
            }
        }
        return datasets;
    }

    /** @return number of added examples whose label is positive */
    public int getNumPosExamples() {
        return numPosExamples;
    }

    /**
     * Not supported for MultiDatasets.
     *
     * @throws IllegalArgumentException always; use {@link #multiIterator()}
     */
    @Override
    public Example.Looper iterator() {
        throw new IllegalArgumentException("Must use multiIterator to iterate through MultiExamples");
    }

    /** @return a looper over the stored (compressed) multi-examples, in current order */
    public MultiExample.Looper multiIterator() {
        return new MultiExample.Looper(examples);
    }

    /** @return number of labeled examples (unlabeled instances are not counted) */
    @Override
    public int size() {
        return examples.size();
    }

    /** Shuffles the labeled examples in place using {@code random}. */
    @Override
    public void shuffle(Random random) {
        Collections.shuffle(examples, random);
    }

    /** Shuffles with a fixed seed (999), so the resulting order is deterministic. */
    @Override
    public void shuffle() {
        shuffle(new Random(999L));
    }

    /**
     * Returns a new MultiDataset to which each labeled example is re-added with
     * {@link #addMulti}. Unlabeled instances are not copied.
     */
    @Override
    public Dataset shallowCopy() {
        MultiDataset copy = new MultiDataset();
        for (MultiExample.Looper it = multiIterator(); it.hasNext(); ) {
            copy.addMulti(it.nextMultiExample());
        }
        return copy;
    }

    @Override
    public String[] getFormatNames() {
        return new String[] {FORMAT_NAME};
    }

    @Override
    public String getExtensionFor(String str) {
        return ".multidata";
    }

    /**
     * Saves this dataset with {@link DatasetLoader#save}.
     *
     * @throws IllegalArgumentException if {@code str} is not this class's format name
     */
    @Override
    public void saveAs(File file, String str) throws IOException {
        if (!str.equals(FORMAT_NAME)) {
            throw new IllegalArgumentException("illegal format " + str);
        }
        DatasetLoader.save(this, file);
    }

    /**
     * Loads a dataset with {@link DatasetLoader#loadFile}.
     *
     * @throws IllegalStateException if the file contains a malformed number; the
     *     original {@link NumberFormatException} is kept as the cause
     */
    @Override
    public Object restore(File file) throws IOException {
        try {
            return DatasetLoader.loadFile(file);
        } catch (NumberFormatException e) {
            throw new IllegalStateException("error loading from " + file + ": " + e, e);
        }
    }

    /** One line per example, each followed by colt's default row separator. */
    public String toString() {
        StringBuilder out = new StringBuilder();
        for (MultiExample.Looper it = multiIterator(); it.hasNext(); ) {
            out.append(it.nextMultiExample().toString());
            out.append(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
        }
        return out.toString();
    }

    /**
     * Annotates every example with predictions from held-out classifiers, using
     * 9-fold cross-validation. Equivalent to {@code annotateData(9)}.
     */
    public MultiDataset annotateData() {
        return annotateData(DEFAULT_ANNOTATION_FOLDS);
    }

    /**
     * Annotates every example with predictions from held-out classifiers.
     * <p>
     * The data is split with a {@link CrossValSplitter} of {@code numFolds}
     * folds. For each partition, a fresh cascading voted-perceptron
     * learner is trained on the training part. Each test example is then re-added
     * with its instance wrapped in an {@link InstanceFromPrediction} carrying that
     * classifier's best class names.
     *
     * @param numFolds number of cross-validation folds passed to the splitter
     * @return a new dataset containing the annotated test examples of every fold
     */
    public MultiDataset annotateData(int numFolds) {
        MultiDataset annotated = new MultiDataset();
        MultiSplit split = MultiSplit(new CrossValSplitter(numFolds));
        for (int fold = 0; fold < split.getNumPartitions(); fold++) {
            // A new learner per fold so no state leaks between folds.
            MultiClassifier classifier = new MultiDatasetClassifierTeacher(split.getTrain(fold))
                    .train(new CascadingBinaryLearner(new BatchVersion(new VotedPerceptron())));
            for (MultiExample.Looper it = split.getTest(fold).multiIterator(); it.hasNext(); ) {
                annotated.addMulti(withPrediction(it.nextMultiExample(), classifier));
            }
        }
        return annotated;
    }

    /**
     * Returns a new dataset in which each example's instance is wrapped in an
     * {@link InstanceFromPrediction} built from {@code multiClassifier}'s best
     * class names. Labels and weights are kept.
     */
    public MultiDataset annotateData(MultiClassifier multiClassifier) {
        MultiDataset annotated = new MultiDataset();
        for (MultiExample.Looper it = multiIterator(); it.hasNext(); ) {
            annotated.addMulti(withPrediction(it.nextMultiExample(), multiClassifier));
        }
        return annotated;
    }

    /** Copies {@code example}, replacing its instance with one augmented by {@code classifier}'s prediction. */
    private static MultiExample withPrediction(MultiExample example, MultiClassifier classifier) {
        Instance instance = example.asInstance();
        return new MultiExample(
                new InstanceFromPrediction(instance, classifier.multiLabelClassification(instance).bestClassName()),
                example.getMultiLabel(),
                example.getWeight());
    }

    @Override
    public Viewer toGUI() {
        SimpleDatasetViewer listViewer = new SimpleDatasetViewer();
        listViewer.setContent(this);
        return new ZoomedViewer(listViewer, GUI.newSourcedMultiExampleViewer());
    }

    /**
     * Splits this dataset with {@code splitter}. Each partition is materialized
     * as a {@link BasicDataset}, with the stored elements cast to {@link Example}.
     */
    @Override
    public Dataset.Split split(final Splitter splitter) {
        splitter.split(examples.iterator());
        return new Dataset.Split() {
            @Override
            public int getNumPartitions() {
                return splitter.getNumPartitions();
            }

            @Override
            public Dataset getTrain(int i) {
                return invertIteration(splitter.getTrain(i));
            }

            @Override
            public Dataset getTest(int i) {
                return invertIteration(splitter.getTest(i));
            }
        };
    }

    /**
     * Splits this dataset with {@code splitter}, exposing partitions as MultiDatasets.
     * (The method name matches the nested class name and is kept for compatibility.)
     */
    public MultiSplit MultiSplit(Splitter splitter) {
        splitter.split(examples.iterator());
        return new MultiSplit(this, splitter);
    }

    /** Collects the elements of {@code it}, cast to {@link Example}, into a new BasicDataset. */
    public Dataset invertIteration(Iterator it) {
        BasicDataset result = new BasicDataset();
        while (it.hasNext()) {
            result.add((Example) it.next());
        }
        return result;
    }

    /** Collects the elements of {@code it}, cast to {@link MultiExample}, into a new MultiDataset. */
    public MultiDataset invertMultiIteration(Iterator it) {
        MultiDataset result = new MultiDataset();
        while (it.hasNext()) {
            result.addMulti((MultiExample) it.next());
        }
        return result;
    }

    public static void main(String[] strArr) {
        System.out.println("Not working yet");
    }
}
