from math import factorial
import mpmath
from mpmath import mpf, fdiv, power, ceil, mp

def fadd(a, b):
    """Return a + b as computed by mpmath.fadd with exact=True.

    Per mpmath's documentation, exact=True requests the sum without rounding
    it to the current working precision.
    """
    return mpmath.fadd(a, b, exact=True)

def fmul(a, b):
    """Return a * b as computed by mpmath.fmul with exact=True.

    Per mpmath's documentation, exact=True requests the product without
    rounding it to the current working precision.
    """
    return mpmath.fmul(a, b, exact=True)

def fsub(a, b):
    """Return a - b as computed by mpmath.fsub with exact=True.

    Per mpmath's documentation, exact=True requests the difference without
    rounding it to the current working precision.
    """
    return mpmath.fsub(a, b, exact=True)

# Shared mpf constants: zero seeds the running sums (and is the default
# offset `p`), and 10000 is the denominator that turns the integer
# parameter `a` into the fractions a/10000 and (10000-a)/10000.
mpf0 = mpf(0)
mpf10000 = mpf(10000)
# Memo tables for the helpers below. They are never cleared, so they grow
# for the lifetime of the process.
factorial_dict = {}   # n -> n!
binom_dict = {}       # (n, r) -> C(n, r)

def factorial_lookup(n):
    """Return n! for a non-negative int n, memoized in factorial_dict.

    Raises ValueError (from math.factorial) when n is negative.
    """
    try:
        return factorial_dict[n]
    except KeyError:
        value = factorial_dict[n] = factorial(n)
        return value

def binom(n, r):
    """Return the binomial coefficient C(n, r) as an exact Python int.

    Evaluates to 0 when r > n >= 0 (there is no way to choose more items
    than are available) instead of failing on the factorial of a negative
    number. Negative arguments still raise ValueError via math.factorial.
    """
    if 0 <= n < r:
        return 0
    return factorial_lookup(n) // (factorial_lookup(r) * factorial_lookup(n - r))

def binom_lookup(n, r):
    """Return C(n, r), memoized in binom_dict under the key (n, r).

    The n == 0 case is answered directly without touching the cache:
    C(0, 0) is 1 and C(0, r) is 0 for every other r.
    """
    if n == 0:
        return 1 if r == 0 else 0
    key = (n, r)
    if key not in binom_dict:
        binom_dict[key] = binom(n, r)
    return binom_dict[key]

def probsum_noK(r, a, n):
    """Return sum_{i=0}^{n} C(r, i) * (a/10000)**(r-i) * ((10000-a)/10000)**i.

    The two base fractions are formed with mpf division and each power comes
    from mpmath.power; the products and the running sum go through the
    exact fmul/fadd wrappers. Returns mpf zero when n < 0 (empty sum).

    r, a and n are used as Python ints (range(), binom_lookup()).
    """
    # Both bases are loop-invariant, so divide once rather than once per term.
    frac_a = mpf(a) / mpf10000
    frac_rest = mpf(10000 - a) / mpf10000
    total = mpf0
    for i in range(n + 1):
        # Ints are handed to mpf() directly, as probsumK does, instead of
        # round-tripping them (including large binomials) through str().
        term = fmul(power(frac_a, mpf(r - i)), power(frac_rest, mpf(i)))
        total = fadd(total, fmul(term, mpf(binom_lookup(r, i))))
    return total

def probsum(r, a, n, k=None, s=None):
    """Dispatch to probsumK when k is given, otherwise to probsum_noK.

    With k left as None the s argument is ignored and the two-outcome sum
    probsum_noK(r, a, n) is returned; otherwise the call is forwarded
    unchanged as probsumK(r, a, n, k, s).
    """
    if k is not None:
        return probsumK(r, a, n, k, s)
    return probsum_noK(r, a, n)

def binary_search(r, a, left, right, k=None, s=None, p=mpf0):
    """Bisect [left, right] for the point where p + probsum(...) crosses 0.5.

    Let total(n) = p + probsum(r, a, n, k, s).
      * If total(right) is already below 0.5, r is returned.
      * Otherwise the interval is bisected and mid - 1 is returned for the
        first probed mid with total(mid) >= 0.5 and total(mid - 1) < 0.5.
        The returned value can therefore be left - 1.
      * If the interval is exhausted without finding such a mid, r is
        returned.
    Assuming total() is non-decreasing in n, the result is the largest n
    whose total is still below 0.5.
    """
    def total(n):
        return fadd(p, probsum(r, a, n, k, s))

    def reaches_half(n):
        return total(n) >= 0.5

    def below_half(n):
        return total(n) < 0.5

    if below_half(right):
        return r

    lo, hi = left, right
    while lo <= hi:
        mid = lo + (hi - lo) // 2
        if not reaches_half(mid):
            lo = mid + 1
        elif below_half(mid - 1):
            return mid - 1
        else:
            hi = mid - 1
    return r

def minN(r, a, k=None, s=None, p=mpf0):
    """Run binary_search over the full range [0, r] and return its result.

    k, s and p are forwarded untouched; see binary_search for the meaning of
    the returned index (including the r sentinel).
    """
    return binary_search(r, a, left=0, right=r, k=k, s=s, p=p)

def rho_inv(r, a, d):
    """Two-outcome rho-inverse value for r, a (in units of 1/10000) and d.

    Steps, as implemented:
      1. n = minN(r, a): cut-off index for the partial sum probsum(r, a, n).
      2. shortfall = 1/2 - probsum(r, a, n).
      3. count = ceil(2**(d-r) * shortfall
                      / ((a/10000)**(r-n-1) * ((10000-a)/10000)**(n+1))).
      4. Result = probsum(r, 10000-a, n)
                  + count * ((10000-a)/10000)**(r-n-1) * (a/10000)**(n+1)
                    / 2**(d-r).
    Divisions, powers and ceil are evaluated by mpmath at the current working
    precision; sums, differences and products use the exact wrappers.
    """
    frac_a = mpf(a) / mpf10000
    frac_rest = mpf(10000 - a) / mpf10000

    n = minN(r, a)
    head = n + 1          # exponent n+1
    tail = r - head       # exponent r-(n+1)
    scale = power(mpf(2), mpf(d - r))

    shortfall = fsub(mpf(1) / mpf(2), probsum(r, a, n))
    unit = fmul(power(frac_a, mpf(tail)), power(frac_rest, mpf(head)))
    count = ceil(fmul(scale, fdiv(shortfall, unit)))

    # Same exponents with the two fractions swapped.
    swapped = fmul(fmul(count, power(frac_rest, mpf(tail))), power(frac_a, mpf(head)))
    return fadd(probsum(r, 10000 - a, n), fdiv(swapped, scale))


def probsumK(r, a, n, k, s):
    """Partial multinomial sum over i = 0 .. n-s for a fixed s.

    With A = a/10000, B = (10000-a)/10000, share = B/k and
    others = share*(k-1), each term is
        share**(r-i-s) * others**i * A**s * C(r, i) * C(r-i, s).
    Returns mpf zero when n < s (empty range).
    """
    frac_a = mpf(a) / mpf10000
    frac_rest = mpf(10000 - a) / mpf10000
    a_to_s = power(frac_a, mpf(s))
    share = fdiv(frac_rest, mpf(k))
    others = fmul(share, mpf(k - 1))

    acc = mpf0
    for i in range(n - s + 1):
        term = fmul(power(share, mpf(r - i - s)), power(others, mpf(i)))
        multipliers = (
            a_to_s,
            mpf(binom_lookup(r, i)),
            mpf(binom_lookup(r - i, s)),
        )
        for factor in multipliers:
            term = fmul(term, factor)
        acc = fadd(acc, term)
    return acc

def probsumK2(r, a, n, k, s):
    """Partial multinomial sum over i = 0 .. n-s with A and B/k swapped.

    With A = a/10000, B = (10000-a)/10000, share = B/k and
    others = (B*(k-1))/k, each term is
        A**(r-i-s) * others**i * share**s * C(r, i) * C(r-i, s),
    i.e. the probsumK term with the roles of A and share exchanged.
    Returns mpf zero when n < s (empty range).
    """
    frac_a = mpf(a) / mpf10000
    frac_rest = mpf(10000 - a) / mpf10000
    # Neither of these depends on i; at the working precisions this script
    # uses, recomputing them for every term is pure overhead.
    others = fdiv(fmul(frac_rest, mpf(k - 1)), mpf(k))
    share_to_s = power(fdiv(frac_rest, mpf(k)), mpf(s))

    total = mpf0
    for i in range(n - s + 1):
        val = fmul(power(frac_a, mpf(r - i - s)), power(others, mpf(i)))
        val = fmul(val, share_to_s)
        val = fmul(val, mpf(binom_lookup(r, i)))
        val = fmul(val, mpf(binom_lookup(r - i, s)))
        total = fadd(total, val)
    return total

def rho_invK(r, a, d, k):
    """Multi-outcome rho-inverse value for r, a (in units of 1/10000), d, k.

    The __main__ driver in this file passes (number of classes - 1) as k, so
    k + 1 below is the number of classes in that use.

    Steps, as implemented:
      1. Starting from s = 0, p = 0, call minN(r, a, k, s, p). While it
         returns r (the whole s-slice fits under 1/2), add the full slice
         probsumK(r, a, r, k, s) to p and move on to s + 1.
      2. Subtract s from the returned index to get n.
      3. shortfall = 1/2 - (p + probsumK(r, a, n, k, s)).
      4. count = ceil((k+1)**(d-r) * shortfall / unit), where unit is the
         probsumK term for index n + 1 (including both binomials).
      5. Result = sum of full probsumK2 slices for all smaller s
                  + probsumK2(r, a, n, k, s)
                  + count * (probsumK2 term for index n + 1) / (k+1)**(d-r).

    NOTE(review): probsumK/probsumK2 already iterate range(n - s + 1), so
    after step 2 the partial sums in steps 3 and 5 run over i <= n - s while
    the extra term uses index n + 1 -- confirm this offset is intended.
    NOTE(review): `p += ...` uses mpf's rounded addition, unlike the exact
    fadd used elsewhere -- confirm this is deliberate.
    """
    frac_a = mpf(a) / mpf10000
    frac_rest = mpf(10000 - a) / mpf10000

    # Step 1: absorb whole s-slices while they keep the total below 1/2.
    s, p = 0, mpf0
    n = minN(r, a, k, s, p)
    while n == r:
        p += probsumK(r, a, r, k, s)
        s += 1
        n = minN(r, a, k, s, p)
    n -= s

    # Pieces shared by the divisor (step 4) and the extra term (step 5).
    nxt = n + 1
    spare = r - nxt - s                      # r-n-1-s
    share = fdiv(frac_rest, mpf(k))
    others_pow = power(fdiv(fmul(frac_rest, mpf(k - 1)), mpf(k)), mpf(nxt))
    ways_i = mpf(binom_lookup(r, nxt))
    ways_s = mpf(binom_lookup(r - nxt, s))
    scale = power(mpf(k + 1), mpf(d - r))

    shortfall = fsub(mpf(1) / mpf(2), fadd(p, probsumK(r, a, n, k, s)))

    divisor = power(share, mpf(spare))
    for factor in (others_pow, ways_i, power(frac_a, mpf(s)), ways_s):
        divisor = fmul(divisor, factor)
    count = ceil(fdiv(fmul(scale, shortfall), divisor))

    swapped_sum = mpf0
    for lower_s in range(s):
        swapped_sum = fadd(swapped_sum, probsumK2(r, a, r, k, lower_s))
    swapped_sum = fadd(swapped_sum, probsumK2(r, a, n, k, s))

    extra = fmul(count, power(frac_a, mpf(spare)))
    for factor in (others_pow, ways_i, power(share, mpf(s)), ways_s):
        extra = fmul(extra, factor)
    extra = fdiv(extra, scale)
    return fadd(swapped_sum, extra)

# q is flip probability * 1000
#   NOTE(review): the code below sets a = 10000 - q and every fraction is
#   formed by dividing by 10000, so (10000 - a) / 10000 equals q / 10000;
#   confirm whether the scale factor above should read 10000.
# d is size of training set
# k is number of labels (classes)
# r is number of labels they differ on
if __name__ == '__main__':
    import os

    k = 10
    # The result files are opened for writing inside this directory, which
    # fails with FileNotFoundError if it does not exist yet.
    os.makedirs('rhoinv', exist_ok=True)
    # Exclusive upper bound on r for specific d values; 2001 for the others.
    r_stop_by_d = {25000: 400, 1800: 100}
    for q, precision in [(10, 1900), (50, 1300), (100, 1200), (150, 1100), (200, 900),
                         (250, 600), (300, 500), (400, 400), (450, 400), (475, 300)]:
        a = 10000-q
        for d in [1800, 13007, 25000, 50000]:
            with open(f'rhoinv/rho_inv_multiclass_q.{q}_d.{d}.txt', 'w') as f:
                print(q, d)
                for r in range(1, r_stop_by_d.get(d, 2001)):
                    if r % 10 == 0:
                        print(r)
                    # Working precision (decimal digits) grows with r.
                    mp.dps = precision + r
                    # rho_invK receives k - 1; one result per line.
                    f.write(str(rho_invK(r, a, d, k-1)) + '\n')