/* ======================================================================
   A simple implementation of a Backprop Net.
   Handles exactly one input, one hidden and one output layer.
   Inspired by the PDP programs.
   Unlike simple-bp.c, this file represents networks using matrices
   rather than structures.  This simplifies the code structure somewhat.
   ====================================================================== */
#include <stdio.h>
#include <math.h>
#include "sim-basics.h"

#define MAX(A,B)  ((A)>(B)?(A):(B))  /* If A>B then A, else B; arguments may be evaluated twice */
#define SCOM	  argv[0]          /* Argument string for the called command */

double compute_error( int PatNo ); /* Computes squared error (2*E_p) */
double show_pat( int PatNo );      /* Prints the state for one pattern; returns its squared error */

#define MAX_UNITS 500                   /* max number of units */
int     INPunits=0;                     /* number of input units */
int     HIDunits=0;                     /* number of hidden units */
int     OUTunits=0;                     /* number of output units */
int     MAXunits=0;                     /* max of input, hidden and output */

/* Macros to access the start and end of each layer.  Units are numbered
   inputs first, then hidden units, then outputs; every "Last" bound is
   exclusive.  FirstTo(U)/LastTo(U) give the half-open range of units that
   feed unit U.  Input units have no incoming connections (empty range
   0..0), hidden units are fed by the inputs, and output units by the
   hidden units.  U is fully parenthesized so expressions are safe. */
#define FirstINP        0
#define FirstHID        INPunits
#define FirstOUT        (INPunits+HIDunits)
#define LastINP         INPunits
#define LastHID         FirstOUT
#define LastOUT         (FirstOUT+OUTunits)
#define FirstTo(U)      (((U)<FirstHID)?0:((U)<FirstOUT)?FirstINP:FirstHID)
#define LastTo(U)       (((U)<FirstHID)?0:((U)<FirstOUT)?LastINP :LastHID )

/* ---- Per-unit state.  Every array is indexed by unit number:
        inputs first, then hidden units, then output units.        ---- */
double  net_input[MAX_UNITS],    /* bias + weighted sum of incoming activations */
        activation[MAX_UNITS],   /* inputs: pattern value; others: logistic(net) */
        error[MAX_UNITS],        /* outputs: target - activation;
                                    hidden: back-propagated error        */
        delta[MAX_UNITS];        /* error times the logistic derivative  */

/* ---- Per-unit bias and its learning state ---------------------------- */
double  bias[MAX_UNITS],         /* bias added to the unit's net input   */
        dbias[MAX_UNITS],        /* last change applied to the bias      */
        bed[MAX_UNITS];          /* cumulative bias deltas               */

/* ---- Connections, stored as matrices indexed [to][from] -------------- */
double  weight[MAX_UNITS][MAX_UNITS],   /* weight of connection from -> to */
        dweight[MAX_UNITS][MAX_UNITS],  /* last change applied to weight   */
        wed[MAX_UNITS][MAX_UNITS];      /* cumulative weight deltas        */

/* learning parameters */
double  momentum=0.9;                   /* weight given to the previous change */
double  epsilon= 0.5;                   /* learning rate for weights */
double  bepsilon=0.5;                   /* learning rate for bias */

#define MAX_PATS 50                     /* maximum number of patterns stored */
int     Npats=0;                        /* number of training patterns */
double  pat[MAX_PATS][MAX_UNITS];       /* patterns pat[pat_no][values] */
/* Each pattern row holds INPunits input values followed by OUTunits
   target values; these macros give the offset of each part in a row. */
#define START_OUT_PAT   INPunits        /* beginning of output (target) values */
#define START_INP_PAT   0               /* beginning of input values */

/* Derivative of the logistic, d_act/d_net, expressed in terms of the
   activation A itself: A*(1-A).  A is parenthesized so that expression
   arguments are safe.  A is evaluated twice. */
#define DERIV_log( A ) ((A)*(1.0 - (A)))

/* Scales for showing different values as integers.  nint() is not
   defined in this file; presumably it rounds to the nearest int. */
#define ACT_SCL(A)  nint((A)*100.0)
#define ERR_SCL(E)  ACT_SCL(E)
#define DTA_SCL(D)  nint((D)*1000.0)
#define WGT_SCL(W)  nint((W)*100.0)


/*======================================================================*/
/*      This function computes the logistic 1/(1+exp(-x)), except at
        the extremes (|x| > 16), where it just returns a close constant
        approximation to save time.                                     */
/*======================================================================*/

double logistic (x)
     double x;
{
  double  exp();
  /* .99999988 is very close to the largest single precise value
     that is resolvably less than 1.0 -- jlm */
  if (x > 16.0)
    return(.99999988);
  else if (x < -16.0)
    return(.00000012);
  else
    return( 1.0 / (1.0 + exp( (-1.0) * x) ));
}

/*======================================================================*/
/*      This function computes the output of the specified network by
	propagating activation values forward.  The activation values
        are determined by using the logistic function.                  */
/*======================================================================*/

void compute_output( void )
{
  int    unit, from, last_to;
  double net;

  /* We skip the input units since their output = their input by
     convention (callers load them from a pattern first).  For each
     unit, we find the net value by adding to the initial bias the sum
     of weighted inputs into the unit.  The output is found by taking
     the logistic of the net input.  Units are visited in index order,
     so every hidden unit is updated before the output units that read
     its activation.  Nothing is returned; results are left in
     net_input[] and activation[].                                     */

  for( unit=FirstHID; unit < LastOUT; ++unit ) {
    net = bias[unit];
    last_to = LastTo(unit);
    for( from=FirstTo(unit); from < last_to; ++from )
      net += weight[unit][from] * activation[from];
    net_input[unit] = net;
    activation[unit] = logistic( net );
  }
}

/*======================================================================*/
/*      This function computes the sum of the squared error over all
        the output units of the network.  This is equal to 2*E_p.  
        Then all the delta_j's (for both output and hidden units) are
        found for use in back-propagation.                              */
/*======================================================================*/     

double compute_error( PatNo )
     int PatNo;
{
  int    punit, unit, from;
  double err, sqrerr;

  /*  For each of the output units, we find the difference 
      between the target pattern value and the current
      activation level of the node.  We then square the error and
      accumulate it to determine the global error (2*E_p).  We also
      find delta_j to be used in back-propogation.                 */

  sqrerr = 0.0;
  for( unit=FirstOUT,punit=START_OUT_PAT; unit < LastOUT; ++unit, ++punit ) {
    error[unit] = pat[PatNo][punit] - activation[unit];
    sqrerr += error[unit] * error[unit];
    delta[unit] = error[unit] * DERIV_log( activation[unit] );
  }

  /*  Now we find the delta_j's for all the hidden units (which
      need the results above from the output units).             */

  for( unit=FirstHID; unit < LastHID; ++unit ) {
    for( from=FirstOUT, err=0.0; from < LastOUT; ++from )
      err += delta[from] * weight[from][unit];
    error[unit] = err;
    delta[unit] = error[unit] * DERIV_log( activation[unit] );
  }
  return sqrerr;
}

/*======================================================================*/
/*   This function is used to accumulate weight error derivatives over
     a series of presentations of different patterns.  The accumulated
     values are used later (by change_weights) to change the weights.  */
/*======================================================================*/

void add_weds( void )
{
  int unit, from, last_to;

  /* For every hidden and output unit, add delta[unit]*activation[from]
     for each incoming connection, and delta[unit] for its bias.  The
     sums grow across patterns until change_weights consumes and
     zeroes them.                                                      */
  for( unit=FirstHID; unit < LastOUT; ++unit ) {
    last_to = LastTo(unit);
    for( from=FirstTo(unit); from < last_to; ++from )
      wed[unit][from] += delta[unit] * activation[from];

    bed[unit] += delta[unit];
  }
}

/*======================================================================*/
/*   This function simply resets the "wed" values to 0 to initialize
     a series of pattern applications to the network.                   */
/*======================================================================*/     

void clear_weds( void )
{
  int unit, from, last_to;

  /* Zero the weight and bias accumulators of every hidden and output
     unit.  (Not called within this file: change_weights also zeroes
     them after each update.)                                        */
  for( unit=FirstHID; unit < LastOUT; ++unit ) {
    last_to = LastTo(unit);
    for( from=FirstTo(unit); from < last_to; ++from )
      wed[unit][from] = 0.0;

    bed[unit] = 0.0;
  }
}

/*======================================================================*/
/*    This function changes the connection weights.  It basically
      calculates:
 	  dweight(T+1) = epsilon * wed + momentum * dweight(T)
      where "wed" is the sum of the delta * activation terms
      accumulated in the add_weds function.                             */
/*======================================================================*/     

void change_weights( void )
{
  int unit, from, last_to;

  for( unit=FirstHID; unit < LastOUT; ++unit ) {
    last_to = LastTo(unit);
    for( from=FirstTo(unit); from < last_to; ++from ) {
      /* step = learning rate * accumulated term + momentum * last step */
      dweight[unit][from] = epsilon*wed[unit][from] +
                            momentum*dweight[unit][from];
      weight[unit][from] += dweight[unit][from];
      wed[unit][from] = 0.0;      /* restart accumulation for next epoch */
    }
    /* same rule for the bias, with its own learning rate */
    dbias[unit] = bepsilon*bed[unit] + momentum*dbias[unit];
    bias[unit] += dbias[unit];
    bed[unit] = 0.0;
  }
}

/*======================================================================*/
#define BReadPats "MFILE\n\
  read the patterns in MFILE and create the specified network structures."
/*======================================================================*/

#define MFILE argv[1]

/* ReadPats -- interpreter command "ReadPats MFILE".

   Reads the network dimensions and the training patterns from MFILE.
   The file holds three integers (number of input, hidden and output
   units) followed by whitespace-separated numbers: INPunits input values
   and then OUTunits target values per pattern.  At most MAX_PATS
   patterns are stored.  If the file holds more, a warning is printed and
   the remainder is ignored.

   On a format error the layer sizes, MAXunits and Npats are reset to 0.
   Every exit path returns OK; errors are reported through the nerror1 /
   nerror2 macros from sim-basics.h.                                    */

int ReadPats( int argc, char *argv[] )
{
  int    kf, unit, npat;
  int    full;              /* set when the pattern table filled up    */
  double extra;             /* scratch for peeking past a full table   */
  FILE   *file;

  if( argc <= 1 )
    nerror1( return OK, "%s: missing pattern file name\n", SCOM );

  /*  If there is no file which matches the specified name, abort. */

  if( NULL == (file=fopen( MFILE, "r" )))
    nerror2( return OK, "%s: Can't find file '%s'\n", SCOM, MFILE );

  /* The header must consist of exactly three numbers: the number of
     input, hidden and output units.  The input and output layers must
     be non-empty, the hidden layer may not be negative, and all units
     must fit in MAX_UNITS.  Each size is bounded on its own before the
     sum is formed, so the addition cannot overflow.  (Connections always
     fit, because weight[][] is indexed by unit number.)  On error the
     file is closed and the network parameters are reset.              */

  if( 3 != fscanf( file, " %d %d %d", &INPunits, &HIDunits, &OUTunits )
      || INPunits <= 0 || INPunits >= MAX_UNITS
      || HIDunits <  0 || HIDunits >= MAX_UNITS
      || OUTunits <= 0 || OUTunits >= MAX_UNITS
      || INPunits+HIDunits+OUTunits >= MAX_UNITS ) {
    fclose(file);
    MAXunits=INPunits=HIDunits=OUTunits=Npats=0;
    nerror2( return OK, "%s: unexpected values. corrupt pattern file '%s'?\n",
             SCOM, MFILE );
  }

  /* At this point we presumably have a consistent set of data values. */

  printf("INPunits %d HIDunits %d OUTunits %d\n", INPunits, HIDunits,OUTunits);

  /* Read the patterns, each made of INPunits input values followed by
     OUTunits target values.  They are echoed as they are read, with
     " ==> " between the input part and the target part.               */

  printf( "reading patterns\n" );
  npat = unit = full = 0;
  while( 1 == ( kf = fscanf( file, " %lf ", &(pat[npat][unit]) ) ) ) {
    printf( "%5d", ACT_SCL(pat[npat][unit]) );
    ++unit;

    if( unit == INPunits+OUTunits ) {
      printf( " : pat %d\n", npat );
      unit = 0;
      ++npat;

      if( npat >= MAX_PATS ) {
        /* The table is full.  Stop reading, and warn only if more data
           follows, so a file with exactly MAX_PATS patterns is accepted
           silently.  Either way the stored patterns are kept.         */
        full = 1;
        if( 1 == fscanf( file, " %lf", &extra ) )
          nerror1( break, "%s: Too many patterns. Rest ignored\n", SCOM );
        break;
      }
    }
    else if( unit == INPunits )
      printf( " ==> " );
  }
  fclose( file );

  /* A partially read pattern, or a token that is not a number (kf is 0
     rather than EOF), means the file is corrupt.  The exception is when
     reading stopped early because the pattern table is full.          */

  if( unit != 0 || ( kf != EOF && !full ) ) {
    MAXunits=INPunits=HIDunits=OUTunits=Npats = 0;
    nerror2( return OK, "%s: corrupt pattern file '%s'?\n", SCOM, MFILE );
  }

  /* Set the network parameters for limiting values.                     */

  MAXunits = MAX( INPunits, HIDunits );
  MAXunits = MAX( MAXunits, OUTunits );
  Npats = npat;
  return OK;
}

/*======================================================================*/
#define BShowPat "PAT_NO\n\
  Show the network state vis-a-vis a given I/O pattern"
/*======================================================================*/

#define SHOW_LGND " %6s%6s%6s%6s%6s%6s%6s%6s\n",\
" Iact"," Hact"," Hdta"," Oact"," Odta","Targt"," Oerr","err^2"
#define PAT_NO 1

/* ShowPat -- interpreter command "ShowPat PAT_NO".
   Prints the legend, then the network state after presenting pattern
   PAT_NO (see show_pat).  SCAN_POSP_INT_ARGV comes from sim-basics.h;
   presumably it parses argv[PAT_NO] into PatNo and runs `return OK` on
   bad input.  Returns OK on every path.                                */

int ShowPat( int argc, char *argv[] )
{
  int PatNo;

  SCAN_POSP_INT_ARGV(return OK, PAT_NO, PatNo );

  /* The pattern must index one of the Npats stored patterns. */
  if( PatNo < 0 || PatNo >= Npats )
    nerror2( return OK, "%s: '%d' invalid pattern no\n", SCOM, PatNo );

  /* Display a table with headers listing info about the network state
     in relation to the pattern number.                               */

  printf( SHOW_LGND );
  show_pat( PatNo );
  return OK;
}

/*======================================================================*/
/*   This function lists network parameters vis-a-vis the selected
     pattern.                                                           */
/*======================================================================*/     

#define ACT_FMT(A)  PRT_FMT(ACT_SCL(A))
#define PAT_FMT(A)  PRT_FMT(ACT_SCL(A))
#define ERR_FMT(E)  PRT_FMT(ERR_SCL(E))
#define DTA_FMT(D)  PRT_FMT(DTA_SCL(D))
#define PRT_FMT(P)  printf("%6d",P)
#define SPC_FMT     printf("%6s"," " )

double show_pat( PatNo )
     int PatNo;
{
  int	 unit, punit;
  double sqrerr=0;

  /* First, set the input units to the values specified in the selected
     I/O pattern.                                                       */

  for( unit=FirstINP, punit=START_INP_PAT; unit < LastINP; ++unit, ++punit) {
    activation[unit] = pat[PatNo][punit];
  }

  /* Find the output values corresponding to the given inputs:          */

  compute_output();

  /* Find the global error - i.e. the difference between the desired
     outputs and the actual outputs.                                    */

  sqrerr = compute_error( PatNo );

  /* Now a table is printed to give the details of activations, deltas
     and errors for each unit.  This enables a detailed analysis of the
     network state.                                                     */

  for( unit=0; unit<MAXunits; ++unit ){
    printf( "%c", unit==0 ? '-' : ' ' );
    if( unit< INPunits ) ACT_FMT( activation[unit] );
    else		 SPC_FMT;
    if( unit< HIDunits ) ACT_FMT( activation[unit+FirstHID] );
    else		 SPC_FMT;
    if( unit< HIDunits ) DTA_FMT( delta[unit+FirstHID] );
    else		 SPC_FMT;
    if( unit< OUTunits ) ACT_FMT( activation[unit+FirstOUT] );
    else		 SPC_FMT;
    if( unit< OUTunits ) DTA_FMT( delta[unit+FirstOUT] );
    else		 SPC_FMT;
    if( unit< OUTunits ) PAT_FMT( pat[PatNo][unit+START_OUT_PAT] );
    else		 SPC_FMT;
    if( unit< OUTunits ) ERR_FMT( error[unit+FirstOUT] );
    else		 SPC_FMT;
    if( unit == 0 )      ERR_FMT( sqrerr );
    printf( "\n" );
  }
  return sqrerr;
}

/*======================================================================*/
#define BShowPats "Show the network state for all given I/O patterns"
/*======================================================================*/     

ShowPats()
{
  int	pat;
  double ssqrerr;

  printf( SHOW_LGND );
  for( pat=0, ssqrerr=0.0; pat < Npats; ++pat ) {
    ssqrerr += show_pat( pat );
    printf( "\n" );
  }
  printf( "sum of sqr errs = %6d\n", ERR_SCL(ssqrerr) );
}

/*======================================================================*/
#define BReadWeights "MFILE\n\
        Read the connection weights from MFILE."

/* To understand the operation of this routine you need to know the 
   format of the connection weights file, which is:

   <I_1 H_1> <I_2 H_1> ... <I_In H_1> <Bias H_1>
   <I_1 H_2> <I_2 H_2> ... <I_In H_2> <Bias H_2>
   ...
   <I_1 H_Hn> <I_2 H_Hn> ... <I_In H_Hn> <Bias H_Hn>
   <H_1 O_1> <H_2 O_1> ... <H_Hn O_1> <Bias O_1>
   <H_1 O_2> <H_2 O_2> ... <H_Hn O_2> <Bias O_2>
   ...
   <H_1 O_On> <H_2 O_On> ... <H_Hn O_On> <Bias O_On>
   
   Where 
	<a b> indicates weight from unit a to unit b.
	<Bias a> is bias weight to unit a.
	I_k is the k'th input unit;	In number of input units.
	H_k is the k'th hidden unit	Hn number of hidden units.
	O_k is the k'th output unit	On number of output units.      */

/*======================================================================*/

#define MFILE argv[1]

/* ReadWeights -- interpreter command "ReadWeights MFILE".
   Reads the weights and biases of the current network (sized by a
   previous ReadPats) and echoes them, scaled by WGT_SCL.  If the file
   runs out of numbers or holds a non-number, an error is reported.
   The units read before that point keep their new values.  The file
   is closed on every path.  Returns OK on every path.                 */

int ReadWeights( int argc, char *argv[] )
{
  int  unit, last_to, from;
  FILE *file;

  if( argc <= 1 )
    nerror1( return OK, "%s: missing weight file name\n", SCOM );

  /*  If there is no file which matches the specified name, abort. */

  if( NULL == (file=fopen( MFILE, "r" )))
    nerror2( return OK, "%s: Can't find file '%s'\n", SCOM, MFILE );

  /* Get the connection weights.  The first set of weights are from
     inputs to hidden units and the second are from hidden to 
     output units.                                                  */

  printf( "reading weights ...\n" );
  for( unit=FirstHID; unit < LastOUT; ++unit ) {
    printf( "to %2d: ", unit );
    last_to = LastTo(unit);
    for( from=FirstTo(unit); from < last_to; ++from ) {
      if( 1 != fscanf( file, " %lf ", &(weight[unit][from]) ) )
        goto corrupt;
      printf( "%6d", WGT_SCL(weight[unit][from]) );
    }
    if( 1 != fscanf( file, " %lf ", &(bias[unit]) ) )
      goto corrupt;
    printf( "  bias= %6d\n", WGT_SCL(bias[unit]) );
  }
  fclose( file );
  return OK;

corrupt:
  printf( "\n" );           /* terminate the partially printed row */
  fclose( file );
  nerror2( return OK, "%s: corrupt or truncated weight file '%s'\n",
           SCOM, MFILE );
  return OK;
}

/*======================================================================*/
#define BShowWeights "Show the connection weights of the network." 
/*======================================================================*/

/* ShowWeights -- interpreter command: for every hidden and output unit,
   print the weights of its incoming connections (scaled by WGT_SCL)
   followed by its bias.  The arguments are not used.  Returns OK.      */

int ShowWeights( int argc, char *argv[] )
{
  int unit, last_to, from;

  (void)argc;               /* the command takes no arguments */
  (void)argv;

  for( unit=FirstHID; unit < LastOUT; ++unit ) {
    printf( "to %2d: ", unit );
    last_to = LastTo(unit);
    for( from=FirstTo(unit); from < last_to; ++from )
      printf( "%6d", WGT_SCL(weight[unit][from]) );
    printf( "  bias= %6d\n", WGT_SCL(bias[unit]) );
  }
  return OK;
}

/*======================================================================*/
#define BStrain "S_TIMES\n\
    Train the network by performing the back-propagation cycle S_TIMES\n\
    [S_TIMES defaults to 30]"
/*======================================================================*/

#define S_TIMES 1

Strain( argc, argv )
     int argc; char *argv[];
{
  int PatNo, time, unit, punit, s_times;
  double sqrerr;

  s_times = 30;
  if( argc > S_TIMES )
    SCAN_POSP_INT_ARGV(return OK, S_TIMES, s_times );

  /* For each of s_times (default 30),
       Compute the values of the network outputs and find the resulting
       output error, for all training patterns.  Use the errors to back-
       propagate computing weight errors for hidden units.  Use this
       aggregated error information to change the connection weights.  
       Printout the squared error every five iterations.               */

  for( time=0; time < s_times; ++time ){
    for( PatNo=0, sqrerr=0.0; PatNo < Npats; ++PatNo ) {
      for( unit=FirstINP,punit=START_INP_PAT; unit<LastINP; ++unit, ++punit)
	activation[unit] = pat[PatNo][punit];
      compute_output();
      sqrerr += compute_error( PatNo );
      add_weds();
    }
    change_weights();
    if( time % 5 == 0 ) printf( "%d=%d", time, ERR_SCL(sqrerr));
    printf( "." );
  }
  printf( "\n" );
}

/* ======================================================================
   install the above defined commands (with their abbreviations) and call
   the interpreter loop
   ======================================================================*/

int main( void )
{
  /* Register each command (and its short alias) with the interpreter,
     then hand control to its read-eval loop on stdin.                 */
  ba_install_cmd( "ReadPats", 	ReadPats,   BReadPats );
  ba_install_cmd( "read", 	ReadPats,   "Abbreviation for 'ReadPats'." );
  ba_install_cmd( "ShowPat", 	ShowPat,    BShowPat );
  ba_install_cmd( "sp", 	ShowPat,    "Abbreviation for 'ShowPat'." );
  ba_install_cmd( "ShowPats", 	ShowPats,   BShowPats );
  ba_install_cmd( "sps", 	ShowPats,   "Abbreviation for 'ShowPats'." );
  ba_install_cmd( "ReadWeights",ReadWeights,BReadWeights );
  ba_install_cmd( "readw",	ReadWeights,"Abbreviation for 'ReadWeights'.");
  ba_install_cmd( "ShowWeights",ShowWeights,BShowWeights );
  ba_install_cmd( "sw",		ShowWeights,"Abbreviation for 'ShowWeights'.");
  ba_install_cmd( "Strain", 	Strain,	    BStrain );
  ba_install_cmd( "st", 	Strain,	    "Abbreviation for 'Strain'.");

  ba_start_msg( "Simple BackProp Net Simulator." );
  ba_prompt( "SBp> " );
  ba_interpreter_loop( stdin );
  return 0;
}
