package edu.cmu.minorthird.classify;

import edu.cmu.minorthird.classify.algorithms.svm.SVMLearner;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.util.StringTokenizer;
import junit.framework.Test;
import junit.framework.TestSuite;
import junit.textui.TestRunner;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_train;
import org.apache.log4j.BasicConfigurator;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

/**
 * Tests for the libsvm integration.
 *
 * <p>{@link #testDirectCode()} trains and evaluates libsvm directly on the a1a sample data.
 * {@link #testWrapper()} runs the same data through the {@link SVMLearner} wrapper.
 * {@link #testSampleData()} runs the wrapper on the toy sample datasets. Each test compares
 * the resulting statistics against stored expected values.
 */
public class LibsvmTest extends AbstractClassificationChecks {
    Logger log = Logger.getLogger(getClass());

    private static final String TRAIN_FILE = "testData/a1a.dat";
    private static final String MODEL_FILE = "modelFile.dat";
    private static final String TEST_FILE = "testData/a1a.t.dat";
    private static final String RESULTS_FILE = "results.dat";

    public LibsvmTest(String name) {
        super(name);
    }

    public LibsvmTest() {
        super("LibsvmTest");
    }

    /** Resets log4j to a basic console configuration and disables standards checking by default. */
    @Override // junit.framework.TestCase
    protected void setUp() {
        Logger.getRootLogger().removeAllAppenders();
        BasicConfigurator.configure();
        log.setLevel(Level.DEBUG);
        super.setCheckStandards(false);
    }

    /** No per-test cleanup is needed. */
    @Override // junit.framework.TestCase
    protected void tearDown() {
    }

    /**
     * Trains a linear-kernel model with libsvm's own {@code svm_train} entry point, predicts the
     * test file, and checks {accuracy, mean squared error, squared correlation}.
     */
    public void testDirectCode() {
        log.debug("start");
        try {
            svm_train.main(new String[] {"-t", "0", TRAIN_FILE, MODEL_FILE});
            log.debug("trained, sent model to: " + MODEL_FILE);
            double[] stats = prediction(new String[] {TEST_FILE, MODEL_FILE, RESULTS_FILE});
            log.debug("ran predict on testfile");
            checkStats(stats, new double[] {0.8352766230693838d, 0.6588935077224645d, 0.28752077092970524d});
        } catch (Exception e) {
            log.error(e, e);
            fail("exception");
        }
    }

    /** Runs the a1a train/test split through the {@link SVMLearner} wrapper with standards checking on. */
    public void testWrapper() {
        try {
            super.setCheckStandards(true);
            super.checkClassify(
                    new SVMLearner(),
                    DatasetLoader.loadSVMStyle(new File(TRAIN_FILE)),
                    DatasetLoader.loadSVMStyle(new File(TEST_FILE)),
                    new double[] {0.16472337693061612d, 0.5532531341004251d, 0.6413123436810357d, 1.3132616875183545d});
        } catch (Exception e) {
            log.error(e, e);
            // Previously the exception was only logged, so a broken loader or learner passed silently.
            fail("exception");
        }
    }

    /** Runs the {@link SVMLearner} wrapper on the toy sample datasets. */
    public void testSampleData() {
        super.checkClassify(
                new SVMLearner(),
                SampleDatasets.toyTrain(),
                SampleDatasets.toyTest(),
                new double[] {0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 1.0d, 1.0d, 1.3132616875182228d, 1.0d, 1.0d, 1.0d, 1.0d});
    }

    /** Builds a suite containing every {@code test*} method of this class. */
    public static Test suite() {
        return new TestSuite(LibsvmTest.class);
    }

    public static void main(String[] args) {
        TestRunner.run(suite());
    }

    /**
     * Predicts every example read from {@code input} with {@code svmModel} and writes one
     * prediction per line to {@code output}.
     *
     * <p>Each non-blank input line is {@code label index:value index:value ...}.
     *
     * @return {accuracy as a fraction, mean squared error, squared correlation coefficient}.
     *     All three are NaN when the input has no examples.
     */
    private double[] predict(BufferedReader input, DataOutputStream output, svm_model svmModel) throws IOException {
        int correct = 0;
        int total = 0;
        double sumSquaredError = 0.0d;
        double sumPredicted = 0.0d;
        double sumTarget = 0.0d;
        double sumPredictedSq = 0.0d;
        double sumTargetSq = 0.0d;
        double sumProduct = 0.0d;

        String line;
        while ((line = input.readLine()) != null) {
            StringTokenizer tokens = new StringTokenizer(line, " \t\n\r\f:");
            if (!tokens.hasMoreTokens()) {
                continue; // blank line: nothing to predict
            }
            double target = atof(tokens.nextToken());
            double predicted = svm.svm_predict(svmModel, parseNodes(tokens));
            output.writeBytes(predicted + "\n");

            if (predicted == target) {
                correct++;
            }
            sumSquaredError += (predicted - target) * (predicted - target);
            sumPredicted += predicted;
            sumTarget += target;
            sumPredictedSq += predicted * predicted;
            sumTargetSq += target * target;
            sumProduct += predicted * target;
            total++;
        }

        // The cast is essential: int/int division would truncate the accuracy to 0 or 1.
        double accuracy = (double) correct / total;
        double meanSquaredError = sumSquaredError / total;
        double covarianceTerm = total * sumProduct - sumPredicted * sumTarget;
        double squaredCorrelation = (covarianceTerm * covarianceTerm)
                / ((total * sumPredictedSq - sumPredicted * sumPredicted)
                        * (total * sumTargetSq - sumTarget * sumTarget));

        log.debug("Accuracy = " + accuracy * 100.0d + "% (" + correct + "/" + total + ") (classification)\n");
        log.debug("Mean squared error = " + meanSquaredError + " (regression)\n");
        log.debug("Squared correlation coefficient = " + squaredCorrelation + " (regression)\n");
        return new double[] {accuracy, meanSquaredError, squaredCorrelation};
    }

    /** Parses the remaining {@code index:value} token pairs of one line into libsvm nodes. */
    private static svm_node[] parseNodes(StringTokenizer tokens) {
        int count = tokens.countTokens() / 2;
        svm_node[] nodes = new svm_node[count];
        for (int k = 0; k < count; k++) {
            nodes[k] = new svm_node();
            nodes[k].index = atoi(tokens.nextToken());
            nodes[k].value = atof(tokens.nextToken());
        }
        return nodes;
    }

    /**
     * Loads the model, predicts the test file, writes predictions to the output file, and
     * closes both files.
     *
     * @param args {test_file, model_file, output_file}
     * @return see {@link #predict(BufferedReader, DataOutputStream, svm_model)}
     * @throws IllegalArgumentException if {@code args} does not have exactly three entries
     * @throws IOException if any file cannot be read or written
     */
    private double[] prediction(String[] args) throws IOException {
        if (args.length != 3) {
            // Throw rather than System.exit: exiting would kill the whole test JVM.
            throw new IllegalArgumentException("usage: svm-predict test_file model_file output_file");
        }
        svm_model svmModel = svm.svm_load_model(args[1]);
        BufferedReader input = new BufferedReader(new FileReader(args[0]));
        try {
            DataOutputStream output = new DataOutputStream(new FileOutputStream(args[2]));
            try {
                return predict(input, output, svmModel);
            } finally {
                output.close();
            }
        } finally {
            input.close();
        }
    }

    private static double atof(String str) {
        return Double.parseDouble(str);
    }

    private static int atoi(String str) {
        return Integer.parseInt(str);
    }
}
