/******************************** CPPFile *****************************

* FileName [HornCheckerAI.cpp]

* PackageName [main]

* Synopsis [Method definitions of HornCheckerAI class.]

* SeeAlso [HornCheckerAI.h]

* Author [Sagar Chaki]

* Copyright [ Copyright (c) 2002 by Carnegie Mellon University. All
* Rights Reserved. This software is for educational purposes only.
* Permission is given to academic institutions to use, copy, and
* modify this software and its documentation provided that this
* introductory message is not removed, that this software and its
* documentation is used for the institutions' internal research and
* educational purposes, and that no monies are exchanged. No guarantee
* is expressed or implied by the distribution of this code. Send
* bug-reports and/or questions to: chaki+@cs.cmu.edu. ]

**********************************************************************/

#include <cassert>
#include <list>
#include <set>
#include <map>
#include <vector>
using namespace std;

#include "HornCheckerAI.h"
#ifdef MAGIC_FULL
#include "SAT.h"
#endif //MAGIC_FULL

/*********************************************************************/
//methods for SimpleNode
/*********************************************************************/
const SimpleNode &SimpleNode::operator = (const SimpleNode &rhs)
{
  // Member-wise copy: reachability status, then the successor lists to
  // simple nodes and to compound (hyperarc) nodes.
  this->status = rhs.status;
  this->sNext = rhs.sNext;
  this->cNext = rhs.cNext;
  return *this;
}

/*********************************************************************/
//methods for CompoundNode
/*********************************************************************/
const CompoundNode &CompoundNode::operator = (const CompoundNode &rhs)
{
  // Member-wise copy: pending-source counter, then the successor list.
  this->status = rhs.status;
  this->sNext = rhs.sNext;
  return *this;
}

/*********************************************************************/
//static data
/*********************************************************************/
map<int,SimpleNode> HornCheckerAI::trueFalse;
map<int,SimpleNode> HornCheckerAI::simples;
map< set<int>,list<CompoundNode> > HornCheckerAI::compounds;

/*********************************************************************/
//initialise the solver
/*********************************************************************/
void HornCheckerAI::Initialise()
{
  // Drop every node and hyperarc from any previous problem.
  trueFalse.clear();
  simples.clear();
  compounds.clear();

  // Recreate the two distinguished nodes with their initial statuses.
  trueFalse[0].status = 0;
  trueFalse[1].status = 1;
}

/*********************************************************************/
//set the number of variables
/*********************************************************************/
void HornCheckerAI::SetVarNum(int num)
{
  // The variable count may only grow.
  const size_t wanted = static_cast<size_t>(num);
  assert(wanted >= simples.size());

  // Create a fresh node for every key from the current size up to num-1.
  size_t next = simples.size();
  while(next < wanted) {
    simples[next] = SimpleNode();
    ++next;
  }
}

/*********************************************************************/
//remove a clause.
/*********************************************************************/
void HornCheckerAI::RemoveClause(const set<int> &clause)
{
  pair< set<int>,set<int> > parts;
  Split(clause,parts);
  const set<int> &sources = parts.first;
  const set<int> &targets = parts.second;

  // No destination literal: the clause is an arc into the true node.
  if(targets.empty()) {
    RemoveArcToTrue(sources);
    return;
  }

  // Split guarantees at most one destination.
  const int target = *targets.begin();
  if(sources.empty()) RemoveArcFromFalse(target);
  else RemoveHyperArc(sources,target);
}

/*********************************************************************/
//add a clause.
/*********************************************************************/
void HornCheckerAI::AddClause(const set<int> &clause)
{
  pair< set<int>,set<int> > parts;
  Split(clause,parts);
  const set<int> &sources = parts.first;
  const set<int> &targets = parts.second;

  // No destination literal: the clause becomes an arc into the true node.
  if(targets.empty()) {
    AddArcToTrue(sources);
    return;
  }

  // Split guarantees at most one destination.
  const int target = *targets.begin();
  if(sources.empty()) AddArcFromFalse(target);
  else AddHyperArc(sources,target);
}

/*********************************************************************/
//split the clause into a set of positive literals and a set of
//negative literals. return the two sets in the second argument.
/*********************************************************************/
void HornCheckerAI::Split(const set<int> &clause,pair< set<int>,set<int> > &dest)
{
  // Even-coded literal x maps to variable x/2 - 1 in dest.first.
  // Odd-coded literal x maps to variable (x-1)/2 - 1 in dest.second.
  // dest is appended to, not cleared.
  set<int>::const_iterator it = clause.begin();
  while(it != clause.end()) {
    const int lit = *it;
    ++it;
    if(lit % 2 != 0) dest.second.insert((lit - 1) / 2 - 1);
    else dest.first.insert(lit / 2 - 1);
  }
  // Horn form: at most one literal may land on the destination side.
  assert(dest.second.size() <= 1);
}

/*********************************************************************/
//add a hyperarc from a source set to the true node
/*********************************************************************/
void HornCheckerAI::AddArcToTrue(const set<int> &source)
{
  list<CompoundNode> &clist = compounds[source];
  list<CompoundNode>::iterator it = clist.insert(clist.end(),CompoundNode());
  for(set<int>::const_iterator i = source.begin();i != source.end();++i) {
    simples[*i].cNext.push_back(it);
  }
  it->sNext.push_back(trueFalse.find(1));
}

/*********************************************************************/
//remove a hyperarc from a source set to the true node
/*********************************************************************/
void HornCheckerAI::RemoveArcToTrue(const set<int> &source)
{
  list<CompoundNode> &clist = compounds[source];
  for(list<CompoundNode>::iterator i = clist.begin();i != clist.end();++i) {
    if(i->sNext.front() == trueFalse.find(1)) {
      clist.erase(i);
      for(set<int>::const_iterator j = source.begin();j != source.end();++j) {
	bool found = false;
	for(list<list<CompoundNode>::iterator>::iterator k = simples[*j].cNext.begin();k != simples[*j].cNext.end();++k) {
	  if((*k) == i) {
	    simples[*j].cNext.erase(k);
	    found = true;
	    break;
	  }
	}
	assert(found);
      }
      return;
    }
  }
  assert(false);
}

/*********************************************************************/
//add an arc from the false node to a destination node
/*********************************************************************/
void HornCheckerAI::AddArcFromFalse(const int dest)
{
  trueFalse[0].sNext.push_back(simples.find(dest));
}

/*********************************************************************/
//remove the arc from the false node to a destination node
/*********************************************************************/
void HornCheckerAI::RemoveArcFromFalse(const int dest)
{
  bool found = false;
  map<int,SimpleNode>::iterator it = simples.find(dest);
  for(list<map<int,SimpleNode>::iterator>::iterator i = trueFalse[0].sNext.begin();i != trueFalse[0].sNext.end();++i) {
    if((*i) == it) {
      trueFalse[0].sNext.erase(i);
      found = true;
      break;
    }
  }
  assert(found);
}

/*********************************************************************/
//add a hyperarc from a source set to a destination node
/*********************************************************************/
void HornCheckerAI::AddHyperArc(const set<int> &source,const int dest)
{
  if(source.size() == 1) {
    int x = *(source.begin());
    simples[x].sNext.push_back(simples.find(dest));
  } else {
    list<CompoundNode> &clist = compounds[source];
    list<CompoundNode>::iterator it = clist.insert(clist.end(),CompoundNode());
    for(set<int>::const_iterator i = source.begin();i != source.end();++i) {
      simples[*i].cNext.push_back(it);
    }    
    it->sNext.push_back(simples.find(dest));
  }
}

/*********************************************************************/
//remove a hyperarc from a source set to a destination node
/*********************************************************************/
void HornCheckerAI::RemoveHyperArc(const set<int> &source,const int dest)
{
  if(source.size() == 1) {
    int x = *(source.begin());
    bool found = false;
    map<int,SimpleNode>::iterator it = simples.find(dest);
    for(list<map<int,SimpleNode>::iterator>::iterator i = simples[x].sNext.begin();i != simples[x].sNext.end();++i) {
      if((*i) == it) {
	simples[x].sNext.erase(i);
	found = true;
	break;
      }
    }
    assert(found);
  } else {
    list<CompoundNode> &clist = compounds[source];
    for(list<CompoundNode>::iterator i = clist.begin();i != clist.end();++i) {
      if(i->sNext.front() == simples.find(dest)) {
	clist.erase(i);
	for(set<int>::const_iterator j = source.begin();j != source.end();++j) {
	  bool found = false;
	  for(list<list<CompoundNode>::iterator>::iterator k = simples[*j].cNext.begin();k != simples[*j].cNext.end();++k) {
	    if((*k) == i) {
	      simples[*j].cNext.erase(k);
	      found = true;
	      break;
	    }
	  }
	  assert(found);
	}
	return;
      }
    }
    assert(false);
  }
}

/*********************************************************************/
//compute closure of a simple node to propagate reachability
//information
/*********************************************************************/
void HornCheckerAI::Closure(map<int,SimpleNode>::iterator &node,vector<int> &elim)
{
  if(node->second.status != 0) {
    --node->second.status;
    if(node->second.status == 0) {
      if(node != trueFalse.find(1)) {
	int index = node->first;
	elim.push_back(index);
      }

      for(list<map<int,SimpleNode>::iterator>::iterator i = node->second.sNext.begin();i != node->second.sNext.end();++i) {
	Closure(*i,elim);
      }
      for(list<list<CompoundNode>::iterator>::iterator i = node->second.cNext.begin();i != node->second.cNext.end();++i) {
	Closure(**i,elim);
      }
    }
  }
}

/*********************************************************************/
//compute closure of a compound node to propagate reachability
//information
/*********************************************************************/
void HornCheckerAI::Closure(CompoundNode &node,vector<int> &elim)
{
  // status counts sources not yet reached. Nothing to do if already zero.
  if(node.status == 0) return;

  // One more source reached. Fire only when this was the last one.
  if(--node.status != 0) return;

  list<map<int,SimpleNode>::iterator>::iterator succ = node.sNext.begin();
  for(;succ != node.sNext.end();++succ) Closure(*succ,elim);
}

/*********************************************************************/
//solve for satisfiability. return true if the formula is satisfiable
//and false otherwise. the second argument is used to store the
//sequence of variables which have been thrown out of the maximal
//simulation relation during the check for satisfiability.
/*********************************************************************/
bool HornCheckerAI::SolveSat(vector<int> &elim)
{
  //set the status values
  for(map< set<int>,list<CompoundNode> >::iterator i = compounds.begin();i != compounds.end();++i) {
    for(list<CompoundNode>::iterator j = i->second.begin();j != i->second.end();++j) {
      j->status = 0;
    }
  }
  for(map<int,SimpleNode>::iterator i = simples.begin();i != simples.end();++i) {
    i->second.status = 1;
    for(list<list<CompoundNode>::iterator>::iterator j = i->second.cNext.begin();j != i->second.cNext.end();++j) {
      ++(*j)->status;
    }
  }
  trueFalse[1].status = 1;

  //do reachability from false node
  for(list<map<int,SimpleNode>::iterator>::iterator i = trueFalse[0].sNext.begin();i != trueFalse[0].sNext.end();++i) {
    Closure(*i,elim);
  }
  for(list<list<CompoundNode>::iterator>::iterator i = trueFalse[0].cNext.begin();i != trueFalse[0].cNext.end();++i) {
    Closure(**i,elim);
  }
  return trueFalse[1].status;
}

/*********************************************************************/
//cleanup
/*********************************************************************/
void HornCheckerAI::Cleanup()
{
  simples.clear();
  compounds.clear();
}

/*********************************************************************/
//end of HornCheckerAI.cpp
/*********************************************************************/
