/* -*- Mode: C++ -*- */

/* These classes represent a Markov Decision Process */

#ifndef _MDP_H_
#define _MDP_H_

#include <iostream>
#include <map>
#include <vector>
#include "data.h"
#include "BinaryFile.h"

class QTable; // QTable.h

/*****************************************************************************/

// Abstract base class for an action in the MDP.
// Derived classes must provide a textual form (print) and a binary
// serialization (writeTo).
class MDPAction
{
public:
  MDPAction() {}
  virtual ~MDPAction() {}

  // Writes a textual form of this action to os.
  virtual void print(std::ostream& os) const = 0;

  // Serializes this action through writer.
  // NOTE(review): the bool result presumably signals success -- confirm
  // against the BinaryFileWriter conventions.
  virtual bool writeTo(BinaryFileWriter& writer) = 0;

  // Stream insertion: forwards to the virtual print() so that the derived
  // class decides the output, then hands the stream back for chaining.
  friend std::ostream& operator<<(std::ostream& os, const MDPAction& act)
  {
    act.print(os);
    return os;
  }
};

/*****************************************************************************/

// A Markov Decision Process.
// The model is stored as a vector of states; each state holds a vector of
// actions (ActionInfo) and each action holds a vector of transitions
// (TranInfo). Rewards are kept in a map keyed by state index.
// The class is abstract: subclasses supply createAction() for both input
// formats and the binary header hooks writeHeader()/readHeader().
class MDP
{
public:
  // A single transition: the index of the next state and a probability.
  class TranInfo
  {
  public:
    // A default-constructed transition has next state -1 and probability 0.
    TranInfo() : nextstate(-1), prob(0) {}
    TranInfo(int nextstate, double prob)
      : nextstate(nextstate), prob(prob) {}
    ~TranInfo() {}

    int getNextState() const { return nextstate; }
    void setNextState(int s) { nextstate = s; }

    double getProb() const { return prob; }
    void setProb(double p) { prob = p; }

    // Text output: next state, one space, then the probability.
    friend std::ostream& operator<<(std::ostream& os, const TranInfo& t)
    {
      return os << t.nextstate << ' ' << t.prob;
    }

    // Text input: reads the next state and then the probability.
    friend std::istream& operator>>(std::istream& is, TranInfo& t)
    {
      return is >> t.nextstate >> t.prob;
    }

    // Binary output: the next state is written first, then the probability.
    // Yields false as soon as one of the two writes reports failure (the
    // probability is not written if the state write failed).
    bool writeTo(BinaryFileWriter& writer) const
    {
      if (!writer.writeInt(nextstate))
        return false;
      if (!writer.writeFloat(prob))
        return false;
      return true;
    }

    // Binary input: mirror image of writeTo(), same field order and the
    // same stop-at-first-failure behavior.
    bool readFrom(BinaryFileReader& reader)
    {
      if (!reader.readInt(&nextstate))
        return false;
      if (!reader.readFloat(&prob))
        return false;
      return true;
    }

  private:
    // Declaration order matters: it is also the initialization order.
    int nextstate;
    double prob;
  };

  // An action available in a state: a pointer to the MDPAction plus the
  // list of transitions the action can cause.
  class ActionInfo
  {
  public:
    typedef std::vector<TranInfo> TranStorage;

    // Note that an ActionInfo does NOT delete its MDPAction on destruction;
    // the action is only deleted when explicitly requested (deleteAction(),
    // or setAction()/clear(), which call it). This lets ActionInfo objects
    // be copied freely in and out of STL containers. Copies share the same
    // pointer, so deleting through one copy invalidates the others.
    ActionInfo(MDPAction* pact = NULL);
    ~ActionInfo() {}

    MDPAction* getAction() { return pact; }
    const MDPAction* getAction() const { return pact; }

    // Deletes the currently held action (if any) and then stores p.
    void setAction(MDPAction* p)
    {
      deleteAction();
      pact = p;
    }

    // Deletes the held action and resets the pointer to NULL.
    // Deleting a null pointer is a no-op, so no explicit check is needed.
    void deleteAction()
    {
      delete pact;
      pact = NULL;
    }

    // Appends a copy of t to the transition list.
    void addTransition( const TranInfo& t ) { trans.push_back(t); }

    // Deletes the action and removes all transitions.
    void clear()
    {
      deleteAction();
      clearTransitions();
    }

    // Removes all transitions; the action pointer is left untouched.
    void clearTransitions() { trans.clear(); }

    const TranStorage& getTransitions() const { return trans; }
    TranStorage& getTransitions() { return trans; }

    // NOTE(review): defined elsewhere; presumably the sum of the
    // probabilities over all transitions -- confirm.
    double getTotalProb() const;

    // NOTE(review): defined elsewhere; presumably rescales the transition
    // probabilities using getTotalProb() -- confirm.
    void normalize();

    // Returns a number < 0 if there is no possible transition to sidx.
    double probForTranTo(int sidx) const;

    // NOTE(review): defined elsewhere; exact meaning of prob and of the
    // returned int is not visible from this header -- confirm.
    int getTranForProb(double prob) const;

    // NOTE(review): defined elsewhere; presumably the maximum-probability
    // transitions (more than one entry on ties) -- confirm.
    std::vector<int> getMaxProbTran() const;

    // Returns the index of a maximum probability transition.
    int getMaxTranRandom() const;

    // If the number of transitions is >= min_tran, adds that number to sds.
    void addTranCountData(SingleDataSummary& sds, int min_tran) const;

    friend std::ostream& operator<<(std::ostream& os, const ActionInfo& t);
    // Does NOT read an MDPAction; that must already have been done.
    friend std::istream& operator>>(std::istream& is, ActionInfo& t);

    bool writeTo(BinaryFileWriter& writer) const;
    // Does NOT read an MDPAction; that must already have been done.
    bool readFrom(BinaryFileReader& reader);

  private:
    // Declaration order matters: it is also the initialization order.
    MDPAction* pact;
    TranStorage trans;
  };

public:
  MDP(int num_states = 0);
  virtual ~MDP();

  // Number of states. The container size is narrowed to int, as the
  // declared return type requires.
  int getNumStates() const { return static_cast<int>(states.size()); }
  int getNumActionsInState(int sidx) const;
  ActionInfo* getAction(int sidx, int aidx);
  const ActionInfo* getAction(int sidx, int aidx) const;
  double getRewardForState(int sidx) const;

  // Calls clear() and then resizes the state list to n entries.
  void setNumStates(int n)
  {
    clear();
    states.resize(n);
  }

  // Stores r as the reward for state sidx, creating or overwriting the
  // map entry. sidx is not checked against the number of states.
  void setReward(int sidx, double r) { rewards[sidx] = r; }

  // Appends a copy of act to the action list of state sidx.
  // sidx is not range-checked here.
  void addAction(int sidx, const ActionInfo& act)
  {
    states[sidx].push_back(act);
  }

  double getTotalProbForAction(int sidx, int act);
  double getTotalProbForState(int sidx);

  void normalize();

  void clear();

  // NOTE(review): defined elsewhere; the meaning of the return value,
  // progress_interval and p_early_term is not visible from this header.
  int solveByQTable(QTable& qt, int progress_interval, bool* p_early_term = NULL);

  // Factory hooks: subclasses build an MDPAction from text or binary input.
  virtual MDPAction* createAction(std::istream& is) = 0;
  virtual MDPAction* createAction(BinaryFileReader& reader) = 0;

  // Returns the indices of all actions which could have led to the
  // transition from sidx1 to sidx2.
  std::vector<int> getPossibleActsForTran(int sidx1, int sidx2) const;

  // Returns the new state idx.
  // Randomly takes a next state.
  int takeStep(int state_idx, int act_idx) const;

  // Returns the new state idx.
  // Follows a max probability transition, breaking ties randomly.
  int takeMaxStep(int state_idx, int act_idx) const;

  // For every state with >= min_act actions, adds the number of actions
  // to sds.
  void addActCountData(SingleDataSummary& sds, int min_act) const;
  // For every action (in every state) with >= min_tran transitions, adds
  // the number of transitions to sds.
  void addTranCountData(SingleDataSummary& sds, int min_tran) const;

  // Returns a random state that has actions.
  int getRandomValidState() const;

  friend std::ostream& operator<<(std::ostream& os, const MDP& m);
  friend std::istream& operator>>(std::istream& is, MDP& m);

  bool writeTo(BinaryFileWriter& writer) const;
  bool readFrom(BinaryFileReader& reader);

  // NOTE(review): defined elsewhere; by its name it accepts either the
  // text or the binary format from is -- confirm how the format is chosen.
  bool readTextOrBinary(std::istream& is);

protected:
  // Subclass hooks for the binary format header.
  virtual bool writeHeader(BinaryFileWriter& writer) const = 0;
  virtual bool readHeader(BinaryFileReader& reader) = 0;

private:
  typedef std::map<int, double> RewardStorage;       // state index -> reward
  typedef std::vector<ActionInfo> ActionStorage;     // actions of one state
  typedef std::vector<ActionStorage> StateStorage;   // indexed by state

  RewardStorage rewards;
  StateStorage states;
};


/*****************************************************************************/




#endif
