
/**********************************************************************
 * $Id: oldLoad.c,v 1.3 93/01/22 15:12:50 drew Exp $
 **********************************************************************/

/**********************************************************************
 *	Copyright 1990,1991,1992 by The University of Toronto,
 *		       Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute,  and sell this software
 * and its documentation for any purpose is hereby granted without fee,
 * provided  that the above copyright notice  appears in all copies and
 * that both the copyright notice and this permission notice  appear in
 * supporting documentation, and  that  the  name of The University  of
 * Toronto  not  be used  in advertising   or publicity pertaining   to
 * distribution  of   the software   without  specific, written   prior
 * permission.  The  University  of Toronto  makes   no representations
 * about the  suitability  of  this software  for  any purpose.   It is
 * provided "as is" without express or implied warranty.
 *
 * THE  UNIVERSITY OF  TORONTO DISCLAIMS ALL WARRANTIES  WITH REGARD TO
 * THIS SOFTWARE,  INCLUDING ALL  IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT  SHALL THE UNIVERSITY  OF TORONTO BE LIABLE
 * FOR ANY SPECIAL,  INDIRECT OR CONSEQUENTIAL  DAMAGES  OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF  USE, DATA OR PROFITS,  WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE  OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF OR  IN  CONNECTION   WITH  THE  USE OR  PERFORMANCE  OF THIS
 * SOFTWARE.
 **********************************************************************/

#include <xerion/commands.h>

#ifndef MAXINT
/* Largest value an int can hold.  Built from an unsigned all-ones value
 * shifted right once, so no signed shift into the sign bit (undefined
 * behaviour) is needed. */
#define MAXINT	((int)(~0U >> 1))
#endif

/* Bit flags combined into the "mode" argument of the load/save helpers.
 * Each expansion is fully parenthesised so it stays a single operand in
 * any surrounding expression.  Bit 2 is not used by these flags. */
#define LOAD	((int)(1<<0))
#define SAVE	((int)(1<<1))
#define TEXT	((int)(1<<3))

/* Forward declarations of the file-local helpers defined further down.
 * ARGS is presumably a project macro that drops the parameter list for
 * pre-ANSI compilers -- TODO confirm against the xerion headers. */
static void	resetLinks	ARGS((Net)) ;
static void	loadSaveWeights	ARGS((Net, FILE *, int)) ;
static void	loadSaveVector	ARGS((Net, FILE *stream, int mode)) ;
static void	getIncoming	ARGS((Unit unit, void *data)) ;
static void	resetArrays	ARGS((Net  net)) ;



/***********************************************************************
 *	Name:		command_oldLoadWeights
 *	Description:	command handler that restores the weights of the
 *			current network from a weight file, or from "din"
 *			when no file is named on the command line
 *	Parameters:	
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens; the first token is
 *				  used as the command name in usage errors
 *	Return Value:	
 *		int	command_oldLoadWeights - 0 on failure, 1 on success
 ***********************************************************************/
int command_oldLoadWeights (tokc, tokv)
  int	tokc ;
  char	*tokv[] ;
{
  FILE	*inStream ;
  String	name ;
  int		mode ;

  IUsage("[-text] [ <file> ]") ;
  if (GiveHelp(tokc)) {
    ISynopsis("Load a weight file into the current network") ;
    IHelp
      (IHelpArgs,
       "\"oldLoadWeights\"  reads  the  weights in  a  file  into  the current",
       "network.  The weights  should have  been  saved  using \"saveWeights\"",
       "from a  version of  xerion  prior to  3.1.  It  is only supplied for",
       "backwards compatibility.",
       "",
       "If no file is given, the weights are read from standard input.",
       "",
       "The \"-text\" option means that the input file is in text format.  The",
       "default file format is binary.",
       "EXAMPLE",
       "\txerion-> oldLoadWeights foo.weights",
       "SEE ALSO",
       "loadWeights, saveWeights",
       NULL) ;
    return 1 ;
  }
  
  name = *tokv ;
  mode = LOAD ;
  for (++tokv, --tokc ; tokc && *tokv ; ++tokv, --tokc) {
    /* Any non-empty prefix of "-text" selects text mode.  The explicit
     * emptiness test is needed because a zero-length strncmp() always
     * compares equal, which would make "" count as an option. */
    if (**tokv != '\0' && strncmp("-text", *tokv, strlen(*tokv)) == 0)
      mode |= TEXT ;
    else
      break ;
  }
  
  /* After the options at most one token (the file name) may remain. */
  if (tokc != 0 && tokc != 1) {
    IErrorAbort(IPrintUsage(name, usage)) ;
    return 0 ;
  }

  /* Checked before the file is opened so that aborting here cannot
   * leave a freshly opened stream behind. */
  if (currentNet == NULL)
    IErrorAbort("No current network") ;

  if (tokc == 0)
    inStream = din ;
  else
    inStream = IOpenFileOrAbort(*tokv, "r", NULL) ;

  loadSaveWeights(currentNet, inStream, mode) ;

  /* din is not owned by this command, so only close what was opened. */
  if (inStream != din)
    ICloseFile(inStream, NULL) ;
    
  markToRebuildDisplay(CONNECTION_DISPLAY) ;
  return 1 ;
}
/**********************************************************************/


/**********************************************************************/
/* Format strings shared by the fscanf() and fprintf() calls of the text
 * file format: INT_STRING is used for the variable count written or read
 * by loadSaveWeights, REAL_STRING for each element of net->variables in
 * loadSaveVector. */
#define INT_STRING	"%d\n"
/* NOTE(review): presumably Real is double when DOUBLE is defined and
 * float otherwise, which is why the scanf length modifier differs --
 * confirm against the definition of Real. */
#ifdef DOUBLE
#define REAL_STRING	"%lg\n"
#else
#define REAL_STRING	"%g\n"
#endif
/***********************************************************************
 *	Name:		loadSaveWeights
 *	Description:	loads or saves all the weights in a network
 *	Parameters:	
 *		Net	net	- the net to use
 *		FILE	*stream	- the stream to read from or write to
 *		int	mode	- LOAD or SAVE, optionally OR'ed with TEXT
 *	Return Value:	NONE
 ***********************************************************************/
/* Error-reporting state shared by resetLSerror() and checkLSerror():
 * the mode flags of the transfer in progress and the index of the next
 * record to be checked. */
static int	lsMode ;
static int	recordNum ;
/**********************************************************************/
/* Start a new transfer: remember its mode flags and restart the record
 * counter at zero. */
static void	resetLSerror(mode) 
  int		mode ;
{
  lsMode = mode ;
  recordNum = 0 ;
}
/**********************************************************************/
/* Validate the result of one transfer step and advance the record
 * counter.
 *
 *   status - the value returned by the read(), write(), fscanf() or
 *            fprintf() call that handled the current record.
 *
 * Every non-positive value is a failure: read()/write() report errors
 * as -1 and end-of-file as 0, fscanf() reports EOF or 0 matched items,
 * and fprintf() may report an error with any negative value, not only
 * EOF.  The message names the direction taken from the mode stored by
 * resetLSerror() together with the index of the offending record. */
static void	checkLSerror(status) 
  int		status ;
{
  if (status <= 0)
    IErrorAbort("Error %s record %d", 
		lsMode & LOAD ? "reading" : "writing", recordNum) ;
  ++recordNum ;
}
/**********************************************************************/
/* Transfer the variable count followed by the variable vector of "net"
 * to or from "stream".
 *
 *   net    - the network whose variables are loaded or saved
 *   stream - the stream to read from or write to
 *   mode   - LOAD or SAVE, optionally OR'ed with TEXT
 *
 * Without TEXT the count is transferred in binary.  When a binary load
 * does not yield the expected count the stream is rewound and the load
 * is retried in text format; "din" is never rewound, so a mismatch on
 * it is reported as an error instead of loading a vector whose length
 * is known to be wrong.  After a load the link weights are recomputed
 * from the variables and the net's epoch counter is reset.
 *
 * Every status is checked before the value it produced is used, so a
 * failed read can never lead to a comparison against an indeterminate
 * count.  NOTE(review): this relies on IErrorAbort() not returning to
 * its caller -- confirm. */
static void	loadSaveWeights(net, stream, mode)
  Net		net ;
  FILE		*stream ;
  int		mode ;
{
  int		numVars = 0 ;
  int		status ;

  resetLSerror(mode) ;

  /* if normal mode try binary. if load fails, try text */
  if (!(mode & TEXT)) {
    if (mode & LOAD) {
      status = read (fileno(stream), &numVars, sizeof(numVars)) ;
      checkLSerror(status) ;

      /* a short read or an unexpected count means "not a matching
       * binary file" */
      if (status != (int)sizeof(numVars) || numVars != net->numVariables) {
	if (stream == din)
	  IErrorAbort("Inconsistent number of weights (%d, %d)",
		      numVars, net->numVariables) ;
	rewind(stream) ;
	mode |= TEXT ;
      }
    } else {
      status = write (fileno(stream), &net->numVariables, 
		      sizeof(net->numVariables)) ;
      checkLSerror(status) ;
    }
  }

  if (mode & TEXT) {
    if (mode & LOAD) {
      status = fscanf (stream, INT_STRING, &numVars) ;
      checkLSerror(status) ;
      if (numVars != net->numVariables)
	IErrorAbort("Inconsistent number of weights (%d, %d)",
		    numVars, net->numVariables) ;
    } else {
      status = fprintf(stream, INT_STRING, net->numVariables) ;
      checkLSerror(status) ;
    }
  }

  loadSaveVector(net, stream, mode) ;
  if (mode & LOAD) {
    resetLinks(net) ;
    /* reset the net that was actually loaded, not the global one */
    net->currentEpoch = 0 ;
  }
}
/**********************************************************************/
/* Transfer the net->numVariables elements of net->variables to or from
 * "stream".
 *
 *   net    - the network whose variable vector is transferred
 *   stream - the stream to read from or write to
 *   mode   - LOAD or SAVE, optionally OR'ed with TEXT
 *
 * In text mode each element is one REAL_STRING record and is checked
 * individually.  In binary mode the whole vector is a single record.
 * read() and write() may transfer fewer bytes than requested (truncated
 * file, pipe), so the transfer is repeated until the vector is complete
 * or a call reports end-of-file or an error; a truncated file is then
 * reported rather than accepted as a successful load. */
static void	loadSaveVector(net, stream, mode)
  Net		net ;
  FILE		*stream ;
  int		mode ;
{
  int		idx ;
  int		status ;

  if (mode & TEXT) {
    for (idx = 0 ; idx < net->numVariables ; ++idx) {
      if (mode & LOAD)
	status = fscanf (stream, REAL_STRING, &net->variables[idx]) ;
      else
	status = fprintf(stream, REAL_STRING, net->variables[idx]) ;

      checkLSerror(status) ;
    }
  } else {
    char	*cursor    = (char *)net->variables ;
    size_t	remaining  = net->numVariables*sizeof(*net->variables) ;

    /* The body runs at least once, so a zero-length request is still
     * issued and its result checked exactly as a single call would be. */
    do {
      if (mode & LOAD)
	status = read (fileno(stream), cursor, remaining) ;
      else
	status = write(fileno(stream), cursor, remaining) ;

      if (status <= 0)
	break ;			/* end-of-file or error: report below */

      cursor    += status ;
      remaining -= (size_t)status ;
    } while (remaining > 0) ;

    checkLSerror(status) ;
  }
}
/**********************************************************************/


/***********************************************************************
 *	Name:		resetLinks
 *	Description:	recomputes the weight of every link in the net as
 *			its variable times its scale factor, and clears
 *			the link's weight delta.
 *	Parameters:	
 *	  Net	net - the net whose links are updated
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static void	resetLinks(net) 
  Net		net ;
{
  unsigned int	total = net->numLinks ;
  unsigned int	n ;

  for (n = 0 ; n != total ; ++n) {
    Link	current = net->links[n] ;

    current->weight      = net->variables[current->variableIdx]
			   * current->scaleFactor ;
    current->deltaWeight = 0.0 ;
  }
}
/**********************************************************************/


/*********************************************************************
 *	Name:		resetArrays
 *	Description:	rebuilds the variable and gradient arrays of a net
 *			from its link values: the gradient array is zeroed
 *			and getIncoming is then applied to every unit.
 *	Parameters:
 *	  Net		net - the net to act on
 *	Return Value:
 *	  NONE
 *	Notes:
 *	  Works on "net" throughout.  Sizing the clear or walking the
 *	  units from the global current net instead would touch the
 *	  wrong amount of memory whenever the two nets differ.
 *********************************************************************/
static void	resetArrays(net)
  Net		net ;
{
  memset(net->gradient, 0, net->numVariables*sizeof(*net->gradient)) ;
  netForAllUnits(net, ALL, getIncoming, NULL) ;
}
/********************************************************************/
/* Per-unit callback handed to netForAllUnits() by resetArrays().
 *
 *   unit - the unit whose incoming links are visited
 *   data - not used
 *
 * For each incoming link, the variable slot named by the link is set to
 * the link's weight divided by its scale factor, and the link's
 * derivative divided by the same factor is added to the matching
 * gradient slot. */
static void	getIncoming(unit, data)
  Unit		unit ;
  void		*data ;
{
  Real	*vars  = unit->net->variables ;
  Real	*grads = unit->net->gradient ;
  Link	*next  = unit->incomingLink ;
  int	left ;

  for (left = unit->numIncoming ; left > 0 ; --left, ++next) {
    Link	in = *next ;

    vars[in->variableIdx]   = in->weight/in->scaleFactor ;
    grads[in->variableIdx] += in->deriv/in->scaleFactor ;
  }
}
/********************************************************************/
