/************************************************************************/
/*
dynamics for a 4 link pendulum
*/
/************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "dynamics4.h"

/************************************************************************/

/* State vector indices.
   Slots 0-3 hold the four joint angles and slots 4-7 the matching joint
   angular velocities (see how integrate_one_step() unpacks d->state). */
#define S_P1 0
#define S_P2 1
#define S_P3 2
#define S_P4 3
#define S_V1 4
#define S_V2 5
#define S_V3 6
#define S_V4 7

/* Action vector indices: one torque per joint, in joint order. */
#define A_T1 0
#define A_T2 1
#define A_T3 2
#define A_T4 3

/* Physical parameters.  dynamics() uses the same LENGTH, WIDTH and MASS
   for every one of the four links.
   NOTE(review): units are not stated anywhere in this file; GRAVITY = 9.81
   suggests SI units -- confirm. */
#define LENGTH (1.0/4.0)	// length of one link
#define WIDTH (0.1f)    // width of one link (float literal; the others are double)
#define MASS (1.0/4.0)     // mass of one link

#define GRAVITY 9.81       // gravitational acceleration used in the gravity terms
/* Alternative non-zero friction setting, currently disabled: */
// #define VISCOUS_FRICTION 0.1 // viscous friction at each joint
#define VISCOUS_FRICTION 0.0 // coefficient multiplying each joint velocity; 0 disables friction

#define TIMESTEP 0.01 // step size used by integrate_one_step() and one_step_cost()

/* Score parameters: weight of the squared position error in one_step_cost(). */
#define STATE_PENALTY 0.001

/* Handy macros.  SQ expands its argument twice, so do not pass an
   expression with side effects. */
#define SQ(x) ((x)*(x))

/**********************************************************************/
/* Globals */

/* Default goal state: first joint angle (index S_P1) at pi, every other
   angle and all velocities at zero.  create_dynamics() copies it into each
   new Dynamics as the desired state and also uses it as the initial state.
   NOTE(review): M_PI is not guaranteed by ISO C's <math.h>; strict-conformance
   builds may need a feature-test macro -- confirm against the build flags. */
float the_desired_state[N_STATE_DIMENSIONS] = { M_PI, 0, 0, 0, 0, 0, 0, 0 };

/**********************************************************************/
/* Don't need to do any initialization */

/* One-time setup hook for the dynamics module.  This model has nothing to
   prepare, so the call only reports success (1). */
int init_dynamics()
{
  const int status_ok = 1;

  return status_ok;
}

/**********************************************************************/

/* Allocate a new Dynamics object on the heap and initialise it.
   The desired state is copied from the_desired_state, the current state is
   set to that same vector at time 0, and the random seed is set to 1.
   On allocation failure a message is printed to stderr and the process
   exits with status -1.  The caller owns the returned pointer. */
Dynamics *create_dynamics()
{
  Dynamics *dyn = malloc( sizeof *dyn );
  int dim = 0;

  if ( !dyn )
    {
      fprintf( stderr, "Can't allocate dynamics.\n" );
      exit( -1 );
    }

  /* Goal state first, then start the simulation exactly at the goal. */
  while ( dim < N_STATE_DIMENSIONS )
    {
      dyn->desired_state[dim] = the_desired_state[dim];
      dim++;
    }
  set_state( dyn, 0.0, the_desired_state );

  /* Seed for the random number generator. */
  dyn->seed = 1;

  return dyn;
}

/**********************************************************************/

/* Overwrite the simulation clock and the full state vector of `d`.
   d     - dynamics object to modify.
   time  - new value for d->time.
   state - source vector; N_STATE_DIMENSIONS entries are read.
   Always returns 1. */
int set_state( Dynamics *d, float time, float *state )
{
  int dim = 0;

  d->time = time;

  /* Element-by-element copy, lowest index first. */
  while ( dim < N_STATE_DIMENSIONS )
    {
      d->state[dim] = state[dim];
      dim++;
    }

  return 1;
}

/**********************************************************************/

/* Forward dynamics of the 4-link pendulum: compute the joint accelerations
   produced by the given joint torques in the given configuration.

   a1..a4     - joint angles.  The gravity terms use the cumulative sums
                a1, a1+a2, a1+a2+a3, ..., so each angle appears to be
                measured relative to the previous link -- TODO confirm.
   a1d..a4d   - joint angular velocities.
   tau1..tau4 - applied joint torques.
   a1dd..a4dd - outputs: the resulting joint angular accelerations.

   The function fills in the sixteen coefficients a..p of a 4x4 matrix
   (row by row, see the "row" comments) and a right-hand side q..t made of
   the applied torque minus viscous friction minus the velocity-dependent
   and gravity terms.  It then solves

       [a b c d] [a1dd]   [q]
       [e f g h] [a2dd] = [r]
       [i j k l] [a3dd]   [s]
       [m n o p] [a4dd]   [t]

   with fully expanded closed-form expressions divided by the determinant.
   There is no guard against a zero determinant. */
void dynamics( float a1, float a2, float a3, float a4,
	       float a1d, float a2d, float a3d, float a4d,
	       float tau1, float tau2, float tau3, float tau4,
	       float *a1dd, float *a2dd, float *a3dd, float *a4dd )
{
  /* Slightly faster to have these as variables than defines. Go figure */
  /* Per-link masses (all links identical). */
  float m1 = MASS;
  float m2 = MASS;
  float m3 = MASS;
  float m4 = MASS;
  /* Distance from each joint to its link's centre of mass: half the link. */
  float l1cm = (LENGTH/2);
  float l2cm = (LENGTH/2);
  float l3cm = (LENGTH/2);
  float l4cm = (LENGTH/2);
  /* Full lengths of links 1-3; the length of link 4 is never needed. */
  float l1 = LENGTH;
  float l2 = LENGTH;
  float l3 = LENGTH;
  float G = GRAVITY;
  /* Moment of inertia of each link about its own centre of mass. */
  float I1 = (MASS*(LENGTH*LENGTH + WIDTH*WIDTH)/12); /* Icom */
  float I2 = (MASS*(LENGTH*LENGTH + WIDTH*WIDTH)/12); /* Icom */
  float I3 = (MASS*(LENGTH*LENGTH + WIDTH*WIDTH)/12); /* Icom */
  float I4 = (MASS*(LENGTH*LENGTH + WIDTH*WIDTH)/12); /* Icom */
  float s1, c1, s2, c2, s3, c3, s4, c4;
  /* a..p: matrix coefficients; q..t: right-hand side (one per row). */
  float a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t;
  float determinant;

  /* NOTE(review): these eight cached sines/cosines are never read again;
     the expressions below call sinf()/cosf() directly. */
  s1 = sinf( a1 );
  c1 = cosf( a1 );
  s2 = sinf( a2 );
  c2 = cosf( a2 );
  s3 = sinf( a3 );
  c3 = cosf( a3 );
  s4 = sinf( a4 );
  c4 = cosf( a4 );
  /*
  s23 = s2*c3 + c2*s3;
  c23 = c2*c3 - s2*s3;
  */

  /* Fourth row */
  /* Coefficients m, n, o, p multiply a1dd, a2dd, a3dd, a4dd respectively. */
  p = I4 + SQ(l4cm)*m4;

  o = I4 + SQ(l4cm)*m4 + l3*l4cm*m4*cosf(a4);

  n = I4 + SQ(l4cm)*m4 + l3*l4cm*m4*cosf(a4) + l2*l4cm*m4*cosf(a3 + a4);

  m = I4 + SQ(l4cm)*m4 + l3*l4cm*m4*cosf(a4) + l2*l4cm*m4*cosf(a3 + a4) + 
    l1*l4cm*m4*cosf(a2 + a3 + a4);

  /* Right-hand side for joint 4: torque - friction - (velocity and gravity
     terms). */
  t = tau4 - VISCOUS_FRICTION*a4d
    -(l4cm*m4*(SQ(a1d + a2d + a3d)*l3*sinf(a4) + 
	       SQ(a1d + a2d)*l2*sinf(a3 + a4) + 
	       SQ(a1d)*l1*sinf(a2 + a3 + a4) + G*sinf(a1 + a2 + a3 + a4)));

  // printf( "%g %g\n", t, tau4 );

  /* Third row */
  /* Coefficients i, j, k, l multiply a1dd, a2dd, a3dd, a4dd respectively. */
  l = I4 + SQ(l4cm)*m4 + l3*l4cm*m4*cosf(a4);

  k = I3 + I4 + SQ(l3cm)*m3 + SQ(l3)*m4 + SQ(l4cm)*m4 + 
    2*l3*l4cm*m4*cosf(a4);

  j = I3 + I4 + SQ(l3cm)*m3 + SQ(l3)*m4 + SQ(l4cm)*m4 + 
    l2*(l3cm*m3 + l3*m4)*cosf(a3) + 2*l3*l4cm*m4*cosf(a4) + 
    l2*l4cm*m4*cosf(a3 + a4);

  i = I3 + I4 + SQ(l3cm)*m3 + SQ(l3)*m4 + SQ(l4cm)*m4 + 
    l2*(l3cm*m3 + l3*m4)*cosf(a3) + l1*(l3cm*m3 + l3*m4)*cosf(a2 + a3) + 
    2*l3*l4cm*m4*cosf(a4) + l2*l4cm*m4*cosf(a3 + a4) + 
    l1*l4cm*m4*cosf(a2 + a3 + a4);

  /* Right-hand side for joint 3. */
  s = tau3 - VISCOUS_FRICTION*a3d
    -(SQ(a1d + a2d)*l2*(l3cm*m3 + l3*m4)*sinf(a3) + 
      SQ(a1d)*l1*(l3cm*m3 + l3*m4)*sinf(a2 + a3) + 
      G*l3cm*m3*sinf(a1 + a2 + a3) + G*l3*m4*sinf(a1 + a2 + a3) - 
      2*a1d*a4d*l3*l4cm*m4*sinf(a4) - 2*a2d*a4d*l3*l4cm*m4*sinf(a4) - 
      2*a3d*a4d*l3*l4cm*m4*sinf(a4) - SQ(a4d)*l3*l4cm*m4*sinf(a4) + 
      SQ(a1d)*l2*l4cm*m4*sinf(a3 + a4) + 
      2*a1d*a2d*l2*l4cm*m4*sinf(a3 + a4) + 
      SQ(a2d)*l2*l4cm*m4*sinf(a3 + a4) + 
      SQ(a1d)*l1*l4cm*m4*sinf(a2 + a3 + a4) + 
      G*l4cm*m4*sinf(a1 + a2 + a3 + a4)
);
  // printf( "%g %g\n", s, tau3 );

  /* Second row */
  /* Coefficients e, f, g, h multiply a1dd, a2dd, a3dd, a4dd respectively. */
  h = I4 + SQ(l4cm)*m4 + l3*l4cm*m4*cosf(a4) + l2*l4cm*m4*cosf(a3 + a4);

  g = I3 + I4 + SQ(l3cm)*m3 + SQ(l3)*m4 + SQ(l4cm)*m4 + 
    l2*(l3cm*m3 + l3*m4)*cosf(a3) + 2*l3*l4cm*m4*cosf(a4) + 
    l2*l4cm*m4*cosf(a3 + a4);

  f = I2 + I3 + I4 + SQ(l2cm)*m2 + SQ(l2)*m3 + SQ(l3cm)*m3 + 
    SQ(l2)*m4 + SQ(l3)*m4 + SQ(l4cm)*m4 + 
    2*l2*(l3cm*m3 + l3*m4)*cosf(a3) + 2*l3*l4cm*m4*cosf(a4) + 
    2*l2*l4cm*m4*cosf(a3 + a4);

  e = I2 + I3 + I4 + SQ(l2cm)*m2 + SQ(l2)*m3 + SQ(l3cm)*m3 + 
    SQ(l2)*m4 + SQ(l3)*m4 + SQ(l4cm)*m4 + 
    l1*(l2cm*m2 + l2*(m3 + m4))*cosf(a2) + 2*l2*(l3cm*m3 + l3*m4)*cosf(a3) + 
    l1*l3cm*m3*cosf(a2 + a3) + l1*l3*m4*cosf(a2 + a3) + 2*l3*l4cm*m4*cosf(a4) + 
    2*l2*l4cm*m4*cosf(a3 + a4) + l1*l4cm*m4*cosf(a2 + a3 + a4);

  /* Right-hand side for joint 2. */
  r = tau2 - VISCOUS_FRICTION*a2d
    - (
       SQ(a1d)*l1*(l2cm*m2 + l2*(m3 + m4))*sinf(a2) + 
       G*(l2cm*m2 + l2*(m3 + m4))*sinf(a1 + a2) - 2*a1d*a3d*l2*l3cm*m3*sinf(a3) - 
       2*a2d*a3d*l2*l3cm*m3*sinf(a3) - SQ(a3d)*l2*l3cm*m3*sinf(a3) - 
       2*a1d*a3d*l2*l3*m4*sinf(a3) - 2*a2d*a3d*l2*l3*m4*sinf(a3) - 
       SQ(a3d)*l2*l3*m4*sinf(a3) + SQ(a1d)*l1*l3cm*m3*sinf(a2 + a3) + 
       SQ(a1d)*l1*l3*m4*sinf(a2 + a3) + G*l3cm*m3*sinf(a1 + a2 + a3) + 
       G*l3*m4*sinf(a1 + a2 + a3) - 2*a1d*a4d*l3*l4cm*m4*sinf(a4) - 
       2*a2d*a4d*l3*l4cm*m4*sinf(a4) - 2*a3d*a4d*l3*l4cm*m4*sinf(a4) - 
       SQ(a4d)*l3*l4cm*m4*sinf(a4) - 2*a1d*a3d*l2*l4cm*m4*sinf(a3 + a4) - 
       2*a2d*a3d*l2*l4cm*m4*sinf(a3 + a4) - 
       SQ(a3d)*l2*l4cm*m4*sinf(a3 + a4) - 
       2*a1d*a4d*l2*l4cm*m4*sinf(a3 + a4) - 2*a2d*a4d*l2*l4cm*m4*sinf(a3 + a4) - 
       2*a3d*a4d*l2*l4cm*m4*sinf(a3 + a4) - 
       SQ(a4d)*l2*l4cm*m4*sinf(a3 + a4) + 
       SQ(a1d)*l1*l4cm*m4*sinf(a2 + a3 + a4) + 
       G*l4cm*m4*sinf(a1 + a2 + a3 + a4)
);

  /* First row */
  /* Coefficients a, b, c, d multiply a1dd, a2dd, a3dd, a4dd respectively. */
  d = I4 + SQ(l4cm)*m4 + l3*l4cm*m4*cosf(a4) + l2*l4cm*m4*cosf(a3 + a4) + 
    l1*l4cm*m4*cosf(a2 + a3 + a4);

  c = I3 + I4 + SQ(l3cm)*m3 + SQ(l3)*m4 + SQ(l4cm)*m4 + 
    l2*(l3cm*m3 + l3*m4)*cosf(a3) + l1*(l3cm*m3 + l3*m4)*cosf(a2 + a3) + 
    2*l3*l4cm*m4*cosf(a4) + l2*l4cm*m4*cosf(a3 + a4) + 
    l1*l4cm*m4*cosf(a2 + a3 + a4);

  b = I2 + I3 + I4 + SQ(l2cm)*m2 + SQ(l2)*m3 + SQ(l3cm)*m3 + 
    SQ(l2)*m4 + SQ(l3)*m4 + SQ(l4cm)*m4 + 
    l1*(l2cm*m2 + l2*(m3 + m4))*cosf(a2) + 2*l2*(l3cm*m3 + l3*m4)*cosf(a3) + 
    l1*l3cm*m3*cosf(a2 + a3) + l1*l3*m4*cosf(a2 + a3) + 2*l3*l4cm*m4*cosf(a4) + 
    2*l2*l4cm*m4*cosf(a3 + a4) + l1*l4cm*m4*cosf(a2 + a3 + a4);

  a = I1 + I2 + I3 + I4 + SQ(l1cm)*m1 + SQ(l1)*m2 + SQ(l2cm)*m2 + 
    SQ(l1)*m3 + SQ(l2)*m3 + SQ(l3cm)*m3 + SQ(l1)*m4 + 
    SQ(l2)*m4 + SQ(l3)*m4 + SQ(l4cm)*m4 + 
    2*l1*(l2cm*m2 + l2*(m3 + m4))*cosf(a2) + 2*l2*(l3cm*m3 + l3*m4)*cosf(a3) + 
    2*l1*l3cm*m3*cosf(a2 + a3) + 2*l1*l3*m4*cosf(a2 + a3) + 
    2*l3*l4cm*m4*cosf(a4) + 2*l2*l4cm*m4*cosf(a3 + a4) + 
    2*l1*l4cm*m4*cosf(a2 + a3 + a4);

  /* Right-hand side for joint 1. */
  q = tau1 - VISCOUS_FRICTION*a1d
    - (
    G*(l1cm*m1 + l1*(m2 + m3 + m4))*sinf(a1) -
    a2d*(2*a1d + a2d)*l1*(l2cm*m2 + l2*(m3 + m4))*sinf(a2) +
    G*l2cm*m2*sinf(a1 + a2) + G*l2*m3*sinf(a1 + a2) + G*l2*m4*sinf(a1 + a2) -
    2*a1d*a3d*l2*l3cm*m3*sinf(a3) - 2*a2d*a3d*l2*l3cm*m3*sinf(a3) -
    SQ(a3d)*l2*l3cm*m3*sinf(a3) - 2*a1d*a3d*l2*l3*m4*sinf(a3) -
    2*a2d*a3d*l2*l3*m4*sinf(a3) - SQ(a3d)*l2*l3*m4*sinf(a3) -
    2*a1d*a2d*l1*l3cm*m3*sinf(a2 + a3) -
    SQ(a2d)*l1*l3cm*m3*sinf(a2 + a3) -
    2*a1d*a3d*l1*l3cm*m3*sinf(a2 + a3) - 2*a2d*a3d*l1*l3cm*m3*sinf(a2 + a3) -
    SQ(a3d)*l1*l3cm*m3*sinf(a2 + a3) - 2*a1d*a2d*l1*l3*m4*sinf(a2 + a3) -
    SQ(a2d)*l1*l3*m4*sinf(a2 + a3) - 2*a1d*a3d*l1*l3*m4*sinf(a2 + a3) -
    2*a2d*a3d*l1*l3*m4*sinf(a2 + a3) - SQ(a3d)*l1*l3*m4*sinf(a2 + a3) +
    G*l3cm*m3*sinf(a1 + a2 + a3) + G*l3*m4*sinf(a1 + a2 + a3) -
    2*a1d*a4d*l3*l4cm*m4*sinf(a4) - 2*a2d*a4d*l3*l4cm*m4*sinf(a4) -
    2*a3d*a4d*l3*l4cm*m4*sinf(a4) - SQ(a4d)*l3*l4cm*m4*sinf(a4) -
    2*a1d*a3d*l2*l4cm*m4*sinf(a3 + a4) - 2*a2d*a3d*l2*l4cm*m4*sinf(a3 + a4) -
    SQ(a3d)*l2*l4cm*m4*sinf(a3 + a4) -
    2*a1d*a4d*l2*l4cm*m4*sinf(a3 + a4) - 2*a2d*a4d*l2*l4cm*m4*sinf(a3 + a4) -
    2*a3d*a4d*l2*l4cm*m4*sinf(a3 + a4) -
    SQ(a4d)*l2*l4cm*m4*sinf(a3 + a4) -
    2*a1d*a2d*l1*l4cm*m4*sinf(a2 + a3 + a4) -
    SQ(a2d)*l1*l4cm*m4*sinf(a2 + a3 + a4) -
    2*a1d*a3d*l1*l4cm*m4*sinf(a2 + a3 + a4) -
    2*a2d*a3d*l1*l4cm*m4*sinf(a2 + a3 + a4) -
    SQ(a3d)*l1*l4cm*m4*sinf(a2 + a3 + a4) -
    2*a1d*a4d*l1*l4cm*m4*sinf(a2 + a3 + a4) -
    2*a2d*a4d*l1*l4cm*m4*sinf(a2 + a3 + a4) -
    2*a3d*a4d*l1*l4cm*m4*sinf(a2 + a3 + a4) -
    SQ(a4d)*l1*l4cm*m4*sinf(a2 + a3 + a4) +
    G*l4cm*m4*sinf(a1 + a2 + a3 + a4)
);

  /* Debug dump of the assembled system, kept for reference:
  printf( "abcd: %g %g %g %g\n", a, b, c, d );
  printf( "efgh: %g %g %g %g\n", e, f, g, h );
  printf( "ijkl: %g %g %g %g\n", i, j, k, l );
  printf( "mnop: %g %g %g %g\n", m, n, o, p );
  printf( "qrst: %g %g %g %g\n", q, r, s, t );
  */

  /* Determinant of the 4x4 coefficient matrix, fully expanded. */
  determinant =
    (d*g*j*m - c*h*j*m - d*f*k*m + b*h*k*m + c*f*l*m - b*g*l*m - d*g*i*n +
     c*h*i*n + d*e*k*n - a*h*k*n - c*e*l*n + a*g*l*n + d*f*i*o - b*h*i*o -
     d*e*j*o + a*h*j*o + b*e*l*o - a*f*l*o - c*f*i*p + b*g*i*p + c*e*j*p -
     a*g*j*p - b*e*k*p + a*f*k*p);
  /* Un-normalised solution: each acceleration is a combination of q, r, s, t
     weighted by products of three matrix coefficients. */
  *a1dd = q*(-(h*k*n) + g*l*n + h*j*o - f*l*o - g*j*p + f*k*p)
    + r*(d*k*n - c*l*n - d*j*o + b*l*o + c*j*p - b*k*p)
    + s*(-(d*g*n) + c*h*n + d*f*o - b*h*o - c*f*p + b*g*p)
    + t*(d*g*j - c*h*j - d*f*k + b*h*k + c*f*l - b*g*l);
  *a2dd = q*(h*k*m - g*l*m - h*i*o + e*l*o + g*i*p - e*k*p)
    + r*(-(d*k*m) + c*l*m + d*i*o - a*l*o - c*i*p + a*k*p)
    + s*(d*g*m - c*h*m - d*e*o + a*h*o + c*e*p - a*g*p)
    + t*(-(d*g*i) + c*h*i + d*e*k - a*h*k - c*e*l + a*g*l);
  *a3dd = q*(-(h*j*m) + f*l*m + h*i*n - e*l*n - f*i*p + e*j*p)
    + r*(d*j*m - b*l*m - d*i*n + a*l*n + b*i*p - a*j*p)
    + s*(-(d*f*m) + b*h*m + d*e*n - a*h*n - b*e*p + a*f*p)
    + t*(d*f*i - b*h*i - d*e*j + a*h*j + b*e*l - a*f*l);
  *a4dd = q*(g*j*m - f*k*m - g*i*n + e*k*n + f*i*o - e*j*o)
    + r*(-(c*j*m) + b*k*m + c*i*n - a*k*n - b*i*o + a*j*o)
    + s*(c*f*m - b*g*m - c*e*n + a*g*n + b*e*o - a*f*o)
    + t*(-(c*f*i) + b*g*i + c*e*j - a*g*j - b*e*k + a*f*k);
  /* Normalise by the determinant (no check for determinant == 0). */
  *a1dd = *a1dd/determinant;
  *a2dd = *a2dd/determinant;
  *a3dd = *a3dd/determinant;
  *a4dd = *a4dd/determinant;
}

/**********************************************************************/

/* Advance the pendulum by one TIMESTEP under the torques in `action`.

   d          - dynamics object; d->state and d->time are updated in place.
   action     - the four joint torques, indexed by A_T1..A_T4.
   next_state - optional output; when non-NULL it receives a copy of the
                updated state (N_STATE_DIMENSIONS entries).

   The accelerations are evaluated once, at the current state.  Each
   velocity is then advanced by acceleration*TIMESTEP, and each angle by the
   average of its old and new velocity times TIMESTEP.

   Always returns 1, like set_state() and init_dynamics(), so that callers
   reading the int result get a defined value. */
int integrate_one_step( Dynamics *d, float *action, float *next_state )
{
  int i;
  float a1 = d->state[S_P1];
  float a2 = d->state[S_P2];
  float a3 = d->state[S_P3];
  float a4 = d->state[S_P4];
  float a1d = d->state[S_V1];
  float a2d = d->state[S_V2];
  float a3d = d->state[S_V3];
  float a4d = d->state[S_V4];
  float torque1 = action[A_T1];
  float torque2 = action[A_T2];
  float torque3 = action[A_T3];
  float torque4 = action[A_T4];
  float a1dd = 0;
  float a2dd = 0;
  float a3dd = 0;
  float a4dd = 0;
  float new_a1d, new_a2d, new_a3d, new_a4d;

  dynamics( a1, a2, a3, a4, a1d, a2d, a3d, a4d,
	    torque1, torque2, torque3, torque4,
	    &a1dd, &a2dd, &a3dd, &a4dd );

  /* Per joint: new velocity first, then the angle from the mean of the old
     and new velocity. */
  new_a1d = a1d + a1dd*TIMESTEP;
  a1 += (new_a1d + a1d)*TIMESTEP/2;
  a1d = new_a1d;
  new_a2d = a2d + a2dd*TIMESTEP;
  a2 += (new_a2d + a2d)*TIMESTEP/2;
  a2d = new_a2d;
  new_a3d = a3d + a3dd*TIMESTEP;
  a3 += (new_a3d + a3d)*TIMESTEP/2;
  a3d = new_a3d;
  new_a4d = a4d + a4dd*TIMESTEP;
  a4 += (new_a4d + a4d)*TIMESTEP/2;
  a4d = new_a4d;

  /* Commit the step to the dynamics object. */
  d->time += TIMESTEP;
  d->state[S_P1] = a1;
  d->state[S_P2] = a2;
  d->state[S_P3] = a3;
  d->state[S_P4] = a4;
  d->state[S_V1] = a1d;
  d->state[S_V2] = a2d;
  d->state[S_V3] = a3d;
  d->state[S_V4] = a4d;
  if ( next_state != NULL )
    {
      for( i = 0; i < N_STATE_DIMENSIONS; i++ )
	next_state[i] = d->state[i];
    }

  return 1;
}

/**********************************************************************/

/* Cost incurred over one TIMESTEP for being in `state` and applying `action`.

   d      - dynamics object; only d->desired_state is read.
   state  - state vector to score; only the four joint angles
            (S_P1..S_P4) are used, velocities are not penalised.
   action - the four joint torques, indexed by A_T1..A_T4.

   Returns TIMESTEP * (sum of squared torques)
         + STATE_PENALTY * TIMESTEP * (sum of squared angle errors),
   where each angle error is state minus desired state. */
float one_step_cost( Dynamics *d, float *state, float *action )
{
  float s1, s2, s3, s4, a1, a2, a3, a4;

  /* Angle errors relative to the goal, one per joint. */
  s1 = state[S_P1] - d->desired_state[S_P1];
  s2 = state[S_P2] - d->desired_state[S_P2];
  s3 = state[S_P3] - d->desired_state[S_P3];
  s4 = state[S_P4] - d->desired_state[S_P4];
  a1 = action[A_T1];
  a2 = action[A_T2];
  a3 = action[A_T3];
  a4 = action[A_T4];
  return TIMESTEP*(a1*a1 + a2*a2 + a3*a3 + a4*a4)
    + STATE_PENALTY*TIMESTEP*(s1*s1 + s2*s2 + s3*s3 + s4*s4);
}

/**********************************************************************/


