
/**********************************************************************
 * $Id: tapsearch.c,v 1.2 92/11/30 11:31:01 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		      Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute, and sell this software
 * and its  documentation for  any purpose is  hereby granted  without
 * fee, provided that the above copyright notice appears in all copies
 * and  that both the  copyright notice  and   this  permission notice
 * appear in   supporting documentation,  and  that the  name  of  The
 * University  of Toronto  not  be  used in  advertising or  publicity
 * pertaining   to  distribution   of  the  software without specific,
 * written prior  permission.   The  University of   Toronto makes  no
 * representations  about  the  suitability of  this software  for any
 * purpose.  It  is    provided   "as is"  without express or  implied
 * warranty.
 *
 * THE UNIVERSITY OF TORONTO DISCLAIMS  ALL WARRANTIES WITH REGARD  TO
 * THIS SOFTWARE,  INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT SHALL THE UNIVERSITY OF TORONTO  BE LIABLE
 * FOR ANY  SPECIAL, INDIRECT OR CONSEQUENTIAL  DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR  PROFITS, WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF  OR  IN  CONNECTION WITH  THE  USE  OR PERFORMANCE  OF THIS
 * SOFTWARE.
 *
 **********************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <errno.h>

#include <xerion/useful.h>
#include "minimize.h"
#include "linesearch.h"
#include "wobbly.h"

#include "tapsearch.h"

/***
 * TODO:
 *	quadratic interp on d when we have three points
 *	cubic interp on f when we have two trustworthy f values
 *	check for NaN in f and d
 */

/*
 * linearFit -- find the zero of the straight line through the slope
 * samples (a1,d1) and (a2,d2): x = a1 - d1*(a1-a2)/(d1-d2).
 *
 * If expected_df is non-NULL and d1 != d2, *expected_df is set to the
 * integral of that line from a1 to a2, i.e. the change f(a2)-f(a1)
 * predicted by the linear slope model.
 *
 * Returns 1 and leaves the zero in *x when it is a usable new point:
 * not NaN, distinct from both a1 and a2, not below `low`, and (only when
 * high > 0.0) not above `high` -- high <= 0.0 means "no upper bound".
 * Returns 0 otherwise (equal slopes, NaN result, or out of range); *x
 * may already have been overwritten in that case.
 * vb is the verbosity level passed to VB() for tracing.
 */
static int	linearFit(a1, d1, a2, d2, low, high, x, expected_df, vb)
  double	a1, d1 ;
  double	a2, d2 ;
  double	low, high ;
  double	*x ;
  double	*expected_df ;
  int		vb ;
{
  /* linear interpolation formula c=a1-d1*(a1-a2)/(d1-d2) */
  VB(3, vb, "Line fitting on (%g,%g) (%g,%g)\n", a1, d1, a2, d2);
  if (d1==d2) {
    VB(3, vb, "line fit failed - points equal\n");
    return 0;
  }
  *x = a1-d1*(a1-a2)/(d1-d2);
  /* expected df is what we expect for f2-f1: the integral of the fitted
     linear slope from a1 to a2 */
  if (expected_df!=NULL)
    *expected_df = 0.5*(a2-a1)*(d2-d1) + d1*(a2-a1);
  /* A NaN in any input propagates into *x.  Every comparison with NaN is
     false, so without this test a NaN would pass all the range checks
     below and be handed to the caller as a step.  x != x holds only for
     NaN. */
  if (*x != *x) {
    VB(3, vb, "line fit failed - result is not a number\n");
    return 0;
  }
  if (*x==a1 || *x==a2) {
    VB(3, vb, "line fit failed - result same as end point\n");
    return 0;
  }
  if (*x<low) {
    VB(3, vb, "line fit failed - less than lower bound\n");
    return 0;
  }
  /* high <= 0.0 disables the upper bound (used when extrapolating) */
  if (high>0.0 && *x>high) {
    VB(3, vb, "line fit failed - greater than upper bound\n");
    return 0;
  }
  VB(3, vb, "line fit gives %g\n", *x);
  return 1;
}

/*
 * quadraticFit -- placeholder for a fit that would also use the extra
 * sample (c, fc).  As written it ignores c, fc and expected_df and simply
 * performs the linear slope fit x = a1 - d1*(a1-a2)/(d1-d2).
 *
 * Returns 1 (with the fitted point in *x) only when d1 != d2, the point
 * differs from both a1 and a2, it is >= low, and it is <= high whenever
 * high > 0.0.  Returns 0 otherwise; *x may already have been written.
 * NOTE(review): *expected_df is never written, even when non-NULL.
 */
static int	quadraticFit(a1, d1, a2, d2, c, fc, 
			     low, high, x, expected_df, vb)
  double	a1, d1 ;
  double	a2, d2 ;
  double	c, fc ;
  double	low, high ;
  double	 *x ;
  double	*expected_df ;
  int		vb ;
{
  int	accepted = 0 ;

  VB(3, vb, "Line fitting on (%g,%g) (%g,%g)\n", a1, d1, a2, d2);
  if (d1 == d2) {
    VB(3, vb, "line fit failed - points equal\n");
  } else {
    /* zero of the line through (a1,d1) and (a2,d2) */
    *x = a1-d1*(a1-a2)/(d1-d2);
    if (*x == a1 || *x == a2) {
      VB(3, vb, "line fit failed - result same as end point\n");
    } else if (*x < low) {
      VB(3, vb, "line fit failed - less than lower bound\n");
    } else if (high > 0.0 && *x > high) {
      VB(3, vb, "line fit failed - greater than upper bound\n");
    } else {
      VB(3, vb, "line fit gives %g\n", *x);
      accepted = 1;
    }
  }
  return accepted;
}

/*
 * tapsLineSearch -- line search along `search` starting from `start`.
 *
 * On entry *ap is the initial step size (must be > 0.0, otherwise
 * IErrorAbort is called).  *fp and *dp are the function value and the
 * slope along `search` at `start`.  Each trial step a does three things:
 *   - fills x via moveInDirection(n, start, search, a, x);
 *   - evaluates the function and gradient there, with fgEval when it is
 *     non-NULL and fEval followed by gEval otherwise;
 *   - takes the slope as dotProduct(n, grad, search).
 *
 * A trial point is accepted when r = |d/d0| < mz->maxSlopeRatio and
 * funcValueOK(mz, f, r) holds.  In that case *ap, *fp and *dp receive
 * that point's step, value and slope, and MZSUCCEED is returned.
 *
 * Otherwise the search maintains a bracket:
 *   - a2 is the lower bracket (slope < 0);
 *   - a3 is the upper bracket (slope >= 0; 0.0 means none yet);
 *   - a1 and a4 hold the previous a2 and a3.
 * New steps come from linear interpolation/extrapolation on the slopes
 * (linearFit), from bisection, or from steps scaled by mz->maxExtrapol.
 *
 * The search gives up when any of these occurs:
 *   - the bracket bookkeeping becomes inconsistent;
 *   - the step exceeds the absolute maximum found so far;
 *   - mz->stopFlag > 1;
 *   - mz->lsMaxFuncEvals evaluations have been used in this search;
 *   - when mz->maxFuncEvals > 0, more than
 *     mz->maxFuncEvals + mz->lsFlexFuncEvals evaluations have been used
 *     in total (mz->nFuncEvals).
 * In those cases getBestByD() is called, presumably to restore the best
 * point recorded via updateBest().  The function then returns
 * mz->lsResultCode (MZFAIL, MZSTOPPED, MZMAXFLINE, MZMAXFLINEFAIL or
 * MZMAXF), or MZFAIL if that restore fails.  The returned code is always
 * also stored in mz->lsResultCode.
 */
int	tapsLineSearch (mz, n, start, search, x, grad, 
			fEval, gEval, fgEval, ap, fp, dp)
  Minimize		mz ;
  int			n ;
  Real			*start ;
  Real			*search ;
  Real			*x ;
  Real			*grad ;
  RealVecFunc		fEval ;
  VecProc		gEval ;
  Real2VecFunc		fgEval ;
  double		*ap ;
  double		*fp ;
  double		*dp ;
{
  double a = *ap;  /* The initial step size */
  double f = *fp; /* f(start) */
  double d = *dp; /* d(start) */
  double f0 = *fp; /* The value of f at start (used to remember start value)*/
  double d0 = *dp; /* The slope d at start (used to remember start value)*/
  double a1=0.0, f1=0.0, d1=0.0;	/* point < a2 */
  double a2=0.0, f2 = *fp, d2 = *dp;	/* the lower bracket (d2<0) */
  /* f3/d3/f4/d4 start at 0.0 so that the bracket shift (f4 = f3 ...) never
     copies indeterminate values; they are meaningful only once a3/a4 > 0 */
  double a3=0.0, f3=0.0, d3=0.0;	/* the upper bracket (d3>=0) */
  double a4=0.0, f4=0.0, d4=0.0;	/* point > a3 */
  /* expected_df is only set by a linearFit() that gets past its equal-slope
     test, but is printed and compared below regardless */
  double new_a = 0.0, r, expected_df = 0.0;
  enum {STARTED,INTERPOLATED,BISECTED,EXTRAPOLATED,JUMPED,MOVEDLEFT} last_action;
  int ok, low_delta_f, reduction_ok, asked_step;
  int vb = mz->lsVerbosity;
  double delta_f = 0.0, old_d, absmax=0.0;
  last_action = STARTED;

  mz->evalReason = "Starting line search";
  mz->lsnFuncEvals = 0;

  initBest(mz, grad);

  if (a<=0.0)
    IErrorAbort("tapsLineSearch: no initial stepsize supplied");

  VB(1, vb, "a= %-12g f= %-12g f-f0= %-12g d= %-12g\n", 0.0, f, f-f0, d);

  for (;;) {
    asked_step = 0;
    if (mz->askStep>0 && mz->nFuncEvals>mz->askStep) {
      /* interactive override of the step size; an empty, unparsable,
	 non-positive or unreadable (EOF/error) reply keeps the current a */
      char buf[100];
      double requested;
      buf[0] = '\0';
      fprintf(dout, "Enter the step size [%g]: ", a);
      fflush(dout);		/* prompt has no newline */
      if (fgets(buf, sizeof buf, din) != NULL) {
	requested = strtod(buf, NULL);
	if (requested>0.0) {
	  a = requested;
	  fprintf(dout, "Taking step %g\n", a);
	  asked_step = 1;
	}
      }
    }
    /* check whether to stop now */
    if (!asked_step && (a<0.0 || a<a2 || a1>a2
			|| (a3>0.0 && (a>a3 || a3<a2))
			|| (a4>0.0 && a4<a3))) {
      VB(2, vb, "Stopping line search: bug in books (%g %g) %g (%g %g)\n",
	 a1, a2, a, a3, a4);
      mz->lsResultCode = MZFAIL;
      goto getout;
    }
    if (absmax>0 && a>absmax) {
      VB(2, vb, "Stopping line search: step greater than absolute maximum\n");
      mz->lsResultCode = MZFAIL;
      goto getout;
    }
    if (mz->stopFlag>1) {
      VB(2, vb, "Stopping line search: stop flag set\n");
      mz->lsResultCode = MZSTOPPED;
      goto getout;
    }
    if (mz->lsnFuncEvals >= mz->lsMaxFuncEvals) {
      VB(2, vb, "Stopping line search: too many f evals in line search\n");
      mz->lsResultCode = (mz->lsBestByD>0 ? MZMAXFLINE : MZMAXFLINEFAIL);
      goto getout;
    }
    /* global budget: the total evaluation count (not this search's own
       count, which is already capped above) may exceed maxFuncEvals by
       at most lsFlexFuncEvals */
    if (mz->maxFuncEvals>0
	&& (  mz->nFuncEvals
	    > (mz->maxFuncEvals + mz->lsFlexFuncEvals))) {
      VB(2, vb, "Stopping line search: too many f evals in total\n");
      mz->lsResultCode = MZMAXF;
      goto getout;
    }
    if (a3>0 && a2==a3) {
      VB(2, vb, "Stopping line search: interval dissappeared\n");
      mz->lsResultCode = MZFAIL;
      goto getout;
    }

    low_delta_f = 0;
    if (-a*d0 < fPrecision(mz, f0)) {
      VB(2, vb, "Expecting low change in func value\n");
      low_delta_f = 1;
    } else if (mz->lsnFuncEvals>0 && fabs(delta_f) < fPrecision(mz, f0)) {
      /* delta_f is from the previous evaluation in this search */
      VB(2, vb, "Observed low change in func value\n");
      low_delta_f = 1;
    }

    if (mz->wobbleWatch) {
      int good, total;
      gradientsConsistent(mz, &good, &total);
      if (low_delta_f || good < total)
	fprintf(dout,
		"Wobbliness diagnostics: deltas: %s, %d slopes ok out of %d\n",
		low_delta_f ? "low" : "ok", good, total);
    }

    /* make a step of distance a along search */
    moveInDirection(n, start, search, a, x);
    
    if (fgEval)
      f = fgEval(mz, n, x, grad);
    else {
      f = fEval(mz, n, x);
      gEval(mz, n, grad);
    }
    old_d = d;
    d = dotProduct(n, grad, search);
    /* did the slope magnitude shrink by at least a factor of 1.5? */
    reduction_ok = (fabs(d*1.5) < fabs(old_d));

    VB(1, vb, "a= %-12g f= %-12g f-f0= %-12g d= %-12g\n", a, f, f-f0, d);
    mz->nFuncEvals++;
    mz->lsnFuncEvals++;

    insertLSData(mz, mz->lsnFuncEvals, a, f, d, 1, 0);
    updateBest(mz, mz->lsnFuncEvals, grad);

    delta_f = f - f0;

    r = fabs(d/d0);
    if (r >= mz->maxSlopeRatio) {
      VB(2, vb, "Continuing search: slope ratio %g too large\n", r);
      VB(3, vb, "Slope at start = %g, at current point = %g\n", d0, d);
    } else if (!funcValueOK(mz, f, r)) {
      VB(2, vb, "Continuing search: function value %g too high\n", f);
    } else {
      VB(2, vb, "Stopping line search: delta_f=%g\n", f-f0);
      VB(1, vb, "Line search %d steps: slope ratio= %g\n", mz->lsnFuncEvals,r);
      *fp = f; *ap = a; *dp = d;
      return mz->lsResultCode = MZSUCCEED;
    }

    /* at this point a2 < a < a3 (except for a3=0.0) */

    if (d<0.0 && f-f0 > fPrecision(mz, f0)) {
      VB(2, vb, "Over in second minima - going back\n");
      /* have to fix all the book keeping now */
      new_a = a2 + (a-a2)/mz->maxExtrapol;
      absmax = a;
      a3 = 0.0;
      a = new_a;
    } else if (d<0.0) {
      if (a3>0.0) {
	/* have an upper bound */
	VB(2, vb, "New slope < 0: working in <%g,%g,%g>\n", a2, a, a3);
	VB(2, vb, "TODO: quadratic interpolation on <%g,%g,%g>\n", a2, a, a3);
	if (a1>0.0) {
	  VB(2, vb, "TODO: quadratic interpolation on <%g,%g,%g>\n", a1, a2,a);
	}
	ok = 0;
	if (d>d2 && last_action!=EXTRAPOLATED) {
	  /* first try extrapolating from a2, a */
	  VB(2, vb, "New slope > d2: extrapolating from <%g,%g>\n", a2, a);
	  last_action = EXTRAPOLATED;
	  ok = linearFit(a2, d2, a, d, a, 0.0, &new_a, NULL, vb);
	}
	if (!ok || new_a>=a3) {
	  VB(2, vb, "Interpolating in <%g,%g>\n", a, a3);
	  ok = linearFit(a, d, a3, d3, a, a3, &new_a, NULL, vb);
	  if (!ok || (!reduction_ok && last_action==INTERPOLATED)) {
	    VB(2, vb, "Interpolation failed or reduction not ok - bisecting\n");
	    new_a = (a+a3)/2.0;
	    last_action = BISECTED;
	  } else
	    last_action = INTERPOLATED;
	}
      } else {
	double amax = a*mz->maxExtrapol;
	VB(2, vb, "New slope < 0: extrapolating in <%g,%g>\n", a, amax);
	/* extrapolate to beyond a */
	/* want to watch out for zero d's here - could go too far */
	/* TODO: try quadratic interpolation on 3 points a1, a2, a */
	if (a1>0.0) {
	  VB(2, vb, "TODO: quadratic interpolation on <%g,%g,%g>\n", a1, a2,a);
	}
	if (d>d2) {
	  VB(2, vb, "New slope > old slope: extrapolating from <%g,%g>\n", a2, a);
	  /* line is pointing at zero - aim for zero */
	  last_action = EXTRAPOLATED;
	  ok = linearFit(a2, d2, a, d, a, 0.0, &new_a, NULL, vb);
	  if (!ok || new_a>amax) {
	    VB(2, vb, "Extrapolation failed new a too high - limiting\n");
	    new_a = amax;
	  }
	} else {
	  VB(2, vb, "New slope <= old slope: jumping from <%g,%g>\n", a2, a);
	  /* line is pointing down - jump to limit */
	  new_a = amax;
	  last_action = JUMPED;
	}
      }
      /* a becomes the new lower bracket; the old one moves to a1 */
      a1 = a2; f1 = f2; d1 = d2;
      a2 = a; f2 = f; d2 = d;
      a = new_a;
    } else {
      VB(2, vb, "New slope >= 0: interpolating in <%g,%g>\n", a2, a);
      /* TODO: try quadratic interpolation on 3 points a2, a, a3 */
      /* TODO: try quadratic interpolation on 3 points a1, a2, a */
      ok = linearFit(a2, d2, a, d, a2, a, &new_a, &expected_df, vb);
      VB(2, vb, "Expected df=%g, observed=%g\n", expected_df, f-f2);
      if (!ok || (!reduction_ok && last_action==INTERPOLATED)) {
	VB(2, vb, "Interplation failed or reduction not ok - bisecting\n");
	last_action = BISECTED;
	new_a = (a2+a)/2.0;
      } else
	last_action = INTERPOLATED;
      /* test BISECTED first: when linearFit bailed out early, expected_df
	 was not updated and must not decide the outcome */
      if (!funcValueOK(mz, f, r) && new_a>a2+(a-a2)/mz->maxExtrapol
	  && (last_action==BISECTED || (f-f2)<expected_df/2.0)) {
	new_a = a2+(a-a2)/mz->maxExtrapol;
	VB(2, vb, "Function value too high, moving left to %g\n", new_a);
	last_action = MOVEDLEFT;
      }
      /* a becomes the new upper bracket; the old one moves to a4 */
      a4 = a3; f4 = f3; d4 = d3;
      a3 = a; f3 = f; d3 = d;
      if (a3<absmax || absmax==0.0)
	absmax = a3;
      a = new_a;
    }
    mz->evalReason = "Retrying line search";
  }

  IErrorAbort("tapsLineSearch: should not reach here.");
 getout:
  ok = getBestByD(mz, n, start, search, x, grad, ap, fp, dp);
  if (!ok) {
    VB(1, vb, "Failed to restore best point of line search\n");
    return mz->lsResultCode = MZFAIL;
  } else
    return mz->lsResultCode;
}

