package iitb.CRF;

import cern.colt.function.DoubleDoubleFunction;
import cern.colt.function.DoubleFunction;
import cern.colt.matrix.impl.AbstractFormatter;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import riso.numerical.LBFGS;

/**
 * Maximum-likelihood trainer for a linear-chain CRF.
 *
 * <p>{@link #train} initialises the caller's weight vector to {@code params.initValue} and then
 * drives the L-BFGS optimiser ({@code riso.numerical.LBFGS}) with the negated regularised
 * log-likelihood and its gradient. It stops when L-BFGS reports it is done, when the optional
 * {@link Evaluator} returns {@code false}, or when more than {@code params.maxIters} iterations
 * have run.
 *
 * <p>Two forward-backward implementations are provided:
 * <ul>
 *   <li>{@code params.trainerType.equals("ll")}: everything is kept in log space
 *       ({@link #computeFunctionGradientLL});</li>
 *   <li>otherwise: values are kept in probability space and divided by per-position scale
 *       factors ({@link #computeFunctionGradient}).</li>
 * </ul>
 *
 * <p>An instance reuses mutable scratch buffers across calls without synchronisation, so it should
 * not be shared between threads.
 */
public class Trainer {
    /** Number of features reported by the feature generator. */
    int numF;
    /** Number of labels. */
    int numY;
    /** Gradient buffer handed to L-BFGS (negated before each call). */
    double[] gradLogli;
    /** Diagonal workspace required by L-BFGS. */
    double[] diag;
    /** Weight vector being optimised; owned by the caller and updated in place. */
    double[] lambda;
    /** Edge potentials of the current position, indexed {@code [yprev][y]}. */
    DenseDoubleMatrix2D Mi_YY;
    /** State potentials of the current position, indexed {@code [y]}. */
    DenseDoubleMatrix1D Ri_Y;
    /** Forward vector of the current position. */
    DenseDoubleMatrix1D alpha_Y;
    /** Forward vector being built for the next position. */
    DenseDoubleMatrix1D newAlpha_Y;
    /** Backward vectors, one per position; grown on demand to twice the longest length seen. */
    DenseDoubleMatrix1D[] beta_Y;
    /** Scratch vector. */
    DenseDoubleMatrix1D tmp_Y;
    /** Per-feature expected-value accumulator for the current sequence. */
    double[] ExpF;
    /** Per-position scale factors used by the probability-space forward-backward. */
    double[] scale;
    /** Not referenced by the code in this class; presumably kept for subclasses. */
    double[] rLogScale;
    DataIter diter;
    FeatureGenerator featureGenerator;
    CrfParams params;
    EdgeGenerator edgeGen;
    /** Number of completed L-BFGS iterations in the current {@link #doTrain} run. */
    int icall;
    MultFunc multFunc = new MultFunc(this);
    SumFunc sumFunc = new SumFunc(this);
    MultSingle constMultiplier = new MultSingle(this);
    Evaluator evaluator = null;

    /** Element-wise product {@code a * b}, used with Colt's {@code assign(other, function)}. */
    public class MultFunc implements DoubleDoubleFunction {

        /** @param trainer unused; kept so the constructor signature is unchanged */
        MultFunc(Trainer trainer) {
        }

        @Override
        public double apply(double d, double d2) {
            return d * d2;
        }
    }

    /** Multiplies each element by the mutable factor {@link #multiplicator}. */
    public class MultSingle implements DoubleFunction {
        /** Factor applied by {@link #apply}; callers set it before each use. */
        public double multiplicator = 1.0d;

        /** @param trainer unused; kept so the constructor signature is unchanged */
        MultSingle(Trainer trainer) {
        }

        @Override
        public double apply(double d) {
            return d * this.multiplicator;
        }
    }

    /** Element-wise sum {@code a + b}, used with Colt's {@code assign(other, function)}. */
    public class SumFunc implements DoubleDoubleFunction {

        /** @param trainer unused; kept so the constructor signature is unchanged */
        SumFunc(Trainer trainer) {
        }

        @Override
        public double apply(double d, double d2) {
            return d + d2;
        }
    }

    /**
     * Returns the Euclidean (L2) norm of {@code dArr}.
     *
     * @param dArr the vector
     * @return {@code sqrt(sum(dArr[i]^2))}; 0 for an empty array
     */
    public double norm(double[] dArr) {
        double sumOfSquares = 0.0d;
        for (double v : dArr) {
            sumOfSquares += v * v;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Creates a trainer.
     *
     * @param crfParams training parameters (prior strength, L-BFGS settings, debug level, ...)
     */
    public Trainer(CrfParams crfParams) {
        this.params = crfParams;
    }

    /**
     * Trains {@code crf} on {@code dataIter}, writing the learned weights into {@code dArr}.
     *
     * @param crf model supplying labels, feature generator and edge generator
     * @param dataIter training sequences
     * @param dArr weight vector; overwritten with the initial value, then optimised in place
     * @param evaluator optional hook called once per iteration; training stops when it returns
     *     {@code false} (may be {@code null})
     */
    public void train(CRF crf, DataIter dataIter, double[] dArr, Evaluator evaluator) {
        init(crf, dataIter, dArr);
        this.evaluator = evaluator;
        if (this.params.debugLvl > 0) {
            Util.printDbg("Number of features :" + this.lambda.length);
        }
        doTrain();
    }

    /** Returns the value every weight is initialised to; subclasses may override. */
    double getInitValue() {
        return this.params.initValue;
    }

    /**
     * Binds the model, data and weight vector, and allocates the scratch buffers.
     *
     * <p>NOTE(review): {@code ExpF} is sized by {@code dArr.length} while {@code gradLogli} and
     * {@code diag} use {@code numF}; the code appears to assume they are equal.
     */
    public void init(CRF crf, DataIter dataIter, double[] dArr) {
        this.edgeGen = crf.edgeGen;
        this.lambda = dArr;
        this.numY = crf.numY;
        this.diter = dataIter;
        this.featureGenerator = crf.featureGenerator;
        this.numF = this.featureGenerator.numFeatures();
        this.gradLogli = new double[this.numF];
        this.diag = new double[this.numF];
        this.Mi_YY = new DenseDoubleMatrix2D(this.numY, this.numY);
        this.Ri_Y = new DenseDoubleMatrix1D(this.numY);
        this.alpha_Y = new DenseDoubleMatrix1D(this.numY);
        this.newAlpha_Y = new DenseDoubleMatrix1D(this.numY);
        this.tmp_Y = new DenseDoubleMatrix1D(this.numY);
        this.ExpF = new double[this.lambda.length];
    }

    /**
     * Runs the L-BFGS loop: it initialises the weights, then repeatedly evaluates the objective and
     * takes one optimiser step.
     *
     * <p>The loop stops when L-BFGS sets its flag to 0, when the evaluator returns {@code false},
     * when L-BFGS throws, or once {@code icall} exceeds {@code params.maxIters}.
     */
    void doTrain() {
        this.icall = 0;
        // Verbosity settings for L-BFGS's own progress output.
        int[] iprint = {this.params.debugLvl - 2, this.params.debugLvl - 1};
        int[] iflag = {0};
        double xtol = 1.0e-16d; // machine-precision estimate expected by LBFGS
        for (int i = 0; i < this.lambda.length; i++) {
            this.lambda[i] = getInitValue();
        }
        do {
            // L-BFGS minimises, so it receives the negated log-likelihood and gradient.
            double objective = -computeFunctionGradient(this.lambda, this.gradLogli);
            for (int i = 0; i < this.lambda.length; i++) {
                this.gradLogli[i] = -this.gradLogli[i];
            }
            if (this.evaluator != null && !this.evaluator.evaluate()) {
                return;
            }
            try {
                LBFGS.lbfgs(this.numF, this.params.mForHessian, this.lambda, objective, this.gradLogli,
                        false, this.diag, iprint, this.params.epsForConvergence, xtol, iflag);
            } catch (LBFGS.ExceptionWithIflag e) {
                System.err.println("CRF: lbfgs failed.\n" + e);
                return;
            }
            this.icall++;
        } while (iflag[0] != 0 && this.icall <= this.params.maxIters);
    }

    /**
     * Computes the regularised log-likelihood of the training data and fills in its gradient.
     *
     * <p>When {@code params.trainerType} is {@code "ll"}, this delegates to
     * {@link #computeFunctionGradientLL}. Otherwise it runs forward-backward in probability space.
     * Each position's vectors are divided by a scale factor, and the logs of those factors are
     * subtracted from the result. The factors are derived from the backward vectors when
     * {@code params.doScaling} is set, and are 1 otherwise.
     *
     * <p>Empty sequences are skipped: they contribute nothing, and processing them would index
     * position -1.
     *
     * @param weights feature weights at which to evaluate
     * @param grad output; overwritten with the prior gradient plus, for every sequence, observed
     *     minus expected feature values
     * @return the sum of per-sequence log-likelihoods plus the Gaussian log-prior
     *     {@code -sum(w^2 * invSigmaSquare) / 2}
     */
    protected double computeFunctionGradient(double[] weights, double[] grad) {
        if (this.params.trainerType.equals("ll")) {
            return computeFunctionGradientLL(weights, grad);
        }
        double logli = initWithGaussianPrior(weights, grad);
        boolean doScaling = this.params.doScaling;
        this.diter.startScan();
        int seqNo = 0;
        while (this.diter.hasNext()) {
            DataSequence seq = this.diter.next();
            if (this.params.debugLvl > 1) {
                Util.printDbg("Read next seq: " + seqNo + " logli " + logli);
            }
            if (seq.length() > 0) {
                double seqLogli = accumulateSequenceScaled(seq, weights, grad, doScaling);
                logli += seqLogli;
                if (this.params.debugLvl > 1) {
                    System.out.println("Sequence " + seqLogli + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR + logli);
                }
            }
            seqNo++;
        }
        printIterationSummary("Iter ", weights, grad, logli);
        return logli;
    }

    /**
     * Runs scaled forward-backward over one non-empty sequence.
     *
     * <p>It adds observed minus expected feature values to {@code grad} and returns the sequence's
     * log-likelihood.
     */
    private double accumulateSequenceScaled(DataSequence seq, double[] weights, double[] grad, boolean doScaling) {
        int len = seq.length();
        this.alpha_Y.assign(1.0d);
        for (int f = 0; f < weights.length; f++) {
            this.ExpF[f] = 0.0d;
        }
        ensureBetaCapacity(len);
        // Sized independently of beta_Y: the log-space path may have grown beta_Y without it.
        if (this.scale == null || this.scale.length < len) {
            this.scale = new double[2 * len];
        }

        // Backward pass; each beta vector is divided by its scale factor.
        this.scale[len - 1] = doScaling ? this.numY : 1.0d;
        this.beta_Y[len - 1].assign(1.0d / this.scale[len - 1]);
        for (int pos = len - 1; pos > 0; pos--) {
            if (this.params.debugLvl > 2) {
                Util.printDbg("Features fired");
            }
            computeLogMi(this.featureGenerator, weights, seq, pos, this.Mi_YY, this.Ri_Y, true);
            this.tmp_Y.assign(this.beta_Y[pos]);
            this.tmp_Y.assign(this.Ri_Y, this.multFunc);
            RobustMath.Mult(this.Mi_YY, this.tmp_Y, this.beta_Y[pos - 1], 1.0d, 0.0d, false, this.edgeGen);
            this.scale[pos - 1] = doScaling ? this.beta_Y[pos - 1].zSum() : 1.0d;
            // Factors of magnitude below 1 are replaced by 1 (no rescaling at this position).
            if (this.scale[pos - 1] < 1.0d && this.scale[pos - 1] > -1.0d) {
                this.scale[pos - 1] = 1.0d;
            }
            this.constMultiplier.multiplicator = 1.0d / this.scale[pos - 1];
            this.beta_Y[pos - 1].assign(this.constMultiplier);
        }

        // Forward pass, accumulating observed and expected feature values.
        double observed = 0.0d;
        for (int pos = 0; pos < len; pos++) {
            computeLogMi(this.featureGenerator, weights, seq, pos, this.Mi_YY, this.Ri_Y, true);
            this.featureGenerator.startScanFeaturesAt(seq, pos);
            this.tmp_Y.assign(this.alpha_Y);
            RobustMath.Mult(this.Mi_YY, this.tmp_Y, this.newAlpha_Y, 1.0d, 0.0d, true, this.edgeGen);
            this.newAlpha_Y.assign(this.Ri_Y, this.multFunc);
            while (this.featureGenerator.hasNext()) {
                Feature feature = this.featureGenerator.next();
                int index = feature.index();
                int y = feature.y();
                int yprev = feature.yprev();
                float value = feature.value();
                if (firesOnGoldPath(seq, pos, y, yprev)) {
                    grad[index] += value;
                    observed += value * weights[index];
                }
                if (yprev < 0) {
                    this.ExpF[index] += this.newAlpha_Y.get(y) * value * this.beta_Y[pos].get(y);
                } else {
                    this.ExpF[index] += this.alpha_Y.get(yprev) * this.Ri_Y.get(y) * this.Mi_YY.get(yprev, y) * value * this.beta_Y[pos].get(y);
                }
            }
            this.alpha_Y.assign(this.newAlpha_Y);
            this.constMultiplier.multiplicator = 1.0d / this.scale[pos];
            this.alpha_Y.assign(this.constMultiplier);
            if (this.params.debugLvl > 2) {
                printPotentials(pos);
            }
        }

        double zSum = this.alpha_Y.zSum();
        if (zSum == 0.0d) {
            // Tiny positive stand-in so log(zSum) is defined and the division below is not by zero.
            zSum = 4.94065646E-316d;
        }
        double seqLogli = observed - log(zSum);
        for (int pos = 0; pos < len; pos++) {
            seqLogli -= log(this.scale[pos]);
        }
        for (int f = 0; f < grad.length; f++) {
            grad[f] -= this.ExpF[f] / zSum;
        }
        return seqLogli;
    }

    /**
     * Accumulates weighted feature values from the generator's current scan into the potentials.
     *
     * <p>Features with {@code yprev < 0} are added to {@code Ri[y]}; all others are added to
     * {@code Mi[yprev][y]}. Both outputs are zeroed first. When {@code z} is true, both are then
     * exponentiated in place. That step indexes {@code Ri} by {@code Mi}'s row numbers, so it
     * assumes {@code Ri} has {@code Mi.rows()} entries.
     *
     * @param featureGenerator generator already positioned by {@code startScanFeaturesAt}
     * @param dArr feature weights
     * @param denseDoubleMatrix2D output edge potentials {@code Mi}
     * @param denseDoubleMatrix1D output state potentials {@code Ri}
     * @param z whether to exponentiate the accumulated log-potentials
     */
    public static void computeLogMi(FeatureGenerator featureGenerator, double[] dArr, DenseDoubleMatrix2D denseDoubleMatrix2D, DenseDoubleMatrix1D denseDoubleMatrix1D, boolean z) {
        denseDoubleMatrix2D.assign(0.0d);
        denseDoubleMatrix1D.assign(0.0d);
        while (featureGenerator.hasNext()) {
            Feature next = featureGenerator.next();
            int index = next.index();
            int y = next.y();
            int yprev = next.yprev();
            float value = next.value();
            if (yprev < 0) {
                denseDoubleMatrix1D.set(y, denseDoubleMatrix1D.get(y) + (dArr[index] * value));
            } else {
                denseDoubleMatrix2D.set(yprev, y, denseDoubleMatrix2D.get(yprev, y) + (dArr[index] * value));
            }
        }
        if (z) {
            for (int i = 0; i < denseDoubleMatrix2D.rows(); i++) {
                denseDoubleMatrix1D.set(i, exp(denseDoubleMatrix1D.get(i)));
                for (int i2 = 0; i2 < denseDoubleMatrix2D.columns(); i2++) {
                    denseDoubleMatrix2D.set(i, i2, exp(denseDoubleMatrix2D.get(i, i2)));
                }
            }
        }
    }

    /**
     * Positions {@code featureGenerator} at position {@code i} of {@code dataSequence}.
     *
     * <p>It then delegates to the 5-argument {@link #computeLogMi}.
     */
    public static void computeLogMi(FeatureGenerator featureGenerator, double[] dArr, DataSequence dataSequence, int i, DenseDoubleMatrix2D denseDoubleMatrix2D, DenseDoubleMatrix1D denseDoubleMatrix1D, boolean z) {
        featureGenerator.startScanFeaturesAt(dataSequence, i);
        computeLogMi(featureGenerator, dArr, denseDoubleMatrix2D, denseDoubleMatrix1D, z);
    }

    /**
     * Log-space variant of {@link #computeFunctionGradient}.
     *
     * <p>Forward, backward and expected-feature accumulators are all kept as logarithms and
     * combined with {@code RobustMath.logSumExp}. Empty sequences are skipped.
     *
     * @param weights feature weights at which to evaluate
     * @param grad output; overwritten with the prior gradient plus, for every sequence, observed
     *     minus expected feature values
     * @return the sum of per-sequence log-likelihoods plus the Gaussian log-prior
     */
    protected double computeFunctionGradientLL(double[] weights, double[] grad) {
        double logli = initWithGaussianPrior(weights, grad);
        this.diter.startScan();
        int seqNo = 0;
        while (this.diter.hasNext()) {
            DataSequence seq = this.diter.next();
            if (this.params.debugLvl > 1) {
                Util.printDbg("Read next seq: " + seqNo + " logli " + logli);
            }
            if (seq.length() > 0) {
                double seqLogli = accumulateSequenceLog(seq, weights, grad);
                logli += seqLogli;
                if (this.params.debugLvl > 1) {
                    System.out.println("Sequence " + seqLogli + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR + logli);
                }
            }
            seqNo++;
        }
        printIterationSummary("Iteration ", weights, grad, logli);
        return logli;
    }

    /**
     * Runs log-space forward-backward over one non-empty sequence.
     *
     * <p>It adds observed minus expected feature values to {@code grad} and returns the sequence's
     * log-likelihood.
     *
     * <p>NOTE(review): {@code Math.log(value)} is NaN for negative feature values, so this path
     * presumably assumes non-negative features.
     */
    private double accumulateSequenceLog(DataSequence seq, double[] weights, double[] grad) {
        int len = seq.length();
        this.alpha_Y.assign(0.0d);
        for (int f = 0; f < weights.length; f++) {
            this.ExpF[f] = RobustMath.LOG0;
        }
        ensureBetaCapacity(len);

        // Backward pass in log space.
        this.beta_Y[len - 1].assign(0.0d);
        for (int pos = len - 1; pos > 0; pos--) {
            if (this.params.debugLvl > 2) {
                Util.printDbg("Features fired");
                this.featureGenerator.startScanFeaturesAt(seq, pos);
                while (this.featureGenerator.hasNext()) {
                    Util.printDbg(this.featureGenerator.next().toString());
                }
            }
            computeLogMi(this.featureGenerator, weights, seq, pos, this.Mi_YY, this.Ri_Y, false);
            this.tmp_Y.assign(this.beta_Y[pos]);
            this.tmp_Y.assign(this.Ri_Y, this.sumFunc);
            RobustMath.logMult(this.Mi_YY, this.tmp_Y, this.beta_Y[pos - 1], 1.0d, 0.0d, false);
        }

        // Forward pass, accumulating observed and (log) expected feature values.
        double observed = 0.0d;
        for (int pos = 0; pos < len; pos++) {
            computeLogMi(this.featureGenerator, weights, seq, pos, this.Mi_YY, this.Ri_Y, false);
            this.featureGenerator.startScanFeaturesAt(seq, pos);
            this.tmp_Y.assign(this.alpha_Y);
            RobustMath.logMult(this.Mi_YY, this.tmp_Y, this.newAlpha_Y, 1.0d, 0.0d, true);
            this.newAlpha_Y.assign(this.Ri_Y, this.sumFunc);
            while (this.featureGenerator.hasNext()) {
                Feature feature = this.featureGenerator.next();
                int index = feature.index();
                int y = feature.y();
                int yprev = feature.yprev();
                float value = feature.value();
                if (firesOnGoldPath(seq, pos, y, yprev)) {
                    grad[index] += value;
                    observed += value * weights[index];
                }
                if (yprev < 0) {
                    this.ExpF[index] = RobustMath.logSumExp(this.ExpF[index], this.newAlpha_Y.get(y) + Math.log(value) + this.beta_Y[pos].get(y));
                } else {
                    this.ExpF[index] = RobustMath.logSumExp(this.ExpF[index], this.alpha_Y.get(yprev) + this.Ri_Y.get(y) + this.Mi_YY.get(yprev, y) + Math.log(value) + this.beta_Y[pos].get(y));
                }
            }
            this.alpha_Y.assign(this.newAlpha_Y);
            if (this.params.debugLvl > 2) {
                printPotentials(pos);
            }
        }

        double logZ = RobustMath.logSumExp(this.alpha_Y);
        for (int f = 0; f < grad.length; f++) {
            grad[f] -= Math.exp(this.ExpF[f] - logZ);
        }
        return observed - logZ;
    }

    /**
     * Overwrites {@code grad[0..weights.length)} with the Gaussian-prior gradient
     * {@code -w[i] * invSigmaSquare} and returns the log-prior {@code -sum(w^2 * invSigmaSquare) / 2}.
     *
     * <p>Any exception (for example, {@code grad} shorter than {@code weights}) propagates to the
     * caller. Earlier code caught it and called {@code System.exit(0)} instead.
     */
    private double initWithGaussianPrior(double[] weights, double[] grad) {
        double logPrior = 0.0d;
        for (int i = 0; i < weights.length; i++) {
            grad[i] = (-1.0d) * weights[i] * this.params.invSigmaSquare;
            logPrior -= ((weights[i] * weights[i]) * this.params.invSigmaSquare) / 2.0d;
        }
        return logPrior;
    }

    /**
     * Grows {@code beta_Y} to {@code 2 * length} fresh vectors when it is missing or shorter than
     * {@code length}.
     */
    private void ensureBetaCapacity(int length) {
        if (this.beta_Y == null || this.beta_Y.length < length) {
            this.beta_Y = new DenseDoubleMatrix1D[2 * length];
            for (int i = 0; i < this.beta_Y.length; i++) {
                this.beta_Y[i] = new DenseDoubleMatrix1D(this.numY);
            }
        }
    }

    /**
     * Returns true when a feature with labels {@code (yprev, y)} matches the gold labels of
     * {@code seq} at {@code pos}.
     *
     * <p>State features ({@code yprev < 0}) only need {@code y} to match. Edge features also need
     * {@code yprev} to match the previous gold label, and never match at position 0.
     */
    private static boolean firesOnGoldPath(DataSequence seq, int pos, int y, int yprev) {
        return seq.y(pos) == y && ((pos - 1 >= 0 && yprev == seq.y(pos - 1)) || yprev < 0);
    }

    /** Debug dump of the forward, state, edge and backward vectors at {@code pos}. */
    private void printPotentials(int pos) {
        System.out.println("Alpha-i " + this.alpha_Y.toString());
        System.out.println("Ri " + this.Ri_Y.toString());
        System.out.println("Mi " + this.Mi_YY.toString());
        System.out.println("Beta-i " + this.beta_Y[pos].toString());
    }

    /**
     * Prints the end-of-evaluation debug output.
     *
     * <p>At debug level above 2 it prints every weight and gradient entry. At debug level above 0
     * it prints a one-line summary prefixed by {@code label}.
     */
    private void printIterationSummary(String label, double[] weights, double[] grad, double logli) {
        if (this.params.debugLvl > 2) {
            for (double w : weights) {
                System.out.print(String.valueOf(w) + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR);
            }
            System.out.println(" :x");
            for (int i = 0; i < weights.length; i++) {
                System.out.print(String.valueOf(grad[i]) + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR);
            }
            System.out.println(" :g");
        }
        if (this.params.debugLvl > 0) {
            Util.printDbg(label + this.icall + " log likelihood " + logli + " norm(grad logli) " + norm(grad) + " norm(x) " + norm(weights));
        }
    }

    /**
     * Natural log that never throws.
     *
     * <p>When the result would be NaN or infinite ({@code d <= 0}, NaN or +infinity), it prints the
     * error and returns {@code -Double.MAX_VALUE}.
     */
    public static double log(double d) {
        try {
            return logE(d);
        } catch (Exception e) {
            System.out.println(e.getMessage());
            e.printStackTrace();
            return -Double.MAX_VALUE;
        }
    }

    /**
     * {@code Math.exp} that never throws.
     *
     * <p>When the result would be NaN or infinite, it prints the error and returns
     * {@code Double.MAX_VALUE}.
     */
    static double exp(double d) {
        try {
            return expE(d);
        } catch (Exception e) {
            System.out.println(e.getMessage());
            e.printStackTrace();
            return Double.MAX_VALUE;
        }
    }

    /**
     * Natural log that throws when the result is NaN or infinite.
     *
     * @throws Exception if {@code Math.log(d)} is NaN or infinite
     */
    static double logE(double d) throws Exception {
        double result = Math.log(d);
        if (Double.isNaN(result) || Double.isInfinite(result)) {
            throw new Exception("Overflow error when taking log of " + d);
        }
        return result;
    }

    /**
     * {@code Math.exp} that throws when the result is NaN or infinite.
     *
     * @throws Exception if {@code Math.exp(d)} is NaN or infinite
     */
    static double expE(double d) throws Exception {
        double result = Math.exp(d);
        if (Double.isNaN(result) || Double.isInfinite(result)) {
            throw new Exception("Overflow error when taking exp of " + d);
        }
        return result;
    }
}
