/* ======================================================================
   A simple implementation of a Backprop Net.
   Deals strictly with only one input, one hidden and one output layer.
   Inspired by the PDP programs.
   ====================================================================== */
#include <stdio.h>
#include <math.h>
#include "sim-basics.h"

/* MAX(A,B): the larger of A and B.  Every use of an argument is
   parenthesized so that operator expressions passed as arguments
   (e.g. MAX(x&1, y)) are compared as written.  The selected argument
   is evaluated twice, so do not pass expressions with side effects. */
#define MAX(A,B)  (((A)>(B))?(A):(B))
#define SCOM	  argv[0]          /* Argument string for the called command */

/* Forward declarations (K&R style, no parameter lists). */
char   *malloc();                  /* allocator; declared here, no <stdlib.h> is included by this file */
double compute_error();            /* Computes squared error (2*E_p) */
double show_pat();                 /* Prints the net state for one pattern; returns its squared error */

/* The structure definitions for connections and units               */

/* One weighted connection into a unit.  The connections feeding a unit
   form a singly linked, NULL-terminated list that starts at that
   unit's `conns` field. */
typedef struct A_CONN {
  struct A_UNIT* from_unit;	    /* pointer to the unit this connection */
				    /*  comes from */
  double	weight;             /* weight of connection */
  double	dweight;            /* most recent weight change (delta weight) */
  double	wed;                /* cumulative weight error derivatives */
  struct A_CONN* next;		    /* pointer to next connection (NULL terminated) */
} a_conn;

/* One unit (input, hidden or output) of the network. */
typedef struct A_UNIT {
  double 	net;                 /* net input to unit (bias + weighted inputs) */
  double	act;                 /* activation of unit */
  double	err;                 /* error of unit */
  double	delta;               /* delta of unit (err * derivative of activation) */
  double	bias;                /* bias to unit */
  double	dbias;               /* most recent bias change (delta bias) */
  double	bed;                 /* cumulative bias error derivatives */
  a_conn*	conns;       	     /* pointer to first incoming connection */
} a_unit;

/* The network itself.  Units are stored in one array, ordered as all
   input units, then all hidden units, then all output units. */
a_unit 	*unit_list;		     /* pointer to array of units */
a_conn  *conn_list;		     /* pointer to array of connections */

int     INPunits=0;                     /* number of input units */
int     HIDunits=0;                     /* number of hidden units */
int     OUTunits=0;                     /* number of output units */
int     MAXunits=0;                     /* max of input, hidden and output */

/* macros to indicate start and ending of layers */
/* First* points at the first unit of a layer; Last* points one past
   the last unit of that layer, so loops run `unit < LastXXX`. */
#define FirstINP	(unit_list)
#define FirstHID	(unit_list+INPunits)
#define FirstOUT	(unit_list+INPunits+HIDunits)
#define LastINP		FirstHID
#define LastHID		FirstOUT
#define LastOUT		(unit_list+INPunits+HIDunits+OUTunits)

/* learning parameters */
double  momentum=0.9;                   /* momentum */
double  epsilon= 0.5;                   /* learning rate for weights */
double  bepsilon=0.5;                   /* learning rate for bias */

/* Each row of `pat` is one I/O pattern: INPunits input values followed
   by OUTunits target output values. */
int	MaxPats;			/* number of pattern row pointers allocated so far */
int     Npats=0;                        /* number of training patterns */
double  **pat;			        /* pointer to input/output pattern matrix */
#define START_OUT_PAT   INPunits        /* column where a pattern's target outputs begin */
#define START_INP_PAT   0               /* column where a pattern's inputs begin */

/* scales for showing different values */
/* Each macro scales a double and converts it with nint() so it can be
   printed with an integer format.  nint() is not defined in this file;
   presumably it rounds to the nearest integer -- TODO confirm. */
#define ACT_SCL(A)  nint((A)*100.0)
#define ERR_SCL(E)  ACT_SCL(E)
#define DTA_SCL(D)  nint((D)*1000.0)
#define WGT_SCL(W)  nint((W)*100.0)

/* derivative of logistic; d_act/d_net, expressed in terms of the
   activation A.  The argument is parenthesized so that an expression
   such as DERIV_log(a+b) is computed as (a+b)*(1-(a+b)).  A is
   evaluated twice, so do not pass expressions with side effects. */
#define DERIV_log( A ) ((A)*(1.0 - (A)))

/*======================================================================*/
/*      This function computes the logistic, except for values at the
        extremes, which it just returns a close approximation to save
        time.                                                           */
/*======================================================================*/

double logistic (x)
     double x;
{
  double  exp();
  /* .99999988 is very close to the largest single precision value
     that is resolvably less than 1.0 -- jlm */
  if (x > 16.0)
    return(.99999988);
  else if (x < -16.0)
    return(.00000012);
  else
    return( 1.0 / (1.0 + exp( (-1.0) * x) ));
}

/*======================================================================*/
/*      This function computes the output of the specified network by
	propagating activation values forward.  The activation values
        are determined by using the logistic function.                  */
/*======================================================================*/

compute_output() 
{
  a_unit	*u;
  a_conn	*link;
  double	sum;

  /* Input units are skipped: their activation is whatever the caller
     stored in them.  Hidden units come before output units in the
     unit array, so one pass in array order propagates activation
     forward.  For each unit the net input is its bias plus the sum of
     (weight * source activation) over its incoming connections; the
     activation is the logistic of that net input.                   */

  for( u=FirstHID; u < LastOUT; ++u ) {
    sum = u->bias;
    link = u->conns;
    while( link != NULL ) {
      sum += link->weight * link->from_unit->act;
      link = link->next;
    }
    u->net = sum;
    u->act = logistic( sum );
  }
}
/*======================================================================*/
/*      This function computes the sum of the squared error over all
        the output units of the network.  This is equal to 2*E_p.  
        Then all the delta_j's (for both output and hidden units) are
        found for use in back-propagation.                              */
/*======================================================================*/     

double compute_error( PatNo )
     int PatNo;
{
  a_unit *out, *hid;
  a_conn *link;
  int    col;
  double total;

  /*  Output layer: the error of each output unit is its target value
      (taken from pattern PatNo) minus its current activation.  The
      squared errors are summed to give the value returned (2*E_p),
      and each unit's delta is its error times the derivative of the
      logistic at its activation.                                    */

  total = 0.0;
  col = START_OUT_PAT;
  for( out=FirstOUT; out < LastOUT; ++out ) {
    out->err = pat[PatNo][col] - out->act;
    total += out->err * out->err;
    out->delta = out->err * DERIV_log( out->act );
    ++col;
  }

  /*  Hidden layer: clear the hidden errors, then let every output
      unit push (delta * weight) back along each of its incoming
      connections onto the unit at the other end.                    */

  for( hid=FirstHID; hid < LastHID; ++hid )
    hid->err = 0.0;

  for( out=FirstOUT; out < LastOUT; ++out ) {
    link = out->conns;
    while( link != NULL ) {
      link->from_unit->err += out->delta * link->weight;
      link = link->next;
    }
  }

  /*  The hidden deltas need the fully accumulated errors from above. */

  for( hid=FirstHID; hid < LastHID; ++hid )
    hid->delta = hid->err * DERIV_log( hid->act );

  return total;
}

/*======================================================================*/
/*   This function is used to accumulate weight error derivatives over
     a series of presentations of different patterns.  This accumulated
     error is used to change the weights (later).                       */
/*======================================================================*/     

add_weds()
{
  a_unit *u;
  a_conn *link;

  /* For every hidden and output unit, add (delta * source activation)
     to the running total of each incoming connection, and add the
     unit's delta to its running bias total. */
  u = FirstHID;
  while( u < LastOUT ) {
    link = u->conns;
    while( link != NULL ) {
      link->wed += u->delta * link->from_unit->act;
      link = link->next;
    }
    u->bed += u->delta;
    ++u;
  }
}

/*======================================================================*/
/*    This function changes the connection weights.  It basically
      calculates:
 	  dweight(T+1) = epsilon * wed + momentum * dweight(T)
      where "wed" is the sum of the delta * activation terms
      accumulated in the add_weds function.                             */
/*======================================================================*/     

change_weights()
{
  a_unit *u;
  a_conn *link;

  u = FirstHID;
  while( u < LastOUT ) {
    /* incoming connection weights of this unit */
    link = u->conns;
    while( link != NULL ) {
      link->dweight = epsilon * link->wed + momentum * link->dweight;
      link->weight += link->dweight;
      link->wed = 0.0;			/* start the next series from zero */
      link = link->next;
    }

    /* the bias is updated the same way, with its own learning rate */
    u->dbias = bepsilon*u->bed + momentum*u->dbias;
    u->bias += u->dbias;
    u->bed = 0.0;
    ++u;
  }
}

/*======================================================================*/
#define BMakeNet "MFILE\n\
  create the specified network structures and read the patterns in MFILE"
/*======================================================================*/     

#define MFILE argv[1]

MakeNet( argc, argv )
     int argc; char *argv[];
{
  a_unit  *unit, *funit;
  a_conn  *conn;
  int     kf, unitc, npat, i;
  double  val, **oldpat;
  FILE    *file;

  /*  If there is no file which matches the specified name, abort. */

  if( NULL == (file=fopen( MFILE, "r" )))
     nerror2( return OK, "%s: Can't find file '%s'\n", SCOM, MFILE );
     
  /* If there is a consistency problem in the specification, abort. 
     A consistency problem exists if there are not exactly three
     numbers indicating the number of input, hidden, and output
     units on the first line.  If any of the numbers are negative,
     this is also an error.  If an error is
     encountered, the file is closed and the network parameters
     are reset to their starting values.                            */

  if( 3 != fscanf( file," %d %d %d", &INPunits, &HIDunits, &OUTunits) 
      || INPunits <= 0 
      || HIDunits <  0
      || OUTunits <= 0 ) {
    fclose(file);
    MAXunits=INPunits=HIDunits=OUTunits=Npats=0;
    nerror2( return OK, "%s: unexpected values. corrupt pattern file '%s'?\n", 
	     SCOM, MFILE );
  }
 
  /* At this point we presumably have a consistent set of data values. */

  printf("INPunits %d HIDunits %d OUTunits %d\n", INPunits, HIDunits,OUTunits);

  /* Let's read in the data specifying the desired I/O behavior for the
     network.  The file contains a set of patterns, each of which gives
     the input values and the desired output values, for a particular case. */

  printf( "reading patterns\n" );
  /* initial allocation for one array of pointers for i/o patterns */
  MaxPats = (INPunits+OUTunits)*4;
  pat = (double **) malloc((unsigned)(sizeof(double *) * MaxPats ));
  if( pat==NULL )
    nerror1( return OK, "%s: problem allocating space for patterns (1)\n", SCOM );

  npat = -1; 
  unitc = INPunits+OUTunits;
  while( EOF != (kf = fscanf( file, " %lf ", &val )) && kf==1 ){
    if( unitc == INPunits ) printf( " ==> " );
    printf( "%5d", ACT_SCL(val));
    if( unitc+1 == INPunits+OUTunits ) printf( " : pat %d\n", npat );
    if( unitc >= INPunits+OUTunits ){   		/* start a new io pattern */
      ++npat;
      unitc = 0; 
      if( npat >= MaxPats ) {  				/* need to allocated more ? */
	oldpat = pat;
	/* twice  as much as before */
	pat = (double **) malloc((unsigned)(sizeof(double *) * 2*MaxPats )); 
	/* move old patterns into new area */
	for( i=0; i < MaxPats; ++i ) pat[i] = oldpat[i]; 
	MaxPats *= 2;
      }
      /* allocate space for a new i/o pattern */
      pat[npat] = (double *) malloc((unsigned) ((sizeof(double) * (INPunits+OUTunits) )));
      /* check for allocation problems */
      if( pat==NULL || pat[npat]==NULL )
	nerror1( return OK, "%s: problem allocating space for patterns (2)\n", SCOM );
    }
    pat[npat][unitc] = val;
    ++unitc;
  }
  fclose( file );

  /* If after reading in all the I/O specifications, it is found that some
     numbers are left over, then abort.                                  */

  if( unitc % (INPunits+OUTunits) != 0 || kf != EOF ) {
    MAXunits=INPunits=HIDunits=OUTunits=Npats = 0;
    nerror2( return OK, "%s: corrupt pattern file '%s'?\n", SCOM, MFILE );
  }

  /* Set the network parameters for limiting values.                     */

  MAXunits = MAX( INPunits, HIDunits );
  MAXunits = MAX( MAXunits, OUTunits );
  Npats = npat+1;

  /* allocate enough space for the units and the connections */
  unit_list = (a_unit *) malloc((unsigned)(sizeof(a_unit)*(INPunits+HIDunits+OUTunits )));
  conn_list = (a_conn *) malloc((unsigned)(sizeof(a_conn)*
				(HIDunits*INPunits + OUTunits*HIDunits )));
  if( unit_list == NULL || conn_list == NULL ) 
    nerror1( return OK, "%s: problem allocating space for network\n", SCOM );

  /* Set up the unit and connection pointers */
  for( unit=FirstINP; unit < LastINP; ++unit ) 
    unit->conns = NULL;				/* no connections to input units */

  /* connections from INPut units to HIDden units */
  conn = conn_list;
  for( unit=FirstHID; unit < LastHID; ++unit ) {
    unit->conns = conn;
    for( funit=FirstINP; funit < LastINP; ++funit, ++conn ) {
      conn->next = (conn+1);
      conn->from_unit = funit;
    }
    (conn-1)->next = NULL;
  }

  /* connections from HIDden units to OUTput units */
  for( unit=FirstOUT; unit < LastOUT; ++unit ) {
    unit->conns = conn;
    for( funit=FirstHID; funit < LastHID; ++funit, ++conn ) {
      conn->next = (conn+1);
      conn->from_unit = funit;
    }
    (conn-1)->next = NULL;
  }
  return OK;
}

/*======================================================================*/
#define BShowPat "PAT_NO\n\
  Show the network state vis-a-vis a given I/O pattern"
/*======================================================================*/     

/* SHOW_LGND expands to a complete printf argument list: a format
   string followed by the eight column headings of the state table. */
#define SHOW_LGND " %6s%6s%6s%6s%6s%6s%6s%6s\n",\
" Iact"," Hact"," Hdta"," Oact"," Odta","Targt"," Oerr","err^2"
#define PAT_NO 1		/* index in argv of the pattern number */

/* ShowPat: command handler that prints the state table for a single
   pattern.  Always returns OK. */
ShowPat( argc, argv )
     int argc; char *argv[];
{
  int 	 PatNo;

  /* The given pattern must be in the range as specified in the input
     file.                                                            */

  /* NOTE(review): SCAN_POSP_INT_ARGV comes from sim-basics.h;
     presumably it parses argv[PAT_NO] into PatNo and executes
     `return OK` when the argument is unusable -- confirm, including
     whether 0 is accepted. */
  SCAN_POSP_INT_ARGV(return OK, PAT_NO, PatNo );
  if( PatNo >= Npats )
    nerror2( return OK, "%s: '%d' invalid pattern no\n", SCOM, PatNo );
  
  /* Display a table with headers listing info about the network state
     in relation to the pattern number.                               */

  printf( SHOW_LGND );
  show_pat( PatNo );
  return OK;
}

/*======================================================================*/
/*   This function runs a pattern through the network and shows the 
     state of units in the network; i.e., the activation, the target
     and the error values.
*/
/*======================================================================*/     

#define ACT_FMT(A)  PRT_FMT(ACT_SCL(A))
#define PAT_FMT(A)  PRT_FMT(ACT_SCL(A))
#define ERR_FMT(E)  PRT_FMT(ERR_SCL(E))
#define DTA_FMT(D)  PRT_FMT(DTA_SCL(D))
#define PRT_FMT(P)  printf("%6d",P)
#define SPC_FMT     printf("%6s"," " )

double show_pat( PatNo )
     int PatNo;
{
  a_unit *unit;
  int	 unitc, punit;
  double sqrerr=0;

  /* First, set the input units to the values specified in the selected
     I/O pattern.                                                       */

  for( unit=FirstINP, punit=START_INP_PAT; 
       unit < LastINP; 
       ++unit, ++punit) {
    unit->act = pat[PatNo][punit];
  }

  /* Find the output values corresponding to the given inputs:          */

  compute_output();

  /* Find the global error - i.e. the difference between the desired
     outputs and the actual outputs.                                    */

  sqrerr = compute_error( PatNo );

  /* Now a table is printed to give the details of activations, deltas
     and errors for each unit.  This enables a detailed analysis of the
     network state.                                                     */

  for( unitc=0, unit=FirstINP; 
       unitc<MAXunits; 
       ++unitc, ++unit ){
    printf( "%c", unitc==0 ? '-' : ' ' );
    if( unitc<INPunits ) ACT_FMT( unit->act );
    else		 SPC_FMT;
    if( unitc<HIDunits ) ACT_FMT( (unit+INPunits)->act );
    else		 SPC_FMT;
    if( unitc<HIDunits ) DTA_FMT( (unit+INPunits)->delta );
    else		 SPC_FMT;
    if( unitc<OUTunits ) ACT_FMT( (unit+HIDunits+INPunits)->act );
    else		 SPC_FMT;
    if( unitc<OUTunits ) DTA_FMT( (unit+HIDunits+INPunits)->delta );
    else		 SPC_FMT;
    if( unitc<OUTunits ) PAT_FMT( pat[PatNo][unitc+START_OUT_PAT] );
    else		 SPC_FMT;
    if( unitc<OUTunits ) ERR_FMT( (unit+HIDunits+INPunits)->err );
    else		 SPC_FMT;
    if( unitc == 0 )     ERR_FMT( sqrerr );
    printf( "\n" );
  }
  return sqrerr;
}

/*======================================================================*/
#define BShowPats "Show the network state for all defined I/O patterns"
/*======================================================================*/     

ShowPats()
{
  int	pat;
  double ssqrerr;

  printf( SHOW_LGND );
  for( pat=0, ssqrerr=0.0; pat < Npats; ++pat ) {
    ssqrerr += show_pat( pat );
    printf( "\n" );
  }
  printf( "sum of sqr errs = %6d\n", ERR_SCL(ssqrerr) );
}

/*======================================================================*/
#define BReadWeights "MFILE\n\
        Read the connection weights from MFILE."

/* To understand the operation of this routine you need to know the 
   format of the connection weights file, which is:

   <I_1 H_1> <I_2 H_1> ... <I_In H_1> <Bias H_1>
   <I_1 H_2> <I_2 H_2> ... <I_In H_2> <Bias H_2>
   ...
   <I_1 H_Hn> <I_2 H_Hn> ... <I_In H_Hn> <Bias H_Hn>
   <H_1 O_1> <H_2 O_1> ... <H_Hn O_1> <Bias O_1>
   <H_1 O_2> <H_2 O_2> ... <H_Hn O_2> <Bias O_2>
   ...
   <H_1 O_On> <H_2 O_On> ... <H_Hn O_On> <Bias O_On>
   
   Where 
	<a b> indicates weight from unit a to unit b.
	<Bias a> is bias weight to unit a.
	I_k is the k'th input unit;	In number of input units.
	H_k is the k'th hidden unit	Hn number of hidden units.
	O_k is the k'th output unit	On number of output units.      */

/*======================================================================*/

#define MFILE argv[1]

/* ReadWeights: command handler that loads weights and biases into the
   current network.

   argv[0] is the command name as typed, argv[1] the weights file.
   Values are read in unit order (hidden units first, then output
   units): for each unit, one weight per incoming connection followed
   by the unit's bias.  Each value is echoed as it is read.

   A problem is reported when the file runs out early, contains a
   token that is not a number, or has data left after the last bias.
   Values read before the problem was detected stay in the network.

   Always returns OK.                                                 */
ReadWeights( argc, argv )
     int argc; char *argv[];
{
  a_unit *unit;
  a_conn *conn;

  int kf;
  FILE *file;

  /*  A file name is required. */

  if( argc < 2 )
    nerror1( return OK, "%s: missing weights file name\n", SCOM );

  /*  If there is no file which matches the specified name, abort. */

  if( NULL == (file=fopen( MFILE, "r" )))
     nerror2( return OK, "%s: Can't find file '%s'\n", SCOM, MFILE );
     
  /* Get the connection weights.  The first set of weights are from
     inputs to hidden units and the second are from hidden to 
     output units.                                                  */

  printf( "reading weights ...\n" );
  conn = NULL;			/* stays NULL when there are no units to fill */
  for( unit = FirstHID; unit < LastOUT; ++unit ) {
    /* a pointer difference is not an int: convert before printing */
    printf( "to %d: ", (int)(unit-unit_list) );
    for( conn=unit->conns; conn!=NULL; conn=conn->next ) {
      if( 1 != fscanf( file, " %lf ", &(conn->weight) )) break;
      printf( "%6d", WGT_SCL(conn->weight) );
    }
    if( 1 != fscanf( file, " %lf ", &(unit->bias) )) break;
    printf( "  bias= %6d\n", WGT_SCL(unit->bias) );
  }

  /* conn != NULL: a weight was missing; unit < LastOUT: a bias was
     missing; anything other than EOF from the last scan: the file
     holds more data than the network needs. */
  if( conn != NULL || unit < LastOUT ||  EOF!=fscanf(file, " %d ", &kf ) )
    nerror2( ;, "%s: problem reading weights '%s'?\n", SCOM, MFILE );
  fclose( file );
  return OK;
}

/*======================================================================*/
#define BShowWeights "Show the connection weights of the network." 
/*======================================================================*/

ShowWeights( argc, argv )
     int argc; char *argv[];
{
  a_unit *unit;
  a_conn *conn;

  /* input to hidden unit connections */
  printf( "  from:" );					
  for( conn = FirstHID->conns; conn!=NULL; conn=conn->next ) 
    printf( "%6d", (conn->from_unit - unit_list ) );
  printf( "\n" );
  for( unit = FirstHID; unit < LastHID; ++unit ) {
    printf( "to %2d: ", (unit-unit_list) );
    for( conn=unit->conns; conn!=NULL; conn=conn->next ) {
      printf( "%6d", WGT_SCL(conn->weight) );
    }
    printf( "  bias= %6d\n", WGT_SCL(unit->bias) );
  }
  /* hidden to output unit connections */
  printf( "\n  from:" );
  for( conn = FirstOUT->conns; conn!=NULL; conn=conn->next ) 
    printf( "%6d", (conn->from_unit - unit_list ) );
  printf( "\n" );
  for( unit = FirstOUT; unit < LastOUT; ++unit ) {
    printf( "to %2d: ", (unit-unit_list) );
    for( conn=unit->conns; conn!=NULL; conn=conn->next ) {
      printf( "%6d", WGT_SCL(conn->weight) );
    }
    printf( "  bias= %6d\n", WGT_SCL(unit->bias) );
  }

  return OK;
}

/*======================================================================*/
#define BStrain "S_TIMES\n\
    Train the network by performing the back-propagation cycle S_TIMES\n\
    [S_TIMES defaults to 30]"
/*======================================================================*/

#define S_TIMES 1

Strain( argc, argv )
  int argc; char *argv[];
 {
  a_unit *inp;
  int p, epoch, col, s_times;
  double epoch_err;

  /* Number of training cycles: 30 unless an argument is given. */
  s_times = 30;
  if( argc > S_TIMES )
    SCAN_POSP_INT_ARGV(return OK, S_TIMES, s_times );

  /* One cycle presents every training pattern once: the pattern's
     inputs are loaded, activation is propagated forward, errors and
     deltas are computed, and the weight error derivatives are
     accumulated.  Only after all patterns have been seen are the
     weights changed.  The summed squared error of the cycle is
     printed every fifth cycle, and a dot for every cycle.           */

  epoch = 0;
  while( epoch < s_times ) {
    epoch_err = 0.0;
    for( p=0; p < Npats; ++p ) {
      /* set the io pattern */
      col = START_INP_PAT;
      for( inp=FirstINP; inp<LastINP; ++inp ) {
	inp->act = pat[p][col];
	++col;
      }
      compute_output();
      epoch_err += compute_error( p );
      add_weds();
    }
    change_weights();
    if( epoch % 5 == 0 )
      printf( "%d=%d", epoch, ERR_SCL(epoch_err));
    printf( "." );
    ++epoch;
  }
  printf( "\n" );
  return OK;
}

/* ======================================================================
   install the above defined commands (with their abbreviations) and call
   the interpreter loop
   ======================================================================*/

/* Program entry point: registers every command under its full name and
   its abbreviation, sets the start message and the prompt, and runs
   the interpreter loop on standard input.  Returns 0 once the loop
   returns, so the exit status is defined.                            */
main( )
{
  ba_install_cmd( "MakeNet", 	MakeNet,    BMakeNet );
  ba_install_cmd( "net", 	MakeNet,   "Abbreviation for 'MakeNet'." );
  ba_install_cmd( "ShowPat", 	ShowPat,    BShowPat );
  ba_install_cmd( "sp", 	ShowPat,    "Abbreviation for 'ShowPat'." );
  ba_install_cmd( "ShowPats", 	ShowPats,   BShowPats );
  ba_install_cmd( "sps", 	ShowPats,   "Abbreviation for 'ShowPats'." );
  ba_install_cmd( "ReadWeights",ReadWeights,BReadWeights );
  ba_install_cmd( "readw",	ReadWeights,"Abbreviation for 'ReadWeights'.");
  ba_install_cmd( "ShowWeights",ShowWeights,BShowWeights );
  ba_install_cmd( "sw",		ShowWeights,"Abbreviation for 'ShowWeights'.");
  ba_install_cmd( "Strain", 	Strain,	    BStrain );
  ba_install_cmd( "st", 	Strain,	    "Abbreviation for 'Strain'.");

  ba_start_msg( "Simple BackProp Net Simulator." );
  ba_prompt( "SBp> " );
  ba_interpreter_loop( stdin );
  return 0;
}
