/* -*- Mode: C++ -*- */

/* These classes handle conversion from a MarkovChain to a Markov Decision Process.
   This is fairly soccer specific */

#ifndef _MDP_CONVERSION_H_
#define _MDP_CONVERSION_H_

#include <cstddef>
#include <istream>
#include <list>
#include <ostream>
#include <set>
#include <string>
#include <vector>
#include "SoccerMDP.h"
#include "MarkovChain.h"
#include "AbstractStateDifference.h"

/*****************************************************************************/

// Forward declarations so the classes below can refer to each other by pointer.
class MDPConverter;
class MDPConvActionInfo;
class MDPConvTranInfo;
class TransitionSorter;

// A list of candidate actions, each held by pointer.
typedef std::list<MDPConvActionInfo*> ActionList;

/*****************************************************************************/

//Used to sort transitions to eventually become actions
class TransitionSorter
{
public:
  class Element
  {
  public:
    class ClassElement
    {
    public:
      ClassElement(const std::string& name = "", int idx = -1)
	: name(name), idx(idx) {}
      ~ClassElement() {}

      const std::string& getName() const { return name; }
      int getIdx() const { return idx; }

      bool lookupIndex(const ASDiffClassifierSet& cset) const;

      bool operator<(const ClassElement& e) const { return name < e.name; }
	
      friend std::ostream& operator<<(std::ostream& os, const ClassElement& e)
      { os << e.name; return os; }
      friend std::istream& operator>>(std::istream& is, ClassElement& e)
      { is >> e.name; return is; }
	
    private:
      std::string name;
      // mutable here means that this can change in const function
      // We do this so that we can put ClassElement into a set (which
      // only has const iterators); Changing idx will NOT change the order
      // in the set since that is based only on name (see op< above)
      mutable int idx;
    };

    typedef std::set< ClassElement > ClassStorage;

  public:
    Element();
    ~Element();

    void clear();
      
    void lookupIndices(const ASDiffClassifierSet& cset);

    bool isPrimary(int class_idx, const ASDiffClassifierSet& cset)
    { return isInClasses(true, class_idx, cset); }
    bool isSecondary(int class_idx, const ASDiffClassifierSet& cset)
    { return isInClasses(false, class_idx, cset); }
    bool isInClasses(bool primary, int class_idx, const ASDiffClassifierSet& cset);
    
    void createActions(AbstractStateDescription* pdesc,
		       const ASDiffClassifierSet& cset,
		       int current_state_idx,
		       std::vector<MDPConvTranInfo*>& vtran,
		       ActionList& vact);
      
    friend std::ostream& operator<<(std::ostream& os, const Element& e);
    friend std::istream& operator>>(std::istream& is, Element& e);
      
  private:
    void lookupIndices(const ASDiffClassifierSet& cset, ClassStorage* pclass);
      
    friend std::istream& operator>>(std::istream& is, ClassStorage& s);
      
    std::string act_name; // the name of the SoccerMDPAction to which this corresponds
    ClassStorage primary_classes;
    ClassStorage secondary_classes;
  };

public:
  TransitionSorter(ASDiffClassifierSet* pcset = NULL,
		   AbstractStateDescription* pdesc = NULL);
  ~TransitionSorter();

  void clear();
  void clearElements() { elements.clear(); allowed_to_null_classes.clear(); }
  
  // does NOT take over the memory
  void setClassifierSet(ASDiffClassifierSet* p) { pcset = p; }
  // does NOT take over this memory
  void setStateDescription(AbstractStateDescription* p) { pdesc = p; }

  void performClassLookups();
  
  void createActions(int current_state_idx,
		     std::vector<MDPConvTranInfo*>& vtran,
		     ActionList& vact);
      
  friend std::ostream& operator<<(std::ostream& os, const TransitionSorter& e);
  friend std::istream& operator>>(std::istream& is, TransitionSorter& e);
      
private:
  typedef std::vector<Element> ElementStorage;
  
  ElementStorage elements;
  ASDiffClassifierSet* pcset;  // does NOT take over the memory
  AbstractStateDescription* pdesc; // does NOT take over this memory
  // This is a list of the classes which we attach to the Null action
  // if they don't transition to something else
  Element::ClassStorage allowed_to_null_classes; 
};


/*****************************************************************************/

// This is a helper class decribing a transition
class MDPConvTranInfo
{
public:
  MDPConvTranInfo();
  MDPConvTranInfo(int nextstate, double prob);
  ~MDPConvTranInfo();

  int getNextState() const       { return nextstate; }
  double getOriginalProb() const { return prob; }
  int getNumReplications() const { return num_replications; }
  int getClassIdx() const        { return class_idx; }
  
  void incrReplicationCount() { num_replications++; }
  void decrReplicationCount() { num_replications--; }
  
  bool convert(MDP::TranInfo& tran);

  int classify(AbstractStateDescription* pdesc, ASDiffClassifierSet* pcset,
	       int original_state);
  
  friend std::ostream& operator<<(std::ostream& os, const MDPConvTranInfo& t);

private:
  int nextstate;
  double prob;
  int num_replications;
  int class_idx; // the index for ASDiffClassifierSet
};



/*****************************************************************************/

// This class represent an intermediary before converting to an MDP::ActionInfo
class MDPConvActionInfo
{
public:
  MDPConvActionInfo();
  ~MDPConvActionInfo();

  SoccerMDPAction* getAction() { return pact; }
  void setAction(SoccerMDPAction* p) { pact = p; }
  void forgetAction() { pact = NULL; }

  //does NOT take over the memory
  void addTran(MDPConvTranInfo* t) { trans.push_back(t); }

  bool tryUnify(MDPConvActionInfo* actinfo);
  
  // we don't copy the MDPAction! You may want to forget the pointer!
  bool convert(MDP::ActionInfo& actinfo);

  friend std::ostream& operator<<(std::ostream& os, const MDPConvActionInfo& a);

public:
  // return number removed
  static int unifyList(ActionList& lact);
  
private:
  void unifyTransitions(MDPConvActionInfo* pactinfo);
  
  // we have to use pointer because we always want to refer to
  // the same MDPConvTranInfo object
  // we are NOT in charge of the memory
  typedef std::vector<MDPConvTranInfo*> TranStorage;

  SoccerMDPAction* pact;
  TranStorage trans;
};

/*****************************************************************************/

class MDPConverter
{
public:
  MDPConverter(TransitionSorter* psorter) : psorter(psorter), pmdp(NULL) {}
  ~MDPConverter();

  TransitionSorter* getSorter() { return psorter; }
  void setSorter(TransitionSorter* p) { if (psorter) delete psorter; psorter = p; }

  SoccerMDP* getMDP() { return pmdp; }
  // the caller takes over the memory! We forget abot the mdp now!
  SoccerMDP* releaseMDP() { SoccerMDP* p = pmdp; pmdp = NULL; return p; }

  bool convertMC(MarkovChain* pmc, int progress_interval = -1);
  
private:
  TransitionSorter* psorter;
  SoccerMDP* pmdp;

};

#endif
