import time
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn

# MNIST training split (60,000 samples), downloaded to ./mnist_data if absent.
# ToTensor() turns each grayscale PIL image into a float tensor of shape
# (1, 28, 28) with pixel values scaled to [0.0, 1.0].
# Indexing trainset[i] yields a tuple (image_tensor, label), where label is a
# plain Python int class index (0..9).
trainset = torchvision.datasets.MNIST(root='./mnist_data',
                                      train=True,
                                      download=True,
                                      transform=transforms.ToTensor())

print(trainset)

class MultiLogisticModel(nn.Module):
  """Multinomial logistic regression: one affine layer that emits class logits.

  No softmax is applied here. The raw scores are intended for
  nn.CrossEntropyLoss, which applies log-softmax internally.

  Args:
    in_dim: number of input features per sample.
    out_dim: number of classes (logits produced per sample).
  """

  def __init__(self, in_dim, out_dim):
    super().__init__()
    # Keep the attribute name `linear`: the script reads model.linear.weight.
    self.linear = nn.Linear(in_dim, out_dim)

  def forward(self, x):
    """Map x (last dimension in_dim) to logits (last dimension out_dim)."""
    return self.linear(x)

# Use the first GPU when CUDA is available, otherwise fall back to the CPU
# (a hard-coded GPU device crashes on CPU-only machines).
# To force CPU execution on a GPU machine, set device = torch.device('cpu').
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

in_dim = 28 * 28   # each image is flattened to 28*28 features before the model
out_dim = 10       # one logit per digit class
model = MultiLogisticModel(in_dim, out_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

batchSize = 100
epochs = 15
print('Loading...')
# Constructing a DataLoader is lazy: no batches are produced until the loader
# is iterated, so this step itself is cheap.
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batchSize, shuffle=True)
print('Loading complete.')

# Total number of training samples, used for the per-epoch accuracy.
num_samples = len(trainloader.dataset)

for epoch in range(epochs):
  runningLoss = 0.0   # sum of per-sample losses over this epoch
  now = time.perf_counter()
  correct = 0.0       # count of correctly classified samples this epoch
  for (images, labels) in trainloader:
    # Flatten each image into a row vector of 28*28 features.
    images = images.view(-1, 28*28).to(device)
    labels = labels.to(device)
    # Actual batch size: the final batch may be smaller than batchSize.
    n = labels.size(0)

    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    # criterion (default reduction) returns the batch mean, so scale by the
    # real batch size to accumulate a per-sample sum.
    runningLoss += loss.item() * n

    # Count correct predictions for this batch in one vectorized step. A
    # per-sample Python loop would force a device sync for each element and
    # would index past the end of a short final batch.
    predict = outputs.detach().argmax(dim=1)
    correct += (predict == labels).sum().item()
  print(epoch, runningLoss, 'time=', time.perf_counter()-now, end='')
  print('  correct = ', correct/num_samples*100, '%')


import matplotlib.pyplot as plt

plt.figure()

# Each row of the linear layer's weight matrix belongs to one class and has
# 28*28 entries. Reshaped back to 28x28, it shows how every pixel contributes
# to that class's logit.
weights = model.linear.weight.detach().cpu()

for i in range(weights.shape[0]):
  plt.clf()  # start fresh so images are not stacked on the same axes
  plt.imshow(weights[i, :].view(28, 28).numpy())
  plt.title(f'Weights for class {i}')
  plt.show(block=False)
  # input() below blocks the GUI event loop. Pausing briefly lets the backend
  # actually draw the figure; otherwise the window may stay blank or stale.
  plt.pause(0.1)
  input('Press Enter to proceed...')




