/* -*- Mode: C++ -*- */

#include "QTableFlex.h"

#include <limits>

#include "MDP.h"
#include "Logger.h"

using namespace spades;

// Static member: magic string that writeTo() emits first in the binary
// format and that readFrom() checks before reading anything else.
const char* QTableFlex::BIN_FILE_MAGIC("PFRQT");
// Static member: binary format version byte; readFrom() refuses any file
// whose version byte differs from this value.
const int QTableFlex::BIN_FILE_VERSION(0);

// Creates a table with num_states state slots, each holding a
// default-constructed (empty) action list; gamma is passed on to the
// QTable base.
QTableFlex::QTableFlex(int num_states, float gamma)
  : QTable(num_states, gamma),
    qtable(num_states)
{
  // all set up by the initializer list
}

// Creates a table shaped after mdp. The base is told the MDP's state count,
// then setSizeFrom() sizes every state's action list to match mdp and zeroes
// the enabled entries.
QTableFlex::QTableFlex(const MDP& mdp, float gamma)
  : QTable(mdp.getNumStates(), gamma)
{
  // qtable starts empty; setSizeFrom does the actual shaping
  setSizeFrom(mdp);
}

// No explicit cleanup: members are destroyed automatically.
QTableFlex::~QTableFlex() {}

// Reshapes the table to match mdp: num_states is copied from the MDP,
// each state gets as many action slots as mdp reports for that state,
// and finally every enabled entry is zeroed.
void
QTableFlex::setSizeFrom(const MDP& mdp)
{
  num_states = mdp.getNumStates();
  qtable.resize(num_states);

  const int state_count = static_cast<int>(qtable.size());
  for (int s = 0; s < state_count; ++s)
    qtable[s].resize(mdp.getNumActionsInState(s));

  zero();
}


//zeros all ENABLED actions
void
QTableFlex::zero()
{
  for (StateStorage::iterator siter = qtable.begin();
       siter != qtable.end();
       siter++)
    {
      for (ActionStorage::iterator aiter = siter->begin();
	   aiter != siter->end();
	   aiter++)
	{
	  if (!aiter->isEnabled())
	    continue;
	  aiter->zero();
	}
    }
}

  
// Returns the number of action slots stored for state, or -1 when state
// is not a valid index (only the state is checked, hence action -1).
int
QTableFlex::getNumActions(int state) const
{
  return checkValid(state, -1) ? static_cast<int>(qtable[state].size()) : -1;
}


// Returns a pointer to the entry for (state, action), or NULL if either
// index is out of range.
// action must name a concrete slot. checkValid() treats action == -1 as
// "validate the state only", so -1 (like any negative index) is rejected
// here explicitly; otherwise it would pass validation and index
// qtable[state][-1].
QTable::StateActionEntry*
QTableFlex::getSA(int state, int action)
{
  if (action < 0 || !checkValid(state, action))
    return NULL;
  return &(qtable[state][action]);
}

// Const overload: returns a pointer to the entry for (state, action), or
// NULL if either index is out of range.
// As in the non-const overload, a negative action (including checkValid's
// -1 "state only" wildcard) is rejected explicitly so it can never index
// qtable[state][-1].
const QTable::StateActionEntry*
QTableFlex::getSA(int state, int action) const
{
  if (action < 0 || !checkValid(state, action))
    return NULL;
  return &(qtable[state][action]);
}

  
// Returns true if state (and, unless action == -1, also action) indexes an
// existing slot of the table.
// action == -1 means "check the state only"; any other negative action is
// invalid.
// The state is bounded by both num_states and qtable.size(). The readers
// assign num_states before resizing qtable, so the two can disagree after
// an aborted read, and qtable[state] must never be evaluated out of range.
bool
QTableFlex::checkValid(int state, int action) const
{
  if (state < 0 || state >= num_states ||
      state >= static_cast<int>(qtable.size()))
    return false;
  if (action == -1)
    return true;
  return action >= 0 && action < static_cast<int>(qtable[state].size());
}

std::ostream&
operator<<(std::ostream& os, const QTableFlex& qt)
{
  os << qt.num_states << " " << qt.gamma << std::endl;
  
  for (QTableFlex::StateStorage::const_iterator siter = qt.qtable.begin();
       siter != qt.qtable.end();
       siter++)
    {
      os << siter->size() << ' ';
      std::copy(siter->begin(), siter->end(),
		std::ostream_iterator<QTableFlex::StateActionEntry>(os, " "));
      os << std::endl;
    }
  return os;
}

// Parses the text form written by operator<<. The layout is
// "<num_states> <gamma>", then for each state an action count followed by
// that many entries. The table is cleared first.
// On malformed input the stream's failbit is left set and qt may hold a
// partially read table.
// Negative counts are treated as corrupt input and set failbit; passing
// them to resize() would convert them to a huge size. num_states is only
// stored once it and gamma have been read successfully, so the table never
// claims more states than qtable holds.
std::istream&
operator>>(std::istream& is, QTableFlex& qt)
{
  qt.clear();

  int num_states;
  is >> num_states;
  if (is.fail())
    return is;
  if (num_states < 0)
    {
      is.setstate(std::ios_base::failbit);
      return is;
    }

  is >> qt.gamma;
  if (is.fail())
    return is;

  qt.num_states = num_states;
  qt.qtable.resize(num_states);

  for (int s = 0; s < num_states; s++)
    {
      int num_act;
      is >> num_act;
      if (is.fail())
	return is;
      if (num_act < 0)
	{
	  is.setstate(std::ios_base::failbit);
	  return is;
	}
      for (int a = 0; a < num_act; a++)
	{
	  QTableFlex::StateActionEntry entry;
	  is >> entry;
	  if (is.fail())
	    return is;
	  qt.qtable[s].push_back(entry);
	}
    }

  return is;
}

// Writes the binary format understood by readFrom(), in this order:
//   - the magic header
//   - one version byte
//   - num_states (int)
//   - gamma (float)
//   - for each state, its action count (short) followed by each entry's
//     own binary form.
// The table is validated before any byte is written, so a table that cannot
// be represented produces no partial output. Returns false in three cases:
//   - qtable does not hold exactly num_states states (readFrom would parse
//     the wrong number of records);
//   - some state has more actions than a short can hold (the count would be
//     truncated and the file unreadable);
//   - any write fails.
bool
QTableFlex::writeTo(BinaryFileWriter& writer) const
{
  const ActionStorage::size_type max_actions =
    std::numeric_limits<short>::max();

  if (static_cast<int>(qtable.size()) != num_states)
    {
      warninglog(10) << "QTableFlex::writeTo: table holds "
		     << static_cast<int>(qtable.size())
		     << " states but num_states=" << num_states
		     << ende;
      return false;
    }
  for (StateStorage::const_iterator siter = qtable.begin();
       siter != qtable.end();
       ++siter)
    {
      if (siter->size() > max_actions)
	{
	  warninglog(10) << "QTableFlex::writeTo: state "
			 << static_cast<int>(siter - qtable.begin())
			 << " has too many actions for the binary format (max "
			 << static_cast<int>(max_actions) << ")"
			 << ende;
	  return false;
	}
    }

  if (!writer.writeMagicHeader(BIN_FILE_MAGIC)) return false;

  if (!writer.writeChar(BIN_FILE_VERSION)) return false;

  if (!writer.writeInt(num_states)) return false;
  if (!writer.writeFloat(gamma)) return false;

  for (StateStorage::const_iterator siter = qtable.begin();
       siter != qtable.end();
       ++siter)
    {
      // narrowing is safe: every count was checked against max_actions above
      if (!writer.writeShort(static_cast<short>(siter->size()))) return false;
      for (ActionStorage::const_iterator aiter = siter->begin();
	   aiter != siter->end();
	   ++aiter)
	{
	  if (!aiter->writeTo(writer)) return false;
	}
    }
  return true;
}

// Reads the binary format produced by writeTo(), in this order:
//   - the magic header
//   - one version byte
//   - num_states (int)
//   - gamma (float)
//   - for each state, a short action count followed by that many entries.
// The table is cleared first. Returns false on any read error, on a
// version mismatch, or on a negative state/action count. Negative counts
// mean corrupt data; a negative num_states passed to resize() would convert
// to a huge size.
// num_states is only stored after it has been validated.
bool
QTableFlex::readFrom(BinaryFileReader& reader)
{
  clear();

  if (!reader.checkMagicHeader(BIN_FILE_MAGIC)) return false;

  char version;
  if (!reader.readChar(&version)) return false;
  if (version != BIN_FILE_VERSION)
    {
      // log the version as a number; streaming the raw char would print a
      // (usually unprintable) character rather than its value
      warninglog(10) << "Trying to read different version QTableFlex: saw="
		     << static_cast<int>(version) << ", exp=" << BIN_FILE_VERSION
		     << ende;
      return false;
    }

  int file_num_states;
  if (!reader.readInt(&file_num_states)) return false;
  if (file_num_states < 0)
    {
      warninglog(10) << "QTableFlex::readFrom: invalid number of states: "
		     << file_num_states
		     << ende;
      return false;
    }
  if (!reader.readFloat(&gamma)) return false;

  num_states = file_num_states;
  qtable.resize(num_states);

  for (int s = 0; s < num_states; s++)
    {
      short num_act;
      if (!reader.readShort(&num_act)) return false;
      // writeTo never emits a negative count, so this is corrupt data
      if (num_act < 0) return false;
      for (int a = 0; a < num_act; a++)
	{
	  StateActionEntry entry;
	  if (!entry.readFrom(reader)) return false;
	  qtable[s].push_back(entry);
	}
    }

  return true;
}

// Loads the table from is, accepting either serialization. The binary form
// (as written by writeTo) is tried first. If that fails, the stream is
// rewound to where it started and parsed as text with operator>>.
// Returns true if either attempt succeeds. If is cannot report its position
// (tellg() == -1), there is nothing to rewind to and only the binary attempt
// is made.
bool
QTableFlex::readTextOrBinary(std::istream& is)
{
  const std::streampos init_pos = is.tellg();

  {
    BinaryFileReader reader(is);

    if (readFrom(reader))
      return true;
  }

  if (init_pos == std::streampos(-1))
    return false;

  // The failed binary attempt may have left failbit/eofbit set. seekg() on
  // a stream in a failed state performs no seek, so the flags must be
  // cleared first or the text fallback could never succeed.
  is.clear();
  is.seekg(init_pos);
  if (is.fail())
    return false;

  is >> *this;

  return !is.fail();
}

