import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn

from CubeDetector2 import CubeDetector2

class RightHalfGrayScaleImage():
  """Callable transform: grayscale a PIL image and keep only its right half.

  Intended as the first step of a ``transforms.Compose`` pipeline. The
  returned image has mode 'L', the original height, and width
  ``width - width // 2`` (the columns from ``width // 2`` to the right edge).
  """

  def __call__(self, pillow_image):
    w, h = pillow_image.size
    left_edge = w // 2  # first column kept; integer division favours the right side on odd widths
    return pillow_image.convert(mode='L').crop((left_edge, 0, w, h))

# Per-image pipeline applied when a training sample is loaded:
#   1. grayscale + keep the right half of the frame,
#   2. edge-pad by 10 px and take a random 240 (H) x 160 (W) crop (light
#      translation augmentation),
#   3. convert to a float tensor,
#   4. normalize the single channel with a fixed mean/std (in place).
preprocess = transforms.Compose([
  RightHalfGrayScaleImage(),
  transforms.RandomCrop(size=(240, 160), padding=10, padding_mode='edge'),
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.450], std=[0.230], inplace=True),
])

# Training images live under ./common_dataset, one sub-folder per class
# (ImageFolder layout); every image is run through `preprocess` on load.
trainset = torchvision.datasets.ImageFolder(root="./common_dataset", transform=preprocess)

# Shuffled mini-batches of 250 samples; the final batch of an epoch may be smaller.
batchSize = 250
trainloader = torch.utils.data.DataLoader(
  dataset=trainset,
  batch_size=batchSize,
  shuffle=True,
)

# Debugging aid: pull one (images, labels) batch at import time so it can be
# inspected interactively.
pats = next(iter(trainloader))

# Use the first CUDA device when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device(0)
else:
  device = torch.device('cpu')

# Right half of the Cozmo camera image after the 10-pixel-padded random crop.
# NOTE(review): presumably (width, height), matching the (240, 160) (H, W)
# RandomCrop in `preprocess`; confirm against CubeDetector2's expectations.
in_dim = (160, 240)

model = CubeDetector2(in_dim).to(device)
criterion = nn.CrossEntropyLoss()
# Alternative optimizer (disabled): torch.optim.SGD(model.parameters(), lr=0.0005, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

nepochs = 250  # epochs run per call to train_model()
epoch = 0      # running epoch count; train_model() advances it across calls
print(len(trainset.samples))
def train_model():
  """Train ``model`` on ``trainloader`` for ``nepochs`` additional epochs.

  Continues the module-level ``epoch`` counter, so repeated calls resume the
  numbering where the previous call stopped. After each epoch one line is
  printed: epoch number, summed per-sample training loss, wall-clock seconds
  for the epoch, number of correctly classified training samples, and the
  percentage correct.
  """
  global epoch
  npats = len(trainset.samples)
  # Ensure training-mode behaviour for dropout/batch-norm layers (if any),
  # in case the model was switched to eval mode elsewhere.
  model.train()
  print('Epoch Cum.Loss   Time       Correct  % Correct')
  for _ in range(nepochs):
    epoch += 1
    runningLoss = 0.0
    now = time.time()
    correct = 0.0
    for (images, labels) in trainloader:
      images = images.to(device)
      labels = labels.to(device)
      optimizer.zero_grad()
      outputs = model(images)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # CrossEntropyLoss (default reduction='mean') averages over the batch;
      # weight by the actual batch size because the last batch of an epoch
      # can be smaller than batchSize.
      runningLoss += loss.item() * labels.size(0)
      predicted = torch.argmax(outputs, dim=1)
      correct += (predicted == labels).sum().item()
    print('%03d  %9.4f  time=%5.4f %5d    %5.3f' %
          (epoch, runningLoss, time.time() - now, correct, correct / npats * 100))

def show_pattern(n=0):
  """Display sample ``n`` of a freshly drawn training batch with its class label.

  Each call draws a new shuffled, randomly cropped batch from ``trainloader``,
  so the image shown changes between calls.
  NOTE(review): the label text assumes class index 0 is 'Cube' and 1 is
  'No Cube'; ImageFolder derives indices from sorted folder names, so confirm
  this against ``trainset.classes``.
  """
  batch_images, batch_labels = next(iter(trainloader))
  pixels = np.array(batch_images[n].view(-1, 160))  # (1, 240, 160) -> (240, 160)
  class_names = ['Cube', 'No Cube']
  plt.imshow(pixels, cmap='gray')
  plt.xlabel('Class: %s' % class_names[batch_labels[n]])
  plt.pause(0.5)

def show_kernel1(k=0):
  """Plot input channel 0 of kernel ``k`` from the model's first parameter tensor.

  NOTE(review): assumes the first entry of ``model.parameters()`` is a 4-D
  convolution weight (out_channels, in_channels, kH, kW); this depends on
  CubeDetector2's layer order.
  """
  all_params = list(model.parameters())
  weights_cpu = all_params[0].detach().to(torch.device('cpu'))
  plt.imshow(weights_cpu[k, 0, :, :], cmap='gray')
  plt.pause(0.5)

def show_kernels1():
  """Plot up to 36 first-layer kernels (input channel 0) in a 6x6 grid.

  All kernels share one gray scale spanning the global min..max of the first
  parameter tensor, which is printed first, so the panels are comparable.
  NOTE(review): assumes the first parameter is a 4-D convolution weight.
  """
  weights = list(model.parameters())[0].detach().to(torch.device('cpu'))
  lo = weights.min().item()
  hi = weights.max().item()
  print('minval=', lo, '   maxval=', hi)
  plt.clf()
  n_shown = min(36, weights.shape[0])
  for idx in range(n_shown):
    plt.subplot(6, 6, idx + 1)
    plt.imshow(weights[idx, 0, :, :], cmap='gray', vmin=lo, vmax=hi)
    plt.axis('off')
  plt.pause(0.5)

def save_model(path="mymodel.pth"):
  """Save the model's ``state_dict`` to ``path``.

  Args:
    path: destination file. Defaults to "mymodel.pth" (relative to the
      current working directory), which was the previously hard-coded location.

  Returns:
    The path the weights were written to.
  """
  torch.save(model.state_dict(), path)
  return path