
/**********************************************************************
 * $Id: net.c,v 1.12 93/01/26 11:20:50 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		       Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute,  and sell this software
 * and its documentation for any purpose is hereby granted without fee,
 * provided  that the above copyright notice  appears in all copies and
 * that both the copyright notice and this permission notice  appear in
 * supporting documentation, and  that  the  name of The University  of
 * Toronto  not  be used  in advertising   or publicity pertaining   to
 * distribution  of   the software   without  specific, written   prior
 * permission.  The  University  of Toronto  makes   no representations
 * about the  suitability  of  this software  for  any purpose.   It is
 * provided "as is" without express or implied warranty.
 *
 * THE  UNIVERSITY OF  TORONTO DISCLAIMS ALL WARRANTIES  WITH REGARD TO
 * THIS SOFTWARE,  INCLUDING ALL  IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT  SHALL THE UNIVERSITY  OF TORONTO BE LIABLE
 * FOR ANY SPECIAL,  INDIRECT OR CONSEQUENTIAL  DAMAGES  OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF  USE, DATA OR PROFITS,  WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE  OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF OR  IN  CONNECTION   WITH  THE  USE OR  PERFORMANCE  OF THIS
 * SOFTWARE.
 *
 **********************************************************************/

#include <stdio.h>

#include <xerion/simulator.h>
#include <xerion/minimize.h>
#include "simUtils.h"
#include "net.h"
#include "traverse.h"
#include "access.h"
#include "gaussian.h"
#include "costModel.h"

/* MIN(x,y): the smaller of two values, unless the headers already
 * provide one.  Being a function-like macro, the selected argument is
 * evaluated twice -- do not pass expressions with side effects. */
#ifndef MIN
#define MIN(x,y)	((x) < (y) ? (x) : (y))
#endif

/***********************************************************************
 *	Private functions
 ***********************************************************************/
/* Net object lifecycle helpers (used by createNet / destroyNet) */
static void	initialize ARGS((Net	net)) ;
static void	finalize   ARGS((Net	net)) ;

/* Callbacks installed into the net's Minimize record by initialize() */
static int	netGetNVars  ARGS((Minimize)) ;
static int	netGetValues ARGS((Minimize, int n, Real *x)) ;
static int	netSetValues ARGS((Minimize, int n, Real *x)) ;
static Real	netFEval     ARGS((Minimize, int n, Real *x)) ;
static Real	netFGEval    ARGS((Minimize, int n, Real *x, Real *g)) ;
static String	netValueName ARGS((Minimize, int i)) ;
static void	netIncrementEpoch ARGS((Minimize)) ;

/* Default (warning-only) error procs installed by initialize() */
static void	updateErrorDerivs ARGS((Net	net, ExampleSet	exampleSet)) ;
static void	updateError       ARGS((Net	net, ExampleSet	exampleSet)) ;

/* Default activity/gradient propagation over the net's groups */
static void	netActivityUpdate ARGS((Net	net)) ;
static void	netGradientUpdate ARGS((Net	net)) ;

static void	groupActivityUpdate ARGS((Group	group, void	*data)) ;
static void	groupGradientUpdate ARGS((Group	group, void	*data)) ;

/* Bookkeeping for the net's link / variable / gradient arrays */
static int	checkArraySize ARGS((Net net)) ;
static Link	linkFromVarIdx ARGS((Net net, int  varIdx, int *linkIdx)) ;
static Link	linkFromPtr    ARGS((Net net, Link link,   int *linkIdx)) ;
static Real	*rotateRealArray ARGS((Real *base, 
				       int numElements, int step)) ;
static Link	*rotateLinkArray ARGS((Link *base, 
				       int numElements, int step)) ;
static Link	*findConstrainedLinks ARGS((Net net, int 
					    varIdx, int *numConstrained)) ;
static void	deleteVariable ARGS((Net net, int varIdx)) ;

/* Weight-cost evaluation through the net's and groups' cost models */
static void	cost		ARGS((Net)) ;
static void	costAndDerivs	ARGS((Net)) ;
static void	updateCost	    ARGS((Group, void *)) ;
static void	updateCostAndDerivs ARGS((Group, void *)) ;

/***********************************************************************
 *	Private variables
 ***********************************************************************/
/* User hooks: createNetHook runs at the end of createNet(),
 * destroyNetHook at the start of destroyNet().  NULL means "no hook". */
static NetProc	createNetHook = NULL ;
static NetProc	destroyNetHook = NULL ;
/**********************************************************************/


/***********************************************************************
 *	Name:		setCreateNetHook  (setDestroyNetHook)
 *	Description:	sets the user hook to be called after (before)
 *			a net is created (destroyed).
 *	Parameters:	
 *		NetProc	hook - the procedure to be called:
 *				void hook (Net	net)
 *	Return Value:	
 *		NONE
 ***********************************************************************/
void	setCreateNetHook(hook)
  NetProc	hook ;
{
  /* createNet() invokes this hook on each new net once it is fully
   * initialized; passing NULL removes the hook. */
  createNetHook = (hook) ;
}
/**********************************************************************/
void	setDestroyNetHook(hook)
  NetProc	hook ;
{
  /* destroyNet() invokes this hook before tearing the net down;
   * passing NULL removes the hook. */
  destroyNetHook = (hook) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		createNet 
 *	Description:	creates a net and initializes all of its data,
 *			and calls createNetHook if it is not NULL.
 *	Parameters:	
 *		char	*name - the name for the net
 *		Mask	mask  - a mask identifying the type of net
 *		int	timeSlices - the number of time slices for the net.
 *	Return Value:	
 *		Net	createNet - the new net
 ***********************************************************************/
Net	createNet (name, mask, timeSlices)
  char	*name ;
  Mask	mask ;
  int	timeSlices ;
{
  Net	net ;

  net = (Net)malloc(sizeof(NetRec)) ;
  if (net == NULL) {
    simError(SIM_ERR_MALLOC, "createNet \"%s\"", name) ;
    return NULL ;
  }

  /* The net owns a private copy of its name (freed in destroyNet()).
   * Bail out if the copy fails: initialize() hands net->name on to the
   * cost model, so we must not continue with a NULL name. */
  net->name = strdup(name) ;
  if (net->name == NULL) {
    simError(SIM_ERR_MALLOC, "createNet \"%s\"", name) ;
    free((void *)net) ;
    return NULL ;
  }

  /* a mask with no recognised type bits is recorded as UNKNOWN */
  if (!(mask & ALL))
    mask = UNKNOWN ;
  net->type = mask ;
  net->timeSlices = timeSlices ;

  initialize(net) ;

  /* the user hook sees the fully initialized net */
  if (createNetHook != NULL)
    (*createNetHook)(net) ;

  return net ;
}
/***********************************************************************/


/***********************************************************************
 *	Name:		initialize
 *	Description:	init proc for the Net class of objects
 *		Net	net - the net object ;
 *	Return Value:	
 *		NONE
 ***********************************************************************/
/* Put a freshly allocated net into its default state: zero the counters
 * and layout sizes, install this file's default procs, build the
 * Minimize record with the net* callbacks, create the three example
 * sets and attach a sum-square cost model.  Called from createNet(). */
static void	initialize (net)
  Net		net ;
{
  /* scalar training state */
  net->error        = 0.0 ;
  net->cost         = 0.0 ;
  net->currentEpoch = 0 ;
  net->batchSize    = 0 ;
  
  /* display/layout parameters */
  net->minorWidth  = 0 ;
  net->minorHeight = 0 ;
  net->minorMargin = 0 ;
  net->majorWidth  = 0 ;
  net->majorHeight = 0 ;
  net->majorMargin = 0 ;
  net->activityWidth  = 0 ;
  net->activityHeight = 0 ;
  net->activityMargin = 0 ;
  net->linkLayoutMode = SHOW_INCOMING ;

  /* default procs; the two error procs only print a warning, so a net
   * type is expected to install real ones */
  net->calculateErrorDerivProc	= updateErrorDerivs ;
  net->calculateErrorProc	= updateError ;
  net->activityUpdateProc = netActivityUpdate ;
  net->gradientUpdateProc = netGradientUpdate ;

  net->weightCost	= 0.0 ;
  net->costProc		= cost ;
  net->costAndDerivsProc= costAndDerivs ;

  /* empty group array (grown by netAddGroup) */
  net->group  	      = NULL ;
  net->numGroups      = 0 ;
  net->groupArraySize = 0 ;
  
  /* empty variable/gradient arrays (grown by checkArraySize) */
  net->variables          = NULL ;
  net->gradient           = NULL ;
  net->numVariables       = 0 ;
  net->numFrozenVariables = 0 ;
  net->variableArraySize  = 0 ;
  
  /* empty link array (grown by checkArraySize) */
  net->links          = NULL ;
  net->numLinks       = 0 ;
  net->numFrozenLinks = 0 ;
  net->linkArraySize  = 0 ;
  
  net->unitGains      = TRUE ;
  
  net->currentTime = 0 ;

  net->mz		    = NULL ;
  net->trainingExampleSet   = NULL ;
  net->testingExampleSet    = NULL ;
  net->validationExampleSet = NULL ;
  net->costModel	    = NULL ;
  net->extension	    = NULL ;

  /* mz was just cleared, so initMinimize is handed NULL here */
  net->mz = initMinimize(net->mz, 0) ;

  /* the net* callbacks below recover the net from mz->userData */
  setMinimizeUserData(net->mz, (void *)net) ;

  setMinimizeMethod(net->mz, MZGETNVARS,  (UnknownProc)netGetNVars) ;
  setMinimizeMethod(net->mz, MZGETVALUES, (UnknownProc)netGetValues) ;
  setMinimizeMethod(net->mz, MZSETVALUES, (UnknownProc)netSetValues) ;
  setMinimizeMethod(net->mz, MZFEVAL,	  (UnknownProc)netFEval) ;
  setMinimizeMethod(net->mz, MZGEVAL,	  (UnknownProc)NULL) ;
  setMinimizeMethod(net->mz, MZFGEVAL,	  (UnknownProc)netFGEval) ;
  setMinimizeMethod(net->mz, MZVALUENAME, (UnknownProc)netValueName) ;
  setMinimizeMethod(net->mz, MZINCITER,	  (UnknownProc)netIncrementEpoch) ;

  net->trainingExampleSet   = createExampleSet("Training",   TRAINING,   net) ;
  net->testingExampleSet    = createExampleSet("Testing",    TESTING,    net) ;
  net->validationExampleSet = createExampleSet("Validation", VALIDATION, net) ;

  /* default net-level cost model, named after the net */
  McostModel(net) = createSumSquareCostModel(net->name, net) ;
}
/**********************************************************************/


/*********************************************************************
 *	Name:		procedures for interfacing the network functions
 *			with the minimize functions
 *	Description:	See individual procedures
 *	Parameters:	ditto
 *	Return Value:	ditto
 *********************************************************************/
/* Each callback recovers the Net that initialize() stored as the
 * Minimize user data and forwards to the corresponding net routine. */
static int	netGetNVars(mz)
  Minimize	mz ;
{
  Net	net = (Net)mz->userData ;

  return getNumberOfVars(net) ;
}
/***********************************************************************/
static int	netGetValues(mz, n, x)
  Minimize	mz ;
  int		n ;
  Real		*x ;
{
  Net	net = (Net)mz->userData ;

  return getCurrentValues(net, n, x) ;
}
/***********************************************************************/
static int	netSetValues(mz, n, x)
  Minimize	mz ;
  int		n ;
  Real		*x ;
{
  Net	net = (Net)mz->userData ;

  return setCurrentValues(net, n, x) ;
}
/***********************************************************************/
static Real	netFEval(mz, n, x)
  Minimize	mz ;
  int		n ;
  Real		*x ;
{
  Net	net = (Net)mz->userData ;

  return evaluate(net, n, x) ;
}
/***********************************************************************/
static Real	netFGEval(mz, n, x, g)
  Minimize	mz ;
  int		n ;
  Real		*x, *g ;
{
  Net	net = (Net)mz->userData ;

  return calculateGradient(net, n, x, g) ;
}
/***********************************************************************/
static String	netValueName(mz, i)
  Minimize	mz ;
  int		i ;
{
  Net	net = (Net)mz->userData ;

  return getValueName(net, i) ;
}
/***********************************************************************/
static void	netIncrementEpoch(mz)
  Minimize	mz ;
{
  Net	owner = (Net)mz->userData ;

  /* one more epoch done; then hand the step taken to setDelta() */
  owner->currentEpoch += 1 ;
  setDelta(owner, mz->n, mz->start, mz->x) ;
}
/***********************************************************************/


/***********************************************************************
 *	Name:		destroyNet 
 *	Description:	destroys a net
 *	Parameters:	
 *		Net	net - the net to destroy
 *	Return Value:	
 *		NONE
 ***********************************************************************/
void	destroyNet (net)
  Net		net ;
{
  /* the user hook runs while the net is still intact */
  if (destroyNetHook)
    (*destroyNetHook)(net) ;

  /* HACK: zero the variable count first so that deleting each link
   * during finalize() does not have to keep every index up to date --
   * the whole variable array is being thrown away anyway. */
  net->numVariables = 0 ;
  finalize(net) ;

  if (McostModel(net) != NULL)
    MCMdestroy(McostModel(net)) ;

  free ((void *)net->name) ;
  free ((void *)net) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		finalize
 *	Description:	finalize proc for the Net class of objects.
 *	Parameters:
 *		Net	net - the net object ;
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static void	finalize (net)
  Net		net ;
{
  /* Destroy groups from the back of the array.  This loop relies on
   * destroyGroup() removing the group from net->group (so numGroups
   * drops) -- NOTE(review): confirm, otherwise it never terminates. */
  for ( ; net->numGroups > 0 ; )
    destroyGroup(net->group[net->numGroups - 1]) ;

  /* the three example sets created by initialize() */
  if (net->trainingExampleSet)
    destroyExampleSet(net->trainingExampleSet) ;
  if (net->testingExampleSet)
    destroyExampleSet(net->testingExampleSet) ;
  if (net->validationExampleSet)
    destroyExampleSet(net->validationExampleSet) ;

  /* the net-owned arrays */
  if (net->group)
    free ((void *)net->group) ;
  if (net->variables)
    free ((void *)net->variables) ;
  if (net->gradient)
    free ((void *)net->gradient) ;
  if (net->links)
    free ((void *)net->links) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		updateErrorDerivs
 *	Description:	proc for calculating error and derivs for
 *			the net class of objects 
 *		Net	net - the net object ;
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static void	updateErrorDerivs (net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  /* Placeholder installed by initialize(): it only warns, and never
   * looks at exampleSet. */
  (void)fprintf(stderr,
		"No derivative calculation procedure registered for net \"%s\"\n",
		net->name) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		updateError
 *	Description:	proc for calculating the error for the
 *			net class of objects 
 *		Net	net - the net object ;
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static void	updateError (net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  /* Placeholder installed by initialize(): it only warns, and never
   * looks at exampleSet. */
  (void)fprintf(stderr,
		"No error calculation procedure registered for net \"%s\"\n",
		net->name) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		cost
 *	Description:	calculates cost of a network using the costModel
 *			in it and its groups
 *		Net	net - the net object ;
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static void	cost (net)
  Net		net ;
{
  /* Reset, add the net-level cost model's contribution, then let every
   * group's cost model add its own (see updateCost). */
  net->cost = 0.0 ;

  if (McostModel(net) != NULL && MweightCost(net) != 0)
    net->cost += MCMevaluateCost(McostModel(net), MweightCost(net)) ;

  netForAllGroups(net, ALL, updateCost, NULL) ;
}
/**********************************************************************/
static void	updateCost(group, data)
  Group		group ;
  void		*data ;
{
  Net	owner = group->net ;

  /* nothing to add without a model or with a zero weight cost */
  if (McostModel(group) == NULL || MweightCost(owner) == 0)
    return ;

  owner->cost += MCMevaluateCost(McostModel(group), MweightCost(owner)) ;
}
/********************************************************************/


/***********************************************************************
 *	Name:		costAndDerivs
 *	Description:	calculates cost of a network and associated 
 *			derivatives, using the costModel in it and
 *			its groups.
 *		Net	net - the net object ;
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static void	costAndDerivs (net)
  Net		net ;
{
  /* Same shape as cost(), but asks each cost model for derivatives too. */
  net->cost = 0.0 ;

  if (McostModel(net) != NULL && MweightCost(net) != 0)
    net->cost += MCMevaluateCostAndDerivs(McostModel(net), MweightCost(net)) ;

  netForAllGroups(net, ALL, updateCostAndDerivs, NULL) ;
}
/**********************************************************************/
static void	updateCostAndDerivs(group, data)
  Group		group ;
  void		*data ;
{
  Net	owner = group->net ;

  /* nothing to add without a model or with a zero weight cost */
  if (McostModel(group) == NULL || MweightCost(owner) == 0)
    return ;

  owner->cost += MCMevaluateCostAndDerivs(McostModel(group),
					  MweightCost(owner)) ;
}
/********************************************************************/


/***********************************************************************
 *	Name:		netActivityUpdate
 *	Description:	calls MupdateGroupActivities for every group
 *			other than INPUT and BIAS groups whose
 *			groupActivityUpdateProc is not NULL
 *		Net	net - the net object ;
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static void	netActivityUpdate (net)
  Net		net ;
{
  /* visit every group except the INPUT and BIAS ones */
  netForAllOtherGroups (net, INPUT | BIAS, groupActivityUpdate, NULL) ;
}
/**********************************************************************/
static void	groupActivityUpdate (group, data)
  Group		group ;
  void		*data ;
{
  /* groups without an activity proc are simply skipped */
  if (group->groupActivityUpdateProc == NULL)
    return ;

  MupdateGroupActivities(group) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		netGradientUpdate
 *	Description:	calls MupdateGroupGradients for every group
 *			(visited via netForAllGroupsBack) whose
 *			groupGradientUpdateProc is not NULL
 *		Net	net - the net object ;
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static void	netGradientUpdate (net)
  Net		net ;
{
  /* every group, visited through the "Back" traversal */
  netForAllGroupsBack(net, ALL, groupGradientUpdate, NULL) ;
}
/**********************************************************************/
static void	groupGradientUpdate (group, data)
  Group		group ;
  void		*data ;
{
  /* groups without a gradient proc are simply skipped */
  if (group->groupGradientUpdateProc == NULL)
    return ;

  MupdateGroupGradients(group) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		netAddGroup 
 *	Description:	adds a group to net.
 *	Parameters:	
 *		Net	net   - the net to add the groups to
 *		Group	group - the group to add
 *	Return Value:	
 *		int	netAddGroup - 1 on success, 0 if the group array
 *				could not be grown
 ***********************************************************************/
#ifndef GRANULARITY
#define GRANULARITY 32
#define DEFINED_GRANULARITY
#endif

int	netAddGroup (net, group)
  Net	net ;
  Group	group ;
{
  if (net->numGroups >= net->groupArraySize) {
    Group	*newArray ;
    int		newSize ;

    if (net->groupArraySize == 0) {
      /* first allocation: some default initial size */
      newSize  = net->numGroups + GRANULARITY ;
      newArray = (Group *)malloc (newSize*sizeof(Group)) ;
    } else {
      /* double the size of the array */
      newSize  = 2 * net->numGroups ;
      newArray = (Group *)realloc ((void *)net->group, newSize*sizeof(Group)) ;
    }

    /* On failure leave the net untouched: the old array (if any) is still
     * valid and groupArraySize still describes it.  Writing the new size
     * or a NULL pointer into the net before this check leaked the old
     * array and made the next call write through NULL. */
    if (newArray == NULL) {
      simError(SIM_ERR_MALLOC, "netAddGroup net: \"%s\", group: \"%s\"", 
	       net->name, group->name) ;
      return 0 ;
    }
    net->group          = newArray ;
    net->groupArraySize = newSize ;
  }

  net->group[net->numGroups] = group ;
  ++(net->numGroups) ;
  return 1 ;
}
#ifdef DEFINED_GRANULARITY
#undef DEFINED_GRANULARITY
#undef GRANULARITY
#endif
/**********************************************************************/


/***********************************************************************
 *	Name:		netDeleteGroup 
 *	Description:	deletes a group from a network
 *	Parameters:	
 *		Net	net   - the net to delete from
 *		Group	group - the group to delete
 *	Return Value:	
 *		NONE
 ***********************************************************************/
void	netDeleteGroup (net, group)
  Net	net ;
  Group	group ;
{
  int	idx ;

  if (net->group == NULL)
    return ;

  for (idx = 0 ; idx < net->numGroups ; ++idx)
    if (net->group[idx] == group)
      break ;

  /* Not in this net: leave everything alone.  (The old code went on to
   * inspect net->group[numGroups], one past the last valid entry.) */
  if (idx >= net->numGroups)
    return ;

  /* Close the gap.  Stop one short of the end so we never read
   * net->group[numGroups], which is out of bounds when the array is full. */
  for ( ; idx < net->numGroups - 1 ; ++idx)
    net->group[idx] = net->group[idx + 1] ;
  --(net->numGroups) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		netGetExampleSet
 *	Description:	returns the requested example set for the net
 *	Parameters:	
 *		Net	net  - the network to get the example set from
 *		Mask	mask - a mask saying which set to get.
 *	Return Value:	
 *		ExampleSet - the requested example set, NULL on error
 ***********************************************************************/
ExampleSet	netGetExampleSet(net, mask)
  Net	net ;
  Mask	mask ;
{
  /* exactly one of TRAINING / TESTING / VALIDATION is expected;
   * anything else yields NULL */
  if (mask == TRAINING)
    return net->trainingExampleSet ;
  if (mask == TESTING)
    return net->testingExampleSet ;
  if (mask == VALIDATION)
    return net->validationExampleSet ;

  return NULL ;
}
/**********************************************************************/


/*********************************************************************
 *	Name:		rotateRealArray
 *	Description:	rotates elements in an array of Reals by a given
 *			step size
 *	Parameters:
 *	  Real		*base - the array of Reals
 *	  int		numElements - the number of elements in the array
 *	  int		step - the step to shift ( if < 0 then rotate
 *				elements back, otherwise up)
 *	Return Value:
 *	  static Real	*rotateRealArray - NULL on error, otherwise the array
 *********************************************************************/
static Real	*rotateRealArray(base, numElements, step)
  Real		*base ;
  int		numElements ;
  int		step ;
{
  Real		*tmpArray = NULL ;
  Real		tmpValue  = 0 ;
  int		stepSize  = abs(step) ;
  int		idx, offset ;

  if (base == NULL || numElements == 0 || step == 0)
    return base ;

  if (stepSize > numElements)
    return NULL ;

  /* From here on numElements counts only the elements that slide; the
   * stepSize elements that wrap around are parked in temporary storage
   * (a single local when stepSize == 1, so no allocation is needed). */
  numElements -= stepSize ;
  offset       = (step < 0) ? 0 : numElements ;
  if (stepSize > 1) {
    tmpArray = (Real *)calloc(stepSize, sizeof(*base)) ;
    if (tmpArray == NULL)
      return NULL ;		/* base has not been modified yet */
    memcpy(tmpArray, base + offset, stepSize*sizeof(*base)) ;
  } else {
    tmpValue = base[offset] ;
  }

  /* slide the remaining elements, iterating so nothing is overwritten
   * before it has been moved */
  if (step < 0) {
    for (idx = 0 ; idx < numElements ; ++idx) 
      base[idx] = base[idx+stepSize] ;
  } else {
    for (idx = numElements - 1 ; idx >= 0 ; --idx) 
      base[idx+stepSize] = base[idx] ;
  }

  /* drop the parked elements into the vacated slots */
  offset = (step < 0) ? numElements : 0 ;
  if (stepSize > 1) {
    memcpy(base + offset, tmpArray, stepSize*sizeof(*base)) ;
    free(tmpArray) ;
  } else {
    base[offset] = tmpValue ;
  }

  return base ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		rotateLinkArray
 *	Description:	rotates elements in an array of Links by a given
 *			step size
 *	Parameters:
 *	  Link		*base - the array of Links
 *	  int		numElements - the number of elements in the array
 *	  int		step - the step to shift ( if < 0 then rotate
 *				elements back, otherwise up)
 *	Return Value:
 *	  static Link	*rotateLinkArray - NULL on error, otherwise the array
 *********************************************************************/
static Link	*rotateLinkArray(base, numElements, step)
  Link		*base ;
  int		numElements ;
  int		step ;
{
  Link		*tmpArray = NULL ;
  Link		tmpValue  = NULL ;
  int		stepSize  = abs(step) ;
  int		idx, offset ;

  if (base == NULL || numElements == 0 || step == 0)
    return base ;

  if (stepSize > numElements)
    return NULL ;

  /* From here on numElements counts only the elements that slide; the
   * stepSize elements that wrap around are parked in temporary storage
   * (a single local when stepSize == 1, so no allocation is needed). */
  numElements -= stepSize ;
  offset       = (step < 0) ? 0 : numElements ;
  if (stepSize > 1) {
    tmpArray = (Link *)calloc(stepSize, sizeof(*base)) ;
    if (tmpArray == NULL)
      return NULL ;		/* base has not been modified yet */
    memcpy(tmpArray, base + offset, stepSize*sizeof(*base)) ;
  } else {
    tmpValue = base[offset] ;
  }

  /* slide the remaining elements, iterating so nothing is overwritten
   * before it has been moved */
  if (step < 0) {
    for (idx = 0 ; idx < numElements ; ++idx) 
      base[idx] = base[idx+stepSize] ;
  } else {
    for (idx = numElements - 1 ; idx >= 0 ; --idx) 
      base[idx+stepSize] = base[idx] ;
  }

  /* drop the parked elements into the vacated slots */
  offset = (step < 0) ? numElements : 0 ;
  if (stepSize > 1) {
    memcpy(base + offset, tmpArray, stepSize*sizeof(*base)) ;
    free(tmpArray) ;
  } else {
    base[offset] = tmpValue ;
  }

  return base ;
}
/********************************************************************/


/**********************************************************************/
int	compareLinks(p1, p2)
  const void	*p1 ;
  const void	*p2 ;
{
  int	returnVal = (*(Link *)p1)->variableIdx - (*(Link *)p2)->variableIdx ;

  if (returnVal) return returnVal ;
  else		 return *(Link *)p1 - *(Link *)p2 ;
}
/**********************************************************************/
int	compareVarIdx(p1, p2)
  const void	*p1 ;
  const void	*p2 ;
{
  return (*(Link *)p1)->variableIdx - (*(Link *)p2)->variableIdx ;
}
/**********************************************************************/


/*********************************************************************
 *	Name:		linkFromPtr
 *	Description:	returns a link and its index in the link array
 *	Parameters:
 *	  Net		net - the net to search in
 *	  Link		link - the link to look for
 *	  int		*idx - the index of the link (returned)
 *	Return Value:
 *	  static Link	linkFromPtr - the link, NULL, if not found.
 *********************************************************************/
static Link	linkFromPtr(net, link, idx)
  Net		net ;
  Link		link ;
  int		*idx ;
{
  int	numLinks ;
  Link	*linkPtr, *links ;

  /* Nothing to find: also avoids doing arithmetic on a NULL net->links
   * before any link has been added. */
  if (link == NULL || net->links == NULL)
    return NULL ;

  /* frozen links occupy the tail of net->links, unfrozen ones the head */
  if (link->type & FROZEN) {
    links    = net->links + net->numLinks - net->numFrozenLinks ;
    numLinks = net->numFrozenLinks ;
  } else {
    links    = net->links ;
    numLinks = net->numLinks - net->numFrozenLinks ;
  }

  /* bsearch() still requires a valid array pointer even when the count
   * is zero, so skip the search for an empty partition */
  if (numLinks <= 0)
    return NULL ;

  /* each partition must be sorted by compareLinks for this to work */
  linkPtr = (Link *)bsearch(&link, links, numLinks, sizeof(Link),compareLinks);

  if (linkPtr == NULL || *linkPtr != link)
    return NULL ;
  *idx = linkPtr - net->links ;

  return *linkPtr ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		linkFromVarIdx
 *	Description:	returns a link and its idx in the link array
 *	Parameters:
 *	  Net		net - the net to search in
 *	  int		varIdx - the variable idx to use as key
 *	  int		*linkIdx - the index of the link (returned)
 *	Return Value:
 *	  static Link	linkFromVarIdx - the link, NULL, if not found.
 *********************************************************************/
static Link	linkFromVarIdx(net, varIdx, linkIdx)
  Net		net ;
  int		varIdx ;
  int		*linkIdx ;
{
  int	idx, numLinks ;
  Link	*linkPtr, *links ;
  /* scratch search key, created on first use and kept across calls */
  static Link	link = NULL ;

  if (link == NULL) {
    link = createLink("tmp", NULL, NULL, UNKNOWN) ;
    if (link == NULL)
      return NULL ;		/* could not build the search key */
  }

  /* no links yet: also avoids arithmetic on a NULL net->links */
  if (net->links == NULL)
    return NULL ;

  /* frozen variables (and their links) are kept at the end */
  if (varIdx >= net->numVariables - net->numFrozenVariables) {
    links    = net->links + net->numLinks - net->numFrozenLinks ;
    numLinks = net->numFrozenLinks ;
  } else {
    links    = net->links ;
    numLinks = net->numLinks - net->numFrozenLinks ;
  }

  /* bsearch() requires a valid array pointer even for a zero count */
  if (numLinks <= 0)
    return NULL ;

  link->variableIdx = varIdx ;
  linkPtr = (Link *)bsearch(&link,links,numLinks,sizeof(Link),compareVarIdx);

  if (linkPtr == NULL)
    return NULL ;

  /* bsearch may land on any link with this variableIdx; back up to the
   * first one */
  idx = linkPtr - links ;
  while (idx > 0 && links[idx - 1]->variableIdx == varIdx)
    --idx ;
  *linkIdx = links - net->links + idx ;

  return links[idx] ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		findConstrainedLinks
 *	Description:	returns a pointer to an array of links constrained
 *			with the one requested. DO NOT ALTER ARRAY
 *	Parameters:
 *	  Net		net - the net to search
 *	  int		varIdx - the shared variableIdx
 *	  int		*numConstrained - number of links constrained together
 *				(including the one passed in)
 *	Return Value:
 *	  static Link	*findConstrainedLinks - the array of constrained
 *				links
 *********************************************************************/
static Link	*findConstrainedLinks(net, varIdx, numConstrained)
  Net		net ;
  int		varIdx ;
  int		*numConstrained ;
{
  Link	first ;
  int	firstIdx, endIdx, limit ;

  *numConstrained = 0 ;

  first = linkFromVarIdx(net, varIdx, &firstIdx) ;
  if (first == NULL)
    return NULL ;

  /* the run may not cross from the unfrozen part into the frozen tail */
  limit = (first->type & FROZEN) ? net->numLinks
				 : net->numLinks - net->numFrozenLinks ;

  /* walk forward over consecutive links sharing this variable */
  endIdx = firstIdx ;
  while (endIdx < limit && net->links[endIdx]->variableIdx == varIdx)
    ++endIdx ;

  *numConstrained = endIdx - firstIdx ;
  return net->links + firstIdx ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		deleteVariable
 *	Description:	deletes a variable (frozen or otherwise) from
 *			a network
 *	Parameters:
 *	  Net		net - the net to delete the variable from
 *	  int		varIdx - the index  of the variable
 *	Return Value:
 *	  static void	deleteVariable - NONE
 *********************************************************************/
static void	deleteVariable(net, varIdx) 
  Net		net ;
  int		varIdx ;
{
  Real		*values   = net->variables ;
  Real		*grads    = net->gradient ;
  Link		*allLinks = net->links ;
  int		tailLen   = net->numVariables - varIdx ;
  Link		victim ;
  int		pos ;

  victim = linkFromVarIdx(net, varIdx, &pos) ;
  if (victim == NULL)
    return ;

  /* rotate the doomed slot to the end of both arrays, then shrink */
  rotateRealArray(values + varIdx, tailLen, -1) ;
  rotateRealArray(grads  + varIdx, tailLen, -1) ;
  net->numVariables -= 1 ;
  if (victim->type & FROZEN)
    net->numFrozenVariables -= 1 ;

  /* every link after the victim refers to a later variable, which has
   * just moved down by one */
  while (++pos < net->numLinks)
    allLinks[pos]->variableIdx -= 1 ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		checkArraySize
 *	Description:	ensures that the arrays of variables and links
 *			in a net are big enough to add another link
 *	Parameters:
 *	  Net		net - the net to work on
 *	Return Value:
 *	  static int	checkArraySize - 0 if fails, 1 if success
 *********************************************************************/
#ifndef GRANULARITY
#define GRANULARITY 1024
#define DEFINED_GRANULARITY
#endif
static int	checkArraySize(net)
  Net		net ;
{
  /* the variable and gradient arrays always grow together */
  if (net->numVariables >= net->variableArraySize) {
    Real	*vars ;
    Real	*grad ;
    int		arraySize ;

    if (net->variableArraySize == 0)
      arraySize = net->numVariables + GRANULARITY ;
    else
      arraySize = 2 * net->numVariables ;

    /* Store each array back into the net as soon as it has been
     * (re)allocated.  If the gradient allocation fails after the
     * variables were realloc'ed, the old variables block may already
     * have been freed; keeping the stale pointer (as before) left
     * net->variables dangling, and the malloc path leaked. */
    if (net->variables == NULL)
      vars = (Real *)malloc (arraySize*sizeof(Real)) ;
    else
      vars = (Real *)realloc((void *)net->variables, arraySize*sizeof(Real)) ;
    if (vars == NULL)
      return 0 ;
    net->variables = vars ;

    if (net->gradient == NULL)
      grad = (Real *)malloc (arraySize*sizeof(Real)) ;
    else
      grad = (Real *)realloc((void *)net->gradient, arraySize*sizeof(Real)) ;
    if (grad == NULL)
      return 0 ;
    net->gradient = grad ;

    /* only now are both arrays known to hold arraySize elements */
    net->variableArraySize = arraySize ;
  }

  if (net->numLinks >= net->linkArraySize) {
    Link	*links    = net->links ;
    int		arraySize = net->linkArraySize ;

    if (arraySize == 0) {
      arraySize = net->numLinks + GRANULARITY ;
      links = (Link *)malloc (arraySize*sizeof(Link)) ;
    } else {
      arraySize = 2 * net->numLinks ;
      links = (Link *)realloc(links, arraySize*sizeof(Link));
    }
    /* a single array: on failure net->links is still the old, valid one */
    if (links == NULL) {
      return 0 ;
    } else {
      net->linkArraySize = arraySize ;
      net->links         = links ;
    }
  }
  return 1 ;
}
#ifdef DEFINED_GRANULARITY
#undef DEFINED_GRANULARITY
#undef GRANULARITY
#endif
/********************************************************************/


/***********************************************************************
 *	Name:		netAddLink 
 *	Description:	adds a link to net.
 *	Parameters:	
 *		Net	net   - the net to add the links to
 *		Link	link - the link to add
 *	Return Value:	
 *		int	netAddLink - 1 on success, 0 if the arrays could
 *				not be grown
 ***********************************************************************/
static void	addToCostModel ARGS((Net, Link)) ;
/********************************************************************/
/* Append link to the net, giving it a fresh variable (initial value and
 * gradient 0.0).  Layout invariant: unfrozen links/variables come first,
 * frozen ones last, so an unfrozen link added while frozen links exist
 * is rotated in just ahead of the frozen tail. */
int	netAddLink (net, link)
  Net	net ;
  Link	link ;
{
  if (checkArraySize(net) == 0) {
    simError(SIM_ERR_MALLOC, "netAddLink net: \"%s\", link: \"%s\"", 
	     net->name, link->name) ;
    return 0 ;
  }

  addToCostModel(net, link) ;

  /* add the link to the proper arrays */
  net->links[net->numLinks]         = link ;
  net->variables[net->numVariables] = 0.0 ;
  net->gradient[net->numVariables]  = 0.0 ;
  ++(net->numLinks) ;
  ++(net->numVariables) ;

  /* if the link is FROZEN, stick it at the end of the array */
  if (link->type & FROZEN) {
    link->variableIdx = net->numVariables - 1 ;
    ++(net->numFrozenVariables) ;
    ++(net->numFrozenLinks) ;
  /* if there are no FROZEN links, stick it at the end of the array */
  } else if (net->numFrozenLinks == 0) {
    link->variableIdx = net->numVariables - 1 ;
  /* if there are FROZEN links, stick it before them in the array */
  } else {
    /* NOTE: add/subtract 1 from offsets because we've already 
     * incremented numVariables and numLinks */
    int	varOffset  = net->numVariables - net->numFrozenVariables - 1 ;
    int linkOffset = net->numLinks - net->numFrozenLinks - 1 ;
    int	numLinks   = net->numFrozenLinks + 1 ;
    int	numVars    = net->numFrozenVariables + 1 ;
    int	idx ;

    /* Rotating the frozen tail (plus the just-appended slot) up by one
     * moves the new entry from the very end to the first position of
     * the tail, i.e. just before the frozen entries. */
    rotateLinkArray(net->links + linkOffset,    numLinks, 1) ;
    rotateRealArray(net->variables + varOffset, numVars,  1) ;
    rotateRealArray(net->gradient  + varOffset, numVars,  1) ;
    link->variableIdx = varOffset ;

    /* increment all frozen links' variableIdx's */
    for (idx = 1 ; idx < numLinks ; ++idx)
      ++(net->links[linkOffset + idx]->variableIdx) ;
  }	

  return 1 ;
}
/**********************************************************************/
static void	addToCostModel(net, link)
  Net		net ;
  Link		link ;
{
  /* Registers a newly added link with a cost model: the cost model of
   * the group the link is attached to (via its preUnit, else its
   * postUnit), falling back to the net-wide cost model.  Does nothing
   * if neither has a cost model. */
  Group		group = NULL ;

  if (link->preUnit)
    group = link->preUnit->group ;
  else if (link->postUnit)
    group = link->postUnit->group ;

  if (group && McostModel(group))
    MCMaddLink(McostModel(group), link) ;
  else if (net && McostModel(net))
    MCMaddLink(McostModel(net), link) ;
}
/********************************************************************/


/***********************************************************************
 *	Name:		netDeleteLink 
 *	Description:	deletes a link from a network
 *	Parameters:	
 *		Net	net  - the net to delete from
 *		Link	link - the link to delete
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static void	removeFromCostModel ARGS((Net, Link)) ;
/**********************************************************************/
void	netDeleteLink (net, link)
  Net	net ;
  Link	link ;
{
  int	linkIdx, setSize = 0 ;
  Link	*linkSet ;

  /* silently ignore links that are not part of this net */
  link = linkFromPtr(net, link, &linkIdx) ;
  if (link == NULL)
    return ;

  removeFromCostModel(net, link) ;

  /* Drop the link's variable only if no other link is constrained to
   * it.  Check the lookup result (as netConstrainLink does) so that
   * setSize is never trusted after a failed lookup. */
  linkSet = findConstrainedLinks(net, link->variableIdx, &setSize) ;
  if (linkSet != NULL && setSize == 1)
    deleteVariable(net, link->variableIdx) ;

  /* rotate the link past the end of the live region, then shrink it */
  rotateLinkArray(net->links + linkIdx, net->numLinks - linkIdx, -1) ;
  --(net->numLinks) ;
  if (link->type & FROZEN)
    --(net->numFrozenLinks) ;
}
/**********************************************************************/
static void	removeFromCostModel(net, link)
  Net		net ;
  Link		link ;
{
  /* Unregisters a link from the cost model it was added to: the cost
   * model of the group the link is attached to (via its preUnit, else
   * its postUnit), falling back to the net-wide cost model.  Mirrors
   * the selection made when the link was added. */
  Group		group = NULL ;

  if (link->preUnit)
    group = link->preUnit->group ;
  else if (link->postUnit)
    group = link->postUnit->group ;

  if (group && McostModel(group))
    MCMremoveLink(McostModel(group), link) ;
  else if (net && McostModel(net))
    MCMremoveLink(McostModel(net), link) ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		netConstrainLink
 *	Description:	constrains one link to be a linear multiple of
 *			another. If the first link is already constrained,
 *			that constraint is broken. If the second link is 
 *			(un)frozen, the first is automatically (un)frozen
 *			too.
 *	Parameters:
 *	  Net	net - the net containing the links
 *	  Link	link - the link to constrain
 *	  Link	toLink - the link to constrain TO
 *	  double scale - the scale between weights (i.e.
 *			link->weight == scale*toLink->weight)
 *	Return Value:
 *	  int	netConstrainLink - 1 on success, 0 on failure
 *********************************************************************/
int	netConstrainLink(net, link, toLink, scale)
  Net	net ;
  Link	link ;
  Link	toLink ;
  double scale ;
{
  int	linkIdx, toLinkIdx, setSize ;
  int	arrayOffset, numItems ;
  Link	*linkSet ;

  /* set the scale and weight no matter what */
  link->scaleFactor = scale*toLink->scaleFactor ;
  linkSetWeight(link, scale*toLink->weight) ;

  /* If they are already constrained together, only the scale changed.
   * That change must still be reflected in unitGains, exactly as in
   * the general path below; otherwise an unfrozen link could carry a
   * non-unit scale while unitGains stays TRUE. */
  if (link->variableIdx == toLink->variableIdx) {
    if (!(link->type & FROZEN) && link->scaleFactor != 1.0)
      net->unitGains = FALSE ;
    return 1 ;
  }

  link = linkFromPtr(net, link, &linkIdx) ;
  if (link == NULL)
    return 0 ;
  /* use the first link sharing toLink's variable as the anchor */
  toLink = linkFromVarIdx(net, toLink->variableIdx, &toLinkIdx) ;
  if (toLink == NULL)
    return 0 ;

  linkSet = findConstrainedLinks(net, link->variableIdx, &setSize) ;
  if (linkSet == NULL)
    return 0 ;

  /* if the link is the only one associated with a variable, 
   * delete the variable (updating necessary fields) */
  if (setSize == 1)
    deleteVariable(net, link->variableIdx) ;

  /* set the variableIdx in the link to its future value */
  link->variableIdx = toLink->variableIdx ;

  /* Rotate the stretch between the two links so that link ends up
   * adjacent to toLink.  The direction depends on which comes first. */
  arrayOffset = MIN(linkIdx, toLinkIdx) ;
  numItems    = abs(linkIdx - toLinkIdx) + 1 ;
  rotateLinkArray(net->links + arrayOffset, numItems,
		  linkIdx < toLinkIdx ? -1 : 1) ;

  /* now sort the links that are constrained together */
  linkSet = findConstrainedLinks(net, link->variableIdx, &setSize) ;
  qsort(linkSet, setSize, sizeof(Link), compareLinks) ;

  /* a constrained link inherits toLink's frozen state */
  if ((link->type & FROZEN) && !(toLink->type & FROZEN)) {
    link->type &= ~FROZEN ;
    --(net->numFrozenLinks) ;
  } else if (!(link->type & FROZEN) && (toLink->type & FROZEN)) {
    link->type |= FROZEN ;
    ++(net->numFrozenLinks) ;
  }
  if (!(link->type & FROZEN) && link->scaleFactor != 1.0)
    net->unitGains = FALSE ;

  return 1 ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		netUnconstrainLink
 *	Description:	unconstrains a link (i.e. gives it its own
 *			degree of freedom)
 *	Parameters:
 *	  Net	net - the net containing the link
 *	  Link	link - the link
 *	Return Value:
 *	  void	netUnconstrainLink - NONE
 *********************************************************************/
void	netUnconstrainLink(net, link)
  Net	net ;
  Link	link ;
{
  /* Take the link out of the net and put it back.  Re-adding it
   * appends a fresh variable slot that belongs to this link alone.
   * netAddLink reports its own failure via simError; the result is
   * deliberately discarded here.
   * NOTE(review): on such a failure the link stays deleted. */
  netDeleteLink(net, link) ;
  (void) netAddLink(net, link) ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		netFreezeLink
 *	Description:	freezes a link (and any others constrained with
 *			it) by moving it (and its variable) to the head
 *			of the frozen section at the end of the arrays,
 *			and incrementing the frozen counters accordingly.
 *	Parameters:
 *	  Net	net - the net holding the link
 *	  Link	link - the link to freeze.
 *	Return Value:
 *	  void	netFreezeLink - NONE
 *********************************************************************/
void	netFreezeLink(net, link)
  Net	net ;
  Link	link ;
{
  Link	*group, *shifted ;
  int	groupSize, firstFrozenVar, oldVar, startLink, span, i ;

  /* Do nothing unless link's constraint set was found and is not yet
   * frozen.  Only the first member's FROZEN bit is inspected. */
  group = findConstrainedLinks(net, link->variableIdx, &groupSize) ;
  if (group == NULL || (group[0]->type & FROZEN))
    return ;

  /* Slide the set's shared variable (and its gradient) to the last
   * unfrozen slot, which becomes the first frozen slot once the
   * counter is bumped at the end. */
  firstFrozenVar = net->numVariables - net->numFrozenVariables ;
  oldVar         = group[0]->variableIdx ;
  rotateRealArray(net->variables + oldVar, firstFrozenVar - oldVar, -1) ;
  rotateRealArray(net->gradient  + oldVar, firstFrozenVar - oldVar, -1) ;

  /* mark every member frozen and point it at the relocated variable */
  for (i = 0 ; i < groupSize ; ++i) {
    group[i]->type       |= FROZEN ;
    group[i]->variableIdx = firstFrozenVar - 1 ;
  }

  /* slide the set's links to the end of the unfrozen region */
  startLink = group - net->links ;
  span      = net->numLinks - net->numFrozenLinks - startLink ;
  rotateLinkArray(net->links + startLink, span, -groupSize) ;

  /* The links that slid down into the vacated space (between the
   * set's old and new positions) each have their variableIdx
   * decremented, since the set's variable moved past theirs. */
  shifted = net->links + startLink ;
  for (i = 0 ; i < span - groupSize ; ++i)
    --(shifted[i]->variableIdx) ;

  ++(net->numFrozenVariables) ;
  net->numFrozenLinks += groupSize ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		netUnfreezeLink
 *	Description:	unfreezes a link (and any others constrained
 *			with it), by moving it to the head of the
 *			frozenLinks in the arrays, and decrementing
 *			the counters by the proper amount
 *	Parameters:
 *	  Net	net - the net containing the link
 *	  Link	link - the link to unfreeze
 *	Return Value:
 *	  void	netUnfreezeLink - NONE
 *********************************************************************/
void	netUnfreezeLink(net, link)
  Net	net ;
  Link	link ;
{
  Link	*group, *shifted ;
  int	groupSize, firstFrozenVar, firstFrozenLink, span, i ;

  /* Do nothing unless link's constraint set was found and is frozen.
   * Only the first member's FROZEN bit is inspected. */
  group = findConstrainedLinks(net, link->variableIdx, &groupSize) ;
  if (group == NULL || !(group[0]->type & FROZEN))
    return ;

  /* Slide the set's shared variable (and its gradient) to the first
   * frozen slot, which becomes the last unfrozen slot once the
   * counter is dropped at the end. */
  firstFrozenVar = net->numVariables - net->numFrozenVariables ;
  span           = group[0]->variableIdx - firstFrozenVar + 1 ;
  rotateRealArray(net->variables + firstFrozenVar, span, 1) ;
  rotateRealArray(net->gradient  + firstFrozenVar, span, 1) ;

  /* clear FROZEN on every member and point it at the relocated variable */
  for (i = 0 ; i < groupSize ; ++i) {
    group[i]->type       &= ~FROZEN ;
    group[i]->variableIdx = firstFrozenVar ;
  }

  /* slide the set's links to the head of the frozen region */
  firstFrozenLink = net->numLinks - net->numFrozenLinks ;
  span = (group - net->links) - firstFrozenLink + groupSize ;
  rotateLinkArray(net->links + firstFrozenLink, span, groupSize) ;

  /* The frozen links pushed up past the set each have their
   * variableIdx incremented, since the set's variable moved in front
   * of theirs. */
  shifted = net->links + firstFrozenLink + groupSize ;
  for (i = 0 ; i < span - groupSize ; ++i)
    ++(shifted[i]->variableIdx) ;

  --(net->numFrozenVariables) ;
  net->numFrozenLinks -= groupSize ;
}
/********************************************************************/
