import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import cv2
import matplotlib.pyplot as plt

from mnist3 import OneConvLayer

# Model hyper-parameters; presumably these must match the architecture saved
# in ./mnist3.pt, otherwise load_state_dict() below will reject the weights.
in_dim = 28 * 28
out_dim = 10
nkernels = 32
# Prefer the first CUDA device, but fall back to the CPU so the script still
# runs on machines without a GPU (torch.device(0) alone would fail there).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = OneConvLayer(in_dim, out_dim, nkernels).to(device)

# map_location restores the checkpoint's tensors onto the device chosen above,
# so weights saved from a GPU run can also be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load('./mnist3.pt', map_location=device))
# Switch to inference behaviour (e.g. dropout / batch-norm layers, if any).
model.eval()

# MNIST training split; ToTensor() turns each image into a float tensor in [0, 1].
trainset = torchvision.datasets.MNIST(root='./mnist_data',
                                      train=True,
                                      download=True,
                                      transform=transforms.ToTensor())

def test_mnist(dataset=None):
    """Evaluate the module-level ``model`` on an MNIST split and report accuracy.

    Every sample is fed to the model on its own (batch size 1). The arg-max of
    the model output is compared with the ground-truth label.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs, as
            returned by ``torchvision.datasets.MNIST`` with a ``ToTensor``
            transform. Defaults to the module-level ``trainset`` (the MNIST
            training split), matching the original behaviour.

    Returns:
        The fraction of correctly classified samples, or ``nan`` if the
        dataset is empty.
    """
    if dataset is None:
        dataset = trainset

    num_correct = 0
    num_total = 0

    # Inference only: don't build an autograd graph for every sample.
    with torch.no_grad():
        for idx in range(len(dataset)):
            # Index through the dataset (rather than its raw .data tensor) so
            # the configured transform is applied. ToTensor yields a
            # (1, 28, 28) float image scaled to [0, 1], not raw 0-255 bytes.
            image, label = dataset[idx]
            batch = image.view(1, 1, 28, 28).to(device)
            # Call the module (not .forward) so any registered hooks run.
            pred = model(batch).argmax().item()
            if pred == int(label):
                num_correct += 1
            num_total += 1

    # Guard the empty-dataset case instead of raising ZeroDivisionError.
    accuracy = num_correct / num_total if num_total else float('nan')

    print("Num correct: {}/{}".format(num_correct, num_total))
    print("Accuracy: {}".format(accuracy))
    return accuracy
