/* -*- Mode: C++ -*- */

#include <functional>
#include <numeric>
#include <iterator>
#include "MarkovChain.h"
#include "utility.h"
#include "misc.h"
#include "StateTraceAnalyzer.h"
#include "Logger.h"

using namespace spades;

#define USE_SANITY_CHECKS

/**************************************************************************************/

// Creates a chain with num_states states, none of which has any transitions yet.
// No reindexing exists initially: reidx is empty and reidx_num_states is -1.
MarkovChain::MarkovChain(int num_states)
  : states(num_states), reidx(), reidx_num_states(-1)
{
  // everything is set up by the member initializers
}

// Nothing to release by hand; the member containers clean up after themselves.
MarkovChain::~MarkovChain()
{}

void
MarkovChain::setNumStates(int new_size)
{
  states.resize(new_size);
  reidx.clear();
}

// Binary functor for std::accumulate: adds element.size() to a running int total.
// std::binary_function was deprecated in C++11 and removed in C++17, so the
// adaptable-functor typedefs it used to provide are declared directly instead.
template <class T>
struct size_adder {
  typedef int first_argument_type;
  typedef T second_argument_type;
  typedef int result_type;

  // The element is taken by const reference so that accumulating over containers
  // (e.g. per-state transition maps) does not copy each one.
  int operator() (int total, const T& item) const
  {
    return total + static_cast<int>(item.size());
  }
};

int
MarkovChain::getNumTransitions() const
{
  return std::accumulate(states.begin(), states.end(), 0, size_adder<TranStorage>());
}

// Returns how many outgoing transitions `state` has.
// With sanity checks enabled, an out-of-range state is logged and 0 is returned.
int
MarkovChain::getNumTransitionsForState(int state) const
{
#ifdef USE_SANITY_CHECKS
  if (state < 0 || state >= (signed)states.size())
    {
      // report under this function's own name (was mislabelled as addTransition)
      errorlog << "MarkovChain::getNumTransitionsForState: state out of range "
	       << state << ", max=" << states.size() << ende;
      return 0;
    }
#endif
  return states[state].size();
}

// Returns an iterator to the tran-th outgoing transition of `state` (counting in
// storage order from begin()).
// With sanity checks enabled, bad arguments are logged and an end() iterator is returned:
//  - bad `tran`: states[state].end(), so callers can compare against the container they
//    are walking;
//  - bad `state`: states[0].end(), or a value-initialized iterator when the chain has no
//    states at all (states[0] does not exist then).
MarkovChain::TranStorage::const_iterator
MarkovChain::getTran(int state, int tran) const
{
#ifdef USE_SANITY_CHECKS
  if (state < 0 || state >= (signed)states.size())
    {
      errorlog << "MarkovChain::getTran: state out of range "
	       << state << ", max=" << states.size() << ende;
      if (states.empty())
	return TranStorage::const_iterator();
      return states[0].end();
    }
  if (tran < 0 || tran >= (signed)states[state].size())
    {
      errorlog << "MarkovChain::getTran: tran out of range for state "
	       << state << ": "
	       << tran << ", max=" << states[state].size() << ende;
      return states[state].end();
    }
#endif
  TranStorage::const_iterator iter = states[state].begin();
  std::advance(iter, tran);
  return iter;
}

// replaces any current for transition to this state
bool
MarkovChain::addTransition(int state, int nextstate, double weight)
{
#ifdef USE_SANITY_CHECKS
  if (state < 0 || state >= (signed)states.size())
    {
      errorlog << "MarkovChain::addTransition: state out of range "
	       << state << ", max=" << states.size() << ende;
      return false;
    }
  if (nextstate < 0 || nextstate >= (signed)states.size())
    {
      errorlog << "MarkovChain::addTransition: nextstate out of range "
	       << nextstate << ", max=" << states.size() << ende;
      return false;
    }
#endif
  states[state][nextstate] = weight;
  return true;
}

// Adds `weight` to transition state -> nextstate, creating the transition with weight 0
// first if it does not exist. Returns the weight held before the increment (0.0 for a
// new transition). With sanity checks enabled, returns -1.0 (after logging) if either
// state is out of range.
double
MarkovChain::incrementTransition(int state, int nextstate, double weight)
{
#ifdef USE_SANITY_CHECKS
  const bool state_ok = (state >= 0) && (state < (signed)states.size());
  if (!state_ok)
    {
      errorlog << "MarkovChain::incrementTransition: state out of range "
	       << state << ", max=" << states.size() << ende;
      return -1.0;
    }
  const bool next_ok = (nextstate >= 0) && (nextstate < (signed)states.size());
  if (!next_ok)
    {
      errorlog << "MarkovChain::incrementTransition: nextstate out of range "
	       << nextstate << ", max=" << states.size() << ende;
      return -1.0;
    }
#endif
  // one lookup: operator[] value-initializes (0.0) a missing entry
  double& slot = states[state][nextstate];
  const double previous = slot;
  slot += weight;
  return previous;
}

// removes the transition. returns the weight that used to be there
double
MarkovChain::removeTransition(int state, int nextstate)
{
#ifdef USE_SANITY_CHECKS
  if (state < 0 || state >= (signed)states.size())
    {
      errorlog << "MarkovChain::removeTransition: state out of range "
	       << state << ", max=" << states.size() << ende;
      return false;
    }
  if (nextstate < 0 || nextstate >= (signed)states.size())
    {
      errorlog << "MarkovChain::removeTransition: nextstate out of range "
	       << nextstate << ", max=" << states.size() << ende;
      return false;
    }
#endif
  TranStorage::iterator iter = states[state].find(nextstate);
  if (iter == states[state].end())
    return 0.0;
  double retval = iter->second;
  states[state].erase(iter);
  return retval;
}


// normalizes all weights to probabilities
void
MarkovChain::normalize()
{
  for (MarkovChain::StateStorage::iterator state_iter = states.begin();
       state_iter != states.end();
       state_iter++)
    {
      double sum = 0;
      for (MarkovChain::TranStorage::iterator tran_iter = state_iter->begin();
	   tran_iter != state_iter->end();
	   tran_iter++)
	{
	  sum += tran_iter->second;
	}
      for (MarkovChain::TranStorage::iterator tran_iter = state_iter->begin();
	   tran_iter != state_iter->end();
	   tran_iter++)
	{
	  tran_iter->second /= sum;
	}
    }
}

/************************************************************************************/
// StateTraceAnalyzer strategy that learns a MarkovChain from state traces.
// Each observed step laststate -> state adds 1.0 to that transition's weight, and the
// chain is normalized when the file list ends.
// Two optional output files (a NULL or empty name disables one) receive, per trace file,
// the number of new transitions and of visits to states with no outgoing transitions yet.
class MCTransitionLearnerStrategy
  : public StateTraceAnalyzerStrategy
{
public:
  MCTransitionLearnerStrategy(MarkovChain& mc,
			      const char* new_state_count_fn,
			      const char* new_tran_count_fn)
    : mc(mc), new_state_count(0), new_tran_count(0)
  {
    // The two files are opened independently. Failing to open the first one must not
    // stop the second one from being opened (it used to return early).
    if (new_state_count_fn && new_state_count_fn[0] != 0)
      {
	os_new_state_count.open(new_state_count_fn);
	if (!os_new_state_count)
	  {
	    errorlog << "Could not open new state count file '" << new_state_count_fn << "'" << ende;
	  }
	else
	  {
	    os_new_state_count << "# This file records the number of new states seen for each file in a MarkovChain learning" << std::endl;
	    os_new_state_count << "# Format: <filename>\t<count of states reached not seen before" << std::endl;
	  }
      }
    if (new_tran_count_fn && new_tran_count_fn[0] != 0)
      {
	os_new_tran_count.open(new_tran_count_fn);
	if (!os_new_tran_count)
	  {
	    errorlog << "Could not open new tran count file '" << new_tran_count_fn << "'" << ende;
	  }
	else
	  {
	    os_new_tran_count << "# This file records the number of new transitions seen for each file in a MarkovChain learning" << std::endl;
	    os_new_tran_count << "# Format: <filename>\t<count of trans seen for the first time" << std::endl;
	  }
      }
  }

  // Called once before the list of files is processed; nothing to do.
  void startFileList(const char* fn)
  { }

  // Called at the beginning of every file of state transitions: resets per-file counters.
  void startFile(const char* fn)
  {
    std::cout << "...Processing: " << fn << std::endl;
    new_state_count = 0;
    new_tran_count = 0;
  }

  // The main callback, called for each element of the state trace inside a file.
  // The previous time and state are supplied for convenience; -1 marks them invalid
  // (first element of a file).
  void handleNextState(int lasttime, int laststate, int time, int state)
  {
    if (laststate != -1)
      {
	// incrementTransition returns the previous weight; 0.0 means a new transition
	if (mc.incrementTransition(laststate, state, 1.0) == 0.0)
	  new_tran_count++;
      }
    if (mc.getNumTransitionsForState(state) == 0)
      new_state_count++;
  }

  // Called at the end of every file: appends the per-file counts to whichever output
  // files were actually opened successfully.
  void endFile(const char* fn)
  {
    if (os_new_state_count.is_open() && os_new_state_count)
      os_new_state_count << fn << "\t" << new_state_count << std::endl;
    if (os_new_tran_count.is_open() && os_new_tran_count)
      os_new_tran_count << fn << "\t" << new_tran_count << std::endl;
  }

  // Called at the end of the file list: converts accumulated counts into probabilities.
  void endFileList(const char* fn)
  {
    std::cout << "...Normalizing" << std::endl;
    mc.normalize();
  }


private:
  MarkovChain& mc;
  int new_state_count;
  int new_tran_count;
  std::ofstream os_new_state_count;
  std::ofstream os_new_tran_count;
};


void
MarkovChain::learnFromTransitions(const char* infn,
				  const char* new_state_count_fn,
				  const char* new_tran_count_fn)
{
  MCTransitionLearnerStrategy strategy(*this, new_state_count_fn, new_tran_count_fn);
  StateTraceAnalyzer analyzer(&strategy);

  int count = analyzer.processListFile(infn);
  if (count < 0)
    errorlog << "MarkovChain::learnFromTransitions: error analyzing file '"
	     << infn << "', count=" << count << ende;
}

/************************************************************************************/
// Performs the fairly simple check of: verify that each state that has a transition in
//  has at least one transition out
bool
MarkovChain::checkMinimalConnectivity(bool output_bad_states)
{
  bool ret = true;
  
  for (MarkovChain::StateStorage::iterator state_iter = states.begin();
       state_iter != states.end();
       state_iter++)
    {
      for (MarkovChain::TranStorage::iterator tran_iter = state_iter->begin();
	   tran_iter != state_iter->end();
	   tran_iter++)
	{
	  if (states[tran_iter->first].empty())
	    {
	      // This is a bad state
	      ret = false;
	      if (output_bad_states)
		std::cout << "State " << tran_iter->first << " has no out transitions; state "
			  << state_iter-states.begin() << " transitions to it"
			  << " (" << state_iter->size() << " tran total)" 
			  << std::endl;
	    }
	}
    }

  return ret;
}

// prunes transitions into states that have not out transitions
// returns number of transitions removed
int
MarkovChain::pruneDeadTransitions()
{
  // we may have to go through several times in case removing a transition makes a new dead state
  int total_removed = 0;
  bool new_state_dead = true;

  while (new_state_dead)
    {
      new_state_dead = false;
      for (MarkovChain::StateStorage::iterator state_iter = states.begin();
	   state_iter != states.end();
	   state_iter++)
	{
	  MarkovChain::TranStorage::iterator next_tran_iter;
	  for (MarkovChain::TranStorage::iterator tran_iter = state_iter->begin();
	       tran_iter != state_iter->end();
	       tran_iter = next_tran_iter)
	    {
	      next_tran_iter = tran_iter;
	      ++next_tran_iter;
	      if (states[tran_iter->first].empty())
		{
		  actionlog(50) << "pruneDeadTransitions: tran from "
				<< state_iter-states.begin()
				<< " to "
				<< tran_iter->first
				<< " is dead"
				<< ende;
		  // This is a bad transition!
		  state_iter->erase(tran_iter);
		  total_removed++;
		  if (state_iter->empty())
		    {
		      new_state_dead = true;
		      actionlog(60) << "pruneDeadTransitions: this makes this state "
				    << state_iter-states.begin()
				    << " dead "
				    << ende;
		    }
		}
	    }
	}
    }

  return total_removed;
}

  
// removes all states which can not reach the given state
int
MarkovChain::pruneNonReverseReachable(int target_state)
{
  /* This is not the most efficient way to do this, but it will work
     The idea is to recursively mark the states that can reach the target_state
     We keep going until we don't mark any new states.
     ALl the unmarked states can't reach target_state, and we then prune those */
#ifdef USE_SANITY_CHECKS
  if (target_state < 0 || target_state >= (signed)states.size())
    {
      errorlog << "MarkovChain::printNonReverseReachable: state out of range "
	       << target_state << ", max=" << states.size() << ende;
      return false;
    }

#endif	

  typedef std::vector<bool> MarkedStorage;
  MarkedStorage marked(states.size(), false);
  marked[target_state] = true;

  // recursively mark states
  for (;;)
    {
      int states_marked_this_step = 0;
      for (int sidx = states.size() - 1;
	   sidx >= 0;
	   sidx--)
	{
	  if (marked[sidx])
	    continue;
	  for (MarkovChain::TranStorage::iterator tran_iter = states[sidx].begin();
	       tran_iter != states[sidx].end();
	       tran_iter++)
	    {
	      if (marked[tran_iter->first])
		{
		  marked[sidx] = true;
		  states_marked_this_step++;
		  break;
		}
	    }
	}

      actionlog(100) << "pruneNonReverseReachable: marked " << states_marked_this_step << ende;
      if (states_marked_this_step == 0)
	break;
    }

  //prune all non marked
  int states_pruned = 0;
  for (MarkedStorage::iterator iter = marked.begin();
       iter != marked.end();
       iter++)
    {
      if (*iter)
	continue;
      int sidx = iter-marked.begin();
      if (!states[sidx].empty())
	{
	  actionlog(100) << "pruneNonReverseReachable: pruning state " << iter-marked.begin() << ende;
	  states_pruned++;
	  states[iter-marked.begin()].clear();
	}
    }

  return states_pruned;
}

// Removes transition first_state -> second_state if it exists. Returns true if a
// transition was removed. Returns false if there was none, or if first_state is out of
// range; the out-of-range case is logged, and this check is not conditional on
// USE_SANITY_CHECKS.
bool
MarkovChain::pruneSpecificTransition(int first_state, int second_state)
{
  const bool in_range = (first_state >= 0) && (first_state < (signed)states.size());
  if (!in_range)
    {
      errorlog << "MarkovChain::pruneSpecificTransition: state out of range "
	       << first_state << ", max=" << states.size() << ende;
      return false;
    }

  // erase-by-key returns how many entries were removed
  return states[first_state].erase(second_state) > 0;
}

  

// Builds the reindexing table. Every state with at least one outgoing transition gets
// the next consecutive index, in original state order. States without transitions map
// to -1. reidx_num_states is set to the number of such "live" states.
void
MarkovChain::identifyLiveStates()
{
  // Size the table here so this no longer relies on clearReIdx() having been called
  // first (reidx is empty after construction or setNumStates). Every entry is
  // overwritten below.
  reidx.resize(states.size());

  int curr_idx = 0;
  for (int s = 0; s < (signed)states.size(); s++)
    reidx[s] = states[s].empty() ? -1 : curr_idx++;

  reidx_num_states = curr_idx;
}


// Resets the reindexing table to one entry per state, each -1 (no state reindexed).
void
MarkovChain::clearReIdx()
{
  // assign() resizes and fills in one step. The previous unqualified fill() only
  // compiled when argument-dependent lookup happened to find std::fill.
  reidx.assign(states.size(), -1);
}

// Returns the reindexed number of state s, or -1 if s has no entry in the reindexing.
// With sanity checks enabled, also returns -1 (after logging) if no reindexing exists
// yet or s is out of range.
int
MarkovChain::getReIdxStateVal(int s)
{
#ifdef USE_SANITY_CHECKS
  if (reidx.empty())
    {
      errorlog << "MarkovChain::getReIdxStateVal: have not reindexed!" << ende;
      return -1;
    }
  const bool in_range = (s >= 0) && (s < (signed)reidx.size());
  if (!in_range)
    {
      errorlog << "MarkovChain::getReIdxStateVal: state out of range "
	       << s << ", max=" << reidx.size() << ende;
      return -1;
    }
#endif
  const int mapped = reidx[s];
  return mapped;
}

// Writes one link per outgoing transition of the states in [start_state, end_state).
// Each link is written reversed: (next state, state).
// With use_reidx, states with no reindex entry are skipped and both endpoints are written
// as reindexed numbers. Without it, raw state numbers are written. Previously reidx was
// indexed even when use_reidx was false, which made the flag ineffective and read an empty
// table out of bounds when no reindexing had been done.
void
MarkovChain::writeTransitions(LibSeaGraphWriter& writer, bool use_reidx, int start_state, int end_state)
{
  for (int s = start_state; s < end_state; s++)
    {
      if (use_reidx && reidx[s] == -1)
	continue;
      const int from = use_reidx ? reidx[s] : s;
      for (TranStorage::iterator tran_iter = states[s].begin();
	   tran_iter != states[s].end();
	   ++tran_iter)
	{
	  const int to = use_reidx ? reidx[tran_iter->first] : tran_iter->first;
	  // every link is reversed here
	  writer.addLink(to, from);
	}
    }
}

  
// Returns the position of the outgoing transition of `state` that `scorer` rates highest.
// The position counts from begin() in storage order, the same index getTran() accepts.
// Ties keep the earliest transition, since only a strictly greater score replaces the
// current best.
// Returns -1 if the state has no transitions or no score exceeds -1e11. With sanity
// checks enabled, also returns -1 (after logging) for an out-of-range state or a NULL
// scorer.
int
MarkovChain::findMaxTransition(int state, TranScorer* scorer)
{
#ifdef USE_SANITY_CHECKS
  if (state < 0 || state >= (signed)states.size())
    {
      errorlog << "MarkovChain::findMaxTransition: state out of range "
	       << state << ", max=" << states.size() << ende;
      return -1;
    }
  if (scorer == NULL)
    {
      errorlog << "MarkovChain::findMaxTransition: null scorer" << ende;
      return -1;
    }
#endif
  int best_tran = -1;
  int tran_idx = 0;
  double best_score = -(1e11);
  for (TranStorage::iterator tran_iter = states[state].begin();
       tran_iter != states[state].end();
       ++tran_iter, ++tran_idx)
    {
      const double score = scorer->score(*this, state, tran_iter->first, tran_iter->second);
      if (score > best_score)
	{
	  best_score = score;
	  best_tran = tran_idx;
	}
    }
  return best_tran;
}


std::ostream&
operator << (std::ostream &os, const MarkovChain& c)
{
  os << "# This is a MarkovChain description" << std::endl;
  os << "# First line is the number of states" << std::endl;
  os << "# Then, one line per state with <num_pairs> then <nextstate> <weight> pairs" << std::endl;
  os << c.states.size() << std::endl;
  for (MarkovChain::StateStorage::const_iterator state_iter = c.states.begin();
       state_iter != c.states.end();
       state_iter++)
    {
      os << state_iter->size();
      for (MarkovChain::TranStorage::const_iterator tran_iter = state_iter->begin();
	   tran_iter != state_iter->end();
	   tran_iter++)
	{
	  os << ' ' << tran_iter->first << ' ' << tran_iter->second;
	}
      os << std::endl;
    }
  return os;
}

// Reads a chain in the format written by operator<<, replacing c's contents.
// Sets failbit on malformed input, including:
//  - a negative state count;
//  - a transition that addTransition rejects.
std::istream&
operator >> (std::istream &is, MarkovChain& c)
{
  if (!skip_to_non_comment(is))
    {
      is.setstate(std::ios::failbit);
      return is;
    }
  int size;
  is >> size;
  if (is.fail())
    return is;
  if (size < 0)
    {
      // a negative count would wrap to a huge size_t inside setNumStates/resize
      is.setstate(std::ios::failbit);
      return is;
    }
  c.clear();
  c.setNumStates(size);
  for (int state = 0; state < size; state++)
    {
      int numtran;
      is >> numtran;
      if (is.fail())
	return is;
      for (int tran = 0; tran < numtran; tran++)
	{
	  int nextstate;
	  double weight;
	  is >> nextstate >> weight;
	  if (is.fail())
	    return is;
	  if (!c.addTransition(state, nextstate, weight))
	    {
	      is.setstate(std::ios::failbit);
	      return is;
	    }
	}
    }
  return is;
}

/**************************************************************************************/
/**************************************************************************************/
/**************************************************************************************/
// Scores a transition by the value of its destination state in the referenced value
// table. `chain`, `state` and `prob` only appear in the debug log line.
double
MarkovChainValue::NextValTranScorer::score(const MarkovChain& chain, int state,
					   int nextstate, double prob)
{
  actionlog(210) << "NextValTranScorer: " << state << " " << nextstate << " " << prob
		 << "\t" << value.getValue(state) << " " << value.getValue(nextstate)
		 << ende;
  const double next_state_value = value.getValue(nextstate);
  return next_state_value;
}

/**************************************************************************************/

// Creates a value table for chain p, which may be NULL: one value per state of p, all
// 0.0, and no rewards.
// The chain is only referenced; ownership stays with the caller.
MarkovChainValue::MarkovChainValue(MarkovChain* p, double discount_factor)
  : pchain(p), discount_factor(discount_factor), rewards(),
    values((p == NULL) ? 0 : p->getNumStates(), 0.0)
{
}

// Returns the current value of `state`. With sanity checks enabled, an out-of-range
// state is logged and 0.0 is returned.
double
MarkovChainValue::getValue(int state)
{
#ifdef USE_SANITY_CHECKS
  const bool valid = (state >= 0) && (state < (signed)values.size());
  if (!valid)
    {
      errorlog << "MarkovChainValue::getValue: state out of range "
	       << state << ", max=" << values.size() << ende;
      return 0.0;
    }
#endif
  const double v = values[state];
  return v;
}

// Sets the reward for `state`, replacing any reward previously set for it.
// The state number is not range-checked here.
void
MarkovChainValue::addReward(int state, double reward)
{
  actionlog(50) << "MarkovChainValue: setting reward of state " << state
		<< " to " << reward << ende;
  double& slot = rewards[state];
  slot = reward;
}


// Attaches chain p (not owned). If no values exist yet, the table is sized to p's state
// count. Otherwise a size mismatch is only logged and the existing values are kept.
// Passing NULL simply detaches the chain. It used to be dereferenced, which crashed
// operator>> when no chain was attached.
void
MarkovChainValue::setChain(MarkovChain* p)
{
  pchain = p;
  if (p == NULL)
    return;

  if (values.empty())
    {
      setNumStates(p->getNumStates());
    }
  else if ((signed)values.size() != p->getNumStates())
    {
      errorlog << "MarkovChainValue::setChain: size not equal "
	       << values.size() << " " << p->getNumStates() << ende;
    }
}

// Detaches the chain and discards all rewards and values; discount_factor is kept.
void
MarkovChainValue::clear()
{
  values.clear();
  rewards.clear();
  pchain = NULL;
}

// Resizes the value table to s entries. New entries start at 0.0; existing entries keep
// their values.
void
MarkovChainValue::setNumStates(int s)
{
  const double initial_value = 0.0;
  values.resize(s, initial_value);
}

// Sets each state's value to minus its distance, in transitions, to target_state.
// Transition probabilities are ignored. target_state gets 0; states that cannot reach
// it keep -1e10.
// Returns false (after logging) when sanity checks find any of:
//  - no chain;
//  - a size mismatch between the chain and the value table;
//  - an out-of-range target_state.
// Returns true otherwise.
bool
MarkovChainValue::setToDistance(int target_state)
{
#ifdef USE_SANITY_CHECKS
  if (pchain == NULL)
    {
      errorlog << "setToDistance: null chain" << ende;
      return false;
    }
  if (pchain->states.size() != values.size())
    {
      errorlog << "setToDistance: size mismatch "
	       << pchain->states.size() << " " << values.size() << ende;
      return false;
    }
  if (target_state < 0 || target_state >= (signed)values.size())
    {
      errorlog << "setToDistance: target state out of range "
	       << target_state << ", max=" << values.size() << ende;
      return false;
    }
#endif

  std::fill(values.begin(), values.end(), -1e10); // just a small value
  values[target_state] = 0;

  // Relax every transition until nothing changes: a state is one step worse than the
  // best state it can move to.
  int num_states_changed = 0;
  do
    {
      num_states_changed = 0;
      for (int state = 0; state < (signed)pchain->states.size(); state++)
	{
	  for (MarkovChain::TranStorage::iterator tran_iter = pchain->states[state].begin();
	       tran_iter != pchain->states[state].end();
	       ++tran_iter)
	    {
	      if (values[state] < values[tran_iter->first] - 1)
		{
		  num_states_changed++;
		  values[state] = values[tran_iter->first] - 1;
		}
	    }
	}
    }
  while (num_states_changed > 0);
  return true;
}

  

// One sweep of value iteration, done in place from the highest-numbered state down to 0:
//   V(s) = reward(s) + discount_factor * sum over transitions s->s' of weight * V(s')
// Values are overwritten during the sweep, so a state already sees this sweep's values
// for higher-numbered states. Weights are used as stored; they are probabilities only if
// MarkovChain::normalize() has been run.
//
// progress_interval:  if > 0, a '.' is printed to std::cout for every state index
//                     divisible by it.
// ptotal_change:      if non-NULL, receives the sum of |new - old| over all states.
// pper_capita_change: if non-NULL, receives that sum divided by the number of states.
//                     It is 0 for an empty table, instead of 0/0 = NaN.
// Returns false (after logging) if there is no chain or the sizes mismatch.
bool
MarkovChainValue::valueIterateInPlace(int progress_interval,
				      double* ptotal_change, double* pper_capita_change)
{
  double delta = 0;

#ifdef USE_SANITY_CHECKS
  if (pchain == NULL)
    {
      errorlog << "valueIterateInPlace: null chain" << ende;
      return false;
    }
  if (pchain->states.size() != values.size())
    {
      errorlog << "valueIterateInPlace: size mismatch "
	       << pchain->states.size() << " " << values.size() << ende;
      return false;
    }
#endif

  for (int state = values.size() - 1; state >= 0; state--)
    {
      double new_state_val = 0;
      for (MarkovChain::TranStorage::const_iterator tran_iter = pchain->states[state].begin();
	   tran_iter != pchain->states[state].end();
	   ++tran_iter)
	{
	  // add in probability * value
	  new_state_val += tran_iter->second * values[tran_iter->first];
	}
      // decay the future rewards
      new_state_val *= discount_factor;
      // add in the reward for this state
      RewardStorage::const_iterator rew_iter = rewards.find(state);
      if (rew_iter != rewards.end())
	new_state_val += rew_iter->second;

      delta += fabs(new_state_val - values[state]);
      values[state] = new_state_val;

      if (progress_interval > 0 && state % progress_interval == 0)
	std::cout << '.' << std::flush;
    }

  if (ptotal_change)
    *ptotal_change = delta;
  if (pper_capita_change)
    *pper_capita_change = values.empty() ? 0.0 : delta / values.size();

  return true;
}


// Runs valueIterateInPlace repeatedly until one of the stopping criteria is met.
//   progress_interval:       passed through to valueIterateInPlace; if > 0, per-iteration
//                            progress is also printed.
//   total_change_limit:      stop once an iteration's total change is below this.
//   per_capita_change_limit: stop once an iteration's per-state change is below this.
//   iteration_limit:         maximum number of iterations; negative means unlimited.
//   p_early_term:            optional flag, checked after each iteration; stop if true.
// Returns false (after logging) if there is no chain, the chain is not minimally
// connected, or an iteration fails. Returns true otherwise.
bool
MarkovChainValue::learnValues(int progress_interval,
			      double total_change_limit,
			      double per_capita_change_limit,
			      int iteration_limit,
			      bool* p_early_term)
{
  // pchain is dereferenced immediately below, so check it regardless of sanity-check mode
  if (pchain == NULL)
    {
      errorlog << "learnValues: null chain" << ende;
      return false;
    }
  if (!pchain->checkMinimalConnectivity(true))
    {
      errorlog << "learnValues: chain is not minimally connected" << ende;
      return false;
    }

  double total_change;
  double per_capita_change;

  for (int iter = 0;
       iteration_limit < 0 || iter < iteration_limit;
       iter++)
    {
      if (progress_interval > 0)
	std::cout << "\tIteration " << iter << ": " << std::flush;
      if (!valueIterateInPlace(progress_interval, &total_change, &per_capita_change))
	{
	  errorlog << "learnValues: iteration " << iter << " failed" << ende;
	  return false;
	}
      if (progress_interval > 0)
	std::cout << "\t" << total_change << " " << per_capita_change << std::endl;

      actionlog(60) << "learnValues: iteration " << iter << " has changes "
		    << total_change << " " << per_capita_change << ende;

      if (total_change < total_change_limit)
	{
	  std::cout << "Stopping for total change limit: " << total_change_limit << std::endl;
	  actionlog(50) << "Stopping at iteration " << iter
			<< " because of total change limit " << total_change_limit
			<< ende;
	  break;
	}

      if (per_capita_change < per_capita_change_limit)
	{
	  std::cout << "Stopping for per capitat change limit: " << per_capita_change_limit << std::endl;
	  actionlog(50) << "Stopping at iteration " << iter
			<< " because of per capita change limit " << per_capita_change_limit
			<< ende;
	  break;
	}

      if (p_early_term && *p_early_term)
	{
	  std::cout << "Got an early termination request" << std::endl;
	  actionlog(50) << "Stopping at iteration " << iter
			<< " because got a term request " << ende;
	  break;
	}
    }
  return true;
}

// Writes this value table and its chain to `writer` as a LibSea graph named `name`.
// Nodes are the states with outgoing transitions, numbered through the chain's
// reindexing. Links are all transitions except those leaving root_state, written
// reversed by MarkovChain::writeTransitions.
// Attributes:
//  - $root marks the root node;
//  - $max_tran marks, per state, the link p_tree's NextValTranScorer rates best;
//  - value-color, value and true-state-number node attributes.
// p_tree must be non-NULL. On invalid input an error is logged and nothing is written.
void
MarkovChainValue::createGraph(LibSeaGraphWriter& writer, const char* name, int root_state,
			      MarkovChainValue* p_tree)
{
#ifdef USE_SANITY_CHECKS
  if (pchain == NULL)
    {
      errorlog << "createGraph: null chain" << ende;
      return;
    }
  if (pchain->states.size() != values.size())
    {
      errorlog << "createGraph: size mismatch "
	       << pchain->states.size() << " " << values.size() << ende;
      return;
    }
  if (root_state < 0 || root_state >= (signed)values.size())
    {
      errorlog << "createGraph: root state out of range "
	       << root_state << ", max=" << values.size() << ende;
      return;
    }
#endif
  // p_tree is dereferenced below to score transitions; refuse before writing anything
  if (p_tree == NULL)
    {
      errorlog << "createGraph: null p_tree" << ende;
      return;
    }

  // Finds the states which actually have transitions
  pchain->clearReIdx();
  pchain->identifyLiveStates();

  int root_attr_idx;
  int max_tran_idx;

  writer.startGraph();

  writer.writeMetaData(name,
		       "Created by MarkovChainValue::createGraph",
		       pchain->getReIdxNumStates(),
		       pchain->getNumTransitions() - pchain->getNumTransitionsForState(root_state),
		       0, 0);
  //// Structural Data
  writer.startStructuralData();

  writer.startLinks();
  // we do this to avoid writing transitions from the root state
  pchain->writeTransitions(writer, true, 0, root_state);
  pchain->writeTransitions(writer, true, root_state + 1, values.size());
  writer.endLinks();

  writer.startPaths();
  writer.endPaths();

  writer.endStructuralData();

  //// Attribute data
  writer.startAttributeData();
  writer.writeEnumerations();
  writer.startAttributeDefs();

  root_attr_idx = writer.startAttribute("$root", "bool", "|| false ||");
  writer.startNodeValues();
  // Nodes are numbered by their reindexed value (as in the other node attributes),
  // so the root must be marked by its reindexed number, not its raw state number.
  const int root_node = pchain->getReIdxStateVal(root_state);
  if (root_node >= 0)
    {
      writer.addAttrValue(root_node, "T");
    }
  else
    {
      errorlog << "createGraph: root state " << root_state
	       << " has no transitions and is not a graph node" << ende;
    }
  writer.endNodeValues();
  writer.startLinkValues();
  writer.endLinkValues();
  writer.startPathValues();
  writer.endPathValues();
  writer.endAttribute();

  max_tran_idx = writer.startAttribute("$max_tran", "bool", "|| false ||");
  writer.startNodeValues();
  writer.endNodeValues();
  writer.startLinkValues();
  NextValTranScorer max_scorer(*p_tree);
  p_tree->createAttrValsForMaxTran(writer, root_state, &max_scorer);
  writer.endLinkValues();
  writer.startPathValues();
  writer.endPathValues();
  writer.endAttribute();

  ValueColorAttributeCreator value_color_creator(this);
  value_color_creator.create(writer, root_state);
  ValueAttributeCreator value_creator(this);
  value_creator.create(writer, root_state);
  TrueStateNumAttributeCreator true_state_creator(this);
  true_state_creator.create(writer, root_state);

  writer.endAttributeDefs();

  writer.startQualifiers();
  writer.writeSpanningTree("short_dist_to_goal_spanning_tree",
			   "The links are the shortest distance to the goal from each node",
			   root_attr_idx, max_tran_idx);
  writer.endQualifiers();

  writer.endAttributeData();

  writer.writeVisualizationHints();
  writer.writeInterfaceHints();

  writer.endGraph();
}

// For every state except root_state, writes a "T" link value on the link of that
// state's best-scoring outgoing transition.
// Link numbers follow the order createGraph writes links: states in increasing order
// with root_state skipped, and each state's transitions in storage order. A state's
// first link number is therefore the running total of transitions of the earlier
// non-root states.
void
MarkovChainValue::createAttrValsForMaxTran(LibSeaGraphWriter& writer,
					   int root_state,
					   MarkovChain::TranScorer* scorer)
{
  int tran_idx = 0;
  for (int state = 0; state < (signed)values.size(); state++)
    {
      if (state == root_state)
	continue;

      // Advance the running link offset for every non-root state before any early exit.
      // A state with transitions but no winner (e.g. every score NaN) must not shift the
      // link numbers of all later states.
      const int first_link = tran_idx;
      tran_idx += pchain->getNumTransitionsForState(state);

      const int best_rel_idx = pchain->findMaxTransition(state, scorer);
      if (best_rel_idx < 0)
	continue;
#define TEMP_TEST_CODE
#ifdef TEMP_TEST_CODE
      int nextstate = pchain->getTranNextState(state, best_rel_idx);
      if (values[state] >= values[nextstate])
	{
	  errorlog << "createAttrValsForMaxTran: did not increase! "
		   << state << " " << best_rel_idx << " " << nextstate
		   << "\t" << values[state] << " " << values[nextstate]
		   << ende;
	}
#endif
      writer.addAttrValue(first_link + best_rel_idx, "T");
    }
}


std::ostream&
operator << (std::ostream &os, const MarkovChainValue& c)
{
  os << "# This represents a MarkovChainValue object (which is related to a MarkovChain object)" << std::endl;
  os << "# Format: " << std::endl;
  os << "# First line: <num states> <discount factor>" << std::endl;
  os << "# Second line (rewards): <num reward states> <state> <rewards> <state> <reward> ..." << std::endl;
  os << "# All other lines: <value>" << std::endl;

  os << c.values.size() << " " << c.discount_factor << std::endl;

  os << c.rewards.size();
  for (MarkovChainValue::RewardStorage::const_iterator iter = c.rewards.begin();
       iter != c.rewards.end();
       iter++)
    {
      os << " " << iter->first << " " << iter->second;
    }
  os << std::endl;

  std::copy(c.values.begin(), c.values.end(), std::ostream_iterator<double>(os, "\n"));
  return os;
}

// Reads a value table in the format written by operator<<, replacing c's rewards and
// values. The currently attached chain, if any, is kept.
// Sets failbit on malformed input, including a negative state count.
std::istream&
operator >> (std::istream &is, MarkovChainValue& c)
{
  if (!skip_to_non_comment(is))
    {
      is.setstate(std::ios::failbit);
      return is;
    }
  int numstates;
  is >> numstates >> c.discount_factor;
  if (is.fail())
    return is;
  if (numstates < 0)
    {
      // a negative count would wrap to a huge size_t inside setNumStates/resize
      is.setstate(std::ios::failbit);
      return is;
    }
  MarkovChain* p = c.getChain();
  c.clear();
  // clear() detached the chain; reattach only a real one. Passing NULL to setChain used
  // to dereference it while sizing the value table.
  if (p != NULL)
    c.setChain(p);

  int numrew;
  is >> numrew;
  if (is.fail())
    return is;
  for (int i = 0; i < numrew; i++)
    {
      int state;
      double reward;
      is >> state >> reward;
      if (is.fail())
	return is;
      c.addReward(state, reward);
    }

  c.setNumStates(numstates);
  for (int state = 0; state < numstates; state++)
    {
      double val;
      is >> val;
      if (is.fail())
	return is;
      c.values[state] = val;
    }
  return is;
}

/**************************************************************************/

// Writes one attribute, using this creator's name/type/default, and returns the index
// from startAttribute().
// Node values: handleState() is called for every state that has a reindexed node
// number.
// Link values: handleLink() is called for every state except root_state, together with
// the number of that state's first link (running total of earlier transitions).
int
MarkovChainValue::AttributeCreator::create(LibSeaGraphWriter& writer, int root_state)
{
  const int attr_idx = writer.startAttribute(name.c_str(), type.c_str(), def.c_str());
  MarkovChain* const chain = pvalue->getChain();

  writer.startNodeValues();
  for (int state = 0; state < pvalue->size(); state++)
    {
      if (chain->getReIdxStateVal(state) >= 0)
	handleState(writer, state);
    }
  writer.endNodeValues();

  writer.startLinkValues();
  int first_link = 0;
  for (int state = 0; state < pvalue->size(); state++)
    {
      if (state != root_state)
	{
	  handleLink(writer, state, first_link);
	  first_link += chain->getNumTransitionsForState(state);
	}
    }
  writer.endLinkValues();

  writer.startPathValues();
  writer.endPathValues();
  writer.endAttribute();

  return attr_idx;
}

// Node color for `state`. Red and blue are fixed at 255; green is scaled linearly from
// the state's value, mapping [-100, 100] onto [0, 255].
// NOTE(review): the +-100 range is a hard-coded guess at the value scale. MinMax is
// presumably a clamp to [0, 255]; confirm against its definition.
void
MarkovChainValue::ValueColorAttributeCreator::handleState(LibSeaGraphWriter& writer,
							  int state)
{
  const double state_value = getValueObject()->getValue(state);
  const int scaled_green = (int)rint(255 * ((state_value + 100)/200));

  const unsigned char r = 255;
  const unsigned char g = MinMax(0, scaled_green, 255);
  const unsigned char b = 255;

  actionlog(210) << "State " << state << " has value color "
		 << (unsigned)r << ' ' << (unsigned)g << ' ' << (unsigned)b
		 << ende;

  const int node = getValueObject()->getChain()->getReIdxStateVal(state);
  writer.addAttrValue(node, LibSeaGraphWriter::encodeRGB(r, g, b));
}


// Node attribute: the state's value, as a float, keyed by its reindexed node number.
void
MarkovChainValue::ValueAttributeCreator::handleState(LibSeaGraphWriter& writer,
						     int state)
{
  const int node = getValueObject()->getChain()->getReIdxStateVal(state);
  const float state_value = static_cast<float>(getValueObject()->getValue(state));
  writer.addAttrValue(node, state_value);
}

// Node attribute: the original (non-reindexed) state number, keyed by the reindexed node
// number.
void
MarkovChainValue::TrueStateNumAttributeCreator::handleState(LibSeaGraphWriter& writer,
							    int state)
{
  const int node = getValueObject()->getChain()->getReIdxStateVal(state);
  writer.addAttrValue(node, state);
}

