
/**********************************************************************
 * $Id: unit.c,v 1.12 93/04/14 12:19:49 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		      Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute, and sell this software
 * and its  documentation for  any purpose is  hereby granted  without
 * fee, provided that the above copyright notice appears in all copies
 * and  that both the  copyright notice  and   this  permission notice
 * appear in   supporting documentation,  and  that the  name  of  The
 * University  of Toronto  not  be  used in  advertising or  publicity
 * pertaining   to  distribution   of  the  software without specific,
 * written prior  permission.   The  University of   Toronto makes  no
 * representations  about  the  suitability of  this software  for any
 * purpose.  It  is    provided   "as is"  without express or  implied
 * warranty.
 *
 * THE UNIVERSITY OF TORONTO DISCLAIMS  ALL WARRANTIES WITH REGARD  TO
 * THIS SOFTWARE,  INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT SHALL THE UNIVERSITY OF TORONTO  BE LIABLE
 * FOR ANY  SPECIAL, INDIRECT OR CONSEQUENTIAL  DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR  PROFITS, WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF  OR  IN  CONNECTION WITH  THE  USE  OR PERFORMANCE  OF THIS
 * SOFTWARE.
 *
 **********************************************************************/

#include <stdio.h>
#include <math.h>
#include <values.h>


#include <xerion/useful.h>
#include <xerion/simulator.h>
#include <xerion/sigmoid.h>

#include "unit.h"
#include "bp.h"

/* Remove any TANH macro that an earlier header may have defined; this
 * file defines its own clamped TANH(x) further down. */
#ifdef TANH
#undef TANH
#endif

/* Forward declarations of the file-local procedures.  The set*()
 * routines below install several of these by address before their
 * definitions appear in the file. */

/* combin procedure pairs */
static void	multiplyCombIn 		ARGS((Unit	unit)) ;
static void	multiplyCombBack 	ARGS((Unit	unit)) ;
static void	dotProductCombIn 	ARGS((Unit	unit)) ;
static void	dotProductCombBack	ARGS((Unit	unit)) ;
static void	distanceCombIn 		ARGS((Unit	unit)) ;
static void	distanceCombBack	ARGS((Unit	unit)) ;

/* transfer function pairs */
static void	logisticTransfer	ARGS((Unit	unit)) ;
static void	logisticTransferBack	ARGS((Unit	unit)) ;
static void	tanhTransfer		ARGS((Unit	unit)) ;
static void	tanhTransferBack	ARGS((Unit	unit)) ;
static void	linearTransfer		ARGS((Unit	unit)) ;
static void	linearTransferBack	ARGS((Unit	unit)) ;
static void	exponentialTransfer	ARGS((Unit	unit)) ;
static void	exponentialTransferBack	ARGS((Unit	unit)) ;
static void	negExponentialTransfer	ARGS((Unit	unit)) ;
static void	negExponentialTransferBack ARGS((Unit	unit)) ;
/* the softMax pair works on a whole Group rather than a single Unit */
static void	softMaxForward		ARGS((Group	group)) ;
static void	softMaxBackward		ARGS((Group	group)) ;

/* Error contribution function pairs */
static void	sumSquareError		ARGS((Unit	unit)) ;
static void	sumSquareErrorDeriv	ARGS((Unit	unit)) ;
static void	adaptiveSumSquareError	ARGS((Unit	unit)) ;
static void	adaptiveSumSquareErrorDeriv ARGS((Unit	unit)) ;
static void	crossEntropyError	ARGS((Unit	unit)) ;
static void	crossEntropyErrorDeriv	ARGS((Unit	unit)) ;

/* forward and backward procedures */
static void	standardForward		ARGS((Group group)) ;
static void	standardBackward	ARGS((Group group)) ;
static void	standardUnitForwardFAU	ARGS((Unit	unit, void *)) ;
static void	standardUnitBackwardFAU	ARGS((Unit	unit, void *)) ;

static void	gatedForward	ARGS((Group	group)) ;
static void	gatedBackward	ARGS((Group	group)) ;

/* miscellaneous */
/* NOTE(review): only addErrorToNetFAU is defined in the part of the
 * file reviewed here; confirm addErrorToNet is still defined and used. */
static void	addErrorToNet	ARGS((Unit, void *)) ;
static Real	square		ARGS((double  x)) ;
static Real	adjustedTarget	ARGS((const Unit	unit)) ;

/*********************************************************************
 *	Name:		setDotProduct
 *	Description:	sets the combIn and combBack procedures for a
 *			group of units to use dot products
 *	Parameters:
 *	  Group	group - the group
 *	Return Value:
 *	  void	setDotProduct - NONE
 *********************************************************************/
void	setDotProduct(group)
  Group	group ;
{
  /* install the dot product pair: the backward combiner and the
   * forward combiner it undoes */
  McombBackProc(group) = dotProductCombBack ;
  McombInProc(group)   = dotProductCombIn ;
}
/***********************************************************************
 *	Name:		dotProductCombIn
 *	Description:	calculates the dot product of all incoming
 *			links for a unit and stores it in the totalInput
 *			field of the unit.
 *	Parameters:	
 *		Unit	unit - the unit to calculate the dot 
 *				product for
 *	Return Value:	
 *		static void	dotProductCombIn - NONE
 ***********************************************************************/
static void	dotProductCombIn(unit)
  Unit		unit ;
{
  Link		*incoming = unit->incomingLink ;
  int		count     = unit->numIncoming ;
  Real		sum       = 0.0 ;
  int		idx ;

  /* accumulate weight * sender output over every incoming link,
   * in link order */
  for (idx = 0 ; idx < count ; ++idx) {
    Link	link = incoming[idx] ;

    sum += link->weight * link->preUnit->output ;
  }

  unit->totalInput = sum ;
}
/***********************************************************************
 *	Name:		dotProductCombBack
 *	Description:	back propagates derivatives through the incoming
 *			links of a unit and adds them to the pre units'
 *			outputDerivs and the links' derivs.
 *	Parameters:	
 *		Unit	unit - the unit to back propagate from
 *	Return Value:	NONE
 ***********************************************************************/
static void	dotProductCombBack(unit)
  Unit	unit ;
{
  Link		*incoming = unit->incomingLink ;
  int		count     = unit->numIncoming ;
  Real		delta     = unit->inputDeriv ;
  /* immediately post input: the pre units' outputDerivs are not needed */
  int		skipPre   = (unit->group->type & IPI) != 0 ;
  int		idx ;

  for (idx = 0 ; idx < count ; ++idx) {
    Link	link = incoming[idx] ;
    Unit	pre  = link->preUnit ;

    /* d(totalInput)/d(pre output) is the link weight */
    if (!skipPre)
      pre->outputDeriv += delta*link->weight ;

    /* d(totalInput)/d(weight) is the pre unit's output */
    link->deriv += delta*pre->output ;
  }
}
/**********************************************************************/


/*********************************************************************
 *	Name:		setMultiply
 *	Description:	sets the combIn and combBack procedures for a
 *			group of units to use multipliers
 *	Parameters:
 *	  Group	group - the group
 *	Return Value:
 *	  void	setMultiply - NONE
 *********************************************************************/
void	setMultiply(group)
  Group	group ;
{
  /* install the multiplicative pair: the backward combiner and the
   * forward combiner it undoes */
  McombBackProc(group) = multiplyCombBack ;
  McombInProc(group)   = multiplyCombIn ;
}
/***********************************************************************
 *	Name:		multiplyCombIn
 *	Description:	calculates the product of all preUnit outputs
 *			and the incoming links' weight for a unit and
 *			stores it in the totalInput field of the unit.
 *			N.B. - This assumes that all incoming links
 *			       are constrained together.
 *	Parameters:	
 *		Unit	unit - the unit to act on
 *	Return Value:	
 *		static void	multiplyCombIn - NONE
 ***********************************************************************/
static void	multiplyCombIn(unit)
  Unit	unit ;
{
  Link	*links = unit->incomingLink ;
  int	count  = unit->numIncoming ;
  Real	result ;
  int	i ;

  /* a unit with no incoming links receives zero input */
  if (count == 0) {
    unit->totalInput = 0.0 ;
    return ;
  }

  /* all links share one weight, so take it from the first link and
   * scale it by every pre unit's output */
  result = links[0]->weight ;
  for (i = 0 ; i < count ; ++i)
    result *= links[i]->preUnit->output ;

  unit->totalInput = result ;
}
/***********************************************************************
 *	Name:		multiplyCombBack
 *	Description:	back propagates derivatives through the incoming
 *			links of a unit and adds them to the pre units'
 *			outputDerivs and the links' derivs.
 *			N.B. - This assumes that all incoming links
 *			       are constrained together.
 *	Parameters:	
 *		Unit	unit - the unit to back propagate from
 *	Return Value:	NONE
 ***********************************************************************/
static void	multiplyCombBack(unit)
  Unit	unit ;
{
  int	numIncoming = unit->numIncoming ;
  Link	*incoming   = unit->incomingLink ;
  Real	product ;
  Real	linkShare ;
  int	idx ;

  /* nothing to propagate through a unit without incoming links */
  if (numIncoming <= 0)
    return ;

  /* calculate the product of all outputs and the inputDeriv */
  product = unit->inputDeriv ;
  for (idx = 0 ; idx < numIncoming ; ++idx)
    product *= incoming[idx]->preUnit->output ;

  /************************************************************
   * Set all links to have same deriv (an equal share of the product
   * from above).  We have to average the product, since derivs are
   * summed for constrained weights into a single derivative.
   * The share is kept in its own variable so that "product" itself
   * is never divided and re-multiplied, which need not give back
   * the original value in floating point.
   */
  linkShare = product / numIncoming ;
  for (idx = 0 ; idx < numIncoming ; ++idx)
    incoming[idx]->deriv += linkShare ;

  /* now backprop all the error derivs: each pre unit receives the
   * weight times the product of every OTHER pre unit's output */
  for (idx = 0 ; idx < numIncoming ; ++idx) {
    Link	link	= incoming[idx] ;
    Unit	preUnit = link->preUnit ;
    Real	subProduct ;

    if (preUnit->output) {	/* just divide by the output */
      subProduct = product / preUnit->output ;

    } else {			/* can't divide by zero, so recalculate */
      int	subIdx ;

      subProduct = unit->inputDeriv ;
      for (subIdx = 0 ; subIdx < numIncoming ; ++subIdx) {
	if (subIdx != idx)
	  subProduct *= incoming[subIdx]->preUnit->output ;
      }
    }
    preUnit->outputDeriv += link->weight * subProduct ;
  }
}
/**********************************************************************/



/*********************************************************************
 *	Name:		setDistance
 *	Description:	sets the combin and back procedures to use
 *			a distance measure
 *	Parameters:
 *	  Group	group - the group of units to set them for
 *	Return Value:
 *	  void	setDistance - NONE
 *********************************************************************/
void	setDistance(group)
  Group	group ;
{
  /* install the squared distance pair: the backward combiner and the
   * forward combiner it undoes */
  McombBackProc(group) = distanceCombBack ;
  McombInProc(group)   = distanceCombIn ;
}
/*********************************************************************
 *	Name:		distanceCombIn
 *	Description:	sets total input to the sum of the squares of 
 *			the difference of weight and preUnit output.
 *	Parameters:
 *	  Unit	unit - the unit to work on
 *	Return Value:
 *	  static void	distanceCombIn - NONE
 *********************************************************************/
static void	distanceCombIn(unit)
  Unit		unit ;
{
  Link		*cursor  = unit->incomingLink ;
  unsigned int	count    = unit->numIncoming ;
  Link		*end     = cursor + count ;
  Real		distance = 0 ;

  /* sum of squared differences between each weight and the output
   * of the unit at the other end of its link */
  while (cursor != end) {
    Link	link = *cursor++ ;

    distance += square(link->weight - link->preUnit->output) ;
  }

  unit->totalInput = distance ;
}
/*********************************************************************
 *	Name:		distanceCombBack
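 *	Description:	backward counterpart of distanceCombIn: adds the
 *			derivative of the squared distance to the links'
 *			derivs and (unless the group is IPI) to the pre
 *			units' outputDerivs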
 *	Parameters:
 *	  Unit		unit - the unit to act on
 *	Return Value:
 *	  static void	distanceCombBack - NOTHING
 *********************************************************************/
static void	distanceCombBack(unit)
  Unit		unit ;
{
  Link		*incoming   = unit->incomingLink ;
  unsigned int	numIncoming = unit->numIncoming ;
  Real		inputDeriv  = unit->inputDeriv ;
  unsigned int	idx ;

  /* totalInput is sum((weight - output)^2), so its derivative is
   * 2*(weight - output) with respect to the weight and the negative
   * of that with respect to the pre unit's output */

  /* immediately post input, no need to backprop the preUnit->outputDeriv */
  if (unit->group->type & IPI) {
    for (idx = 0 ; idx < numIncoming ; ++idx) {
      Link	link = incoming[idx] ;
      link->deriv   += 2*inputDeriv*(link->weight - link->preUnit->output) ;
    }
  } else {
    for (idx = 0 ; idx < numIncoming ; ++idx) {
      Link	link	= incoming[idx] ;
      Unit	preUnit = link->preUnit ;
      Real	delta	= 2*inputDeriv*(link->weight - preUnit->output) ;

      preUnit->outputDeriv -= delta ;
      link->deriv          += delta ;
    }
  }
}
/********************************************************************/


/*********************************************************************
 *	Name:		setLinear
 *	Description:	sets the forward and backward procedures as well
 *			as the transfer function to use linear function
 *	Parameters:
 *	  Group	group - the group
 *	Return Value:
 *	  void	setLinear - NONE
 *********************************************************************/
void	setLinear(group)
  Group	group ;
{
  /* standard group-level forward/backward procedures come first */
  setStandardForwardBackward(group) ;

  /* then the identity transfer pair */
  MtransferBackProc(group) = linearTransferBack ;
  MtransferProc(group)	   = linearTransfer ;
}
/*********************************************************************
 *	Name:		linearTransfer/linearTransferBack
 *	Description:	linearTransfer passes the unit's totalInput
 *			through a linear transformation and sets
 *			the output.
 *			linearTransferBack sets the inputDeriv using
 *			the inverse transformation.
 *	Parameters:
 *	  Unit		unit - the unit
 *	Return Value:
 *	  static void	linearTransfer - NONE
 *********************************************************************/
static void	linearTransfer(unit)
  Unit		unit ;
{
  unit->output = unit->totalInput ;
}
/********************************************************************/
static void	linearTransferBack(unit)
  Unit		unit ;
{
  unit->inputDeriv = unit->outputDeriv ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		setLogistic
 *	Description:	sets the forward and backward procedures as well
 *			as the transfer function to use logistic function
 *	Parameters:
 *	  Group	group - the group
 *	Return Value:
 *	  void	setLogistic - NONE
 *********************************************************************/
void	setLogistic(group)
  Group	group ;
{
  /* standard group-level forward/backward procedures come first */
  setStandardForwardBackward(group) ;

  /* then the logistic transfer pair */
  MtransferBackProc(group) = logisticTransferBack ;
  MtransferProc(group)	   = logisticTransfer ;
}
/*********************************************************************
 *	Name:		logisticTransfer/logisticTransferBack
 *	Description:	logisticTransfer passes the unit's totalInput
 *			through a logistic (sigmoid) transformation 
 *			and sets the output
 *			logisticTransferBack sets the inputDeriv using
 *			the inverse transformation.
 *	Parameters:
 *	  Unit		unit - the unit
 *	Return Value:
 *	  static void	logisticTransfer - NONE
 *********************************************************************/
static Sigmoid	sigmoid = NULL ;
static void	logisticTransfer(unit)
  Unit		unit ;
{
  if (sigmoid == NULL)
    sigmoid = createSigmoid(0.0, 0.0, 1.0, 1.0) ;

  unit->output     = MSsigmoid(sigmoid, unit->totalInput) ;
  unit->inputDeriv = unit->output*(1.0 - unit->output) ;
}
/********************************************************************/
static void	logisticTransferBack(unit)
  Unit		unit ;
{
  /* inputDeriv already holds the slope y*(1-y) stored by
   * logisticTransfer; scale it by the derivative arriving from above */
  unit->inputDeriv = unit->inputDeriv * unit->outputDeriv ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		setTanh
 *	Description:	sets the forward and backward procedures as well
 *			as the transfer function to use tanh function
 *	Parameters:
 *	  Group	group - the group
 *	Return Value:
 *	  void	setTanh - NONE
 *********************************************************************/
void	setTanh(group)
  Group	group ;
{
  /* standard group-level forward/backward procedures come first */
  setStandardForwardBackward(group) ;

  /* then the tanh transfer pair */
  MtransferBackProc(group) = tanhTransferBack ;
  MtransferProc(group)	   = tanhTransfer ;
}
/*********************************************************************
 *	Name:		tanhTransfer/tanhTransferBack
 *	Description:	tanhTransfer passes the unit's totalInput
 *			through a hyperbolic tangent (tanh) transformation 
 *			and sets the output.
 *			tanhTransferBack sets the inputDeriv using
 *			the inverse transformation.
 *	Parameters:
 *	  Unit		unit - the unit
 *	Return Value:
 *	  static void	tanhTransfer - NONE
 *********************************************************************/
/* Plain hyperbolic tangent; the argument is cast to double for tanh(). */
#define SIMPLE_TANH(x)	(tanh((double)(x)))
/* Clamped tanh: yields exactly 1.0 above 30 and -1.0 below -30, and
 * only calls tanh() in between.
 * N.B. - (x) is expanded more than once, so the argument must be free
 *        of side effects. */
#define TANH(x)		((x) > 30 ? (1.0) : \
			 ((x) < -30 ? (-1.0) : SIMPLE_TANH(x)))
/********************************************************************/
/* Derivative of tanh written as 1/cosh(x)^2, using the file's square(). */
#define SIMPLE_TANH_DERIV(x)	(1.0/square(cosh((double)(x))))
/* Clamped derivative: exactly 0.0 wherever TANH is clamped (x above 30
 * or below -30).
 * N.B. - (x) is expanded more than once here as well. */
#define TANH_DERIV(x)		((x) > 30 ? (0.0) : \
				 ((x) < -30 ? (0.0) : SIMPLE_TANH_DERIV(x)))
/********************************************************************/
static void	tanhTransfer(unit)
  Unit		unit ;
{
  unit->output = TANH(unit->totalInput) ;
}
/********************************************************************/
static void	tanhTransferBack(unit)
  Unit		unit ;
{
  unit->inputDeriv = unit->outputDeriv * TANH_DERIV(unit->totalInput) ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		setExponential
 *	Description:	sets the forward and backward procedures as well
 *			as the transfer function to use exp function
 *	Parameters:
 *	  Group	group - the group
 *	Return Value:
 *	  void	setExponential - NONE
 *********************************************************************/
void	setExponential(group)
  Group	group ;
{
  /* standard group-level forward/backward procedures come first */
  setStandardForwardBackward(group) ;

  /* then the exp(x) transfer pair */
  MtransferBackProc(group) = exponentialTransferBack ;
  MtransferProc(group)	   = exponentialTransfer ;
}
/*********************************************************************/
void	setNegExponential(group)
  Group	group ;
{
  /* standard group-level forward/backward procedures come first */
  setStandardForwardBackward(group) ;

  /* then the exp(-x) transfer pair */
  MtransferBackProc(group) = negExponentialTransferBack ;
  MtransferProc(group)	   = negExponentialTransfer ;
}
/*********************************************************************
 *	Name:		exponentialTransfer/exponentialTransferBack
 *	Description:	exponentialTransfer passes the unit's totalInput
 *			through an exponential transformation and sets
 *			the output.
 *			exponentialTransferBack sets the inputDeriv using
 *			the inverse transformation.
 *	Parameters:
 *	  Unit		unit - the unit
 *	Return Value:
 *	  static void	exponentialTransfer - NONE
 *********************************************************************/
static void	exponentialTransfer(unit)
  Unit		unit ;
{
  unit->output = exp(unit->totalInput) ;
}
/********************************************************************/
static void	exponentialTransferBack(unit)
  Unit		unit ;
{
  /* d(e^x)/dx = e^x, which is the output already stored by the
   * forward pass */
  unit->inputDeriv = unit->output * unit->outputDeriv ;
}
/********************************************************************/
static void	negExponentialTransfer(unit)
  Unit		unit ;
{
  unit->output = exp(-unit->totalInput) ;
}
/********************************************************************/
static void	negExponentialTransferBack(unit)
  Unit		unit ;
{
  /* d(e^-x)/dx = -e^-x, i.e. minus the output stored by the
   * forward pass */
  unit->inputDeriv = unit->output * -unit->outputDeriv ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		setSumSquareError
 *	Description:	sets the error  procedures for a
 *			group of units to use sum square error
 *	Parameters:
 *	  Group	group - the group
 *	Return Value:
 *	  void	setSumSquareError - NONE
 *********************************************************************/
void	setSumSquareError(group)
  Group	group ;
{
  /* install the sum square error pair: derivative and error */
  McontributeErrorBackProc(group) = sumSquareErrorDeriv ;
  McontributeErrorProc(group)	  = sumSquareError ;
}
/*********************************************************************
 *	Name:		sumSquareError/sumSquareErrorDeriv
 *	Description:	calculates the sum square error of a unit
 *			and adds it to the net.
 *			sumSquareErrorDeriv adds the derivative to
 *			the units outputDeriv.
 *	Parameters:
 * 	  Unit		unit -  the unit
 *	Return Value:
 *	  static void	sumSquareError -  NONE
 *********************************************************************/
static void	sumSquareError(unit)
  Unit		unit ;
{
  unit->net->error += square(unit->output - adjustedTarget(unit)) ;
}
/********************************************************************/
static void	sumSquareErrorDeriv(unit)
  Unit		unit ;
{
  /* distance between the output and its radius-adjusted target */
  const double	miss = unit->output - adjustedTarget(unit) ;

  /* d(miss^2)/d(output) = 2*miss */
  unit->outputDeriv += 2.0*miss ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		setAdaptiveSumSquareError
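 *	Description:	sets the error procedures for a group of units
 *			to use sum square error with an adaptive
 *			standard deviation, and creates the
 *			"StdDeviation" variable (initial value 1.0)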
 *	Parameters:
 *	  Group	group - the group
 *	Return Value:
 *	  void	setAdaptiveSumSquareError - NONE
 *********************************************************************/
void	setAdaptiveSumSquareError(group)
  Group	group ;
{
  /* install the adaptive sum square error pair */
  McontributeErrorBackProc(group) = adaptiveSumSquareErrorDeriv ;
  McontributeErrorProc(group)	  = adaptiveSumSquareError ;

  /* create the group's standard deviation variable on its net, flagged
   * LOG_TRANSFORM, and start it at 1.0 */
  MerrorStdDevVar(group)
    = createVariable("StdDeviation", group->net, LOG_TRANSFORM) ;
  variableSetValue(MerrorStdDevVar(group), 1.0) ;
}
/*********************************************************************
 *	Name:		adaptiveSumSquareError/adaptiveSumSquareErrorDeriv
 *	Description:	calculates the sum square error of a unit
 *			with adaptive variance and adds it to the net.
 *			adaptiveSumSquareErrorDeriv adds the derivative to
 *			the units outputDeriv.
 *	Parameters:
 *	  Unit		unit -  the unit
 *	Return Value:
 *	  static void	adaptiveSumSquareError - NONE
 *********************************************************************/
#define K	0.91893853	/* ln(sqrt(2*PI)) */
/********************************************************************/
static void	adaptiveSumSquareError(unit)
  Unit		unit ;
{
  Real	stdDev = variableGetValue(MerrorStdDevVar(unit->group)) ;

  unit->net->error 
    += K*log(stdDev) + square((unit->output - adjustedTarget(unit))/stdDev)/2 ;
}
/********************************************************************/
static void	adaptiveSumSquareErrorDeriv(unit)
  Unit		unit ;
{
  Real	sigma   = variableGetValue(MerrorStdDevVar(unit->group)) ;
  Real	sigmaSq = square(sigma) ;

  /* derivative of the error with respect to the unit's output */
  unit->outputDeriv += (unit->output - adjustedTarget(unit))/sigmaSq ;

  /* contribution to the standard deviation variable's derivative.
   * NOTE(review): presumably taken with respect to log(stdDev), since
   * the variable is created with LOG_TRANSFORM - confirm. */
  variableAddToDeriv(MerrorStdDevVar(unit->group),
		     K - square((unit->output - adjustedTarget(unit))/sigma));
}
/********************************************************************/
#undef K

/*********************************************************************
 *	Name:		setGated
 *	Description:	sets the activity and gradient update procedures
 *			of a group to gatedForward and gatedBackward
 *	Parameters:
 *	  Group	group - the group
 *	Return Value:
 *	  void	setGated - NONE
 *********************************************************************/
/* Scratch record handed to sumErrorsFA while it walks a group of units
 * in step with the units of a second group. */
typedef struct _GateErrorRec {
  Group	output ;	/* group whose units are visited in parallel */
  int	idx ;		/* index of the current unit within "output" */
  Real	sum ;		/* running total accumulated so far */
} GateErrorRec, *GateError ;
/* unit of the parallel group at the current index */
#define gateErrorCurrentUnit(g)		((g)->output->unit[(g)->idx])
/* the running total (usable as an lvalue) */
#define gateErrorSum(g)			((g)->sum)
/* add x to the running total */
#define gateErrorAddToSum(g, x)		(gateErrorSum(g) += (x))
/* step to the next unit of the parallel group */
#define gateErrorIncrementUnit(g)	(++((g)->idx))
/* point the record at group o and reset the index and the total */
#define gateErrorInit(g, o)		((g)->output = (o), \
					 (g)->idx = 0, (g)->sum = 0.0)
/********************************************************************/
static void	setGaussianFA		ARGS((Unit, void *)) ;
static void	sumErrorsFA		ARGS((Unit, void *)) ;

static void	expertErrorDerivFA	ARGS((Unit, void *)) ;
static void	gateErrorDerivFA	ARGS((Unit, void *)) ;
/********************************************************************/
void	setGated(group)
  Group	group ;
{
  /* install the gated group-level pair: gradient and activity */
  group->groupGradientUpdateProc = gatedBackward ;
  group->groupActivityUpdateProc = gatedForward ;
}
/********************************************************************/
static void	gatedForward (group)
  Group		group ;
{
  Group		gate = Mgate(group) ;

  /* only act when the group has a gate and the gate has a mixture */
  if (gate != NULL && Mmixture(gate) != NULL) {
    /* load each gate unit's gaussian with its mean and proportion */
    groupForAllUnits(gate, setGaussianFA, NULL) ;

    /* then charge the mixture's complexity to the net's error */
    gate->net->error += MMcomplexity(Mmixture(gate), 0.0) ;
  }
}
/********************************************************************/
static void	gatedBackward (group)
  Group		group ;
{
  Group		gate = Mgate(group) ;

  /* without a gate there is nothing to back propagate */
  if (gate != NULL)
    groupForAllUnitsBack(gate, gateErrorDerivFA,	NULL) ;
}
/********************************************************************/


/********************************************************************/
static void	setGaussianFA(unit, data)
  Unit		unit ;
  void		*data ;
{
  Gaussian	gaussian = Mgaussian(unit) ;
  GateErrorRec	gateError ;

  gateErrorInit(&gateError, Mgated(unit->group)) ;
  groupForAllUnits(Mexpert(unit), sumErrorsFA, &gateError) ;

  MGsetMean(gaussian, sqrt(gateErrorSum(&gateError))) ;
  MGsetProportion(gaussian, unit->output) ;
}
/********************************************************************/
static void	sumErrorsFA(unit, data)
  Unit		unit ;
  void		*data ;
{
  GateError	gateError = (GateError)data ;
  Unit		output    = gateErrorCurrentUnit(gateError) ;

  unit->target = output->target ;
  gateErrorAddToSum(gateError, square(unit->output - adjustedTarget(unit))) ;
  gateErrorIncrementUnit(gateError) ;
}
/********************************************************************/


/********************************************************************/
static void	gateErrorDerivFA(unit, data)
  Unit		unit ;
  void		*data ;
{
  Gaussian	gaussian = Mgaussian(unit) ;
  Real		share    = MGresponsibility(gaussian, 0.0) ;
  Real		gateDeriv ;

  /* hand the responsibility to every unit of this gate unit's expert */
  groupForAllUnits(Mexpert(unit), expertErrorDerivFA, &share) ;

  /* the gate unit itself receives -responsibility/proportion */
  gateDeriv	    = -share/MGproportion(gaussian) ;
  unit->outputDeriv += gateDeriv ;
} 
/********************************************************************/
static void	expertErrorDerivFA(unit, data)
  Unit		unit ;
  void		*data ;
{
  const Real	*share = (const Real *)data ;

  /* overwrite (not accumulate) the output derivative with the miss
   * scaled by the responsibility passed in through data */
  unit->outputDeriv = *share * (unit->output - adjustedTarget(unit)) ;
}
/********************************************************************/


/***********************************************************************
 *	Name:		adjustedTarget
 *	Description:	calculates the target value of a unit assuming
 *			a zeroErrorRadius
 *	Parameters:	
 *		const Unit	unit - the unit to calculate the error of
 *	Return Value:	
 *		Real		adjustedTarget - the target of the unit
 ***********************************************************************/
static Real	adjustedTarget(unit)
  const Unit	unit ;
{
  Real	radius = MzeroErrorRadius(unit->net) ;
  Real	target, output ;
  Real	upper, lower ;

  /* no (positive) radius: the raw target is used as is */
  if (!(radius > 0.0))
    return unit->target ;

  target = unit->target ;
  output = unit->output ;
  upper  = target + radius ;
  lower  = target - radius ;

  /* outside the band the nearest edge becomes the target ... */
  if (output > upper)
    return upper ;
  if (output < lower)
    return lower ;

  /* ... inside it the output is its own target, i.e. zero error */
  return output ;
}
/**********************************************************************/


/*********************************************************************
 *	Name:		setCrossEntropyError
 *	Description:	sets the error procedures for a
 *			group of units to use cross entropy error
 *	Parameters:
 *	  Group	group - the group
 *	Return Value:
 *	  void	setCrossEntropyError - NONE
 *********************************************************************/
void	setCrossEntropyError(group)
  Group	group ;
{
  /* install the cross entropy error pair: derivative and error */
  McontributeErrorBackProc(group) = crossEntropyErrorDeriv ;
  McontributeErrorProc(group)	  = crossEntropyError ;
}
/*********************************************************************
 *	Name:		crossEntropyError/crossEntropyErrorDeriv
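 *	Description:	crossEntropyError adds the unit's cross entropy
 *			error (divergence for softmax groups) to the net.
 *			crossEntropyErrorDeriv adds the matching
 *			derivative to the unit's outputDeriv.
 *			If the output or target lies outside [0,1] the
 *			error / outputDeriv is set to MAXREAL instead.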
 *	Parameters:
 *	  Unit		unit -  the unit
 *	Return Value:
 *	  static void	crossEntropyError - NONE
 *********************************************************************/
/* Smoothing constant: added to the argument of every log() and to every
 * denominator below, so that an output of exactly 0 or 1 can produce
 * neither log(0) nor a division by zero.
 * N.B. - the macro arguments are expanded more than once and must be
 *        free of side effects. */
#define DELTA				(1.0e-6)
/* cross entropy of output y against target d, for 0 < d < 1 */
#define SIMPLE_CROSS_ENTROPY(y, d)   (-(d)*log(DELTA+(y)) - (1.0-(d))*log(DELTA+1.0-(y)))
/* special cases d == 0 and d == 1; an exact hit costs exactly 0.0 */
#define CROSS_ENTROPY_ZERO_TARGET(y) ((y) == 0.0 ? 0.0 : -log(DELTA+1.0 - (y)))
#define CROSS_ENTROPY_ONE_TARGET(y)  ((y) == 1.0 ? 0.0 : -log(DELTA+(y)))
#define CROSS_ENTROPY(y, d)		\
  (((d) == 0.0) ? CROSS_ENTROPY_ZERO_TARGET(y) : 	\
   (((d) == 1.0) ? CROSS_ENTROPY_ONE_TARGET(y) : SIMPLE_CROSS_ENTROPY(y, d)))
/**********************************************************************/
/* Derivative of CROSS_ENTROPY with respect to y.  Each branch is the
 * derivative of the matching smoothed error expression above, so the
 * result stays finite even when y is exactly 0 or 1. */
#define CROSS_ENTROPY_DERIV(y, d)		\
  ((d) == 0.0 ?   1.0/(DELTA+1.0-(y)) : 		\
   ((d) == 1.0 ? -1.0/(DELTA+(y))      : 		\
    ((1.0-(d))/(DELTA+1.0-(y)) - (d)/(DELTA+(y)))))
/**********************************************************************/
/* divergence -d*log(y) and its derivative, smoothed the same way */
#define DIVERGENCE(y, d)		((d) == 0.0 ? 0.0 : -(d)*log(DELTA+(y)))
#define DIVERGENCE_DERIV(y, d)		((d) == 0.0 ? 0.0 : -(d)/(DELTA+(y)))
/**********************************************************************/
static void	crossEntropyError(unit)
  Unit		unit ;
{
  Real	output = unit->output ;
  Real	target = unit->target ;
  if (output >= 0.0 && output <= 1.0 && target >= 0.0 && target <= 1.0)
    unit->net->error
      += (unit->group->groupActivityUpdateProc == softMaxForward ?
	  DIVERGENCE(output, target) : CROSS_ENTROPY(output, target)) ;
  else
    unit->net->error = MAXREAL ;
}
/********************************************************************/
static void	crossEntropyErrorDeriv(unit)
  Unit		unit ;
{
  Real	output = unit->output ;
  Real	target = unit->target ;

  /* anything outside [0,1] cannot be scored: flag it with MAXREAL */
  if (!(output >= 0.0 && output <= 1.0 && target >= 0.0 && target <= 1.0)) {
    unit->outputDeriv = MAXREAL ;
    return ;
  }

  /* softmax groups use the divergence derivative, all others the
   * full cross entropy derivative */
  if (unit->group->groupGradientUpdateProc == softMaxBackward)
    unit->outputDeriv += DIVERGENCE_DERIV(output, target) ;
  else
    unit->outputDeriv += CROSS_ENTROPY_DERIV(output, target) ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		setSoftMax
 *	Description:	sets the forward and backward procedures as well
 *			as the transfer function to use softMax function
 *	Parameters:
 *	  Group	group - the group
 *	Return Value:
 *	  void	setSoftMax - NONE
 *********************************************************************/
void	setSoftMax(group)
  Group	group ;
{
  /* softMaxForward and softMaxBackward both insist on the
   * exponential transfer pair, so install that pair ... */
  MtransferBackProc(group) = exponentialTransferBack ;
  MtransferProc(group)	   = exponentialTransfer ;

  /* ... together with the group-level softmax procedures */
  group->groupGradientUpdateProc = softMaxBackward ;
  group->groupActivityUpdateProc = softMaxForward ;
}
/*********************************************************************
 *	Name:		softMaxForward
 *	Description:	does the forward pass for a group of units.
 *			first does combin for all units, then calculates
 *			the sum of all e^inputs, then does transfer and
 *			error contribution
 *	Parameters:
 *	  Group		group - the group to act on
 *	Return Value:
 *	  static void	softMaxForward  - NONE
 *********************************************************************/
static void	combInFAU(unit, data)
  Unit	unit ;
  void	*data ;
{
  (void)data ;		/* for-all-units callback; data is not used */

  /* run the unit's input combination procedure */
  McombIn(unit) ;
}
/********************************************************************/
static void	combBackFAU(unit, data)
  Unit	unit ;
  void	*data ;
{
  (void)data ;		/* for-all-units callback; data is not used */

  /* run the unit's backward combination procedure */
  McombBack(unit) ;
}
/********************************************************************/
static void	transferFAU(unit, data)
  Unit	unit ;
  void	*data ;
{
  (void)data ;		/* for-all-units callback; data is not used */

  /* run the unit's transfer procedure */
  Mtransfer(unit) ;
}
/********************************************************************/
static void	softMaxTransferBackFAU(unit, data)
  Unit	unit ;
  void	*data ;
{
  /* data points at a Real that is subtracted from outputDeriv */
  const Real	*offset = (const Real *)data ;

  unit->inputDeriv += unit->output*(unit->outputDeriv - *offset) ;
}
/********************************************************************/
static void	sumDerivOutputProdFAU(unit, data)
  Unit	unit ;
  void	*data ;
{
  /* data points at the caller's Real accumulator; add this unit's
   * output * outputDeriv to it */
  *(Real *)data += unit->output * unit->outputDeriv ;
}
/********************************************************************/


/********************************************************************/
static void	addErrorToNetFAU(unit, data)
  Unit	unit ;
  void	*data ;
{
  (void)data ;		/* for-all-units callback; data is not used */

  /* run the unit's error contribution procedure */
  McontributeError(unit) ;
}
/********************************************************************/
static void	softMaxErrorBackFAU(unit, data)
  Unit	unit ;
  void	*data ;
{
  (void)data ;		/* for-all-units callback; data is not used */

  /* accumulate (output - target) into the input derivative */
  unit->inputDeriv = unit->inputDeriv + (unit->output - unit->target) ;
}
/********************************************************************/
static void	sumOutputFAU(unit, data)
  Unit	unit ;
  void	*data ;
{
  *(Real *)data += unit->output ;
}
/********************************************************************/
static void	divideOutputFAU(unit, data)
  Unit	unit ;
  void	*data ;
{
  unit->output /= *(Real *)data ;
}
/********************************************************************/
static void	getMaxInputFAU(unit, data)
  Unit	unit ;
  void	*data ;
{
  if (unit->totalInput > *(Real *)data)
    *(Real *)data = unit->totalInput ;
}
/********************************************************************/
static void	addToInputFAU(unit, data)
  Unit	unit ;
  void	*data ;
{
  unit->totalInput += *(Real *)data ;
}
/********************************************************************/
/* softMaxForward: forward pass for a softmax group.
 *
 *   group - the group to act on; its transfer procedure must be
 *	     exponentialTransfer, otherwise IErrorAbort is called.
 *
 * Steps: combin on every unit, shift all totalInputs by a common
 * amount, transfer, undo the shift, then divide every output by the
 * sum of the outputs (left in MoutputSum(group)).  For OUTPUT groups
 * the error contribution is added last.
 */
static void	softMaxForward (group)
  Group		group ;
{
  Real		maxInput = -MAXREAL ;
  Real		shift ;

  if (MtransferProc(group) != exponentialTransfer)
    IErrorAbort("Group \"%s\" is softmax, but doesn't have an \
exponential activation function", group->name) ;

  /* do the combin */
  groupForAllUnits(group, combInFAU,	  NULL) ;

  /* transform the inputs to avoid overflow: after the shift the
   * largest input equals log(MAXREAL/numUnits) - 2.  Adding the same
   * constant to every input is harmless because the outputs are
   * normalized below (presumably the transfer is e^totalInput, so the
   * common factor cancels -- confirm against exponentialTransfer). */
  groupForAllUnits(group, getMaxInputFAU, &maxInput) ;
  shift = log(MAXREAL/group->numUnits) - maxInput - 2 ;

  groupForAllUnits(group, addToInputFAU, &shift) ;

  /* do the transfer */
  groupForAllUnits(group, transferFAU, NULL) ;

  /* shift the inputs back where they belong */
  shift = -shift ;
  groupForAllUnits(group, addToInputFAU, &shift) ;

  /* normalize output */
  MoutputSum(group) = 0.0 ;
  groupForAllUnits(group, sumOutputFAU,		&MoutputSum(group)) ;
  groupForAllUnits(group, divideOutputFAU,	&MoutputSum(group)) ;

  if (group->type & OUTPUT)
    groupForAllUnits(group, addErrorToNetFAU, NULL) ;
}
/********************************************************************/
/* softMaxBackward: backward pass for a softmax group.
 *
 *   group - the group to act on.
 *
 * Calls IErrorAbort if the group's transfer-back procedure is not
 * exponentialTransferBack, or if an OUTPUT group's error-back
 * procedure is not crossEntropyErrorDeriv.  Then, for OUTPUT groups,
 * adds (output - target) to every inputDeriv; accumulates the sum of
 * output*outputDeriv in MoutputSum(group); adds
 * output*(outputDeriv - sum) to every inputDeriv; and finally runs
 * McombBack on every unit.
 */
static void	softMaxBackward (group)
  Group		group ;
{
  if (MtransferBackProc(group) != exponentialTransferBack)
    IErrorAbort("Group \"%s\" is softmax, but doesn't have exponential \
activation function.", group->name) ;

  if (group->type & OUTPUT) {
    if (McontributeErrorBackProc(group) != crossEntropyErrorDeriv)
      IErrorAbort("Group \"%s\" is softmax, but doesn't have cross entropy \
error function.", group->name) ;

    groupForAllUnitsBack(group, softMaxErrorBackFAU,  NULL) ;
  }

  /* MoutputSum is reused here as the accumulator for output*outputDeriv */
  MoutputSum(group) = 0.0 ;
  groupForAllUnitsBack(group, sumDerivOutputProdFAU,  &MoutputSum(group)) ;
  groupForAllUnitsBack(group, softMaxTransferBackFAU, &MoutputSum(group)) ;

  groupForAllUnitsBack(group, combBackFAU, NULL) ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		setStandardForwardBackward
 *	Description:	installs standardForward as the group's activity
 *			update procedure and standardBackward as its
 *			gradient update procedure.  Nothing else in the
 *			group is modified.
 *	Parameters:
 *	  Group	group - the group
 *	Return Value:
 *	  void	setStandardForwardBackward - NONE
 *********************************************************************/
void		setStandardForwardBackward(group)
  Group		group ;
{
  group->groupGradientUpdateProc = standardBackward ;
  group->groupActivityUpdateProc = standardForward ;
}
/*********************************************************************
 *	Name:		standardForward
 *	Description:	The standard forward procedure for a group.
 *			Runs standardUnitForwardFAU on every unit of the
 *			group through groupForAllUnits.
 *	Parameters:
 *	  Group		group - the group
 *	Return Value:
 *	  static void	standardForward - NONE
 *********************************************************************/
static void	standardForward (group)
  Group		group ;
{
  void		*noData = NULL ;	/* the per-unit callback needs no data */

  groupForAllUnits(group, standardUnitForwardFAU, noData) ;
}
/********************************************************************/
static void	standardUnitForwardFAU(unit, data)
  Unit		unit ;
  void		*data ;
{
  Group		group = unit->group ;
  UnitProc	proc ;

  if (proc = McombInProc(group))
    proc(unit) ;

  if (proc = MtransferProc(group))
    proc(unit) ;
  
  if (proc = McontributeErrorProc(group))
    proc(unit) ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		standardBackward
 *	Description:	The standard backward procedure for a group.
 *			Runs standardUnitBackwardFAU on every unit of the
 *			group through groupForAllUnitsBack.
 *	Parameters:
 *	  Group		group - the group
 *	Return Value:
 *	  static void	standardBackward - NONE
 *********************************************************************/
static void	standardBackward (group)
  Group		group ;
{
  void		*noData = NULL ;	/* the per-unit callback needs no data */

  groupForAllUnitsBack(group, standardUnitBackwardFAU, noData) ;
}
/********************************************************************/
static void	standardUnitBackwardFAU(unit, data)
  Unit		unit ;
  void		*data ;
{
  Group		group = unit->group ;
  UnitProc	proc ;

  if (proc = McontributeErrorBackProc(group))
    proc(unit) ;

  if (proc = MtransferBackProc(group))
    proc(unit) ;

  if (proc = McombBackProc(group))
    proc(unit) ;
}
/**********************************************************************/


/*********************************************************************
 *	Name:		square
 *	Description:	multiplies a number by itself; the product is
 *			formed from the double argument and then
 *			converted to Real
 *	Parameters:
 *	  double	x - the number to square
 *	Return Value:
 *	  static Real	square - x*x, converted to Real
 *********************************************************************/
static Real	square(x)
  double	x ;
{
  Real		result ;

  result = (Real) (x * x) ;
  return result ;
}
/********************************************************************/
