import numpy as np
import argparse
from datetime import datetime
import mpmath
from mpmath import mpf, fsub, fmul, fadd, fdiv, mp, power
from sklearn.decomposition import PCA
from scipy.optimize import newton, minimize_scalar
# Working precision, in decimal digits, for all mpmath arithmetic in this script.
precision = 800
mp.dps = precision
# 1 - 10**(-precision): the largest probability the robustness helpers will certify.
# They substitute it for p == 1 (e.g. compute_kl_label_robustness takes log(1 - p),
# which would be -inf at p == 1).
max_prob = fsub(mpf(1), power(mpf(10), mpf(-precision)))


def compute_kl_label_robustness(p, flip_prob=None):
    """Return the KL-based label-flip robustness (number of certified flips).

    Computes floor(log(4 p (1 - p)) / (2 (1 - 2q) log(q / (1 - q)))) in mpmath
    precision, where q is the label-flip probability.

    Args:
        p: lower bound on the probability of the predicted class (mpf or float).
        flip_prob: label-flip probability q. Defaults to the global
            ``args.flip_prob`` so existing callers are unaffected. Must not be
            0.5, where the denominator is zero.

    Returns:
        A non-negative int; 0 whenever p <= 0.5.
    """
    if flip_prob is None:
        flip_prob = args.flip_prob
    # We've hit the limit of what we can certify for the given precision
    # (log(1 - p) would be -inf at p == 1). This could be increased.
    if p == 1:
        p = max_prob
    # Same cut-off as compute_discrete_label_robustness: nothing is certified at or below 0.5.
    if p <= 0.5:
        return 0
    q = mpf(flip_prob)
    # log(p) + log(1 - p) + log(4) = log(4 p (1 - p)), which is <= 0.
    num = fadd(fadd(mpmath.log(p), mpmath.log(fsub(mpf(1), p))), mpmath.log(mpf(4)))
    denom = fmul(mpf(2), fsub(mpf(1), fmul(mpf(2), q)))
    denom = fmul(denom, fsub(mpmath.log(q), mpmath.log(fsub(mpf(1), q))))
    # Clamp so that rounding in `num` for p just above 0.5 can never yield a negative count.
    return max(0, int(mpmath.floor(fdiv(num, denom))))

def compute_discrete_label_robustness(p, rho_inv):
    """Count how many rho^-1 thresholds the probability lower bound ``p`` reaches.

    Args:
        p: lower bound on the probability of the predicted class.
        rho_inv: array of threshold values; entry i is met when p >= rho_inv[i].

    Returns:
        0 when p <= 0.5, otherwise the number of thresholds with p >= threshold.
    """
    if p <= 0.5:
        return 0
    # p == 1 is the limit of what we can certify for the given precision; cap it
    # at max_prob instead. This could be increased.
    capped = max_prob if p == 1 else p
    return (capped >= rho_inv).sum()

def mp_log_abc(a, b, c):
    """Return log(a + b * exp(c)) evaluated with mpmath, converted to a float.

    Scalar fallback used by ``log_abc`` when the float64 path is not finite.
    """
    return float(mpmath.log(a + b * mpmath.exp(c)))

# Element-wise mpmath version. Fixing ``otypes`` stops numpy from calling
# mp_log_abc an extra time on the first element to infer the output dtype,
# and lets size-0 inputs through instead of raising.
vec_mp_log_abc = np.vectorize(mp_log_abc, otypes=[float])
# np.logaddexp is already an element-wise ufunc; wrapping it in np.vectorize
# only added a Python-level loop (and rejected empty inputs).
vec_log_abc = np.logaddexp

def log_abc(a, b, c):
    """Compute the element-wise log of ``a + b * exp(c)``.

    Evaluated as logaddexp(log(a), log(b) + c) in float64. Zero entries in ``a``
    or ``b`` give -inf logs (np.log warns), which logaddexp tolerates. If any
    result is non-finite (inf or NaN), the whole array is recomputed with
    mpmath via ``vec_mp_log_abc``.
    """
    z = vec_log_abc(np.log(a), np.log(b) + c)
    # np.isfinite is False for NaN as well as +/-inf, so one test covers both.
    if not np.all(np.isfinite(z)):
        z = vec_mp_log_abc(a, b, c)
    return z

def chernoff_bound(t, b, c, w):
    """Evaluate exp(t/2 + sum_i log(b_i + c_i * exp(-t * w_i))).

    Args:
        t: scalar free parameter of the bound.
        b, c, w: equal-length arrays (in this script b = probs, c = 1 - probs,
            w = a row of the hat matrix).

    Returns:
        An mpmath mpf, or the float 0.5 when the exponent is inf or NaN.
    """
    exponent = t / 2 + log_abc(b, c, -t * w).sum()
    if np.isfinite(exponent):
        return mpmath.exp(exponent)
    return 0.5

def safe_chernoff_bound(t, b, c, w):
    """Same bound as chernoff_bound, but every log term is computed with mpmath.

    Evaluates exp(t/2 + sum_i log(b_i + c_i * exp(-t * w_i))) using
    ``vec_mp_log_abc`` for each term instead of the float64 logaddexp path.

    Returns:
        An mpmath mpf, or the float 0.5 when the exponent is inf or NaN.
    """
    log_terms = vec_mp_log_abc(b, c, -t * w)
    exponent = t / 2 + log_terms.sum()
    if np.isfinite(exponent):
        return mpmath.exp(exponent)
    return 0.5

def deriv1(t, b, c, w):
    """Return d/dt of log(chernoff_bound(t, b, c, w)).

    log(bound) = t/2 + sum_i log(b_i + c_i * exp(-t * w_i)); its derivative is
    1/2 - sum_i w_i * c_i / (c_i + b_i * exp(t * w_i)). The ratio is formed in
    log space through log_abc, presumably so exp(t * w_i) cannot overflow.
    """
    log_ratio = np.log(c)
    log_ratio -= log_abc(c, b, t * w)
    weighted = w * np.exp(log_ratio)
    return 0.5 - np.sum(weighted)

def deriv2(t, b, c, w):
    """Return d^2/dt^2 of log(chernoff_bound(t, b, c, w)).

    Equals sum_i w_i^2 * b_i * c_i * exp(t * w_i) / (c_i + b_i * exp(t * w_i))^2,
    with each term assembled in log space. An exact 0 is replaced by 1e-8 so
    callers can divide by the result.
    """
    tw = t * w
    log_terms = tw + np.log(b) + np.log(c) + 2 * np.log(np.abs(w))
    log_terms -= 2 * log_abc(c, b, tw)
    total = np.exp(log_terms).sum()
    if total == 0:
        return 1e-8
    return total

def newtons(probs, alpha, beta, eps, max_iter=None):
    """Minimise the Chernoff bound over t with Newton's method and a line search.

    Newton directions use the first and second derivatives of log(bound)
    (deriv1 / deriv2). The step length starts at 1 and is adjusted by a
    backtracking (Wolfe-style) line search on the bound itself.

    Args:
        probs: array b passed to chernoff_bound; ``1 - probs`` is used as c.
        alpha: weight array w passed to chernoff_bound.
        beta: backtracking factor; the step is multiplied by it on each shrink.
        eps: convergence tolerance on |deriv1|.
        max_iter: maximum number of Newton iterations. Defaults to
            ``args.max_newton`` (the --max-newton flag).

    Returns:
        (t, f_val): the final t and chernoff_bound evaluated at that t.
    """
    if max_iter is None:
        max_iter = args.max_newton
    # Sufficient-decrease (c1) and curvature (c2) constants of the line search.
    c1, c2 = 1e-4, 0.9
    inv_probs = 1 - probs

    def evaluate(point):
        # Bound value and its log-derivative at `point`; always kept in sync.
        return (chernoff_bound(point, probs, inv_probs, alpha),
                deriv1(point, probs, inv_probs, alpha))

    t = 1
    f_val, f_prime = evaluate(t)
    count = 0
    # `count` was previously never incremented, so --max-newton had no effect.
    while np.abs(f_prime) > eps and count < max_iter:
        count += 1
        f_prime2 = deriv2(t, probs, inv_probs, alpha)
        # A zero or non-finite curvature would produce an unusable (inf/NaN) step.
        if f_prime2 == 0 or not np.isfinite(f_prime2):
            break
        direction = -f_prime / f_prime2
        step_size = 1
        f_next, f_prime_next = evaluate(t + step_size * direction)
        # Check Wolfe condition, but also whether we've reached a small enough derivative or value
        while f_next > f_val + c1 * step_size * f_prime * direction and np.abs(f_prime_next) > eps:
            term1 = -direction * f_prime_next
            term2 = -c2 * direction * f_prime
            # curvature condition violated: enlarge the step by 1/beta and accept it
            if term1 > term2:
                step_size /= beta
                # Re-evaluate so f_val / f_prime match the step actually taken
                # (previously they stayed at the pre-enlargement trial point).
                f_next, f_prime_next = evaluate(t + step_size * direction)
                break
            step_size *= beta
            f_next, f_prime_next = evaluate(t + step_size * direction)
        t += step_size * direction
        f_val, f_prime = f_next, f_prime_next
    return t, f_val


# Command-line configuration. compute_kl_label_robustness and newtons read the
# module-level `args`, so it must be parsed before they are called.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--flip-prob', type=float, default=0.1)
parser.add_argument('--num-to-test', type=int, default=-1)
parser.add_argument('--max-newton', type=int, default=10)
parser.add_argument('--regularize', action='store_true', default=False)
args = parser.parse_args()

np.random.seed(args.seed)

# Embeddings reduced to 5 principal components (PCA fitted on the training set
# only), plus a constant column of ones acting as a bias term.
X_train = np.load('data/dogfish/dogfish_xtrain_embeddings_224.npy')
pca = PCA(n_components=5)
X_train = pca.fit_transform(X_train)
num_train = len(X_train)
X_train = np.hstack([X_train, np.ones((num_train, 1))])
# `np.int` was only an alias of the builtin `int` and was removed in NumPy 1.24.
Y_train = np.squeeze(np.load('data/dogfish/dogfish_ytrain.npy').astype(int))
X_test = np.load('data/dogfish/dogfish_xtest_embeddings_224.npy')
num_test = len(X_test)
X_test = pca.transform(X_test)
X_test = np.hstack([X_test, np.ones((num_test, 1))])
Y_test = np.squeeze(np.load('data/dogfish/dogfish_ytest.npy').astype(int))

# Precomputed rho^-1 thresholds for this flip probability; the last
# comma-separated field is dropped (presumably empty from a trailing comma).
# The file name encodes flip_prob * 10000: use round() rather than int(), since
# float error makes e.g. 0.57 * 10000 == 5699.999..., which int() truncates to 5699.
with open(f'rhoinv/rho_inv_q.{round(args.flip_prob * 10000)}_d.{25000}-10000.txt', 'r') as f:
    rho_inv = np.array(list(map(mpf, f.readline().strip().split(',')[:-1])))

# Largest discrete certificate reachable at this mpmath precision.
print(f'max cert with precision {precision}: {(max_prob >= rho_inv).sum()}')

filename = f'flipprob{args.flip_prob}'
if args.regularize:
    filename += '_regularized'
# Tab-separated per-sample results and a human-readable log; closed at the end of the script.
data_file = open(f'logs/dogfish/{filename}_empirical.txt', 'w')
log_file = open(f'logs/dogfish/{filename}_empirical_stats.txt', 'w')
print(f'ID\tp\tflips_discrete\tflips_kl\tcorrect\tempirical', file=data_file, flush=True)

# Least squares via the normal equations. hat_matrix maps a training label
# vector to test-set outputs (hat_matrix @ Y_train), so each of its rows acts
# as the weight vector of one test prediction.
XTX = X_train.T @ X_train
if args.regularize:
    # Ridge constant built from the unregularised fit's residual noise
    # estimate and the condition number of X^T X.
    dim = X_train.shape[1]
    condition = np.linalg.cond(XTX)
    XTXinvXT = np.linalg.solve(XTX, X_train.T)
    fitted = X_train @ XTXinvXT @ Y_train
    residual = np.linalg.norm(Y_train - fitted)
    noise_estimate = residual ** 2 / (num_train - dim)
    reg_constant = (1 + args.flip_prob) / 2 * condition * dim * noise_estimate / num_train
    lambda_msg = f'LAMBDA = {reg_constant}'
    print(lambda_msg)
    print(lambda_msg, file=log_file, flush=True)
    XTX += reg_constant * np.eye(dim)
XTXinvXT = np.linalg.solve(XTX, X_train.T)
hat_matrix = X_test @ XTXinvXT

# Accuracy of the least-squares classifier on the clean (un-flipped) labels.
predictions = (hat_matrix @ Y_train) >= 0.5
correct = (predictions == Y_test).sum()
print(correct / len(predictions))


# probs is the vector of probabilities that each training example will be labeled 0.
probs = np.empty(num_train, dtype=float)
probs[Y_train == 0] = 1 - args.flip_prob
probs[Y_train == 1] = args.flip_prob

acc = 0
if args.num_to_test == -1:
    args.num_to_test = num_test
k = np.zeros(args.num_to_test)  # discrete certificate per test sample
z = []  # discrete certificate when the prediction is correct, else 0

# Flipping a training label changes the test inner product by +w_i when the
# label is 0 and by -w_i when it is 1. Independent of the test sample, so
# computed once.
flip_factor = 1 - 2 * Y_train

for sample in range(args.num_to_test):
    start = datetime.now()
    label = Y_test[sample]
    w = hat_matrix[sample]

    # The sign of the optimised t gives the prediction; 1 - bound is used as
    # the lower bound on the probability of that prediction.
    min_t, _ = newtons(probs, w, 0.9, 5e-3)
    bound = safe_chernoff_bound(min_t, probs, 1 - probs, w)
    prediction = int(min_t >= 0)
    acc += int(prediction == label)
    chernoff_lower_bound = fsub(mpf(1), bound)
    discrete_robust = compute_discrete_label_robustness(chernoff_lower_bound, rho_inv)
    kl_robust = 0 if chernoff_lower_bound <= 0.5 else compute_kl_label_robustness(chernoff_lower_bound)
    k[sample] = discrete_robust

    # #####################################################
    # Empirical attack: flip the training labels that push the inner product
    # hardest away from the true label first (descending order for label 0,
    # ascending for label 1), then binary-search the smallest number of flips
    # that changes the prediction. required_flips stays num_train if no tried
    # count (0 .. num_train - 1) changes it.
    inner_product_order = np.argsort(w * flip_factor)[::(2 * label - 1)]
    # Fancy indexing already returns copies; the reordering is loop-invariant.
    ordered_probs = probs[inner_product_order]
    ordered_w = w[inner_product_order]
    low, high = 0, num_train
    required_flips = num_train
    while low < high:
        flips_to_try = (high - low) // 2 + low
        attacked_probs = ordered_probs.copy()
        attacked_probs[:flips_to_try] = 1 - attacked_probs[:flips_to_try]
        attacked_t, _ = newtons(attacked_probs, ordered_w, 0.9, 5e-3)
        attacked_pred = int(attacked_t >= 0)
        if attacked_pred == label:
            low = flips_to_try + 1
        else:
            high = flips_to_try
            required_flips = min(required_flips, flips_to_try)
    # #####################################################

    elapsed = datetime.now() - start
    z.append(discrete_robust if prediction == label else 0)
    print(sample, acc/(sample+1), (np.array(z) >= 1).mean(), discrete_robust, required_flips)
    print(f'{sample} ({label}) Chernoff: {chernoff_lower_bound} k (KL): {kl_robust} k: (discrete): {discrete_robust} '
          f'({elapsed.seconds}.{elapsed.microseconds // 1000:03d} s) (empirical flips to break: {required_flips})', file=log_file, flush=True)
    print(f'{sample}\t{chernoff_lower_bound}\t{discrete_robust}\t{kl_robust}\t{int(prediction == label)}\t{required_flips}',
          file=data_file, flush=True)
print(f'Overall accuracy: {acc / args.num_to_test}', file=log_file, flush=True)
print(f'mean: {k.mean()}, std: {k.std()}, median: {np.median(k)}', file=log_file, flush=True)
print(f'Overall accuracy: {acc / args.num_to_test}')
print(f'mean: {k.mean()}, std: {k.std()}, median: {np.median(k)}')
data_file.close()
log_file.close()
