import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from math import pi

class OneHiddenLayerModel(nn.Module):
    """Fully connected network: input -> tanh hidden layer -> tanh output layer.

    After every forward pass the hidden activations and the network output are
    kept on the instance as ``self.hidden`` and ``self.out`` so that other code
    (e.g. plotting) can inspect them without re-running the model.
    """

    def __init__(self, in_dim, out_dim, hidden_dim):
        super().__init__()
        # Projection into the hidden space, squashed into (-1, 1) by tanh.
        self.in_to_hidden = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.Tanh())
        # Projection from the hidden space to the outputs, also tanh-squashed.
        self.hidden_to_out = nn.Sequential(nn.Linear(hidden_dim, out_dim), nn.Tanh())

    def forward(self, x):
        """Map ``x`` through both layers, caching intermediate activations."""
        self.hidden = self.in_to_hidden(x)
        result = self.hidden_to_out(self.hidden)
        self.out = result
        return result

def encoder(patterns=5, learn_rate=0.2, momentum=0.5):
    """Build an N-2-N encoder network and train it on the N bipolar one-hot patterns.

    Sets the module-level globals ``NPATS``, ``one_hot`` and ``model`` (which
    ``train`` and ``plot_hiddens`` read) and then calls ``train()``.

    Args:
        patterns: number of patterns N; also the input and output width.
        learn_rate: SGD learning rate.
        momentum: SGD momentum.

    Raises:
        ValueError: if ``patterns`` is less than 1.
    """
    global NPATS, one_hot, model
    if patterns < 1:
        raise ValueError(f"patterns must be >= 1, got {patterns}")
    NPATS = patterns
    # Row i is +1 at column i and -1 elsewhere, matching the (-1, 1) range of
    # the tanh output layer. torch.eye builds this directly instead of going
    # through a (slow, warning-emitting) list of numpy arrays.
    one_hot = 2 * torch.eye(NPATS, dtype=torch.float) - 1
    # plot_hiddens draws a 2-D hidden space, so the bottleneck stays at 2 units.
    nhiddens = 2
    model = OneHiddenLayerModel(NPATS, NPATS, nhiddens)

    # Optimizer and loss are attached to the model so train() can reach them.
    model.optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate, momentum=momentum)
    model.criterion = nn.MSELoss()
    train()

def train(epochs=100001):
    """Run full-batch training of the global ``model`` on the global ``one_hot``.

    The network is trained to reproduce its input (``one_hot`` is both the
    input and the target), using ``model.optimizer`` and ``model.criterion``.
    Every 50 epochs the hidden-space plot is redrawn and the loss printed;
    training stops at that checkpoint if the loss has fallen below 0.01.

    Args:
        epochs: maximum number of epochs to run.
    """
    for epoch in range(epochs):
        predictions = model(one_hot)
        model.optimizer.zero_grad()
        loss = model.criterion(predictions, one_hot)
        loss.backward()
        model.optimizer.step()

        # Only report / check convergence on every 50th epoch.
        if epoch % 50:
            continue
        plot_hiddens()
        current_loss = loss.item()
        print('Epoch %4d' % epoch, ' loss = ', current_loss)
        if current_loss < 0.01:
            break

def plot_hiddens():
    """Redraw the 2-D hidden-unit space of the global ``model``.

    Draws one dot per pattern at its cached hidden activation
    (``model.hidden`` from the most recent forward pass). For each output
    unit i, it also draws the line in hidden space where that unit's
    pre-activation is zero: ``w1*h1 + w2*h2 + b = 0``. Dot i and line i share
    a color.

    Assumes the model has exactly 2 hidden units.
    """
    colors = ['red', 'green', 'blue', 'cyan', 'magenta', 'yellow', 'black', 'tab:orange', 'tab:brown', 'tab:purple', '#a0ffa0']
    lo, hi = -1.5, 1.5
    plt.cla()
    plt.axis([lo, hi, lo, hi])
    plt.xlabel('Hidden Unit 1')
    plt.ylabel('Hidden Unit 2')

    # The palette is cycled so more patterns than colors no longer raises IndexError.
    hidden_patterns = model.hidden.detach().numpy()
    for i, point in enumerate(hidden_patterns):
        plt.plot(point[0], point[1], 'o', color=colors[i % len(colors)])

    params = list(model.hidden_to_out.parameters())
    weights = params[0].detach().numpy()
    biases = params[1].detach().numpy()
    xrange = [lo, hi]
    for i in range(NPATS):
        w1 = weights[i][0]
        w2 = weights[i][1]
        b = biases[i]
        color = colors[i % len(colors)]
        if w2 != 0:
            # Solve w1*h1 + w2*h2 + b = 0 for h2 at both ends of the h1 range.
            plt.plot(xrange, [-(b + w1 * x) / w2 for x in xrange], color=color)
        elif w1 != 0:
            # No h2 dependence: the boundary is the vertical line h1 = -b/w1.
            x0 = -b / w1
            plt.plot([x0, x0], [lo, hi], color=color)
        # With both weights zero the unit has no boundary line to draw.
    plt.pause(0.001)


# Script entry: train and animate the 5-2-5 encoder.
encoder(patterns=5)
