package edu.cmu.minorthird.classify.algorithms.linear;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.OnlineBinaryClassifierLearner;
import java.io.Serializable;

/* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/linear/PassiveAggressiveLearner.class */
/**
 * Online binary learner that applies a passive-aggressive style update to a linear hyperplane.
 *
 * <p>For each example with numeric label {@code y} and current score {@code s}, if
 * {@code y * s < eta} the current hyperplane is moved along the example by
 * {@code y * (eta - y * s) / (||x||^2 + gamma)}; otherwise it is left unchanged. When
 * {@code voted} is true the learner also keeps a running sum of the current hyperplane taken
 * after every example, and {@link #getClassifier()} returns (a copy of) that sum.
 *
 * <p>NOTE(review): the {@code gamma} term in the denominator resembles the PA-II variant of
 * Crammer et al.; confirm the intended variant before relying on that interpretation.
 */
public class PassiveAggressiveLearner extends OnlineBinaryClassifierLearner implements Serializable {
    /** Current weight vector; updated in place by {@link #addExample(Example)}. */
    private Hyperplane pos_t;
    /** Sum of {@code pos_t} taken after each example; only allocated when {@code voted}. */
    private Hyperplane vpos_t;
    /** Margin each example must reach ({@code y * score >= eta}) to avoid an update. */
    private double eta;
    /** Constant added to the squared instance norm in the step-size denominator. */
    private double gamma;
    /** Number of examples passed to {@link #addExample(Example)} since the last reset. */
    private int excount;
    // NOTE(review): never read or written in this class. Kept (with its original modifiers)
    // so that the default serialized form of this Serializable class is unchanged.
    private int numActiveFeatures;
    /** If true, {@link #getClassifier()} returns the summed hyperplane instead of the current one. */
    private boolean voted;

    /** Creates a learner with {@code eta = 1.0}, {@code gamma = 0.1} and voting enabled. */
    public PassiveAggressiveLearner() {
        this(1.0d, 0.1d, true);
    }

    /**
     * Creates a learner with explicit parameters.
     *
     * @param eta margin threshold; an example triggers an update when {@code y * score < eta}
     * @param gamma constant added to the squared instance norm in the step-size denominator;
     *     updates are skipped whenever {@code ||x||^2 + gamma} is not strictly positive
     * @param voted if true, {@link #getClassifier()} returns the sum of the hyperplanes taken
     *     after every example rather than the latest hyperplane
     */
    public PassiveAggressiveLearner(double eta, double gamma, boolean voted) {
        this.eta = eta;
        this.gamma = gamma;
        // voted must be assigned before reset(), which reads it to decide whether to allocate vpos_t
        this.voted = voted;
        reset();
    }

    /** Discards all learned state: fresh hyperplane(s) and a zeroed example counter. */
    @Override
    public void reset() {
        this.pos_t = new Hyperplane();
        if (this.voted) {
            this.vpos_t = new Hyperplane();
        }
        this.excount = 0;
    }

    /**
     * Processes one labeled example.
     *
     * <p>If the signed margin {@code y * score} is below {@code eta}, the current hyperplane is
     * incremented by the example scaled by {@code y * (eta - margin) / (||x||^2 + gamma)}. When
     * voting is enabled the current hyperplane is then added into the running sum, whether or not
     * an update happened.
     *
     * @param example the training example; its label's {@code numericLabel()} supplies {@code y}
     */
    @Override
    public void addExample(Example example) {
        this.excount++;
        Instance instance = example.asInstance();
        double label = example.getLabel().numericLabel();
        double margin = label * this.pos_t.score(instance);
        if (margin < this.eta) {
            double denominator = getNormSquared(instance) + this.gamma;
            // A zero denominator (e.g. a featureless instance with gamma == 0) would give an
            // infinite step, and a negative one would flip the update direction; skip both.
            if (denominator > 0.0d) {
                this.pos_t.increment(instance, label * ((this.eta - margin) / denominator));
            }
        }
        if (this.voted) {
            this.vpos_t.increment(this.pos_t, 1.0d);
        }
    }

    /**
     * Returns the sum of the squared weights of all features of {@code instance}.
     *
     * @param instance the instance whose feature weights are summed
     * @return {@code sum_f weight(f)^2}; {@code 0.0} for an instance with no features
     */
    public double getNormSquared(Instance instance) {
        double sumOfSquares = 0.0d;
        for (Feature.Looper it = instance.featureIterator(); it.hasNext(); ) {
            double w = instance.getWeight(it.nextFeature());
            sumOfSquares += w * w;
        }
        return sumOfSquares;
    }

    /**
     * Returns a new hyperplane holding a copy of the learned weights.
     *
     * <p>With voting enabled this is the unnormalized sum of the per-example hyperplanes, so its
     * scores grow with the number of examples seen; otherwise it is the current hyperplane.
     *
     * @return a fresh {@link Hyperplane}, independent of further training
     */
    @Override
    public Classifier getClassifier() {
        Hyperplane hyperplane = new Hyperplane();
        if (this.voted) {
            hyperplane.increment(this.vpos_t);
        } else {
            hyperplane.increment(this.pos_t);
        }
        return hyperplane;
    }

    @Override
    public String toString() {
        return "PassiveAggressive Algorithm";
    }

    /**
     * Small demonstration: trains on four hand-built examples, printing the classifier after each.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        PassiveAggressiveLearner learner = new PassiveAggressiveLearner();

        MutableInstance first = new MutableInstance();
        first.addNumeric(new Feature("f2"), 2.0d);
        first.addNumeric(new Feature("f3"), 3.0d);
        first.addNumeric(new Feature("f4"), 4.0d);
        learner.addExample(new Example(first, ClassLabel.positiveLabel(1.0d)));
        printClassifier(learner);

        MutableInstance second = new MutableInstance();
        second.addNumeric(new Feature("f3"), 1.0d);
        second.addNumeric(new Feature("f4"), 2.0d);
        second.addNumeric(new Feature("f5"), 3.0d);
        learner.addExample(new Example(second, ClassLabel.negativeLabel(-1.0d)));
        printClassifier(learner);

        MutableInstance third = new MutableInstance();
        third.addNumeric(new Feature("f3"), -5.0d);
        third.addNumeric(new Feature("f4"), -12.0d);
        third.addNumeric(new Feature("f5"), -34.0d);
        learner.addExample(new Example(third, ClassLabel.positiveLabel(1.0d)));
        printClassifier(learner);

        MutableInstance fourth = new MutableInstance();
        fourth.addNumeric(new Feature("f3"), -5.0d);
        fourth.addNumeric(new Feature("f4"), -12.0d);
        fourth.addNumeric(new Feature("f5"), -34.0d);
        // Previously this feature was mistakenly added to the first instance.
        fourth.addNumeric(new Feature("f2"), -2.0d);
        learner.addExample(new Example(fourth, ClassLabel.positiveLabel(1.0d)));
        printClassifier(learner);
    }

    /** Prints the learner's current classifier for the demo in {@link #main(String[])}. */
    private static void printClassifier(PassiveAggressiveLearner learner) {
        System.out.println("PassiveAggressive Hyperplane = " + learner.getClassifier().toString());
    }
}
