import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class ALVINNData(Dataset):
    """In-memory dataset of (pattern, desired) pairs.

    Args:
        patterns: sequence of input patterns; ``patterns[i]`` pairs with
            ``desired[i]``.
        desired: sequence of target vectors, same length as ``patterns``.
        transform: optional callable applied to the pattern (not the target)
            each time an item is fetched with ``__getitem__``.

    Raises:
        ValueError: if ``patterns`` and ``desired`` differ in length (``zip``
            would otherwise silently drop the unmatched examples).
    """

    def __init__(self, patterns, desired, transform=None):
        if len(patterns) != len(desired):
            raise ValueError(
                'patterns and desired must have the same length '
                '(got %d and %d)' % (len(patterns), len(desired)))
        self.transform = transform
        self.data = list(zip(patterns, desired))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        if self.transform is None:
            # No transform: return the stored pair unchanged.
            return item
        pattern, target = item
        return self.transform(pattern), target

def _load_float32(path):
    """Read a whitespace-delimited numeric text file as a float32 array."""
    return np.loadtxt(path).astype('float32')


# Training data: patterns are reshaped to (N, 1, 32, 30) image tensors for the CNN.
patterns = _load_float32('../data/Patterns.dat')
desired = _load_float32('../data/Desired.dat')
batchSize = 10
trainset = ALVINNData(patterns.reshape(-1, 1, 32, 30), desired)
trainloader = DataLoader(dataset=trainset, batch_size=batchSize, shuffle=True)

# Cross-validation data (kept in the flat layout it was loaded in).
test_patterns = _load_float32('../data/CVPatterns.dat')
test_desired = _load_float32('../data/CVDesired.dat')
testset = ALVINNData(test_patterns, test_desired)

# Two-lane data (also kept flat).
twolane_patterns = _load_float32('../data/TwoLanePatterns.dat')
twolane_desired = _load_float32('../data/TwoLaneDesired.dat')
twolaneset = ALVINNData(twolane_patterns, twolane_desired)

#---------------- One Convolutional Layer + Linear Output ----------------

class OneConvLayer(nn.Module):
    """One convolutional block followed by a linear read-out layer.

    The sizes are hard-wired for 32 x 30 single-channel images (the layout
    ``trainset`` is reshaped to). The 5x5 convolution with padding 2 keeps the
    32 x 30 spatial size, and the 2x2 max-pool halves it to 16 x 15. The linear
    layer therefore sees ``nkernels * 16 * 15`` features.

    Args:
        in_dim: accepted for interface compatibility; not used by the layers.
        out_dim: number of output units.
        nkernels: number of convolution kernels (feature maps).
    """

    def __init__(self, in_dim, out_dim, nkernels):
        super().__init__()
        conv_block = [
            nn.Conv2d(in_channels=1, out_channels=nkernels,
                      kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(nkernels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        # Attribute names are kept as network1/network2 so state_dict keys match.
        self.network1 = nn.Sequential(*conv_block)
        pooled_features = nkernels * 16 * 15
        self.network2 = nn.Linear(pooled_features, out_dim)

    def forward(self, x):
        features = torch.flatten(self.network1(x), start_dim=1)
        return self.network2(features)

# Compute device: keep exactly one of these two lines active.
# device = torch.device(0)        # GPU board
device = torch.device('cpu')      # regular CPU

# Network geometry: 30 x 32 input pixels, 30 output units, 5 conv kernels.
in_dim, out_dim, nkernels = 30 * 32, 30, 5

model = OneConvLayer(in_dim, out_dim, nkernels).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, weight_decay=0.05)

# Training state shared (via globals) with the train/test/display functions.
epoch, nepochs = 0, 100
outputs = None

def train_alvinn():
    """Train ``model`` for ``nepochs`` more epochs on ``trainloader``.

    Each epoch prints the mean per-pattern summed squared error. ``epoch`` is a
    global counter, so repeated calls continue the count.

    Every 10th epoch, the whole training set is re-evaluated in eval mode
    without autograd. The result (a CPU tensor indexed like ``trainset``) is
    stored in the global ``outputs``, and pattern 0 and kernel 0 are displayed.
    """
    global epoch, outputs
    # Full training set in the (N, 1, 32, 30) layout the CNN expects; used
    # for the per-pattern mean and the periodic full-set evaluation.
    pats = torch.tensor(patterns.reshape(-1, 1, 32, 30))
    for _ in range(nepochs):
        epoch += 1
        running_loss = 0.0
        for images, labels in trainloader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            # Kept local: mini-batches are shuffled, so their outputs must not
            # replace the global `outputs`, which is indexed like `trainset`.
            batch_outputs = model(images)
            loss = criterion(batch_outputs, labels)
            # MSELoss averages over every element; scale back to a sum of
            # squared errors so the epoch mean is per pattern.
            running_loss += loss.item() * labels.numel()
            loss.backward()
            optimizer.step()
        # End of epoch
        print('epoch %3d:' % epoch, '  mean loss = %7.5f' % (running_loss / len(pats)))
        if epoch % 10 == 0:
            # Eval mode makes BatchNorm use (and not update) its running
            # statistics; no_grad avoids building a graph over the whole set.
            model.eval()
            with torch.no_grad():
                outputs = model(pats.to(device)).cpu()
            model.train()
            show_pattern(0, trainset, outputs)
            show_kernel(0)

def test_alvinn(show=True):
    """Evaluate the CNN on the cross-validation set.

    The flat ``test_patterns`` are reshaped to the (N, 1, 32, 30) layout used
    for training. The model is evaluated in eval mode without autograd, and the
    loss is reported with the same metric as ``train_alvinn``: MSE scaled by
    ``out_dim``, i.e. summed squared error per pattern. The model's previous
    train/eval mode is restored afterwards.

    Args:
        show: when True (the default), display every test pattern with its
            desired and actual output via ``show_pattern``.

    Returns:
        The test-set loss as a float.
    """
    pats = torch.tensor(test_patterns.reshape(-1, 1, 32, 30)).to(device)
    labels = torch.tensor(test_desired).to(device)
    was_training = model.training
    model.eval()
    with torch.no_grad():
        test_outputs = model(pats)
        loss = criterion(test_outputs, labels).item() * out_dim
    model.train(was_training)
    print('Loss on test set is', loss)
    if show:
        test_outputs = test_outputs.cpu()
        for i in range(len(testset)):
            show_pattern(i, testset, test_outputs)
    return loss

#---------------- Display ----------------

def show_pattern(n, data=trainset.data, outs=None):
    """Display example ``n`` of ``data`` with its desired and actual output.

    Args:
        n: index of the example to display.
        data: indexable collection whose items are (pattern, desired) pairs,
            e.g. an ``ALVINNData`` instance or its ``.data`` list. Defaults to
            the training set.
        outs: network outputs indexed like ``data``. When None, the global
            ``outputs`` is used. If that is also None, only the pattern and
            desired vector are drawn.

    Figure 1 is cleared and redrawn, then the call pauses for one second.
    """
    if outs is None:
        outs = outputs
    pattern_vec, desired_vec = data[n]
    pattern = pattern_vec.reshape(32, 30).T
    desired_row = desired_vec.reshape(1, 30)
    plt.figure(1)
    plt.clf()
    plt.axes((0, 0.05, 0.9, 0.8))
    plt.imshow(pattern, cmap='gray')
    plt.axis('off')
    plt.axes((0.125, 0.85, 0.65, 0.1))
    plt.imshow(desired_row, cmap='gray')
    plt.axis('off')
    if outs is not None:
        output = outs[n]
        if isinstance(output, torch.Tensor):
            # imshow needs host memory and no autograd history.
            output = output.detach().cpu()
        plt.axes((0.125, 0.90, 0.65, 0.1))
        plt.imshow(output.reshape(1, 30))
        plt.axis('off')
    plt.show(block=False)
    plt.pause(1.0)

def show_kernel(n):
    """Show the 5x5 weights of convolution kernel ``n``.

    Reads the first layer of ``model.network1`` (the Conv2d with one input
    channel and 5x5 kernels) and draws kernel ``n`` as a grayscale image.
    Figure 2 is used so the pattern display in figure 1 is left intact.

    Raises:
        IndexError: if ``n`` is not a valid kernel index.
    """
    weights = model.network1[0].weight.detach().cpu()
    kernel = weights[n, 0].numpy()
    plt.figure(2)
    plt.clf()
    plt.imshow(kernel, cmap='gray')
    plt.title('kernel %d' % n)
    plt.axis('off')
    plt.show(block=False)
    plt.pause(1.0)

def show_kernels():
    """Call show_kernel for every kernel index 0 .. nkernels-1, in order."""
    for kernel_index in range(nkernels):
        show_kernel(kernel_index)
