/*
   File:        stage.c
   Author:      Justin A. Boyan
   Created:     Fri. Mar 29, 1996 09:24
   Description: multi-stage predictive hillclimbing

   Copyright (C) 1996, Justin A. Boyan
*/

#ifdef ENABLE_STAGE

#include <stdio.h>

#include "ambs.h"
#include "jbbs.h"
#include "fitter.h"

#include "opera.h"

#define MAX_STAGES 10



/*
 * Multi-stage hillclimbing: search using fits[0], then fits[1], ... and
 * finally fits[nfits-1].
 *
 *   gov, ow  - passed through unchanged to the helpers called below
 *   fits     - array of nfits fitters, one per stage
 *   nfits    - number of stages to run
 *   st       - array of nfits stat accumulators; st[k] is updated with the
 *              objective value of res->best at the end of stage k
 *   res      - re-initialised before every stage and filled in by
 *              fitted_hc_start(); on return res->tot_evaluations holds the
 *              sum of the per-stage evaluation counts
 *
 * Stage 0 starts from the configuration returned by mk_random_opconfig();
 * every later stage starts from a copy of the previous stage's res->best.
 *
 * NOTE(review): only the last copy of res->best is freed here.  The start
 * configurations handed to fitted_hc_start() are not freed in this function;
 * presumably fitted_hc_start() takes ownership of them -- confirm.
 */
void run_stage_hc(Opgov *gov, Opworld *ow, fitter **fits, int nfits, stat *st,
                  Opresults *res)
{
  Opconfig *start = mk_random_opconfig(ow);
  int stage_idx = 0;
  int evals_so_far = 0;

  while (stage_idx < nfits) {
    stat *stage_stat = &st[stage_idx];
    double objval;

    IFVERB(15) printf("\nStage %d:\n", stage_idx);
    IFVERB(20) if (gov->draw_freq) draw_opconfig(ow, start);
    IFVERB(15) print_opconfig("|| ", ow, start);

    /* run this stage's search from the current start configuration */
    reinit_opresults(gov, ow, res);
    fitted_hc_start(gov, ow, fits[stage_idx], start, res);
    evals_so_far += res->tot_evaluations;

    /* record the objective value reached at the end of the stage */
    objval = objective_function(ow, res->best);
    stat_update(stage_stat, objval);
    IFVERB(10) {
      printf("{STAGE[%d]: this=%g; n=%d, min=%g, mean=%g, max=%g, stdev=%g}\n",
             stage_idx, objval, stat_n(stage_stat), stat_min(stage_stat),
             stat_mean(stage_stat), stat_max(stage_stat),
             stat_sdev(stage_stat));
    }

    /* the next stage (or the final report below) uses a copy of the best */
    stage_idx++;
    start = mk_copy_opconfig(ow, res->best);
  }

  IFVERB(15) {
    printf("\nStage %d:\n", stage_idx);
    if (gov->draw_freq) draw_opconfig(ow, start);
    print_opconfig("!! ", ow, start);
  }

  free_opconfig(ow, start);

  /* report the total over all stages, not just the last one */
  res->tot_evaluations = evals_so_far;
}



/*
 * Two-stage optimizer: train gov->fit on hillclimbing trajectories, then
 * use it to guide the search.
 *
 * Training phase (gov->ntrain_traj iterations):
 *   Each iteration calls fitted_hc_traj() with a NULL fitter and then walks
 *   res_traj->path along the hook links.  Configurations are handed to
 *   fitter_observe1() with their feats; the target value is the
 *   configuration's own oc->store when gov->just_smooth_objval is set,
 *   otherwise the trajectory's end value (res_traj->best->store).
 *   Path states that have a successor (non-NULL hook) are observed only
 *   when gov->fit_intermediates is set; the last state on the path is
 *   always observed.
 *
 * Evaluation phase (gov->neval_traj iterations):
 *   Each iteration calls run_stage_hc() with the two-entry fitter list
 *   { gov->fit, NULL } and keeps a copy of the best configuration found
 *   so far in res->best.
 *
 * On return res->tot_evaluations is the sum of the tot_evaluations values
 * reported by all training and evaluation runs.
 *
 * NOTE(review): the path walk dereferences res_traj->path without a NULL
 * check; presumably fitted_hc_traj() always leaves at least one state on
 * the path -- confirm.
 */
void do_stage(Opgov *gov, Opworld *ow, Opresults *res)
{
  int i;
  int nevals=0;
  fitter *fits[MAX_STAGES];
  fitter *fit = gov->fit;
  fit_hints *fh = gov->fh;
  Opresults res_traj[1];
  stat objval_st[2];

  init_opresults(gov, ow, res_traj);
  init_fitter(fit, fh);
  init_stat(objval_st);   /* objval_st[0] doubles as the training statistic */

  /*
   * first pass implementation (2-stage only):
   * learn function q2(x) = map x -> hillclimb(obj(x))
   *
   * new optimizer = hillclimb on q2(x), then on obj(x)!
   *
   */

  if (gov->draw_freq) draw_opconfig(ow, NULL);

  for (i=0; i<gov->ntrain_traj; i++) {
    Opconfig *oc;
    int j;
    double endval;

    reinit_opresults(gov, ow, res_traj);
    fitted_hc_traj(gov, ow, NULL, res_traj);
    endval = res_traj->best->store;
    nevals += res_traj->tot_evaluations;

    /* visit every path state that has a successor; j labels them 0,-1,-2.. */
    for (j=0, oc = res_traj->path;
         oc->hook;
         j--, oc=oc->hook) {

      if (gov->fit_intermediates) {  /* fit intermediate states, too */
        fitter_observe1(fit, oc->feats,
                        gov->just_smooth_objval ? oc->store : endval);
        IFVERB(15) {
          /* j grows arbitrarily negative on long paths, so the label buffer
           * must hold any int plus "> " and the terminator; snprintf bounds
           * the write regardless. */
          char str[32];
          snprintf(str, sizeof str, "%d> ", j);
          print_opconfig(str, ow, oc);
        }
      }

    }

    /* oc is now the last state on the path (hook == NULL) */
    if (gov->draw_freq) draw_opconfig(ow, oc);
    IFVERB(15) {
      print_opconfig(".. ", ow, oc);
      printf("objective: %g ---> %g\n", oc->store, endval);
    }
    fitter_observe1(fit, oc->feats,
                    gov->just_smooth_objval ? oc->store : endval);

    stat_update(objval_st, endval);
    IFVERB(10) {
      printf("{TRAIN: this=%g; n=%d, min=%g, mean=%g, max=%g, stdev=%g}\n",
             endval, stat_n(objval_st), stat_min(objval_st),
             stat_mean(objval_st), stat_max(objval_st), stat_sdev(objval_st));
    }
  }

  print_fitter(fit);

  /* stage 0 uses the trained fitter, stage 1 a NULL fitter; the per-stage
   * statistics are reset so training values do not leak into them */
  fits[0] = fit;   init_stat(&objval_st[0]);
  fits[1] = NULL;  init_stat(&objval_st[1]);

  for (i=0; i<gov->neval_traj; i++) {
    reinit_opresults(gov, ow, res_traj);
    run_stage_hc(gov, ow, fits, 2, objval_st, res_traj);

    /* the first run always replaces whatever res->best held on entry */
    if (i==0 || is_better_opconfig(ow, res_traj->best, res->best)) {
      free_opconfig(ow, res->best);
      res->best = mk_copy_opconfig(ow, res_traj->best);
    }
    nevals += res_traj->tot_evaluations;
  }

  fitter_free_data(fit, fh);
  free_opresults_internals(gov, ow, res_traj);

  res->tot_evaluations = nevals;
}



#endif /* ENABLE_STAGE */
