/* -*- Mode: C++ -*- */

/* This set of classes captures differences between abstract state spaces */

#ifndef _ABSTRACT_STATE_DIFFERENCE_H_
#define _ABSTRACT_STATE_DIFFERENCE_H_

#include <fstream>
#include <iostream>
#include <list>
#include <string>
#include <utility>
#include <vector>

#include "AbstractState.h"
#include "FileReader.h"

// Forward declaration only: within this header MarkovChain is used solely as a
// const reference parameter of
// ASDiffClassificationStore::classifyMarkovChainTransitions.
class MarkovChain;

/***********************************************************************/
/***********************************************************************/
/***********************************************************************/

/* Holds a pair of AbstractState pointers so that the other classes in this
   file can compare the two states. It also records which factor indices have
   already been used by a comparison.
   This class does NOT take over the memory of either state. */
class AbstractStateCompare
{
public:
  // Tells checkUnused() how strict to be about the elements not marked used.
  enum CheckUnusedPolicy {
    CP_Strict,           // all changes are bad
    CP_AllowToInvalid,   // allow valid to change to invalid
    CP_AllowFromInvalid, // allow invalid states to turn into anything
    CP_AllowAllInvalid,  // allow all transitions to/from invalid
    CP_AllowAll,         // everything is ok
    CP_MAX               // never actually use this
  };

  // Stream output for a policy value (defined in the implementation file).
  friend std::ostream& operator<<(std::ostream& os, CheckUnusedPolicy p);

  AbstractStateCompare(AbstractState* first_state, AbstractState* second_state);
  ~AbstractStateCompare();

  // Access to the first state of the pair.
  AbstractState* getFirstState()
  { return first_state; }
  const AbstractState* getFirstState() const
  { return first_state; }
  void setFirstState(AbstractState* p);

  // Access to the second state of the pair.
  AbstractState* getSecondState()
  { return second_state; }
  const AbstractState* getSecondState() const
  { return second_state; }
  void setSecondState(AbstractState* p);

  // Sets the mark for fac_idx to `used`.
  // Returns whether the index was already marked as used.
  bool markUsed(int fac_idx, bool used = true);

  // Clears all marks.
  void clearUsed();

  // Returns whether all elements which are not marked as used are identical,
  // with `policy` selecting how strictly differences are judged.
  bool checkUnused(CheckUnusedPolicy policy);

private:
  // One flag per index recording whether it has been marked used.
  typedef std::vector<bool> UsedIdxStorage;

  // Checks that the states are something that we can use.
  bool checkStates(bool allow_null);

  void setupUsedIdxStorage();

  // Not owned (see the class comment).
  AbstractState* first_state;
  AbstractState* second_state;

  UsedIdxStorage used_idx_storage;
};


/***********************************************************************/
/***********************************************************************/
/***********************************************************************/
// Forward declarations of the pattern classes defined below.
class ASDiffPatternElement;
class ASDiffPatternChildren;
class ASDiffPatternSimple;

/* A pattern describing what the state description should look like.
   Subclasses allow a tree of patterns to be formed; the leaves of the tree
   store their associated index into the AbstractState class. */
class ASDiffPatternElement
{
public:
  ASDiffPatternElement(ASFactorType factor_type);
  virtual ~ASDiffPatternElement() {}

  // Matches against a whole description.
  // Also fills in the matching factor and the matching factor index.
  bool matchTo(AbstractStateDescription* pdesc);
  // Matches against one factor; first_idx is the first index into the
  // AbstractState to use for this match.
  bool matchTo(AbstractStateFactor* pfac, int first_idx = 0);

  // True once a matching factor has been recorded.
  bool isMatched() const
  { return NULL != p_match_factor; }

  // Forgets the recorded factor and index.
  // NOTE(review): this inline body does not call protClearMatch(); confirm
  // where the subclass hook is expected to be invoked.
  void clearMatch()
  {
    match_factor_idx = -1;
    p_match_factor = NULL;
  }

  AbstractStateFactor* getMatchingFactor()
  { return p_match_factor; }
  const AbstractStateFactor* getMatchingFactor() const
  { return p_match_factor; }
  int getMatchingFactorIdx() const
  { return match_factor_idx; }

  void markUsed(AbstractStateCompare* pcomp);

  friend std::ostream& operator<<(std::ostream& os, const ASDiffPatternElement& e);

  // Convenience functions.
  // They allocate new memory. The types should be AND or OR.
  // The leaf can be anything (including an element with children).
  static ASDiffPatternChildren* makePattern(ASFactorType type1,
                                            ASDiffPatternElement* leaf);
  static ASDiffPatternChildren* makePattern(ASFactorType type1,
                                            ASFactorType type2,
                                            ASDiffPatternElement* leaf);

protected:
  void setMatchingFactorIdx(int idx)
  { match_factor_idx = idx; }

  // When this is called the types have already been verified to be correct,
  // so a downcast is safe. Any extra checking (like checking children)
  // should be done here.
  // If this returns true, match_factor_idx must be set!
  virtual bool tryMatch(AbstractStateFactor* pfac, int first_idx) = 0;

  // Subclass hook for clearing any match state the subclass holds.
  virtual void protClearMatch() = 0;

  // Subclass hook for marking used indices in pcomp.
  virtual void protMarkUsed(AbstractStateCompare* pcomp) = 0;

  // Extra information to be printed, after the standard factor type and such.
  virtual void printExtra(std::ostream& os) const = 0;

private:
  ASFactorType factor_type;
  // Whenever the next two members are set, the factor's type is guaranteed to
  // be factor_type, so a downcast is safe.
  AbstractStateFactor* p_match_factor;
  int match_factor_idx;
};

/***********************************************************************/

/* A pattern element that carries a list of child patterns.
   Note that this element will only allow each factor element to match once. */
class ASDiffPatternChildren : public ASDiffPatternElement
{
public:
  typedef std::vector<ASDiffPatternElement*> ChildrenStorage;

  ASDiffPatternChildren(ASFactorType factor_type);
  virtual ~ASDiffPatternChildren();

  // Appends a child pattern and then clears this element's recorded match.
  void addChild(ASDiffPatternElement* p)
  {
    children.push_back(p);
    clearMatch();
  }

protected:
  // When this is called the types have already been verified to be correct,
  // so a downcast is safe. Any extra checking (like checking children)
  // should be done here.
  // If this returns true, match_factor_idx must be set!
  virtual bool tryMatch(AbstractStateFactor* pfac, int first_idx);

  virtual void protClearMatch();

  virtual void protMarkUsed(AbstractStateCompare* pcomp);

  // Extra information to be printed, after the standard factor type and such.
  virtual void printExtra(std::ostream& os) const;

private:
  ChildrenStorage children;
};

/***********************************************************************/

/* A pattern element without children. Its clear-match, mark-used and
   print-extra hooks are all empty. */
class ASDiffPatternSimple : public ASDiffPatternElement
{
public:
  ASDiffPatternSimple(ASFactorType factor_type)
    : ASDiffPatternElement(factor_type)
  {}
  virtual ~ASDiffPatternSimple() {}

protected:
  // When this is called the types have already been verified to be correct,
  // so a downcast is safe. Any extra checking should be done here.
  // If this returns true, match_factor_idx must be set!
  virtual bool tryMatch(AbstractStateFactor* pfac, int first_idx);

  // Nothing to clear.
  virtual void protClearMatch() {}

  // Nothing to mark.
  virtual void protMarkUsed(AbstractStateCompare* /*pcomp*/) {}

  // No extra information is printed.
  virtual void printExtra(std::ostream& /*os*/) const {}
};

/***********************************************************************/
/***********************************************************************/
/***********************************************************************/

// This class represents one particular feature of an AbstractStateCompare
// (a pair of AbstractState).
// NOTE(review): no copy constructor or assignment operator is declared, so the
// compiler-generated ones copy the ppat pointer as-is; confirm that features
// are never copied by value.
class ASDiffFeature
{
public:
  // Creates a feature from the given stream. The accepted syntax is defined in
  // the implementation file.
  // NOTE(review): presumably the result is heap-allocated and owned by the
  // caller -- confirm against the implementation.
  static ASDiffFeature* createFromStream(std::istream& is);

public:

  // Selects which state(s) of the compared pair a feature refers to.
  enum StateSelection {
    SS_First,
    SS_Second,
    SS_Both,
    SS_Neither,
    SS_MAX
  };
  friend std::ostream& operator<<(std::ostream& os, const StateSelection& s);

public:
  // This class will take over the pattern memory
  ASDiffFeature(ASDiffPatternElement* ppat,
		AbstractStateCompare::CheckUnusedPolicy policy = AbstractStateCompare::CP_AllowAll);
  ASDiffFeature(AbstractStateCompare::CheckUnusedPolicy policy = AbstractStateCompare::CP_AllowAll);
  virtual ~ASDiffFeature();

  const ASDiffPatternElement* getPattern() const { return ppat; }
  ASDiffPatternElement* getPattern() { return ppat; }

  // Replaces the pattern. The previously held pattern is deleted and the new
  // one is taken over; the pattern match is flagged as needing to be redone.
  // Passing the pattern that is already installed is safe: it is kept rather
  // than deleted, so ppat never ends up pointing at freed memory.
  void setPattern(ASDiffPatternElement* p)
  {
    if (p != ppat) {
      delete ppat; // deleting a null pointer is a no-op
      ppat = p;
    }
    need_pattern_match = true;
  }

  // Tests this feature against the given pair of states.
  bool match(AbstractStateCompare* pcomp);

  friend std::ostream& operator<<(std::ostream& os, const ASDiffFeature& e);

protected:
  // This is called after the pattern match is confirmed to match
  // (and notably, when the pattern has the associated factors
  //  and indices stored)
  // This function does NOT need to mark the indices used in
  // AbstractStateCompare. The base match function will do that
  // if this function returns true
  virtual bool protMatch(AbstractStateCompare* pcomp) = 0;

  // extra information to be printed, after the standard factor type and such
  virtual void printExtra(std::ostream& os) const = 0;

  AbstractStateCompare::CheckUnusedPolicy getCheckUnusedPolicy() const
  { return check_unused_policy; }
  void setCheckUnusedPolicy(AbstractStateCompare::CheckUnusedPolicy p)
  { check_unused_policy = p; }

private:
  // This function is generated from asdiff_features.names
  static ASDiffFeature* resolveName(std::string name);

  // Owned pattern (see the constructor comment and setPattern).
  ASDiffPatternElement* ppat;
  AbstractStateCompare::CheckUnusedPolicy check_unused_policy;

  // Set by setPattern whenever the pattern is (re)installed.
  bool need_pattern_match;
};


/***********************************************************************/

/* Abstract base for features that depend on only a single state.
   It carries a selector for which state to look at. */
class ASDiffFeatureSingleState : public ASDiffFeature
{
public:
  ASDiffFeatureSingleState(AbstractStateCompare::CheckUnusedPolicy policy,
                           StateSelection state)
    : ASDiffFeature(policy),
      state(state)
  {}
  ASDiffFeatureSingleState(StateSelection state)
    : ASDiffFeature(),
      state(state)
  {}
  virtual ~ASDiffFeatureSingleState() {}

  // The selector given at construction.
  StateSelection getStateSel() const
  { return state; }

protected:
  virtual bool protMatch(AbstractStateCompare* pcomp);

  // Implemented by subclasses to test one state.
  virtual bool singleStateMatch(AbstractState* pstate) = 0;

private:
  // Which state to look at. How each StateSelection value is interpreted is
  // decided by protMatch in the implementation file.
  StateSelection state;
};

/***********************************************************************/
/* A feature whose protMatch ignores the compared states and simply returns
   the boolean given at construction. */
class ASDiffFeatureConstant : public ASDiffFeature
{
public:
  // val is the value that protMatch returns.
  // The base class is handed a newly allocated simple ASF_Or pattern.
  ASDiffFeatureConstant(bool val)
    : ASDiffFeature(new ASDiffPatternSimple(ASF_Or)),
      val(val)
  {}
  virtual ~ASDiffFeatureConstant() {}

protected:
  virtual bool protMatch(AbstractStateCompare* /*pcomp*/)
  { return val; }

  // Extra information to be printed, after the standard factor type and such.
  virtual void printExtra(std::ostream& os) const
  {
    os << "[Constant(val=" << val;
    os << ")]";
  }

private:
  bool val;
};

/***********************************************************************/

/* Single-state feature concerning a goal being scored. */
class ASDiffFeatureGoal : public ASDiffFeatureSingleState
{
public:
  // my_side is true if we want this to find when we score.
  ASDiffFeatureGoal(StateSelection state, bool my_side);
  virtual ~ASDiffFeatureGoal() {}

protected:
  virtual bool singleStateMatch(AbstractState* pcomp);

  // Extra information to be printed, after the standard factor type and such.
  virtual void printExtra(std::ostream& os) const
  {
    os << "[GoalScore(state=" << getStateSel();
    os << ", ours=" << my_side << ")]";
  }

private:
  ASDiffPatternSimple* pGoal;
  bool my_side;
};

/***********************************************************************/

/* Single-state feature concerning which side owns the ball. */
class ASDiffFeatureBallOwner : public ASDiffFeatureSingleState
{
public:
  ASDiffFeatureBallOwner(StateSelection state, RelativeTeamSide side);
  virtual ~ASDiffFeatureBallOwner() {}

protected:
  virtual bool singleStateMatch(AbstractState* pcomp);

  // Extra information to be printed, after the standard factor type and such.
  // Note that the side is printed under the label "idx".
  virtual void printExtra(std::ostream& os) const
  {
    os << "[BallOwner(state=" << getStateSel();
    os << ", idx=" << side << ")]";
  }

private:
  ASDiffPatternSimple* pBallOwner;
  RelativeTeamSide side;
};

/***********************************************************************/

/* Matches when the pair of states is a transition to a free kick for
   (my_side ? us : them) from a play_on mode. */
class ASDiffFeatureFreeKickTran : public ASDiffFeature
{
public:
  // my_side is true if we want this to find where we have the free kick.
  ASDiffFeatureFreeKickTran(bool my_side);
  virtual ~ASDiffFeatureFreeKickTran() {}

protected:
  virtual bool protMatch(AbstractStateCompare* pcomp);

  // Extra information to be printed, after the standard factor type and such.
  virtual void printExtra(std::ostream& os) const
  {
    os << "[FreeKickTran(ours=" << my_side;
    os << ")]";
  }

private:
  ASDiffPatternSimple* pDeadBall;
  bool my_side;
};

/***********************************************************************/

/* Single-state feature concerning a free kick for the given side. */
class ASDiffFeatureFreeKick : public ASDiffFeatureSingleState
{
public:
  // The state refers to whether to check the first or the second state.
  ASDiffFeatureFreeKick(StateSelection state, RelativeTeamSide side);
  virtual ~ASDiffFeatureFreeKick() {}

protected:
  virtual bool singleStateMatch(AbstractState* pstate);

  // Extra information to be printed, after the standard factor type and such.
  virtual void printExtra(std::ostream& os) const
  {
    os << "[FreeKick(side=" << side;
    os << ", state=" << getStateSel() << ")]";
  }

private:
  // Helper taking a pattern and a state (defined in the implementation file).
  bool doesMatch(ASDiffPatternSimple* ppat, AbstractState* pstate);

  ASDiffPatternSimple* pDeadBall;
  RelativeTeamSide side;
};

/***********************************************************************/

/* Feature named for self transitions; it takes no parameters. The matching
   rule itself is in protMatch in the implementation file. */
class ASDiffFeatureSelfTran : public ASDiffFeature
{
public:
  ASDiffFeatureSelfTran();
  virtual ~ASDiffFeatureSelfTran() {}

protected:
  virtual bool protMatch(AbstractStateCompare* pcomp);

  // Extra information to be printed, after the standard factor type and such.
  virtual void printExtra(std::ostream& os) const
  {
    os << "[SelfTran]";
  }
};

/***********************************************************************/
class BallGridFactor; // defined in AbstractStateElements

/* This is not an ASDiffFeature, but a helper class for anything that wants to
   look at the ball.
   Note that this does NOT deallocate its own memory; that is left to whoever
   uses constructPattern. */
class BallPosDecoder
{
public:
  BallPosDecoder();
  ~BallPosDecoder() {}

  // Allocates memory for the returned pattern.
  // You then have to call a pattern match yourself. This class stores info
  // from this pattern, so once you do the match it has all the info it needs.
  ASDiffPatternChildren* constructPattern();

  // ppfac and pgrid_idx are return params.
  // Returns whether either factor is valid.
  bool getValidFactor(AbstractState* pstate,
                      BallGridFactor** ppfac,
                      int* pgrid_idx);
  // px and py are return params for the grid coordinates.
  bool getGridCoords(AbstractState* pstate, int* px, int* py);

  // This is somewhat of a hack. The users of this class really shouldn't have
  // to know that there are two possible ball grid factors.
  // IF YOU CHANGE THIS, make sure you look at
  // SoccerMDP.C: SoccerMDPActionBallMovement
  ASDiffPatternSimple* getFirstPattern()
  { return pBallGrid1; }
  ASDiffPatternSimple* getSecondPattern()
  { return pBallGrid2; }

private:
  bool isValid(ASDiffPatternSimple*,
               AbstractState* pstate,
               BallGridFactor** ppfac,
               int* pgrid_idx);

  // We have two because the ball grid appears with the dead ball factor and
  // the main play on factor.
  ASDiffPatternSimple* pBallGrid1;
  ASDiffPatternSimple* pBallGrid2;
};

/***********************************************************************/

/* Single-state feature about the ball position relative to a rectangle. */
class ASDiffFeatureBallInRect : public ASDiffFeatureSingleState
{
public:
  // If the ball could be anywhere in that rectangle, this matches.
  // The state refers to whether to check the first or the second state.
  ASDiffFeatureBallInRect(StateSelection state, const Rectangle& rect);
  virtual ~ASDiffFeatureBallInRect() {}

protected:
  virtual bool singleStateMatch(AbstractState* pstate);

  // Extra information to be printed, after the standard factor type and such.
  virtual void printExtra(std::ostream& os) const
  {
    os << "[BallInRect(rect=" << rect;
    os << ", state=" << getStateSel() << ")]";
  }

private:
  BallPosDecoder decoder;
  Rectangle rect;
};

/***********************************************************************/

/* Single-state feature named for the ball being at a penalty-area corner,
   parameterised by a side. The exact test is in the implementation file. */
class ASDiffFeatureBallAtPenCorner : public ASDiffFeatureSingleState
{
public:
  // The state refers to whether to check the first or the second state.
  ASDiffFeatureBallAtPenCorner(StateSelection state, RelativeTeamSide side);
  virtual ~ASDiffFeatureBallAtPenCorner() {}

protected:
  virtual bool singleStateMatch(AbstractState* pstate);

  // Extra information to be printed, after the standard factor type and such.
  virtual void printExtra(std::ostream& os) const
  {
    os << "[BallAtPenCorner(side=" << side;
    os << ", state=" << getStateSel() << ")]";
  }

private:
  BallPosDecoder decoder;
  RelativeTeamSide side;
};


/***********************************************************************/

/* Feature about how far the ball moved. */
class ASDiffFeatureBallMoveRange : public ASDiffFeature
{
public:
  // All min/max are inclusive.
  ASDiffFeatureBallMoveRange(int max_x,
                             int max_y,
                             int min_manhattan,
                             int max_manhattan,
                             int min_on_max_dim);
  virtual ~ASDiffFeatureBallMoveRange() {}

protected:
  virtual bool protMatch(AbstractStateCompare* pcomp);

  // Extra information to be printed, after the standard factor type and such.
  virtual void printExtra(std::ostream& os) const;

private:
  BallPosDecoder decoder;
  // Limits named after the constructor parameters.
  int max_x;
  int max_y;
  int min_manhattan;
  int max_manhattan;
  int min_on_max_dim;
};

/***********************************************************************/

/* Feature about whether the ball moved in the forward or backward
   direction. */
class ASDiffFeatureBallDir : public ASDiffFeature
{
public:
  // NOTE(review): presumably `side` selects whose forward direction is meant;
  // confirm against protMatch in the implementation file.
  ASDiffFeatureBallDir(RelativeTeamSide side);
  virtual ~ASDiffFeatureBallDir() {}

protected:
  virtual bool protMatch(AbstractStateCompare* pcomp);

  // Extra information to be printed, after the standard factor type and such.
  virtual void printExtra(std::ostream& os) const
  {
    os << "[BallDir(side=" << side;
    os << ")]";
  }

private:
  BallPosDecoder decoder;
  RelativeTeamSide side;
};

/***********************************************************************/
/* Single-state feature over the player occupancy set, expressed with two
   bitmasks (filled / empty).
   This is actually pretty ugly: a bitmask is used on the player occupancy
   set, which works only because the set currently in use is binary.
   The bitmask could still make sense as "a region has at least one player in
   it", but that would require a bit more hacking. */
class ASDiffFeaturePlayerRegFilledSet : public ASDiffFeatureSingleState
{
public:
  // The state refers to whether to check the first or the second state.
  ASDiffFeaturePlayerRegFilledSet(StateSelection state,
                                  unsigned bitmask_filled,
                                  unsigned bitmask_empty);
  virtual ~ASDiffFeaturePlayerRegFilledSet() {}

protected:
  virtual bool singleStateMatch(AbstractState* pstate);

  // Extra information to be printed, after the standard factor type and such.
  virtual void printExtra(std::ostream& os) const;

private:
  ASDiffPatternSimple* pPlayerOcc;
  unsigned bitmask_filled;
  unsigned bitmask_empty;
};

/***********************************************************************/
/* Single-state feature over player occupancy, expressed with two bitmasks
   (filled / empty).
   This is actually pretty ugly: a bitmask is used on the player occupancy
   set, which works only because the set currently in use is binary.
   The bitmask could still make sense as "a region has at least one player in
   it", but that would require a bit more hacking.
   This differs from the Set variant because it expects an AND factor with a
   bunch of POElements, where the Set expects one POSetFactor. */
class ASDiffFeaturePlayerRegFilledElem : public ASDiffFeatureSingleState
{
public:
  // The state refers to whether to check the first or the second state.
  ASDiffFeaturePlayerRegFilledElem(StateSelection state,
                                   int num_elements,
                                   unsigned bitmask_filled,
                                   unsigned bitmask_empty);
  virtual ~ASDiffFeaturePlayerRegFilledElem() {}

protected:
  virtual bool singleStateMatch(AbstractState* pstate);

  // Extra information to be printed, after the standard factor type and such.
  virtual void printExtra(std::ostream& os) const;

private:
  typedef std::vector<ASDiffPatternSimple*> PatternStorage;

  PatternStorage patterns;
  unsigned bitmask_filled;
  unsigned bitmask_empty;
};


/***********************************************************************/
/***********************************************************************/
/***********************************************************************/

// This class represents a classification of the difference from one state
// to another. I don't think it should be subclassed
// It represents a logical expression in DNF format: an OR of terms, where
// each term is an AND of (possibly negated) features.
class ASDiffClassifier
{
public:
  // name may be NULL.
  ASDiffClassifier(const char* name = NULL);
  ~ASDiffClassifier();

  // Sets the classifier name. A NULL pointer is treated as the empty name;
  // assigning a null const char* to std::string is undefined behaviour, and
  // the constructor already accepts NULL for "no name".
  void setName(const char* n) { name = (n != NULL) ? n : ""; }
  const std::string& getName() const { return name; }

  // Empties the name and drops all terms.
  // NOTE(review): the atoms are discarded without deleteFeature() being
  // called on them, and FeatureAtom's destructor does not delete either;
  // confirm that callers release or hand off the features before clear().
  void clear() { name.clear(); terms.clear(); }

  // Evaluates the expression against the given pair of states.
  bool match(AbstractStateCompare* pcomp);

  friend std::ostream& operator<<(std::ostream& os, const ASDiffClassifier& c);
  friend std::istream& operator>>(std::istream& is, ASDiffClassifier& c);

private:
  // One logical atom: a feature together with an optional negation.
  // You have to be VERY careful with this; It does not free its own memory
  class FeatureAtom
  {
  public:
    FeatureAtom(bool not_flag, ASDiffFeature* pfeature)
      : not_flag(not_flag), pfeature(pfeature) {}
    ~FeatureAtom() { /* do NOT delete */ }

    // Explicitly frees the feature; safe to call more than once because the
    // pointer is reset to NULL afterwards.
    void deleteFeature() { if (pfeature) delete pfeature; pfeature = NULL; }

    // Result of the feature's match, inverted when not_flag is set.
    // pfeature must not be NULL here.
    bool match(AbstractStateCompare* pcomp)
    { return not_flag ^ pfeature->match(pcomp); }

    // Prints "!" for a negated atom, followed by the feature.
    // a.pfeature must not be NULL here.
    friend std::ostream& operator<<(std::ostream& os, const FeatureAtom& a)
    { os << (a.not_flag ? "!" : "") << *a.pfeature; return os; }

  private:
    bool not_flag;
    ASDiffFeature* pfeature;
  };
  // the AND of these features (logical atoms)
  typedef std::vector<FeatureAtom> FeatureStorage;
  // the OR of these terms
  typedef std::vector<FeatureStorage> TermStorage;

  std::string name;
  TermStorage terms;
};

/***********************************************************************/

/* Represents an ordered set of classifications. */
class ASDiffClassifierSet
{
public:
  ASDiffClassifierSet();
  ~ASDiffClassifierSet();

  // Appends p at the end of the set.
  void addClassifier(ASDiffClassifier* p)
  {
    classifiers.push_back(p);
  }
  ASDiffClassifier* getClassifier(int idx);
  const ASDiffClassifier* getClassifier(int idx) const;
  // Number of stored classifiers.
  int getNumClassifiers() const
  {
    return static_cast<int>(classifiers.size());
  }
  int lookupClassifier(const char* name) const;

  void clear();

  // Returns -1 if nothing matches; otherwise returns the idx of the
  // classifier (which can be passed to getClassifier).
  int classify(AbstractStateCompare* pcomp);

  friend std::ostream& operator<<(std::ostream& os, const ASDiffClassifierSet& s);
  // See ASDCSetFileReader for the input operation.

private:
  typedef std::vector<ASDiffClassifier*> ClassifierStorage;

  ClassifierStorage classifiers;
};

/***********************************************************************/
/* File reader bound to an ASDiffClassifierSet; it provides the input
   operation for the set. Only the pointer to the set is stored. */
class ASDCSetFileReader : public spades::FileReader
{
public:
  ASDCSetFileReader(ASDiffClassifierSet* pset)
    : FileReader(),
      pset(pset)
  {}
  ~ASDCSetFileReader() {}

  // Reads the file at path, defaulting ourselves to version 1.0.
  void readFile (const char* path)
  {
    FileReader::readFile(path, 1.0);
  }

protected:
  // Per-line handler; the implementation is in the source file.
  bool processLine(std::istrstream& line,
                   const char* fileid,
                   const char* path,
                   float version);

private:
  ASDiffClassifierSet* pset;
};

/***********************************************************************/
/* This class stores a classification of state -> state. */
class ASDiffClassificationStore
{
public:
  // Does NOT take over the memory of the classifier set.
  ASDiffClassificationStore( ASDiffClassifierSet* p);
  ~ASDiffClassificationStore();

  ASDiffClassifierSet* getSet()
  { return pset; }
  const ASDiffClassifierSet* getSet() const
  { return pset; }
  // Switches to another classifier set and then calls resize().
  void setSet(ASDiffClassifierSet* p)
  {
    pset = p;
    resize();
  }

  // Stores this classification.
  void add(AbstractStateCompare* pcomp, int classification);
  void add(int first_state, int second_state, int classification);

  void clear();

  void classifyMarkovChainTransitions(const MarkovChain& mc,
                                      AbstractStateDescription* pdesc,
                                      int progress_interval);

  void writeSummary(std::ostream& os) const;
  void writeFull(std::ostream& os) const;

private:
  void resize();

  // A list of pairs of ints; by the signature of add() these are presumably
  // (first_state, second_state) -- confirm against the implementation.
  typedef std::list< std::pair<int, int> > StateStateList;
  // One StateStateList per classification index.
  typedef std::vector< StateStateList > Classifications;

  ASDiffClassifierSet* pset;
  // The index one past the last valid one from pset will be for unmatched ones.
  Classifications stored_class;
};

#endif
