/************************************************************************/
/************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "nr.h"
#include "nrutil.h"

#include "opt-parameters.h"
#include "main.h"

/*
#define GRAPHICS
*/

/*****************************************************************************/
/*****************************************************************************/

static OPT_PARAM *params;                 /* parameter descriptions read from argv[1] in main() */
extern SIM sim;                           /* the simulation being optimized (defined elsewhere) */
static float all_time_low_cost = 1e20;    /* lowest cost reported by call_one_func/call_many_func */
static int func_calls = 0;                /* objective-evaluation counter shown in debug output */
static int n_parameters = 0;              /* number of optimized parameters (set by process_parameters) */
static char output_file[10000];           /* "<parameter-file>.new"; rewritten on every new all-time low */
static int debug_opt = 1;                 /* nonzero: print cost and point for every evaluation */

/*****************************************************************************/
/*****************************************************************************/
/*****************************************************************************/
/* State for the Numerical Recipes amebsa() simulated-annealing simplex. */

static float **p; /* simplex vertices: matrix(1..n_parameters+1, 1..n_parameters) */
static float p_offsets[MAX_N_PARAMETERS+1]; /* 0-indexed initial step per parameter used to build the simplex */
static float y[MAX_N_PARAMETERS+2]; /* cost at each vertex of p, 1-indexed y[1..n_parameters+1] */
static float pb[MAX_N_PARAMETERS+1]; /* best point seen so far, 1-indexed pb[1..n_parameters] */
static float yb = 1e30; /* best cost seen so far; starts as a large sentinel so the first vertex wins */
static float xoff[MAX_N_PARAMETERS+1]; /* NOTE(review): not referenced in this part of the file */
static float init_parameters[MAX_N_PARAMETERS+1]; /* 0-indexed starting point from parameters_to_vector() */

static float temperature = 100000.0; /* initial annealing temperature; multiplied by decay each stage */
static int iter1 = 100; /* iteration budget handed to each amebsa() call (one temperature stage) */
/* static int iter2 = 50; */ /* earlier setting, kept for reference */
static int iter2 = 20; /* maximum number of temperature stages */
static int n_iterations = 0; /* total amebsa() iterations consumed so far (sum of iter1-iter) */
static int iter = 0; /* budget passed to amebsa() by pointer; after the call it holds the unused remainder, and main() stops annealing when it is > 0 */
static int n_temperatures = 0; /* number of temperature stages completed */
static float decay = 0.8; /* 0.9 */ /* geometric cooling: T_new = decay*T (0.9 was an earlier setting) */
static float cost_reject = 10000000.0; /* initial vertex at/above this cost -> main() retries the step in the opposite direction */
static float reject = 10000000.0; /* same test as cost_reject; presumably the cost assigned to a crashed simulation — confirm */
#define FTOL 1.0e-6 /* fractional convergence tolerance passed to amebsa() */

long idum=(-61); /* NR random-number state/seed. Left non-static on purpose: presumably referenced via extern by NR amebsa()/ran1() — verify before changing linkage */

/*****************************************************************************/
/*****************************************************************************/
/*****************************************************************************/

/* Run one simulation trial with the parameter vector x and return its score.
 *
 * x is a 1-indexed (Numerical Recipes style) vector: only &x[1] is handed
 * to vector_to_sim(), so x[0] is never read and x[1]..x[n_parameters] hold
 * the parameters.  After vector_to_sim() and reinit_sim(), the
 * controller / integrator loop steps the global `sim` until sim.time
 * reaches sim.trial_duration.  The result of get_score() is then returned.
 *
 * Side effects: mutates the global `sim` and, presumably, `params` via
 * vector_to_sim() (confirm).  When GRAPHICS is defined, it also redraws
 * the display. */
float func_core( float *x )
{
  extern float get_score( SIM *s );
  float cost;

  vector_to_sim( &(x[1]), n_parameters, params );
  reinit_sim( &sim );
#ifdef GRAPHICS
  reinit_display();
#endif

  /* Step the closed-loop simulation until the trial duration is reached. */
  while ( sim.time < sim.trial_duration )
    {
      controller( &sim );
      integrate_one_time_step( &sim );
#ifdef GRAPHICS
      redisplay_stuff();
#endif
    }

  cost = get_score( &sim );
  return cost;
}

/***************************************************************************/
/***************************************************************************/
/***************************************************************************/

/* Objective function: evaluate x with a single simulation trial.
 *
 * x is 1-indexed (x[1]..x[n_parameters]).  The trial runs with
 * sim.rand_seed = 1.
 *
 * Side effects:
 *   - increments func_calls;
 *   - when the cost beats all_time_low_cost, prints the point as
 *     init_parameters[] assignments, calls write_param_file() on
 *     output_file and records the new low;
 *   - prints a per-call debug line when debug_opt is set.
 *
 * Returns the trial cost. */
float call_one_func( float x[] )
{
  float cost;
  int i;

  /* Previously never incremented, so the debug output always showed 0. */
  func_calls++;

  sim.rand_seed = 1;
  cost = func_core( x );

  if ( cost < all_time_low_cost )
    {
      printf( "ALL TIME LOW COST: %g\n", cost );
      for ( i = 1; i <= n_parameters; i++ )
	printf( "init_parameters[%d] = %18.12f;\n", i-1, x[i] );
      write_param_file( output_file, params );
      all_time_low_cost = cost;
    }

  if ( debug_opt )
    {
      printf( "%d %d: cost %g\n", iter, func_calls, cost );
      for( i = 1; i <= n_parameters; i++ )
	printf( "%g ", x[i] );
      printf( "\n\n" );
    }

  return cost;
}

/***************************************************************************/

/* Objective function passed to amebsa(): evaluate x over
 * sim.how_many_trials simulation trials and return the summed cost.
 *
 * x is 1-indexed (x[1]..x[n_parameters]).  Trial k (0-based) runs with
 * sim.rand_seed = k + 1, so every call uses the same seed sequence.
 * Assuming the simulation is deterministic for a given seed, costs of
 * different points are therefore directly comparable.
 *
 * Side effects:
 *   - increments func_calls;
 *   - when the summed cost beats all_time_low_cost, prints the point as
 *     init_parameters[] assignments, calls write_param_file() on
 *     output_file and records the new low;
 *   - prints a per-call debug line when debug_opt is set. */
float call_many_func( float x[] )
{
  float cost = 0;
  float c;
  int i;

  /* Previously never incremented, so the debug output always showed 0. */
  func_calls++;

  printf( "call_many_func\n" );

  for( i = 0; i < sim.how_many_trials; i++ )
    {
      sim.rand_seed = i + 1;
      c = func_core( x );
      cost += c;
    }

  if ( cost < all_time_low_cost )
    {
      printf( "ALL TIME LOW COST: %g\n", cost );
      for ( i = 1; i <= n_parameters; i++ )
	printf( "init_parameters[%d] = %18.12f;\n", i-1, x[i] );
      write_param_file( output_file, params );
      all_time_low_cost = cost;
    }

  if ( debug_opt )
    {
      printf( "%d %d: cost %g\n", iter, func_calls, cost );
      for( i = 1; i <= n_parameters; i++ )
	printf( "%g ", x[i] );
      printf( "\n\n" );
    }

  return cost;
}

/************************************************************************/
/************************************************************************/
/* Main: initialize and then call simulated annealing optimizer */

/* Print the best point (pb, 1-indexed) and cost (yb) found so far, plus the
 * total iteration count.  The point is printed as init_parameters[]
 * assignments so it can be pasted back in as a new starting point. */
static void report_minimum( void )
{
  int i;

  printf( "Iterations: %d\n", n_iterations );
  printf( "Minimum found at: \n" );
  for ( i = 1; i <= n_parameters; i++ )
    printf( "  init_parameters[%d] = %18.12f;\n", i-1, pb[i] );
  printf( "\n\nMinimum function value = %12.6f\n", yb );
}

/* Entry point: read the parameter file named by argv[1], build an initial
 * simplex around its values and run NR amebsa() simulated annealing.
 *
 * The run uses up to iter2 temperature stages of iter1 iterations each,
 * with temperature *= decay between stages.  It stops early when amebsa()
 * converges, i.e. returns with budget left over (iter > 0).  Improved
 * parameter sets are written to "<argv[1]>.new" by the objective
 * function.  Exits with EXIT_FAILURE on bad usage or input. */
int main(int argc, char **argv)
{
  int i, j;
  int n_written;
  float last_y_best = 1e30;
  OPT_PARAM *read_opt_param_file();

  init_default_parameters( &sim );

  if ( argc < 2 )
    {
      fprintf( stderr, "Optimize using which parameter file?\n" );
      fprintf( stderr, "%s parameter-file\n", argv[0] );
      exit( EXIT_FAILURE );
    }
  params = read_opt_param_file( argv[1] );
  if ( params == NULL )
    {
      fprintf( stderr, "Could not read parameter file %s\n", argv[1] );
      exit( EXIT_FAILURE );
    }
  n_parameters = process_parameters( params, &sim, 1 );
  /* amebsa() needs at least one dimension: with ndim == 0 it would read
     the never-evaluated vertex y[2]. */
  if ( n_parameters < 1 )
    {
      fprintf( stderr, "No parameters to optimize in %s\n", argv[1] );
      exit( EXIT_FAILURE );
    }
  if ( n_parameters > MAX_N_PARAMETERS )
    {
      fprintf( stderr, "Too many parameters %d > %d\n",
	       n_parameters, MAX_N_PARAMETERS );
      exit( EXIT_FAILURE );
    }
  /* Bounded write: argv[1] is arbitrary-length user input. */
  n_written = snprintf( output_file, sizeof output_file, "%s.new", argv[1] );
  if ( n_written < 0 || (size_t) n_written >= sizeof output_file )
    {
      fprintf( stderr, "Parameter file name too long: %s\n", argv[1] );
      exit( EXIT_FAILURE );
    }

  init_sim( &sim );
#ifdef GRAPHICS
  init_my_graphics();
#endif

  /* Initial step per parameter: 1% of its starting value, or 0.01 when the
     starting value is zero (a zero step would collapse the simplex). */
  parameters_to_vector( params, init_parameters );
  for( i = 0; i < n_parameters; i++ )
    {
      p_offsets[i] = 0.01*init_parameters[i];
      if ( p_offsets[i] == 0.0 )
	p_offsets[i] = 0.01;
    }

  /* Build the n+1 vertex simplex: vertex 1 is the starting point, and
     vertex i (i >= 2) perturbs parameter i-1.  A perturbation that is
     rejected is retried in the opposite direction. */
  debug_opt = 1;
  p = matrix( 1, n_parameters+1, 1, n_parameters );
  for( i = 1; i <= n_parameters+1; i++ )
    {
      for ( j = 1; j <= n_parameters; j++ )
	p[i][j] = init_parameters[j-1];
      if ( i >= 2 )
	p[i][i-1] += p_offsets[i-2];
      y[i] = call_many_func( p[i] );
      if ( i >= 2 && ( y[i] >= cost_reject || y[i] >= reject ) )
	{
	  printf( "Trying other direction: %d.\n", i );
	  p[i][i-1] -= 2*p_offsets[i-2];
	  y[i] = call_many_func( p[i] );
	}
      if ( y[i] < yb )
	{
	  yb = y[i];
	  for( j = 1; j <= n_parameters; j++ )
	    pb[j] = p[i][j];
	}
    }
  last_y_best = yb;

  /* Anneal: one amebsa() call per temperature stage, cooling geometrically. */
  for( n_temperatures = 0; n_temperatures < iter2; n_temperatures++ )
    {
      iter = iter1;
      amebsa( p, y, n_parameters, pb, &yb, FTOL, call_many_func, &iter,
	      temperature );
      n_iterations += iter1 - iter;
      if ( yb < last_y_best )
	{
	  last_y_best = yb;
	  printf( "new best: %6d %g: ", n_iterations, temperature );
	  printf( "%g\n", yb );
	  for ( j = 1; j <= n_parameters; j++ )
	    printf( "%9.5f ", pb[j] );
	  printf( "\n" );
	}
      /* Progress dump after every stage so an interrupted run still
	 reports its best point. */
      report_minimum();
      /* Budget left over means amebsa() met FTOL: stop annealing. */
      if ( iter > 0 )
	break;
      temperature *= decay;
    }

  report_minimum();

  free_matrix( p, 1, n_parameters+1, 1, n_parameters );
  p = NULL;
  return 0;
}

/************************************************************************/
