/*
 * Decompiled with CFR 0.152.
 */
package sals.single;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;
import sals.single.ArrayMethods;
import sals.single.GreedyRowAssignment;
import sals.single.MultiThread;
import sals.single.Output;
import sals.single.Performance;
import sals.single.Tensor;
import sals.single.TensorMethods;

public class SALS {
    /** Magnitude threshold: freshly computed factor/bias entries strictly inside (-epsilon, epsilon) are stored as exactly 0. */
    private static float epsilon = 1.0E-12f;
    /** Optional held-out tensor; when non-null, a test RMSE is computed after every outer iteration. */
    private Tensor test;
    /** Factor matrices indexed as params[mode][row][column]; (re)allocated at the start of each fit. */
    private float[][][] params;
    /** Per-mode bias vectors indexed as bias[mode][row]; only allocated when fitting with biases. */
    private float[][] bias;
    /** Value copied from the training tensor's {@code mu} and subtracted from every residual when fitting with biases. */
    private float mu;

    /**
     * Command-line entry point; runs the factorization without bias terms.
     *
     * @param args command-line arguments, forwarded unchanged to {@link #run(String[], boolean)}
     * @throws Exception whatever the delegated run propagates
     */
    public static void main(String[] args) throws Exception {
        run(args, false);
    }

    /**
     * Parses the command line, fits the model and writes all outputs.
     *
     * <p>Expected argument layout (same order as the usage message printed on a parse error):
     * {@code [training] [output] [M] [Tout] [Tin] [N] [K] [C] [lambda] [useWeight] [I1] ... [IN] [test] [query]},
     * where {@code test} and {@code query} are optional.
     *
     * @param args    raw command-line arguments
     * @param useBias whether to fit (and write) bias terms in addition to the factor matrices
     * @throws Exception any failure is rethrown; if it happened while parsing arguments, the usage
     *                   line and the stack trace are printed to stderr first
     */
    public static void run(String[] args, boolean useBias) throws Exception {
        // Stays true until every argument has been parsed, so the catch block can tell
        // a malformed command line apart from a failure during the actual computation.
        boolean inputError = true;
        try {
            System.out.println("========================");
            System.out.println("Start parameter check...");
            System.out.println("========================");
            final String training = args[0];
            System.out.println("-training: " + training);
            final String outputDir = args[1];
            System.out.println("-output: " + outputDir);
            final int M = Integer.parseInt(args[2]);
            System.out.println("-M: " + M);
            final int Tout = Integer.parseInt(args[3]);
            System.out.println("-Tout: " + Tout);
            final int Tin = Integer.parseInt(args[4]);
            System.out.println("-Tin: " + Tin);
            final int N = Integer.parseInt(args[5]);
            System.out.println("-N: " + N);
            final int K = Integer.parseInt(args[6]);
            System.out.println("-K: " + K);
            final int C = Integer.parseInt(args[7]);
            System.out.println("-C: " + C);
            final float lambda = Float.parseFloat(args[8]);
            System.out.println("-lambda: " + lambda);
            final boolean useWeight = Integer.parseInt(args[9]) > 0;
            System.out.println("-useWeight: " + useWeight);

            // The N mode sizes follow the ten fixed arguments.
            final int[] modeSizes = new int[N];
            for (int mode = 0; mode < N; ++mode) {
                modeSizes[mode] = Integer.parseInt(args[10 + mode]);
                System.out.println("-I" + (mode + 1) + ": " + modeSizes[mode]);
            }

            // Optional trailing arguments: test file, then query file.
            String testPath = null;
            if (args.length > 10 + N) {
                testPath = args[10 + N];
                System.out.println("-test: " + testPath);
            }
            String queryPath = null;
            if (args.length > 11 + N) {
                queryPath = args[11 + N];
                System.out.println("-query: " + queryPath);
            }
            inputError = false;

            final int[] modesIdx = ArrayMethods.createSequnce(N);
            System.out.println("==============================");
            System.out.println("Start greedy row assignment...");
            System.out.println("==============================");
            final int[][] permutedIdx = GreedyRowAssignment.run(training, N, modesIdx, modeSizes, M);

            final String name = useBias ? "Bias-SALS" : "SALS";
            System.out.println("=============");
            System.out.println("Start " + name + "...");
            System.out.println("=============");
            final SALS solver = new SALS();
            final Tensor trainingTensor = TensorMethods.importSparseTensor(training, ",", modeSizes, modesIdx, N, permutedIdx);
            if (testPath != null) {
                solver.setTest(TensorMethods.importSparseTensor(testPath, ",", modeSizes, modesIdx, N, permutedIdx));
            }
            final double[][] performance = solver.run(trainingTensor, K, Tout, Tin, M, C, lambda, useWeight, useBias, true);

            System.out.println("=======================");
            System.out.println("Start writing output...");
            System.out.println("=======================");
            Output.writePerformance(outputDir, performance, Tout);
            Output.writeFactorMatrices(outputDir, solver.params, permutedIdx);
            if (useBias) {
                Output.writeBiases(outputDir, solver.bias, permutedIdx);
                Output.writeMU(outputDir, solver.mu);
            }
            if (queryPath != null) {
                // NOTE(review): the query tensor is imported with 0 where training/test pass N;
                // kept as-is - confirm against TensorMethods.importSparseTensor.
                final Tensor queryTensor = TensorMethods.importSparseTensor(queryPath, ",", modeSizes, modesIdx, 0, permutedIdx);
                if (useBias) {
                    Output.calculateEstimate(queryTensor, solver.mu, solver.bias, solver.params, N, K);
                } else {
                    Output.calculateEstimate(queryTensor, solver.params, N, K);
                }
                Output.writeEstimate(outputDir, queryTensor, permutedIdx, N);
            }
            System.out.println("===========");
            System.out.println("Complete!!!");
            System.out.println("===========");
        }
        catch (Exception e) {
            if (inputError) {
                final String fileName = useBias ? "run_single_bias_sals.sh" : "run_single_sals.sh";
                System.err.println("Usage: " + fileName + " [training] [output] [M] [Tout] [Tin] [N] [K] [C] [lambda] [useWeight] [I1] [I2] ... [IN] [test] [query]");
                e.printStackTrace();
            }
            throw e;
        }
    }

    /**
     * Registers the tensor used to compute the test RMSE during fitting.
     *
     * @param test held-out tensor; when left unset (null) the reported test RMSE stays 0.0
     */
    private void setTest(Tensor test) {
        this.test = test;
    }

    /**
     * Fits the factor matrices (and optionally biases) to {@code training} and records the
     * performance after every outer iteration.
     *
     * <p>The method works on a copy of {@code training} whose values serve as the residuals.
     * Each outer iteration draws a random column order, then processes the K columns in groups
     * of at most {@code C}: the group's contribution is added back to the residuals, the group
     * is re-solved {@code Tin} times mode by mode, and its new contribution is subtracted again.
     * All element-wise and row-wise work is split into {@code M} blocks and dispatched through
     * {@link MultiThread}.
     *
     * @param training  tensor to factorize; not modified (a copy is used)
     * @param K         number of columns of every factor matrix
     * @param Tout      number of outer iterations
     * @param Tin       number of inner iterations per column group
     * @param M         number of blocks the rows/elements are split into for parallel work
     * @param C         number of columns updated together; must be positive
     * @param lambda    regularization strength
     * @param useWeight if true, the regularization of a row is scaled by that row's {@code nnzFiber} count
     * @param useBias   if true, {@code training.mu} is subtracted from the residuals and per-mode biases are fitted
     * @param printLog  if true, one CSV line per outer iteration is printed to stdout
     * @return one row per outer iteration: {iteration (1-based), elapsed milliseconds since the
     *         start of fitting, training RMSE, test RMSE (0.0 when no test tensor is set)}
     * @throws IllegalArgumentException if {@code C} is not positive
     */
    public double[][] run(Tensor training, int K, int Tout, int Tin, final int M, int C, final float lambda, final boolean useWeight, boolean useBias, boolean printLog) {
        if (C <= 0) {
            // With C == 0 the number of column groups is K / 0.0 = Infinity and the split loop never ends.
            throw new IllegalArgumentException("C must be positive but was " + C);
        }
        Random random = new Random();
        final Tensor R = training.copy();
        final int nnzTraining = R.omega;
        boolean useTest = this.test != null;
        final int N = R.N;
        final int[] modeLengths = R.modeLengths;
        final int[][] nnzFiber = TensorMethods.cardinality(R);
        // division[b][n]: element indices whose mode-n index belongs to block b, grouped by that index.
        final int[][][] division = SALS.buildDivision(R, M);

        if (useBias) {
            this.mu = training.mu;
            new MultiThread<Object>(){

                @Override
                public Object runJob(int b, int threadIndex) {
                    int[] range = SALS.blockIndex(nnzTraining, M, b);
                    for (int elemIdx = range[0]; elemIdx <= range[1]; ++elemIdx) {
                        R.values[elemIdx] = R.values[elemIdx] - SALS.this.mu;
                    }
                    return null;
                }
            }.run(M, MultiThread.createJobList(M));
            this.bias = new float[N][];
            for (int n = 0; n < N; ++n) {
                this.bias[n] = new float[modeLengths[n]];
            }
        }

        this.params = new float[N][][];
        for (int dim = 0; dim < N; ++dim) {
            this.params[dim] = ArrayMethods.createUniformRandomMatrix(modeLengths[dim], K, dim != 0 ? 1 : 0, random);
        }
        // Scratch copy holding only the (at most C) columns currently being updated.
        final float[][][] currentParams = new float[N][][];
        for (int dim = 0; dim < N; ++dim) {
            currentParams[dim] = new float[modeLengths[dim]][C];
        }

        final int splitCount = (int)Math.ceil((double)K / (double)C);
        double[][] result = new double[Tout][4];
        long start = System.currentTimeMillis();
        for (int outIter = 0; outIter < Tout; ++outIter) {
            final int[] permutedColumns = SALS.createRandomSequence(K, random);
            for (int splitIter = 0; splitIter < splitCount; ++splitIter) {
                final int cStart = C * splitIter;
                final int cEnd = Math.min(C * (splitIter + 1) - 1, K - 1);
                final int cLength = cEnd - cStart + 1;

                // 1. Copy the selected columns into the scratch matrices.
                new MultiThread<Object>(){

                    @Override
                    public Object runJob(int b, int threadIndex) {
                        for (int dim = 0; dim < N; ++dim) {
                            int[] rows = SALS.blockIndex(modeLengths[dim], M, b);
                            for (int row = rows[0]; row <= rows[1]; ++row) {
                                for (int column = cStart; column <= cEnd; ++column) {
                                    currentParams[dim][row][column - cStart] = SALS.this.params[dim][row][permutedColumns[column]];
                                }
                            }
                        }
                        return null;
                    }
                }.run(M, MultiThread.createJobList(M));

                // 2. Add the selected columns' contribution back to the residuals.
                new MultiThread<Object>(){

                    @Override
                    public Object runJob(int b, int threadIndex) {
                        int[] range = SALS.blockIndex(nnzTraining, M, b);
                        SALS.this.updateR(R, currentParams, cLength, range[0], range[1], true);
                        return null;
                    }
                }.run(M, MultiThread.createJobList(M));

                // 3. Re-solve the selected columns, one mode at a time, Tin times.
                for (int innerIter = 0; innerIter < Tin; ++innerIter) {
                    for (int dim = 0; dim < N; ++dim) {
                        final int currentDim = dim;
                        new MultiThread<Object>(){

                            @Override
                            public Object runJob(int b, int threadIndex) {
                                int[] rows = SALS.blockIndex(modeLengths[currentDim], M, b);
                                SALS.updateFactor(R, division[b][currentDim], currentDim, currentParams, cLength, lambda, useWeight, rows[0], rows[1], nnzFiber[currentDim]);
                                return null;
                            }
                        }.run(M, MultiThread.createJobList(M));
                    }
                }

                // 4. Subtract the updated columns' contribution from the residuals.
                new MultiThread<Object>(){

                    @Override
                    public Object runJob(int b, int threadIndex) {
                        int[] range = SALS.blockIndex(nnzTraining, M, b);
                        SALS.this.updateR(R, currentParams, cLength, range[0], range[1], false);
                        return null;
                    }
                }.run(M, MultiThread.createJobList(M));

                // 5. Write the updated columns back into the full factor matrices.
                new MultiThread<Object>(){

                    @Override
                    public Object runJob(int b, int threadIndex) {
                        for (int dim = 0; dim < N; ++dim) {
                            int[] rows = SALS.blockIndex(modeLengths[dim], M, b);
                            for (int row = rows[0]; row <= rows[1]; ++row) {
                                for (int column = cStart; column <= cEnd; ++column) {
                                    SALS.this.params[dim][row][permutedColumns[column]] = currentParams[dim][row][column - cStart];
                                }
                            }
                        }
                        return null;
                    }
                }.run(M, MultiThread.createJobList(M));
            }

            if (useBias) {
                for (int n = 0; n < N; ++n) {
                    final int mode = n;
                    // Snapshot: the residual correction below needs both the old and the new bias.
                    final float[] oldBias = this.bias[mode].clone();
                    new MultiThread<Object>(){

                        @Override
                        public Object runJob(int b, int threadIndex) {
                            int[] rows = SALS.blockIndex(modeLengths[mode], M, b);
                            SALS.this.updateBiases(R, division[b][mode], mode, oldBias, SALS.this.bias[mode], lambda, useWeight, rows[0], rows[1], nnzFiber[mode]);
                            return null;
                        }
                    }.run(M, MultiThread.createJobList(M));
                    new MultiThread<Object>(){

                        @Override
                        public Object runJob(int b, int threadIndex) {
                            int[] range = SALS.blockIndex(nnzTraining, M, b);
                            SALS.this.updateR(R, mode, SALS.this.bias[mode], oldBias, range[0], range[1]);
                            return null;
                        }
                    }.run(M, MultiThread.createJobList(M));
                }
            }

            // Sum of squared residuals; each job accumulates locally and merges under a lock.
            final double[] loss = new double[1];
            new MultiThread<Object>(){

                @Override
                public Object runJob(int b, int threadIndex) {
                    double innerLoss = 0.0;
                    int[] range = SALS.blockIndex(nnzTraining, M, b);
                    for (int elemIdx = range[0]; elemIdx <= range[1]; ++elemIdx) {
                        // Square in double: a float product would be rounded before being widened.
                        double residual = R.values[elemIdx];
                        innerLoss += residual * residual;
                    }
                    synchronized (loss) {
                        loss[0] = loss[0] + innerLoss;
                    }
                    return null;
                }
            }.run(M, MultiThread.createJobList(M));
            double trainingRMSE = Math.sqrt(loss[0] / (double)nnzTraining);
            double testRMSE = 0.0;
            if (useTest) {
                testRMSE = useBias ? Performance.computeRMSE(this.test, this.params, this.bias, this.mu, N, K, M) : Performance.computeRMSE(this.test, this.params, N, K, M);
            }
            long elapsedTime = System.currentTimeMillis() - start;
            if (printLog) {
                System.out.printf("%d,%d,%f,%f\n", outIter + 1, elapsedTime, trainingRMSE, testRMSE);
            }
            result[outIter] = new double[]{outIter + 1, elapsedTime, trainingRMSE, testRMSE};
        }
        return result;
    }

    /**
     * Builds, for every block and mode, the list of element indices whose mode index falls into
     * that block, ordered by descending mode index so that equal indices are contiguous.
     */
    private static int[][][] buildDivision(Tensor R, int M) {
        int N = R.N;
        int[] modeLengths = R.modeLengths;
        int[][][] division = new int[M][N][];
        int[][] divisionCount = new int[M][N];
        // Pass 1: count the elements per (block, mode) to size the arrays exactly.
        for (int elemIdx = 0; elemIdx < R.omega; ++elemIdx) {
            for (int n = 0; n < N; ++n) {
                ++divisionCount[SALS.blockOf(R.indices[n][elemIdx], modeLengths[n], M)][n];
            }
        }
        for (int m = 0; m < M; ++m) {
            for (int n = 0; n < N; ++n) {
                division[m][n] = new int[divisionCount[m][n]];
                divisionCount[m][n] = 0;
            }
        }
        // Pass 2: fill, reusing the counters as write cursors.
        for (int elemIdx = 0; elemIdx < R.omega; ++elemIdx) {
            for (int n = 0; n < N; ++n) {
                int b = SALS.blockOf(R.indices[n][elemIdx], modeLengths[n], M);
                division[b][n][divisionCount[b][n]++] = elemIdx;
            }
        }
        for (int m = 0; m < M; ++m) {
            for (int n = 0; n < N; ++n) {
                SALS.sortDescendingByModeIndex(division[m][n], R.indices[n]);
            }
        }
        return division;
    }

    /**
     * Maps a mode index to its block, i.e. floor(index * M / modeLength) capped at M - 1.
     * The product is formed in long because index * M can exceed the int range.
     */
    private static int blockOf(int index, int modeLength, int M) {
        return (int)Math.min((long)index * (long)M / (long)modeLength, (long)(M - 1));
    }

    /** Stable in-place sort of element indices by descending {@code modeIndices[element]}. */
    private static void sortDescendingByModeIndex(int[] elemIdxs, final int[] modeIndices) {
        Integer[] boxed = new Integer[elemIdxs.length];
        for (int i = 0; i < elemIdxs.length; ++i) {
            boxed[i] = elemIdxs[i];
        }
        Arrays.sort(boxed, new Comparator<Integer>(){

            @Override
            public int compare(Integer elemIdx, Integer tElemIdx) {
                return Integer.compare(modeIndices[tElemIdx], modeIndices[elemIdx]);
            }
        });
        for (int i = 0; i < elemIdxs.length; ++i) {
            elemIdxs[i] = boxed[i];
        }
    }

    /**
     * Re-solves the rows of factor matrix {@code n} (for the currently selected columns) that are
     * touched by the elements listed in {@code RIndex}.
     *
     * <p>For each run of consecutive elements sharing the same mode-{@code n} index, a
     * {@code cLength x cLength} system {@code B x = c} is accumulated, where each element
     * contributes the outer product of the element-wise product of the other modes' rows (to B)
     * and that product times the element's residual (to c). The row is then solved via Cholesky
     * decomposition and written to {@code currentParams[n][row]}.
     *
     * @param R             tensor whose {@code values} hold the current residuals
     * @param RIndex        element indices to process; elements with equal mode-{@code n} index
     *                      must be contiguous, otherwise a row is solved from partial data
     * @param n             mode whose factor rows are updated
     * @param currentParams scratch factor matrices, indexed [mode][row][column]
     * @param cLength       number of columns being updated
     * @param lambda        regularization strength added to the diagonal of B
     * @param useWeight     if true, lambda is multiplied by {@code nnzFiber[row]}
     * @param firstRow      unused
     * @param lastRow       unused
     * @param nnzFiber      per-row weights for mode {@code n}, read only when {@code useWeight} is set
     */
    private static void updateFactor(Tensor R, int[] RIndex, int n, float[][][] currentParams, int cLength, float lambda, boolean useWeight, int firstRow, int lastRow, int[] nnzFiber) {
        int dimension = currentParams.length;
        double[][] B = new double[cLength][cLength];
        double[][] c = new double[cLength][1];
        // Reused for every element instead of being reallocated inside the loop.
        double[] product = new double[cLength];
        int oldResultIndex = -1;
        for (int idx = 0; idx < RIndex.length; ++idx) {
            int elemIdx = RIndex[idx];
            int resultIndex = R.indices[n][elemIdx];
            if (oldResultIndex >= 0 && oldResultIndex != resultIndex) {
                // The run for the previous row is complete: solve it and start a fresh system.
                double regularization = (double)(lambda * (float)(useWeight ? nnzFiber[oldResultIndex] : 1));
                SALS.solveFactorRow(B, c, cLength, regularization, currentParams[n][oldResultIndex]);
                B = new double[cLength][cLength];
                c = new double[cLength][1];
            }
            oldResultIndex = resultIndex;

            // Element-wise product of the rows of all modes except n.
            Arrays.fill(product, 1.0);
            for (int i = 1; i < dimension; ++i) {
                int nextmode = (n + i) % dimension;
                float[] otherRow = currentParams[nextmode][R.indices[nextmode][elemIdx]];
                for (int r = 0; r < cLength; ++r) {
                    product[r] = product[r] * (double)otherRow[r];
                }
            }

            // Accumulate only the upper triangle of B; it is mirrored just before solving.
            float value = R.values[elemIdx];
            for (int r1 = 0; r1 < cLength; ++r1) {
                for (int r2 = r1; r2 < cLength; ++r2) {
                    B[r1][r2] = B[r1][r2] + product[r1] * product[r2];
                }
                c[r1][0] = c[r1][0] + product[r1] * (double)value;
            }
        }
        if (oldResultIndex >= 0) {
            // Solve the last pending row.
            double regularization = (double)(lambda * (float)(useWeight ? nnzFiber[oldResultIndex] : 1));
            SALS.solveFactorRow(B, c, cLength, regularization, currentParams[n][oldResultIndex]);
        }
    }

    /**
     * Completes the symmetric system (mirrors the upper triangle, adds the regularization to the
     * diagonal), solves it by Cholesky decomposition and stores the solution in {@code target},
     * replacing entries strictly inside (-epsilon, epsilon) with 0.
     */
    private static void solveFactorRow(double[][] B, double[][] c, int cLength, double regularization, float[] target) {
        for (int r1 = 0; r1 < cLength; ++r1) {
            for (int r2 = r1 + 1; r2 < cLength; ++r2) {
                B[r2][r1] = B[r1][r2];
            }
        }
        for (int r = 0; r < cLength; ++r) {
            B[r][r] = B[r][r] + regularization;
        }
        double[][] newParam = new CholeskyDecomposition(new Matrix(B)).solve(new Matrix(c)).getArray();
        for (int r = 0; r < cLength; ++r) {
            float result = (float)newParam[r][0];
            if (result > -epsilon && result < epsilon) {
                result = 0.0f;
            }
            target[r] = result;
        }
    }

    /**
     * Recomputes the mode-{@code n} biases of rows {@code firstRow..lastRow}.
     *
     * <p>For each row the new bias is
     * {@code sum(residual + oldBias[row]) / (count + lambda * weight)}, taken over the listed
     * elements that fall on that row, with {@code weight = nnzFiber[row]} when
     * {@code useWeight} is set and 1 otherwise. Rows whose denominator is exactly 0 keep their
     * current bias; results strictly inside (-epsilon, epsilon) are stored as 0.
     *
     * @param R           tensor whose {@code values} hold the current residuals
     * @param Rindex      element indices to process; each must have a mode-{@code n} index
     *                    within {@code firstRow..lastRow}
     * @param n           mode whose biases are updated
     * @param oldBias     biases before this update
     * @param currentBias destination for the new biases
     * @param lambda      regularization strength
     * @param useWeight   whether lambda is scaled by {@code nnzFiber[row]}
     * @param firstRow    first row handled by this call (inclusive)
     * @param lastRow     last row handled by this call (inclusive)
     * @param nnzFiber    per-row weights for mode {@code n}
     */
    private void updateBiases(Tensor R, int[] Rindex, int n, float[] oldBias, float[] currentBias, float lambda, boolean useWeight, int firstRow, int lastRow, int[] nnzFiber) {
        final int rowCount = lastRow - firstRow + 1;
        final float[] sums = new float[rowCount];
        final float[] weights = new float[rowCount];

        // Accumulate numerator and observation count per row (rows are stored relative to firstRow).
        for (int elemIdx : Rindex) {
            final int row = R.indices[n][elemIdx];
            final int slot = row - firstRow;
            sums[slot] += R.values[elemIdx] + oldBias[row];
            weights[slot] += 1.0f;
        }

        for (int slot = 0; slot < rowCount; ++slot) {
            final int row = slot + firstRow;
            weights[slot] += lambda * (float)(useWeight ? nnzFiber[row] : 1);
            if (weights[slot] == 0.0f) {
                // Nothing observed and no regularization: leave the existing bias untouched.
                continue;
            }
            float updated = sums[slot] / weights[slot];
            if (updated > -epsilon && updated < epsilon) {
                updated = 0.0f;
            }
            currentBias[row] = updated;
        }
    }

    /**
     * Adjusts the residuals of elements {@code startIdx..endIdx} (inclusive) after the
     * mode-{@code n} biases changed: the old bias is added back and the new one subtracted.
     *
     * @param R           tensor whose {@code values} are updated in place
     * @param n           mode the biases belong to
     * @param currentBias biases after the update
     * @param oldBias     biases before the update
     * @param startIdx    first element index (inclusive)
     * @param endIdx      last element index (inclusive)
     */
    private void updateR(Tensor R, int n, float[] currentBias, float[] oldBias, int startIdx, int endIdx) {
        for (int elem = startIdx; elem <= endIdx; ++elem) {
            final int row = R.indices[n][elem];
            // Keep this evaluation order: (value + old) - new.
            R.values[elem] = (R.values[elem] + oldBias[row]) - currentBias[row];
        }
    }

    /**
     * Adds to or subtracts from the residuals of elements {@code startIdx..endIdx} (inclusive)
     * the contribution of the first {@code columLength} columns of {@code currentParams}.
     *
     * <p>The contribution of one column to one element is the product, over all modes, of
     * the entry of that column in the row selected by the element's index in that mode.
     *
     * @param R             tensor whose {@code values} are updated in place
     * @param currentParams factor matrices, indexed [mode][row][column]
     * @param columLength   number of leading columns to apply
     * @param startIdx      first element index (inclusive)
     * @param endIdx        last element index (inclusive)
     * @param add           true to add the contribution, false to subtract it
     * @return always 0.0
     */
    private double updateR(Tensor R, float[][][] currentParams, int columLength, int startIdx, int endIdx, boolean add) {
        final int modeCount = currentParams.length;
        final int[] coords = new int[modeCount];
        for (int elem = startIdx; elem <= endIdx; ++elem) {
            for (int mode = 0; mode < modeCount; ++mode) {
                coords[mode] = R.indices[mode][elem];
            }
            for (int col = 0; col < columLength; ++col) {
                float term = 1.0f;
                for (int mode = 0; mode < modeCount; ++mode) {
                    term *= currentParams[mode][coords[mode]][col];
                }
                if (add) {
                    R.values[elem] = R.values[elem] + term;
                } else {
                    R.values[elem] = R.values[elem] - term;
                }
            }
        }
        return 0.0;
    }

    /**
     * Returns the integers {@code 0..n-1} in an order shuffled with {@code random}.
     *
     * @param n      length of the sequence
     * @param random source of randomness
     * @return a new array containing each of {@code 0..n-1} exactly once
     */
    private static int[] createRandomSequence(int n, Random random) {
        final int[] sequence = new int[n];
        for (int value = 0; value < n; ++value) {
            sequence[value] = value;
        }
        SALS.shuffle(sequence, random);
        return sequence;
    }

    /**
     * Shuffles {@code vec} in place: walking from the last position down to the first, each
     * position is swapped with a position drawn from {@code random} among itself and all
     * positions before it.
     *
     * @param vec    array to permute in place
     * @param random source of randomness; exactly {@code vec.length} draws are made
     */
    private static void shuffle(int[] vec, Random random) {
        for (int last = vec.length - 1; last >= 0; --last) {
            final int pick = random.nextInt(last + 1);
            final int held = vec[last];
            vec[last] = vec[pick];
            vec[pick] = held;
        }
    }

    /**
     * Computes the inclusive index range of block {@code i} when {@code n} items are split
     * into {@code m} consecutive blocks.
     *
     * @param n number of items
     * @param m number of blocks
     * @param i zero-based block number
     * @return {@code {first, last}} with {@code first = ceil(n*i/m)} and
     *         {@code last = ceil(n*(i+1)/m) - 1}; {@code last < first} denotes an empty block
     */
    private static int[] blockIndex(int n, int m, int i) {
        final double total = (double)n + 0.0;
        final int first = (int)Math.ceil(total * (double)i / (double)m);
        final int last = (int)Math.ceil(total * (double)(i + 1) / (double)m) - 1;
        return new int[]{first, last};
    }
}

