
/**********************************************************************
 * $Id: fem.c,v 1.2 92/11/30 11:55:36 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		       Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute,  and sell this software
 * and its documentation for any purpose is hereby granted without fee,
 * provided  that the above copyright notice  appears in all copies and
 * that both the copyright notice and this permission notice  appear in
 * supporting documentation, and  that  the  name of The University  of
 * Toronto  not  be used  in advertising   or publicity pertaining   to
 * distribution  of   the software   without  specific, written   prior
 * permission.  The  University  of Toronto  makes   no representations
 * about the  suitability  of  this software  for  any purpose.   It is
 * provided "as is" without express or implied warranty.
 *
 * THE  UNIVERSITY OF  TORONTO DISCLAIMS ALL WARRANTIES  WITH REGARD TO
 * THIS SOFTWARE,  INCLUDING ALL  IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT  SHALL THE UNIVERSITY  OF TORONTO BE LIABLE
 * FOR ANY SPECIAL,  INDIRECT OR CONSEQUENTIAL  DAMAGES  OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF  USE, DATA OR PROFITS,  WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE  OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF OR  IN  CONNECTION   WITH  THE  USE OR  PERFORMANCE  OF THIS
 * SOFTWARE.
 **********************************************************************/

 /*********************************************************************
 *
 *  MFT/FEM/Boltzmann Modules written by     Evan W. Steeg
 *                                           Dept. of Computer Science
 *  August 1991                              Univ. of Toronto
 *
 **********************************************************************/

#include <stdio.h>
#include <math.h>

#include <xerion/useful.h>
#include <xerion/version.h>
#include <xerion/simulator.h>
#include "fem.h"
#include "fem-train.h"
#include "help.h"

/***********************************************************************
 *	Private functions
 ***********************************************************************/
static void	initNet    ARGS((Net	 net)) ;
static void	deinitNet  ARGS((Net	 net)) ;
static void	initGroup  ARGS((Group	group)) ;
static void	initLink   ARGS((Link	 link)) ;
static void	deinitLink ARGS((Link	 link)) ;

static void	netActivityUpdate  ARGS((Net	net)) ;
static void	unitActivityUpdate ARGS((Unit	unit)) ;
static void	unitGradientUpdate ARGS((Unit	unit)) ;

static void	clampOutput ARGS((Unit	unit, void	*data)) ;
static void	setOutput   ARGS((Unit	unit, void	*data)) ;
static void	setOutputProbActual ARGS((Unit	unit, void	*data)) ;
static void     unitUpdateEnergies  ARGS((Unit  unit, void      *data)) ;
static void	setIncomingProducts ARGS((Unit	unit, void	*data)) ;
static void	updateActivity      ARGS((Unit	unit, void	*data)) ;
/***********************************************************************/

/* Trace object fired (via IDoTrace) once per relaxation sweep in
 * netActivityUpdate whenever the net is not in "running" mode
 * (Mrunning(net) != TRUE), so that the relaxation can be observed. */
struct TRACE    inRelaxation ;

/***********************************************************************
 *	Name:		installFemHooks
 *	Description:	registers this module's create/destroy hooks for
 *			nets, groups and links with the simulator. Must run
 *			before IStandardInit so that every object the
 *			simulator creates afterwards gets the FEM setup.
 *	Parameters:	NONE
 *	Return Value:	NONE
 ***********************************************************************/
static void	installFemHooks()
{
  /* Nets: FEM error procedures, activity procedure, extension record */
  setCreateNetHook  (initNet) ;
  setDestroyNetHook (deinitNet) ;

  /* Groups: per-unit activity and gradient procedures */
  setCreateGroupHook(initGroup) ;

  /* Links: per-link extension record */
  setCreateLinkHook (initLink) ;
  setDestroyLinkHook(deinitLink) ;
}

/***********************************************************************
 *	Name:		main 
 *	Description:	entry point of the FEM flavour of the xerion
 *			simulator: installs the FEM hooks, performs the
 *			simulator's standard initialization and then runs
 *			the command loop on stdin/stdout.
 *	Parameters:	
 *		int	argc	- the number of command line arguments
 *		char	**argv  - the command line argument strings
 *	Return Value:	
 *		int	main	- always 0
 ***********************************************************************/
int main (argc, argv)
  int	argc ;
  char	**argv ;
{
  authors = "Evan Steeg" ;
  installFemHooks() ;

  /* Standard simulator initialization (it is handed the command line),
   * then the loop that reads commands and handles graphics. */
  IStandardInit(&argc, argv) ;
  ICommandLoop(stdin, stdout, NULL) ;

  return 0 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initNet 
 *	Description:	create hook for nets. Installs the FEM error and
 *			error-derivative procedures and the FEM
 *			activityUpdateProc, allocates a zeroed extension
 *			record and initializes the annealing schedule
 *			(T_MAX, T_MIN, T_DECAY) and the running flag.
 *			If the extension cannot be allocated, a message is
 *			printed on stderr and net->extension is left NULL.
 *	Parameters:	
 *		Net	net - the net to act on
 *	Return Value:	NONE
 ***********************************************************************/
static void	initNet(net)
  Net	net ;
{
  net->calculateErrorDerivProc = calculateNetErrorDeriv ;
  net->calculateErrorProc      = calculateNetError ;
  net->activityUpdateProc      = netActivityUpdate ;

  net->extension = (NetExtension)calloc(1, sizeof(NetExtensionRec)) ;
  if (net->extension == NULL) {
    /* Every Mxxx(net) accessor below goes through net->extension, so
     * continuing would dereference NULL right here. */
    fprintf(stderr, "initNet: unable to allocate net extension record\n") ;
    return ;
  }

  MtMax(net)    = T_MAX ;
  MtMin(net)    = T_MIN ;
  MtDecay(net)  = T_DECAY ;
  Mrunning(net) = FALSE ;
}
/**********************************************************************/
/***********************************************************************
 *	Name:		deinitNet
 *	Description:	destroy hook for nets: frees the extension record
 *			allocated by initNet and clears the pointer so the
 *			net never holds a dangling extension.
 *	Parameters:	
 *		Net	net - the net being destroyed
 *	Return Value:	NONE
 ***********************************************************************/
static void	deinitNet(net)
  Net	net ;
{
  /* The NULL test is kept for pre-ANSI libraries where free(NULL)
   * was not guaranteed to be harmless. */
  if (net->extension != NULL) {
    free(net->extension) ;
    net->extension = NULL ;
  }
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initGroup 
 *	Description:	create hook for groups: makes every unit of the
 *			group use the FEM per-unit gradient and activity
 *			update procedures.
 *	Parameters:	
 *		Group	group - the newly created group
 *	Return Value:	NONE
 ***********************************************************************/
static void	initGroup(group)
  Group	group ;
{
  group->unitGradientUpdateProc = unitGradientUpdate ;
  group->unitActivityUpdateProc = unitActivityUpdate ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initLink 
 *	Description:	create hook for links: allocates a zeroed link
 *			extension record (which holds the Mprod value).
 *			If the allocation fails, a message is printed on
 *			stderr and link->extension is left NULL.
 *	Parameters:	
 *		Link	link - the link to act on
 *	Return Value:	NONE
 ***********************************************************************/
static void	initLink(link)
  Link	link ;
{
  link->extension = (LinkExtension)calloc(1, sizeof(LinkExtensionRec)) ;
  if (link->extension == NULL) {
    /* Report it here: otherwise the failure only surfaces later as a
     * NULL dereference inside Mprod(link). */
    fprintf(stderr, "initLink: unable to allocate link extension record\n") ;
  }
}
/**********************************************************************/
/***********************************************************************
 *	Name:		deinitLink
 *	Description:	destroy hook for links: frees the extension record
 *			allocated by initLink and clears the pointer so the
 *			link never holds a dangling extension.
 *	Parameters:	
 *		Link	link - the link being destroyed
 *	Return Value:	NONE
 ***********************************************************************/
static void	deinitLink(link)
  Link	link ;
{
  /* The NULL test is kept for pre-ANSI libraries where free(NULL)
   * was not guaranteed to be harmless. */
  if (link->extension != NULL) {
    free(link->extension) ;
    link->extension = NULL ;
  }
}
/**********************************************************************/


/***********************************************************************
 *	Name:		dotProduct
 *	Description:	sums weight * presynaptic output over all incoming
 *			links of a unit, caches the sum in the unit's
 *			totalInput field and returns it.
 *	Parameters:	
 *		Unit	unit - the unit whose net input is computed
 *	Return Value:	
 *		Real	dotProduct - the weighted input sum
 ***********************************************************************/
static Real	dotProduct(unit)
  Unit	unit ;
{
  Real	sum = 0.0 ;
  int	k   = 0 ;

  while (k < unit->numIncoming) {
    Link	lnk = unit->incomingLink[k++] ;

    sum += lnk->weight * lnk->preUnit->output ;
  }

  unit->totalInput = sum ;
  return sum ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		netActivityUpdate
 *	Description:	activates the network for one example, assuming the
 *			example has already been presented. The steps are:
 *			1. Reset the link products and the energy terms.
 *			2. Start all hidden units (not INPUT, OUTPUT or
 *			   BIAS) at 0.5.
 *			3. Relax them in random order, cooling from tMax by
 *			   tDecay when annealing.
 *			4. Accumulate the link products s_i*s_j.
 *			5. Compute energy, entropy and free energy, and from
 *			   them the probability the net assigns to the
 *			   example being a positive case.
 *	Parameters:	
 *		Net	net - the net to activate.
 *	Return Value:	NONE
 ***********************************************************************/
static void	netActivityUpdate(net)
  Net		net ;
{
  Real		tMin   = MtMin(net) ;
  Real		tMax   = MtMax(net) ;
  Real		tDecay = MtDecay(net) ;
  int		relaxSweepCount ;
  Boolean	stopAnnealing ;

  /* The mode is passed through the void* argument; go via long so the
   * integer-to-pointer conversion is not a size-changing cast. */
  netForAllUnits(net, ALL, setIncomingProducts, (void *)(long)CLEAR) ;
  Menergy(net) = Mentropy(net) = MfreeEnergy(net) = MprobActual(net) = 0.0;

  /* Positive phase */
  netForAllUnits(net, ~(INPUT | OUTPUT | BIAS), setOutput, NULL) ;
  netForAllUnits(net, OUTPUT, clampOutput, NULL) ;

  relaxSweepCount = 0 ;
  Mtemp(net)      = tMax ;
  for (;;) {
    /* MmaxDelta is a per-sweep statistic: unitActivityUpdate only ever
     * raises it, so it must be reset before every sweep. Otherwise the
     * tolerance test below could never succeed after the first sweep. */
    MmaxDelta(net) = 0.0 ;
    relaxSweepCount++ ;
    netForAllUnitsRandom(net, ~(INPUT | OUTPUT | BIAS), updateActivity, NULL) ;

    stopAnnealing = (!Manneal(net) || (Mtemp(net) <= tMin) ||
		     (MmaxDelta(net) <= MrelaxTolerance(net))) ? TRUE : FALSE ;

    if (Mrunning(net) != TRUE) {
      /* Bound computed in long so a large delay count cannot overflow
       * int; the counter is volatile so an optimizing compiler cannot
       * delete this otherwise side-effect-free delay loop. */
      long		delayLimit = (long)MdelayCount(net) * 1000L ;
      volatile long	spin ;

      IDoTrace(&inRelaxation) ;
      /* For viewing relaxation */
      for (spin = 0 ; spin < delayLimit ; spin++)
	;
    }

    if (stopAnnealing == TRUE)
      break ;

    /* Cool only when another sweep follows, so that Mtemp is still the
     * temperature of the final sweep when the free energy is computed. */
    Mtemp(net) *= tDecay ;
  }
  /* Running total of sweeps; NOTE(review): presumably averaged elsewhere. */
  MrelaxSweepCountAve(net) += relaxSweepCount ;
  netForAllUnits(net, ALL, setIncomingProducts, (void *)(long)POSITIVE) ;

  /* Calculate Energy, Entropy, and Free Energy of net. From that,
   * compute the actual probability the net gives of the example being
   * a "positive case". */
  netForAllUnits(net, ALL, unitUpdateEnergies, NULL);

  /* E = -1/2 * sum over links of weight * s_i * s_j */
  Menergy(net)     = 0.5 * (0.0 - Menergy(net));
  MfreeEnergy(net) = Menergy(net) - (Mtemp(net) * Mentropy(net));
  MprobActual(net) = 1.0 / (1 + exp(MfreeEnergy(net)));

  netForAllUnits(net, OUTPUT, setOutputProbActual, NULL);

  McaseError(net) = MprobActual(net) - MprobDesired(net);
}
/**********************************************************************/


/***********************************************************************
 *	Name:		unitActivityUpdate
 *	Description:	one mean-field update of a single unit. The net
 *			input is divided by the current temperature and
 *			passed through tanh (when MunitFunction is set) or
 *			the sigmoid. The output then moves the fraction
 *			MrelaxStepSize of the way toward that value.
 *			MmaxDelta(net) is raised to |change| when larger.
 *	Parameters:	
 *		Unit	unit - the unit to update.
 *	Return Value:	NONE
 ***********************************************************************/
static void	unitActivityUpdate(unit)
  Unit		unit ;
{
  Net	net = unit->net ;
  Real	change ;

  /* dotProduct() also caches the sum in unit->totalInput */
  unit->totalInput = dotProduct(unit) ;

  if (!MunitFunction(net))
    change = sigmoid(unit->totalInput/Mtemp(net)) - unit->output ;
  else
    change = tanh(unit->totalInput/Mtemp(net)) - unit->output ;

  unit->output  += change * MrelaxStepSize(net) ;
  MmaxDelta(net) = MAX(MmaxDelta(net), fabs(change)) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		unitGradientUpdate
 *	Description:	adds this example's contribution to the derivative
 *			of every incoming link of a unit:
 *			    deriv += (probActual - probDesired) * Mprod(link)
 *			where McaseError(net) holds probActual - probDesired.
 *	Parameters:	
 *		Unit	unit - the unit whose incoming links are updated.
 *	Return Value:	NONE
 ***********************************************************************/
static void	unitGradientUpdate(unit)
  Unit	unit ;
{
  Link	*links = unit->incomingLink ;
  Real	err    = McaseError(unit->net) ;
  int	k      = unit->numIncoming ;

  /* Each link's deriv is independent, so the visiting order is free */
  while (k-- > 0) {
    Link	cur = links[k] ;

    cur->deriv += err * Mprod(cur) ;
  }
}
/**********************************************************************/


/**********************************************************************/
static void	clampOutput(unit, data)
  Unit		unit ;
  void		*data ;
{
 MprobDesired(unit->net) = unit->target;
}
/**********************************************************************/
/* Per-unit callback (data unused): puts the unit's output at the neutral
 * starting value 0.5 used before relaxation begins. */
static void	setOutput(unit, data)
  Unit		unit ;
  void		*data ;
{
  (void)data ;
  unit->output = 0.5 ;
}
/**********************************************************************/
static void	setOutputProbActual(unit, data)
  Unit		unit ;
  void		*data ;
{
  unit->output = MprobActual(unit->net);
}
/**********************************************************************/
/* Per-unit callback (data unused) adapting MupdateUnitActivity to the
 * (unit, data) signature expected by netForAllUnitsRandom. */
static void	updateActivity(unit, data)
  Unit		unit ;
  void		*data ;
{
  (void)data ;
  MupdateUnitActivity(unit) ;
}
/**********************************************************************/
/***********************************************************************
 *	Name:		setIncomingProducts
 *	Description:	per-unit callback that maintains Mprod for every
 *			incoming link of a unit.
 *	Parameters:	
 *		Unit	unit - the unit whose incoming links are updated
 *		void	*data - the mode, an integer carried in the pointer:
 *				POSITIVE adds pre->output * post->output to
 *				Mprod(link); CLEAR resets Mprod(link) to 0.0;
 *				any other value leaves the links untouched.
 *	Return Value:	NONE
 ***********************************************************************/
static void     setIncomingProducts(unit, data)
  Unit          unit ;
  void          *data ;
{
  /* Go through long: casting a pointer straight to int is a narrowing,
   * implementation-defined conversion on LP64 hosts (and triggers a
   * pointer-to-int-cast warning). The mode value itself is small. */
  int           mode        = (int)(long)data ;
  int           numIncoming = unit->numIncoming ;
  Link          *incoming   = unit->incomingLink ;
  int           idx ;

  if (mode == POSITIVE) {
    for (idx = 0 ; idx < numIncoming ; ++idx) {
      Link        link = incoming[idx] ;

      Mprod(link) += link->preUnit->output * link->postUnit->output ;
    }
  } else if (mode == CLEAR) {
    for (idx = 0 ; idx < numIncoming ; ++idx)
      Mprod(incoming[idx]) = 0.0 ;
  }
}
/**********************************************************************/
/***********************************************************************
 *	Name:		unitUpdateEnergies
 *	Description:	per-unit callback (data unused) that adds
 *			weight * Mprod(link) for each incoming link to the
 *			net's Menergy, and ENTROPY(output) to Mentropy.
 *			netActivityUpdate turns these running sums into the
 *			energy and free energy.
 *	Parameters:	
 *		Unit	unit - the unit whose contribution is added
 *		void	*data - unused
 *	Return Value:	NONE
 ***********************************************************************/
static void	unitUpdateEnergies(unit, data)
  Unit		unit ;
  void		*data ;
{
  Net	net = unit->net ;
  int	k ;

  (void)data ;

  /* Accumulate link by link, in order, directly into the net total */
  for (k = 0 ; k < unit->numIncoming ; k++) {
    Link	lnk = unit->incomingLink[k] ;

    Menergy(net) += lnk->weight * Mprod(lnk) ;
  }

  Mentropy(net) += ENTROPY(unit->output) ;
}
/**********************************************************************/
