import torch
import numpy as np
import argparse
from datetime import datetime
from pretrain_omniglot import Embedder
import mpmath
from mpmath import mpf, fsub, fmul, fadd, fdiv, mp, power
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.decomposition import PCA, FastICA
# Working precision, in decimal digits, for every mpmath computation in this module.
precision = 1500
# Must be set before max_prob is computed so that 1 - 10**(-precision) is representable.
mp.dps = precision
# Stand-in for a probability of exactly 1: 1 - 10**(-precision). The robustness
# functions below substitute it when they receive p == 1.
max_prob = fsub(mpf(1), power(mpf(10), mpf(-precision)))


def compute_kl_label_robustness(p, K):
    """Return the KL-based label-flip robustness count for probability ``p``.

    With ``q = args.flip_prob`` (read from the module-level ``args``), the value is

        floor( log(4 p (1 - p))
               / ( 2 (1 - K q / (K - 1)) (log q - log((1 - q)(K - 1))) ) )

    evaluated with mpmath at the module's working precision.

    Args:
        p: Probability assigned to the predicted class.
        K: Number of classes.

    Returns:
        0 when ``p < 0.5``; otherwise the floored ratio above as an ``int``.
    """
    if p < 0.5:
        return 0
    if p == 1:
        # log(1 - p) is undefined at p == 1, so fall back to the largest value
        # representable at the current precision (which could be raised).
        p = max_prob

    q = mpf(args.flip_prob)
    k_minus_1 = mpf(K - 1)

    # Numerator: log(p) + log(1 - p) + log(4).
    log_p = mpmath.log(p)
    log_not_p = mpmath.log(fsub(mpf(1), p))
    numerator = fadd(fadd(log_p, log_not_p), mpmath.log(mpf(4)))

    # Denominator: 2 * (1 - K q / (K - 1)) * (log q - log((1 - q)(K - 1))).
    scale = fmul(mpf(2), fsub(mpf(1), fdiv(fmul(mpf(K), q), k_minus_1)))
    log_ratio = fsub(mpmath.log(q), mpmath.log(fmul(fsub(mpf(1), q), k_minus_1)))
    denominator = fmul(scale, log_ratio)

    return int(mpmath.floor(fdiv(numerator, denominator)))


def compute_discrete_label_robustness(p, rho_inv):
    """Count how many entries of ``rho_inv`` are less than or equal to ``p``.

    Args:
        p: Probability assigned to the predicted class.
        rho_inv: Array of thresholds; each entry that ``p`` reaches adds one
            to the result.

    Returns:
        0 when ``p <= 0.5``; otherwise the number of entries of ``rho_inv``
        that ``p`` meets or exceeds.
    """
    # An exact 1 is replaced by the largest value representable at the current
    # mpmath precision (which could be raised to certify more).
    p = max_prob if p == 1 else p
    if p <= 0.5:
        return 0
    reached = p >= rho_inv
    return reached.sum()


def deriv0(a, b, c, exp, negexp, alpha):
    """Return ``log(a * negexp + b + c * exp)``.

    ``alpha`` is unused here; it is accepted so that deriv0, deriv1 and deriv2
    share one signature and can be passed interchangeably to
    create_deriv_evaluator.
    """
    decaying = a * negexp
    growing = c * exp
    return np.log(decaying + b + growing)


def deriv1(a, b, c, exp, negexp, alpha):
    """Return ``alpha * (c * exp - a * negexp) / (a * negexp + b + c * exp)``.

    When ``exp = e^(alpha t)`` and ``negexp = e^(-alpha t)`` this is the
    derivative with respect to ``t`` of ``log(a * negexp + b + c * exp)``.
    """
    decaying = a * negexp
    growing = c * exp
    return alpha * (growing - decaying) / (decaying + b + growing)


def deriv2(a, b, c, exp, negexp, alpha):
    """Return ``z'' / z - (z' / z) ** 2`` for ``z = a * negexp + b + c * exp``.

    Here ``z' = alpha * (c * exp - a * negexp)`` and
    ``z'' = alpha ** 2 * (c * exp + a * negexp)``. When ``exp = e^(alpha t)``
    and ``negexp = e^(-alpha t)`` this is the second derivative with respect
    to ``t`` of ``log(z)``.
    """
    decaying = a * negexp
    growing = c * exp
    total = decaying + b + growing
    first_ratio = alpha * (growing - decaying) / total
    second_ratio = alpha ** 2 * (growing + decaying) / total
    return second_ratio - first_ratio ** 2


def create_deriv_evaluator(q, alpha, K, Y, i, i_prime, deriv_func):
    """Build a function of ``t`` that sums ``deriv_func`` over all training samples.

    For every sample, ``deriv_func`` is evaluated with three coefficient
    triples ``(a, b, c)``:

    * ``(1 - q, (K - 2) q / (K - 1), q / (K - 1))`` -- used for samples labelled ``i``
    * ``(q / (K - 1), (K - 2) q / (K - 1), 1 - q)`` -- used for samples labelled ``i_prime``
    * ``(q / (K - 1), 1 - 2 q / (K - 1), q / (K - 1))`` -- used for all other samples

    The label matrix ``Y`` selects which of the three applies to each sample.

    Args:
        q: Label flip probability.
        alpha: Per-sample weights; multiplied by ``t`` inside the exponentials.
        K: Number of classes.
        Y: Label indicator matrix with one row per sample and one column per
            class (the script passes a one-hot matrix).
        i: Index of the first class of the pair.
        i_prime: Index of the second class of the pair.
        deriv_func: Callable with the signature of deriv0 / deriv1 / deriv2.

    Returns:
        A callable ``evaluator(t)`` returning the summed value.
    """
    flip_share = q / (K - 1)
    keep = 1 - q
    middle_pair = (K - 2) * flip_share
    middle_other = 1 - 2 * flip_share

    def evaluator(t):
        grow = np.exp(alpha * t)
        decay = np.exp(alpha * -t)
        # Row 0: sample labelled i, row 1: labelled i_prime, row 2: any other label.
        per_sample = np.vstack([
            deriv_func(keep, middle_pair, flip_share, grow, decay, alpha),
            deriv_func(flip_share, middle_pair, keep, grow, decay, alpha),
            deriv_func(flip_share, middle_other, flip_share, grow, decay, alpha),
        ])
        # Column j of per_class holds each row's sum over the samples labelled j.
        per_class = np.einsum('kj,ji->ki', per_sample, Y)
        others = np.ones(per_class.shape[1], dtype=bool)
        others[[i, i_prime]] = False
        return per_class[0, i] + per_class[1, i_prime] + np.sum(per_class[2][others])

    return evaluator

def newtons(d0, d1, d2, eps, beta, max_newton=None):
    """Minimise a scalar function with a line-searched Newton iteration from t = 1.

    Each iteration takes the Newton direction ``-f'(t) / f''(t)`` and shrinks
    the step by ``beta`` while the sufficient-decrease (Armijo) test fails. If
    the curvature test indicates the step has become too short, the last
    shrink is undone and the search stops.

    Args:
        d0: Callable returning the function value at ``t``.
        d1: Callable returning the first derivative at ``t``.
        d2: Callable returning the second derivative at ``t``.
        eps: Stop once ``|f'(t)|`` is at most this value.
        beta: Multiplicative factor applied to the step when backtracking.
        max_newton: Maximum number of Newton iterations. When ``None`` the
            module-level ``args.max_newton`` is used, which only exists when
            the file is run as a script.

    Returns:
        ``(t, f(t))`` for the final iterate.
    """
    if max_newton is None:
        max_newton = args.max_newton

    # Constants of the sufficient-decrease and curvature conditions.
    c1, c2 = 1e-4, 0.9

    t = 1
    f_val = d0(t)
    f_prime = d1(t)
    count = 0
    while np.abs(f_prime) > eps and count < max_newton:
        count += 1
        f_prime2 = d2(t)
        if f_prime2 == 0:
            # The Newton direction is undefined; keep the current iterate.
            break
        direction = -f_prime / f_prime2
        step_size = 1
        f_next = d0(t + step_size * direction)
        f_prime_next = d1(t + step_size * direction)
        # Check Wolfe condition, but also whether we've reached a small enough derivative or value
        while f_next > f_val + c1 * step_size * f_prime * direction and np.abs(f_prime_next) > eps:
            term1 = -direction * f_prime_next
            term2 = -c2 * direction * f_prime
            # curvature condition
            if term1 > term2:
                step_size /= beta
                # The step changed, so re-evaluate; otherwise the value and
                # derivative carried forward would belong to a different point
                # than the one t is moved to below.
                f_next = d0(t + step_size * direction)
                f_prime_next = d1(t + step_size * direction)
                break
            step_size *= beta
            f_next = d0(t + step_size * direction)
            f_prime_next = d1(t + step_size * direction)
        t += step_size * direction
        f_val = f_next
        f_prime = f_prime_next
    return t, f_val


if __name__ == '__main__':
    # Script entry point: load cached train/test embeddings, form the least-squares
    # hat matrix, and for every test sample compute a Chernoff-style lower bound on
    # the predicted class's probability plus the label-flip robustness derived from it.
    #
    # NOTE: `args` has to remain a module-level name; compute_kl_label_robustness
    # and newtons look it up as a global.
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--flip-prob', type=float, default=0.3)
    parser.add_argument('--num-to-test', type=int, default=-1)
    parser.add_argument('--max-newton', type=int, default=3)
    parser.add_argument('--regularize', action='store_true', default=False)
    args = parser.parse_args()

    np.random.seed(args.seed)
    # Not referenced again below; kept in case constructing it has side effects -- TODO confirm.
    network = Embedder()

    with torch.no_grad():
        # The commented-out code below writes the same .npy files that are loaded
        # here; it is kept as a record of how they appear to have been produced.
        # print('Embedding train data')
        # mnist_train = datasets.MNIST("../mnist-data", train=True, download=True, transform=transforms.ToTensor())
        # y_train, num_train = mnist_train.targets, len(mnist_train)
        # train_loader = DataLoader(mnist_train, batch_size=512, shuffle=False)
        # embedding = None
        # for i, (X, _) in enumerate(train_loader):
        #     if i == 0:
        #         embedding = X.reshape(-1, 784)
        #     else:
        #         embedding = torch.cat([embedding, X.reshape(-1, 784)], dim=0)
        # ica = FastICA(n_components=20, whiten=True, max_iter=500).fit(embedding)
        # embedding = ica.transform(embedding)
        # np.save('data/multiclass_mnist/train/x_train_ica_embed.npy', embedding)
        # np.save('data/multiclass_mnist/train/y_train_ica.npy', y_train)
        embedding = torch.from_numpy(np.load('data/multiclass_mnist/train/x_train_ica_embed.npy')).float()
        y_train = np.load('data/multiclass_mnist/train/y_train_ica.npy')
        num_train = len(y_train)

        # print('Embedding test data')
        # mnist_test = datasets.MNIST("../mnist-data", train=False, download=True, transform=transforms.ToTensor())
        # y_test, num_test = mnist_test.targets, len(mnist_test)
        # test_loader = DataLoader(mnist_test, batch_size=100, shuffle=False)
        # test_embed = None
        # for i, (X, _) in enumerate(test_loader):
        #     if i == 0:
        #         test_embed = X.reshape(-1, 784)
        #     else:
        #         test_embed = torch.cat([test_embed, X.reshape(-1, 784)], dim=0)
        # test_embed = ica.transform(test_embed)
        # np.save('data/multiclass_mnist/test/x_test_ica_embed.npy', test_embed)
        # np.save('data/multiclass_mnist/test/y_test_ica.npy', y_test)
        test_embed = torch.from_numpy(np.load('data/multiclass_mnist/test/x_test_ica_embed.npy')).float()
        y_test = np.load('data/multiclass_mnist/test/y_test_ica.npy')
        num_test = len(y_test)

    # Number of classes and one-hot training labels, shape (num_train, K).
    K = int(y_train.max() + 1)
    Y = np.eye(K)[y_train]

    # Thresholds consumed by compute_discrete_label_robustness: the first line of the
    # file is comma-separated, and the field after the last comma is dropped.
    with open(f'rhoinv/rho_inv_multiclass_q.{int(args.flip_prob*1000)}_d.{num_train}.txt', 'r') as f:
        rho_inv = np.array(list(map(mpf, f.readline().strip().split(',')[:-1])))

    filename = f'multiclass_flipprob{args.flip_prob}'
    if args.regularize:
        filename += '_regularized'

    # Both output files are managed by the `with` statement so they are closed even
    # if the evaluation loop raises.
    with torch.no_grad(), \
            open(f'logs/mnist/{filename}.txt_tmp', 'w') as data_file, \
            open(f'logs/mnist/{filename}_stats.txt_tmp', 'w') as log_file:
        # Append a constant column to the features.
        embedding = torch.cat([embedding, torch.ones((num_train, 1))], dim=1)
        XTX = torch.mm(embedding.t(), embedding)
        # NOTE(review): torch.solve(B, A) solves A @ X = B and returns (X, LU). Recent
        # PyTorch releases replace it with torch.linalg.solve(A, B); migrate once the
        # project's minimum PyTorch version is confirmed.
        if args.regularize:
            dim = embedding.shape[1]
            condition = np.linalg.cond(embedding)
            XTXinvXT, _ = torch.solve(embedding.t(), XTX)
            # Residual of the unregularised fit, used to estimate the noise level.
            residual = np.linalg.norm(Y - embedding.numpy() @ (XTXinvXT.numpy() @ Y))
            noise_estimate = residual ** 2 / (num_train - dim)
            reg_constant = (1 + args.flip_prob) / 2 * condition * dim * noise_estimate / num_train
            print(f'LAMBDA = {reg_constant}')
            print(f'LAMBDA = {reg_constant}', file=log_file, flush=True)
            XTX += torch.eye(dim) * reg_constant
        XTXinvXT, _ = torch.solve(embedding.t(), XTX)
        test_embed = torch.cat([test_embed, torch.ones((num_test, 1))], dim=1)
        # Row s holds the weights that map training targets to the fitted value for test sample s.
        hat_matrix = torch.mm(test_embed, XTXinvXT).numpy()

        # Noiseless least-squares accuracy
        # logits = np.empty((num_test, 10))
        # for cls in range(10):
        #     ytmp = np.where(y_train == cls, np.ones(num_train), np.zeros(num_train))
        #     logits[:, cls] = hat_matrix @ ytmp
        # predictions = logits.argmax(1)
        # print('Noiseless accuracy: ', (predictions == y_test).sum() / num_test)

        print('ID\tp\tflips_discrete\tflips_kl\tcorrect', file=data_file, flush=True)
        acc = 0
        if args.num_to_test == -1:
            args.num_to_test = num_test
        k = np.empty(args.num_to_test)
        z = []
        for sample in range(args.num_to_test):
            start = datetime.now()
            alpha = hat_matrix[sample]
            # One entry per class, sized from K rather than a fixed 10 so the
            # bookkeeping below is correct for any number of classes.
            worst_bounds = np.full(K, -np.inf)
            best_worst = -np.inf
            prediction = -1
            for class1 in range(K-1):
                for class2 in range(class1+1, K):
                    chernoff_calculator = create_deriv_evaluator(args.flip_prob, alpha, K, Y, class1, class2, deriv0)
                    deriv1_calculator = create_deriv_evaluator(args.flip_prob, alpha, K, Y, class1, class2, deriv1)
                    deriv2_calculator = create_deriv_evaluator(args.flip_prob, alpha, K, Y, class1, class2, deriv2)
                    # min_t = newton(deriv1_calculator, 1, tol=.5, disp=False)#, fprime=deriv2_calculator)
                    # val = chernoff_calculator(min_t)
                    min_t, val = newtons(chernoff_calculator, deriv1_calculator, deriv2_calculator, 5e-3, 0.5)
                    # Each class keeps the largest value seen over the pairs it took part in.
                    if val > worst_bounds[class1]:
                        worst_bounds[class1] = val
                    if val > worst_bounds[class2]:
                        worst_bounds[class2] = val
                    # The prediction is updated whenever the smallest per-class value
                    # increases; the sign of min_t picks the class within the pair.
                    new_best_worst = worst_bounds.min()
                    if new_best_worst > best_worst:
                        best_worst = new_best_worst
                        prediction = class1 if min_t > 0 else class2
            # Lower bound on the predicted class's probability, floored at 0.5.
            chernoff_lower_bound = max(fsub(mpf(1), mpmath.exp(best_worst)), 0.5)
            discrete_robust = compute_discrete_label_robustness(chernoff_lower_bound, rho_inv)
            kl_robust = 0 if chernoff_lower_bound <= 0.5 else compute_kl_label_robustness(chernoff_lower_bound, K)
            k[sample] = discrete_robust
            elapsed = datetime.now() - start
            summary = (f'{sample} Chernoff: {chernoff_lower_bound} k (KL): {kl_robust} k: (discrete): {discrete_robust} '
                       f'({elapsed.seconds}.{elapsed.microseconds // 1000:03d} s)')
            print(summary, file=log_file, flush=True)
            print(summary)
            print(f'{sample}\t{chernoff_lower_bound}\t{discrete_robust}\t{kl_robust}\t{int(prediction==y_test[sample])}',
                  file=data_file, flush=True)
            if prediction == y_test[sample]:
                acc += 1
                z.append(discrete_robust)
            else:
                z.append(0)
            # print(sample, acc/(sample+1), (np.array(z) >= 10).mean())
        accuracy_line = f'Overall accuracy: {acc / args.num_to_test}'
        stats_line = f'mean: {k.mean()}, std: {k.std()}, median: {np.median(k)}'
        print(accuracy_line)
        print(stats_line)
        print(accuracy_line, file=log_file, flush=True)
        print(stats_line, file=log_file, flush=True)