import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn

from CubeDetector1 import CubeDetector1

class RightHalfGrayScaleImage:
  """torchvision-style transform: grayscale the image and keep its right half.

  Called with a PIL image; returns a new mode-'L' image spanning columns
  width//2 .. width (so for odd widths the middle column is kept) at full height.
  """
  def __call__(self, pillow_image):
    w, h = pillow_image.size
    left_edge = w // 2
    return pillow_image.convert(mode='L').crop((left_edge, 0, w, h))

# Pipeline applied to every image loaded from disk:
#   1. keep the right half of the frame, converted to grayscale (PIL mode 'L');
#   2. pad each border by 10 px (edge replication), then take a random
#      240x160 (height x width) crop -- random-position augmentation; if the
#      half-frame is already 240x160 this is a shift of up to +/-10 px;
#   3. convert to a 1-channel float tensor and normalize it in place.
preprocess = transforms.Compose([
  RightHalfGrayScaleImage(),
  transforms.RandomCrop((240,160), padding=10, padding_mode='edge'),
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.450], std=[0.230], inplace=True)
])

# One class per subdirectory of ./common_dataset; ImageFolder numbers the
# classes in sorted directory-name order. The helpers below treat label 0 as
# "Cube" and label 1 as "No Cube".
trainset = torchvision.datasets.folder.ImageFolder(
  root="./common_dataset",
  transform=preprocess
)

# Shuffled loader used for training and for the review/display helpers.
# reviewloader (batch size 10) is not used by any function in this file.
batchSize = 250
trainloader = torch.utils.data.DataLoader(dataset=trainset,
                                          batch_size=batchSize,
                                          shuffle=True)
reviewloader = torch.utils.data.DataLoader(dataset=trainset,
                                          batch_size=10,
                                          shuffle=True)

# For debugging: fetch one (images, labels) batch so it can be inspected
# interactively. Note this loads and preprocesses a full batch every time the
# module is run.
pats = trainloader.__iter__().__next__()

# First CUDA device if available, otherwise the CPU.
device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu')

# Presumably (width, height): the RandomCrop above produces 240x160
# (height x width) images -- confirm the expected order against CubeDetector1.
in_dim = (160,240)  # right half of Cozmo camera image with 10 pixel crop margin

model = CubeDetector1(in_dim).to(device)
# Cross-entropy over the model's per-class scores; the default reduction
# averages the loss over the batch.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
#momentum=0.9  -- stale note: torch.optim.Adam has no `momentum` argument (see `betas`)

# Epochs run per train_model() call, and the running epoch counter that
# train_model() increments so repeated calls continue the numbering.
nepochs = 250
epoch = 0

def train_model():
  """Train the global `model` on `trainloader` for `nepochs` epochs.

  Uses the module-level `criterion` and `optimizer`. The global `epoch`
  counter is incremented once per pass, so repeated calls continue the
  numbering. After each epoch this prints the summed per-sample loss, the
  wall-clock time of the epoch, and how many training patterns (and what
  percentage) were classified correctly during that epoch.
  """
  global epoch, model
  model = model.to(device)
  # Make sure the model is in training mode (matters if CubeDetector1
  # contains dropout or batch-norm layers).
  model.train()
  npats = len(trainset.samples)
  print('Epoch Cum.Loss   Time       Correct  % Correct')
  for _ in range(nepochs):
    epoch += 1
    runningLoss = 0.0
    now = time.time()
    correct = 0
    for (images, labels) in trainloader:
      images = images.to(device)
      labels = labels.to(device)
      optimizer.zero_grad()
      outputs = model(images)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # The criterion returns the batch *mean*; weight it by the actual batch
      # size, because the last batch is usually smaller than batchSize.
      runningLoss += loss.item() * labels.size(0)
      # With CrossEntropyLoss the predicted class is the arg-max score (the
      # same rule review_wrong_answers() uses), not the sign of one logit.
      correct += (outputs.argmax(1) == labels).sum().item()
    print('%03d  %9.4f  time=%5.4f %5d    %5.3f' %
          (epoch, runningLoss,  time.time()-now, correct, correct/npats*100))

def save_weights(filename="cube_weights.pt"):
  """Save the global model's state_dict to `filename`.

  The model is moved to the CPU first, so the file holds CPU tensors. It is
  then moved to `device`, whatever device it was on before the call.
  """
  global model
  model = model.to('cpu')
  state = model.state_dict()
  torch.save(state, filename)
  print('Weights saved to', filename)
  model = model.to(device)

def load_weights(filename="cube_weights.pt"):
  """Load a state_dict from `filename` (e.g. one written by save_weights()).

  map_location='cpu' lets a checkpoint that contains CUDA tensors be opened
  on a machine without a GPU. load_state_dict() then copies the values into
  the model's existing parameters on whatever device they currently live.
  """
  # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
  state = torch.load(filename, map_location='cpu')
  model.load_state_dict(state)
  print('Weights loaded from', filename)

def review_false_positives():
  """Step through label-1 ("No Cube") patterns the model wrongly calls a cube."""
  return review_wrong_answers(1)

def review_false_negatives():
  """Step through label-0 ("Cube") patterns the model wrongly calls no-cube."""
  return review_wrong_answers(0)

def review_wrong_answers(correct):
  """Interactively show training patterns of one class that are misclassified.

  Args:
    correct: the true label to inspect (0 or 1). Each pattern whose label
      equals `correct` but whose arg-max prediction differs is printed with
      both raw output scores and displayed; press Enter to continue.

  The global `model` is moved to the CPU and left there. Its train/eval mode
  is restored before returning, even if the review is interrupted.
  """
  global model
  error_type = 'Wrongly said %scube: ' % ['no ',''][correct]
  cpu = torch.device('cpu')
  model = model.to(cpu)
  # Inference only: use eval mode (deterministic dropout/batch-norm, if the
  # model has any) and skip building an autograd graph for every batch.
  was_training = model.training
  model.eval()
  try:
    for (images, labels) in trainloader:
      images = images.to(cpu)
      labels = labels.to(cpu)
      with torch.no_grad():
        outputs = model(images)
      outclasses = outputs.argmax(1)
      scores = outputs.numpy()
      for image, label, predicted, score in zip(images, labels, outclasses, scores):
        if label == correct and label != predicted:
          print(error_type, 'cube = %7.4f    nocube = %7.4f' % (score[0], score[1]))
          plt.clf()
          plt.imshow(image[0].numpy(), cmap='gray')
          plt.pause(0.5)
          input('Press the Enter key to proceed...')
  finally:
    model.train(was_training)

def show_pattern(n=0):
  """Display pattern `n` of a freshly drawn, shuffled and augmented training batch.

  Args:
    n: index within the batch; must be smaller than the batch size.
  """
  images, labels = next(iter(trainloader))
  # Images are single-channel (grayscale); select that channel rather than
  # hard-coding the image width, so a different crop size still works.
  img = images[n, 0].numpy()
  # Clear the figure so we don't draw into a leftover subplot (e.g. the
  # kernel grid from show_kernels1).
  plt.clf()
  plt.imshow(img, cmap='gray')
  plt.xlabel('Class: %s' % ['Cube','No Cube'][labels[n].item()])
  plt.pause(0.5)

def show_kernel1(k=0):
  """Display input-channel 0 of kernel `k` from the model's first parameter.

  Assumes the first parameter is a 4-D convolution weight
  (out_channels, in_channels, kh, kw) -- depends on CubeDetector1's layer order.
  """
  kernels1 = next(model.parameters()).detach().to(torch.device('cpu'))
  kernel = kernels1[k, 0, :, :]
  # Clear the figure so the kernel isn't drawn into a leftover subplot.
  plt.clf()
  plt.imshow(kernel.numpy(), cmap='gray')
  plt.pause(0.5)

def show_kernels1():
  """Plot up to 36 kernels of the model's first parameter on a 6x6 grid.

  Only input channel 0 of each kernel is drawn. All panels share one grey
  scale spanning the global minimum and maximum weight, so intensities are
  comparable across panels. The min/max values are printed first.
  """
  weights = next(iter(model.parameters())).detach().to(torch.device('cpu'))
  lo, hi = weights.min().item(), weights.max().item()
  print('minval=', lo, '   maxval=', hi)
  plt.clf()
  for panel, kernel in enumerate(weights[:36], start=1):
    plt.subplot(6, 6, panel)
    plt.imshow(kernel[0, :, :], cmap='gray', vmin=lo, vmax=hi)
    plt.axis('off')
  plt.pause(0.5)


#load_weights()
#review_false_negatives()
