/*
   File:        wrld.c
   Author:      Andrew W. Moore
   Created:     Sat Sep 19 11:29:33 EDT 1992
   Description: Spec. of the World Control problem

   Copyright (C) 1992, Andrew W. Moore
*/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "ambs.h"      /* Very basic operations */
#include "amma.h"      /* Fast, non-fragmenting, memory management */
#include "amar.h"      /* Obvious operations on 1-d arrays */
#include "amgr.h"      /* Basic (0,512)x(0,512) Graphics window */
#include "maxdim.h"    /* The MAX_DIM declaration */
#include "gpro.h"      /* Graphics projections kd->2d space */
#include "hype.h"      /* Hyper-rectangles from ../kdtr */
#include "region.h"    /* K-d region Data Structure */
#include "wrld.h"      /* Spec. of the World Control problem */
#include "whis.h"      /* History of worsts */

void init_worst(wst,mode)
worst *wst;
int mode;
/*
   Sets the mode of *wst. The continuous state vector (nstate) is not
   touched.
*/
{
  worst *target = wst;
  target -> mode = mode;
}

worst *create_worst(mode,size)
int mode;
int size;
/*
   Allocates a worst with AM_MALLOC and sets its mode via init_worst().
   The state vector is not initialized here, and "size" is currently
   unused. free_worst() is the matching deallocator in this file.
*/
{
  worst *result;

  result = AM_MALLOC(worst);
  init_worst(result,mode);
  return result;
}

void free_worst(wst)
worst *wst;
/*
   Gives a worst (e.g. one from create_worst()) back to the am_ memory
   manager, telling it the size of a worst.
*/
{
  char *raw = (char *) wst;
  am_free(raw,sizeof(worst));
}

void copy_worst(wst1,wst2,size)
worst *wst1,*wst2;
int size;
/*
   Copies wst1 into wst2: the mode, plus the first "size" entries of the
   state vector nstate. Direction is wst1 (source) -> wst2 (destination).
*/
{
  worst *src = wst1;
  worst *dst = wst2;

  dst -> mode = src -> mode;
  copy_floats(src->nstate,dst->nstate,size);
}

mtrans *add_to_mtrans(reg,next_worst,mts)
region *reg;
worst *next_worst;
mtrans *mts;
/*
   Pushes a new mode transition (region reg -> next_worst) onto the front
   of list mts and returns the new head. reg and next_worst are stored by
   pointer, not copied.
*/
{
  mtrans *head = AM_MALLOC(mtrans);

  head -> next = mts;
  head -> next_worst = next_worst;
  head -> region = reg;
  return head;
}

void init_space_spec(spsp,dim)
space_spec *spsp;
int dim;
/*
   Records the dimensionality of a space and attaches a freshly created
   bounding hyper-rectangle (create_hype) of that dimension.
*/
{
  space_spec *target = spsp;

  target -> dim = dim;
  target -> bound = create_hype(dim);
}

mode_spec *create_empty_mode_spec(sdim,adim)
int sdim,adim;
/*
   Allocates all necessary memory, apart from name (left NULL). State and
   action spaces get bounding hyper-rectangles from create_hype(), which
   the original author describes as infinite.
   middle, scales and splittable_attributes are allocated with sdim
   entries but their contents are not set here. No mode transitions and
   no callbacks are installed.
*/
{
  mode_spec *msp = AM_MALLOC(mode_spec);

  msp -> name = NULL;

  /* Bounding boxes for the state and action spaces. */
  init_space_spec(&msp->state,sdim);
  init_space_spec(&msp->action,adim);

  /* Per-state-component arrays, to be filled in by the caller. */
  msp -> middle = AM_MALLOC_ARRAY(float,sdim);
  msp -> scales = AM_MALLOC_ARRAY(float,sdim);
  msp -> splittable_attributes = AM_MALLOC_ARRAY(bool,sdim);

  msp -> mode_transitions = NULL;

  /* Optional per-mode callbacks: none yet. */
  msp -> next_worst = NULL;
  msp -> local_control = NULL;
  msp -> draw_worst = NULL;
  msp -> is_stuck = NULL;
  msp -> should_remove_trans = NULL;

  return msp;
}

mode_spec *mode_ref(wld,mode)
world *wld;
int mode;
{
  if ( wld == NULL || mode < 0 || mode >= wld->number_modes )
  {
    if ( wld == NULL ) printf("wld == NULL\n");
    printf("mode = %d\n",mode);
    my_error("mode_ref()");
  }

  if ( wld->modes[mode] == NULL )
    my_error("mode_ref() NULL mode");
  
  return(wld->modes[mode]);
}

int state_dim(wld,mode)
world *wld;
int mode;
{
  if ( wld == NULL || mode < 0 || mode >= wld->number_modes )
  {
    if ( wld == NULL ) printf("wld == NULL\n");
    printf("mode = %d\n",mode);
    my_error("state_dim()");
  }

  return(mode_ref(wld,mode)->state.dim);
}  

void fprintf_worst(s,m1,wst,wld,m2)
FILE *s;
char *m1;
worst *wst;
world *wld;
char *m2;
/*
   Writes m1, then wst (its mode's name, or "nameless", followed by its
   state vector), then m2, to stream s.
   A NULL wst prints as "(worst *)NULL". A mode number >= the world's
   number of modes is reported with my_error().
*/
{
  char *name;

  if ( wst == NULL )
  {
    fprintf(s,"%s (worst *)NULL %s",m1,m2);
    return;
  }

  if ( wst->mode >= wld->number_modes )
  {
    my_error("oufbv");
    return;
  }

  name = mode_ref(wld,wst->mode)->name;
  fprintf(s,"%s%s",m1,( name == NULL ) ? "nameless" : name);
  fprintf_floats(s,"",wst->nstate,state_dim(wld,wst->mode),"");
  fprintf(s,"%s",m2);
}

void fprint_mode(s,m,md,wld)
FILE *s;
char *m;
mode_spec *md;
world *wld;
/*
   Prints every field of mode_spec md to stream s, one field per line,
   with each line prefixed by m. wld is only needed to print the
   next_worst of each mode transition (via fprintf_worst()).

   The per-transition prefixes are built into a local buffer with a
   bounded snprintf(). The earlier unbounded sprintf() into a 100-byte
   buffer overran the stack whenever m was longer than roughly 70
   characters. An overlong prefix is now truncated instead.
*/
{
  fprintf(s,"%s -> name = %s\n",m,(md->name==NULL)?"nameless":md->name);
  fprintf(s,"%s -> state.dim = %d\n",m,md -> state.dim);
  fprintf(s,"%s",m);
  fprintf_hype(s," -> state.bound = ",md -> state.bound,"\n");
  fprintf(s,"%s -> action.dim = %d\n",m,md -> action.dim);
  fprintf(s,"%s",m);
  fprintf_hype(s," -> action.bound = ",md -> action.bound,"\n");
  fprintf(s,"%s",m);
  fprintf_floats(s," -> middle = ",md -> middle,md->state.dim,"\n");
  fprintf(s,"%s",m);
  fprintf_floats(s," -> scales = ",md -> scales,md->state.dim,"\n");
  fprintf(s,"%s",m);
  fprintf_bools(s," -> splittable_attributes = ",
                md -> splittable_attributes,
                md->state.dim,
                "\n"
               );
  if ( md->mode_transitions == NULL )
    fprintf(s,"%s -> mode_transitions = NULL\n",m);
  else
  {
    mtrans *mts;
    int mnum = 1;   /* transitions are numbered from 1 in the listing */
    for ( mts = md->mode_transitions ; mts != NULL ; mts = mts->next )
    {
      char buff[200];
      snprintf(buff,sizeof buff,"%s -> mtrans[%d] -> region = ",m,mnum);
      fprintf_region(s,buff,mts->region,"\n");
      snprintf(buff,sizeof buff,"%s -> mtrans[%d] -> next_worst = ",m,mnum);
      fprintf_worst(s,buff,mts->next_worst,wld,"\n");
      mnum += 1;
    }
  }
}

void fprint_world(s,m,wld)
FILE *s;
char *m;
world *wld;
/*
   Prints every mode of wld (via fprint_mode, with prefix "<m>->mode[i]"),
   then the number of modes and the goal worst, to stream s.
   Calls wait_for_key() after each mode, so this is interactive.

   The per-mode prefix is built with a bounded snprintf(). The earlier
   sprintf() into a 100-byte buffer could overflow for a long m.
*/
{
  int i;
  for ( i = 0 ; i < wld->number_modes ; i++ )
  {
    char buff[200];
    snprintf(buff,sizeof buff,"%s->mode[%d]",m,i);
    fprint_mode(s,buff,mode_ref(wld,i),wld);
    wait_for_key();
  }
  fprintf(s,"%s -> number_modes = %d\n",m,wld->number_modes);
  fprintf(s,"%s",m);
  fprintf_worst(s," -> goal = ",wld->goal,wld,"\n");
}

void perform_mtranses(wld,wst,next_wst)
world *wld;
worst *wst,*next_wst;
/*
   See the discussion of mode transitions in wrld.h.
   This function starts with wst and keeps applying mode transitions to
   it until none applies, then copies the resulting worst into next_wst.

   At each stage the mode's transition list is scanned in order. The
   first transition whose region contains the current state wins, and
   its stored next_worst becomes the current worst.

   If the transitions form a cycle, the original loop never terminated.
   After MTRANS_WARN_AFTER transitions, each further one is reported on
   stderr (as before). After MTRANS_GIVE_UP_AFTER transitions, my_error()
   is called instead of spinning forever.
*/
{
  enum { MTRANS_WARN_AFTER = 50, MTRANS_GIVE_UP_AFTER = 1000 };
  worst *this_worst = wst;
  mode_spec *this_msp = mode_ref(wld,wst->mode);
  int count = 0;

  while ( TRUE )
  {
    mtrans *mts;
    mtrans *apply_me = NULL;

    for ( mts = this_msp -> mode_transitions ;
          apply_me == NULL && mts != NULL ;
          mts = mts -> next )
      if ( is_inside_region(mts->region,this_worst->nstate) )
        apply_me = mts;

    if ( apply_me == NULL )
    {
      /* Fixed point reached: no transition fires from this worst. */
      copy_worst(this_worst,next_wst,this_msp->state.dim);
      return;
    }

    this_worst = apply_me -> next_worst;
    this_msp = mode_ref(wld,this_worst->mode);
    count++;
    if ( count > MTRANS_WARN_AFTER )
      fprintf_worst(stderr,"looper",this_worst,wld,"\n");
    if ( count > MTRANS_GIVE_UP_AFTER )
    {
      my_error("perform_mtranses(): mode transitions appear to cycle");
      /* Only reached if my_error() returns: leave next_wst defined. */
      copy_worst(this_worst,next_wst,this_msp->state.dim);
      return;
    }
  }
}

void perform_world_action(wld,wst,wac,next_wst)
world *wld;
worst *wst;
float *wac;
worst *next_wst;
/* Applies action wac to wst and stores the successor in next_wst.
   The successor is first computed by the mode's next_worst callback, or
   is a plain copy of wst if the mode has none. Mode transitions
   (perform_mtranses) are then applied to it.

   Note, it *IS* okay for wst and next_wst to have the same memory, though
   the individual next-routines do not assume that: they always write into
   a local intermediate worst.
*/
{
  mode_spec *msp = mode_ref(wld,wst->mode);
  worst intermediate;

  if ( msp->next_worst == NULL )
    copy_worst(wst,&intermediate,state_dim(wld,wst->mode));
  else
    msp->next_worst(wld,wst,wac,&intermediate);

  if ( Verbosity > 40.0 )
    printf("int_wst.mode = %d\n",intermediate.mode);

  perform_mtranses(wld,&intermediate,next_wst);
}
  
void step_world(wld,wst,goal_wst)
world *wld;
worst *wst;
worst *goal_wst;
/*
   One control cycle: the action is chosen by the current mode's
   local_control callback (aiming at goal_wst), then applied with
   perform_world_action(), overwriting wst.
   A mode with a non-zero action dimension but no local_control is a
   fatal configuration error. A mode with action.dim == 0 and no
   controller is stepped with an unset (zero-length) action.

   POST: wst contains the new state.
*/
{
  float wac[MAX_DIM];
  mode_spec *msp = mode_ref(wld,wst->mode);

  if ( msp->local_control == NULL )
  {
    if ( msp -> action.dim != 0 )
    {
      fprint_mode(stderr,"msp",msp,wld);
      printf("The above mode has no local control yet action.dim = %d > 0\n",
             msp->action.dim
            );
      my_error("step_world()");
    }
  }
  else
    msp->local_control(wld,wst,goal_wst,wac);

  if ( Verbosity > 30.0 )
    fprintf_floats(stdout,"wac-after-lcon = ",wac,msp->action.dim,"\n");

  perform_world_action(wld,wst,wac,wst);
}

void run_world(wld,wst,goal_wst,steps)
world *wld;
worst *wst;
worst *goal_wst;
int steps;
/*
   Runs "steps" control cycles (see step_world()), each choosing its
   action with the current mode's local controller toward goal_wst.
   A non-positive "steps" does nothing.

   POST: wst contains the final state visited.
*/
{
  int remaining;

  for ( remaining = steps ; remaining > 0 ; remaining-- )
    step_world(wld,wst,goal_wst);
}

int mode_number_called(wld,name)
world *wld;
char *name;
{
  int i;
  int result = -1;
  for ( i = 0 ; result < 0 && i < wld -> number_modes ; i++ )
    if ( eq_string(name,mode_ref(wld,i)->name) )
      result = i;

  return(result);
}

void init_empty_mode_spec_array(wld,size)
world *wld;
int size;
/*
   Gives wld room for "size" modes, with every slot set to NULL. Each slot
   must be filled in before mode_ref() will accept it.
*/
{
  int k;

  wld -> modes = AM_MALLOC_ARRAY(mode_spec_ptr,size);
  wld -> number_modes = size;

  for ( k = size - 1 ; k >= 0 ; k-- )
    wld->modes[k] = NULL;
}

void world_draw_worst(gp,wld,wst)
gproj *gp;
world *wld;
worst *wst;
/*
   Draws wst in projection gp using the draw_worst callback of wst's mode.
   Modes without a draw_worst callback draw nothing.
   A NULL wst is reported with my_error(). The message previously
   misspelled this function's name ("world_darw_worst").
*/
{
  mode_spec *msp;

  if ( wst == NULL )
  {
    my_error("world_draw_worst NULL worst");
    return;
  }

  msp = mode_ref(wld,wst->mode);
  if ( msp -> draw_worst != NULL )
    msp->draw_worst(gp,wld,wst);
}

void start_worst_from_args(wld,start_worst,argc,argv)
world *wld;
worst *start_worst;
int argc;
char *argv[];
/*
   Puts start_worst in the world's "running" mode at that mode's middle.
   If the command line contains
       -start <v0> <v1> ... <v(sdim-1)>
   the state vector is overwritten with those sdim values. The flag is
   ignored if fewer than sdim values follow it.

   Values are parsed with strtod(). A value that is not a number still
   yields 0.0, as atof() did, but a warning is now printed to stderr.
   A world with no "running" mode is reported explicitly rather than
   via mode_ref()'s generic "mode = -1" diagnostic.
*/
{
  int running_mode = mode_number_called(wld,"running");
  int sdim;
  int i;

  if ( running_mode < 0 )
    my_error("start_worst_from_args(): world has no mode called \"running\"");

  sdim = state_dim(wld,running_mode);
  start_worst -> mode = running_mode;
  copy_floats(mode_ref(wld,running_mode)->middle,start_worst->nstate,sdim);

  i = index_of_arg("-start",argc,argv);
  if ( i >= 0 && i + sdim < argc )
  {
    int j;
    for ( j = 0 ; j < sdim ; j++ )
    {
      char *text = argv[i+j+1];
      char *end;
      double value = strtod(text,&end);
      if ( end == text )
        fprintf(stderr,"-start: \"%s\" is not a number, using 0\n",text);
      start_worst->nstate[j] = (float) value;
    }
  }
}

void test_world(argc,argv)
int argc;
char *argv[];
/*
   Interactive test driver for a world. Once a world is loaded it loops
   forever ("while ( TRUE )" has no exit).

   argv[1] names the world type. The whole argc/argv is also passed to
   load_world(). Options read here:
     -verbose <v>   sets the global Verbosity (default 3.0)
     -start ...     start state, read by start_worst_from_args()
   The user picks a start state with the mouse. Each later click picks a
   local goal and runs 1 (left), 10 (middle) or 40 (any other button)
   control steps toward it. Every state is drawn and recorded in a
   worst_hist, and the history is then printed and checked for being
   stuck.
*/
{
  if ( argc < 2 )
    printf("Usage testwrld <worldtype> <worlddetails>\n");
  else
  {
    world wld;
    int run_mode;
    worst wst;
    mode_spec *mds;
    gproj gp;
    worst_hist wh;
    worst start_worst;

    /* -verbose <v>: global debug level; 3.0 when the flag is absent. */
    int i = index_of_arg("-verbose",argc,argv);

    if ( i >= 0 && i < argc-1 )
      Verbosity = atof(argv[i+1]);
    else
      Verbosity = 3.0;

    init_worst_hist(&wh);

    load_world(&wld,argv[1],argc,argv);
    printf("Will fprint_world\n");
    fprint_world(stdout,"wld",&wld);
    printf("printed\n");

    run_mode = mode_number_called(&wld,"running");
    printf("run_mode = %d\n",run_mode);

    /* Default start is the running mode's middle, overridable by -start. */
    start_worst_from_args(&wld,&start_worst,argc,argv);
    copy_worst(&start_worst,&wst,state_dim(&wld,run_mode));

    /* Project state components 0 and 1 of the running mode onto screen. */
    mds = mode_ref(&wld,run_mode);
    printf("get groj..\n");
    gproj_from_m_and_s(&gp,mds->middle,mds->scales,mds->state.dim,0,1);

    gp_clear(&gp);
    if ( wld.draw_structure != NULL )
      wld.draw_structure(&gp,&wld,&wst);

    /* Left/middle click keeps start_worst; right click replaces it
       (see gp_worst_from_mouse). Either way wst receives the result. */
    printf("Click on Start WorSt\n");
    (void) gp_worst_from_mouse(&gp,&wld,&start_worst,&wst);
                            
    world_draw_worst(&gp,&wld,&wst);
    fprintf_worst(stdout,"wst",&wst,&wld,"\n");

    while ( TRUE )
    {
      int button;
      int steps;
      worst goal;

      empty_worst_hist(&wh);

      printf("Click LEFT to run one step to same local goal (click point)\n");
      printf("Click MIDDLE to run 10 steps to same local goal(click point)\n");
      printf("Click RIGHT to run 40 steps to new local goal (click point)\n");

      /* start_worst doubles as the remembered goal here: left/middle
         reuse it, a right click replaces it with a newly entered one. */
      init_worst(&goal,run_mode);
      button = gp_worst_from_mouse(&gp,&wld,&start_worst,&goal);
      fprintf_worst(stdout,"local_goal = ",&goal,&wld,"\n");
      steps = (button==1) ? 1 : (button==2) ? 10 : 40;

      add_to_worst_hist(&wh,&wst,state_dim(&wld,wst.mode));
      printf("Added to worst hist\n");

      for ( ; steps > 0 ; steps-- )
      {
        printf("wst.mode = %d\n",wst.mode);
        printf("goal.mode = %d\n",goal.mode);
        step_world(&wld,&wst,&goal);
        /* NOTE(review): printed before the decrement, so this reports one
           more step than actually remains. */
        printf("%d step%s left\n",steps,(steps==1)?"":"s");
        add_to_worst_hist(&wh,&wst,state_dim(&wld,wst.mode));
        fprintf_worst(stdout,"wst",&wst,&wld,"\n");
        world_draw_worst(&gp,&wld,&wst);
        if ( Verbosity > 30.0 ) wait_for_key();
      }

      /* A stuck check is only meaningful when every recorded worst shares
         one mode. */
      fprint_worst_hist(stdout,"whist",&wh,&wld);
      if ( wh.modes_all_same )
        printf("whist %s stuck.\n",(whist_is_stuck(&wld,&wh))?"is":"isn't");
      else
        printf("whist not stuck -- has differing modes\n");
    }
  }
}

bool is_scaled_worst_dist_below(wld,w1,w2,epsilon)
world *wld;
worst *w1,*w2;
float epsilon;
{
  if ( w1 == NULL || w2 == NULL || w1->mode != w2 -> mode )
    my_error("whist.c, is_scaleisdub");
  else
  {
    mode_spec *msp = mode_ref(wld,w1->mode);
    float dsqd = floats_antiscaled_dsqd(w1->nstate,
                                          w2->nstate,
                                          msp->scales,
                                          msp->state.dim
                                         );
    return(dsqd < epsilon * epsilon);
  }
}

   
int gp_worst_from_mouse(gp,wld,start_wst,wst)
gproj *gp;
world *wld;
worst *start_wst,*wst;
/*
   Reads a mouse click in projection gp into wst (in the "running" mode)
   and returns the button number.
   LEFT OR MID CLICK = the clicked point is discarded and wst becomes a
                       copy of start_wst.
   RIGHT (button 3)  = keep the clicked point. In the "arm" world every
                       state component, and otherwise (when sdim > 2)
                       each component not shown on the x/y axes, is
                       typed in by the user. The result is then saved
                       back into start_wst.
*/
{
  int run_mode = mode_number_called(wld,"running");
  int sdim = state_dim(wld,run_mode);
  int button;

  wst -> mode = run_mode;
  button = gp_use_mouse(gp,wst->nstate);

  if ( button != 3 )
  {
    copy_worst(start_wst,wst,sdim);
    return(button);
  }

  {
    bool enter_all = eq_string(wld->name,"arm");
    if ( enter_all || sdim > 2 )
    {
      int comp;
      for ( comp = 0 ; comp < sdim ; comp++ )
      {
        bool on_screen = ( comp == gp -> x.comp || comp == gp -> y.comp );
        if ( enter_all || !on_screen )
        {
          char prompt[200];
          sprintf(prompt,"Buff initial value of state variable %d> ",comp);
          wst->nstate[comp] = input_float(prompt);
        }
      }
    }
  }

  copy_worst(wst,start_wst,sdim);
  return(button);
}

void normalize_to_legality(wld,mode,wac)
world *wld;
int mode;
float *wac;
/*
   Finds the longest vector which points in the same direction as
   wac, which is no longer than wac, and which does not exceed the
   action bounds for this mode. This is all relative to the action origin,
   which is assumed at (0,0, .. 0)

   This is a utility used optionally in the local controller of some of the
   worlds.
*/
{
  mode_spec *msp = mode_ref(wld,mode);
  int adim = msp -> action.dim;
  float multiple = 1.0;
  int i;
  hype *hy = msp->action.bound;

  if ( Verbosity > 19.0 )
    fprintf_floats(stdout,"\n------\nwac before: ",wac,adim,"\n");
  if ( Verbosity > 19.0 )
    fprintf_hype(stdout,"bound hype: ",hy,"\n");
  
  for ( i = 0 ; i < adim ; i++ )
  {
    if ( fabs(wac[i]) >= 1e-6 )
    {
      float magi;
      if ( wac[i] < 0.0 )
        magi = hype_ref(hy,i,LO) / wac[i];
      else
        magi = hype_ref(hy,i,HI) / wac[i];
      if ( Verbosity > 19.0 )
        printf("mag(%d) = %g\n",i,magi);
      multiple = real_min(multiple,magi);
    }
  }

  floats_scalar_multiply(wac,adim,multiple,wac);
  if ( Verbosity > 19.0 )
    printf("multiple = %g\n",multiple);
  if ( Verbosity > 19.0 )
    fprintf_floats(stdout,"wac after : ",wac,adim,"\n------\n\n");
}


  
