/* -*- Mode: C++ -*- */

#include <algorithm>
#include <iterator>
#include <functional>
#include "AdviceTree.h"
#include "AbstractState.h"
#include "CoachMessageQueue.h"
#include "misc.h"
#include "Logger.h"

using namespace spades;

/***************************************************************************************/

// Builds an empty tree: no root node exists yet and the level list is empty.
AdviceTree::AdviceTree()
  : root(NULL),
    levels()
{
  // nothing else to set up
}

// Releases everything the tree holds by delegating to clear().
AdviceTree::~AdviceTree()
{
  this->clear();
}

// Empties the level list and destroys the node hierarchy.
//
// root is reset to NULL after it is deleted. The destructor calls clear()
// as well, so leaving the old pointer in place would delete the same node
// twice whenever clear() had already been called explicitly, and the
// NULL-root checks elsewhere in this class would be fooled by the stale
// pointer.
void
AdviceTree::clear()
{
  levels.clear();
  delete root;   // deleting a NULL pointer is a no-op
  root = NULL;
}

// Throws away any existing node hierarchy and builds a fresh one from the
// current level list (starting at the first level). Logs an error when
// createTree() hands back NULL.
void
AdviceTree::createTreeFromLevels()
{
  // delete on a NULL pointer does nothing, so no guard is required here
  delete root;

  root = createTree(levels.begin());
  if (!root)
    errorlog << "Creating tree from levels failed" << ende;
}

// Inserts an action for the given abstract state into the tree, starting
// the descent at the root (depth 0).
//
// pstate: state the action belongs to; the nodes query it to choose a child.
// pact:   the action to store.
//         NOTE(review): the leaf destructor deletes its stored actions, so
//         ownership presumably passes to the tree on success -- confirm.
//
// If the tree has not been built yet (NULL root) an error is logged and the
// action is not added, mirroring the NULL-root handling in createAdvice().
void
AdviceTree::addAction(AbstractState* pstate, AdviceTreeAction* pact)
{
  if (root == NULL)
    {
      errorlog << "AdviceTree::addAction: NULL root" << ende;
      return;
    }
  root->addAction(pstate, pact, levels, 0);
}

// Asks the root node to generate its advice rules into pqueue.
//
// pqueue:    message queue that receives the generated rule definitions.
// prefix:    rule name handed to the root node.
// pos_assoc: optional stream; when non-NULL a three line header is written
//            to it before the nodes are asked to produce their output.
//
// Returns whatever the root node's createAdvice() returns, or an empty
// string (after logging an error) when there is no root.
std::string
AdviceTree::createAdvice(CoachMessageQueue* pqueue, const std::string& prefix,
			 std::ostream* pos_assoc)
{
  if (!root)
    {
      errorlog << "AdviceTree::createAdvice: NULL root" << ende;
      return "";
    }

  if (pos_assoc != NULL)
    {
      std::ostream& out = *pos_assoc;
      out << "#This file was created by AdviceTree to establish the association between" << std::endl;
      out << "# rule names and abstract states that go into them " << std::endl;
      out << "# Format: <rulename> [<list of state indices>]" << std::endl;
    }

  return root->createAdvice(pqueue, prefix, levels, 0, pos_assoc, &assoc_map);
}

const std::string&
AdviceTree::getRuleForState(int state_idx) const
{
  static std::string empty_ret("");
  AssocMap::const_iterator iter = assoc_map.find(state_idx);
  if (iter == assoc_map.end())
    return empty_ret;
  return iter->second;
}

// Recursively builds the subtree for the levels from ilevel onward.
//
// - Past the last level a leaf node is created.
// - Otherwise an internal node is created with one child per state of the
//   level's factor (pfac->getNumStates()), each child being the subtree for
//   the next level.
//
// Returns the new subtree, or NULL when this level or any deeper level has
// a NULL pfac. In the failure case the partially built node is deleted, so
// a tree never contains NULL children that later code would dereference.
AdviceTreeNode*
AdviceTree::createTree(LevelStorage::iterator ilevel)
{
  if (ilevel == levels.end())
    return new AdviceTreeNodeLeaf();

  if (ilevel->pfac == NULL)
    {
      errorlog << "AdviceTree::createTree: level " << (ilevel - levels.begin())
	       << " has a null pfac" << ende;
      return NULL;
    }

  AdviceTreeNodeInternal* pnode = new AdviceTreeNodeInternal;

  const int num_children = ilevel->pfac->getNumStates();
  for (int child_idx = 0; child_idx < num_children; ++child_idx)
    {
      AdviceTreeNode* pchild = createTree(ilevel + 1);
      if (pchild == NULL)
	{
	  // A deeper level failed (already logged there). The node's
	  // destructor deletes the children added so far.
	  delete pnode;
	  return NULL;
	}
      pnode->addChild(pchild);
    }

  return pnode;
}

//randomly shuffles all actions in the tree
void
AdviceTree::shuffleActions()
{
  std::vector<AdviceTreeNodeLeaf*> vleaves;

  root->findLeaves(vleaves, true);

  actionlog(100) << "AdviceTree: Shuffling " << vleaves.size() << " leaf actions" << ende;

  for (int idx1 = vleaves.size() - 1; idx1 >= 0; idx1--)
    {
      // the +1 is to allow idx2 == idx1;
      int idx2 = int_random(idx1 + 1);
      vleaves[idx1]->swapActions(*vleaves[idx2]);
    }
}


//friend
// Writes a textual dump of the tree: a heading line followed either by the
// output of the root node's print() or by a marker line when the tree has
// no root.
std::ostream&
operator<<(std::ostream& os, const AdviceTree& tree)
{
  os << "AdviceTree: " << std::endl;

  if (tree.root != NULL)
    tree.root->print(os, "-", tree.levels, 0);
  else
    os << "-NULL root" << std::endl;

  return os;
}


/***************************************************************************************/

// Runs deleteptr over every stored condition and directive and then
// empties both containers.
void
AdviceTreeAction::clear()
{
  // conditions first ...
  std::for_each(conds.begin(), conds.end(), deleteptr<rcss::clang::Cond>());
  // ... then directives
  std::for_each(dirs.begin(), dirs.end(), deleteptr<rcss::clang::Dir>());

  // the pointers are no longer usable, drop them
  dirs.clear();
  conds.clear();
}

// Appends a condition to this action.
//
// A NULL pointer is reported and rejected rather than stored: addCondsTo()
// and the stream operator dereference every stored pointer, so a stored
// NULL would crash later, far from the faulty caller.
void
AdviceTreeAction::addCondition(rcss::clang::Cond* p)
{
  if (p == NULL)
    {
      errorlog << "AdviceTreeAction::addCondition: adding NULL!" << ende;
      return;
    }
  conds.push_back(p);
}

// Appends a directive to this action.
//
// A NULL pointer is reported and rejected rather than stored: addDirsTo()
// and the stream operator dereference every stored pointer, so a stored
// NULL would crash later, far from the faulty caller.
void
AdviceTreeAction::addDirective(rcss::clang::Dir* p)
{
  if (p == NULL)
    {
      errorlog << "AdviceTreeAction::addDirective: adding NULL!" << ende;
      return;
    }
  dirs.push_back(p);
}


// Appends a deep copy of each stored condition to pAnd. The conditions held
// by this action are left in place.
void
AdviceTreeAction::addCondsTo(rcss::clang::CondAnd* pAnd)
{
  CondStorage::iterator cur = conds.begin();
  while (cur != conds.end())
    {
      pAnd->push_back( (*cur)->deepCopy() );
      ++cur;
    }
}

// Appends a deep copy of each stored directive to store. The copy is
// released from the object deepCopy() returns, so store receives the raw
// pointer.
void
AdviceTreeAction::addDirsTo(rcss::clang::SimpleRule::Storage& store)
{
  DirStorage::iterator cur = dirs.begin();
  while (cur != dirs.end())
    {
      store.push_back( (*cur)->deepCopy().release() );
      ++cur;
    }
}

//friend
// Prints the action as "AdviceTreeAction[conds=...; dirs=...]", passing
// every stored condition and directive through derefprinter.
std::ostream&
operator<<(std::ostream& os, const AdviceTreeAction& a)
{
  os << "AdviceTreeAction[conds=";
  std::for_each(a.conds.begin(), a.conds.end(),
		derefprinter<rcss::clang::Cond>(os, " "));

  os << "; dirs=";
  std::for_each(a.dirs.begin(), a.dirs.end(),
		derefprinter<rcss::clang::Dir>(os, " "));

  os << "]";
  return os;
}


/***************************************************************************************/

// Public entry point for adding an action at this node: hands the work to
// the node-type specific protAddAction() and afterwards marks this node.
// The mark is set unconditionally, i.e. also when protAddAction() returned
// early without storing anything.
void
AdviceTreeNode::addAction(AbstractState* pstate, AdviceTreeAction* pact,
			  AdviceTree::LevelStorage& levels, int depth)
{
  this->protAddAction(pstate, pact, levels, depth);
  this->setMark();
}

/***************************************************************************************/

// Applies deleteptr to every child node held by this internal node.
AdviceTreeNodeInternal::~AdviceTreeNodeInternal()
{
  deleteptr<AdviceTreeNode> destroy_child;
  for (ChildrenStorage::iterator child = children.begin();
       child != children.end();
       ++child)
    {
      destroy_child(*child);
    }
}

// Forwards the action to the child selected by the state's index for this
// level's factor. If that index does not name an existing child an error
// is logged and the action is not forwarded.
void
AdviceTreeNodeInternal::protAddAction(AbstractState* pstate, AdviceTreeAction* pact,
				      AdviceTree::LevelStorage& levels, int depth)
{
  // which child does this state select at the current level?
  const int child_idx = pstate->getFactorIdx(levels[depth].factor_idx);
  const int num_children = static_cast<int>(children.size());

  const bool in_range = (0 <= child_idx) && (child_idx < num_children);
  if (!in_range)
    {
      errorlog << "AdviceTreeNodeInternal::protAddAction: child_idx out of range: "
	       << child_idx
	       << ", max=" << children.size() << ende;
      return;
    }

  // recurse one level down
  children[child_idx]->addAction(pstate, pact, levels, depth+1);
}

// Generates the rule definitions for this internal node and its marked
// children.
//
// Strategy:
//   for each marked child:
//     tell the child to create its rules under the name "<prefix><idx>"
//     build a nested rule that tests this level's condition for <idx> and
//       refers to the rule "<prefix><idx>"
//   then define a rule named "<prefix>" that contains all of those test
//   rules and push it onto the queue's define container.
//
// pqueue:      queue whose define container receives the definitions.
// prefix:      name of the rule defined for this node.
// levels:      level descriptions; levels[depth].pfac creates the conditions.
// depth:       index of this node's level.
// pos_assoc:   passed through unchanged to the children (may be NULL).
// p_assoc_map: passed through unchanged to the children (may be NULL).
//
// Returns prefix when a rule was defined, or an empty string when no child
// produced a rule (in which case nothing is defined for this node).
//
// A child that reports an empty name has defined nothing, so it is skipped;
// otherwise this node's rule would refer to a rule name that does not exist.
std::string
AdviceTreeNodeInternal::createAdvice(CoachMessageQueue* pqueue, const std::string& prefix,
				     AdviceTree::LevelStorage& levels, int depth,
				     std::ostream* pos_assoc,
				     AdviceTree::AssocMap* p_assoc_map)
{
  rcss::clang::NestedRule::Storage subrules;
  for (ChildrenStorage::iterator iter = children.begin();
       iter != children.end();
       ++iter)
    {
      if (!(*iter)->isMarked())
	continue;

      const int child_idx = static_cast<int>(iter - children.begin());
      const std::string child_name = prefix + toStringGivenMax(child_idx, children.size());

      const std::string defined_name =
	(*iter)->createAdvice(pqueue, child_name, levels, depth + 1, pos_assoc, p_assoc_map);
      if (defined_name.empty())
	continue;   // the child defined no rule; do not reference it

      rcss::clang::NestedRule* prule =
	new rcss::clang::NestedRule(std::auto_ptr<rcss::clang::Cond>(levels[depth].pfac->createCondition(child_idx)));

      rcss::clang::RuleIDList idlist;
      idlist.push_back(child_name);
      rcss::clang::IDListRule* psubrule =
	new rcss::clang::IDListRule(idlist);

      prule->getRules().push_back(psubrule);

      subrules.push_back(prule);
    }

  if (subrules.empty())
    {
      actionlog(50) << "AdviceTreeNodeInternal::createAdvice: my children didn't make any rules!"
		    << "(prefix=" << prefix << ")" << ende;
      return std::string("");
    }

  // wrap all test rules in a single always-true rule ...
  rcss::clang::NestedRule* poverallrule =
    new rcss::clang::NestedRule(std::auto_ptr<rcss::clang::Cond>(new rcss::clang::CondBool(true)),
				subrules);

  // ... and publish it under this node's name
  rcss::clang::DefRule *pdef =
    new rcss::clang::DefRule( prefix,
			      std::auto_ptr<rcss::clang::Rule>(poverallrule),
			      false );
  pqueue->getDefineContainer().push(pdef);
  return prefix;
}

// Collects leaves below this node by asking each child in turn to append
// its leaves to vleaves. require_marked is passed through unchanged.
void
AdviceTreeNodeInternal::findLeaves(std::vector<AdviceTreeNodeLeaf*>& vleaves,
				   bool require_marked)
{
  // A std::for_each formulation using std::bind2nd/std::mem_fun was tried
  // here before but did not compile, hence the explicit loop.
  for (ChildrenStorage::size_type idx = 0; idx < children.size(); ++idx)
    {
      children[idx]->findLeaves(vleaves, require_marked);
    }
}

// Prints this node as one line (prefix followed by the factor of this
// node's level), then prints every child with the prefix extended by "-"
// and the depth increased by one.
void
AdviceTreeNodeInternal::print(std::ostream& os, const std::string& prefix,
			      const AdviceTree::LevelStorage& levels, int depth) const
{
  os << prefix << *levels[depth].pfac << std::endl;

  const std::string child_prefix = prefix + "-";
  const int child_depth = depth + 1;
  for (ChildrenStorage::const_iterator child = children.begin();
       child != children.end();
       ++child)
    {
      (*child)->print(os, child_prefix, levels, child_depth);
    }
}

/***************************************************************************************/

// Applies deleteptr to every action stored in this leaf.
AdviceTreeNodeLeaf::~AdviceTreeNodeLeaf()
{
  deleteptr<AdviceTreeAction> destroy_action;
  for (ActionStorage::iterator act = actions.begin();
       act != actions.end();
       ++act)
    {
      destroy_action(*act);
    }
}

// Stores the action in this leaf (via addAdvice) and records the state's
// index in relevant_states. levels and depth are not used at a leaf.
void
AdviceTreeNodeLeaf::protAddAction(AbstractState* pstate, AdviceTreeAction* pact,
				  AdviceTree::LevelStorage& levels, int depth)
{
  actionlog(220) << "AdviceTreeNodeLeaf: for state " << *pstate
		 << ", adding action " << *pact << ende;

  // keep the action ...
  this->addAdvice(pact);
  // ... and remember which state it was added for
  this->relevant_states.insert(pstate->getStateIdx());
}

// Defines one simple rule named prefix from all actions stored in this leaf
// and pushes the definition onto the queue's define container.
//
// The rule's directives are copies of every action's directives. Its
// condition is the conjunction of copies of every action's conditions, or a
// constant-true condition when there are none. An error is logged when the
// leaf holds no actions, but the rule is still defined.
//
// pos_assoc:   when non-NULL, receives one line: prefix followed by the
//              space separated state indices recorded in this leaf.
// p_assoc_map: when non-NULL, each recorded state index is mapped to prefix.
// levels and depth are not used at a leaf.
//
// Always returns prefix.
std::string
AdviceTreeNodeLeaf::createAdvice(CoachMessageQueue* pqueue, const std::string& prefix,
				 AdviceTree::LevelStorage& levels, int depth,
				 std::ostream* pos_assoc,
				 AdviceTree::AssocMap* p_assoc_map)
{
  if (actions.empty())
    errorlog << "AdviceTreeNodeLeaf::createAdvice: I have no actions (prefix="
	     << prefix << ")" << ende;

  // gather copies of all directives and conditions
  rcss::clang::SimpleRule::Storage dirs;
  rcss::clang::CondAnd* pAnd = new rcss::clang::CondAnd();

  for (ActionStorage::iterator act = actions.begin();
       act != actions.end();
       ++act)
    {
      (*act)->addDirsTo(dirs);
      (*act)->addCondsTo(pAnd);
    }

  // use the conjunction unless it turned out empty
  rcss::clang::Cond* pCond = pAnd;
  if (pAnd->getConds().empty())
    {
      delete pAnd;
      pCond = new rcss::clang::CondBool(true);
    }

  rcss::clang::SimpleRule* prule =
    new rcss::clang::SimpleRule( std::auto_ptr<rcss::clang::Cond>(pCond),
				 dirs );
  rcss::clang::DefRule *pdef =
    new rcss::clang::DefRule( prefix,
			      std::auto_ptr<rcss::clang::Rule>(prule),
			      false );
  pqueue->getDefineContainer().push(pdef);

  // optional: write "<rule name> <state indices...>" to the stream
  if (pos_assoc != NULL)
    {
      std::ostream& out = *pos_assoc;
      out << prefix << " ";
      for (StateIdxStorage::const_iterator st = relevant_states.begin();
	   st != relevant_states.end();
	   ++st)
	{
	  const int state_idx = *st;
	  out << state_idx << " ";
	}
      out << "\n";
    }

  // optional: record state index -> rule name
  if (p_assoc_map != NULL)
    {
      AdviceTree::AssocMap& assoc = *p_assoc_map;
      for (StateIdxStorage::const_iterator st = relevant_states.begin();
	   st != relevant_states.end();
	   ++st)
	{
	  assoc[*st] = prefix;
	}
    }

  return prefix;
}

// Appends this leaf to vleaves. When require_marked is set, the leaf is
// only appended if it is marked.
void
AdviceTreeNodeLeaf::findLeaves(std::vector<AdviceTreeNodeLeaf*>& vleaves,
			       bool require_marked)
{
  if (require_marked && !isMarked())
    return;

  vleaves.push_back(this);
}

// Prints a single line for this leaf: the prefix and the number of actions
// it holds. levels and depth are not used at a leaf.
void
AdviceTreeNodeLeaf::print(std::ostream& os, const std::string& prefix,
			  const AdviceTree::LevelStorage& levels, int depth) const
{
  os << prefix << ": Leaf: ";
  os << actions.size() << " actions";
  os << std::endl;
}

// Exchanges the stored actions of this leaf with those of other_leaf.
// Only the action containers are swapped; nothing else in either leaf is
// touched.
void
AdviceTreeNodeLeaf::swapActions(AdviceTreeNodeLeaf& other_leaf)
{
  other_leaf.actions.swap(this->actions);
}


/***************************************************************************************/
