
/**********************************************************************
 * $Id: command.c,v 1.4 93/01/15 13:04:19 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		      Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute, and sell this software
 * and its  documentation for  any purpose is  hereby granted  without
 * fee, provided that the above copyright notice appears in all copies
 * and  that both the  copyright notice  and   this  permission notice
 * appear in   supporting documentation,  and  that the  name  of  The
 * University  of Toronto  not  be  used in  advertising or  publicity
 * pertaining   to  distribution   of  the  software without specific,
 * written prior  permission.   The  University of   Toronto makes  no
 * representations  about  the  suitability of  this software  for any
 * purpose.  It  is    provided   "as is"  without express or  implied
 * warranty.
 *
 * THE UNIVERSITY OF TORONTO DISCLAIMS  ALL WARRANTIES WITH REGARD  TO
 * THIS SOFTWARE,  INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT SHALL THE UNIVERSITY OF TORONTO  BE LIABLE
 * FOR ANY  SPECIAL, INDIRECT OR CONSEQUENTIAL  DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR  PROFITS, WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF  OR  IN  CONNECTION WITH  THE  USE  OR PERFORMANCE  OF THIS
 * SOFTWARE.
 *
 **********************************************************************/

#include <xerion/useful.h>
#include <xerion/simulator.h>
#include <xerion/commands.h>

#include "unit.h"
#include "bp.h"

/* Search record handed to findGateFA through its data argument.
 * `expert' is the value to look for in a gate unit's expert slot
 * (the callers in this file use NULL to look for a slot holding no
 * expert); `unit' receives the matching gate unit and is left
 * untouched when no unit matches. */
typedef struct _GatePair {
  Group	expert ;
  Unit	unit ;
} GatePairRec, *GatePair ;

/* True when the group `e' has a gate recorded in Mgate(e).  Mgate is
 * not defined in this file (presumably it comes from bp.h - confirm). */
#define expertIsGated(e)	(Mgate(e))

/* Operations on whole gate and expert groups; each returns 0 on
 * failure and 1 on success. */
static int	resetGate	ARGS((Group gate, Group output)) ;
static int	confirmGate	ARGS((Group gate)) ;
static int	ungateExpert	ARGS((Group expert)) ;
static int	gateExpert	ARGS((Group gate, Group expert)) ;

/* Per-unit callbacks passed to groupForAllUnits/groupForAllUnitsBack.
 * NOTE(review): findExpertFA is declared here but no definition or
 * use of it appears in the visible part of this file - confirm. */
static void	ungateExpertFA	ARGS((Unit, void *)) ;
static void	confirmGateFA	ARGS((Unit, void *)) ;
static void	findExpertFA	ARGS((Unit, void *)) ;
static void	findGateFA	ARGS((Unit, void *)) ;

/***********************************************************************
 *	Name:		command_gate
 *	Description:	gate a set of expert groups: links the named gate
 *			group to the named gated output group, then
 *			assigns each expert group named on the command
 *			line to a unit of the gate.  An error is reported
 *			if the gate/output pair cannot be linked or if an
 *			expert cannot be placed in the gate; a warning is
 *			issued if the gate is left with empty slots.
 *	Parameters:
 *		int	tokc    - the number of command line tokens
 *		char	*tokv[] - the vector of tokens
 *	Return Value:	
 *		int	command_gate - 0 on failure, 1 on success
 ***********************************************************************/
int	command_gate(tokc, tokv)
  int	tokc ;
  char	**tokv ;
{
  Group	gate, output, expert ;
  char	*name ;

  IUsage("<gate> <output> <expert1> [<expert2> ...]") ;
  if (GiveHelp(tokc)) {
    ISynopsis("gate a set of experts");
    IHelp
      (IHelpArgs,
       "This  command is used to build a  mixture of experts network. To use",
       "it, you must have already created the following groups:",
       "",
       "n expert groups with m units:",
       "	bp-> addGroup \"Expert 1\" <m>",
       "	...",
       "	bp-> addGroup \"Expert <n>\" <m>",
       "",
       "a dummy output group with m units, no input connections and type",
       "GATED:",
       "	bp-> addGroup -type OUTPUT -type GATED \"Target\" <m>",
       "	bp-> disconnectGroups '.*' 'Target'",
       "",
       "a GATE group with n units (the same number of units as experts)",
       "	bp-> addGroup -type GATE \"Gate\" <n>",
       "",
       "With these groups you can then use the command:",
       "	bp-> gate \"Gate\" \"Target\" \"Expert 1\" ... \"Expert <n>\"",
       "",
       "The command  will  associate each unit  in the gate with one of  the",
       "expert groups (in the order specified on the command line), and  all",
       "of the  experts with the output  group. When an example is presented",
       "to  the network, the targets are stored  in the output  group.   The",
       "error measure is then calculated using the gate and experts, not the",
       "output group. ",
       "",
       "The major  assumption is that the errors  of each expert come from a",
       "gaussian distribution with  0 mean, and fixed variance.  The gate is",
       "used to calculate mixing  proportions for each of these gaussians to",
       "produce  a mixture  of gaussians error function  for the entire net.",
       "The error  being  optimized is the -log  of the probability that the",
       "observed  errors  came  from  the  underlying mixture  of  gaussians",
       "distribution.",
       "",
       NULL) ;
    return 1 ;
  }

  if (currentNet == NULL)
    IErrorAbort("There is no current net.") ;

  /* Parse any command line options (none are accepted at present, so
   * any token starting with '-' is a usage error) */
  name = *tokv ;
  for (++tokv, --tokc ; tokc > 0 ; ++tokv, --tokc) {
    if (*tokv[0] == '-') {
      IErrorAbort(IPrintUsage(name, usage)) ;
    } else {
      break ;
    }
  }
  /* need at least a gate, an output and one expert */
  if (tokc < 3)
    IErrorAbort(IPrintUsage(name, usage)) ;

  gate = groupFromName(currentNet, *tokv) ;
  if (gate == NULL)
    IErrorAbort("Unknown group: \"%s\".", *tokv) ;
  if (!(gate->type & GATE))
    IErrorAbort("\"%s\" is not a gate group.", *tokv) ;

  ++tokv, --tokc ;
  output = groupFromName(currentNet, *tokv) ;
  if (output == NULL)
    IErrorAbort("Unknown group: \"%s\".", *tokv) ;
  if (!(output->type & GATED))
    IErrorAbort("\"%s\" is not a gated group.", *tokv) ;

  /* resetGate has stricter type requirements than the checks above
   * (the output must also be an OUTPUT group).  If it refuses, the
   * gate has not been linked to the output, so stop here rather than
   * go on to gate experts against an unlinked gate. */
  if (!resetGate(gate, output))
    IErrorAbort("\"%s\" is not a gated output group.", *tokv) ;

  for (++tokv, --tokc ; tokc > 0 ; ++tokv, --tokc) {
    expert = groupFromName(currentNet, *tokv) ;
    if (expert == NULL)
      IErrorAbort("Unknown group: \"%s\".", *tokv) ;
    /* an expert can only sit in one gate slot: release any old one */
    if (expertIsGated(expert))
      ungateExpert(expert) ;
    /* gateExpert fails when the expert and the output differ in size
     * or when the gate has no free unit left */
    if (!gateExpert(gate, expert))
      IErrorAbort("Cannot gate \"%s\": wrong size or gate is full.", *tokv) ;
  }

  if (!confirmGate(gate))
    warn("gate", "Gate has empty slots.") ;

  return 1 ;
}
/********************************************************************/
    

/*********************************************************************
 *	Name:		resetGate
 *	Description:	links a gate group and a gated output group to
 *			each other, replaces the gate's mixture with a
 *			new one holding one gaussian (variance 1.0) per
 *			gate unit, and detaches every expert currently
 *			held by a unit of the gate
 *	Parameters:
 *	  Group	gate   - the gate group; must have type GATE
 *	  Group	output - the group to be gated; must have both
 *			 the OUTPUT and the GATED type
 *	Return Value:
 *	  int	resetGate - 0 if either group has the wrong type
 *			    (nothing is changed), 1 otherwise
 *********************************************************************/
static int	resetGate(gate, output)
  Group		gate ;
  Group		output ;
{
  int	unitIdx ;

  /* all three type bits are required before anything is touched */
  if (!((gate->type & GATE)
	&& (output->type & OUTPUT) && (output->type & GATED)))
    return 0 ;

  /* cross-link the two groups */
  Mgate(output) = gate ;
  Mgated(gate)  = output ;

  /* throw away any previous mixture and build a new one */
  if (Mmixture(gate))
    MMdestroy(Mmixture(gate)) ;
  Mmixture(gate) = createMixture(gate->numUnits, 0.0, 0.0) ;

  /* give each gate unit its own gaussian from the new mixture */
  unitIdx = 0 ;
  while (unitIdx < gate->numUnits) {
    Gaussian	unitGaussian = MMgaussian(Mmixture(gate), unitIdx) ;

    MGsetVariance(unitGaussian, 1.0) ;
    Mgaussian(gate->unit[unitIdx]) = unitGaussian ;
    ++unitIdx ;
  }

  /* finally, release every expert still attached to the gate */
  groupForAllUnits(gate, ungateExpertFA, NULL) ;

  return 1 ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		gateExpert
 *	Description:	
 *	Parameters:
 *	  Group	gate - 
 *	  Group	expert - 
 *	Return Value:
 *	  int	gateExpert - 
 *********************************************************************/
static int	gateExpert(gate, expert)
  Group	gate ;
  Group	expert ;
{
  GatePairRec	gatePair ;

  if (!(gate->type & GATE))
    return 0 ;

  if (expert->numUnits != Mgated(gate)->numUnits)
    return 0 ;
  
  gatePair.unit   = NULL ;
  gatePair.expert = NULL ;
  groupForAllUnitsBack(gate, findGateFA, &gatePair) ;

  if (gatePair.unit == NULL)
    return 0 ;

  Mexpert(gatePair.unit)   = expert ;

  return 1 ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		ungateExpert
 *	Description:	detaches an expert group from the gate recorded
 *			in Mgate(expert): the gate unit holding the
 *			expert has its expert slot cleared and the
 *			expert's gate is reset to NULL
 *	Parameters:
 *	  Group	expert - the expert group to detach
 *	Return Value:
 *	  int	ungateExpert - 0 if the expert has no gate, the
 *			       recorded gate is not a GATE group, or
 *			       no unit of the gate holds the expert;
 *			       1 on success
 *********************************************************************/
static int	ungateExpert(expert)
  Group	expert ;
{
  GatePairRec	gatePair ;
  Group	gate = Mgate(expert) ;

  /* an expert that was never gated has nothing to undo */
  if (gate == NULL)
    return 0 ;

  if (!(gate->type & GATE))
    return 0 ;

  /* find the gate unit whose expert slot points at this expert */
  gatePair.unit   = NULL ;
  gatePair.expert = expert ;
  groupForAllUnitsBack(gate, findGateFA, &gatePair) ;

  if (gatePair.unit == NULL)
    return 0 ;

  Mexpert(gatePair.unit) = NULL ;
  Mgate(gatePair.expert) = NULL ;

  return 1 ;
}
/********************************************************************/

/*********************************************************************
 *	Name:		confirmGate
 *	Description:	checks that every unit of a gate group passes
 *			the per-unit check made by confirmGateFA
 *	Parameters:
 *	  Group	gate - the gate group to check
 *	Return Value:
 *	  int	confirmGate - FALSE if the group is not a GATE group
 *			      or any unit fails the check, TRUE
 *			      otherwise
 *********************************************************************/
static int	confirmGate(gate)
  Group	gate ;
{
  Boolean	ok = TRUE ;

  /* The answer is already FALSE for a non-gate group and the walk
   * below can only clear the flag, so return without visiting units
   * that may not carry any gate data. */
  if (!(gate->type & GATE))
    return FALSE ;

  groupForAllUnits(gate, confirmGateFA, &ok) ;
  
  return ok ;
}
/********************************************************************/


/********************************************************************/
static void	ungateExpertFA(unit, data)
  Unit		unit ;
  void		*data ;
{
  if (Mexpert(unit)) {
    ungateExpert(Mexpert(unit)) ;
    Mexpert(unit) = NULL ;
  }
}
/********************************************************************/
/*********************************************************************
 *	Name:		confirmGateFA
 *	Description:	per-unit callback: clears the caller's flag if
 *			the unit holds no expert, if the unit's group
 *			is not linked to a gated group, or if the
 *			expert and the gated group differ in size.
 *			The flag is never set to TRUE here.
 *	Parameters:
 *	  Unit	unit  - the gate unit being visited
 *	  void	*data - pointer to the Boolean flag to clear
 *********************************************************************/
static void	confirmGateFA(unit, data)
  Unit		unit ;
  void		*data ;
{
  Group	expert = Mexpert(unit) ;
  Group	output = Mgated(unit->group) ;

  /* output is NULL when the gate was never linked to a gated group;
   * treat that as a failed check instead of dereferencing it */
  if (expert == NULL || output == NULL
      || expert->numUnits != output->numUnits)
    *(Boolean *)data = FALSE ;
}
/********************************************************************/
static void	findGateFA(unit, data)
  Unit		unit ;
  void		*data ;
{
  GatePair	gatePair = (GatePair)data ;
  if (Mexpert(unit) == gatePair->expert)
    gatePair->unit = unit ;
}
/********************************************************************/
