/**********************************************************************
 * $Id: mft.c,v 1.2 92/11/30 12:00:10 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		       Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute,  and sell this software
 * and its documentation for any purpose is hereby granted without fee,
 * provided  that the above copyright notice  appears in all copies and
 * that both the copyright notice and this permission notice  appear in
 * supporting documentation, and  that  the  name of The University  of
 * Toronto  not  be used  in advertising   or publicity pertaining   to
 * distribution  of   the software   without  specific, written   prior
 * permission.  The  University  of Toronto  makes   no representations
 * about the  suitability  of  this software  for  any purpose.   It is
 * provided "as is" without express or implied warranty.
 *
 * THE  UNIVERSITY OF  TORONTO DISCLAIMS ALL WARRANTIES  WITH REGARD TO
 * THIS SOFTWARE,  INCLUDING ALL  IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT  SHALL THE UNIVERSITY  OF TORONTO BE LIABLE
 * FOR ANY SPECIAL,  INDIRECT OR CONSEQUENTIAL  DAMAGES  OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF  USE, DATA OR PROFITS,  WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE  OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF OR  IN  CONNECTION   WITH  THE  USE OR  PERFORMANCE  OF THIS
 * SOFTWARE.
 **********************************************************************/

 /*********************************************************************
 *
 *  MFT/FEM/Boltzmann Modules written by     Evan W. Steeg
 *                                           Dept. of Computer Science
 *  August 1991                              Univ. of Toronto
 *
 **********************************************************************/

#include <stdio.h>
#include <math.h>

#include <xerion/useful.h>
#include <xerion/version.h>
#include <xerion/simulator.h>

#include "mft.h"
#include "train.h"
#include "help.h"

/* Create/destroy hooks registered with the simulator in main() for
   nets, groups, units and links.  Groups only get a create hook. */
static void	initNet     ARGS((Net	 net)) ;
static void	deinitNet   ARGS((Net	 net)) ;
static void	initGroup   ARGS((Group	group)) ;
static void	initUnit    ARGS((Unit	 unit)) ;
static void	deinitUnit  ARGS((Unit	 unit)) ;
static void	initLink    ARGS((Link	 link)) ;
static void	deinitLink  ARGS((Link	 link)) ;

/* Procedures installed by the hooks: netActivityUpdate by initNet,
   unitActivityUpdate/unitGradientUpdate by initGroup.  updateNetError
   is the per-unit callback netActivityUpdate passes to netForAllUnits. */
static void	netActivityUpdate  ARGS((Net	net)) ;
static void	unitActivityUpdate ARGS((Unit	unit)) ;
static void	unitGradientUpdate ARGS((Unit	unit)) ;
static void	updateNetError     ARGS((Unit	unit, void	*data)) ;

/* Trace fired (IDoTrace) by netActivityUpdate after each annealing
   sweep while Mrunning(net) != TRUE.  Deliberately not static --
   presumably so the simulator's trace commands can reach it by name
   (TODO confirm). */
struct TRACE    inRelaxation ;


/***********************************************************************
 *	Name:		main 
 *	Description:	entry point of the MFT simulator.  Installs this
 *			module's create/destroy hooks, performs the
 *			standard xerion start-up on the command line and
 *			then runs the interactive command loop (commands
 *			read from stdin, output written to stdout).
 *	Parameters:	
 *		int	argc	- number of command-line arguments
 *		char	**argv  - the command-line argument strings
 *	Return Value:	
 *		int	main	- always 0
 ***********************************************************************/
int main (argc, argv)
  int	argc ;
  char	**argv ;
{
  authors = "Evan Steeg" ;

  /* Module-specific hooks go in before the simulator is initialized. */
  setCreateNetHook   (initNet) ;
  setDestroyNetHook  (deinitNet) ;
  setCreateGroupHook (initGroup) ;
  setCreateUnitHook  (initUnit) ;
  setDestroyUnitHook (deinitUnit) ;
  setCreateLinkHook  (initLink) ;
  setDestroyLinkHook (deinitLink) ;

  /* argc is passed by address -- presumably so the standard start-up
     can strip the options it consumes (TODO confirm). */
  IStandardInit(&argc, argv) ;

  /* Interactive session; returns when the command loop ends. */
  ICommandLoop(stdin, stdout, NULL) ;

  return 0 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initNet 
 *	Description:	create hook for nets.  Installs this module's
 *			error, error-derivative and activity-update
 *			procedures, allocates a zero-filled extension
 *			record and sets the default annealing schedule
 *			(T_MAX, T_MIN, T_DECAY) with Mrunning cleared.
 *
 *			If the allocation fails, a message is printed on
 *			stderr and net->extension is left NULL (deinitNet
 *			already checks for NULL); the annealing defaults
 *			are then not written.
 *	Parameters:	
 *		Net	net - the net to act on
 *	Return Value:	NONE
 ***********************************************************************/
static void	initNet(net)
  Net	net ;
{
  net->calculateErrorDerivProc = calculateNetErrorDeriv ;
  net->calculateErrorProc      = calculateNetError ;
  net->activityUpdateProc      = netActivityUpdate ;

  net->extension = (NetExtension)calloc(1, sizeof(NetExtensionRec)) ;
  if (net->extension == NULL) {
    /* The M* accessors below presumably store into this record
       (defined in mft.h -- confirm), so they must not run on a
       failed allocation. */
    fprintf(stderr, "initNet: out of memory allocating net extension\n") ;
    return ;
  }

  MtMax(net)    = T_MAX ;
  MtMin(net)    = T_MIN ;
  MtDecay(net)  = T_DECAY ;
  Mrunning(net) = FALSE ;
}
/**********************************************************************/
/***********************************************************************
 *	Name:		deinitNet 
 *	Description:	destroy hook for nets.  Frees the extension record
 *			allocated by initNet, if any, and clears the
 *			pointer so the net never refers to freed memory
 *			(a repeated call is then harmless).
 *	Parameters:	
 *		Net	net - the net being destroyed
 *	Return Value:	NONE
 ***********************************************************************/
static void	deinitNet(net)
  Net	net ;
{
  if (net->extension != NULL) {
    free(net->extension) ;
    net->extension = NULL ;
  }
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initGroup 
 *	Description:	create hook for groups.  Makes every unit of the
 *			group use this module's gradient and activity
 *			update procedures.
 *	Parameters:	
 *		Group	group - the newly created group
 *	Return Value:	NONE
 ***********************************************************************/
static void	initGroup(group)
  Group	group ;
{
  group->unitGradientUpdateProc = unitGradientUpdate ;
  group->unitActivityUpdateProc = unitActivityUpdate ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initUnit 
 *	Description:	create hook for units.  Allocates a zero-filled
 *			unit extension record.  If the allocation fails,
 *			a message is printed on stderr and unit->extension
 *			is left NULL (deinitUnit checks for NULL).
 *	Parameters:	
 *		Unit	unit - the unit to act on
 *	Return Value:	NONE
 ***********************************************************************/
static void	initUnit(unit)
  Unit	unit ;
{
  unit->extension = (UnitExtension)calloc(1, sizeof(UnitExtensionRec)) ;
  if (unit->extension == NULL)
    fprintf(stderr, "initUnit: out of memory allocating unit extension\n") ;
}
/**********************************************************************/
/***********************************************************************
 *	Name:		deinitUnit 
 *	Description:	destroy hook for units.  Frees the extension
 *			record allocated by initUnit, if any, and clears
 *			the pointer so it cannot dangle or be freed twice.
 *	Parameters:	
 *		Unit	unit - the unit being destroyed
 *	Return Value:	NONE
 ***********************************************************************/
static void	deinitUnit(unit)
  Unit	unit ;
{
  if (unit->extension != NULL) {
    free(unit->extension) ;
    unit->extension = NULL ;
  }
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initLink 
 *	Description:	create hook for links.  Allocates a zero-filled
 *			link extension record.  If the allocation fails,
 *			a message is printed on stderr and link->extension
 *			is left NULL (deinitLink checks for NULL).
 *	Parameters:	
 *		Link	link - the link to act on
 *	Return Value:	NONE
 ***********************************************************************/
static void	initLink(link)
  Link	link ;
{
  link->extension = (LinkExtension)calloc(1, sizeof(LinkExtensionRec)) ;
  if (link->extension == NULL)
    fprintf(stderr, "initLink: out of memory allocating link extension\n") ;
}
/**********************************************************************/
/***********************************************************************
 *	Name:		deinitLink 
 *	Description:	destroy hook for links.  Frees the extension
 *			record allocated by initLink, if any, and clears
 *			the pointer so it cannot dangle or be freed twice.
 *	Parameters:	
 *		Link	link - the link being destroyed
 *	Return Value:	NONE
 ***********************************************************************/
static void	deinitLink(link)
  Link	link ;
{
  if (link->extension != NULL) {
    free(link->extension) ;
    link->extension = NULL ;
  }
}
/**********************************************************************/


/***********************************************************************
 *	Name:		unitError
 *	Description:	error of a single unit under the info() measure.
 *			With a positive zero-error radius r, any output
 *			inside [target - r, target + r] (clipped to
 *			[0, 1]) costs nothing; an output outside that
 *			band is measured against the nearest band edge.
 *			With r <= 0 the output is measured against the
 *			target itself.
 *	Parameters:	
 *		const Unit	unit - the unit to calculate the error of
 *	Return Value:	
 *		Real	unitError - the error of the unit
 ***********************************************************************/
static Real	unitError(unit)
  const Unit	unit ;
{
  Real		high, low ;

  if (MzeroErrorRadius(unit->net) <= 0.0)
    return info(unit->output, unit->target) ;

  high = unit->target + MzeroErrorRadius(unit->net) ;
  low  = unit->target - MzeroErrorRadius(unit->net) ;
  high = MIN(1.0, high) ;
  low  = MAX(0.0, low) ;

  if (unit->output > high)
    return info(unit->output, high) ;
  if (unit->output < low)
    return info(unit->output, low) ;

  /* Inside the tolerance band. */
  return 0.0 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		dotProduct
 *	Description:	weighted sum of a unit's incoming links.  Each
 *			link contributes weight * (presynaptic value),
 *			where the presynaptic value is the saved old
 *			output Mold(preUnit) when the net uses
 *			synchronous updates, and preUnit->output
 *			otherwise.  The sum is stored in
 *			unit->totalInput and also returned.
 *	Parameters:	
 *		Unit	unit - the unit to calculate the dot product
 *			       for (its totalInput field is written)
 *	Return Value:	
 *		Real	dotProduct - the dot product
 ***********************************************************************/
static Real	dotProduct(unit)
  Unit	unit ;
{
  int	numIncoming = unit->numIncoming ;
  Link	*incoming   = unit->incomingLink ;
  Real	totalInput  = 0.0 ;
  int	idx ;

  /* The update mode cannot change while summing, so test it once
     instead of once per link. */
  if (MsynchronousUpdate(unit->net)) {
    for (idx = 0 ; idx < numIncoming ; ++idx)
      totalInput += incoming[idx]->weight * Mold(incoming[idx]->preUnit) ;
  } else {
    for (idx = 0 ; idx < numIncoming ; ++idx)
      totalInput += incoming[idx]->weight * incoming[idx]->preUnit->output ;
  }

  unit->totalInput = totalInput ;
  return totalInput ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		updateNetError
 *	Description:	netForAllUnits callback: adds the error of one
 *			unit (as computed by unitError) to the error
 *			total of that unit's net.
 *	Parameters:	
 *		Unit	unit - the unit whose error is accumulated
 *		void	*data - unused
 *	Return Value:	NONE
 ***********************************************************************/
static void	updateNetError(unit, data)
  Unit		unit ;
  void		*data ;
{
  unit->net->error += unitError(unit) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		netActivityUpdate
 *	Description:	activates the network for an example that has
 *			already been presented, by annealing.  Every
 *			sweep updates all non-INPUT, non-BIAS units; the
 *			temperature starts at MtMax and is multiplied by
 *			MtDecay after each sweep until STOP_ANNEALING(net)
 *			holds.  Afterwards the error of every OUTPUT unit
 *			is added to net->error.
 *
 *			Synchronous mode visits units in order, and every
 *			unit reads the outputs saved (Mold) at the end of
 *			the previous sweep.  Asynchronous mode visits units
 *			in random order and reads current outputs.
 *			MmaxDelta holds the largest output change of the
 *			most recent sweep.
 *	Parameters:	
 *		Net	net - the net to activate.
 *	Return Value:	NONE
 ***********************************************************************/
static void	netActivityUpdate(net)
  Net		net ;
{
  /* NOTE(review): tMin is not used directly below; kept in case the
     STOP_ANNEALING macro (defined elsewhere) refers to it -- confirm. */
  Real		tMin   = MtMin(net) ;
  Real		tMax   = MtMax(net) ;
  Real		tDecay = MtDecay(net) ;

  /* Anneal until (user specified no annealing) or (finished annealing
     schedule) or (reached equilibrium).                               */

  netForAllUnits(net, ~(INPUT | BIAS), setOutput, NULL) ;

  /* Synchronous sweeps read Mold, so seed it before the first sweep
     (setOldOutput presumably copies each unit's output into Mold). */
  if (MsynchronousUpdate(net)) 
    netForAllUnits(net, ALL, setOldOutput, NULL) ;

  for (Mtemp(net) = tMax ; TRUE ; Mtemp(net) *= tDecay) {
    /* MmaxDelta measures one sweep.  Clearing it only once, before the
       loop, let it grow without bound across sweeps, so (presumably in
       STOP_ANNEALING) it could never indicate equilibrium. */
    MmaxDelta(net) = 0.0 ;

    if (MsynchronousUpdate(net)) {
      netForAllUnits(net, ~(INPUT | BIAS), updateActivity, NULL) ;
      /* Save the new outputs only after the whole sweep, so every unit
	 in a sweep sees the previous sweep's values (true synchronous,
	 not in-order asynchronous, updating). */
      netForAllUnits(net, ALL, setOldOutput, NULL) ;
    } else {
      netForAllUnitsRandom(net, ~(INPUT | BIAS), updateActivity, NULL) ;
    }

    if (Mrunning(net) != TRUE) {
      volatile long	spin ;
      long		spinLimit = (long)MdelayCount(net) * 1000L ;

      IDoTrace(&inRelaxation) ;
      /* Busy-wait so the relaxation can be watched.  spin is volatile
	 because an empty loop over a plain local has no observable
	 effect and optimizing compilers remove it entirely. */
      for (spin = 0 ; spin < spinLimit ; spin++)
	;
    }
    if (STOP_ANNEALING(net))
      break ;
  }

  netForAllUnits(net, OUTPUT, updateNetError, NULL) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		unitActivityUpdate
 *	Description:	updates a single unit.  Its net input (from
 *			dotProduct) is divided by the current temperature
 *			and passed through tanh (when MunitFunction is
 *			set) or the sigmoid; the output then moves a
 *			fraction MrelaxStepSize of the way toward that
 *			value.  The largest |delta| seen is recorded in
 *			MmaxDelta.
 *
 *			Mold is deliberately NOT written here: in
 *			synchronous mode netActivityUpdate saves all
 *			outputs after each complete sweep.
 *	Parameters:	
 *		Unit	unit - the unit to activate.
 *	Return Value:	NONE
 ***********************************************************************/
static void	unitActivityUpdate(unit)
  Unit		unit ;
{
  Real		delta;
  Net		net = unit->net ;

  unit->totalInput = dotProduct(unit) ;

  if (MunitFunction(net))
    delta = tanh(unit->totalInput/Mtemp(net)) - unit->output ;
  else 
    delta = sigmoid(unit->totalInput/Mtemp(net)) - unit->output ;

  unit->output += delta * MrelaxStepSize(net) ;

  MmaxDelta(net) = MAX(MmaxDelta(net), fabs(delta)) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		unitGradientUpdate
 *	Description:	accumulates gradient contributions into every
 *			incoming link of a unit: each link's deriv is
 *			increased by MnegProd(link) - MposProd(link).
 *	Parameters:	
 *		Unit	unit - the unit whose incoming links are updated
 *	Return Value:	NONE
 ***********************************************************************/
static void	unitGradientUpdate(unit)
  Unit	unit ;
{
  Link	*cursor    = unit->incomingLink ;
  int	remaining  = unit->numIncoming ;

  while (remaining-- > 0) {
    Link	current = *cursor++ ;

    current->deriv += MnegProd(current) - MposProd(current) ;
  }
}
/**********************************************************************/
