import torch
import numpy as np
import argparse
import pickle
import math
from sklearn.linear_model import LogisticRegression

def flip_labels(labels, indices):
    """Return a copy of *labels* with the entries at *indices* flipped.

    A label equal to 0 becomes 1 and any other label becomes 0. The input
    array is not modified and the result keeps the input's dtype.
    """
    flipped = np.copy(labels)
    # Cast the boolean mask to the label dtype explicitly; the ``np.bool``
    # alias used for this previously no longer exists in current NumPy.
    flipped[indices] = (labels[indices] == 0).astype(labels.dtype)
    return flipped


def pairwise_l2_distances(test_points, train_points):
    """Return a (num_test, num_train) array of Euclidean distances.

    Row ``i`` holds the distance from ``test_points[i]`` to every row of
    ``train_points``. Inputs are converted with ``np.asarray``.
    """
    test_points = np.asarray(test_points)
    train_points = np.asarray(train_points)
    distances = np.empty((len(test_points), len(train_points)))
    for row, point in enumerate(test_points):
        distances[row] = np.linalg.norm(train_points - point, axis=1)
    return distances


def load_attacks(path):
    """Load the pickled ``{test index: flipped train indices}`` dict.

    Returns an empty dict when *path* does not exist.
    """
    # NOTE(security): pickle can execute arbitrary code while loading; only
    # point this at attack files produced by a trusted run of this script.
    try:
        with open(path, 'rb') as handle:
            return pickle.load(handle)
    except FileNotFoundError:
        return {}


def save_attacks(labels_to_flip, path):
    """Pickle *labels_to_flip* to *path*, closing the file afterwards."""
    with open(path, 'wb') as handle:
        pickle.dump(labels_to_flip, handle)


def fit_logreg(features, labels, C, max_iter):
    """Fit the L-BFGS logistic regression used for every (re)training."""
    return LogisticRegression(C=C, tol=1e-8, fit_intercept=False, solver='lbfgs',
                              warm_start=True, max_iter=max_iter,
                              random_state=0).fit(features, labels)


def changed_predictions(model, features, reference_predictions):
    """Return the indices where *model*'s predictions differ from the reference."""
    return np.argwhere(model.predict(features) != reference_predictions).flatten()


def record_attacks(labels_to_flip, changed_indices, flips, num_train, save_path=None):
    """Store *flips* for every changed test index it improves upon.

    *flips* replaces the entry for a test index when it is shorter than the
    set already stored there; a missing entry counts as length *num_train*.
    When *save_path* is given the dict is re-pickled after each update so an
    interrupted run keeps its progress.

    Returns the list of test indices whose entry was updated.
    """
    improved = []
    num_flips = len(flips)
    for changed in changed_indices:
        best = labels_to_flip.get(changed)
        best_size = num_train if best is None else len(best)
        if num_flips < best_size:
            labels_to_flip[changed] = flips
            print(f'Broke test index {changed} in {num_flips} flips ({len(labels_to_flip.keys())} total)')
            if save_path is not None:
                save_attacks(labels_to_flip, save_path)
            improved.append(changed)
    return improved


def main():
    """Search for label-flipping attacks and optionally write a log.

    For each test point, the labels of its nearest training points (by L2
    distance between embeddings) are flipped in growing batches and the
    classifier is retrained; the smallest flip set seen to change each test
    prediction is kept in ``<influencedir>/successful_attacks.dict``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('influencedir')
    parser.add_argument('--logdir', default=None)
    parser.add_argument('--no-compute', action='store_true', default=False)
    parser.add_argument('--mult', type=float, default=1.1)
    args = parser.parse_args()

    train_embeddings = torch.tensor(np.load('data/dogfish/dogfish_xtrain_embeddings.npy'))
    num_train = len(train_embeddings)
    y_train = np.load('data/dogfish/dogfish_ytrain.npy')

    test_embeddings = torch.tensor(np.load('data/dogfish/dogfish_xtest_embeddings.npy'))
    num_test = len(test_embeddings)
    y_test = np.load('data/dogfish/dogfish_ytest.npy')

    influences_path = f'{args.influencedir}/feature_influences.npy'
    attacks_path = f'{args.influencedir}/successful_attacks.dict'

    # The distance matrix is cached on disk; recompute only when it is missing.
    try:
        feature_influences = np.load(influences_path)
    except FileNotFoundError:
        feature_influences = pairwise_l2_distances(test_embeddings, train_embeddings)
        np.save(influences_path, feature_influences)

    labels_to_flip = load_attacks(attacks_path)

    weight_decay = 1e-3
    C = 1.0 / (num_train * weight_decay)
    max_lbfgs_iter = 1000
    model = fit_logreg(train_embeddings, y_train, C, max_lbfgs_iter)
    orig_predictions = model.predict(test_embeddings)
    print('Original model test accuracy:', (orig_predictions == y_test).sum() / num_test)
    orig_incorrect = np.argwhere(orig_predictions != y_test).flatten()
    print('Incorrect indices:', orig_incorrect)

    if not args.no_compute:
        with torch.no_grad():
            print('Generating label poisoning attacks...')

            for test_idx in range(num_test):
                sizes = [len(x) for x in labels_to_flip.values()]
                # Avoid the empty-mean RuntimeWarning before any attack is found.
                print('Average:', np.mean(sizes) if sizes else float('nan'))
                print('Test Index', test_idx)
                indices_to_flip = np.argsort(feature_influences[test_idx])  # Sort by smallest l2 distance
                num_to_flip = 11
                prev_flipped_count = 0
                while num_to_flip < 500:
                    flips = indices_to_flip[:num_to_flip]
                    corrupted_model = fit_logreg(train_embeddings, flip_labels(y_train, flips),
                                                 C, max_lbfgs_iter)
                    diff_predictions = changed_predictions(corrupted_model, test_embeddings,
                                                           orig_predictions)
                    # Only record when the number of changed predictions moved
                    # since the last recorded step.
                    if len(diff_predictions) != prev_flipped_count:
                        record_attacks(labels_to_flip, diff_predictions, flips,
                                       num_train, attacks_path)
                        prev_flipped_count = len(diff_predictions)
                    # Always advance by at least one so --mult <= 1 cannot hang.
                    num_to_flip = max(num_to_flip + 1, math.ceil(num_to_flip * args.mult))

            # Second pass: retry each learned flip set with progressively finer
            # decrements to shrink it. A binary search would be faster.
            decs = [100, 20, 5, 1]
            for i, dec in enumerate(decs):
                print('decrementing by', dec)
                for test_idx in range(num_test):
                    print('Test Index', test_idx)
                    if test_idx not in labels_to_flip:
                        continue
                    indices_to_flip = labels_to_flip[test_idx]
                    num_to_flip = len(indices_to_flip) - dec
                    while num_to_flip > 0:
                        # Stop once past the range the previous, coarser
                        # decrement already covered.
                        if i != 0 and num_to_flip < len(indices_to_flip) - 1 - decs[i - 1]:
                            break
                        flips = indices_to_flip[:num_to_flip]
                        corrupted_model = fit_logreg(train_embeddings, flip_labels(y_train, flips),
                                                     C, max_lbfgs_iter)
                        diff_predictions = changed_predictions(corrupted_model, test_embeddings,
                                                               orig_predictions)
                        if test_idx not in diff_predictions:
                            break
                        record_attacks(labels_to_flip, diff_predictions, flips,
                                       num_train, attacks_path)
                        num_to_flip -= dec

            # Double check that each stored flip set really changes the prediction.
            for test_idx in range(num_test):
                print('Confirming Test Index', test_idx)
                if test_idx not in labels_to_flip:
                    continue
                corrupted_y_train = flip_labels(y_train, labels_to_flip[test_idx])
                corrupted_model = fit_logreg(train_embeddings, corrupted_y_train, C, max_lbfgs_iter)
                new_prediction = corrupted_model.predict(test_embeddings[test_idx:test_idx + 1])
                # Raised explicitly so the check survives ``python -O``.
                if new_prediction[0] == orig_predictions[test_idx]:
                    raise AssertionError(
                        f'Stored flips for test index {test_idx} do not change its prediction')

    if args.logdir is not None:
        incorrect = set(orig_incorrect.tolist())
        with open(f'{args.logdir}/undefended.txt', 'w') as logfile:
            print('ID\tp\tflips_discrete\tflips_kl\tcorrect', file=logfile)
            for idx in range(num_test):
                # Test points with no recorded attack are logged with num_train;
                # otherwise one less than the size of the stored flip set.
                if idx in labels_to_flip:
                    flips_discrete = len(labels_to_flip[idx]) - 1
                else:
                    flips_discrete = num_train
                print(f'{idx}\t \t{flips_discrete}\t \t{int(idx not in incorrect)}',
                      file=logfile)


if __name__ == '__main__':
    main()
