package iitb.CRF;

import cern.colt.function.DoubleDoubleFunction;
import cern.colt.function.IntIntDoubleFunction;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import java.util.TreeSet;
import org.apache.log4j.Priority;

/* loaded from: input_file:iitb/CRF/RobustMath.class */
/**
 * Numerically robust helpers for arithmetic on values stored in the log domain.
 *
 * <p>{@link #LOG0} is used as a stand-in for log(0) and {@link #LOG2} for log(2). Sums of
 * log-domain values go through {@link #logSumExp(double, double)}, which ignores the smaller
 * term once it is more than {@link #MINUS_LOG_EPSILON} below the larger one.
 */
public class RobustMath {
    /** Log-domain gap beyond which the smaller term of a log-sum is dropped (e^-30 is ~1e-13). */
    static final double MINUS_LOG_EPSILON = 30.0d;
    /** Stand-in for log(0): the most negative finite double. */
    public static double LOG0 = -Double.MAX_VALUE;
    /** Approximation of log(2); adding it to v gives log(e^v + e^v). */
    public static double LOG2 = 0.69314718055d;
    /** {@link DoubleDoubleFunction} adapter around {@link #logSumExp(double, double)}. */
    public static LogSumExp logSumExpFunc = new LogSumExp();
    /**
     * Shared visitor used by the non-edge-generator {@code logMult} overload. Because it is a
     * single mutable static instance, that overload is not safe for concurrent use.
     */
    static LogMult logMult = new LogMult();

    /**
     * Memoizes {@code log(1 + exp(-d))} for {@code d} in {@code [0, MINUS_LOG_EPSILON]}.
     *
     * <p>Arguments below {@code CUT_OFF} are bucketed with resolution {@code 1/NUM_FINE}, larger
     * ones with resolution {@code 1/NUM_COARSE}. A bucket keeps the value computed for the first
     * argument that landed in it, so a result is an approximation whose last digits can depend
     * on call history. The table is read and written without synchronization.
     */
    public static class LogExpCache {
        static int CUT_OFF = 6;
        static int NUM_FINE = 10000;
        static int NUM_COARSE = 1000;
        /** When false, every lookup is computed directly and the table is bypassed. */
        static boolean useCache = true;
        /** Cached values; -1 marks an empty slot (cached values are never negative). */
        static double[] vals =
                new double[CUT_OFF * NUM_FINE + ((int) MINUS_LOG_EPSILON - CUT_OFF) * NUM_COARSE + 1];

        static {
            for (int i = 0; i < vals.length; i++) {
                vals[i] = -1.0d;
            }
        }

        LogExpCache() {
        }

        /**
         * Returns (an approximation of) {@code log(1 + exp(-d))}.
         *
         * @param d the non-negative gap between two log-domain values
         * @return the amount to add to the larger value to obtain their log-sum
         */
        static double lookupAdd(double d) {
            // Outside the tabulated range (including NaN) there is no valid slot: compute
            // directly rather than index out of bounds or store a NaN in the table.
            if (!useCache || !(d >= 0.0d && d <= MINUS_LOG_EPSILON)) {
                return Math.log(Math.exp(-d) + 1.0d);
            }
            int index = d < CUT_OFF
                    ? (int) Math.rint(d * NUM_FINE)
                    : NUM_FINE * CUT_OFF + (int) Math.rint((d - CUT_OFF) * NUM_COARSE);
            if (vals[index] < 0.0d) {
                vals[index] = Math.log(Math.exp(-d) + 1.0d);
            }
            return vals[index];
        }
    }

    /**
     * Visitor for {@link DoubleMatrix2D#forEachNonZero}: for each visited cell (i, j) it
     * log-adds {@code M[i][j] + y[src] + lalpha} into {@code z[dst]}, where (dst, src) is
     * (i, j), or (j, i) when {@code transposeA} is set. The matrix cell is returned unchanged.
     */
    static class LogMult implements IntIntDoubleFunction {
        DoubleMatrix2D M;
        DoubleMatrix1D z;
        double lalpha;
        boolean transposeA;
        DoubleMatrix1D y;

        LogMult() {
        }

        @Override // cern.colt.function.IntIntDoubleFunction
        public double apply(int row, int col, double value) {
            int dst = transposeA ? col : row;
            int src = transposeA ? row : col;
            z.set(dst, RobustMath.logSumExp(z.get(dst), M.get(row, col) + y.get(src) + lalpha));
            // Returning the incoming value leaves the matrix untouched.
            return value;
        }
    }

    /** {@link DoubleDoubleFunction} that combines two log-domain values with {@link RobustMath#logSumExp(double, double)}. */
    static class LogSumExp implements DoubleDoubleFunction {
        LogSumExp() {
        }

        @Override // cern.colt.function.DoubleDoubleFunction
        public double apply(double d, double d2) {
            return RobustMath.logSumExp(d, d2);
        }
    }

    /**
     * Returns {@code log(exp(d) + exp(d2))} without leaving the log domain.
     *
     * <p>If the two values differ by more than {@link #MINUS_LOG_EPSILON}, the larger one is
     * returned as is. A NaN argument yields NaN.
     */
    public static double logSumExp(double d, double d2) {
        if (d == d2) {
            // log(e^v + e^v) = v + log 2. Testing equality (rather than |d - d2| < eps) also makes
            // two equal infinities return that infinity instead of producing a NaN difference.
            return d + LOG2;
        }
        double min = Math.min(d, d2);
        double max = Math.max(d, d2);
        double gap = max - min;
        // Negated comparison so that a NaN gap also returns here and never reaches the cache.
        if (!(gap <= MINUS_LOG_EPSILON)) {
            return max;
        }
        return max + LogExpCache.lookupAdd(gap);
    }

    /**
     * Inserts the log-domain value {@code d} into {@code treeSet}, merging duplicates.
     *
     * <p>A set cannot hold a value twice, so when {@code d} is already present the stored copy
     * is removed and {@code d + LOG2} (the log of twice the value) is inserted instead,
     * repeating as long as collisions occur.
     */
    public static void addNoDups(TreeSet<Double> treeSet, double d) {
        double value = d;
        while (!treeSet.add(value)) {
            treeSet.remove(value);
            value += LOG2;
        }
    }

    /**
     * Returns the log-sum of all log-domain values in {@code treeSet}, consuming the set.
     *
     * <p>The two smallest values are repeatedly combined first, which limits loss of precision.
     * On return the set holds at most the single result.
     *
     * @return the log-sum, or {@link #LOG0} if the set is empty
     */
    public static double logSumExp(TreeSet<Double> treeSet) {
        while (treeSet.size() > 1) {
            double smallest = treeSet.pollFirst();
            double nextSmallest = treeSet.pollFirst();
            addNoDups(treeSet, logSumExp(smallest, nextSmallest));
        }
        return treeSet.isEmpty() ? LOG0 : treeSet.first();
    }

    /**
     * Returns the log-sum of the entries of {@code doubleMatrix1D}, skipping entries equal to
     * {@link #LOG0}; returns {@link #LOG0} if every entry is skipped.
     */
    public static double logSumExp(DoubleMatrix1D doubleMatrix1D) {
        TreeSet<Double> logValues = new TreeSet<Double>();
        for (int i = 0; i < doubleMatrix1D.size(); i++) {
            double value = doubleMatrix1D.getQuick(i);
            if (value != LOG0) {
                addNoDups(logValues, value);
            }
        }
        return logSumExp(logValues);
    }

    /** Element-wise, in place: {@code z[i] = logSumExp(z[i], other[i])}. */
    public static void logSumExp(DoubleMatrix1D z, DoubleMatrix1D other) {
        for (int i = 0; i < z.size(); i++) {
            z.set(i, logSumExp(z.get(i), other.get(i)));
        }
    }

    /**
     * Log-domain analogue of {@code z = alpha * A * y + beta * z}.
     *
     * <p>{@code matrix}, {@code y} and {@code z} hold log-domain values; {@code alpha} and
     * {@code beta} are linear-domain factors whose logs are applied. Only the non-zero cells
     * of {@code matrix} are visited (NOTE(review): a stored log-value of exactly 0 is therefore
     * skipped; confirm this is intended). Not safe for concurrent use (shared static visitor).
     *
     * @param transposeA if true, use the transpose of {@code matrix}
     * @return {@code z}, updated in place
     */
    public static DoubleMatrix1D logMult(DoubleMatrix2D matrix, DoubleMatrix1D y, DoubleMatrix1D z,
            double alpha, double beta, boolean transposeA) {
        double lalpha = alpha != 1.0d ? Math.log(alpha) : 0.0d;
        scaleLogDomain(z, beta);
        logMult.M = matrix;
        logMult.z = z;
        logMult.lalpha = lalpha;
        logMult.transposeA = transposeA;
        logMult.y = y;
        matrix.forEachNonZero(logMult);
        return z;
    }

    /**
     * Log-domain analogue of {@code z = alpha * A * y + beta * z}, visiting only the cells
     * enumerated by {@code edgeGen}.
     *
     * <p>For each column {@code c}, rows are taken from {@code edgeGen.first(c)},
     * {@code edgeGen.next(c, r)}, ... until a row index reaches {@code matrix.rows()}.
     *
     * @param transposeA if true, use the transpose of {@code matrix}
     * @return {@code z}, updated in place
     */
    public static DoubleMatrix1D logMult(DoubleMatrix2D matrix, DoubleMatrix1D y, DoubleMatrix1D z,
            double alpha, double beta, boolean transposeA, EdgeGenerator edgeGen) {
        double lalpha = alpha != 1.0d ? Math.log(alpha) : 0.0d;
        scaleLogDomain(z, beta);
        int rows = matrix.rows();
        for (int col = 0; col < matrix.columns(); col++) {
            // A row index >= rows marks the end of this column's edges.
            for (int row = edgeGen.first(col); row < rows; row = edgeGen.next(col, row)) {
                int dst = transposeA ? col : row;
                int src = transposeA ? row : col;
                z.setQuick(dst, logSumExp(z.getQuick(dst), matrix.getQuick(row, col) + y.get(src) + lalpha));
            }
        }
        return z;
    }

    /**
     * Linear-domain {@code z = alpha * A * y + beta * z}, visiting only the cells enumerated by
     * {@code edgeGen} (rows from {@code first(c)}/{@code next(c, r)} until {@code >= matrix.rows()}).
     *
     * @param transposeA if true, use the transpose of {@code matrix}
     * @return {@code z}, updated in place
     */
    public static DoubleMatrix1D Mult(DoubleMatrix2D matrix, DoubleMatrix1D y, DoubleMatrix1D z,
            double alpha, double beta, boolean transposeA, EdgeGenerator edgeGen) {
        for (int i = 0; i < z.size(); i++) {
            z.set(i, z.get(i) * beta);
        }
        int rows = matrix.rows();
        for (int col = 0; col < matrix.columns(); col++) {
            for (int row = edgeGen.first(col); row < rows; row = edgeGen.next(col, row)) {
                int dst = transposeA ? col : row;
                int src = transposeA ? row : col;
                z.set(dst, z.getQuick(dst) + matrix.getQuick(row, col) * y.getQuick(src) * alpha);
            }
        }
        return z;
    }

    /**
     * Applies the log-domain scaling {@code z[i] += log(beta)} in place; {@code beta == 0}
     * resets every entry to {@link #LOG0} and {@code beta == 1} leaves {@code z} untouched.
     */
    private static void scaleLogDomain(DoubleMatrix1D z, double beta) {
        if (beta == 0.0d) {
            z.assign(LOG0);
        } else if (beta != 1.0d) {
            double logBeta = Math.log(beta);
            for (int i = 0; i < z.size(); i++) {
                z.set(i, z.get(i) + logBeta);
            }
        }
    }

    /** Prints {@code logSumExp} of the two numbers given on the command line. */
    public static void main(String[] strArr) {
        if (strArr.length < 2) {
            System.err.println("Usage: RobustMath <logValue1> <logValue2>");
            return;
        }
        System.out.println(logSumExp(Double.parseDouble(strArr[0]), Double.parseDouble(strArr[1])));
    }

    /**
     * Returns {@code Math.exp(d)}, except that 0 is returned when {@code d} is infinite (of
     * either sign), NaN, or below {@code -MINUS_LOG_EPSILON}.
     *
     * <p>NOTE(review): returning 0 for +infinity differs from {@code Math.exp}; it is kept
     * as-is for existing callers, but confirm it is intended.
     */
    public static double exp(double d) {
        if (Double.isInfinite(d) || !(d >= -MINUS_LOG_EPSILON)) {
            return 0.0d;
        }
        return Math.exp(d);
    }

    /** Returns the natural log of {@code f}, with an exact 0 for {@code f == 1}. */
    public static double log(float f) {
        return f == 1.0f ? 0.0d : Math.log(f);
    }
}
