import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib.pyplot as plt

# Disabled preprocessing, kept for reference only -- presumably what produced
# 'images.pt' from the 'cube_cropped' image folder (confirm). The script loads
# the saved tensor below instead of running this.
#
# folder = datasets.ImageFolder(root='cube_cropped', transform=transforms.ToTensor())
# images = [im[0].permute(1,2,0).numpy() for im in folder]
#
# # Replace magenta-ish pixels (red >= 0.8, blue >= 0.8, green <= 0.2) with mid-grey.
# for im in images:
#     for i in range(im.shape[0]):
#         for j in range(im.shape[1]):
#             if (im[i,j,0] >= 0.8) and (im[i,j,2] >= 0.8) and (im[i,j,1] <= 0.2):
#                 im[i,j] = [0.5, 0.5, 0.5]
#
# # Keep only the first (red) channel of each image.
# onechan = [im[:,:,0] for im in images]
# NPatterns = len(onechan)
#
# traindata = torch.tensor(onechan).view(NPatterns, -1)
#
# plt.imshow(traindata[0].view(240,320).numpy())
# plt.show(block=False)

# Training patterns. Presumably a float tensor of shape (NPatterns, 240*320),
# i.e. one flattened grayscale image per row -- confirm against images.pt.
# map_location='cpu' lets the file load even on a machine without CUDA if it
# was saved from GPU tensors; the tensor is moved to the training device later.
patterns = torch.load('images.pt', map_location='cpu')
NPatterns = patterns.size(0)

#================

NHiddens = 100  # width of Encoder1's hidden layer
ImageSize = 320*240  # pixels per flattened image (240 rows x 320 columns)

class Encoder1(nn.Module):
    """Fully connected autoencoder with a single tanh hidden layer.

    Maps a flattened input of ``image_size`` features to ``n_hiddens`` tanh
    units, then linearly back to ``image_size`` outputs.

    The hidden activations and the reconstruction from the most recent
    forward pass are kept on ``self.out1`` and ``self.out2`` for inspection.
    These hold references to the latest tensors and, during training, to
    their autograd graph.

    Args:
        image_size: number of input/output features. Defaults to the
            module-level ``ImageSize`` when omitted.
        n_hiddens: width of the hidden layer. Defaults to the module-level
            ``NHiddens`` when omitted.
    """

    def __init__(self, image_size=None, n_hiddens=None):
        super().__init__()
        # Resolve defaults lazily so the module constants are only consulted
        # when no explicit size is given.
        if image_size is None:
            image_size = ImageSize
        if n_hiddens is None:
            n_hiddens = NHiddens
        self.layer1 = nn.Sequential(
            nn.Linear(image_size, n_hiddens),
            nn.Tanh(),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(n_hiddens, image_size),
        )

    def forward(self, x):
        """Encode ``x`` to the hidden layer and decode it back.

        Args:
            x: tensor whose last dimension has ``image_size`` features.

        Returns:
            The reconstruction, with the same trailing size as ``x``.
        """
        self.out1 = self.layer1(x)
        self.out2 = self.layer2(self.out1)
        return self.out2

class Encoder2(nn.Module):
    """Convolutional autoencoder.

    ``layer1`` is a 5x5 convolution (padding 2) from 1 to 16 channels
    followed by tanh. ``layer2`` is a 3x3 convolution (padding 1) from 16
    channels back to 1. Both preserve the spatial size.

    Input is expected as (batch, 1, H, W). The output is returned flattened
    as (batch, 1, 1, H*W), which is the layout this model has always
    produced. The hidden feature maps (batch, 16, H, W) and the output are
    kept on ``self.out1`` / ``self.out2`` for inspection.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.Tanh(),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        batch = x.size(0)
        # Decode on the spatial feature maps. The earlier version viewed them
        # as (NPatterns, 16, 1, H*W) first. That reduced the 3x3 decoder to a
        # 1x3 filter over flattened rows (mixing pixels across row ends), and
        # it only worked when the batch size equalled NPatterns.
        self.out1 = self.layer1(x)
        decoded = self.layer2(self.out1)
        # Preserve the original flat output layout for existing callers.
        self.out2 = decoded.reshape(batch, 1, 1, -1)
        return self.out2

# Train on the GPU when one is available, otherwise fall back to the CPU so the
# script does not crash on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = Encoder1().to(device)

# The autoencoder learns to reproduce its input, so the targets are the
# patterns themselves. Move the data once and alias it rather than making two
# separate device copies of the same tensor.
# (To train Encoder2 instead, the inputs would first need reshaping to
# (NPatterns, 1, 240, 320).)
patterns = patterns.to(device)
desired = patterns

def train(lr=0.1, momentum=0.9, epochs=2000):
    """Train the global ``model`` to reconstruct ``patterns``.

    Runs full-batch SGD with momentum on the MSE between the model's
    per-sample flattened outputs and ``desired``, printing the loss every 100
    epochs. It then puts the model in eval mode and displays the
    reconstruction of pattern 0 via ``show_output``.

    Args:
        lr: SGD learning rate.
        momentum: SGD momentum factor.
        epochs: number of full-batch optimisation steps.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.MSELoss()
    # NOTE(review): nothing below waits for input; this message looks stale.
    print('Press Enter to continue...')
    # Flatten targets per sample once, so they match the flattened outputs
    # whatever layout the model returns.
    target = desired.reshape(desired.size(0), -1)
    model.train()  # undo a previous call's model.eval()
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(patterns)
        # Use the real batch size rather than a hard-coded 60.
        outputs = outputs.reshape(outputs.size(0), -1)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        if epoch % 100 == 0:
            print(epoch, 'Loss=', loss.item())
            # show_hidden_patterns() could be called here to visualise progress.
    model.eval()
    show_output(0)

def show_output(n, height=240, width=320):
    """Display the model's reconstruction of pattern ``n`` as a grayscale image.

    Args:
        n: index into ``patterns``.
        height: number of image rows used to reshape the flat output.
        width: number of image columns used to reshape the flat output.
            The defaults, 240x320, match ``ImageSize``.
    """
    # Inference only: don't build an autograd graph for the whole batch.
    with torch.no_grad():
        outputs = model(patterns)
    # reshape (rather than in-place ndarray.resize) raises a clear error if
    # the pixel count doesn't match height * width.
    image = outputs[n].cpu().numpy().reshape(height, width)
    # plt.gray() returns None, so the original cmap argument only worked
    # through its side effect; name the colormap explicitly instead.
    plt.imshow(image, cmap='gray')
    plt.show(block=False)

def show_pattern(n, height=240, width=320):
    """Display training pattern ``n`` as a grayscale image.

    Args:
        n: index into ``patterns``.
        height: number of image rows used to reshape the flat pattern.
        width: number of image columns used to reshape the flat pattern.
            The defaults, 240x320, match ``ImageSize``.
    """
    # reshape (rather than in-place ndarray.resize) raises a clear error if
    # the pixel count doesn't match height * width.
    image = patterns[n].detach().cpu().numpy().reshape(height, width)
    # plt.gray() returns None, so the original cmap argument only worked
    # through its side effect; name the colormap explicitly instead.
    plt.imshow(image, cmap='gray')
    plt.show(block=False)
