/* -*- Mode: C++ -*- */
/* These classes represent a Markov Decision Process */

#include <numeric>
#include "MDP.h"
#include "QTable.h"
#include "utility.h"
#include "Logger.h"
#include "misc.h"

using namespace spades;

/*****************************************************************************/
// Wraps an existing action object; the transition list starts out empty.
MDP::ActionInfo::ActionInfo(MDPAction* pact)
  : pact(pact),
    trans()
{}

// Binary functor for std::accumulate: folds the probability of one
// transition into a running sum.
struct TranProbAdder {
  double operator()(double p, const MDP::TranInfo& tran)
  {
    return p + tran.getProb();
  }
};

// Sum of the probabilities of all transitions stored for this action.
double
MDP::ActionInfo::getTotalProb() const
{
  TranProbAdder adder;
  return std::accumulate(trans.begin(), trans.end(), 0.0, adder);
}

void
MDP::ActionInfo::normalize()
{
  double total = 0.0;
  for (TranStorage::iterator iter = trans.begin();
       iter != trans.end();
       iter++)
    total += iter->getProb();
  for (TranStorage::iterator iter = trans.begin();
       iter != trans.end();
       iter++)
    iter->setProb( iter->getProb() / total);
}

// Probability of the first stored transition whose next state is sidx.
// Returns -1.0 (a number < 0) if this action has no transition to sidx.
double
MDP::ActionInfo::probForTranTo(int sidx) const
{
  TranStorage::const_iterator iter = trans.begin();
  while (iter != trans.end())
    {
      if (iter->getNextState() == sidx)
        return iter->getProb();
      ++iter;
    }
  return -1.0;
}

// Maps a probability value onto a transition index by walking the
// cumulative distribution: returns the index of the first transition at
// which the running sum of probabilities exceeds prob.
// Logs an error and returns -1 if prob is never exceeded.
int
MDP::ActionInfo::getTranForProb(double prob) const
{
  double cumulative = 0.0;
  const int num_tran = trans.size();
  for (int idx = 0; idx < num_tran; ++idx)
    {
      cumulative += trans[idx].getProb();
      if (prob < cumulative)
        return idx;
    }
  errorlog << "MDP::ActionInfo::getTranForProb: Bad prob value? " << prob << ende;
  return -1;
}

// Returns the indices of all transitions whose probability is within
// PROB_EPS of the largest probability seen so far during the scan.
// The result is empty when there are no transitions.
std::vector<int>
MDP::ActionInfo::getMaxProbTran() const
{
  const double PROB_EPS = .00001;
  std::vector<int> best;
  double best_prob = -1.0;
  int idx = 0;
  for (TranStorage::const_iterator iter = trans.begin();
       iter != trans.end();
       ++iter, ++idx)
    {
      const double p = iter->getProb();
      // a clearly better probability restarts the candidate list
      if (p > best_prob + PROB_EPS)
        {
          best.clear();
          best_prob = p;
        }
      // anything (approximately) tied with the current best is kept
      if (p > best_prob - PROB_EPS)
        best.push_back(idx);
    }
  return best;
}

// Returns the index of a maximum probability transition, picking one of
// the entries reported by getMaxProbTran() via int_random.
// Returns -1 if getMaxProbTran() reports no candidates.
int
MDP::ActionInfo::getMaxTranRandom() const
{
  const std::vector<int> candidates = getMaxProbTran();
  if (candidates.empty())
    return -1;
  return candidates[int_random(candidates.size())];
}

// Adds the number of transitions of this action to sds as one data point,
// but only when that number is at least min_tran.
void
MDP::ActionInfo::addTranCountData(SingleDataSummary& sds, int min_tran) const
{
  const int num_tran = trans.size();
  if (num_tran < min_tran)
    return;
  sds.addPoint((double)num_tran);
}


// Text output of one action: the action description, the number of
// transitions, then every transition; each item is followed by one space.
std::ostream&
operator<<(std::ostream& os, const MDP::ActionInfo& t)
{
  os << *t.pact << ' ';
  os << t.trans.size() << ' ';

  std::ostream_iterator<MDP::TranInfo> tran_out(os, " ");
  std::copy(t.trans.begin(), t.trans.end(), tran_out);

  return os;
}

// This does NOT read an MDPAction. That should have been done already
std::istream&
operator>>(std::istream& is, MDP::ActionInfo& act)
{
  act.clearTransitions();
  int size;
  is >> size;
  if (is.fail())
    return is;
  for (int i = 0; i < size; i++)
    {
      MDP::TranInfo t;
      is >> t;
      if (is.fail())
	return is;
      act.trans.push_back(t);
    }
  
  return is;
}

// Binary output of one action: the action itself, the transition count
// (written via writeIntAsShort), then each transition in order.
// Returns false as soon as any individual write fails.
bool
MDP::ActionInfo::writeTo(BinaryFileWriter& writer) const
{
  if (!pact->writeTo(writer) || !writer.writeIntAsShort(trans.size()))
    return false;

  TranStorage::const_iterator iter = trans.begin();
  while (iter != trans.end())
    {
      if (!iter->writeTo(writer))
        return false;
      ++iter;
    }
  return true;
}

// Binary counterpart of the text reader (operator>>): reads a transition
// count (as a short) followed by that many transitions.
// This does NOT read an MDPAction. That should have been done already.
//
// Existing transitions are discarded first via clearTransitions(), matching
// what operator>> does; otherwise reading into an already-populated
// ActionInfo would append to the old transitions.
//
// Returns false as soon as a read fails (transitions read before the
// failure remain stored), true otherwise.
bool
MDP::ActionInfo::readFrom(BinaryFileReader& reader)
{
  clearTransitions();

  short size;
  if (!reader.readShort(&size)) return false;
  for (int i = 0; i < size; i++)
    {
      MDP::TranInfo t;
      if (!t.readFrom(reader)) return false;
      trans.push_back(t);
    }
  return true;
}


/*****************************************************************************/

// Creates an MDP with num_states states, none of which has any actions
// yet, and an empty reward table.
MDP::MDP(int num_states)
  : rewards(),
    states(num_states)
{}

// All cleanup is delegated to clear().
MDP::~MDP()
{
  clear();
}

// Number of actions stored for state sidx.
// Logs an error and returns -1 if sidx is not a valid state index.
int
MDP::getNumActionsInState(int sidx) const
{
  const bool in_range = (sidx >= 0 && sidx < getNumStates());
  if (in_range)
    return states[sidx].size();

  errorlog << "MDP::getNumActionsInState: sidx out of range "
           << sidx << ", max=" << getNumStates() << ende;
  return -1;
}


// Mutable access to action aidx of state sidx.
// Logs an error and returns NULL if either index is out of range.
MDP::ActionInfo*
MDP::getAction(int sidx, int aidx)
{
  if (!(sidx >= 0 && sidx < getNumStates()))
    {
      errorlog << "MDP::getAction: sidx out of range "
               << sidx << ", max=" << getNumStates() << ende;
      return NULL;
    }

  ActionStorage& actions = states[sidx];
  const int num_actions = (signed)actions.size();
  if (!(aidx >= 0 && aidx < num_actions))
    {
      errorlog << "MDP::getAction: aidx out of range "
               << aidx << ", max=" << actions.size()
               << ende;
      return NULL;
    }

  return &actions[aidx];
}

// Read-only access to action aidx of state sidx.
// Logs an error and returns NULL if either index is out of range.
const MDP::ActionInfo*
MDP::getAction(int sidx, int aidx) const
{
  if (!(sidx >= 0 && sidx < getNumStates()))
    {
      errorlog << "MDP::getAction: sidx out of range "
               << sidx << ", max=" << getNumStates() << ende;
      return NULL;
    }

  const ActionStorage& actions = states[sidx];
  const int num_actions = (signed)actions.size();
  if (!(aidx >= 0 && aidx < num_actions))
    {
      errorlog << "MDP::getAction: aidx out of range "
               << aidx << ", max=" << actions.size()
               << ende;
      return NULL;
    }

  return &actions[aidx];
}


void
MDP::normalize()
{
  for (StateStorage::iterator state_iter = states.begin();
       state_iter != states.end();
       state_iter++)
    std::for_each(state_iter->begin(), state_iter->end(),
		  std::mem_fun_ref(&ActionInfo::normalize));
}

void
MDP::clear()
{
  rewards.clear();
  for (StateStorage::iterator state_iter = states.begin();
       state_iter != states.end();
       state_iter++)
    std::for_each(state_iter->begin(), state_iter->end(),
		  std::mem_fun_ref(&ActionInfo::deleteAction));
  states.clear();
}

double
MDP::getRewardForState(int sidx) const
{
  if (sidx < 0 || sidx >= getNumStates())
    {
      errorlog << "MDP::getRewardForState: bad idx " << sidx << ", max=" << getNumStates() << ende;
      return 0;
    }
  RewardStorage::const_iterator iter = rewards.find(sidx);
  return (iter == rewards.end()) ? 0.0 : iter->second;
}


// Total transition probability of action act in state sidx.
// Logs an error and returns 0 if either index is out of range.
double
MDP::getTotalProbForAction(int sidx, int act)
{
  const bool state_ok = (sidx >= 0 && sidx < getNumStates());
  if (!state_ok)
    {
      errorlog << "MDP::getTotalProbForAction: bad state idx " << sidx
               << ", max=" << getNumStates() << ende;
      return 0;
    }

  const bool act_ok = (act >= 0 && act < (signed)states[sidx].size());
  if (!act_ok)
    {
      errorlog << "MDP::getTotalProbForAction: bad act idx " << act
               << ", max=" << states[sidx].size() << ende;
      return 0;
    }

  return states[sidx][act].getTotalProb();
}



// Binary functor for std::accumulate: folds the total transition
// probability of one action into a running sum.
struct ActProbAdder {
  double operator()(double p, const MDP::ActionInfo& act)
  {
    return p + act.getTotalProb();
  }
};

// Sum, over all actions of state sidx, of each action's total transition
// probability. Logs an error and returns 0 for an invalid index.
double
MDP::getTotalProbForState(int sidx)
{
  const bool in_range = (sidx >= 0 && sidx < getNumStates());
  if (!in_range)
    {
      errorlog << "MDP::getTotalProbForState: bad idx " << sidx
               << ", max=" << getNumStates() << ende;
      return 0;
    }

  ActProbAdder adder;
  return std::accumulate(states[sidx].begin(), states[sidx].end(),
                         0.0, adder);
}


// Solves this MDP by repeatedly applying qt.mdpDPUpdate() until an
// iteration reports no updates.
// We pass in the qtable to allow stuff like disabled actions.
// We use a simple DP method where we update all states in order and in
// place.
//
// qt                the table being updated; must have the same number of
//                   states as this MDP
// progress_interval progress is printed to std::cout when this is >= 0; the
//                   value is also forwarded to qt.mdpDPUpdate()
// p_early_term      optional (may be NULL) flag checked after every
//                   iteration; when it is set the solve stops early
//
// Returns the number of DP iterations performed, or 0 if the state counts
// of qt and this MDP do not match.
int
MDP::solveByQTable(QTable& qt, int progress_interval, bool* p_early_term)
{
  if (qt.getNumStates() != getNumStates())
    {
      errorlog << "Can't solve by QTable without size match! "
               << qt.getNumStates() << " == " << getNumStates() << ende;
      return 0;
    }

  int num_updates;
  int num_iterations = 0;

  if (progress_interval >= 0)
    std::cout << "MDP solving by QTable: " << std::endl;

  do
    {
      if (progress_interval >= 0)
        std::cout << "\tIteration " << num_iterations << ": " << std::flush;

      // initialized so that a defined value is printed even if the update
      // does not write to it
      double change = 0.0;
      num_updates = qt.mdpDPUpdate(*this, progress_interval, &change);

      actionlog(100) << "MDP::solveByQTable: iteration " << num_iterations << " has "
                     << num_updates << " updates" << ende;

      actionlog(240) << "MDPsolve: iteration " << num_iterations
        //<< '\n' << qt
                     << ende;

      num_iterations++;
      if (progress_interval >= 0)
        std::cout << ": update = " << num_updates << ", change = " << change << std::endl;

      if (p_early_term && *p_early_term)
        {
          std::cout << "Got a termination request, exiting" << std::endl;
          // actually honor the request; previously the message was printed
          // but the loop kept running until convergence
          break;
        }
    }
  while (num_updates > 0);

  return num_iterations;
}

// Returns the indices of all actions which could have led to this transition
std::vector<int>
MDP::getPossibleActsForTran(int sidx1, int sidx2) const
{
  std::vector<int> vposs;

  if (sidx1 < 0 || sidx1 >= getNumStates())
    {
      errorlog << "MDP::getPossibleActsForTran: sidx1 out of range "
	       << sidx1 << ", max=" << getNumStates() << ende;
      return vposs;
    }
  if (sidx2 < 0 || sidx2 >= getNumStates())
    {
      errorlog << "MDP::getPossibleActsForTran: sidx2 out of range "
	       << sidx2 << ", max=" << getNumStates() << ende;
      return vposs;
    }

  for (ActionStorage::const_iterator iter = states[sidx1].begin();
       iter != states[sidx1].end();
       iter++)
    {
      if (iter->probForTranTo(sidx2) >= 0)
	vposs.push_back(iter - states[sidx1].begin());
    }

  return vposs;
}

// Simulates one step: draws a value from prob_random() and follows the
// transition of (state_idx, act_idx) selected for it by getTranForProb().
// Returns the new state idx, or -1 (after logging) if the indices are bad
// or no transition could be selected.
int
MDP::takeStep(int state_idx, int act_idx) const
{
  const ActionInfo* pact = getAction(state_idx, act_idx);
  if (!pact)
    {
      errorlog << "MDP::takeStep: bad state or act idx " << state_idx << " " << act_idx << ende;
      return -1;
    }

  const double prob = prob_random();
  const int tran = pact->getTranForProb(prob);
  if (tran >= 0)
    return pact->getTransitions()[tran].getNextState();

  errorlog << "MDP::takeStep: failed to tran (" << tran << ") for prob " << prob << ende;
  return -1;
}

// Follows a max probability transition of (state_idx, act_idx), breaking
// ties randomly (see ActionInfo::getMaxTranRandom).
// Returns the new state idx, or -1 (after logging) if the indices are bad
// or the action has no usable transition.
int
MDP::takeMaxStep(int state_idx, int act_idx) const
{
  const ActionInfo* pact = getAction(state_idx, act_idx);
  if (!pact)
    {
      errorlog << "MDP::takeMaxStep: bad state or act idx " << state_idx << " " << act_idx << ende;
      return -1;
    }

  const int tran = pact->getMaxTranRandom();
  if (tran >= 0)
    return pact->getTransitions()[tran].getNextState();

  errorlog << "MDP::takeMaxStep: failed to tran (" << tran << ")" << ende;
  return -1;
}

// For every state with >= min_act actions, adds the number of actions to
// sds. States are visited from the highest index down to 0.
void
MDP::addActCountData(SingleDataSummary& sds, int min_act) const
{
  int sidx = getNumStates();
  while (--sidx >= 0)
    {
      const int cnt = getNumActionsInState(sidx);
      if (cnt >= min_act)
        sds.addPoint((double)cnt);
    }
}

// For every action (in every state) with >= min_tran transitions, adds the
// number of transitions to sds. States are visited from the highest index
// down to 0; the actions of each state are visited in storage order.
void
MDP::addTranCountData(SingleDataSummary& sds, int min_tran) const
{
  int sidx = getNumStates();
  while (--sidx >= 0)
    {
      const ActionStorage& actions = states[sidx];
      ActionStorage::const_iterator iter = actions.begin();
      while (iter != actions.end())
        {
          iter->addTranCountData(sds, min_tran);
          ++iter;
        }
    }
}

  
// returns a random state that has actions
int
MDP::getRandomValidState() const
{
  // This is kind of a dumb way to do this, but I want something that terminates
  // even if very few states are valid
  int start_point = int_random(states.size());
  int amt_forward = int_random(states.size());
  int state = start_point;
  for (int cnt = 0; cnt < amt_forward; cnt++)
    {
      do
	{
	  state = (state + 1) % states.size();
	}
      while (states[state].empty());
    }

  return state;
}


std::ostream&
operator<<(std::ostream& os, const MDP& m)
{
  os << "# This file represents an MDP (aka Markov Decision Process)" << std::endl;
  os << "# Format: First a rewards line, then one section per state" << std::endl;
  os << "# Rewards line: <num rewards> <list rewards>" << std::endl;
  os << "# where <list rewards> is a list of pairs <state> reward" << std::endl;
  os << "# A section consists of <number of actions> on one line, followed by one line per action" << std::endl;
  os << "# Each action line is <action desc> <num tran> <list of tran>" << std::endl;
  os << "# where <action desc> format is determined by the subclass of MDP action" << std::endl;
  os << "# and <list of tran> is a list of pairs of <nextstate> <probability>" << std::endl;

  os << m.rewards.size();
  for (MDP::RewardStorage::const_iterator rew_iter = m.rewards.begin();
       rew_iter != m.rewards.end();
       rew_iter++)
    os << ' ' << rew_iter->first << ' ' << rew_iter->second;
  os << std::endl;
      
  for (MDP::StateStorage::const_iterator state_iter = m.states.begin();
       state_iter != m.states.end();
       state_iter++)
    {
      //os << "# " << state_iter - m.states.begin() << std::endl;
      os << state_iter->size() << std::endl;
      std::copy(state_iter->begin(), state_iter->end(), std::ostream_iterator<MDP::ActionInfo>(os, "\n"));
    }

  return os;
}

// Text-format reader for an MDP; the counterpart of
// operator<<(std::ostream&, const MDP&).
// Expected layout: one rewards line (<count> followed by <state> <reward>
// pairs), then any number of state sections, each made of an action count
// followed by that many actions (action description + transition list).
// Any previous content of m is discarded first. On return, failbit on the
// stream signals an error; running out of input after at least one state
// section has been read is treated as success.
std::istream&
operator>>(std::istream& is, MDP& m)
{
  // drop rewards and states held from before
  m.clear();

  // NOTE(review): skip_to_non_comment is defined elsewhere; presumably it
  // skips comment lines and returns false when no further data is
  // available -- confirm.
  if (!spades::skip_to_non_comment(is))
    {
      if (m.getNumStates() == 0)
	is.setstate(std::ios::failbit);
      return is;
    }

  // rewards line: count, then (state index, reward) pairs
  int size;
  is >> size;
  if (is.fail())
    return is;
  for (int i=0; i<size; i++)
    {
      int sidx;
      double rew;
      is >> sidx >> rew;
      if (is.fail())
	return is;
      m.rewards[sidx] = rew;
    }
  
  // state sections; the only exits from this loop are the returns inside
  for (;;)
    {
      if (!spades::skip_to_non_comment(is))
	{
	  // nothing more to read: an error only if no state was read at all,
	  // otherwise the failbit is removed so the caller sees success
	  if (m.getNumStates() == 0)
	    is.setstate(std::ios::failbit);
	  else
	    is.clear(is.rdstate() & ~std::ios::failbit);
	  return is;
	}
      
      // number of actions in this state (shadows the rewards count above)
      int size;
      is >> size;
      if (is.fail())
	{
	  if (m.getNumStates() > 0)
	    {
	      // this removes the fail bit and the bad bit from the stream
	      is.clear(is.rdstate() & ~std::ios::failbit & ~std::ios::badbit);
	    }
	  
	  return is;
	}

      // append an empty state and fill in its actions in place
      m.states.push_back(MDP::ActionStorage());
      MDP::ActionStorage& these_actions = *(m.states.end() - 1);
      for (int i = 0; i < size; i++)
	{
	  MDP::ActionInfo act;
	  // the action object is read first (createAction), then its
	  // transition list (operator>> for ActionInfo)
	  act.setAction(m.createAction(is));
	  if (act.getAction() == NULL)
	    {
	      errorlog << "MDP reading: Could not read action" << ende;
	      is.setstate(std::ios::failbit);
	      return is;
	    }
	  is >> act;
	  if (is.fail())
	    {
	      // NOTE(review): act is not stored in m.states on this path, so
	      // MDP::clear() will never call deleteAction() on it; looks like
	      // the object from createAction may leak -- confirm ownership.
	      errorlog << "MDP reading: failed reading action info after type" << ende;
	      return is;
	    }
	  these_actions.push_back(act);
	}
    }
  
  // not reached: the loop above only exits via return
  return is;
}

// Binary output of the whole MDP: the header (writeHeader), the reward
// table (count, then state index / reward pairs), then for every state its
// action count (via writeIntAsShort) followed by each action.
// Returns false as soon as any individual write fails.
bool
MDP::writeTo(BinaryFileWriter& writer) const
{
  if (!writeHeader(writer))
    return false;

  // reward table
  if (!writer.writeInt(rewards.size()))
    return false;
  RewardStorage::const_iterator rew_iter = rewards.begin();
  while (rew_iter != rewards.end())
    {
      if (!writer.writeInt(rew_iter->first))
        return false;
      if (!writer.writeFloat(rew_iter->second))
        return false;
      ++rew_iter;
    }

  // per-state action lists
  StateStorage::const_iterator state_iter = states.begin();
  while (state_iter != states.end())
    {
      if (!writer.writeIntAsShort(state_iter->size()))
        return false;

      ActionStorage::const_iterator act_iter = state_iter->begin();
      while (act_iter != state_iter->end())
        {
          if (!act_iter->writeTo(writer))
            return false;
          ++act_iter;
        }
      ++state_iter;
    }

  return true;
}

// Binary-format reader for an MDP; the counterpart of MDP::writeTo.
// Discards the current content, then reads the header, the reward table
// and state sections until the reader can deliver no further action count.
// Returns true if at least one state was read before the input ran out;
// returns false on any earlier read failure (the MDP may then be left
// partially filled).
bool
MDP::readFrom(BinaryFileReader& reader)
{
  // drop rewards and states held from before
  clear();
  
  actionlog(230) << "MDP::readFrom; start" << ende;

  if (!readHeader(reader)) return false;

  // reward table: count, then (state index, reward) pairs
  int size;
  if (!reader.readInt(&size)) return false;
  actionlog(230) << "MDP::readFrom; rewards size is " << size << ende;
  for (int i=0; i<size; i++)
    {
      int sidx;
      double rew;
      // NOTE(review): readFloat is handed a double* here -- presumably an
      // overload exists for that; confirm against BinaryFileReader.
      if (!reader.readInt(&sidx) || !reader.readFloat(&rew)) return false;
      rewards[sidx] = rew;
    }

  actionlog(230) << "MDP::readFrom; read rewards " << size << ende;
  
  // state sections; the only exits from this loop are the returns inside
  for (;;)
    {
      // number of actions in this state (shadows the rewards count above)
      // NOTE(review): readShort is handed an int* here, whereas
      // ActionInfo::readFrom passes a short* -- confirm the overload used.
      int size;
      if (!reader.readShort(&size))
	{
	  // no further count could be read: success only if some state was
	  // read before
	  if (getNumStates() > 0)
	    return true;
	  return false;
	}

      // append an empty state and fill in its actions in place
      states.push_back(MDP::ActionStorage());
      MDP::ActionStorage& these_actions = *(states.end() - 1);
      for (int i = 0; i < size; i++)
	{
	  MDP::ActionInfo act;
	  // the action object is read first (createAction), then its
	  // transition list (ActionInfo::readFrom)
	  act.setAction(createAction(reader));
	  if (act.getAction() == NULL)
	    {
	      errorlog << "MDP reading: Could not read action" << ende;
	      return false;
	    }
	  if (!act.readFrom(reader))
	    {
	      // NOTE(review): act is not stored in states on this path, so
	      // clear() will never call deleteAction() on it; looks like the
	      // object from createAction may leak -- confirm ownership.
	      errorlog << "MDP reading: failed reading action info after type" << ende;
	      return false;
	    }
	  these_actions.push_back(act);
	}

      actionlog(230) << "MDP::readFrom; read state " << states.size() - 1 << ende;
    }
  
  // not reached: the loop above only exits via return
  return true;
}

// Reads an MDP from is, accepting either the binary or the text format.
// The binary format is tried first; if that fails the stream is rewound to
// where it started and the text reader (operator>>) is tried.
// Returns true if either attempt succeeded.
bool
MDP::readTextOrBinary(std::istream& is)
{
  // we'll try binary first, and if that fails, we'll reset and try ascii
  std::streampos init_pos = is.tellg();

  {
    BinaryFileReader reader(is);

    if (readFrom(reader))
      return true;
  }

  // The failed binary attempt may have left failbit/eofbit set on the
  // stream. seekg() does nothing while the stream is in a failed state, so
  // the error flags have to be cleared before rewinding; otherwise the text
  // attempt below could never succeed.
  is.clear();
  is.seekg(init_pos);

  is >> *this;

  return !is.fail();
}

