package edu.cmu.minorthird.classify.algorithms.linear;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.OnlineBinaryClassifierLearner;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.Serializable;

/* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/linear/BalancedWinnow.class */
/**
 * Balanced Winnow: an online, mistake-driven binary learner that maintains two weight vectors,
 * {@code pos_t} and {@code neg_t}, and scores an instance {@code x} as
 * {@code pos_t.score(x) - neg_t.score(x) - theta}.
 *
 * <p>Whenever {@code numericLabel() * score} does not exceed {@code margin}, the weights of the
 * example's features are multiplied by {@code alpha} in the vector that should have scored
 * higher and by {@code beta} in the other one. Weights are only promoted while below
 * {@code W_MAX} and only demoted while above {@code W_MIN}.
 *
 * <p>In voted mode every intermediate hypothesis is accumulated, weighted by the number of
 * examples it was credited with, and {@link #getClassifier()} returns the average of those
 * hypotheses over all examples seen.
 */
public class BalancedWinnow extends OnlineBinaryClassifierLearner implements Serializable {
    /** Current positive weight vector. */
    private Hyperplane pos_t;
    /** Current negative weight vector. */
    private Hyperplane neg_t;
    /** Voted mode only: survival-weighted sum of past positive weight vectors (null otherwise). */
    private Hyperplane vpos_t;
    /** Voted mode only: survival-weighted sum of past negative weight vectors (null otherwise). */
    private Hyperplane vneg_t;
    /** Decision threshold subtracted from the raw score. */
    private double theta = 1.0d;
    /** Promotion factor; validated to be {@code >= 1}. */
    private double alpha;
    /** Demotion factor; validated to lie in {@code [0, 1]}. */
    private double beta;
    /** Number of examples passed to {@link #addExample} since the last {@link #reset()}. */
    private int excount;
    /** Number of examples the current hypothesis has been credited with (used in voted mode). */
    private int votedCount;
    /** A prediction counts as correct only if {@code numericLabel() * score} is strictly above this. */
    private double margin = 0.0d;
    /** Whether {@link #getClassifier()} returns the averaged (voted) hypothesis. */
    private boolean voted = false;
    /** Weights at or above this value are no longer promoted. */
    private double W_MAX = Math.pow(2.0d, 200.0d);
    /** Weights at or below this value are no longer demoted. */
    private double W_MIN = 1.0d / Math.pow(2.0d, 200.0d);

    /**
     * Classifier produced by {@link BalancedWinnow#getClassifier()}.
     *
     * <p>Predicts positive when {@code pos_h.score(x) - neg_h.score(x) - mytheta >= 0}. Before
     * scoring, features that {@code pos_h} does not contain are dropped and the remaining
     * weights are normalized.
     */
    public class MyClassifier implements Classifier, Serializable, Visible {
        private static final long serialVersionUID = 1;
        private final int CURRENT_SERIAL_VERSION = 1;
        private Hyperplane pos_h;
        private Hyperplane neg_h;
        // NOTE(review): never assigned in this file.
        private ExampleSchema schema;
        private double mytheta;
        // NOTE(review): this field and the constructor's first parameter look like decompiler
        // renderings of the implicit outer-instance reference. They are left unchanged so the
        // constructor signature and serialized field layout stay the same.
        private final BalancedWinnow this$0;

        /**
         * @param balancedWinnow the learner that created this classifier
         * @param hyperplane positive weight vector
         * @param hyperplane2 negative weight vector
         * @param d decision threshold
         */
        public MyClassifier(BalancedWinnow balancedWinnow, Hyperplane hyperplane, Hyperplane hyperplane2, double d) {
            this.this$0 = balancedWinnow;
            this.pos_h = hyperplane;
            this.neg_h = hyperplane2;
            this.mytheta = d;
        }

        /**
         * Classifies an instance.
         *
         * @param instance the instance to classify
         * @return a positive label when the thresholded score is {@code >= 0}, otherwise a
         *     negative label; the score is attached to the returned label
         */
        @Override // edu.cmu.minorthird.classify.Classifier
        public ClassLabel classification(Instance instance) {
            // Wrap the instance in an Example (with a placeholder POS label) so it can go
            // through filterFeat and Winnow.normalizeWeights, which both operate on Examples.
            Example wrapped = new Example(instance, new ClassLabel(ExampleSchema.POS_CLASS_NAME));
            Instance asInstance = Winnow.normalizeWeights(filterFeat(wrapped), true).asInstance();
            double score = (this.pos_h.score(asInstance) - this.neg_h.score(asInstance)) - this.mytheta;
            return score >= 0.0d ? ClassLabel.positiveLabel(score) : ClassLabel.negativeLabel(score);
        }

        /**
         * Returns a copy of {@code example} that keeps only the features present in
         * {@code pos_h}, with their original weights and the original label.
         *
         * @param example the example to filter
         * @return the filtered example
         */
        public Example filterFeat(Example example) {
            MutableInstance mutableInstance = new MutableInstance();
            Feature.Looper featureIterator = example.asInstance().featureIterator();
            while (featureIterator.hasNext()) {
                Feature nextFeature = featureIterator.nextFeature();
                if (this.pos_h.hasFeature(nextFeature)) {
                    mutableInstance.addNumeric(nextFeature, example.getWeight(nextFeature));
                }
            }
            return new Example(mutableInstance, example.getLabel());
        }

        /** Renders both weight vectors, positive first. */
        public String toString() {
            return "POS = " + this.pos_h.toString() + "\nNEG = " + this.neg_h.toString();
        }

        @Override // edu.cmu.minorthird.classify.Classifier
        public String explain(Instance instance) {
            return "BalancedWinnow: Not implemented yet";
        }

        @Override // edu.cmu.minorthird.classify.Classifier
        public Explanation getExplanation(Instance instance) {
            return new Explanation(new Explanation.Node("BalancedWinnow Explanation"));
        }

        /**
         * Returns a viewer that displays the positive weight vector of this classifier.
         *
         * @return the viewer, with this classifier set as its content
         */
        @Override // edu.cmu.minorthird.util.gui.Visible
        public Viewer toGUI() {
            // The content is supplied through setContent below, so the viewer only needs its
            // sub-viewer at construction time.
            TransformedViewer transformedViewer = new TransformedViewer(new SmartVanillaViewer()) {
                @Override // edu.cmu.minorthird.util.gui.TransformedViewer
                public Object transform(Object obj) {
                    return ((MyClassifier) obj).pos_h;
                }
            };
            transformedViewer.setContent(this);
            return transformedViewer;
        }
    }

    /** Creates a non-voted learner with {@code alpha = 1.5} and {@code beta = 0.5}. */
    public BalancedWinnow() {
        this(1.5d, 0.5d, false);
    }

    /**
     * Creates a learner.
     *
     * @param alpha promotion factor; must be {@code >= 1}
     * @param beta demotion factor; must lie in {@code [0, 1]}
     * @param voted whether {@link #getClassifier()} should return the averaged hypothesis
     * @throws IllegalArgumentException if {@code alpha} or {@code beta} is out of range or NaN
     */
    public BalancedWinnow(double alpha, double beta, boolean voted) {
        // Written with negated comparisons so that NaN arguments are rejected as well.
        if (!(alpha >= 1.0d) || !(beta >= 0.0d && beta <= 1.0d)) {
            throw new IllegalArgumentException(
                "Error in BalancedWinnow initial parameters: require (alpha >= 1) && (0 <= beta <= 1), got alpha="
                    + alpha + ", beta=" + beta);
        }
        this.alpha = alpha;
        this.beta = beta;
        this.voted = voted;
        reset();
    }

    /**
     * Discards everything learned. Creates fresh empty weight vectors and zeroes the example
     * count. In voted mode, also clears the voted sums and the survival count.
     */
    @Override // edu.cmu.minorthird.classify.OnlineClassifierLearner, edu.cmu.minorthird.classify.ClassifierLearner
    public void reset() {
        this.pos_t = new Hyperplane();
        this.neg_t = new Hyperplane();
        this.excount = 0;
        if (this.voted) {
            this.votedCount = 0;
            this.vpos_t = new Hyperplane();
            this.vneg_t = new Hyperplane();
        }
    }

    /**
     * Processes one training example.
     *
     * <p>The example's weights are first normalized with
     * {@code Winnow.normalizeWeights(example, true)}. Features not yet in {@code pos_t} are
     * introduced by incrementing {@code pos_t} by 2 and {@code neg_t} by 1. If
     * {@code numericLabel() * localscore(x)} exceeds {@code margin}, the prediction counts as
     * correct and only the survival count is bumped. Otherwise the weights are updated.
     *
     * @param example the labeled training example
     */
    @Override // edu.cmu.minorthird.classify.OnlineClassifierLearner, edu.cmu.minorthird.classify.ClassifierLearner
    public void addExample(Example example) {
        this.excount++;
        Example normalized = Winnow.normalizeWeights(example, true);
        Feature.Looper featureIterator = normalized.asInstance().featureIterator();
        while (featureIterator.hasNext()) {
            Feature feature = featureIterator.nextFeature();
            if (!this.pos_t.hasFeature(feature)) {
                this.pos_t.increment(feature, 2.0d);
                this.neg_t.increment(feature, 1.0d);
            }
        }
        if (normalized.getLabel().numericLabel() * localscore(normalized.asInstance()) > this.margin) {
            this.votedCount++;
            return;
        }
        if (this.voted) {
            // Retire the current hypothesis into the voted sums, weighted by the number of
            // examples it was credited with (at least 1). Then credit the updated hypothesis
            // with this example.
            updateVotedHyperplane(this.votedCount == 0 ? 1.0d : this.votedCount);
            this.votedCount = 1;
        }
        if (normalized.getLabel().isPositive()) {
            promoteAndDemote(normalized, this.pos_t, this.neg_t);
        } else {
            promoteAndDemote(normalized, this.neg_t, this.pos_t);
        }
    }

    /**
     * For each feature of {@code example}, multiplies its weight in {@code promoted} by
     * {@code alpha} (while below {@code W_MAX}) and its weight in {@code demoted} by
     * {@code beta} (while above {@code W_MIN}).
     */
    private void promoteAndDemote(Example example, Hyperplane promoted, Hyperplane demoted) {
        Feature.Looper featureIterator = example.featureIterator();
        while (featureIterator.hasNext()) {
            Feature feature = featureIterator.nextFeature();
            if (promoted.featureScore(feature) < this.W_MAX) {
                promoted.multiply(feature, this.alpha);
            }
            if (demoted.featureScore(feature) > this.W_MIN) {
                demoted.multiply(feature, this.beta);
            }
        }
    }

    /**
     * Adds the current weight vectors, scaled by {@code d}, to the voted sums, then resets the
     * survival count to zero. Only valid in voted mode: otherwise the voted sums are null.
     *
     * @param d weight given to the current hypothesis
     */
    public void updateVotedHyperplane(double d) {
        this.vpos_t.increment(this.pos_t, d);
        this.vneg_t.increment(this.neg_t, d);
        this.votedCount = 0;
    }

    /**
     * Returns a classifier for the current state of the learner.
     *
     * <p>In non-voted mode the classifier wraps the learner's own weight vectors, so later
     * training is visible through it. In voted mode it wraps fresh hyperplanes holding the
     * voted sums plus the current hypothesis (weighted by its survival count), all divided by
     * the number of examples seen. This does not modify the learner, so it can be called
     * repeatedly during training.
     *
     * @return the classifier
     */
    @Override // edu.cmu.minorthird.classify.OnlineClassifierLearner, edu.cmu.minorthird.classify.ClassifierLearner
    public Classifier getClassifier() {
        if (!this.voted) {
            return new MyClassifier(this, this.pos_t, this.neg_t, this.theta);
        }
        // Avoid a 1/0 scale before any example has been seen; the sums are empty then anyway.
        double scale = this.excount > 0 ? 1.0d / this.excount : 0.0d;
        Hyperplane avgPos = new Hyperplane();
        Hyperplane avgNeg = new Hyperplane();
        avgPos.increment(this.vpos_t, scale);
        avgNeg.increment(this.vneg_t, scale);
        // Include the still-active hypothesis without folding it into the voted sums, so the
        // learner's survival accounting is left untouched.
        avgPos.increment(this.pos_t, this.votedCount * scale);
        avgNeg.increment(this.neg_t, this.votedCount * scale);
        return new MyClassifier(this, avgPos, avgNeg, this.theta);
    }

    /**
     * Scores an instance with the current (non-averaged) weight vectors.
     *
     * @param instance the instance to score
     * @return {@code pos_t.score(instance) - neg_t.score(instance) - theta}
     */
    public double localscore(Instance instance) {
        return (this.pos_t.score(instance) - this.neg_t.score(instance)) - this.theta;
    }

    public String toString() {
        return "BalancedWinnow, voted=" + this.voted;
    }

    /** Small smoke test: trains on four hand-built examples, printing the classifier after each. */
    public static void main(String[] strArr) {
        BalancedWinnow balancedWinnow = new BalancedWinnow();
        ClassLabel positiveLabel = ClassLabel.positiveLabel(1.0d);
        MutableInstance mutableInstance = new MutableInstance();
        mutableInstance.addNumeric(new Feature("f2"), 2.0d);
        mutableInstance.addNumeric(new Feature("f3"), 3.0d);
        mutableInstance.addNumeric(new Feature("f4"), 4.0d);
        balancedWinnow.addExample(new Example(mutableInstance, positiveLabel));
        System.out.println("BWinnow Hyperplane = " + balancedWinnow.getClassifier().toString());
        ClassLabel negativeLabel = ClassLabel.negativeLabel(-1.0d);
        MutableInstance mutableInstance2 = new MutableInstance();
        mutableInstance2.addNumeric(new Feature("f3"), 1.0d);
        mutableInstance2.addNumeric(new Feature("f4"), 2.0d);
        mutableInstance2.addNumeric(new Feature("f5"), 3.0d);
        balancedWinnow.addExample(new Example(mutableInstance2, negativeLabel));
        System.out.println("BalancedWinnow Hyperplane = " + balancedWinnow.getClassifier().toString());
        ClassLabel positiveLabel2 = ClassLabel.positiveLabel(1.0d);
        MutableInstance mutableInstance3 = new MutableInstance();
        mutableInstance3.addNumeric(new Feature("f3"), -5.0d);
        mutableInstance3.addNumeric(new Feature("f4"), -12.0d);
        mutableInstance3.addNumeric(new Feature("f5"), -34.0d);
        balancedWinnow.addExample(new Example(mutableInstance3, positiveLabel2));
        System.out.println("BalancedWinnow Hyperplane = " + balancedWinnow.getClassifier().toString());
        ClassLabel positiveLabel3 = ClassLabel.positiveLabel(1.0d);
        MutableInstance mutableInstance4 = new MutableInstance();
        mutableInstance4.addNumeric(new Feature("f3"), -5.0d);
        mutableInstance4.addNumeric(new Feature("f4"), -12.0d);
        mutableInstance4.addNumeric(new Feature("f5"), -34.0d);
        mutableInstance4.addNumeric(new Feature("f2"), -2.0d);
        balancedWinnow.addExample(new Example(mutableInstance4, positiveLabel3));
        System.out.println("BWinnow Hyperplane = " + balancedWinnow.getClassifier().toString());
    }
}
