/* backprop.c
 * CMUnited-97 (soccer client for Robocup-97)
 * Peter Stone <pstone@cs.cmu.edu>
 * Computer Science Department
 * Carnegie Mellon University
 * Copyright (C) 1997 Peter Stone
 *
 * CMUnited-97 was created by Peter Stone and Manuela Veloso
 *
 * You may copy and distribute this program freely as long as you retain this notice.
 * If you make any changes or have any comments we would appreciate a message.
 */

/************************************
   Code for implementing the backpropagation algorithm for training
   a fully-connected MLP neural network with 1 layer of hidden units.
   Loosely translated from backprop.lisp by David Touretzky.

   Compile with the command:  cc -o backprop backprop.c -lm
   
   Justin Boyan, Oct 5 1993
   Modified by Peter Stone, 1996
*************************************/

#include "backprop.h"

#ifndef FOR_USE   
#include "main.h"   /* All the rest of the backprop stuff: not needed for use */
extern  int NN_n_hid;
extern  float ETA1;
extern  float ETA2;
#else
#include "global.h" /* We're in the clienttrial directory                     */
#endif



/*** Set up the network ***/
/* Training set: n_pat patterns; train_in[p] and train_out[p] are the input
   and target vectors of pattern p. */
static int   n_pat;
static float **train_in;
static float **train_out;
/*********/

static float eta1, eta2;/* learning rates used in w1 and w2 weights */
static float alpha=0.9;           /* momentum coefficient */
static float beta=0.999;           /* weight decay: weights are multiplied by it after every update */
static float randmax=0.5;         /* random weights are initialized in [-R,R) */

static long  epoch_counter=0;	  /* keeps track of # of training epochs */
static float epoch_error;	  /* total sum^2 error over the current epoch */
static int   epoch_error_num;    /* number of output values whose |error| exceeded .4 this epoch */

static float inp[NN_n_inp+1];     /* input unit activations -- +1 for bias */
static float *hid0,*hid;        /* hidden units before & after activ. fn. (hid has one extra bias slot) */
static float out0[NN_n_out],out[NN_n_out];  /* output unit before & after activ. fn. */

static float target[NN_n_out];          /* target output values */
static float dout[NN_n_out],*dhid;/* delta vals used in backprop computation*/
static float **w1;                 /* input->hidden weight matrix, indexed [input][hidden] */
static float **w2;     /* hidden->output weight matrix, indexed [hidden][output] */
static float **dw1;    /* accumulates weight changes to w1 per-epoch */
static float **dw2;    /* accumulates weight changes to w2 per-epoch */
static float **prev_dw1;/* previous epoch's change to w1 */
static float **prev_dw2;/* previous epoch's change to w2 */

/* Kinds of activation function understood by actfn() and actfnprime():
   sigm = logistic sigmoid, line = identity, gauss = exp(-x^2).
   h_actfn is applied to the hidden layer, o_actfn to the output layer. */
enum actfntype {sigm,line,gauss}
/* Alternative setting, currently disabled: h_actfn = sigm, o_actfn = sigm */
  h_actfn = sigm, o_actfn = line;  /* Fahlman suggests line for cont output*/

/*** Prototypes ***/

/* Square of x.  The argument is evaluated twice, so it must be free of
   side effects. */
#define sqr(x) ((x)*(x))
/* Uniform pseudo-random double in [lo,hi).  rand() is scaled by RAND_MAX+1.0
   rather than by a hard-coded 2^31-1: the result then stays inside the
   interval whatever RAND_MAX is on the platform, and hi itself is never
   returned. */
#define rnd(lo,hi) ((lo) + ((hi)-(lo))*(rand()/(RAND_MAX+1.0)))
#ifndef FOR_USE
void  test(char *source);
void  train(char *source, int max_epochs);
void  initialize();
void  train_one_epoch();
void  backward_pass();
void  update_weights();
void  write_weights(char *source);
void  initialize_from_file(char *source);
#endif
void  forward_pass();
void  load_weights(char *source);
/***/

/* Allocate a dim1 x dim2 matrix of floats as an array of dim1 row pointers,
   each row being a separately allocated array of dim2 floats.  The element
   values are left uninitialized.  Every row, and the returned row-pointer
   array itself, was obtained with new[] and must be released with delete[]. */
float **allocate_2dim(int dim1, int dim2){
  /* The element type must not be parenthesized here: "new (float*)[dim1]"
     is not an array allocation in standard C++ and is rejected by current
     compilers. */
  float **new_array = new float*[dim1];

  for (int i=0; i<dim1; i++){
    new_array[i] = new float[dim2];
  }
  return new_array;
}

/* Allocate every buffer needed for training: the hidden-layer activation and
   delta vectors, the weight matrices with their per-epoch and previous-epoch
   change matrices, and room for `patterns` training input/target vectors.
   Matrices touching a bias unit get one extra row (NN_n_inp+1, NN_n_hid+1). */
void allocate_arrays(int patterns){
  int i,h,o;

  hid0 = new float[NN_n_hid];
  hid  = new float[NN_n_hid+1];
  dhid = new float[NN_n_hid];

  w1 =       allocate_2dim(NN_n_inp+1,NN_n_hid);
  dw1 =      allocate_2dim(NN_n_inp+1,NN_n_hid);
  prev_dw1 = allocate_2dim(NN_n_inp+1,NN_n_hid);

  w2 =       allocate_2dim(NN_n_hid+1,NN_n_out);
  dw2 =      allocate_2dim(NN_n_hid+1,NN_n_out);
  prev_dw2 = allocate_2dim(NN_n_hid+1,NN_n_out);

  /* update_weights() reads the previous-epoch changes in the very first
     epoch, before anything has been stored in them, and allocate_2dim()
     does not clear its memory: start the momentum terms at zero. */
  for (i=0; i<NN_n_inp+1; i++) {
    for (h=0; h<NN_n_hid; h++) {
      prev_dw1[i][h] = 0.0;
    }
  }
  for (h=0; h<NN_n_hid+1; h++) {
    for (o=0; o<NN_n_out; o++) {
      prev_dw2[h][o] = 0.0;
    }
  }

  train_in  = allocate_2dim(patterns,NN_n_inp);
  train_out = allocate_2dim(patterns,NN_n_out);
}

/* Prepare the network for forward-only use: allocate the hidden-layer
   buffers and both weight matrices, clamp the bias units to 1 and read the
   weights from the file named by source. */
void NN_initialize_to_use(char *source)
{
  /* hidden activations before/after the activation function; hid has one
     extra slot for the bias unit */
  hid0 = new float[NN_n_hid];
  hid  = new float[NN_n_hid+1];

  /* weight matrices, each with one extra row for the bias unit */
  w1 = allocate_2dim(NN_n_inp+1,NN_n_hid);
  w2 = allocate_2dim(NN_n_hid+1,NN_n_out);

  /* the last slot of each layer is its bias unit */
  hid[NN_n_hid] = 1.0;
  inp[NN_n_inp] = 1.0;

  load_weights(source);
}

/* Run the network on one sample, in place.  On entry array[0..NN_n_inp-1]
   holds the raw inputs; on return array[0..NN_n_out-1] holds the outputs.
   Inputs are shifted by INP_BASES and divided by INP_RANGES before the
   forward pass; outputs are multiplied by OUT_RANGES and shifted by
   OUT_BASES afterwards.
   Assumes that the array has room for (max n_inp,n_out). */
void NN_use(float *array)
{
  int k;

  /* scale the caller's values into the input units */
  k = 0;
  while (k < NN_n_inp) {
    inp[k] = (array[k]-INP_BASES[k])/INP_RANGES[k];
    k++;
  }

  forward_pass();

  /* scale the output units back and hand them to the caller */
  k = 0;
  while (k < NN_n_out) {
    array[k] = (out[k]*OUT_RANGES[k]) + OUT_BASES[k];
    k++;
  }
}

#ifndef FOR_USE   /** Only need forward_pass and load_weights if just the net **/

/* Build a time-stamped weight file name of the form
   "weights-<month><day>-<hour>:<min>.dat" in name.

   The time stamp is taken from the first line of the file "date.log", which
   is expected to look like the output of date(1), e.g.
   "Mon Oct  6 14:05:09 EDT 1997"; this function does not create that file.
   If date.log cannot be opened or its first line cannot be parsed, a
   warning is printed on stderr and name is set to "weights.dat".

   name must point to a buffer of at least 64 bytes: the month is limited to
   9 characters and the other fields are ints, so the result cannot be
   longer than that. */
void GetStampedName( char *name ){
  const char *outputName = "weights";
  char date[100] = "";
  char month[10] = "";
  int  day = 0, hour = 0, min = 0;
  FILE *dateFile;

  dateFile = fopen("date.log","r");
  if ( dateFile ) {
    /* read at most one line, never more than the buffer can hold */
    if ( fscanf(dateFile,"%99[^\n]",date) != 1 )
      date[0] = '\0';
    fclose(dateFile);
  }

  /* weekday is skipped; seconds, time zone and year are not needed */
  if ( sscanf(date,"%*s %9s %d %d:%d",month,&day,&hour,&min) == 4 ) {
    sprintf(name,"%s-%s%d-%d:%d.dat",outputName,month,day,hour,min);
  }
  else {
    fprintf(stderr,"GetStampedName: no usable date in date.log, using %s.dat\n",
            outputName);
    sprintf(name,"%s.dat",outputName);
  }
}


void test (char *source)
{
  int p,i,o;
  float error;
  char WeightFileName[200];

  if ( !strcmp(TEST_WEIGHTS,"none") )
    GetStampedName(WeightFileName);
  else
    strcpy(WeightFileName,TEST_WEIGHTS);

  w1 = allocate_2dim(NN_n_inp+1,NN_n_hid);
  w2 = allocate_2dim(NN_n_hid+1,NN_n_out);

  if ( LengthSourceFile(source, &n_pat) ) my_error("LengthSourceFile");
  train_in  = allocate_2dim(n_pat, NN_n_inp);
  train_out = allocate_2dim(n_pat, NN_n_out);

  if ( InputSourceFile(source, train_in, train_out) )
    my_error("InputSourceFile");
  NN_initialize_to_use(WeightFileName);

  error = 0;
  epoch_error = 0;
  epoch_error_num = 0;

  for (p=0; p<n_pat; p++) {
    for (i=0;i<NN_n_inp;i++) inp[i] = train_in[p][i]; /* set input vector */
    forward_pass();
    for (o=0;o<NN_n_out;o++) {
      target[o] = train_out[p][o]; /* set target vector */
      error = target[o]-out[o];
      epoch_error += sqr(error);
      if ( fabs(target[o]-out[o]) > .4 )
	epoch_error_num++;
    }
  }
  printf("total error = %7.4f, avg. error = %7.4f, %d errors out of %d.\n",
	     epoch_error, 
	     epoch_error/(n_pat*NN_n_out), epoch_error_num, n_pat);

  for (i=0; i<n_pat; i++){
    delete(train_in[i]);
    delete(train_out[i]);
  }
/*  delete(train_in);
    delete(train_out);*/
}

/* Train the network on the patterns in the file named by source.

   Weights start from INITIALIZE_FILE, or from random values when that is
   "none".  Training runs one epoch at a time and stops when max_epochs
   epochs have been run, when the epoch's total squared error drops below
   0.01, or when, at a display epoch (every DISPLAY_FREQ epochs), the error
   differs by less than .001 from the one seen at the previous display
   epoch.  Weights are written to a time-stamped file at display epochs that
   are also multiples of SAVE_WGTS_FREQ, and a summary is appended to
   "compilation.dat" at the end.

   argc and argv are accepted but not used. */
void train (char *source, int max_epochs, int argc, char **argv)
{
  int i;
  char WeightFileName[200];

  (void)argc;
  (void)argv;

  GetStampedName( WeightFileName );

  if ( LengthSourceFile(source, &n_pat) ) my_error("LengthSourceFile");
  allocate_arrays(n_pat);
  eta1 = ETA1;
  eta2 = ETA2;

  if ( InputSourceFile(source, train_in, train_out) )
    my_error("InputSourceFile");

  if ( strcmp(INITIALIZE_FILE,"none") )
    initialize_from_file(INITIALIZE_FILE);
  else
    initialize();

  /* error at the previous display epoch; static, so it survives calls */
  static float last_epoch_error = 0;
  do {
    train_one_epoch();
    epoch_counter ++;

    if (epoch_counter % DISPLAY_FREQ == 0) {
      /* epoch_counter is a long: it needs %ld, not %d */
      printf("Epoch %ld:  tot err = %7.4f, avg err = %7.4f, num wrong = %d\n",
             epoch_counter, epoch_error,
             epoch_error/(n_pat*NN_n_out), epoch_error_num);
      if (epoch_counter % SAVE_WGTS_FREQ == 0) {
        write_weights(WeightFileName);
      }
      if ( fabs(epoch_error - last_epoch_error) < .001 ) break;
      last_epoch_error = epoch_error;
    }
  } while (epoch_counter < max_epochs && epoch_error >= 0.01);


  printf("\nBackprop quit after epoch %ld with error %7.4f (%d wrong)\n",
         epoch_counter, epoch_error, epoch_error_num);

  /* Losing the summary line is not worth aborting a finished run for, so a
     log file that cannot be opened is reported and skipped. */
  FILE *result_compilation = fopen ("compilation.dat","a");
  if ( result_compilation ) {
    fprintf(result_compilation,"--%s: %d hiddens, eta1 = %f, eta2 = %f\n",
            WeightFileName, NN_n_hid, eta1, eta2);
    fprintf(result_compilation,"Epoch %ld:  tot err = %7.4f, avg err = %7.4f, num wrong = %d\n\n",
            epoch_counter, epoch_error,
            epoch_error/(n_pat*NN_n_out), epoch_error_num);
    fclose(result_compilation);
  }
  else {
    fprintf(stderr,"train: couldn't open compilation.dat\n");
  }

  /* the pattern rows were allocated with new[] */
  for (i=0; i<n_pat; i++){
    delete[] train_in[i];
    delete[] train_out[i];
  }
  delete[] train_in;
  delete[] train_out;
  train_in  = 0;
  train_out = 0;
}


/* Start training from scratch: print the network configuration, clamp the
   bias units to 1, draw every weight at random with rnd(-randmax,randmax)
   and clear the accumulated weight changes. */
void initialize()
{
  int i,h,o;
  printf("Initializing %d->%d->%d network:\n", NN_n_inp,NN_n_hid,NN_n_out);
  printf("\teta = (%f,%f), alpha = %f, beta = %f, randmax = %f\n",
         eta1,eta2,alpha,beta,randmax);
  printf("\t%d training patterns\n", n_pat);

  /* initialize bias units */
  inp[NN_n_inp] = 1.0;
  hid[NN_n_hid] = 1.0;

  /* initialize input->hidden weights */
  for (i=0; i<NN_n_inp+1; i++) {
    for (h=0; h<NN_n_hid; h++) {
      w1[i][h] = rnd(-randmax,randmax);
      dw1[i][h] = 0.0;
    }
  }
  /* initialize hidden->output weights */
  for (h=0; h<NN_n_hid+1; h++) {
    for (o=0; o<NN_n_out; o++) {
      w2[h][o] = rnd(-randmax,randmax);
      /* index with the loop variable o; a literal 0 here cleared only the
         first column over and over */
      dw2[h][o] = 0.0;
    }
  }
}

void train_one_epoch ()
{
  int i,h,o,p;
  /* clear all weight deltas */
  for (i=0; i<NN_n_inp+1; i++) for (h=0; h<NN_n_hid; h++) dw1[i][h]=0.0;
  for (h=0; h<NN_n_hid+1; h++) for (o=0; o<NN_n_out; o++) dw2[h][o]=0.0;

  epoch_error = 0.0;
  epoch_error_num = 0;
  for (p=0; p<n_pat; p++) {
    for (i=0;i<NN_n_inp;i++) inp[i] = train_in[p][i]; /* set input vector */
    forward_pass();
    for (o=0;o<NN_n_out;o++) target[o] = train_out[p][o]; /* set target vector */
    backward_pass();
  }
  update_weights ();
/*my_error("1epoch");*/
}

#endif  /* Need actfn for forward_pass which is needed for use*/

/* Applies an activation function of type g to a value:
     line  -> value itself
     sigm  -> 1 / (1 + exp(-value))
     gauss -> exp(-value^2)
   Any other g is treated like line, so that the function returns a defined
   result on every path instead of running off its end. */
float actfn(enum actfntype g, float value)
{
  switch (g) {
  case line:  return value;
  case sigm:  return 1.0 / (1.0+exp(-value));
  case gauss: return exp(-sqr(value));
  }
  /* only reached when g is not one of the enumerators */
  return value;
}

/* Computes the derivative of an activation function g at a value, possibly
   using g_value = g(value), which the caller has already computed:
     line  -> 1
     sigm  -> g_value*(1-g_value) + .1
     gauss -> -2*value*g_value
   Any other g is treated like line, so that the function returns a defined
   result on every path instead of running off its end. */
float actfnprime(enum actfntype g,float value,float g_value)
{
  switch (g) {
  case line:  return 1.0;
  case sigm:  return g_value*(1.0-g_value) + .1;  /* .1 is Fahlman's suggestion */
  case gauss: return -2.0*value*g_value;
  }
  /* only reached when g is not one of the enumerators */
  return 1.0;
}

/* Propagate the current contents of inp[] through the network.  The caller
   must have filled inp[] beforehand.  On return hid0[]/hid[] hold the hidden
   units before/after the hidden activation function and out0[]/out[] the
   output units before/after the output activation function. */
void forward_pass ()
{
  int src, dst;

  /* input layer (bias slot included) -> hidden layer */
  for (dst=0; dst<NN_n_hid; dst++) {
    hid0[dst] = 0.0;
    for (src=0; src<=NN_n_inp; src++) {
      hid0[dst] += inp[src] * w1[src][dst];
    }
    hid[dst] = actfn(h_actfn, hid0[dst]);
  }

  /* hidden layer (bias slot included) -> output layer */
  for (dst=0; dst<NN_n_out; dst++) {
    out0[dst] = 0.0;
    for (src=0; src<=NN_n_hid; src++) {
      out0[dst] += hid[src] * w2[src][dst];
    }
    out[dst] = actfn(o_actfn, out0[dst]);
  }
}

#ifndef FOR_USE  

/* Backpropagate the error of the pattern currently held by the network.
   Preconditions: inp[] is set, forward_pass() has been run on it, and
   target[] holds the wanted outputs.
   Effects: adds this pattern's squared error to epoch_error, counts every
   output whose absolute error exceeds .4 in epoch_error_num, fills dout[]
   and dhid[], and adds this pattern's weight changes to dw1 and dw2. */
void backward_pass ()
{
  int in_idx, hid_idx, out_idx;
  float diff, back_sum;

  /* error and delta of every output unit */
  for (out_idx=0; out_idx<NN_n_out; out_idx++) {
    diff = target[out_idx]-out[out_idx];
    epoch_error += sqr(diff);
    if ( fabs(diff) > .4 ) {
      epoch_error_num++;
    }
    dout[out_idx] = diff * actfnprime(o_actfn,out0[out_idx],out[out_idx]);
  }

  /* delta of every hidden unit: output deltas weighted by w2 */
  for (hid_idx=0; hid_idx<NN_n_hid; hid_idx++) {
    back_sum = 0.0;
    for (out_idx=0; out_idx<NN_n_out; out_idx++) {
      back_sum += dout[out_idx] * w2[hid_idx][out_idx];
    }
    dhid[hid_idx] = back_sum * actfnprime(h_actfn,hid0[hid_idx],hid[hid_idx]);
  }

  /* With a delta for every unit, accumulate the weight changes.  Each
     matrix entry is touched exactly once, so the nesting order of these
     loops does not matter. */
  for (hid_idx=0; hid_idx<=NN_n_hid; hid_idx++) {
    for (out_idx=0; out_idx<NN_n_out; out_idx++) {
      dw2[hid_idx][out_idx] += eta2 * dout[out_idx] * hid[hid_idx];
    }
  }
  for (in_idx=0; in_idx<=NN_n_inp; in_idx++) {
    for (hid_idx=0; hid_idx<NN_n_hid; hid_idx++) {
      dw1[in_idx][hid_idx] += eta1 * dhid[hid_idx] * inp[in_idx];
    }
  }
}

/* Apply one epoch's worth of learning to w1 and w2.  For every weight the
   step is the change accumulated in dw1/dw2 plus alpha times the step taken
   in the previous epoch; that step is remembered in prev_dw1/prev_dw2 for
   the next call, added to the weight, and the weight is then scaled by
   beta. */
void update_weights()
{
  int row, col;

  /* input->hidden weights */
  for (row=0; row<=NN_n_inp; row++) {
    for (col=0; col<NN_n_hid; col++) {
      prev_dw1[row][col] = dw1[row][col] + alpha*prev_dw1[row][col];
      w1[row][col] += prev_dw1[row][col];
      w1[row][col] *= beta;
    }
  }

  /* hidden->output weights */
  for (row=0; row<=NN_n_hid; row++) {
    for (col=0; col<NN_n_out; col++) {
      prev_dw2[row][col] = dw2[row][col] + alpha*prev_dw2[row][col];
      w2[row][col] += prev_dw2[row][col];
      w2[row][col] *= beta;
    }
  }
}

/* Print a rows x cols matrix as "{{a b } {c d } }" (no trailing newline). */
static void nn_dump_matrix(FILE *weightFile, float **m, int rows, int cols)
{
  int i,j;

  fprintf(weightFile,"{");
  for (i=0; i<rows; i++){
    fprintf(weightFile,"{");
    for (j=0; j<cols; j++)
      fprintf(weightFile,"%f ",m[i][j]);
    fprintf(weightFile,"} ");
  }
  fprintf(weightFile,"}");
}

/* Print n floats, each followed by one blank (no trailing newline). */
static void nn_dump_vector(FILE *weightFile, const float *v, int n)
{
  int i;

  for (i=0; i<n; i++)
    fprintf(weightFile,"%f ",v[i]);
}

/* Write the weights to a file.
   source names the file, which is created or truncated.  It receives, in
   this order: w1 and w2 (the part load_weights() reads back), the weight
   changes accumulated in dw1 and dw2, the unit activations, targets and
   deltas of the last pattern processed, and a summary of the network
   configuration and of the current epoch's error.  my_error() is called if
   the file cannot be opened. */
void write_weights(char *source)
{
  FILE *weightFile;

  weightFile = fopen(source,"w");
  if ( !weightFile )
    my_error("couldn't open weightFile");

/* The weights */
  nn_dump_matrix(weightFile,w1,NN_n_inp+1,NN_n_hid);
  fprintf(weightFile,"\n");
  nn_dump_matrix(weightFile,w2,NN_n_hid+1,NN_n_out);
  fprintf(weightFile,"\n\n");

  fprintf(weightFile,"Last changes in weights:\n");
  nn_dump_matrix(weightFile,dw1,NN_n_inp+1,NN_n_hid);
  fprintf(weightFile,"\n");
  nn_dump_matrix(weightFile,dw2,NN_n_hid+1,NN_n_out);
  fprintf(weightFile,"\n\n");

  /* epoch_counter is a long: it needs %ld, not %d */
  fprintf(weightFile,"The last forward pass, followed by target (epoch %ld):\n",
          epoch_counter);
  nn_dump_vector(weightFile,inp,NN_n_inp+1);
  fprintf(weightFile,"\n");
  nn_dump_vector(weightFile,hid0,NN_n_hid);
  fprintf(weightFile,"\n");
  nn_dump_vector(weightFile,hid,NN_n_hid+1);
  fprintf(weightFile,"\n");
  nn_dump_vector(weightFile,out0,NN_n_out);
  fprintf(weightFile,"\n");
  nn_dump_vector(weightFile,out,NN_n_out);
  fprintf(weightFile,"\n");
  nn_dump_vector(weightFile,target,NN_n_out);
  fprintf(weightFile,"\n\n");

  fprintf(weightFile,"dhid and dout from last pass:\n");
  nn_dump_vector(weightFile,dhid,NN_n_hid);
  fprintf(weightFile,"\n");
  nn_dump_vector(weightFile,dout,NN_n_out);
  fprintf(weightFile,"\n\n");

  fprintf(weightFile,"%d->%d->%d network:\n", NN_n_inp,NN_n_hid,NN_n_out);
  fprintf(weightFile,"\teta = (%f,%f), alpha = %f, beta = %f, randmax = %f\n",
          eta1,eta2,alpha,beta,randmax);
  fprintf(weightFile,"\t%d training patterns from %s\n", n_pat, TRAIN_FILE);
  if ( strcmp(INITIALIZE_FILE,"none") )
    fprintf(weightFile,"Weights initialized from %s\n",INITIALIZE_FILE);

  fprintf(weightFile,"Epoch %ld:  tot err = %7.4f, avg err = %7.4f, num wrong = %d\n",
          epoch_counter, epoch_error,
          epoch_error/(n_pat*NN_n_out), epoch_error_num);

  fclose(weightFile);
}

#endif

/* Read a rows x cols matrix written as "{{a b } {c d } }" into m.
   Returns 1 when every element was read, 0 as soon as one is missing or is
   not a number.  Only the numbers can be checked this way: fscanf() reports
   conversions, not whether literal braces matched, but a missing brace
   makes the next number fail to convert. */
static int nn_scan_matrix(FILE *weightFile, float **m, int rows, int cols)
{
  int i,j;

  fscanf(weightFile,"{");
  for (i=0; i<rows; i++){
    fscanf(weightFile,"{");
    for (j=0; j<cols; j++) {
      if ( fscanf(weightFile,"%f ",&(m[i][j])) != 1 )
        return 0;
    }
    fscanf(weightFile,"} ");
  }
  fscanf(weightFile,"} ");
  return 1;
}

/* Load the weights from a file.
   source names a file in the format produced by write_weights(); only its
   first two matrices, w1 and w2, are read.  w1 and w2 must already be
   allocated.  my_error() is called if the file cannot be opened or if it
   ends, or stops looking like a weight file, before both matrices are
   complete -- previously such a file silently left part of the weights
   unset. */
void load_weights(char *source)
{
  FILE *weightFile;
  int ok;

  weightFile = fopen(source,"r");
  if ( !weightFile )
    my_error("couldn't open weightFile");

/* The weights */
  ok = nn_scan_matrix(weightFile,w1,NN_n_inp+1,NN_n_hid)
    && nn_scan_matrix(weightFile,w2,NN_n_hid+1,NN_n_out);

  fclose(weightFile);

  if ( !ok )
    my_error("malformed weightFile");
}

#ifndef FOR_USE

/* Read a rows x cols matrix written as "{{a b } {c d } }" into m.
   Returns 1 when every element was read, 0 as soon as one is missing or is
   not a number (fscanf() reports conversions only, so the braces themselves
   cannot be checked directly). */
static int nn_scan_init_matrix(FILE *weightFile, float **m, int rows, int cols)
{
  int i,j;

  fscanf(weightFile,"{");
  for (i=0; i<rows; i++){
    fscanf(weightFile,"{");
    for (j=0; j<cols; j++) {
      if ( fscanf(weightFile,"%f ",&(m[i][j])) != 1 )
        return 0;
    }
    fscanf(weightFile,"} ");
  }
  fscanf(weightFile,"} ");
  return 1;
}

/* Start training from a previously saved weight file.
   Prints the network configuration, clamps the bias units to 1 and reads,
   from the file named by source (format of write_weights()), w1 and w2
   followed by the "Last changes in weights" matrices into dw1 and dw2.
   All four matrices and hid must already be allocated.  my_error() is
   called if the file cannot be opened or is incomplete or malformed.

   NOTE(review): train_one_epoch() clears dw1 and dw2 before using them, so
   the changes loaded here have no effect on training; if they are meant to
   seed the momentum term they would have to go into prev_dw1/prev_dw2 --
   confirm the intent before changing it. */
void initialize_from_file(char *source)
{
  FILE *weightFile;
  int ok;

  printf("Initializing %d->%d->%d network:\n", NN_n_inp,NN_n_hid,NN_n_out);
  printf("\teta = (%f,%f), alpha = %f\n",
         eta1,eta2,alpha);
  printf("\t%d training patterns\n", n_pat);
  printf("Initializing weights from %s\n", source);

  weightFile = fopen(source,"r");
  if ( !weightFile )
    my_error("couldn't open weightFile");

  /* initialize bias units */
  inp[NN_n_inp] = 1.0;
  hid[NN_n_hid] = 1.0;

/* The weights */
  ok = nn_scan_init_matrix(weightFile,w1,NN_n_inp+1,NN_n_hid)
    && nn_scan_init_matrix(weightFile,w2,NN_n_hid+1,NN_n_out);

  if ( ok ) {
    fscanf(weightFile,"Last changes in weights:\n");
    ok = nn_scan_init_matrix(weightFile,dw1,NN_n_inp+1,NN_n_hid)
      && nn_scan_init_matrix(weightFile,dw2,NN_n_hid+1,NN_n_out);
  }

  fclose(weightFile);

  if ( !ok )
    my_error("malformed weightFile");
}

#endif
