/*
 * Decompiled with CFR 0.152.
 */
package kr.ac.kaist.itcknow.bigml.algo.multicore.sgd;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import kr.ac.kaist.itcknow.bigml.algo.hadoop.DSGD;
import kr.ac.kaist.itcknow.bigml.util.MersenneTwisterFast;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.PosixParser;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.log4j.Logger;

/**
 * Multinomial logistic regression trained with stochastic gradient descent.
 *
 * <p>One weight row is kept per class except the last one (index {@code num_label - 1}). That
 * class is the reference class with an implicit all-zero weight vector, which is why every
 * softmax denominator below starts at {@code 1.0} (= exp(0)).
 *
 * <p>Data lines have the form {@code [key<delim>]label<delim>id:value<delim>id:value...}. The
 * leading key token is present only when {@link #isHadoopGeneratedData} is set.
 *
 * <p>NOTE(review): {@link #readData()} maps feature ids to columns in first-seen order (via
 * {@link #feature_index}). The parameter I/O, likelihood and accuracy methods instead use
 * {@link #getFeatureIndex(String)}, i.e. numeric id minus one. The two schemes agree only when
 * features first appear in ascending id order with no gaps. Confirm this holds for the data used.
 */
public class LRSGD {
    /** Logger for this class (not referenced within this file). */
    public static Logger logger = Logger.getLogger(LRSGD.class);
    /** When true, each data line starts with an extra key token before the label. */
    public boolean isHadoopGeneratedData = false;
    /** Token separator for data lines; passed to {@link String#split}, so it is a regex. */
    public String dataDeliminator = " ";
    /** When true, {@link #train} writes parameter snapshots under {@code ./params/}. */
    public boolean isPrintParams = false;
    /** Number of feature columns. */
    public int dim = 0;
    /** Number of labels, including the reference label. */
    public int num_label = 0;
    /** Number of non-blank lines counted by the last {@code readData} call. */
    public int num_instance;
    /** Label string to class index, assigned in first-seen order. */
    public HashMap<String, Integer> label_index = new HashMap<String, Integer>();
    /** Feature id string to column index, assigned in first-seen order by {@code readData}. */
    public HashMap<String, Integer> feature_index = new HashMap<String, Integer>();
    /** Loaded instances. {@link Data} uses identity equality, so iteration order is unspecified. */
    public HashSet<Data> data_set = new HashSet<Data>();
    /** Weights indexed {@code [class][feature]}; the reference class has no row. */
    public double[][] etas;
    /** Experiment name; selects label set and data format in {@link #init()}. */
    public String expt;
    /** Path of the training data file. */
    public String input_file;
    /** Number of passes over {@link #data_set}. */
    public int iterations;
    /** SGD step size. */
    public double epsilon;
    /** Regularisation strength; parsed from the command line but not used by the code here. */
    public double lamda;
    /** Reference timestamp (ms) for progress reports. */
    public long start_time;
    public int process_id;
    public int thread_id;
    /** Probability of visiting each instance per pass; values below 1 subsample at random. */
    public double chance;
    /** Progress is reported every {@code interval} iterations. */
    public int interval;
    /** Path of the test data file. */
    public String test_file;
    public static Options options = new Options();
    public static String opt_expt = "experiment_name";
    public static String opt_input_file = "input_file";
    public static String opt_iterations = "iterations";
    public static String opt_epsilon = "epsilon";
    public static String opt_process_id = "process_id";
    public static String opt_thread_id = "thread_id";
    public static String opt_chance = "chance";
    public static String opt_interval = "interval";
    public static String opt_start_time = "start_time";
    public static String opt_test_file = "test_file";
    public static String opt_lamda = "lamda";

    /**
     * Returns the index mapped to {@code key}.
     *
     * <p>If the key is new, it is assigned the next free index ({@code map.size()}).
     */
    public int getIndex(HashMap<String, Integer> map, String key) {
        Integer index = map.get(key);
        if (index == null) {
            index = map.size();
            map.put(key, index);
        }
        return index;
    }

    /** Converts a 1-based numeric feature id into a 0-based column index. */
    public int getFeatureIndex(String key) {
        return Integer.parseInt(key) - 1;
    }

    /** Returns the token position of the label: 1 when lines carry a leading key, else 0. */
    private int labelPosition() {
        return this.isHadoopGeneratedData ? 1 : 0;
    }

    /** Returns true for lines with no content; such lines are skipped everywhere. */
    private static boolean isBlank(String line) {
        return line.trim().isEmpty();
    }

    /**
     * Scans {@link #input_file} once to register every label and feature id.
     *
     * <p>Sets {@link #num_label}, {@link #dim} and {@link #num_instance} from the scan.
     */
    private void indexLabelsAndFeatures() throws IOException {
        int labelPos = this.labelPosition();
        int count = 0;
        BufferedReader br = new BufferedReader(new FileReader(this.input_file));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                if (isBlank(line)) {
                    continue;
                }
                ++count;
                String[] tokens = line.split(this.dataDeliminator);
                this.getIndex(this.label_index, tokens[labelPos]);
                for (int i = labelPos + 1; i < tokens.length; ++i) {
                    this.getIndex(this.feature_index, tokens[i].split(":")[0]);
                }
            }
        } finally {
            br.close();
        }
        this.num_label = this.label_index.size();
        this.dim = this.feature_index.size();
        this.num_instance = count;
    }

    /** Parses one training line into a dense vector, using {@link #feature_index} for columns. */
    private Data parseTrainingInstance(String line) {
        int labelPos = this.labelPosition();
        String[] tokens = line.split(this.dataDeliminator);
        int y = this.getIndex(this.label_index, tokens[labelPos]);
        double[] feature_vector = new double[this.dim];
        for (int i = labelPos + 1; i < tokens.length; ++i) {
            String[] toks = tokens[i].split(":");
            feature_vector[this.getIndex(this.feature_index, toks[0])] = Double.parseDouble(toks[1]);
        }
        return new Data(y, feature_vector);
    }

    /**
     * Parses one evaluation line into a dense vector, using {@link #getFeatureIndex} for columns.
     *
     * <p>Feature ids outside {@code [0, dim)} have no trained weight and are ignored instead of
     * raising an index error.
     */
    private Data parseTestInstance(String line) {
        int labelPos = this.labelPosition();
        String[] tokens = line.split(this.dataDeliminator);
        int y = this.getIndex(this.label_index, tokens[labelPos]);
        double[] feature_vector = new double[this.dim];
        for (int i = labelPos + 1; i < tokens.length; ++i) {
            String[] toks = tokens[i].split(":");
            int d = this.getFeatureIndex(toks[0]);
            if (d < 0 || d >= this.dim) {
                continue;
            }
            feature_vector[d] = Double.parseDouble(toks[1]);
        }
        return new Data(y, feature_vector);
    }

    /**
     * Loads every instance of {@link #input_file} into {@link #data_set}.
     *
     * <p>A first pass registers labels and features; a second pass builds dense feature vectors.
     * Blank lines are skipped.
     */
    public void readData() throws Exception {
        this.readData(Integer.MAX_VALUE);
    }

    /**
     * Like {@link #readData()}, but loads at most {@code scale} instances.
     *
     * <p>Labels, features and {@link #num_instance} are still computed over the whole file.
     *
     * @param scale maximum number of instances to add to {@link #data_set}
     */
    public void readData(int scale) throws Exception {
        this.indexLabelsAndFeatures();
        BufferedReader br = new BufferedReader(new FileReader(this.input_file));
        try {
            int loaded = 0;
            String line;
            // Check the budget before reading, so exactly `scale` instances are loaded
            // (the previous `lineCnt++ <= scale` loaded scale + 1).
            while (loaded < scale && (line = br.readLine()) != null) {
                if (isBlank(line)) {
                    continue;
                }
                this.data_set.add(this.parseTrainingInstance(line));
                ++loaded;
            }
        } finally {
            br.close();
        }
    }

    /** Reads a parameter file in the format written by {@link #printParams(String)}. */
    public void readParams(String params_file) throws Exception {
        this.parseParams(new BufferedReader(new FileReader(params_file)));
    }

    /** Reads parameters from {@code inputStreamReader} and closes it. */
    public void readParams(InputStreamReader inputStreamReader) throws Exception {
        this.parseParams(new BufferedReader(inputStreamReader));
    }

    /**
     * Parses parameter data and closes the reader.
     *
     * <p>Line 1 holds {@code num_label}, line 2 holds {@code dim}. Then come {@code num_label}
     * lines of the form {@code label[\tid:value...]}. Lines without weights (the reference
     * label) are skipped, so that label is not registered here.
     */
    private void parseParams(BufferedReader br) throws IOException {
        try {
            this.num_label = Integer.parseInt(requireLine(br));
            this.dim = Integer.parseInt(requireLine(br));
            this.etas = new double[this.num_label - 1][this.dim];
            for (int i = 0; i < this.num_label; ++i) {
                String[] tokens = requireLine(br).split("\t");
                if (tokens.length <= 1) {
                    continue;
                }
                int c = this.getIndex(this.label_index, tokens[0]);
                for (int j = 1; j < tokens.length; ++j) {
                    String[] toks = tokens[j].split(":");
                    this.etas[c][this.getFeatureIndex(toks[0])] = Double.parseDouble(toks[1]);
                }
            }
        } finally {
            br.close();
        }
    }

    /** Reads the next line, failing with a clear message on truncated input. */
    private static String requireLine(BufferedReader br) throws IOException {
        String line = br.readLine();
        if (line == null) {
            throw new IOException("Unexpected end of parameter data");
        }
        return line;
    }

    /**
     * Averages the weights found in {@code part-*} files of directory {@code params_path}.
     *
     * <p>Each line with at least one tab counts as one contribution. Its {@code id:value} tokens
     * (after the first token) are summed into row 0 of {@link #etas}, and the sum is divided by
     * the number of contributions. Only row 0 is merged, i.e. this assumes a two-label model.
     * If no contributions are found, the weights stay zero.
     *
     * @throws IOException if {@code params_path} is not a readable directory
     */
    public void mergeParams(String params_path) throws Exception {
        File folder = new File(params_path);
        File[] listOfFiles = folder.listFiles();
        if (listOfFiles == null) {
            throw new IOException("Not a readable directory: " + params_path);
        }
        List<String> partFilenameList = new ArrayList<String>();
        for (File file : listOfFiles) {
            if (file.isFile() && file.getName().startsWith("part-")) {
                partFilenameList.add(file.getName());
            }
        }
        // Deterministic summation order across runs.
        Collections.sort(partFilenameList);
        final int c = 0;
        int mapCnt = 0;
        this.etas = new double[this.num_label - 1][this.dim];
        for (String partFilename : partFilenameList) {
            // Resolve against the directory so a missing trailing separator still works.
            BufferedReader br = new BufferedReader(new FileReader(new File(folder, partFilename)));
            try {
                String line;
                while ((line = br.readLine()) != null) {
                    String[] tokens = line.split("\t");
                    if (tokens.length <= 1) {
                        continue;
                    }
                    ++mapCnt;
                    for (int j = 1; j < tokens.length; ++j) {
                        String[] toks = tokens[j].split(":");
                        if (toks.length != 2) {
                            continue;
                        }
                        this.etas[c][this.getFeatureIndex(toks[0])] += Double.parseDouble(toks[1]);
                    }
                }
            } finally {
                br.close();
            }
        }
        if (mapCnt == 0) {
            // Nothing to average; avoid turning the zero weights into 0/0 = NaN.
            return;
        }
        for (int d = 0; d < this.dim; ++d) {
            this.etas[c][d] = this.etas[c][d] / (double) mapCnt;
        }
    }

    /**
     * Adds {@code value} to the sparse weight for {@code feature_string}.
     *
     * <p>The entry is removed when the result is exactly zero.
     *
     * @return the updated weight
     */
    protected static double updateFeature(HashMap<String, Double> eta, String feature_string, double value) {
        double feature_value = LRSGD.getFeature(eta, feature_string) + value;
        if (feature_value == 0.0) {
            eta.remove(feature_string);
        } else {
            eta.put(feature_string, feature_value);
        }
        return feature_value;
    }

    /** Returns the sparse weight for {@code feature_string}, or 0.0 when absent. */
    public static double getFeature(HashMap<String, Double> eta, String feature_string) {
        Double value = eta.get(feature_string);
        return value == null ? 0.0 : value;
    }

    /** Returns the live weight matrix (not a copy). */
    public double[][] getParamFormatDouble() {
        return this.etas;
    }

    /**
     * Formats every non-reference weight row as {@code \tid:value} pairs (id is 1-based).
     *
     * <p>Rows appear in {@link #label_index} iteration order.
     */
    public String getParamFormatString() {
        StringBuilder paramStr = new StringBuilder();
        for (Map.Entry<String, Integer> entry : this.label_index.entrySet()) {
            int c = entry.getValue();
            if (c == this.num_label - 1) {
                continue;
            }
            for (int d = 0; d < this.dim; ++d) {
                paramStr.append(formatParam(d, this.etas[c][d]));
            }
        }
        return paramStr.toString();
    }

    /**
     * Registers the label set and data format for known experiment names, then allocates
     * zeroed weights of shape {@code [num_label - 1][dim]}.
     *
     * <p>Call after {@link #dim} is known, e.g. after {@link #readData()}.
     *
     * @throws IllegalStateException if no labels are registered
     */
    public void init() {
        String e = this.expt;
        if ("1".equals(e) || "2".equals(e) || "3".equals(e)) {
            this.getIndex(this.label_index, "0");
            this.getIndex(this.label_index, "1");
            this.isHadoopGeneratedData = true;
            this.dataDeliminator = "\t";
        } else if ("rcv1".equals(e) || "arcene".equals(e) || "news20".equals(e) || "zeta".equals(e) || "a9".equals(e)) {
            this.getIndex(this.label_index, "-1");
            this.getIndex(this.label_index, "1");
            this.isHadoopGeneratedData = false;
            this.dataDeliminator = " ";
        } else if ("kdda".equals(e) || "kddb".equals(e)) {
            this.getIndex(this.label_index, "0");
            this.getIndex(this.label_index, "1");
            this.isHadoopGeneratedData = false;
            this.dataDeliminator = " ";
        }
        this.num_label = this.label_index.size();
        if (this.num_label < 1) {
            throw new IllegalStateException("No labels registered for experiment '" + e + "'; call readData() before init()");
        }
        this.etas = new double[this.num_label - 1][this.dim];
    }

    /** Copies the first {@code dim} weights of each row of {@code etas_previous} into {@link #etas}. */
    public void init(double[][] etas_previous) {
        for (int c = 0; c < this.num_label - 1; ++c) {
            System.arraycopy(etas_previous[c], 0, this.etas[c], 0, this.dim);
        }
    }

    /** Writes all weights (zeros included) in the parameter-file format. */
    protected static void printParams(String filename, int num_label, int dim, double[][] etas, HashMap<String, Integer> label_index) throws Exception {
        writeParams(filename, num_label, dim, etas, label_index, false);
    }

    /** Writes a snapshot to {@code ./params/<expt>_<iter>.params}. */
    public static void printParams(String expt, int iter, int num_label, int dim, double[][] etas, HashMap<String, Integer> label_index) throws Exception {
        LRSGD.printParams(String.format(Locale.ROOT, "./params/%s_%d.params", expt, iter), num_label, dim, etas, label_index);
    }

    /** Writes this model's non-zero weights in the format read by {@link #readParams(String)}. */
    public void printParams(String filename) throws Exception {
        writeParams(filename, this.num_label, this.dim, this.etas, this.label_index, true);
    }

    /**
     * Writes the parameter-file format.
     *
     * <p>The file holds {@code num_label}, then {@code dim}, then one line per label. Each label
     * line carries {@code \tid:value} pairs unless it is the reference label. Numbers are
     * formatted with {@link Locale#ROOT} so {@code Double.parseDouble} can read them back in
     * any locale.
     */
    private static void writeParams(String filename, int num_label, int dim, double[][] etas, Map<String, Integer> label_index, boolean skipZeros) throws IOException {
        BufferedWriter bw = new BufferedWriter(new FileWriter(filename));
        try {
            bw.write(String.format(Locale.ROOT, "%d\n", num_label));
            bw.write(String.format(Locale.ROOT, "%d\n", dim));
            for (Map.Entry<String, Integer> entry : label_index.entrySet()) {
                bw.write(String.valueOf(entry.getKey()));
                int c = entry.getValue();
                if (c != num_label - 1) {
                    for (int d = 0; d < dim; ++d) {
                        if (skipZeros && etas[c][d] == 0.0) {
                            continue;
                        }
                        bw.write(formatParam(d, etas[c][d]));
                    }
                }
                bw.newLine();
            }
        } finally {
            bw.close();
        }
    }

    /** Formats one weight as {@code \t<d+1>:<value>} using locale-independent digits. */
    private static String formatParam(int d, double value) {
        return String.format(Locale.ROOT, "\t%d:%f", d + 1, value);
    }

    /** Trains on {@link #data_set} using this object's configuration. */
    public void train() throws Exception {
        this.train(this.data_set, this.num_label, this.dim, this.iterations, this.epsilon, this.etas, this.label_index, this.feature_index, this.process_id, this.thread_id, this.start_time, String.format("%s_sgd_p%d", this.expt, this.process_id), this.chance, this.interval);
    }

    /** Applies one SGD step for a single instance with class index {@code y}. */
    public void trainOneInstance(int y, double[] feature_vector) throws Exception {
        sgdStep(y, feature_vector, this.num_label, this.dim, this.etas, this.epsilon, new double[this.num_label - 1]);
    }

    /** Applies one SGD step for a single instance. */
    public void trainOneInstance(Data data) throws Exception {
        sgdStep(data.y, data.feature_vector, this.num_label, this.dim, this.etas, this.epsilon, new double[this.num_label - 1]);
    }

    /**
     * One coordinate-wise SGD step on the softmax log-likelihood for one instance.
     *
     * <p>After each weight update, the class score and the softmax denominator are corrected
     * incrementally. Later coordinates therefore see the already-updated probability.
     *
     * @param scores scratch buffer of length {@code numLabel - 1}; overwritten
     */
    private static void sgdStep(int y, double[] featureVector, int numLabel, int dim, double[][] etas, double epsilon, double[] scores) {
        // The reference class scores 0, contributing exp(0) = 1 to the denominator.
        double denominator = 1.0;
        for (int c = 0; c < numLabel - 1; ++c) {
            double score = 0.0;
            for (int d = 0; d < dim; ++d) {
                score += etas[c][d] * featureVector[d];
            }
            scores[c] = score;
            denominator += Math.exp(score);
        }
        for (int c = 0; c < numLabel - 1; ++c) {
            double[] w = etas[c];
            double score = scores[c];
            double p = Math.exp(score) / denominator;
            double delta = c == y ? 1.0 : 0.0;
            for (int d = 0; d < dim; ++d) {
                double x = featureVector[d];
                if (x == 0.0) {
                    continue;
                }
                double gradient = x * (delta - p);
                // Remove this class's old contribution, update the weight, then add it back.
                denominator -= Math.exp(score);
                score -= w[d] * x;
                w[d] += epsilon * gradient;
                score += w[d] * x;
                scores[c] = score;
                double numerator = Math.exp(score);
                denominator += numerator;
                p = numerator / denominator;
            }
        }
    }

    /**
     * Runs {@code iterations} SGD passes over {@code data_set}, updating {@code etas} in place.
     *
     * <p>Each instance is visited with probability {@code chance}. Progress is printed every
     * {@code interval} iterations. Thread 0 reports elapsed time excluding the time spent
     * writing snapshots (when {@link #isPrintParams} is set).
     */
    public void train(HashSet<Data> data_set, int num_label, int dim, int iterations, double epsilon, double[][] etas, HashMap<String, Integer> label_index, HashMap<String, Integer> feature_index, int process_id, int thread_id, long start_time, String prefix, double chance, int interval) throws Exception {
        MersenneTwisterFast rand = null;
        if (chance < 1.0) {
            rand = new MersenneTwisterFast(System.currentTimeMillis());
        }
        System.out.println(String.format("Process:%d\tThread:%d\tIterations:%d/%d\tTime:%e", process_id, thread_id, 0, iterations, (double)(System.currentTimeMillis() - start_time)));
        long overhead_end = 0L;
        long overhead_start = 0L;
        long duration = 0L;
        long current = 0L;
        if (thread_id == 0) {
            overhead_start = System.currentTimeMillis();
            current = overhead_start - start_time;
            if (this.isPrintParams) {
                LRSGD.printParams(prefix, 0, num_label, dim, etas, label_index);
            }
            overhead_end = System.currentTimeMillis();
        }
        double[] scores = new double[num_label - 1];
        for (int iter = 1; iter <= iterations; ++iter) {
            for (Data data : data_set) {
                if (chance < 1.0 && chance < rand.nextDouble()) {
                    continue;
                }
                sgdStep(data.y, data.feature_vector, num_label, dim, etas, epsilon, scores);
            }
            if (iter % interval != 0) {
                continue;
            }
            if (thread_id == 0) {
                overhead_start = System.currentTimeMillis();
                duration = overhead_start - overhead_end;
                current = duration + current;
                System.out.println(String.format("Process:%d\tThread:%d\tIterations:%d/%d\tTime:%e", process_id, thread_id, iter, iterations, (double)current));
                if (this.isPrintParams) {
                    LRSGD.printParams(prefix, iter, num_label, dim, etas, label_index);
                }
                overhead_end = System.currentTimeMillis();
                continue;
            }
            System.out.println(String.format("Process:%d\tThread:%d\tIterations:%d/%d\tTime:%e", process_id, thread_id, iter, iterations, (double)(System.currentTimeMillis() - start_time)));
        }
    }

    /** Registers the command-line options on the shared {@link #options} and returns it. */
    public static Options InitOptions() {
        options.addOption("x", opt_expt, true, "Experiment Name");
        options.addOption("f", opt_input_file, true, "Input File");
        options.addOption("i", opt_iterations, true, "Iterations");
        options.addOption("e", opt_epsilon, true, "Epsilon");
        options.addOption("p", opt_process_id, true, "Process ID");
        options.addOption("t", opt_thread_id, true, "Thread ID");
        options.addOption("c", opt_chance, true, "Chance");
        options.addOption("v", opt_interval, true, "Output Interval");
        options.addOption("s", opt_start_time, true, "Start Time");
        options.addOption("tt", opt_test_file, true, "Testing File");
        options.addOption("l", opt_lamda, true, "Lamda");
        return options;
    }

    public LRSGD() {
    }

    /**
     * Configures the model from parsed command-line options.
     *
     * <p>Defaults apply for absent optional values: epsilon 0.1, start time now, ids 0,
     * chance 1.0, interval 1, lamda 0.01.
     */
    public LRSGD(CommandLine cli) {
        this();
        this.expt = cli.getOptionValue(opt_expt);
        this.input_file = cli.getOptionValue(opt_input_file);
        this.test_file = cli.getOptionValue(opt_test_file);
        this.iterations = Integer.parseInt(cli.getOptionValue(opt_iterations));
        this.epsilon = cli.hasOption(opt_epsilon) ? Double.parseDouble(cli.getOptionValue(opt_epsilon)) : 0.1;
        this.start_time = cli.hasOption(opt_start_time) ? Long.parseLong(cli.getOptionValue(opt_start_time)) : System.currentTimeMillis();
        this.process_id = cli.hasOption(opt_process_id) ? Integer.parseInt(cli.getOptionValue(opt_process_id)) : 0;
        this.thread_id = cli.hasOption(opt_thread_id) ? Integer.parseInt(cli.getOptionValue(opt_thread_id)) : 0;
        this.chance = cli.hasOption(opt_chance) ? Double.parseDouble(cli.getOptionValue(opt_chance)) : 1.0;
        this.interval = cli.hasOption(opt_interval) ? Integer.parseInt(cli.getOptionValue(opt_interval)) : 1;
        // lamda is fractional (default 0.01); Integer.parseInt rejected values such as "0.01".
        this.lamda = cli.hasOption(opt_lamda) ? Double.parseDouble(cli.getOptionValue(opt_lamda)) : 0.01;
    }

    /**
     * Sums per-instance log-likelihoods over all non-blank lines and closes the reader.
     *
     * @param reportProgress when true, prints {@code count: n} for n = 1, 1001, 2001, ...
     */
    private double sumLikelihood(BufferedReader br, boolean reportProgress) throws IOException {
        double likelihood = 0.0;
        int cnt = 0;
        try {
            String line;
            while ((line = br.readLine()) != null) {
                if (isBlank(line)) {
                    continue;
                }
                likelihood += this.calculateLikelihoodOneInstance(this.parseTestInstance(line));
                ++cnt;
                if (reportProgress && cnt % 1000 == 1) {
                    System.out.println("count: " + cnt);
                }
            }
        } finally {
            br.close();
        }
        return likelihood;
    }

    /** Returns the total log-likelihood of a local file, printing progress every 1000 lines. */
    public double calculateLikelihoodFromFile(String test_file) throws IOException {
        return this.sumLikelihood(new BufferedReader(new InputStreamReader(new FileInputStream(test_file))), true);
    }

    /** Returns the total log-likelihood of {@code test_file} read through {@code fs}. */
    public double calculateLikelihoodSequential(FileSystem fs, String test_file) throws IOException {
        return this.sumLikelihood(new BufferedReader(new InputStreamReader((InputStream)fs.open(new Path(test_file)))), false);
    }

    /** Returns the total log-likelihood of {@link #input_file} read through {@code fs}. */
    public double calculateLikelihoodSequential(FileSystem fs) throws Exception {
        return this.calculateLikelihoodSequential(fs, this.input_file);
    }

    /**
     * Returns {@code log p(y | x)} under the current weights.
     *
     * <p>A label index without a weight row is treated as the reference class (numerator 1).
     */
    public double calculateLikelihoodOneInstance(Data data) {
        int y = data.y;
        double[] feature_vector = data.feature_vector;
        double numerator = 1.0;
        double denominator = 1.0;
        for (int c = 0; c < this.num_label - 1; ++c) {
            double score = 0.0;
            for (int d = 0; d < this.dim; ++d) {
                score += this.etas[c][d] * feature_vector[d];
            }
            double expScore = Math.exp(score);
            if (y == c) {
                numerator = expScore;
            }
            denominator += expScore;
        }
        return Math.log(numerator) - Math.log(denominator);
    }

    /** Returns the total log-likelihood of {@link #data_set}. */
    public double calculateLikelihood() {
        double likelihood = 0.0;
        for (Data data : this.data_set) {
            likelihood += this.calculateLikelihoodOneInstance(data);
        }
        return likelihood;
    }

    /**
     * Classifies every non-blank line of {@code in} by arg-max score and closes the reader.
     *
     * <p>The reference class scores 0 and wins ties. Feature ids outside {@code [0, dim)} are
     * ignored. With no test lines the accuracy is NaN (0/0).
     *
     * @throws IllegalArgumentException if a line's label is not in {@code label_index}
     */
    public Result getAccuracy(InputStreamReader in, int num_label, int dim, double[][] etas, HashMap<String, Integer> label_index) throws Exception {
        BufferedReader br = new BufferedReader(in);
        int labelPos = this.labelPosition();
        int num_correct = 0;
        int num_test_cases = 0;
        try {
            String line;
            while ((line = br.readLine()) != null) {
                if (isBlank(line)) {
                    continue;
                }
                String[] tokens = line.split(this.dataDeliminator);
                Integer yIndex = label_index.get(tokens[labelPos]);
                if (yIndex == null) {
                    throw new IllegalArgumentException("Unknown label in test data: " + tokens[labelPos]);
                }
                int y = yIndex;
                // Parse the sparse features once instead of once per class.
                int nnz = tokens.length - labelPos - 1;
                int[] ids = new int[nnz];
                double[] values = new double[nnz];
                int used = 0;
                for (int i = labelPos + 1; i < tokens.length; ++i) {
                    String[] toks = tokens[i].split(":");
                    int d = this.getFeatureIndex(toks[0]);
                    if (d < 0 || d >= dim) {
                        continue;
                    }
                    ids[used] = d;
                    values[used] = Double.parseDouble(toks[1]);
                    ++used;
                }
                int best_label = num_label - 1;
                double best_logp = 0.0;
                for (int c = 0; c < num_label - 1; ++c) {
                    double logp = 0.0;
                    for (int k = 0; k < used; ++k) {
                        logp += etas[c][ids[k]] * values[k];
                    }
                    if (logp > best_logp) {
                        best_label = c;
                        best_logp = logp;
                    }
                }
                if (best_label == y) {
                    ++num_correct;
                }
                ++num_test_cases;
            }
        } finally {
            br.close();
        }
        double accuracy = (double) num_correct / (double) num_test_cases;
        return new Result(accuracy, num_correct, num_test_cases);
    }

    /** Computes accuracy on {@code in} using this model's parameters. */
    public Result getAccuracy(InputStreamReader in) throws Exception {
        return this.getAccuracy(in, this.num_label, this.dim, this.etas, this.label_index);
    }

    /**
     * Command-line entry point: reads data, initialises and trains.
     *
     * <p>With no arguments, a built-in development configuration is used.
     */
    public static void main(String[] args) throws Exception {
        if (args.length == 0) {
            // Previously this default overwrote the real arguments unconditionally.
            args = "--x sgd --f input/training_D_4500_N_120000_S_0.01.txt --i 100".split(" ");
        }
        PosixParser parser = new PosixParser();
        LRSGD.InitOptions();
        CommandLine cli = parser.parse(options, args);
        LRSGD log_reg = new LRSGD(cli);
        log_reg.readData();
        log_reg.init();
        System.out.println("Dim: " + log_reg.dim);
        System.out.println("Num labels: " + log_reg.num_label);
        System.out.println("Iterations: " + log_reg.iterations);
        System.out.println(String.format("Epsilon: %f", log_reg.epsilon));
        System.out.println();
        log_reg.train();
    }

    /** Accuracy summary returned by {@code getAccuracy}. */
    public class Result {
        public double accuracy;
        public int num_correct;
        public int num_test_cases;

        public Result(double accuracy, int num_correct, int num_test_cases) {
            this.accuracy = accuracy;
            this.num_correct = num_correct;
            this.num_test_cases = num_test_cases;
        }
    }

    /** One instance: class index {@code y} and a dense feature vector. Uses identity equality. */
    public static class Data {
        public int y;
        public double[] feature_vector;

        public Data(int y, double[] feature_vector) {
            this.y = y;
            this.feature_vector = feature_vector;
        }
    }
}

