/* funs - callbacks for Bayes routines in XLISP-STAT and S             */
/* XLISP-STAT 2.1 Copyright (c) 1990, by Luke Tierney                  */
/* Additions to Xlisp 2.1, Copyright (c) 1989 by David Michael Betz    */
/* You may give out copies of this software; for conditions see the    */
/* file COPYING included with this distribution.                       */

#include "xlisp.h"
#include "xlstat.h"
#define PRINTSTR(s) stdputstr(s)

extern char *S_alloc();
char *minresultstring();

/************************************************************************/
/**                                                                    **/
/**                      Definitions and Globals                       **/
/**                                                                    **/
/************************************************************************/

/* sqrt(2 * pi) and 1 / pi */
#define ROOT2PI 2.5066282746310005024157652848110452530070
#define PI_INV  0.3183098861837906715377675267450287240689

/*
 * Exponents applied to the machine epsilon to obtain default tolerances
 * and step sizes.  The replacement lists are parenthesized so that the
 * macros behave as single values inside larger expressions
 * (e.g. 1.0 / H_POWER or -H_POWER).
 */
#define GRADTOL_POWER (1.0 / 3.0)
#define H_POWER (1.0 / 6.0)

/*
 * Fundata - everything the evaluation callbacks need to know about one
 * family of S functions.  Three instances exist: func (the log posterior),
 * gfuncs (the tilt functions) and cfuncs (the constraint functions); each
 * instance only uses the members relevant to it.
 */
typedef struct {
  char *f;          /* S function handed to call_S by evalfunc             */
  char **sf;        /* optional support test used by in_support, or nil    */
  char **g;         /* array of S functions (tilt or constraint functions) */
  int n;            /* length of the argument vector x                     */
  int k;            /* number of functions in g                            */
  int change_sign;  /* nonzero: evalfunc negates value and derivatives;    */
                    /* for gfuncs, add_tilt negates the tilt               */
  int fderivs;      /* derivative orders last returned by f: 0, 1 or 2     */
  int *gderivs;     /* like fderivs, one entry per function in g           */
  double typf;      /* typical function size; install_func defaults to 1.0 */
  double h;         /* step size passed to numergrad / numerhess           */
  double dflt;      /* value evalfunc reports when x is outside support    */
  RVector typx;     /* per-coordinate scale passed to numergrad/numerhess  */
  RVector fsum;     /* work vector shared by numergrad and numerhess       */
  RVector cvals;    /* not referenced by the code in this file             */
  RVector ctarget;  /* constraint targets subtracted in evalcfunc, or nil  */
  RMatrix gfsum;    /* one fsum-style row per tilt function                */
} Fundata;

static Fundata func;
static Fundata gfuncs;
static Fundata cfuncs;

/* forward declarations */
/* add_tilt is defined after minfunc, which calls it; the _(( )) wrapper */
/* is the project's prototype macro.                                     */
LOCAL VOID add_tilt _((RVector x, double *pval, RVector grad, RMatrix hess,
		       double tilt, int exptilt));


/************************************************************************/
/**                                                                    **/
/**                         Memory Utilities                           **/
/**                                                                    **/
/************************************************************************/

/*
 * makespace - maintain a long-lived buffer of at least `size' bytes.
 *
 *   pptr   address of the pointer that owns the buffer; *pptr must be NULL
 *          before the first call
 *   size   requested size in bytes; requests <= 0 are ignored and leave
 *          *pptr untouched
 *
 * The first request is served by calloc (so the buffer starts zeroed);
 * later requests go through realloc, which keeps the old contents but does
 * not zero any newly added bytes.  This allows functions using memory
 * allocation to be called repeatedly (but not recursively) from within the
 * same call from S without leaving dangling allocations behind.
 *
 * On allocation failure Recover() is called.  The previous buffer is left
 * in place in that case: realloc does not free the original block when it
 * fails, so overwriting *pptr with its NULL result would leak it.
 */
LOCAL VOID makespace(pptr, size)
     char **pptr;
     int size;
{
  char *fresh;

  if (size <= 0) return;

  if (*pptr == NULL) fresh = calloc(size, 1);
  else fresh = realloc(*pptr, size);

  if (fresh == NULL) {
    Recover("memory allocation failed", NULL);
    return; /* only reached if Recover returns; keep the old buffer */
  }
  *pptr = fresh;
}

/************************************************************************/
/**                                                                    **/
/**                    Functions Evaluation Routines                   **/
/**                                                                    **/
/************************************************************************/

/*
 * All Hessian evaluations by numerical derivatives assume that the gradient
 * has been evaluated first at the same location. The results are cached.
 */

/*
 * install_func - record the log posterior function used by evalfunc.
 *
 *   f            S function to evaluate
 *   sf           optional support-test function(s) for in_support, or nil
 *   n            length of the argument vector
 *   change_sign  nonzero if evalfunc should negate values and derivatives
 *   typf         typical function size; nonpositive values become 1.0
 *   h            numerical derivative step; nonpositive values select
 *                pow(macheps(), H_POWER)
 *   typx         per-coordinate scales or nil; missing or nonpositive
 *                entries become 1.0
 *   dflt         value evalfunc reports outside the support
 *
 * The typx and fsum work vectors live in the static `func' record and are
 * resized with makespace on every call.  The derivative level is reset to 0.
 */
LOCAL VOID install_func(f, sf, n, change_sign, typf, h, typx, dflt)
     char *f, **sf;
     int n, change_sign;
     double typf, h, dflt;
     RVector typx;
{
  static int first_call = TRUE;
  int j;

  /* makespace needs the buffer pointers to start out as nil */
  if (first_call) {
    func.typx = nil;
    func.fsum = nil;
    first_call = FALSE;
  }
  makespace((char **) &func.typx, n * sizeof(double));
  makespace((char **) &func.fsum, n * sizeof(double));

  func.f = f;
  func.sf = sf;
  func.n = n;
  func.change_sign = change_sign;

  if (typf > 0.0) func.typf = typf;
  else func.typf = 1.0;

  if (h > 0.0) func.h = h;
  else func.h = pow(macheps(), H_POWER);

  for (j = 0; j < n; j++) {
    if (typx != nil && typx[j] > 0.0) func.typx[j] = typx[j];
    else func.typx[j] = 1.0;
  }

  func.dflt = dflt;
  func.fderivs = 0;
}

/*
 * install_gfuncs - record the tilt functions used by evalgfunc and add_tilt.
 *
 *   g            array of k S functions, selected through which_gfunc
 *   n            length of the argument vector
 *   k            number of tilt functions
 *   change_sign  nonzero if add_tilt should negate the tilt
 *   h            numerical derivative step; nonpositive values select
 *                pow(macheps(), H_POWER)
 *   typx         per-coordinate scales or nil; missing or nonpositive
 *                entries become 1.0
 *
 * All buffers are kept between calls and resized with makespace.  gfsum is
 * a table of k row pointers into one contiguous k * n block.
 *
 * The per-function derivative levels are reset to 0 here, mirroring the
 * reset of func.fderivs in install_func.  Without the reset, entries added
 * when the array grows hold whatever realloc left there, and evalgfunc
 * reads them before it ever writes them.
 */
LOCAL VOID install_gfuncs(g, n, k, change_sign, h, typx)
     char **g;
     int n, k, change_sign;
     double h;
     RVector typx;
{
  int i;
  static int inited = FALSE;
  static double *gfsumdata = nil;

  /* makespace needs the buffer pointers to start out as nil */
  if (! inited) {
    gfuncs.typx = nil;
    gfuncs.gfsum = nil;
    gfuncs.gderivs = nil;
    inited = TRUE;
  }
  makespace((char **) &gfuncs.typx, n * sizeof(double));
  makespace((char **) &gfuncs.gfsum, k * sizeof(double *));
  makespace((char **) &gfsumdata, k * n * sizeof(double));
  makespace((char **) &gfuncs.gderivs, k * sizeof(int));

  gfuncs.g = g;
  gfuncs.n = n;
  gfuncs.k = k;
  gfuncs.change_sign = change_sign;
  gfuncs.h = (h > 0.0) ? h : pow(macheps(), H_POWER);
  for (i = 0; i < n; i++)
    gfuncs.typx[i] = (typx != nil && typx[i] > 0.0) ? typx[i] : 1.0;
  for (i = 0; i < k; i++) {
    gfuncs.gfsum[i] = gfsumdata + i * n;
    gfuncs.gderivs[i] = 0;  /* nothing known yet about derivatives of g[i] */
  }
}

/*
 * install_cfuncs - record the constraint functions used by evalcfunc.
 *
 *   g        array of k S functions, selected through which_cfunc
 *   n        length of the argument vector
 *   k        number of constraint functions
 *   ctarget  target values subtracted from the constraint values by
 *            evalcfunc, or nil; the pointer is stored, not copied
 *   h        numerical derivative step; nonpositive values select
 *            pow(macheps(), H_POWER)
 *   typx     per-coordinate scales or nil; missing or nonpositive
 *            entries become 1.0
 *
 * Buffers are kept between calls and resized with makespace.  Note that a
 * single fsum work vector is shared by all constraint functions.
 *
 * The per-function derivative levels are reset to 0 here, mirroring the
 * reset of func.fderivs in install_func.  Without the reset, entries added
 * when the array grows hold whatever realloc left there, and evalcfunc
 * reads them before it ever writes them.
 */
LOCAL VOID install_cfuncs(g, n, k, ctarget, h, typx)
     char **g;
     int n, k;
     double h;
     RVector typx, ctarget;
{
  int i;
  static int inited = FALSE;

  /* makespace needs the buffer pointers to start out as nil */
  if (! inited) {
    cfuncs.typx = nil;
    cfuncs.fsum = nil;
    cfuncs.gderivs = nil;
    inited = TRUE;
  }
  makespace((char **) &cfuncs.typx, n * sizeof(double));
  makespace((char **) &cfuncs.fsum, n * sizeof(double));
  makespace((char **) &cfuncs.gderivs, k * sizeof(int));

  cfuncs.g = g;
  cfuncs.n = n;
  cfuncs.k = k;
  cfuncs.h = (h > 0.0) ? h : pow(macheps(), H_POWER);
  for (i = 0; i < n; i++)
    cfuncs.typx[i] = (typx != nil && typx[i] > 0.0) ? typx[i] : 1.0;
  for (i = 0; i < k; i++)
    cfuncs.gderivs[i] = 0;  /* nothing known yet about derivatives of g[i] */
  cfuncs.ctarget = ctarget;
}

/*
 * in_support - ask S whether x lies in the support of the posterior.
 *
 *   ff  support-test function(s); only ff[0] is used.  When ff or ff[0]
 *       is nil no test is installed and every x is accepted.
 *   n   length of x
 *   x   point to test, passed to S as a "double" vector of length n
 *
 * Returns TRUE when no test is installed; otherwise returns the first
 * element of the S result, read as an int.
 */
LOCAL int in_support(ff, n, x)
     char **ff;
     int n;
     double *x;
{
  char *argv[1], *argmode[1], *out[1];
  long arglen[1];
  int *flag;

  if (ff == nil || ff[0] == nil) return(TRUE);

  argmode[0] = "double";
  arglen[0] = n;
  argv[0] = (char *) x;
  call_S(ff[0], 1L, (ALLOCTYPE **) argv, (ALLOCTYPE **) argmode, arglen, 0L, 1L,
         (ALLOCTYPE **) out);

  flag = (int *) out[0];
  return(flag[0]);
}

/*
 * evalfunc - evaluate the installed log posterior and its derivatives.
 *
 *   x     point of evaluation (func.n doubles)
 *   pval  receives the function value, or nil if not wanted
 *   grad  receives the gradient, or nil if not wanted
 *   hess  receives the Hessian as hess[i][j], or nil if not wanted
 *
 * Returns TRUE if x passes the support test.  Otherwise FALSE is returned
 * and, when pval is not nil, *pval is set to func.dflt; grad and hess are
 * left untouched.
 *
 * The S function is asked for up to three results: value, gradient and
 * Hessian.  Which of them came back is remembered in func.fderivs.  The S
 * function is only called when the value or the Hessian is requested, or
 * when the previous call showed that analytic derivatives exist; a
 * gradient-only request with func.fderivs == 0 goes straight to numergrad.
 * Results are sign-flipped when func.change_sign is set.
 */
LOCAL int evalfunc(x, pval, grad, hess)
     RVector x, grad;
     double *pval;
     RMatrix hess;
{
  char *argv[1], *argmode[1], *out[3];
  long arglen[1];
  double *res, fx;
  int r, c;

  out[0] = nil;
  out[1] = nil;
  out[2] = nil;

  /* outside the support: report the default value and stop */
  if (! in_support(func.sf, func.n, x)) {
    if (pval != nil) *pval = func.dflt;
    return(FALSE);
  }

  if (pval != nil || func.fderivs > 0 || hess != nil) {
    argmode[0] = "double";
    arglen[0] = func.n;
    argv[0] = (char *) x;
    call_S(func.f, 1L, (ALLOCTYPE **) argv, (ALLOCTYPE **) argmode, arglen, 0L, 3L,
           (ALLOCTYPE **) out);
    res = (double *) out[0];
    if (func.change_sign) fx = -res[0];
    else fx = res[0];
    if (pval != nil) *pval = fx;

    /* remember how many derivative orders the S function supplied */
    func.fderivs = (out[2] != nil) ? 2 : ((out[1] != nil) ? 1 : 0);
  }

  if (grad != nil) {
    if (func.fderivs > 0) {
      /* analytic gradient */
      res = (double *) out[1];
      for (r = 0; r < func.n; r++) {
        if (func.change_sign) grad[r] = -res[r];
        else grad[r] = res[r];
      }
    }
    else {
      numergrad(func.n, x, grad, func.fsum, evalfunc, func.h, func.typx);
    }
  }

  if (hess != nil) {
    if (func.fderivs == 2) {
      /* analytic Hessian; element (r, c) is read from res[r + c * n] */
      res = (double *) out[2];
      for (r = 0; r < func.n; r++) {
        for (c = 0; c < func.n; c++) {
          if (func.change_sign) hess[r][c] = -res[r + c * func.n];
          else hess[r][c] = res[r + c * func.n];
        }
      }
    }
    else {
      /* kludge to get fsum for analytic gradients: numergrad is run */
      /* with fsum as both its gradient and its fsum argument        */
      if (func.fderivs == 1)
        numergrad(func.n, x, func.fsum, func.fsum,
                  evalfunc, func.h, func.typx);
      numerhess(func.n, x, hess, fx, func.fsum, evalfunc, func.h, func.typx);
    }
  }

  return(TRUE);
}


/*
 * Tilt function evaluation.  which_gfunc selects the entry of gfuncs.g
 * (and of gfuncs.gderivs / gfuncs.gfsum) that evalgfunc works on; callers
 * set it before each call.
 */
static int which_gfunc;

/*
 * evalgfunc - evaluate tilt function number which_gfunc.
 *
 *   x     point of evaluation (gfuncs.n doubles)
 *   pval  receives the function value, or nil if not wanted
 *   grad  receives the gradient, or nil if not wanted
 *   hess  receives the Hessian as hess[i][j], or nil if not wanted
 *
 * Always returns FALSE.  The S function is asked for up to three results
 * (value, gradient, Hessian); which of them came back is remembered per
 * function in gfuncs.gderivs.  Missing derivatives are replaced by
 * numergrad / numerhess.  No sign change is applied here.
 */
LOCAL int evalgfunc(x, pval, grad, hess)
     RVector x, grad;
     double *pval;
     RMatrix hess;
{
  char *argv[1], *argmode[1], *out[3];
  long arglen[1];
  double *res, gx;
  int r, c;

  out[0] = nil;
  out[1] = nil;
  out[2] = nil;

  if (pval != nil || gfuncs.gderivs[which_gfunc] > 0 || hess != nil) {
    argmode[0] = "double";
    arglen[0] = gfuncs.n;
    argv[0] = (char *) x;
    call_S(gfuncs.g[which_gfunc], 1L, (ALLOCTYPE **) argv, (ALLOCTYPE **) argmode,
           arglen, 0L, 3L, (ALLOCTYPE **) out);
    res = (double *) out[0];
    gx = res[0];
    if (pval != nil) *pval = res[0];

    /* remember how many derivative orders this function supplied */
    gfuncs.gderivs[which_gfunc] = (out[2] != nil) ? 2 : ((out[1] != nil) ? 1 : 0);
  }

  if (grad != nil) {
    if (gfuncs.gderivs[which_gfunc] > 0) {
      /* analytic gradient */
      res = (double *) out[1];
      for (r = 0; r < gfuncs.n; r++) grad[r] = res[r];
    }
    else {
      numergrad(gfuncs.n, x, grad, gfuncs.gfsum[which_gfunc], evalgfunc,
                gfuncs.h, gfuncs.typx);
    }
  }

  if (hess != nil) {
    if (gfuncs.gderivs[which_gfunc] == 2) {
      /* analytic Hessian; element (r, c) is read from res[r + c * n] */
      res = (double *) out[2];
      for (r = 0; r < gfuncs.n; r++) {
        for (c = 0; c < gfuncs.n; c++) {
          hess[r][c] = res[r + c * gfuncs.n];
        }
      }
    }
    else {
      /* kludge to get fsum if analytic gradient used */
      if (gfuncs.gderivs[which_gfunc] == 1) {
        numergrad(gfuncs.n, x, gfuncs.gfsum[which_gfunc],
                  gfuncs.gfsum[which_gfunc], evalgfunc, gfuncs.h, gfuncs.typx);
      }
      numerhess(gfuncs.n, x, hess, gx, gfuncs.gfsum[which_gfunc], evalgfunc,
                gfuncs.h, gfuncs.typx);
    }
  }

  return(FALSE);
}

/*
 * Constraint function evaluation.  which_cfunc selects the entry of
 * cfuncs.g (and of cfuncs.gderivs / cfuncs.ctarget) that evalcfunc works
 * on; callers set it before each call.
 */
static int which_cfunc;

/*
 * evalcfunc - evaluate constraint function number which_cfunc.
 *
 *   x     point of evaluation (cfuncs.n doubles)
 *   pval  receives the function value minus its target (when
 *         cfuncs.ctarget is not nil), or nil if not wanted
 *   grad  receives the gradient, or nil if not wanted
 *   hess  receives the Hessian as hess[i][j], or nil if not wanted
 *
 * Always returns FALSE.  The S function is asked for up to three results
 * (value, gradient, Hessian); which of them came back is remembered per
 * function in cfuncs.gderivs.  Missing derivatives are replaced by
 * numergrad / numerhess.
 *
 * The result slots are cleared before the call, exactly as evalfunc and
 * evalgfunc do: the nil tests below are what detects missing derivatives,
 * so the slots must not start out holding stack garbage.
 */
LOCAL int evalcfunc(x, pval, grad, hess)
     RVector x, grad;
     double *pval;
     RMatrix hess;
{
  char *args[1], *values[3];
  double *result, val;
  char *mode[1];
  long length[1];
  int i, j;

  for (i = 0; i < 3; i++) values[i] = nil;

  if (pval != nil || cfuncs.gderivs[which_cfunc] > 0 || hess != nil) {
    mode[0] = "double";
    length[0] = cfuncs.n;
    args[0] = (char *) x;
    call_S(cfuncs.g[which_cfunc], 1L, (ALLOCTYPE **) args, (ALLOCTYPE **) mode,
           length, 0L, 3L, (ALLOCTYPE **) values);
    result = (double *) values[0];
    val = result[0];  /* raw value (target not subtracted) for numerhess */
    if (pval != nil) {
      *pval = result[0];
      if (cfuncs.ctarget != nil) *pval -= cfuncs.ctarget[which_cfunc];
    }

    /* remember how many derivative orders this function supplied */
    if (values[2] != nil) cfuncs.gderivs[which_cfunc] = 2;
    else if (values[1] != nil) cfuncs.gderivs[which_cfunc] = 1;
    else cfuncs.gderivs[which_cfunc] = 0;
  }
  if (grad != nil) {
    if (cfuncs.gderivs[which_cfunc] > 0) {
      /* analytic gradient */
      result = (double *) values[1];
      for (i = 0; i < cfuncs.n; i++) grad[i] = result[i];
    }
    else {
      numergrad(cfuncs.n, x, grad, cfuncs.fsum, evalcfunc,
                cfuncs.h, cfuncs.typx);
    }
  }
  if (hess != nil) {
    if (cfuncs.gderivs[which_cfunc] == 2) {
      /* analytic Hessian; element (i, j) is read from result[i + j * n] */
      result = (double *) values[2];
      for (i = 0; i < cfuncs.n; i++)
        for (j = 0; j < cfuncs.n; j++)
          hess[i][j] = result[i + j * cfuncs.n];
    }
    else {
      /* kludge to get fsum if analytic gradient used */
      if (cfuncs.gderivs[which_cfunc] == 1)
        numergrad(cfuncs.n, x, cfuncs.fsum, cfuncs.fsum, evalcfunc,
                  cfuncs.h, cfuncs.typx);
      numerhess(cfuncs.n, x, hess, val, cfuncs.fsum, evalcfunc,
                cfuncs.h, cfuncs.typx);
    }
  }
  return(FALSE);
}

/*
 * evalfront - S front end for log posterior evaluation.
 *
 *   ff     ff[0] is the S function to evaluate
 *   n      *n is the length of x
 *   x      point of evaluation
 *   val    receives the function value (may be nil)
 *   grad   receives the gradient (may be nil)
 *   phess  receives the Hessian in one contiguous block of *n rows of
 *          *n doubles each (may be nil)
 *   h      *h is the numerical derivative step (see install_func)
 *   typx   per-coordinate scales (see install_func)
 *
 * The function is installed without a support test, without sign change,
 * with typf 1.0 and default value 0.0, and then evaluated once.
 *
 * The table of row pointers into phess is cached in a static buffer.  The
 * pointer actually handed to evalfunc is kept in a separate local so that
 * a call with phess == nil does not overwrite (and thereby leak) the
 * cached table.
 */
VOID evalfront(ff, n, x, val, grad, phess, h, typx)
     char **ff;
     int *n;
     double *x, *val, *grad, *phess, *typx, *h;
{
  static RMatrix hessrows = nil;
  RMatrix hess = nil;
  int i;

  install_func(ff[0], nil, *n, FALSE, 1.0, *h, typx, 0.0);
  if (phess != nil) {
    makespace((char **) &hessrows, *n * sizeof(double *));
    for (i = 0; i < *n; i++, phess += *n) hessrows[i] = phess;
    hess = hessrows;
  }
  evalfunc(x, val, grad, hess);
}

/************************************************************************/
/**                                                                    **/
/**                       Maximization Routines                        **/
/**                                                                    **/
/************************************************************************/

/*
 * Tilt state shared by set_tilt_info, minfunc and add_tilt.
 */
struct {
  double tilt;     /* tilt applied by minfunc through add_tilt           */
  RVector gval;    /* cached tilt function values, one per function      */
  RMatrix ggrad;   /* cached tilt function gradients, one row per        */
                   /* function                                           */
  RMatrix ghess;   /* scratch Hessian, reused for every tilt function    */
  int exptilt;     /* nonzero: add tilt * g; zero: add tilt * log(g)     */
  RVector tscale;  /* optional per-function divisors for the tilt in the */
                   /* exptilt case, or nil                               */
} tiltinfo;

/*
 * set_tilt_info - size the tilt work areas and record the tilt settings.
 *
 *   n        length of the argument vector
 *   m        number of tilt functions
 *   tilt     tilt value stored for minfunc
 *   exptilt  nonzero for the tilt * g form, zero for tilt * log(g)
 *   tscale   optional per-function tilt divisors, or nil; the pointer is
 *            stored, not copied
 *
 * ggrad becomes a table of m row pointers into an m * n block, ghess a
 * table of n row pointers into an n * n block.
 *
 * gval holds one value per tilt function (add_tilt indexes it with the
 * function number, 0 .. m-1).  It used to be sized by n alone, which is
 * too small whenever m > n; it is now sized by the larger of the two so
 * that it is never smaller than before.
 */
LOCAL VOID set_tilt_info(n, m, tilt, exptilt, tscale)
     int n, m;
     double tilt, *tscale;
     int exptilt;
{
  static double *hessdata = nil, *graddata = nil;
  static int inited = FALSE;
  int i, nvals;

  /* makespace needs the buffer pointers to start out as nil */
  if (! inited) {
    tiltinfo.gval = nil;
    tiltinfo.ggrad = nil;
    tiltinfo.ghess = nil;
    inited = TRUE;
  }

  nvals = (m > n) ? m : n;
  makespace((char **) &tiltinfo.gval, nvals * sizeof(double));
  makespace((char **) &tiltinfo.ggrad, m * sizeof(double *));
  makespace((char **) &tiltinfo.ghess, n * sizeof(double *));
  makespace((char **) &graddata, n * m * sizeof(double));
  makespace((char **) &hessdata, n * n * sizeof(double));

  tiltinfo.tilt = tilt;
  tiltinfo.exptilt = exptilt;
  for (i = 0; i < m; i++) tiltinfo.ggrad[i] = graddata + i * n;
  for (i = 0; i < n; i++) tiltinfo.ghess[i] = hessdata + i * n;
  tiltinfo.tscale = tscale;
}

/*
 * minfunc - objective handed to the minimizer by maxfront.
 *
 * Evaluates the installed log posterior at x and, if x is in the support
 * and tilt functions are installed, adds the tilt contribution to
 * whichever of *pval, grad and hess were requested (nil means "not
 * wanted").
 */
LOCAL VOID minfunc(x, pval, grad, hess)
     RVector x, grad;
     double *pval;
     RMatrix hess;
{
  int ntilt = gfuncs.k;

  if (! evalfunc(x, pval, grad, hess)) return;
  if (ntilt > 0) {
    add_tilt(x, pval, grad, hess, tiltinfo.tilt, tiltinfo.exptilt);
  }
}

/*
 * constfunc - constraint callback handed to the minimizer by maxfront.
 *
 *   x     point of evaluation
 *   vals  receives one value per constraint function, or nil
 *   jac   receives one gradient row per constraint function, or nil
 *   hess  accepted but not used; constraint Hessians are never requested
 *
 * Each installed constraint function is evaluated in turn through
 * evalcfunc.  Always returns FALSE.
 */
LOCAL int constfunc(x, vals, jac, hess)
     RVector x, vals;
     RMatrix jac, hess;
{
  int c, ncons = cfuncs.k;
  double *valslot, *jacrow;

  for (c = 0; c < ncons; c++) {
    if (vals != nil) valslot = vals + c;
    else valslot = nil;

    if (jac != nil) jacrow = jac[c];
    else jacrow = nil;

    which_cfunc = c;
    evalcfunc(x, valslot, jacrow, nil);
  }
  return(FALSE);
}

/*
 * add_tilt - add the tilt contribution of every tilt function to an
 * already evaluated value, gradient and Hessian.
 *
 *   x        point of evaluation
 *   pval     value to update, or nil
 *   grad     gradient to update, or nil
 *   hess     Hessian to update, or nil
 *   tilt     tilt; negated first when gfuncs.change_sign is set
 *   exptilt  nonzero: add (tilt / tscale[k]) * g_k  (tscale optional);
 *            zero:    add tilt * log(g_k)
 *
 * For each tilt function only the quantities matching the non-nil outputs
 * are requested from evalgfunc; results land in the tiltinfo work areas.
 * In the log form the value (and, for the Hessian, the gradient) is always
 * read from those work areas, so when it was not requested in this call
 * the entry left there by an earlier evaluation is used.  A nonpositive
 * value in the log form is reported through Recover.
 */
LOCAL VOID add_tilt(x, pval, grad, hess, tilt, exptilt)
     RVector x, grad;
     double *pval, tilt;
     RMatrix hess;
     int exptilt;
{
  int r, c, fn, dim = func.n, ntilt = gfuncs.k;
  double *gval, *ggrad, **ghess, etilt;

  if (ntilt == 0) return;

  if (gfuncs.change_sign) tilt = -tilt;

  for (fn = 0; fn < ntilt; fn++) {
    /* request only what the caller wants updated */
    gval = nil;
    ggrad = nil;
    ghess = nil;
    if (pval != nil) gval = tiltinfo.gval + fn;
    if (grad != nil) ggrad = tiltinfo.ggrad[fn];
    if (hess != nil) ghess = tiltinfo.ghess;

    which_gfunc = fn;
    evalgfunc(x, gval, ggrad, ghess);

    if (exptilt) {
      if (tiltinfo.tscale != nil) etilt = tilt / tiltinfo.tscale[fn];
      else etilt = tilt;

      if (pval != nil) *pval += etilt * *gval;
      if (grad != nil) {
        for (r = 0; r < dim; r++) grad[r] += etilt * ggrad[r];
      }
      if (hess != nil) {
        for (r = 0; r < dim; r++) {
          for (c = 0; c < dim; c++) hess[r][c] += etilt * ghess[r][c];
        }
      }
    }
    else {
      /* log form: work from the cached areas regardless of the request */
      gval = tiltinfo.gval + fn;
      ggrad = tiltinfo.ggrad[fn];
      ghess = tiltinfo.ghess;

      if (*gval <= 0.0) Recover("nonpositive function value", NULL);
      if (pval != nil) *pval += tilt * log(*gval);
      if (grad != nil) {
        for (r = 0; r < dim; r++) grad[r] += tilt * ggrad[r] / *gval;
      }
      if (hess != nil) {
        for (r = 0; r < dim; r++) {
          for (c = 0; c < dim; c++) {
            hess[r][c] +=
              tilt * (ghess[r][c] / *gval
                      - (ggrad[r] / *gval) * (ggrad[c] / *gval));
          }
        }
      }
    }
  }
}

VOID maxfront(ff, gf, cf, x, typx, fvals, gvals, cvals, ctarget, ipars, dpars, 
	 tscale, msg)
     char **ff, **gf, **cf;
     double *x, *typx, *fvals, *gvals, *cvals, *ctarget, *tscale;
     MaxIPars *ipars;
     MaxDPars *dpars;
     char **msg;
{
  static char *work = nil;
  static RMatrix H = nil, cJ = nil;
  double *pf, *grad, *c;
  int i, n, m, k;
  int (*cfun)();

  if (ipars->verbose > 0) PRINTSTR("maximizing...\n");

  n = ipars->n;
  m = ipars->m;
  k = ipars->k;
  if (k >= n) Recover("too many constraints", NULL);

  makespace((char **) &H, n * sizeof(double *));
  makespace((char **) &work, minworkspacesize(n, k));

  pf = fvals; fvals++;
  grad = fvals; fvals += n;
  for (i = 0; i < n; i++, fvals += n) H[i] = fvals;
  set_tilt_info(n, m, dpars->newtilt, ipars->exptilt, tscale);

  if (k == 0) {
    c = nil;
    cJ = nil;
    cfun = nil;
  }
  else {
    c = cvals;
    cvals += k;
    makespace((char **) &cJ, k * sizeof(double *));
    for (i = 0; i < k; i++) cJ[i] = cvals + i * n;
    cfun = constfunc;
  }

  install_func(ff[0], nil, n, TRUE, dpars->typf, dpars->h, typx, dpars->dflt);
  install_gfuncs(gf, n, m, TRUE, dpars->h, typx);
  install_cfuncs(cf, n, k, ctarget, dpars->h, typx);

  minsetup(n, k, minfunc, cfun, x, dpars->typf, typx, work);
  minsetoptions(dpars->gradtol, dpars->steptol, dpars->maxstep,
		ipars->itnlimit, ipars->verbose, ipars->backtrack, TRUE);

  if (ipars->vals_suppl) {
    for (i = 0; i < k; i++) c[i] -= ctarget[i];
    if (dpars->newtilt != dpars->tilt) {
      add_tilt(x, pf, grad, H, dpars->newtilt - dpars->tilt, ipars->exptilt);
      dpars->tilt = dpars->newtilt;
    }
    minsupplyvalues(*pf, grad, H, c, cJ);
  }

  minimize();
  minresults(x, pf, nil, grad, H, c, cJ, &ipars->count, &ipars->termcode,
	     &dpars->hessadd);
  msg[0] = minresultstring(ipars->termcode);

  for (i = 0; i < k; i++) c[i] += ctarget[i];
  ipars->vals_suppl = TRUE;
}

/************************************************************************/
/**                                                                    **/
/**                     Log Laplace Approximation                      **/
/**                                                                    **/
/************************************************************************/

/*
 * loglapdet - determinant part of the log Laplace approximation.
 *
 *   fvals    1 + n + n*n doubles: value, gradient, Hessian (row by row);
 *            the Hessian block is overwritten
 *   cvals    k + k*n doubles: constraint values, then one gradient row
 *            per constraint; the gradient rows are overwritten
 *   ipars    supplies n (ipars->n) and k (ipars->k)
 *   dpars    not used by this function
 *   val      receives (n - k) * log(ROOT2PI) - ldL - ldcv, plus -fvals[0]
 *            unless *detonly is nonzero
 *   detonly  *detonly nonzero suppresses the function value term
 *
 * ldL is the sum of the logs of the diagonal left in the Hessian by
 * choldecomp.  With constraints, the constraint gradients are forward
 * solved against that factor, their k x k cross-product matrix is formed
 * in the leading block of the Hessian storage and decomposed again; ldcv
 * is the sum of the logs of its diagonal.
 *
 * Recover() is called when k >= n.  The results of choldecomp are not
 * checked for positive definiteness.
 */
VOID loglapdet(fvals, cvals, ipars, dpars, val, detonly)
     double *fvals, *cvals;
     MaxIPars *ipars;
     MaxDPars *dpars;
     double *val;
     int *detonly;
{
  static RMatrix hess = nil, cgrad = nil;
  int r, c, q, n = ipars->n, k = ipars->k;
  double fval = -fvals[0];
  double *hbase = fvals + n + 1, *jbase = cvals + k;
  double ldL, ldcv, maxadd;
  double *row;

  if (k >= n) Recover("too many constraints", NULL);

  makespace((char **) &hess, n * sizeof(double *));
  makespace((char **) &cgrad, k * sizeof(double *));

  /* row tables into the caller's Hessian and constraint gradient blocks */
  for (r = 0; r < n; r++) hess[r] = hbase + r * n;
  for (r = 0; r < k; r++) cgrad[r] = jbase + r * n;

  choldecomp(hess, n, 0.0, &maxadd);
  /**** do something if not pos. definite ****/

  ldL = 0.0;
  for (r = 0; r < n; r++) ldL += log(hess[r][r]);

  if (k > 0) {
    /* forward solve for (L^-1) cgrad^T, one constraint at a time */
    for (q = 0; q < k; q++) {
      row = cgrad[q];
      for (r = 0; r < n; r++) {
        if (hess[r][r] != 0.0) row[r] /= hess[r][r];
        for (c = r + 1; c < n; c++) row[c] -= hess[c][r] * row[r];
      }
    }

    /* cross products of the solved rows, stored symmetrically in hess */
    for (r = 0; r < k; r++) {
      for (c = r; c < k; c++) {
        hess[r][c] = 0.0;
        for (q = 0; q < n; q++)
          hess[r][c] += cgrad[r][q] * cgrad[c][q];
        hess[c][r] = hess[r][c];
      }
    }

    choldecomp(hess, k, 0.0, &maxadd);
    /**** do something if not pos. definite ****/
    ldcv = 0.0;
    for (r = 0; r < k; r++) ldcv += log(hess[r][r]);
  }
  else {
    ldcv = 0.0;
  }

  *val = (n - k) * log(ROOT2PI) - ldL - ldcv;
  if (! *detonly) *val += fval;
}

#ifdef TODO
get Hessian from gradient for analytical gradients
avoid repeated derivative calls in minimize.
2d margins
use pos. definiteness info in margins
#endif /* TODO */
