import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class ALVINNData(Dataset):
    """Map-style dataset pairing each input pattern with its desired output.

    Args:
        patterns: sequence of input patterns, one per item.
        desired: sequence of target vectors, aligned item-for-item with
            ``patterns``.
        transform: optional callable applied to a pattern when it is fetched
            with ``dataset[index]``; ``None`` (the default) leaves patterns
            unchanged.
    """

    def __init__(self, patterns, desired, transform=None):
        self.transform = transform
        # Materialize the (pattern, desired) pairs up front; `data` is also
        # indexed directly by other code (e.g. `testset.data`).
        self.data = list(zip(patterns, desired))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        pattern, target = self.data[index]
        # Previously `transform` was stored but silently ignored.
        if self.transform is not None:
            pattern = self.transform(pattern)
        return pattern, target

def _load_float32(path):
    """Read a whitespace-delimited numeric text file as a float32 NumPy array."""
    return np.loadtxt(path).astype('float32')


# Training split, iterated in shuffled mini-batches of `batchSize` patterns.
patterns = _load_float32('../data/Patterns.dat')
desired = _load_float32('../data/Desired.dat')
trainset = ALVINNData(patterns, desired)
batchSize = 10
trainloader = DataLoader(dataset=trainset, batch_size=batchSize, shuffle=True)

# Cross-validation split.
test_patterns = _load_float32('../data/CVPatterns.dat')
test_desired = _load_float32('../data/CVDesired.dat')
testset = ALVINNData(test_patterns, test_desired)

# Additional split loaded from the TwoLane*.dat files.
twolane_patterns = _load_float32('../data/TwoLanePatterns.dat')
twolane_desired = _load_float32('../data/TwoLaneDesired.dat')
twolaneset = ALVINNData(twolane_patterns, twolane_desired)

# Reference vectors, one per column; left at loadtxt's default float64 dtype.
gaussians = np.loadtxt('../data/Gaussians.dat')

#---------------- Fully Connected One Hidden Layer ----------------

class OneHiddenLayer(nn.Module):
    """Fully connected network: Linear -> ReLU -> Linear.

    Args:
        in_dim: number of input features.
        out_dim: number of output units.
        nhiddens: width of the single hidden layer.
    """

    def __init__(self, in_dim, out_dim, nhiddens):
        super().__init__()
        layers = [
            nn.Linear(in_dim, nhiddens),
            nn.ReLU(),
            nn.Linear(nhiddens, out_dim),
        ]
        # Kept as a single Sequential named `network` so parameters() yields
        # hidden weight, hidden bias, output weight, output bias in that order.
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the linear output-layer activations for input batch ``x``."""
        return self.network(x)

# Compute device: the regular CPU. Switch to torch.device(0) to use the
# first GPU board instead.
device = torch.device('cpu')

# Network shape: 30 * 32 input values, 30 output units, 5 hidden units.
in_dim = 30 * 32
out_dim = 30
nhiddens = 5

model = OneHiddenLayer(in_dim, out_dim, nhiddens).to(device)
criterion = nn.MSELoss()  # mean over every element of the batch
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=0.05)

# Training state shared with the functions below through module globals.
epoch = 0        # epochs completed so far; train_alvinn() keeps counting up
nepochs = 100    # epochs run per train_alvinn() call
outputs = None   # latest network outputs, the default for show_pattern()

def train_alvinn():
    """Train ``model`` on ``trainloader`` for ``nepochs`` more epochs.

    After every epoch, prints the sum-squared error averaged over training
    patterns.  Whenever the global ``epoch`` counter (which persists across
    calls) is a multiple of 10, the whole training set is re-evaluated.  The
    global ``outputs`` is replaced with those predictions, and pattern 0 and
    the hidden-unit weights are drawn.
    """
    global epoch, outputs
    for _ in range(nepochs):
        epoch += 1
        running_loss = 0.0
        for images, labels in trainloader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            # Local name: mini-batch predictions must not clobber the global
            # `outputs`, which show_pattern() indexes by training-pattern number.
            batch_outputs = model(images)
            loss = criterion(batch_outputs, labels)
            # MSELoss averages over batch_size * out_dim elements; scale back
            # up to the batch's total sum-squared error.
            running_loss += loss.item() * images.shape[0] * out_dim
            loss.backward()
            optimizer.step()
        # End of epoch
        print('epoch %3d:' % epoch, '  mean loss = %7.5f' % (running_loss / patterns.shape[0]))
        if epoch % 10 == 0:
            # Evaluate without building an autograd graph.  Inputs go to the
            # model's device and results come back to the CPU for plotting.
            with torch.no_grad():
                outputs = model(torch.tensor(patterns).to(device)).cpu()
            show_pattern(0, trainset, outputs)
            show_hiddens()

def test_alvinn():
    """Evaluate ``model`` on the cross-validation set and show each pattern.

    Prints the sum-squared error averaged over test patterns.  Then, for every
    test pattern, prints its OARE value and draws it with ``show_pattern``.
    The OARE value is the sum-squared error to the closest column of
    ``gaussians``, as returned by ``closest_gaussian``.
    """
    # No gradients are needed for evaluation.  Move inputs to the model's
    # device and bring results back to the CPU so `.numpy()` works below.
    with torch.no_grad():
        outputs = model(torch.tensor(test_patterns).to(device)).cpu()
    labels = torch.tensor(test_desired)
    # MSELoss is a per-element mean; multiply by out_dim for per-pattern SSE.
    loss = criterion(outputs, labels).item() * out_dim
    print('Loss on', test_patterns.shape[0], 'test patterns is', loss)
    print('Pattern  OARE')
    for i, pattern_output in enumerate(outputs):
        _, oare, _ = closest_gaussian(pattern_output.numpy())
        print('%3d  %f' % (i, oare))
        show_pattern(i, testset, outputs)

def closest_gaussian(output, gaussian_table=None):
    """Find the reference vector closest to a network output.

    Args:
        output: 1-D output vector (anything ``np.asarray`` accepts) with one
            entry per row of the table.
        gaussian_table: 2-D array holding one candidate vector per column.
            Defaults to the module-level ``gaussians``.

    Returns:
        ``(closest_column, sse, index)``: the best-matching column, its
        sum-squared error against ``output``, and its column index.

    Raises:
        ValueError: if ``output`` does not have one entry per table row.
    """
    table = gaussians if gaussian_table is None else gaussian_table
    nrows = table.shape[0]
    column = np.asarray(output, dtype=np.float64).reshape(-1, 1)
    if column.shape[0] != nrows:
        raise ValueError('output has %d values; expected %d' % (column.shape[0], nrows))
    # Broadcasting compares the output with every column at once; no need to
    # tile it into a (nrows, ngauss) matrix first.
    sse = ((table - column) ** 2).sum(axis=0)  # sum-squared error per column
    closest_index = sse.argmin()
    return table[:, closest_index], sse[closest_index], closest_index

#---------------- Display ----------------

def show_pattern(n, data=None, outs=None):
    """Draw input pattern ``n`` with its desired and (optionally) actual output.

    Figure 1 shows three things:
      * the input image (reshaped to 32x30, then transposed);
      * the desired-output strip above it;
      * if outputs are available, the actual-output strip at the top.
    Figure 3 then plots the actual output against the closest column of
    ``gaussians``.  Its x-label gives the sum-squared error and the OARE
    error from ``closest_gaussian``.

    Args:
        n: index of the pattern to show.
        data: indexable of ``(pattern, desired)`` pairs.  Defaults to
            ``trainset.data``, looked up at call time.
        outs: network outputs indexed like ``data`` (tensor or array).
            Defaults to the global ``outputs``.
    """
    if data is None:
        data = trainset.data
    if outs is None:
        outs = outputs
    if isinstance(outs, torch.Tensor):
        # Detach and move to the CPU so NumPy/matplotlib can consume it.
        outs = outs.detach().cpu().numpy()
    pattern = data[n][0].reshape(32, 30).T
    desired = data[n][1].reshape(1, 30)
    plt.figure(1)
    plt.clf()
    plt.axes((0, 0.05, 0.9, 0.8))
    plt.imshow(pattern, cmap='gray')
    plt.axis('off')
    plt.axes((0.125, 0.85, 0.65, 0.1))
    plt.imshow(desired, cmap='gray')
    plt.axis('off')
    if outs is not None:
        actual = outs[n]
        sum_sq_error = ((desired - actual) ** 2).sum()
        plt.axes((0.125, 0.90, 0.65, 0.1))
        plt.imshow(actual.reshape(1, 30))
        plt.axis('off')
        plt.show(block=False)
        plt.draw()
        plt.pause(0.001)
        plt.figure(3)
        plt.clf()
        # Plot the outputs passed in, not the global `outputs`, so the curve
        # matches the pattern and the errors shown in the label.
        plt.plot(actual)
        ideal, oare, _ = closest_gaussian(actual)
        plt.plot(ideal)
        plt.legend(('actual', 'ideal'))
        plt.xlabel('Pattern %d:  error=%f  OARE error=%f' % (n, sum_sq_error, oare))
    plt.draw()
    plt.pause(0.001)

def show_hidden(n, ax=None):
    """Draw the incoming and outgoing weights of hidden unit ``n``.

    The unit's input weights are reshaped to 32x30, transposed, and shown as
    an image.  Its output weights are shown as a 1x30 strip.  Colour limits
    are the min/max of the whole corresponding weight matrix, so different
    units share the same scale.

    Args:
        n: hidden-unit index.
        ax: axes for the input-weight image.  If ``None``, figure 2 is cleared
            and laid out for this unit alone, then redrawn immediately.
    """
    params = list(model.parameters())
    # For OneHiddenLayer, parameters() yields hidden W, hidden b, output W,
    # output b.  Detach and move to the CPU so `.numpy()` works on any device.
    hidden_w = params[0].detach().cpu()
    output_w = params[2].detach().cpu()
    imin, imax = hidden_w.min().item(), hidden_w.max().item()
    omin, omax = output_w.min().item(), output_w.max().item()
    in_weights = hidden_w[n].view(32, 30).T.numpy()
    out_weights = output_w[:, n].view(1, 30).numpy()
    if ax is None:
        plt.figure(2)
        plt.clf()
        ax_in = plt.axes((0, 0.05, 0.9, 0.8))
        plt.ylabel('Epoch %d   Hidden %d' % (epoch, n))
        ax_out = plt.axes((0.125, 0.85, 0.65, 0.1))
    else:
        ax_in = ax
        ax_in.axis('off')
        bbox = ax_in.get_position()
        # Output strip: same left edge and width as `ax`, starting at its
        # bottom edge, 0.7 (figure fraction) tall.
        ax_out = plt.axes((bbox.xmin, bbox.ymin, bbox.xmax - bbox.xmin, 0.7))
    ax_in.imshow(in_weights, vmin=imin, vmax=imax, aspect='equal')
    ax_out.imshow(out_weights, vmin=omin, vmax=omax)
    ax_out.axis('off')
    if ax is None:
        plt.draw()
        plt.pause(0.001)

def show_hiddens(ncols=3):
    """Draw every hidden unit's weights in one grid on figure 2.

    Each hidden unit gets one subplot (drawn by ``show_hidden``), and a final
    subplot shows the current epoch number.  The grid has ``ncols`` columns
    and enough rows for all of them, with at least 2 rows.  With the defaults
    this keeps the 2x3 layout for up to 5 hidden units instead of failing
    when ``nhiddens`` exceeds 5.

    Args:
        ncols: number of subplot columns.
    """
    # ceil((nhiddens + 1) / ncols) rows: one cell per unit plus the label cell.
    nrows = max(2, -(-(nhiddens + 1) // ncols))
    plt.figure(2)
    plt.clf()
    for i in range(nhiddens):
        ax = plt.subplot(nrows, ncols, i + 1)
        show_hidden(i, ax)
    label_ax = plt.subplot(nrows, ncols, nhiddens + 1)
    label_ax.axis('off')
    label_ax.text(0.2, 0.5, 'Epoch %d' % epoch)
    plt.draw()
    plt.pause(0.001)

def show_hiddens_seq():
    """Display each hidden unit's weights one after another, in index order.

    Each ``show_hidden`` call without axes clears and redraws figure 2, so
    only the last unit is still visible when this returns.
    """
    for unit_index in range(nhiddens):
        show_hidden(unit_index, ax=None)

# Script entry point: training starts as soon as this file is executed.
# Because the call is at module level, it also runs on import.
train_alvinn()

