/* -*- Mode: C++ -*- */

#include "MDPConversion.h"
#include "utility.h"
#include "Logger.h"

using namespace spades;


// Looks this element's name up in cset and stores the resulting index in idx.
// Returns true when the stored index is non-negative, false otherwise.
// NOTE(review): idx is written from a const method, so it is presumably
// declared mutable in the header.
bool
TransitionSorter::Element::ClassElement::lookupIndex(const ASDiffClassifierSet& cset) const
{
  idx = cset.lookupClassifier(name.c_str());
  if (idx < 0)
    return false;
  return true;
}

/*****************************************************************************/
// Default constructor: explicitly default-initialises the action name and
// both class collections.
TransitionSorter::Element::Element()
  : act_name(), primary_classes(), secondary_classes()
{}
      
// Destructor: the body is intentionally empty, there is nothing to release here.
TransitionSorter::Element::~Element()
{}
      
// Resets this element: empties both class collections and the action name.
void
TransitionSorter::Element::clear()
{
  secondary_classes.clear();
  primary_classes.clear();
  act_name.clear();
}


void
TransitionSorter::Element::lookupIndices(const ASDiffClassifierSet& cset)
{
  lookupIndices(cset, &primary_classes);
  lookupIndices(cset, &secondary_classes);
}

void
TransitionSorter::Element::lookupIndices(const ASDiffClassifierSet& cset,
					 ClassStorage* pclass)
{
  for (ClassStorage::iterator iter = pclass->begin();
       iter != pclass->end();
       iter++)
    {
      if (!iter->lookupIndex(cset))
	errorlog << "TransitionSorter::Element: Failed lookup for " << *iter << ende;
    }
}

      

bool
TransitionSorter::Element::isInClasses(bool primary,
				       int class_idx,
				       const ASDiffClassifierSet& cset)
{
  if (class_idx < 0)
    {
      errorlog << "isInClasses: bad class_idx " << class_idx << ende;
      return false;
    }
  const ASDiffClassifier* p = cset.getClassifier(class_idx);
  if (p == NULL)
    {
      errorlog << "isInClasses: lookup failed for " << class_idx << ende;
      return false;
    }
  return (primary ? primary_classes : secondary_classes).count(ClassElement(p->getName()));
}

// Creates the actions this element contributes for one state.
//
// For every transition in vtran whose class is a primary class of this
// element:
//   * a new MDPConvActionInfo is allocated and given an action created by
//     name from act_name;
//   * the action's parameters are set from the transition; if the action
//     rejects it, the MDPConvActionInfo is deleted and the transition skipped;
//   * the primary transition is added and its replication count incremented;
//   * every transition in vtran whose class is a secondary class and which
//     the action accepts is added too (replication count incremented);
//   * the MDPConvActionInfo is appended to lact.
//
// A transition whose class index is -1 is reported on errorlog, and is then
// still passed through the primary test.
void
TransitionSorter::Element::createActions(AbstractStateDescription* pdesc,
                                         const ASDiffClassifierSet& cset,
                                         int current_state_idx,
                                         std::vector<MDPConvTranInfo*>& vtran,
                                         ActionList& lact)
{
  typedef std::vector<MDPConvTranInfo*>::iterator TranIter;

  for (TranIter prim = vtran.begin(); prim != vtran.end(); ++prim)
    {
      MDPConvTranInfo* pprim = *prim;

      if (pprim->getClassIdx() == -1)
        errorlog << "createActions: vtran has element with undefined class: "
                 << *pprim << ende;

      if (!isPrimary(pprim->getClassIdx(), cset))
        continue;

      MDPConvActionInfo* pactinfo = new MDPConvActionInfo;
      pactinfo->setAction(SoccerMDPAction::createByName(act_name.c_str()));

      if (!pactinfo->getAction()->setParamsFromTransition(pdesc,
                                                          current_state_idx,
                                                          pprim->getNextState()))
        {
          // The action rejected this transition: discard the half-built info
          delete pactinfo;
          continue;
        }

      pactinfo->addTran(pprim);
      pprim->incrReplicationCount();

      // Collect every secondary transition the new action is willing to accept
      for (TranIter sec = vtran.begin(); sec != vtran.end(); ++sec)
        {
          MDPConvTranInfo* psec = *sec;

          if (!isSecondary(psec->getClassIdx(), cset))
            continue;
          if (!pactinfo->getAction()->acceptTran(pdesc,
                                                 current_state_idx,
                                                 psec->getNextState()))
            continue;

          pactinfo->addTran(psec);
          psec->incrReplicationCount();
        }

      lact.push_back(pactinfo);
    }
}
      
      
// Writes an element as:  <act name> : { <primary classes> } { <secondary classes> }
// Each class is followed by a single space inside its braces.
std::ostream&
operator<<(std::ostream& os, const TransitionSorter::Element& e)
{
  typedef std::ostream_iterator<TransitionSorter::Element::ClassElement> ClassWriter;

  os << e.act_name << " : " << "{ ";
  std::copy(e.primary_classes.begin(), e.primary_classes.end(), ClassWriter(os, " "));
  os << "} " << "{ ";
  std::copy(e.secondary_classes.begin(), e.secondary_classes.end(), ClassWriter(os, " "));
  os << "} ";

  return os;
}
      
// Reads an element in the form
//   <act name> : { <primary classes> } { <secondary classes> }
// The element is cleared first. If the ':' separator is missing, failbit is
// set on the stream. On any failure the stream is returned as is, with e
// holding whatever had been read up to that point.
std::istream&
operator>>(std::istream& is, TransitionSorter::Element& e)
{
  e.clear();

  is >> e.act_name;
  if (is.fail() || !skip_white_space(is))
    return is;

  const int separator = is.get();
  if (separator != ':')
    {
      is.setstate(std::ios::failbit);
      return is;
    }

  is >> e.primary_classes;
  if (!is.fail())
    is >> e.secondary_classes;

  return is;
}

// Reads a brace-delimited list of class names:  { name name ... }
// s is cleared first. A missing '{' sets failbit. Names are read as
// whitespace-delimited tokens until a token position starts with '}',
// which is consumed.
std::istream&
operator>>(std::istream& is, TransitionSorter::Element::ClassStorage& s)
{
  s.clear();

  if (!skip_white_space(is))
    return is;

  const int opener = is.get();
  if (opener != '{')
    {
      is.setstate(std::ios::failbit);
      return is;
    }

  while (skip_white_space(is))
    {
      if (is.peek() == '}')
        {
          // Swallow the closing brace; the list is complete
          is.get();
          break;
        }

      std::string token;
      is >> token;
      if (is.fail())
        break;
      s.insert(token);
    }

  return is;
}


/*****************************************************************************/

// Stores the given classifier set and state description pointers; the
// element list and the allowed-to-null class collection are
// default-initialised.
TransitionSorter::TransitionSorter(ASDiffClassifierSet* pcset,
                                   AbstractStateDescription* pdesc)
  : elements(),
    pcset(pcset),
    pdesc(pdesc),
    allowed_to_null_classes()
{}
  
// Destructor: deliberately frees nothing.
// NOTE(review): "this memory" in the original note presumably means the
// objects behind pcset and pdesc.
TransitionSorter::~TransitionSorter()
{}
  
// Resets the sorter: clears the elements via clearElements() and forgets the
// classifier set and state description pointers WITHOUT deleting them.
void
TransitionSorter::clear()
{
  clearElements();

  // Only the pointers are dropped; the pointed-to objects are not freed here
  pdesc = NULL;
  pcset = NULL;
}

void
TransitionSorter::performClassLookups()
{
  for (ElementStorage::iterator iter = elements.begin();
       iter != elements.end();
       iter++)
    {
      iter->lookupIndices(*pcset);
    }
  for (Element::ClassStorage::iterator iter = allowed_to_null_classes.begin();
       iter != allowed_to_null_classes.end();
       iter++)
    {
      if (!iter->lookupIndex(*pcset))
	errorlog << "TransitionSorter::performClassLookups: Failed lookup for " << *iter << ende;
    }
}

// Builds the action list for one state from its outgoing transitions.
//   current_state_idx - index of the state the transitions leave from
//   vtran             - the transitions; each one is classified here
//   lact              - receives the created MDPConvActionInfo pointers
//
// Steps:
//   1. classify every transition;
//   2. let every element create actions from the transitions;
//   3. every transition still unused (replication count 0) becomes a
//      SoccerMDPActionNull action if its class is listed in
//      allowed_to_null_classes, and is reported as an error otherwise;
//   4. unify the resulting list via MDPConvActionInfo::unifyList.
//
// A transition whose classifier cannot be found is reported and skipped.
// The original code went on to dereference the NULL classifier pointer.
void
TransitionSorter::createActions(int current_state_idx,
                                std::vector<MDPConvTranInfo*>& vtran,
                                ActionList& lact)
{
  typedef std::vector<MDPConvTranInfo*>::iterator TranIter;

  // Classify all the transitions
  for (TranIter tran_iter = vtran.begin(); tran_iter != vtran.end(); ++tran_iter)
    {
      (*tran_iter)->classify(pdesc, pcset, current_state_idx);
    }

  // Let each type of transition have a crack at this
  for (ElementStorage::iterator elem_iter = elements.begin();
       elem_iter != elements.end();
       ++elem_iter)
    {
      elem_iter->createActions(pdesc, *pcset, current_state_idx, vtran, lact);
    }

  // Make sure all the transitions got put at least one place
  for (TranIter tran_iter = vtran.begin(); tran_iter != vtran.end(); ++tran_iter)
    {
      if ((*tran_iter)->getNumReplications() != 0)
        continue;

      // An unused transition is okay only if its class is in our allowed list
      const ASDiffClassifier* p = pcset->getClassifier((*tran_iter)->getClassIdx());
      if (p == NULL)
        {
          errorlog << "Check for allowed null: Lookup failed for " << **tran_iter << ende;
          // No classifier means no name to test; skip rather than
          // dereference the null pointer
          continue;
        }

      if (allowed_to_null_classes.count(Element::ClassElement(p->getName())))
        {
          // This is in our set of allowed Null guys!
          MDPConvActionInfo* pactinfo = new MDPConvActionInfo;
          pactinfo->setAction(new SoccerMDPActionNull);
          pactinfo->addTran(*tran_iter);
          (*tran_iter)->incrReplicationCount();
          lact.push_back(pactinfo);
        }
      else
        {
          errorlog << "TransitionSorter::createActions: in state " << current_state_idx
                   << ": transition went no where: " << **tran_iter << ende;
        }
    }

  // Now, remove duplications from the list (the returned count is not used)
  MDPConvActionInfo::unifyList(lact);
}

      

std::ostream&
operator<<(std::ostream& os, const TransitionSorter& s)
{
  os << "# This file represents a TransitionSorter" << std::endl;
  os << "# Format: each line is an TransitionSorter::Element, describing a possible action" << std::endl;
  os << "# line: format: <act name> : { <list of primary classes> } { <list of sec. classes> } " << std::endl;
  os << "{ ";
  std::copy(s.allowed_to_null_classes.begin(), s.allowed_to_null_classes.end(),
	    std::ostream_iterator<TransitionSorter::Element::ClassElement>(os, " "));
  os << "} " << std::endl;

  std::copy(s.elements.begin(), s.elements.end(),
	    std::ostream_iterator<TransitionSorter::Element>(os, "\n"));
  return os;
}
  
std::istream&
operator>>(std::istream& is, TransitionSorter& s)
{
  // We only clear the elements, don't forget our AbstractStateDescription
  // or classification set
  s.clearElements();
  
  if (!skip_to_non_comment(is))
    return is;

  is >> s.allowed_to_null_classes;
  if (is.fail())
    return is;
  
  while (!is.eof())
    {
      if (!skip_to_non_comment(is))
	{
	  is.clear(std::ios::eofbit);
	  return is;
	}
      TransitionSorter::Element e;
      is >> e;
      if (is.fail())
	return is;
      s.elements.push_back(e);
    }
  return is;
}

/*****************************************************************************/
// Default constructor: next state -1, probability 0.0, no replications,
// class index -1.
MDPConvTranInfo::MDPConvTranInfo()
  : nextstate(-1),
    prob(0.0),
    num_replications(0),
    class_idx(-1)
{}

// Constructs a transition to nextstate with probability prob; the
// replication count starts at 0 and the class index at -1.
MDPConvTranInfo::MDPConvTranInfo(int nextstate, double prob)
  : nextstate(nextstate),
    prob(prob),
    num_replications(0),
    class_idx(-1)
{}


// Destructor: empty body, nothing is released here.
MDPConvTranInfo::~MDPConvTranInfo()
{}
  
// Fills tran from this transition: the next state is copied and the
// probability is divided by the replication count. A replication count of
// zero is reported on errorlog, and the division is still performed.
// Always returns true.
bool
MDPConvTranInfo::convert(MDP::TranInfo& tran)
{
  if (num_replications == 0)
    errorlog << "MDPConvTranInfo::convert: converting with no replications: "
             << *this << ende;

  tran.setNextState(nextstate);
  tran.setProb(prob / num_replications);

  return true;
}
  
// Classifies this transition and stores the result in class_idx.
//   pdesc          - state description used to build both abstract states
//   pcset          - classifier set that performs the classification
//   original_state - index of the state this transition leaves from
// The comparison is made between the state at original_state and the state
// at nextstate. An invalid (negative) nextstate and a failed (negative)
// classification are both reported on errorlog.
// Returns the computed class index, which is negative when classification
// failed. The original code returned the constant -1 in every case, which
// made the return value useless.
int
MDPConvTranInfo::classify(AbstractStateDescription* pdesc, ASDiffClassifierSet* pcset,
                          int original_state)
{
  if (nextstate < 0)
    errorlog << "MDPConvTranInfo::classify: invalid nextstate: " << *this << ende;

  AbstractState first_state(pdesc);
  first_state.setStateIdx(original_state);
  AbstractState second_state(pdesc);
  second_state.setStateIdx(nextstate);

  AbstractStateCompare comp(&first_state, &second_state);
  class_idx = pcset->classify(&comp);
  if (class_idx < 0)
    errorlog << "MDPConvTranInfo::classify: could not classify "
             << class_idx << ": " << *this << ende;

  return class_idx;
}

// Writes a one-line description of the transition:
//   MDPConvTranInfo(nextstate=N, prob=P, repl=R, class=C)
std::ostream&
operator<<(std::ostream& os, const MDPConvTranInfo& t)
{
  os << "MDPConvTranInfo(";
  os << "nextstate=" << t.nextstate << ", ";
  os << "prob=" << t.prob << ", ";
  os << "repl=" << t.num_replications << ", ";
  os << "class=" << t.class_idx << ")";
  return os;
}



/*****************************************************************************/

// Default constructor: no action yet (pact is NULL) and a
// default-initialised transition list.
MDPConvActionInfo::MDPConvActionInfo()
  : pact(NULL),
    trans()
{}
  
// Destructor: deletes the action if one is still held. The transitions in
// trans are not touched here.
MDPConvActionInfo::~MDPConvActionInfo()
{
  // delete on a null pointer is a no-op, so no explicit check is needed
  delete pact;
}

// Attempts to unify pactinfo into *this. The two actions are asked to unify
// first; only if that succeeds are pactinfo's transitions merged into ours.
// Returns whether the unification took place.
bool
MDPConvActionInfo::tryUnify(MDPConvActionInfo* pactinfo)
{
  if (pact->tryUnify(pactinfo->pact))
    {
      unifyTransitions(pactinfo);
      return true;
    }
  return false;
}

// Merges the transitions of pactinfo into this object's list.
// Entries are compared with std::find (i.e. by stored value); a linear
// search per entry keeps this simple rather than fast.
//   * an entry we do not hold yet is appended;
//   * an entry we already hold has its replication count decremented.
void
MDPConvActionInfo::unifyTransitions(MDPConvActionInfo* pactinfo)
{
  for (TranStorage::iterator src = pactinfo->trans.begin();
       src != pactinfo->trans.end();
       ++src)
    {
      // Each transition is expected to appear at most once in our list
      TranStorage::iterator existing = std::find(trans.begin(), trans.end(), *src);
      if (existing != trans.end())
        {
          // Already present here: just decrement the replication count
          (*existing)->decrReplicationCount();
          continue;
        }

      // Not present yet: take it over
      trans.push_back(*src);
    }
}



// Fills actinfo from this object: actinfo is cleared, handed our action
// pointer, and given one converted MDP::TranInfo per transition.
// The MDPAction itself is NOT copied, only the pointer is passed along, so
// the caller may want to make this object forget the pointer afterwards.
// Returns false (after logging) as soon as one transition fails to convert;
// transitions converted before that point stay in actinfo.
bool
MDPConvActionInfo::convert(MDP::ActionInfo& actinfo)
{
  actinfo.clear();
  actinfo.setAction(pact);

  TranStorage::iterator cur = trans.begin();
  while (cur != trans.end())
    {
      MDP::TranInfo converted;
      if (!(*cur)->convert(converted))
        {
          errorlog << "MDPConvActionInfo::convert: Error converting transition: "
                   << **cur << ende;
          return false;
        }
      actinfo.addTransition(converted);
      ++cur;
    }

  return true;
}

// Writes a description of the action info:
//   MDPConvActionInfo(<action>, [<transition> <transition> ... ])
// The original code never wrote the closing "])" and dereferenced pact
// unconditionally. pact can be NULL (the default constructor sets it so), in
// which case "NULL" is written in place of the action.
std::ostream&
operator<<(std::ostream& os, const MDPConvActionInfo& a)
{
  os << "MDPConvActionInfo(";
  if (a.pact != NULL)
    os << *a.pact;
  else
    os << "NULL";
  os << ", [";
  std::for_each (a.trans.begin(), a.trans.end(),
                 derefprinter<MDPConvTranInfo>(os, " "));
  os << "])";
  return os;
}

// static
// Removes duplicates from lact by pairwise unification. Every entry (the
// "keeper") is offered every other entry (the "candidate"); when the keeper
// absorbs a candidate via tryUnify, the candidate object is deleted and its
// slot erased from the list.
// Returns the number of entries removed.
// NOTE(review): the saved successor iterator is used after erase(), which
// assumes ActionList is a container (e.g. std::list) where erase leaves
// other iterators valid. Confirm against the typedef.
int
MDPConvActionInfo::unifyList(ActionList& lact)
{
  int merged = 0;

  for (ActionList::iterator keeper = lact.begin(); keeper != lact.end(); ++keeper)
    {
      ActionList::iterator candidate = lact.begin();
      while (candidate != lact.end())
        {
          // Remember the successor before candidate is possibly erased
          ActionList::iterator following = candidate;
          ++following;

          if (candidate != keeper && (*keeper)->tryUnify(*candidate))
            {
              // The candidate has been absorbed by the keeper: drop it
              delete *candidate;
              lact.erase(candidate);
              ++merged;
            }

          candidate = following;
        }
    }

  return merged;
}


/*****************************************************************************/

// Destructor: deletes the sorter and the MDP held by this converter.
MDPConverter::~MDPConverter()
{
  // delete on a null pointer is a no-op, so no explicit checks are needed
  delete psorter;
  delete pmdp;
}

// Converts the Markov chain *pmc into a new SoccerMDP stored in pmdp (any
// previously held MDP is deleted first).
//   pmc               - the Markov chain to convert
//   progress_interval - if > 0, a '.' is written to std::cout for every
//                       state whose index is a multiple of this value
//
// For each state (highest index first) the chain's transitions are wrapped
// in MDPConvTranInfo objects, the sorter turns them into actions, and each
// action is converted and added to the MDP. After all states are done the
// MDP is normalized.
//
// Returns false if, for any state with non-zero total probability, that
// total differs from 1.0 by more than EPSILON; true otherwise. A failed
// action conversion is only logged and does not affect the return value.
//
// Fix: the MDPConvActionInfo objects in lact are now deleted once their
// actions have been forgotten; previously they leaked for every state.
bool
MDPConverter::convertMC(MarkovChain* pmc, int progress_interval)
{
  bool retval = true;

  // delete on a null pointer is a no-op
  delete pmdp;

  pmdp = new SoccerMDP(pmc->getNumStates());

  for (int sidx = pmc->getNumStates() - 1; sidx >= 0; sidx--)
    {
      if (progress_interval > 0 && sidx % progress_interval == 0)
        std::cout << '.' << std::flush;

      // Set up our transitions
      std::vector<MDPConvTranInfo*> vtran;
      for (int tidx = pmc->getNumTransitionsForState(sidx) - 1; tidx >= 0; tidx--)
        {
          vtran.push_back(new MDPConvTranInfo(pmc->getTranNextState(sidx, tidx),
                                              pmc->getTranProb(sidx, tidx)));
        }

      // Create our actions
      ActionList lact;
      psorter->createActions(sidx, vtran, lact);

      // Now make these MDP actions
      for (ActionList::iterator iter = lact.begin();
           iter != lact.end();
           ++iter)
        {
          MDP::ActionInfo actinfo;
          if (!(*iter)->convert(actinfo))
            errorlog << "Failed converting action: " << **iter << ende;
          pmdp->addAction(sidx, actinfo);
        }

      // Clean up
      // we have to forget the actions because the MDP actions took them
      std::for_each(lact.begin(), lact.end(), std::mem_fun(&MDPConvActionInfo::forgetAction));
      // Free the action infos themselves. This must come after forgetAction,
      // because ~MDPConvActionInfo deletes any action it still holds.
      // NOTE(review): relies on forgetAction dropping the pact pointer.
      for (ActionList::iterator iter = lact.begin();
           iter != lact.end();
           ++iter)
        {
          delete *iter;
        }
      std::for_each(vtran.begin(), vtran.end(), deleteptr<MDPConvTranInfo>());

      // Sanity checking
      // We'll look at the total prob in this state, across all action
      // Since we haven't normalized, this should be equal to the total probability
      // that existed in the MarkovChain transitions (1.0)
      double total_prob = pmdp->getTotalProbForState(sidx);
      if (total_prob != 0.0 && fabs(total_prob - 1.0) > EPSILON)
        {
          errorlog << "MDPConverter: total prob for state " << sidx
                   << " not equal to 1.0: " << total_prob << ende;
          retval = false;
        }
    }

  pmdp->normalize();

  return retval;
}

