/* this class provides functionality for learning a formation from data */

#include <iomanip>
#include "FormationLearner.h"
#include "ServerParam.h"
#include "Logger.h"
#include "CoachParam.h"
#include "misc.h"

using namespace spades;

// Sizes every per-player container to the server's team size and shapes the
// pairwise-distance table as a lower triangle (diagonal included): row r ends
// up with r+1 default-constructed summaries, which is the layout that
// getSdsPlayerDist indexes into.
FormationLearner::FormationLearner()
  : side_to_learn(TS_None),
    vSdsXPos(ServerParam::instance()->getSPTeamSize()),
    vSdsYPos(ServerParam::instance()->getSPTeamSize()),
    vvSdsPlayerDist(ServerParam::instance()->getSPTeamSize()),
    vvPos(ServerParam::instance()->getSPTeamSize())
{
  const int team_size = ServerParam::instance()->getSPTeamSize();
  for (int row = 0; row < team_size; ++row) {
    // Rows start out empty, so assigning row+1 entries fills them exactly.
    vvSdsPlayerDist[row].assign(row + 1, SingleDataSummary());
  }
}

// The destructor body is intentionally empty; nothing is released explicitly here.
FormationLearner::~FormationLearner() {}


// reads data (created by recordCurrDataToFile) into summary structures
bool
FormationLearner::loadDataFrom(std::istream& is)
{
  while (skip_to_non_comment(is))
    {
      std::vector< std::pair<bool,VecPosition> > vPosData;
      int time;
      is >> time;
      for (int num=0; num<ServerParam::instance()->getSPTeamSize(); num++)
	{
	  bool valid;
	  VecPosition pos;
	  is >> valid >> pos;
	  if (is.fail())
	    {
	      errorlog << "loadDataFrom: format error" << ende;
	      return false;
	    }
	  vPosData.push_back( std::make_pair(valid, pos) );
	}
      addToSummaries(vPosData);
    }
  return true;
}

// Writes one observation record for the current world state to `o`, in the
// format read back by loadDataFrom:
//   <time> then, for each slot 1..team_size, "<valid> <position> "
// followed by a newline (std::endl, so the stream is flushed per record).
//
// Slot layout: the goalie is written in slot 1 and player 1 is written in the
// goalie's slot (i.e. numbers 1 and gnum are swapped). The swap is only
// applied when the reported goalie number is a usable uniform number other
// than 1; otherwise players are written under their own numbers.
//
// A player that the world state does not know about is written as valid bit 0
// with position (0,0). Positions of the right-hand side are passed through
// flipCoords() before being written.
//
// Returns false (after logging) if side_to_learn is neither TS_Left nor
// TS_Right; returns true otherwise.
bool
FormationLearner::recordCurrDataToFile(std::ostream& o,const WorldState& ws)
{
  if (side_to_learn != TS_Left && side_to_learn != TS_Right)
    {
      errorlog << "recordCurrDataToFile: side_to_learn bad: " << side_to_learn << ende;
      return false;
    }

  const int team_size = ServerParam::instance()->getSPTeamSize();
  const int gnum = ws.getGoalieNum(side_to_learn);
  // Guard against an out-of-range goalie number so that slot 1 is never
  // redirected to a player number that cannot exist.
  const bool swap_goalie = (gnum > 1 && gnum <= team_size);

  o << ws.getTime() << " ";
  for (int slot = 1; slot <= team_size; slot++)
    {
      // Which player's data goes into this slot (1 <-> gnum when swapping).
      int player_num = slot;
      if (swap_goalie)
        {
          if (slot == 1)
            player_num = gnum;
          else if (slot == gnum)
            player_num = 1;
        }

      const PlayerInfo* pi = ws.getPlayer(side_to_learn, player_num);
      if (pi == NULL)
        {
          o << 0 << ' ' << VecPosition(0,0) << ' ';
        }
      else
        {
          VecPosition v = pi->getPos();
          if (side_to_learn == TS_Right)
            v = v.flipCoords();
          o << 1 << ' ' << v << ' ';
        }
    }
  o << std::endl;
  return true;
}

// Writes the '#'-prefixed comment header that precedes the records produced
// by recordCurrDataToFile. The format line describes a full record: the time
// stamp first, then one <validbit> <pos> pair per player.
// Always returns true.
bool
FormationLearner::writeRecordHeader(std::ostream& o)
{
  o << "# This file was generated by FormationLearner (OWL coach)" << std::endl
    << "# the method recordCurrDataToFile records this to be read by loadDataFrom later" << std::endl
    << "# Format: <time> [<validbit> <pos of player i>] (repeated team_size times)" << std::endl;
  return true;
}


// takes the current positions and loads them into the summary structures
bool
FormationLearner::addCurrentToSummary(const WorldState& ws)
{
  //Note that we do NOT do anything special for the goalie here!
  // we'll let this be taken care of elsewhere
  std::vector< std::pair<bool,VecPosition> > vPosData;
  for (int num=1; num<=ServerParam::instance()->getSPTeamSize(); num++)
    {
      const PlayerInfo* pi = ws.getPlayer(side_to_learn, num);
      if (pi == NULL)
	vPosData.push_back( std::make_pair(false, VecPosition(0,0)) );
      else
	vPosData.push_back( std::make_pair(true, pi->getPos()) );
    }
  addToSummaries(vPosData);
  return true;
}


// Derives one home rectangle per player from the accumulated summaries and
// installs them in `f`.
//
//   Stage 1: each player's rectangle starts as mean +/- one standard deviation
//            of the observed x/y positions and is then refined by
//            randomExplore against that player's recorded positions.
//   Stage 2: climbOnPlayerDistances shifts the rectangle centres based on the
//            observed pairwise distances.
//
// The resulting rectangles are handed to f.setHomeRegion as freshly allocated
// rcss::clang::RegRec objects (presumably `f` takes ownership of them --
// confirm against Formation). Progress dots go to std::cout and the distance
// matrices are written to the action log. Always returns true.
bool
FormationLearner::learnFormation(Formation& f)
{
  const int team_size = ServerParam::instance()->getSPTeamSize();

  actionlog(20) << "Now learning a formation based on observations" << ende;

  // Matrix of mean observed distances between every pair of players.
  actionlog(40) << "Formation learning: Target distances" << ende;
  for (int a = 1; a <= team_size; ++a) {
    for (int b = 1; b <= team_size; ++b) {
      actionlog(40) << std::setw(7) << std::setprecision(4) << getSdsPlayerDist(a, b).getMean();
    }
    actionlog(40) << ende;
  }

  std::cout << "Learning formation (stage 1): " << std::flush;
  std::vector<Rectangle> vRects(team_size);
  for (int idx = 0; idx < team_size; ++idx) {
    actionlog(30) << "Working on player " << (idx + 1) << ende;
    VecPosition center(vSdsXPos[idx].getMean(), vSdsYPos[idx].getMean());
    VecPosition size(vSdsXPos[idx].getStdDev(), vSdsYPos[idx].getStdDev());
    // Seed rectangle: one standard deviation either side of the mean position.
    vRects[idx] = randomExplore(Rectangle(center-size, center+size), vvPos[idx]);
    std::cout << "." << std::flush;
  }
  std::cout << std::endl;

  //#define TEST_LEARN
#ifdef TEST_LEARN
#error This code has not been converted
  //this is so we can compare formations before and after phase 2 learning
  for (int num=1; num<=ServerParam::instance()->getSPTeamSize(); num++)
    f.setHomeRegion(num, RegionPtr(new RegQuad(rects[num-1])));

  char newname[100];
  setFormationNameFromFn(&f, Mem->CP_formation_learn_output_fn);
  sprintf(newname, "%s.phase1", f.getName());
  f.setName(newname);
  lForm.push_back(f);

  char fn[MAX_FILE_LEN];
  sprintf(fn, "%s.phase1", Mem->CP_formation_learn_output_fn);
  ofstream outtmp(fn);
  if (!outtmp)
    my_error("Could not open phase 1 learned formation output file '%s'", fn);
  outtmp << f;
  outtmp.close();

#endif

  // Centre-to-centre distances after stage 1.
  actionlog(40) << "Formation learning: Phase 1 distances" << ende;
  for (int a = 0; a < team_size; ++a) {
    for (int b = 0; b < team_size; ++b) {
      actionlog(40) << std::setw(7) << std::setprecision(5)
                    << vRects[a].getCenter().getDistanceTo(vRects[b].getCenter());
    }
    actionlog(40) << ende;
  }

  std::cout << "Learning formation (stage 2): " << std::flush;
  climbOnPlayerDistances(vRects, true);
  std::cout << std::endl;

  // Centre-to-centre distances after stage 2.
  actionlog(40) << "Formation learning: Phase 2 distances" << ende;
  for (int a = 0; a < team_size; ++a) {
    for (int b = 0; b < team_size; ++b) {
      actionlog(40) << std::setw(7) << std::setprecision(5)
                    << vRects[a].getCenter().getDistanceTo(vRects[b].getCenter());
    }
    actionlog(40) << ende;
  }

  // Player numbers are 1-based; vRects is 0-based.
  for (int idx = 0; idx < team_size; ++idx) {
    f.setHomeRegion(idx + 1, new rcss::clang::RegRec(vRects[idx]));
  }

  return true;
}

void
FormationLearner::addToSummaries(const std::vector< std::pair<bool,VecPosition> >& v)
{
  /* Set average position and store current position */
  for (unsigned num=0; num < v.size(); num++) {
    if (!v[num].first) continue;
    vSdsXPos[num].addPoint(v[num].second.getX());
    vSdsYPos[num].addPoint(v[num].second.getY());
    vvPos[num].push_back(v[num].second);
  }

  /* handle the distance between players stuff */
  for (unsigned num1=0; num1 < v.size(); num1++) {
    if (!v[num1].first) continue;
    for (unsigned num2=0; num2 <= num1; num2++) {
      if (!v[num2].first) continue;
      float d = v[num1].second.getDistanceTo(v[num2].second);
      //LogAction5(10, "adding %d->%d distance %.3f", num1, num2, d);
      getSdsPlayerDist(num1+1, num2+1).addPoint(d);
    }
  }
}


// Returns the distance summary for the unordered pair of 1-based player
// numbers (num1, num2). The table is stored as a lower triangle, so the
// larger number selects the row and the smaller one the column; the argument
// order therefore does not matter. Arguments are not range-checked.
SingleDataSummary&
FormationLearner::getSdsPlayerDist(int num1, int num2)
{
  const int row = (num2 <= num1) ? num1 : num2;
  const int col = (num2 <= num1) ? num2 : num1;
  return vvSdsPlayerDist[row-1][col-1];
}

// Scores a candidate rectangle given the fraction of observed points it
// contains. The score is a blend, controlled by the coach parameter
// FormationPointWeight (w):
//   w       * cube root of point_fraction
//   (1 - w) * (1 - area / 900)
// so the area term is 1 for a zero-area rectangle, 0 at an area of 900 and
// negative beyond that.
float
FormationLearner::simpleEvalFunction(const Rectangle& r, float point_fraction)
{
  const float max_area = 900;

  const double coverage_term =
    CoachParam::instance()->getFormationPointWeight() * pow(point_fraction, 1.0/3.0);
  const double area_term =
    (1.0-CoachParam::instance()->getFormationPointWeight()) * (-1/max_area * r.getArea() + 1);

  return coverage_term + area_term;
}

float
FormationLearner::simpleEvalFunction(const Rectangle& r, const std::vector<VecPosition>& v)
{
  int num_pts_int = 0;
  int num_pts_ext = 0;
  for (std::vector<VecPosition>::const_iterator iter = v.begin();
       iter != v.end();
       ++iter) {
    if (r.isInside(*iter))
      num_pts_int++;
    else
      num_pts_ext++;
  }
  return simpleEvalFunction(r, (float)num_pts_int / (float)(num_pts_int + num_pts_ext));
}

// Random search for a better axis-aligned rectangle around `start_rect`.
//
// For up to FormationREMaxSteps iterations, each of the four edges of
// start_rect is perturbed independently with gaussian_sample (spread
// FormationREStDev). Candidates are always drawn around start_rect, not
// around the best rectangle found so far. A candidate is discarded when it is
// degenerate (left/top not strictly before right/bottom by more than EPSILON)
// or when any edge lies outside the pitch, taken here to span
// [-length/2, length/2] x [-width/2, width/2]. Remaining candidates are
// scored with simpleEvalFunction against `vpos`.
//
// Returns the best-scoring rectangle seen, or start_rect itself if no
// candidate scores strictly higher.
Rectangle
FormationLearner::randomExplore(Rectangle start_rect, const std::vector<VecPosition>& vpos)
{
  const double half_length = ServerParam::instance()->getSPPitchLength() / 2.0;
  const double half_width  = ServerParam::instance()->getSPPitchWidth() / 2.0;

  Rectangle best_rect = start_rect;
  float best_val = simpleEvalFunction(start_rect, vpos);

  for (int step=0; step < CoachParam::instance()->getFormationREMaxSteps(); step++)
    {
      // Draw order (left, right, top, bottom) is fixed so the sequence of
      // random samples consumed per step stays the same.
      float left =   gaussian_sample(start_rect.getPosLeftTop().getX(),
                                     CoachParam::instance()->getFormationREStDev());
      float right =  gaussian_sample(start_rect.getPosRightBottom().getX(),
                                     CoachParam::instance()->getFormationREStDev());
      float top =    gaussian_sample(start_rect.getPosLeftTop().getY(),
                                     CoachParam::instance()->getFormationREStDev());
      float bottom = gaussian_sample(start_rect.getPosRightBottom().getY(),
                                     CoachParam::instance()->getFormationREStDev());

      // Reject inverted or (near) zero-size rectangles.
      if (left >= right-EPSILON || top >= bottom-EPSILON)
        continue;

      // Reject rectangles that stick out of the pitch. The lower bounds are
      // the NEGATIVE half-extents; comparing left/top against the positive
      // half-extents would reject every possible candidate.
      if (left < -half_length ||
          right > half_length ||
          top < -half_width ||
          bottom > half_width)
        continue;

      Rectangle test_rect(VecPosition(left, top), VecPosition(right, bottom));
      float value = simpleEvalFunction(test_rect, vpos);
      if (value > best_val)
        {
          actionlog(210) << "I found a better rectangle on step " << step << ": "
                         << std::setprecision(4) << value << ende;
          best_rect = test_rect;
          best_val  = value;
        }
    }
  return best_rect;
}



// Adjusts the centres of the per-player rectangles so that the distances
// between centres move toward the mean observed distances between players.
//
// vRects      - one rectangle per player (index 0 is player 1); on return each
//               rectangle has been replaced by vRects[i].shiftCenter(<new centre>).
// show_status - when true, prints a '.' to std::cout every 100 steps.
//
// Runs FormationHCMaxSteps iterations. Each iteration picks one player p via
// int_random(team size) (presumably a uniform index in [0, team size) --
// confirm against int_random), accumulates an update vector from every other
// player's current centre plus p's own original centre, and moves p's current
// centre by that vector times FormationHCLRate.
void
FormationLearner::climbOnPlayerDistances(std::vector<Rectangle>& vRects, bool show_status)
{
  // alpha = alpha_y_int * exp(alpha_slope * target_dist): pairs with a larger
  // mean observed distance get a smaller weight.
  const float alpha_slope = -.01;
  const float alpha_y_int = 0.1;
  // Relative weight of the pull toward a player's own original centre.
  const float orig_pos_factor = .5;
  std::vector<VecPosition> orig_center(ServerParam::instance()->getSPTeamSize());
  std::vector<VecPosition> curr_center(ServerParam::instance()->getSPTeamSize());
  for (int num=0; num<ServerParam::instance()->getSPTeamSize(); num++)
    {
      curr_center[num] = orig_center[num] = vRects[num].getCenter();
    }

  int step;
  for (step = 0; step < CoachParam::instance()->getFormationHCMaxSteps(); step++)
    {
      // The single player whose centre is updated on this step.
      int p = int_random(ServerParam::instance()->getSPTeamSize());

      /* Now we need to calculate the gradient for the centre of this player's
	 rectangle.
	 The evaluation function is
	 OLD: alpha * (curr_dist / ideal_dist - 1)^2
	 alpha * curr_dist_off^2
	 where curr_dist_off is (current distance - target distance) and
	 alpha decreases with the length of target_dist
      */
      //SMURF: do we want to decrease the contribution of the player's original position?
      VecPosition gradient(0,0);
      for (int num=0; num<ServerParam::instance()->getSPTeamSize(); num++)
	{
	  VecPosition disp;
	  float extra_factor;
	  if (num == p)
	    {
	      // Own term: displacement back to where p's centre started. The
	      // target distance used below is then the (p,p) summary entry,
	      // which addToSummaries feeds with a position's distance to itself.
	      disp = orig_center[num] - curr_center[p];
	      extra_factor = orig_pos_factor;
	    }
	  else
	    {
	      // Pair term: displacement from p's centre to the other player's centre.
	      disp = curr_center[num] - curr_center[p];
	      extra_factor = 1.0;
	    }
      
	  /* This is not really the right thing! We probably want to go in *any* direction
	     (if target dist is > 0). As written, a zero displacement contributes nothing. */
	  if (disp == VecPosition(0,0)) continue;
	  /* This was my old code; alpha was 0.5
	     if (getSdsPlayerDist(p+1, num+1).getMean() <= 1.0) continue;
	     gradient += disp.scaleTo(2 * alpha * (disp.mod() / getSdsPlayerDist(p+1, num+1).getMean() - 1.0 ));
	  */
	  float alpha =
	    alpha_y_int * exp (alpha_slope * getSdsPlayerDist(p+1, num+1).getMean());
	  //one can also consider the extra_factor to multiply alpha
	  // The contribution points along disp with length
	  // 2 * alpha * (current distance - target distance); presumably a
	  // negative length reverses the direction -- confirm against scaleTo.
	  // NOTE(review): VecPosition(extra_factor) * <vector> relies on the
	  // single-argument constructor and operator* of VecPosition; confirm
	  // that this scales both components by extra_factor.
	  gradient += VecPosition(extra_factor) * 
	    disp.scaleTo(2 * alpha * (disp.getMagnitude() - getSdsPlayerDist(p+1, num+1).getMean()));
	  actionlog(240) << "gradient piece "
			 << 2 * alpha * (disp.getMagnitude() - getSdsPlayerDist(p+1, num+1).getMean())
			 << ende;
	  if (gradient.isnan() || gradient.isinf())
	    {
	      errorlog << "gradient is invalid! "
		       << alpha << " " 
		       << disp.getMagnitude() << " " 
		       << getSdsPlayerDist(p+1, num+1).getMean() << " " 
		       << ende;
	      errorlog << "More info: "
		       << p << " "
		       << num << " "
		       << curr_center[p] << " "
		       << curr_center[num] << " "
		       << ende;
	      // Drops everything accumulated so far on this step, including
	      // the valid contributions of earlier players; later players in
	      // the loop still add their terms.
	      gradient = VecPosition(0,0);
	    }
	}
      actionlog (210) << "gradient "
		      << gradient.getMagnitude() << " "
		      << gradient << " " 
		      << ende;
      // A very large update is only reported; it is still applied below.
      if (gradient.getMagnitude() > 1000)
	{
	  errorlog << "Gradient is big on step " << step << ": " << gradient.getMagnitude() << ende;
	}
      curr_center[p] += gradient * CoachParam::instance()->getFormationHCLRate();

      if (show_status && step % 100 == 0)
	std::cout << "." << std::flush;
    }

  // Write the adjusted centres back into the caller's rectangles.
  for (int num=0; num<ServerParam::instance()->getSPTeamSize(); num++)
    {
      vRects[num] = vRects[num].shiftCenter(curr_center[num]);
    }
}
