import numpy as np
import argparse
from datetime import datetime
from scipy.optimize import minimize_scalar
import mpmath
from mpmath import mpf, fsub, fmul, fdiv, mp
mp.dps = 50

def compute_tv_robustness(p):
    """Return how far ``p`` lies above 0.5, floored at zero.

    The caller in this file passes a lower confidence bound on the fraction
    of ensemble members that predicted correctly; the name suggests the
    result is read as a total-variation robustness radius (interpretation
    not verifiable from this file alone).
    """
    excess = p - .5
    # Same selection rule as max(excess, 0): keep ``excess`` unless zero is
    # strictly larger.
    return 0 if 0 > excess else excess

def mistake_prob_bound_func(label, xis, probs):
    """Build a one-argument objective ``t -> prod_i E_i(t)``.

    For each element ``i`` the factor is

        exp(t * label * xis[i]) * (1 - probs[i])
        + exp(-t * label * xis[i]) * probs[i]

    and the returned closure multiplies all factors together.  The naming
    suggests this is a Chernoff-style bound to be minimised over ``t``
    (presumably with ``probs`` as per-element sign-flip probabilities;
    confirm against the caller).

    ``xis`` must support ``xis * -1`` and ``probs`` must support
    ``1 - probs`` (e.g. NumPy arrays or scalars).
    """
    def objective(t):
        scaled_t = np.float64(t) * label
        # Exponents for the "unflipped" and "flipped" outcomes respectively.
        pos_exponent = scaled_t * np.float64(xis)
        neg_exponent = scaled_t * np.float64(xis * -1)
        per_term = np.exp(pos_exponent) * (1 - probs) + np.exp(neg_exponent) * probs
        return np.prod(per_term)
    return objective

# Command-line options for the evaluation script.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0,
                    help="seed for NumPy's global random number generator")
parser.add_argument('--fraction', type=float, default=0.2,
                    help='fraction of the training set subsampled for each ensemble member')
parser.add_argument('--num-to-test', type=int, default=-1,
                    help='number of test points to evaluate; -1 evaluates every '
                         'SKIP-th test point')
# NOTE(review): --min-prob is not referenced anywhere in the visible part of
# this script; kept for command-line compatibility.
parser.add_argument('--min-prob', type=float, default=0.7)
parser.add_argument('--delta', type=float, default=0.001,
                    help='failure probability used for the Hoeffding lower bound')
parser.add_argument('--ensemble-size', type=int, default=10000,
                    help='number of subsampled predictors evaluated per test point')
parser.add_argument('--skip', type=int, default=25,
                    help='stride between evaluated test-set indices')
args = parser.parse_args()

# --seed was previously parsed but never applied, so the random subsampling
# differed from run to run.  Seed the global NumPy RNG so results are
# reproducible for a given seed.
np.random.seed(args.seed)

# Load the precomputed feature matrices and label vectors.
# NOTE(review): the label filenames contain a ')' ("y_train)_word2vec.npy");
# this looks like a typo but may match the files on disk -- confirm before
# renaming.
X_train = np.load('data/aclImdb/train/x_train_word2vec.npy')
# Append a constant-one column (intercept feature).  The row count is taken
# from the loaded array instead of a hard-coded 25000 so other dataset sizes
# also work.
X_train = np.hstack([X_train, np.ones((X_train.shape[0], 1))])
Y_train = np.squeeze(np.load('data/aclImdb/train/y_train)_word2vec.npy'))
X_test = np.load('data/aclImdb/test/x_test_word2vec.npy')
X_test = np.hstack([X_test, np.ones((X_test.shape[0], 1))])
Y_test = np.squeeze(np.load('data/aclImdb/test/y_test)_word2vec.npy'))

# Least-squares "hat"-style operator: P = X_test (X_train^T X_train)^{-1} X_train^T.
# Row i of P holds the weight each training label receives in the linear
# prediction for test point i.
XTX = np.matmul(X_train.T, X_train)
# Solve the normal equations against X_train^T rather than forming an
# explicit inverse.
XTXinvXT = np.linalg.solve(XTX, X_train.T)
P = np.matmul(X_test, XTXinvXT)

num_train = len(Y_train)
num_test = len(Y_test)
# Running count of test points whose ensemble majority vote is correct.
acc = 0.
# -1 means "evaluate every SKIP-th test point".
if args.num_to_test == -1:
    args.num_to_test = num_test // args.skip

# Fail fast on settings that would otherwise only blow up later: an
# out-of-range test index (IndexError after most of the run has completed)
# or a division by zero in the final accuracy report.
if args.skip < 1:
    raise ValueError(f'--skip must be >= 1, got {args.skip}')
if args.num_to_test < 1 or (args.num_to_test - 1) * args.skip >= num_test:
    raise ValueError(
        f'--num-to-test={args.num_to_test} with --skip={args.skip} is not '
        f'compatible with a test set of {num_test} points'
    )

# One robustness value per evaluated test point.
k = np.zeros(args.num_to_test)

# Quantities that are identical for every evaluated test point.
num_to_subsample = int(num_train * args.fraction)
# One-sided Hoeffding deviation: sqrt(log(1/delta) / (2 * ensemble_size)).
hoeff_t = np.sqrt(np.log(args.delta) / args.ensemble_size / -2)

for slot, sample in enumerate(range(0, args.num_to_test * args.skip, args.skip)):
    start = datetime.now()
    label = Y_test[sample]

    # Count how many ensemble members classify this test point correctly.
    # Each member is a linear predictor built from a random subset of the
    # training labels, drawn without replacement.
    # (A label-flip noise step was sketched here at one point but is disabled.)
    correct = 0
    for _ in range(args.ensemble_size):
        subset = np.random.choice(num_train, num_to_subsample, replace=False)
        vote = 1 if np.dot(P[sample, subset], Y_train[subset]) > 0 else -1
        correct += (vote == label)

    # Fraction of members that voted for the true label.
    margin = correct / args.ensemble_size
    if margin > 0.5:
        acc += 1

    # Lower confidence bound on the true vote fraction, then map it to a
    # robustness value.
    hoeff_lower_bound = margin - hoeff_t
    tv_robust = compute_tv_robustness(hoeff_lower_bound)
    k[slot] = tv_robust

    elapsed = datetime.now() - start
    print(f'{sample} accuracy: {margin:.3f} Hoeffding: {hoeff_lower_bound:.5f} k (TV): {tv_robust} ({elapsed.seconds}.{elapsed.microseconds // 1000:03d} s)')
print(f'Overall accuracy: {acc / args.num_to_test}')
print(f'mean: {k.mean()}, std: {k.std()}, median: {np.median(k)}')
