#ifndef MPC_SOLVER_H
#define MPC_SOLVER_H

#ifdef MATLAB
#include "mex.h"
#endif

#include "mpc_model.h"
#include "config_file.h"
#include <cassert>
#include <cstdio>
#include <iostream>
#include "eiquadprog.hpp"
#include <Eigen/Core>
#include <Eigen/Cholesky>

using namespace Eigen;

template<class MODEL>
class MPC_Solver{

public:

	MPC_Solver();
	MPC_Solver(ConfigurationData *cd);
	MPC_Solver(char *cdfile);


	void precompute(double t0, VectorXd &x0);
	void precompute(mpcInput *in);
	void precompute();
	void update(double t0, VectorXd &x0);
	void update(mpcInput *in);
	void update();
	double solve(double t0, VectorXd &x0);
	double solve(mpcInput *in);
	double solve();

	void set_params(ConfigurationData *cd);
	void set_params(char *cdfile);
	
	void set_input(double t0, VectorXd &x0);
	void set_input(mpcInput *in);

	mpcOutput get_output();

	double total_cost();
	
	void print_trajectory(bool);
	
	VectorXd get_output_trajectory(bool);
	void print_output_trajectory(bool);

	VectorXd getU(){ return model->U;};
	VectorXd getX(){ return model->X;};
	VectorXd getY(){ return model->Y;};

	mpcModel *get_model_ptr(){return model;};
	//MPC_Solver<MODEL> operator=(MPC_Solver<MODEL>);
	
private:

	bool is_model_set;
	bool is_precomputed;
	bool is_updated;
	bool is_solved;

	mpcModel *model;
	MODEL the_real_model;

	LLT<MatrixXd,Lower> chol;
	double Htrace;
			
	double cost;
	
};

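// Disabled: as written, `*this = s` inside operator= invokes this same
// operator again (infinite recursion). A working copy must copy the members
// and then re-point `model` at this->the_real_model.
//template<class MODEL>
//MPC_Solver<MODEL> MPC_Solver<MODEL>::operator=(MPC_Solver<MODEL> s){
//	*this = s;
//	this->model = &this->the_real_model;
//	return *this;
//}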


/**
 * Default constructor: binds `model` to the embedded MODEL instance and
 * clears the precompute/update/solve flags.  No parameters are loaded here.
 * Htrace and cost are zero-initialised so they never hold indeterminate values.
 */
template<class MODEL>
MPC_Solver<MODEL>::MPC_Solver()
	: is_model_set(true),
	  is_precomputed(false),
	  is_updated(false),
	  is_solved(false),
	  model(&the_real_model),
	  Htrace(0.0),
	  cost(0.0)
{
}

/**
 * Construct and immediately load parameters from an already-parsed
 * configuration.
 *
 * @param cd configuration forwarded to set_params(); ownership stays with
 *           the caller (this constructor does not free it).
 */
template<class MODEL>
MPC_Solver<MODEL>::MPC_Solver(ConfigurationData *cd)
	: is_model_set(true),
	  is_precomputed(false),
	  is_updated(false),
	  is_solved(false),
	  model(&the_real_model),
	  Htrace(0.0),
	  cost(0.0)
{
	set_params(cd);
}

/**
 * Construct and load parameters from a configuration file.
 *
 * @param cdfile path passed to set_params(char*), which loads, applies and
 *               frees the configuration.
 */
template<class MODEL>
MPC_Solver<MODEL>::MPC_Solver(char *cdfile)
	: is_model_set(true),
	  is_precomputed(false),
	  is_updated(false),
	  is_solved(false),
	  model(&the_real_model),
	  Htrace(0.0),
	  cost(0.0)
{
	set_params(cdfile);
}

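// Disabled: `MPC_Solver();` inside a constructor body only constructs and
// discards a temporary; it does not initialise *this. set_model() is also
// not declared in the class.
//template<class MODEL>
//MPC_Solver<MODEL>::MPC_Solver(MODEL *m){
//	MPC_Solver();
//	set_model(m);
//}
//
//template<class MODEL>
//void MPC_Solver<MODEL>::set_model(MODEL *m){
//	model = m;
//	is_model_set=true;
//}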

/**
 * Apply a parsed configuration to the embedded model.
 *
 * New parameters make every cached product stale: the precomputed dynamics,
 * cost and constraints, and the Hessian factorization in `chol`.  The state
 * flags are therefore reset so the next update()/solve() re-runs
 * precompute() instead of silently reusing the old matrices.
 *
 * @param cd configuration to apply; not freed here.
 */
template<class MODEL>
void MPC_Solver<MODEL>::set_params(ConfigurationData *cd){
	the_real_model.set_params(cd);
	model = &the_real_model;

	is_precomputed = false;
	is_updated = false;
	is_solved = false;
}

/**
 * Load a configuration file, apply it via set_params(ConfigurationData*),
 * then release it.
 *
 * NOTE(review): the configuration is freed right after it is applied, so the
 * model presumably must not retain a pointer into it — confirm in MODEL.
 */
template<class MODEL>
void MPC_Solver<MODEL>::set_params(char *cdfile){
	ConfigurationData *conf = load_conf_data(cdfile);
	this->set_params(conf);
	free_conf_data(conf);
}
	

/**
 * Precompute for an explicit start time and state.
 *
 * @param t0 start time
 * @param x0 start state
 */
template<class MODEL>
void MPC_Solver<MODEL>::precompute(double t0, VectorXd &x0){
	assert(is_model_set);

	// Use a named object: `&mpcInput(t0,x0)` takes the address of a
	// temporary, which standard C++ rejects.  `in` outlives the call, just as
	// the temporary lived until the end of the full expression.
	mpcInput in(t0, x0);
	precompute(&in);
}

/**
 * Store `in` as the model's MPC input, then run the full precompute().
 *
 * @param in input forwarded to model->set_mpc_input().
 */
template<class MODEL>
void MPC_Solver<MODEL>::precompute(mpcInput *in){
	assert(is_model_set);
	model->set_mpc_input(in);
	this->precompute();
}

/**
 * Run the model's precomputation of dynamics, cost and constraints.
 *
 * When the model declares a constant Hessian, model->PreLuu is Cholesky-
 * factorized once into `chol` and its trace is cached in `Htrace`; both are
 * reused by every later solve.  Resets the updated/solved flags.
 */
template<class MODEL>
void MPC_Solver<MODEL>::precompute(){
	assert(is_model_set);
		
	model->precompute_dynamics();
	model->precompute_cost();
	model->precompute_constraints();

	if(model->is_constant_hessian){
		printf("MPC Precomputing Hessian Decomposition\n");
		chol.compute(model->PreLuu);
		// LLT reports failure through info() instead of throwing.  Without
		// this check a non-positive-definite Hessian would make every later
		// solve silently use a garbage factorization.
		if(chol.info() != Success){
			printf("Warning: MPC Hessian Cholesky decomposition failed (matrix not positive definite?)\n");
		}
		Htrace = model->PreLuu.trace();
	}	

	is_updated = false;
	is_precomputed=true;
	is_solved = false;
	std::cout << "MPC Precomputed" << std::endl;
}


/**
 * Refresh cost and constraints for start (t0, x0), precomputing first if
 * that has not been done.  Marks the problem updated but not yet solved.
 */
template<class MODEL>
void MPC_Solver<MODEL>::update(double t0, VectorXd &x0){
	if(!is_precomputed){
		precompute(t0, x0);
	}

	model->update_cost(t0, x0);
	model->update_constraints(t0, x0);

	is_updated = true;
	is_solved = false;

	std::cout << "MPC Updated t0 = " << t0
	          << " x0 = " << x0.transpose()
	          << std::endl;
}

/**
 * Set `in` as the model's input, then refresh cost and constraints via update().
 * On first use the model is precomputed with the same input.
 */
template<class MODEL>
void MPC_Solver<MODEL>::update(mpcInput *in){
	if(!is_precomputed){
		precompute(in);
	}
	model->set_mpc_input(in);
	this->update();
}

/**
 * Refresh cost and constraints from the input already stored in the model.
 * Precomputes first if needed; leaves the problem updated but unsolved.
 */
template<class MODEL>
void MPC_Solver<MODEL>::update(){
	if(!is_precomputed){
		this->precompute();
	}

	model->update_cost();
	model->update_constraints();

	is_updated = true;
	is_solved = false;

	std::cout << "MPC Updated" << std::endl;
}

/**
 * Solve the MPC problem starting at (t0, x0).
 *
 * The QP is re-assembled (update(t0, x0)) only when nothing is cached or the
 * start time/state differs from the model's current start.  Like solve(), an
 * unconstrained model is solved in closed form; otherwise the QP goes to
 * eiquadprog.  The optimal inputs are written to model->U.
 *
 * @return objective value at the solution (also stored in `cost`)
 */
template<class MODEL>
double MPC_Solver<MODEL>::solve(double t0, VectorXd &x0){
	assert(is_model_set);

	// The size test precedes `!=` so Eigen never compares vectors of
	// different length (that asserts in debug builds).
	if(!is_updated
	   || t0 != model->start_time()
	   || x0.size() != model->start_state().size()
	   || x0 != model->start_state()){
		update(t0, x0);
	}

	if(model->is_unconstrained){
		// Minimiser of 0.5*U'HU + f'U (eiquadprog's objective form) is -H^{-1} f.
		if(model->is_constant_hessian){
			model->U = -chol.solve(model->f);
		}else{
			model->U = -model->H.llt().solve(model->f);
		}
		// At U = -H^{-1} f the objective equals 0.5*f'U.
		cost = 0.5 * model->f.dot(model->U);
	}else if(model->is_constant_hessian){
		cost = solve_quadprog2(chol, Htrace, model->f, model->CE, model->ce0, model->CI, model->ci0, model->U);
	}else{
		cost = solve_quadprog(model->H, model->f, model->CE, model->ce0, model->CI, model->ci0, model->U);
	}

	is_solved = true;

	return cost;
}

/**
 * Store `in` as the model's input and solve.
 *
 * @return the value from solve(), or 0 (with an error message) if no model is set.
 */
template<class MODEL>
double MPC_Solver<MODEL>::solve(mpcInput *in){
	if(is_model_set){
		model->set_mpc_input(in);
		return this->solve();
	}

	printf("Error: MPC Model Not Set\n");
	return 0;
}


/**
 * Solve the MPC problem for the input already stored in the model.
 *
 * Always calls update() first, which precomputes on first use.  It then
 * solves the unconstrained problem in closed form, or hands the QP to
 * eiquadprog.  The optimal inputs are written to model->U.
 *
 * @return objective value 0.5*U'HU + f'U at the solution (also stored in
 *         `cost`), or 0 with an error message if no model is set.
 */
template<class MODEL>
double MPC_Solver<MODEL>::solve(){
	if(!is_model_set){
		printf("Error: MPC Model Not Set\n");
		return 0;
	}

	//if(!is_updated)	update();
	update();

	if(model->is_unconstrained){
		// Minimiser of 0.5*U'HU + f'U is U = -H^{-1} f.
		if(model->is_constant_hessian){
			model->U = -chol.solve(model->f);
		}else{
			model->U = -model->H.llt().solve(model->f);
		}
		// `cost` must be set on this path too; otherwise a stale (or never
		// initialised) value is printed and returned.  At U = -H^{-1} f the
		// objective equals 0.5*f'U.
		cost = 0.5 * model->f.dot(model->U);
	}else {
		if(model->is_constant_hessian){
			cost = solve_quadprog2(chol, Htrace, model->f, model->CE, model->ce0, model->CI, model->ci0, model->U);
		}else{
			cost = solve_quadprog(model->H, model->f, model->CE, model->ce0, model->CI, model->ci0, model->U);
		}
	}
	
	std::cout << "MPC Sovled - Value = " << cost << std::endl;
	
	is_solved = true;

	return cost;
}

/**
 * Ask the model for the total cost of the current start and the last solved
 * inputs U.  Requires a prior solve (asserted).
 */
template<class MODEL>
double MPC_Solver<MODEL>::total_cost(){
	assert(is_solved);
	mpcModel *m = model;
	return m->total_cost(m->start_time(), m->start_state(), m->U);
}

/**
 * Print the planned trajectory to stdout.
 *
 * Each "Continuous" row is: time T(step), every state X, then every
 * continuous action U for that step.  Discrete actions are stored in U after
 * all continuous ones and are printed on one "Discrete" line.
 *
 * @param run_sim when true, first re-simulate X from the current start and U.
 */
template<class MODEL>
void MPC_Solver<MODEL>::print_trajectory(bool run_sim){
	assert(is_model_set);

	if(run_sim){
		model->X = model->simulate(model->start_time(), model->start_state(), model->U);
	}

	printf("Continuous:\n");
	for(int step = 0; step < model->N_TIMESTEPS; ++step){
		printf("%4.2f ", model->T(step));
		for(int s = 0; s < model->num_states; ++s){
			printf("% 4.2f ", model->X(step*model->num_states + s));
		}
		for(int a = 0; a < model->num_actions_continuous; ++a){
			printf("% 4.2f ", model->U(step*model->num_actions_continuous + a));
		}
		printf("\n");
	}

	if(model->num_actions_discrete <= 0){
		return;
	}
	printf("Discrete:\n");
	for(int d = 0; d < model->num_actions_discrete; ++d){
		printf("% f ", model->U(model->N_TIMESTEPS*model->num_actions_continuous + d));
	}
	printf("\n");
}

template<class MODEL>
VectorXd MPC_Solver<MODEL>::get_output_trajectory(bool run_sim){
	// Returns a copy of the model's output trajectory Y.
	// run_sim: when true, first run model->simulate() from the current start and U.
	// NOTE(review): unlike print_trajectory(), simulate()'s return value is
	// discarded here; this presumes simulate() refreshes model->Y itself — confirm.
	assert(is_model_set);

	if(run_sim)		model->simulate(model->start_time(),model->start_state(),model->U);
	
	return model->Y;
}


/**
 * Print the output trajectory: one row per time step holding T(step)
 * followed by every output value Y for that step.
 *
 * @param run_sim when true, first run model->simulate() from the current start and U.
 */
template<class MODEL>
void MPC_Solver<MODEL>::print_output_trajectory(bool run_sim){
	assert(is_model_set);

	if(run_sim){
		model->simulate(model->start_time(), model->start_state(), model->U);
	}

	printf("Output:\n");
	for(int step = 0; step < model->N_TIMESTEPS; ++step){
		printf("%4.2f ", model->T(step));
		for(int o = 0; o < model->num_outputs; ++o){
			printf("% 4.2f ", model->Y(step*model->num_outputs + o));
		}
		printf("\n");
	}
}

/**
 * Pass a new start time/state to the model and mark the QP as needing an update.
 */
template<class MODEL>
void MPC_Solver<MODEL>::set_input(double t0, VectorXd &x0){
	assert(is_model_set);
	this->model->set_mpc_input(t0, x0);
	this->is_updated = false;
}

/**
 * Pass an mpcInput to the model and mark the QP as needing an update.
 */
template<class MODEL>
void MPC_Solver<MODEL>::set_input(mpcInput *in){
	assert(is_model_set);
	this->model->set_mpc_input(in);
	this->is_updated = false;
}

/**
 * Return the model's current mpcOutput (by value).
 */
template<class MODEL>
mpcOutput MPC_Solver<MODEL>::get_output(){
	assert(is_model_set);
	mpcOutput out = model->get_output();
	return out;
}

#endif