
/**********************************************************************
 * $Id: fem-train.c,v 1.3 92/11/30 11:55:33 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		       Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute,  and sell this software
 * and its documentation for any purpose is hereby granted without fee,
 * provided  that the above copyright notice  appears in all copies and
 * that both the copyright notice and this permission notice  appear in
 * supporting documentation, and  that  the  name of The University  of
 * Toronto  not  be used  in advertising   or publicity pertaining   to
 * distribution  of   the software   without  specific, written   prior
 * permission.  The  University  of Toronto  makes   no representations
 * about the  suitability  of  this software  for  any purpose.   It is
 * provided "as is" without express or implied warranty.
 *
 * THE  UNIVERSITY OF  TORONTO DISCLAIMS ALL WARRANTIES  WITH REGARD TO
 * THIS SOFTWARE,  INCLUDING ALL  IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT  SHALL THE UNIVERSITY  OF TORONTO BE LIABLE
 * FOR ANY SPECIAL,  INDIRECT OR CONSEQUENTIAL  DAMAGES  OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF  USE, DATA OR PROFITS,  WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE  OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF OR  IN  CONNECTION   WITH  THE  USE OR  PERFORMANCE  OF THIS
 * SOFTWARE.
 **********************************************************************/

 /*********************************************************************
 *
 *  MFT/FEM/Boltzmann Modules written by     Evan W. Steeg
 *                                           Dept. of Computer Science
 *  August 1991                              Univ. of Toronto
 *
 **********************************************************************/


#include <stdio.h>
#include <math.h>

#include <xerion/simulator.h>
#include "fem.h"
#include "fem-train.h"

/**********************************************************************/
/* Forward declaration: zeroLinks is defined at the bottom of this file but
 * is handed to netForAllUnits() earlier, in calculateNetErrorDeriv().
 * NOTE(review): ARGS is defined outside this file -- presumably it hides
 * the parameter list from pre-ANSI compilers; confirm in the headers. */
static void	zeroLinks           ARGS((Unit	unit, void	*data)) ;
/**********************************************************************/


/***********************************************************************
 *	Name:		calculateNetErrorDeriv
 *	Description:	procedure for updating error and associated
 *			derivatives for a Mean Field Net.
 *			It performs a positive annealed phase, a
 *			negative annealed phase, and then updates
 *			the gradients for each example.
 *	Parameters:	
 *		Net		net - the network to use
 *		ExampleSet	exampleSet - the example set to use
 *	Return Value:	
 *		NONE
 ***********************************************************************/
void		calculateNetErrorDeriv(net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  int		numExamples ;

  Mrunning(net) = TRUE ;

  Merror(net) = 0.0 ;
  netForAllUnits(net, ALL, zeroLinks, NULL) ;

  MrelaxSweepCountAve(net) = 0 ;
  for (numExamples = 0 ; numExamples < net->batchSize ; ++numExamples) {
    MgetNext(exampleSet) ;

    MupdateNetActivities(net) ;
    Merror(net) += square((double)McaseError(net))/2 ;

    MupdateNetGradients(net) ;
  }

  if (numExamples <= 0)
    IErrorAbort("train: no examples processed") ;

  MrelaxSweepCountAve(net) /= (Real)numExamples ;

  /* update the cost after everything else is done */
  MevaluateCostAndDerivs(net) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		calculateNetError
 *	Description:	evaluates a mean field net on one batch of
 *			examples without touching any derivatives. The
 *			net error is cleared, then each example of the
 *			batch is fetched and the net's updateActivities
 *			procedure is run on it; afterwards the cost is
 *			evaluated.
 *	Parameters:	
 *		Net		net - the network to use
 *		ExampleSet	exampleSet - the example set to use
 *	Return Value:	
 *		NONE
 *	Notes:		IErrorAbort is called when net->batchSize is not
 *			positive, i.e. when the loop processed nothing.
 *			NOTE(review): unlike calculateNetErrorDeriv, this
 *			routine never adds the per-case error to
 *			net->error itself; confirm that the accumulation
 *			happens inside the activity update or the cost
 *			evaluation.
 ***********************************************************************/
void		calculateNetError(net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  int		processed ;

  Mrunning(net) = TRUE ;
  net->error = 0.0 ;

  /* present every example of the batch to the net */
  for (processed = 0 ; processed < net->batchSize ; ++processed) {
    MgetNext(exampleSet) ;
    MupdateNetActivities(net) ;
  }

  if (processed < 1)
    IErrorAbort("train: no examples processed") ;

  /* the cost can only be evaluated once the whole batch has been seen */
  MevaluateCost(net) ;
}
/**********************************************************************/


/*********************************************************************
 *	Name:		zeroLinks
 *	Description:	clears the deriv field of every incoming link
 *			of a unit. Has the (unit, data) signature under
 *			which it is passed to netForAllUnits.
 *	Parameters:
 *	  Unit		unit - the unit whose incoming links are reset
 *	  void		*data - UNUSED
 *	Return Value:
 *	  static void	zeroLinks - NONE
 *********************************************************************/
static void	zeroLinks(unit, data)
  Unit		unit ;
  void		*data ;
{
  Link	*link     = unit->incomingLink ;
  int	remaining = unit->numIncoming ;

  /* walk the link array front to back; nothing happens when the
   * unit reports no (or a non-positive number of) incoming links */
  for ( ; remaining > 0 ; --remaining, ++link)
    (*link)->deriv = 0.0 ;
}
/**********************************************************************/


/*********************************************************************
 *	Name:		square
 *	Description:	multiplies a value by itself
 *	Parameters:
 *	  double	x - the value to square
 *	Return Value:
 *	  Real		square - x*x, converted to Real
 *********************************************************************/
Real		square(x)
  double	x ;
{
  Real		result ;

  /* the product is formed from the double argument and only then
   * converted to Real */
  result = (Real) (x * x) ;
  return result ;
}
/**********************************************************************/
