/* -*- Mode: C++ -*- */

/* These classes provide an interface to an MDP model together with its value
   (Q-table) and the factored state-space description that was used to collect it */

#ifndef _SOCCER_MODEL_W_VALUE_H_
#define _SOCCER_MODEL_W_VALUE_H_

#include "SoccerMDP.h"
#include "AbstractState.h"
#include "QTableFlex.h"
#include "StateValue.h"
#include "StateTraceAnalyzer.h"
#include "data.h"

class AdviceTree;

/*****************************************************************************/

class SoccerStateFilter;
class SoccerActionFilter;
class SoccerModelAdviser;

/*****************************************************************************/

// An MDP model (SoccerMDP) bundled with its Q-table (QTableFlex) and the
// abstract state description the model was built over.
//
// NOTE(review): forgetStateDescription() only nulls pdesc, which suggests the
// destructor deletes pdesc. If so, copying this class would double-delete the
// description -- confirm in the .cpp and consider forbidding copies.
class SoccerModelWValue
{
public:
  // gamma: presumably the discount factor used with the model -- TODO confirm.
  SoccerModelWValue(float gamma = 0.9);
  ~SoccerModelWValue();

  /* ---- state description ---- */
  void setStateDescription(AbstractStateDescription* p);
  AbstractStateDescription* getStateDescription()
  {
    return pdesc;
  }
  const AbstractStateDescription* getStateDescription() const
  {
    return pdesc;
  }
  // Drops the stored pointer without touching the object it pointed to.
  void forgetStateDescription()
  {
    pdesc = NULL;
  }

  /* ---- MDP ---- */
  bool readMDPFrom(const char* fn);
  SoccerMDP& getMDP()
  {
    return mdp;
  }
  const SoccerMDP& getMDP() const
  {
    return mdp;
  }

  /* ---- Q-table ---- */
  bool readQTableFrom(const char* fn);
  QTable& getQTable()
  {
    return qtable;
  }
  const QTable& getQTable() const
  {
    return qtable;
  }

  bool check(bool desc_valid, bool mdp_valid, bool qtable_valid);

  /* ---- advice generation ---- */
  // Advises for the state with index sidx; returns the number of actions advised.
  int adviseForState(int sidx,
                     std::list<SoccerActionFilter*>& lfilters,
                     SoccerModelAdviser* padviser);

  // Same, but for an AbstractState; does not take ownership of pstate.
  int adviseForState(AbstractState* pstate,
                     std::list<SoccerActionFilter*>& lfilters,
                     SoccerModelAdviser* padviser);

  // Returns the number of states for which advice was given.
  // pcount_no_act_states (optional) presumably receives the number of states
  // that passed the state filters but had no action advised -- TODO confirm.
  int adviseFor(std::list<SoccerStateFilter*>& lfilters_state,
                std::list<SoccerActionFilter*>& lfilters_act,
                SoccerModelAdviser* padviser,
                int* pcount_no_act_states = NULL);

  bool doesPassActionFilters(int sidx, int aidx,
                             std::list<SoccerActionFilter*>& lfilters);
  bool doesPassStateFilters(int sidx,
                            std::list<SoccerStateFilter*>& lfilters);

  // Estimates the average reward per step by taking num_steps from the start
  // state, repeating the whole process `repititions` times.
  SingleDataSummary estimateOptAvgStepReward(int start_idx,
                                             int num_steps,
                                             int repititions,
                                             std::list<SoccerActionFilter*>& lfilters_act,
                                             bool show_status);

private:
  AbstractStateDescription* pdesc;
  SoccerMDP mdp;
  QTableFlex qtable;
};

/*****************************************************************************/
/*****************************************************************************/
/*****************************************************************************/

// Abstract predicate over the states of a SoccerModelWValue. Lists of these
// are handed to SoccerModelWValue::adviseFor / doesPassStateFilters and to
// SoccerModelTraceComparator.
class SoccerStateFilter
{
public:
  SoccerStateFilter()
  {
  }

  virtual ~SoccerStateFilter()
  {
  }

  // Return true to keep state_idx of *pmodel, false to filter it out.
  virtual bool acceptState(SoccerModelWValue* pmodel, int state_idx) = 0;
};

/*****************************************************************************/

// State filter on the number of actions in a state (counted as defined in the
// .cpp). Accepts a state when min <= #actions < max.
class SSFilterNumActions : public SoccerStateFilter
{
public:
  SSFilterNumActions(int min, int max)
    : SoccerStateFilter(),
      min(min),
      max(max)
  {
  }

  ~SSFilterNumActions()
  {
  }

  bool acceptState(SoccerModelWValue* pmodel, int state_idx);

private:
  int min; // inclusive lower bound
  int max; // exclusive upper bound
};

/*****************************************************************************/

// Randomized state filter; presumably accepts each state with probability
// accept_prob (implementation lives in the .cpp -- TODO confirm).
class SSFilterRand : public SoccerStateFilter
{
public:
  SSFilterRand(double accept_prob)
    : SoccerStateFilter(),
      accept_prob(accept_prob)
  {
  }

  ~SSFilterRand()
  {
  }

  bool acceptState(SoccerModelWValue* pmodel, int state_idx);

private:
  double accept_prob;
};

/*****************************************************************************/

// Stateless state filter; by its name it presumably accepts states that have
// at least one action expressible as a CLang action -- TODO confirm in .cpp.
class SSFilterHasCLangAct : public SoccerStateFilter
{
public:
  SSFilterHasCLangAct()
    : SoccerStateFilter()
  {
  }

  ~SSFilterHasCLangAct()
  {
  }

  bool acceptState(SoccerModelWValue* pmodel, int state_idx);
};

/*****************************************************************************/

// State filter parameterized by a factor index and a value; presumably accepts
// states whose factor fac_idx equals val -- TODO confirm in .cpp.
class SSFilterFactorHasValue : public SoccerStateFilter
{
public:
  SSFilterFactorHasValue(int fac_idx, int val)
    : SoccerStateFilter(),
      fac_idx(fac_idx),
      val(val)
  {
  }

  ~SSFilterFactorHasValue()
  {
  }

  bool acceptState(SoccerModelWValue* pmodel, int state_idx);

private:
  int fac_idx; // which factor of the state to inspect
  int val;     // value to compare that factor against
};

/*****************************************************************************/

template <class Value>
class SSFilterByRank : public SoccerStateFilter
{
public:
  // psv is borrowed, not owned.
  // from_top:  keep states with rank < count (the first `count` ranks).
  // !from_top: keep states with rank >= getNumStates() - count (the last ones).
  // A negative count keeps every state.
  SSFilterByRank(StateValue<Value>* psv, int count, bool from_top)
    : SoccerStateFilter(),
      psv(psv),
      count(count),
      from_top(from_top)
  {
  }

  ~SSFilterByRank()
  {
  }

  // Defined in the header because the class is a template.
  bool acceptState(SoccerModelWValue* pmodel, int state_idx)
  {
    if (count < 0)
      return true;
    return from_top
      ? (psv->getRank(state_idx) < count)
      : (psv->getRank(state_idx) >= psv->getNumStates() - count);
  }

private:
  StateValue<Value>* psv;
  int count;
  bool from_top; // true: take the first-ranked states; false: the last-ranked
};

/*****************************************************************************/

template <class Value>
class SSFilterByValue
  : public SoccerStateFilter
{
public:
  // Does NOT take over the memory of psv.
  // Accepts a state when its value v lies in the half-open range
  // [min, max), i.e. min <= v && v < max.
  SSFilterByValue(StateValue<Value>* psv, const Value& min, const Value& max)
    : SoccerStateFilter(), psv(psv), min(min), max(max) {}
  ~SSFilterByValue() {}

  // Defined inline because this is a class template.
  // Requires Value to provide operator<= and operator<.
  bool acceptState(SoccerModelWValue* pmodel, int state_idx)
  {
    Value v = psv->getValue(state_idx);
    return (min <= v) && (v < max);
  }

private:
  StateValue<Value>* psv;
  Value min; // inclusive lower bound
  Value max; // exclusive upper bound
};

/*****************************************************************************/
/*****************************************************************************/
/*****************************************************************************/

// Abstract predicate over (state, action) pairs of a SoccerModelWValue. Lists
// of these are handed to the advice and estimation methods of that class.
class SoccerActionFilter
{
public:
  SoccerActionFilter()
  {
  }

  virtual ~SoccerActionFilter()
  {
  }

  // Return true to keep action act_idx in state state_idx of *pmodel.
  virtual bool acceptAction(SoccerModelWValue* pmodel,
                            int state_idx,
                            int act_idx) = 0;
};

/*****************************************************************************/

// Action filter that, by its name, keeps actions near the optimum; perc
// presumably sets how close "near" is -- TODO confirm in .cpp.
// When choose_worst is true, the worst actions are taken instead.
class SAFilterNearOpt : public SoccerActionFilter
{
public:
  SAFilterNearOpt(double perc = 1.0, bool choose_worst = false)
    : SoccerActionFilter(),
      perc(perc),
      choose_worst(choose_worst)
  {
  }

  ~SAFilterNearOpt()
  {
  }

  bool acceptAction(SoccerModelWValue* pmodel, int state_idx, int act_idx);

private:
  double perc;
  bool choose_worst; // select the worst actions rather than the best
};

/*****************************************************************************/
/*****************************************************************************/
/*****************************************************************************/

// Abstract sink for advice produced by SoccerModelWValue.
// Judging by the names, a caller presumably brackets the addStateAdvice calls
// for one state between beginStateAdvice and endStateAdvice -- TODO confirm
// against SoccerModelWValue's .cpp.
//
// Subclasses that override only one addStateAdvice overload should add
// `using SoccerModelAdviser::addStateAdvice;` so the other stays visible.
class SoccerModelAdviser
{
public:
  SoccerModelAdviser() {}
  virtual ~SoccerModelAdviser() {}

  virtual void beginStateAdvice(AbstractState* pstate) = 0;

  // Receives one concrete CLang action for pstate.
  // Ownership of pact passes to the adviser: you have to take over the memory!
  // This overload is pure virtual, so every concrete adviser must define it,
  // even one that also overrides the SoccerMDPAction overload below.
  virtual void addStateAdvice(AbstractState* pstate, rcss::clang::Action* pact) = 0;

  // Converts an MDP action into a CLang action for pstate and forwards it to
  // the overload above. Override this to handle MDP actions differently.
  // Returns whether a real action got output: false when pmdpact is NULL or
  // yields no CLang action. pmdpact itself is not deleted here.
  virtual bool addStateAdvice(AbstractState* pstate, SoccerMDPAction* pmdpact)
  {
    if (pmdpact == NULL)
      return false;
    rcss::clang::Action* pclangact =
      pmdpact->createAction(pstate->getStateDescription(), pstate->getStateIdx());
    if (pclangact == NULL)
      {
        // No real action got produced here
        return false;
      }
    addStateAdvice(pstate, pclangact); // ownership of pclangact moves on
    return true;
  }

  virtual void endStateAdvice(AbstractState* pstate, int act_count) = 0;
};

/*****************************************************************************/

class SoccerModelAdviserNull
  : public SoccerModelAdviser
{
public:
  SoccerModelAdviserNull() : SoccerModelAdviser() {}
  ~SoccerModelAdviserNull() {}

  void beginStateAdvice(AbstractState* pstate) {}
  // you have to take over the memory!
  void addStateAdvice(AbstractState* pstate, rcss::clang::Action* pact) { delete pact; }
  void endStateAdvice(AbstractState* pstate, int act_count) {}
};

/*****************************************************************************/

class SoccerModelAdviserFlat
  : public SoccerModelAdviser
{
public:
  SoccerModelAdviserFlat(const std::string& rule_prefix, CoachMessageQueue* pmqueue)
    : SoccerModelAdviser(),
      rule_prefix(rule_prefix),
      pmqueue(pmqueue),
      p_curr_dir(NULL)
  {}
  ~SoccerModelAdviserFlat() { if (p_curr_dir) delete p_curr_dir; }

  void beginStateAdvice(AbstractState* pstate);
  // you have to take over the memory!
  void addStateAdvice(AbstractState* pstate, rcss::clang::Action* pact);
  void endStateAdvice(AbstractState* pstate, int act_count);

private:
  std::string rule_prefix;
  CoachMessageQueue* pmqueue;

  rcss::clang::DirComm* p_curr_dir;
};

/*****************************************************************************/

// Adviser backed by an AdviceTree (behavior lives in the .cpp).
// The destructor is empty, so ptree is borrowed: the caller keeps ownership.
class SoccerModelAdviserTree : public SoccerModelAdviser
{
public:
  SoccerModelAdviserTree(AdviceTree* ptree)
    : SoccerModelAdviser(),
      ptree(ptree)
  {
  }

  ~SoccerModelAdviserTree()
  {
  }

  // No per-state setup is needed.
  void beginStateAdvice(AbstractState* pstate)
  {
  }

  // Takes ownership of pact.
  void addStateAdvice(AbstractState* pstate, rcss::clang::Action* pact);
  // Also overrides the MDP-action path; returns whether a real action got output.
  bool addStateAdvice(AbstractState* pstate, SoccerMDPAction* pmdpact);

  // No per-state teardown is needed.
  void endStateAdvice(AbstractState* pstate, int act_count)
  {
  }

private:
  AdviceTree* ptree;
};

/*****************************************************************************/

class SoccerModelAdviserOutput
  : public SoccerModelAdviser
{
public:
  SoccerModelAdviserOutput(const SoccerModelWValue& soccer_model, std::ostream& os)
    : SoccerModelAdviser(), soccer_model(soccer_model), os(os) {}
  ~SoccerModelAdviserOutput() {}

  void beginStateAdvice(AbstractState* pstate);
  // you have to take over the memory!
  void addStateAdvice(AbstractState* pstate, rcss::clang::Action* pact);
  void endStateAdvice(AbstractState* pstate, int act_count);
  
private:
  const SoccerModelWValue& soccer_model;
  std::ostream& os;
};

/*****************************************************************************/

class SoccerModelAdviserCountAct
  : public SoccerModelAdviser
{
public:
  typedef std::vector<int> CountStorage;

public:
  SoccerModelAdviserCountAct()
    : SoccerModelAdviser(), bucket(0, 1) {}
  ~SoccerModelAdviserCountAct() {}

  void beginStateAdvice(AbstractState* pstate) {}
  // you have to take over the memory!
  void addStateAdvice(AbstractState* pstate, rcss::clang::Action* pact)
  { delete pact; }
  void endStateAdvice(AbstractState* pstate, int act_count);

  const IntBucket& getBucket() const { return bucket; }
  
  friend std::ostream& operator<<(std::ostream& os, const SoccerModelAdviserCountAct& sma);
  
private:
  IntBucket bucket;
};

/*****************************************************************************/
/*****************************************************************************/
/*****************************************************************************/

// Compares what is in a SoccerModelWValue with an observed state trace.
// You pass this to a StateTraceAnalyzer to make it do its thing.
class SoccerModelTraceComparator
  : public StateTraceAnalyzerStrategy
{
public:
  // Category assigned to an observed (last_state -> state) transition.
  enum TransitionClass
    {
      TC_Error, // some error occurred in classification
      TC_CorrectAct, // did an action as allowed by the state/action filters
      TC_PossibleAct, // in an allowed state, did a disallowed action
      TC_IgnoredSelfTran, // explicitly told to ignore these self transitions
      TC_ImpossibleAct, // in an allowed state, did an action never seen before
      TC_FilteredStatePossibleAct, // in a filtered state, with a possible action
      TC_FilteredStateImpossibleAct, // in a filtered state, with an impossible action (cf. TC_ImpossibleAct)
      TC_ImpossibleState // in a state never seen before
    };
  // Number of TransitionClass values. Derived from the last enumerator so it
  // cannot silently drift from the enum; keep TC_ImpossibleState last (or
  // update this expression) when adding classes.
  // NOTE(review): pre-C++17, ODR-using this (e.g. binding it to a const&)
  // needs an out-of-class definition in the .cpp.
  static const int NUM_TRANSITION_CLASSES = TC_ImpossibleState + 1;

public:
  // CAREFUL: the filter lists are copied, but the filter objects they point
  // to are not cloned -- they are shared with the caller and must stay valid
  // while this comparator is in use.
  SoccerModelTraceComparator(SoccerModelWValue* pmodel,
                             std::list<SoccerStateFilter*>& lfilters_state,
                             std::list<SoccerActionFilter*>& lfilters_act);
  ~SoccerModelTraceComparator();

  void resetCounts();

  // avg_factor's meaning is defined in the .cpp -- TODO document.
  void printCounts(std::ostream& o, double avg_factor = 0.0);

public:
  // The StateTraceAnalyzerStrategy interface
  void startFileList(const char* fn) {}

  // called at the beginning of every file of state transitions
  void startFile(const char* fn) {}

  // The main callback: inside each file, called for each state trace element.
  // Also receives the previous time and state, just for convenience;
  // -1 means invalid for last_state and last_time.
  void handleNextState(int last_time, int last_state, int time, int state);

  // called at the end of every file analysis
  void endFile(const char* fn) {}

  // called at the end of processing a list of files
  void endFileList(const char* fn) {}

private:
  TransitionClass classifyTransition(int laststate, int state);

  typedef std::vector<int> VTranClassCounts;

  SoccerModelWValue* pmodel;
  std::list<SoccerStateFilter*> lfilters_state;
  std::list<SoccerActionFilter*> lfilters_act;
  VTranClassCounts v_tran_class_counts;
};

// Stream insertion for SoccerModelTraceComparator::TransitionClass
// (defined out of line).
std::ostream& operator<<(std::ostream& os, SoccerModelTraceComparator::TransitionClass tc);

#endif
