
/**********************************************************************
 * $Id: netCom.c,v 1.4 92/11/30 11:27:42 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		       Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute,  and sell this software
 * and its documentation for any purpose is hereby granted without fee,
 * provided  that the above copyright notice  appears in all copies and
 * that both the copyright notice and this permission notice  appear in
 * supporting documentation, and  that  the  name of The University  of
 * Toronto  not  be used  in advertising   or publicity pertaining   to
 * distribution  of   the software   without  specific, written   prior
 * permission.  The  University  of Toronto  makes   no representations
 * about the  suitability  of  this software  for  any purpose.   It is
 * provided "as is" without express or implied warranty.
 *
 * THE  UNIVERSITY OF  TORONTO DISCLAIMS ALL WARRANTIES  WITH REGARD TO
 * THIS SOFTWARE,  INCLUDING ALL  IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT  SHALL THE UNIVERSITY  OF TORONTO BE LIABLE
 * FOR ANY SPECIAL,  INDIRECT OR CONSEQUENTIAL  DAMAGES  OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF  USE, DATA OR PROFITS,  WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE  OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF OR  IN  CONNECTION   WITH  THE  USE OR  PERFORMANCE  OF THIS
 * SOFTWARE.
 **********************************************************************/

#include <xerion/commands.h>
#include "simulatorHelp.h"

/* Table of every net known to the simulator, kept in creation order
 * (addNet appends, delNet compacts).  It holds numNets entries and is
 * NULL when there are none. */
static Net 	*netArray ;
static int	numNets ;

/* The net that commands operate on (set by useNet); NULL when no net
 * is selected. */
Net	currentNet ;

/***********************************************************************
 *	Name:		command_addNet
 *	Description:	creates a network and adds it to the list
 *			of existing networks, it creates a bias
 *			unit automatically in the net and makes the
 *			new net the current net.  If a net with the
 *			same name exists, the user is asked whether
 *			to replace it.
 *	Parameters:	
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_addNet - 0 on failure, 1 on success
 ***********************************************************************/
int	command_addNet (tokc, tokv)
  int	tokc ;
  char	*tokv[] ;
{
  char	*name ;
  int	timeSlices, newMask, netMask ;
  Net	net ;

  IUsage("[ -type TYPE [ -type TYPE ...] ] [ -time <n> ] <name>") ;
  if (GiveHelp(tokc)) {
    ISynopsis("add a network to the list of existing ones");
    IHelp
      (IHelpArgs,
       "\"addNet\" creates  a network  object with the  given name and adds it",
       "into the  list  of   existing networks.   It  then  sets the current",
       "network to  this net.   The network  is  created  with the specified",
       "(user defined) type mask.   Multiple  types may be  specified,  each",
       "preceded with the keyword \"-type\".",
       "",
       "If  another network with  the same name already   exists the command",
       "returns an error message. The net is created with a  bias layer with",
       "a single unit in it (named \"Bias\").",
       "",
       "The keyword  \"-time\" specifies  that the   net is to be  a recurrent",
       "network with <n> time slices.  This feature is only used  in the rbp",
       "module.",
       "EXAMPLE",
       "To create a new network with the name \"Backprop Net\",",
       "",
       "\txerion-> addNet \"Backprop Net\"",
       "SEE ALSO",
       "useNet, deleteNets, addGroup, addUnit, addExamples",
       NULL);
    return 0;
  }

  /* Options come as (keyword, value) pairs followed by the name, so the
   * token count including the command name must be even. */
  if (tokc < 2 || tokc % 2 != 0) {
    IErrorAbort(IPrintUsage(tokv[0], usage)) ;
    return 0 ;
  }

  /* Parse the option pairs; keywords may be abbreviated to any prefix.
   * The loop stops when only the name is left.  (The explicit returns
   * after IErrorAbort match the rest of this function and keep us from
   * continuing with bad input should it ever return.) */
  netMask    = 0 ;
  timeSlices = 1 ;
  for (++tokv, --tokc ; tokc > 1 ; ++tokv, --tokc) {
    if (strncmp(*tokv, "-type", strlen(*tokv)) == 0) {
      ++tokv, --tokc ;
      newMask = findClassMask(NET_CLASS, *tokv) ;
      if (newMask == 0) {
	IErrorAbort("Unknown net type: \"%s\".", *tokv) ;
	return 0 ;
      }    
      netMask |= newMask ;
    } else if (strncmp(*tokv, "-time", strlen(*tokv)) == 0) {
      ++tokv, --tokc ;
      timeSlices = atoi(*tokv) ;
      if (timeSlices < 2) {
	IErrorAbort("Invalid number of time slices: %s", *tokv) ;
	return 0 ;
      }
      netMask |= RECURRENT ;
    } else {
      IErrorAbort("Invalid option: %s", *tokv) ;
      return 0 ;
    }
  }

  name = *tokv ;

  if ((net = addNet(name, netMask, timeSlices)) == NULL) {
    char	buffer[BUFSIZ] ;

    fprintf(stdout,
	    "Network \"%s\" already exists, replace it? (y/n)[n] ", name) ;
    /* Make sure the prompt is visible before blocking on input. */
    fflush(stdout) ;
    /* On EOF/error take the default answer ("n"); otherwise buffer[0]
     * would be read uninitialized. */
    if (fgets(buffer, sizeof buffer, stdin) == NULL)
      buffer[0] = '\0' ;

    if (buffer[0] != 'y') {
      IErrorAbort("Network \"%s\" unchanged\n", name) ;
      return 0 ;
    }

    /* delNet() treats its argument as an anchored pattern (and refuses
     * '^' and '$'), so a name containing pattern characters could delete
     * other nets too, or nothing at all.  Refuse to replace in that case
     * rather than risk losing unrelated nets. */
    if (strpbrk(name, ".*+[]{}()\\^$") != NULL) {
      IErrorAbort("Cannot replace network \"%s\": name contains pattern characters",
		  name) ;
      return 0 ;
    }
    if (delNet(name) == 0
	|| (net = addNet(name, netMask, timeSlices)) == NULL) {
      IErrorAbort("Cannot replace network \"%s\"", name) ;
      return 0 ;
    }
  }

  if (createUnit("Bias", createGroup("Bias", BIAS, net)) == NULL) {
    IErrorAbort("Cannot create bias unit") ;
    return 0 ;
  }

  useNet(net->name) ;

  return 1 ;
}
/**********************************************************************/
/*
 * addNet - append a newly created net called NAME to netArray.
 *
 * If a net with exactly that name already exists, nothing is created
 * and NULL is returned.  Otherwise the table grows by one slot, the net
 * is built with createNet(name, mask, timeSlices), stored in the new
 * slot and returned.  A NULL result from createNet() is reported
 * through IErrorAbort().
 */
Net	addNet(name, mask, timeSlices)
  char	*name ;
  int	mask ;
  int	timeSlices ;
{
  int	slot ;
  Net	created ;

  /* Duplicate names are not allowed. */
  for (slot = numNets - 1 ; slot >= 0 ; --slot) {
    if (strcmp(netArray[slot]->name, name) == 0)
      return NULL ;
  }

  /* Make room for one more entry. */
  if (netArray != NULL)
    netArray = (Net *)reallocOrAbort((void *)netArray,
				     (numNets + 1)*sizeof(Net)) ;
  else
    netArray = (Net *)callocOrAbort(1, sizeof(Net)) ;

  created = createNet(name, mask, timeSlices) ;
  netArray[numNets] = created ;
  if (created == NULL)
    IErrorAbort("Cannot create net \"%s\"", name) ;

  ++numNets ;
  return created ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		command_deleteNet / command_deleteNets
 *	Description:	destroy every net whose name matches one of the
 *			given patterns and remove it from netArray.
 *			"deleteNet" is an alias: command_deleteNet simply
 *			forwards to command_deleteNets.
 *	Parameters:	
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_deleteNet(s) - 0 on failure, 1 on success
 ***********************************************************************/
int	command_deleteNet (tokc, tokv)
  int	tokc ;
  char	*tokv[] ;
{
  int	status ;

  status = command_deleteNets (tokc, tokv) ;
  return status ;
}
/**********************************************************************/
/*
 * command_deleteNets - "deleteNets <name> [<name2> ...]": pass each
 * argument to delNet() in order, stopping with an error at the first
 * one that matches no net.  Returns 0 on help or failure, 1 otherwise.
 */
int	command_deleteNets (tokc, tokv)
  int	tokc ;
  char	*tokv[] ;
{
  int	arg ;

  IUsage("<name> [ <name2> ... ]") ;
  if (GiveHelp(tokc)) {
    ISynopsis("delete networks from the simulator");
    IHelp
      (IHelpArgs,
       "\"deleteNets\" destroys the networks with the given  names and removes",
       "them from the list of existing networks.  Names  may contain pattern",
       "matching expressions of the form used by grep.",
       "",
       "EXAMPLE",
       "To delete the network with name \"Backprop Net\",",
       "",
       "\txerion-> deleteNets \"Backprop Net\"",
       "",
       "To delete all networks,",
       "",
       "\txerion-> deleteNets .*",
       "NOTES",
       "\"deleteNet\" and \"deleteNets\" are the same command.",
       "SEE ALSO",
       "useNet, addNet, addGroup, addUnit, addExamples",
       NULL);
    return 0;
  }

  /* At least one name is required. */
  if (tokc == 1)
    IErrorAbort(IPrintUsage(tokv[0], usage)) ;

  for (arg = 1 ; arg < tokc ; ++arg) {
    if (!delNet(tokv[arg])) {
      IErrorAbort("Net not found: \"%s\"", tokv[arg]) ;
      return 0 ;
    }
  }

  return 1 ;
}
/**********************************************************************/
/***********************************************************************
 *	Name:		delNet
 *	Description:	destroys every net whose name matches the regcmp()
 *			pattern NAME (implicitly anchored at both ends) and
 *			removes it from netArray, preserving the order of
 *			the remaining nets.  If the current net is
 *			destroyed, currentNet becomes NULL and the displays
 *			are marked for rebuilding.
 *	Parameters:	
 *		char	*name - the pattern; it may not contain '^' or '$'
 *	Return Value:	
 *		int	delNet - 1 if at least one net was deleted; 0 if
 *			none matched or the pattern was rejected (anchors,
 *			too long, or failed to compile)
 ***********************************************************************/
int	delNet(name)
  char	*name ;
{
  int		next, last ;
  char		buffer[BUFSIZ], *compiled ;
  Boolean	match = FALSE ;

  /* The pattern is anchored below, so user-supplied anchors are refused. */
  if (strpbrk(name, "$^") != NULL)
    return 0 ;

  /* "^" + name + "$" + NUL must fit in buffer; the name comes from the
   * command line, so an over-long one would otherwise overflow it. */
  if (strlen(name) > sizeof(buffer) - 3)
    return 0 ;

  sprintf(buffer, "^%s$", name) ;
  compiled = regcmp(buffer, NULL) ;
  if (compiled == NULL)
    return 0 ;

  /* Destroy matching nets, leaving NULL holes in the table. */
  for (next = 0 ; next < numNets ; ++next) {
    if (regex(compiled, netArray[next]->name) != NULL) {
      match = TRUE ;
      if (currentNet == netArray[next]) {
	currentNet = NULL ;
	markToRebuildDisplay(ALL_DISPLAYS) ;
      }
      destroyNet(netArray[next]) ;
      netArray[next] = NULL ;
    }
  }
  free(compiled) ;
  
  if (match == FALSE)
    return 0 ;

  /* Squeeze out the holes, keeping the survivors in order. */
  for (last = next = 0 ; next < numNets ; ++next) {
    if (netArray[next] != NULL)
      netArray[last++] = netArray[next] ;
  }
  numNets = last ;

  if (numNets == 0) {
    free(netArray) ;
    netArray = NULL ;
  } else {
    netArray = (Net *)reallocOrAbort((void *)netArray, 
				     numNets*sizeof(Net)) ;
  }  

  return 1 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		command_useNet
 *	Description:	"useNet <name>" makes the named net current;
 *			"useNet" with no argument prints the current net
 *			followed by the names of all existing nets.
 *	Parameters:	
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_useNet - 0 if only help was given, else 1
 ***********************************************************************/
int	command_useNet (tokc, tokv)
  int	tokc ;
  char	*tokv[] ;
{
  int	i ;

  IUsage("[ <name> ]") ;
  if (GiveHelp(tokc)) {
    ISynopsis("set the current network");
    IHelp
      (IHelpArgs,
       "\"useNet\" sets the current network to the one  with the given <name>.",
       "If no name is supplied, it lists all of the existing networks.",
       "EXAMPLES",
       "To set the current network to \"Backprop Net\",",
       "",
       "\txerion-> useNet \"Backprop Net\"",
       "",
       "To list the existing networks,",
       "",
       "\txerion-> useNet",
       "\tCurrent net:",
       "\t\tFamily Tree",
       "\tNetworks:",
       "\t\tFamily Tree",
       "\t\tBackprop Net",
       "SEE ALSO",
       "addNet, deleteNets",
       NULL);
    return 0;
  }

  /* One argument: select that net. */
  if (tokc == 2) {
    if (useNet(tokv[1]) == NULL)
      IErrorAbort("Net not found.") ;
    return 1 ;
  }

  /* Anything other than no argument is a usage error. */
  if (tokc != 1) {
    IErrorAbort(IPrintUsage(tokv[0], usage)) ;
    return 1 ;
  }

  /* No argument: report the current net and list every net. */
  fprintf(dout, "Current net:\n\t%s\n", 
	  currentNet != NULL ? currentNet->name : "NULL") ;
  fputs(numNets == 0 ? "No networks.\n" : "Networks:\n", dout) ;
  for (i = 0 ; i < numNets ; i++)
    fprintf(dout, "\t%s\n", netArray[i]->name) ;

  return 1 ;
}
/**********************************************************************/
/*
 * useNet - make the net called NAME the current net.
 * Returns that net, or NULL (leaving currentNet untouched) if no net
 * has that name.  A successful switch marks all displays for rebuild.
 */
Net	useNet(name)
  char	*name ;
{
  Net	found = getNet(name) ;

  if (found == NULL)
    return NULL ;

  currentNet = found ;
  markToRebuildDisplay(ALL_DISPLAYS) ;
  return found ;
}
/**********************************************************************/
/*
 * getNet - look up a net by exact name (no pattern matching).
 * Returns the first net in netArray whose name equals NAME, or NULL.
 */
Net	getNet(name)
  char	*name ;
{
  int	i ;

  for (i = 0 ; i < numNets ; ++i) {
    if (strcmp(netArray[i]->name, name) == 0)
      return netArray[i] ;
  }
  return NULL ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		listNets
 *	Description:	returns a list of names of all the nets.
 *	Parameters:	NONE
 *	Return Value:	
 *		char	**listNets - a NULL terminated array of pointers 
 *			to the name field in the nets. The names are the
 *			REAL name in the net, DO NOT FREE THEM.
 *			The array is malloced. IT SHOULD BE FREED.
 ***********************************************************************/
char	**listNets()
{
  int	idx ;
  char	**nameList ;
  
  nameList = (char **)callocOrAbort(numNets + 1, sizeof(char *)) ;

  for (idx = 0 ; idx < numNets ; ++idx)
    nameList[idx] = netArray[idx]->name ;

  return nameList ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		command_train
 *	Description:	trains the current network on its training example
 *			set by running the equivalent "minimize ... -momentum"
 *			command, then prints the epoch, error and (when
 *			non-zero) extra cost.
 *	Parameters:	
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_train - 0 on failure or when help is
 *			printed, 1 on success
 ***********************************************************************/
int command_train (tokc, tokv)
  int	tokc ;
  char	*tokv[] ;
{
  ExampleSet	exampleSet ;
  int		repSize = 0 ;
  int		numEpochs = -1 ;
  /* Holds the "minimize" command built below: 45 fixed characters plus
   * two %d conversions (at most 20 characters each, even for a 64-bit
   * int) plus the NUL.  The former 64 bytes overflowed for large or
   * negative counts (e.g. two values of 2147483647 need 66 bytes). */
  char		command[128] ;
  struct ARG_DESC *argd;

  argd = StartArgs(tokv[0]);
  Args(argd, "[-Prepetitions <m batches>%i]", &repSize);
  Args(argd, "[<n batches>%i]", &numEpochs);
  EndArgs(argd);
  if (GiveHelp(tokc)) {
    ISynopsis("train a network on the examples");
    IHelp
      (tokc, tokv[0], NULL, synopsis,
       "Train the network for n batches (iterations) on the training example",
       "set.   Use repetitions   of <m>  batches  and at  the end   of  each",
       "repetition print network statistics.",
       "",
       "The size of each batch is determined by the \"batchSize\" field in the",
       "currentNet.",
       "",
       "This command is equivalent to:",
       "  minimize -iterations n -rep m -momentum",
       "SEE ALSO",
       "minimize, test, validate",
       NULL);
    return 0;
  }
  (void) ParseArgs(argd, tokc, tokv, 0);

  /* The returns after IErrorAbort() keep currentNet from being
   * dereferenced when it is NULL, should IErrorAbort() ever return. */
  if (currentNet == NULL) {
    IErrorAbort("There is no current net.") ;
    return 0 ;
  }

  if (currentNet->calculateErrorDerivProc == NULL) {
    IErrorAbort("There is no train procedure for net \"%s\".", 
		currentNet->name) ;
    return 0 ;
  }

  exampleSet = netGetExampleSet(currentNet, TRAINING) ;
  if (exampleSet == NULL) {
    IErrorAbort("There is no training example set for net \"%s\".", 
		currentNet->name) ;
    return 0 ;
  }

  sprintf(command, "minimize -iterations %d -repetitions %d -momentum", 
	  numEpochs, repSize) ;

  if (IDoCommandLine(command) == 0)
    return 0 ;

  fprintf(dout,"epoch = %d,\terror = %-8g",
	  currentNet->currentEpoch, currentNet->error);
  if (currentNet->cost != 0.0)
    fprintf(dout,"\textra cost = %-8g", currentNet->cost);
  putc('\n', dout) ;

  return 1 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		command_test
 *	Description:	runs MupdateNetError on the current net with its
 *			TESTING example set, then prints the epoch, error
 *			and (when non-zero) extra cost.  A non-positive
 *			batchSize is temporarily replaced by the number of
 *			testing examples and restored afterwards.
 *	Parameters:	
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_test - 0 if only help was given, else 1
 ***********************************************************************/
int command_test (tokc, tokv)
  int	tokc ;
  char	*tokv[] ;
{
  ExampleSet	testSet ;

  IUsage(" ");
  if (GiveHelp(tokc)) {
    ISynopsis("test a network on the testing example set");
    IHelp
      (IHelpArgs,
       "For one  batch of   the testing example  set,  activate   the current",
       "network and print network statistics.",
       "SEE ALSO",
       "minimize, train, validate",
       NULL);
    return 0;
  }

  if (tokc != 1)
    IErrorAbort(IPrintUsage(tokv[0],usage));

  if (currentNet == NULL)
    IErrorAbort("There is no current net.") ;

  if (currentNet->calculateErrorProc == NULL)
    IErrorAbort("There is no test procedure for net \"%s\".", 
		currentNet->name) ;

  testSet = netGetExampleSet(currentNet, TESTING) ;
  if (testSet != NULL) {
    int	savedBatchSize = currentNet->batchSize ;

    /* Evaluate with the whole set when no positive batch size is set. */
    currentNet->batchSize = (savedBatchSize > 0) ? savedBatchSize
						 : testSet->numExamples ;
    MupdateNetError(currentNet, testSet) ;
    currentNet->batchSize = savedBatchSize ;
  } else {
    IErrorAbort("There is no testing example set for net \"%s\".", 
		currentNet->name) ;
  }

  fprintf(dout,"epoch = %d,\terror = %-8g",
	  currentNet->currentEpoch, currentNet->error);
  if (currentNet->cost != 0.0)
    fprintf(dout,"\textra cost = %-8g", currentNet->cost);
  fputc('\n', dout) ;

  return 1 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		command_validate
 *	Description:	runs MupdateNetError on the current net with its
 *			VALIDATION example set, then prints the epoch,
 *			error and (when non-zero) extra cost.  A
 *			non-positive batchSize is temporarily replaced by
 *			the number of validation examples and restored
 *			afterwards.
 *	Parameters:	
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_validate - 0 if only help was given, else 1
 ***********************************************************************/
int command_validate (tokc, tokv)
  int	tokc ;
  char	*tokv[] ;
{
  ExampleSet	validationSet ;

  IUsage(" ");
  if (GiveHelp(tokc)) {
    ISynopsis("validate a network using the validation example set");
    IHelp
      (IHelpArgs,
       "For  one batch of  the validation  example set, activate the current",
       "network and print network statistics.",
       "SEE ALSO",
       "train, test",
       NULL);
    return 0;
  }

  if (tokc != 1)
    IErrorAbort(IPrintUsage(tokv[0],usage));

  if (currentNet == NULL)
    IErrorAbort("There is no current net.") ;

  if (currentNet->calculateErrorProc == NULL)
    IErrorAbort("There is no validation procedure for net \"%s\".", 
		currentNet->name) ;

  validationSet = netGetExampleSet(currentNet, VALIDATION) ;
  if (validationSet != NULL) {
    int	savedBatchSize = currentNet->batchSize ;

    /* Evaluate with the whole set when no positive batch size is set. */
    currentNet->batchSize = (savedBatchSize > 0) ? savedBatchSize
						 : validationSet->numExamples ;
    MupdateNetError(currentNet, validationSet) ;
    currentNet->batchSize = savedBatchSize ;
  } else {
    IErrorAbort("There is no validation example set for net \"%s\".", 
		currentNet->name) ;
  }

  fprintf(dout,"epoch = %d,\terror = %-8g",
	  currentNet->currentEpoch, currentNet->error);
  if (currentNet->cost != 0.0)
    fprintf(dout,"\textra cost = %-8g", currentNet->cost);
  fputc('\n', dout) ;
    
  return 1 ;
}
/**********************************************************************/
