import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from math import pi

# Sample the target function on N evenly spaced points over [-pi, pi]
# (both endpoints included, per np.linspace defaults).
N = 100
xvals = np.linspace(-pi, pi, num=N)

# Target curve: a cosine oscillation multiplied by a linear ramp (the factor
# xvals), damped by a Gaussian envelope centred at x = 0.4 / 0.6, then scaled
# by 0.75 and shifted up by 0.28. Factors are combined in the same order as a
# single left-to-right product so the floating-point result is unchanged.
_wave = -0.85 * np.cos(2 * xvals) * xvals
_envelope = np.exp(-(0.6 * xvals - 0.4) ** 2)
yvals = 0.28 + (_wave * _envelope) * 0.75
# NOTE(review): by hand evaluation the samples span roughly 0.007 .. 1.03, so
# the peak near x ~ pi/2 slightly exceeds the (0, 1) range of the model's
# Sigmoid output -- confirm this is intended.

# Training tensors: float32 column vectors of shape (N, 1), i.e. one sample
# per row with a single feature, matching nn.Linear(1, ...) in the model.
x_train = torch.tensor(xvals, dtype=torch.float32).unsqueeze(1)
y_train = torch.tensor(yvals, dtype=torch.float32).unsqueeze(1)

# Draw the target curve and show it without blocking, so the script can go
# on to train while the window stays open.
plt.plot(xvals, yvals, 'b')
plt.show(block=False)
plt.pause(0.0001)
 

class OneHiddenLayerModel(nn.Module):
    """Fully connected network with one ReLU hidden layer and a Sigmoid output.

    Computes ``sigmoid(Linear2(relu(Linear1(x))))``; the final Sigmoid
    squashes every output into (0, 1).

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
        hidden_dim: number of hidden units.
    """

    def __init__(self, in_dim, out_dim, hidden_dim):
        super().__init__()
        # Stage 1: affine map into the hidden layer followed by ReLU. Kept as a
        # named attribute because plot_hiddens() reads it directly.
        self.in_to_hidden = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
        # Stage 2: affine map to the outputs, squashed by a Sigmoid.
        self.hidden_to_out = nn.Sequential(nn.Linear(hidden_dim, out_dim), nn.Sigmoid())

    def forward(self, x):
        """Map ``x`` with trailing dimension ``in_dim`` to outputs with trailing dimension ``out_dim``."""
        return self.hidden_to_out(self.in_to_hidden(x))

# --- Hyperparameters -------------------------------------------------------
nhiddens = 20       # width of the single hidden layer
learn_rate = 0.2    # SGD step size
momentum = 0.0      # plain SGD: no momentum term
epochs = 50000      # maximum number of full-batch training iterations

# --- Network, optimiser and loss ------------------------------------------
# One input feature -> nhiddens hidden units -> one output.
model = OneHiddenLayerModel(1, 1, nhiddens)
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate, momentum=momentum)
criterion = nn.MSELoss()

def plot_hiddens():
    """Overlay every hidden unit's activation curve on the current axes.

    Runs the model's input-to-hidden stage (``model.in_to_hidden``) on
    ``x_train`` and draws one green line per hidden unit against ``xvals``.
    Reads the module-level ``model``, ``x_train`` and ``xvals``; returns None.
    """
    # Plotting needs no gradients. Evaluating under no_grad avoids building a
    # throwaway autograd graph on every redraw. It also lets us call .numpy()
    # directly, because the result does not require grad.
    with torch.no_grad():
        # x_train is (N, 1), so the result is (N, hidden units): column j holds
        # unit j's post-activation output. Using the module itself (rather than
        # indexing parameters() and re-applying ReLU by hand) keeps this in sync
        # with however the hidden stage is defined.
        activations = model.in_to_hidden(x_train)
    # Given a 2-D y array, plt.plot draws one line per column.
    plt.plot(xvals, activations.numpy(), 'g')

report_every = 100   # epochs between progress prints / plot redraws
loss_target = 0.001  # stop training once the full-batch MSE drops below this

# Full-batch gradient descent over the whole training set each epoch.
for epoch in range(epochs):
    out = model(x_train)
    optimizer.zero_grad()
    loss = criterion(out, y_train)
    loss.backward()
    optimizer.step()

    # `out` and `loss` were computed before optimizer.step(); .item() gives a
    # plain float that prints cleanly (no tensor repr / grad_fn).
    loss_value = loss.item()
    # Test convergence every epoch, not just on report epochs, so training
    # stops as soon as the target is met instead of up to 99 epochs later.
    converged = loss_value < loss_target

    # Report periodically, and also on the epoch where training stops early
    # so the final fit is shown.
    if epoch % report_every == 0 or converged:
        print('Epoch', epoch, ' loss', loss_value)
        plt.cla()
        plt.axis([-3.2, 3.2, -0.1, 1.2])
        plt.plot(xvals, yvals, 'bo')
        # detach() because `out` is still attached to the autograd graph.
        plt.plot(xvals, out.detach().numpy(), 'r')
        plot_hiddens()
        plt.show(block=False)
        plt.pause(0.0001)

    if converged:
        break


