/*
 * Decompiled with CFR 0.152.
 */
package kr.ac.kaist.itcknow.bigml.algo.multicore.psgd;

import java.util.HashMap;
import java.util.HashSet;
import java.util.concurrent.CountDownLatch;
import kr.ac.kaist.itcknow.bigml.algo.multicore.sgd.LRSGD;
import kr.ac.kaist.itcknow.bigml.util.MersenneTwisterFast;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.PosixParser;

public class LRTrainPSGD2
extends LRSGD {
    public Slave[] slaves;
    public Thread[] threads;
    public int num_threads;
    public int num_processes;
    public CountDownLatch doneSignal;
    public static String opt_num_procs = "num_procs";
    public static String opt_num_threads = "num_threads";

    public static Options InitOptions() {
        LRSGD.InitOptions();
        options.addOption("np", opt_num_procs, true, "Number of Processes");
        options.addOption("nt", opt_num_threads, true, "Number of Threads");
        return options;
    }

    public LRTrainPSGD2(CommandLine cli) {
        super(cli);
        this.num_threads = cli.hasOption(opt_num_threads) ? Integer.parseInt(cli.getOptionValue(opt_num_threads)) : 1;
        this.num_processes = cli.hasOption(opt_num_procs) ? Integer.parseInt(cli.getOptionValue(opt_num_procs)) : 1;
        this.threads = new Thread[this.num_threads];
        this.slaves = new Slave[this.num_threads];
        this.doneSignal = new CountDownLatch(this.num_threads);
        for (int p = 0; p < this.num_threads; ++p) {
            this.slaves[p] = new Slave(this, p);
            this.threads[p] = new Thread(this.slaves[p]);
        }
    }

    public static void train(HashSet<LRSGD.Data> data_set, int num_label, int dim, int iterations, double epsilon, double[][] etas, HashMap<String, Integer> label_index, HashMap<String, Integer> feature_index, HashSet<Integer> sub_features, int process_id, int thread_id, long start_time, String prefix, double chance, int interval) throws Exception {
        MersenneTwisterFast rand = null;
        if (chance < 1.0) {
            rand = new MersenneTwisterFast(System.currentTimeMillis());
        }
        System.out.println(String.format("Process:%d\tThread:%d\tIterations:%d/%d\tTime:%e", process_id, thread_id, 0, iterations, (double)(System.currentTimeMillis() - start_time)));
        long overhead_end = 0L;
        long overhead_start = 0L;
        long duration = 0L;
        long current = 0L;
        if (thread_id == 0) {
            overhead_start = System.currentTimeMillis();
            current = overhead_start - start_time;
            LRTrainPSGD2.printParams(prefix, 0, num_label, dim, etas, label_index);
            overhead_end = System.currentTimeMillis();
        }
        double[] label_cache = new double[num_label - 1];
        for (int iter = 1; iter <= iterations; ++iter) {
            for (LRSGD.Data data : data_set) {
                int c;
                if (chance < 1.0 && chance < rand.nextDouble()) continue;
                int y = data.y;
                double[] feature_vector = data.feature_vector;
                double denominator = 1.0;
                for (c = 0; c < num_label - 1; ++c) {
                    double denominator_exp = 0.0;
                    for (int d = 0; d < dim; ++d) {
                        denominator_exp += etas[c][d] * feature_vector[d];
                    }
                    label_cache[c] = denominator_exp;
                    denominator += Math.exp(denominator_exp);
                }
                for (c = 0; c < num_label - 1; ++c) {
                    double numerator_exp = label_cache[c];
                    double numerator = Math.exp(numerator_exp);
                    double p = numerator / denominator;
                    double delta = c == y ? 1.0 : 0.0;
                    for (int d : sub_features) {
                        double x = feature_vector[d];
                        if (x == 0.0) continue;
                        double gradient = x * (delta - p);
                        denominator -= Math.exp(numerator_exp);
                        numerator_exp -= etas[c][d] * x;
                        double[] dArray = etas[c];
                        int n = d;
                        dArray[n] = dArray[n] + epsilon * gradient;
                        label_cache[c] = numerator_exp += etas[c][d] * x;
                        numerator = Math.exp(numerator_exp);
                        p = numerator / (denominator += numerator);
                    }
                }
            }
            if (iter % interval != 0) continue;
            if (thread_id == 0 && iter != iterations) {
                overhead_start = System.currentTimeMillis();
                duration = overhead_start - overhead_end;
                current = duration + current;
                System.out.println(String.format("Process:%d\tThread:%d\tIterations:%d/%d\tTime:%e", process_id, thread_id, iter, iterations, (double)current));
                LRTrainPSGD2.printParams(prefix, iter, num_label, dim, etas, label_index);
                overhead_end = System.currentTimeMillis();
                continue;
            }
            System.out.println(String.format("Process:%d\tThread:%d\tIterations:%d/%d\tTime:%e", process_id, thread_id, iter, iterations, (double)(System.currentTimeMillis() - start_time)));
        }
    }

    public void startParallel() {
        for (int p = 1; p < this.num_threads; ++p) {
            this.threads[p].start();
        }
        this.threads[0].start();
    }

    public void initParallel() {
        int p = 0;
        for (int d = 0; d < this.dim; ++d) {
            this.slaves[p].addFeature(d);
            p = (p + 1) % this.num_threads;
        }
    }

    public static void main(String[] args) throws Exception {
        args = "--x testPSGD --f data/training_D_100_N_10000.txt --i 100 --np 2 --nt 2".split(" ");
        PosixParser parser = new PosixParser();
        LRTrainPSGD2.InitOptions();
        CommandLine cli = parser.parse(options, args);
        LRTrainPSGD2 log_reg = new LRTrainPSGD2(cli);
        log_reg.readData();
        log_reg.init();
        System.out.println("Dim: " + log_reg.dim);
        System.out.println("Num labels: " + log_reg.num_label);
        System.out.println("Iterations: " + log_reg.iterations);
        System.out.println(String.format("Epsilon: %f", log_reg.epsilon));
        System.out.println("Threads: " + log_reg.num_threads);
        System.out.println("Process ID:" + log_reg.process_id);
        System.out.println();
        log_reg.initParallel();
        log_reg.startParallel();
        log_reg.doneSignal.await();
        LRTrainPSGD2.printParams(String.format("%s_psgd2_sgd_p%d", log_reg.expt, log_reg.process_id), log_reg.iterations, log_reg.num_label, log_reg.dim, log_reg.etas, log_reg.label_index);
    }

    public static class Slave
    implements Runnable {
        public int thread_id;
        public LRTrainPSGD2 master;
        public HashSet<Integer> sub_features;

        public Slave(LRTrainPSGD2 master, int thread_id) {
            this.master = master;
            this.thread_id = thread_id;
            this.sub_features = new HashSet();
        }

        public void addFeature(int feature) {
            this.sub_features.add(feature);
        }

        @Override
        public void run() {
            try {
                LRTrainPSGD2.train(this.master.data_set, this.master.num_label, this.master.dim, this.master.iterations, this.master.epsilon, this.master.etas, this.master.label_index, this.master.feature_index, this.sub_features, this.master.process_id, this.thread_id, this.master.start_time, String.format("%s_psgd2_sgd_p%d", this.master.expt, this.master.process_id), 1.0 / (double)this.master.num_processes, this.master.interval);
                this.master.doneSignal.countDown();
            }
            catch (Exception e) {
                e.printStackTrace();
                this.master.doneSignal.countDown();
            }
        }
    }
}

