/*
 * Decompiled with CFR 0.152.
 */
package kr.ac.kaist.itcknow.bigml.algo.multicore.psgd;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.concurrent.CountDownLatch;
import kr.ac.kaist.itcknow.bigml.algo.multicore.sgd.LRSGD;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.PosixParser;

/**
 * Parallel SGD driver for logistic regression.
 *
 * <p>Launches {@code num_processes} child JVMs (one per {@link Slave}), each running a
 * single-process SGD trainer, waits for them via {@link #doneSignal}, and can average
 * the per-process parameter files they produce with {@link #mergeParams(String, int)}.
 */
public class LRTrainPSGD extends LRSGD {
    public Slave[] slaves;
    public Thread[] threads;
    public int num_processes;
    /** Counted down once by each slave when its child process finishes (or fails). */
    public CountDownLatch doneSignal;
    public static String opt_num_procs = "num_procs";

    /**
     * Registers the base SGD options plus {@code -np/--num_procs}.
     *
     * @return the shared option set inherited from {@code LRSGD}
     */
    public static Options InitOptions() {
        LRSGD.InitOptions();
        options.addOption("np", opt_num_procs, true, "Number of Processes");
        return options;
    }

    /**
     * Builds the driver and one (not yet started) worker thread per process.
     *
     * @param cli parsed command line; must contain {@code --num_procs} with a value >= 1
     * @throws IllegalArgumentException if {@code --num_procs} is missing or less than 1
     * @throws NumberFormatException if {@code --num_procs} is not an integer
     */
    public LRTrainPSGD(CommandLine cli) {
        super(cli);
        String numProcsValue = cli.getOptionValue(opt_num_procs);
        if (numProcsValue == null) {
            throw new IllegalArgumentException("Missing required option --" + opt_num_procs);
        }
        this.num_processes = Integer.parseInt(numProcsValue.trim());
        // Zero processes would make mergeParams divide by zero; negative ones break the latch.
        if (this.num_processes < 1) {
            throw new IllegalArgumentException(
                    "--" + opt_num_procs + " must be >= 1, got " + this.num_processes);
        }
        this.slaves = new Slave[this.num_processes];
        this.threads = new Thread[this.num_processes];
        this.doneSignal = new CountDownLatch(this.num_processes);
        for (int p = 0; p < this.num_processes; ++p) {
            this.slaves[p] = new Slave(this, p);
            this.threads[p] = new Thread(this.slaves[p]);
        }
    }

    /** Hook for parallel initialisation; currently does nothing. */
    public void initParallel() throws Exception {
    }

    /** Starts every worker thread created by the constructor. */
    public void startParallel() {
        for (int p = 0; p < this.num_processes; ++p) {
            this.threads[p].start();
        }
    }

    /**
     * Replaces {@code etas} with the element-wise average of the per-process parameter files
     * {@code params/<expt>_<sub_method>_sgd_p<p>_<iteration>.params}.
     *
     * <p>Each file is read as: a first line holding the number of label lines, a second line
     * that is skipped, then that many lines of {@code label\tfeature:value\tfeature:value...}.
     * Lines with no {@code feature:value} entries are ignored.
     *
     * @param sub_method  sub-method tag embedded in the file name
     * @param iteration   iteration number embedded in the file name
     * @throws IOException if a file is missing, empty, or has fewer label lines than declared
     */
    public void mergeParams(String sub_method, int iteration) throws Exception {
        // Java zero-initialises new arrays, so no explicit clearing pass is needed.
        this.etas = new double[this.num_label - 1][this.dim];
        for (int p = 0; p < this.num_processes; ++p) {
            String paramFile = String.format("params/%s_%s_sgd_p%d_%d.params",
                    this.expt, sub_method, p, iteration);
            try (BufferedReader br = new BufferedReader(new FileReader(paramFile))) {
                String header = br.readLine();
                if (header == null) {
                    throw new IOException("Empty parameter file: " + paramFile);
                }
                int numLabelP = Integer.parseInt(header.trim());
                br.readLine(); // second header line is not used here
                for (int i = 0; i < numLabelP; ++i) {
                    String line = br.readLine();
                    if (line == null) {
                        throw new IOException("Parameter file " + paramFile + " declares "
                                + numLabelP + " label lines but ends after " + i);
                    }
                    String[] tokens = line.split("\t");
                    if (tokens.length <= 1) {
                        continue;
                    }
                    double[] row = this.etas[this.getIndex(this.label_index, tokens[0])];
                    for (int j = 1; j < tokens.length; ++j) {
                        String[] kv = tokens[j].split(":");
                        row[this.getIndex(this.feature_index, kv[0])] += Double.parseDouble(kv[1]);
                    }
                }
            }
        }
        for (int c = 0; c < this.num_label - 1; ++c) {
            double[] row = this.etas[c];
            for (int d = 0; d < this.dim; ++d) {
                row[d] /= (double) this.num_processes;
            }
        }
    }

    /**
     * Entry point: parses options, launches all workers and blocks until they finish.
     * When no arguments are supplied, a built-in sample configuration is used.
     */
    public static void main(String[] args) throws Exception {
        if (args.length == 0) {
            // Fallback for quick local runs; real command-line arguments take precedence.
            args = "--x testPSGD --f data/training_D_100_N_10000.txt --i 100 --np 3".split(" ");
        }
        PosixParser parser = new PosixParser();
        LRTrainPSGD.InitOptions();
        CommandLine cli = parser.parse(options, args);
        LRTrainPSGD log_reg = new LRTrainPSGD(cli);
        log_reg.startParallel();
        log_reg.doneSignal.await();
    }

    /**
     * Worker that runs one child-JVM SGD trainer, copies its combined stdout/stderr to
     * {@code ./log/<expt>_psgd_<process_id>.log}, and always counts down the master's latch.
     */
    public static class Slave implements Runnable {
        public int process_id;
        public LRTrainPSGD master;

        public Slave(LRTrainPSGD master, int process_id) {
            this.master = master;
            this.process_id = process_id;
        }

        /** Builds the child JVM command line from the master's configuration. */
        private String[] buildCommand() {
            // File.pathSeparator is ';' on Windows and ':' elsewhere; a hard-coded ';' breaks Unix.
            String classpath = "target/PSCD-0.0.1-SNAPSHOT.jar" + File.pathSeparator
                    + "./lib/commons-cli-1.2.jar";
            // Numbers are concatenated rather than formatted with %f/%d: %f truncates to six
            // decimals (e.g. 1e-7 -> "0.000000") and both are locale-sensitive.
            // NOTE(review): the main class lives in package "...bigml.sgd" while this file's
            // superclass is in "...bigml.algo.multicore.sgd" -- confirm the name is current.
            return new String[] {
                "java", "-cp", classpath, "-Xmx1g", "kr.ac.kaist.itcknow.bigml.sgd.LRTrainSGD",
                String.format("--experiment_name=%s_psgd", this.master.expt),
                String.format("--input_file=%s", this.master.input_file),
                "--iterations=" + this.master.iterations,
                "--epsilon=" + this.master.epsilon,
                "--interval=" + this.master.interval,
                "--process_id=" + this.process_id,
                "--thread_id=" + 0,
                "--start_time=" + this.master.start_time,
                "--chance=" + (1.0 / (double) this.master.num_processes)
            };
        }

        @Override
        public void run() {
            try {
                ProcessBuilder pb = new ProcessBuilder(buildCommand());
                pb.directory(new File(System.getProperty("user.dir")));
                System.out.println(pb.directory());
                System.out.println(pb.command());
                // Merge stderr into stdout so a single reader drains all child output;
                // an undrained pipe can fill up and block the child process.
                pb.redirectErrorStream(true);
                File logFile = new File(String.format("./log/%s_psgd_%d.log",
                        this.master.expt, this.process_id));
                File logDir = logFile.getParentFile();
                if (logDir != null) {
                    logDir.mkdirs(); // result ignored: false also when the directory already exists
                }
                // Open the log before starting the child so a log failure does not orphan it.
                try (BufferedWriter bw = new BufferedWriter(new FileWriter(logFile))) {
                    Process proc = pb.start();
                    InputStream stdout = proc.getInputStream();
                    try (BufferedReader br = new BufferedReader(new InputStreamReader(stdout))) {
                        String line;
                        while ((line = br.readLine()) != null) {
                            bw.write(line);
                            bw.newLine();
                        }
                    }
                    proc.waitFor();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                e.printStackTrace();
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                // Always release the master, even on Errors, so main() cannot hang forever.
                this.master.doneSignal.countDown();
            }
        }
    }
}

