import torch
import numpy as np
import argparse
import pickle
import math
from sklearn.linear_model import LogisticRegression

def sigmoid(t):
    """Return the logistic function ``1 / (1 + exp(-t))`` element-wise.

    The previous implementation evaluated ``1 / (1 + exp(t))``, i.e. the
    logistic function of ``-t``, which silently flipped the sign convention
    for every caller, and it overflowed inside ``np.exp`` for large inputs.

    Args:
        t: Scalar or array-like of real values.

    Returns:
        A NumPy scalar (for scalar input) or array with the same shape as
        ``t`` holding values in ``[0, 1]``.
    """
    t = np.asarray(t)
    if not np.issubdtype(t.dtype, np.floating):
        # Integer/bool input: promote first so negation cannot wrap around
        # for unsigned dtypes.
        t = t.astype(float)
    # sigma(t) = exp(-log(1 + exp(-t))); logaddexp evaluates the log term
    # without overflowing for large |t|.
    return np.exp(-np.logaddexp(0.0, -t))

def compute_grad_loss(coef, x, y):
    """Evaluate ``-sigmoid(-y * <coef, x>) * y * x`` for one example.

    This has the form of the gradient, with respect to ``coef``, of the
    logistic loss ``log(1 + exp(-y * <coef, x>))`` -- provided ``sigmoid`` is
    the standard logistic function and ``y`` is encoded as -1/+1.  Because the
    result is multiplied by ``y``, a label of ``0`` always yields a zero
    vector.

    Args:
        coef: Model coefficients, combined with ``x`` via ``np.dot``.
        x: Feature vector of the example.
        y: Label of the example (see the encoding note above).

    Returns:
        The expression above, broadcast over the shapes of its operands.
    """
    negated_margin = -y * np.dot(coef, x)
    scale = sigmoid(negated_margin)
    return -scale * y * x

def compute_hessian(coef, xtrain):
    """Return the Hessian of the mean logistic loss over ``xtrain``.

    Computes ``X^T diag(w) X / n`` with ``w_i = s(z_i) * s(-z_i)``,
    ``z_i = <coef, x_i>`` and ``s`` the logistic function.  No regularisation
    term is included.

    The per-example weights are applied by broadcasting instead of building an
    ``n x n`` diagonal matrix, so memory stays O(n * d) rather than O(n^2).

    Args:
        coef: Coefficient vector, either 1-D of length ``d`` or 2-D with the
            coefficients in row 0 (e.g. scikit-learn's ``coef_`` of shape
            ``(1, d)``).
        xtrain: Array-like of shape ``(n, d)`` with one example per row.

    Returns:
        A ``(d, d)`` NumPy array.
    """
    xtrain = np.asarray(xtrain)
    # Row 0 only, matching the original behaviour for a (1, d) coefficient
    # matrix; atleast_2d lets a plain 1-D vector work as well.
    logits = (np.atleast_2d(coef) @ xtrain.T)[0]
    # s(z) * s(-z) = exp(-log(1 + e^z) - log(1 + e^-z)).  The product is
    # symmetric in z and logaddexp keeps it free of overflow for large |z|.
    weights = np.exp(-np.logaddexp(0.0, logits) - np.logaddexp(0.0, -logits))
    # Scaling the columns of X^T by the weights equals X^T @ diag(weights).
    return (xtrain.T * weights) @ xtrain / len(xtrain)

def get_indices_to_flip(coef, HVPs, x, y):
    """Rank the columns of ``HVPs`` by their absolute score for one example.

    The score of column ``j`` is ``|g @ HVPs[:, j]|`` where
    ``g = compute_grad_loss(coef, x, y)``.

    Args:
        coef: Model coefficients, forwarded to ``compute_grad_loss``.
        HVPs: Matrix whose columns are scored against the gradient.
        x: Feature vector of the example.
        y: Label of the example, forwarded to ``compute_grad_loss``.

    Returns:
        Column indices of ``HVPs`` ordered from largest to smallest absolute
        score (the reverse of ``np.argsort``'s ascending order).
    """
    example_grad = compute_grad_loss(coef, x, y)
    scores = np.abs(example_grad @ HVPs)
    ascending = np.argsort(scores)
    return ascending[::-1]


# Upper bound used by the search loop: an attack is only pursued while it
# would stay strictly below this many label flips.
_MAX_FLIPS = 500


def _load_binary_labels(path):
    """Load digit labels from ``path`` and remap digit 1 -> 0 and digit 7 -> 1."""
    labels = np.load(path)
    # Order matters: 1 -> 0 must happen before 7 -> 1, otherwise the freshly
    # written 1s would be remapped again.
    labels[labels == 1] = 0
    labels[labels == 7] = 1
    return labels


def _fit_model(features, labels, C, max_iter):
    """Fit an L2-regularised, intercept-free logistic regression."""
    return LogisticRegression(C=C, tol=1e-8, fit_intercept=False, solver='lbfgs',
                              warm_start=True, max_iter=max_iter,
                              random_state=0).fit(features, labels)


def _load_attacks(path):
    """Return the saved ``{test index: train indices to flip}`` dict, or ``{}``."""
    try:
        # NOTE: unpickling executes arbitrary code; only load files this
        # script wrote itself.
        with open(path, 'rb') as handle:
            attacks = pickle.load(handle)
    except FileNotFoundError:
        print('failed to load, restarting')
        return {}
    print('loaded dictionary')
    return attacks


def _save_attacks(attacks, path):
    """Pickle ``attacks`` to ``path``, closing the file so the data is flushed."""
    with open(path, 'wb') as handle:
        pickle.dump(attacks, handle)


def _best_flip_count(attacks, test_idx, num_train):
    """Return the size of the best known attack on ``test_idx`` (``num_train`` if none)."""
    return len(attacks[test_idx]) if test_idx in attacks else num_train


def _generate_attacks(coefs, train_x, y_train, test_x, y_test, orig_predictions,
                      attacks, attack_path, refit):
    """Greedily search for small label-flip sets that change test predictions.

    For every test point the training points are ranked with
    ``get_indices_to_flip`` and flipped one at a time, refitting after each
    flip, until the test point's prediction changes or the flip budget is
    exhausted.  Any test point whose prediction changes along the way is
    recorded in ``attacks`` (and checkpointed to ``attack_path``) if the
    current flip set is smaller than its best known one.
    """
    num_train = len(train_x)
    num_test = len(test_x)

    print('Calculating gradients...')
    # compute_grad_loss multiplies by the label, so it needs a -1/+1 encoding;
    # with 0/1 labels every class-0 example would get an all-zero gradient
    # and therefore an arbitrary ranking.
    signed_train = 2 * y_train.astype(int) - 1
    signed_test = 2 * y_test.astype(int) - 1

    # NOTE(review): compute_hessian returns the Hessian of the unregularised
    # mean loss; the L2 term of the fitted objective is not added here --
    # confirm whether that is intended.
    hessian = compute_hessian(coefs, train_x)
    # Column i holds grad(loss_i with its label) - grad(loss_i with the
    # flipped label).
    grad_diffs = np.stack(
        [compute_grad_loss(coefs, x, y) - compute_grad_loss(coefs, x, -y)
         for x, y in zip(train_x, signed_train)],
        axis=1,
    )
    # One solve with a matrix right-hand side instead of one solve per
    # training point.
    HVPs = np.linalg.solve(hessian, grad_diffs)

    print('Generating label poisoning attacks...')
    for test_idx in range(num_test):
        num_small = sum(len(flips) < _MAX_FLIPS for flips in attacks.values())
        print(f'% less than {_MAX_FLIPS}:', num_small / num_test)
        print(f'Test Index {test_idx} (cur min: {_best_flip_count(attacks, test_idx, num_train)})')

        ranked = get_indices_to_flip(coefs, HVPs, test_x[test_idx], signed_test[test_idx])
        corrupted_y_train = np.copy(y_train)
        indices_to_flip = []
        # Test points whose prediction has already changed during this search;
        # each is considered for recording only the first time it changes.
        already_changed = np.zeros(num_test, dtype=bool)

        # Only continue while one more flip would still beat the best known
        # attack on this test point and stay under the flip budget.
        while len(indices_to_flip) + 1 < min(_best_flip_count(attacks, test_idx, num_train), _MAX_FLIPS):
            # Flips are taken in ranking order, so the next candidate is the
            # first ranked index not yet used.
            index_to_flip = ranked[len(indices_to_flip)]
            indices_to_flip.append(index_to_flip)
            corrupted_y_train[index_to_flip] = int(y_train[index_to_flip] == 0)

            new_probs = refit(corrupted_y_train).predict_proba(test_x)
            new_predictions = (new_probs[:, 0] < new_probs[:, 1]).astype(int)
            changed = new_predictions != orig_predictions

            for changed_prediction in np.flatnonzero(changed & ~already_changed):
                if len(indices_to_flip) < _best_flip_count(attacks, changed_prediction, num_train):
                    attacks[changed_prediction] = np.copy(indices_to_flip)
                    print(f'Broke test index {changed_prediction} in {len(indices_to_flip)} flips ({len(attacks)} total)')
                    # Checkpoint immediately so progress survives interruption.
                    _save_attacks(attacks, attack_path)
            already_changed |= changed

            if changed[test_idx]:
                break


def _verify_attacks(y_train, test_x, orig_predictions, attacks, refit):
    """Refit with each stored flip set and check the target prediction changes.

    Raises:
        AssertionError: If a stored attack no longer changes the prediction
            of its test point.
    """
    for test_idx in range(len(test_x)):
        print('Confirming Test Index', test_idx)
        if test_idx not in attacks:
            continue
        indices_to_flip = attacks[test_idx]
        corrupted_y_train = np.copy(y_train)
        corrupted_y_train[indices_to_flip] = (y_train[indices_to_flip] == 0).astype(int)
        new_prediction = refit(corrupted_y_train).predict(test_x[test_idx:test_idx + 1])
        # Raised explicitly (not via ``assert``) so the check survives ``-O``.
        if new_prediction[0] == orig_predictions[test_idx]:
            raise AssertionError(
                f'stored attack for test index {test_idx} does not change its prediction')


def main():
    """Run the label-flipping attack search, or verify previously saved attacks.

    Command line:
        influencedir   Directory holding ``influence_successful_attacks.dict``.
        --no-compute   Skip the search and instead re-check every saved attack.
                       (The check used to sit behind an early exit and was
                       unreachable.)
        --mult         Parsed but not used by this script.
        --seed         Seed for NumPy's global random generator.
    """
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('influencedir')
    parser.add_argument('--no-compute', action='store_true', default=False)
    parser.add_argument('--mult', type=float, default=1.1)
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    np.random.seed(args.seed)
    train_embeddings = np.load('./data/binarymnist/train/xtrain_17_features.npy')
    test_embeddings = np.load('./data/binarymnist/test/xtest_17_features.npy')
    num_train = len(train_embeddings)
    num_test = len(test_embeddings)

    y_train = _load_binary_labels('./data/binarymnist/train/ytrain_17.npy')
    y_test = _load_binary_labels('./data/binarymnist/test/ytest_17.npy')

    attack_path = f'{args.influencedir}/influence_successful_attacks.dict'
    labels_to_flip = _load_attacks(attack_path)

    weight_decay = 1e-3
    C = 1.0 / (num_train * weight_decay)
    max_lbfgs_iter = 1000

    def refit(labels):
        return _fit_model(train_embeddings, labels, C, max_lbfgs_iter)

    model = refit(y_train)
    coefs = model.coef_
    orig_predictions = model.predict(test_embeddings)
    print('Original model test accuracy:', (orig_predictions == y_test).sum() / num_test)
    orig_incorrect = np.argwhere(orig_predictions != y_test).flatten()
    print('Incorrect indices:', orig_incorrect)

    if not args.no_compute:
        # Make sure checkpoints can be written before starting the long search.
        os.makedirs(args.influencedir, exist_ok=True)
        _generate_attacks(coefs, train_embeddings, y_train, test_embeddings, y_test,
                          orig_predictions, labels_to_flip, attack_path, refit)
        return

    # Double check that each saved result is actually correct.
    _verify_attacks(y_train, test_embeddings, orig_predictions, labels_to_flip, refit)


if __name__ == '__main__':
    main()
