/* A simple class to do table-based Q-learning.
   States and actions are identified by non-negative int indices. */

#include <cmath>
#include <cstring>
#include <iomanip>
#include <vector>
#include "QTable.h"
#include "misc.h"
#include "MDP.h"
#include "Logger.h"

using namespace spades;
using namespace std;

/****************************************************************************************/
// Textual names of the exploration strategies. parseExplore and
// getExploreString both map between an explore_t value and the entry at the
// same index, so the order here must follow the explore_t values.
// The pointers themselves are const: the table is read-only.
static const char* const EXPLORE_STRINGS[] = {
  "none",
  "max",
  "random",
  "alpha",
  "boltzmann"
};

/* Parses an exploration-strategy name from the front of the string *ps.
   Matching is by prefix: *ps only has to start with one of the names in
   EXPLORE_STRINGS. On a match, *ps is advanced past the matched name and
   the explore_t with the same index is returned.
   If nothing matches, or ps / *ps is NULL, *ps is left untouched and
   EXP_None is returned. */
explore_t
parseExplore(const char** ps)
{
  // Guard against NULL input instead of dereferencing it.
  if (ps == NULL || *ps == NULL)
    return EXP_None;

  for (int i=0; i<NUM_EXPLORE_T; i++)
    {
      const size_t len = strlen(EXPLORE_STRINGS[i]);
      if (strncmp(*ps, EXPLORE_STRINGS[i], len) == 0)
        {
          *ps += len;
          return static_cast<explore_t>(i);
        }
    }
  return EXP_None;
}

/* Returns the textual name of an exploration strategy (the inverse of
   parseExplore). Out-of-range values are logged and mapped to "invalid";
   indexing EXPLORE_STRINGS with them would read past the end of the table. */
const char*
getExploreString(const explore_t& exp)
{
  if (exp < 0 || exp >= NUM_EXPLORE_T)
    {
      errorlog << "Trying to get explore move out of range! " << (int)exp << ende;
      return "invalid";
    }
  return EXPLORE_STRINGS[(int)exp];
}


/****************************************************************************************/

// Sentinel stored in StateActionEntry::visits to mark an entry as disabled
// (see StateActionEntry::changeEnableStatus, which resets visits to 0 when
// re-enabling).
const int QTable::StateActionEntry::DISABLE_VALUE(-1);

/****************************************************************************************/

// Stores the number of states and the discount factor gamma, which is
// applied to successor values in observedUpdate and mdpDPUpdate.
QTable::QTable(int num_states, float gamma)
  : num_states(num_states),
    gamma(gamma)
{
  // no further initialization is done here
}


// Intentionally empty: this destructor performs no explicit cleanup.
QTable::~QTable() {}

/****************************************************************************************/

/* Applies one Q-learning backup for the observed transition
   (last_state, action) -> next_state that yielded `reward`:
     Q(s,a) <- (1 - alpha) * Q(s,a) + alpha * (reward + gamma * V(next_state))
   alpha is 1.0 when det is true (the new estimate simply overwrites Q);
   otherwise it is the entry's visit-based rate from calcUpdateAlpha().
   Returns true iff the greedy action of last_state changed because of this
   update. Returns false (after logging) when an index is out of range or
   the state/action pair is disabled. */
bool
QTable::observedUpdate(int last_state, int action, int next_state, int reward, bool det)
{
  if (!checkValid(last_state, action))
    {
      errorlog << "QTable: last_state/action out of range: " << last_state << ' ' << action << ende;
      return false;
    }

  if (!checkValid(next_state, -1))
    {
      errorlog << "QTable: next_state out of range: " << next_state << ende;
      return false;
    }

  StateActionEntry* psa = getSA(last_state, action);

  // Reject disabled entries before computing alpha: a disabled entry holds
  // DISABLE_VALUE (-1) in visits, for which calcUpdateAlpha() divides by zero.
  if (!psa->isEnabled())
    {
      errorlog << "QTable::observedUpdate requesting a disabled action: "
               << last_state << ' ' << action << ende;
      return false;
    }

  const float alpha = det ? 1.0 : psa->calcUpdateAlpha();
  int old_best_action, new_best_action;

  getV(last_state, &old_best_action);
  const float thisQVal = reward + gamma * getV(next_state);

  /* We only increment visits if the value we are putting in is significantly
     different from the value already there. This prevents 'wasting' good
     visits because values haven't been backed up far enough. */
  if (fabs(psa->Q - thisQVal) > .000001)
    psa->visits++;
  psa->Q = (1.0 - alpha) * psa->Q + alpha * (thisQVal);

  getV(last_state, &new_best_action);

  actionlog(210) << "observedUpdate: policy change "
                 << ((old_best_action != new_best_action) ? "true" : "false") << ' '
                 << last_state << ' '
                 << action << ' '
                 << psa->Q << ' '
                 << old_best_action << ' '
                 << new_best_action << ' '
                 << "; " << psa->visits
                 << ende;

  return old_best_action != new_best_action;
}

/****************************************************************************************/

//returns the number of Q values updated
//does an in place in order update
int
QTable::mdpDPUpdate(const MDP& mdp, int progress_interval, double* ptotal_change)
{
  int num_updates = 0;
  double total_change = 0;
  
  for (int sidx = 0; sidx < num_states; sidx++)
    {
      if (progress_interval >= 0 && sidx % progress_interval == 0)
	std::cout << '.' << std::flush;
      
      int num_actions = getNumActions(sidx);
      for (int aidx = 0; aidx < num_actions; aidx++)
	{
	  if (!isEnabled(sidx, aidx))
	    continue;

	  const MDP::ActionInfo* ainfo = mdp.getAction(sidx,aidx);
	  StateActionEntry* pentry = getSA(sidx, aidx);
	  if (pentry == NULL)
	    {
	      errorlog << "mdpDPUpdate: pentry is NULL for " << sidx << ' ' << aidx << ende;
	      continue;
	    }
	  
	  if (ainfo->getTransitions().empty())
	    continue;
      
	  //calc new Q
	  float newQ = 0;
	  for (MDP::ActionInfo::TranStorage::const_iterator iter = ainfo->getTransitions().begin();
	       iter != ainfo->getTransitions().end();
	       iter++)
	    {
	      newQ += iter->getProb() * (mdp.getRewardForState(iter->getNextState()) + gamma * getV(iter->getNextState()));
	    }

	  //we consider it an update if the value changes by more than 0.0001%
	  double change = fabs(pentry->Q - newQ);
	  if (change > (fabs(newQ) * .000001))
	    num_updates++;
	  total_change += change;
	  
	  pentry->Q = newQ;
	}
    }

  if (ptotal_change)
    *ptotal_change = total_change;
  
  return num_updates;
}

/****************************************************************************************/

float
QTable::getV(int state, int* paction) const
{
  int max_act = -1;
  if (paction)
    *paction = -1;
  //SMURF: I would like to use -infinity here, but I can't find a float constant
  // defined on several architectures.
  float max_val = -1e9;
  int num_actions = getNumActions(state);
  if (num_actions == -1)
    {
      errorlog << "QTable: tried to getV of an invalid state " << state << ende;
      return max_val;
    }
  else if (num_actions == 0)
    {
      actionlog(40) << "QTable: tried to getV of a state with no actions " << state << ende;
      return 0;
    }
  
  for (int act=0; act < num_actions; act++)
    {
      const StateActionEntry* psa = getSA(state, act);
      if (!psa->isEnabled())
	continue;
      float val = psa->Q;
      if (val > max_val)
	{
	  max_act = act;
	  max_val = val;
	}
    }

  if (paction)
    *paction = max_act;
  return max_val;
}

/****************************************************************************************/

float
QTable::getQ(int state, int action) const
{
  const StateActionEntry* psa = getSA(state, action);
  if (psa == NULL)
    {
      errorlog << "QTable::getQ: invalid state/action " << state << " " << action << ende;
      return 0.0;
    }
  
  return psa->Q;
}

/****************************************************************************************/
/* Returns the spread (max Q - min Q) over the enabled actions of state.
   Returns -1.0 when the state is invalid (logged), has no actions, or has no
   enabled actions. The extremes are seeded from the first enabled action
   rather than from +/-1e9 sentinels, so an all-disabled state does not
   report the meaningless sentinel difference, and large Q values are not
   clamped. */
float
QTable::getQRange(int state) const
{
  int num_actions = getNumActions(state);
  if (num_actions == -1)
    {
      errorlog << "QTable: tried to getQRange of an invalid state " << state << ende;
      return -1.0;
    }

  bool found_enabled = false;
  float min = 0;
  float max = 0;
  for (int act=0; act < num_actions; act++)
    {
      const StateActionEntry* psa = getSA(state, act);
      if (!psa->isEnabled())
        continue;
      if (!found_enabled)
        {
          min = psa->Q;
          max = psa->Q;
          found_enabled = true;
          continue;
        }
      if (psa->Q < min)
        min = psa->Q;
      if (psa->Q > max)
        max = psa->Q;
    }

  // covers both "no actions at all" and "all actions disabled"
  if (!found_enabled)
    return -1.0;

  return max - min;
}

/****************************************************************************************/
float
QTable::getWorstQ(int state, int* paction) const
{
  int min_act = -1;
  if (paction)
    *paction = -1;
  //SMURF: I would like to use -infinity here, but I can't find a float constant
  // defined on several architectures.
  float min_val = 1e9;
  int num_actions = getNumActions(state);
  if (num_actions == -1)
    {
      errorlog << "QTable: tried to getWorstQ of an invalid state " << state << ende;
      return min_val;
    }
  else if (num_actions == 0)
    {
      actionlog(40) << "QTable: tried to getWorstQ of a state with no actions " << state << ende;
      return 0;
    }
  
  for (int act=0; act < num_actions; act++)
    {
      const StateActionEntry* psa = getSA(state, act);
      if (!psa->isEnabled())
	continue;
      float val = psa->Q;
      if (val < min_val)
	{
	  min_act = act;
	  min_val = val;
	}
    }

  if (paction)
    *paction = min_act;
  return min_val;
}

/****************************************************************************************/


// Returns the greedy action of state (lowest index on ties), or -1 when
// getV finds no enabled action.
int
QTable::getActMax(int state) const
{
  int best_action = -1;
  getV(state, &best_action);
  return best_action;
}

// if several max, choose randomly between them
int
QTable::getActMaxRandom(int state) const
{
  float value;
  vector<int> vAct;
  int num_actions = getNumActions(state);
  if (num_actions == -1)
    {
      errorlog << "QTable: tried to getActMaxRandom of an invalid state " << state << ende;
      return -1;
    }

  value = getV(state);
  for (int a=0; a < num_actions; a++)
    {
      const StateActionEntry* psa = getSA(state, a);
      if (!psa->isEnabled())
	continue;
      if (fabs(value - psa->Q) < .0001)
	vAct.push_back(a);
    }

  if (vAct.empty())
    {
      actionlog(210) << "QTable::getActMaxRandom: vAct is empty " << state << ende;
      return -1;
    }

  actionlog(220) << "getActMaxRandom: state=" << state << " has " << vAct.size()
		 << " opt actions of value " << value << ": ";
#ifndef NO_ACTION_LOG
    for (vector<int>::iterator iter = vAct.begin();
	 iter != vAct.end();
	 iter++)
      actionlog(220) << *iter << ' ';
#endif    
  actionlog(220) << ende;
  
  return (vAct[int_random(vAct.size())]);
}



/* Returns a random enabled action of state, chosen via int_random over the
   enabled actions (uniform, assuming int_random(n) draws uniformly from
   0 .. n-1). Returns -1 (after logging) when the state is invalid or has
   no enabled action. */
int
QTable::getActRandom(int state) const
{
  int num_actions = getNumActions(state);
  if (num_actions == -1)
    {
      // message previously named getActMaxRandom (copy-paste slip)
      errorlog << "QTable: tried to getActRandom of an invalid state " << state << ende;
      return -1;
    }

  // First pass: count the enabled actions.
  int num_valid_acts = 0;
  for (int a=0; a < num_actions; a++)
    {
      if (getSA(state, a)->isEnabled())
        num_valid_acts++;
    }

  if (num_valid_acts == 0)
    {
      errorlog << "QTable::getActRandom: No valid actions for " << state << ende;
      return -1;
    }

  // Second pass: return the act-th enabled action (0-based).
  int act = int_random(num_valid_acts);
  for (int a=0; a < num_actions; a++)
    {
      if (!getSA(state, a)->isEnabled())
        continue;
      if (act-- == 0)
        return a;
    }

  errorlog << "I should never get here" << ende;
  return -1;
}

  
/* Epsilon-greedy selection: when range_random(0, 1.0) falls below
   explore_prob, a random enabled action is returned (getActRandom);
   otherwise a greedy action with ties broken randomly (getActMaxRandom).
   pfr: getActMax used to be the greedy branch, but always taking the
   lowest-index maximum did pretty ridiculous things at the beginning,
   like always sending X. */
int
QTable::getActAlphaExplore(int state, float explore_prob) const
{
  const bool explore = range_random(0, 1.0) < explore_prob;
  return explore ? getActRandom(state) : getActMaxRandom(state);
}

//higher k values are more exploitive
int
QTable::getActBoltzmann(int state, float k) const
{
  float total_weight = 0.0;
  int num_actions = getNumActions(state);
  if (num_actions == -1)
    {
      errorlog << "QTable: tried to getActBoltzMann of an invalid state " << state << ende;
      return -1;
    }
  float act_weight[num_actions];

  //first calculate the weights
  actionlog(210) << "getActBoltzmann weights for " << state << ": ";
  for (int i=0; i<num_actions; i++)
    {
      const StateActionEntry* psa = getSA(state, i);
      if (psa->isEnabled())
	{
	  act_weight[i] = pow(k, psa->Q);
	  actionlog(210) << i << ':' << act_weight[i] << ' ';
	  total_weight += act_weight[i];
	}
      else
	{
	  act_weight[i] = 0.0;
	}
    }

  float r = range_random(0.0, total_weight);
  float weight = 0.0;
  for (int i=0; i<num_actions; i++)
    {
      weight += act_weight[i];
      if (r < weight)
	{
	  actionlog(210) << "; chose " << i << ende;
	  return i;
	}
    }

  errorlog << "I should never get here" << ende;
  return 0;
}

// Returns the lowest-index enabled action of state, or -1 (after logging)
// when the state is invalid or has no enabled action.
int
QTable::getActFirst(int state) const
{
  const int num_actions = getNumActions(state);
  if (num_actions == -1)
    {
      errorlog << "QTable: tried to getActFirst of an invalid state " << state << ende;
      return -1;
    }

  int a = 0;
  while (a < num_actions && !getSA(state, a)->isEnabled())
    ++a;
  if (a < num_actions)
    return a;

  errorlog << "QTable: getActFirst with no enabled actions" << ende;
  return -1;
}


/* Dispatches to the action-selection routine for exploration strategy exp.
   explore_prob is used only by EXP_AlphaExplore and k only by EXP_Boltzmann.
   EXP_None and unrecognized values are logged and yield action 0. */
int
QTable::getAct(explore_t exp, int state, float explore_prob, float k)
{
  if (exp == EXP_Max)
    return getActMax(state);
  if (exp == EXP_Random)
    return getActRandom(state);
  if (exp == EXP_AlphaExplore)
    return getActAlphaExplore(state, explore_prob);
  if (exp == EXP_Boltzmann)
    return getActBoltzmann(state, k);

  if (exp == EXP_None)
    {
      errorlog << "getAct: I don't handle EXP_None" << ende;
    }
  else
    {
      errorlog << "I don't understand exp: " << exp << ende;
    }
  return 0;
}

/****************************************************************************************/

bool
QTable::changeEnableStatus(int state, int action, bool enable)
{
  StateActionEntry* psa = getSA(state, action);
  if (psa == NULL)
    {
      errorlog << "changeEnableStatus: invalid state/action " << state << " " << action << ende;
      return false;
    }
  
  return psa->changeEnableStatus(enable);
}


int
QTable::disableNonApproxMax(float perc_max)
{
  int num_disabled = 0;
  
  for (int s=0; s < num_states; s++)
    {
      float max_val = getV(s);
      float critical_value = max_val * (max_val > 0 ? perc_max : (1.0+perc_max));
      int num_actions = getNumActions(s);
      if (num_actions == -1)
	{
	  errorlog << "QTable: s is invalid? " << s << ende;
	  continue;
	}

      actionlog(100) << "disableNonApproxMax: "
		     << "state=" << s << "  "
		     << "max_val=" << max_val << "  "
		     << "critical_value=" << critical_value << "  "
		     << ende;
      for (int a=0; a < num_actions; a++)
	{
	  StateActionEntry* psa = getSA(s, a);
	  if (!psa->isEnabled())
	    continue;
	  if (psa->Q < critical_value)
	    {
	      psa->changeEnableStatus(false);
	      num_disabled++;
	    }
	}
    }
  
  return num_disabled;
}

// I didn't move ActionSubsetMap over, so this method is compiled out and is
// kept for reference only.
#ifdef OLD_CODE
/* Enables or disables every (state, action) pair listed in m.
   Returns the number of entries whose status actually changed (the same
   convention as changeEnableStatusAll). */
int
QTable::changeEnableStatusSet(const ActionSubsetMap& m, bool enable)
{
  if (m.getNumStates() == -1)
    {
      actionlog(150) << "QTable::changeEnableStatusSet: given empty set" << ende;
      return 0;
    }

  if (m.getNumStates() != num_states)
    {
      errorlog << "Can't change enable status without size match! "
               << m.getNumStates() << " != " << num_states << ende;
      return 0;
    }

  int count = 0;

  for (int s=0; s<num_states; s++)
    {
      for (ActionSubsetMap::ActionSet::const_iterator iter = m.getVActions()[s].begin();
           iter != m.getVActions()[s].end();
           ++iter)
        {
          // count was never incremented before, so this always returned 0
          if (changeEnableStatus(s, *iter, enable))
            count++;
        }
    }

  return count;
}
#endif

int
QTable::changeEnableStatusAll(bool enable)
{
  int count = 0;
  for (int s= num_states-1; s>=0; s--)
    {
      int num_actions = getNumActions(s);
      if (num_actions == -1)
	{
	  errorlog << "QTable: s is invalid? " << s << ende;
	  continue;
	}
      for (int a=0; a < num_actions; a++)
	{
	  StateActionEntry* psa = getSA(s, a);
	  if (psa->changeEnableStatus(enable))
	    count++;
	}
    }
  return count;
}

bool
QTable::isEnabled(int state, int action) const
{
  const StateActionEntry* psa = getSA(state, action);
  if (psa == NULL)
    {
      errorlog << "isEnabled: invalid state/action " << state << " " << action << ende;
      return false;
    }
  
  return psa->isEnabled();
}

/****************************************************************************************/

int
QTable::computeStateDiff(const QTable* pQT)
{
  int diff_count = 0;
  for (int s=0; s<num_states; s++)
    {
      int a;
      getV(s, &a);
      if (!pQT->isEnabled(s,a))
	diff_count++;
    }

  return diff_count;
}


/****************************************************************************************/
/* Deep comparison with qt: number of states, gamma (exact float compare),
   per-state action counts, per-action enabled flags and, for actions that
   are enabled in both tables, the StateActionEntry contents (operator==).
   On the first mismatch a description is written to *perrout (if non-NULL)
   and false is returned; otherwise returns true. */
bool
QTable::isEqualTo(QTable& qt, std::ostream* perrout)
{
  if (getNumStates() != qt.getNumStates())
    {
      if (perrout) *perrout << "Num states differ: "
                            << getNumStates() << "\t" << qt.getNumStates()
                            << std::endl;
      return false;
    }

  if (gamma != qt.gamma)
    {
      if (perrout) *perrout << "Gamma differs: "
                            << gamma << "\t" << qt.gamma
                            << std::endl;
      return false;
    }

  for (int sidx = 0; sidx < num_states; sidx++)
    {
      int num_actions = getNumActions(sidx);
      if (num_actions != qt.getNumActions(sidx))
        {
          if (perrout) *perrout << "Num actions differs at " << sidx << ": "
                                << num_actions << "\t" << qt.getNumActions(sidx)
                                << std::endl;
          return false;
        }

      for (int aidx = 0; aidx < num_actions; aidx++)
        {
          const bool this_enabled = isEnabled(sidx, aidx);
          const bool qt_enabled = qt.isEnabled(sidx, aidx);
          if (this_enabled != qt_enabled)
            {
              if (perrout) *perrout << "Enabled differs at (" << sidx << ", " << aidx << "): "
                                    << this_enabled << "\t" << qt_enabled
                                    << std::endl;
              return false;
            }
          if (!this_enabled)
            continue;  // disabled in both tables: contents are not compared

          StateActionEntry* pthisentry = getSA(sidx, aidx);
          StateActionEntry* pqtentry = qt.getSA(sidx, aidx);
          if (!(*pthisentry == *pqtentry))
            {
              if (perrout) *perrout << "StateActionEntry differ at (" << sidx << ", " << aidx << "): "
                                    << *pthisentry << "\t" << *pqtentry
                                    << std::endl;
              return false;
            }
        }
    }

  return true;
}


/****************************************************************************************/
/****************************************************************************************/

// Learning rate for the next update of this entry: 1 / (1 + visits). It is
// 1 for an unvisited entry and shrinks as visits grows. For a disabled entry
// (visits == DISABLE_VALUE == -1) the denominator is 0.
float
QTable::StateActionEntry::calcUpdateAlpha() const
{
  const double denominator = 1.0 + static_cast<float>(visits);
  return static_cast<float>(1.0 / denominator);
}

// Switches this entry to the requested enabled/disabled state.
// Enabling resets visits to 0; disabling stores DISABLE_VALUE in visits.
// Returns true iff the status actually changed.
bool
QTable::StateActionEntry::changeEnableStatus(bool enable)
{
  const bool currently_enabled = isEnabled();
  if (enable == currently_enabled)
    return false;  // already in the requested state

  visits = enable ? 0 : StateActionEntry::DISABLE_VALUE;
  return true;
}

