import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt

from mnist3 import OneConvLayer

# Load the trained OneConvLayer MNIST model from ./mnist3-saved.pt, show one
# training image, and print the model's raw output and its predicted label.

in_dim = 28 * 28
out_dim = 10
nkernels = 32
# Use the first GPU when present (as before), but fall back to the CPU so the
# script still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = OneConvLayer(in_dim, out_dim, nkernels).to(device)

# map_location lets a checkpoint saved on one device be loaded onto another.
model.load_state_dict(torch.load('./mnist3-saved.pt', map_location=device))
model.eval()

trainset = torchvision.datasets.MNIST(root='./mnist_data',
                                      download=True,
                                      transform=transforms.ToTensor())

# Index the dataset instead of reading the raw uint8 `train_data` tensor
# (deprecated, and unscaled 0-255). Indexing applies the ToTensor transform,
# so the image comes back as a float tensor of shape (1, 28, 28) in [0, 1],
# together with its label.
# NOTE(review): assumes the model was trained on ToTensor-scaled inputs;
# confirm against mnist3's training code.
sample_index = 2
d2, true_label = trainset[sample_index]
plt.imshow(d2.squeeze(0).numpy(), cmap='gray')
# Non-blocking: under a plain `python script.py` run the window may close
# when the script exits.
plt.show(block=False)

# Inference only: call the module (not .forward) so registered hooks run, and
# skip autograd bookkeeping.
with torch.no_grad():
    result = model(d2.view(1, 1, 28, 28).to(device))

print(result)

print('Best label is', result.argmax().item())
print('True label is', true_label)
