
/**********************************************************************
 *		   Copyright 1992 by Drew van Camp,
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute, and sell this software
 * and its  documentation for any  purpose  is hereby  granted without
 * fee, provided that the above copyright notice appears in all copies
 * and  that  both  the  copyright  notice and  this permission notice
 * appear in supporting  documentation, and that the  name of Drew van
 * Camp  not  be  used  in  advertising  or  publicity  pertaining  to
 * distribution  of  the  software  without  specific,  written  prior
 * permission.  Drew  van  Camp makes  no  representations  about  the
 * suitability  of this software for any  purpose.  It is provided "as
 * is" without express or implied warranty.
 *
 * DREW  VAN  CAMP  DISCLAIMS  ALL  WARRANTIES  WITH  REGARD  TO  THIS
 * SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF  MERCHANTABILITY  AND
 * FITNESS, IN NO EVENT SHALL DREW VAN CAMP BE LIABLE FOR ANY SPECIAL,
 * INDIRECT  OR   CONSEQUENTIAL  DAMAGES  OR  ANY  DAMAGES  WHATSOEVER
 * RESULTING FROM LOSS OF  USE, DATA  OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR
 * IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 **********************************************************************/

/********************************************************************/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include "sim.h"
/********************************************************************/

/********************************************************************/
/* Forward declarations of the file-local routines that main() and the
 * command loop call; their definitions appear further down this file. */
static int	loop      (Net	*) ;
static void	initialize(Net  *) ;
static void	setFile   (Net  *, char	*) ;
static void	runNet    (Net  *, unsigned int, int) ;
static void	dump      (Net) ;
/********************************************************************/

/*********************************************************************
 *	Name:		main
 *	Description:	Main routine for a simple backprop simulator.
 *			It checks the args, initializes the random 
 *			number generator and the network, and (if
 *			specified) it sets the pattern file. It then
 *			goes into the command reading loop.
 *	Parameters:
 * 	  int	argc	- the number of command line arguments.
 *	  char	*argv[]	- the command line arguments.
 *	Return Value:
 *	  int	main - 0 (always)
 *********************************************************************/
int	main(int	argc,
	     char	*argv[]) {
  Net	net ;

  if (argc > 2) {
    fprintf(stderr, "usage: %s <pattern-file>\n", argv[0]) ;
    exit(1) ;
  }

  srand(0) ;
  initialize(&net) ;
  if (argc == 2)
    setFile(&net, argv[1]) ;

  loop(&net) ;

  return 0 ;
}
/********************************************************************/
/*********************************************************************
 *	Name:		loop (and printOptions)
 *	Description:	Command line loop (very primitive). It 
 *			understands the commands:
 *			 init	     - (re)initialize the network.
 *			 file <file> - set the pattern file
 *			 train <n>   - train on <n> patterns
 *			 test <n>    - test on <n> patterns
 *			 dump	     - dump network data
 *			 quit	     - quit the loop
 *			Lines are read with a bounded fgets(); any
 *			input beyond the line buffer is discarded.
 *			Blank lines are ignored. A command whose
 *			argument is missing or malformed prints the
 *			option list instead of running with a stale
 *			or uninitialized value. The loop ends on
 *			"quit" or at end of input.
 *	Parameters:
 *	  Net	*netPtr - pointer to the network we're gonna act on
 *	Return Value:
 *	  static int	loop - 0
 *********************************************************************/
#define prompt(fp)	fputs("sim-> ", fp)
/*********************************************************************/
/* Print the list of understood commands to stderr. */
static void	printOptions(void) {
  fprintf(stderr, "Options:\n") ;
  fprintf(stderr, "\tinit\n") ;
  fprintf(stderr, "\tfile <pattern-file>\n") ;
  fprintf(stderr, "\ttest <num-patterns>\n") ;
  fprintf(stderr, "\ttrain <num-patterns>\n") ;
  fprintf(stderr, "\tdump\n") ;
  fprintf(stderr, "\tquit\n") ;
}
/*********************************************************************/
static int	loop(Net	*netPtr) {
  char		line[128], word[128] ;
  unsigned int	numPatterns ;

  for (prompt(stdout) ;
       fgets(line, sizeof line, stdin) != NULL ;
       prompt(stdout)) {

    /* The line did not fit in the buffer: throw the rest of it away
     * so the tail is not mistaken for another command. */
    if (strchr(line, '\n') == NULL) {
      int	ch ;
      while ((ch = getchar()) != EOF && ch != '\n')
        ;
    }

    /* Blank line: word would be left unset, so just prompt again. */
    if (sscanf(line, "%127s", word) != 1)
      continue ;

    if (strcmp(word, "init") == 0) {
      initialize(netPtr) ;

    } else if (strcmp(word, "file") == 0) {
      if (sscanf(line, "%*s%127s", word) == 1)
        setFile(netPtr, word) ;
      else
        printOptions() ;

    } else if (strcmp(word, "train") == 0) {
      /* %u matches the unsigned int the count is stored in. */
      if (sscanf(line, "%*s%u", &numPatterns) == 1)
        runNet(netPtr, numPatterns, 1) ;
      else
        printOptions() ;

    } else if (strcmp(word, "test") == 0) {
      if (sscanf(line, "%*s%u", &numPatterns) == 1)
        runNet(netPtr, numPatterns, 0) ;
      else
        printOptions() ;

    } else if (strcmp(word, "dump") == 0) {
      dump(*netPtr) ;

    } else if (strcmp(word, "quit") == 0) {
      break ;

    } else {
      printOptions() ;
    }
  }
  puts("") ;
  return 0 ;
}
/********************************************************************/
#undef prompt
/********************************************************************/
/*********************************************************************
 *	Name:		initialize (and utilities)
 *	Description:	(Re)initialize a network: clear the network
 *			error, zero the fields of every unit, give
 *			each hidden and output unit fresh incoming
 *			weights, and clamp every bias unit's output
 *			to 1.0. The pattern file pointer is left
 *			untouched.
 *	Parameters:
 *	  Net	*netPtr - the network to act on
 *	Return Value:
 *	  static void	initialize - NONE.
 *********************************************************************/
static void	randomizeWeights (int  numWeights, double *weight) ;
static void	zeroUnit (Unit *unitPtr) ;
/********************************************************************/
static void	initialize(Net	*netPtr) {
  int	unitNo ;

  netPtr->error = 0.0 ;

  /* Units without incoming weights: the inputs ... */
  for (unitNo = 0 ; unitNo < NUM_INPUT ; unitNo++)
    zeroUnit(&netPtr->input[unitNo]) ;

  /* ... and the biases, whose output is held at 1.0. */
  for (unitNo = 0 ; unitNo < NUM_BIAS ; unitNo++) {
    Unit	*biasUnit = &netPtr->bias[unitNo] ;

    zeroUnit(biasUnit) ;
    biasUnit->output = 1.0 ;
  }

  /* Hidden layer: weights from the inputs first, then from the
   * biases (the weights are drawn in this order). */
  for (unitNo = 0 ; unitNo < NUM_HIDDEN ; unitNo++) {
    zeroUnit(&netPtr->hidden[unitNo]) ;
    randomizeWeights(NUM_INPUT, netPtr->i2h[unitNo]) ;
    randomizeWeights(NUM_BIAS,  netPtr->b2h[unitNo]) ;
  }

  /* Output layer: weights from the hidden units, then the biases. */
  for (unitNo = 0 ; unitNo < NUM_OUTPUT ; unitNo++) {
    zeroUnit(&netPtr->output[unitNo]) ;
    randomizeWeights(NUM_HIDDEN, netPtr->h2o[unitNo]) ;
    randomizeWeights(NUM_BIAS,   netPtr->b2o[unitNo]) ;
  }
}
/********************************************************************/
/* Store a fresh random() value in each of the first numWeights
 * entries of weight, in ascending order. */
static void	randomizeWeights(int	numWeights,
				 double	*weight) {
  int	remaining = numWeights ;

  while (remaining-- > 0)
    *weight++ = random() ;
}
/********************************************************************/
/* Reset the input, output, target and beta fields of a unit to zero. */
static void	zeroUnit(Unit	*unitPtr) {
  unitPtr->beta = unitPtr->target = 0.0 ;
  unitPtr->output = unitPtr->input = 0.0 ;
}
/********************************************************************/
/*********************************************************************
 *	Name:		setFile
 *	Description:	Closes the network's current pattern file (if
 *			there is one) and opens the named file for
 *			reading in its place. If the open fails a
 *			message is written to stderr and the network
 *			is left with no pattern file.
 *	Parameters:
 *	  Net	*netPtr - the network to use.
 *	  char	*name   - the name of the pattern file to open
 *	Return Value:
 *	  static void	setFile - NONE
 *********************************************************************/
static void	setFile(Net	*netPtr,
			char	*name) {
  FILE	*opened ;

  /* Drop the previous pattern file before switching. */
  if (netPtr->fpPatterns != NULL)
    fclose(netPtr->fpPatterns) ;

  opened = fopen(name, "r") ;
  if (opened == NULL)
    fprintf(stderr, "Unable to open pattern file \"%s\"\n", name) ;

  /* NULL here means "no pattern file" to the rest of the code. */
  netPtr->fpPatterns = opened ;
}
/********************************************************************/
/*********************************************************************
 *	Name:		runNet (and utilities)
 *	Description:	Train/Test a network on a given number of 
 *			patterns. The network error is reset to zero
 *			first. Stops running and returns if it 
 *			can't get a pattern. The only difference
 *			between training and testing is that training
 *			does a backward pass through the net.
 *	Parameters:
 *	  Net		*netPtr - the network to train/test.
 *	  unsigned int	maxPatterns - the number of patterns to run
 *				through the network
 *	  int		updateWeights - if non-zero, then the weights
 *				should be updated after each example
 *				is presented (i.e train the net).
 *	Return Value:
 *	  static void	runNet - NONE
 *********************************************************************/
static int	getPattern(Net  *netPtr) ;
static void	forward   (Net	*netPtr) ;
static void	backward  (Net	*netPtr) ;
/********************************************************************/
static void	runNet(Net		*netPtr,
		       unsigned int	maxPatterns,
		       int		updateWeights) {
  /* Same type as maxPatterns: a signed counter would be compared
   * as unsigned and overflow for counts above INT_MAX. */
  unsigned int	idx ;

  netPtr->error = 0.0 ;
  for (idx = 0 ; idx < maxPatterns ; ++idx) {
    if (getPattern(netPtr) < 0)
      break ;
    
    forward(netPtr) ;
    if (updateWeights)
      backward(netPtr) ;
  }
}
/********************************************************************/
/*********************************************************************
 *	Name:		getPattern
 *	Description:	get a pattern from the pattern file, and clamp
 *			the input units' output values, and set the
 *			output units' targets. If we are at the end
 *			of the file, wrap to the beginning and start
 *			over (once). Every read is checked: an empty
 *			file, a truncated pattern or a value that
 *			does not parse as a number is reported as a
 *			failure. On failure some units may already
 *			hold values from the partial pattern.
 *	Parameters:
 *	  Net	*netPtr - the network to set the pattern for
 *	Return Value:
 *	  static int	getPattern - -1 if we can't get a pattern,
 *				      1 if we can.
 *********************************************************************/
static int	getPattern(Net	*netPtr) {
  int		idx ;
  int		status ;

  if (netPtr->fpPatterns == NULL)
    return -1 ;

  /* First value of the pattern; at end-of-file wrap around and try
   * exactly once more so an empty file cannot succeed. */
  status = fscanf(netPtr->fpPatterns, "%lg", &netPtr->input[0].output) ;
  if (status == EOF) {
    rewind(netPtr->fpPatterns) ;
    status = fscanf(netPtr->fpPatterns, "%lg", &netPtr->input[0].output) ;
  }
  if (status != 1)
    return -1 ;

  for (idx = 1 ; idx < NUM_INPUT ; ++idx)
    if (fscanf(netPtr->fpPatterns, "%lg", &netPtr->input[idx].output) != 1)
      return -1 ;

  for (idx = 0 ; idx < NUM_OUTPUT ; ++idx)
    if (fscanf(netPtr->fpPatterns, "%lg", &netPtr->output[idx].target) != 1)
      return -1 ;

  return 1 ;
}
/********************************************************************/
/*********************************************************************
 *	Name:		forward (and fanIn)
 *	Description:	Do a forward propagation of a pattern through 
 *			a net, updating unit activations, and network
 *			error. To activate a unit, do a fan in of all
 *			units in the previous layer (weight * output),
 *			then pass the result through a sigmoid.
 *			Half of this pattern's summed squared output
 *			error is added to the network error, so the
 *			network error accumulates over successive
 *			patterns until it is reset.
 *			Note: pattern must already be presented.
 *	Parameters:
 *	  Net	*netPtr - the network to do the forward pass through
 *	Return Value:
 *	  static void	forward - NONE
 *********************************************************************/
static void	fanIn  (Unit	*unitPtr, int, double *, Unit *) ;
/********************************************************************/
static void	forward(Net	*netPtr) {
  int		idx ;
  Unit		*unit ;
  double	patternError = 0.0 ;

  unit = netPtr->hidden ;
  for (idx = 0 ; idx < NUM_HIDDEN ; ++idx) {
    unit[idx].input = 0.0 ;
    fanIn(&unit[idx], NUM_INPUT, netPtr->i2h[idx], netPtr->input) ;
    fanIn(&unit[idx], NUM_BIAS,  netPtr->b2h[idx], netPtr->bias) ;
    unit[idx].output = sigmoid(unit[idx].input) ;
  }

  unit = netPtr->output ;
  for (idx = 0 ; idx < NUM_OUTPUT ; ++idx) {
    unit[idx].input = 0.0 ;
    fanIn(&unit[idx], NUM_HIDDEN, netPtr->h2o[idx], netPtr->hidden) ;
    fanIn(&unit[idx], NUM_BIAS,   netPtr->b2o[idx], netPtr->bias) ;
    unit[idx].output = sigmoid(unit[idx].input) ;

    patternError += square(unit[idx].output - unit[idx].target) ;
  }

  /* Halve only this pattern's contribution. Dividing the running
   * total would halve the error of all earlier patterns again on
   * every call. */
  netPtr->error += patternError / 2.0 ;
}
/********************************************************************/
/* Add (output * weight) of each of the numIncoming source units to
 * unitPtr->input, in ascending order. The caller clears the input
 * beforehand when a fresh sum is wanted. */
static void	fanIn(Unit	*unitPtr,
		      int	numIncoming,
		      double	*incomingWeight,
		      Unit	*incomingUnit) {
  Unit		*source = incomingUnit ;
  double	*weight = incomingWeight ;
  int		left ;

  for (left = numIncoming ; left > 0 ; --left, ++source, ++weight)
    unitPtr->input += (source->output * *weight) ;
}
/********************************************************************/
/*********************************************************************
 *	Name:		backward (and fanBack, and zeroBetas)
 *	Description:	Do a backward propagation of a pattern through 
 *			a net, updating unit betas, and weights.
 *			Each output unit's beta is its output minus
 *			its target; that beta is fanned back through
 *			the hidden->output weights into the hidden
 *			units' betas. Only after every beta has been
 *			computed are the weights updated, first the
 *			ones into the output units, then the ones
 *			into the hidden units.
 *			Note: pattern must already be presented, and
 *			      forward pass done (i.e. outputs
 *			      calculated).
 *	Parameters:
 *	  Net	*netPtr - the network to do the backward pass through
 *	Return Value:
 *	  static void	backward - NONE
 *********************************************************************/
static void	fanBack      (Unit	*unitPtr, int, double *, Unit *) ;
static void	updateWeights(Unit	*unitPtr, int, double *, Unit *) ;
static void	zeroBetas(Net	*netPtr) ;
/********************************************************************/
static void	backward(Net	*netPtr) {
  Unit	*outUnit = netPtr->output ;
  Unit	*hidUnit = netPtr->hidden ;
  int	o, h ;

  zeroBetas(netPtr) ;

  /* Beta of each output unit, pushed straight back to the hidden
   * layer. Fanning back unit o needs only unit o's own beta, so the
   * hidden betas accumulate in the same order as with two passes.
   * Input and bias units need no beta. */
  for (o = 0 ; o < NUM_OUTPUT ; ++o) {
    outUnit[o].beta = outUnit[o].output - outUnit[o].target ;
    fanBack(&outUnit[o], NUM_HIDDEN, netPtr->h2o[o], hidUnit) ;
  }

  /* Weights into the output layer (from hidden, then from bias). */
  for (o = 0 ; o < NUM_OUTPUT ; ++o) {
    updateWeights(&outUnit[o], NUM_HIDDEN, netPtr->h2o[o], hidUnit) ;
    updateWeights(&outUnit[o], NUM_BIAS,   netPtr->b2o[o], netPtr->bias) ;
  }

  /* Weights into the hidden layer (from input, then from bias). */
  for (h = 0 ; h < NUM_HIDDEN ; ++h) {
    updateWeights(&hidUnit[h], NUM_INPUT, netPtr->i2h[h], netPtr->input) ;
    updateWeights(&hidUnit[h], NUM_BIAS,  netPtr->b2h[h], netPtr->bias) ;
  }
}
/********************************************************************/
/********************************************************************/
/* Add to the beta of each of the numIncoming source units the product
 * of its connecting weight, sigmoidDerivative of unitPtr's output,
 * and unitPtr's beta. */
static void	fanBack(Unit	*unitPtr,
			int	numIncoming,
			double	*incomingWeight,
			Unit	*incomingUnit) {
  int	src = 0 ;

  while (src < numIncoming) {
    incomingUnit[src].beta += (incomingWeight[src]
			       * sigmoidDerivative(unitPtr->output)
			       * unitPtr->beta) ;
    ++src ;
  }
}
/********************************************************************/
/* Step each of the numIncoming weights into unitPtr: subtract
 * STEP_SIZE times (source output * sigmoidDerivative of unitPtr's
 * output * unitPtr's beta). When COST is defined, COST(weight) of the
 * already-stepped weight is subtracted as well. */
static void	updateWeights(Unit	*unitPtr,
			      int	numIncoming,
			      double	*incomingWeight,
			      Unit	*incomingUnit) {
  int	src = 0 ;

  while (src < numIncoming) {
    incomingWeight[src] -= STEP_SIZE * (incomingUnit[src].output
					* sigmoidDerivative(unitPtr->output)
					* unitPtr->beta) ;
#ifdef COST
    incomingWeight[src] -= COST(incomingWeight[src]) ;
#endif
    ++src ;
  }
}
/********************************************************************/
/* Clear the beta of every hidden and every output unit. */
static void	zeroBetas(Net	*netPtr) {
  Unit	*unit ;
  int	left ;

  for (unit = netPtr->hidden, left = NUM_HIDDEN ; left > 0 ; --left, ++unit)
    unit->beta = 0.0 ;

  for (unit = netPtr->output, left = NUM_OUTPUT ; left > 0 ; --left, ++unit)
    unit->beta = 0.0 ;
}
/********************************************************************/
/*********************************************************************
 *	Name:		dump (and dumpRow)
 *	Description:	Writes the contents of a network to stdout:
 *			the network error, the output of each input
 *			unit, then for each hidden and each output
 *			unit its output followed by its incoming
 *			weights from the previous layer (bias
 *			weights are not printed).
 *	Parameters:
 * 	  Net	net - the network to dump (passed by value).
 *	Return Value:
 *	  static void	dump - NONE
 *********************************************************************/
/* Print one unit's output, then its numWeights incoming weights,
 * tab separated, as a single line on stdout. */
static void	dumpRow(double		unitOutput,
			const double	*weight,
			int		numWeights) {
  int	col ;

  printf("%-8g,", unitOutput) ;
  for (col = 0 ; col < numWeights ; col++)
    printf("\t%-8g", weight[col]) ;
  printf("\n") ;
}
/********************************************************************/
static void	dump(Net	net) 
{
  int	row ;

  printf("Error: %-8g\n", net.error) ;

  printf("Input Units:\nOutput\n") ;
  for (row = 0 ; row < NUM_INPUT ; row++)
    printf("%-8g\n", net.input[row].output) ;

  printf("Hidden Units:\nOutput\t\tIncoming Weights\n") ;
  for (row = 0 ; row < NUM_HIDDEN ; row++)
    dumpRow(net.hidden[row].output, net.i2h[row], NUM_INPUT) ;

  printf("Output Units:\nOutput\t\tIncoming Weights\n") ;
  for (row = 0 ; row < NUM_OUTPUT ; row++)
    dumpRow(net.output[row].output, net.h2o[row], NUM_HIDDEN) ;
}
/********************************************************************/
