import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data

class Embedder(nn.Module):
    """Small convolutional network with a 50-feature embedding and 8 outputs.

    ``embed`` runs two convolution/pooling stages followed by two linear
    layers and returns a 50-feature vector per sample. ``forward`` applies
    one more linear layer (``fc3``) on top of that embedding to produce
    8 values per sample.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order on purpose: it fixes both the
        # state-dict key order and the order in which seeded random
        # initialisation is consumed.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 50)
        self.fc3 = nn.Linear(50, 8)

    def embed(self, x):
        """Return the 50-feature embedding (output of ``fc2``) for ``x``."""
        # First stage: convolution -> 2x2 max-pool -> ReLU.
        stage_one = F.max_pool2d(self.conv1(x), 2)
        stage_one = F.relu(stage_one)

        # Second stage adds channel-wise dropout before pooling.
        stage_two = self.conv2_drop(self.conv2(stage_one))
        stage_two = F.relu(F.max_pool2d(stage_two, 2))

        # Flatten to rows of 320 features (20 channels * 4 * 4 for a
        # 28x28 single-channel input).
        flat = stage_two.view(-1, 320)

        hidden = F.relu(self.fc1(flat))
        # Functional dropout must be told explicitly whether we are training.
        hidden = F.dropout(hidden, training=self.training)
        return self.fc2(hidden)

    def forward(self, x):
        """Return the 8 outputs of ``fc3`` applied to ``embed(x)``."""
        embedding = self.embed(x)
        return self.fc3(embedding)

# Maps each retained raw digit label to a contiguous class index 0..7.
# Labels 1 and 7 have no entry, so looking them up raises KeyError.
digit_map = dict(zip((0, 2, 3, 4, 5, 6, 8, 9), range(8)))

def do_map(vals):
    """Translate raw digit labels into class indices using ``digit_map``.

    Args:
        vals: iterable of elements exposing ``.item()`` (e.g. a 1-D tensor
            of labels).

    Returns:
        A new tensor holding ``digit_map[label]`` for every element.

    Raises:
        KeyError: if a label has no entry in ``digit_map``.
    """
    mapped = []
    for val in vals:
        mapped.append(digit_map[val.item()])
    return torch.tensor(mapped)


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs ``x[i]`` with ``y[i]``.

    The length is taken from ``x`` once at construction time and stored in
    ``self.len``.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y
        self.len = len(x)

    def __getitem__(self, index):
        """Return the ``(x[index], y[index])`` pair."""
        sample = self.x[index]
        label = self.y[index]
        return sample, label

    def __len__(self):
        """Return the number of samples recorded at construction."""
        return self.len


def train_pretrained(epoch):
    """Run one pass over ``train_loader`` training ``network`` as a classifier.

    Uses the module-level ``network``, ``optimizer``, ``train_loader`` and
    ``args``. Labels are converted with ``do_map`` before the cross-entropy
    loss is computed. Progress is printed every ``args.log_interval``
    batches, and the network and optimizer state dicts are written to
    ``args.outdir`` once the pass is complete.

    Args:
        epoch: epoch number, used only in the progress message.
    """
    progress_line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'

    network.train()
    for step, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = network(inputs)
        batch_loss = F.cross_entropy(logits, do_map(labels))
        batch_loss.backward()
        optimizer.step()

        if step % args.log_interval != 0:
            continue
        samples_seen = step * len(inputs)
        percent_done = 100. * step / len(train_loader)
        print(progress_line.format(
            epoch, samples_seen, len(train_loader.dataset),
            percent_done, batch_loss.item()))

    # Checkpoint both model and optimizer so a later --resume can continue.
    torch.save(network.state_dict(), f'{args.outdir}/embedder.pth')
    torch.save(optimizer.state_dict(), f'{args.outdir}/optimizer.pth')

def test_pretrained():
    """Evaluate ``network`` on ``test_loader`` and print loss and accuracy.

    Uses the module-level ``network`` and ``test_loader``. The loss is
    accumulated per sample (``reduction='sum'``) and divided by the dataset
    size, so the printed value is a per-sample average that does not depend
    on how many batches the loader produces or on the size of the last one.
    """
    network.eval()
    loss_sum = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            # Map labels once and reuse them for both loss and accuracy.
            mapped = do_map(target)
            output = network(data)
            loss_sum += F.cross_entropy(output, mapped, reduction='sum').item()
            # .item() keeps the running count a plain Python int.
            correct += (output.argmax(1) == mapped).sum().item()

    n_samples = len(test_loader.dataset)
    # Guard the divisions so an empty dataset reports zeros instead of raising.
    denom = max(n_samples, 1)
    print('\nTest set: loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        loss_sum / denom, correct, n_samples,
        100. * correct / denom))

def finetune(epoch, classes):
    """Finetune ``network`` for one pass over ``finetune_train_loader``.

    Uses the module-level ``network``, ``optimizer``, ``finetune_train_loader``,
    ``finetune_test_loader`` and ``args``. For every training batch:

    1. keep only samples whose label is in ``classes`` and turn the labels
       into 0 (``classes[0]``) / 1 (``classes[1]``);
    2. flip each label where a sampled noise bit is 1 (probability 0.49);
    3. fit a least-squares head ``beta = (X^T X)^{-1} X^T y`` on the
       embeddings of the batch;
    4. sum the binary cross-entropy of ``sigmoid(embed(test) @ beta)`` over
       all batches of ``finetune_test_loader`` (clean labels), backpropagate
       through the whole computation and take an optimizer step;
    5. print the loss and save the network state dict.

    Args:
        epoch: epoch number, used only in the progress message.
        classes: the two raw digit labels forming the binary task.
    """
    network.train()
    # torch.solve was removed from recent PyTorch releases; prefer
    # torch.linalg.solve when it exists and fall back otherwise.
    linalg_solve = getattr(getattr(torch, 'linalg', None), 'solve', None)

    for batch_idx, (data, target) in enumerate(finetune_train_loader):
        keep = torch.tensor(np.isin(target, classes))
        data, target = data[keep], target[keep]

        # Derive binary labels from the untouched raw labels in a single
        # comparison. Relabelling in place in two steps (first class -> 0,
        # then second class -> 1) turns every label into 1 whenever the
        # second class is the digit 0.
        labels = (target == classes[1]).long()

        # Label noise: a noise bit of 1 flips the label (|y - 1| = 1 - y).
        noise = torch.tensor(np.random.choice(2, len(labels), p=[0.51, 0.49])).long()
        labels = torch.abs(labels - noise)

        optimizer.zero_grad()
        embedding = network.embed(data)
        gram = torch.mm(embedding.t(), embedding)
        if linalg_solve is not None:
            # Note the argument order: linalg.solve(A, B) solves A X = B.
            xtx_inv_xt = linalg_solve(gram, embedding.t())
        else:
            xtx_inv_xt, _ = torch.solve(embedding.t(), gram)
        beta = torch.mm(xtx_inv_xt, torch.unsqueeze(labels, dim=1).float())

        loss = 0
        for test_data, test_target in finetune_test_loader:
            test_keep = torch.tensor(np.isin(test_target, classes))
            test_data, test_target = test_data[test_keep], test_target[test_keep]
            test_labels = (test_target == classes[1]).float()
            # squeeze(1) keeps a 1-D result even for a single-sample batch,
            # so the shape always matches test_labels.
            scores = torch.sigmoid(torch.mm(network.embed(test_data), beta)).squeeze(1)
            loss += F.binary_cross_entropy(scores, test_labels)
        loss.backward()
        optimizer.step()
        print('Finetune Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(finetune_train_loader.dataset),
            100. * batch_idx / len(finetune_train_loader), loss.item()))
        # Checkpoint after every batch.
        torch.save(network.state_dict(), f'{args.outdir}/finetuned_embedder.pth')

def finetune_test(classes):
    """Evaluate the least-squares head on the binary task and print results.

    Uses the module-level ``network``, ``finetune_train_loader`` and
    ``finetune_test_loader``. Only the first batch drawn from
    ``finetune_train_loader`` is used: its samples from ``classes`` get
    0/1 labels with random flips (noise bit 1 with probability 0.49), a
    least-squares head ``beta = (X^T X)^{-1} X^T y`` is fitted on their
    embeddings, and that head is scored on every batch of
    ``finetune_test_loader`` with clean labels. The printed loss is the
    per-sample average binary cross-entropy.

    Args:
        classes: the two raw digit labels forming the binary task;
            ``classes[0]`` becomes label 0 and ``classes[1]`` label 1.
    """
    network.eval()
    # torch.solve was removed from recent PyTorch releases; prefer
    # torch.linalg.solve when it exists and fall back otherwise.
    linalg_solve = getattr(getattr(torch, 'linalg', None), 'solve', None)

    with torch.no_grad():
        for data, target in finetune_train_loader:
            keep = torch.tensor(np.isin(target, classes))
            data, target = data[keep], target[keep]

            # Derive binary labels from the untouched raw labels. Two-step
            # in-place relabelling maps everything to 1 when the second
            # class is the digit 0.
            labels = (target == classes[1]).long()
            noise = torch.tensor(np.random.choice(2, len(labels), p=[0.51, 0.49])).long()
            labels = torch.abs(labels - noise)

            embedding = network.embed(data)
            gram = torch.mm(embedding.t(), embedding)
            if linalg_solve is not None:
                # linalg.solve(A, B) solves A X = B (arguments reversed
                # relative to the old torch.solve).
                xtx_inv_xt = linalg_solve(gram, embedding.t())
            else:
                xtx_inv_xt, _ = torch.solve(embedding.t(), gram)
            beta = torch.mm(xtx_inv_xt, torch.unsqueeze(labels, dim=1).float())

            loss_sum = 0.0
            correct = 0
            total = 0
            for test_data, test_target in finetune_test_loader:
                test_keep = torch.tensor(np.isin(test_target, classes))
                test_data, test_target = test_data[test_keep], test_target[test_keep]
                test_labels = (test_target == classes[1]).long()
                total += len(test_labels)
                # squeeze(1) keeps a 1-D result even for one-sample batches.
                scores = torch.sigmoid(torch.mm(network.embed(test_data), beta)).squeeze(1)
                predictions = (scores > 0.5).long()
                correct += (predictions == test_labels).sum().item()
                # Sum per-sample losses so the final figure is a true average.
                loss_sum += F.binary_cross_entropy(
                    scores, test_labels.float(), reduction='sum').item()

            # Avoid dividing by zero when no test sample belongs to `classes`.
            denom = max(total, 1)
            print('\nTest set: loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                loss_sum / denom, correct, total, 100. * correct / denom))
            # Only the first training batch is used for evaluation.
            break


if __name__ == '__main__':
    # Script entry point: parse options, load data, pretrain the embedder as
    # a classifier, then (optionally) run the finetuning stage. The names
    # assigned here (args, network, optimizer, *_loader) are module-level
    # globals read by the training/evaluation functions above.

    parser = argparse.ArgumentParser()
    parser.add_argument('outdir', type=str)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--lr', type=float, default=1e-2)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--log-interval', type=int, default=10)
    parser.add_argument('--resume', action='store_true', default=False)
    parser.add_argument('--finetune-epochs', type=int, default=10)
    args = parser.parse_args()

    # Create the checkpoint directory up front; otherwise the first
    # torch.save only fails after a full epoch of training.
    os.makedirs(args.outdir, exist_ok=True)

    batch_size_test = 1000
    torch.backends.cudnn.enabled = False
    # Seed both torch and numpy: numpy drives class selection and label noise.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Inputs are wrapped as tensors; labels stay numpy arrays and are
    # converted to tensors by the DataLoader when batches are collated.
    x_train = torch.tensor(np.load('../../data/binarymnist/train/xtrain_no17.npy'))
    y_train = np.load('../../data/binarymnist/train/ytrain_no17.npy')

    x_test = torch.tensor(np.load('../../data/binarymnist/test/xtest_no17.npy'))
    y_test = np.load('../../data/binarymnist/test/ytest_no17.npy')

    train_loader = torch.utils.data.DataLoader(Dataset(x_train, y_train), batch_size=args.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(Dataset(x_test, y_test), batch_size=batch_size_test, shuffle=False)

    network = Embedder()
    optimizer = optim.SGD(network.parameters(), lr=args.lr, momentum=0.9)

    if args.resume:
        # Continue from the checkpoints written by train_pretrained.
        network.load_state_dict(torch.load(f'{args.outdir}/embedder.pth'))
        optimizer.load_state_dict(torch.load(f'{args.outdir}/optimizer.pth'))

    # Baseline evaluation, then evaluate again after every training epoch.
    test_pretrained()
    for epoch in range(1, args.epochs + 1):
        train_pretrained(epoch)
        test_pretrained()

    # Finetuning uses a fresh Adam optimizer at a tenth of the learning rate
    # and much larger batches than pretraining.
    optimizer = optim.Adam(network.parameters(), lr=args.lr / 10)
    finetune_train_loader = torch.utils.data.DataLoader(Dataset(x_train, y_train), batch_size=args.batch_size*50, shuffle=True)
    finetune_test_loader = torch.utils.data.DataLoader(Dataset(x_test, y_test), batch_size=batch_size_test*5, shuffle=False)

    # NOTE(review): range(0) is empty, so this loop body never runs and the
    # finetuning stage is effectively disabled (--finetune-epochs has no
    # effect). Raise the range to enable it — confirm whether this is intended.
    for _ in range(0):
        # Pick two distinct digits as a random binary task.
        classes = np.random.choice(list(digit_map.keys()), 2, replace=False)
        print(classes)
        for epoch in range(1, args.finetune_epochs + 1):
            finetune_test(classes)
            finetune(epoch, classes)
            finetune_test(classes)
            test_pretrained()