/*
 * Decompiled with CFR 0.152.
 */
package sals.mr;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.Closeable;
import java.io.DataOutput;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.FileUtil;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapreduce.Reducer;
import sals.mr.CommonReducer;
import sals.mr.ElementWritable;
import sals.mr.TripleWritable;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Reducer that fits a K-column factorisation of this task's partition of an N-mode sparse tensor and
 * cooperates with the other reducers through HDFS.
 *
 * <p>The model value of an entry with coordinates (i_0, ..., i_(N-1)) is
 * sum over columns k of prod over modes m of A_m[i_m][k] (see {@code updateR}). When {@code use_bias} is set,
 * the configured {@code average} is subtracted from every observed value and per-mode bias terms are fitted too.
 *
 * <p>Lifecycle:
 * <ol>
 *   <li>{@code reduce} spills the task's entries to local files under {@code ~/SALS<machineId>}. It separates
 *       training entries, test entries (non-training, with a value) and query entries (non-training, NaN
 *       value).</li>
 *   <li>{@code cleanup} initialises the factor matrices and runs {@code Tout} outer iterations. Each outer
 *       iteration visits the K columns in a random permutation, in groups of at most {@code C} columns. For each
 *       group it:
 *       <ul>
 *         <li>loads the group's columns;</li>
 *         <li>adds their contribution back into the stored residuals;</li>
 *         <li>runs {@code Tin} rounds of row-wise regularised least-squares updates, one mode at a time,
 *             exchanging updated rows with the other machines through HDFS;</li>
 *         <li>subtracts the new contribution and saves the columns.</li>
 *       </ul>
 *   </li>
 * </ol>
 */
public class SALSReducer extends CommonReducer {
    /** Whether per-mode bias terms are fitted (configuration key {@code use_bias}). */
    private boolean useBias = false;
    /**
     * Working copy of the current column group, indexed {@code [mode][row][slot]}. While a group is processed,
     * slot {@code j} holds column {@code permutedColumns[cStart + j]}.
     */
    private float[][][] curCols;
    /** Maximum number of columns updated together (configuration key {@code C}). */
    private int C;
    /** Offset subtracted from observed values when biases are used (configuration key {@code average}). */
    private float mu = 0.0f;

    /**
     * Reads the column-group size and the bias settings, and allocates {@code curCols}.
     *
     * @param context the task context, forwarded to {@code CommonReducer.setup}
     * @throws IllegalArgumentException if the configured {@code C} is not positive; such a value would make
     *         {@code cleanup} loop forever over empty column groups
     */
    @Override
    public void setup(Reducer.Context context) {
        super.setup(context);
        this.C = this.conf.getInt("C", 0);
        if (this.C <= 0) {
            throw new IllegalArgumentException("Configuration value C must be positive, got " + this.C);
        }
        this.curCols = new float[this.N][][];
        for (int mode = 0; mode < this.N; ++mode) {
            // One spare row beyond modeLengths[mode]. NOTE(review): its purpose is not visible here -- confirm.
            this.curCols[mode] = new float[this.modeLengths[mode] + 1][this.C];
        }
        this.useBias = this.conf.getBoolean("use_bias", false);
        if (this.useBias) {
            this.mu = this.conf.getFloat("average", 0.0f);
        }
    }

    /**
     * Spills the entries of one key to the local files of mode {@code key.mid}.
     *
     * <p>On the first call, the task's local storage is initialised using {@code key.left} as the machine id.
     * Each value is routed by kind:
     * <ul>
     *   <li>Training entries go to the mode's training files and are counted per row in {@code nnzFiber}.</li>
     *   <li>Non-training entries with a NaN value are queries, stored with value {@code -mu}.</li>
     *   <li>All other entries are test entries.</li>
     * </ul>
     * Stored training and test values are {@code value - mu}.
     *
     * <p>NOTE(review): only the mode-0 test and query streams are opened, so test and query entries are expected
     * to arrive only with {@code key.mid == 0}.
     *
     * @throws RuntimeException wrapping the {@link IOException} if an entry cannot be written. Continuing would
     *         leave the spill files out of step with the nnz counters.
     */
    public void reduce(TripleWritable key, Iterable<ElementWritable> values, Reducer.Context context) {
        if (this.machineId < 0) {
            this.machineId = key.left;
            this.initializeLocalStorage();
        }
        int fileMode = key.mid;
        try {
            for (ElementWritable value : values) {
                if (value.isTraining) {
                    this.nnzTraining[fileMode]++;
                    int[] index = value.index;
                    this.writeIndexTuple(this.outIndexR[fileMode], index);
                    this.outValueR[fileMode].writeFloat(value.value - this.mu);
                    this.nnzFiber[fileMode][index[fileMode] - this.startIndex[fileMode]]++;
                } else if (Float.isNaN(value.value)) {
                    this.nnzQuery[fileMode]++;
                    this.writeIndexTuple(this.outIndexRQuery[fileMode], value.index);
                    this.outValueRQuery[fileMode].writeFloat(-this.mu);
                } else {
                    this.nnzTest[fileMode]++;
                    this.writeIndexTuple(this.outIndexRTest[fileMode], value.index);
                    this.outValueRTest[fileMode].writeFloat(value.value - this.mu);
                }
            }
        } catch (IOException e) {
            throw new RuntimeException("Failed to spill entries of mode " + fileMode + " to local disk", e);
        }
    }

    /**
     * Prepares local storage once {@code machineId} is known. It recreates {@code ~/SALS<machineId>} and the
     * per-mode parameter directories, opens the spill streams, and computes the row range
     * [startIndex, endIndex) of every mode owned by this machine.
     */
    private void initializeLocalStorage() {
        String userHome = System.getProperty("user.home");
        this.baseLocalPath = userHome + "/SALS" + this.machineId;
        this.tempLocalFile = this.baseLocalPath + "/TEMP";
        File baseDir = new File(this.baseLocalPath);
        if (baseDir.exists()) {
            try {
                FileUtil.fullyDelete(baseDir);
            } catch (IOException e1) {
                e1.printStackTrace();
            }
        }
        baseDir.mkdir();
        for (int dim = 0; dim < this.N; ++dim) {
            File paramDir = new File(this.getLocalParamPath(dim));
            if (paramDir.exists()) {
                paramDir.delete();
            }
            paramDir.mkdir();
        }
        try {
            for (int mode = 0; mode < this.N; ++mode) {
                this.outIndexR[mode] = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalRPath(mode, true, 0, false))));
                this.outValueR[mode] = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalRPath(mode, false, 0, false))));
            }
            // Test (type 1) and query (type 2) entries are only spilled for mode 0.
            this.outIndexRTest[0] = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalRPath(0, true, 1, false))));
            this.outValueRTest[0] = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalRPath(0, false, 1, false))));
            this.outIndexRQuery[0] = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalRPath(0, true, 2, false))));
            this.outValueRQuery[0] = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalRPath(0, false, 2, false))));
        } catch (IOException e) {
            throw new RuntimeException("Cannot open local spill files under " + this.baseLocalPath, e);
        }
        for (int mode = 0; mode < this.N; ++mode) {
            this.startIndex[mode] = this.getStartIndex(mode, this.machineId);
            this.endIndex[mode] = this.getStartIndex(mode, this.machineId + 1);
            // NOTE(review): communicate() sizes its flags as boolean[M], which implies machineId < M, so this
            // branch looks unreachable -- confirm the intended id range.
            if (this.machineId == this.M) {
                this.endIndex[mode]++;
            }
            this.nnzFiber[mode] = new int[this.endIndex[mode] - this.startIndex[mode]];
        }
    }

    /** Writes the first {@code N} coordinates of {@code index} to {@code out}. */
    private void writeIndexTuple(DataOutput out, int[] index) throws IOException {
        for (int dim = 0; dim < this.N; ++dim) {
            out.writeInt(index[dim]);
        }
    }

    /**
     * Runs the whole optimisation after all entries have been spilled, then writes the results.
     *
     * <p>One row of the performance table is recorded per outer iteration:
     * {iteration, elapsed ms, training squared-error sum, training nnz, test squared-error sum, test nnz}.
     * When biases are not used, the row is recorded after the last column group. When biases are used, it is
     * recorded after the last bias mode.
     *
     * @param context the task context (progress, counters, output)
     */
    public void cleanup(Reducer.Context context) throws IOException, InterruptedException {
        this.closeSpillStreams();
        Random rand = new Random(this.conf.getInt("seed", 0));
        for (int dim = 0; dim < this.N; ++dim) {
            context.progress();
            this.createParamMatrix(dim, this.K, dim != 0 ? 1 : 0, rand, context);
        }
        if (this.useBias) {
            for (int n = 0; n < this.N; ++n) {
                context.progress();
                this.createBiasTerms(n, context);
            }
        }
        double[][] result = new double[this.Tout][6];
        long startTime = System.currentTimeMillis();
        // ceil(K / C): number of column groups per outer iteration (C > 0 is enforced in setup).
        int numGroups = (this.K + this.C - 1) / this.C;
        for (int outIter = 0; outIter < this.Tout; ++outIter) {
            FileSystem fs = FileSystem.get((Configuration)this.conf);
            int[] permutedColumns = SALSReducer.createRandomSequence(this.K, rand);
            for (int splitIter = 0; splitIter < numGroups; ++splitIter) {
                int cStart = this.C * splitIter;
                int cEnd = Math.min(this.C * (splitIter + 1) - 1, this.K - 1);
                // Only slots 0..cLength-1 hold this group's columns. When K is not a multiple of C, the last group
                // leaves the previous group's columns in the remaining slots, so every residual update below is
                // limited to cLength slots.
                int cLength = cEnd - cStart + 1;
                boolean measureCost = !this.useBias && cEnd == this.K - 1;
                System.out.printf("Iter : %d, column sequence : %d\n", outIter, splitIter);
                for (int n = 0; n < this.N; ++n) {
                    for (int column = cStart; column <= cEnd; ++column) {
                        long time = System.currentTimeMillis();
                        context.progress();
                        this.loadFromLocal(n, permutedColumns[column], column - cStart);
                        context.getCounter("Speed", "Initialize").increment(System.currentTimeMillis() - time);
                    }
                }
                // R_hat = R + contribution of the group's current columns.
                for (int n = 0; n < this.N; ++n) {
                    context.progress();
                    long time = System.currentTimeMillis();
                    if (n == 0) {
                        this.updateResidualColumns(n, 0, true, false, cLength);
                        this.updateResidualColumns(n, 1, true, false, cLength);
                        this.updateResidualColumns(n, 2, true, false, cLength);
                    } else {
                        this.updateResidualColumns(n, 0, true, false, cLength);
                    }
                    context.getCounter("Speed", "Update R_hat").increment(System.currentTimeMillis() - time);
                }
                for (int innerIter = 0; innerIter < this.Tin; ++innerIter) {
                    for (int n = 0; n < this.N; ++n) {
                        context.progress();
                        long time = System.currentTimeMillis();
                        this.updateFactors(n, cLength, context);
                        context.getCounter("Speed", "Optimize").increment(System.currentTimeMillis() - time);
                        time = System.currentTimeMillis();
                        this.communicate(outIter, splitIter, cLength, innerIter, n, context, fs);
                        context.getCounter("Speed", "Broadcast").increment(System.currentTimeMillis() - time);
                    }
                }
                // R = R_hat - contribution of the updated columns; also persist the updated columns.
                double errorTrainingSum = 0.0;
                double errorTestSum = 0.0;
                int nnzTrainingSum = 0;
                int nnzTestSum = 0;
                for (int n = 0; n < this.N; ++n) {
                    context.progress();
                    long time = System.currentTimeMillis();
                    if (n == 0) {
                        errorTrainingSum += this.updateResidualColumns(n, 0, false, measureCost, cLength);
                        errorTestSum += this.updateResidualColumns(n, 1, false, measureCost, cLength);
                        this.updateResidualColumns(n, 2, false, false, cLength);
                        nnzTrainingSum += this.nnzTraining[n];
                        nnzTestSum += this.nnzTest[n];
                    } else {
                        this.updateResidualColumns(n, 0, false, false, cLength);
                    }
                    context.getCounter("Speed", "Update R").increment(System.currentTimeMillis() - time);
                    time = System.currentTimeMillis();
                    for (int column = cStart; column <= cEnd; ++column) {
                        context.progress();
                        this.writeFactors(n, permutedColumns[column], column - cStart);
                    }
                    context.getCounter("Speed", "Update Param").increment(System.currentTimeMillis() - time);
                }
                if (measureCost) {
                    System.out.println(Math.sqrt(errorTrainingSum / nnzTrainingSum));
                    System.out.println(Math.sqrt(errorTestSum / nnzTestSum));
                    result[outIter] = new double[]{outIter, System.currentTimeMillis() - startTime, errorTrainingSum, nnzTrainingSum, errorTestSum, nnzTestSum};
                }
            }
            if (this.useBias) {
                this.runBiasUpdates(outIter, startTime, result, context, fs);
            }
            if (this.machineId == 0) {
                context.getCounter("Time ", "" + outIter).increment(System.currentTimeMillis() - startTime);
            }
            // NOTE(review): FileSystem.get() returns a JVM-wide cached instance, so closing it here also closes
            // it for any other code in this task that obtained it the same way. Kept as in the original -- confirm
            // this is intended.
            try {
                fs.close();
            } catch (Exception ignored) {
                // Best effort: a failed close does not affect the optimisation state.
            }
        }
        this.nnzFiber = null;
        this.curCols = null;
        if (this.nnzQuery[0] > 0) {
            this.writeEstimate(context);
        }
        this.writePerformance(context, result);
        if (this.machineId == 0) {
            this.writeFactormatricesResult(context);
            if (this.useBias) {
                this.writeBiasesResults(context);
            }
        }
        File baseDir = new File(this.baseLocalPath);
        if (baseDir.exists()) {
            try {
                FileUtil.fullyDelete(baseDir);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Closes the local spill streams opened by {@code reduce}.
     *
     * <p>Failures are printed and otherwise ignored. The first failure stops the remaining streams from being
     * closed.
     */
    private void closeSpillStreams() {
        try {
            for (int i = 0; i < this.outIndexR.length; ++i) {
                this.outIndexR[i].close();
                this.outValueR[i].close();
            }
            this.outIndexRTest[0].close();
            this.outValueRTest[0].close();
            this.outIndexRQuery[0].close();
            this.outValueRQuery[0].close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Updates the bias terms of every mode for one outer iteration.
     *
     * <p>For each mode it loads, optimises and broadcasts the biases, refreshes the stored residuals, and saves
     * the new biases. After the last mode, it prints the training and test RMSE and records the performance row
     * for {@code outIter}.
     */
    private void runBiasUpdates(int outIter, long startTime, double[][] result, Reducer.Context context, FileSystem fs) throws IOException, InterruptedException {
        for (int n = 0; n < this.N; ++n) {
            long time = System.currentTimeMillis();
            context.progress();
            this.loadBiasFromLocal(n);
            context.getCounter("Speed", "Initialize").increment(System.currentTimeMillis() - time);
            time = System.currentTimeMillis();
            this.oldBias = (float[])this.curBias.clone();
            context.progress();
            this.updateBias(n);
            context.getCounter("Speed", "Optimize").increment(System.currentTimeMillis() - time);
            time = System.currentTimeMillis();
            this.communicateBias(outIter, n, context, fs);
            context.getCounter("Speed", "Broadcast").increment(System.currentTimeMillis() - time);
            boolean lastMode = n == this.N - 1;
            double trainErrorSum = 0.0;
            double testErrorSum = 0.0;
            int nnzSum = 0;
            int nnzTestSum = 0;
            for (int nr = 0; nr < this.N; ++nr) {
                context.progress();
                time = System.currentTimeMillis();
                if (nr == 0) {
                    trainErrorSum += this.updateRWithBias(nr, n, 0, lastMode);
                    testErrorSum += this.updateRWithBias(nr, n, 1, lastMode);
                    this.updateRWithBias(nr, n, 2, false);
                    nnzSum += this.nnzTraining[nr];
                    nnzTestSum += this.nnzTest[nr];
                } else {
                    this.updateRWithBias(nr, n, 0, false);
                }
                context.getCounter("Speed", "Update R").increment(System.currentTimeMillis() - time);
                time = System.currentTimeMillis();
            }
            this.writeBiasParams(n);
            context.getCounter("Speed", "Update Param").increment(System.currentTimeMillis() - time);
            if (lastMode) {
                System.out.println(Math.sqrt(trainErrorSum / nnzSum));
                System.out.println(Math.sqrt(testErrorSum / nnzTestSum));
                result[outIter] = new double[]{outIter, System.currentTimeMillis() - startTime, trainErrorSum, nnzSum, testErrorSum, nnzTestSum};
            }
        }
    }

    /**
     * Re-solves the first {@code C} slots of every row of mode {@code n} that appears in the mode's training
     * spill file.
     *
     * <p>For a row i the normal equations are (sum_e p_e p_e^T + lambda * w_i * I) x = sum_e r_e p_e, where:
     * <ul>
     *   <li>e ranges over the row's training entries;</li>
     *   <li>r_e is the value read from the training spill file;</li>
     *   <li>p_e[k] is the product of the other modes' slot-k factor values;</li>
     *   <li>w_i is the row's training nnz when {@code useWeight} is set, and 1 otherwise.</li>
     * </ul>
     *
     * <p>NOTE(review): a row is solved whenever index[n] changes. This assumes a row's entries are contiguous in
     * the spill file (presumably guaranteed by the reducer key order) -- confirm.
     *
     * @param n       mode whose rows are updated
     * @param C       number of active slots; shadows the field of the same name
     * @param context used to report progress during long scans
     */
    private void updateFactors(int n, int C, Reducer.Context context) throws IOException {
        double[][] B = new double[C][C];
        double[][] c = new double[C][1];
        int oldResultIndex = -1;
        int[] index = new int[this.N];
        double[] product = new double[C];
        ObjectInputStream inIndex = null;
        ObjectInputStream inValue = null;
        try {
            inIndex = new ObjectInputStream(new BufferedInputStream(new FileInputStream(this.getLocalRPath(n, true, 0, false))));
            inValue = new ObjectInputStream(new BufferedInputStream(new FileInputStream(this.getLocalRPath(n, false, 0, false))));
            for (int elem = 0; elem < this.nnzTraining[n]; ++elem) {
                for (int mode = 0; mode < this.N; ++mode) {
                    index[mode] = inIndex.readInt();
                }
                float r = inValue.readFloat();
                int resultIndex = index[n];
                if (oldResultIndex >= 0 && oldResultIndex != resultIndex) {
                    if (oldResultIndex % 1000000 == 0) {
                        context.progress();
                    }
                    this.solveFactorRow(n, oldResultIndex, B, c, C);
                    B = new double[C][C];
                    c = new double[C][1];
                }
                oldResultIndex = resultIndex;
                for (int column = 0; column < C; ++column) {
                    product[column] = 1.0;
                }
                // Multiply in the other modes in the order n+1, n+2, ... (mod N).
                for (int i = 1; i < this.N; ++i) {
                    int nextmode = (n + i) % this.N;
                    float[] otherRow = this.curCols[nextmode][index[nextmode]];
                    for (int column = 0; column < C; ++column) {
                        product[column] = product[column] * (double)otherRow[column];
                    }
                }
                // Accumulate only the upper triangle of B; solveFactorRow mirrors it.
                for (int column1 = 0; column1 < C; ++column1) {
                    for (int column2 = column1; column2 < C; ++column2) {
                        B[column1][column2] += product[column1] * product[column2];
                    }
                    c[column1][0] += product[column1] * (double)r;
                }
            }
            if (oldResultIndex >= 0) {
                this.solveFactorRow(n, oldResultIndex, B, c, C);
            }
        } finally {
            SALSReducer.closeStreams(inIndex, inValue);
        }
    }

    /**
     * Solves the normal equations of one row and stores the solution in {@code curCols[n][row][0..C-1]}.
     *
     * <p>The method mirrors the accumulated upper triangle of {@code B} into the lower triangle and adds the
     * ridge term to the diagonal. Solution values strictly between -epsilon and epsilon are stored as 0.
     *
     * <p>If the Cholesky solve throws, "Singular matrix" is printed and the row keeps its previous values. Jama
     * throws when B is not symmetric positive definite.
     */
    private void solveFactorRow(int n, int row, double[][] B, double[][] c, int C) {
        for (int column1 = 0; column1 < C; ++column1) {
            for (int column2 = column1 + 1; column2 < C; ++column2) {
                B[column2][column1] = B[column1][column2];
            }
        }
        double ridge = (double)(this.lambda * (float)(this.useWeight ? this.nnzFiber[n][row - this.startIndex[n]] : 1));
        for (int column = 0; column < C; ++column) {
            B[column][column] += ridge;
        }
        try {
            double[][] newParam = new CholeskyDecomposition(new Matrix(B)).solve(new Matrix(c)).getArray();
            for (int column = 0; column < C; ++column) {
                float value = (float)newParam[column][0];
                if (value > -this.epsilon && value < this.epsilon) {
                    value = 0.0f;
                }
                this.curCols[n][row][column] = value;
            }
        } catch (RuntimeException e) {
            System.out.println("Singular matrix");
        }
    }

    /**
     * Adds ({@code add == true}) or subtracts the contribution of all {@code C} slots of {@code curCols} to or
     * from every stored value of one spill file of mode {@code n}.
     *
     * @param n           mode whose spill file is rewritten
     * @param type        0 = training, 1 = test, 2 = query spill file
     * @param add         whether to add (true) or subtract (false) the contribution
     * @param measureCost when subtracting, whether to accumulate the squared updated values
     * @return the sum of squared updated values when subtracting with {@code measureCost}, otherwise 0
     */
    public double updateR(int n, int type, boolean add, boolean measureCost) throws IOException {
        return this.updateResidualColumns(n, type, add, measureCost, this.C);
    }

    /**
     * Same as {@code updateR}, but only slots {@code 0..numColumns-1} of {@code curCols} contribute.
     *
     * <p>Updated values are written to the temporary path {@code getLocalRPath(n, false, type, true)}. Then
     * {@code replace} is called on the original value path, presumably swapping the temporary file in.
     */
    private double updateResidualColumns(int n, int type, boolean add, boolean measureCost, int numColumns) throws IOException {
        double errorSum = 0.0;
        ObjectInputStream inIndex = null;
        ObjectInputStream inValue = null;
        ObjectOutputStream outValue = null;
        try {
            inIndex = new ObjectInputStream(new BufferedInputStream(new FileInputStream(this.getLocalRPath(n, true, type, false))));
            inValue = new ObjectInputStream(new BufferedInputStream(new FileInputStream(this.getLocalRPath(n, false, type, false))));
            outValue = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalRPath(n, false, type, true))));
            int nnz = type == 0 ? this.nnzTraining[n] : (type == 1 ? this.nnzTest[n] : this.nnzQuery[n]);
            int[] index = new int[this.N];
            float[] product = new float[numColumns];
            for (int elem = 0; elem < nnz; ++elem) {
                for (int mode = 0; mode < this.N; ++mode) {
                    index[mode] = inIndex.readInt();
                }
                float r = inValue.readFloat();
                for (int column = 0; column < numColumns; ++column) {
                    product[column] = 1.0f;
                }
                for (int dim = 0; dim < this.N; ++dim) {
                    float[] factorRow = this.curCols[dim][index[dim]];
                    for (int column = 0; column < numColumns; ++column) {
                        product[column] *= factorRow[column];
                    }
                }
                for (int column = 0; column < numColumns; ++column) {
                    if (add) {
                        r += product[column];
                    } else {
                        r -= product[column];
                    }
                }
                outValue.writeFloat(r);
                if (!add && measureCost) {
                    errorSum += (double)(r * r);
                }
            }
        } finally {
            SALSReducer.closeStreams(inIndex, inValue, outValue);
        }
        this.replace(this.getLocalRPath(n, false, type, false));
        return errorSum;
    }

    /**
     * Exchanges the updated rows of mode {@code n} with every other machine through HDFS.
     *
     * <p>The method writes this machine's rows [startIndex[n], endIndex[n]) of the first {@code C} slots, and
     * marks the write. It then polls the marker directory until the rows of every other machine have been read
     * into {@code curCols}.
     *
     * <p>Failed reads are counted under ("Error", "err") and retried on a later poll. Polls are spaced by a
     * random delay of up to {@code waiting} ms, minus the time the poll itself took.
     *
     * @throws InterruptedIOException if the thread is interrupted while waiting; the interrupt flag is restored
     */
    private void communicate(int outIter, int splitIter, int C, int inIter, int n, Reducer.Context context, FileSystem fs) throws IOException {
        Path outPath = new Path(this.getHDFSParamPath(outIter, splitIter, inIter, n, this.machineId, false));
        FSDataOutputStream out = fs.create(outPath);
        try {
            for (int i = this.startIndex[n]; i < this.endIndex[n]; ++i) {
                for (int column = 0; column < C; ++column) {
                    out.writeFloat(this.curCols[n][i][column]);
                }
            }
        } finally {
            out.close();
        }
        this.markWrite(outIter, splitIter, inIter, n, fs);
        boolean[] markReadComplete = new boolean[this.M];
        markReadComplete[this.machineId] = true;
        int pending = this.M - 1;
        while (true) {
            long requestTime = System.currentTimeMillis();
            FileStatus[] statusList = fs.listStatus(new Path(this.getHDFSParamPath(outIter, splitIter, inIter, n, true)));
            // Older Hadoop versions return null for a missing directory; treat that as "nothing published yet".
            if (statusList != null) {
                SALSReducer.shuffle(statusList);
                for (FileStatus status : statusList) {
                    int taskId = Integer.parseInt(status.getPath().getName());
                    if (!markReadComplete[taskId] && this.readPeerRows(outIter, splitIter, inIter, n, C, taskId, context, fs)) {
                        markReadComplete[taskId] = true;
                        --pending;
                    }
                }
            }
            if (pending == 0) {
                break;
            }
            context.progress();
            long timeToWait = (long)((double)this.waiting * Math.random()) - (System.currentTimeMillis() - requestTime);
            if (timeToWait <= 0L) {
                continue;
            }
            try {
                Thread.sleep(timeToWait);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                InterruptedIOException interrupted = new InterruptedIOException("Interrupted while waiting for parameters of mode " + n);
                interrupted.initCause(e);
                throw interrupted;
            }
        }
    }

    /**
     * Reads the rows of mode {@code n} published by {@code taskId} into {@code curCols}.
     *
     * @return true on success. On failure, the message is printed, ("Error", "err") is incremented, and false is
     *         returned so the caller can retry.
     */
    private boolean readPeerRows(int outIter, int splitIter, int inIter, int n, int C, int taskId, Reducer.Context context, FileSystem fs) {
        FSDataInputStream in = null;
        try {
            in = fs.open(new Path(this.getHDFSParamPath(outIter, splitIter, inIter, n, taskId, false)));
            for (int i = this.getStartIndex(n, taskId); i < this.getStartIndex(n, taskId + 1); ++i) {
                for (int column = 0; column < C; ++column) {
                    this.curCols[n][i][column] = in.readFloat();
                }
            }
            return true;
        } catch (Exception e) {
            System.out.println(e.getMessage());
            context.getCounter("Error", "err").increment(1L);
            return false;
        } finally {
            if (in != null) {
                try {
                    in.close();
                } catch (Exception ignored) {
                    // Best effort: the data has already been consumed (or the read already failed).
                }
            }
        }
    }

    /** Loads saved column {@code k} of mode {@code n} into slot {@code c} of {@code curCols}. */
    private void loadFromLocal(int n, int k, int c) throws IOException {
        ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(new FileInputStream(this.getLocalParamPath(n, k, false))));
        try {
            for (int i = 0; i < this.modeLengths[n]; ++i) {
                this.curCols[n][i][c] = in.readFloat();
            }
        } finally {
            in.close();
        }
    }

    /** Returns a random permutation of {@code 0..n-1} drawn from {@code rand}. */
    private static int[] createRandomSequence(int n, Random rand) {
        int[] result = new int[n];
        for (int i = 0; i < n; ++i) {
            result[i] = i;
        }
        SALSReducer.shuffle(result, rand);
        return result;
    }

    /** Fisher-Yates shuffle: moves a uniformly chosen element of the unshuffled prefix to its end each step. */
    private static void shuffle(int[] vec, Random rand) {
        int n = vec.length;
        for (int i = 0; i < n; ++i) {
            int last = n - i - 1;
            int randI = rand.nextInt(n - i);
            int temp = vec[last];
            vec[last] = vec[randI];
            vec[randI] = temp;
        }
    }

    /** Saves slot {@code c} of mode {@code n} as column {@code k} in the local parameter files. */
    private void writeFactors(int n, int k, int c) throws FileNotFoundException, IOException {
        ObjectOutputStream os = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(this.getLocalParamPath(n, k, false))));
        try {
            for (int row = 0; row < this.modeLengths[n]; ++row) {
                os.writeFloat(this.curCols[n][row][c]);
            }
        } finally {
            os.close();
        }
    }

    /**
     * Closes every non-null stream in order, even if an earlier close fails. Afterwards it rethrows the first
     * {@link IOException} encountered.
     */
    private static void closeStreams(Closeable... streams) throws IOException {
        IOException firstFailure = null;
        for (Closeable stream : streams) {
            if (stream == null) {
                continue;
            }
            try {
                stream.close();
            } catch (IOException e) {
                if (firstFailure == null) {
                    firstFailure = e;
                }
            }
        }
        if (firstFailure != null) {
            throw firstFailure;
        }
    }
}

