
/**********************************************************************
 * $Id: linkCom.c,v 1.16 93/04/14 09:41:17 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		       Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute,  and sell this software
 * and its documentation for any purpose is hereby granted without fee,
 * provided  that the above copyright notice  appears in all copies and
 * that both the copyright notice and this permission notice  appear in
 * supporting documentation, and  that  the  name of The University  of
 * Toronto  not  be used  in advertising   or publicity pertaining   to
 * distribution  of   the software   without  specific, written   prior
 * permission.  The  University  of Toronto  makes   no representations
 * about the  suitability  of  this software  for  any purpose.   It is
 * provided "as is" without express or implied warranty.
 *
 * THE  UNIVERSITY OF  TORONTO DISCLAIMS ALL WARRANTIES  WITH REGARD TO
 * THIS SOFTWARE,  INCLUDING ALL  IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT  SHALL THE UNIVERSITY  OF TORONTO BE LIABLE
 * FOR ANY SPECIAL,  INDIRECT OR CONSEQUENTIAL  DAMAGES  OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF  USE, DATA OR PROFITS,  WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE  OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF OR  IN  CONNECTION   WITH  THE  USE OR  PERFORMANCE  OF THIS
 * SOFTWARE.
 **********************************************************************/

#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <xerion/commands.h>

/* randReal: returns centre + field * (2*IRandReal() - 1). */
static Real	randReal	ARGS((double centre, double field)) ;
/* resetFile: closes STREAM and truncates file NAME after a failed save. */
static void	resetFile	ARGS((char *name, FILE *stream)) ;

/* How command_setLink treats the LOG_TRANSFORM flag of the links it sets:
 *   UnspecifiedRange - no option given: leave the flag as it is
 *   PositiveRange    - "-logTransform": set LOG_TRANSFORM
 *   FreeRange        - "-free": clear LOG_TRANSFORM
 */
typedef enum _RangeType { 
  UnspecifiedRange, PositiveRange, FreeRange } RangeType ;

/*******************************************************************
 *	Name:		command_setLink
 *	Description:	command for setting the weight and/or the
 *			LOG_TRANSFORM flag of every link matching one of
 *			the given link names (regular expressions allowed).
 *			Each matched link is passed through linkSetWeight
 *			even when no "-weight" is given (its current weight
 *			is re-applied), presumably so that constraints and
 *			the new transform flag take effect.
 *	Parameters:
 *	  int   tokc   - the number of command line tokens
 *	  char *tokv[] - the vector of tokens; tokv[0] is the command name
 *	Return Value:
 *	  int command_setLink - 1 on success (errors abort via IErrorAbort)
 *******************************************************************/
int	command_setLink(tokc,tokv)
  int   tokc;
  char *tokv[];
{
  String	name ;
  Link		*link ;
  Real		weight = 0.0 ;
  int		idx ;
  size_t	len ;
  RangeType	range = UnspecifiedRange ;
  Boolean	weightSet = FALSE ;

  IUsage("[-logTransform] [-free] [-weight <w>] <link1> ...") ;
  if (GiveHelp(tokc)) {
    ISynopsis("set the weight of a link") ;
    IHelp
      (IHelpArgs,
       "\"setLink\"  sets  the weight  of  a link  to a specified value.  Link",
       "names may contain regular expressions.",
       "",
       "The \"-logTransform\" option causes the training algorithm to optimize",
       "the log of the weight.  This has the effect of forcing the weight to",
       "be positive.  The \"-free\" option turns off the logTransform.",
       "",
       "If the link is constrained with other links,  these  other links are",
       "also set to the proper value.",
       "EXAMPLES",
       "To  set  link  \"Hidden.0 -> Output.0\" to value 0.5 use the command:",
       "",
       "xerion-> setLink -weight 0.5 \"Hidden.0 -> Output.0\"",
       "SEE ALSO",
       "freezeLink, constrainLink",
       NULL);
    return 1;
  }
  
  name = *tokv ;
  /* Options are only examined while at least one more token follows, so
   * the last token is always treated as a link name.  Options may be
   * abbreviated to any prefix of two or more characters ("-w"); a lone
   * "-" (or an empty token) no longer silently matches "-logTransform". */
  for (++tokv, --tokc ; tokc > 1 ; ++tokv, --tokc) {
    len = strlen(*tokv) ;
    if (len > 1 && strncmp(*tokv, "-logTransform", len) == 0) {
      if (range != UnspecifiedRange)
	IErrorAbort("Can only specify one range") ;
      range = PositiveRange ;
    } else if (len > 1 && strncmp(*tokv, "-free", len) == 0) {
      if (range != UnspecifiedRange)
	IErrorAbort("Can only specify one range") ;
      range = FreeRange ;
    } else if (len > 1 && strncmp(*tokv, "-weight", len) == 0) {
      /* the value is the next token; tokc > 1 guarantees it exists */
      ++tokv, --tokc ;
      if (!IIsNumber(*tokv))
	IErrorAbort(IPrintUsage(name, usage)) ;
      weight = atof(*tokv) ;
      weightSet = TRUE ;
    } else if (*tokv[0] == '-') {
      IErrorAbort("Invalid option: %s", *tokv) ;
    } else {
      break ;
    }
  }

  if (tokc < 1)
    IErrorAbort(IPrintUsage(name, usage)) ;

  if (currentNet == NULL)
    IErrorAbort("There is no current net.") ;

  /* every remaining token is a link name / regular expression */
  while (tokc > 0) {
    link = linksFromRegex(currentNet, *tokv) ;
    if (link == NULL || link[0] == NULL)
      IErrorAbort("Unknown link: \"%s\".", *tokv) ;
    for (idx = 0 ; link[idx] != NULL ; ++idx) {
      switch (range) {
      case PositiveRange:
	link[idx]->type |=  LOG_TRANSFORM ;
	break ;
      case FreeRange:
	link[idx]->type &= ~LOG_TRANSFORM ;
	break ;
      default:			/* UnspecifiedRange: keep the flag */
	break ;
      }
      linkSetWeight(link[idx], weightSet ? weight : link[idx]->weight) ;
    }
    ++tokv, --tokc ;
  }

  markToRebuildDisplay(CONNECTION_DISPLAY) ;
  return 1 ;
}
/******************************************************************/


/***********************************************************************
 *	Name:		command_randomize
 *	Description:	command for randomizing the non-frozen variables of
 *			the current network (uniformly around <mean> with
 *			half-width <rad>), zeroing their gradient entries and
 *			resetting currentNet->currentEpoch
 *	Parameters:	
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_randomize - 0 on failure, 1 on success
 ***********************************************************************/
int command_randomize (tokc, tokv)
  int	tokc ;
  char	*tokv[] ;
{
  Real	radius, mean ;
  int	idx, numVars ;
  char	*end ;

  IUsage("[<rad>] [<mean>]");
  if (GiveHelp(tokc)) {
    ISynopsis("Randomize all weights on links in current net");
    IHelp
      (IHelpArgs,
       "Set all the weights to random numbers with  a radius  of <rad> and a",
       "mean of <mean>,  i.e.   on  the range [mean-rad, mean+rad].  Default",
       "values are  mean = 0.0,   rad =  0.3.  This command  also clears the",
       "deltaWeight field of the links and resets the net's currentEpoch.",
       "",
       "This command  does not change the weights  in links which have  been",
       "frozen, and it obeys all constraints placed on links.",
       "EXAMPLE",
       "\txerion-> randomize 1.0",
       "SEE ALSO",
       "seed, freezeLink, constrainLink",
       NULL);
    return 1;
  }

  if (tokc > 3) {
    IErrorAbort(IPrintUsage(tokv[0], usage));
    return 0 ;
  }

  mean   = 0.0 ;
  radius = 0.3 ;

  /* strtod instead of atof: a mistyped argument (e.g. "0,5") is reported
     rather than silently becoming 0 */
  if (tokc > 1) {
    radius = strtod(tokv[1], &end) ;
    if (end == tokv[1] || *end != '\0') {
      IErrorAbort(IPrintUsage(tokv[0], usage));
      return 0 ;
    }
  }
  if (tokc > 2) {
    mean = strtod(tokv[2], &end) ;
    if (end == tokv[2] || *end != '\0') {
      IErrorAbort(IPrintUsage(tokv[0], usage));
      return 0 ;
    }
  }

  if (currentNet == NULL)
    IErrorAbort("No current network") ;

  /* Only the first numVariables - numFrozenVariables entries are touched;
   * this assumes frozen variables occupy the tail of the vector.
   * NOTE(review): the help text mentions the deltaWeight field, but this
   * loop clears currentNet->gradient -- confirm which is intended. */
  numVars = currentNet->numVariables - currentNet->numFrozenVariables ;
  for (idx = 0 ; idx < numVars ; ++idx) {
    currentNet->variables[idx] = randReal(mean, radius) ; 
    currentNet->gradient[idx]  = 0.0 ;
  }

  syncValues(currentNet, WeightsFromVector) ;
  currentNet->currentEpoch = 0 ;

  markToRebuildDisplay(CONNECTION_DISPLAY) ;
  return 1 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		command_seed
 *	Description:	command for seeding the random number generator
 *			through ISeed.  With no argument the seed is 0; an
 *			argument must be a complete base-10 integer that
 *			fits in an int, otherwise the usage is reported.
 *	Parameters:	
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_seed - 0 on failure, 1 on success
 ***********************************************************************/
int command_seed (tokc, tokv)
  int	tokc ;
  char	*tokv[] ;
{
  int	seed ;
  long	value ;
  char	*end ;

  IUsage("[ <seed> ]") ;
  if (GiveHelp(tokc)) {
    ISynopsis("Seed the random number generator") ;
    IHelp
      (IHelpArgs,
       "Seed the random number generator with a  specific value (default: 0)",
       "allowing identical reruns of a training set.",
       "EXAMPLE ",
       "\txerion-> seed 743",
       "SEE ALSO",
       "randomize",
       NULL) ;
    return 1 ;
  }

  if (tokc == 1)
    seed = 0 ;
  else if (tokc == 2) {
    /* strtol instead of atoi so that garbage or out-of-range input is
       rejected instead of silently producing an arbitrary seed */
    errno = 0 ;
    value = strtol(tokv[1], &end, 10) ;
    if (end == tokv[1] || *end != '\0' || errno == ERANGE
	|| value < INT_MIN || value > INT_MAX) {
      IErrorAbort(IPrintUsage(tokv[0], usage)) ;
      return 0 ;
    }
    seed = (int) value ;
  } else {
    IErrorAbort(IPrintUsage(tokv[0], usage)) ;
    return 0 ;
  }

  ISeed(seed) ;

  return 1 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		command_saveWeights
 *	Description:	saves the variables of the current network to the
 *			named file, or to dout when no file (or "-") is
 *			given.  The first line is a header
 *			"<named|unnamed>[,binary|,binarydouble]: <n> # ..."
 *			followed by one entry per variable.  On any write
 *			error a named output file is truncated to zero
 *			length and the command aborts.
 *	Parameters:
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_saveWeights - 1 on success (errors abort via
 *					      IErrorAbort)
 ***********************************************************************/
int 	command_saveWeights (tokc, tokv)
  int	tokc ;
  char	*tokv[] ;
{
  struct ARG_DESC *argd;
  FILE *stream = dout;
  FILE *truncated;
  char *format = "%.6g";
  int named = 1;
  int binarySingle = 0;
  int binaryDouble = 0;
  int ascii = 1;
  int saveFrozen = 1;
  Link *link;
  int i, j, k, n;
  int status, errnum;
  double saveValueD;
  float saveValueF;
  char *filename = NULL;

  argd = StartArgs(tokv[0]);
  Args(argd, "[-Pbinary%P%?]", &binarySingle,
       "write variables in binary single precision format (default ascii)");
  Args(argd, "[-Pdouble%P%?]", &binaryDouble,
       "write variables in binary double precision format (default ascii)");
  Args(argd, "[-Snamed%S%?]", &named,
       "write link names in the file (default on)");
  Args(argd, "[-Sfrozen%S%?]", &saveFrozen,
       "save the frozen variables (default on)");
  Args(argd, "[<format>%F%?]", &format,
       "printf format for writing weights (default %.6g)");
  Args(argd, "[<file>%s]", &filename);
  EndArgs(argd);
  if (GiveHelp(tokc)) {
    ISynopsis("save the variables in the network");
    IHelp(tokc, tokv[0], NULL, synopsis,
	  "DESCRIPTION",
	  "Save the variables in the network.  Output is sent to the file",
	  "specified, or to stdout.",
	  "Variables are in a one-to-many relationship with link weights.",
	  "When one variable corresponds to multiple links, it is only",
	  "saved once.  The link name used is the first one found,",
	  "with preference for a link that has a scaleFactor equal to 1.",
	  "The link weight may be a multiple (link->scaleFactor) of",
	  "its corresponding variable.  If the -named option is given the",
	  "variable * link->scaleFactor is saved, rather than the plain",
	  "variable value.",
	  "AUTHOR",
	  "Tony Plate (tap@cs.toronto.edu)",
          NULL);
    return 1;
  }
  (void) ParseArgs(argd, tokc, tokv, 0);

  if (currentNet==NULL)
    IErrorAbort("current net is NULL");
  syncValues(currentNet, VectorFromWeights) ;

  link = currentNet->links;
  if (filename!=NULL && strcmp(filename, "-")==0)
    filename = NULL;
  if (filename!=NULL)
    stream = IOpenFileOrAbort(filename, "w", NULL) ;

  /* without frozen variables only the leading, non-frozen part of the
     vector is written */
  n = currentNet->numVariables;
  if (!saveFrozen)
    n -= currentNet->numFrozenVariables;

  if (binaryDouble || binarySingle)
    ascii = 0;
  status = fprintf(stream, "%s%s: %d # currentNet.currentEpoch: %d\n", 
		   named ? "named" : "unnamed",
		   binaryDouble ? ",binarydouble"
		   : binarySingle ? ",binary" : "", n,
		   currentNet->currentEpoch);
  
  for (i = 0 ; i<n && status>0; ++i) {
    if (named) {
      /*
       * find the link that corresponds to this variable, preferring one
       * with scaleFactor==1.0 but remembering the first other one in k
       */
      k = -1;
      for (j=0; j<currentNet->numLinks; j++) {
	if (link[j]->variableIdx != i)
	  continue;
	if (link[j]->scaleFactor == 1.0)
	  break;
	if (k == -1)
	  k = j;
      }
      if (j>=currentNet->numLinks) {
	if (k != -1)
	  j = k;
	else {
	  resetFile(filename, stream);
	  IErrorAbort("cannot find link for variable %d (needed for name)", i);
	}
      }
      if (link[j]->name==NULL || link[j]->name[0]=='\0') {
	resetFile(filename, stream);
	IErrorAbort("link for variable %d has no name", i);
      }
      /* the loader splits "name: value" on the first colon */
      if (strchr(link[j]->name, ':')!=NULL) {
	resetFile(filename, stream);
	IErrorAbort("cannot have colons in link names (%s)", link[j]->name);
      }
      status = fprintf(stream, "%s:", link[j]->name);
      if (status>0 && ascii)
	status = fprintf(stream, " ");
      saveValueF = saveValueD = currentNet->variables[i]*link[j]->scaleFactor;
    } else {
      /* save unnamed variable */
      saveValueF = saveValueD = currentNet->variables[i];
    }
    if (status>0) {
      if (binaryDouble)
	status = fwrite(&saveValueD, sizeof(saveValueD), 1, stream);
      else if (binarySingle)
	status = fwrite(&saveValueF, sizeof(saveValueF), 1, stream);
      else
	/* NOTE(review): FORMAT comes from the command line and receives a
	   single double; unless the %F argument spec validates it, any
	   other conversion is undefined behaviour -- confirm. */
	status = fprintf(stream, format, saveValueD);
    }
    if (status>0 && (named || ascii))
      status = fprintf(stream, "\n");
  }

  if (status<=0) {
    /* resetFile closes and reopens the file, which may overwrite errno;
       keep the write error for perror */
    errnum = errno;
    resetFile(filename, stream);
    errno = errnum;
    perror(tokv[0]) ;
    if (filename==NULL)
      IErrorAbort("error writing stdout");
    else
      IErrorAbort("error writing \"%s\"", filename);
  }

  if (stream!=dout && ICloseFile(stream, NULL)==EOF) {
    errnum = errno;
    /* The stream is already released (assuming ICloseFile wraps fclose,
       which disassociates the stream even on failure), so it must not be
       closed again via resetFile; only truncate the partial file. */
    truncated = fopen(filename, "w");
    if (truncated != NULL)
      fclose(truncated);
    errno = errnum;
    perror(tokv[0]) ;
    IErrorAbort("error writing \"%s\"", filename);
  }

  return 1;
}
/**********************************************************************/
/**********************************************************************
 *	Name:		resetFile
 *	Description:	error-recovery helper for command_saveWeights:
 *			closes STREAM and truncates the file NAME to zero
 *			length, so a failed save does not leave a partially
 *			written weight file behind.  Does nothing when NAME
 *			is NULL or "-" (output was going to dout).  Errors
 *			are ignored: this is best-effort cleanup on a path
 *			that is already reporting a failure.
 *	Parameters:
 *		char	*name   - the output file name, or NULL / "-"
 *		FILE	*stream - the still-open stream for NAME; it is
 *				  closed here and must not be used again
 *	Return Value:	none
 **********************************************************************/
static void resetFile(name, stream)
  char *name;
  FILE *stream;
{
  FILE *truncated;

  if (name==NULL || strcmp(name, "-")==0)
    return;
  (void) ICloseFile(stream, NULL);
  /* reopening with "w" truncates; fopen may fail, so guard fclose */
  truncated = fopen(name, "w");
  if (truncated != NULL)
    fclose(truncated);
}
/**********************************************************************/


/***********************************************************************
 *	Name:		command_loadWeights
 *	Description:	restores the variables of the current network from
 *			a file in the format written by command_saveWeights
 *			(or from din when no file, or "-", is given).  Values
 *			are stored in currentNet->variables and copied into
 *			the link weights (syncValues WeightsFromVector) only
 *			after the whole input has been read; on any error the
 *			vector is re-synchronised from the weights instead, so
 *			the weights are left unchanged.
 *	Parameters:
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_loadWeights - 1 on success (errors abort via
 *					      IErrorAbort)
 ***********************************************************************/
int command_loadWeights (tokc, tokv)
  int	tokc ;
  char	*tokv[] ;
{
  struct ARG_DESC *argd;
  FILE *stream = din;
  int named = 1;
  int frozen = 1;
  int loadedCount = 0, ignoredCount = 0;
  int binarySingle = 0;
  int binaryDouble = 0;
  int ascii = 1;
  int numUnfrozen;
  char *name, *p, *end;
  int c;
  Link *link;
  int weightNum, linkIdx, variableIdx, n;
  double valueD;
  float valueF;
  char buf[1024];		/* the "%1023[" scanf widths below must match */
  int status, closeStatus = 0;
  char *filename = NULL;

  argd = StartArgs(tokv[0]);
  Args(argd, "[<file>%s]", &filename);
  Args(argd, "[-Sfrozen%S%?]", &frozen,
       "load values into frozen weights (default on)");
  EndArgs(argd);
  if (GiveHelp(tokc)) {
    ISynopsis("load the weights");
    IHelp(tokc, tokv[0], NULL, synopsis,
	  "any errors result in the weights being left unchanged",
	  "AUTHOR",
	  "Tony Plate (tap@cs.toronto.edu)",
          NULL);
    return 1;
  }
  (void) ParseArgs(argd, tokc, tokv, 0);

  if (currentNet==NULL)
    IErrorAbort("current net is NULL");
  syncValues(currentNet, VectorFromWeights) ;

  link = currentNet->links;
  /* frozen variables occupy indices [numUnfrozen, numVariables) */
  numUnfrozen = currentNet->numVariables - currentNet->numFrozenVariables;
  if (filename!=NULL && strcmp(filename, "-")==0)
    filename = NULL;
  if (filename!=NULL)
    stream = IOpenFileOrAbort(filename, "r", NULL) ;
  
  /* header: "<keyword>[,<keyword>]: <n>"; the rest of the line is skipped */
  status = fscanf(stream, "%1023[^:\n]: %d", buf, &n);
  do { c = getc(stream); } while (c!=EOF && c!='\n');
  if (status!=2) {
    if (filename!=NULL) ICloseFile(stream, NULL);
    IErrorAbort("first line must be of form \"<keywords>: <n>\"");
  }
  p = strtok(buf, ",");
  while (p!=NULL) {
    if (!strcmp(p, "named"))
      named = 1;
    else if (!strcmp(p, "unnamed"))
      named = 0;
    else if (!strcmp(p, "binarydouble"))
      {binaryDouble = 1; ascii = 0;}
    else if (!strcmp(p, "binarysingle") || !strcmp(p, "binary"))
      {binarySingle = 1; ascii = 0;}
    else {
      if (filename!=NULL) ICloseFile(stream, NULL);
      IErrorAbort("invalid keyword for first line: \"%s\"", p);
    }
    p = strtok(NULL, ",");	/* keywords are comma separated */
  }

  if (!named && n!=currentNet->numVariables && n!=numUnfrozen) {
    if (filename!=NULL) ICloseFile(stream, NULL);
    IErrorAbort("supplied number of weights (%d) does not match net", n);
  }
  
  for (weightNum = 0 ; weightNum<n  && !feof(stream); ++weightNum) {
    if (named) {
      buf[0] = '\0';
      if (fscanf(stream, "%1023[^:\n]:", buf)!=1) {
	syncValues(currentNet, VectorFromWeights) ;
	if (filename!=NULL) ICloseFile(stream, NULL);
	IErrorAbort("cannot read name for weight %d", weightNum+1);
      }
      /* chop whitespace off the end ... */
      end = buf + strlen(buf);
      while (end>buf && isspace((unsigned char) end[-1]))
	*--end = '\0';
      /* ... and the beginning */
      name = buf;
      while (isspace((unsigned char) *name))
	name++;
      /* cheap first guess: link[weightNum] (matches when link and
	 variable order line up), otherwise search every link */
      linkIdx = weightNum;
      if (linkIdx>=currentNet->numLinks
	  || link[linkIdx]->name==NULL || strcmp(link[linkIdx]->name,name)) {
	for (linkIdx=0; linkIdx<currentNet->numLinks; linkIdx++)
	  if (link[linkIdx]->name!=NULL
	      && strcmp(link[linkIdx]->name, name)==0)
	    break;
      }
	
      if (linkIdx>=currentNet->numLinks) {
	syncValues(currentNet, VectorFromWeights) ;
	if (filename!=NULL) ICloseFile(stream, NULL);
	IErrorAbort("cannot find weight with name %s", name);
      }
      variableIdx = link[linkIdx]->variableIdx;
    } else {
      sprintf(buf, "%d", weightNum);
      name = buf;
      variableIdx = weightNum;
      linkIdx = -1;
    }
    /* get the value */
    if (binaryDouble)
      status = fread(&valueD, sizeof(valueD), 1, stream);
    else if (binarySingle) {
      status = fread(&valueF, sizeof(valueF), 1, stream);
      valueD = valueF;
    } else
      status = fscanf(stream, " %lf", &valueD);
    if (status!=1) {
      syncValues(currentNet, VectorFromWeights) ;
      if (filename!=NULL) ICloseFile(stream, NULL);
      IErrorAbort("cannot read weight %s (status=%d)", name, status);
    }

    /* assign it (the first frozen variable is at index numUnfrozen) */
    if (frozen==0 && variableIdx >= numUnfrozen)
      ignoredCount++;
    else {
      /* named files store variable * scaleFactor; undo that */
      if (named)
	currentNet->variables[variableIdx] = valueD/link[linkIdx]->scaleFactor;
      else
	currentNet->variables[variableIdx] = valueD;
      loadedCount++;
    }

    /* read to the end of the line (unnamed binary files have no lines) */
    if (named || ascii)
      do { c = getc(stream); } while (c!=EOF && c!='\n');
  }
  if (weightNum!=n) {
    syncValues(currentNet, VectorFromWeights) ;
    if (filename!=NULL) ICloseFile(stream, NULL);
    IErrorAbort("could not read %d weights (only read %d)", n, weightNum);
  }
  if (filename!=NULL)
    closeStatus = ICloseFile(stream, NULL);
  if (closeStatus==EOF) {
    /* the stream has been released by ICloseFile; do not close it again */
    perror(tokv[0]) ;
    syncValues(currentNet, VectorFromWeights) ;
    IErrorAbort("error closing file \"%s\"", filename);
  }
  syncValues(currentNet, WeightsFromVector) ;
  if (ignoredCount>0)
    fprintf(dout, "Loaded %d values into %s, ignored %d values.\n",
	    loadedCount, currentNet->name, ignoredCount);
  else
    fprintf(dout, "Loaded %d values into %s.\n",
	    loadedCount, currentNet->name);

  markToRebuildDisplay(CONNECTION_DISPLAY) ;
  return 1;
}
/**********************************************************************/


/**********************************************************************
 *	Name:	     randReal
 *	Description: returns centre + field * (2*IRandReal() - 1), i.e. a
 *		     random value spread symmetrically about CENTRE with a
 *		     half-width of FIELD.  If IRandReal() is uniform on
 *		     [0,1) (assumed -- not visible here), the result is
 *		     uniform on [centre - field, centre + field).
 *	Parameters:  double centre - the centre of the distribution
 *		     double field  - half the width of the distribution
 *	Return:	     Real randReal - the random value
 **********************************************************************/
static Real	randReal(centre, field)
  double	centre ;
  double	field ;
{
  /* stretch the unit draw onto [-1, 1), scale by the half-width, shift */
  return field * (IRandReal() * 2.0 - 1.0) + centre ;
}
/**********************************************************************/
