import time
import numpy as np
import matplotlib.pyplot as plt
import math
import argparse

import torch
import torchvision
import torchvision.transforms as transforms

from CubeClassifierNet import *

# Command-line configuration; parsed once when the module is executed.
parser = argparse.ArgumentParser()
for flag, options in (
    ("--save", {"type": int, "default": 1}),
    ("--data_path", {"type": str, "default": '/afs/cs.cmu.edu/academic/class/15494-s20/projects/cubes/common_dataset'}),
    ("--batch_size", {"type": int, "default": 4}),
    ("--learning_rate", {"type": float, "default": 0.001}),
    ("--epoches", {"type": int, "default": 200}),
):
    parser.add_argument(flag, **options)
args = parser.parse_args()


class RightHalfGrayScaleImage():
  """Callable transform: grayscale a PIL image and keep only its right half."""

  def __call__(self, pillow_image):
    # 'L' mode is 8-bit single-channel grayscale; the crop box is
    # (left, upper, right, lower) covering the right half of the frame.
    w, h = pillow_image.size
    right_half = pillow_image.convert(mode='L').crop((w // 2, 0, w, h))
    return right_half


def GetData(path, batch_size, save):
    """Build DataLoaders over an ImageFolder-layout dataset rooted at `path`.

    Args:
        path: root directory organised as one subdirectory per class
            (torchvision ImageFolder convention).
        batch_size: minibatch size for every DataLoader created here.
        save: truthy -> final/deployment run: train on the whole dataset
            (plain + augmented) and return (trainloader, train_num);
            falsy -> 70/30 train/validation split and return
            (trainloader, testloader, train_num, test_num).
    """
    # Deterministic preprocessing: right-half grayscale, jittered
    # fixed-size crop, tensor conversion, normalisation.
    transform = transforms.Compose([
      RightHalfGrayScaleImage(),
      transforms.RandomCrop((240,160), padding=10, padding_mode='edge'),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.450], std=[0.230], inplace=True)
    ])

    # Same pipeline plus random affine/flips, used to double the
    # effective training set.
    augmentation_transform = transforms.Compose([
        RightHalfGrayScaleImage(),
        transforms.RandomCrop((240,160), padding=10, padding_mode='edge'),
        transforms.RandomAffine(degrees=30, translate=None, scale=None, shear=None),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.450], std=[0.230], inplace=True)
    ])

    # Public API path (torchvision.datasets.ImageFolder) instead of the
    # internal .folder module. Both datasets index the same files, so
    # sample i in `data` and `aug_data` is the same underlying image.
    data = torchvision.datasets.ImageFolder(root=path, transform=transform)
    aug_data = torchvision.datasets.ImageFolder(root=path, transform=augmentation_transform)

    # Train on whole dataset
    if save:
        train_data = torch.utils.data.ConcatDataset([data, aug_data])
        train_num = len(train_data)
        trainloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
        return trainloader, train_num

    else:
        train_num = int(0.7 * len(data))
        test_num = len(data) - train_num
        train_split, valid_data = torch.utils.data.random_split(data, [train_num, test_num])

        # BUG FIX: previously the augmented copy of the *entire* dataset
        # was concatenated onto the training split, so augmented versions
        # of every validation image leaked into training. Restrict the
        # augmented dataset to the indices chosen for the training split.
        aug_train = torch.utils.data.Subset(aug_data, train_split.indices)
        train_data = torch.utils.data.ConcatDataset([train_split, aug_train])
        train_num = len(train_data)

        trainloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
        testloader = torch.utils.data.DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=True)

        print('train_num:', train_num, ' test_num:', test_num)
        return trainloader, testloader, train_num, test_num


def train(net, trainloader, train_num, epoches):
    """Run the optimisation loop for `epoches` passes over `trainloader`.

    NOTE(review): reads module-level globals — `device`, `optimizer`,
    `criterion`, and `args` (for batch_size in the loss average) — which
    are defined in the __main__ block.
    """
    net.train()
    batches_per_epoch = math.ceil(train_num / args.batch_size)
    for epoch in range(epoches):
        epoch_loss = 0.0
        epoch_correct = 0.0
        for batch in trainloader:
            # BCELoss needs float targets shaped like the (N, 1) outputs.
            inputs = batch[0].to(device).type(dtype=torch.float)
            labels = batch[1].to(device).type(dtype=torch.float).view(-1, 1)

            # Standard step: reset grads, forward, backward, update.
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            # Binary decision at the 0.5 threshold, compared on CPU.
            predictions = (outputs > 0.5).view(-1).float().cpu()
            epoch_correct += np.array(predictions == labels.view(-1).cpu(), dtype='int').sum()

        print('%d total loss: %9.4f  correct %5.3f' % (
            epoch + 1, epoch_loss / batches_per_epoch, epoch_correct / train_num * 100))
    print('Finished Training')


def evaluate(net, testloader):
    """Print the classification accuracy of `net` over `testloader`.

    NOTE(review): reads the module-level global `device`, defined in the
    __main__ block.
    """
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device).view(-1)
            # Threshold sigmoid outputs at 0.5 for the binary decision.
            predicted = (net(images) > 0.5).view(-1).float().cpu()
            total += labels.size(0)
            correct += (predicted == labels.cpu()).sum().item()

    print('Accuracy of the network on the %d test images: %5.3f %%' % (total, 100 * correct / total))
    print('correct:', correct, 'total:', total)


if __name__ == "__main__":
    # Prefer GPU 0 when CUDA is available, otherwise fall back to CPU.
    # NOTE: `device`, `criterion`, and `optimizer` below are module-level
    # globals that train() and evaluate() read directly.
    device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu')
    print('device:', device)

    # Model
    # `CubeClassifierNet` and `nn` come from the star-import of
    # CubeClassifierNet at the top of the file.
    net = CubeClassifierNet().to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=1e-3)

    # Train and evaluation
    # --save nonzero: train on the full dataset and persist the weights;
    # otherwise: hold out 30% for validation and report accuracy.
    if args.save:
        trainloader, train_num = GetData(args.data_path, args.batch_size, args.save)
        train(net, trainloader, train_num, args.epoches)
        # Move to CPU before saving so the checkpoint loads on CPU-only hosts.
        net = net.to('cpu')
        torch.save(net.state_dict(), "cubeDetection_weight.h5")
    else:
        trainloader, testloader, train_num, test_num = GetData(args.data_path, args.batch_size, args.save)
        train(net, trainloader, train_num, args.epoches)
        evaluate(net, testloader)


