/* -*- Mode: C++ -*- */

/* These classes provide an interface to an MDP model with associated values,
   together with the factored description of the state space used to collect it */

#include <iostream>
#include <algorithm>
#include <iomanip>
#include "SoccerModelWValue.h"
#include "AdviceTree.h"
#include "misc.h"
#include "Policy.h"
#include "CoachParam.h"
#include "Logger.h"

using namespace spades;

// Builds a model with no state description attached yet.  The MDP is
// default-constructed and the Q-table is constructed from that MDP and gamma.
SoccerModelWValue::SoccerModelWValue(float gamma)
  : pdesc(NULL),
    mdp(),
    qtable(mdp, gamma)
{
}


// Releases the state description held in pdesc, if there is one.
SoccerModelWValue::~SoccerModelWValue()
{
  // deleting a null pointer is a no-op, so no explicit guard is required
  delete pdesc;
}


// Installs p as the state description and has the MDP process it.
// Any previously held description is deleted; p becomes the pointer that
// the destructor will delete.
void
SoccerModelWValue::setStateDescription(AbstractStateDescription* p)
{
  // Re-installing the description we already hold must not delete it:
  // that would leave pdesc dangling and hand a dead object to the MDP.
  if (p != pdesc)
    {
      delete pdesc;
      pdesc = p;
    }

  mdp.processDescription(pdesc);
}

// Loads the MDP from the file named fn.
// Returns false (after logging) if the file cannot be opened or the MDP
// fails to read; returns true otherwise.  When a state description is
// already installed, the freshly read MDP processes it again.
bool
SoccerModelWValue::readMDPFrom(const char* fn)
{
  std::ifstream infile(fn);

  if (infile.fail())
    {
      errorlog << "SoccerModelWValue::readMDPFrom: could not open file '"
               << fn << "'" << ende;
      return false;
    }

  if (!mdp.readTextOrBinary(infile))
    {
      errorlog << "SoccerModelWValue::readMDPFrom: failed reading from '"
               << fn << "'" << ende;
      return false;
    }

  // the description has to be (re)applied to the newly loaded MDP
  if (pdesc != NULL)
    mdp.processDescription(pdesc);

  return true;
}

// Loads the Q-table from the file named fn.
// Returns false (after logging) if the file cannot be opened or the table
// fails to read; returns true otherwise.
bool
SoccerModelWValue::readQTableFrom(const char* fn)
{
  std::ifstream infile(fn);

  if (infile.fail())
    {
      errorlog << "SoccerModelWValue::readQTableFrom: could not open file '"
               << fn << "'" << ende;
      return false;
    }

  if (!qtable.readTextOrBinary(infile))
    {
      errorlog << "SoccerModelWValue::readQTableFrom: failed reading from '"
               << fn << "'" << ende;
      return false;
    }

  return true;
}

// Consistency check between the state description, the MDP and the Q-table.
// Each flag says whether the corresponding component should take part in
// the check.  A component pair is compared (by number of states) only when
// both of its flags are set.  Returns false, after logging the reason, on
// the first problem found; returns true when nothing is wrong.
bool
SoccerModelWValue::check(bool desc_valid, bool mdp_valid, bool qtable_valid)
{
  // a description that is claimed valid must actually exist
  if (desc_valid && pdesc == NULL)
    {
      errorlog << "SoccerModelWValue::check: pdesc is NULL" << ende;
      return false;
    }

  // description vs. MDP
  if (desc_valid && mdp_valid &&
      pdesc->getNumStates() != mdp.getNumStates())
    {
      errorlog << "SoccerModelWValue::check: pdesc and mdp size mismatch: "
               << pdesc->getNumStates() << " "
               << mdp.getNumStates() << ende;
      return false;
    }

  // description vs. Q-table
  if (desc_valid && qtable_valid &&
      pdesc->getNumStates() != qtable.getNumStates())
    {
      errorlog << "SoccerModelWValue::check: pdesc and qtable size mismatch: "
               << pdesc->getNumStates() << " "
               << qtable.getNumStates() << ende;
      return false;
    }

  // MDP vs. Q-table
  if (mdp_valid && qtable_valid &&
      mdp.getNumStates() != qtable.getNumStates())
    {
      errorlog << "SoccerModelWValue::check: mdp and qtable size mismatch: "
               << mdp.getNumStates() << " "
               << qtable.getNumStates() << ende;
      return false;
    }

  return true;
}

  
// Advises for the state with index sidx.
// Runs the full consistency check first; on failure it logs and returns 0.
// Otherwise it builds an AbstractState for sidx and forwards to the
// AbstractState* overload.  Returns the number of actions in the advice.
int
SoccerModelWValue::adviseForState(int sidx,
                                  std::list<SoccerActionFilter*>& lfilters,
                                  SoccerModelAdviser* padviser)
{
  if (check(true, true, true))
    {
      AbstractState state(pdesc);
      state.setStateIdx(sidx);
      return adviseForState(&state, lfilters, padviser);
    }

  errorlog << "SoccerModelWValue::adviseForState: failed check" << ende;
  return 0;
}

// Advises for the given abstract state.
// The adviser is always bracketed with beginStateAdvice/endStateAdvice, even
// when the state index is -1 (in which case no actions are offered).
// Actions are visited from the highest index down to 0; an action counts
// only if it passes every filter in lfilters and the adviser accepts it.
// Returns the number of actions counted.
int
SoccerModelWValue::adviseForState(AbstractState* pstate,
                                  std::list<SoccerActionFilter*>& lfilters,
                                  SoccerModelAdviser* padviser)
{
  int num_advised = 0;
  const int sidx = pstate->getStateIdx();

  padviser->beginStateAdvice(pstate);

  if (sidx != -1)
    {
      int aidx = mdp.getNumActionsInState(sidx);
      while (--aidx >= 0)
        {
          if (doesPassActionFilters(sidx, aidx, lfilters))
            {
              SoccerMDPAction* pmdpact =
                (SoccerMDPAction*)mdp.getAction(sidx, aidx)->getAction();

              // the adviser reports whether it really used the action
              if (padviser->addStateAdvice(pstate, pmdpact))
                ++num_advised;
            }
        }
    }

  padviser->endStateAdvice(pstate, num_advised);

  return num_advised;
}


// Advises for every state of the MDP that passes all of lfilters_state,
// using lfilters_act to select the actions for each such state.
// Returns the number of states that passed the state filters, or 0 when the
// consistency check fails.
// If pcount_no_act_states is non-NULL it receives the number of visited
// states for which no action was advised (0 when the check fails).
int
SoccerModelWValue::adviseFor(std::list<SoccerStateFilter*>& lfilters_state,
                             std::list<SoccerActionFilter*>& lfilters_act,
                             SoccerModelAdviser* padviser,
                             int* pcount_no_act_states)
{
  if (!check(true, true, true))
    {
      errorlog << "SoccerModelWValue::adviseFor: failed check" << ende;
      // No state was visited: give the out-parameter a defined value and
      // return an int count (the function does not return a bool).
      if (pcount_no_act_states)
        *pcount_no_act_states = 0;
      return 0;
    }

  int state_count = 0;
  int count_no_act_states = 0;
  for (int sidx = 0; sidx < mdp.getNumStates(); sidx++)
    {
      if (!doesPassStateFilters(sidx, lfilters_state))
        continue;

      if (adviseForState(sidx, lfilters_act, padviser) == 0)
        ++count_no_act_states;

      state_count++;
    }

  if (pcount_no_act_states)
    *pcount_no_act_states = count_no_act_states;

  return state_count;
}

bool
SoccerModelWValue::doesPassActionFilters(int sidx, int aidx,
					 std::list<SoccerActionFilter*>& lfilters)
{
  for (std::list<SoccerActionFilter*>::iterator iter = lfilters.begin();
       iter != lfilters.end();
       iter++)
    {
      if (!(*iter)->acceptAction(this, sidx, aidx))
	{
	  return false;
	}
    }
  return true;
}

bool
SoccerModelWValue::doesPassStateFilters(int sidx,
					std::list<SoccerStateFilter*>& lfilters)
{
  for (std::list<SoccerStateFilter*>::iterator iter = lfilters.begin();
       iter != lfilters.end();
       iter++)
    {
      if (!(*iter)->acceptState(this, sidx))
	{
	  return false;
	}
    }
  return true;
}

// Estimates the average reward per step by simulation.
// Starting from start_idx, takes up to num_steps steps; at each step one of
// the actions passing lfilters_act is picked at random and the reward of the
// state reached is added to the summary.  The whole walk is repeated
// repititions times, each time restarting from start_idx.
// If a state is reached in which no action passes the filters, that walk is
// cut short (and the problem logged), since there is nothing to choose from.
// When show_status is set, progress is printed to std::cout (one '.' per
// repetition).
// Returns the summary of all rewards collected.
SingleDataSummary
SoccerModelWValue::estimateOptAvgStepReward(int start_idx,
                                            int num_steps, int repititions,
                                            std::list<SoccerActionFilter*>& lfilters_act,
                                            bool show_status)
{
  if (show_status)
    std::cout << "Estimating reward: " << std::flush;
  SingleDataSummary sds;
  for (int rep = 0; rep < repititions; ++rep)
    {
      int state = start_idx;
      for (int step = 0; step < num_steps; ++step)
        {
          // collect the actions allowed by the filters in this state
          std::vector<int> vact;
          for (int aidx=mdp.getNumActionsInState(state) - 1; aidx >= 0; aidx--)
            {
              if (doesPassActionFilters(state, aidx, lfilters_act))
                vact.push_back(aidx);
            }

          // Indexing an empty vector is undefined behaviour, so a state
          // with no admissible action has to end this walk.
          if (vact.empty())
            {
              errorlog << "SoccerModelWValue::estimateOptAvgStepReward: "
                       << "no action passes the filters in state "
                       << state << ende;
              break;
            }

          state = mdp.takeStep(state, vact[int_random(vact.size())]);
          double rew = mdp.getRewardForState(state);
          sds.addPoint(rew);
        }
      if (show_status)
        std::cout << '.' << std::flush;
    }
  if (show_status)
    std::cout << std::endl;
  return sds;
}



/*****************************************************************************/
/*****************************************************************************/
/*****************************************************************************/

// Accepts a state when its number of actions lies in [min, max).
bool
SSFilterNumActions::acceptState(SoccerModelWValue* pmodel, int state_idx)
{
  const int action_count = pmodel->getMDP().getNumActionsInState(state_idx);

  if (!(min <= action_count))
    return false;

  return action_count < max;
}

/*****************************************************************************/

// Accepts a state when a fresh draw from prob_random() is below accept_prob.
// Neither the model nor the state index influences the decision.
bool
SSFilterRand::acceptState(SoccerModelWValue* pmodel, int state_idx)
{
  const bool accepted = (prob_random() < accept_prob);
  return accepted;
}


/*****************************************************************************/

// Accepts a state when at least one of its actions reports
// doesCreateCLangAction().  Actions are examined from the highest index
// down to 0 and the search stops at the first hit.
bool
SSFilterHasCLangAct::acceptState(SoccerModelWValue* pmodel, int state_idx)
{
  int aidx = pmodel->getMDP().getNumActionsInState(state_idx);
  while (--aidx >= 0)
    {
      SoccerMDPAction* pmdpact =
        (SoccerMDPAction*)(pmodel->getMDP().getAction(state_idx, aidx)->getAction());
      if (pmdpact->doesCreateCLangAction())
        return true;
    }
  return false;
}

/*****************************************************************************/

// Accepts a state when, in the model's state description, factor fac_idx
// of that state has the index val.
bool
SSFilterFactorHasValue::acceptState(SoccerModelWValue* pmodel, int state_idx)
{
  AbstractState probe(pmodel->getStateDescription());
  probe.setStateIdx(state_idx);

  const bool matches = (probe.getFactorIdx(fac_idx) == val);
  return matches;
}


/*****************************************************************************/

// Accepts an action whose Q value is close to a reference value of the state.
//  - choose_worst unset: reference is V(state); accept when
//      Q >= V - |V * (1 - perc)|
//  - choose_worst set:   reference is the worst Q of the state; accept when
//      Q <= worst + |worst * (1 - perc)|
bool
SAFilterNearOpt::acceptAction(SoccerModelWValue* pmodel, int state_idx, int act_idx)
{
  if (!choose_worst)
    {
      const double opt = pmodel->getQTable().getV(state_idx);
      /* The band is written with fabs() on purpose.  The simpler test
         getQ(..) >= opt * perc breaks down when Q and V are negative:
         no action would satisfy it. */
      const double lower_bound = opt - fabs(opt * (1.0 - perc));
      return pmodel->getQTable().getQ(state_idx, act_idx) >= lower_bound;
    }

  const double worst = pmodel->getQTable().getWorstQ(state_idx);
  const double upper_bound = worst + fabs(worst * (1.0 - perc));
  return pmodel->getQTable().getQ(state_idx, act_idx) <= upper_bound;
}

/*****************************************************************************/
/*****************************************************************************/
/*****************************************************************************/

// Starts the advice for one state by allocating a fresh DirComm in
// p_curr_dir, addressed to the unum set containing UNum::uAll.
// NOTE(review): the meaning of the two `true` constructor arguments is not
// visible here -- confirm against rcss::clang::DirComm.
void
SoccerModelAdviserFlat::beginStateAdvice(AbstractState* pstate)
{
  rcss::clang::UNumSet targets;
  targets.add(rcss::clang::UNum::uAll);

  p_curr_dir = new rcss::clang::DirComm(true, true, targets);
}

// Adds pact to the directive currently being built.
// The callee takes over the memory: pact is wrapped in an auto_ptr which is
// handed to the directive.
void
SoccerModelAdviserFlat::addStateAdvice(AbstractState* pstate, rcss::clang::Action* pact)
{
  std::auto_ptr<rcss::clang::Action> owned_action(pact);
  p_curr_dir->add(owned_action);
}

// Finishes the advice for one state.
// With no advised actions the pending directive is simply discarded.
// Otherwise a rule named rule_prefix + <state index> is built from the
// state's condition and the pending directive; its definition is pushed to
// pmqueue's define container and an ActivateRules for that name is pushed to
// its rule container.
// If the state cannot produce a condition, the problem is logged and the
// pending directive is discarded.
void
SoccerModelAdviserFlat::endStateAdvice(AbstractState* pstate, int act_count)
{
  if (act_count == 0)
    {
      // we didn't have anything to say about this state!
      delete p_curr_dir;
      p_curr_dir = NULL;  // do not keep a dangling pointer around
      return;
    }

  rcss::clang::Cond* pCond = pstate->createCondition();
  if (!pCond)
    {
      errorlog << "SoccerModelAdviserFlat::endStateAdvice: did not get a condition!" << ende;
      // No rule will be built, so nothing else would ever free the
      // directive allocated in beginStateAdvice.
      delete p_curr_dir;
      p_curr_dir = NULL;
      return;
    }

  std::string rule_name =
    rule_prefix + toString(pstate->getStateIdx());

  rcss::clang::SimpleRule* pRule =
    new rcss::clang::SimpleRule(std::auto_ptr<rcss::clang::Cond>(pCond));

  // NOTE(review): presumably the rule takes ownership of the directive
  // here -- confirm against rcss::clang::SimpleRule.
  pRule->getDirs().push_back(p_curr_dir);

  rcss::clang::DefRule* pDef =
    new rcss::clang::DefRule( rule_name,
                              std::auto_ptr<rcss::clang::Rule>(pRule),
                              false );
  pmqueue->getDefineContainer().push(pDef);

  rcss::clang::RuleIDList l;
  l.push_back(rule_name);
  pmqueue->getRuleContainer().push( new rcss::clang::ActivateRules(true, l) );
}

/*****************************************************************************/

// Action-pointer overload of addStateAdvice for the tree adviser.
// Judging by the message it logs, this overload is not expected to be
// reached; it reports the call and deletes the action it was given.
void
SoccerModelAdviserTree::addStateAdvice(AbstractState* pstate,
                                       rcss::clang::Action* pact)
{
  errorlog << "SoccerModelAdviserTree::addStateAdvice: How did I get called" << ende;

  // the action is disposed of here rather than being left with the caller
  delete pact;
}

// Turns an MDP action into an advice-tree action and adds it to the tree.
// Returns whether a real action got output: false when there is no tree
// (logged as an error) or when the MDP action yields no advice-tree action,
// true once the action has been added to the tree.
bool
SoccerModelAdviserTree::addStateAdvice(AbstractState* pstate, SoccerMDPAction* pmdpact)
{
  if (!ptree)
    {
      errorlog << "SoccerModelAdviserTree::addStateAdvice: null tree" << ende;
      return false;
    }

  actionlog(220) << "SoccerModelAdviserTree::addStateAdvice: advice for "
                 << *pstate << " is " << *pmdpact << ende;

  AdviceTreeAction* ptreeact =
    pmdpact->createAdviceTreeAction(pstate->getStateDescription(),
                                    pstate->getStateIdx());

  if (ptreeact != NULL)
    {
      ptree->addAction(pstate, ptreeact);
      return true;
    }

  return false;
}


/*****************************************************************************/

// Starts an output line for the state: writes the state index, a space, and
// then the state's reward according to the model's MDP.  For the index -1
// a 0 is written in place of the reward.
void
SoccerModelAdviserOutput::beginStateAdvice(AbstractState* pstate)
{
  const int idx = pstate->getStateIdx();

  os << idx << ' ';
  if (idx != -1)
    os << soccer_model.getMDP().getRewardForState(idx);
  else
    os << 0;
}

// Appends a space and the printed form of the action to the current output
// line.  The callee takes over the memory: the action is deleted afterwards.
void
SoccerModelAdviserOutput::addStateAdvice(AbstractState* pstate,
                                         rcss::clang::Action* pact)
{
  os << ' ';
  os << *pact;

  delete pact;
}

// Terminates the output line of the current state (newline plus flush via
// std::endl).  Neither pstate nor act_count is used.
void
SoccerModelAdviserOutput::endStateAdvice(AbstractState* pstate, int act_count)
{
  os << std::endl;
}

/*****************************************************************************/

// Records the number of actions advised for this state as one data point in
// bucket.  The state itself is not inspected.
void
SoccerModelAdviserCountAct::endStateAdvice(AbstractState* pstate, int act_count)
{
  bucket.addPoint(act_count);
}

//friend
// Writes three '#'-prefixed header lines describing the data, followed by
// the adviser's bucket.  Returns the stream.
std::ostream&
operator<<(std::ostream& os, const SoccerModelAdviserCountAct& sma)
{
  os << "# This is a SoccerModelAdviserCountAct" << std::endl;
  os << "# It represents the number of states which have the given number of advised acts" << std::endl;
  os << "# Format: <num acts> <num states which advise that many acts" << std::endl;
  os << sma.bucket;

  return os;
}


/*****************************************************************************/
// Remembers the model and the two filter lists, and sets up one zeroed
// counter per transition class (NUM_TRANSITION_CLASSES entries).
// NOTE(review): whether the filter lists are stored by reference or copied
// depends on the member declarations, which are not visible here.
SoccerModelTraceComparator::SoccerModelTraceComparator(
    SoccerModelWValue* pmodel,
    std::list<SoccerStateFilter*>& lfilters_state,
    std::list<SoccerActionFilter*>& lfilters_act)
  : pmodel(pmodel),
    lfilters_state(lfilters_state),
    lfilters_act(lfilters_act),
    v_tran_class_counts(NUM_TRANSITION_CLASSES, 0)
{
}

// Destructor; the body is intentionally empty.
SoccerModelTraceComparator::~SoccerModelTraceComparator()
{
  // nothing to do here
}


void
SoccerModelTraceComparator::resetCounts()
{
  std::fill(v_tran_class_counts.begin(), v_tran_class_counts.end(), 0);
}


void
SoccerModelTraceComparator::printCounts(std::ostream& o, double avg_factor)
{
  for (VTranClassCounts::iterator iter = v_tran_class_counts.begin();
       iter != v_tran_class_counts.end();
       iter++)
    {
      o << std::setw(25) << ((TransitionClass)(iter-v_tran_class_counts.begin()))
	<< ": "
	<< std::setw(10) << *iter;
      if (avg_factor > 0)
	o << "\t" << (((double)(*iter)) / avg_factor);
      o << std::endl;
    }
}


// THe main callback: inside of each file, called for each state trace element
// Also remembers previous times and states, just for conveinance
// -1 means invalid for last_state, last_tiem
void
SoccerModelTraceComparator::handleNextState(int lasttime, int laststate,
					    int time, int state)
{
  if (laststate == -1)
    return;
  v_tran_class_counts[classifyTransition(laststate, state)]++;
}

// Classifies the observed transition laststate -> state against the model.
//  - TC_ImpossibleState: laststate has no actions in the MDP.
//  - TC_IgnoredSelfTran: no action explains the transition, it is a
//    self-transition, and laststate lies in the CoachParam range
//    [IgnoreSelfTranMin, IgnoreSelfTranMax).
//  - TC_FilteredStateImpossibleAct / TC_ImpossibleAct: no action explains
//    the transition; which one depends on the state-filter result.
//  - TC_FilteredStatePossibleAct: some action explains it and the
//    state-filter result is true.
//  - TC_CorrectAct: some explaining action passes the action filters.
//  - TC_PossibleAct: explaining actions exist, but none passes the filters.
// NOTE(review): passes_state_filters is true when laststate *passes*
// lfilters_state, yet that case maps to the TC_FilteredState* classes --
// confirm whether "filtered" is meant as "selected" or the test is inverted.
SoccerModelTraceComparator::TransitionClass
SoccerModelTraceComparator::classifyTransition(int laststate, int state)
{
  // no actions means a state we haven't seen
  if (pmodel->getMDP().getNumActionsInState(laststate) <= 0)
    return TC_ImpossibleState;

  const bool passes_state_filters =
    pmodel->doesPassStateFilters(laststate, lfilters_state);
  std::vector<int> candidates =
    pmodel->getMDP().getPossibleActsForTran(laststate, state);

  if (candidates.empty())
    {
      if (laststate == state &&
          CoachParam::instance()->getAbstractCompareSt2mIgnoreSelfTranMin() <= laststate &&
          laststate < CoachParam::instance()->getAbstractCompareSt2mIgnoreSelfTranMax())
        return TC_IgnoredSelfTran;

      actionlog(200) << "Impossible action (" << passes_state_filters << ")"
                     << ": " << laststate << " -> " << state
                     << ende;

      // this particular transition has never been observed before
      if (passes_state_filters)
        return TC_FilteredStateImpossibleAct;
      return TC_ImpossibleAct;
    }

  // some action could have produced the transition
  if (passes_state_filters)
    return TC_FilteredStatePossibleAct;

  // look for a candidate that the action filters let through; if there is
  // none the transition was possible, but not by a correct action
  std::vector<int>::iterator pos = candidates.begin();
  while (pos != candidates.end())
    {
      if (pmodel->doesPassActionFilters(laststate, *pos, lfilters_act))
        return TC_CorrectAct;
      ++pos;
    }

  return TC_PossibleAct;
}


// Prints the name of a transition class.  Values outside
// [0, NUM_TRANSITION_CLASSES) are printed as a marker containing the
// numeric value instead.  Returns the stream.
std::ostream&
operator<<(std::ostream& os,
           SoccerModelTraceComparator::TransitionClass tc)
{
  // names indexed by the numeric value of the transition class
  static const char* CLASS_NAMES[] =
    { "Error", "CorrectAct", "PossibleAct", "IgnoredSelfTran", "ImpossibleAct",
      "FilteredStatePossibleAct", "FilteredStateImpossibleAct", "ImpossibleState" };

  if (tc >= 0 && tc < SoccerModelTraceComparator::NUM_TRANSITION_CLASSES)
    os << CLASS_NAMES[(int)tc];
  else
    os << "InvalidTransitioClass(" << (int)tc << ")";

  return os;
}


