/* This set of classes captures differences between abstract state spaces */

#include "AbstractStateDifference.h"
#include "AbstractStateElements.h"
#include "MarkovChain.h"
#include "utility.h"
#include "ServerParam.h"
#include "Logger.h"

using namespace spades;

/***************************************************************************************/
/***************************************************************************************/
/***************************************************************************************/

// Builds a comparator over two abstract states.
// Either pointer may be NULL (a state can be supplied later through
// setFirstState/setSecondState). When both are given, checkStates verifies
// that they share the same state description; a failure is logged, exactly
// as the setters do, instead of being silently discarded.
AbstractStateCompare::AbstractStateCompare(AbstractState* first_state,
					   AbstractState* second_state)
  : first_state(first_state), second_state(second_state)
{
  if (!checkStates(true))
    errorlog << "AbstractStateCompare::AbstractStateCompare: checkStates failed" << ende;
  setupUsedIdxStorage();
}

// The destructor does not delete first_state or second_state; it only has
// members that clean themselves up.
AbstractStateCompare::~AbstractStateCompare()
{
  // intentionally empty
}

// Replaces the first state of the comparison.
// The new pair is re-validated (a NULL state is allowed) and the used-index
// marks are rebuilt for it.
void
AbstractStateCompare::setFirstState(AbstractState* p)
{
  first_state = p;
  bool consistent = checkStates(true);
  if (!consistent)
    {
      errorlog << "AbstractStateCompare::setFirstState: checkStates failed" << ende;
    }
  setupUsedIdxStorage();
}

// Replaces the second state of the comparison.
// The new pair is re-validated (a NULL state is allowed) and the used-index
// marks are rebuilt for it.
void
AbstractStateCompare::setSecondState(AbstractState* p)
{
  second_state = p;
  bool consistent = checkStates(true);
  if (!consistent)
    {
      errorlog << "AbstractStateCompare::setSecondState: checkStates failed" << ende;
    }
  setupUsedIdxStorage();
}

// Sets the "used" mark of leaf factor fac_idx to `used`.
// Returns the mark's previous value. An out-of-range index is logged, nothing
// is modified, and false is returned.
bool
AbstractStateCompare::markUsed(int fac_idx, bool used)
{
  const int num_marks = (signed)used_idx_storage.size();
  const bool in_range = (fac_idx >= 0 && fac_idx < num_marks);
  if (in_range)
    {
      const bool previous = used_idx_storage[fac_idx];
      used_idx_storage[fac_idx] = used;
      return previous;
    }

  errorlog << "markUsed: idx out of range " << fac_idx
	   << ", max=" << used_idx_storage.size()
	   << ende;
  return false;
}

// Resets every used-index mark to false, keeping the number of marks.
void
AbstractStateCompare::clearUsed()
{
  used_idx_storage.assign(used_idx_storage.size(), false);
}
  
// Validates the pair of states being compared.
// If allow_null is false, a NULL state is an error. When both states are
// present, they must have the same AbstractStateDescription pointer.
// Returns true when the pair is acceptable. Every failure is logged.
bool
AbstractStateCompare::checkStates(bool allow_null)
{
  const bool first_missing = (first_state == NULL);
  const bool second_missing = (second_state == NULL);

  if (!allow_null)
    {
      if (first_missing)
	{
	  errorlog << "AbstractStateCompare::checkState: first_state is NULL" << ende;
	  return false;
	}
      if (second_missing)
	{
	  errorlog << "AbstractStateCompare::checkState: second_state is NULL" << ende;
	  return false;
	}
    }

  // Nothing more can be compared until both states are present.
  if (first_missing || second_missing)
    return true;

  if (first_state->getStateDescription() == second_state->getStateDescription())
    return true;

  errorlog << "AbstractStateCompare::checkState: descriptions do not match" << ende;
  return false;
}

void
AbstractStateCompare::setupUsedIdxStorage()
{
  if (first_state == NULL || second_state == NULL)
    return;
  // the two states should have the same state description
  used_idx_storage.resize(first_state->getStateDescription()->getFactor()->getLeafCount());
  clearUsed();
}

// Returns whether all leaf factors that are NOT marked as used have the same
// value in both states, as relaxed by `policy`:
//  - CP_AllowAll: always true.
//  - CP_Strict: every unused factor must be identical.
//  - CP_AllowFromInvalid: a difference is tolerated if the first state's
//    value is -1.
//  - CP_AllowToInvalid: a difference is tolerated if the second state's
//    value is -1.
//  - CP_AllowAllInvalid: either of the two tolerances above applies.
// Returns false (and logs) if either state is NULL, instead of dereferencing it.
bool
AbstractStateCompare::checkUnused(CheckUnusedPolicy policy)
{
  if (policy == CP_AllowAll)
    return true;

  if (first_state == NULL || second_state == NULL)
    {
      errorlog << "AbstractStateCompare::checkUnused: NULL state "
	       << first_state << " " << second_state << ende;
      return false;
    }

  // Both flags are false under CP_Strict, so no difference is tolerated there.
  const bool allow_from_invalid =
    (policy == CP_AllowFromInvalid || policy == CP_AllowAllInvalid);
  const bool allow_to_invalid =
    (policy == CP_AllowToInvalid || policy == CP_AllowAllInvalid);

  for (int fac_idx = 0;
       fac_idx < (signed)used_idx_storage.size();
       fac_idx++)
    {
      if (used_idx_storage[fac_idx])
	continue;
      const int first_val = first_state->getFactorIdx(fac_idx);
      const int second_val = second_state->getFactorIdx(fac_idx);
      if (first_val == second_val)
	continue;
      if (first_val == -1 && allow_from_invalid)
	continue;
      if (second_val == -1 && allow_to_invalid)
	continue;
      actionlog(120) << "checkUnusedSame: policy=" << policy
		     << ", factor " << fac_idx << " differs: "
		     << first_val << " " << second_val << ende;
      return false;
    }
  return true;
}

// Display names for AbstractStateCompare::CheckUnusedPolicy, indexed by the
// enum value (assumed to mirror the enum's declaration order).
static const char* CHECK_UNUSED_POLICY_STRINGS[] = {
  "Strict",
  "AllowToInvalid",
  "AllowFromInvalid",
  "AllowAllInvalid",
  "AllowAll"
};

// Writes the policy's name. Values outside [0, CP_MAX) are written as
// "InvalidCheckUnusedPolicy(<n>)".
std::ostream&
operator<<(std::ostream& os, AbstractStateCompare::CheckUnusedPolicy p)
{
  if (p >= 0 && p < AbstractStateCompare::CP_MAX)
    return os << CHECK_UNUSED_POLICY_STRINGS[(int)p];
  return os << "InvalidCheckUnusedPolicy(" << (int)p << ")";
}


/***************************************************************************************/
/***************************************************************************************/
/***************************************************************************************/

// Creates a pattern element for factors of type factor_type.
// It starts unmatched: no matched factor (NULL) and no matching leaf index (-1).
ASDiffPatternElement::ASDiffPatternElement(ASFactorType factor_type)
  : factor_type(factor_type),
    p_match_factor(NULL),
    match_factor_idx(-1)
{
  // nothing else to initialise
}

// Matches this pattern against the root factor of a state description.
// It delegates to matchTo(AbstractStateFactor*, int), using that overload's
// default first_idx. On success that call fills in the matched factor and
// leaf index. Returns false (and logs) if the description or its root
// factor is NULL.
bool
ASDiffPatternElement::matchTo(AbstractStateDescription* pdesc)
{
  if (pdesc != NULL)
    {
      if (pdesc->getFactor() != NULL)
	return matchTo(pdesc->getFactor());

      errorlog << "ASDiffPatternElement::matchTo: null factor in description" << ende;
      return false;
    }

  errorlog << "ASDiffPatternElement::matchTo: null description" << ende;
  return false;
}

// Tries to match this element against factor pfac, whose first leaf has
// index first_idx. Any previous match is cleared first.
// To match, pfac's type must equal this element's factor_type and the
// subclass's tryMatch must accept it (tryMatch records the leaf index). On
// success the factor is remembered in p_match_factor.
bool
ASDiffPatternElement::matchTo(AbstractStateFactor* pfac, int first_idx)
{
  clearMatch();

  bool matched = false;
  if (pfac == NULL)
    {
      errorlog << "ASDiffPatternElement::matchTo: null factor" << ende;
    }
  else if (pfac->getType() != factor_type)
    {
      actionlog(210) << "matchTo: wrong type, fac="
		     << pfac->getType() << ", mine=" << factor_type << ende;
    }
  else if (!tryMatch(pfac, first_idx))
    {
      actionlog(210) << "matchTo: subclass rejected" << ende;
    }
  else
    {
      actionlog(210) << "matchTo: got a match of type " << factor_type
		     << " idx=" << first_idx << ende;
      p_match_factor = pfac;
      matched = true;
    }
  return matched;
}

// Marks the leaf factor matched by this element as used in pcomp, then lets
// the subclass mark whatever else it covers (e.g. its children).
// And/Or (children) elements do not absorb a leaf of their own:
// ASDiffPatternChildren::tryMatch records first_idx only as the position at
// which the subtree starts. Marking that index would wrongly exempt the
// subtree's first leaf from AbstractStateCompare::checkUnused, so only
// non-children elements mark their index.
void
ASDiffPatternElement::markUsed(AbstractStateCompare* pcomp)
{
  const bool absorbs_leaf = (factor_type != ASF_And && factor_type != ASF_Or);
  if (absorbs_leaf && match_factor_idx != -1)
    pcomp->markUsed(match_factor_idx);
  protMarkUsed(pcomp);
}

// Writes "ASDiffPatternElement(<type>, <matched factor ptr>, <leaf idx>)",
// followed by whatever the subclass appends in printExtra.
std::ostream&
operator<<(std::ostream& os, const ASDiffPatternElement& e)
{
  os << "ASDiffPatternElement(";
  os << e.factor_type << ", ";
  os << e.p_match_factor << ", ";
  os << e.match_factor_idx << ")";
  e.printExtra(os);
  return os;
}

/***************************************************************************************/
// These are convenience functions.
// They allocate new memory (the returned tree is the caller's). The types should be AND or OR.
// The leaf can be anything (including an element with children).
// Returns a newly allocated type1 node with `leaf` as its only child. The
// node deletes its children when destroyed.
ASDiffPatternChildren*
ASDiffPatternElement::makePattern(ASFactorType type1,
				  ASDiffPatternElement* leaf)
{
  ASDiffPatternChildren* proot = new ASDiffPatternChildren(type1);
  proot->addChild(leaf);
  return proot;
}

// Returns a newly allocated two-level chain type1 -> type2 -> leaf.
// The outer node is constructed first, then the inner one.
ASDiffPatternChildren*
ASDiffPatternElement::makePattern(ASFactorType type1, ASFactorType type2,
				  ASDiffPatternElement* leaf)
{
  ASDiffPatternChildren* pouter = new ASDiffPatternChildren(type1);
  ASDiffPatternChildren* pinner = new ASDiffPatternChildren(type2);
  pouter->addChild(pinner);
  pinner->addChild(leaf);
  return pouter;
}

/***************************************************************************************/

// Creates an interior pattern node. Only ASF_And and ASF_Or are expected
// here; any other type is logged but still accepted.
ASDiffPatternChildren::ASDiffPatternChildren(ASFactorType factor_type)
  : ASDiffPatternElement(factor_type)
{
  const bool is_children_type = (factor_type == ASF_And || factor_type == ASF_Or);
  if (!is_children_type)
    errorlog << "ASDiffPatternChildren: I don't think this is a children factor type: "
	     << factor_type << ende;
}

// Deletes every child pattern element added through addChild.
ASDiffPatternChildren::~ASDiffPatternChildren()
{
  deleteptr<ASDiffPatternElement> deleter;
  for (ChildrenStorage::iterator iter = children.begin();
       iter != children.end();
       ++iter)
    deleter(*iter);
}

// Greedy match of this And/Or pattern against factor pfac (matchTo has
// already checked that the types are equal).
// Each pattern child must match a distinct child of pfac. For every pattern
// child, the factor children are tried in order and the first unused one
// that matches is taken; there is no backtracking.
// first_idx is the leaf index at which pfac starts. Each factor child starts
// at first_idx plus the leaf counts of the factor children tried before it.
// When a child matches, this node's matching index is set to first_idx. On
// overall failure, all child matches are cleared.
bool
ASDiffPatternChildren::tryMatch(AbstractStateFactor* pfac, int first_idx)
{
  // Safe downcast: this node is an And or an Or, and so is pfac.
  AbstractStateFactorWChildren* pparent =
    static_cast<AbstractStateFactorWChildren*>(pfac);
  const int num_fac_children = pparent->getNumChildren();
  std::vector<bool> taken(num_fac_children, false);

  bool all_matched = true;
  ChildrenStorage::iterator pat_iter = children.begin();
  while (all_matched && pat_iter != children.end())
    {
      int leaf_start = first_idx;
      int found = -1;
      for (int child = 0; found < 0 && child < num_fac_children; child++)
	{
	  AbstractStateFactor* pchild = pparent->getChild(child);
	  if (!taken[child] && (*pat_iter)->matchTo(pchild, leaf_start))
	    found = child;
	  else
	    leaf_start += pchild->getLeafCount();
	}

      if (found < 0)
	{
	  all_matched = false;
	}
      else
	{
	  // This node absorbs no leaves itself; first_idx only records where
	  // the subtree begins in an AbstractState.
	  setMatchingFactorIdx(first_idx);
	  taken[found] = true;
	  ++pat_iter;
	}
    }

  if (!all_matched)
    clearMatch();

  return all_matched;
}

void
ASDiffPatternChildren::protClearMatch()
{
  std::for_each(children.begin(), children.end(), std::mem_fun(&ASDiffPatternElement::clearMatch));
}

// Forwards markUsed to every child so that the leaves they matched get
// marked in pcomp.
void
ASDiffPatternChildren::protMarkUsed(AbstractStateCompare* pcomp)
{
  ChildrenStorage::iterator iter = children.begin();
  while (iter != children.end())
    {
      ASDiffPatternElement* pchild = *iter;
      pchild->markUsed(pcomp);
      ++iter;
    }
}

void
ASDiffPatternChildren::printExtra(std::ostream& os) const
{
  os << "[";
  std::for_each(children.begin(), children.end(), derefprinter<ASDiffPatternElement>(os));
  os << "]";
}


/***************************************************************************************/

// When this is called, it is already verified that the types are correct, so a
// downcast is safe. If any extra checking needs to be done (like checking children)
// it should be done here
// If this returns true, match_factor_idx must be set!
bool
ASDiffPatternSimple::tryMatch(AbstractStateFactor* pfac, int first_idx)
{
  setMatchingFactorIdx(first_idx);
  return true;
}


/***************************************************************************************/
/***************************************************************************************/
/***************************************************************************************/
//static
// Reads one whitespace-delimited feature name from `is` and returns
// resolveName's result for it. Returns NULL if no name could be read.
// Feature arguments are not parsed yet.
ASDiffFeature*
ASDiffFeature::createFromStream(std::istream& is)
{
  std::string feature_name;
  if ((is >> feature_name).fail())
    return NULL;

  //SMURF: extract the arg part and let the sublass get it
  return resolveName(feature_name);
}

/***************************************************************************************/
// Display names for ASDiffFeature::StateSelection, indexed by the enum value
// (assumed to mirror the enum's declaration order).
static const char* STATE_SELECTION_STRINGS[] = {
  "First", "Second", "Both", "Neither"
};

// Writes the selection's name. Values outside [0, SS_MAX) are written as
// "InvalidStateSelection(<n>)".
std::ostream&
operator<<(std::ostream& os, const ASDiffFeature::StateSelection& s)
{
  if (s >= 0 && s < ASDiffFeature::SS_MAX)
    return os << STATE_SELECTION_STRINGS[(int)s];
  return os << "InvalidStateSelection(" << (int)s << ")";
}


/***************************************************************************************/
// Creates a feature with pattern ppat, which the destructor deletes, and the
// given policy for leaves the pattern does not cover. The pattern is matched
// against a state description lazily, on the first call to match().
ASDiffFeature::ASDiffFeature(ASDiffPatternElement* ppat,
			     AbstractStateCompare::CheckUnusedPolicy policy)
  : ppat(ppat),
    check_unused_policy(policy),
    need_pattern_match(true)
{
}

// Creates a feature without a pattern. Subclasses install one through
// setPattern() in their constructors.
ASDiffFeature::ASDiffFeature(AbstractStateCompare::CheckUnusedPolicy policy)
  : ppat(NULL),
    check_unused_policy(policy),
    need_pattern_match(true)
{
}

ASDiffFeature::~ASDiffFeature()
{
  // deleting a NULL pointer is a no-op, so no check is needed
  delete ppat;
}

// Decides whether this feature describes the transition held in pcomp.
// Steps:
//  * match our pattern against the first state's description (only until
//    it first succeeds; afterwards need_pattern_match is false and the
//    match is reused — this presumes every pcomp uses the same description)
//  * let the subclass check its specific condition (protMatch)
//  * mark the leaves covered by the pattern as used
//  * check that all unused leaves agree between the two states according to
//    check_unused_policy
// Returns false (and logs) if the pattern, pcomp, or either state is NULL.
bool
ASDiffFeature::match(AbstractStateCompare* pcomp)
{
  if (ppat == NULL || pcomp == NULL)
    {
      errorlog << "ASDiffFeature::match: invalid vals " << ppat << " " << pcomp << ende;
      return false;
    }
  // Both states are dereferenced below (pattern match and protMatch).
  if (pcomp->getFirstState() == NULL || pcomp->getSecondState() == NULL)
    {
      errorlog << "ASDiffFeature::match: NULL state " << pcomp->getFirstState()
	       << " " << pcomp->getSecondState() << ende;
      return false;
    }

  pcomp->clearUsed();

  if (need_pattern_match)
    {
      if (!ppat->matchTo(pcomp->getFirstState()->getStateDescription()))
	{
	  // We expect that the pattern will always match
	  warninglog(10) << "ASDiffFeature::match: pattern did not match: " << *ppat << ende;
	  return false;
	}
      need_pattern_match = false;
    }

  // Note that protMatch could set our CheckUnusedPolicy
  if (!protMatch(pcomp))
    {
      actionlog(80) << "ASDiffFeature::match: " << *this << ": protMatch rejected" << ende;
      return false;
    }

  ppat->markUsed(pcomp);

  if (!pcomp->checkUnused(check_unused_policy))
    {
      actionlog(80) << "ASDiffFeature::match: " << *this
		    << ": checkUnused failed" << ende;
      return false;
    }

  actionlog(70) << "ASDiffFeature: got a full match: " << *this << ende;

  return true;
}

// Writes "ASDiffFeature(ppat=<pattern or NULL>, policy=<policy>)", followed
// by whatever the subclass appends in printExtra.
std::ostream&
operator<<(std::ostream& os, const ASDiffFeature& f)
{
  os << "ASDiffFeature(ppat=";
  if (f.ppat != NULL)
    os << *f.ppat;
  else
    os << "NULL";
  os << ", policy=" << f.check_unused_policy << ")";
  f.printExtra(os);
  return os;
}

/***************************************************************************************/

// Applies singleStateMatch to the state(s) picked by `state`: the first
// state for SS_First/SS_Both, the second for SS_Second/SS_Both. Every
// checked state must match. With any other selection nothing is checked
// and the result is true.
bool
ASDiffFeatureSingleState::protMatch(AbstractStateCompare* pcomp)
{
  const bool check_first = (state == SS_First || state == SS_Both);
  const bool check_second = (state == SS_Second || state == SS_Both);

  if (check_first)
    {
      if (singleStateMatch(pcomp->getFirstState()))
	{
	  actionlog(150) << "SingleState: first returned true" << ende;
	}
      else
	{
	  actionlog(150) << "SingleState: first returned false" << ende;
	  return false;
	}
    }

  if (check_second)
    {
      if (singleStateMatch(pcomp->getSecondState()))
	{
	  actionlog(150) << "SingleState: second returned true" << ende;
	}
      else
	{
	  actionlog(150) << "SingleState: second returned false" << ende;
	  return false;
	}
    }

  actionlog(150) << "SingleState: whole thing is true" << ende;
  return true;
}

/***************************************************************************************/

// Goal feature: the pattern is an Or node whose child is the Goal leaf.
// pGoal is kept so singleStateMatch can read the leaf index it matched; the
// pattern tree deletes it.
ASDiffFeatureGoal::ASDiffFeatureGoal(StateSelection state, bool my_side)
  : ASDiffFeatureSingleState(state),
    pGoal(new ASDiffPatternSimple(ASF_Goal)),
    my_side(my_side)
{
  ASDiffPatternChildren* proot = ASDiffPatternElement::makePattern(ASF_Or, pGoal);
  setPattern(proot);
}

// True if pstate's goal leaf is valid (not -1) and names the requested side:
// GoalFactor::ST_Mine when my_side is set, GoalFactor::ST_Theirs otherwise.
bool
ASDiffFeatureGoal::singleStateMatch(AbstractState* pstate)
{
  const int goal_leaf = pGoal->getMatchingFactorIdx();
  int val = pstate->getFactorIdx(goal_leaf);
  actionlog(210) << "Goal: index is " << goal_leaf << ende;

  if (val == -1)
    {
      actionlog(150) << "Goal: Not a goal " << val << ende;
      return false;
    }

  GoalFactor::State goal_state = static_cast<GoalFactor::State>(val);
  const bool wanted = my_side
    ? (goal_state == GoalFactor::ST_Mine)
    : (goal_state == GoalFactor::ST_Theirs);

  if (wanted)
    {
      actionlog(150) << "Goal: right goal "
		     << my_side << " " << goal_state << " " << pstate->getStateIdx() << ende;
      return true;
    }

  actionlog(150) << "Goal: wrong goal "
		 << my_side << " " << goal_state << " " << pstate->getStateIdx() << ende;
  return false;
}

/***************************************************************************************/

// Ball-owner feature: the pattern is Or -> And -> BallOwner leaf.
// pBallOwner is kept so singleStateMatch can read its matched leaf index.
ASDiffFeatureBallOwner::ASDiffFeatureBallOwner(StateSelection state, RelativeTeamSide side)
  : ASDiffFeatureSingleState(state),
    pBallOwner(new ASDiffPatternSimple(ASF_BallOwner)),
    side(side)
{
  ASDiffPatternChildren* proot =
    ASDiffPatternElement::makePattern(ASF_Or, ASF_And, pBallOwner);
  setPattern(proot);
}

// True if pstate's ball-owner leaf equals the owner implied by `side`:
// RTS_None -> ST_Other, RTS_Mine -> ST_Mine, RTS_Theirs -> ST_Theirs.
// Any other side is logged and yields false.
bool
ASDiffFeatureBallOwner::singleStateMatch(AbstractState* pstate)
{
  const BallOwnerFactor::State owner = static_cast<BallOwnerFactor::State>(
    pstate->getFactorIdx(pBallOwner->getMatchingFactorIdx()));

  if (side == RTS_None)
    return owner == BallOwnerFactor::ST_Other;
  if (side == RTS_Mine)
    return owner == BallOwnerFactor::ST_Mine;
  if (side == RTS_Theirs)
    return owner == BallOwnerFactor::ST_Theirs;

  errorlog << "ASDiffFeatureBallOwner: What is side? " << side << ende;
  return false;
}

/***************************************************************************************/

// Free-kick transition feature: the pattern is Or -> And -> DeadBall leaf.
// pDeadBall is kept so protMatch can read its matched leaf index.
ASDiffFeatureFreeKickTran::ASDiffFeatureFreeKickTran(bool my_side)
  : ASDiffFeature(),
    pDeadBall(new ASDiffPatternSimple(ASF_DeadBall)),
    my_side(my_side)
{
  ASDiffPatternChildren* proot =
    ASDiffPatternElement::makePattern(ASF_Or, ASF_And, pDeadBall);
  setPattern(proot);
}

// True when the dead-ball leaf goes from invalid (-1) in the first state to
// a valid value in the second, and that value is the requested side's kick:
// ST_Mine when my_side is set, ST_Theirs otherwise.
bool
ASDiffFeatureFreeKickTran::protMatch(AbstractStateCompare* pcomp)
{
  const int leaf = pDeadBall->getMatchingFactorIdx();
  int first_val = pcomp->getFirstState()->getFactorIdx(leaf);
  int second_val = pcomp->getSecondState()->getFactorIdx(leaf);
  actionlog(210) << "FreeKickTran: index is " << leaf << ende;

  const bool becomes_dead_ball = (first_val == -1 && second_val != -1);
  if (!becomes_dead_ball)
    {
      actionlog(210) << "FreeKickTran: Not free kick transition "
		     << first_val << " " << second_val << ende;
      return false;
    }

  DeadBallFactor::State kick_state = static_cast<DeadBallFactor::State>(second_val);
  const bool right_side = my_side
    ? (kick_state == DeadBallFactor::ST_Mine)
    : (kick_state == DeadBallFactor::ST_Theirs);
  if (right_side)
    return true;

  actionlog(210) << "FreeKickTran: wrong transition "
		 << my_side << " " << kick_state << ende;
  return false;
}

/***************************************************************************************/
// Self-transition feature. It needs no pattern, but ASDiffFeature::match
// rejects a NULL pattern, so an empty Or node is installed.
ASDiffFeatureSelfTran::ASDiffFeatureSelfTran()
  : ASDiffFeature()
{
  ASDiffPatternChildren* pempty = new ASDiffPatternChildren(ASF_Or);
  setPattern(pempty);
}

// True when both states of the comparison have the same state index.
bool
ASDiffFeatureSelfTran::protMatch(AbstractStateCompare* pcomp)
{
  AbstractState* pfrom = pcomp->getFirstState();
  AbstractState* pto = pcomp->getSecondState();
  return pfrom->getStateIdx() == pto->getStateIdx();
}


/***************************************************************************************/
// Free-kick feature: the pattern is Or -> And -> DeadBall leaf.
// pDeadBall is kept so singleStateMatch can read its matched leaf index.
ASDiffFeatureFreeKick::ASDiffFeatureFreeKick(StateSelection state, RelativeTeamSide side)
  : ASDiffFeatureSingleState(state),
    pDeadBall(new ASDiffPatternSimple(ASF_DeadBall)),
    side(side)
{
  ASDiffPatternChildren* proot =
    ASDiffPatternElement::makePattern(ASF_Or, ASF_And, pDeadBall);
  setPattern(proot);
}

// True if pstate's dead-ball leaf equals the value implied by `side`:
// RTS_Mine -> ST_Mine, RTS_Theirs -> ST_Theirs, RTS_None -> ST_Invalid.
// Any other side is logged and yields false.
bool
ASDiffFeatureFreeKick::singleStateMatch(AbstractState* pstate)
{
  const DeadBallFactor::State kick = static_cast<DeadBallFactor::State>(
    pstate->getFactorIdx(pDeadBall->getMatchingFactorIdx()));

  if (side == RTS_Mine)
    return kick == DeadBallFactor::ST_Mine;
  if (side == RTS_Theirs)
    return kick == DeadBallFactor::ST_Theirs;
  if (side == RTS_None)
    return kick == DeadBallFactor::ST_Invalid;

  errorlog << "ASDiffFeatureFreeKick: What is side? " << side << ende;
  return false;
}

/***************************************************************************************/
// Creates the two ball-grid leaf patterns. They are only placed into a
// pattern tree (which then deletes them) by constructPattern().
BallPosDecoder::BallPosDecoder()
  : pBallGrid1(new ASDiffPatternSimple(ASF_BallGrid)),
    pBallGrid2(new ASDiffPatternSimple(ASF_BallGrid))
{
  // nothing else to initialise
}


// Allocates and returns the pattern Or( And(pBallGrid1), And(pBallGrid2) ).
// The returned tree contains pBallGrid1/pBallGrid2 and deletes them when
// destroyed. NOTE(review): calling this twice would put the same leaves
// into two trees.
ASDiffPatternChildren*
BallPosDecoder::constructPattern()
{
  ASDiffPatternChildren* proot =
    ASDiffPatternElement::makePattern(ASF_Or, ASF_And, pBallGrid1);
  ASDiffPatternChildren* psecond =
    ASDiffPatternElement::makePattern(ASF_And, pBallGrid2);
  proot->addChild(psecond);
  return proot;
}
  
// Finds the ball-grid leaf that is valid (not -1) in pstate, trying
// pBallGrid1 before pBallGrid2. ppfac and pgrid_idx are output parameters
// (either may be NULL) and are filled from the first valid leaf.
// Returns whether either factor is valid.
bool
BallPosDecoder::getValidFactor(AbstractState* pstate,
			       BallGridFactor** ppfac,
			       int *pgrid_idx)
{
  return isValid(pBallGrid1, pstate, ppfac, pgrid_idx) ||
	 isValid(pBallGrid2, pstate, ppfac, pgrid_idx);
}

// Checks whether the leaf matched by ppat is valid (not -1) in pstate.
// If it is, the matched factor is written to *ppfac and the leaf value (the
// grid index) to *pgrid_idx, for whichever output pointers are non-NULL.
bool
BallPosDecoder::isValid(ASDiffPatternSimple* ppat, AbstractState* pstate,
			BallGridFactor** ppfac, int *pgrid_idx)
{
  const int grid_val = pstate->getFactorIdx(ppat->getMatchingFactorIdx());
  if (grid_val != -1)
    {
      // safe downcast: the pattern only matches ASF_BallGrid factors
      if (ppfac != NULL)
	*ppfac = (BallGridFactor*)ppat->getMatchingFactor();
      if (pgrid_idx != NULL)
	*pgrid_idx = grid_val;
      return true;
    }
  return false;
}

// Decodes the ball's grid coordinates in pstate into *px and *py.
// Returns false if no ball-grid leaf is valid in pstate, or if the valid
// factor cannot decode its grid index. In either case *px/*py may not have
// been set. Callers read the coordinates whenever this returns true, so a
// decode failure must not be reported as success.
bool
BallPosDecoder::getGridCoords(AbstractState* pstate, int* px, int* py)
{
  BallGridFactor* pfac = NULL;
  int grid_idx = -1;
  if (!getValidFactor(pstate, &pfac, &grid_idx))
    return false;
  if (!pfac->decodeGridPos(grid_idx, px, py))
    {
      errorlog << "BallPosDecoder::getGridCoords: I got a valid factor, but failed in decodeGridPos?" << ende;
      return false;
    }
  return true;
}



/***************************************************************************************/

// Ball-in-rectangle feature. The pattern comes from the decoder, and rect
// is the region the ball's grid cell has to touch.
ASDiffFeatureBallInRect::ASDiffFeatureBallInRect(StateSelection state, const Rectangle& rect)
  : ASDiffFeatureSingleState(state),
    decoder(),
    rect(rect)
{
  ASDiffPatternChildren* proot = decoder.constructPattern();
  setPattern(proot);
}

// True if pstate has a valid ball-grid cell whose rectangle intersects rect.
bool
ASDiffFeatureBallInRect::singleStateMatch(AbstractState* pstate)
{
  BallGridFactor* pgrid = NULL;
  int cell = 0;
  if (decoder.getValidFactor(pstate, &pgrid, &cell))
    {
      Rectangle cell_rect(pgrid->getGridRectangle(cell));
      return cell_rect.doesIntersect(rect);
    }
  return false;
}

/***************************************************************************************/

// Ball-at-penalty-corner feature. Only RTS_Mine and RTS_Theirs are
// meaningful sides; anything else is logged (after the pattern is set).
ASDiffFeatureBallAtPenCorner::ASDiffFeatureBallAtPenCorner(StateSelection state, RelativeTeamSide side)
  : ASDiffFeatureSingleState(state),
    decoder(),
    side(side)
{
  ASDiffPatternChildren* proot = decoder.constructPattern();
  setPattern(proot);

  const bool known_side = (side == RTS_Mine || side == RTS_Theirs);
  if (!known_side)
    errorlog << "ASDiffFeatureBallAtPenCorner: what is side? " << side << ende;
}

// True if pstate has a valid ball-grid cell containing a penalty-area corner.
// For RTS_Mine it checks the right-top and right-bottom corners of
// getSPPenaltyAreaRectangle(true). For any other side it checks the
// left-top and left-bottom corners of getSPPenaltyAreaRectangle(false).
bool
ASDiffFeatureBallAtPenCorner::singleStateMatch(AbstractState* pstate)
{
  BallGridFactor* pgrid = NULL;
  int cell = 0;
  if (!decoder.getValidFactor(pstate, &pgrid, &cell))
    return false;

  Rectangle cell_rect(pgrid->getGridRectangle(cell));
  if (side == RTS_Mine)
    {
      return cell_rect.isInside(ServerParam::instance()->getSPPenaltyAreaRectangle(true).getPosRightTop()) ||
	cell_rect.isInside(ServerParam::instance()->getSPPenaltyAreaRectangle(true).getPosRightBottom());
    }
  return cell_rect.isInside(ServerParam::instance()->getSPPenaltyAreaRectangle(false).getPosLeftTop()) ||
    cell_rect.isInside(ServerParam::instance()->getSPPenaltyAreaRectangle(false).getPosLeftBottom());
}

/***************************************************************************************/
// Ball-move-range feature: constrains how far the ball's grid cell moves
// between the two states (see protMatch for how each limit is disabled).
ASDiffFeatureBallMoveRange::ASDiffFeatureBallMoveRange(int max_x,
						       int max_y,
						       int min_manhattan,
						       int max_manhattan,
						       int min_on_max_dim)
  : ASDiffFeature(),
    decoder(),
    max_x(max_x),
    max_y(max_y),
    min_manhattan(min_manhattan),
    max_manhattan(max_manhattan),
    min_on_max_dim(min_on_max_dim)
{
  ASDiffPatternChildren* proot = decoder.constructPattern();
  setPattern(proot);
}

// True if the ball's grid movement between the two states satisfies every
// enabled limit (dx/dy are absolute grid differences):
//  - dx <= max_x and dy <= max_y (each disabled when negative)
//  - dx + dy >= min_manhattan and dx + dy <= max_manhattan (each disabled
//    when <= 0)
//  - max(dx, dy) >= min_on_max_dim (disabled when negative)
// False if either state has no decodable ball position.
bool
ASDiffFeatureBallMoveRange::protMatch(AbstractStateCompare* pcomp)
{
  int x_from, y_from, x_to, y_to;
  if (!decoder.getGridCoords(pcomp->getFirstState(), &x_from, &y_from) ||
      !decoder.getGridCoords(pcomp->getSecondState(), &x_to, &y_to))
    return false;

  const int dx = abs(x_from - x_to);
  const int dy = abs(y_from - y_to);
  const int manhattan = dx + dy;

  const bool x_ok = (max_x < 0 || dx <= max_x);
  const bool y_ok = (max_y < 0 || dy <= max_y);
  const bool min_dist_ok = (min_manhattan <= 0 || manhattan >= min_manhattan);
  const bool max_dist_ok = (max_manhattan <= 0 || manhattan <= max_manhattan);
  const bool dominant_ok = (min_on_max_dim < 0 || min_on_max_dim <= std::max(dx, dy));

  return x_ok && y_ok && min_dist_ok && max_dist_ok && dominant_ok;
}


// extra information to be printed, after the standard factor type and such
void
ASDiffFeatureBallMoveRange::printExtra(std::ostream& os) const
{
  os << "[BallMoveRange("
     << "max_x=" << max_x << ", "
     << "max_y=" << max_y << ", "
     << "min_manhattan=" << min_manhattan << ", "
     << "max_manhattan=" << max_manhattan << ")]";
}

/***************************************************************************************/
// Ball-direction feature. The pattern comes from the decoder, and side
// selects the x-direction tested in protMatch.
ASDiffFeatureBallDir::ASDiffFeatureBallDir(RelativeTeamSide side)
  : ASDiffFeature(),
    decoder(),
    side(side)
{
  ASDiffPatternChildren* proot = decoder.constructPattern();
  setPattern(proot);
}

// Compares the ball's grid x coordinate in the first (x1) and second (x2)
// states: RTS_Mine needs x1 < x2, RTS_Theirs needs x1 > x2, RTS_None needs
// x1 == x2. False if either position cannot be decoded or side is unknown.
bool
ASDiffFeatureBallDir::protMatch(AbstractStateCompare* pcomp)
{
  int x_from, y_from, x_to, y_to;
  if (!decoder.getGridCoords(pcomp->getFirstState(), &x_from, &y_from) ||
      !decoder.getGridCoords(pcomp->getSecondState(), &x_to, &y_to))
    return false;

  if (side == RTS_Mine)
    return x_from < x_to;
  if (side == RTS_Theirs)
    return x_from > x_to;
  if (side == RTS_None)
    return x_from == x_to;

  errorlog << "ASDiffFeatureBallDir::protMatch: what is side?" << side << ende;
  errorlog << "ASDiffFeatureBallDir::protMatch: How did I get here?" << ende;
  return false;
}

/***************************************************************************************/
// Player-region occupancy feature over a single ASF_POSet leaf
// (pattern Or -> And -> POSet), checked against the two bitmasks.
ASDiffFeaturePlayerRegFilledSet::ASDiffFeaturePlayerRegFilledSet(StateSelection state,
							   unsigned bitmask_filled,
							   unsigned bitmask_empty)
  : ASDiffFeatureSingleState(state),
    pPlayerOcc(new ASDiffPatternSimple(ASF_POSet)),
    bitmask_filled(bitmask_filled),
    bitmask_empty(bitmask_empty)
{
  ASDiffPatternChildren* proot =
    ASDiffPatternElement::makePattern(ASF_Or, ASF_And, pPlayerOcc);
  setPattern(proot);
}

// Checks the occupancy leaf of pstate against the masks: every bit in
// bitmask_filled must be set and no bit in bitmask_empty may be set.
// An invalid leaf (-1) never matches. Converted to unsigned it would have
// every bit set and would falsely look like "all regions filled".
bool
ASDiffFeaturePlayerRegFilledSet::singleStateMatch(AbstractState* pstate)
{
  const int raw_val = pstate->getFactorIdx(pPlayerOcc->getMatchingFactorIdx());
  if (raw_val == -1)
    return false;

  const unsigned val = static_cast<unsigned>(raw_val);
  // check that all the regions specified by the bitmask are filled
  if ( (val & bitmask_filled) != bitmask_filled)
    return false;
  if ( (val & bitmask_empty) != 0)
    return false;
  return true;
}

// Prints the state selection and both bitmasks, with the masks in hex.
// std::hex is sticky on a stream, so the caller's format flags are saved
// and restored. Otherwise every later integer written to os would also
// come out in hex.
void
ASDiffFeaturePlayerRegFilledSet::printExtra(std::ostream& os) const
{
  const std::ios::fmtflags saved_flags = os.flags();
  os << "[PlayerRegFilledSet("
     << "state=" << getStateSel() << ", "
     << "bitmask_filled=" << std::hex << bitmask_filled << ", "
     << "bitmask_empty=" << bitmask_empty << ")]";
  os.flags(saved_flags);
}

/***************************************************************************************/
// Player-region occupancy feature over num_elements separate ASF_POElement
// leaves. They are grouped under an And node, which is placed in an
// Or -> And chain. Each leaf is also kept in `patterns` (in creation order)
// so singleStateMatch can read it.
ASDiffFeaturePlayerRegFilledElem::ASDiffFeaturePlayerRegFilledElem(StateSelection state,
								   int num_elements,
								   unsigned bitmask_filled,
								   unsigned bitmask_empty)
  : ASDiffFeatureSingleState(state),
    patterns(),
    bitmask_filled(bitmask_filled),
    bitmask_empty(bitmask_empty)
{
  ASDiffPatternChildren* pelems = new ASDiffPatternChildren(ASF_And);
  for (int elem = 0; elem < num_elements; ++elem)
    {
      ASDiffPatternSimple* pleaf = new ASDiffPatternSimple(ASF_POElement);
      pelems->addChild(pleaf);
      patterns.push_back(pleaf);
    }

  ASDiffPatternChildren* proot =
    ASDiffPatternElement::makePattern(ASF_Or, ASF_And, pelems);
  setPattern(proot);
}

// Builds a bit set from the per-element occupancy leaves (element i
// contributes its value shifted left by i) and checks it against the
// masks: every bit of bitmask_filled set, no bit of bitmask_empty set.
// Returns false if any element leaf is negative (invalid). Left-shifting a
// negative int is undefined and would smear bits into every position.
// Also returns false (and logs) if there are more elements than bits in
// an unsigned.
bool
ASDiffFeaturePlayerRegFilledElem::singleStateMatch(AbstractState* pstate)
{
  const int max_bits = (int)(sizeof(unsigned) * 8);
  unsigned val = 0;
  int bit = 0;
  for (PatternStorage::iterator iter = patterns.begin();
       iter != patterns.end();
       ++iter, ++bit)
    {
      const int elem_val = pstate->getFactorIdx((*iter)->getMatchingFactorIdx());
      if (elem_val < 0)
	return false;
      if (bit >= max_bits)
	{
	  errorlog << "ASDiffFeaturePlayerRegFilledElem: too many elements for bitmask: "
		   << patterns.size() << ende;
	  return false;
	}
      val |= (static_cast<unsigned>(elem_val) << bit);
    }

  // check that all the regions specified by the bitmask are filled
  if ( (val & bitmask_filled) != bitmask_filled)
    return false;
  if ( (val & bitmask_empty) != 0)
    return false;
  return true;
}

// Prints the state selection and both bitmasks, with the masks in hex.
// std::hex is sticky on a stream, so the caller's format flags are saved
// and restored. Otherwise every later integer written to os would also
// come out in hex.
void
ASDiffFeaturePlayerRegFilledElem::printExtra(std::ostream& os) const
{
  const std::ios::fmtflags saved_flags = os.flags();
  os << "[PlayerRegFilledElem("
     << "state=" << getStateSel() << ", "
     << "bitmask_filled=" << std::hex << bitmask_filled << ", "
     << "bitmask_empty=" << bitmask_empty << ")]";
  os.flags(saved_flags);
}

/***********************************************************************/
/***********************************************************************/
/***********************************************************************/

// Creates a classifier with no terms. A NULL name is stored as "".
ASDiffClassifier::ASDiffClassifier(const char* name)
  : name(name ? name : ""),
    terms()
{
}

// Calls deleteFeature on every feature atom of every term.
ASDiffClassifier::~ASDiffClassifier()
{
  for (TermStorage::iterator term = terms.begin(); term != terms.end(); ++term)
    {
      for (FeatureStorage::iterator atom = term->begin(); atom != term->end(); ++atom)
	atom->deleteFeature();
    }
}

// Evaluates the classifier as an OR of terms, each term being an AND of
// feature atoms. Atoms in a term are evaluated in order and evaluation stops
// at the first one that fails. Returns true as soon as a term has every atom
// matching (an empty term therefore matches). Returns false otherwise,
// including when there are no terms.
bool
ASDiffClassifier::match(AbstractStateCompare* pcomp)
{
  for (TermStorage::iterator term = terms.begin(); term != terms.end(); ++term)
    {
      FeatureStorage::iterator atom = term->begin();
      while (atom != term->end() && atom->match(pcomp))
	++atom;
      if (atom == term->end())
	return true;
    }
  return false;
}

  
// Writes "<name> {<atoms of term 1> || <atoms of term 2> ... }".
// Each atom is followed by a space, and " || " separates consecutive terms
// (never after the last one), mirroring the "||" syntax accepted by
// operator>>.
std::ostream&
operator<<(std::ostream& os, const ASDiffClassifier& c)
{
  os << c.name << " ";
  os << "{";
  for (ASDiffClassifier::TermStorage::const_iterator term_iter = c.terms.begin();
       term_iter != c.terms.end();
       ++term_iter)
    {
      if (term_iter != c.terms.begin())
	os << " || ";
      std::copy(term_iter->begin(), term_iter->end(),
		std::ostream_iterator<ASDiffClassifier::FeatureAtom>(os, " "));
    }
  os << "}";
  return os;
}

// Parses a classifier of the form
//   <name> { [!]feature [!]feature ... || [!]feature ... }
// into c (c is cleared first). "||" starts a new term, and a leading '!'
// sets the not-flag of the following feature atom. Features are created by
// ASDiffFeature::createFromStream.
// On any syntax error the stream's failbit is set. Terms and features parsed
// so far remain stored in c.
// NOTE(review): "{}" yields one empty term. ASDiffClassifier::match treats
// an empty term as matching, so such a classifier matches everything.
// NOTE(review): "|| }" (a separator right before the closing brace) makes
// createFromStream read "}" as a feature name. Whether that fails depends
// on resolveName.
std::istream&
operator>>(std::istream& is, ASDiffClassifier& c)
{
  std::string name;
  c.clear();
  
  is >> name;
  if (is.fail())
    return is;
  c.setName(name.c_str());
  
  // skip_white_space presumably returns false when the stream runs out
  if (!skip_white_space(is) || is.peek() != '{')
    {
      is.setstate(std::ios::failbit);
      return is;
    }
  is.get(); // remove the '{'

  // Start the first term. The iterator is re-taken after every push_back,
  // since growing c.terms may invalidate it (end() - 1 implies a
  // random-access container, presumably std::vector).
  c.terms.push_back(ASDiffClassifier::FeatureStorage());
  ASDiffClassifier::TermStorage::iterator curr_term = c.terms.end() - 1;
  
  for (;;)
    {
      if (!skip_white_space(is))
	{
	  is.setstate(std::ios::failbit);
	  return is;
	}
      if (is.peek() == '}')
	break;

      // "||": close the current term and start a new one
      if (is.peek() == '|')
	{
	  is.get(); //remove the first |
	  if (is.get() != '|')
	    {
	      is.setstate(std::ios::failbit);
	      return is;
	    }
	  c.terms.push_back(ASDiffClassifier::FeatureStorage());
	  curr_term = c.terms.end() - 1;
	  if (!skip_white_space(is))
	    {
	      is.setstate(std::ios::failbit);
	      return is;
	    }
	}

      // optional negation of the next feature
      bool not_flag = false;
      if (is.peek() == '!')
	{
	  not_flag = true;
	  is.get();
	}
      
      ASDiffFeature* p = ASDiffFeature::createFromStream(is);
      if (p == NULL)
	{
	  is.setstate(std::ios::failbit);
	  return is;
	}
      curr_term->push_back(ASDiffClassifier::FeatureAtom(not_flag, p));
    }
  is.get(); //remove the '}'

  return is;
}


/***********************************************************************/
// Creates an empty classifier set.
ASDiffClassifierSet::ASDiffClassifierSet()
  : classifiers()
{
}

// Deletes every classifier still held by the set (through clear()).
ASDiffClassifierSet::~ASDiffClassifierSet()
{
  clear();
}

// Returns the classifier at idx, or NULL if idx is out of range.
ASDiffClassifier*
ASDiffClassifierSet::getClassifier(int idx)
{
  const bool in_range = (idx >= 0 && idx < (signed)classifiers.size());
  return in_range ? classifiers[idx] : NULL;
}

// Const overload: returns the classifier at idx, or NULL if idx is out of
// range.
const ASDiffClassifier*
ASDiffClassifierSet::getClassifier(int idx) const
{
  const bool in_range = (idx >= 0 && idx < (signed)classifiers.size());
  return in_range ? classifiers[idx] : NULL;
}

// Returns the index of the first classifier whose name equals `name`,
// ignoring case, or -1 if there is no such classifier. A NULL name is
// rejected with -1 rather than being passed to strcasecmp.
int
ASDiffClassifierSet::lookupClassifier(const char* name) const
{
  if (name == NULL)
    return -1;
  for (int idx = 0; idx < (signed)classifiers.size(); idx++)
    {
      if (strcasecmp(name, classifiers[idx]->getName().c_str()) == 0)
	return idx;
    }
  return -1;
}


void
ASDiffClassifierSet::clear()
{
  std::for_each(classifiers.begin(), classifiers.end(),
		deleteptr<ASDiffClassifier>());
  classifiers.clear();
}

// Returns the index (usable with getClassifier) of the first classifier that
// matches pcomp, or -1 if nothing matches.
int
ASDiffClassifierSet::classify(AbstractStateCompare* pcomp)
{
  const int count = (signed)classifiers.size();
  for (int idx = 0; idx < count; idx++)
    {
      if (classifiers[idx]->match(pcomp))
	return idx;
    }
  return -1;
}

// Writes a four-line "#" comment header, then every classifier through
// derefprinter with "\n" as the separator argument.
std::ostream&
operator<<(std::ostream& os, const ASDiffClassifierSet& s)
{
  const char* const header[] = {
    "# This represents a set of ASDiffClassifiers",
    "# Each line represents on classifier as: <name> { <list of features> }",
    "# Each feature is an ASDiffFeature, and they are space separated",
    "# IMPORTANT: This file can not be read back in as a classifier set!"
  };
  const unsigned num_header_lines = sizeof(header) / sizeof(header[0]);
  for (unsigned line = 0; line < num_header_lines; line++)
    os << header[line] << std::endl;

  std::for_each(s.classifiers.begin(), s.classifiers.end(),
		derefprinter<ASDiffClassifier>(os, "\n"));
  return os;
}

  
/***********************************************************************/
// Parses one line of a classifier-set file into a new ASDiffClassifier and
// hands it to pset. If the line does not parse, the partial classifier is
// deleted and false is returned. fileid, path and version are unused.
bool
ASDCSetFileReader::processLine(std::istrstream& line,
			       const char* fileid,
			       const char* path,
			       float version)
{
  ASDiffClassifier* pclassifier = new ASDiffClassifier;
  if ((line >> *pclassifier).fail())
    {
      delete pclassifier;
      return false;
    }

  pset->addClassifier(pclassifier);
  return true;
}


/***********************************************************************/
/***********************************************************************/
/***********************************************************************/

// Creates a store for classifier set p. resize() creates one list per
// classifier plus one for unmatched transitions (none if p is NULL).
ASDiffClassificationStore::ASDiffClassificationStore( ASDiffClassifierSet* p)
  : pset(p),
    stored_class()
{
  resize();
}

// Nothing to release explicitly; the store does not delete pset.
ASDiffClassificationStore::~ASDiffClassificationStore()
{
}


  
// Stores this classification: records the state indices of pcomp's first
// and second states under the given classification.
// If pcomp or either of its states is NULL, an error is logged and nothing
// is stored.
void
ASDiffClassificationStore::add(AbstractStateCompare* pcomp, int classification)
{
  if (pcomp == NULL)
    {
      errorlog << "ASDiffClassificationStore::add: NULL comp" << ende;
      return;
    }

  if (pcomp->getFirstState() != NULL && pcomp->getSecondState() != NULL)
    {
      add(pcomp->getFirstState()->getStateIdx(),
	  pcomp->getSecondState()->getStateIdx(),
	  classification);
      return;
    }

  errorlog << "ASDiffClassificationStore::add: NULL state: "
	   << pcomp->getFirstState() << " " << pcomp->getSecondState()
	   << ende;
}

// Appends the pair (first_state, second_state) to the list for classification.
// A classification of -1 (what ASDiffClassifierSet::classify returns when
// nothing matches) is stored in the last slot, which resize() reserves for
// unmatched pairs. Any other out-of-range value is logged and dropped.
void
ASDiffClassificationStore::add(int first_state, int second_state, int classification)
{
  const int num_slots = (signed)stored_class.size();
  // The invalid classification is stored as one past the regular ones
  const int slot = (classification == -1) ? num_slots - 1 : classification;

  if (slot >= 0 && slot < num_slots)
    {
      stored_class[slot].push_back(std::make_pair(first_state, second_state));
      return;
    }

  errorlog << "ASDiffClassificationStore::add: classification out of range "
	   << slot << ", max=" << stored_class.size() << ende;
}


void
ASDiffClassificationStore::clear()
{
  std::fill(stored_class.begin(), stored_class.end(), StateStateList());
}

// Sizes stored_class to one list per classifier in pset, plus a trailing
// list for unmatched pairs. Without a classifier set, all storage is dropped.
void
ASDiffClassificationStore::resize()
{
  if (pset != NULL)
    {
      // the extra trailing slot collects pairs no classifier matched
      stored_class.resize(pset->getNumClassifiers() + 1);
      return;
    }
  stored_class.clear();
}

void
ASDiffClassificationStore::classifyMarkovChainTransitions(const MarkovChain& mc,
							  AbstractStateDescription* pdesc,
							  int progress_interval)
{
  // First, let's do some sanity checks
  if (pdesc == NULL)
    {
      errorlog << "classifyMarkovChainTransitions: NULL description" << ende;
      return;
    }
  if (mc.getNumStates() != pdesc->getNumStates())
    {
      errorlog << "classifyMarkovChainTransitions: state count mismatch "
	       << mc.getNumStates() << " " << pdesc->getNumStates()
	       << ende;
      return;
    }
  if (pset == NULL)
    {
      errorlog << "classifyMarkovChainTransitions: null classifier set" << ende;
      return;
    }
  
  // put this out here so we don't have to keep realloc
  AbstractState first_state(pdesc);
  AbstractState second_state(pdesc);
  
  for (int state_idx = mc.getNumStates() - 1;
       state_idx >= 0;
       state_idx--)
    {
      if (progress_interval > 0 && state_idx % progress_interval == 0)
	std::cout << '.' << std::flush;
      first_state.setStateIdx(state_idx);
      for (int tran = mc.getNumTransitionsForState(state_idx) - 1;
	   tran >= 0;
	   tran--)
	{
	  second_state.setStateIdx(mc.getTranNextState(state_idx, tran));
	  AbstractStateCompare comp(&first_state, &second_state);
	  int c = pset->classify(&comp);
	  add(&comp, c);
	}
    }
}


// Writes a short summary of the stored classifications to os.
// Each class gets one line: its index, its classifier name (the final slot is
// reported as UNMATCHED) and the number of state pairs recorded for it.
// Every header line is prefixed with "# ", as writeFull does. The original
// left the first and third header lines unprefixed.
void
ASDiffClassificationStore::writeSummary(std::ostream& os) const
{
  os << "# This file is a summary of the classification of abstract state differences (usually transitions)"
     << std::endl;
  os << "# as represented by ASDiffClassificationStore" << std::endl;
  os << "# Format: <idx> <name> <num of matching>" << std::endl;
  for (Classifications::const_iterator iter = stored_class.begin();
       iter != stored_class.end();
       iter++)
    {
      // The last slot holds pairs that no classifier matched (see add())
      os << (iter-stored_class.begin()) << "\t"
	 << ( (iter == stored_class.end() - 1)
	      ? "UNMATCHED"
	      : pset->getClassifier(iter-stored_class.begin())->getName()) << "\t"
	 << iter->size() << std::endl;
    }
}


// Writes every stored state pair to os, grouped by class.
// Layout (after the '#' header lines):
//   <number of classes>
//   for each class: "<idx>\t<name>\t<number of pairs>", then one
//   "\t<state idx 1> <state idx 2>" line per pair.
// The final class slot is labelled UNMATCHED.
// Lines end with '\n' rather than std::endl, so a large dump does not flush
// once per pair. A single flush at the end gives the same flushed state the
// original left the stream in.
void
ASDiffClassificationStore::writeFull(std::ostream& os) const
{
  os << "# This file is the classification of abstract state differences (usually transitions)" << '\n';
  os << "# as represented by ASDiffClassificationStore" << '\n';
  os << "# Format:" << '\n';
  os << "# First line: <number of classes>" << '\n';
  os << "# Section head: <idx> <name> <number of pairs>" << '\n';
  os << "# Each line in section:       <state idx 1> <state idx 2>" << '\n';

  os << stored_class.size() << '\n';
  
  for (ASDiffClassificationStore::Classifications::const_iterator class_iter = stored_class.begin();
       class_iter != stored_class.end();
       class_iter++)
    {
      // The last slot holds pairs that no classifier matched (see add())
      os << (class_iter-stored_class.begin()) << "\t"
	 << ( (class_iter == stored_class.end() - 1)
	      ? "UNMATCHED"
	      : pset->getClassifier(class_iter-stored_class.begin())->getName()) << "\t"
	 << class_iter->size() << '\n';
      for (ASDiffClassificationStore::StateStateList::const_iterator pair_iter = class_iter->begin();
	   pair_iter != class_iter->end();
	   pair_iter++)
	{
	  os << "\t" << pair_iter->first << " " << pair_iter->second << '\n';
	}
    }
  os << std::flush;
}

