package edu.cmu.minorthird.classify.algorithms.linear;

import edu.cmu.minorthird.classify.BasicFeatureIndex;
import edu.cmu.minorthird.classify.BatchBinaryClassifierLearner;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/linear/PoissonLearner.class */
/**
 * Batch learner that builds a {@link PoissonClassifier} from a binary-labelled {@link Dataset}.
 *
 * <p>Training makes two passes over a {@link BasicFeatureIndex} of the dataset: the first sums
 * feature weights separately for positively and non-positively labelled examples, the second
 * turns the per-class counts of every feature into smoothed estimates and feeds their
 * difference (and the difference of their logarithms) into the classifier.
 */
public class PoissonLearner extends BatchBinaryClassifierLearner {
    private static final Logger log = Logger.getLogger(PoissonLearner.class);

    /**
     * Flag passed to the three-argument {@code PoissonClassifier.increment}. The value paired
     * with it is a difference of logarithms; presumably the flag selects the classifier's
     * log-domain weight -- confirm against PoissonClassifier.
     */
    private static final boolean LOG = true;

    /** Scale handed to {@link PoissonClassifier#setScale} for every classifier this learner trains. */
    private final double SCALE;

    /** Creates a learner with the default scale of {@code 10.0}. */
    public PoissonLearner() {
        this(10.0d);
    }

    /**
     * Creates a learner with an explicit scale.
     *
     * @param scale value passed to {@code PoissonClassifier.setScale} during training
     */
    public PoissonLearner(double scale) {
        this.SCALE = scale;
        reset();
    }

    /**
     * Trains a {@link PoissonClassifier} on the given dataset.
     *
     * @param dataset the labelled examples to learn from
     * @return the trained classifier
     */
    @Override // edu.cmu.minorthird.classify.BatchClassifierLearner
    public Classifier batchTrain(Dataset dataset) {
        BasicFeatureIndex index = new BasicFeatureIndex(dataset);
        PoissonClassifier classifier = new PoissonClassifier();
        classifier.setScale(this.SCALE);

        // Pass 1: total feature weight carried by positive vs. all other examples.
        double posWeight = 0.0d;
        double negWeight = 0.0d;
        for (Feature.Looper it = index.featureIterator(); it.hasNext(); ) {
            Feature feature = it.nextFeature();
            int exampleCount = index.size(feature);
            for (int i = 0; i < exampleCount; i++) {
                Example example = index.getExample(feature, i);
                double weight = example.getWeight(feature);
                if (example.getLabel().isPositive()) {
                    posWeight += weight;
                } else {
                    negWeight += weight;
                }
            }
        }

        // Uniform prior over features, used as the smoothing target below.
        double featurePrior = 1.0d / index.numberOfFeatures();

        // Loop-invariant quantities. Assumes getScale() is a plain accessor that is not
        // affected by increment() -- TODO confirm against PoissonClassifier.
        double scale = classifier.getScale();
        double scaledPosWeight = posWeight / scale;
        double scaledNegWeight = negWeight / scale;
        double pseudoCounts = 1.0d / scale;

        // Pass 2: per-feature updates from the smoothed class-conditional estimates.
        for (Feature.Looper it = index.featureIterator(); it.hasNext(); ) {
            Feature feature = it.nextFeature();
            double posCount = index.getCounts(feature, ExampleSchema.POS_CLASS_NAME);
            double negCount = index.getCounts(feature, ExampleSchema.NEG_CLASS_NAME);

            // Linear term: negative-class estimate minus positive-class estimate.
            classifier.increment(
                    feature,
                    -smoothedEstimate(posCount, scaledPosWeight, featurePrior, pseudoCounts)
                            + smoothedEstimate(negCount, scaledNegWeight, featurePrior, pseudoCounts));

            // Log term: log positive-class estimate minus log negative-class estimate.
            classifier.increment(
                    feature,
                    logSmoothedEstimate(posCount, scaledPosWeight, featurePrior, pseudoCounts)
                            - logSmoothedEstimate(negCount, scaledNegWeight, featurePrior, pseudoCounts),
                    LOG);
        }

        // Bias: log-ratio of the smoothed class weight shares (prior 0.5, one pseudo-count).
        double totalWeight = posWeight + negWeight;
        classifier.incrementBias(logSmoothedEstimate(posWeight, totalWeight, 0.5d, 1.0d));
        classifier.incrementBias(-logSmoothedEstimate(negWeight, totalWeight, 0.5d, 1.0d));
        return classifier;
    }

    /**
     * Computes {@code (count + prior * pseudoCounts) / (total + pseudoCounts)}.
     *
     * @param count observed count for the event
     * @param total observed total the count is taken from
     * @param prior value the estimate is pulled towards
     * @param pseudoCounts weight given to the prior
     * @return the smoothed estimate
     */
    private double smoothedEstimate(double count, double total, double prior, double pseudoCounts) {
        return (count + (prior * pseudoCounts)) / (total + pseudoCounts);
    }

    /**
     * Natural logarithm of {@link #smoothedEstimate(double, double, double, double)}.
     *
     * @param count observed count for the event
     * @param total observed total the count is taken from
     * @param prior value the estimate is pulled towards
     * @param pseudoCounts weight given to the prior
     * @return {@code Math.log} of the smoothed estimate
     */
    private double logSmoothedEstimate(double count, double total, double prior, double pseudoCounts) {
        return Math.log(smoothedEstimate(count, total, prior, pseudoCounts));
    }
}
