import torch
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt
import time

# Number of training patterns; each pattern vector has this many components.
NPatterns = 8

# One pattern per row: +1 on the diagonal and -1 everywhere else.
patterns = 2.0 * torch.eye(NPatterns, dtype=torch.float) - 1.0

class N2N_Encoder(nn.Module):
    """Autoencoder that squeezes ``NPats`` inputs through two hidden units.

    Both stages are a linear map followed by ``tanh``.  The activations of
    the latest forward pass are kept on the instance as ``out1`` (hidden
    layer) and ``out2`` (output layer) so that other code can inspect them.
    """

    def __init__(self, NPats):
        """Create the NPats -> 2 -> NPats network with fresh random weights."""
        super().__init__()
        # Encoder first, decoder second: the order fixes parameter naming
        # and the order in which random initial weights are drawn.
        self.layer1 = nn.Sequential(nn.Linear(NPats, 2), nn.Tanh())
        self.layer2 = nn.Sequential(nn.Linear(2, NPats), nn.Tanh())

    def forward(self, x):
        """Return the reconstruction of ``x`` and cache both activations."""
        hidden = self.layer1(x)
        self.out1 = hidden
        reconstruction = self.layer2(hidden)
        self.out2 = reconstruction
        return reconstruction

# Shared module-level network: train() re-initialises and fits it, and
# show_hidden_patterns() reads its cached hidden activations.
model = N2N_Encoder(NPatterns)

def train(epochs=5000, lr=0.5, momentum=0.5, plot_every=100):
    """Re-initialise the global ``model`` and train it to reproduce ``patterns``.

    Each epoch is one full-batch SGD step on the MSE between the model
    output and ``patterns``.  The loss is printed every epoch, and
    ``show_hidden_patterns()`` is called on every ``plot_every``-th epoch,
    starting with epoch 0; that function waits for keyboard input before
    returning, so training pauses at each plot.  The model is put into
    eval mode once the loop finishes.

    Args:
        epochs: Number of optimisation steps to run.
        lr: Learning rate passed to ``torch.optim.SGD``.
        momentum: Momentum passed to ``torch.optim.SGD``.
        plot_every: Epoch interval between hidden-layer plots; ``0`` or
            ``None`` disables plotting.
    """
    # Re-run the constructor on the existing instance so every call starts
    # from fresh random weights while the module-level ``model`` object
    # (which show_hidden_patterns() reads) stays the same.
    model.__init__(NPatterns)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.MSELoss()
    print('Press Enter to continue...')
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(patterns)
        loss = criterion(outputs, patterns)
        loss.backward()
        print(epoch, 'Loss=', loss.item())
        optimizer.step()
        # The plot shows the activations cached by the forward pass above,
        # i.e. from before this epoch's weight update.
        if plot_every and epoch % plot_every == 0:
            show_hidden_patterns()
    model.eval()

# Turn on matplotlib's interactive mode for the plot that
# show_hidden_patterns() redraws during training.
plt.ion()

def show_hidden_patterns():
    """Plot each pattern's two hidden-unit activations, then wait for input.

    Reads ``model.out1`` as cached by the most recent forward pass of the
    global ``model``, clears the current axes, draws one marker per pattern
    (hidden unit 1 on x, hidden unit 2 on y) and blocks on ``input()``
    until the user responds.
    """
    plt.cla()
    hidden = model.out1.detach().numpy()
    # One plot() call per pattern, so every pattern is its own line artist.
    for row in range(NPatterns):
        plt.plot(hidden[row, 0], hidden[row, 1], 'o')
    plt.xlabel('Hidden 1')
    plt.ylabel('Hidden 2')
    plt.show(block=False)
    input('continue?')
