import numpy as np
import argparse
from datetime import datetime
import mpmath
from mpmath import mpf, fsub, fmul, fadd, fdiv, mp
mp.dps = 1500
from sklearn.decomposition import PCA, FastICA


def compute_kl_label_robustness(p, flip_prob=None):
    """Return the number of label flips certified by the KL-divergence bound.

    Computes, in mpmath arithmetic,

        floor( log(4 p (1 - p)) / (2 (1 - 2q) (log q - log(1 - q))) )

    where q is the label-flip probability.

    Args:
        p: lower bound on the probability of the predicted class (mpf or float).
        flip_prob: label-flip probability q. Defaults to the module-level
            ``args.flip_prob`` (the original hidden dependency).

    Returns:
        int: 0 when p <= 0.5; ``len(rho_inv) + 1`` when p == 1, where
        ``rho_inv`` is the module-level threshold array (log(1 - p) is -inf
        there, so the formula cannot be evaluated); otherwise the floor
        above, clamped at 0.
    """
    # We've hit the limit of what we can certify for the given precision. This could be increased.
    if p == 1:
        return len(rho_inv) + 1
    # Same cut-off as compute_discrete_label_robustness. At p == 0.5 the numerator
    # is log(1) == 0 only up to rounding; a tiny positive residue divided by the
    # negative denominator would floor to -1.
    if p <= 0.5:
        return 0
    q = mpf(args.flip_prob if flip_prob is None else flip_prob)
    num = fadd(fadd(mpmath.log(p), mpmath.log(fsub(mpf(1), p))), mpmath.log(mpf(4)))
    denom = fmul(fmul(mpf(2), fsub(mpf(1), fmul(mpf(2), q))),
                 fsub(mpmath.log(q), mpmath.log(fsub(mpf(1), q))))
    # A certified flip count can never be negative.
    return max(0, int(mpmath.floor(fdiv(num, denom))))

def compute_discrete_label_robustness(p, rho_inv):
    """Return the certified number of label flips from the ``rho_inv`` thresholds.

    Args:
        p: lower bound on the probability of the predicted class.
        rho_inv: array of thresholds; every entry that ``p`` reaches counts once.

    Returns:
        ``len(rho_inv) + 1`` when p == 1, 0 when p <= 0.5, and otherwise
        1 + the number of entries of ``rho_inv`` that are <= p.
    """
    if p == 1:
        # We've hit the limit of what we can certify for the given precision. This could be increased.
        robustness = len(rho_inv) + 1
    elif p <= 0.5:
        robustness = 0
    else:
        reached = p >= rho_inv
        robustness = reached.sum() + 1
    return robustness

def mp_log_abc(a, b, c):
    """Return log(a + b * exp(c)) evaluated with mpmath and converted to float.

    Serves (through ``vec_mp_log_abc``) as the arbitrary-precision fallback
    used by ``log_abc`` when its float64 computation is not finite.
    """
    return float(mpmath.log(a + b * mpmath.exp(c)))

# Element-wise wrapper around mp_log_abc (a Python-level loop over elements).
vec_mp_log_abc = np.vectorize(mp_log_abc)
# np.logaddexp is already an element-wise, broadcasting ufunc. Wrapping it in
# np.vectorize only added a slow Python-level loop and made size-0 inputs raise
# ValueError, so the ufunc is used directly.
vec_log_abc = np.logaddexp

def log_abc(a, b, c):
    """Compute element-wise log(a + b * exp(c)).

    The fast path evaluates logaddexp(log a, log b + c) in float64, so a large
    ``c`` does not overflow. If any element of that result is non-finite, the
    whole array is recomputed in arbitrary precision with ``vec_mp_log_abc``.

    Args:
        a, b: non-negative arrays (a zero gives log(0) = -inf on the fast path).
        c: exponent array, broadcast against ``a`` and ``b``.

    Returns:
        float64 array of log(a + b * exp(c)).
    """
    z = np.logaddexp(np.log(a), np.log(b) + c)
    # isfinite is False for +/-inf and NaN alike, so one check covers both.
    if not np.all(np.isfinite(z)):
        z = vec_mp_log_abc(a, b, c)
    return z

def chernoff_bound(t, b, c, w):
    """Evaluate exp(t / 2 + sum_i log(b_i + c_i * exp(-t * w_i))).

    The exponent is accumulated in float64 through ``log_abc`` and only
    exponentiated with mpmath (mp.dps is set at module level), so very small
    bounds do not underflow to 0.

    Returns:
        an mpmath ``mpf``, or the float 0.5 when the exponent is not finite
        (+/-inf or NaN). With 0.5, the caller's ``1 - bound`` is 0.5, which both
        robustness helpers map to 0 certified flips.
    """
    # Named log_bound rather than `sum` so the builtin is not shadowed.
    log_bound = t / 2 + log_abc(b, c, -t * w).sum()
    if not np.isfinite(log_bound):
        return 0.5
    return mpmath.exp(log_bound)

# def safe_chernoff_bound(t, b, c, w):
#     terms = vec_mp_log_abc(b, c, -t * w)
#     sum = t / 2 + terms.sum()
#     if not np.isfinite(sum):
#         return 0.5
#     return mpmath.exp(sum)

def deriv1(t, b, c, w):
    """Return d/dt of the Chernoff exponent t/2 + sum(log(b + c*exp(-t*w))).

    Each summand contributes -w*c / (c + b*exp(t*w)). The ratio
    c / (c + b*exp(t*w)) is formed in log space with log_abc, so exp(t*w)
    is never materialised in float64.
    """
    log_ratio = np.log(c) - log_abc(c, b, t * w)
    weighted = np.exp(log_ratio) * w
    return 0.5 - weighted.sum()

def deriv2(t, b, c, w):
    """Return the second t-derivative of the Chernoff exponent.

    This is the sum of w^2 * b * c * exp(t*w) / (c + b*exp(t*w))^2, evaluated
    term by term in log space. A result of exactly 0 is replaced by 1e-8, so
    Newton's division by this value stays finite.
    """
    tw = t * w
    log_numer = tw + np.log(b) + np.log(c) + 2 * np.log(np.abs(w))
    log_denom = 2 * log_abc(c, b, tw)
    total = np.exp(log_numer - log_denom).sum()
    if total == 0:
        return 1e-8
    return total

def newtons(probs, alpha, beta, eps, max_iter=None):
    """Minimise the Chernoff bound over t with damped Newton steps.

    The Newton direction is -deriv1 / deriv2. These are the first and second
    t-derivatives of the exponent t/2 + sum(log(probs + (1 - probs) * exp(-t * alpha))).
    Each step is accepted after a backtracking line search on ``chernoff_bound``.

    Args:
        probs: vector passed as ``b``; at this script's call site it holds the
            probability that each training example is labeled 0.
        alpha: weight vector passed as ``w``.
        beta: factor applied to the step size while backtracking.
        eps: stop once |deriv1| <= eps.
        max_iter: cap on outer Newton iterations; defaults to ``args.max_newton``.

    Returns:
        (t, f_val): the final t and the last accepted ``chernoff_bound`` value.
    """
    if max_iter is None:
        max_iter = args.max_newton
    t = 1
    inv_probs = 1 - probs
    f_val = chernoff_bound(t, probs, inv_probs, alpha)
    f_prime = deriv1(t, probs, inv_probs, alpha)
    count = 0
    while np.abs(f_prime) > eps and count < max_iter:
        # The counter was previously never advanced, so --max-newton could not
        # stop a non-converging run.
        count += 1
        f_prime2 = deriv2(t, probs, inv_probs, alpha)
        # deriv2 substitutes 1e-8 for 0, so with the current deriv2 this guard does not fire.
        if f_prime2 == 0:
            break
        direction = -f_prime / f_prime2
        step_size = 1
        f_next = chernoff_bound(t + step_size * direction, probs, inv_probs, alpha)
        f_prime_next = deriv1(t + step_size * direction, probs, inv_probs, alpha)
        c1, c2 = 1e-4, 0.9
        # Check Wolfe condition, but also whether we've reached a small enough derivative or value.
        # NOTE(review): f_val / f_next are chernoff_bound values, while f_prime is the
        # derivative of its log (deriv1), so this sufficient-decrease test mixes
        # scales. Confirm this is intended.
        while f_next > f_val + c1 * step_size * f_prime * direction and np.abs(f_prime_next) > eps:
            term1 = -direction * f_prime_next
            term2 = -c2 * direction * f_prime
            # curvature condition
            if term1 > term2:
                # NOTE(review): for beta < 1 this enlarges step_size, but f_next /
                # f_prime_next were evaluated at the previous step. The f_val /
                # f_prime assigned below then no longer match the new t.
                step_size /= beta
                break
            step_size *= beta
            f_next = chernoff_bound(t + step_size * direction, probs, inv_probs, alpha)
            f_prime_next = deriv1(t + step_size * direction, probs, inv_probs, alpha)
        t += step_size * direction
        f_val = f_next
        f_prime = f_prime_next
    return t, f_val


# Command-line options for the certification experiment.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=0, type=int)           # seed for np.random
parser.add_argument('--flip-prob', default=0.4, type=float)  # label-flip probability
parser.add_argument('--num-to-test', default=-1, type=int)   # -1 => num_test // skip
parser.add_argument('--max-newton', default=100, type=int)   # Newton iteration limit for newtons()
parser.add_argument('--regularize', action='store_true')     # add a ridge term to X^T X
parser.add_argument('--skip', default=1, type=int)           # stride over test examples
args = parser.parse_args()

np.random.seed(args.seed)

# Load word2vec features for the training split and reduce them to 5 whitened
# PCA components, fitted on the training data only.
X_train = np.load('data/aclImdb/train/x_train_word2vec.npy')
num_train = len(X_train)
pca = PCA(n_components=5, whiten=True).fit(X_train)
X_train = pca.transform(X_train)
# Append a constant column of ones so the least-squares fit has an intercept term.
X_train = np.hstack([X_train, np.ones((num_train, 1))])
# NOTE(review): unlike the feature paths, the label paths contain a ')'
# ("y_train)_word2vec.npy"). This looks like a typo but may match the files on
# disk; confirm before changing.
Y_train = np.squeeze(np.load('data/aclImdb/train/y_train)_word2vec.npy'))
X_test = np.load('data/aclImdb/test/x_test_word2vec.npy')
num_test = len(X_test)
# Project the test split with the PCA fitted on the training split.
X_test = pca.transform(X_test)
X_test = np.hstack([X_test, np.ones((len(X_test), 1))])
Y_test = np.squeeze(np.load('data/aclImdb/test/y_test)_word2vec.npy'))

# Remap -1 labels to 0. This presumably turns {-1, 1} labels into {0, 1}; TODO confirm the stored label set.
Y_train[Y_train == -1] = 0
Y_test[Y_test == -1] = 0

# Load the rho_inv thresholds for this flip probability and training-set size.
# NOTE(review): int(args.flip_prob * 1000) truncates, so a float product landing
# just below an integer would select a different file. Confirm this matches how
# the rhoinv files are named.
# Only the first line is read. It is split on ',' and the last field is dropped,
# which assumes a trailing comma. The result is an array of mpmath mpf values.
with open(f'rhoinv/rho_inv_q.{str(int(args.flip_prob * 1000))}_d.{num_train}.txt', 'r') as f:
    rho_inv = np.array(list(map(mpf, f.readline().strip().split(',')[:-1])))

# Output files: tab-separated per-example rows (data_file) and a human-readable
# log (log_file). Both are closed explicitly at the end of the script.
filename = f'flipprob{args.flip_prob}'
if args.regularize:
    filename += '_regularized'
data_file = open(f'logs/imdb/{filename}.txt_tmp', 'w')
log_file = open(f'logs/imdb/{filename}_stats.txt_tmp', 'w')
print(f'ID\tp\tflips_discrete\tflips_kl\tcorrect', file=data_file, flush=True)

# Normal-equations least squares: XTXinvXT = (X^T X [+ lambda * I])^{-1} X^T.
XTX = X_train.T @ X_train
if args.regularize:
    # Ridge strength chosen from the condition number of X^T X, the feature
    # dimension and a residual-variance estimate from the unregularised fit.
    dim = X_train.shape[1]
    condition = np.linalg.cond(XTX)
    XTXinvXT = np.linalg.solve(XTX, X_train.T)
    residual = np.linalg.norm(Y_train - X_train @ XTXinvXT @ Y_train)
    noise_estimate = residual ** 2 / (num_train - dim)
    reg_constant = (1 + args.flip_prob) / 2 * condition * dim * noise_estimate / num_train
    print(f'LAMBDA = {reg_constant}')
    print(f'LAMBDA = {reg_constant}', file=log_file, flush=True)
    # In-place: the solve below uses the regularised matrix.
    XTX += np.eye(dim) * reg_constant
XTXinvXT = np.linalg.solve(XTX, X_train.T)
# Row i maps the training-label vector to the least-squares output for test example i.
hat_matrix = X_test @ XTXinvXT

# Noiseless least-squares accuracy
predictions = (hat_matrix @ Y_train) >= 0.5
correct = (predictions == Y_test).sum()
print(correct / len(predictions))


# probs is the vector of probabilities that each training example will be labeled 0:
# 1 - flip_prob for examples whose label is 0, flip_prob for examples whose label is 1.
# Labels other than 0/1 previously left uninitialised np.empty memory in probs;
# reject them explicitly instead.
if not np.isin(Y_train, (0, 1)).all():
    raise ValueError('Y_train must contain only 0/1 labels (after mapping -1 to 0)')
probs = np.where(Y_train == 0, 1 - args.flip_prob, args.flip_prob)

# Certify every `skip`-th test example: minimise the Chernoff bound, predict from
# the sign of the optimal t, and convert 1 - bound into certified flip counts.
acc = 0
if args.num_to_test == -1:
    args.num_to_test = num_test // args.skip
k = np.zeros(args.num_to_test)
for idx, sample in enumerate(range(0, args.num_to_test * args.skip, args.skip)):
    start = datetime.now()
    label = Y_test[sample]
    min_t, _ = newtons(probs, hat_matrix[sample], 0.5, 5e-3)
    bound = chernoff_bound(min_t, probs, 1-probs, hat_matrix[sample])
    prediction = int(min_t >= 0)
    acc += int(prediction == label)
    chernoff_lower_bound = fsub(mpf(1), bound)
    discrete_robust = compute_discrete_label_robustness(chernoff_lower_bound, rho_inv)
    # Running accuracy over the idx + 1 examples processed so far. The old
    # `sample + 1` denominator under-reported accuracy whenever --skip > 1.
    print(sample, acc / (idx + 1), discrete_robust)
    kl_robust = 0 if chernoff_lower_bound <= 0.5 else compute_kl_label_robustness(chernoff_lower_bound)
    k[idx] = discrete_robust
    elapsed = datetime.now() - start
    print(f'{sample} Chernoff: {chernoff_lower_bound} k (KL): {kl_robust} k: (discrete): {discrete_robust} '
          f'({elapsed.seconds}.{elapsed.microseconds // 1000:03d} s)', file=log_file, flush=True)
    print(f'{sample}\t{chernoff_lower_bound}\t{discrete_robust}\t{kl_robust}\t{int(prediction == label)}',
          file=data_file, flush=True)
print(f'Overall accuracy: {acc / args.num_to_test}')
print(f'mean: {k.mean()}, std: {k.std()}, median: {np.median(k)}')
print(f'Overall accuracy: {acc / args.num_to_test}', file=log_file, flush=True)
print(f'mean: {k.mean()}, std: {k.std()}, median: {np.median(k)}', file=log_file, flush=True)
data_file.close()
log_file.close()
