/*
 * Decompiled with CFR 0.152.
 */
package sals.mr;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.FileUtil;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapreduce.Reducer;
import sals.mr.CommonReducer;
import sals.mr.ElementWritable;
import sals.mr.TripleWritable;
import sals.single.ArrayMethods;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Reducer that spools the tensor entries it receives to local scratch files in
 * {@code reduce} and then, in {@code cleanup}, repeatedly updates one factor
 * column at a time for every mode, exchanging each machine's block of the
 * column with the other machines through a Hadoop {@code FileSystem}.
 *
 * NOTE(review): "CDTF" presumably stands for coordinate-descent tensor
 * factorization; the name is not explained anywhere in this file.
 */
public class CDTFReducer
extends CommonReducer {
    /** Whether bias terms are trained; read from the "use_bias" configuration key in setup(). */
    private boolean useBias = false;
    /** Per-mode snapshot of the column taken before the inner iterations of the current column start. */
    private float[][] oldCols;
    /** Per-mode values of the column currently being updated; each row has modeLengths[mode] entries. */
    private float[][] curCols;
    /** Value of the "average" configuration key (only read when useBias is set); subtracted from every value stored by reduce(). */
    private float mu = 0.0f;

    /**
     * Runs the parent setup, allocates the per-mode column buffers and reads
     * the bias-related configuration.
     *
     * @param context reducer context, forwarded unchanged to the parent setup
     */
    @Override
    public void setup(Reducer.Context context) {
        super.setup(context);

        // One buffer per mode, each as long as that mode.
        this.oldCols = new float[this.N][];
        this.curCols = new float[this.N][];
        for (int m = 0; m < this.N; m++) {
            this.oldCols[m] = new float[this.modeLengths[m]];
            this.curCols[m] = new float[this.modeLengths[m]];
        }

        this.useBias = this.conf.getBoolean("use_bias", false);
        if (!this.useBias) {
            // The average is only looked up when bias terms are enabled.
            return;
        }
        this.mu = this.conf.getFloat("average", 0.0f);
    }

    /**
     * Appends the entries delivered for one key to this reducer's local scratch files.
     *
     * <p>On the first call (while {@code machineId} is still negative) the machine id is
     * taken from {@code key.left}, the scratch directory {@code <user.home>/CDTF_<machineId>}
     * is recreated, the output streams are opened and the index block owned by this
     * machine is computed for every mode.
     *
     * <p>Every value is then routed by kind: training entries ({@code isTraining}),
     * query entries (not training, value is NaN) and test entries (everything else).
     * The configured average {@code mu} is subtracted from stored values; queries are
     * stored as {@code -mu}.
     *
     * @param key     {@code left} is the machine id, {@code mid} the mode whose files receive the values
     * @param values  entries to spool
     * @param context reducer context (not used here)
     * @throws FileNotFoundException if a scratch file cannot be opened
     * @throws IOException if a scratch directory cannot be created or a write fails
     */
    public void reduce(TripleWritable key, Iterable<ElementWritable> values, Reducer.Context context) throws FileNotFoundException, IOException {
        if (this.machineId < 0) {
            this.machineId = key.left;
            String userHome = System.getProperty("user.home");
            this.baseLocalPath = userHome + "/CDTF_" + this.machineId;
            this.tempLocalFile = this.baseLocalPath + "/TEMP";

            // Start from an empty scratch directory; a failed delete stays best-effort.
            File baseDir = new File(this.baseLocalPath);
            if (baseDir.exists()) {
                try {
                    FileUtil.fullyDelete(baseDir);
                }
                catch (IOException e1) {
                    e1.printStackTrace();
                }
            }
            // Fail here with a clear message instead of a later FileNotFoundException,
            // but keep going if the directory is still there from a failed delete.
            if (!baseDir.mkdir() && !baseDir.isDirectory()) {
                throw new IOException("Could not create local scratch directory " + baseDir);
            }
            for (int dim = 0; dim < this.N; dim++) {
                File paramDir = new File(this.getLocalParamPath(dim));
                if (paramDir.exists()) {
                    paramDir.delete();
                }
                if (!paramDir.mkdir() && !paramDir.isDirectory()) {
                    throw new IOException("Could not create local parameter directory " + paramDir);
                }
            }

            // Training entries get one index/value stream pair per mode.
            for (int mode = 0; mode < this.N; mode++) {
                this.outIndexR[mode] = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalRPath(mode, true, 0, false))));
                this.outValueR[mode] = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalRPath(mode, false, 0, false))));
            }
            // Test (type 1) and query (type 2) entries only get streams for mode 0.
            // NOTE(review): the loop below writes them to slot key.mid, so this assumes
            // test/query entries always arrive with key.mid == 0 - confirm against the mapper.
            this.outIndexRTest[0] = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalRPath(0, true, 1, false))));
            this.outValueRTest[0] = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalRPath(0, false, 1, false))));
            this.outIndexRQuery[0] = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalRPath(0, true, 2, false))));
            this.outValueRQuery[0] = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalRPath(0, false, 2, false))));

            // Block of indices [startIndex, endIndex) owned by this machine in every mode.
            for (int mode = 0; mode < this.N; mode++) {
                this.startIndex[mode] = this.getStartIndex(mode, this.machineId);
                this.endIndex[mode] = this.getStartIndex(mode, this.machineId + 1);
                // NOTE(review): communicate() indexes a boolean[M] with machineId, so
                // machineId == M looks unreachable; confirm whether M - 1 was intended.
                if (this.machineId == this.M) {
                    this.endIndex[mode] = this.endIndex[mode] + 1;
                }
                this.nnzFiber[mode] = new int[this.endIndex[mode] - this.startIndex[mode]];
            }
        }

        int fileMode = key.mid;
        for (ElementWritable value : values) {
            int[] index = value.index;
            if (value.isTraining) {
                this.nnzTraining[fileMode] = this.nnzTraining[fileMode] + 1;
                for (int dim = 0; dim < this.N; dim++) {
                    this.outIndexR[fileMode].writeInt(index[dim]);
                }
                this.outValueR[fileMode].writeFloat(value.value - this.mu);
                // Count training entries per index of this mode, relative to the local block.
                int[] fiberCounts = this.nnzFiber[fileMode];
                int localRow = index[fileMode] - this.startIndex[fileMode];
                fiberCounts[localRow] = fiberCounts[localRow] + 1;
            } else if (Float.isNaN(value.value)) {
                // A NaN value marks an entry to be estimated; only the negated average is stored.
                this.nnzQuery[fileMode] = this.nnzQuery[fileMode] + 1;
                for (int dim = 0; dim < this.N; dim++) {
                    this.outIndexRQuery[fileMode].writeInt(index[dim]);
                }
                this.outValueRQuery[fileMode].writeFloat(-this.mu);
            } else {
                this.nnzTest[fileMode] = this.nnzTest[fileMode] + 1;
                for (int dim = 0; dim < this.N; dim++) {
                    this.outIndexRTest[fileMode].writeInt(index[dim]);
                }
                this.outValueRTest[fileMode].writeFloat(value.value - this.mu);
            }
        }
    }

    /**
     * Runs the whole training once all entries have been spooled by {@code reduce}.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>close the spool streams;</li>
     *   <li>create the initial parameter matrices (and bias terms when enabled);</li>
     *   <li>for each of {@code Tout} outer iterations and each of {@code K} columns: load the
     *       column, run {@code Tin} inner iterations of update + exchange per mode, refresh the
     *       stored values via {@code updateR} and write the column back;</li>
     *   <li>when bias is enabled, update, exchange and write the bias of every mode;</li>
     *   <li>write estimates (if there are query entries), performance figures and, on
     *       machine 0, the final factor matrices and biases;</li>
     *   <li>delete the local scratch directory, also when a step above fails.</li>
     * </ol>
     *
     * <p>Each row of the performance table is
     * {@code {outIter, elapsedMillis, trainErrorSum, nnzTrain, testErrorSum, nnzTest}}.
     *
     * @param context reducer context used for progress reporting and counters
     * @throws IOException if local or distributed file access fails
     * @throws InterruptedException declared for compatibility with the overridden signature
     */
    public void cleanup(Reducer.Context context) throws IOException, InterruptedException {
        for (int i = 0; i < this.outIndexR.length; i++) {
            this.outIndexR[i].close();
            this.outValueR[i].close();
        }
        this.outIndexRTest[0].close();
        this.outValueRTest[0].close();
        this.outIndexRQuery[0].close();
        this.outValueRQuery[0].close();

        try {
            Random rand = new Random(this.conf.getInt("seed", 0));
            for (int n = 0; n < this.N; n++) {
                context.progress();
                this.createParamMatrix(n, this.K, n != 0 ? 1 : 0, rand, context);
            }
            if (this.useBias) {
                for (int n = 0; n < this.N; n++) {
                    context.progress();
                    this.createBiasTerms(n, context);
                }
            }

            double[][] result = new double[this.Tout][6];
            long startTime = System.currentTimeMillis();
            for (int outIter = 0; outIter < this.Tout; outIter++) {
                FileSystem fs = FileSystem.get((Configuration)this.conf);

                for (int k = 0; k < this.K; k++) {
                    System.out.printf("Iter : %d, column : %d\n", outIter, k);

                    // Load column k of every mode and remember it as the "old" column.
                    for (int n = 0; n < this.N; n++) {
                        long time = System.currentTimeMillis();
                        context.progress();
                        this.loadFromLocal(n, k);
                        context.getCounter("Speed", "Initialize").increment(System.currentTimeMillis() - time);
                    }
                    this.oldCols = ArrayMethods.copy(this.curCols);

                    for (int innerIter = 0; innerIter < this.Tin; innerIter++) {
                        for (int n = 0; n < this.N; n++) {
                            context.progress();
                            long time = System.currentTimeMillis();
                            this.updateFactors(n);
                            context.getCounter("Speed", "Optimize").increment(System.currentTimeMillis() - time);
                            time = System.currentTimeMillis();
                            this.communicate(outIter, k, innerIter, n, context, fs);
                            context.getCounter("Speed", "Broadcast").increment(System.currentTimeMillis() - time);
                        }
                    }

                    // Errors are only measured on the last column and only when no bias step follows.
                    boolean measure = !this.useBias && k == this.K - 1;
                    double trainErrorSum = 0.0;
                    double testErrorSum = 0.0;
                    int nnzSum = 0;
                    int nnzTestSum = 0;
                    for (int n = 0; n < this.N; n++) {
                        context.progress();
                        long time = System.currentTimeMillis();
                        if (n == 0) {
                            // Test and query files exist for mode 0 only.
                            trainErrorSum += this.updateR(n, 0, measure);
                            testErrorSum += this.updateR(n, 1, measure);
                            this.updateR(n, 2, false);
                            nnzSum += this.nnzTraining[n];
                            nnzTestSum += this.nnzTest[n];
                        } else {
                            this.updateR(n, 0, false);
                        }
                        context.getCounter("Speed", "Update R").increment(System.currentTimeMillis() - time);
                        time = System.currentTimeMillis();
                        this.writeFactors(n, k);
                        context.getCounter("Speed", "Update Param").increment(System.currentTimeMillis() - time);
                    }
                    if (measure) {
                        System.out.println(Math.sqrt(trainErrorSum / (double)nnzSum));
                        System.out.println(Math.sqrt(testErrorSum / (double)nnzTestSum));
                        result[outIter] = new double[]{outIter, System.currentTimeMillis() - startTime, trainErrorSum, nnzSum, testErrorSum, nnzTestSum};
                    }
                }

                if (this.useBias) {
                    for (int n = 0; n < this.N; n++) {
                        long time = System.currentTimeMillis();
                        context.progress();
                        this.loadBiasFromLocal(n);
                        context.getCounter("Speed", "Initialize").increment(System.currentTimeMillis() - time);
                        time = System.currentTimeMillis();
                        this.oldBias = (float[])this.curBias.clone();
                        context.progress();
                        this.updateBias(n);
                        context.getCounter("Speed", "Optimize").increment(System.currentTimeMillis() - time);
                        time = System.currentTimeMillis();
                        this.communicateBias(outIter, n, context, fs);
                        context.getCounter("Speed", "Broadcast").increment(System.currentTimeMillis() - time);

                        // Errors are measured after the bias of the last mode has been updated.
                        boolean lastMode = n == this.N - 1;
                        // Kept in double, like the factor branch above, so the sums are not narrowed to float.
                        double trainErrorSum = 0.0;
                        double testErrorSum = 0.0;
                        int nnzSum = 0;
                        int nnzTestSum = 0;
                        for (int nr = 0; nr < this.N; nr++) {
                            context.progress();
                            time = System.currentTimeMillis();
                            if (nr == 0) {
                                trainErrorSum += this.updateRWithBias(nr, n, 0, lastMode);
                                testErrorSum += this.updateRWithBias(nr, n, 1, lastMode);
                                this.updateRWithBias(nr, n, 2, false);
                                nnzSum += this.nnzTraining[nr];
                                nnzTestSum += this.nnzTest[nr];
                            } else {
                                this.updateRWithBias(nr, n, 0, false);
                            }
                            context.getCounter("Speed", "Update R").increment(System.currentTimeMillis() - time);
                            time = System.currentTimeMillis();
                        }
                        this.writeBiasParams(n);
                        context.getCounter("Speed", "Update Param").increment(System.currentTimeMillis() - time);
                        if (lastMode) {
                            System.out.println(Math.sqrt(trainErrorSum / (double)nnzSum));
                            System.out.println(Math.sqrt(testErrorSum / (double)nnzTestSum));
                            result[outIter] = new double[]{outIter, System.currentTimeMillis() - startTime, trainErrorSum, nnzSum, testErrorSum, nnzTestSum};
                        }
                    }
                }

                if (this.machineId == 0) {
                    context.getCounter("Time ", "" + outIter).increment(System.currentTimeMillis() - startTime);
                }
                // NOTE(review): FileSystem.get may hand out a shared cached instance, so closing
                // it every iteration can affect other users in this JVM - confirm this is intended.
                try {
                    fs.close();
                }
                catch (Exception ignored) {
                    // Best-effort close; a failure here must not abort the training.
                }
            }

            // Drop the large buffers before the result-writing steps.
            this.nnzFiber = null;
            this.curCols = null;
            this.oldCols = null;
            if (this.nnzQuery[0] > 0) {
                this.writeEstimate(context);
            }
            this.writePerformance(context, result);
            if (this.machineId == 0) {
                this.writeFactormatricesResult(context);
                if (this.useBias) {
                    this.writeBiasesResults(context);
                }
            }
        } finally {
            // Remove the scratch directory even when training failed, so a failed
            // attempt does not leave its spool files behind on the local disk.
            File file = new File(this.baseLocalPath);
            if (file.exists()) {
                try {
                    FileUtil.fullyDelete(file);
                }
                catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * Recomputes this machine's block of the current column for mode {@code n}
     * from the training entries spooled for that mode.
     *
     * <p>For every entry, with {@code p} the product of the current column values of all
     * modes except {@code n} and {@code old} the product of the old column values of all
     * modes, the row addressed by the entry accumulates {@code p * (r + old)} into its
     * numerator and {@code p * p} into its denominator. Rows with a non-zero denominator
     * are then set to {@code numerator / (denominator + lambda * w)}, where {@code w} is
     * the row's training-entry count when {@code useWeight} is set and 1 otherwise;
     * results strictly inside {@code (-epsilon, epsilon)} are stored as 0.
     *
     * @param n mode whose column block is updated
     * @throws FileNotFoundException if a spool file is missing
     * @throws IOException if reading a spool file fails
     */
    private void updateFactors(int n) throws FileNotFoundException, IOException {
        final int blockStart = this.startIndex[n];
        final int blockLength = this.endIndex[n] - blockStart;
        float[] numerators = new float[blockLength];
        float[] denominators = new float[blockLength];

        // Both streams are closed in finally blocks so a failed read does not leak file handles.
        ObjectInputStream inIndex = new ObjectInputStream(new BufferedInputStream(new FileInputStream(this.getLocalRPath(n, true, 0, false))));
        try {
            ObjectInputStream inValue = new ObjectInputStream(new BufferedInputStream(new FileInputStream(this.getLocalRPath(n, false, 0, false))));
            try {
                // Reused for every entry; all N slots are overwritten before being read.
                int[] index = new int[this.N];
                for (int elem = 0; elem < this.nnzTraining[n]; elem++) {
                    for (int mode = 0; mode < this.N; mode++) {
                        index[mode] = inIndex.readInt();
                    }
                    float r = inValue.readFloat();

                    float oldProduct = 1.0f;
                    float othersProduct = 1.0f;
                    for (int dim = 0; dim < this.N; dim++) {
                        oldProduct *= this.oldCols[dim][index[dim]];
                        if (dim != n) {
                            othersProduct *= this.curCols[dim][index[dim]];
                        }
                    }

                    int row = index[n] - blockStart;
                    numerators[row] += othersProduct * (r + oldProduct);
                    denominators[row] += othersProduct * othersProduct;
                }
            } finally {
                inValue.close();
            }
        } finally {
            inIndex.close();
        }

        for (int i = 0; i < blockLength; i++) {
            // Rows without any contribution keep their previous value.
            if (denominators[i] != 0.0f) {
                denominators[i] += this.lambda * (float)(this.useWeight ? this.nnzFiber[n][i] : 1);
                float result = numerators[i] / denominators[i];
                if (result > -this.epsilon && result < this.epsilon) {
                    result = 0.0f;
                }
                this.curCols[n][i + blockStart] = result;
            }
        }
    }

    /**
     * Rewrites the stored values of one spool file after a column update.
     *
     * <p>Each stored value {@code r} becomes {@code r + old - cur}, where {@code old} and
     * {@code cur} are the products over all modes of the old and current column values at
     * the entry's indices. The new values are written to a second file, which then takes
     * the place of the original through {@code replace}.
     *
     * @param n           mode whose spool file is rewritten
     * @param type        0 for training entries, 1 for test entries, anything else for query entries
     * @param measureCost when true, the squares of the rewritten values are summed
     * @return the sum of squared rewritten values, or 0 when {@code measureCost} is false
     * @throws FileNotFoundException if a spool file is missing
     * @throws IOException if reading or writing a spool file fails
     */
    private double updateR(int n, int type, boolean measureCost) throws FileNotFoundException, IOException {
        double errorSum = 0.0;
        final int nnz = type == 0 ? this.nnzTraining[n] : (type == 1 ? this.nnzTest[n] : this.nnzQuery[n]);

        // Streams are closed in finally blocks so a failure part-way through does not leak handles.
        ObjectInputStream inIndex = new ObjectInputStream(new BufferedInputStream(new FileInputStream(this.getLocalRPath(n, true, type, false))));
        try {
            ObjectInputStream inValue = new ObjectInputStream(new BufferedInputStream(new FileInputStream(this.getLocalRPath(n, false, type, false))));
            try {
                ObjectOutputStream outValue = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalRPath(n, false, type, true))));
                try {
                    // Reused for every entry; all N slots are overwritten before being read.
                    int[] index = new int[this.N];
                    for (int elem = 0; elem < nnz; elem++) {
                        for (int mode = 0; mode < this.N; mode++) {
                            index[mode] = inIndex.readInt();
                        }
                        float r = inValue.readFloat();

                        float oldProduct = 1.0f;
                        float newProduct = 1.0f;
                        for (int mode = 0; mode < this.N; mode++) {
                            oldProduct *= this.oldCols[mode][index[mode]];
                            newProduct *= this.curCols[mode][index[mode]];
                        }

                        r = r + oldProduct - newProduct;
                        outValue.writeFloat(r);
                        if (measureCost) {
                            errorSum += (double)(r * r);
                        }
                    }
                } finally {
                    outValue.close();
                }
            } finally {
                inValue.close();
            }
        } finally {
            inIndex.close();
        }

        // Only reached when the rewrite completed and every stream has been closed.
        this.replace(this.getLocalRPath(n, false, type, false));
        return errorSum;
    }

    /**
     * Publishes this machine's block of the current column for mode {@code n} and
     * waits until the blocks of all other machines have been read into {@code curCols[n]}.
     *
     * <p>The block is written to its own file and announced with {@code markWrite}. The
     * method then repeatedly lists the marker directory, reads the block of every machine
     * listed there that has not been read yet, and returns once all {@code M} machines are
     * covered. A failed read is logged, counted under the "Error"/"err" counter and retried
     * on the next round. Between rounds it sleeps for a random time below {@code waiting}
     * milliseconds, minus the time the round already took.
     *
     * @param outIter outer iteration, used to build the file paths
     * @param k       column index, used to build the file paths
     * @param inIter  inner iteration, used to build the file paths
     * @param n       mode whose column is exchanged
     * @param context reducer context used for progress reporting and the error counter
     * @param fs      file system through which the blocks are exchanged
     * @throws IOException if publishing or listing fails, or if the thread is interrupted
     *                     while waiting (the interrupt flag is restored first)
     */
    private void communicate(int outIter, int k, int inIter, int n, Reducer.Context context, FileSystem fs) throws IOException {
        FSDataOutputStream out = fs.create(new Path(this.getHDFSParamPath(outIter, k, inIter, n, this.machineId, false)));
        try {
            for (int i = this.startIndex[n]; i < this.endIndex[n]; i++) {
                out.writeFloat(this.curCols[n][i]);
            }
        } finally {
            // Close even when a write fails so the stream is not leaked.
            out.close();
        }
        this.markWrite(outIter, k, inIter, n, fs);

        boolean[] markReadComplete = new boolean[this.M];
        markReadComplete[this.machineId] = true;
        while (true) {
            long requestTime = System.currentTimeMillis();
            FileStatus[] statusList = fs.listStatus(new Path(this.getHDFSParamPath(outIter, k, inIter, n, true)));
            CDTFReducer.shuffle(statusList);
            for (FileStatus status : statusList) {
                // Marker files are named after the machine that wrote them.
                int otherId = Integer.parseInt(status.getPath().getName());
                if (markReadComplete[otherId]) {
                    continue;
                }
                boolean readOk = false;
                FSDataInputStream in = null;
                try {
                    in = fs.open(new Path(this.getHDFSParamPath(outIter, k, inIter, n, otherId, false)));
                    for (int i = this.getStartIndex(n, otherId); i < this.getStartIndex(n, otherId + 1); i++) {
                        this.curCols[n][i] = in.readFloat();
                    }
                    readOk = true;
                }
                catch (Exception e) {
                    // Leave the machine unmarked so the block is read again next round.
                    System.out.println(e.getMessage());
                    context.getCounter("Error", "err").increment(1L);
                }
                finally {
                    if (in != null) {
                        try {
                            in.close();
                        }
                        catch (Exception ignored) {
                            // Best-effort close of a read-only stream.
                        }
                    }
                }
                if (readOk) {
                    markReadComplete[otherId] = true;
                }
            }

            boolean allRead = true;
            for (int id = 0; id < this.M; id++) {
                if (!markReadComplete[id]) {
                    allRead = false;
                    break;
                }
            }
            if (allRead) {
                break;
            }

            context.progress();
            long timeToWait = (long)((double)this.waiting * Math.random()) - (System.currentTimeMillis() - requestTime);
            if (timeToWait <= 0L) {
                continue;
            }
            try {
                Thread.sleep(timeToWait);
            }
            catch (InterruptedException e) {
                // Restore the flag and stop waiting instead of polling on after an interrupt.
                Thread.currentThread().interrupt();
                throw new IOException("Interrupted while waiting for parameter blocks from other machines", e);
            }
        }
    }

    /**
     * Writes all {@code modeLengths[n]} values of the current column for mode {@code n}
     * to the local parameter file of column {@code k}.
     *
     * @param n mode whose column is written
     * @param k column index selecting the parameter file
     * @throws FileNotFoundException if the parameter file cannot be opened
     * @throws IOException if writing fails
     */
    private void writeFactors(int n, int k) throws FileNotFoundException, IOException {
        ObjectOutputStream os = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalParamPath(n, k, false))));
        try {
            for (int row = 0; row < this.modeLengths[n]; row++) {
                os.writeFloat(this.curCols[n][row]);
            }
        } finally {
            // Close even when a write fails so the file handle is not leaked.
            os.close();
        }
    }

    /**
     * Reads {@code modeLengths[n]} values from the local parameter file of column
     * {@code k} for mode {@code n} into {@code curCols[n]}.
     *
     * @param n mode whose column is loaded
     * @param k column index selecting the parameter file
     * @throws FileNotFoundException if the parameter file is missing
     * @throws IOException if reading fails
     */
    private void loadFromLocal(int n, int k) throws FileNotFoundException, IOException {
        ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(new FileInputStream(this.getLocalParamPath(n, k, false))));
        try {
            for (int i = 0; i < this.modeLengths[n]; i++) {
                this.curCols[n][i] = in.readFloat();
            }
        } finally {
            // Close even when a read fails so the file handle is not leaked.
            in.close();
        }
    }
}

