/* -*- Mode: C++ -*- */

/* a simple class to do table based Q-learning.
   Assumes unsigned ints for state and action space */

#ifndef _Q_TABLE_H_
#define _Q_TABLE_H_

#include <iostream>
#include <vector>
#include <StateValue.h>
#include "BinaryFile.h"

class MDP; //defined in MDP.hpp
// I didn't move this over
//class ActionSubsetMap; //defined in ActionSubsetMap.hpp

// Action-selection (exploration) strategies; passed to QTable::getAct.
enum explore_t {
  EXP_None,
  EXP_Max,
  EXP_Random,
  EXP_AlphaExplore,
  EXP_Boltzmann
};
// Number of explore_t values. Derived from the last enumerator instead of a
// hard-coded literal so it cannot silently drift from the enum; keep
// EXP_Boltzmann last (or update this line) when adding new strategies.
const int NUM_EXPLORE_T = EXP_Boltzmann + 1;
// Parses an explore_t from the text at *ps.
// NOTE(review): presumably advances *ps past the consumed text — confirm in the .cpp.
explore_t parseExplore(const char** ps);
// Returns a printable name for exp.
const char* getExploreString(const explore_t& exp);

/* Abstract base for a table-based Q-learner; concrete subclasses supply the
   storage through getSA/checkValid/getNumActions/zero.

   You can mark state/actions as disabled.
   Internally, this is done by setting the visit count of the entry to
   StateActionEntry::DISABLE_VALUE (documented as -1); all the functions will
   pretend that the action does not exist, and in appropriate cases, errors
   will be printed if the state/action is disabled. */
class QTable
{
public:
  QTable(int num_states, float gamma);
  virtual ~QTable();

  //zeros all ENABLED actions
  virtual void zero() = 0;

  int getNumStates() const { return num_states; }
  //returns -1 if state is invalid
  virtual int getNumActions(int state) const = 0;
  
  //returns true if this update causes a change in the policy for this state
  //det = true if you want a deterministic update
  bool observedUpdate(int last_state, int action, int next_state, int reward, bool det = false);

  //returns the number of Q values updated
  //does an in-place, in-order update
  int mdpDPUpdate(const MDP& mdp, int progress_interval, double* ptotal_change);
  
  //if paction is not null, the max action is filled in
  float getV(int state, int* paction = NULL) const;

  float getQ(int state, int action) const;

  float getQRange(int state) const;

  //if paction is not null, the action is filled in
  float getWorstQ(int state, int* paction = NULL) const;
  
  //we will have several methods of selecting actions
  int getActMax(int state) const;
  int getActMaxRandom(int state) const; // if several max, choose randomly between them
  int getActRandom(int state) const;
  // with prob explore_prob, chooses a random action, and with prob (1-explore_prob), chooses max
  int getActAlphaExplore(int state, float explore_prob) const;
  //higher k values are more exploitative
  int getActBoltzmann(int state, float k) const;
  //gets the first enabled action
  int getActFirst(int state) const;
  
  // Selects an action using the strategy named by exp.
  // NOTE(review): presumably dispatches to the getAct* selectors above, with
  // explore_prob / k used only by the matching strategies — confirm in the .cpp.
  int getAct(explore_t exp, int state, float explore_prob, float k);
  
  float getGamma() const { return gamma; }
  void setGamma(float g) { gamma = g; }

  //Now the enable/disable functions
  //returns true if the action was previously disabled
  bool enable(int state, int action) { return changeEnableStatus(state, action, true); }
  //returns true if the action was previously enabled
  bool disable(int state, int action) { return changeEnableStatus(state, action, false); }
  // returns true if the status changes
  bool changeEnableStatus(int state, int action, bool enable);
  //This method disables all actions whose Q values are not within perc_max of the V value
  //returns number of actions disabled
  int disableNonApproxMax(float perc_max);

  // I didn't move ActionSubsetMap over
#ifdef OLD_CODE	
  int enableSet(const ActionSubsetMap& m) { return changeEnableStatusSet(m, true); }
  int disableSet(const ActionSubsetMap& m)  { return changeEnableStatusSet(m, false); }
  //returns number of entries changed
  int changeEnableStatusSet(const ActionSubsetMap& m, bool enable);
#endif
  
  int enableAll()  { return changeEnableStatusAll(true); }
  int disableAll() { return changeEnableStatusAll(false); }
  int changeEnableStatusAll(bool enable);
  
  bool isEnabled(int state, int action) const;

  //computes the number of states for which the max action is not
  // an enabled action of pQT
  int computeStateDiff(const QTable* pQT);

  bool isEqualTo(QTable& qt, std::ostream* perrout = NULL);
  
  // Values a state by its max Q value (getV). States for which getV reports
  // no max action (act == -1) are valued as invalid_val.
  // The QTable pointer is not owned (never deleted here).
  class MaxQValuator : public StateValue<float>::StateValuator
  {
  public:
    MaxQValuator(QTable* p, float invalid_val)
      : p(p), invalid_val(invalid_val) {}

    float getStateValue(int sidx)
    {
      // Pre-set to -1 so that if getV() leaves the out-param untouched
      // (e.g. for an invalid state) we return invalid_val instead of
      // reading an uninitialized int.
      int act = -1;
      float val = p->getV(sidx, &act);
      if (act == -1)
        return invalid_val;
      return val;
    }
  private:
    QTable* p;
    float invalid_val;
  };
  
  // Values a state by its Q range (getQRange). A negative range is mapped
  // to invalid_val (presumably getQRange signals invalid/fully-disabled
  // states that way — confirm). The QTable pointer is not owned.
  class RangeQValuator : public StateValue<float>::StateValuator
  {
  public:
    RangeQValuator(QTable* p, float invalid_val) : p(p), invalid_val(invalid_val) {}

    float getStateValue(int sidx)
    {
      float val = p->getQRange(sidx);
      if (val < 0.0)
        return invalid_val;
      return val;
    }
  private:
    QTable* p;
    float invalid_val;
  };
  
protected:
  // One (state, action) cell: its Q value and visit count. A visit count
  // equal to DISABLE_VALUE marks the cell as disabled.
  class StateActionEntry
  {
  public:
    StateActionEntry() : Q(0), visits(0) {}
    StateActionEntry(float Q, int visits) : Q(Q), visits(visits) {}

    static const int DISABLE_VALUE;

    bool isEnabled() const { return visits != DISABLE_VALUE; }
    bool changeEnableStatus(bool enable);
    float calcUpdateAlpha() const;

    void zero() { Q = 0.0f; visits = 0; }
    
    // text format: "<Q> <visits>"
    friend std::ostream& operator<<(std::ostream& o, const StateActionEntry& e)
    { o << e.Q << ' ' << e.visits; return o; }
    friend std::istream& operator>>(std::istream& i, StateActionEntry& e)
    { i >> e.Q >> e.visits; return i; }

    // binary format: Q as float, then visits written via writeIntAsShort.
    // NOTE(review): visits values outside the short range presumably do
    // not round-trip — confirm against BinaryFile.h.
    bool writeTo(BinaryFileWriter& writer) const
    { return writer.writeFloat(Q) && writer.writeIntAsShort(visits); }
    bool readFrom(BinaryFileReader& reader)
    { return reader.readFloat(&Q) && reader.readShort(&visits); }

    // exact (bitwise-value) comparison of Q and visits; const so that
    // const entries can be compared
    bool operator==(const StateActionEntry& sa) const
    { return Q == sa.Q && visits == sa.visits; }
      
    float Q;
    int visits;
    
  };
  
  virtual StateActionEntry* getSA(int state, int action) = 0;
  virtual const StateActionEntry* getSA(int state, int action) const = 0;
  
  //a -1 for action just verifies the state
  virtual bool checkValid(int state, int action) const = 0;

  int num_states;
  float gamma;


};



#endif
