import time
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn

# Optional: enable the cuDNN autotuner, which benchmarks convolution algorithms
# once for the fixed 28x28 input size and can speed up GPU training.
# torch.backends.cudnn.benchmark = True

trainset = torchvision.datasets.MNIST(root='./mnist_data',
                                      download=True,
                                      transform=transforms.ToTensor())

# trainset yields 60,000 (image, label) pairs; ToTensor() converts each 28x28
# uint8 image to a float tensor scaled to [0, 1].  The raw pixels live in
# trainset.data (60000x28x28 uint8) and the labels in trainset.targets.
print(trainset)

class TwoConvLayers(nn.Module):
  def __init__(self):
    super(TwoConvLayers, self).__init__()
    self.network1 = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
      nn.BatchNorm2d(16),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2),
    )
    self.network2 = nn.Sequential(
      nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
      nn.BatchNorm2d(32),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2)
    )
    self.network3 = nn.Linear(32*7*7, 10)


  def forward(self, x):
    self.out1 = self.network1(x)
    self.out2 = self.network2(self.out1)
    self.out3 = self.network3(self.out2.view(self.out2.size(0),-1))
    return self.out3
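
# A minimal shape check (sketch): the padded 5x5 convs keep 28x28, and each
# MaxPool2d(2) halves it (28 -> 14 -> 7), which is why the final Linear layer
# expects 32*7*7 input features.  Uncomment to verify on a dummy batch:
# assert TwoConvLayers()(torch.zeros(1, 1, 28, 28)).shape == (1, 10)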

# Uncomment just one of the two lines below:
device = torch.device('cuda:0')  # first CUDA GPU
#device = torch.device('cpu')    # regular CPU
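
# A more defensive alternative (sketch): pick the device automatically rather
# than editing the two lines above by hand.
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')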

in_dim = 28 * 28
out_dim = 10
nkernels = 32
batch_size = 100

model = TwoConvLayers().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

def train():
    epochs = 15
    print('Loading...')
    trainloader = torch.utils.data.DataLoader(dataset=trainset,  batch_size=batch_size, shuffle=True)
    print('Loading complete.')

    for epoch in range(epochs):
      model.train()   # test() leaves the model in eval mode, so switch back each epoch
      runningLoss = 0.0
      now = time.time()
      correct = 0
      for (images, labels) in trainloader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        predict = outputs.argmax(1)
        correct += (predict == labels).sum().item()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        runningLoss += loss.item() * batch_size
      print(epoch, runningLoss, 'time=', time.time() - now, end='')
      print('  num_correct=', correct, '  %correct =', correct/60000*100, '%')
      test()
    torch.save(model.state_dict(), './mnist3-saved.pt')
    print('model saved')
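
def load_saved(path='./mnist3-saved.pt'):
    # Sketch (hypothetical helper, not part of the original script): rebuild the
    # network and load the state_dict written by train() at the path above.
    restored = TwoConvLayers().to(device)
    restored.load_state_dict(torch.load(path, map_location=device))
    restored.eval()
    return restored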


def test(train=False):
    # Globals are kept so the last batch can be inspected interactively after a run.
    global out, pred, succ, p, d
    if train:
      model.train()
    else:
      model.eval()
    succ = 0.0
    with torch.no_grad():
      for i in range(0, 60000, batch_size):
        # trainset.data holds raw uint8 pixels; rescale to [0, 1] to match ToTensor().
        p = trainset.data[i:i+batch_size].to(device).float().view(batch_size, 1, 28, 28) / 255.0
        d = trainset.targets[i:i+batch_size].to(device)
        out = model(p)
        pred = out.argmax(1)
        succ += (pred == d).sum().item()
    print('num_correct=', succ, ' success rate:', succ/60000.0*100.0)
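
def test_heldout():
    # Sketch (hypothetical helper): test() above re-scores the *training* images;
    # this scores the held-out 10,000-image MNIST test split instead, assuming the
    # same ./mnist_data root directory.
    testset = torchvision.datasets.MNIST(root='./mnist_data', train=False,
                                         download=True, transform=transforms.ToTensor())
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size)
    model.eval()
    hits = 0
    with torch.no_grad():
      for x, y in testloader:
        hits += (model(x.to(device)).argmax(1) == y.to(device)).sum().item()
    print('test-set num_correct=', hits, ' success rate:', hits/len(testset)*100.0)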

import matplotlib.pyplot as plt

def display():
    plt.figure()
    # First-layer conv weights have shape (16, 1, 5, 5): one 5x5 kernel per output channel.
    weights = next(model.network1.parameters()).detach().cpu()
    for i in range(16):
      plt.imshow(weights[i].view(5, 5))
      plt.show(block=False)
      plt.pause(1)
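
# Typical interactive use (sketch; the script only defines these functions):
#   >>> train()     # 15 epochs of SGD, then saves ./mnist3-saved.pt
#   >>> test()      # re-scores the 60,000 training images in eval mode
#   >>> display()   # steps through the 16 first-layer 5x5 kernels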

