import torch
import torch.utils.data.dataset
import torch.nn as nn
import torchvision
import torchvision.transforms
import numpy as np
import scipy.linalg
import torch.nn.functional as F

# Global run configuration and RNG seeding for reproducibility.
cuda = torch.cuda.is_available()
# One seed for every RNG. torch.manual_seed already seeds all devices, so the
# CUDA call below must use the same value; the previous torch.cuda.manual_seed(0)
# silently overrode the GPU seed with a different number than the CPU seed.
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed(seed)
# Input width of the linear classifier (784 = flattened 28x28 image as used
# by the training/eval loops) and the Adam learning rate.
num_features = 784
lr = 5e-3

class Classifier(nn.Module):
    """Single linear layer mapping ``num_features`` inputs to 10 class logits."""

    def __init__(self, num_features):
        super().__init__()
        # Keep the attribute name `classifier` so state_dict keys remain
        # `classifier.weight` / `classifier.bias`.
        self.classifier = nn.Linear(num_features, 10)

    def forward(self, x):
        """Return unnormalized logits for the last dimension of ``x``."""
        logits = self.classifier(x)
        return logits

class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) one into a single axis."""

    def forward(self, x):
        # `view` (not `reshape`): requires a view-compatible layout, never copies.
        batch = x.size(0)
        return x.view(batch, -1)

def mnist_net():
    """Build a small convolutional feature extractor.

    Two 4x4, stride-2, padding-1 convolutions each halve the spatial size, so a
    1x28x28 input (assumed MNIST-shaped) becomes 32x7x7. That is flattened and
    projected to the module-level ``num_features`` width, followed by ReLU.
    """
    layers = [
        nn.Conv2d(1, 16, 4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 32, 4, stride=2, padding=1),
        nn.ReLU(),
        Flatten(),
        nn.Linear(32 * 7 * 7, num_features),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)

def _mnist_loader(train, shuffle):
    """Return a batch-512 DataLoader over the MNIST train or test split.

    Images are converted with ToTensor and normalized with mean 0.1307 and
    std 0.3081; the dataset root is '../../data' with download=True.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset = torchvision.datasets.MNIST(
        '../../data', train=train, download=True, transform=transform)
    return torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=shuffle)


# Training data is shuffled each epoch; test data is iterated in fixed order.
train_loader = _mnist_loader(train=True, shuffle=True)
test_loader = _mnist_loader(train=False, shuffle=False)

# Build the networks, move them to the GPU when available, and optimize only
# the linear classifier's parameters.
# NOTE(review): `model` (the conv net) is not used by the training/evaluation
# loops visible in this file — the classifier is fed raw flattened pixels.
model = mnist_net()
classifier = Classifier(num_features)
if cuda:
    # Module.cuda() moves parameters in place and returns the module itself;
    # rebinding to the same (correctly spelled) name keeps that explicit.
    model = model.cuda()
    classifier = classifier.cuda()
# Created after the .cuda() move so Adam holds the device-resident parameters.
# Uses the configured module-level `lr` instead of a duplicated literal.
optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)

# Train the linear classifier for 5 epochs, printing per-example mean loss and
# accuracy after each epoch.
# NOTE(review): inputs are flattened raw pixels (X.reshape(-1, 784)); the conv
# `model` is not applied here — confirm this is intended.
classifier.train()
for epoch in range(5):
    train_n = 0.
    train_loss = 0.
    train_acc = 0.
    for X, y in train_loader:
        if cuda:
            X = X.cuda()
            y = y.cuda()
        features = X.reshape(-1, 784)
        output = classifier(features)
        loss = F.cross_entropy(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # cross_entropy is a batch mean; weight by batch size so the epoch
        # figure is a true per-example mean even with a short final batch.
        n = len(X)
        correct = (output.argmax(1) == y).sum().item()
        train_n += n
        train_loss += loss.item() * n
        train_acc += correct
    print(f'Train loss: {train_loss / train_n}\tTrain acc: {train_acc / train_n}')
# Evaluate the trained classifier on the test set and print per-example mean
# loss and accuracy.
classifier.eval()
test_loss = test_acc = test_n = 0.
# Inference only: disable autograd so no computation graph is built for each
# test batch (saves memory and time; results are unchanged).
with torch.no_grad():
    for X, y in test_loader:
        if cuda:
            X, y = X.cuda(), y.cuda()
        features = X.reshape(-1, 784)
        output = classifier(features)
        loss = F.cross_entropy(output, y)
        test_n += len(X)
        # cross_entropy returns the batch mean; weight by batch size so the
        # final figure is the per-example mean even with a short last batch.
        test_loss += loss.item() * len(X)
        test_acc += (output.argmax(1) == y).sum().item()
print(f'Test loss: {test_loss / test_n}\tTest acc: {test_acc / test_n}')



