/*
 * Decompiled with CFR 0.152.
 */
package sals.single;

import java.util.Random;
import sals.single.ArrayMethods;
import sals.single.GreedyRowAssignment;
import sals.single.MultiThread;
import sals.single.Output;
import sals.single.Performance;
import sals.single.Tensor;
import sals.single.TensorMethods;

/**
 * Coordinate-descent tensor factorization (CDTF), optionally extended with a global mean and
 * per-mode bias vectors ("Bias-CDTF").
 *
 * <p>The model is fitted by rank-one coordinate descent. For every column {@code k}, the current
 * rank-one component is added back into the residual tensor, then column {@code k} of every factor
 * matrix is re-solved {@code Tin} times, and finally the refreshed component is subtracted again.
 * All heavy loops are split into {@code M} blocks that are executed through {@link MultiThread}.
 */
public class CDTF {
    /** Updated values whose magnitude is below this threshold are snapped to exactly zero. */
    private static final float EPSILON = 1.0E-12f;

    /** Optional held-out tensor; when set, a test RMSE is reported after each outer iteration. */
    private Tensor test;
    /** Factor matrices, indexed as {@code params[mode][row][column]}. */
    private float[][][] params;
    /** Per-mode bias vectors, indexed as {@code bias[mode][row]}; only allocated when biases are used. */
    private float[][] bias;
    /** Global mean copied from the training tensor; only set when biases are used. */
    private float mu;

    /** Command-line entry point for plain CDTF (no biases). */
    public static void main(String[] args) throws Exception {
        CDTF.run(args, false);
    }

    /**
     * Parses command-line arguments, trains CDTF (or Bias-CDTF) and writes every output file.
     *
     * <p>Argument order: {@code training output M Tout Tin N K lambda useWeight I1 ... IN [test] [query]},
     * where {@code useWeight} is enabled by any integer greater than zero.
     *
     * @param args    command-line arguments in the order above
     * @param useBias whether to fit the global mean and per-mode biases as well
     * @throws Exception any parsing, I/O or training failure; if it happens while the arguments are
     *                   still being parsed, a usage line and the stack trace are printed first
     */
    public static void run(String[] args, boolean useBias) throws Exception {
        boolean inputError = true;
        try {
            System.out.println("========================");
            System.out.println("Start parameter check...");
            System.out.println("========================");
            String training = args[0];
            System.out.println("-training: " + training);
            String outputDir = args[1];
            System.out.println("-output: " + outputDir);
            int M = Integer.parseInt(args[2]);
            System.out.println("-M: " + M);
            int Tout = Integer.parseInt(args[3]);
            System.out.println("-Tout: " + Tout);
            int Tin = Integer.parseInt(args[4]);
            System.out.println("-Tin: " + Tin);
            int N = Integer.parseInt(args[5]);
            System.out.println("-N: " + N);
            int K = Integer.parseInt(args[6]);
            System.out.println("-K: " + K);
            float lambda = Float.parseFloat(args[7]);
            System.out.println("-lambda: " + lambda);
            boolean useWeight = Integer.parseInt(args[8]) > 0;
            System.out.println("-useWeight: " + useWeight);
            int[] modeSizes = new int[N];
            for (int n = 0; n < N; n++) {
                modeSizes[n] = Integer.parseInt(args[9 + n]);
                System.out.println("-I" + (n + 1) + ": " + modeSizes[n]);
            }
            String testPath = null;
            if (args.length > 9 + N) {
                testPath = args[9 + N];
                System.out.println("-test: " + testPath);
            }
            String queryPath = null;
            if (args.length > 10 + N) {
                queryPath = args[10 + N];
                System.out.println("-query: " + queryPath);
            }
            inputError = false;

            int[] modesIdx = ArrayMethods.createSequnce(N);
            System.out.println("==============================");
            System.out.println("Start greedy row assignment...");
            System.out.println("==============================");
            int[][] permutedIdx = GreedyRowAssignment.run(training, N, modesIdx, modeSizes, M);
            Tensor testTensor = null;
            if (testPath != null) {
                testTensor = TensorMethods.importSparseTensor(testPath, ",", modeSizes, modesIdx, N, permutedIdx);
            }

            String name = useBias ? "Bias-CDTF" : "CDTF";
            System.out.println("=============");
            System.out.println("Start " + name + "...");
            System.out.println("=============");
            CDTF method = new CDTF();
            Tensor trainingTensor = TensorMethods.importSparseTensor(training, ",", modeSizes, modesIdx, N, permutedIdx);
            if (testPath != null) {
                method.setTest(testTensor);
            }
            double[][] result = method.run(trainingTensor, K, Tout, Tin, M, lambda, useWeight, useBias, true);

            System.out.println("=======================");
            System.out.println("Start writing output...");
            System.out.println("=======================");
            Output.writePerformance(outputDir, result, Tout);
            Output.writeFactorMatrices(outputDir, method.params, permutedIdx);
            if (useBias) {
                Output.writeBiases(outputDir, method.bias, permutedIdx);
                Output.writeMU(outputDir, method.mu);
            }
            if (queryPath != null) {
                Tensor queryTensor = TensorMethods.importSparseTensor(queryPath, ",", modeSizes, modesIdx, 0, permutedIdx);
                if (useBias) {
                    Output.calculateEstimate(queryTensor, method.mu, method.bias, method.params, N, K);
                } else {
                    Output.calculateEstimate(queryTensor, method.params, N, K);
                }
                Output.writeEstimate(outputDir, queryTensor, permutedIdx, N);
            }
            System.out.println("===========");
            System.out.println("Complete!!!");
            System.out.println("===========");
        } catch (Exception e) {
            if (inputError) {
                String fileName = useBias ? "run_single_bias_cdtf.sh" : "run_single_cdtf.sh";
                System.err.println("Usage: " + fileName + " [training] [output] [M] [Tout] [Tin] [N] [K] [lambda] [useWeight] [I1] [I2] ... [IN] [test] [query]");
                e.printStackTrace();
            }
            throw e;
        }
    }

    /** Sets the held-out tensor used to report test RMSE. */
    private void setTest(Tensor test) {
        this.test = test;
    }

    /**
     * Fits the factor matrices (and, optionally, the mean and biases) to {@code training}.
     *
     * <p>{@code training} itself is not modified; the method works on a copy that is kept equal to
     * the current residuals.
     *
     * @param training  observed entries
     * @param K         number of columns (rank) of every factor matrix
     * @param Tout      number of outer iterations
     * @param Tin       number of inner iterations per column
     * @param M         number of blocks/jobs used for parallel work
     * @param lambda    regularization strength
     * @param useWeight when true, regularization of a row is scaled by its non-zero count
     * @param useBias   whether to fit the global mean and per-mode biases
     * @param printLog  whether to print a CSV progress line per outer iteration
     * @return a {@code Tout x 4} array whose rows are
     *         {@code {iteration (1-based), elapsed ms, training RMSE, test RMSE (0 without a test set)}}
     */
    public double[][] run(Tensor training, int K, int Tout, int Tin, final int M, final float lambda, final boolean useWeight, boolean useBias, boolean printLog) {
        Random random = new Random();
        if (printLog) {
            System.out.println("iteration,elapsed_time,training_rmse,test_rmse");
        }
        // R is a private copy of the training data that is kept equal to the residuals
        // (observed value minus current estimate) for the whole run.
        final Tensor R = training.copy();
        final int nnzTraining = R.omega;
        boolean useTest = this.test != null;
        final int N = R.N;
        final int[] modeLength = R.modeLengths;
        final int[][] nnzFiber = TensorMethods.cardinality(R);
        // division[m][n] lists the entries whose mode-n row falls into row block m.
        final int[][][] division = partitionEntries(R, M);

        if (useBias) {
            this.mu = training.mu;
            subtractMean(R, M);
            this.bias = new float[N][];
            for (int n = 0; n < N; n++) {
                this.bias[n] = new float[modeLength[n]];
            }
        }

        this.params = new float[N][][];
        for (int n = 0; n < N; n++) {
            // NOTE(review): mode 0 is passed 0 instead of 1 — presumably this yields a zero matrix,
            // so the initial model contributes nothing and R starts as the exact residual. Confirm
            // against ArrayMethods.createUniformRandomMatrix.
            this.params[n] = ArrayMethods.createUniformRandomMatrix(modeLength[n], K, n != 0 ? 1 : 0, random);
        }

        // Scratch space holding the column currently being optimized for every mode.
        float[][] updatedColumn = new float[N][];
        for (int n = 0; n < N; n++) {
            updatedColumn[n] = new float[modeLength[n]];
        }

        double[][] result = new double[Tout][4];
        long start = System.currentTimeMillis();
        for (int outIter = 0; outIter < Tout; outIter++) {
            for (int k = 0; k < K; k++) {
                updateColumn(R, k, Tin, M, lambda, useWeight, division, nnzFiber, updatedColumn);
            }
            if (useBias) {
                updateAllBiases(R, M, lambda, useWeight, division, nnzFiber);
            }
            double trainingRMSE = Math.sqrt(sumOfSquaredResiduals(R, M) / (double) nnzTraining);
            double testRMSE = 0.0;
            if (useTest) {
                testRMSE = useBias
                        ? Performance.computeRMSE(this.test, this.params, this.bias, this.mu, N, K, M)
                        : Performance.computeRMSE(this.test, this.params, N, K, M);
            }
            long elapsedTime = System.currentTimeMillis() - start;
            if (printLog) {
                System.out.printf("%d,%d,%f,%f\n", outIter + 1, elapsedTime, trainingRMSE, testRMSE);
            }
            result[outIter] = new double[]{outIter + 1, elapsedTime, trainingRMSE, testRMSE};
        }
        return result;
    }

    /**
     * Groups the entries of {@code R} by row block, separately for every mode.
     *
     * @return {@code division[m][n]}: indices of the entries whose mode-{@code n} row lies in block
     *         {@code m}, in ascending entry order
     */
    private static int[][][] partitionEntries(Tensor R, int M) {
        int N = R.N;
        int[] modeLength = R.modeLengths;
        // First pass: count entries per (block, mode) to size the arrays exactly.
        int[][] count = new int[M][N];
        for (int elemIdx = 0; elemIdx < R.omega; elemIdx++) {
            for (int n = 0; n < N; n++) {
                count[rowBlock(R.indices[n][elemIdx], modeLength[n], M)][n]++;
            }
        }
        int[][][] division = new int[M][N][];
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < N; n++) {
                division[m][n] = new int[count[m][n]];
                count[m][n] = 0;
            }
        }
        // Second pass: fill, reusing count as the per-bucket write cursor.
        for (int elemIdx = 0; elemIdx < R.omega; elemIdx++) {
            for (int n = 0; n < N; n++) {
                int m = rowBlock(R.indices[n][elemIdx], modeLength[n], M);
                division[m][n][count[m][n]++] = elemIdx;
            }
        }
        return division;
    }

    /** Subtracts the global mean {@link #mu} from every stored value of {@code R}. */
    private void subtractMean(final Tensor R, final int M) {
        final int nnzTraining = R.omega;
        new MultiThread<Object>() {
            @Override
            public Object runJob(int m, int threadIndex) {
                int[] range = CDTF.blockIndex(nnzTraining, M, m);
                for (int elemIdx = range[0]; elemIdx <= range[1]; elemIdx++) {
                    R.values[elemIdx] = R.values[elemIdx] - CDTF.this.mu;
                }
                return null;
            }
        }.run(M, MultiThread.createJobList(M));
    }

    /**
     * Performs one rank-one coordinate-descent step for column {@code _k} of every factor matrix,
     * leaving {@code R} equal to the residuals of the updated model.
     */
    private void updateColumn(final Tensor R, final int _k, int Tin, final int M, final float lambda,
            final boolean useWeight, final int[][][] division, final int[][] nnzFiber,
            final float[][] updatedColumn) {
        final int N = R.N;
        final int[] modeLength = R.modeLengths;
        final int nnzTraining = R.omega;

        // 1. Copy column _k of every factor matrix into the scratch columns.
        new MultiThread<Object>() {
            @Override
            public Object runJob(int blockIndex, int threadIndex) {
                for (int n = 0; n < N; n++) {
                    int[] rows = CDTF.blockIndex(modeLength[n], M, blockIndex);
                    for (int row = rows[0]; row <= rows[1]; row++) {
                        updatedColumn[n][row] = CDTF.this.params[n][row][_k];
                    }
                }
                return null;
            }
        }.run(M, MultiThread.createJobList(M));

        // 2. Add the current rank-one component back into the residuals.
        new MultiThread<Object>() {
            @Override
            public Object runJob(int blockIndex, int threadIndex) {
                int[] range = CDTF.blockIndex(nnzTraining, M, blockIndex);
                CDTF.this.updateR(R, updatedColumn, range[0], range[1], true);
                return null;
            }
        }.run(M, MultiThread.createJobList(M));

        // 3. Re-solve the column of each mode in turn; each job owns a disjoint row block.
        for (int innerIter = 0; innerIter < Tin; innerIter++) {
            for (int mode = 0; mode < N; mode++) {
                final int _n = mode;
                new MultiThread<Object>() {
                    @Override
                    public Object runJob(int m, int threadIndex) {
                        int[] rows = CDTF.blockIndex(modeLength[_n], M, m);
                        CDTF.this.updateFactors(R, division[m][_n], _n, updatedColumn, lambda,
                                useWeight, rows[0], rows[1], nnzFiber[_n]);
                        return null;
                    }
                }.run(M, MultiThread.createJobList(M));
            }
        }

        // 4. Subtract the refreshed component and copy the column back into params.
        new MultiThread<Object>() {
            @Override
            public Object runJob(int b, int threadIndex) {
                int[] range = CDTF.blockIndex(nnzTraining, M, b);
                CDTF.this.updateR(R, updatedColumn, range[0], range[1], false);
                for (int n = 0; n < N; n++) {
                    int[] rows = CDTF.blockIndex(modeLength[n], M, b);
                    for (int idx = rows[0]; idx <= rows[1]; idx++) {
                        CDTF.this.params[n][idx][_k] = updatedColumn[n][idx];
                    }
                }
                return null;
            }
        }.run(M, MultiThread.createJobList(M));
    }

    /** Re-solves the bias vector of every mode in turn and folds each change into {@code R}. */
    private void updateAllBiases(final Tensor R, final int M, final float lambda, final boolean useWeight,
            final int[][][] division, final int[][] nnzFiber) {
        final int[] modeLength = R.modeLengths;
        final int nnzTraining = R.omega;
        for (int mode = 0; mode < R.N; mode++) {
            final int _n = mode;
            // Snapshot: the solve reads the old biases while new values are written in place.
            final float[] oldBias = this.bias[_n].clone();
            new MultiThread<Object>() {
                @Override
                public Object runJob(int m, int threadIndex) {
                    int[] rows = CDTF.blockIndex(modeLength[_n], M, m);
                    CDTF.this.updateBiases(R, division[m][_n], _n, oldBias, CDTF.this.bias[_n], lambda,
                            useWeight, rows[0], rows[1], nnzFiber[_n]);
                    return null;
                }
            }.run(M, MultiThread.createJobList(M));
            new MultiThread<Object>() {
                @Override
                public Object runJob(int b, int threadIndex) {
                    int[] range = CDTF.blockIndex(nnzTraining, M, b);
                    CDTF.this.updateR(R, _n, CDTF.this.bias[_n], oldBias, range[0], range[1]);
                    return null;
                }
            }.run(M, MultiThread.createJobList(M));
        }
    }

    /**
     * Returns the sum of squared values of {@code R}. The computation is done in double precision,
     * and the per-block partial sums are combined in a fixed order so the result is deterministic.
     */
    private double sumOfSquaredResiduals(final Tensor R, final int M) {
        final int nnzTraining = R.omega;
        // loss[b] receives the partial sum of job b; jobs write disjoint slots, so no lock is needed.
        final double[] loss = new double[M];
        new MultiThread<Object>() {
            @Override
            public Object runJob(int b, int threadIndex) {
                int[] range = CDTF.blockIndex(nnzTraining, M, b);
                double sum = 0.0;
                for (int elemIdx = range[0]; elemIdx <= range[1]; elemIdx++) {
                    double residual = R.values[elemIdx];
                    sum += residual * residual;
                }
                loss[b] = sum;
                return null;
            }
        }.run(M, MultiThread.createJobList(M));
        double total = 0.0;
        for (double partial : loss) {
            total += partial;
        }
        return total;
    }

    /**
     * Re-solves rows {@code firstRow..lastRow} of {@code updatedColumn[n]} with all other modes'
     * columns held fixed (regularized least squares, one row at a time).
     *
     * @param R             residual tensor with the current rank-one component added back in
     * @param Rindex        entries of {@code R} whose mode-{@code n} row lies in {@code [firstRow, lastRow]}
     * @param n             mode being updated
     * @param updatedColumn current column of every mode; row entries of mode {@code n} are overwritten
     * @param lambda        regularization strength
     * @param useWeight     scale regularization of each row by {@code nnzFiber[row]}
     * @param firstRow      first row (inclusive) handled by this call
     * @param lastRow       last row (inclusive) handled by this call
     * @param nnzFiber      per-row counts for mode {@code n}, used only when {@code useWeight}
     */
    public void updateFactors(Tensor R, int[] Rindex, int n, float[][] updatedColumn, float lambda, boolean useWeight, int firstRow, int lastRow, int[] nnzFiber) {
        int numberOfRows = lastRow - firstRow + 1;
        float[] numerators = new float[numberOfRows];
        float[] denominators = new float[numberOfRows];
        for (int elemIdx : Rindex) {
            // Product of the other modes' column entries at this tensor entry.
            float coefficient = 1.0f;
            for (int i = 0; i < updatedColumn.length; i++) {
                if (i != n) {
                    coefficient *= updatedColumn[i][R.indices[i][elemIdx]];
                }
            }
            int local = R.indices[n][elemIdx] - firstRow;
            numerators[local] += coefficient * R.values[elemIdx];
            denominators[local] += coefficient * coefficient;
        }
        solveRows(numerators, denominators, firstRow, lambda, useWeight, nnzFiber, updatedColumn[n]);
    }

    /**
     * Re-solves rows {@code firstRow..lastRow} of the mode-{@code n} bias vector.
     *
     * @param R           residual tensor (current biases already subtracted)
     * @param Rindex      entries of {@code R} whose mode-{@code n} row lies in {@code [firstRow, lastRow]}
     * @param n           mode being updated
     * @param oldBias     bias values before this update (added back into the residuals)
     * @param updatedBias destination for the new bias values
     * @param lambda      regularization strength
     * @param useWeight   scale regularization of each row by {@code nnzFiber[row]}
     * @param firstRow    first row (inclusive) handled by this call
     * @param lastRow     last row (inclusive) handled by this call
     * @param nnzFiber    per-row counts for mode {@code n}, used only when {@code useWeight}
     */
    public void updateBiases(Tensor R, int[] Rindex, int n, float[] oldBias, float[] updatedBias, float lambda, boolean useWeight, int firstRow, int lastRow, int[] nnzFiber) {
        int numberOfRows = lastRow - firstRow + 1;
        float[] numerators = new float[numberOfRows];
        float[] denominators = new float[numberOfRows];
        for (int elemIdx : Rindex) {
            int row = R.indices[n][elemIdx];
            numerators[row - firstRow] += R.values[elemIdx] + oldBias[row];
            denominators[row - firstRow] += 1.0f;
        }
        solveRows(numerators, denominators, firstRow, lambda, useWeight, nnzFiber, updatedBias);
    }

    /**
     * Writes {@code numerators[i] / (denominators[i] + regularization)} into {@code target[firstRow + i]}.
     * Rows whose regularized denominator is zero keep their previous value. Results below
     * {@link #EPSILON} in magnitude are stored as exactly zero.
     */
    private static void solveRows(float[] numerators, float[] denominators, int firstRow, float lambda,
            boolean useWeight, int[] nnzFiber, float[] target) {
        for (int i = 0; i < numerators.length; i++) {
            int row = firstRow + i;
            float denominator = denominators[i] + lambda * (float) (useWeight ? nnzFiber[row] : 1);
            if (denominator != 0.0f) {
                float value = numerators[i] / denominator;
                target[row] = Math.abs(value) < EPSILON ? 0.0f : value;
            }
        }
    }

    /**
     * Adds ({@code add == true}) or subtracts the rank-one component described by
     * {@code updatedColumn} to/from entries {@code startIdx..endIdx} of {@code R}.
     */
    private void updateR(Tensor R, float[][] updatedColumn, int startIdx, int endIdx, boolean add) {
        for (int idx = startIdx; idx <= endIdx; idx++) {
            float product = 1.0f;
            for (int n = 0; n < updatedColumn.length; n++) {
                product *= updatedColumn[n][R.indices[n][idx]];
            }
            R.values[idx] = add ? R.values[idx] + product : R.values[idx] - product;
        }
    }

    /** Replaces the old mode-{@code n} bias with the new one in entries {@code startIdx..endIdx} of {@code R}. */
    private void updateR(Tensor R, int n, float[] updatedBias, float[] oldBias, int startIdx, int endIdx) {
        for (int idx = startIdx; idx <= endIdx; idx++) {
            int rowIdx = R.indices[n][idx];
            R.values[idx] = R.values[idx] + oldBias[rowIdx] - updatedBias[rowIdx];
        }
    }

    /**
     * Returns the block (0..M-1) that owns {@code row} of a mode with {@code length} rows. This is
     * always consistent with {@link #blockIndex}.
     */
    private static int rowBlock(int row, int length, int M) {
        // Widen to long: row * M can overflow int for long modes combined with many blocks.
        return (int) Math.min((long) row * M / length, (long) (M - 1));
    }

    /**
     * Splits {@code [0, n)} into {@code m} contiguous, nearly equal blocks.
     *
     * @return {@code {first, last}} (both inclusive) of block {@code i}; {@code last == first - 1}
     *         denotes an empty block
     */
    private static int[] blockIndex(int n, int m, int i) {
        // Exact integer ceil(n * i / m), computed in long to avoid overflow and rounding.
        int first = (int) (((long) n * i + m - 1) / m);
        int last = (int) (((long) n * (i + 1) + m - 1) / m) - 1;
        return new int[]{first, last};
    }
}

