/*
 * Decompiled with CFR 0.152.
 */
package kr.ac.kaist.itcknow.bigml.algo.multicore.cd;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import kr.ac.kaist.itcknow.bigml.algo.multicore.sgd.LRSGD;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.PosixParser;

/**
 * Multinomial logistic regression whose parameters are updated one feature coordinate at a time.
 *
 * <p>The weights live in {@code etas}: one row per non-reference class ({@code num_label - 1}
 * rows) and one column per feature ({@code dim}). The class whose index is
 * {@code num_label - 1} has no row; it acts as an implicit reference class with score 0.
 * That is why every softmax denominator below starts at 1.0, and why the formatters skip that
 * index. Fields, option definitions, {@code getIndex} and {@code getFeatureIndex} are
 * inherited from {@link LRSGD}.
 *
 * <p>Parameter files read by {@link #addParams(InputStreamReader)} and
 * {@link #readParams(InputStreamReader, HashSet)} have this layout:
 * <ul>
 *   <li>line 1: the label count;</li>
 *   <li>line 2: the dimension;</li>
 *   <li>then exactly one line per label: {@code label<TAB>feature:value<TAB>feature:value...}.</li>
 * </ul>
 */
public class LRCD extends LRSGD {

    /** File-name prefix of the per-worker outputs merged by the {@code mergeParams*} methods. */
    private static final String PART_FILE_PREFIX = "part-";

    /**
     * Adds the parameter values read from a parameter file onto the current {@code etas}.
     *
     * <p>The header overwrites {@code num_label} and {@code dim}, but {@code etas} is NOT
     * reallocated. The existing array must therefore already cover every label/feature in the
     * file. The reader is always closed, even if parsing fails.
     *
     * @param inputStreamReader source of a parameter file; closed before this method returns
     * @throws Exception if reading fails, the file is truncated, or a number is malformed
     */
    public void addParams(InputStreamReader inputStreamReader) throws Exception {
        BufferedReader br = new BufferedReader(inputStreamReader);
        try {
            this.num_label = Integer.parseInt(requireLine(br));
            this.dim = Integer.parseInt(requireLine(br));
            for (int i = 0; i < this.num_label; ++i) {
                String[] tokens = requireLine(br).split("\t");
                if (tokens.length <= 1) {
                    continue; // label line that carries no parameters
                }
                int c = this.getIndex(this.label_index, tokens[0]);
                for (int j = 1; j < tokens.length; ++j) {
                    String[] toks = tokens[j].split(":");
                    int d = this.getFeatureIndex(toks[0]);
                    this.etas[c][d] += Double.parseDouble(toks[1]);
                }
            }
        } finally {
            br.close();
        }
    }

    /**
     * Convenience overload of {@link #addParams(InputStreamReader)} that reads from a file path.
     *
     * @param filename path of the parameter file to add
     * @throws Exception if the file cannot be opened or parsed
     */
    public void addParams(String filename) throws Exception {
        this.addParams(new FileReader(filename));
    }

    /**
     * Replaces {@code etas} with a freshly allocated array and fills it from a parameter file.
     *
     * <p>Only features whose key is contained in {@code dSet} are loaded. Every other entry
     * stays 0.0. The reader is always closed, even if parsing fails.
     *
     * @param inputStreamReader source of a parameter file; closed before this method returns
     * @param dSet feature keys (as written in the file) to keep
     * @throws Exception if reading fails, the file is truncated, or a number is malformed
     */
    public void readParams(InputStreamReader inputStreamReader, HashSet<String> dSet) throws Exception {
        BufferedReader br = new BufferedReader(inputStreamReader);
        try {
            this.num_label = Integer.parseInt(requireLine(br));
            this.dim = Integer.parseInt(requireLine(br));
            this.etas = new double[this.num_label - 1][this.dim];
            for (int i = 0; i < this.num_label; ++i) {
                String[] tokens = requireLine(br).split("\t");
                if (tokens.length <= 1) {
                    continue; // label line that carries no parameters
                }
                int c = this.getIndex(this.label_index, tokens[0]);
                for (int j = 1; j < tokens.length; ++j) {
                    String[] toks = tokens[j].split(":");
                    if (!dSet.contains(toks[0])) {
                        continue;
                    }
                    int d = this.getFeatureIndex(toks[0]);
                    this.etas[c][d] = Double.parseDouble(toks[1]);
                }
            }
        } finally {
            br.close();
        }
    }

    /**
     * Rebuilds {@code etas} by SUMMING the {@code name:value} entries of every {@code part-*}
     * file in {@code params_path}.
     *
     * @param params_path directory containing the part files (trailing separator optional)
     * @throws Exception if the directory cannot be listed or a part file cannot be read
     */
    public void mergeParamsDPSCD(String params_path) throws Exception {
        this.mergePartFiles(params_path, true);
    }

    /**
     * Rebuilds {@code etas} from the {@code name:value} entries of every {@code part-*} file in
     * {@code params_path}. When a feature appears more than once, the last value read wins.
     * Part files are read in sorted path order.
     *
     * @param params_path directory containing the part files (trailing separator optional)
     * @throws Exception if the directory cannot be listed or a part file cannot be read
     */
    public void mergeParamsDSCD(String params_path) throws Exception {
        this.mergePartFiles(params_path, false);
    }

    /**
     * Shared implementation of the {@code mergeParams*} methods.
     *
     * @param paramsPath directory containing the part files
     * @param accumulate {@code true} to add values onto the entry, {@code false} to overwrite it
     */
    private void mergePartFiles(String paramsPath, boolean accumulate) throws IOException {
        List<File> partFiles = findPartFiles(new File(paramsPath));
        // NOTE(review): every value goes into row 0 of etas regardless of label, which only
        // covers all classes when num_label == 2 — confirm this is intended.
        int c = 0;
        this.etas = new double[this.num_label - 1][this.dim];
        for (File partFile : partFiles) {
            BufferedReader br = new BufferedReader(new FileReader(partFile));
            try {
                String line;
                while ((line = br.readLine()) != null) {
                    String[] tokens = line.split("\t");
                    if (tokens.length <= 1) {
                        continue;
                    }
                    // tokens[0] is a per-line key that is not used here. Every later token
                    // that splits into exactly "name:value" is applied.
                    for (int j = 1; j < tokens.length; ++j) {
                        String[] toks = tokens[j].split(":");
                        if (toks.length != 2) {
                            continue;
                        }
                        int d = this.getFeatureIndex(toks[0]);
                        double value = Double.parseDouble(toks[1]);
                        if (accumulate) {
                            this.etas[c][d] += value;
                        } else {
                            this.etas[c][d] = value;
                        }
                    }
                }
            } finally {
                br.close();
            }
        }
    }

    /**
     * Lists the regular files in {@code folder} whose names start with {@code "part-"}, sorted
     * by path so that merges are deterministic.
     *
     * @throws IOException if {@code folder} is not a directory or cannot be listed
     */
    private static List<File> findPartFiles(File folder) throws IOException {
        File[] entries = folder.listFiles();
        if (entries == null) {
            throw new IOException("Cannot list parameter directory: " + folder);
        }
        List<File> partFiles = new ArrayList<File>();
        for (File entry : entries) {
            if (entry.isFile() && entry.getName().startsWith(PART_FILE_PREFIX)) {
                partFiles.add(entry);
            }
        }
        Collections.sort(partFiles);
        return partFiles;
    }

    /** Reads the next line, failing with a clear message instead of returning {@code null}. */
    private static String requireLine(BufferedReader br) throws IOException {
        String line = br.readLine();
        if (line == null) {
            throw new IOException("Unexpected end of parameter file");
        }
        return line;
    }

    /**
     * Formats the weight of feature {@code dIndex} for every non-reference class.
     *
     * <p>Each entry is {@code "\t<dIndex + 1>:<value>"}, always using '.' as the decimal
     * separator. Entries follow the iteration order of {@code label_index}. Zero weights are
     * included.
     *
     * @param dIndex 0-based feature column
     * @return the concatenated entries, or {@code ""} when {@code dIndex} is outside [0, dim)
     */
    public String getParamFormatString(int dIndex) {
        if (dIndex < 0 || dIndex >= this.dim) {
            return "";
        }
        StringBuilder paramStr = new StringBuilder();
        for (Object value : this.label_index.values()) {
            int c = (Integer) value;
            if (c == this.num_label - 1) {
                continue; // reference class has no row in etas
            }
            formatParamEntry(paramStr, dIndex, this.etas[c][dIndex]);
        }
        return paramStr.toString();
    }

    /**
     * Formats the non-zero weights of the features whose 1-based index (as a string) is in
     * {@code dSet}, for every non-reference class.
     *
     * <p>Each entry is {@code "\t<d + 1>:<value>"}, always using '.' as the decimal separator.
     *
     * @param dSet 1-based feature indices, as strings, to include
     * @return the concatenated entries (possibly empty)
     */
    public String getParamFormatString(HashSet<String> dSet) {
        StringBuilder paramStr = new StringBuilder();
        for (Object value : this.label_index.values()) {
            int c = (Integer) value;
            if (c == this.num_label - 1) {
                continue; // reference class has no row in etas
            }
            for (int d = 0; d < this.dim; ++d) {
                if (!dSet.contains(String.valueOf(d + 1)) || this.etas[c][d] == 0.0) {
                    continue;
                }
                formatParamEntry(paramStr, d, this.etas[c][d]);
            }
        }
        return paramStr.toString();
    }

    /**
     * Appends {@code "\t<featureIndex + 1>:<value>"} to {@code out}.
     *
     * <p>{@link Locale#ROOT} keeps the output parseable by {@link Double#parseDouble} in any
     * default locale.
     */
    private static void formatParamEntry(StringBuilder out, int featureIndex, double value) {
        out.append(String.format(Locale.ROOT, "\t%d:%f", featureIndex + 1, value));
    }

    /**
     * Applies one round of per-coordinate gradient steps for a single instance.
     *
     * <p>Each class and each feature key in {@code dSet} is updated in turn.
     *
     * @param dSet feature keys (resolved via {@code getFeatureIndex}) to update
     * @param y class index of the instance
     * @param feature_vector dense feature values, indexed like the columns of {@code etas}
     */
    public void trainOneInstance(HashSet<String> dSet, int y, double[] feature_vector) throws Exception {
        this.stepOverFeatureSubset(dSet, y, feature_vector);
    }

    /** Same as {@link #trainOneInstance(HashSet, int, double[])}, taking label/features from {@code data}. */
    public void trainOneInstance(HashSet<String> dSet, LRSGD.Data data) throws Exception {
        this.stepOverFeatureSubset(dSet, data.y, data.feature_vector);
    }

    /**
     * Applies one gradient step to feature column {@code dIndex} for every non-reference class.
     *
     * @param dIndex 0-based feature column to update
     * @param data instance providing the class index and dense feature values
     */
    public void trainOneInstance(int dIndex, LRSGD.Data data) throws Exception {
        this.stepOverSingleFeature(dIndex, data.y, data.feature_vector);
    }

    /**
     * Updates {@code etas[c][d]} for every non-reference class c and every feature d in dSet.
     *
     * <p>After each single-weight change, class c's term in the softmax denominator is swapped
     * for its new value. Each later step therefore sees probabilities consistent with all
     * earlier steps.
     */
    private void stepOverFeatureSubset(HashSet<String> dSet, int y, double[] featureVector) {
        int numClasses = this.num_label - 1;
        double[][] weights = this.etas;
        double[] scores = scoreNonReferenceClasses(weights, featureVector, numClasses, this.dim);
        double denominator = referenceSoftmaxDenominator(scores);
        for (int c = 0; c < numClasses; ++c) {
            double score = scores[c];
            double p = Math.exp(score) / denominator;
            double delta = c == y ? 1.0 : 0.0;
            for (String d : dSet) {
                int di = this.getFeatureIndex(d);
                double x = featureVector[di];
                if (x == 0.0) {
                    continue; // zero feature: gradient and score change are both zero
                }
                double gradient = x * (delta - p);
                // Take class c's old term out of the denominator and its old weight out of
                // the score, update the weight, then put both back.
                denominator -= Math.exp(score);
                score -= weights[c][di] * x;
                weights[c][di] += this.epsilon * gradient;
                score += weights[c][di] * x;
                double numerator = Math.exp(score);
                denominator += numerator;
                p = numerator / denominator;
            }
        }
    }

    /**
     * Updates column {@code dIndex} of {@code etas} for every non-reference class.
     *
     * <p>NOTE(review): unlike {@link #stepOverFeatureSubset}, every class's probability uses
     * the denominator computed before any update in this call — confirm this is intended.
     */
    private void stepOverSingleFeature(int dIndex, int y, double[] featureVector) {
        double x = featureVector[dIndex];
        if (x == 0.0) {
            return; // every class's gradient at this coordinate is zero
        }
        int numClasses = this.num_label - 1;
        double[][] weights = this.etas;
        double[] scores = scoreNonReferenceClasses(weights, featureVector, numClasses, this.dim);
        double denominator = referenceSoftmaxDenominator(scores);
        for (int c = 0; c < numClasses; ++c) {
            double p = Math.exp(scores[c]) / denominator;
            double delta = c == y ? 1.0 : 0.0;
            double gradient = x * (delta - p);
            weights[c][dIndex] += this.epsilon * gradient;
        }
    }

    /** Returns the dot product of {@code weights[c][0..dim)} with the feature vector for each c < numClasses. */
    private static double[] scoreNonReferenceClasses(double[][] weights, double[] featureVector,
                                                     int numClasses, int dim) {
        double[] scores = new double[numClasses];
        for (int c = 0; c < numClasses; ++c) {
            double score = 0.0;
            for (int d = 0; d < dim; ++d) {
                score += weights[c][d] * featureVector[d];
            }
            scores[c] = score;
        }
        return scores;
    }

    /**
     * Returns {@code 1 + sum(exp(score))}, where the leading 1 is exp(0) for the reference class.
     *
     * <p>NOTE(review): no max-subtraction is applied, so very large scores overflow to
     * Infinity.
     */
    private static double referenceSoftmaxDenominator(double[] scores) {
        double denominator = 1.0;
        for (double score : scores) {
            denominator += Math.exp(score);
        }
        return denominator;
    }

    /** Creates an empty model with zero labels/dimensions and empty index maps. */
    public LRCD() {
        this.dim = 0;
        this.num_label = 0;
        this.label_index = new HashMap();
        this.feature_index = new HashMap();
        this.data_set = new HashSet();
    }

    /**
     * Creates a model configured from parsed command-line options.
     *
     * <p>The epsilon, start time, process id, thread id, chance and interval options are
     * optional. When absent they default to 0.1, the current time, 0, 0, 1.0 and 1
     * respectively.
     *
     * @param cli parsed options; the iterations option is required (parsed unconditionally)
     */
    public LRCD(CommandLine cli) {
        this();
        this.expt = cli.getOptionValue(opt_expt);
        this.input_file = cli.getOptionValue(opt_input_file);
        this.test_file = cli.getOptionValue(opt_test_file);
        this.iterations = Integer.parseInt(cli.getOptionValue(opt_iterations));
        this.epsilon = cli.hasOption(opt_epsilon) ? Double.parseDouble(cli.getOptionValue(opt_epsilon)) : 0.1;
        this.start_time = cli.hasOption(opt_start_time) ? Long.parseLong(cli.getOptionValue(opt_start_time)) : System.currentTimeMillis();
        this.process_id = cli.hasOption(opt_process_id) ? Integer.parseInt(cli.getOptionValue(opt_process_id)) : 0;
        this.thread_id = cli.hasOption(opt_thread_id) ? Integer.parseInt(cli.getOptionValue(opt_thread_id)) : 0;
        this.chance = cli.hasOption(opt_chance) ? Double.parseDouble(cli.getOptionValue(opt_chance)) : 1.0;
        this.interval = cli.hasOption(opt_interval) ? Integer.parseInt(cli.getOptionValue(opt_interval)) : 1;
    }

    /**
     * Command-line entry point: reads data, initializes, prints the configuration and trains.
     *
     * <p>Uses the built-in development arguments only when no arguments are supplied.
     * Previously, real arguments were always discarded.
     */
    public static void main(String[] args) throws Exception {
        if (args.length == 0) {
            args = "--x cd --f data/training_D_100_N_10000.txt --i 100".split(" ");
        }
        PosixParser parser = new PosixParser();
        LRCD.InitOptions();
        CommandLine cli = parser.parse(options, args);
        LRCD log_reg = new LRCD(cli);
        log_reg.readData();
        log_reg.init();
        System.out.println("Dim: " + log_reg.dim);
        System.out.println("Num labels: " + log_reg.num_label);
        System.out.println("Iterations: " + log_reg.iterations);
        System.out.println(String.format("Epsilon: %f", log_reg.epsilon));
        System.out.println();
        log_reg.train();
    }
}

