
/**********************************************************************
 * $Id: access.c,v 1.5 92/12/01 12:51:22 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		       Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute,  and sell this software
 * and its documentation for any purpose is hereby granted without fee,
 * provided  that the above copyright notice  appears in all copies and
 * that both the copyright notice and this permission notice  appear in
 * supporting documentation, and  that  the  name of The University  of
 * Toronto  not  be used  in advertising   or publicity pertaining   to
 * distribution  of   the software   without  specific, written   prior
 * permission.  The  University  of Toronto  makes   no representations
 * about the  suitability  of  this software  for  any purpose.   It is
 * provided "as is" without express or implied warranty.
 *
 * THE  UNIVERSITY OF  TORONTO DISCLAIMS ALL WARRANTIES  WITH REGARD TO
 * THIS SOFTWARE,  INCLUDING ALL  IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT  SHALL THE UNIVERSITY  OF TORONTO BE LIABLE
 * FOR ANY SPECIAL,  INDIRECT OR CONSEQUENTIAL  DAMAGES  OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF  USE, DATA OR PROFITS,  WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE  OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF OR  IN  CONNECTION   WITH  THE  USE OR  PERFORMANCE  OF THIS
 * SOFTWARE.
 *
 **********************************************************************/

#include <stdio.h>
#include <math.h>
#include <xerion/simulator.h>
#include "simUtils.h"
#include "access.h"

/* Presumably bit flags. NOTE(review): WEIGHT and DERIV are not referenced
 * anywhere in this file; confirm whether they are still needed. */
#define WEIGHT	(int)(1<<1)
#define DERIV	(int)(1<<2)

/* Internal helpers that map between net->variables / net->gradient and
 * the per-link weights / derivatives (defined at the end of this file). */
static void	sumLinkDerivs  ARGS((Net	net)) ;
static void	setLinkWeights ARGS((Net	net)) ;

/***********************************************************************
 *	Name:		getNumberOfVars
 *	Description:	returns the number of degrees of freedom in
 *			a network: the number of variables in the net
 *			minus the number of frozen variables. This is the
 *			array length that getCurrentValues,
 *			setCurrentValues, calculateGradient, evaluate and
 *			setDelta require for their "n" argument.
 *	Parameters:	
 *		Net	net - the net to evaluate
 *	Return Value:	
 *		int	getNumberOfVars - net->numVariables minus
 *				net->numFrozenVariables. There is no error
 *				return value.
 ***********************************************************************/
int	getNumberOfVars (net)
  Net	net ;
{
  /* Frozen variables are excluded from the vector exchanged with the
   * optimizer routines below. */
  return net->numVariables - net->numFrozenVariables ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		getValueName
 *	Description:	returns the name associated with an index in
 *			the value array, i.e. the name of the first
 *			non-frozen link whose variableIdx equals n.
 *	Parameters:	
 *		Net	net - the net to search through
 *		int	n - the index of the value in the array
 *	Return Value:	
 *		char	*getValueName - a STATIC array holding the link's
 *				name, "link #<idx>" if that link has no
 *				name, or "" if no non-frozen link uses
 *				variable n. Names longer than the buffer
 *				are truncated. Overwritten at each call.
 ***********************************************************************/
char	*getValueName (net, n)
  Net	net ;
  int	n ;
{
  static char	valueName[100];

  Link	*link    = net->links ;
  int	numLinks = net->numLinks - net->numFrozenLinks ;
  int	idx ;

  /* Only the first numLinks - numFrozenLinks links are searched. */
  for (idx = 0 ; idx < numLinks && link[idx]->variableIdx != n ; ++idx)
    ;

  if (idx >= numLinks) {
    valueName[0] = '\0' ;
  } else if (link[idx]->name) {
    /* Link names have no length limit, so bound the copy to the static
     * buffer. "%.*s" is C89, unlike snprintf, which keeps this file
     * buildable with the pre-ANSI toolchains ARGS() caters for. */
    sprintf(valueName, "%.*s", (int)sizeof(valueName) - 1,
	    link[idx]->name) ;
  } else {
    sprintf(valueName, "link #%d", idx) ;
  }

  return valueName ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		getCurrentValues
 *	Description:	copies the current values that represent the
 *			weights in a network (the first n entries of
 *			net->variables) into value[].
 *	Parameters:	
 *		Net	net	- the net to get the weights from
 *		int	n       - the size of the array; must equal
 *				  getNumberOfVars(net)
 *		Real	value[]	- the array of values returned (size from 
 *				  getNumberOfVars above)
 *	Return Value:	
 *		int	getCurrentValues - 1 on success; 0 if n is the
 *				  wrong size, after reporting
 *				  SIM_ERR_BADARRAYSIZE through simError. In
 *				  that case value[] is not written.
 ***********************************************************************/
int	getCurrentValues(net, n, value)
  Net	net ;
  int	n ;
  Real	value[] ;
{
  int	numVars = net->numVariables - net->numFrozenVariables ;

  if (n != numVars) {
    simError(SIM_ERR_BADARRAYSIZE,
	     "getCurrentValues: improper array size - need %d, have %d",
	     numVars, n) ;
    return 0 ;
  }

  /* Only the leading n variables are exported; the frozen ones are not. */
  memcpy(value, net->variables, n*sizeof(Real)) ;
  return 1 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		setCurrentValues
 *	Description:	sets the weights in a network given a vector
 *			of values. The values do not necessarily
 *			directly correspond to the weights: they are
 *			copied into the first n entries of
 *			net->variables, and every link weight is then
 *			recomputed from its variable by setLinkWeights.
 *	Parameters:	
 *		Net	net	- the net to set the weights in
 *		int	n       - the size of the array; must equal
 *				  getNumberOfVars(net)
 *		Real	value[]	- the array of values (size from 
 *				  getNumberOfVars above)
 *	Return Value:	
 *		int	setCurrentValues - 1 on success; 0 if n is the
 *				  wrong size, after reporting
 *				  SIM_ERR_BADARRAYSIZE through simError. In
 *				  that case the net is left unchanged.
 ***********************************************************************/
int	setCurrentValues(net, n, value)
  Net	net ;
  int	n ;
  Real	value[] ;
{
  int	numVars = net->numVariables - net->numFrozenVariables ;

  if (n != numVars) {
    simError(SIM_ERR_BADARRAYSIZE,
	     "setCurrentValues: improper array size - need %d, have %d",
	     numVars, n) ;
    return 0 ;
  }

  memcpy(net->variables, value, n*sizeof(Real)) ;

  /* Propagate the new variable values to the link weights. */
  setLinkWeights(net) ;
  return 1 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		calculateGradient
 *	Description:	returns the gradient vector for a net, for the
 *			current values (passed in).
 *			It does this by installing the values with
 *			setCurrentValues, running the net on the training
 *			examples (if any) via MupdateNetErrorDerivs, and
 *			summing the link derivatives into net->gradient.
 *	Parameters:	
 *		Net	net	- the net to set the weights in
 *		int	n       - the size of the arrays; must equal
 *				  getNumberOfVars(net)
 *		Real	value[]	- the array of values (size from 
 *				  getNumberOfVars above)
 *		Real	grad[]	- the array of derivatives to return
 *				  (size from getNumberOfVars above)
 *	Return Value:	
 *		Real	calculateGradient - net->error + net->cost using
 *				  the given values; 0.0 if n is the wrong
 *				  size (after reporting SIM_ERR_BADARRAYSIZE
 *				  through simError).
 ***********************************************************************/
Real	calculateGradient (net, n, value, grad)
  Net	net ;
  int  	n ;
  Real 	*value ;
  Real	*grad ;
{
  int	numVars = net->numVariables - net->numFrozenVariables ;

  if (n != numVars) {
    simError(SIM_ERR_BADARRAYSIZE,
	     "calculateGradient: improper array size - need %d, have %d",
	     numVars, n) ;
    return 0.0 ;
  }

  setCurrentValues(net, n, value) ;

  if (net->trainingExampleSet != NULL) {
    /* A non-positive batch size is temporarily replaced by the number
     * of training examples, then restored. */
    int	oldBatchSize = net->batchSize ;
    if (oldBatchSize <= 0)
      net->batchSize = net->trainingExampleSet->numExamples ;
    
    MupdateNetErrorDerivs(net, net->trainingExampleSet) ;

    net->batchSize = oldBatchSize ;
  }

  /* Clear the WHOLE gradient, not just the first n entries.
   * sumLinkDerivs adds (+=) into the slot of every link's variable,
   * frozen ones included, so slots past n would otherwise keep growing
   * with every call. */
  memset(net->gradient, (int)0, net->numVariables*sizeof(Real)) ;

  sumLinkDerivs(net) ;

  memcpy(grad, net->gradient, n*sizeof(Real)) ;

  return (Real)(net->error + net->cost) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		evaluate
 *	Description:	evaluates the network using the given weight
 *			values. The values are installed with
 *			setCurrentValues, and then the net is run over the
 *			training examples (when a training set is
 *			attached) to refresh its error.
 *	Parameters:	
 *		Net	net	- the net to set the weights in
 *		int	n       - the size of the array; must equal
 *				  getNumberOfVars(net)
 *		Real	value[]	- the array of values (size from 
 *				  getNumberOfVars above)
 *	Return Value:	
 *		Real	evaluate - net->error + net->cost for the given
 *				  values; 0.0 if n is the wrong size (after
 *				  reporting SIM_ERR_BADARRAYSIZE).
 ***********************************************************************/
Real	evaluate(net, n, value)
  Net	net ; 
  int	n ;
  Real	value[] ;
{
  int	expected = net->numVariables - net->numFrozenVariables ;
  int	savedBatch ;

  if (n == expected) {
    setCurrentValues(net, n, value) ;

    if (net->trainingExampleSet) {
      /* Non-positive batch size means "all examples" for this pass. */
      savedBatch = net->batchSize ;
      if (savedBatch <= 0)
	net->batchSize = net->trainingExampleSet->numExamples ;
      MupdateNetError(net, net->trainingExampleSet) ;
      net->batchSize = savedBatch ;
    }

    return (Real)(net->error + net->cost) ;
  }

  simError(SIM_ERR_BADARRAYSIZE,
	   "evaluate: improper array size - need %d, have %d",
	   expected, n) ;
  return 0.0 ;
}
/**********************************************************************/


/*********************************************************************
 *	Name:		syncValues
 *	Description:	synchronizes the representations of the network
 *			weights and variables
 *	Parameters:
 *	  Net		net - the network to synchronize
 *	  SyncDirection direction - the direction to synchronize in:
 *			VectorFromWeights rebuilds net->variables from
 *			the link weights and net->gradient from the link
 *			derivatives; WeightsFromVector rebuilds the link
 *			weights from net->variables.
 *	Return Value:
 *	  void		syncValues - NONE
 *********************************************************************/
void		syncValues(net, direction)
  Net		net ;
  SyncDirection direction ;
{
  Link	*linkArray = net->links ;
  Real	*variable  = net->variables ;
  int	numLinks   = net->numLinks ;
  int	idx ;
  
  switch(direction) {
  case VectorFromWeights:
    /* Variables that no link refers to are left at zero. */
    memset(variable, (int)0, net->numVariables*sizeof(Real)) ;
    for (idx = 0 ; idx < numLinks ; ++idx) {
      Link        link  = linkArray[idx] ;
      /* NOTE(review): this always divides by scaleFactor, while
       * setLinkWeights ignores scaleFactor when net->unitGains is set;
       * confirm scaleFactor is 1 for unit-gain nets. */
      switch (link->type & LOG_TRANSFORM) {
      case LOG_TRANSFORM:
	variable[link->variableIdx]  = log(link->weight/link->scaleFactor) ;
	break ;
      default:
	variable[link->variableIdx]  = link->weight/link->scaleFactor ;
	break ;
      }
    }
    /* sumLinkDerivs accumulates (+=) into net->gradient, so start from
     * zero. Otherwise each sync would stack the link derivatives on
     * top of whatever the gradient already held. */
    memset(net->gradient, (int)0, net->numVariables*sizeof(Real)) ;
    sumLinkDerivs(net) ;
    break ;
  case WeightsFromVector:
    setLinkWeights(net) ;
    break ;
  default:
    break ;
  }
}
/********************************************************************/


/***********************************************************************
 *	Name:		setDelta
 *	Description:	sets each link's deltaWeight to the change in its
 *			weight implied by moving the variables from
 *			start[] to end[]. The change uses the same
 *			variable->weight mapping as the rest of this file:
 *			scaleFactor*exp(v) for LOG_TRANSFORM links,
 *			scaleFactor*v otherwise.
 *	Parameters:	
 *		Net	net	- the net whose links are updated
 *		int	n	- the size of the arrays; must equal
 *				  getNumberOfVars(net)
 *		Real	start[]	- variable values before the step
 *		Real	end[]	- variable values after the step
 *	Return Value:	
 *		int	setDelta - 1 on success; 0 if n is the wrong size,
 *				  after reporting SIM_ERR_BADARRAYSIZE through
 *				  simError. In that case no link is touched.
 ***********************************************************************/
int	setDelta(net, n, start, end)
  Net  net ;
  int  n ;
  Real start[] ;
  Real end[] ;
{
  Link	*links   = net->links ;
  int	numLinks = net->numLinks ;
  int	idx ;

  if (n != net->numVariables - net->numFrozenVariables) {
    simError(SIM_ERR_BADARRAYSIZE,
	     "setDelta: improper array size - need %d, have %d",
	     net->numVariables - net->numFrozenVariables, n) ;
    return 0 ;
  }

  for (idx = 0 ; idx < numLinks ; ++idx) {
    Link	link   = links[idx] ;
    int		varIdx = link->variableIdx  ;

    /* start[] and end[] hold only n entries, but this loop visits every
     * link. Links tied to frozen variables (presumably the indices at n
     * and above, beyond what getCurrentValues exports) would read past
     * the caller's arrays. Those variables do not move, so their delta
     * is zero. */
    if (varIdx < 0 || varIdx >= n) {
      link->deltaWeight = 0.0 ;
      continue ;
    }

    switch (link->type & LOG_TRANSFORM) {
    case LOG_TRANSFORM:
      link->deltaWeight =  link->scaleFactor * (exp(end[varIdx]) 
						- exp(start[varIdx])) ;
      break ;
    default:
      link->deltaWeight =  link->scaleFactor * (end[varIdx] - start[varIdx]) ;
      break ;
    }
  }
  return 1 ;
}
/**********************************************************************
 * setLinkWeights - recompute every link weight from net->variables.
 * Each link reads variables[link->variableIdx]. For LOG_TRANSFORM links
 * the value is exponentiated first. Unless net->unitGains is set, the
 * result is then multiplied by the link's scaleFactor.
 **********************************************************************/
static void	setLinkWeights(net)
  Net		net ;
{
  Real	*vars       = net->variables ;
  Link	*links      = net->links ;
  int	count       = net->numLinks ;
  int	applyScale  = !net->unitGains ;
  int	i ;

  for (i = 0 ; i < count ; ++i) {
    Link	lnk = links[i] ;
    Real	raw = vars[lnk->variableIdx] ;

    if ((lnk->type & LOG_TRANSFORM) == LOG_TRANSFORM)
      lnk->weight = applyScale ? lnk->scaleFactor*exp(raw) : exp(raw) ;
    else
      lnk->weight = applyScale ? lnk->scaleFactor*raw : raw ;
  }
}
/**********************************************************************
 * sumLinkDerivs - add dE/d(variable) for every link into net->gradient.
 *
 * link->deriv (dE/dweight) is chained through the variable->weight
 * mapping used by setLinkWeights and ADDED to gradient[variableIdx].
 * Variables shared by several links therefore receive the sum of their
 * contributions. This routine does not clear net->gradient itself.
 *
 *   mapping used by setLinkWeights      d weight / d variable
 *   w = exp(v)      (log, unit gains)   w
 *   w = s*exp(v)    (log)               w
 *   w = v           (unit gains)        1
 *   w = s*v                             s
 **********************************************************************/
static void	sumLinkDerivs(net)
  Net		net ;
{
  Real	*array     = net->gradient ;
  Link	*linkArray = net->links ;
  int	numLinks   = net->numLinks ;
  int	idx ;

  for (idx = 0 ; idx < numLinks ; ++idx) {
    Link	link = linkArray[idx] ;
    switch (link->type & LOG_TRANSFORM) {
    case LOG_TRANSFORM:
      /* d(s*exp(v))/dv == w, with or without a scale factor. */
      array[link->variableIdx] += link->deriv*link->weight ;
      break ;
    default:
      if (net->unitGains) {
	array[link->variableIdx] += link->deriv ;
      } else {
	/* w = s*v, so dE/dv = s*dE/dw. Multiplying (not dividing) by
	 * scaleFactor keeps this consistent with the LOG_TRANSFORM branch
	 * and with the error returned by evaluate(). */
	array[link->variableIdx] += link->deriv*link->scaleFactor ;
      }
      break ;
    }
  }
}
/**********************************************************************/
