/**********************************************************************
 * Written by Tony Plate in June 1991.
 **********************************************************************/


/**********************************************************************
 * $Id: checkgrad.c,v 1.4 93/02/09 16:45:12 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		      Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute, and sell this software
 * and its  documentation for  any purpose is  hereby granted  without
 * fee, provided that the above copyright notice appears in all copies
 * and  that both the  copyright notice  and   this  permission notice
 * appear in   supporting documentation,  and  that the  name  of  The
 * University  of Toronto  not  be  used in  advertising or  publicity
 * pertaining   to  distribution   of  the  software without specific,
 * written prior  permission.   The  University of   Toronto makes  no
 * representations  about  the  suitability of  this software  for any
 * purpose.  It  is    provided   "as is"  without express or  implied
 * warranty.
 *
 * THE UNIVERSITY OF TORONTO DISCLAIMS  ALL WARRANTIES WITH REGARD  TO
 * THIS SOFTWARE,  INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT SHALL THE UNIVERSITY OF TORONTO  BE LIABLE
 * FOR ANY  SPECIAL, INDIRECT OR CONSEQUENTIAL  DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR  PROFITS, WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF  OR  IN  CONNECTION WITH  THE  USE  OR PERFORMANCE  OF THIS
 * SOFTWARE.
 *
 **********************************************************************/

#include <stdio.h>
#include <math.h>
#include <errno.h>
#include <setjmp.h>
#include <signal.h>
#include <string.h>

#include <xerion/useful.h>

#include "minimize.h"
#include "linesearch.h"
#include "checkgrad.h"

/* Fallback min/max helpers, defined only when no header already supplies
 * them.  Both arguments may be evaluated more than once, so pass only
 * side-effect-free expressions. */
#if !defined(MIN)
#  define MIN(a, b)	(((a) < (b)) ? (a) : (b))
#endif

#if !defined(MAX)
#  define MAX(a, b)	(((a) > (b)) ? (a) : (b))
#endif

/* Forward declarations of the file-local helpers defined further down. */
static char *simpleName ARGS((Minimize mz, int i));

static void checkConsistency ARGS((Minimize mz, int n,
				   intFunc getNVars,
				   VecProc getValues, VecProc setValues,
				   RealVecFunc fEval, VecProc gEval,
				   Real2VecFunc fgEval, strFunc valueName,
				   Real *save, Real *x1, Real *temp, Real *x2,
				   Real *g1, Real *g2, Real *g3,
				   char *me, int *warnCtr, int *serious,
				   int verbosity));

/***
 * Default parameter-naming callback used by checkGradients when the
 * caller supplies none: parameter i is named "p<i>".  The minimize
 * record is not consulted.
 *
 * Returns a pointer to a static buffer that is overwritten by the next
 * call, so the result must be used (or copied) before calling again.
 */
static char *simpleName(mz, i)
  Minimize	mz;
  int i;
{
  /* Room for "p", a sign and the digits of any int (even a 64-bit one)
   * plus the NUL.  A 10-byte buffer overflowed once i >= 100000000. */
  static char name[32];
  sprintf(name, "p%d", i);
  return name;
}

int checkGradients(chkg, n, x, ef, mz, fEval, gEval, fgEval, valueName)
  struct CHECK_GRAD *chkg ;
  int	n ;			/* size of x & ef */
  Real	*x ;			/* values to do it at */
  Real	*ef;			/* error factors in gradient */
  Minimize	mz ;		/* the minimize record */
  RealVecFunc	fEval ;		/* evaluate function */
  VecProc	gEval ;		/* evaluate gradient */
  Real2VecFunc	fgEval ;	/* evaluation function and gradient */
  strFunc 	valueName ;	/* name of value */
{
  Real *g, *newg;
  Real Eep, xep, ratio, dE, dx, E, newE, newx, oldx, dEdx;
  Real gch;
  Real macheps = machineEpsilon();
  int i, this_fail, printed_header=0;
  int fail = 0;
  char *name;
  char flags[10];
  int zg_warning = 0;
  int xep_warning = 0;
  int Eep_warning = 0;
  int g_warning = 0;
  if (valueName==NULL)
    valueName = simpleName;
  if (chkg->epswarn<=0)
    chkg->epswarn = 2;
  if (chkg->gwarn<=0)
    chkg->gwarn = 0.5;
  if (chkg->gwarn>1)
    chkg->gwarn = 1/chkg->gwarn;
  if (chkg->criterion<=0)
    chkg->criterion = 0.9;
  if (chkg->criterion>1)
    chkg->criterion = 1/chkg->criterion;
  if (chkg->epsilon==0.0)
    chkg->epsilon = 0.001;

  /* Compute the gradients */

  g = ITempCalloc(Real, n);
  newg = ITempCalloc(Real, n);
  if (fgEval)
    E = fgEval(mz, n, x, g);
  else if (fEval && gEval) {
    E = fEval(mz, n, x) ;
    gEval(mz, n, g) ;
  } else {
    fprintf(dout, "Error: No way to evaluate the function\n") ;
    return ;
  }

  for (i=0; i<n; i++) {
    dEdx = g[i];
    name = valueName(mz, i);
    oldx = x[i];
    this_fail = 0;

    if (chkg->method==0)
      dx = chkg->epsilon*(dEdx>0 ? -1.0:1.0);
    else if (chkg->method==1)
      dx = -chkg->epsilon*dEdx;
    else if (chkg->method==2)
      dx = chkg->epsilon*oldx*(oldx*dEdx>0 ? -1.0:1.0);
    else if (chkg->method==3)
      dx = -chkg->epsilon/dEdx;
    else if (chkg->method==4) {
      Real dx1 = chkg->epsilon*oldx*(oldx*dEdx>0 ? -1.0:1.0);
      Real dx2 = -chkg->epsilon/dEdx;
      /* fprintf(dout, "dx1 = %12g dx2 = %12g\n", dx1, dx2); */
      if (dx1 < 0)
	dx = MIN(dx1, dx2);
      else
	dx = MAX(dx1, dx2);
    } else {
      IErrorAbort("checkGrad: bad method %d", chkg->method);
      return fail ;
    }

    newx = oldx + dx;

    x[i] = newx;
    if (fgEval)
      newE = fgEval(mz, n, x, newg);
    else if (fEval && gEval) {
      newE = fEval(mz, n, x) ;
      gEval(mz, n, newg) ;
    } else {
      fprintf(dout, "Error: No way to evaluate the function\n") ;
      return ;
    }
    x[i] = oldx;
    dE = newE - E;
    ratio = dE/(dEdx*dx);
    Eep = log10(fabs(dEdx*dx)/macheps);
    xep = log10(fabs(dx)/macheps);
    gch = newg[i]/g[i];
    if (ratio<chkg->criterion || ratio>1/chkg->criterion) {
      this_fail = 1;
      fail++;
    }
    if (chkg->verbosity && (this_fail||chkg->printall)) {
      if (xep<=chkg->epswarn)
	xep_warning++;
      if (dEdx*dx==0.0)
	zg_warning++;
      else if (Eep<=chkg->epswarn)
	Eep_warning++;
      if (gch<chkg->gwarn || gch>1/chkg->gwarn)
	g_warning++;

      if (!printed_header) {
	fprintf(dout, "Starting value of E= %g, method= %d eps= %g\n\n",
		E, chkg->method, chkg->epsilon);
	fprintf(dout, "%-26s %-12s %4s", "      - Name -", "obs/exp", "Note");
	if (chkg->verbosity>1)
	  fprintf(dout, " %-12s %-12s %-12s %-12s", "E'", "dE", "dx", "dEdx");
	if (chkg->verbosity>2)
	  fprintf(dout, " %-4s %-4s %-4s", "Eep", "xep", "g'/g");
	fprintf(dout, "\n");
	printed_header = 1;
      }

      flags[0] = '\0';
      if (this_fail) strcat(flags, "*");
      if (dEdx*dx==0.0) strcat(flags, "1");
      else if (Eep<=chkg->epswarn) strcat(flags, "2");
      if (xep<=chkg->epswarn) strcat(flags, "3");
      if (gch<chkg->gwarn || gch>1/chkg->gwarn) strcat(flags, "4");
      if (strlen(flags)==0) strcat(flags, "-");
      fprintf(dout, "%-26s %-12g %-4s", name, ratio, flags);
      if (chkg->verbosity>1)
	fprintf(dout, " %-12g %-12g %-12g %-12g",
		newE, dE, dx, dEdx);
      if (chkg->verbosity>2 && dEdx*dx!=0.0)
	fprintf(dout, " %-4.2g %-4.2g %-4.2g", Eep, xep, gch);
      fprintf(dout, "\n");
      fflush(dout);
    }
  }
  if (chkg->verbosity && (fail||chkg->printall)) {
    if (xep_warning || Eep_warning || g_warning || fail)
      fprintf(dout, "\n");
    if (fail)
      fprintf(dout,
	      "Note *: Failed: The observed change was more than %g%% different to\n\
         the expected change.\n", 100*(1-chkg->criterion));
    if (zg_warning)
      fprintf(dout,
	      "Note 1: Warning: gradient disappeared: expected change in E was zero - obs/exp not computable\n");
    if (Eep_warning)
      fprintf(dout,
	      "Note 2: Warning: dE was less than 10^%g times the machine precision\n\
         for E - poor accuracy in obs/exp is expected.\n", chkg->epswarn);
    if (xep_warning)
      fprintf(dout,
	      "Note 3: Warning: dx was less than 10^%g times the machine precision\n\
         for x - poor accuracy in obs/exp is expected.\n", chkg->epswarn);
    if (g_warning)
      fprintf(dout,
	      "Note 4: Warning: old gradient for x more than %g%% different to\n\
         new gradient : - poor accuracy in obs/exp is expected.\n",
	      100*(1-chkg->gwarn));
    if (xep_warning || Eep_warning || g_warning)
      fprintf(dout, "\n");
  }
  return fail;
}

/***
 * Check that one combination of evaluation callbacks behaves consistently.
 *
 * The parameters are first set to `save` (copied into x1).  The function
 * is then evaluated three times:
 *   - at x1, using fEval (plus gEval if given) when fEval is supplied,
 *     otherwise fgEval;
 *   - optionally at the distractor point x2;
 *   - at x1 again, using fgEval when supplied, otherwise fEval (plus gEval
 *     if given).
 *
 * Warnings:
 *   - an evaluation call changed the contents of x1;
 *   - with no x2, getValues afterwards does not return x1;
 *   - with x2 given, f1 == f2 although x1 != x2.
 * Errors:
 *   - f1 != f3;
 *   - the two gradients computed at x1 differ.
 * IErrorAbort is called for an invalid callback combination, or if an
 * evaluation changes the number of variables.
 *
 * save, x1, temp, g1, g2 and g3 are caller-provided work vectors of length
 * n.  x2 is the distractor point or NULL.  valueName is not used here, and
 * me prefixes the abort message.
 *
 * Warnings increment *warnCtr and errors increment *serious.  Messages go
 * to dout only when verbosity is nonzero; verbosity > 1 also prints norms
 * and function values.  The parameters are reset to `save` before
 * returning (on the early error return they already equal save).
 */
static void checkConsistency (mz, n, getNVars, getValues, setValues, 
		       fEval, gEval, fgEval, valueName, save, x1, 
		       temp, x2, g1, g2, g3, me, warnCtr, serious, verbosity)
  Minimize		mz ;
  int			n ;
  intFunc		getNVars ;
  VecProc		getValues ;
  VecProc		setValues ;
  RealVecFunc		fEval ;
  VecProc		gEval ;
  Real2VecFunc		fgEval ;
  strFunc		valueName ;
  Real			*save ;
  Real			*x1 ;
  Real			*temp ;
  Real			*x2 ;
  Real			*g1 ;
  Real			*g2 ;
  Real			*g3 ;
  char			*me ;
  int			*warnCtr ;
  int			*serious ;
  int			verbosity ;
{
  Real f1, f2 = 0.0, f3;
  /* Transcript of the calls made, quoted in the messages.  Only fixed
   * literals are appended.  The longest transcript (fEval and gEval at x1,
   * at x2, then at x1 again) needs 123 bytes including the NUL, which
   * overflowed a 100-byte buffer; 256 leaves ample room. */
  char seq[256];
  /* The first evaluation fills g1 (and the x2 evaluation fills g2) only
   * when it computes a gradient: via gEval, or via fgEval when fEval is
   * absent.  In that case g3 is also computed by the last evaluation. */
  int haveGrad = (gEval != NULL) || (fgEval != NULL && fEval == NULL);

  if ((!fEval && !fgEval) || (gEval && !fEval))
    IErrorAbort("checkConsistency: bad calling combination");

  copyVector(n, save, x1);
  setValues(mz, n, x1);

  seq[0] = '\0';

  /* Make sure there is a way to evaluate the function */
  if (fEval == NULL && fgEval == NULL) {
    if (verbosity)
      fprintf(dout, "Error: No way to evaluate the function\n") ;
    (*serious)++;
    return ;
  }

  /* If we have a choice, use fEval;gEval first time */
  if (fEval) {
    f1 = fEval(mz, n, x1);
    sprintf(seq, "%s", "f1 = fEval(mz, n, x1);");
    if (gEval) {
      gEval(mz, n, g1);
      sprintf(seq+strlen(seq), "%s", " gEval(mz, n, g1);");
    }
  } else {
    f1 = fgEval(mz, n, x1, g1);
    sprintf(seq, "%s", "f1 = fgEval(mz, n, x1, g1);");
  }
  
  if (!sameVector(n, save, x1)) {
    if (verbosity)
      fprintf(dout, 
	      "Warning: The calling sequence\n  {%s}\nchanges x values.\n",
	      seq);
    (*warnCtr)++;
    copyVector(n, save, x1);
  }
  
  if (getNVars(mz)!=n)
    IErrorAbort("%s: The calling sequence\n  {%s}\nchanges number of values", 
		me, seq);
  
  if (x2==NULL) {
    getValues(mz, n, temp);
    if (!sameVector(n, temp, x1)) {
      if (verbosity)
	fprintf(dout, 
		"Warning: The calling sequence\n  {%s getValues(n, x2);}\ngives x1 != x2.\n", 
		seq);
      (*warnCtr)++;
    }
  }

  /***
   * a distractor computation
   */
  if (x2!=NULL) {
    if (fEval) {
      f2 = fEval(mz, n, x2);
      sprintf(seq+strlen(seq), "%s", " f2 = fEval(mz, n, x2);");
      if (gEval) {
	gEval(mz, n, g2);
	sprintf(seq+strlen(seq), "%s", " gEval(mz, n, g2);");
      }
    } else {
      f2 = fgEval(mz, n, x2, g2);
      sprintf(seq+strlen(seq), "%s", " f2 = fgEval(mz, n, x2, g2);");
    }

    if (f1 == f2 && !sameVector(n, x1, x2)) {
      if (verbosity)
	fprintf(dout, "Warning: The calling sequence\n  {%s}\ngives f1 == f2 for x1 != x2.\n", seq);
      (*warnCtr)++;
    }
  }

  /* If we have a choice, use fgEval second time */

  if (fgEval) {
    f3 = fgEval(mz, n, x1, g3);
    sprintf(seq+strlen(seq), "%s", " f3 = fgEval(mz, n, x1, g3);");
  } else {
    f3 = fEval(mz, n, x1);
    /* This call produces f3; label it accordingly in the transcript. */
    sprintf(seq+strlen(seq), "%s", " f3 = fEval(mz, n, x1);");
    if (gEval) {
      gEval(mz, n, g3);
      sprintf(seq+strlen(seq), "%s", " gEval(mz, n, g3);");
    }
  }
  
  if (f1 != f3) {
    if (verbosity)
      fprintf(dout, "Error: The calling sequence\n  {%s}\ngives f1 != f3.\n",
	      seq);
    (*serious)++;
  }
  
  /* Compare gradients only when g1 was really computed in this call;
   * otherwise it holds stale data from an earlier check. */
  if (haveGrad && !sameVector(n, g1, g3)) {
    if (verbosity)
      fprintf(dout, "Error: The calling sequence\n  {%s}\ngives g1 != g3.\n",
	      seq);
    (*serious)++;
  }

  if (verbosity>1) {
    fprintf(dout, "Results for sequence\n  {%s}\nare:\n", seq);
    fprintf(dout, "  |x1|= %-12g f1= %-12g", vectorLength(n, x1), f1);
    if (haveGrad)
      fprintf(dout, " |g1|= %-12g", vectorLength(n, g1));
    fprintf(dout, "\n");
    if (x2!=NULL) {
      fprintf(dout, "  |x2|= %-12g f2= %-12g", vectorLength(n, x2),f2);
      if (haveGrad)
	fprintf(dout, " |g2|= %-12g", vectorLength(n, g2));
      fprintf(dout, "\n");
    }
    fprintf(dout, "  |x1|= %-12g f3= %-12g", vectorLength(n, x1), f3);
    if (haveGrad)
      fprintf(dout, " |g3|= %-12g", vectorLength(n, g3));
    fprintf(dout, "\n\n");
  }

  setValues(mz, n, save);
}


/* Resume point for abandoning checkOutFunctions on SIGINT.  It is filled
 * by the setjmp() in checkOutFunctions and may only be jumped to while
 * that call is still active. */
static jmp_buf jmp_buf_env;

/***
 * SIGINT handler installed by checkOutFunctions.  It never returns
 * normally: control goes back to the setjmp() site in checkOutFunctions,
 * where setjmp() then yields 1.
 */
int interruptCheckOutFunctions(sig)
  int	sig ;			/* signal number; not inspected */
{
  (void) sig;
  longjmp(jmp_buf_env, 1);
  return 1 ;			/* not reached */
}

/***
 * Exercise the user-supplied evaluation callbacks for consistency before
 * they are used by the minimizer.
 *
 * Every available calling combination is checked by checkConsistency:
 * fEval alone, fEval+gEval, fgEval alone, and fEval with fgEval (plus
 * gEval if given).  This is done twice: first without, then with a
 * distractor evaluation at a randomly perturbed point x2.
 *
 * Callback combinations that make no sense are rejected with IErrorAbort,
 * using `me` as the message prefix.  The current parameters are saved
 * first and restored if the run is interrupted with CTRL-C (SIGINT).
 *
 * With nonzero verbosity each problem is described, plus a closing
 * summary or "All functions checked out OK".  With verbosity 0, only a
 * one-line count of problems is printed, and only if there are any.
 *
 * Returns the number of serious problems found, or 0 if interrupted.
 */
int checkOutFunctions (mz, getNVars, getValues, setValues, fEval, 
		       gEval, fgEval, valueName, me, verbosity)
  Minimize		mz ;
  intFunc		getNVars ;
  VecProc		getValues ;
  VecProc		setValues ;
  RealVecFunc		fEval ;
  VecProc		gEval ;
  Real2VecFunc		fgEval ;
  strFunc		valueName ;
  char			*me;
  int			verbosity ;
{
  Real *g1, *g2, *g3, *save, *x1, *temp, *x2;
  Real *distractor;
  int i, n, pass;
  int warnCtr = 0;
  int serious = 0;
  SignalHandler old_signal_handler ;
  
  if (fEval==NULL && gEval==NULL && fgEval==NULL)
    IErrorAbort("%sno evaluation functions supplied", me);
  if (fEval==NULL && gEval!=NULL)
    IErrorAbort("%smakes no sense to supply gEval without fEval", me);
  if (gEval==NULL && fgEval==NULL)
    IErrorAbort("%sno gradient evaluation functions supplied", me);
  if (fEval!=NULL && fEval==(RealVecFunc)gEval)
    IErrorAbort("%sfEval==gEval - cannot be!", me);
  if (fEval!=NULL && (Real2VecFunc)fEval==fgEval)
    IErrorAbort("%sfEval==fgEval - give NULL for fEval", me);
  if (gEval!=NULL && (Real2VecFunc)gEval==fgEval)
    IErrorAbort("%sgEval==fgEval - give NULL for gEval", me);

  n = getNVars(mz);
  save = ITempCalloc(Real, n);
  x1 = ITempCalloc(Real, n);
  temp = ITempCalloc(Real, n);
  x2 = ITempCalloc(Real, n);
  g1 = ITempCalloc(Real, n);
  g2 = ITempCalloc(Real, n);
  g3 = ITempCalloc(Real, n);

  getValues(mz, n, save);

  fprintf(dout, "Checking consistency of functions... (CTRL-C to abort)\n");

  /* NOTE(review): where setjmp() does not save the signal mask, SIGINT
   * may remain blocked after longjmp()-ing out of the handler.
   * sigsetjmp()/siglongjmp() would avoid this where POSIX is available;
   * confirm on the target platforms.  Only variables set before setjmp()
   * are used on the interrupt path below. */
  old_signal_handler = (SignalHandler)signal(SIGINT, interruptCheckOutFunctions);
  if (setjmp(jmp_buf_env)) {
    setValues(mz, n, save);
    signal(SIGINT, old_signal_handler);
    fprintf(dout, "\nConsistency checking interrupted and stopped.  Values restored.\n");
    return 0;
  }

  /* Pass 0 checks each combination on its own.  Pass 1 repeats every
   * combination, this time with a distractor evaluation at x2.  Running
   * both passes through the same code keeps the two sets of calls
   * identical apart from the distractor. */
  for (pass = 0; pass < 2; pass++) {
    distractor = NULL;
    if (pass == 1) {
      /* x1 holds the point used by the last check (set from save). */
      for (i=0; i<n; i++)
	x2[i] = x1[i] * (0.8 + 0.4*IRandReal());
      distractor = x2;
    }

    if (fEval)
      checkConsistency(mz, n, 
		       getNVars, getValues, setValues,
		       fEval, NULL, NULL, valueName,
		       save, x1, temp, distractor, g1, g2, g3,
		       me, &warnCtr, &serious, verbosity);

    if (fEval && gEval)
      checkConsistency(mz, n, 
		       getNVars, getValues, setValues,
		       fEval, gEval, NULL, valueName,
		       save, x1, temp, distractor, g1, g2, g3,
		       me, &warnCtr, &serious, verbosity);

    if (fgEval)
      checkConsistency(mz, n, 
		       getNVars, getValues, setValues,
		       NULL, NULL, fgEval, valueName,
		       save, x1, temp, distractor, g1, g2, g3,
		       me, &warnCtr, &serious, verbosity);

    if (fEval && fgEval)
      checkConsistency(mz, n, 
		       getNVars, getValues, setValues,
		       fEval, gEval, fgEval, valueName,
		       save, x1, temp, distractor, g1, g2, g3,
		       me, &warnCtr, &serious, verbosity);
  }

  if (warnCtr || serious) {
    if (verbosity) {
      fprintf(dout, "\n");
      if (warnCtr)
	fprintf(dout, "The behaviour described in the warnings should not affect the\nminimize procedure.\n");
      if (serious)
	fprintf(dout, "The behaviour described as errors could cause serious problems\nin the minimize procedure.\n");
      fprintf(dout, "\n");
    } else {
      fprintf(dout, "CheckOutFunctions found");
      if (warnCtr)
	fprintf(dout, " %d cosmetic problem%s", warnCtr, warnCtr>1?"s":"");
      if (warnCtr && serious)
	fprintf(dout, " and");
      if (serious)
	fprintf(dout, " %d serious problem%s", serious, serious>1?"s":"");
      fprintf(dout, " with the functions.\n");
      fprintf(dout, "To print details of the problems, run the command\n  \"checkGrad -check 1 -nogradients\"\n");
    }
  } else {
    if (verbosity)
      fprintf(dout, "All functions checked out OK\n");
  }

  signal(SIGINT, old_signal_handler);

  return serious;
}
