/************************************************************************/
/*
Dynamics for a four-link pendulum: equations of motion, a fixed-step
integrator, and a per-step quadratic cost.
*/
/************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "dynamics4.h"

/************************************************************************/

/* State vector indices: the four joint angles come first (0-3), followed
   by the four joint angular velocities (4-7). */
#define S_P1 0
#define S_P2 1
#define S_P3 2
#define S_P4 3
#define S_V1 4
#define S_V2 5
#define S_V3 6
#define S_V4 7

/* Action vector indices: one applied torque per joint. */
#define A_T1 0
#define A_T2 1
#define A_T3 2
#define A_T4 3

/* Parameters: all four links are identical.  Units are not stated; with
   GRAVITY = 9.81 they are presumably SI (m, kg, s) -- TODO confirm. */
#define LENGTH (1.0/4.0)	// length of link
#define WIDTH (0.1)    // width of link (used only in the link moment of inertia)
#define MASS (1.0/4.0)     // mass of link

#define GRAVITY 9.81
// #define VISCOUS_FRICTION 0.1 // viscous friction at each joint
#define VISCOUS_FRICTION 0.0 // viscous friction coefficient at each joint (0 = disabled)

#define TIMESTEP 0.01 // Timestep of integrator

/* Score parameters: weight on the squared joint-angle error in
   one_step_cost(); squared torques have weight 1. */
#define STATE_PENALTY 0.001

/* Handy macros */
/* Square of x.  x is evaluated twice, so do not pass expressions with
   side effects. */
#define SQ(x) ((x)*(x))

/**********************************************************************/
/* Globals */

/* M_PI is a POSIX/XSI extension rather than ISO C; provide it so the file
   also builds under strict standard modes (e.g. -std=c11). */
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

/* Target state copied into every Dynamics by create_dynamics(): first joint
   angle at pi, every other angle and all velocities zero.  Presumably this
   is the balanced (inverted) configuration -- confirm against the gravity
   sign convention in dynamics(). */
double the_desired_state[N_STATE_DIMENSIONS] = { M_PI, 0, 0, 0, 0, 0, 0, 0 };

/**********************************************************************/
/* Module-level setup hook.  This model has nothing to prepare ahead of
   time, so the call is a no-op that always reports success (1). */

int init_dynamics()
{
  const int success = 1;

  return success;
}

/**********************************************************************/

/*
 * Allocate and initialise a new Dynamics object.
 *
 * Both the desired state and the current state are set from
 * the_desired_state, the clock is set to 0.0 and the random-number seed to 1.
 * On allocation failure a message is printed to stderr and the process
 * exits with status -1.
 *
 * No matching destructor is visible in this file; presumably callers
 * release the object with free() -- confirm.
 */
Dynamics *create_dynamics()
{
  int k;
  Dynamics *dyn = malloc( sizeof *dyn );

  if ( !dyn )
    {
      fprintf( stderr, "Can't allocate dynamics.\n" );
      exit( -1 );
    }

  for ( k = 0; k < N_STATE_DIMENSIONS; ++k )
    {
      dyn->desired_state[k] = the_desired_state[k];
    }

  /* Start the simulation at time zero, sitting exactly on the target. */
  set_state( dyn, 0.0, the_desired_state );

  /* Initial seed for the random number generator. */
  dyn->seed = 1;

  return dyn;
}

/**********************************************************************/

/*
 * Overwrite the simulation clock of d with `time` and copy
 * N_STATE_DIMENSIONS values from `state` into d->state.
 * Always returns 1.
 */
int set_state( Dynamics *d, double time, double *state )
{
  int idx;

  d->time = time;

  for ( idx = 0; idx < N_STATE_DIMENSIONS; ++idx )
    {
      d->state[idx] = state[idx];
    }

  return 1;
}

/**********************************************************************/

/*
 * Forward dynamics of the four-link chain.
 *
 * Inputs:
 *   a1..a4      joint angles.  Gravity terms use sin(a1), sin(a1+a2),
 *               sin(a1+a2+a3) and sin(a1+a2+a3+a4), so each angle appears
 *               to be measured relative to the previous link (a1 relative
 *               to the base) -- confirm against the model derivation.
 *   a1d..a4d    joint angular velocities.
 *   tau1..tau4  applied joint torques.
 * Outputs:
 *   *a1dd..*a4dd  joint angular accelerations.
 *
 * Method: builds the 4x4 mass matrix
 *     [ a b c d ]
 *     [ e f g h ]      symmetric: b = e, c = i, d = m, g = j, h = n, l = o
 *     [ i j k l ]
 *     [ m n o p ]
 * and the right-hand side (q, r, s, t).  Each right-hand-side entry is the
 * joint torque, minus viscous friction, minus the velocity-product and
 * gravity terms.  The routine then solves M * [a1dd a2dd a3dd a4dd] = rhs
 * using explicit cofactor expressions divided by det(M).
 *
 * All links share the LENGTH / MASS / WIDTH parameters.
 */
void dynamics( double a1, double a2, double a3, double a4,
	       double a1d, double a2d, double a3d, double a4d,
	       double tau1, double tau2, double tau3, double tau4,
	       double *a1dd, double *a2dd, double *a3dd, double *a4dd )
{
  /* Slightly faster to have these as variables than defines. Go figure */
  /* lXcm is the distance from joint X to link X's centre of mass.  There is
     no l4: only link 4's centre-of-mass offset l4cm enters the equations. */
  double m1 = MASS;
  double m2 = MASS;
  double m3 = MASS;
  double m4 = MASS;
  double l1cm = (LENGTH/2);
  double l2cm = (LENGTH/2);
  double l3cm = (LENGTH/2);
  double l4cm = (LENGTH/2);
  double l1 = LENGTH;
  double l2 = LENGTH;
  double l3 = LENGTH;
  double G = GRAVITY;
  /* Per-link moment of inertia about its centre of mass: m*(L^2 + W^2)/12. */
  double I1 = (MASS*(LENGTH*LENGTH + WIDTH*WIDTH)/12); /* Icom */
  double I2 = (MASS*(LENGTH*LENGTH + WIDTH*WIDTH)/12); /* Icom */
  double I3 = (MASS*(LENGTH*LENGTH + WIDTH*WIDTH)/12); /* Icom */
  double I4 = (MASS*(LENGTH*LENGTH + WIDTH*WIDTH)/12); /* Icom */
  double s1, c1, s2, c2, s3, c3, s4, c4;
  double a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t;
  double determinant;
  double s12, c12, s23, c23, s34, c34, s1234, s123, s234, c234;
  double a1d_a1d, a2d_a2d, a3d_a3d, a4d_a4d;
  double a1d_p_a2d_2, l4cm_m4, l3_m4, l3cm_m3, l2cm_m2, l3cm_m3_l3_m4;
  double l2cm_m2_p_l2_m3_p_m4;
  double l3_l4cm_m4, l2_l4cm_m4, l1_l4cm_m4;
  double l2_l3cm_m3_l3_m4, l1_l3cm_m3_l3_m4, l2_l4cm_m4_c34;
  double expr1, expr2, expr3, expr4, expr5, expr6, expr7, expr8;
  double expr4a, expr4b, expr5a, expr9a, expr9;
  double a123d, l1_l3cm_m3_l3_m4_s23, l2_l4cm_m4_s34;

  /* Disabled single-precision variant of the trig evaluations below:
  s1 = sinf( a1 );
  c1 = cosf( a1 );
  s2 = sinf( a2 );
  c2 = cosf( a2 );
  s3 = sinf( a3 );
  c3 = cosf( a3 );
  s4 = sinf( a4 );
  c4 = cosf( a4 );
  */
  s1 = sin( a1 );
  c1 = cos( a1 );
  s2 = sin( a2 );
  c2 = cos( a2 );
  s3 = sin( a3 );
  c3 = cos( a3 );
  s4 = sin( a4 );
  c4 = cos( a4 );
  /* Angle-sum sines/cosines from the angle-addition formulas, e.g.
     s12 = sin(a1+a2), s1234 = sin(a1+a2+a3+a4), c234 = cos(a2+a3+a4). */
  s12 = s1*c2 + c1*s2;
  c12 = c1*c2 - s1*s2;
  s23 = s2*c3 + c2*s3;
  c23 = c2*c3 - s2*s3;
  s34 = s3*c4 + c3*s4;
  c34 = c3*c4 - s3*s4;
  s1234 = s12*c34 + c12*s34;
  s123 = s12*c3 + c12*s3;
  s234 = s2*c34 + c2*s34;
  c234 = c2*c34 - s2*s34;

  /* Common subexpressions: velocity products and parameter products. */
  a1d_a1d = a1d*a1d;
  a2d_a2d = a2d*a2d;
  a3d_a3d = a3d*a3d;
  a4d_a4d = a4d*a4d;
  a1d_p_a2d_2 = (a1d + a2d)*(a1d + a2d);

  l4cm_m4 = l4cm*m4;
  l3_l4cm_m4 = l3*l4cm_m4;
  l2_l4cm_m4 = l2*l4cm_m4;
  l2_l4cm_m4_c34 = l2_l4cm_m4*c34;
  l1_l4cm_m4 = l1*l4cm_m4;
  l3_m4 = l3*m4;
  l3cm_m3 = l3cm*m3;
  l3cm_m3_l3_m4 = l3cm_m3 + l3_m4;
  l2cm_m2 = l2cm*m2;
  l2cm_m2_p_l2_m3_p_m4 = l2cm_m2 + l2*(m3 + m4);
  l2_l3cm_m3_l3_m4 = l2*l3cm_m3_l3_m4;
  l1_l3cm_m3_l3_m4 = l1*l3cm_m3_l3_m4;
  a123d = a1d + a2d + a3d;
  l1_l3cm_m3_l3_m4_s23 = l1_l3cm_m3_l3_m4*s23;
  l2_l4cm_m4_s34 = l2_l4cm_m4*s34;

  /* Pieces shared between the right-hand sides q, r, s and t
     (expr1, expr3: gravity; the rest: velocity-product terms). */
  expr1 = G*(s123*l3cm_m3_l3_m4 + s1234*l4cm_m4);
  expr2 = (2*a123d + a4d)*a4d*l3_l4cm_m4*s4;
  expr3 = G*l2cm_m2_p_l2_m3_p_m4*s12;
  expr4a = 2*a1d*a4d + 2*a2d*a4d + 2*a3d*a4d + a4d_a4d;
  expr4b = 2*a1d*a3d + 2*a2d*a3d + a3d_a3d;
  expr4 = (expr4b + expr4a)*l2_l4cm_m4_s34;
  expr5a = a1d_a1d*l1*s234;
  expr5 = l4cm_m4*expr5a;
  expr6 = expr4b*l2_l3cm_m3_l3_m4*s3;
  expr7 = l1*l2cm_m2_p_l2_m3_p_m4;
  expr8 = l1*(m2+m3+m4);
  expr9a = 2*a1d*a2d + a2d_a2d;
  expr9 = (expr9a + expr4b);

  /* Fourth row */
  /* Mass-matrix row 4 is (m, n, o, p); right-hand side t. */
  p = I4 + l4cm*l4cm_m4;

  o = p + l3_l4cm_m4*c4;

  n = o + l2_l4cm_m4_c34;

  m = n + l1_l4cm_m4*c234;

  t = tau4 - VISCOUS_FRICTION*a4d
    -(l4cm_m4*(a123d*a123d*l3*s4 + 
	       a1d_p_a2d_2*l2*s34 + 
	       expr5a + G*s1234));

  /* Third row */
  /* Row 3 is (i, j, k, l); right-hand side s.  l = o by symmetry. */
  l = o;

  k = I3 + o + l3cm*l3cm_m3 + l3*l3_m4 + l3_l4cm_m4*c4;

  j = k + l2_l3cm_m3_l3_m4*c3 + l2_l4cm_m4_c34;

  i = j + l1_l3cm_m3_l3_m4*c23
    + l1_l4cm_m4*c234;

  /* The trailing "+" followed by "- expr2" on the next line parses as
     "+ (-expr2)", i.e. expr2 is subtracted inside the parentheses. */
  s = tau3 - VISCOUS_FRICTION*a3d
    -((a1d_p_a2d_2*l2_l3cm_m3_l3_m4*s3 + a1d_a1d*l1_l3cm_m3_l3_m4_s23) + 
      - expr2 
      + a1d_p_a2d_2*l2_l4cm_m4_s34
      + expr5
      + expr1
);

  /* Second row */
  /* Row 2 is (e, f, g, h); right-hand side r.  g = j and h = n by symmetry. */
  h = n;

  g = j;

  f = j + I2 + l2cm*l2cm_m2  + SQ(l2)*(m3 + m4) 
    + l2_l3cm_m3_l3_m4*c3 + l2_l4cm_m4_c34;

  e = f + i - j + expr7*c2;

  r = tau2 - VISCOUS_FRICTION*a2d
    - (
       a1d_a1d*expr7*s2
       - expr6
       + a1d_a1d*l1_l3cm_m3_l3_m4_s23
       - expr2
       - expr4
       + expr5
       + expr3
       + expr1
);

  /* First row */
  /* Row 1 is (a, b, c, d); right-hand side q.  b = e, c = i and d = m by
     symmetry. */
  d = m;

  c = i;

  b = e;

  a = 2*e + I1 - f + SQ(l1cm)*m1 + l1*expr8;

  q = tau1 - VISCOUS_FRICTION*a1d
    - ( -expr9a*expr7*s2
	- expr6
	- expr9*l1_l3cm_m3_l3_m4_s23
	- expr2
	- expr4
	- (expr9 + expr4a)*l1_l4cm_m4*s234
	+ expr3
	+ G*(l1cm*m1 + expr8)*s1
	+ expr1
	);

  /* Debug dump of the mass matrix and right-hand side:
  printf( "abcd: %g %g %g %g\n", a, b, c, d );
  printf( "efgh: %g %g %g %g\n", e, f, g, h );
  printf( "ijkl: %g %g %g %g\n", i, j, k, l );
  printf( "mnop: %g %g %g %g\n", m, n, o, p );
  printf( "qrst: %g %g %g %g\n", q, r, s, t );
  */

  /* Solve M * [a1dd a2dd a3dd a4dd]^T = [q r s t]^T.  Each acceleration is
     a cofactor-weighted sum of the right-hand side divided by det(M).
     No check is made for det(M) == 0. */
  determinant =
    (d*g*j*m - c*h*j*m - d*f*k*m + b*h*k*m + c*f*l*m - b*g*l*m - d*g*i*n +
     c*h*i*n + d*e*k*n - a*h*k*n - c*e*l*n + a*g*l*n + d*f*i*o - b*h*i*o -
     d*e*j*o + a*h*j*o + b*e*l*o - a*f*l*o - c*f*i*p + b*g*i*p + c*e*j*p -
     a*g*j*p - b*e*k*p + a*f*k*p);
  *a1dd = q*(-(h*k*n) + g*l*n + h*j*o - f*l*o - g*j*p + f*k*p)
    + r*(d*k*n - c*l*n - d*j*o + b*l*o + c*j*p - b*k*p)
    + s*(-(d*g*n) + c*h*n + d*f*o - b*h*o - c*f*p + b*g*p)
    + t*(d*g*j - c*h*j - d*f*k + b*h*k + c*f*l - b*g*l);
  *a2dd = q*(h*k*m - g*l*m - h*i*o + e*l*o + g*i*p - e*k*p)
    + r*(-(d*k*m) + c*l*m + d*i*o - a*l*o - c*i*p + a*k*p)
    + s*(d*g*m - c*h*m - d*e*o + a*h*o + c*e*p - a*g*p)
    + t*(-(d*g*i) + c*h*i + d*e*k - a*h*k - c*e*l + a*g*l);
  *a3dd = q*(-(h*j*m) + f*l*m + h*i*n - e*l*n - f*i*p + e*j*p)
    + r*(d*j*m - b*l*m - d*i*n + a*l*n + b*i*p - a*j*p)
    + s*(-(d*f*m) + b*h*m + d*e*n - a*h*n - b*e*p + a*f*p)
    + t*(d*f*i - b*h*i - d*e*j + a*h*j + b*e*l - a*f*l);
  *a4dd = q*(g*j*m - f*k*m - g*i*n + e*k*n + f*i*o - e*j*o)
    + r*(-(c*j*m) + b*k*m + c*i*n - a*k*n - b*i*o + a*j*o)
    + s*(c*f*m - b*g*m - c*e*n + a*g*n + b*e*o - a*f*o)
    + t*(-(c*f*i) + b*g*i + c*e*j - a*g*j - b*e*k + a*f*k);
  *a1dd = *a1dd/determinant;
  *a2dd = *a2dd/determinant;
  *a3dd = *a3dd/determinant;
  *a4dd = *a4dd/determinant;
}

/**********************************************************************/

/*
 * Advance d by one TIMESTEP under the given joint torques.
 *
 *   d           simulation state; d->state and d->time are updated in place.
 *   action      joint torques, indexed by A_T1..A_T4.
 *   next_state  if non-NULL, receives a copy of the updated d->state
 *               (N_STATE_DIMENSIONS entries).
 *
 * Accelerations are evaluated once, at the start-of-step state.  Each
 * velocity is then advanced with an explicit Euler step.  Each angle is
 * advanced using the average of its old and new velocity.
 *
 * Always returns 1 (success), like set_state() and init_dynamics().
 */
int integrate_one_step( Dynamics *d, double *action, double *next_state )
{
  int i;
  double a1 = d->state[S_P1];
  double a2 = d->state[S_P2];
  double a3 = d->state[S_P3];
  double a4 = d->state[S_P4];
  double a1d = d->state[S_V1];
  double a2d = d->state[S_V2];
  double a3d = d->state[S_V3];
  double a4d = d->state[S_V4];
  double a1dd, a2dd, a3dd, a4dd;   /* all four are written by dynamics() */
  double new_a1d, new_a2d, new_a3d, new_a4d;

  dynamics( a1, a2, a3, a4, a1d, a2d, a3d, a4d,
	    action[A_T1], action[A_T2], action[A_T3], action[A_T4],
	    &a1dd, &a2dd, &a3dd, &a4dd );

  /* Velocity: explicit Euler step. */
  new_a1d = a1d + a1dd*TIMESTEP;
  new_a2d = a2d + a2dd*TIMESTEP;
  new_a3d = a3d + a3dd*TIMESTEP;
  new_a4d = a4d + a4dd*TIMESTEP;

  /* Angle: uses the mean of the old and new velocity over the step. */
  a1 += (new_a1d + a1d)*TIMESTEP/2;
  a2 += (new_a2d + a2d)*TIMESTEP/2;
  a3 += (new_a3d + a3d)*TIMESTEP/2;
  a4 += (new_a4d + a4d)*TIMESTEP/2;

  d->time += TIMESTEP;
  d->state[S_P1] = a1;
  d->state[S_P2] = a2;
  d->state[S_P3] = a3;
  d->state[S_P4] = a4;
  d->state[S_V1] = new_a1d;
  d->state[S_V2] = new_a2d;
  d->state[S_V3] = new_a3d;
  d->state[S_V4] = new_a4d;

  if ( next_state != NULL )
    {
      for( i = 0; i < N_STATE_DIMENSIONS; i++ )
	next_state[i] = d->state[i];
    }

  /* The function is declared int; without this return, using the result
     would be undefined behaviour. */
  return 1;
}

/**********************************************************************/

/*
 * Quadratic cost of applying `action` in `state` for one TIMESTEP:
 *
 *     TIMESTEP * (sum of squared joint torques)
 *   + STATE_PENALTY * TIMESTEP * (sum of squared joint-angle errors)
 *
 * Angle errors are plain differences from d->desired_state for the four
 * joint angles (S_P1..S_P4).  Joint velocities do not enter the cost.
 * NOTE(review): errors are not wrapped to (-pi, pi], so an angle that
 * differs by a full turn is penalised -- confirm this is intended.
 */
double one_step_cost( Dynamics *d, double *state, double *action )
{
  double s1, s2, s3, s4, a1, a2, a3, a4;

  s1 = state[S_P1] - d->desired_state[S_P1];
  s2 = state[S_P2] - d->desired_state[S_P2];
  s3 = state[S_P3] - d->desired_state[S_P3];
  /* Fourth joint error goes into s4: writing it to s3 would drop the third
     joint's error and leave s4 uninitialised. */
  s4 = state[S_P4] - d->desired_state[S_P4];
  a1 = action[A_T1];
  a2 = action[A_T2];
  a3 = action[A_T3];
  a4 = action[A_T4];
  return TIMESTEP*(a1*a1 + a2*a2 + a3*a3 + a4*a4)
    + STATE_PENALTY*TIMESTEP*(s1*s1 + s2*s2 + s3*s3 + s4*s4);
}

/**********************************************************************/


