
/**********************************************************************
 * $Id: hcl.c,v 1.3 92/11/30 11:56:53 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		      Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute, and sell this software
 * and its  documentation for  any purpose is  hereby granted  without
 * fee, provided that the above copyright notice appears in all copies
 * and  that both the  copyright notice  and   this  permission notice
 * appear in   supporting documentation,  and  that the  name  of  The
 * University  of Toronto  not  be  used in  advertising or  publicity
 * pertaining   to  distribution   of  the  software without specific,
 * written prior  permission.   The  University of   Toronto makes  no
 * representations  about  the  suitability of  this software  for any
 * purpose.  It  is    provided   "as is"  without express or  implied
 * warranty.
 *
 * THE UNIVERSITY OF TORONTO DISCLAIMS  ALL WARRANTIES WITH REGARD  TO
 * THIS SOFTWARE,  INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT SHALL THE UNIVERSITY OF TORONTO  BE LIABLE
 * FOR ANY  SPECIAL, INDIRECT OR CONSEQUENTIAL  DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR  PROFITS, WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF  OR  IN  CONNECTION WITH  THE  USE  OR PERFORMANCE  OF THIS
 * SOFTWARE.
 *
 **********************************************************************/

#include <stdio.h>
#include <math.h>

#include <xerion/useful.h>
#include <xerion/version.h>
#include <xerion/simulator.h>
#include <xerion/minimize.h>
#include "hcl.h"
#include "help.h"

/*
 * Forward declarations for the file-local routines defined below.
 * Every parameter list is wrapped in ARGS(()), which is not defined in
 * this file (presumably supplied by one of the xerion headers so the
 * prototypes can be suppressed on pre-ANSI compilers -- confirm).
 */

/* Net hooks and the net-level error / error-derivative procedures. */
static void	initNet                ARGS((Net	net)) ;
static void	deinitNet              ARGS((Net	net)) ;
static void	calculateNetErrorDeriv ARGS((Net net,ExampleSet exampleSet));
static void	calculateNetError      ARGS((Net net,ExampleSet exampleSet));

/* Group hook and the group-wide activity update. */
static void	initGroup              ARGS((Group	group)) ;
static Proc     winnerTakeAllGroupActivityUpdate ARGS((Group group)) ;

/* Unit hook and per-unit helpers. */
static void	initUnit               ARGS((Unit  unit)) ;
static Real	unitError              ARGS((const Unit  unit)) ;
static Proc	gradUpdate             ARGS((Unit  unit)) ;
static void	zeroLinks              ARGS((Unit  unit, void *data)) ;

/* Numeric helper: returns x * x converted to Real. */
static Real	square    ARGS((double  x)) ;

/***********************************************************************
 *	Name:		main
 *	Description:	entry point of the hcl simulator.  Registers the
 *			"scatter" user display, sets the author string,
 *			initializes the simulator, installs the create /
 *			destroy hooks defined in this file and then hands
 *			control to the simulator's command loop.
 *	Parameters:	
 *		int	argc	- the number of input args
 *		char	**argv  - array of argument strings from command 
 *				  line
 *	Return Value:	
 *		int	main	- 0, once ICommandLoop has returned
 *	Notes:		NOTE(review): a create-unit hook is installed but
 *			no matching destroy-unit hook appears in this file,
 *			so the records allocated by initUnit are not
 *			released here -- confirm whether the simulator
 *			frees them.
 ***********************************************************************/
int main (argc, argv)
  int	argc ;
  char	**argv ;
{
  /* Both routines are defined outside this file. */
  extern void	createScatterDisplay() ;
  extern int	addUserDisplay() ;

  /* Register createScatterDisplay under the name "scatter".  This is
   * done before IStandardInit is called, and the int result of
   * addUserDisplay is not examined. */
  addUserDisplay("scatter", createScatterDisplay) ;

  /* 'authors' is not declared in this file (presumably it comes from
   * xerion/version.h -- confirm). */
  authors = "Sue Becker" ;

  /* Perform initialization of the simulator.  argc is passed by
   * address, so the callee is able to modify it. */
  IStandardInit(&argc, argv);

  /* Install the private hooks: initNet / deinitNet for nets,
   * initGroup for groups and initUnit for units. */
  setCreateNetHook  (initNet) ;
  setDestroyNetHook (deinitNet) ;
  setCreateGroupHook(initGroup) ;
  setCreateUnitHook(initUnit) ;

  /* Enter loop that reads commands and handles graphics */
  ICommandLoop(stdin, stdout, NULL);

  return 0 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initNet 
 *	Description:	create-net hook.  Installs the error and
 *			error-derivative procedures, allocates a zero-filled
 *			net extension record and sets the starting values
 *			for epsilon, momentum, radius, the direction method
 *			and the step method.
 *	Parameters:	
 *		Net	net - the net to act on
 *	Return Value:	NONE
 *	Errors:		a failed allocation is reported through
 *			IErrorAbort instead of being silently ignored
 ***********************************************************************/
static void	initNet (net)
  Net	net ;
{
  net->calculateErrorDerivProc = calculateNetErrorDeriv ;
  net->calculateErrorProc      = calculateNetError ;

  /* calloc hands back a zero-filled record, or NULL when out of memory;
   * the NULL case must not pass unnoticed */
  net->extension = (NetExtension)calloc(1, sizeof(NetExtensionRec)) ;
  if (net->extension == NULL)
    IErrorAbort("initNet: unable to allocate net extension record") ;

  /* fixed-step steepest descent with no momentum */
  Mepsilon(net) = 0.0001 ;
  Mmomentum(net) = 0.0 ;
  Mradius(net) = 0.05;
  MdirectionMethod(net) = MZSTEEPEST;
  MstepMethod(net) = MZFIXEDSTEP;
}

/**********************************************************************/
/***********************************************************************
 *	Name:		deinitNet
 *	Description:	destroy-net hook.  Releases the net extension record
 *			and clears the pointer, so that a stale address can
 *			neither be freed a second time nor dereferenced.
 *	Parameters:	
 *		Net	net - the net whose extension record is released
 *	Return Value:	NONE
 ***********************************************************************/
static void	deinitNet (net)
  Net	net ;
{
  /* the guard is kept on purpose: this file is written for pre-ANSI
   * compilers as well, where free(NULL) is not guaranteed to be safe */
  if (net->extension != NULL) {
    free(net->extension) ;
    net->extension = NULL ;
  }
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initGroup 
 *	Description:	create-group hook.  Activities are computed for the
 *			whole group at once by
 *			winnerTakeAllGroupActivityUpdate, so no per-unit
 *			activity procedure is installed; gradients are
 *			handled unit by unit through gradUpdate.
 *	Parameters:	
 *		Group	group - the group whose procedures are set
 *	Return Value:	NONE
 ***********************************************************************/
static void	initGroup (group)
  Group	group ;
{
  /* per-unit procedures */
  group->unitGradientUpdateProc  = gradUpdate ;
  group->unitActivityUpdateProc  = NULL ;

  /* group-wide procedure */
  group->groupActivityUpdateProc = winnerTakeAllGroupActivityUpdate ;
}
/**********************************************************************/

/***********************************************************************
 *	Name:		initUnit
 *	Description:	create-unit hook.  Allocates a zero-filled unit
 *			extension record and attaches it to the unit.
 *	Parameters:	
 *		Unit	unit - the unit to allocate the space for
 *	Return Value:	NONE
 *	Errors:		a failed allocation is reported through
 *			IErrorAbort instead of being silently ignored
 ***********************************************************************/
static void	initUnit (unit)
  Unit	unit ;
{
  /* calloc hands back a zero-filled record, or NULL when out of memory */
  unit->extension = (UnitExtension)calloc(1, sizeof(UnitExtensionRec)) ;
  if (unit->extension == NULL)
    IErrorAbort("initUnit: unable to allocate unit extension record") ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		winnerTakeAllGroupActivityUpdate
 *	Description:	sets each unit's totalInput to the squared difference
 *			between its weight vector and input vector
 *                      The unit with the SMALLEST total input wins the 
 *                      competition and outputs a 1. The rest output 0.
 *                      Returns if the proc is NULL
 *      Parameters:
 *		Group	group - the group object ;
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static Proc	winnerTakeAllGroupActivityUpdate (group)
  Group		group ;
{
  Unit *unit;
  int unitIndex;
  int winner = -1;
  Real totalInput, minTotalInput = HUGE_VAL;

  if (group->type & INPUT) return;
  for (unitIndex = 0, unit = group->unit ;
       unitIndex < group->numUnits ; ++unitIndex, ++unit) {
     int	numIncoming = (*unit)->numIncoming ;
     Link	*incoming   = (*unit)->incomingLink ;
     int	linkIndex ;
 
     (*unit)->output = 0.0;
     totalInput = 0.0;
     for (linkIndex = 0 ; linkIndex < numIncoming ; ++linkIndex) {
        Link	link = incoming[linkIndex] ;
        totalInput += square(link->weight - (link->preUnit->output));
      }
     (*unit)->totalInput = totalInput;
     if (totalInput < minTotalInput) {
       minTotalInput = totalInput;
       winner = unitIndex;
     }
   }
   if (winner < 0)     
    IErrorAbort("winnerTakeAllGroupActivityUpdate: no units in group") ;
   else {
     unit = (group->unit)+winner;
     (*unit)->output = 1.0;
     (*unit)->net->error += unitError(*unit);
   }
}

/**********************************************************************/


/***********************************************************************
 *	Name:		unitError
 *	Description:	error contribution of a single unit.  A unit whose
 *			output is below 1 contributes nothing; any other
 *			unit contributes its totalInput.
 *	Parameters:	
 *		const Unit	unit - the unit to calculate the error of
 *	Return Value:	
 *		Real		unitError - 0.0 or the unit's totalInput
 ***********************************************************************/
static Real	unitError(unit)
  const Unit	unit ;
{
  /* the comparison is deliberately "output < 1.0": a NaN output fails
   * it and therefore yields totalInput */
  return (unit->output < 1.0) ? 0.0 : unit->totalInput ;
}
/**********************************************************************/

/***********************************************************************
 *	Name:		gradUpdate
 *	Description:	accumulates into the deriv field of every incoming
 *			link of a unit whose output is at least 1.  Each
 *			link receives (weight - output of its preUnit).
 *			Units with a smaller output are left untouched.
 *			Note that the amount added is the plain difference,
 *			i.e. half the derivative of the squared distance
 *			summed in winnerTakeAllGroupActivityUpdate.
 *	Parameters:	
 *		Unit	unit - the unit to set the grads for
 *	Return Value:	NONE
 ***********************************************************************/
static Proc	gradUpdate(unit)
  Unit	unit ;
{
  Link	*cursor ;
  int	remaining ;

  /* only a unit with output >= 1.0 has its links adjusted */
  if (!(unit->output >= 1.0))
    return ;

  /* walk the incoming links front to back */
  cursor = unit->incomingLink ;
  for (remaining = unit->numIncoming ; remaining > 0 ; --remaining) {
    Link	link = *cursor++ ;

    link->deriv -= (link->preUnit->output - link->weight) ;
  }
}
/**********************************************************************/

/***********************************************************************
 *	Name:		calculateNetErrorDeriv
 *	Description:	error-derivative procedure installed on the net by
 *			initNet.  Clears the net error and the deriv field
 *			of the incoming links of every unit selected by the
 *			mask ~(INPUT | BIAS), then processes
 *			'MbatchSize(net)' examples, updating activities and
 *			gradients for each, and finally evaluates cost and
 *			derivatives.
 *	Parameters:	
 *		Net		net - the net to use
 *		ExampleSet	exampleSet - the examples to use
 *	Return Value:	
 *		NONE
 *	Errors:		reports through IErrorAbort when no example was
 *			processed
 ***********************************************************************/
static void	calculateNetErrorDeriv(net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  int		processed = 0 ;

  /* start from a clean slate: no error, no accumulated derivatives */
  net->error = 0.0 ;
  netForAllUnits(net, ~(INPUT | BIAS), zeroLinks, NULL) ;

  /* one forward pass and one gradient pass per example; the batch
   * size is re-read on every iteration */
  while (processed < MbatchSize(net)) {
    MgetNext(exampleSet) ;
    MupdateNetActivities(net) ;
    MupdateNetGradients(net) ;
    ++processed ;
  }

  if (processed <= 0)
    IErrorAbort("calculateNetErrorDeriv: no examples processed") ;

  /* the cost is brought up to date last */
  MevaluateCostAndDerivs(net) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		calculateNetError
 *	Description:	error procedure installed on the net by initNet.
 *			Clears the net error, processes 'MbatchSize(net)'
 *			examples with an activity update for each, and then
 *			evaluates the cost.  No derivatives are touched.
 *	Parameters:	
 *		Net		net - the net to use
 *		ExampleSet	exampleSet - the examples to use
 *	Return Value:	
 *		NONE
 *	Errors:		reports through IErrorAbort when no example was
 *			processed
 ***********************************************************************/
static void	calculateNetError(net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  int		processed = 0 ;

  net->error = 0.0 ;

  /* forward pass only; the batch size is re-read on every iteration */
  while (processed < MbatchSize(net)) {
    MgetNext(exampleSet) ;
    MupdateNetActivities(net) ;
    ++processed ;
  }

  if (processed <= 0)
    IErrorAbort("calculateNetError: no examples processed") ;

  /* the cost is brought up to date last */
  MevaluateCost(net) ;
}
/**********************************************************************/

/**********************************************************************/
/***********************************************************************
 *	Name:		zeroLinks
 *	Description:	sets the deriv field of every incoming link of a
 *			unit to zero.  Passed to netForAllUnits by
 *			calculateNetErrorDeriv.
 *	Parameters:	
 *		Unit	unit - the unit whose incoming links are cleared
 *		void	*data - not used
 *	Return Value:	NONE
 ***********************************************************************/
static void	zeroLinks(unit, data)
  Unit	unit ;
  void	*data ;
{
  Link	*cursor = unit->incomingLink ;
  int	remaining ;

  for (remaining = unit->numIncoming ; remaining > 0 ; --remaining) {
    (*cursor)->deriv = 0.0 ;
    ++cursor ;
  }
}

/**********************************************************************/

/***********************************************************************
 *	Name:		square
 *	Description:	multiplies a value by itself
 *	Parameters:	
 *		double	x - the value to square
 *	Return Value:	
 *		Real	square - x * x, converted to Real
 ***********************************************************************/
static Real	square(x)
  double	x ;
{
  Real	result ;

  result = (Real) (x * x) ;
  return result ;
}
/**********************************************************************/

