package edu.cmu.minorthird.classify.algorithms.knn;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.DatasetIndex;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.util.MathUtil;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import org.apache.log4j.Logger;

/**
 * A k-nearest-neighbor classifier backed by a {@link DatasetIndex}.
 *
 * <p>To classify an instance, every example returned by the index's
 * {@code getNeighbors(instance)} is scored by cosine similarity (over feature weights) to the
 * instance. The {@code k} most similar examples each vote for their best class name with weight
 * {@code exampleWeight * similarity}. Each class's share of the total vote is then converted into a
 * smoothed log-odds score.
 */
class KnnClassifier implements Classifier, Serializable {
    private static final long serialVersionUID = 1;

    /** Unused at runtime, but kept because it is part of this class's serialized form. */
    private final int CURRENT_VERSION_NUMBER = 1;

    private static final Logger log = Logger.getLogger(KnnClassifier.class);
    private static final boolean DEBUG = log.isDebugEnabled();

    /** Added inside both logarithms so a class holding all or none of the vote gets a finite score. */
    private static final double SMOOTHING = 0.001d;

    private DatasetIndex index;
    private ExampleSchema schema;
    private int k;

    /** An indexed example paired with its similarity to the instance being classified. */
    private static class Neighbor implements Comparable<Neighbor> {
        final Example e;
        final double sim;

        public Neighbor(Example example, double similarity) {
            this.e = example;
            this.sim = similarity;
        }

        /**
         * Orders neighbors by decreasing similarity. Note: not consistent with {@code equals}, so
         * instances must not be stored in sorted sets (equal similarities would collapse).
         */
        @Override
        public int compareTo(Neighbor other) {
            // Double.compare is a total order (unlike a sign-of-difference), which sorting requires.
            return Double.compare(other.sim, this.sim);
        }
    }

    /**
     * @param datasetIndex index queried for candidate neighbors of each instance
     * @param exampleSchema schema of the training data (stored, not used by the visible code)
     * @param i number of nearest neighbors that vote; a value {@code <= 0} yields an empty label
     */
    public KnnClassifier(DatasetIndex datasetIndex, ExampleSchema exampleSchema, int i) {
        this.index = datasetIndex;
        this.schema = exampleSchema;
        this.k = i;
        if (DEBUG) {
            log.debug("knn classifier for index:\n" + datasetIndex);
        }
    }

    /**
     * Classifies {@code instance} by a similarity-weighted vote of its {@code k} most similar
     * indexed examples.
     *
     * @return a label with one score per class that received a vote, or an empty label when no
     *     neighbor voted
     */
    @Override
    public ClassLabel classification(Instance instance) {
        if (DEBUG) {
            log.debug("classifying: " + instance);
        }
        List<Neighbor> neighbors = collectNeighbors(instance);
        // Stable sort: neighbors with equal similarity are all retained (a TreeSet would silently
        // drop them) and keep their retrieval order.
        Collections.sort(neighbors);

        Map<String, Double> votes = new HashMap<String, Double>();
        double totalWeight = 0.0d;
        int limit = Math.min(this.k, neighbors.size());
        for (int i = 0; i < limit; i++) {
            Neighbor neighbor = neighbors.get(i);
            String className = neighbor.e.getLabel().bestClassName();
            double weight = neighbor.e.getWeight() * neighbor.sim;
            Double previous = votes.get(className);
            double updated = (previous == null ? 0.0d : previous.doubleValue()) + weight;
            votes.put(className, Double.valueOf(updated));
            totalWeight += weight;
            if (DEBUG) {
                log.debug("neighbor: " + neighbor.e + " distance: " + neighbor.sim
                        + " weight: " + weight + " count[" + className + "]: " + votes.get(className));
            }
        }
        return toClassLabel(votes, totalWeight);
    }

    /** Scores every candidate example returned by the index against {@code instance}. */
    private List<Neighbor> collectNeighbors(Instance instance) {
        List<Neighbor> result = new ArrayList<Neighbor>();
        Example.Looper candidates = this.index.getNeighbors(instance);
        while (candidates.hasNext()) {
            Example candidate = candidates.nextExample();
            result.add(new Neighbor(candidate, computeSimilarity(instance, candidate)));
        }
        return result;
    }

    /**
     * Converts per-class vote totals into smoothed log-odds scores:
     * {@code log(share + SMOOTHING) - log((1 - share) + SMOOTHING)}.
     *
     * <p>When the total vote weight is zero (e.g. every voting neighbor had similarity 0), the shares are
     * undefined; each voted class then gets the neutral score 0 instead of NaN.
     */
    private static ClassLabel toClassLabel(Map<String, Double> votes, double totalWeight) {
        ClassLabel label = new ClassLabel();
        for (Map.Entry<String, Double> entry : votes.entrySet()) {
            double classWeight = entry.getValue().doubleValue();
            double score;
            if (totalWeight == 0.0d) {
                score = 0.0d;
            } else {
                score = Math.log((classWeight / totalWeight) + SMOOTHING)
                        - Math.log(((totalWeight - classWeight) / totalWeight) + SMOOTHING);
            }
            label.add(entry.getKey(), score);
        }
        return label;
    }

    @Override
    public String explain(Instance instance) {
        return "not implemented";
    }

    /**
     * Cosine similarity of the two instances' feature-weight vectors.
     *
     * <p>The dot product iterates only over {@code instance}'s features. This presumably relies on
     * {@code getWeight} returning 0 for absent features (TODO confirm).
     *
     * @return the cosine similarity, or 0 when either instance has zero norm (avoids 0/0 = NaN)
     */
    private double computeSimilarity(Instance instance, Instance instance2) {
        double normSq1 = 0.0d;
        double dot = 0.0d;
        Feature.Looper features1 = instance.featureIterator();
        while (features1.hasNext()) {
            Feature feature = features1.nextFeature();
            double weight = instance.getWeight(feature);
            normSq1 += weight * weight;
            dot += weight * instance2.getWeight(feature);
        }
        double normSq2 = 0.0d;
        Feature.Looper features2 = instance2.featureIterator();
        while (features2.hasNext()) {
            double weight = instance2.getWeight(features2.nextFeature());
            normSq2 += weight * weight;
        }
        if (normSq1 == 0.0d || normSq2 == 0.0d) {
            return 0.0d;
        }
        return dot / (Math.sqrt(normSq1) * Math.sqrt(normSq2));
    }
}
