import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib.pyplot as plt

# Run on the GPU when one is available; fall back to CPU so the script
# still works on machines without CUDA (previously this was hard-coded
# to 'cuda' and crashed on CPU-only hosts).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

"""
folder = datasets.ImageFolder(root='cube_cropped', transform=transforms.ToTensor())
images = [im[0].permute(1,2,0).numpy() for im in folder]

for im in images:
    for i in range(im.shape[0]):
        for j in range(im.shape[1]):
            if (im[i,j,0] >= 0.8) and (im[i,j,2] >= 0.8) and (im[i,j,1] <= 0.2):
                im[i,j] = [0.5, 0.5, 0.5]

onechan = [im[:,:,0] for im in images]
NPatterns = len(onechan)

traindata = torch.tensor(onechan).view(NPatterns, -1)

plt.imshow(traindata[0].view(240,320).numpy())
plt.show(block=False)
"""

NKernels = 16        # number of feature maps produced by the encoder's first conv layer
ImageSize = 320*240  # flattened pixel count of one 240x320 training image
KernelSize = 8       # side length of one conv patch; must evenly divide both 240 and 320
#KernelSize = 16


def reorder(patterns, kernel_size=None):
    """
    Rearrange the pixels of each training image into blocks of size
    kernel_size**2, matching the channel layout produced by the Conv2d
    layer (stride == kernel size).

    Parameters
    ----------
    patterns : torch.Tensor of shape (N, 1, H, W)
        Batch of single-channel images; H and W must be divisible by
        the kernel size (originally fixed at 240x320).
    kernel_size : int, optional
        Side length of each block.  Defaults to the module-level
        KernelSize constant, preserving the original behavior.

    Returns
    -------
    torch.Tensor of shape (N, kernel_size**2, H//kernel_size, W//kernel_size)
        result[p, a*kernel_size + b, i, j] == patterns[p, 0,
        i*kernel_size + a, j*kernel_size + b].

    Notes
    -----
    The original implementation used the deprecated Tensor.resize and a
    triple Python loop, and forced the result onto CUDA device 0.  The
    result now stays on the input tensor's device (callers move it to
    the target device afterwards anyway).
    """
    if kernel_size is None:
        kernel_size = KernelSize
    num_patterns, _, height, width = patterns.size()
    rows = height // kernel_size
    cols = width // kernel_size
    # Split H into (rows, k) and W into (cols, k), then move the two
    # k-axes in front of the spatial grid and flatten them into channels.
    blocks = patterns.reshape(num_patterns, rows, kernel_size, cols, kernel_size)
    return blocks.permute(0, 2, 4, 1, 3).reshape(
        num_patterns, kernel_size ** 2, rows, cols)

"""
patterns = torch.load('images.pt')
NPatterns = patterns.size()[0]
patterns = patterns.view(NPatterns, 1, 240, 320)

big_patterns = torch.zeros((NPatterns*16, 1, 240, 320))
cnt = 0
new_cnt = 0
for p in range(NPatterns):
    for i in (0,2,4,6):
        for j in (0,2,4,6):
            big_patterns[new_cnt, 0, :, :] = patterns[cnt, 0, :, :]
            big_patterns[new_cnt, 0, 0:240-i, 0:320-j] = patterns[cnt, 0, i:, j:]
            new_cnt += 1
    cnt += 1

patterns = big_patterns
"""

# Load the pre-built training images and derive the block-reordered
# regression targets.  NOTE(review): assumes the saved tensor has shape
# (N, 1, 240, 320) — reorder() indexes it that way; confirm against the
# code that produced patterns960.pt.
patterns = torch.load('patterns960.pt')
NPatterns = patterns.size()[0]
desired = reorder(patterns)

#================

class Encoder2(nn.Module):
    """Two-stage convolutional encoder.

    layer1 tiles the single-channel input into non-overlapping
    KernelSize x KernelSize patches (stride == kernel size) and maps
    each patch to NKernels feature channels through a ReLU; layer2 is a
    1x1 convolution expanding those features to KernelSize**2 channels
    per patch, matching the layout produced by reorder().
    """

    def __init__(self):
        super().__init__()
        # The attribute names and Sequential layout are load-bearing:
        # saved checkpoints use state_dict keys like 'layer1.0.weight'.
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, NKernels, KernelSize,
                      stride=KernelSize, padding=0),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(NKernels, KernelSize ** 2, 1,
                      stride=1, padding=0),
        )

    def forward(self, x):
        # The intermediate activation is kept on the instance so the
        # show_encoding* helpers can inspect it after a forward pass.
        self.out1 = self.layer1(x)
        self.out2 = self.layer2(self.out1)
        return self.out2

# Restore previously trained weights, then move the model and data onto
# the chosen device.  NOTE(review): torch.load without map_location
# requires the checkpoint's original save device to be available —
# consider torch.load('model960.pt', map_location=device).
m = torch.load('model960.pt')
model = Encoder2().to(device)
model.load_state_dict(m)

patterns = patterns.to(device)
desired = desired.to(device)

def train(lr=0.01, momentum=0.5, epochs=3000):
    """
    Fit the global `model` so its output matches `desired` for the
    global `patterns` batch, using full-batch SGD with momentum.

    Parameters
    ----------
    lr : float
        SGD learning rate.
    momentum : float
        SGD momentum coefficient.
    epochs : int
        Number of full-batch update steps; loss is printed every 100.

    Side effects: updates model parameters in place, leaves the model in
    eval mode, and displays the reconstruction of pattern 0.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.MSELoss()
    # Bug fix: the original never re-entered training mode, so a second
    # call to train() ran with the model still in eval mode (left there
    # by the previous call's model.eval()).
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(patterns)
        loss = criterion(outputs, desired)
        loss.backward()
        optimizer.step()
        if (epoch % 100) == 0:
            print(epoch, 'Loss=', loss.item())
    model.eval()
    show_output(0)

def show_output(n):
    """Render the model's reconstruction of pattern n as a 240x320 image.

    Inverts the block layout produced by reorder(): each channel vector
    at grid cell (i, j) is reshaped into a KernelSize x KernelSize tile
    and pasted back at its spatial position, then shown in grayscale.
    """
    reconstruction = model(patterns)[n].detach().cpu().numpy()
    _, block_rows, block_cols = reconstruction.shape
    canvas = np.zeros((240, 320))
    for r in range(block_rows):
        top = r * KernelSize
        for c in range(block_cols):
            left = c * KernelSize
            tile = reconstruction[:, r, c].reshape(KernelSize, KernelSize)
            canvas[top:top + KernelSize, left:left + KernelSize] = tile
    plt.clf()
    plt.imshow(canvas, cmap='gray')
    plt.show(block=False)

def show_pattern(n):
    """Display training pattern n as a 240x320 grayscale image."""
    image = patterns[n].detach().cpu().numpy().reshape(240, 320)
    plt.clf()
    plt.imshow(image, cmap='gray')
    plt.show(block=False)


def show_kernel(n):
    """Plot the n-th first-layer convolution kernel as a red/blue heat map.

    The first entry of model.parameters() is the layer1 Conv2d weight
    tensor of shape (NKernels, 1, KernelSize, KernelSize).
    """
    first_weight = next(model.parameters())
    kernel = first_weight.detach().cpu().numpy()[n, 0]
    plt.clf()
    plt.imshow(kernel, cmap='RdBu')
    plt.colorbar()
    plt.show(block=False)


def show_encoding_slice(n, c=None):
    """Show the layer-1 activations of pattern n along one image column.

    Requires a prior forward pass (reads model.out1).  Column c defaults
    to the middle of the activation grid; the plot has one row per
    spatial position and one column per kernel.
    """
    activations = model.out1.detach().cpu()
    if c is None:
        c = activations.size(3) // 2
    column = activations[n, :, :, c]        # (NKernels, rows)
    plt.imshow(column.reshape(NKernels, -1).t())
    plt.show(block=False)

def show_encoding(n):
    """Tile the layer-1 activations of pattern n into one mosaic image.

    Requires a prior forward pass (reads model.out1).  Each spatial cell
    of the activation grid becomes an s x s tile of its kernel
    responses, where s = sqrt(number of kernels).
    """
    activations = model.out1.detach()[n].cpu()
    kernels, rows, cols = activations.size()
    s = int(kernels ** 0.5)
    mosaic = np.zeros((s * rows, s * cols))
    for r in range(rows):
        for c in range(cols):
            mosaic[r * s:(r + 1) * s, c * s:(c + 1) * s] = \
                activations[:, r, c].view(s, s)
    plt.imshow(mosaic)
    plt.show(block=False)


# Load a labelled train set for end-to-end checks.
# NOTE(review): torch.tensor() always copies; if 'train_data' is already
# a tensor, torch.as_tensor would avoid the copy — verify the stored type.
trainset = torch.load('trainset.pt')
nimages = trainset['train_data'].shape[0]
images = torch.tensor(trainset['train_data']).float().view(nimages, 1, 240, 320).to(device)
labels = torch.tensor(trainset['train_labels']).to(device)

def test_traindata():
    """Smoke test: push the full labelled image batch through the model."""
    # Calling the module invokes forward() via nn.Module.__call__.
    model(images)
