import time
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn

import matplotlib.pyplot as plt

# MNIST training split, downloaded into ./mnist_data if not already present.
# The ToTensor transform converts each image to a float tensor when indexed.
trainset = torchvision.datasets.MNIST(
  root='./mnist_data',
  train=True,
  download=True,
  transform=transforms.ToTensor(),
)

# Indexing trainset yields (image_tensor, int_label) pairs (60,000 of them for
# the training split). Note: the raw pixels are stored in trainset.data as a
# uint8 tensor; the old `train_data` name is only a deprecated alias for it.
print(trainset)

class OneHiddenLayer(nn.Module):
  """Fully connected classifier: Linear -> BatchNorm1d -> ReLU -> Linear.

  Args:
    in_dim: number of input features per sample.
    out_dim: number of output scores per sample.
    nhiddens: width of the single hidden layer.
  """

  def __init__(self, in_dim, out_dim, nhiddens):
    super().__init__()
    layers = [
      nn.Linear(in_dim, nhiddens),
      nn.BatchNorm1d(nhiddens),
      nn.ReLU(),
      nn.Linear(nhiddens, out_dim),
    ]
    # All layers live in one Sequential called `network`; code elsewhere reads
    # model.network.parameters() and relies on this exact layer order.
    self.network = nn.Sequential(*layers)

  def forward(self, x):
    """Return the output of the final Linear layer for the batch ``x`` (no softmax)."""
    return self.network(x)

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU,
# so the script runs unmodified on machines without a GPU.
if torch.cuda.is_available():
  device = torch.device('cuda', 0)   # GPU board (same as torch.device(0))
else:
  device = torch.device('cpu')       # regular CPU
print('device is', device)

in_dim = 28 * 28     # images are flattened to 28*28 inputs in train_model
out_dim = 10         # one output score per digit class
nhiddens = 20
model = OneHiddenLayer(in_dim, out_dim, nhiddens).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

batchSize = 100
epochs = 20
print('Loading...')
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batchSize, shuffle=True)
print('Loading complete.')

def train_model():
  """Train the global ``model`` on ``trainloader`` for ``epochs`` epochs.

  After each epoch, prints the epoch index, the training loss summed over all
  samples (not averaged), the wall-clock time of the epoch, and the percentage
  of training samples classified correctly during that epoch.
  """
  # The model contains BatchNorm1d, whose behaviour depends on train/eval
  # mode; make the training mode explicit.
  model.train()
  nsamples = len(trainloader.dataset)
  for epoch in range(epochs):
    runningLoss = 0.0
    now = time.time()
    correct = 0
    for images, labels in trainloader:
      images = images.view(-1, 28*28).to(device)
      labels = labels.to(device)
      optimizer.zero_grad()
      outputs = model(images)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # criterion returns the batch mean; scale back up to a per-sample sum.
      runningLoss += loss.item() * labels.shape[0]

      # Count correct predictions over the batch's actual size (the last batch
      # may hold fewer than batchSize samples). The vectorized comparison
      # replaces a per-sample Python loop that forced a device sync per element.
      predicted = outputs.detach().argmax(dim=1)
      correct += (predicted == labels).sum().item()
    # end of epoch
    print('{:2}  loss = {:8.2f}   time = {:.2f} s   correct = {:.2f} %'.format(
          epoch, runningLoss, time.time()-now, correct/nsamples*100))

def show_hidden_weights():
  """Step through the hidden units, drawing each unit's input weights as a 28x28 image.

  Each row of the first Linear layer's weight matrix (one row per hidden unit)
  is reshaped to 28x28 and drawn in the current figure. The function waits for
  Enter on stdin before moving on to the next unit.
  """
  params = list(model.network.parameters())
  # params[0] is the first Linear layer's weight, one row of 28*28 inputs per
  # hidden unit.
  weights = params[0].detach().cpu()
  # Iterate over however many hidden units the model actually has instead of
  # assuming 20, so changing nhiddens neither crashes nor truncates the display.
  for unit, row in enumerate(weights):
    plt.clf()
    plt.imshow(row.view(28, 28))
    plt.xlabel(f'Hidden Unit {unit}')
    plt.pause(0.001)
    input('Press Enter to proceed...')

def show_output_weights():
  """Draw each output unit's incoming weights as a vertical strip, labelled by class.

  Output unit ``i`` gets its own subplot showing its ``nhiddens`` weights from
  the hidden layer as a one-column image, with ``i`` written underneath. An
  empty subplot column separates adjacent strips.
  """
  params = list(model.network.parameters())
  # Parameters follow the Sequential order: Linear1 (w, b), BatchNorm1d (w, b),
  # Linear2 (w, b); so params[-2] is the output Linear's weight matrix.
  weights = params[-2].detach().cpu()
  nout = weights.shape[0]
  plt.clf()
  for i in range(nout):
    # 2*nout+1 columns with strips in the odd positions (21 columns for 10
    # outputs), derived from the weights rather than hard-coded.
    ax = plt.subplot(1, 2*nout + 1, 1 + 2*i)
    ax.imshow(weights[i, :].view(-1, 1))
    ax.axis('off')
    # Place the label in this subplot's own axes coordinates so it stays under
    # its strip regardless of nhiddens or figure size.
    ax.text(0.5, -0.02, f'{i}', transform=ax.transAxes, ha='center', va='top')
  plt.pause(0.01)


# Train the network, then draw the learned output-layer weights in a new figure.
train_model()
plt.figure()
show_output_weights()
# show_output_weights() only pauses briefly. Without a final show(), a plain
# `python script.py` run would exit and close the window immediately. In
# interactive mode this call has no effect.
plt.show()
