
/**********************************************************************
 * $Id: costCom.c,v 1.8 92/11/30 13:16:50 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		      Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute, and sell this software
 * and its  documentation for  any purpose is  hereby granted  without
 * fee, provided that the above copyright notice appears in all copies
 * and  that both the  copyright notice  and   this  permission notice
 * appear in   supporting documentation,  and  that the  name  of  The
 * University  of Toronto  not  be  used in  advertising or  publicity
 * pertaining   to  distribution   of  the  software without specific,
 * written prior  permission.   The  University of   Toronto makes  no
 * representations  about  the  suitability of  this software  for any
 * purpose.  It  is    provided   "as is"  without express or  implied
 * warranty.
 *
 * THE UNIVERSITY OF TORONTO DISCLAIMS  ALL WARRANTIES WITH REGARD  TO
 * THIS SOFTWARE,  INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT SHALL THE UNIVERSITY OF TORONTO  BE LIABLE
 * FOR ANY  SPECIAL, INDIRECT OR CONSEQUENTIAL  DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR  PROFITS, WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF  OR  IN  CONNECTION WITH  THE  USE  OR PERFORMANCE  OF THIS
 * SOFTWARE.
 *
 **********************************************************************/

#include <xerion/useful.h>
#include <xerion/commands.h>
#include <xerion/simulator.h>

#include "costCom.h"

/***********************************************************************
 *	Name:		removeFromCostModel, addToGroupCostModel,
 *			addToNetCostModel
 *	Description:	per-unit callbacks that move the incoming links
 *			of a unit out of, or into, a cost model.  They
 *			are handed to groupForAllUnits/netForAllUnits
 *			by addMixtureCost, addSumSquareCost and
 *			deleteCost further down in this file.
 *		Unit	unit  - the unit whose incoming links are moved
 *		void	*data - the CostModel, passed as a generic pointer
 *	Return Value:	none
 ***********************************************************************/
static void	removeFromCostModel(unit, data)
  Unit		unit ;
  void		*data ;
{
  CostModel	costModel = (CostModel)data ;
  int		numIncoming = unit->numIncoming ;
  Link		*incoming   = unit->incomingLink ;
  int		idx ;

  for (idx = 0 ; idx < numIncoming ; ++idx)
    MCMremoveLink(costModel, incoming[idx]) ;
}
/**********************************************************************/
static void	addToGroupCostModel(unit, data)
  Unit		unit ;
  void		*data ;
{
  CostModel	costModel = (CostModel)data ;
  int		numIncoming = unit->numIncoming ;
  Link		*incoming   = unit->incomingLink ;
  int		idx ;

  if (unit->group && McostModel(unit->group))
    removeFromCostModel(unit, McostModel(unit->group)) ;

  else if (unit->net && McostModel(unit->net))
    removeFromCostModel(unit, McostModel(unit->net)) ;

  for (idx = 0 ; idx < numIncoming ; ++idx)
    MCMaddLink(costModel, incoming[idx]) ;
}
/**********************************************************************/
static void	addToNetCostModel(unit, data)
  Unit		unit ;
  void		*data ;
{
  CostModel	costModel = (CostModel)data ;
  int		numIncoming = unit->numIncoming ;
  Link		*incoming   = unit->incomingLink ;
  int		idx ;

  if (unit->group && McostModel(unit->group))
    return ;

  else if (unit->net && McostModel(unit->net))
    removeFromCostModel(unit, McostModel(unit->net)) ;

  for (idx = 0 ; idx < numIncoming ; ++idx)
    MCMaddLink(costModel, incoming[idx]) ;
}
/**********************************************************************/
/***********************************************************************
 *	Name:		command_addMixtureCost
 *	Description:	command front end for addMixtureCost: parses the
 *			"-min", "-max" and "-group" options (unique
 *			prefixes are accepted) and the number of
 *			gaussians, then creates the cost model on the
 *			current net or on the named group.
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_addMixtureCost - 0 on failure, 1 on success
 ***********************************************************************/
int	command_addMixtureCost(tokc, tokv)
  int	tokc ;
  char	**tokv ;
{
  Real		min, max ;
  int		numGaussians ;
  String	name ;
  Group		group ;
  
  IUsage("[-min <x>] [-max <y>] [-group <group>] <n>");
  if (GiveHelp(tokc)) {
    ISynopsis("add a mixture of gaussians cost model to the net");
    IHelp
      (IHelpArgs,
       "This command is used to create a Mixture  of Gaussians cost model of",
       "the form used by Nowlan and Hinton. When used,  it implements a form",
       "of \"soft weight-sharing\", in  which groups of  weights are clustered",
       "together.",
       "",
       "When creating  the  mixture,  you must  specify  <n>, the  number of",
       "gaussians,  in the  model.  You may  also specify  the  minimum  and",
       "maximum values  to  spread the means of  the  gausians  over. If  no",
       "minimum and maximum are specified the range [-1, 1] is assumed.",
       "",
       "The \"-group\" option allows you to create a cost model that acts only",
       "on the incoming links of a  single group. If a  group does not  have",
       "it's own cost model, it uses the network cost model.",
       "",
       "The means, standard deviations, and proportions of the gaussians can",
       "be frozen and constrained just like the links.   They  are given the",
       "names  \"?.Mean.?\", \"?.StdDeviation.?\",  and  \"?.Proportion.?\", where",
       "the first \"?\"  is  the name of the network (or group) the cost model",
       "is connected to, and the second  is the index of the gaussian within",
       "the model, starting at 0.",
       "",
       "SEE ALSO",
       "addSumSquareCost, resetMixtureCost, showCost, showMixture,",
       "updateMixtureDisplay",
       "",
       "CRG-TR-91-4, \"Simplifying Neural Networks by Soft Weight-Sharing\"",
       "Steven J. Nowlan, Geoffrey E. Hinton, October 1991.",
       NULL) ;
    return 1 ;
  }

  if (currentNet == NULL)
    IErrorAbort("No currentNet") ;

  name = *tokv ;
  min = -1.0 ;
  max =  1.0 ;
  group = NULL ;
  for (++tokv, --tokc ; tokc > 0 ; ++tokv, --tokc) {
    /* The first token that is not an option ends option parsing.  This
     * test also keeps an empty token from matching every option name
     * in the zero-length prefix comparisons below. */
    if ((*tokv)[0] != '-')
      break ;

    if (strncmp(*tokv, "-min", strlen(*tokv)) == 0) {
      ++tokv, --tokc ;
      /* the option needs a value: never look past the last token */
      if (tokc <= 0 || !IIsNumber(*tokv))
	IErrorAbort(IPrintUsage(name, usage)) ;
      min = atof(*tokv) ;
    } else if (strncmp(*tokv, "-max", strlen(*tokv)) == 0) {
      ++tokv, --tokc ;
      if (tokc <= 0 || !IIsNumber(*tokv))
	IErrorAbort(IPrintUsage(name, usage)) ;
      max = atof(*tokv) ;
    } else if (strncmp(*tokv, "-group", strlen(*tokv)) == 0) {
      ++tokv, --tokc ;
      if (tokc <= 0)
	IErrorAbort(IPrintUsage(name, usage)) ;
      group = groupFromName(currentNet, *tokv) ;
      if (group == NULL)
	IErrorAbort("Unknown group: \"%s\"", *tokv) ;
    } else {
      /* starts with '-' but is no known option */
      IErrorAbort(IPrintUsage(name, usage)) ;
    }
  }

  /* exactly one token must remain: the number of gaussians */
  if (tokc != 1 || !IIsInteger(tokv[0])) 
    IErrorAbort(IPrintUsage(name, usage)) ;

  numGaussians = atol(tokv[0]) ;

  return addMixtureCost(currentNet, group, numGaussians, min, max) ;
}
/**********************************************************************/
/*
 * Create a mixture cost model with `num' gaussians over the range
 * [min, max] and install it on `group', or on `net' when `group' is
 * NULL.  The incoming links of the affected units are handed over to
 * the new model, and a cost model previously installed in the same
 * place is destroyed.  All displays are marked for rebuilding.
 *
 * Returns 0 without doing anything when `net' is NULL, `num' is not
 * positive or `max' is below `min'; returns 1 otherwise.
 */
int		addMixtureCost(net, group, num, min, max)
  Net		net ;
  Group		group ;
  int		num ;
  double	min ;
  double	max ;
{
  CostModel	newModel ;

  /* argument sanity checks */
  if (net == NULL)
    return 0 ;
  if (num <= 0)
    return 0 ;
  if (max < min)
    return 0 ;

  if (group == NULL) {
    /* network-wide model, named after the net */
    newModel = createMixtureCostModel(net->name, net, num, min, max) ;

    netForAllUnits(net, ALL, addToNetCostModel, newModel) ;

    /* the links have been moved, so the old model can go */
    if (McostModel(net) != NULL)
      MCMdestroy(McostModel(net)) ;
    McostModel(net) = newModel ;

  } else {
    /* model for one group only, named after the group */
    newModel = createMixtureCostModel(group->name, net, num, min, max) ;

    groupForAllUnits(group, addToGroupCostModel, newModel) ;

    if (McostModel(group) != NULL)
      MCMdestroy(McostModel(group)) ;
    McostModel(group) = newModel ;
  }

  markToRebuildDisplay(ALL_DISPLAYS) ;

  return 1 ;
}
/**********************************************************************/
/***********************************************************************
 *	Name:		command_addSumSquareCost
 *	Description:	command front end for addSumSquareCost: parses
 *			the optional "-group" option (a prefix is
 *			accepted) and installs a sum of squares cost
 *			model on the current net or the named group.
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_addSumSquareCost - the value returned by
 *			addSumSquareCost, or 1 after printing help
 ***********************************************************************/
int	command_addSumSquareCost(tokc, tokv)
  int	tokc ;
  char	**tokv ;
{
  Group		group ;
  String	name ;

  IUsage("[-group <group>]") ;
  if (GiveHelp(tokc)) {
    ISynopsis("set the network cost model to use sum of squares") ;
    IHelp
      (IHelpArgs,
       "This command sets the cost model used by the  current network to use",
       "sum of squares. That is, the network cost will be the network  field",
       "\"weightCost\" times the  sum  of the  squares of  the  weights in the",
       "network.",
       "",
       "The \"-group\" option allows you to create a cost model that acts only",
       "on the incoming links of a  single group. If a  group does not  have",
       "it's own cost model, it uses the network cost model.",
       "",
       "SEE ALSO",
       "addMixtureCost, showCost",
       NULL) ;
    return 1 ;
  }

  if (currentNet == NULL)
    IErrorAbort("No currentNet") ;

  name = *tokv ;
  group = NULL ;
  for (++tokv, --tokc ; tokc > 0 ; ++tokv, --tokc) {
    /* A token that is not an option ends option parsing.  This test
     * also keeps an empty token from matching "-group" in the
     * zero-length prefix comparison below. */
    if ((*tokv)[0] != '-')
      break ;

    if (strncmp(*tokv, "-group", strlen(*tokv)) == 0) {
      ++tokv, --tokc ;
      /* the option needs a value: never look past the last token */
      if (tokc <= 0)
	IErrorAbort(IPrintUsage(name, usage)) ;
      group = groupFromName(currentNet, *tokv) ;
      if (group == NULL)
	IErrorAbort("Unknown group: \"%s\"", *tokv) ;
    } else {
      /* starts with '-' but is no known option */
      IErrorAbort(IPrintUsage(name, usage)) ;
    }
  }

  /* no positional arguments are accepted */
  if (tokc != 0)
    IErrorAbort(IPrintUsage(name, usage)) ;

  return addSumSquareCost(currentNet, group) ;
}
/**********************************************************************/
/*
 * Create a sum of squares cost model and install it on `group', or on
 * `net' when `group' is NULL.  The incoming links of the affected
 * units are handed over to the new model, and a cost model previously
 * installed in the same place is destroyed.  All displays are marked
 * for rebuilding.
 *
 * Returns 0 without doing anything when `net' is NULL (the same
 * convention as addMixtureCost); returns 1 otherwise.
 */
int	addSumSquareCost(net, group)
  Net	net ;
  Group	group ;
{
  CostModel	costModel ;

  /* net is dereferenced below and handed to the model constructor */
  if (net == NULL)
    return 0 ;

  if (group) {
    costModel = createSumSquareCostModel(group->name, net) ;

    groupForAllUnits(group, addToGroupCostModel, costModel) ;

    /* the links have been moved, so the old model can go */
    if (McostModel(group))
      MCMdestroy(McostModel(group)) ;
    McostModel(group) = costModel ;

  } else {
    costModel = createSumSquareCostModel(net->name, net) ;

    netForAllUnits(net, ALL, addToNetCostModel, costModel) ;

    if (McostModel(net))
      MCMdestroy(McostModel(net)) ;
    McostModel(net) = costModel ;
  }

  markToRebuildDisplay(ALL_DISPLAYS) ;

  return 1 ;
}
/**********************************************************************/

/***********************************************************************
 *	Name:		command_resetMixtureCost
 *	Description:	command front end for resetMixtureCost: parses
 *			the "-min", "-max" and "-group" options (unique
 *			prefixes are accepted) and resets the mixture
 *			cost model of the current net or named group.
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_resetMixtureCost - 1 on success; an
 *			error is raised through IErrorAbort otherwise
 ***********************************************************************/
int	command_resetMixtureCost(tokc, tokv)
  int	tokc ;
  char	**tokv ;
{
  Group		group ;
  String	name ;
  Real		min, max ;

  IUsage("[-min <x>] [-max <y>] [-group <group>]");
  if (GiveHelp(tokc)) {
    ISynopsis("reset the mixture of gaussians cost model in the net");
    IHelp
      (IHelpArgs,
       "This command is used to reset the Mixture of Gaussians cost model in",
       "the  current  net  (see addMixtureCost  for  a  description  of this",
       "model).",
       "",
       "When reseting the mixture,  you may  specify the minimum and maximum",
       "values to spread the means of  the gausians over. If no minimum  and",
       "maximum are specified the range [-1, 1] is assumed.",
       "",
       "The \"-group\" option  allows you to reset the cost model of  a single",
       "group.",
       "",
       "The means, variances, and proportions of the gaussians can be frozen",
       "and constrained  just  like the links.  They  are  given  the  names",
       "\"?.Mean.?\",  \"?.Variance.?\",  and \"?.Proportion.?\",  where the first",
       "\"?\"  is  the name of  the  network  (or group)  the  cost  model  is",
       "connected to, and the second is the index of the gaussian within the",
       "model, starting at 0.",
       "",
       "SEE ALSO",
       "addSumSquareCost, showCost, showMixture, updateMixtureDisplay",
       "",
       "CRG-TR-91-4, \"Simplifying Neural Networks by Soft Weight-Sharing\"",
       "Steven J. Nowlan, Geoffrey E. Hinton, October 1991.",
       NULL) ;
    return 1 ;
  }

  if (currentNet == NULL)
    IErrorAbort("No currentNet") ;

  name = *tokv ;
  min = -1.0 ;
  max =  1.0 ;
  group = NULL ;
  for (++tokv, --tokc ; tokc > 0 ; ++tokv, --tokc) {
    /* A token that is not an option ends option parsing.  This test
     * also keeps an empty token from matching every option name in
     * the zero-length prefix comparisons below. */
    if ((*tokv)[0] != '-')
      break ;

    if (strncmp(*tokv, "-min", strlen(*tokv)) == 0) {
      ++tokv, --tokc ;
      /* the option needs a value: never look past the last token */
      if (tokc <= 0 || !IIsNumber(*tokv))
	IErrorAbort(IPrintUsage(name, usage)) ;
      min = atof(*tokv) ;
    } else if (strncmp(*tokv, "-max", strlen(*tokv)) == 0) {
      ++tokv, --tokc ;
      if (tokc <= 0 || !IIsNumber(*tokv))
	IErrorAbort(IPrintUsage(name, usage)) ;
      max = atof(*tokv) ;
    } else if (strncmp(*tokv, "-group", strlen(*tokv)) == 0) {
      ++tokv, --tokc ;
      if (tokc <= 0)
	IErrorAbort(IPrintUsage(name, usage)) ;
      group = groupFromName(currentNet, *tokv) ;
      if (group == NULL)
	IErrorAbort("Unknown group: \"%s\"", *tokv) ;
    } else {
      /* starts with '-' but is no known option */
      IErrorAbort(IPrintUsage(name, usage)) ;
    }
  }

  /* no positional arguments are accepted */
  if (tokc != 0)
    IErrorAbort(IPrintUsage(name, usage)) ;

  /* resetMixtureCost returns 0 when the selected model is missing or
   * is not a mixture model */
  if (resetMixtureCost(currentNet, group, min, max) == 0)
    IErrorAbort("Not using mixture of gaussians cost model") ;
  else
    return 1 ;
}
/**********************************************************************/
/*
 * Reset the mixture cost model of `group', or of `net' when `group' is
 * NULL, by calling resetMixtureCostModel with the range [min, max],
 * then mark all displays for rebuilding.
 *
 * Returns 0 when there is no cost model to reset (including a NULL
 * `net' with no group) or when the model is not of type CM_MIXTURE;
 * returns 1 otherwise.
 */
int	resetMixtureCost(net, group, min, max)
  Net		net ;
  Group		group ;
  double	min ;
  double	max ;
{
  CostModel	costModel ;

  /* use the net the caller passed in, not the global currentNet */
  if (group)
    costModel = McostModel(group) ;
  else if (net)
    costModel = McostModel(net) ;
  else
    costModel = NULL ;

  if (costModel == NULL || MCMtype(costModel) != CM_MIXTURE)
    return 0 ;

  resetMixtureCostModel(costModel, min, max) ;

  markToRebuildDisplay(ALL_DISPLAYS) ;

  return 1 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		command_deleteCost
 *	Description:	command front end for deleteCost: parses the
 *			optional "-group" option (a prefix is accepted)
 *			and deletes the cost model of the current
 *			network or of the named group.
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_deleteCost - 0 on failure, 1 on success
 ***********************************************************************/
int	command_deleteCost(tokc, tokv)
  int	tokc ;
  char	**tokv ;
{
  Group		group ;
  String	name ;

  IUsage("[-group <group>]") ;
  if (GiveHelp(tokc)) {
    ISynopsis("delete the network cost model") ;
    IHelp
      (IHelpArgs,
       "This command removes the network cost model.",
       "",
       "The \"-group\" option allows you to delete a cost model from a single",
       "group.  Once a group does not have it's own cost model, it uses the",
       "network cost model.",
       "SEE ALSO",
       "addMixtureCost, addSumSquareCost",
       NULL) ;
    return 1 ;
  }

  if (currentNet == NULL)
    IErrorAbort("No currentNet") ;

  name = *tokv ;
  group = NULL ;
  for (++tokv, --tokc ; tokc > 0 ; ++tokv, --tokc) {
    /* A token that is not an option ends option parsing.  This test
     * also keeps an empty token from matching "-group" in the
     * zero-length prefix comparison below. */
    if ((*tokv)[0] != '-')
      break ;

    if (strncmp(*tokv, "-group", strlen(*tokv)) == 0) {
      ++tokv, --tokc ;
      /* the option needs a value: never look past the last token */
      if (tokc <= 0)
	IErrorAbort(IPrintUsage(name, usage)) ;
      group = groupFromName(currentNet, *tokv) ;
      if (group == NULL)
	IErrorAbort("Unknown group: \"%s\"", *tokv) ;
    } else {
      /* starts with '-' but is no known option */
      IErrorAbort(IPrintUsage(name, usage)) ;
    }
  }

  /* no positional arguments are accepted */
  if (tokc != 0)
    IErrorAbort(IPrintUsage(name, usage)) ;

  return deleteCost(currentNet, group) ;
}
/**********************************************************************/
/*
 * Delete the cost model of `group', or of `net' when `group' is NULL,
 * and mark all displays for rebuilding.
 *
 * When a group's model is deleted and the net has a cost model, the
 * group's units are handed to the net's model through
 * addToNetCostModel.  When `group' is given but has no cost model of
 * its own, nothing is deleted; in particular the net's cost model is
 * left alone.
 *
 * Always returns 1.
 */
int	deleteCost(net, group)
  Net	net ;
  Group	group ;
{
  if (group) {
    /* a group request never falls through to the net's model */
    if (McostModel(group)) {
      MCMdestroy(McostModel(group)) ;
      McostModel(group) = NULL ;
      if (net && McostModel(net))
	groupForAllUnits(group, addToNetCostModel, McostModel(net)) ;
    }
  } else if (net && McostModel(net)) {
    MCMdestroy(McostModel(net)) ;
    McostModel(net) = NULL ;
  }

  markToRebuildDisplay(ALL_DISPLAYS) ;

  return 1 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		command_showCost
 *	Description:	prints the cost given by the cost model of the
 *			current network, or of the group named with
 *			"-group".  The deriv field of every link in the
 *			current network is set to zero before the cost
 *			and derivatives are evaluated.
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_showCost - 0 on failure, 1 on success
 ***********************************************************************/
int	command_showCost(tokc, tokv)
  int	tokc ;
  char	**tokv ;
{
  CostModel	costModel ;
  Group		group ;
  String	name ;
  int	idx ;

  IUsage("[-group <group>]");
  if (GiveHelp(tokc)) {
    ISynopsis("print the cost of a network") ;
    IHelp
      (IHelpArgs,
       "The \"-group\"  option  allows you  to show the  cost  of  a  specific",
       "group's cost model.",
       "",
       NULL) ;
    return 1 ;
  }

  if (currentNet == NULL)
    IErrorAbort("No currentNet") ;

  name = *tokv ;
  group = NULL ;
  for (++tokv, --tokc ; tokc > 0 ; ++tokv, --tokc) {
    /* A token that is not an option ends option parsing.  This test
     * also keeps an empty token from matching "-group" in the
     * zero-length prefix comparison below. */
    if ((*tokv)[0] != '-')
      break ;

    if (strncmp(*tokv, "-group", strlen(*tokv)) == 0) {
      ++tokv, --tokc ;
      /* the option needs a value: never look past the last token */
      if (tokc <= 0)
	IErrorAbort(IPrintUsage(name, usage)) ;
      group = groupFromName(currentNet, *tokv) ;
      if (group == NULL)
	IErrorAbort("Unknown group: \"%s\"", *tokv) ;
    } else {
      /* starts with '-' but is no known option */
      IErrorAbort(IPrintUsage(name, usage)) ;
    }
  }

  /* no positional arguments are accepted */
  if (tokc != 0)
    IErrorAbort(IPrintUsage(name, usage)) ;

  if (group)
    costModel = McostModel(group) ;
  else
    costModel = McostModel(currentNet) ;

  if (costModel == NULL)
    IErrorAbort("No cost model") ;
  
  /* start from zero derivatives on every link of the net */
  for (idx = 0 ; idx < currentNet->numLinks ; ++idx)
    currentNet->links[idx]->deriv = 0.0 ;

  fprintf(dout, "Cost = %g\n", 
	  MCMevaluateCostAndDerivs(costModel, MweightCost(currentNet))) ;
  fprintf(dout, "Look at all the links to see the derivatives\n") ;

  return 1 ;
}
/**********************************************************************/



/***********************************************************************
 *	Name:		command_showMixture
 *	Description:	prints the proportion, mean and variance of each
 *			gaussian in the mixture of gaussians cost model
 *			of the current network, or of the group named
 *			with "-group".
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_showMixture - 0 on failure, 1 on success
 ***********************************************************************/
int	command_showMixture(tokc, tokv)
  int	tokc ;
  char	**tokv ;
{
  CostModel	costModel ;
  Mixture	mixture ;
  String	name ;
  Group		group ;
  int		idx ;
  
  IUsage("[-group <group>]");
  if (GiveHelp(tokc)) {
    ISynopsis("show the mixture of gaussians cost model") ;
    IHelp
      (IHelpArgs,
       "The  \"-group\"  option allows you to  show  the mixture  model  of  a",
       "specific group.",
       "",
       "",
       NULL) ;
    return 1 ;
  }

  if (currentNet == NULL)
    IErrorAbort("No currentNet") ;

  name = *tokv ;
  group = NULL ;
  for (++tokv, --tokc ; tokc > 0 ; ++tokv, --tokc) {
    /* A token that is not an option ends option parsing.  This test
     * also keeps an empty token from matching "-group" in the
     * zero-length prefix comparison below. */
    if ((*tokv)[0] != '-')
      break ;

    if (strncmp(*tokv, "-group", strlen(*tokv)) == 0) {
      ++tokv, --tokc ;
      /* the option needs a value: never look past the last token */
      if (tokc <= 0)
	IErrorAbort(IPrintUsage(name, usage)) ;
      group = groupFromName(currentNet, *tokv) ;
      if (group == NULL)
	IErrorAbort("Unknown group: \"%s\"", *tokv) ;
    } else {
      /* starts with '-' but is no known option */
      IErrorAbort(IPrintUsage(name, usage)) ;
    }
  }

  /* no positional arguments are accepted */
  if (tokc != 0)
    IErrorAbort(IPrintUsage(name, usage)) ;

  if (group)
    costModel = McostModel(group) ;
  else
    costModel = McostModel(currentNet) ;

  /* only a mixture model carries the data printed below */
  if (costModel == NULL || MCMtype(costModel) != CM_MIXTURE)
    IErrorAbort("Not using a mixture of gaussians cost model.") ;

  mixture = costModel->costModelData.mixtureData->mixture ;

  /* one header line, then one line per gaussian */
  fprintf(dout, "%10s\t%10s\t%10s\n", "Proportion", "Mean", "Variance") ;
  for (idx = 0 ; idx < MMnumGaussians(mixture) ; ++idx) {
    Gaussian	gaussian = MMgaussian(mixture, idx) ;
    fprintf(dout, "%10.6g\t%10.6g\t%10.6g\n",
	    MGproportion(gaussian), MGmean(gaussian), MGvariance(gaussian)) ;
  }
  return 1 ;
}
/**********************************************************************/

