import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Synthetic regression data: N points with x uniform on [0, xmax) and
# y = 10*sin(x) + 20 corrupted by standard-normal noise.
N = 50
xmax = 1.7
xvals = xmax * np.random.random(N)
noise = np.random.randn(N)
yvals = 10 * np.sin(xvals) + noise + 20

# Scatter the raw samples as blue dots; show without blocking so the
# script continues on to training.
plt.plot(xvals, yvals, 'bo')
plt.show(block=False)

# Convert to float32 column tensors of shape (N, 1): one sample per row,
# one feature per column.
x_train, y_train = (
    torch.tensor(arr, dtype=torch.float32).unsqueeze(1) for arr in (xvals, yvals)
)


class LinearModel(nn.Module):
    """Affine model ``y = x @ W.T + b`` wrapping a single ``nn.Linear`` layer.

    Args:
        in_dim: Number of input features (size of the last dimension of ``x``).
        out_dim: Number of output features.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Sole registered submodule, so ``parameters()`` yields its weight
        # (shape (out_dim, in_dim)) followed by its bias (shape (out_dim,)).
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x``; the last dimension must be ``in_dim``."""
        return self.linear(x)

model = LinearModel(1,1)

learn_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
criterion = nn.MSELoss()

epochs = 12

# Full-batch gradient descent. After each step, overlay the current fitted
# line y = m*x + b on the scatter plot: dashed red for intermediate epochs,
# thick green for the final one.
for epoch in range(epochs):
    out = model(x_train)
    optimizer.zero_grad()
    loss = criterion(out, y_train)
    loss.backward()
    optimizer.step()
    # .item() yields a plain float; printing the tensor itself shows
    # "tensor(..., grad_fn=<...>)" noise. This is the loss *before* this
    # epoch's step, whereas the line plotted below reflects the step.
    print('Epoch', epoch, ' loss', loss.item())
    # Read slope/intercept straight from the layer rather than relying on
    # the positional order of model.parameters() and nested list indexing.
    m = model.linear.weight.item()
    b = model.linear.bias.item()
    is_last = epoch == epochs - 1
    color = 'g' if is_last else 'r--'
    linewidth = 4.0 if is_last else 0.5
    plt.plot([0, xmax], [b, m * xmax + b], color, linewidth=linewidth)
    # plt.pause shows the figure AND briefly runs the GUI event loop, so the
    # new line is actually drawn before input() blocks; show(block=False)
    # alone does not guarantee a redraw.
    plt.pause(0.001)
    if not is_last:
        input('Hit <Enter> to continue...')

# Without a final blocking show, the script exits right after the last epoch
# and the window holding the final (green) fit closes immediately.
plt.show()
