# this file is based on code publicly available at
#   https://github.com/bearpaw/pytorch-classification
# written by Wei Yang.

import argparse
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import StepLR
import time
import datetime
import numpy as np
from train_utils import AverageMeter, accuracy, init_logfile, log
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Command-line configuration for the autoencoder training script.
parser = argparse.ArgumentParser()
# main() reads args.outdir; the argument is an optional positional so the
# script can be launched both with and without an explicit output folder.
parser.add_argument('outdir', type=str, nargs='?', default='.',
                    help='folder to save model and training log '
                         '(default: current directory)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--batch', default=256, type=int, metavar='N',
                    help='batchsize (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
                    help='initial learning rate', dest='lr')
parser.add_argument('--lr_step_size', type=int, default=30,
                    help='How often to decrease learning by gamma.')
parser.add_argument('--gamma', type=float, default=0.1,
                    help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
# main() stores args.arch inside every checkpoint it writes.
parser.add_argument('--arch', type=str, default='AutoEncoder',
                    help='architecture name recorded in the checkpoint '
                         '(default: AutoEncoder)')
parser.add_argument('--cuda', action='store_true', default=False)
parser.add_argument('--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
args = parser.parse_args()

# Use the GPU only when it is both requested (--cuda) and actually available.
device = torch.device('cuda') if torch.cuda.is_available() and args.cuda else torch.device('cpu')

class AutoEncoder(nn.Module):
    """Convolutional autoencoder with a fully-connected bottleneck.

    A stride-2 convolution halves the spatial resolution, two linear layers
    squeeze the flattened feature map down to ``embed_dim`` values and expand
    it back, and a nearest-neighbour upsample followed by a convolution
    restores the input resolution and channel count.

    Args:
        in_channels: number of channels of the input (and reconstructed) image.
        img_size: height and width of the square input image. Must be a
            positive even number so that halving and doubling round-trips.
        hidden_channels: number of feature maps produced by the encoder conv.
        embed_dim: width of the bottleneck embedding.

    Raises:
        ValueError: if ``img_size`` is not a positive even number.
    """

    def __init__(self, in_channels=3, img_size=32, hidden_channels=64, embed_dim=1024):
        super(AutoEncoder, self).__init__()

        if img_size <= 0 or img_size % 2 != 0:
            raise ValueError(
                "img_size must be a positive even number, got {}".format(img_size))

        self.hidden_channels = hidden_channels

        # Downsampling: kernel 3, stride 2, padding 1 halves height and width.
        self.down = nn.Sequential(nn.Conv2d(in_channels, hidden_channels, 3, 2, 1), nn.ReLU())
        # Fully-connected layers
        self.down_size = img_size // 2
        down_dim = hidden_channels * self.down_size ** 2
        self.fc1 = nn.Sequential(
            nn.Linear(down_dim, embed_dim),
            # NOTE(review): the second positional argument of BatchNorm1d is
            # `eps`, so the historical `BatchNorm1d(1024, 0.8)` set eps=0.8.
            # The value is kept (now spelled out) to preserve behaviour; it
            # was possibly meant to be `momentum` -- confirm before changing.
            nn.BatchNorm1d(embed_dim, eps=0.8),
            nn.ReLU(inplace=True),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(embed_dim, down_dim),
            nn.BatchNorm1d(down_dim),
            nn.ReLU(inplace=True),
        )
        # Upsampling: double the resolution, then map back to image channels.
        self.up = nn.Sequential(nn.Upsample(scale_factor=2),
                                nn.Conv2d(hidden_channels, in_channels, 3, 1, 1))

    def forward(self, img):
        """Encode and reconstruct a batch of images.

        Args:
            img: tensor of shape ``(batch, in_channels, img_size, img_size)``.

        Returns:
            A ``(reconstruction, embedding)`` tuple: the reconstruction has
            the same shape as ``img`` and the embedding has shape
            ``(batch, embed_dim)``.
        """
        out = self.down(img)
        # Flatten the feature map so it can go through the linear bottleneck.
        embed = self.fc1(out.view(out.size(0), -1))
        out = self.fc2(embed)
        # Restore the (channels, height, width) layout before upsampling.
        out = self.up(out.view(out.size(0), self.hidden_channels, self.down_size, self.down_size))
        return out, embed

def custom_loss(inputs, outputs, embed):
    """Return the mean absolute reconstruction error between outputs and inputs.

    ``embed`` is accepted so the signature matches what the training and
    evaluation loops pass in, but it does not contribute to the loss.
    """
    residual = outputs - inputs
    return residual.abs().mean()


def main():
    """Train the autoencoder on CIFAR-10, logging and checkpointing each epoch.

    Configuration comes from the module-level ``args`` namespace. For every
    epoch the function trains, evaluates, appends a row to ``log.txt`` and
    overwrites ``checkpoint.pth.tar``, both inside the output directory.
    The dataset is requested with ``download=True`` under
    ``../data/cifar-data``.
    """
    # Fall back to defaults when the parsed namespace does not carry these
    # attributes, rather than failing with AttributeError.
    outdir = getattr(args, 'outdir', None) or '.'
    # makedirs handles nested paths and tolerates an already-existing folder.
    os.makedirs(outdir, exist_ok=True)

    # Same preprocessing for both splits: to tensor, then per-channel normalise.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
    ])
    train_dataset = datasets.CIFAR10("../data/cifar-data", train=True, download=True,
                                     transform=preprocess)
    test_dataset = datasets.CIFAR10("../data/cifar-data", train=False, download=True,
                                    transform=preprocess)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch)
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch)

    model = AutoEncoder().to(device)
    arch = getattr(args, 'arch', None) or type(model).__name__

    logfilename = os.path.join(outdir, 'log.txt')
    init_logfile(logfilename, "epoch\ttime\tlr\ttrain loss\ttestloss")

    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma)

    for epoch in range(args.epochs):
        before = time.time()
        # Learning rate actually used during this epoch (read before stepping).
        current_lr = scheduler.get_last_lr()[0]
        train_loss = train(train_loader, model, custom_loss, optimizer, epoch)
        test_loss = test(test_loader, model, custom_loss)
        # Step without the deprecated epoch argument: passing `epoch` here,
        # after the epoch has finished, delayed every decay by one epoch.
        scheduler.step()
        after = time.time()

        # The elapsed time is already a string; formatting it with a
        # precision would truncate it to three characters.
        log(logfilename, "{}\t{}\t{:.3}\t{:.3}\t{:.3}".format(
            epoch, str(datetime.timedelta(seconds=(after - before))),
            current_lr, train_loss, test_loss))

        torch.save({
            'epoch': epoch + 1,
            'arch': arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, os.path.join(outdir, 'checkpoint.pth.tar'))



def train(loader: DataLoader, model: torch.nn.Module, criterion, optimizer: Optimizer, epoch: int):
    """Run one optimisation pass over ``loader`` and return the average loss.

    Each batch is moved to the module-level ``device``, reconstructed by
    ``model`` and scored with ``criterion(inputs, outputs, embeds)``; labels
    are ignored. Every ``args.print_freq`` batches a progress line is printed
    and the first eight inputs and reconstructions of the batch are written
    to ``reconstruction<i>.png`` in the working directory.

    ``epoch`` is only used in the progress line. The return value is the
    ``avg`` of the loss meter, which is updated with each batch loss and the
    batch size.
    """
    progress_template = ('Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t')

    step_meter = AverageMeter()
    fetch_meter = AverageMeter()
    loss_meter = AverageMeter()
    tick = time.time()

    # Enable training-mode behaviour of the model's layers.
    model.train()

    for i, (inputs, _) in enumerate(loader):
        # Time spent waiting for the loader to hand over this batch.
        fetch_meter.update(time.time() - tick)
        inputs = inputs.to(device)

        # Forward pass and reconstruction loss.
        outputs, embeds = model(inputs)
        loss = criterion(inputs, outputs, embeds)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Bookkeeping: loss weighted by batch size, then wall-clock per step.
        loss_meter.update(loss.item(), inputs.size(0))
        step_meter.update(time.time() - tick)
        tick = time.time()

        if i % args.print_freq != 0:
            continue

        print(progress_template.format(
            epoch, i, len(loader), batch_time=step_meter,
            data_time=fetch_meter, loss=loss_meter))

        # Inputs first, reconstructions second, in a single image grid.
        preview = torch.cat([inputs[:8], outputs[:8]])
        save_image(preview, f'reconstruction{i}.png')

    return loss_meter.avg


def test(loader: DataLoader, model: torch.nn.Module, criterion):
    """Evaluate ``model`` on ``loader`` without gradients; return the average loss.

    Each batch is moved to the module-level ``device``, reconstructed by
    ``model`` and scored with ``criterion(inputs, outputs, embeds)``; labels
    are ignored. A progress line is printed every ``args.print_freq``
    batches. The return value is the ``avg`` of the loss meter, which is
    updated with each batch loss and the batch size.
    """
    progress_template = ('Test: [{0}/{1}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t')

    step_meter = AverageMeter()
    fetch_meter = AverageMeter()
    loss_meter = AverageMeter()
    tick = time.time()

    # Enable evaluation-mode behaviour of the model's layers.
    model.eval()

    with torch.no_grad():
        for i, (inputs, _) in enumerate(loader):
            # Time spent waiting for the loader to hand over this batch.
            fetch_meter.update(time.time() - tick)
            inputs = inputs.to(device)

            # Forward pass and reconstruction loss.
            outputs, embeds = model(inputs)
            loss = criterion(inputs, outputs, embeds)

            # Bookkeeping: loss weighted by batch size, then wall-clock per step.
            loss_meter.update(loss.item(), inputs.size(0))
            step_meter.update(time.time() - tick)
            tick = time.time()

            if i % args.print_freq != 0:
                continue

            print(progress_template.format(
                i, len(loader), batch_time=step_meter,
                data_time=fetch_meter, loss=loss_meter))

    return loss_meter.avg


# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()
