import torch
import numpy as np
import argparse
from datetime import datetime
from pretrain import Embedder
import mpmath
from mpmath import mpf, fsub, fmul, fadd, fdiv, mp
from scipy.special import logsumexp
from scipy.optimize import newton
# Set mpmath's working precision to 600 significant decimal digits.
# NOTE(review): presumably needed so that bounds extremely close to 1 remain distinguishable
# from exactly 1 (the robustness functions special-case p == 1) — confirm before lowering.
mp.dps = 600


def compute_kl_label_robustness(p, flip_prob=None, max_flips=None):
    """Return the number of label flips certified by the KL-based bound.

    For ``0.5 <= p < 1`` the result is::

        floor( log(4 * p * (1 - p)) / (2 * (1 - 2 * q) * log(q / (1 - q))) )

    with ``q`` the label-flip probability, evaluated with mpmath at the
    module's working precision.

    Args:
        p: Lower bound on the probability of the predicted class (number or mpf).
        flip_prob: Label-flip probability ``q``. When omitted, the script-level
            ``args.flip_prob`` is used; that global only exists when this file
            is executed as a script.
        max_flips: Value returned when ``p == 1``. When omitted, the length of
            the script-level ``rho_inv`` array is used.

    Returns:
        The certified number of flips as an ``int`` (``max_flips`` is returned
        unchanged when given and ``p == 1``).
    """
    # We've hit the limit of what we can certify for the given precision. This could be increased.
    if p == 1:
        return len(rho_inv) if max_flips is None else max_flips
    if p < 0.5:
        return 0
    # Resolve q lazily so the early returns above never touch the script-level globals.
    q = mpf(args.flip_prob if flip_prob is None else flip_prob)
    # num = log(p) + log(1 - p) + log(4)
    num = fadd(fadd(mpmath.log(p), mpmath.log(fsub(mpf(1), p))), mpmath.log(mpf(4)))
    # denom = 2 * (1 - 2q) * (log(q) - log(1 - q))
    log_odds = fsub(mpmath.log(q), mpmath.log(fsub(mpf(1), q)))
    denom = fmul(fmul(mpf(2), fsub(mpf(1), fmul(mpf(2), q))), log_odds)
    return int(mpmath.floor(fdiv(num, denom)))

def compute_discrete_label_robustness(p, rho_inv):
    """Count how many thresholds in ``rho_inv`` the probability bound ``p`` reaches.

    Args:
        p: Lower bound on the probability of the predicted class.
        rho_inv: Array of thresholds; entry comparison is element-wise.

    Returns:
        ``len(rho_inv)`` when ``p == 1``, ``0`` when ``p <= 0.5``, otherwise the
        number of entries of ``rho_inv`` that are less than or equal to ``p``.
    """
    if p == 1:
        # We've hit the limit of what we can certify for the given precision. This could be increased.
        certified = len(rho_inv)
    elif p <= 0.5:
        certified = 0
    else:
        reached = p >= rho_inv
        certified = reached.sum()
    return certified

def log_abc(a, b, c):
    """Compute the element-wise log of ``a + b * exp(c)``.

    The two summands are combined in log space with ``logsumexp`` so that large
    values of ``c`` do not overflow ``exp``.
    """
    log_first = np.log(a)
    log_second = np.log(b) + c
    stacked = np.vstack((log_first, log_second))
    return logsumexp(stacked, axis=0)

def chernoff_bound(t, b, c, w):
    """Evaluate ``exp(t / 2) * prod(b + c * exp(-t * w))``.

    The product is accumulated as a sum of logs and only exponentiated at the
    end (with mpmath). If that log-sum is not finite, the constant ``0.5`` is
    returned instead.
    """
    log_bound = t / 2 + log_abc(b, c, -t * w).sum()
    if np.isfinite(log_bound):
        return mpmath.exp(log_bound)
    return 0.5

def mp_log_abc(a, b, c):
    """Scalar mpmath evaluation of ``log(a + b * exp(c))``."""
    scaled_exp = b * mpmath.exp(c)
    return mpmath.log(a + scaled_exp)

# Element-wise wrapper so mp_log_abc can be applied across NumPy arrays.
mp_vec_abc = np.vectorize(mp_log_abc)

def deriv1(t, b, c, w):
    """Return ``0.5 - sum(w * c / (c + b * exp(t * w)))``.

    This is the derivative with respect to ``t`` of the log of the expression
    evaluated by ``chernoff_bound``. Each ratio is formed in log space.
    """
    # log(c) - log(c + b * exp(t * w)), element-wise.
    log_ratio = np.log(c)
    log_ratio -= log_abc(c, b, t * w)
    weighted = np.exp(log_ratio) * w
    return 0.5 - weighted.sum()

def deriv2(t, b, c, w):
    """Return ``sum(w**2 * b * c * exp(t * w) / (c + b * exp(t * w))**2)``.

    This is the derivative of ``deriv1`` with respect to ``t``; numerator and
    denominator are combined in log space before exponentiating.
    """
    log_numerator = t * w + np.log(b) + np.log(c) + 2 * np.log(np.abs(w))
    log_denominator = 2 * log_abc(c, b, t * w)
    return np.exp(log_numerator - log_denominator).sum()

def newtons(probs, alpha):
    """Take one Newton step from ``t = 1`` and evaluate the Chernoff bound there.

    Turns out for this case a single newton step is all we need.

    Returns:
        A tuple ``(t, bound)`` with the updated ``t`` and
        ``chernoff_bound(t, probs, 1 - probs, alpha)``.
    """
    start = 1
    complement = 1 - probs
    gradient = deriv1(start, probs, complement, alpha)
    curvature = deriv2(start, probs, complement, alpha)
    t = start + (-gradient / curvature)
    return t, chernoff_bound(t, probs, complement, alpha)


if __name__ == '__main__':
    # Script entry point: fit a least-squares classifier on the network's embeddings, then for
    # every test point log a Chernoff lower bound on the prediction under random label flips,
    # the certified number of flips (discrete and KL), and the number of flips a greedy
    # attack needs to change the prediction.

    def solve_linear_system(coeffs, rhs):
        # Solve ``coeffs @ X = rhs``. ``torch.solve(rhs, coeffs)`` was deprecated and then removed
        # from PyTorch in favour of ``torch.linalg.solve(coeffs, rhs)`` (arguments reversed, no LU
        # returned); fall back to the old call only on releases that predate the replacement.
        linalg_solve = getattr(getattr(torch, 'linalg', None), 'solve', None)
        if linalg_solve is not None:
            return linalg_solve(coeffs, rhs)
        solution, _ = torch.solve(rhs, coeffs)
        return solution

    def minimize_t(b, c, w):
        # Root of deriv1 in t, found with Newton's method starting from t = 1.
        # NOTE(review): --max-newton is parsed below but has never been forwarded here, so
        # scipy's default iteration cap applies — confirm whether maxiter should be set.
        return newton(lambda t: deriv1(t, b, c, w), 1, tol=5e-3, disp=False,
                      fprime=lambda t: deriv2(t, b, c, w))

    parser = argparse.ArgumentParser()
    parser.add_argument('modeldir')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--flip-prob', type=float, default=0.3)
    parser.add_argument('--num-to-test', type=int, default=-1)
    parser.add_argument('--max-newton', type=int, default=300)
    parser.add_argument('--regularize', action='store_true', default=False)
    # ``args`` stays a module-level name: compute_kl_label_robustness reads it.
    args = parser.parse_args()

    np.random.seed(args.seed)
    network = Embedder()
    checkpoint_path = f'{args.modeldir}/embedder.pth'
    # torch.load unpickles the file, so modeldir must only point at trusted checkpoints.
    network.load_state_dict(torch.load(checkpoint_path))
    network.eval()

    x_train = torch.tensor(np.load('./data/binarymnist/train/xtrain_17.npy'))
    num_train = len(x_train)
    y_train = np.load('./data/binarymnist/train/ytrain_17.npy')
    # Map label value 1 -> 0 first, then 7 -> 1; the other order would collapse both to 0.
    y_train[y_train == 1] = 0
    y_train[y_train == 7] = 1
    # probs is the vector of probabilities that each training example will be labeled 0.
    probs = np.empty(num_train, dtype=np.float64)
    probs[y_train == 0] = 1 - args.flip_prob
    probs[y_train == 1] = args.flip_prob
    inv_probs = 1 - probs

    x_test = torch.tensor(np.load('./data/binarymnist/test/xtest_17.npy'))
    num_test = len(x_test)
    y_test = np.load('./data/binarymnist/test/ytest_17.npy')
    y_test[y_test == 1] = 0
    y_test[y_test == 7] = 1

    # First line of the file is a comma-separated list with a trailing comma (last token dropped).
    # ``rho_inv`` stays a module-level name: compute_kl_label_robustness reads it.
    with open(f'rhoinv/rho_inv_q.{int(args.flip_prob*1000)}_d.{num_train}.txt', 'r') as f:
        rho_inv = np.array([mpf(token) for token in f.readline().strip().split(',')[:-1]])

    with torch.no_grad():
        log_stem = f'flipprob{args.flip_prob}'
        if args.regularize:
            log_stem += '_regularized'
        # Context managers guarantee both logs are closed even if a sample raises.
        with open(f'logs/mnist/{log_stem}.txt_tmp', 'w') as data_file, \
                open(f'logs/mnist/{log_stem}_stats.txt_tmp', 'w') as log_file:
            # Append a constant column so the least-squares fit includes a bias term.
            embedding = network.embed(x_train)
            embedding = torch.cat([embedding, torch.ones((num_train, 1))], dim=1)
            XTX = torch.mm(embedding.t(), embedding)
            if args.regularize:
                dim = embedding.shape[1]
                condition = np.linalg.cond(XTX)
                XTXinvXT = solve_linear_system(XTX, embedding.t())
                # Multiply right-to-left so no (num_train x num_train) matrix is materialised.
                fitted = embedding.numpy() @ (XTXinvXT.numpy() @ y_train)
                residual = np.linalg.norm(y_train - fitted)
                noise_estimate = residual ** 2 / (num_train - dim)
                reg_constant = (1 + args.flip_prob) / 2 * condition * dim * noise_estimate / num_train
                print(f'LAMBDA = {reg_constant}')
                print(f'LAMBDA = {reg_constant}', file=log_file, flush=True)
                XTX += torch.eye(dim) * reg_constant
            XTXinvXT = solve_linear_system(XTX, embedding.t())
            test_embed = network.embed(x_test)
            test_embed = torch.cat([test_embed, torch.ones((num_test, 1))], dim=1)
            # Row i holds the weights that turn the training labels into test point i's score.
            hat_matrix = torch.mm(test_embed, XTXinvXT).numpy()

            print('ID\tp\tflips_discrete\tflips_kl\tcorrect\tempirical', file=data_file, flush=True)
            acc = 0
            # -1 means "all test points"; also clamp requests larger than the test set.
            if args.num_to_test == -1 or args.num_to_test > num_test:
                args.num_to_test = num_test
            k = np.empty(args.num_to_test)
            z = []
            # If train label is 0, flipping will add. If train label is 1, flipping will subtract.
            # Cast to a signed type so 1 - 2 * y cannot wrap should the labels be stored unsigned.
            flip_factor = 1 - 2 * y_train.astype(np.int64)
            for sample in range(args.num_to_test):
                start = datetime.now()
                # Plain int keeps the slice step below at exactly +1 or -1.
                label = int(y_test[sample])
                # w is alpha.T in the paper
                w = hat_matrix[sample]
                min_t = minimize_t(probs, inv_probs, w)
                bound = chernoff_bound(min_t, probs, inv_probs, w)
                prediction = int(min_t >= 0)
                chernoff_lower_bound = fsub(mpf(1), bound)
                discrete_robust = compute_discrete_label_robustness(chernoff_lower_bound, rho_inv)
                kl_robust = 0 if chernoff_lower_bound <= 0.5 else compute_kl_label_robustness(chernoff_lower_bound)
                k[sample] = discrete_robust

                #####################################################
                # Generate attack by flipping training points with highest contribution to correct label
                # w * flip_factor is now how much the inner product will change if we flip the label.
                # If label is 0, we want to add to attack (decreasing). If label is 1, we subtract to attack (increasing).
                inner_product_order = np.argsort(w * flip_factor)[::(2 * label - 1)]
                # Fancy indexing copies, so these reordered arrays are built once per sample.
                ordered_probs = probs[inner_product_order]
                ordered_w = w[inner_product_order]
                # Binary search for the smallest prefix of flips that changes the prediction.
                low, high = 0, num_train
                required_flips = num_train
                while low < high:
                    flips_to_try = (high - low) // 2 + low
                    attacked_probs = ordered_probs.copy()
                    attacked_probs[:flips_to_try] = 1 - attacked_probs[:flips_to_try]
                    attacked_t = minimize_t(attacked_probs, 1 - attacked_probs, ordered_w)
                    attacked_pred = int(attacked_t >= 0)
                    if attacked_pred == label:
                        low = flips_to_try + 1
                    else:
                        high = flips_to_try
                        required_flips = min(required_flips, flips_to_try)
                #####################################################

                elapsed = datetime.now() - start
                print(f'{sample} Chernoff: {chernoff_lower_bound} k (KL): {kl_robust} k: (discrete): {discrete_robust} '
                      f'({elapsed.seconds}.{elapsed.microseconds // 1000:03d} s) (empirical flips to break: {required_flips})', file=log_file, flush=True)
                print(f'{sample}\t{chernoff_lower_bound}\t{discrete_robust}\t{kl_robust}\t{int(prediction==label)}\t{required_flips}',
                      file=data_file, flush=True)
                # Misclassified points count as zero certified flips in the running summary.
                if prediction == label:
                    acc += 1
                    z.append(discrete_robust)
                else:
                    z.append(0)
                print(sample, acc/(sample+1), required_flips, (np.array(z) >= 10).mean())
            print(f'Overall accuracy: {acc / args.num_to_test}')
            print(f'mean: {k.mean()}, std: {k.std()}, median: {np.median(k)}')
            print(f'Overall accuracy: {acc / args.num_to_test}', file=log_file, flush=True)
            print(f'mean: {k.mean()}, std: {k.std()}, median: {np.median(k)}', file=log_file, flush=True)