
/**********************************************************************
 * $Id: train.c,v 1.3 92/11/30 12:00:12 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		      Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute, and sell this software
 * and its  documentation for  any purpose is  hereby granted  without
 * fee, provided that the above copyright notice appears in all copies
 * and  that both the  copyright notice  and   this  permission notice
 * appear in   supporting documentation,  and  that the  name  of  The
 * University  of Toronto  not  be  used in  advertising or  publicity
 * pertaining   to  distribution   of  the  software without specific,
 * written prior  permission.   The  University of   Toronto makes  no
 * representations  about  the  suitability of  this software  for any
 * purpose.  It  is    provided   "as is"  without express or  implied
 * warranty.
 *
 * THE UNIVERSITY OF TORONTO DISCLAIMS  ALL WARRANTIES WITH REGARD  TO
 * THIS SOFTWARE,  INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT SHALL THE UNIVERSITY OF TORONTO  BE LIABLE
 * FOR ANY  SPECIAL, INDIRECT OR CONSEQUENTIAL  DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR  PROFITS, WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF  OR  IN  CONNECTION WITH  THE  USE  OR PERFORMANCE  OF THIS
 * SOFTWARE.
 *
 **********************************************************************/

/*********************************************************************
 *
 *  MFT/FEM/Boltzmann Modules written by     Evan W. Steeg
 *                                           Dept. of Computer Science
 *  August 1991                              Univ. of Toronto
 *
 **********************************************************************/

#include <stdio.h>
#include <math.h>

#include <xerion/simulator.h>

#include "mft.h"
#include "train.h"

/* Mode selectors handed to setIncomingProducts() through its `data'
 * argument (cast to void *):
 *   CLEAR    - reset both the positive and negative products to 0.0
 *   POSITIVE - accumulate the pre/post output product into MposProd
 *   NEGATIVE - accumulate the pre/post output product into MnegProd   */
#define CLEAR		0
#define POSITIVE	1
#define NEGATIVE	2

/**********************************************************************/
/* Forward declarations of the file-local per-unit callbacks that
 * calculateNetErrorDeriv() passes to netForAllUnits().                */
static void	clampOutput         ARGS((Unit	unit, void	*data)) ;
static void	setIncomingProducts ARGS((Unit	unit, void	*data)) ;
static void	zeroLinks           ARGS((Unit	unit, void	*data)) ;
/**********************************************************************/


/***********************************************************************
 *	Name:		calculateNetErrorDeriv
 *	Description:	procedure for calculating network error and 
 *			associated derivatives in a Mean Field Net.
 *			For each of net->batchSize examples it performs
 *			a positive annealed phase (OUTPUT units clamped
 *			to their targets), a negative annealed phase
 *			(OUTPUT units left free), and then updates the
 *			gradients.  Before the batch, net->error and the
 *			deriv of every incoming link are zeroed; after
 *			it, the mean number of relaxation sweeps per
 *			example is left in MrelaxSweepCountAve(net) and
 *			the cost and its derivatives are evaluated once.
 *			Mrunning(net) is TRUE for the duration of the
 *			call.
 *	Parameters:	
 *		Net		net - the network to use
 *		ExampleSet	exampleSet - the example set to use
 *	Return Value:	
 *		NONE
 ***********************************************************************/
void		calculateNetErrorDeriv(net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  /* Annealing schedule parameters.  tMin is not referenced by name in
     this function; presumably STOP_ANNEALING() picks it up -- TODO
     confirm against mft.h before renaming or removing it.            */
  Real		tMin  = MtMin(net) ;
  Real		tMax  = MtMax(net) ;
  Real		tDecay = MtDecay(net) ;
  int		numExamples ;

  Mrunning(net) = TRUE ;
  net->error = 0.0 ;
  /* start the batch with every incoming link's deriv at zero */
  netForAllUnits(net, ALL, zeroLinks, NULL) ;

  MrelaxSweepCountAve(net) = 0.0 ;
  for (numExamples = 0 ; numExamples < net->batchSize ; ++numExamples) {
    /* sweeps spent on this example, positive and negative phase together */
    int	relaxSweepCount = 0 ; 
    MgetNext(exampleSet) ;
    /* reset the positive and negative products of every link */
    netForAllUnits(net, ALL, setIncomingProducts, (void *)CLEAR) ;

    /* Annealing updates are done either asynchronously (units visited
       through netForAllUnitsRandom) or synchronously (in-order
       traversal through netForAllUnits).  For synchronous updates the
       current outputs are saved in extension->old before the sweeps
       start.  Each phase keeps sweeping, multiplying the temperature
       by tDecay after every sweep, until STOP_ANNEALING(net) is true. */

    /* Positive phase: units selected by ~(INPUT | OUTPUT | BIAS) start
       at 0.5 and OUTPUT units are clamped to their targets; only the
       former are updated during the sweeps.                            */
    netForAllUnits(net, ~(INPUT | OUTPUT | BIAS), setOutput, NULL) ;
    netForAllUnits(net, OUTPUT, clampOutput, NULL) ;
    if (MsynchronousUpdate(net)) 
      netForAllUnits(net, ALL, setOldOutput, NULL) ;

    MmaxDelta(net) = 0.0;
    for (Mtemp(net) = tMax ; TRUE ; Mtemp(net) *= tDecay) {
      if (MsynchronousUpdate(net)) 
        netForAllUnits(net, ~(INPUT | OUTPUT | BIAS),updateActivity, NULL) ;
      else 
        netForAllUnitsRandom(net, ~(INPUT | OUTPUT | BIAS),
			     updateActivity, NULL) ;
      relaxSweepCount++ ;
      /* at least one sweep always runs: the test comes after the sweep */
      if (STOP_ANNEALING(net))
	break ;
    }
    /* record the pre*post output products of the clamped phase */
    netForAllUnits(net, ALL, setIncomingProducts, (void *)POSITIVE) ;

    /* Negative phase: same procedure, but the mask is ~(INPUT | BIAS),
       so OUTPUT units are also reset to 0.5 and updated rather than
       clamped.                                                         */
    netForAllUnits(net, ~(INPUT | BIAS), setOutput, NULL) ;
    if (MsynchronousUpdate(net)) 
      netForAllUnits(net, ALL, setOldOutput, NULL) ;

    MmaxDelta(net) = 0.0;
    for (Mtemp(net) = tMax ; TRUE ; Mtemp(net) *= tDecay) {
      if (MsynchronousUpdate(net)) 
       netForAllUnits(net, ~(INPUT | BIAS),updateActivity, NULL) ;
      else 
	netForAllUnitsRandom(net, ~(INPUT | BIAS),updateActivity, NULL) ;
      relaxSweepCount++ ;
      if (STOP_ANNEALING(net))
	break ;
    }
    /* record the pre*post output products of the free phase */
    netForAllUnits(net, ALL, setIncomingProducts, (void *)NEGATIVE) ;

    /* running total here; turned into a per-example mean after the loop */
    MrelaxSweepCountAve(net) += relaxSweepCount ;
    /* gradient update */
    MupdateNetGradients(net) ;

    /* update the error for the net */
    MupdateNetActivities(net) ;
  }
  /* Reached only with a zero count when net->batchSize <= 0.
     NOTE(review): if IErrorAbort() can return, the division below is
     by zero -- confirm that it does not return.                       */
  if (numExamples <= 0)
    IErrorAbort("calculateNetErrorDeriv: no examples processed") ;

  MrelaxSweepCountAve(net) /= (Real)numExamples ;

  /* update the cost after everything else is done */
  MevaluateCostAndDerivs(net) ;

  Mrunning(net) = FALSE ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		calculateNetError
 *	Description:	calculates the error of a mean field net over
 *			one batch of examples, without derivatives.
 *			net->error is reset to zero; then, net->batchSize
 *			times, the next example is fetched with MgetNext
 *			and MupdateNetActivities is applied to the net.
 *			The cost is evaluated once after the batch.
 *			Mrunning(net) is TRUE for the duration of the
 *			call.
 *	Parameters:	
 *		Net		net - the network to evaluate
 *		ExampleSet	exampleSet - where the examples come from
 *	Return Value:	
 *		NONE
 ***********************************************************************/
void		calculateNetError(net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  int		processed = 0 ;

  Mrunning(net) = TRUE ;
  net->error    = 0.0 ;

  /* one fetch and one activity update per example in the batch */
  while (processed < net->batchSize) {
    MgetNext(exampleSet) ;
    MupdateNetActivities(net) ;
    ++processed ;
  }

  /* a non-positive batch size means the loop body never ran */
  if (!(processed > 0))
    IErrorAbort("calculateNetError: no examples processed") ;

  /* the cost is evaluated only after the whole batch has been seen */
  MevaluateCost(net) ;

  Mrunning(net) = FALSE ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		updateCost
 *	Description:	adds the weight-cost contribution of one unit's
 *			incoming links to the net: for each link,
 *			weightCost*weight^2 is added to net->cost and
 *			2*weightCost*weight is added to the link's deriv.
 *			The weights themselves are not modified.
 *	Parameters:	
 *		Unit	unit - the unit whose incoming links are visited
 *		void	*data - the Net (passed as void *) whose cost is
 *				accumulated and whose weight cost is used
 *	Return Value:	NONE
 ***********************************************************************/
static void	updateCost(unit, data)
  Unit	unit ;
  void	*data ;
{
  Net	net        = (Net)data ;
  Real	weightCost = MweightCost(net) ;
  Link	*next      = unit->incomingLink ;
  int	remaining  = unit->numIncoming ;

  /* Same as for MFT and Boltzmann; links are visited in array order */
  while (remaining-- > 0) {
    Link	link = *next++ ;

    net->cost   += weightCost*square(link->weight) ;
    link->deriv += 2.0*weightCost*link->weight ;
  }
}
/**********************************************************************/


/**********************************************************************/
/*
 *	clampOutput:	per-unit callback that forces a unit's output to
 *			its target value.
 *	  Unit		unit - the unit to clamp
 *	  void		*data - UNUSED
 */
static void	clampOutput(unit, data)
  Unit		unit ;
  void		*data ;
{
  (void)data ;		/* present only to fit the callback signature */

  unit->output = unit->target ;
}
/**********************************************************************/
/*
 *	updateActivity:	per-unit callback that applies
 *			MupdateUnitActivity to the unit.
 *	  Unit		unit - the unit to update
 *	  void		*data - UNUSED
 */
void		updateActivity(unit, data)
  Unit		unit ;
  void		*data ;
{
  (void)data ;		/* present only to fit the callback signature */

  MupdateUnitActivity(unit) ;
}
/**********************************************************************/
/*
 *	setIncomingProducts:	maintains the (s_i*s_j)plus and
 *			(s_i*s_j)minus products, stored on each incoming
 *			link of a unit, that the gradient update uses.
 *	  Unit		unit - the unit whose incoming links are visited
 *	  void		*data - one of CLEAR, POSITIVE or NEGATIVE, passed
 *				as an integer cast to void *:
 *				  CLEAR    - set both products to 0.0
 *				  POSITIVE - add preUnit output times
 *					     postUnit output to MposProd
 *				  NEGATIVE - add the same product to
 *					     MnegProd
 *				any other value leaves the links untouched
 */
static void	setIncomingProducts(unit, data)
  Unit		unit ;
  void		*data ;
{
  /* Go through long rather than converting the pointer straight to
     int: a direct pointer-to-int cast is a size-mismatch conversion
     wherever pointers are wider than int.                             */
  int		mode        = (int)(long)data ;
  int		numIncoming = unit->numIncoming ;
  Link		*incoming   = unit->incomingLink ;
  int		idx ;

  /* The mode is the same for every link, so choose once and then loop */
  switch (mode) {
  case POSITIVE:
    for (idx = 0 ; idx < numIncoming ; ++idx) {
      Link	link = incoming[idx] ;
      MposProd(link) += link->preUnit->output*link->postUnit->output ;
    }
    break ;

  case NEGATIVE:
    for (idx = 0 ; idx < numIncoming ; ++idx) {
      Link	link = incoming[idx] ;
      MnegProd(link) += link->preUnit->output*link->postUnit->output ;
    }
    break ;

  case CLEAR:
    for (idx = 0 ; idx < numIncoming ; ++idx) {
      Link	link = incoming[idx] ;
      MposProd(link) = MnegProd(link) = 0.0 ;
    }
    break ;

  default:
    /* unrecognised mode: nothing to do */
    break ;
  }
}
/**********************************************************************/


/*********************************************************************
 *	Name:		zeroLinks
 *	Description:	sets the deriv field of every incoming link of
 *			a unit to 0.0
 *	Parameters:
 *	  Unit		unit - the unit whose incoming links are reset
 *	  void		*data - UNUSED
 *	Return Value:
 *	  static void	zeroLinks - NONE
 *********************************************************************/
static void	zeroLinks(unit, data)
  Unit		unit ;
  void		*data ;
{
  Link	*next     = unit->incomingLink ;
  int	remaining = unit->numIncoming ;

  (void)data ;		/* present only to fit the callback signature */

  while (remaining-- > 0) {
    Link	link = *next++ ;

    link->deriv = 0.0 ;
  }
}
/********************************************************************/
/*
 *	square:		returns x*x, computed in double and converted to
 *			Real.
 *	  double	x - the value to square
 */
Real		square(x)
  double	x ;
{
  Real		squared = (Real) (x * x) ;

  return squared ;
}
/**********************************************************************/
/*
 *	setOutput:	per-unit callback that sets a unit's output to
 *			0.5; calculateNetErrorDeriv applies it to the
 *			unclamped units at the start of each annealing
 *			phase.
 *	  Unit		unit - the unit to initialise
 *	  void		*data - UNUSED
 */
void		setOutput(unit, data)
  Unit		unit ;
  void		*data ;
{
  (void)data ;		/* present only to fit the callback signature */

  unit->output = 0.5 ;
}
/**********************************************************************/
/*
 *	setOldOutput:	per-unit callback that copies a unit's current
 *			output into extension->old; calculateNetErrorDeriv
 *			applies it before synchronous annealing sweeps.
 *	  Unit		unit - the unit whose output is saved
 *	  void		*data - UNUSED
 */
void		setOldOutput(unit, data)
  Unit		unit ;
  void		*data ;
{
  (void)data ;		/* present only to fit the callback signature */

  unit->extension->old = unit->output ;
}
/**********************************************************************/
