package edu.cmu.minorthird.classify.algorithms.trees;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.BatchBinaryClassifierLearner;
import edu.cmu.minorthird.classify.BatchClassifierLearner;
import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.StringUtil;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.VanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.awt.GridBagConstraints;
import java.awt.GridBagLayout;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/trees/AdaBoost.class */
public class AdaBoost extends BatchBinaryClassifierLearner {
    /** Class-wide log4j logger; assigned in the static initializer at the bottom of this class. */
    private static Logger log;
    /** Learner that {@code batchTrain} invokes once per boosting round to obtain a weak classifier. */
    private BatchClassifierLearner baseLearner;
    /** Number of boosting rounds {@code batchTrain} runs; one weak classifier is collected per round. */
    private int maxRounds;
    /**
     * Holder for this class's {@code Class} object.
     * NOTE(review): looks like a compiler/decompiler artifact of pre-Java-5 class-literal
     * emulation rather than hand-written state -- confirm before relying on it.
     */
    static Class class$edu$cmu$minorthird$classify$algorithms$trees$AdaBoost;

    /* renamed from: edu.cmu.minorthird.classify.algorithms.trees.AdaBoost$1, reason: invalid class name */
    /* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/trees/AdaBoost$1.class */
    /**
     * Empty marker type with no members. Within this file it appears only as the parameter type
     * of the package-private {@code BoostedClassifierViewer} constructor.
     * NOTE(review): presumably the synthetic accessor class the compiler emitted as
     * {@code AdaBoost$1} -- confirm before removing.
     */
    static class AnonymousClass1 {
    }

    /* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/trees/AdaBoost$BoostedClassifier.class */
    /**
     * Ensemble classifier produced by boosting: it holds the weak classifiers learned in each
     * round and scores an instance as the plain sum of their individual scores.
     */
    private static class BoostedClassifier extends BinaryClassifier implements Serializable, Visible {
        /** Weak classifiers in the order they were learned; every element is cast to BinaryClassifier on use. */
        private List classifiers;

        /**
         * @param list the weak classifiers making up the ensemble; the list is stored by reference,
         *             not copied
         */
        public BoostedClassifier(List list) {
            this.classifiers = list;
        }

        /**
         * Sums the scores of all weak classifiers for the given instance.
         *
         * @param instance the instance to score
         * @return the sum of the member scores ({@code 0.0} when there are no members)
         */
        @Override // edu.cmu.minorthird.classify.BinaryClassifier
        public double score(Instance instance) {
            double total = 0.0d;
            // The list is raw, so elements arrive as Object and are cast explicitly.
            for (Object member : this.classifiers) {
                total += ((BinaryClassifier) member).score(instance);
            }
            return total;
        }

        /**
         * Builds a multi-line report: for every weak classifier its score on the instance followed
         * by that classifier's own explanation (indented), and finally the total score.
         *
         * @param instance the instance whose classification is explained
         * @return the textual report
         */
        @Override // edu.cmu.minorthird.classify.Classifier
        public String explain(Instance instance) {
            StringBuilder report = new StringBuilder();
            double total = 0.0d;
            for (Object member : this.classifiers) {
                BinaryClassifier weak = (BinaryClassifier) member;
                // Score once per member so the printed value and the running total always agree.
                double weakScore = weak.score(instance);
                total += weakScore;
                report.append("score of ").append(weak).append(": ").append(weakScore)
                        .append(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
                report.append(StringUtil.indent(1, weak.explain(instance)))
                        .append(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
            }
            report.append("total score: ").append(total);
            return report.toString();
        }

        /**
         * @return a bracketed listing with one weak classifier per row
         */
        public String toString() {
            StringBuilder text = new StringBuilder("[boosted classifier:\n");
            for (Object member : this.classifiers) {
                text.append(((BinaryClassifier) member).toString())
                        .append(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
            }
            text.append("]");
            return text.toString();
        }

        /**
         * Creates a viewer for this ensemble and hands it this classifier as content.
         *
         * @return the viewer displaying this classifier
         */
        @Override // edu.cmu.minorthird.util.gui.Visible
        public Viewer toGUI() {
            BoostedClassifierViewer viewer = new BoostedClassifierViewer();
            viewer.setContent(this);
            return viewer;
        }
    }

    /* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/trees/AdaBoost$BoostedClassifierViewer.class */
    /**
     * Swing viewer for a {@code BoostedClassifier}: shows one sub-viewer per weak classifier,
     * stacked in a single column inside a scroll pane.
     */
    private static class BoostedClassifierViewer extends ComponentViewer {
        private BoostedClassifierViewer() {
        }

        /**
         * Package-private constructor that ignores its argument and delegates to the no-argument
         * constructor. Kept so existing call sites that pass a marker argument still compile.
         *
         * @param anonymousClass1 unused
         */
        BoostedClassifierViewer(AnonymousClass1 anonymousClass1) {
            this();
        }

        /**
         * Builds the component for a boosted classifier.
         *
         * @param obj the object to display; cast to {@code BoostedClassifier}
         * @return a scroll pane wrapping a panel with one row per weak classifier
         */
        @Override // edu.cmu.minorthird.util.gui.ComponentViewer
        public JComponent componentFor(Object obj) {
            BoostedClassifier boosted = (BoostedClassifier) obj;
            JPanel column = new JPanel();
            column.setLayout(new GridBagLayout());
            int row = 0;
            // The list is raw, so elements arrive as Object and are cast explicitly.
            for (Object member : boosted.classifiers) {
                Classifier classifier = (Classifier) member;
                GridBagConstraints constraints = new GridBagConstraints();
                constraints.fill = GridBagConstraints.HORIZONTAL;
                constraints.weighty = 0.0d;
                constraints.weightx = 0.0d;
                constraints.gridx = 0;
                constraints.gridy = row++;
                // Use the classifier's own viewer when it offers one, otherwise a generic fallback.
                Viewer view = classifier instanceof Visible ? ((Visible) classifier).toGUI() : new VanillaViewer(classifier);
                view.setSuperView(this);
                column.add(view, constraints);
            }
            JScrollPane scroller = new JScrollPane(column);
            scroller.setHorizontalScrollBarPolicy(JScrollPane.HORIZONTAL_SCROLLBAR_AS_NEEDED);
            return scroller;
        }
    }

    /* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/trees/AdaBoost$L.class */
    /**
     * Variant of {@code AdaBoost} that differs only in its discount factor:
     * {@code 1 + exp(label * score)} instead of {@code exp(label * score)}.
     */
    public static class L extends AdaBoost {
        /** Creates the variant with the settings chosen by the no-argument {@code AdaBoost} constructor. */
        public L() {
            super();
        }

        /**
         * @param batchClassifierLearner learner used to train the weak classifier of each round
         * @param i number of boosting rounds
         */
        public L(BatchClassifierLearner batchClassifierLearner, int i) {
            super(batchClassifierLearner, i);
        }

        /**
         * @param d numeric label of the example
         * @param d2 score the current weak classifier assigned to the example
         * @return {@code 1 + exp(d * d2)}, the value the example's weight is divided by
         */
        @Override // edu.cmu.minorthird.classify.algorithms.trees.AdaBoost
        protected double discountFactor(double d, double d2) {
            final double margin = d * d2;
            return Math.exp(margin) + 1.0d;
        }
    }

    /**
     * Creates a booster that runs 10 rounds using a new {@link DecisionTreeLearner} as base learner.
     */
    public AdaBoost() {
        this(new DecisionTreeLearner(), 10);
    }

    /**
     * Creates a booster with an explicit base learner and round count.
     *
     * @param batchClassifierLearner learner used to train the weak classifier of each round
     * @param i number of boosting rounds
     */
    public AdaBoost(BatchClassifierLearner batchClassifierLearner, int i) {
        baseLearner = batchClassifierLearner;
        maxRounds = i;
    }

    /**
     * @return the number of boosting rounds currently configured
     */
    public int getMaxRounds() {
        return maxRounds;
    }

    /**
     * @param i the number of boosting rounds to run in subsequent training
     */
    public void setMaxRounds(int i) {
        maxRounds = i;
    }

    /**
     * @return the learner used to train the weak classifier of each round
     */
    public BatchClassifierLearner getBaseLearner() {
        return baseLearner;
    }

    /**
     * @param batchClassifierLearner the learner to use for the weak classifier of each round
     */
    public void setBaseLearner(BatchClassifierLearner batchClassifierLearner) {
        baseLearner = batchClassifierLearner;
    }

    /**
     * Trains a boosted ensemble. Each round trains the base learner on a working copy of the
     * data, keeps the resulting classifier, divides every example's weight by
     * {@code discountFactor(label, score)} and then rescales the weights by their sum.
     *
     * @param dataset the training examples; they are re-wrapped into new {@code Example} objects
     *                before any weight is changed
     * @return a classifier whose score is the sum of the scores of the per-round classifiers
     */
    @Override // edu.cmu.minorthird.classify.BatchClassifierLearner
    public Classifier batchTrain(Dataset dataset) {
        // Reweighting below is applied to these fresh Example objects, not to the ones iterated here.
        BasicDataset working = new BasicDataset();
        for (Example.Looper source = dataset.iterator(); source.hasNext(); ) {
            Example original = source.nextExample();
            working.add(new Example(original.asInstance(), original.getLabel()));
        }

        ArrayList ensemble = new ArrayList(maxRounds);
        ProgressCounter progress = new ProgressCounter("boosting", "round", maxRounds);
        for (int round = 0; round < maxRounds; round++) {
            log.info("Adaboost is starting round " + (round + 1) + "/" + maxRounds);
            log.info("Learning classifier with " + baseLearner);
            BinaryClassifier weak = (BinaryClassifier) baseLearner.batchTrain(working);
            ensemble.add(weak);
            if (log.isDebugEnabled()) {
                log.debug("classifier is " + weak);
            }

            log.info("Generating new distribution");
            // Pass 1: discount each weight and accumulate the new total.
            double totalWeight = 0.0d;
            for (Example.Looper pass = working.iterator(); pass.hasNext(); ) {
                Example example = pass.nextExample();
                double previous = example.getWeight();
                double factor = discountFactor(example.getLabel().numericLabel(), weak.score(example));
                example.setWeight(previous / factor);
                totalWeight += example.getWeight();
            }
            // Pass 2: rescale every weight by the total accumulated above.
            for (Example.Looper pass = working.iterator(); pass.hasNext(); ) {
                Example example = pass.nextExample();
                example.setWeight(example.getWeight() / totalWeight);
            }
            progress.progress();
        }
        progress.finished();
        return new BoostedClassifier(ensemble);
    }

    /**
     * Computes the value an example's weight is divided by after a boosting round.
     *
     * @param d numeric label of the example
     * @param d2 score the current weak classifier assigned to the example
     * @return {@code exp(d * d2)}
     */
    protected double discountFactor(double d, double d2) {
        final double margin = d * d2;
        return Math.exp(margin);
    }

    /**
     * Loads the class with the given fully qualified name.
     *
     * @param str fully qualified class name
     * @return the loaded class
     * @throws NoClassDefFoundError if the class cannot be found; the message is that of the
     *         underlying {@code ClassNotFoundException}, which is attached as the cause
     */
    static Class class$(String str) {
        try {
            return Class.forName(str);
        } catch (ClassNotFoundException e) {
            // NoClassDefFoundError has no (String, Throwable) constructor, so attach the cause explicitly.
            NoClassDefFoundError error = new NoClassDefFoundError(e.getMessage());
            error.initCause(e);
            throw error;
        }
    }

    static {
        Class cls;
        if (class$edu$cmu$minorthird$classify$algorithms$trees$AdaBoost == null) {
            cls = class$("edu.cmu.minorthird.classify.algorithms.trees.AdaBoost");
            class$edu$cmu$minorthird$classify$algorithms$trees$AdaBoost = cls;
        } else {
            cls = class$edu$cmu$minorthird$classify$algorithms$trees$AdaBoost;
        }
        log = Logger.getLogger(cls);
    }
}
