import torch
import torch.nn as nn
from torch.utils.data import RandomSampler, BatchSampler

import matplotlib.pyplot as plt

"""
To run this code, first run generate.py to generate the training set.
Then run test4.py and do train() until all patterns are correctly classified.
"""

device = 'cuda' if torch.cuda.is_available() else 'cpu'

KernelSize1 = (8,8)
NKernels1 = 16

KernelSize2 = (3,4)
NKernels2 = 8
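
# Shape flow for a 240x320 input:
#   layer1: 8x8 conv, stride 8       -> (N, 16, 30, 40)
#   layer2: (3,4) conv, stride (3,4) -> (N, 8, 10, 10)
#   layer3: linear on 8*10*10 = 800 features -> 2 class logits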

class CubeDetector(nn.Module):
    def __init__(self):
        super().__init__()
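        # layer1 is initialized from the pretrained encoder 'model960.pt'
        # (copied below) and held fixed by the no_grad() in forward().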
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=NKernels1, kernel_size=KernelSize1, stride=KernelSize1, padding=0),
            nn.BatchNorm2d(NKernels1),
            nn.ReLU()
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=NKernels1, out_channels=NKernels2,
                      kernel_size=KernelSize2, stride=KernelSize2, padding=0),
            nn.BatchNorm2d(NKernels2),
            nn.ReLU()
        )
        self.layer3 = nn.Sequential(
            nn.Linear(10*10*NKernels2, 2)
        )
    
    def forward(self, x):
        npats = x.shape[0]
        # layer1 is a frozen feature extractor: no gradients flow into it.
        with torch.no_grad():
            self.out1 = self.layer1(x)
        self.out2 = self.layer2(self.out1)
        # Intermediate activations are stored as attributes so they can be
        # inspected after a forward pass.
        self.out3 = self.layer3(self.out2.view(npats, -1))
        return self.out3

model = CubeDetector()

# Load the pretrained encoder and copy over only its first-layer conv
# parameters; layer1's BatchNorm and the later layers train from scratch.
encoder_model_state = torch.load('model960.pt', map_location=device)
encoder_keys = ['layer1.0.weight', 'layer1.0.bias']

new_model_state = model.state_dict()
for k in encoder_keys:
    new_model_state[k] = encoder_model_state[k]
model.load_state_dict(new_model_state)
model = model.to(device)
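
# Optional: mark layer1's parameters as frozen so the intent is explicit
# (the no_grad() in forward() already keeps them from being updated):
# for p in model.layer1.parameters():
#     p.requires_grad_(False)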

# Training images are 240x320 grayscale; labels are 0 (no cube) / 1 (cube).
trainset = torch.load('trainset.pt')
nimages = trainset['train_data'].shape[0]
images = torch.as_tensor(trainset['train_data']).float().view(nimages, 1, 240, 320).to(device)
labels = torch.as_tensor(trainset['train_labels']).to(device)
batch_size = 100
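# Batches are formed by indexing the in-memory tensors directly with the
# indices produced by BatchSampler, rather than through a DataLoader.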

def train(lr=0.1, momentum=0.0, epochs=100):
    model.train()  # show_output() leaves the model in eval mode
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    sampler = BatchSampler(RandomSampler(labels), batch_size=batch_size, drop_last=False)
    for epoch in range(epochs):
        running_loss = 0.0
        num_correct = 0
        for indices in sampler:  # RandomSampler reshuffles on each iteration
            im = images[indices]
            lab = labels[indices]
            optimizer.zero_grad()
            outputs = model(im)
            predict = outputs.argmax(1)
            num_correct += (predict == lab).sum().item()
            loss = criterion(outputs, lab)
            running_loss += loss.item() * im.shape[0]  # summed (not mean) loss
            loss.backward()
            optimizer.step()
        print(epoch, 'loss =', running_loss, 'correct =', num_correct)
    print(num_correct, 'correct out of', images.shape[0])
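
# A minimal whole-set evaluation sketch (a new helper, not part of the
# original workflow); assumes the in-memory `images`/`labels` tensors above.
def evaluate(batch=100):
    model.eval()
    num_correct = 0
    with torch.no_grad():  # inference only, no gradient bookkeeping
        for start in range(0, images.shape[0], batch):
            out = model(images[start:start+batch])
            num_correct += (out.argmax(1) == labels[start:start+batch]).sum().item()
    model.train()  # restore training mode for subsequent train() calls
    print(num_correct, 'correct out of', images.shape[0])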


def show_output(n):
    model.eval()  # use running BatchNorm statistics for a single image
    with torch.no_grad():
        outputs = model(images[n].view(1, 1, 240, 320))
    pred = outputs.argmax(1).item()
    label = 'cube' if pred == 1 else 'no cube'
    plt.imshow(images[n].view(240, 320).cpu(), cmap='gray')
    plt.show(block=False)
    print(label + ':', outputs)
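
# Example interactive session (hyperparameter values here are illustrative):
#   >>> train(lr=0.1, momentum=0.9, epochs=50)
#   >>> evaluate()
#   >>> show_output(0)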

