import numpy as np
import matplotlib

matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
from typing import *
import pandas as pd
import seaborn as sns

sns.set()
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']


class Accuracy(object):
    """Accuracy curves computed from one tab-separated results log.

    The methods read the columns ``correct``, ``flips_discrete``, ``flips_kl``
    and ``empirical`` from the log. ``correct`` is combined with ``&``, so it is
    presumably a per-example boolean flag (TODO confirm against the log writer).
    The parsed table is cached per path so that drawing several curves from the
    same object does not re-read the file each time.
    """

    def __init__(self, data_file_path: str):
        self.data_file_path = data_file_path
        # (path, DataFrame) from the most recent load, or None before first use.
        self._cache = None

    def _frame(self) -> pd.DataFrame:
        """Return the parsed log, reading it from disk only when needed.

        The cache is keyed on ``data_file_path``, so reassigning that attribute
        forces a fresh read.
        """
        # getattr keeps this working for instances created without __init__.
        cached = getattr(self, "_cache", None)
        if cached is None or cached[0] != self.data_file_path:
            cached = (self.data_file_path,
                      pd.read_csv(self.data_file_path, delimiter="\t"))
            self._cache = cached
        return cached[1]

    def all_empirical(self, flips):
        """Return, for each value in ``flips``, the fraction of rows whose
        ``empirical`` column is >= that value.

        Unlike ``at_flips``, this does not require the row to be ``correct``.
        """
        robust = self._frame()["empirical"]
        return np.array([(robust >= flip).mean() for flip in flips])

    def at_flips(self, flips: np.ndarray, kl: bool = False) -> np.ndarray:
        """Evaluate ``at_flip`` at every flip budget in ``flips``.

        Args:
            flips: flip budgets to evaluate.
            kl: read ``flips_kl`` instead of ``flips_discrete``.

        Returns:
            Array with one accuracy value per entry of ``flips``.
        """
        df = self._frame()
        return np.array([self.at_flip(df, flip, kl) for flip in flips])

    def at_flip(self, df: pd.DataFrame, flips: int, kl: bool = False) -> float:
        """Fraction of rows that are ``correct`` and whose robustness column is
        at least ``flips``.

        Args:
            df: the parsed results table.
            flips: a single flip budget (despite the plural name).
            kl: read ``flips_kl`` instead of ``flips_discrete``.
        """
        robust = df["flips_kl"] if kl else df["flips_discrete"]
        return (df["correct"] & (robust >= flips)).mean()

    def accuracy(self) -> float:
        """Return ``at_flip`` at a budget of 0 flips.

        This equals plain accuracy provided ``flips_discrete`` is never
        negative (assumption; not enforced here).
        """
        return self.at_flip(self._frame(), 0)

class Line:
    """One curve to be drawn by ``plot_certified_accuracy``.

    Attributes:
        quantity: the ``Accuracy`` object the curve's values are read from.
        legend: legend label. ``plot_certified_accuracy`` special-cases the
            exact string ``'undefended'``.
        plot_fmt: a format string kept on the instance; nothing visible in this
            file reads it.
    """

    def __init__(self, quantity: Accuracy, legend: str, plot_fmt: str = ""):
        self.quantity, self.legend, self.plot_fmt = quantity, legend, plot_fmt


def plot_certified_accuracy(outfile: str, title: str, max_flips: int, num_train: int, loc: str, uncert: float, num_cls: int,
                            lines: List[Line], empirical: bool=False) -> None:
    """Plot accuracy versus number of label flips and save it as PDF and PNG.

    ``outfile + ".pdf"`` is saved without a title; ``outfile + ".png"`` is saved
    after the title has been added.

    Args:
        outfile: output path without extension.
        title: figure title (only present in the PNG).
        max_flips: curves are evaluated at 0..max_flips flips.
        num_train: training-set size; the top axis shows flips as a percentage of it.
        loc: matplotlib legend location string.
        uncert: value drawn as the horizontal '$q=0$, uncertified' reference
            line (only when ``empirical`` is False).
        num_cls: number of classes; ``1/num_cls`` is drawn as the '$g$ constant' line.
        lines: curves to draw. A line whose legend is exactly 'undefended' is drawn dashed.
        empirical: if True, each non-undefended line gets a dashed curve from
            ``all_empirical`` and no uncertified baseline is drawn; otherwise each
            line gets a dashed horizontal line at its ``accuracy()``.
    """
    flips = np.arange(max_flips + 1)
    fig, ax1 = plt.subplots()
    try:
        for i, line in enumerate(lines):
            # Cycle the palette so more curves than colors cannot IndexError.
            color = colors[i % len(colors)]
            undefended = line.legend == 'undefended'
            ax1.plot(flips, line.quantity.at_flips(flips), color=color,
                     linestyle='--' if undefended else '-', label=line.legend)
            # Companion dashed reference curve in the same color, kept out of the legend.
            if empirical:
                if not undefended:
                    ax1.plot(flips, line.quantity.all_empirical(flips), color=color,
                             linestyle='--', label='_nolegend_')
            else:
                ax1.plot(flips, np.repeat(line.quantity.accuracy(), len(flips)), color=color,
                         linestyle='--', label='_nolegend_')
        if not empirical:
            ax1.plot(flips, np.repeat(uncert, len(flips)), 'k-.', label='$q=0$, uncertified')
        ax1.plot(flips, np.repeat(1 / num_cls, len(flips)), 'k:', label='$g$ constant')
        # Labels are attached to the artists themselves, so the legend cannot
        # drift out of step with the plotted handles.
        ax1.legend(loc=loc, fontsize=13)

        ax2 = ax1.twiny()
        # twiny shares the y axis, so this limit applies to both axes.
        ax1.set_ylim((0, 1))
        ax1.set_xlim((0, max_flips))
        # The top axis re-expresses the same flip range as a percent of the training set.
        ax2.set_xlim((0, max_flips / num_train * 100))
        ax1.tick_params(labelsize=14)
        ax2.tick_params(labelsize=12)
        ax1.set_xlabel("Label Flips", fontsize=16)
        ax1.set_ylabel("Certified Accuracy", fontsize=16)
        ax2.set_xlabel("Percentage of Training Set", fontsize=16)
        ax2.grid(False)
        fig.tight_layout()
        fig.savefig(outfile + ".pdf", dpi=200)
        # The title goes on the twin (top) axes and is only in the PNG.
        ax2.set_title(title, fontsize=20)
        fig.tight_layout()
        fig.savefig(outfile + ".png", dpi=300)
    finally:
        # Always release the figure, even if loading or saving fails.
        plt.close(fig)

if __name__ == "__main__":
    import os

    # savefig does not create directories; make sure the output folder exists.
    os.makedirs("plots", exist_ok=True)

    plot_certified_accuracy(
        "plots/MNIST_vary_q", "MNIST 1/7, vary $q$", 2000, 13007, 'lower left', .9926, 2, [
            Line(Accuracy("logs/mnist/flipprob0.3_regularized.txt"), "$q = 0.3$"),
            Line(Accuracy("logs/mnist/flipprob0.4_regularized.txt"), "$q = 0.4$"),
            Line(Accuracy("logs/mnist/flipprob0.45_regularized.txt"), "$q = 0.45$"),
            Line(Accuracy("logs/mnist/flipprob0.475_regularized.txt"), "$q = 0.475$"),
            Line(Accuracy("logs/mnist/undefended.txt"), "undefended"),
        ])
    plot_certified_accuracy(
        "plots/MNIST_vary_q_attack", "MNIST 1/7, vary $q$", 2000, 13007, 'lower left', .9926, 2, [
            Line(Accuracy("logs/mnist/flipprob0.3_regularized.txt_tmp"), "$q = 0.3$"),
            Line(Accuracy("logs/mnist/flipprob0.4_regularized.txt_tmp"), "$q = 0.4$"),
            Line(Accuracy("logs/mnist/flipprob0.475_regularized.txt_tmp"), "$q = 0.475$"),
            Line(Accuracy("logs/mnist/undefended.txt"), "undefended"),
        ], empirical=True)
    # Raw strings: LaTeX commands such as \ell and \lambda are not valid Python escapes.
    plot_certified_accuracy(
        "plots/MNIST_reg_vs_no", r"MNIST 1/7, Effects of $\ell_2$ Regularization", 1900, 13007, 'lower right', 0, 2, [
            Line(Accuracy("logs/mnist/flipprob0.3_regularized.txt"), r"$q = 0.3, \lambda\approx12291$"),
            Line(Accuracy("logs/mnist/flipprob0.3.txt"), r"$q = 0.3, \lambda=0$"),
            Line(Accuracy("logs/mnist/flipprob0.4_regularized.txt"), r"$q = 0.4, \lambda\approx13237$"),
            Line(Accuracy("logs/mnist/flipprob0.4.txt"), r"$q = 0.4, \lambda=0$"),
        ])
    plot_certified_accuracy(
        "plots/imdb_vary_q", "IMDB, vary q", 350, 25000, 'upper right', .8478, 2, [
            Line(Accuracy("logs/imdb/flipprob0.01_regularized.txt"), "$q = 0.01$"),
            Line(Accuracy("logs/imdb/flipprob0.025_regularized.txt"), "$q = 0.025$"),
            Line(Accuracy("logs/imdb/flipprob0.05_regularized.txt"), "$q = 0.05$"),
            Line(Accuracy("logs/imdb/flipprob0.1_regularized.txt"), "$q = 0.1$"),
        ])
    plot_certified_accuracy(
        "plots/imdb_vary_q_ica", "IMDB, vary q", 500, 25000, 'upper right', .6495, 2, [
            Line(Accuracy("logs/imdb/flipprob0.01_regularized.txt_tmp"), "$q = 0.01$"),
            Line(Accuracy("logs/imdb/flipprob0.01_regularized.txt"), "$q = 0.01$ old"),
            Line(Accuracy("logs/imdb/flipprob0.025_regularized.txt_tmp"), "$q = 0.025$"),
            Line(Accuracy("logs/imdb/flipprob0.05_regularized.txt_tmp"), "$q = 0.05$"),
            Line(Accuracy("logs/imdb/flipprob0.1_regularized.txt"), "$q = 0.1$ old"),
        ])
    plot_certified_accuracy(
        "plots/dogfish_poisoning_results", "Dogfish Inception Poisoning Robustness", 150, 1800, 'upper right', .8983, 2, [
            Line(Accuracy("logs/dogfish/flipprob0.0001_regularized.txt"), "$q = 0.0001$"),
            Line(Accuracy("logs/dogfish/flipprob0.01_regularized.txt"), "$q = 0.01$"),
            Line(Accuracy("logs/dogfish/flipprob0.05_regularized.txt"), "$q = 0.05$"),
            Line(Accuracy("logs/dogfish/undefended.txt"), "undefended"),
        ]
    )
    plot_certified_accuracy(
        "plots/dogfish_poisoning_results_rica", "Dogfish Inception Poisoning Robustness (RICA)", 80, 1800, 'upper right', .6717, 2, [
            Line(Accuracy("logs/dogfish/flipprob0.0001_regularized.txt_rica"), "$q = 0.0001$"),
            Line(Accuracy("logs/dogfish/flipprob0.001_regularized.txt_rica"), "$q = 0.001$"),
            Line(Accuracy("logs/dogfish/flipprob0.01_regularized.txt_rica"), "$q = 0.01$"),
            Line(Accuracy("logs/dogfish/undefended.txt"), "undefended"),
        ]
    )
    plot_certified_accuracy(
        "plots/dogfish_vary_q_attack", "Dogfish, vary q", 150, 1800, 'upper right', .8983, 2, [
            Line(Accuracy("logs/dogfish/flipprob0.0001_regularized_empirical.txt"), "$q = 0.0001$"),
            Line(Accuracy("logs/dogfish/flipprob0.01_regularized_empirical.txt"), "$q = 0.01$"),
            Line(Accuracy("logs/dogfish/undefended.txt"), "undefended"),
        ], empirical=True
    )
    plot_certified_accuracy(
        "plots/multiMNIST_vary_q_ica_tmp", "Multiclass MNIST, vary $q$", 500, 60000, 'upper right', .8112, 10, [
            Line(Accuracy("logs/mnist/multiclass_flipprob0.0125.txt"), "$q = 0.0125$"),
            Line(Accuracy("logs/mnist/multiclass_flipprob0.025.txt"), "$q = 0.025$"),
            Line(Accuracy("logs/mnist/multiclass_flipprob0.05.txt"), "$q = 0.05$"),
        ])
    plot_certified_accuracy(
        "plots/CIFAR10_vary_q", "CIFAR10, vary $q$", 500, 50000, 'upper right', .9116, 10, [
            Line(Accuracy("logs/cifar/multiclass_flipprob0.012_regularized.txt"), "$q = 0.012$"),
            Line(Accuracy("logs/cifar/multiclass_flipprob0.025_regularized.txt_tmp"), "$q = 0.025$"),
            Line(Accuracy("logs/cifar/multiclass_flipprob0.1_regularized.txt_tmp"), "$q = 0.1$"),
        ])