/****************************************************************************
This program develops a controller for a one step optimization problem.
min L(x,u) with one step dynamics y = F(x,u)
approx: 
L(x,u) = L0 + Lx*dx + Lu*du + 0.5*dx'*Lxx*dx + du'*Lux*dx + 0.5*du'*Luu*du 
F(x,u) = F0 + Fx*dx + Fu*du + 0.5*dx'*Fxx*dx + du'*Fux*dx + 0.5*du'*Fuu*du 

 ****************************************************************************
Fixes:
Z -> Q
A,B -> Fx, Fu.
n_states, n_controls -> n_x, n_u, n_y
symmetrize Fxx, Fuu, ...

/****************************************************************************/
/* INCLUDES */

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "../useful/useful.h"
#include "../cholesky/ldlt.h"
#include "../dm/dm.h"
#include "ddp-one-step.h"

/****************************************************************************/
/* DEFINES */

/****************************************************************************/
/****************************************************************************/

DDP1 *ddp1_initialize( int nx, /* Dimensionality of input vector */
		       int nu, /* Dimensionality of controls */
		       int ny, /* Dimensionality of output vector */
		       int (*one_step_dynamics)(),
		       double (*cost_function)(),
		       void *argument /* argument to dynamics and cost functions */
		       )
{
  int i;
  DDP1 *r;

  r = (DDP1 *) malloc( sizeof( DDP1 ) );
  if ( r == NULL )
    {
      fprintf( stderr, "allocation failure 1 in DDP1()");
      exit( -1 );
    }

  r->n_x = nx;
  r->n_u = nu;
  r->n_y = ny;

  r->one_step_dynamics_function = one_step_dynamics;
  r->cost_function = cost_function;
  r->argument = argument;

  printf( "Initializing DDP1: %d %d %d 0x%x 0x%x 0x%x\n",
	  r->n_x, r->n_u, r->n_y,
	  r->one_step_dynamics_function, 
	  r->cost_function,
	  r->argument ); 

  r->terminal_cost_function = NULL;
  r->compute_linear_model_function = NULL;
  r->analytic_cost_derivatives = NULL;

  r->x = dv( r->n_x );
  dv_zero( r->x, r->n_x );
  r->u = dv( r->n_u );
  dv_zero( r->u, r->n_u );
  r->y = dv( r->n_y );
  dv_zero( r->y, r->n_y );

  r->delta_u = dv( r->n_u );
  dv_zero( r->delta_u, r->n_u );

  r->use_u_limits = FALSE;
  r->u_min = dv( r->n_u );
  r->u_max = dv( r->n_u );

  r->Fx = dm( r->n_y, r->n_x );
  dm_zero( r->Fx, r->n_y, r->n_x );
  r->Fu = dm( r->n_y, r->n_u );
  dm_zero( r->Fu, r->n_y, r->n_u );

  r->Fxx = d3( r->n_y, r->n_x, r->n_x );
  d3_zero( r->Fxx, r->n_y, r->n_x, r->n_x );
  r->Fux = d3( r->n_y, r->n_u, r->n_x );
  d3_zero( r->Fux, r->n_y, r->n_u, r->n_x );
  r->Fxu = d3( r->n_y, r->n_x, r->n_u );
  d3_zero( r->Fxu, r->n_y, r->n_x, r->n_u );
  r->Fuu = d3( r->n_y, r->n_u, r->n_u );
  d3_zero( r->Fuu, r->n_y, r->n_u, r->n_u );
  
  r->Lx = dv( r->n_x );
  dv_zero( r->Lx, r->n_x );
  r->Lu = dv( r->n_u );
  dv_zero( r->Lu, r->n_u );

  r->Lxx = dm( r->n_x, r->n_x );
  dm_zero( r->Lxx, r->n_x, r->n_x );
  r->Lxu = dm( r->n_x, r->n_u );
  dm_zero( r->Lxu, r->n_x, r->n_u );
  r->Lux = dm( r->n_u, r->n_x );
  dm_zero( r->Lux, r->n_u, r->n_x );
  r->Luu = dm( r->n_u, r->n_u );
  dm_zero( r->Luu, r->n_u, r->n_u );
 
  r->Ty = dv( r->n_y );
  dv_zero( r->Ty, r->n_y );
  r->Tyy = dm( r->n_y, r->n_y );
  dm_zero( r->Tyy, r->n_y, r->n_y );

  r->K = dm( r->n_u, r->n_x );
  dm_zero( r->K, r->n_u, r->n_x );

  r->Vx = dv( r->n_x );
  dv_zero( r->Vx, r->n_x );
  r->Vxx = dm( r->n_x, r->n_x );
  dm_zero( r->Vxx, r->n_x, r->n_x );

  r->Vy = dv( r->n_y );
  dv_zero( r->Vy, r->n_y );
  r->Vyy = dm( r->n_y, r->n_y );
  dm_zero( r->Vyy, r->n_y, r->n_y );
  r->Vy_terminal = dv( r->n_y );
  dv_zero( r->Vy_terminal, r->n_y );
  r->Vyy_terminal = dm( r->n_y, r->n_y );
  dm_zero( r->Vyy_terminal, r->n_y, r->n_y );
  r->y_terminal = dv( r->n_y );
  dv_zero( r->y_terminal, r->n_y );

  r->Qx = dv( r->n_x );
  dv_zero( r->Qx, r->n_x );
  r->Qu = dv( r->n_u );
  dv_zero( r->Qu, r->n_u );
  r->Qxx = dm( r->n_x, r->n_x );
  dm_zero( r->Qxx, r->n_x, r->n_x );
  r->Qxu = dm( r->n_x, r->n_u );
  dm_zero( r->Qxu, r->n_x, r->n_u );
  r->Qux = dm( r->n_u, r->n_x );
  dm_zero( r->Qux, r->n_u, r->n_x );
  r->Quu = dm( r->n_u, r->n_u );
  dm_zero( r->Quu, r->n_u, r->n_u );
  r->Quu_L = dm( r->n_u, r->n_u );
  dm_zero( r->Quu_L, r->n_u, r->n_u );
  r->Quu_D = dv( r->n_u );
  dv_zero( r->Quu_D, r->n_u );

  r->tempx1 = dv( r->n_x );
  dv_zero( r->tempx1, r->n_x );
  r->tempx2 = dv( r->n_x );
  dv_zero( r->tempx2, r->n_x );
  r->tempx3 = dv( r->n_x );
  dv_zero( r->tempx3, r->n_x );
  r->tempu1 = dv( r->n_u );
  dv_zero( r->tempu1, r->n_u );
  r->tempu2 = dv( r->n_u );
  dv_zero( r->tempu2, r->n_u );
  r->tempu3 = dv( r->n_u );
  dv_zero( r->tempu3, r->n_u );
  r->tempy1 = dv( r->n_y );
  dv_zero( r->tempy1, r->n_y );
  r->tempy2 = dv( r->n_y );
  dv_zero( r->tempy2, r->n_y );

  r->tempxx1 = dm( r->n_x, r->n_x );
  dm_zero( r->tempxx1, r->n_x, r->n_x );
  r->tempxx2 = dm( r->n_x, r->n_x );
  dm_zero( r->tempxx2, r->n_x, r->n_x );
  r->tempxx3 = dm( r->n_x, r->n_x );
  dm_zero( r->tempxx3, r->n_x, r->n_x );
  r->tempxu1 = dm( r->n_x, r->n_u );
  dm_zero( r->tempxu1, r->n_x, r->n_u );
  r->tempxu2 = dm( r->n_x, r->n_u );
  dm_zero( r->tempxu2, r->n_x, r->n_u );
  r->tempuu1 = dm( r->n_u, r->n_u );
  dm_zero( r->tempuu1, r->n_u, r->n_u );
  r->tempux1 = dm( r->n_u, r->n_x );
  dm_zero( r->tempux1, r->n_u, r->n_x );
  r->tempux2 = dm( r->n_u, r->n_x );
  dm_zero( r->tempux2, r->n_u, r->n_x );
  r->tempyx1 = dm( r->n_y, r->n_x );
  dm_zero( r->tempyx1, r->n_y, r->n_x );
  r->tempyx2 = dm( r->n_y, r->n_x );
  dm_zero( r->tempyx2, r->n_y, r->n_x );
  r->tempxy1 = dm( r->n_x, r->n_y );
  dm_zero( r->tempxy1, r->n_x, r->n_y );
  r->tempyu1 = dm( r->n_y, r->n_u );
  dm_zero( r->tempyu1, r->n_y, r->n_u );
  r->tempyu2 = dm( r->n_y, r->n_u );
  dm_zero( r->tempyu2, r->n_y, r->n_u );
  r->tempuy1 = dm( r->n_u, r->n_y );
  dm_zero( r->tempuy1, r->n_u, r->n_y );
  r->tempyy1 = dm( r->n_y, r->n_y );
  dm_zero( r->tempyy1, r->n_y, r->n_y );

  r->epsilon_gradient = 1e-3;
  r->epsilon_second_order = 1.0;
  r->delta = 1e-4;

  // printf( "Initialization done.\n" );

  return r;
}

/****************************************************************************/

/*
 * First derivatives of the one-step dynamics at (x, u).
 *
 * Fx receives an n_y by n_x matrix and Fu an n_y by n_u matrix.  When a
 * compute_linear_model_function has been installed it is called and its
 * return value is passed straight back.  Otherwise each entry is a central
 * difference with step ddp1->delta: x and u are perturbed in place one
 * element at a time and restored afterwards, and ddp1->tempy1 / tempy2 are
 * overwritten.  The numerical path returns 0.
 */
int ddp1_compute_dynamics_first_derivatives( DDP1 *ddp1,
					double *x,
					double *u,
					/* return values */
					double **Fx,
					double **Fu )
{
  const double step = ddp1->delta;
  const double span = 2*step;
  double *y_plus = ddp1->tempy2;
  double *y_minus = ddp1->tempy1;
  double nominal;
  int col, row;

  /* A user-supplied linear model takes precedence over differencing. */
  if ( ddp1->compute_linear_model_function != NULL )
    return ddp1->compute_linear_model_function( ddp1->argument,
						x, u, Fx, Fu );

  /* Column col of Fx: sensitivity of y to x[col]. */
  for ( col = 0; col < ddp1->n_x; col++ )
    {
      nominal = x[col];

      x[col] = nominal + step;
      ddp1->one_step_dynamics_function( ddp1->argument, x, u, y_plus );

      x[col] = nominal - step;
      ddp1->one_step_dynamics_function( ddp1->argument, x, u, y_minus );

      x[col] = nominal;

      for ( row = 0; row < ddp1->n_y; row++ )
	Fx[row][col] = (y_plus[row] - y_minus[row])/span;
    }

  /* Column col of Fu: sensitivity of y to u[col]. */
  for ( col = 0; col < ddp1->n_u; col++ )
    {
      nominal = u[col];

      u[col] = nominal + step;
      ddp1->one_step_dynamics_function( ddp1->argument, x, u, y_plus );

      u[col] = nominal - step;
      ddp1->one_step_dynamics_function( ddp1->argument, x, u, y_minus );

      u[col] = nominal;

      for ( row = 0; row < ddp1->n_y; row++ )
	Fu[row][col] = (y_plus[row] - y_minus[row])/span;
    }

  return 0;
}

/****************************************************************************/

/*
 * Second derivatives of the one-step dynamics at (x, u), obtained by central
 * differences (step z->delta) of ddp1_compute_dynamics_first_derivatives().
 *
 * For every output j:
 *   Fxx[j] is n_x by n_x, Fux[j] is n_u by n_x,
 *   Fxu[j] is n_x by n_u, Fuu[j] is n_u by n_u.
 * The last index is the variable that was perturbed.  Fxx[j] and Fuu[j] are
 * replaced by the average of themselves and their transpose.  Averaging Fxu
 * against Fux' is present but compiled out (#ifdef COMMENT).
 *
 * x and u are perturbed in place and restored.  Scratch used: tempyx1,
 * tempyx2, tempyu1, tempyu2, tempxx2, tempuu1.  Always returns 0.
 */
int ddp1_compute_dynamics_second_derivatives( DDP1 *z, double *x, double *u,
					/* return values */
					 double ***Fxx, double ***Fxu,
					 double ***Fux, double ***Fuu )
{
  const double span = 2*z->delta;
  double held;
  int wrt, out, in;

  /* Differentiate Fx and Fu with respect to each state. */
  for ( wrt = 0; wrt < z->n_x; wrt++ )
    {
      held = x[wrt];

      x[wrt] = held + z->delta;
      ddp1_compute_dynamics_first_derivatives( z, x, u, z->tempyx1, z->tempyu1 );

      x[wrt] = held - z->delta;
      ddp1_compute_dynamics_first_derivatives( z, x, u, z->tempyx2, z->tempyu2 );

      x[wrt] = held;

      for ( out = 0; out < z->n_y; out++ )
	{
	  for ( in = 0; in < z->n_x; in++ )
	    Fxx[out][in][wrt] = (z->tempyx1[out][in] - z->tempyx2[out][in])/span;
	  for ( in = 0; in < z->n_u; in++ )
	    Fux[out][in][wrt] = (z->tempyu1[out][in] - z->tempyu2[out][in])/span;
	}
    }

  /* Differentiate Fx and Fu with respect to each control. */
  for ( wrt = 0; wrt < z->n_u; wrt++ )
    {
      held = u[wrt];

      u[wrt] = held + z->delta;
      ddp1_compute_dynamics_first_derivatives( z, x, u, z->tempyx1, z->tempyu1 );

      u[wrt] = held - z->delta;
      ddp1_compute_dynamics_first_derivatives( z, x, u, z->tempyx2, z->tempyu2 );

      u[wrt] = held;

      for ( out = 0; out < z->n_y; out++ )
	{
	  for ( in = 0; in < z->n_x; in++ )
	    Fxu[out][in][wrt] = (z->tempyx1[out][in] - z->tempyx2[out][in])/span;
	  for ( in = 0; in < z->n_u; in++ )
	    Fuu[out][in][wrt] = (z->tempyu1[out][in] - z->tempyu2[out][in])/span;
	}
    }

  /* Symmetrize each Fxx[out] and Fuu[out]: M <- 0.5*(M + M'). */
  for ( out = 0; out < z->n_y; out++ )
    {
      dm_transpose( Fxx[out], z->tempxx2, z->n_x, z->n_x );
      dm_acc( Fxx[out], z->tempxx2, z->n_x, z->n_x );
      dm_scale( z->tempxx2, 0.5, Fxx[out], z->n_x, z->n_x );

      dm_transpose( Fuu[out], z->tempuu1, z->n_u, z->n_u );
      dm_acc( Fuu[out], z->tempuu1, z->n_u, z->n_u );
      dm_scale( z->tempuu1, 0.5, Fuu[out], z->n_u, z->n_u );
    }

  /* Forcing Fxu[out] = Fux[out]' is currently disabled. */
#ifdef COMMENT
  for ( out = 0; out < z->n_y; out++ )
    {
      for ( in = 0; in < z->n_x; in++ )
	{
	  for ( wrt = 0; wrt < z->n_u; wrt++ )
	    {
	      held = (Fxu[out][in][wrt] + Fux[out][wrt][in])/2;
	      Fxu[out][in][wrt] = Fux[out][wrt][in] = held;
	    }
	}
    }
#endif

  return 0;
}

/****************************************************************************/

/*
 * Gradient of the one-step cost at (x, u): Lx has n_x entries and Lu has
 * n_u entries.
 *
 * If analytic_cost_derivatives is installed it is called and its return
 * value is passed straight back.  Otherwise each entry is a central
 * difference of cost_function with step ddp1->delta; x and u are perturbed
 * in place one element at a time and restored.  The numerical path
 * returns 0.
 */
int ddp1_compute_cost_first_derivatives( DDP1 *ddp1,
				    double *x,
				    double *u,
				    /* return values */
				    double *Lx,
				    double *Lu )
{
  const double step = ddp1->delta;
  const double span = 2*step;
  double cost_plus, cost_minus;
  double nominal;
  int idx;

  /* Prefer user-supplied derivatives when available. */
  if ( ddp1->analytic_cost_derivatives != NULL )
    return ddp1->analytic_cost_derivatives( ddp1->argument,
					    x, u, Lx, Lu );

  /* dL/dx */
  for ( idx = 0; idx < ddp1->n_x; idx++ )
    {
      nominal = x[idx];

      x[idx] = nominal + step;
      cost_plus = ddp1->cost_function( ddp1->argument, x, u );

      x[idx] = nominal - step;
      cost_minus = ddp1->cost_function( ddp1->argument, x, u );

      x[idx] = nominal;
      Lx[idx] = (cost_plus - cost_minus)/span;
    }

  /* dL/du */
  for ( idx = 0; idx < ddp1->n_u; idx++ )
    {
      nominal = u[idx];

      u[idx] = nominal + step;
      cost_plus = ddp1->cost_function( ddp1->argument, x, u );

      u[idx] = nominal - step;
      cost_minus = ddp1->cost_function( ddp1->argument, x, u );

      u[idx] = nominal;
      Lu[idx] = (cost_plus - cost_minus)/span;
    }

  return 0;
}

/****************************************************************************/

/*
 * Hessian blocks of the one-step cost at (x, u), obtained by central
 * differences (step z->delta) of ddp1_compute_cost_first_derivatives().
 *
 *   Lxx is n_x by n_x, Lux is n_u by n_x,
 *   Lxu is n_x by n_u, Luu is n_u by n_u.
 *
 * Lxx and Luu are replaced by the average of themselves and their
 * transpose, and Lxu / Lux are averaged so that Lxu == Lux'.
 *
 * x and u are perturbed in place and restored.  Scratch used: tempx1,
 * tempx2, tempu1, tempu2, tempxx2, tempuu1.  Always returns 0.
 */
int ddp1_compute_cost_second_derivatives( DDP1 *z,
				     double *x,
				     double *u,
				     /* return values */
				     double **Lxx,
				     double **Lxu,
				     double **Lux,
				     double **Luu )
{
  const double span = 2*z->delta;
  double held, mean;
  int wrt, row;

  /* Differentiate (Lx, Lu) with respect to each state. */
  for ( wrt = 0; wrt < z->n_x; wrt++ )
    {
      held = x[wrt];

      x[wrt] = held + z->delta;
      ddp1_compute_cost_first_derivatives( z, x, u, z->tempx1, z->tempu1 );

      x[wrt] = held - z->delta;
      ddp1_compute_cost_first_derivatives( z, x, u, z->tempx2, z->tempu2 );

      x[wrt] = held;

      for ( row = 0; row < z->n_x; row++ )
	Lxx[row][wrt] = (z->tempx1[row] - z->tempx2[row])/span;
      for ( row = 0; row < z->n_u; row++ )
	Lux[row][wrt] = (z->tempu1[row] - z->tempu2[row])/span;
    }

  /* Lxx <- 0.5*(Lxx + Lxx') */
  dm_transpose( Lxx, z->tempxx2, z->n_x, z->n_x );
  dm_acc( Lxx, z->tempxx2, z->n_x, z->n_x );
  dm_scale( z->tempxx2, 0.5, Lxx, z->n_x, z->n_x );

  /* Differentiate (Lx, Lu) with respect to each control. */
  for ( wrt = 0; wrt < z->n_u; wrt++ )
    {
      held = u[wrt];

      u[wrt] = held + z->delta;
      ddp1_compute_cost_first_derivatives( z, x, u, z->tempx1, z->tempu1 );

      u[wrt] = held - z->delta;
      ddp1_compute_cost_first_derivatives( z, x, u, z->tempx2, z->tempu2 );

      u[wrt] = held;

      for ( row = 0; row < z->n_x; row++ )
	Lxu[row][wrt] = (z->tempx1[row] - z->tempx2[row])/span;
      for ( row = 0; row < z->n_u; row++ )
	Luu[row][wrt] = (z->tempu1[row] - z->tempu2[row])/span;
    }

  /* Luu <- 0.5*(Luu + Luu') */
  dm_transpose( Luu, z->tempuu1, z->n_u, z->n_u );
  dm_acc( Luu, z->tempuu1, z->n_u, z->n_u );
  dm_scale( z->tempuu1, 0.5, Luu, z->n_u, z->n_u );

  /* Make the mixed blocks agree: Lxu[r][c] == Lux[c][r]. */
  for ( wrt = 0; wrt < z->n_u; wrt++ )
    {
      for ( row = 0; row < z->n_x; row++ )
	{
	  mean = (Lxu[row][wrt] + Lux[wrt][row])/2;
	  Lux[wrt][row] = mean;
	  Lxu[row][wrt] = mean;
	}
    }

  return 0;
}

/****************************************************************************/
// Vy_terminal, Vyy_terminal, y_terminal, and y need to be set.

/*
 * Seed the value-function model at the current output z->y from the
 * terminal quadratic (after page 71 of Dyer and McReynolds):
 *
 *   V(y) = VY0 + Vy*(y-y_terminal) + 0.5*(y-y_terminal)'*Vyy*(y-y_terminal)
 *
 *   Vyy = Vyy_terminal
 *   Vy  = Vy_terminal + Vyy_terminal * (y - y_terminal)
 *
 * Reads Vy_terminal, Vyy_terminal, y_terminal and y; writes Vy and Vyy;
 * overwrites tempy1.  Always returns 0.
 */
int ddp1_second_order_default_initialization( DDP1 *z )
{
  const int ny = z->n_y;

  /* The curvature is taken over unchanged. */
  dm_copy( z->Vyy_terminal, z->Vyy, ny, ny );

  /* tempy1 = y - y_terminal */
  dv_subtract( z->y, z->y_terminal, z->tempy1, ny );

  /* Vy = Vyy_terminal * tempy1, then add the terminal gradient. */
  dmv_mult( z->Vyy_terminal, z->tempy1, z->Vy, ny, ny );
  dv_acc( z->Vy_terminal, z->Vy, ny );

  return 0;
}

/****************************************************************************/
/*
This is derived from page 71 of Dyer and McReynolds
Qx = Vy * Fx + Lx
Qu = Vy * Fu + Lu
This is from page 69 of Dyer and McReynolds
Qxx = Fx' * Vyy * Fx + Vy * Fxx + Lxx
Qux = Fu' * Vyy * Fx + Vy * Fux + Lux
Qxu = Fx' * Vyy * Fu + Vy * Fxu + Lxu
Quu = Fu' * Vyy * Fu + Vy * Fuu + Luu
*/

/*
 * Build the local second-order policy at the current point (z->x, z->u).
 *
 * Inputs read from z: x, u, and the value model Vy / Vyy (set beforehand,
 * e.g. by ddp1_second_order_default_initialization()).
 *
 * Steps:
 *   1. Evaluate first and second derivatives of the dynamics (Fx, Fu, Fxx,
 *      Fxu, Fux, Fuu) and of the cost (Lx, Lu, Lxx, Lxu, Lux, Luu).
 *   2. Assemble Qx, Qu, Qxx, Qux, Quu; Qxu is set to the transpose of Qux.
 *      Qxx and Quu are symmetrized.
 *   3. LDL' factor Quu into Quu_L / Quu_D.
 *   4. delta_u = Quu^-1 * Qu and K = Quu^-1 * Qux.
 *   5. Vx = Qx - Qu*K and Vxx = Qxx - Qxu*K (symmetrized).
 *
 * Returns 0 on success.  If ldlt_decompose() reports failure, Luu and Quu
 * are printed to stdout and -1 is returned without updating delta_u, K, Vx
 * or Vxx.  On success the diagonal Quu_D is printed to stdout.
 *
 * Statement order matters: tempyx1 (Vyy*Fx) and tempuy1 (Fu') are reused by
 * later steps, so do not reorder the blocks below without checking.
 *
 * NOTE(review): delta_u is stored as Quu^-1 * Qu with no sign flip;
 * presumably the caller applies the sign/step size - confirm.
 */
int ddp1_second_order_policy( DDP1 *z )
{
  /* j and k are only referenced inside the disabled COMMENT blocks. */
  int i, j, k, state, ok;

  /* Return values of the derivative routines are not checked here. */
  ddp1_compute_dynamics_first_derivatives( z, z->x, z->u, z->Fx, z->Fu );
  ddp1_compute_dynamics_second_derivatives( z, z->x, z->u,
					    z->Fxx, z->Fxu, z->Fux, z->Fuu );
  ddp1_compute_cost_first_derivatives( z, z->x, z->u, z->Lx, z->Lu );
  ddp1_compute_cost_second_derivatives( z, z->x, z->u, 
					z->Lxx, z->Lxu, z->Lux, z->Luu );

  // Qx = Vy * Fx + Lx
  dvm_mult( z->Vy, z->Fx, z->tempx1, z->n_y, z->n_x );
  dv_add( z->tempx1, z->Lx, z->Qx, z->n_x );

  // Qu = Vy * Fu + Lu
  dvm_mult( z->Vy, z->Fu, z->tempu1, z->n_y, z->n_u );
  dv_add( z->tempu1, z->Lu, z->Qu, z->n_u );

  // Qxx = Fx' * Vyy * Fx + Vy * Fxx + Lxx
  // Fx' * Vyy * Fx term
  // tempyx1 = Vyy*Fx is kept for the Qux computation further down.
  dm_mult( z->Vyy, z->Fx, z->tempyx1, z->n_y, z->n_y, z->n_x );
  dm_transpose( z->Fx, z->tempxy1, z->n_y, z->n_x );
  dm_mult( z->tempxy1, z->tempyx1, z->Qxx, z->n_x, z->n_y, z->n_x );
  // Vy*Fxx term
  // Sum over outputs of Vy[state] * Fxx[state].
  for( state = 0; state < z->n_y; state++ )
    {
      dm_scale( z->Fxx[state], z->Vy[state], z->tempxx1, z->n_x, z->n_x );
      dm_acc( z->tempxx1, z->Qxx, z->n_x, z->n_x );
    }
  // Lxx term
  dm_acc( z->Lxx, z->Qxx, z->n_x, z->n_x );

  // Let's make sure Qxx is symmetric.
  // Qxx <- 0.5*(Qxx + Qxx')
  dm_transpose( z->Qxx, z->tempxx2, z->n_x, z->n_x );
  dm_acc( z->Qxx, z->tempxx2, z->n_x, z->n_x );
  dm_scale( z->tempxx2, 0.5, z->Qxx, z->n_x, z->n_x );

  // Qux = Fu' * Vyy * Fx + Vy * Fux + Lux
  // Fu' * Vyy * Fx term: tempyx1 is still Vyy*Fx
  // tempuy1 = Fu' is kept for the Quu computation further down.
  dm_transpose( z->Fu, z->tempuy1, z->n_y, z->n_u );
  dm_mult( z->tempuy1, z->tempyx1, z->Qux, z->n_u, z->n_y, z->n_x );

  // Vy*Fux term
  for( state = 0; state < z->n_y; state++ )
    {
      dm_scale( z->Fux[state], z->Vy[state], z->tempux1, z->n_u, z->n_x );
      dm_acc( z->tempux1, z->Qux, z->n_u, z->n_x );
    }
  // Lux term
  dm_acc( z->Lux, z->Qux, z->n_u, z->n_x );
  // Qxu is not computed independently; it is defined as Qux'.
  dm_transpose( z->Qux, z->Qxu, z->n_u, z->n_x );

  /* Disabled cross-check from an older trajectory-based version; it uses
     names (traj, Zxu, temp_xu1, ...) that are not declared in this
     function. */
#ifdef COMMENT
  /*
    Is Zux really the transpose of Zxu?
    compute it and check
  */
  /* Zxu = Fx' * Vxx * Fu + Vx * Fxu + Lxu */
  dm_mult( traj->next->Vxx, Fu, temp_xu1,
	   traj->next->n_x, traj->next->n_x, traj->n_u );
  dm_transpose( Fx, temp_xx1, traj->next->n_x, traj->n_x );
  dm_mult( temp_xx1, temp_xu1, Zxu, traj->n_x, traj->next->n_x, traj->n_u );
  /* Vx * Fxu term */
  for( state = 0; state < traj->next->n_x; state++ )
    {
      dm_scale( Fxu[state], traj->next->Vx[state], temp_xu1, traj->n_x, traj->n_u );
      dm_acc( temp_xu1, Zxu, traj->n_x, traj->n_u );
    }
  dm_scale( z->Zxu, z->discount, z->Zxu, z->n_states, z->n_u );
  dm_acc( Lxu, Zxu, traj->n_x, traj->n_u );
#endif      

  // Quu = Fu' * Vyy * Fu + Vy * Fuu + Luu
  dm_mult( z->Vyy, z->Fu, z->tempyu1, z->n_y, z->n_y, z->n_u );
  // tempuy1 is still Fu'
  dm_mult( z->tempuy1, z->tempyu1, z->Quu, z->n_u, z->n_y, z->n_u );
  // Vy*Fuu term
  for( state = 0; state < z->n_y; state++ )
    {
      dm_scale( z->Fuu[state], z->Vy[state], z->tempuu1, z->n_u, z->n_u );
      dm_acc( z->tempuu1, z->Quu, z->n_u, z->n_u );
    }
  // Luu term
  dm_acc( z->Luu, z->Quu, z->n_u, z->n_u );

  /* Let's make sure Quu is symmetric. */
  /* Quu <- 0.5*(Quu + Quu') */
  dm_transpose( z->Quu, z->tempuu1, z->n_u, z->n_u );
  dm_acc( z->Quu, z->tempuu1, z->n_u, z->n_u );
  dm_scale( z->tempuu1, 0.5, z->Quu, z->n_u, z->n_u );

  /* Disabled: explicit inverse of the old Zuu plus a check that
     Zuu_inv*Zuu is close to the identity. */
#ifdef COMMENT
  /* Zuu_inv */
  dm_invert( z->Zuu, z->Zuu_inv, z->n_u );

  /* Check out if inverse is correct */
  dm_mult( z->Zuu_inv, z->Zuu, z->tempuu1,
	   z->n_u, z->n_u, z->n_u );
  // dm_print( stdout, "%g ", "\n", z->tempuu1, z->n_u, z->n_u );
  // printf( "\n" );
  for ( j = 0; j < z->n_u; j++ )
    for ( k = 0; k < z->n_u; k++ )
      {
	if ( j == k )
	  {
	    if ( fabs( z->tempuu1[j][k] - 1.0 ) > 1e-10 )
	      printf( "diagonal error: %g %g\n", z->tempuu1[j][k], z->tempuu1[j][k] - 1.0 );
	  }
	else
	  {
	    if ( fabs( z->tempuu1[j][k] ) > 1e-4 )
	      printf( "off diagonal error: %g\n", z->tempuu1[j][k] );
	  }
      }
#endif

  /* Disabled: the explicit-inverse form of the solves done below with the
     LDL' factors. */
#ifdef COMMENT
  /* delta_U_array(i,:) = Zuu_inv*Zu */
  dmv_mult( z->Zuu_inv, z->Zu, z->tempu1, z->n_u, z->n_u );
  
  /* K(t) = Zuu_inv*Zux */
  dm_mult( z->Zuu_inv, z->Zux, z->K,
	   z->n_u, z->n_u, z->n_states );
#endif

  /* Let's do an LDL' decomposition of Quu */
  /* ldlt_decompose() returns zero on failure (see the !ok test); tempu1 is
     passed as scratch. */
  ok = ldlt_decompose( z->Quu, z->Quu_L, z->Quu_D, z->tempu1, z->n_u );
  if ( !ok )
    {
      /* Dump the offending matrices and give up on this step. */
      printf( "Luu:\n" );
      dm_print( stdout, "%g ", "\n", z->Luu, z->n_u, z->n_u );
      printf( "Quu:\n" );
      dm_print( stdout, "%g ", "\n", z->Quu, z->n_u, z->n_u );
      // exit( -1 );
      return -1;
    }

  /* Diagnostic output: the diagonal of the factorization, printed on every
     successful call. */
  printf( "z->Quu_D: " );
  for( i = 0; i < z->n_u; i++ )
    printf( "%g ", z->Quu_D[i] );
  printf( "\n" );

  /* delta_u = Quu_inv*Qu */
  ldlt_solve_vector( z->Quu_L, z->Quu_D, z->Qu, z->delta_u, z->tempu1, z->n_u );

  /* K = Quu_inv*Qux */
  ldlt_solve_matrix( z->Quu_L, z->Quu_D, z->Qux, z->K, z->tempu1, z->n_u, z->n_x );

  /* Vx = Qx - Qu*K */
  dvm_mult( z->Qu, z->K, z->tempx1, z->n_u, z->n_x );
  dv_subtract( z->Qx, z->tempx1, z->Vx, z->n_x );

  /* Vxx = Qxx - Qxu*K */
  dm_mult( z->Qxu, z->K, z->tempxx1, z->n_x, z->n_u, z->n_x );
  dm_subtract( z->Qxx, z->tempxx1, z->Vxx, z->n_x, z->n_x );

  /* Let's make sure Vxx is symmetric. */
  /* Vxx <- 0.5*(Vxx + Vxx') */
  dm_transpose( z->Vxx, z->tempxx1, z->n_x, z->n_x );
  dm_acc( z->Vxx, z->tempxx1, z->n_x, z->n_x );
  dm_scale( z->tempxx1, 0.5, z->Vxx, z->n_x, z->n_x );

  return 0;
}

/****************************************************************************/
/****************************************************************************/

/*
 * Install per-control bounds and switch limit handling on.
 *
 * u_min, u_max  arrays of z->n_u values each; they are copied into the
 *               object, so the caller keeps ownership of its arrays.
 *
 * Sets z->use_u_limits to TRUE.  Returns 0.
 *
 * The return type is spelled out (it used to be implicit int, which is not
 * valid since C99) and a value is actually returned, so callers that read
 * the result no longer get an indeterminate value.
 */
int ddp1_set_u_limits( DDP1 *z, double *u_min, double *u_max )
{
  dv_copy ( u_min, z->u_min, z->n_u );
  dv_copy ( u_max, z->u_max, z->n_u );
  z->use_u_limits = TRUE;

  return 0;
}

/****************************************************************************/
/************************************************************************/

#ifdef COMMENT

/****************************************************************************/

/*
 * (Inside the "#ifdef COMMENT" region: compiled only when COMMENT is
 * defined.)
 *
 * Discounted cost of a trajectory of n_points points, accumulated backwards
 * from the last point.  The last point is charged with
 * terminal_cost_function when one is set, otherwise with cost_function at
 * STEADY_STATE_INDEX.  The running cost-to-go is stored in z->V_array[i]
 * for every point.  A non-finite total is replaced by DDP1_BAD_VALUE; only
 * the total (V_array[0]) gets that treatment.
 *
 * NOTE(review): uses z->V_array and z->discount and calls cost_function
 * with an extra index argument; none of that matches the rest of this file
 * - confirm before re-enabling.
 */
double trajectory_cost( DDP1 *z, 
			double **x_array, double **u_array, int n_points )
{
  int i;
  double cost = 0;

  /* Cost of the final point. */
  if ( z->terminal_cost_function )
    cost += z->terminal_cost_function( z->argument, x_array[n_points-1] );
  else
    cost += z->cost_function( z->argument, x_array[n_points-1], 
				u_array[n_points-1], STEADY_STATE_INDEX );

  z->V_array[n_points-1] = cost;

  /* Backward accumulation: cost(i) = L(i) + discount*cost(i+1). */
  for ( i = n_points - 2; i >= 0; i-- )
    {
      cost = z->cost_function( z->argument, x_array[i], u_array[i], i )
	+ z->discount*cost;
      z->V_array[i] = cost;
    }

  if ( !finite( cost ) )
    cost = DDP1_BAD_VALUE;

  z->V_array[0] = cost;

  return( cost );
}

/****************************************************************************/

/*
 * (Inside the "#ifdef COMMENT" region: compiled only when COMMENT is
 * defined.)
 * Cost of the trajectory stored in the object itself: forwards z->X, z->U
 * and z->n_points to trajectory_cost().
 */
double total_cost( DDP1 *z )
{
  return trajectory_cost( z, z->X, z->U, z->n_points );
}

/****************************************************************************/

/*
 * (Inside the "#ifdef COMMENT" region: compiled only when COMMENT is
 * defined.)
 *
 * First-order backward sweep over the stored trajectory.  Vx starts as
 * P_ss * (X[n_points-1] - X_d_final); then, from the last point down to the
 * first, it forms
 *     Zx = discount * Vx*A + Lx
 *     Zu = discount * Vx*B + Lu
 * stores Zu in delta_U_array[i] and carries Zx back as the new Vx.
 * Always returns 0.
 *
 * NOTE(review): relies on fields and helpers (X, U, A, B, P_ss, n_states,
 * discount, compute_linear_model, compute_cost_first_derivatives) that are
 * not defined in the visible part of this file - confirm before
 * re-enabling.
 */
int gradient( DDP1 *z )
{
  /* j is not used in this function. */
  int i, j;

  /* This is from page 58 of Dyer and McReynolds */
  /* Vx = P_final * (X[n_points-1] - X_d_final) */
  dv_subtract ( z->X[z->n_points-1], z->X_d_final, z->tempx1, z->n_states ); 
  dmv_mult ( z->P_ss, z->tempx1, z->Vx, z->n_states, z->n_states ); 
  /*
  printf( "%g %g; %g %g\n", z->tempx1[0], z->tempx1[1], 
	  z->Vx[0], z->Vx[1] );
  */

  for ( i = z->n_points - 1; i >= 0; i-- )
    {
      compute_linear_model( z, z->X[i], z->U[i], i, z->A, z->B );
      compute_cost_first_derivatives( z, z->X[i], z->U[i], i, z->Lx, z->Lu );

#ifdef COMMENT
      /* Zx = Vx * A + (Lxx * (X[i] - X_d[i])) */
      dv_subtract( z->X[i], z->X_d[i], z->tempx1, z->n_states );
      dmv_mult( z->Lxx, z->tempx1, z->tempx2, z->n_states, z->n_states );
      dvm_mult( z->Vx, z->A, z->tempx1, z->n_states, z->n_states );
      dv_add( z->tempx1, z->tempx2, z->Zx, z->n_states );
#endif
      /* Zx = Vx * A + Lx */
      /* The Vx*A term is multiplied by the discount before Lx is added. */
      dvm_mult( z->Vx, z->A, z->tempx1, z->n_states, z->n_states );
      dv_scale( z->tempx1, z->discount, z->tempx1, z->n_states );
      dv_add( z->tempx1, z->Lx, z->Zx, z->n_states );

#ifdef COMMENT
      /* Zu = Vx * B + Luu * (U[i] - U_d[i]) */
      dvm_mult( z->Vx, z->B, z->tempu2, z->n_states, z->n_u );
      dv_subtract( z->U[i], z->U_d[i], z->tempu1, z->n_u );
      dmv_mult( z->Luu, z->tempu1, z->tempu3,
		z->n_u, z->n_u );
      dv_add( z->tempu3, z->tempu2, z->Zu, z->n_u );
#endif
      /* Zu = Vx * B + Lu */
      /* The Vx*B term is multiplied by the discount before Lu is added. */
      dvm_mult( z->Vx, z->B, z->tempu1, z->n_states, z->n_u );
      dv_scale( z->tempu1, z->discount, z->tempu1, z->n_u );
      dv_add( z->tempu1, z->Lu, z->Zu, z->n_u );

      /* delta_U_array(i,:) = Zu */
      dv_copy( z->Zu, z->delta_U_array[i], z->n_u );

      /* Vx = Zx */
      dv_copy( z->Zx, z->Vx, z->n_states );

      /*
	printf( "%d: %g %g %g\n", i, z->delta_U_array[i][0], 
	z->Vx[0], z->Vx[1] );
      */
    }
  return 0;
}

/****************************************************************************/
/****************************************************************************/
/****************************************************************************/

/*
 * (Inside the "#ifdef COMMENT" region: compiled only when COMMENT is
 * defined.)
 * Old-style (K&R) setter with implicit int return type: stores the dynamics
 * callback in one_step_dynamics_function, or does nothing when
 * initialized() is false.  No value is returned on either path.
 * NOTE(review): initialized() and one_step_dynamics_function are not
 * declared in the visible part of this file - confirm they still exist
 * before re-enabling.
 */
set_ddp1_dynamics( one_step_dynamics )
     int (*one_step_dynamics)();
{
  if ( !initialized() )
    return;

  one_step_dynamics_function = one_step_dynamics;
}

/****************************************************************************/

/*
 * (Inside the "#ifdef COMMENT" region: compiled only when COMMENT is
 * defined.)
 * Old-style (K&R) setter with implicit int return type: installs the
 * linear-model callback.  A null argument selects
 * compute_linear_model_numerically instead.  Does nothing when
 * initialized() is false.  No value is returned on any path.
 * NOTE(review): initialized(), compute_linear_model_function and
 * compute_linear_model_numerically are not defined in the visible part of
 * this file - confirm before re-enabling.
 */
set_ddp1_derivatives( compute_linear_model )
     int (*compute_linear_model)();
{
  /* Local declaration of the numerical fallback. */
  int compute_linear_model_numerically();

  if ( !initialized() )
    return;

  if ( compute_linear_model )
    compute_linear_model_function = compute_linear_model;
  else
    compute_linear_model_function = compute_linear_model_numerically;
}

/****************************************************************************/

/*
 * (Inside the "#ifdef COMMENT" region: compiled only when COMMENT is
 * defined.)
 * Old-style (K&R) setter with implicit int return type: assigns arg to
 * delta, or does nothing when initialized() is false.  No value is
 * returned.
 * NOTE(review): delta here is a bare name, not a DDP1 field, and is not
 * declared in the visible part of this file - confirm before re-enabling.
 */
set_delta( arg )
     double arg;
{
  if ( !initialized() )
    return;

  delta = arg;
}

/****************************************************************************/

/*
 * (Inside the "#ifdef COMMENT" region: compiled only when COMMENT is
 * defined.)
 * Old-style (K&R) setter with implicit int return type: assigns epsilon to
 * epsilon_gradient, or does nothing when initialized() is false.  No value
 * is returned.
 * NOTE(review): epsilon_gradient here is a bare name, not a DDP1 field, and
 * is not declared in the visible part of this file - confirm before
 * re-enabling.
 */
set_epsilon_gradient( epsilon )
     double epsilon;
{
  if ( !initialized() )
    return;

  epsilon_gradient = epsilon;
}

/****************************************************************************/

/*
 * (Inside the "#ifdef COMMENT" region: compiled only when COMMENT is
 * defined.)
 * Old-style (K&R) setter with implicit int return type: assigns epsilon to
 * epsilon_second_order, or does nothing when initialized() is false.  No
 * value is returned.
 * NOTE(review): epsilon_second_order here is a bare name, not a DDP1 field,
 * and is not declared in the visible part of this file - confirm before
 * re-enabling.
 */
set_epsilon_second_order( epsilon )
     double epsilon;
{
  if ( !initialized() )
    return;

  epsilon_second_order = epsilon;
}

/****************************************************************************/
/****************************************************************************/
/****************************************************************************/
/****************************************************************************/
#endif
