/* -*- Mode: C++ -*- */

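/* Declares MarkovChain, a weighted Markov chain over integer-indexed states,
   and MarkovChainValue, which learns discounted per-state values over a chain. */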

#ifndef _MARKOVCHAIN_H_
#define _MARKOVCHAIN_H_

#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "LibSeaGraphWriter.h"

class MarkovChain;
class MarkovChainValue;

// A Markov chain over integer-indexed states.
//
// States are stored densely in a vector (indices 0 .. getNumStates()-1). Each
// state keeps a map from next-state index to a weight; normalize() turns the
// weights into probabilities. A separate, optional "reindexing" (built by
// identifyLiveStates(), read via the getReIdx* accessors) maps states to a
// compact numbering; writeTransitions() takes a use_reidx flag to select it.
class MarkovChain 
{
public:
  // NOTE(review): presumably creates num_states states with no transitions
  // (the constructor is defined out of line). It is not explicit, so an int
  // converts implicitly to a MarkovChain; kept for backward compatibility.
  MarkovChain(int num_states = 0);
  ~MarkovChain();

  // Removes every state, and with them every transition. Only `states` is
  // cleared: the reindexing (reidx / reidx_num_states) is left as-is, so
  // reset it separately (presumably via clearReIdx()) if that matters.
  void clear() { states.clear(); }
  
  // Number of states. The cast makes the size_type -> int narrowing explicit.
  int getNumStates() const { return static_cast<int>(states.size()); }
  void setNumStates(int new_size);

  int getNumTransitions() const;
  int getNumTransitionsForState(int state) const;

  // Accessors for the transition selected by (state, tran); the lookup is
  // done by getTran(). getTranNextState gives the target state (the map key),
  // getTranProb the stored weight (a probability once normalize() has run).
  // NOTE(review): `tran` looks like an ordinal among the state's transitions,
  // not a target state index -- confirm against getTran().
  int getTranNextState(int state, int tran) const { return getTran(state, tran)->first; }
  double getTranProb(int state, int tran) const { return getTran(state, tran)->second; }

  // Sets the weight of state -> nextstate, replacing any existing weight for
  // that transition.
  // NOTE(review): meaning of the bool result is defined by the implementation.
  bool addTransition(int state, int nextstate, double weight);
  // Adds `weight` to any existing weight of state -> nextstate.
  // Returns the previous weight, or a number < 0 on error.
  double incrementTransition(int state, int nextstate, double weight);
  // Removes the transition state -> nextstate and returns the weight it had.
  double removeTransition(int state, int nextstate);
  
  // Normalizes all weights into probabilities.
  void normalize();

  // Learns transitions from the file `infn`; the other two arguments name
  // output files (presumably per-state and per-transition counts -- confirm).
  void learnFromTransitions(const char* infn,
			    const char* new_state_count_fn,
			    const char* new_tran_count_fn);

  // Performs the simple check: every state that has an incoming transition
  // has at least one outgoing transition. If output_bad_states is true, the
  // offending states are presumably reported.
  bool checkMinimalConnectivity(bool output_bad_states);

  // Prunes transitions into states that have no outgoing transitions.
  // Returns the number of transitions removed.
  int pruneDeadTransitions();

  // Removes all states which cannot reach target_state.
  int pruneNonReverseReachable(int target_state);

  // Removes the transition first_state -> second_state if it exists.
  // NOTE(review): returns a bool (not the transition); presumably true when
  // the transition existed and was removed -- confirm in the implementation.
  bool pruneSpecificTransition(int first_state, int second_state);
  
  // Fills in a reindexing array that excludes every state with zero transitions.
  void identifyLiveStates();

  void clearReIdx();
  // Number of states in the current reindexing.
  int getReIdxNumStates() const { return reidx_num_states; }
  // Maps state `s` into the reindexing; returns -1 if `s` is not in the
  // reindexing.
  int getReIdxStateVal(int s);

  // Writes the transitions to `writer`, optionally using the reindexing.
  // NOTE(review): start/end presumably bound the range written -- confirm.
  void writeTransitions(LibSeaGraphWriter& writer, bool use_reidx, int start, int end);
  
  // Interface for ranking transitions out of a state (see findMaxTransition()).
  class TranScorer
  {
  public:
    TranScorer() {}
    virtual ~TranScorer() {}

    // Scores the transition state -> nextstate whose stored weight/probability
    // is `prob`.
    virtual double score(const MarkovChain& chain, int state,
			 int nextstate, double prob) = 0;
  };

  // Picks a transition out of `state` using `scorer` (by name, the one with
  // the maximum score).
  // NOTE(review): whether the result is a next-state index or a transition
  // ordinal, and what is returned for a state without transitions, is
  // defined by the implementation -- confirm before relying on it.
  int findMaxTransition(int state, TranScorer* scorer);

  friend std::ostream& operator << (std::ostream &os, const MarkovChain& c);
  friend std::istream& operator >> (std::istream &is, MarkovChain& c);

protected:
  friend class MarkovChainValue;

  // Maps a next state to a weight, which after normalize() represents a probability.
  typedef std::map<int, double> TranStorage;
  // One TranStorage per state, indexed by state number.
  typedef std::vector<TranStorage> StateStorage;
  typedef std::vector<int> ReIdxStorage;

  TranStorage::const_iterator getTran(int state, int tran) const;
  
  StateStorage states;
  ReIdxStorage reidx;
  int reidx_num_states;
};

/*********************************************************************************/

// Per-state values over a MarkovChain.
//
// Holds a non-owning pointer to the chain, a discount factor, per-state
// rewards (a sparse state -> reward map) and a dense vector of learned values.
// Values are computed by value iteration (valueIterateInPlace / learnValues)
// and can be exported as a graph via createGraph().
class MarkovChainValue
{
public:
  // TranScorer that scores transitions using a MarkovChainValue (by its name,
  // presumably from the value of the next state). Only a reference is stored,
  // so the referenced MarkovChainValue must outlive the scorer.
  class NextValTranScorer
    : public MarkovChain::TranScorer
  {
  public:
    NextValTranScorer(MarkovChainValue& value) : value(value) {}

    virtual double score(const MarkovChain& chain, int state,
			 int nextstate, double prob);
  private:
    MarkovChainValue& value;
  };

public:
  // Does NOT take ownership of `p`; the chain must outlive this object.
  MarkovChainValue(MarkovChain* p, double discount_factor);
  ~MarkovChainValue() {}

  MarkovChain* getChain() { return pchain; }
  void setChain(MarkovChain* p);
  
  double getDiscountFactor() const { return discount_factor; }
  void setDiscountFactor(double d) { discount_factor = d; }

  double getValue(int state);
  // Number of stored values. The cast makes the size_type -> int narrowing explicit.
  int size() const { return static_cast<int>(values.size()); }

  // Sets the reward for `state`, overwriting any reward already set for it.
  void addReward(int state, double reward);

  void clear();

  // Sets the values to -distance to target_state; probabilities are ignored.
  bool setToDistance(int target_state);
  
  // One in-place value-iteration sweep.
  // progress_interval <= 0 turns progress output off; otherwise a dot is
  // shown every progress_interval states. The total and per-capita change
  // are reported through ptotal_change / pper_capita_change.
  bool valueIterateInPlace(int progress_interval, double* ptotal_change, double* pper_capita_change);

  // Repeats value iteration until a stopping criterion is met.
  // progress_interval is passed to valueIterateInPlace. Iteration stops when
  // the total change drops below total_change_limit, the per-capita change
  // drops below per_capita_change_limit, or iteration_limit is reached.
  // If p_early_term is non-NULL, the bool it points to is checked after every
  // iteration; if it is true, learning exits early.
  bool learnValues(int progress_interval,
		   double total_change_limit, double per_capita_change_limit,
		   int iteration_limit,
		   bool* p_early_term = NULL);

  // Writes a graph named `name` rooted at root_state. p_tree is used (by
  // following the max next possible transition) to create the spanning tree.
  void createGraph(LibSeaGraphWriter& writer,
		   const char* name, int root_state,
		   MarkovChainValue* p_tree);
  
  friend std::ostream& operator << (std::ostream &os, const MarkovChainValue& c);
  friend std::istream& operator >> (std::istream &is, MarkovChainValue& c);

private:
  void setNumStates(int s);

  void createAttrValsForMaxTran(LibSeaGraphWriter& writer, int root_state, MarkovChain::TranScorer* scorer);

  // Rewards are associated with states.
  typedef std::map<int, double> RewardStorage; // state -> reward
  typedef std::vector<double> ValueStorage;    // indexed by state
  
  MarkovChain* pchain;      // non-owning
  double discount_factor;
  RewardStorage rewards;
  ValueStorage values;

private:
  // Base helper for emitting one graph attribute (name/type/default) through
  // a LibSeaGraphWriter. Subclasses override the per-state / per-link hooks;
  // the base hooks do nothing.
  class AttributeCreator
  {
  public:
    AttributeCreator(MarkovChainValue* pvalue,
		     const char* name, const char* type, const char* def)
      : pvalue(pvalue), name(name), type(type), def(def) {}
    virtual ~AttributeCreator() {}
    
    // Returns the attribute id.
    int create(LibSeaGraphWriter& writer, int root_state);
    virtual void handleState(LibSeaGraphWriter& /*writer*/, int /*state*/) {}
    // Note: `state` here is a state index!
    virtual void handleLink(LibSeaGraphWriter& /*writer*/, int /*state*/, int /*curr_global_tran_count*/) {}

    MarkovChainValue* getValueObject() { return pvalue; }
    
  private:
    MarkovChainValue *pvalue;   // non-owning
    std::string name;
    std::string type;
    std::string def;
  };
  

  // Emits the "$value_color" int attribute.
  class ValueColorAttributeCreator
    : public AttributeCreator
  {
  public:
    ValueColorAttributeCreator(MarkovChainValue* pvalue)
      : AttributeCreator(pvalue, "$value_color", "int", "") {}

    virtual void handleState(LibSeaGraphWriter& writer, int state);
  };

  // Emits the "$value" float attribute.
  class ValueAttributeCreator
    : public AttributeCreator
  {
  public:
    ValueAttributeCreator(MarkovChainValue* pvalue)
      : AttributeCreator(pvalue, "$value", "float", "") {}

    virtual void handleState(LibSeaGraphWriter& writer, int state);
  };

  // Emits the "$orig_state_num" int attribute.
  class TrueStateNumAttributeCreator
    : public AttributeCreator
  {
  public:
    TrueStateNumAttributeCreator(MarkovChainValue* pvalue)
      : AttributeCreator(pvalue, "$orig_state_num", "int", "") {}

    virtual void handleState(LibSeaGraphWriter& writer, int state);
  };
};

#endif
