
/**********************************************************************
 * $Id: costModel.c,v 1.7 92/11/30 13:13:01 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		      Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute, and sell this software
 * and its  documentation for  any purpose is  hereby granted  without
 * fee, provided that the above copyright notice appears in all copies
 * and  that both the  copyright notice  and   this  permission notice
 * appear in   supporting documentation,  and  that the  name  of  The
 * University  of Toronto  not  be  used in  advertising or  publicity
 * pertaining   to  distribution   of  the  software without specific,
 * written prior  permission.   The  University of   Toronto makes  no
 * representations  about  the  suitability of  this software  for any
 * purpose.  It  is    provided   "as is"  without express or  implied
 * warranty.
 *
 * THE UNIVERSITY OF TORONTO DISCLAIMS  ALL WARRANTIES WITH REGARD  TO
 * THIS SOFTWARE,  INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT SHALL THE UNIVERSITY OF TORONTO  BE LIABLE
 * FOR ANY  SPECIAL, INDIRECT OR CONSEQUENTIAL  DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR  PROFITS, WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF  OR  IN  CONNECTION WITH  THE  USE  OR PERFORMANCE  OF THIS
 * SOFTWARE.
 *
 **********************************************************************/

#include <stdio.h>
#include <math.h>
#include <xerion/useful.h>
#include <xerion/simulator.h>

#include "gaussian.h"
#include "costModel.h"

#ifndef M_PI
#define M_PI		3.14159265358979323846
#endif
#ifndef M_SQR
#define M_SQR(x)	((x)*(x))
#endif

/********************************************************************/
static CostModel	createCostModel ARGS((String name,
					      int type, Net net, 
					      int, double, double)) ;
static SumSquareData	createSumSquareData ARGS((CostModel)) ;
static MixtureData	createMixtureData ARGS((CostModel, 
						int, double, double)) ;
/********************************************************************/
static void	destroy ARGS((CostModel)) ;
/********************************************************************/
static void	addLink    ARGS((CostModel, Link)) ;
static void	removeLink ARGS((CostModel, Link)) ;
/********************************************************************/
static void	mixtureReset	     ARGS((CostModel, double, double)) ;
static void	mixtureSync	     ARGS((CostModel)) ;
static Real	mixtureCost	     ARGS((CostModel, double)) ;
static Real	mixtureCostAndDerivs ARGS((CostModel, double)) ;
/********************************************************************/
static void	sumSquareSync	       ARGS((CostModel)) ;
static Real	sumSquareCost	       ARGS((CostModel, double)) ;
static Real	sumSquareCostAndDerivs ARGS((CostModel, double)) ;
/********************************************************************/


/*********************************************************************
 *	Name:		createCostModel/destroyCostModel
 *	Description:	creates (destroys) a costModel distribution object.
 *			The public creation entry points in this file are
 *			createMixtureCostModel and createSumSquareCostModel.
 *	Parameters:
 *	  String	name - the name given to the costModel
 *	  Net		net  - the network the costModel belongs to
 *	  int		numGaussians - the number of gaussians in the
 *			  costModel (mixture models only)
 *	  double	min - the minimum value for the range the
 *			  gaussians will be spread out on.
 *	  double	max - the maximum value for that range.
 *	Return Value:
 *	  CostModel	createCostModel  - the new costModel
 *	  void		destroyCostModel - NONE
 *********************************************************************/
CostModel	createMixtureCostModel(name, net, numGaussians, min, max)
  String	name ;
  Net		net ;
  int		numGaussians ;
  double	min ;
  double	max ;
{
  /*
   * Public constructor for a mixture cost model: forwards every
   * argument to the shared createCostModel() with the type fixed to
   * CM_MIXTURE and hands back whatever that call returns.
   */
  CostModel	model ;

  model = createCostModel(name, CM_MIXTURE, net, numGaussians, min, max) ;

  return model ;
}
/********************************************************************/
/*
 * Resets a cost model over the range [min, max] by calling the
 * model's own reset handler.
 *
 *   this - the cost model to reset
 *   min  - passed through to the reset handler
 *   max  - passed through to the reset handler
 *
 * Sum-square cost models are created with a NULL reset handler, so
 * the call is skipped (instead of jumping through a NULL function
 * pointer) when no handler is installed.  A NULL model is ignored.
 */
void		resetMixtureCostModel(this, min, max)
  CostModel	this ;
  double	min ;
  double	max ;
{
  if (this == NULL || this->reset == NULL)
    return ;

  this->reset(this, min, max) ;
}
/********************************************************************/
/*
 * Public constructor for a sum-of-squares cost model.
 *
 *   name - the name given to the cost model
 *   net  - the network the cost model belongs to
 *
 * Returns whatever createCostModel() returns for type CM_SUMSQUARE.
 * The mixture-only arguments are unused for this type; numGaussians
 * is passed as the integer 0 (not 0.0) so that the argument has the
 * declared type even when ARGS(()) expands to an empty parameter list
 * and no prototype conversion takes place.
 */
CostModel	createSumSquareCostModel(name, net)
  String	name ;
  Net		net ;
{
  return createCostModel(name, CM_SUMSQUARE, net, 0, 0.0, 0.0) ;
}
/********************************************************************/
/*
 * Allocates and initialises a cost model of the given type.
 *
 *   name         - model name; a private copy is made with strdup()
 *   type         - CM_MIXTURE or CM_SUMSQUARE
 *   net          - the network the model belongs to
 *   numGaussians - forwarded to createMixtureData() (mixture only)
 *   min, max     - forwarded to createMixtureData() (mixture only)
 *
 * Returns the new model, or NULL if the type is not recognised or any
 * allocation fails.  Nothing is leaked on the failure paths: the
 * record and the copied name are released before returning NULL.
 */
static CostModel createCostModel(name, type, net, numGaussians, min, max)
  String	name ;
  int		type ;
  Net		net ;
  int		numGaussians ;
  double	min ;
  double	max ;
{
  CostModel	this ;
  int		created ;

  /* reject unknown types before anything has been allocated */
  if (type != CM_MIXTURE && type != CM_SUMSQUARE)
    return NULL ;

  this = (CostModel)malloc(sizeof(CostModelRec)) ;
  if (this == NULL)
    return NULL ;

  this->name = strdup(name) ;
  if (this->name == NULL) {
    free((void *)this) ;
    return NULL ;
  }
  this->type	 = type ;
  this->net	 = net ;
  this->numLinks = 0 ;
  this->maxLinks = 0 ;
  this->link     = NULL ;

  this->destroy = destroy ;
  this->addLink = addLink ;
  this->removeLink = removeLink ;

  /* the type specific constructors also install the reset, sync and
   * evaluate handlers on the model */
  if (type == CM_MIXTURE) {
    this->costModelData.mixtureData
      = createMixtureData(this, numGaussians, min, max) ;
    created = (this->costModelData.mixtureData != NULL) ;
  } else {
    this->costModelData.sumSquareData = createSumSquareData(this) ;
    created = (this->costModelData.sumSquareData != NULL) ;
  }

  if (!created) {
    /* without its data record the model is unusable: give it back */
    free((void *)this->name) ;
    free((void *)this) ;
    return NULL ;
  }

  return this ;
}
/********************************************************************/
/*
 * Allocates the private data record for a sum-of-squares cost model
 * and installs the sum-square handlers on the parent model.
 *
 *   parent - the cost model that will own the record
 *
 * Returns the new record, or NULL if the allocation fails (in which
 * case the parent's handlers are left untouched).  The reset handler
 * is deliberately NULL: this model type has nothing to reset.
 */
static SumSquareData	createSumSquareData(parent)
  CostModel	parent ;
{
  SumSquareData	this ;

  this = (SumSquareData)malloc(sizeof(SumSquareDataRec)) ;
  if (this == NULL)
    return NULL ;

  parent->reset			= NULL ;
  parent->sync			= sumSquareSync ;
  parent->evaluateCost		= sumSquareCost ;
  parent->evaluateCostAndDerivs	= sumSquareCostAndDerivs ;
  
  return this ;
}
/********************************************************************/
/*
 * Allocates the private data record for a mixture cost model,
 * installs the mixture handlers on the parent model, creates the
 * mixture itself and creates three variables per gaussian (mean,
 * standard deviation and proportion) initialised from the mixture.
 *
 *   parent       - the cost model that will own the record
 *   numGaussians - requested number of gaussians; the count actually
 *                  used is the one reported by the created mixture
 *   min, max     - passed to createMixture()
 *
 * Returns the new record, or NULL if an allocation fails.  When one
 * of the variable arrays cannot be allocated, everything allocated
 * here is released before returning.
 */
static MixtureData	createMixtureData(parent, numGaussians, min, max)
  CostModel	parent ;
  int		numGaussians ;
  double	min ;
  double	max ;
{
  MixtureData	this ;
  int		idx ;
  Net		net = parent->net ;

  this = (MixtureData)malloc(sizeof(MixtureDataRec)) ;
  if (this == NULL)
    return NULL ;

  parent->reset			= mixtureReset ;
  parent->sync			= mixtureSync ;
  parent->evaluateCost		= mixtureCost ;
  parent->evaluateCostAndDerivs	= mixtureCostAndDerivs ;

  this->mixture	     = createMixture(numGaussians, min, max) ;
  this->numGaussians = MMnumGaussians(this->mixture) ;
  numGaussians       = this->numGaussians ;

  if (numGaussians <= 0) {
    this->mean		= NULL ;
    this->stdDeviation	= NULL ;
    this->proportion	= NULL ;
  } else {
    this->mean		= (Variable *)calloc(numGaussians, sizeof(Variable));
    this->stdDeviation	= (Variable *)calloc(numGaussians, sizeof(Variable));
    this->proportion	= (Variable *)calloc(numGaussians, sizeof(Variable));

    if (this->mean == NULL
	|| this->stdDeviation == NULL
	|| this->proportion == NULL) {
      /* the arrays are indexed right below: bail out cleanly rather
       * than writing through a NULL pointer */
      if (this->mean)
	free((void *)this->mean) ;
      if (this->stdDeviation)
	free((void *)this->stdDeviation) ;
      if (this->proportion)
	free((void *)this->proportion) ;
      if (this->mixture)
	MMdestroy(this->mixture) ;
      free((void *)this) ;
      return NULL ;
    }
  }
    
  /* create the variables that hold the mixture parameters */
  for (idx = 0 ; idx < numGaussians ; ++idx) {
    Gaussian	gaussian = MMgaussian(this->mixture, idx) ;
    /* The "%.200s" precision caps the model-name part at 200
     * characters, so the longest possible result (200 + the 14
     * characters of ".StdDeviation." + the digits of idx + NUL)
     * always fits in the buffer. */
    char	name[256] ;

    sprintf(name, "%.200s.Mean.%d", parent->name, idx) ;
    this->mean[idx] = createVariable(name, net, UNKNOWN) ;
    variableSetValue(this->mean[idx], MGmean(gaussian)) ;

    sprintf(name, "%.200s.StdDeviation.%d", parent->name, idx) ;
    this->stdDeviation[idx] = createVariable(name, net, LOG_TRANSFORM) ;
    variableSetValue(this->stdDeviation[idx], MGstdDeviation(gaussian)) ;

    sprintf(name, "%.200s.Proportion.%d", parent->name, idx) ;
    this->proportion[idx] = createVariable(name, net, LOG_TRANSFORM) ;
    variableSetValue(this->proportion[idx], MGproportion(gaussian)) ;
  }
  return this ;
}
/********************************************************************/
/*
 * Reset handler for mixture cost models.
 *
 *   this     - the cost model to reset
 *   min, max - passed to createMixture() for the replacement mixture
 *
 * Destroys the current mixture (if there is one), creates a new one
 * with the same number of gaussians, and copies the new mean,
 * standard deviation and proportion of every gaussian into the
 * corresponding variable.  Variables whose type has the FROZEN bit
 * set keep their current value.
 */
static void	mixtureReset(this, min, max)
  CostModel	this ;
  double	min ;
  double	max ;
{
  MixtureData	data = this->costModelData.mixtureData ;
  int		idx ;

  if (data->mixture)
    MMdestroy(data->mixture) ;

  data->mixture = createMixture(data->numGaussians, min, max) ;

  /* reset the variables */
  for (idx = 0 ; idx < data->numGaussians ; ++idx) {
    Gaussian	gaussian = MMgaussian(data->mixture, idx) ;
    Variable	variable ;

    variable = data->mean[idx] ;
    if (!(variableGetType(variable) & FROZEN))
      variableSetValue(variable, MGmean(gaussian)) ;

    variable = data->stdDeviation[idx] ;
    if (!(variableGetType(variable) & FROZEN))
      variableSetValue(variable, MGstdDeviation(gaussian)) ;

    variable = data->proportion[idx] ;
    if (!(variableGetType(variable) & FROZEN))
      variableSetValue(variable, MGproportion(gaussian)) ;
  }
}
/********************************************************************/
/*
 * Destroy handler shared by all cost model types.
 *
 *   this - the cost model to destroy (NULL is ignored)
 *
 * Detaches every link (last to first), then releases the link array,
 * the type specific data record with everything it owns, the copy of
 * the name made by createCostModel(), and finally the model record
 * itself.  The model must not be used after this call.
 */
static void	destroy(this)
  CostModel	this ;
{
  int		idx ;

  if (this == NULL)
    return ;

  for (idx = this->numLinks - 1 ; idx >= 0 ; --idx)
    MCMremoveLink(this, this->link[idx]) ;

  if (this->link)
    free((void *)this->link) ;

  if (this->type == CM_MIXTURE) {
    MixtureData	mixtureData = this->costModelData.mixtureData ;

    /* the data record may be missing if its creation failed */
    if (mixtureData != NULL) {
      if (mixtureData->mixture)
	MMdestroy(mixtureData->mixture) ;

      for (idx = 0 ; idx < mixtureData->numGaussians ; ++idx) {
	destroyVariable(mixtureData->mean[idx],		this->net) ;
	destroyVariable(mixtureData->stdDeviation[idx],	this->net) ;
	destroyVariable(mixtureData->proportion[idx],	this->net) ;
      }

      if (mixtureData->mean)
	free((void *)mixtureData->mean) ;
      if (mixtureData->stdDeviation)
	free((void *)mixtureData->stdDeviation) ;
      if (mixtureData->proportion)
	free((void *)mixtureData->proportion) ;
    
      free((void *)mixtureData) ;
    }
  } else if (this->type == CM_SUMSQUARE) {
    if (this->costModelData.sumSquareData)
      free((void *)this->costModelData.sumSquareData) ;
  }

  /* the name was duplicated when the model was created */
  if (this->name)
    free((void *)this->name) ;

  free((void *)this) ;
}
/********************************************************************/

/********************************************************************/
#ifndef GRANULARITY
#define GRANULARITY 64
#endif
/********************************************************************/
/*
 * addLink handler: appends a link to the model's link array, growing
 * the array when it is full.
 *
 *   this - the cost model
 *   link - the link to append
 *
 * The first allocation holds GRANULARITY entries; after that the
 * capacity doubles.  The new block is only stored in the model once
 * the allocation has succeeded, so a failed realloc() neither leaks
 * the old array nor leaves a NULL pointer behind.  If memory cannot
 * be obtained the model is left unchanged and the link is NOT added
 * (the handler has no way to report the failure).
 */
static void	addLink(this, link)
  CostModel	this ;
  Link		link ;
{
  if (this->numLinks >= this->maxLinks) {
    int		newMax ;
    Link	*newLinks ;

    if (this->maxLinks == 0) {
      /* set to some default initial size */
      newMax   = this->numLinks + GRANULARITY ;
      newLinks = (Link *)malloc(newMax*sizeof(Link)) ;
    } else {
      /* double the size of the array */
      newMax   = 2 * this->numLinks ;
      newLinks = (Link *)realloc(this->link, newMax*sizeof(Link)) ;
    }

    if (newLinks == NULL)
      return ;

    this->link     = newLinks ;
    this->maxLinks = newMax ;
  }

  this->link[this->numLinks] = link ;
  ++(this->numLinks) ;
}
/**********************************************************************/  
/*
 * removeLink handler: removes the first occurrence of a link from
 * the model's link array, closing the gap by shifting the following
 * entries down one slot.
 *
 *   this - the cost model
 *   link - the link to remove
 *
 * Does nothing if the model has no links or the link is not present.
 * Only the entries link[0 .. numLinks-1] are ever read, so the array
 * is never accessed past its last valid element, even when it is
 * completely full.
 */
static void	removeLink(this, link)
  CostModel	this ;
  Link		link ;
{
  int	last ;
  int	idx ;

  if (this->link == NULL || this->numLinks <= 0)
    return ;

  /* locate the link */
  for (idx = 0 ; idx < this->numLinks ; ++idx)
    if (this->link[idx] == link)
      break ;

  if (idx >= this->numLinks)
    return ;			/* not in this model */

  /* shift the tail down; stop one short so link[idx + 1] stays valid */
  last = this->numLinks - 1 ;
  for ( ; idx < last ; ++idx)
    this->link[idx] = this->link[idx + 1] ;

  this->numLinks = last ;
}
/**********************************************************************/  


/**********************************************************************/  
/*
 * Sync handler for mixture cost models: copies the current variable
 * values into the gaussians of the mixture (the variance is set to
 * the square of the standard deviation variable) and then rescales
 * the proportions so that they sum to one.
 */
static void     mixtureSync(this)
  CostModel     this ;
{
  MixtureData	data    = this->costModelData.mixtureData ;
  Mixture	mixture = data->mixture ;
  Real		total   = 0.0 ;
  int		k ;
  
  /* push every variable value into its gaussian */
  k = 0 ;
  while (k < data->numGaussians) {
    Gaussian	g = MMgaussian(mixture, k) ;

    MGsetMean       (g, variableGetValue(data->mean[k])) ;
    MGsetVariance   (g, M_SQR(variableGetValue(data->stdDeviation[k]))) ;
    MGsetProportion (g, variableGetValue(data->proportion[k])) ;
    ++k ;
  }

  /* add up the proportions just stored */
  k = 0 ;
  while (k < data->numGaussians) {
    total += MGproportion(MMgaussian(mixture, k)) ;
    ++k ;
  }

  /* already normalized: nothing left to do */
  if (total == 1.0)
    return ;

  /* divide every proportion by the total */
  k = 0 ;
  while (k < data->numGaussians) {
    Gaussian	g = MMgaussian(mixture, k) ;

    MGsetProportion(g, MGproportion(g)/total) ;
    ++k ;
  }
}
/**********************************************************************/  
/*
 * evaluateCost handler for mixture cost models.
 *
 *   this  - the cost model
 *   scale - factor applied to the summed complexity
 *
 * Synchronizes the model, adds up MMcomplexity() of the weight of
 * every link attached to the model, and returns scale times that sum.
 */
static Real     mixtureCost(this, scale)
  CostModel     this ;
  double	scale ;
{
  MixtureData	data    = this->costModelData.mixtureData ;
  Mixture	mixture = data->mixture ;
  Real		total ;
  int		n ;

  MCMsync(this) ;

  /* accumulate the complexity of each linked weight */
  total = 0.0 ;
  n     = 0 ;
  while (n < this->numLinks) {
    total += MMcomplexity(mixture, this->link[n]->weight) ;
    ++n ;
  }

  return scale*total ;
}
/**********************************************************************/
/*
 * evaluateCostAndDerivs handler for mixture cost models.
 *
 *   this  - the cost model
 *   scale - factor applied to the cost and to every derivative
 *
 * Returns the value of MCMevaluateCost(this, scale).  As side effects
 * it adds to the deriv field of every link attached to the model and
 * rebuilds the derivatives of the mean, standard deviation and
 * proportion variables (they are zeroed first, then accumulated).
 *
 * The three scratch arrays are static: they are shared by all calls
 * and only reallocated when the number of gaussians in the mixture
 * differs from the size remembered from the previous call.
 *
 * NOTE(review): the calloc() results are not checked, and the local
 * `gaussian` in the inner loop is never used.
 */
static Real	mixtureCostAndDerivs(this, scale)
  CostModel	this ;
  double	scale ;
{
  /* scratch buffers and their size, kept between calls */
  static Real	*dCdMean, *dCdProportion, *dCdStdDeviation ;
  static int	numGaussians ;

  MixtureData	data       = this->costModelData.mixtureData ;
  Mixture	mixture	   = data->mixture ;
  int		numLinks   = this->numLinks ;
  Real		sum, cost ;
  int		idx ;

  /* set up the arrays */
  /* (re)allocate only when the gaussian count has changed */
  if (numGaussians != MMnumGaussians(mixture)) {
    numGaussians = MMnumGaussians(mixture) ;
    if (dCdMean)
      free((void *)dCdMean) ;
    dCdMean = (Real *)calloc(numGaussians, sizeof(Real)) ;
    if (dCdStdDeviation)
      free((void *)dCdStdDeviation) ;
    dCdStdDeviation = (Real *)calloc(numGaussians, sizeof(Real)) ;
    if (dCdProportion)
      free((void *)dCdProportion) ;
    dCdProportion = (Real *)calloc(numGaussians, sizeof(Real)) ;
  }
  /* start every variable derivative from zero */
  for (idx = 0 ; idx < numGaussians ; ++idx) {
    variableSetDeriv(data->mean[idx],		0.0) ;
    variableSetDeriv(data->stdDeviation[idx],	0.0) ;
    variableSetDeriv(data->proportion[idx],	0.0) ;
  }

  /* evaluate cost first to be sure all variables are synced */
  cost = MCMevaluateCost(this, scale) ;

  for (idx = 0 ; idx < numLinks ; ++idx) {
    int	subIdx ;
    /* the value returned by MMderivatives goes to the link; the three
     * arrays are presumably filled in by the call with the per-gaussian
     * derivatives for this weight -- TODO confirm in the mixture code */
    this->link[idx]->deriv
      += scale*MMderivatives(mixture, this->link[idx]->weight, 
			     dCdMean, dCdStdDeviation, dCdProportion) ;

    /* accumulate the scaled array entries on the variables */
    for (subIdx = 0 ; subIdx < numGaussians ; ++subIdx) {
      Gaussian	gaussian = MMgaussian(mixture, subIdx) ;

      variableAddToDeriv(data->mean[subIdx], 
			 scale * dCdMean[subIdx]) ;
      variableAddToDeriv(data->stdDeviation[subIdx],
			 scale * dCdStdDeviation[subIdx]) ;
      variableAddToDeriv(data->proportion[subIdx],
			 scale * dCdProportion[subIdx]) ;
    }
  }

  /* take care of normalized proportions */
  /* sum the proportion derivatives, then subtract that sum times the
   * gaussian's proportion from each proportion derivative */
  sum = 0.0 ;
  for (idx = 0 ; idx < numGaussians ; ++idx)
    sum += variableGetDeriv(data->proportion[idx]) ;

  for (idx = 0 ; idx < numGaussians ; ++idx)
    variableAddToDeriv(data->proportion[idx],
		       -sum*MGproportion(MMgaussian(mixture, idx))) ;

  return cost ;
}
/**********************************************************************/  


/**********************************************************************/  
/*
 * Sync handler for sum-of-squares cost models.  This model type keeps
 * no state that needs synchronizing, so the handler does nothing.
 */
static void     sumSquareSync(this)
  CostModel     this ;
{
  /* intentionally empty */
}
/**********************************************************************/  
/*
 * evaluateCost handler for sum-of-squares cost models.
 *
 *   this  - the cost model
 *   scale - factor applied to the summed squares
 *
 * Synchronizes the model and returns scale times the sum of the
 * squared weights of all links attached to the model.
 */
static Real     sumSquareCost(this, scale)
  CostModel     this ;
  double	scale ;
{
  Real	total = 0.0 ;
  int	n     = 0 ;

  MCMsync(this) ;

  while (n < this->numLinks) {
    total += M_SQR(this->link[n]->weight) ;
    ++n ;
  }

  return scale*total ;
}
/**********************************************************************/
/*
 * evaluateCostAndDerivs handler for sum-of-squares cost models.
 *
 *   this  - the cost model
 *   scale - factor applied to the cost and to the derivatives
 *
 * Synchronizes the model, then for every attached link adds the
 * squared weight to the running total and adds 2*scale*weight to the
 * link's deriv field.  Returns scale times the total.
 */
static Real     sumSquareCostAndDerivs(this, scale)
  CostModel     this ;
  double	scale ;
{
  Real	total = 0.0 ;
  int	n ;

  MCMsync(this) ;

  for (n = 0 ; n < this->numLinks ; ++n) {
    total		 += M_SQR(this->link[n]->weight) ;
    this->link[n]->deriv += 2.0*scale*this->link[n]->weight ;
  }

  return scale*total ;
}
/**********************************************************************/
