import numpy as np
from scipy.optimize import leastsq
from Visualizer import TwoLinkArm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from matplotlib import pyplot as plt


# Link lengths of the two-link arm; main() hands them to TwoLinkArm(l1, l2).
# NOTE(review): units are not stated anywhere in this file -- presumably
# meters; confirm against the Visualizer module.
l1 = .1
l2 = .11	


def residual(thetas, target, func):
	"""Return the element-wise absolute error between func(thetas) and target.

	Input - thetas: argument forwarded unchanged to func (the joint angles
	when called from compute_inverse_kinematics).
	Input - target: value the prediction func(thetas) is compared against.
	Input - func (callable): maps thetas to something that target can be
	subtracted from.
	Output - np.abs(func(thetas) - target), i.e. one non-negative entry per
	component of the prediction.
	"""
	return np.abs(func(thetas) - target)



def Forward_Kinematics(thetas, link_lengths=None):
	"""Forward kinematics of the planar two-link arm.

	Input - thetas (numpy array shape (2), or (..., 2) for a batch): joint
	angles in radians. thetas[..., 0] is the angle of the first link and
	thetas[..., 1] is the angle of the second link measured relative to the
	first one.
	Input - link_lengths (optional pair of floats): lengths of the first and
	second link. When omitted, the module-level constants l1 and l2 are used.
	Output - point (numpy array shape (3), or (..., 3) for a batch): x, y and
	orientation theta of the end of the arm. theta is the plain sum of the two
	joint angles and is NOT wrapped, so the output stays continuous in thetas
	(which the least-squares solver relies on).
	Raises - ValueError if the last axis of thetas does not have length 2.
	"""
	# NOTE(review): the relative-angle convention is assumed to match the one
	# used by TwoLinkArm.render_pos -- confirm against the Visualizer module.
	thetas = np.asarray(thetas, dtype=float)
	if thetas.shape[-1:] != (2,):
		raise ValueError("thetas must have 2 joint angles on its last axis")

	# Resolve the defaults at call time so later changes to l1/l2 are honored.
	first_len, second_len = (l1, l2) if link_lengths is None else link_lengths

	first_angle = thetas[..., 0]
	# Absolute orientation of the second link, which is also the orientation
	# of the end of the arm.
	heading = first_angle + thetas[..., 1]

	x = first_len*np.cos(first_angle) + second_len*np.cos(heading)
	y = first_len*np.sin(first_angle) + second_len*np.sin(heading)

	point = np.stack((x, y, heading), axis=-1)
	return(point)

def compute_inverse_kinematics(target, forward_kin = Forward_Kinematics, guess=np.array([0,0])):
	"""Solve for joint angles whose forward kinematics best match target.

	Input - target: desired output of forward_kin (an xytheta point when the
	default forward kinematics are used).
	Input - forward_kin (callable): maps joint angles to a point comparable
	with target.
	Input - guess (numpy array shape (2)): starting joint angles handed to
	scipy.optimize.leastsq.
	Output - numpy array with the least-squares joint angles; the first two
	entries are wrapped into the interval [-pi, pi).
	"""
	def objective(candidate):
		# Residual vector minimized (in the least-squares sense) by leastsq.
		return residual(candidate, target, forward_kin)

	# leastsq returns a tuple; element 0 is the solution vector.
	solution = leastsq(objective, guess)[0]

	full_turn = np.pi*2
	for joint in (0, 1):
		# Shift, take the modulus and shift back to land in [-pi, pi).
		solution[joint] = (solution[joint]+np.pi)%full_turn-np.pi
	return(solution)

def generate_data(N, forward_kin=None):
	"""Sample N random joint configurations and their end-of-arm points.

	Input - N (int): number of samples to generate.
	Input - forward_kin (callable, optional): maps one joint-angle sample of
	shape (2) to an xytheta point of shape (3). Defaults to
	Forward_Kinematics.
	Output - theta_data (numpy array shape (N,2)): array of sampled joint
	angles where each row is a joint angle sample, drawn uniformly from
	[-pi, pi).
	Output - point_data (numpy array shape (N,3)): array of xytheta points
	corresponding to the sampled joint angles. Each row is a point.
	"""
	if forward_kin is None:
		# Looked up at call time so the module-level function is always the
		# current one.
		forward_kin = Forward_Kinematics

	# [-pi, pi) is the same interval compute_inverse_kinematics wraps its
	# solutions into, so dataset angles and solver angles are comparable.
	theta_data = np.random.uniform(-np.pi, np.pi, size=(N,2))

	# Evaluate one sample at a time: forward_kin is only documented to accept
	# a single joint-angle vector of shape (2).
	point_data = np.zeros((N,3))
	for row, joint_angles in enumerate(theta_data):
		point_data[row] = forward_kin(joint_angles)

	return(theta_data, point_data)

def K_Nearest_Neighbors(thetas, points, query_point, K=1):
	"""Look up joint angles for query_point from its K nearest dataset points.

	Input - thetas (numpy array shape (N,2)): array of joint angles where each
	row is a joint angle sample.
	Input - points (numpy array shape (N,3)): array of xytheta points
	corresponding to the sampled joint angles. Each row is a point.
	Input - query_point (numpy array shape (3)): query xytheta point.
	Input - K (int): number of nearest points to consider. Values larger than
	N are clamped to N.
	Output - result (numpy array shape (2)): for K=1 the joint angles of the
	dataset point closest to the query point; for K>1 the circular mean of
	the joint angles of the K closest points (in (-pi, pi]).
	Raises - ValueError if the dataset is empty, if thetas and points have a
	different number of rows, or if K < 1.
	"""
	thetas = np.asarray(thetas, dtype=float)
	points = np.asarray(points, dtype=float)
	query_point = np.asarray(query_point, dtype=float)

	if len(points) == 0:
		raise ValueError("K_Nearest_Neighbors needs at least one data point")
	if len(points) != len(thetas):
		raise ValueError("thetas and points must have the same number of rows")
	if K < 1:
		raise ValueError("K must be at least 1")

	# points - query_point allocates a new array, so the caller's data is
	# never modified by the in-place wrap below.
	offsets = points - query_point
	# The third component is an orientation: wrap its difference into
	# [-pi, pi) so that angles a full turn apart count as identical.
	offsets[:, 2] = (offsets[:, 2] + np.pi) % (2*np.pi) - np.pi
	distances = np.linalg.norm(offsets, axis=1)

	# Stable sort keeps the earliest sample first when distances tie.
	nearest = np.argsort(distances, kind="stable")[:min(K, len(points))]
	if nearest.size == 1:
		# Return the stored sample untouched (copied so callers cannot edit
		# the dataset through the result).
		return(thetas[nearest[0]].copy())

	# Average angles on the unit circle; a plain arithmetic mean would be
	# wrong for neighbors on opposite sides of the +/-pi seam.
	neighbors = thetas[nearest]
	result = np.arctan2(np.sin(neighbors).mean(axis=0), np.cos(neighbors).mean(axis=0))
	return(result)


class ArmNet(nn.Module):
	"""Small fully connected network mapping xytheta points to joint angles.

	The default input width is 3 because the network is fed xytheta points
	(ArmData yields (point, thetas) pairs and eval_model is called with a
	3-component target); the default output width is 2, one value per joint.
	"""
	def __init__(self, in_features=3, hidden_features=8, out_features=2):
		# Input - in_features (int): size of each input sample.
		# Input - hidden_features (int): width of the hidden layer.
		# Input - out_features (int): size of each output sample.
		# NOTE: state dicts are only loadable into a network built with the
		# same sizes they were saved with.
		super(ArmNet, self).__init__()
		# Example net with only two layers, when you add more layers you have to initialize them here.
		self.fc1 = nn.Linear(in_features, hidden_features)
		self.fc2 = nn.Linear(hidden_features, out_features)
		# TODO: experiment with more layers

	def forward(self,x):
		# Forward pass of the network, when you add more layers, you have to include them in the
		# forward pass too.
		# Input - x (tensor, last dimension in_features): batch of points.
		# Output - tensor with last dimension out_features.
		x = F.relu(self.fc1(x))
		# No activation on the output layer: this is a regression target.
		x = self.fc2(x)
		# TODO: Experiment with other activation functions or layers

		return(x)


class ArmData(Dataset):
	"""Dataset pairing each xytheta point with its joint angles.

	Indexing returns (point, thetas) -- the point comes first because it is
	the network input and the joint angles are the label.
	"""
	def __init__(self, thetas, points):
		# Both containers are stored as given; row i of one belongs to row i
		# of the other.
		self.points = points
		self.thetas = thetas

	def __len__(self):
		# Number of samples = number of rows in the joint-angle array.
		sample_count = self.thetas.shape[0]
		return sample_count

	def __getitem__(self, idx):
		# Tensor indices are converted to plain Python values before indexing.
		index = idx.tolist() if torch.is_tensor(idx) else idx
		return (self.points[index], self.thetas[index])

def load_model(path = "armnet.pth"):
	"""Load a saved ArmNet state dict and return the model ready for inference.

	Input - path (str): file holding a state dict written by torch.save.
	Output - model (ArmNet): network in eval mode, placed on the GPU when one
	is available and on the CPU otherwise.
	"""
	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	model = ArmNet()
	# map_location remaps the stored tensors onto the current device, so a
	# checkpoint written on a GPU machine still loads on a CPU-only one.
	# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
	state_dict = torch.load(path, map_location=device)
	model.load_state_dict(state_dict)
	model.eval()

	model = model.to(device)
	return(model)

def _mean_joint_error(model, points, thetas):
	# Mean Euclidean distance between the predicted and reference joint angles.
	diff = eval_model(model, points) - thetas
	return(np.mean(np.sqrt(np.sum(diff**2,axis=1))))


def train_network(thetas, points, batch_size=50, lr = 2e-3, epochs = 20, load = False, verbose=True):
	"""Train an ArmNet to map xytheta points to joint angles.

	Input - thetas (numpy array shape (N,2)): joint angles used as labels.
	Input - points (numpy array shape (N,3)): xytheta points used as inputs.
	Input - batch_size (int): mini-batch size of the training loader.
	Input - lr (float): learning rate of the Adam optimizer.
	Input - epochs (int): number of passes over the training data.
	Input - load (bool): start from the weights stored in "armnet.pth"
	instead of a freshly initialized network.
	Input - verbose (bool): print the average mini-batch loss every 10
	mini-batches.
	Output - model (ArmNet): the trained network.
	Side effects: prints the device in use, saves the weights to "armnet.pth"
	and shows a plot of the per-epoch train and test error.
	"""
	train_data = ArmData(thetas, points)
	train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True)
	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	print(device)
	model = ArmNet()
	if load:
		# map_location lets a checkpoint saved on a GPU machine load on a
		# CPU-only one.
		model.load_state_dict(torch.load("armnet.pth", map_location=device))

	model = model.to(device)
	optimizer = optim.Adam(model.parameters(), lr=lr)
	loss_func = nn.MSELoss()

	# Held-out samples used only to monitor generalization.
	test_thetas, test_points = generate_data(2000)

	print_freq = 10
	test_errors = []
	train_errors = []
	for epoch in range(epochs):  # Loop over the dataset multiple times.

		# Training mode matters as soon as mode-dependent layers (dropout,
		# batch norm) are added to ArmNet.
		model.train()
		running_loss = 0.0       # Initialize running loss.
		for i, (inputs, labels) in enumerate(train_loader, 0):

			# Cast to float32 and move the batch to the specified device.
			inputs = inputs.float().to(device)
			labels = labels.float().to(device)

			# Zero the parameter gradients.
			optimizer.zero_grad()

			# Forward step.
			outputs = model(inputs)
			loss = loss_func(outputs, labels)

			# Backward step.
			loss.backward()

			# Optimization step (update the parameters).
			optimizer.step()

			# Print statistics.
			running_loss += loss.item()
			if verbose and i % print_freq == print_freq - 1: # Print every several mini-batches.
				avg_loss = running_loss / print_freq
				print('[epoch: {}, i: {:5d}] avg mini-batch loss: {:.5f}'.format(
					epoch, i, avg_loss))
				running_loss = 0.0

		# Measure both errors at the same moment (end of the epoch) so the
		# two plotted curves line up epoch for epoch.
		model.eval()
		train_errors.append(_mean_joint_error(model, points, thetas))
		test_errors.append(_mean_joint_error(model, test_points, test_thetas))

	print('Finished Training.')
	torch.save(model.state_dict(), "armnet.pth")
	plt.plot(train_errors,c='b',label="Train Error")
	plt.plot(test_errors,c='r',label="Test Error")
	plt.xlabel("Training Epochs")
	plt.ylabel("Error")
	plt.title("Training and Test Error vs Epochs")
	plt.legend()
	plt.show()

	return(model)

def eval_model(model, query_point):
	"""Run model on query_point and return the prediction as a numpy array.

	Input - model (torch.nn.Module): network to evaluate.
	Input - query_point (numpy array or array-like): one input sample or a
	batch of samples (one per row); converted to float32.
	Output - x (numpy array): the output of model for query_point.
	"""
	# Run on whatever device the model lives on instead of guessing from CUDA
	# availability; a model without parameters falls back to the CPU.
	first_param = next(model.parameters(), None)
	device = first_param.device if first_param is not None else torch.device("cpu")

	inputs = torch.as_tensor(np.asarray(query_point), dtype=torch.float32)
	inputs = inputs.to(device)

	# Inference only: skip building the autograd graph.
	with torch.no_grad():
		outputs = model(inputs)

	# .cpu() is required before .numpy() when the model runs on a GPU.
	x = outputs.cpu().numpy()
	return(x)


def main():
	"""Run the three-part demo: least-squares IK, K-nearest-neighbors, neural net.

	Each part prints its joint-angle estimate and renders it with the
	TwoLinkArm visualizer; parts 2 and 3 also print how far their end point
	is from the end point of the least-squares solution.
	"""
	# PART 1: inverse kinematics by least squares on the forward kinematics

	# Fixed target xytheta point
	target = np.array([0.07071068, 0.18071068,np.pi/2])
	print("Desired Target: ", target)

	# Solve for the joint angles
	ik_thetas = compute_inverse_kinematics(target)
	print("Calculated Thetas: ", ik_thetas)

	# Visualizer shared by all three parts
	vis = TwoLinkArm(l1,l2)

	def show(joint_angles):
		# Open a fresh plot and draw the arm together with the target.
		vis.make_plot()
		vis.render_pos(joint_angles, target)

	show(ik_thetas)

	# exit() #Comment this out when ready to move on

	# PART 2: K-Nearest Neighbors

	# Sample the lookup dataset
	theta_data, point_data = generate_data(100)

	# Inverse kinematics by nearest-neighbor lookup
	knn_thetas = K_Nearest_Neighbors(theta_data, point_data, target)
	print("KNN Thetas: ", knn_thetas)
	knn_point = Forward_Kinematics(knn_thetas)
	ik_point = Forward_Kinematics(ik_thetas)
	print("KNN Error: ", np.linalg.norm(knn_point - ik_point))

	show(knn_thetas)

	# exit() #Comment this out when ready to move on

	# PART 3: Neural Networks

	# Train a model on the sampled data
	model = train_network(theta_data, point_data, batch_size=50, lr = 2e-3, epochs = 20, verbose=True)
	# Reload the weights that training just saved.
	# Once you are happy with your model and just want to evaluate it, you can comment out the
	# training line above and only load the previously trained model.
	model = load_model()

	# Inverse kinematics with the trained network
	net_thetas = eval_model(model,target)
	print("Neural Network thetas: ", net_thetas)
	net_point = Forward_Kinematics(net_thetas)
	ik_point = Forward_Kinematics(ik_thetas)
	print("Neural Network Error: ", np.linalg.norm(net_point - ik_point))

	show(net_thetas)







# Run the demo only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
	main()
	


