/******************************** CPPFile *****************************

* FileName [SatSimChecker.cpp]

* PackageName [main]

* Synopsis [Method definitions of SatSimChecker class.]

* SeeAlso [SatSimChecker.h]

* Author [Sagar Chaki]

* Copyright [ Copyright (c) 2002 by Carnegie Mellon University. All
* Rights Reserved. This software is for educational purposes only.
* Permission is given to academic institutions to use, copy, and
* modify this software and its documentation provided that this
* introductory message is not removed, that this software and its
* documentation is used for the institutions' internal research and
* educational purposes, and that no monies are exchanged. No guarantee
* is expressed or implied by the distribution of this code. Send
* bug-reports and/or questions to: chaki+@cs.cmu.edu. ]

**********************************************************************/

#include <cstdio>
#include <cassert>
#include <unistd.h>
#include <gmp.h>
#include <string>
#include <list>
#include <set>
#include <map>
#include <vector>
using namespace std;

#include "BigInt.h"
#include "Util.h"
#include "Statistics.h"
#include "Node.h"
#include "Action.h"
#include "ProcAbs.h"
#include "Database.h"
#include "Predicate.h"
#include "PredSet.h"
#include "ImplState.h"
#include "ContLoc.h"
#include "ProcInfo.h"
#include "LtsInfo.h"
#include "HornCheckerAI.h"
#ifdef MAGIC_FULL
#include "SAT.h"
#endif //MAGIC_FULL
#include "SatSimChecker.h"
#include "SimChecker.h"
#include "GlobalAbsLts.h"
using namespace magic;

/*********************************************************************/
//static members
/*********************************************************************/
#ifdef MAGIC_FULL
//handle of the native chaff solver. it is created by SAT_InitManager()
//in CreateSatManager() and released by SAT_ReleaseManager() in
//DestroySatManager() (non-incremental mode only)
SAT_Manager SatSimChecker::manager;
//scratch file into which the clause set is written in DIMACS format
//before an external solver (SATO or GRASP) is invoked on it
const char SatSimChecker::DIMACS_FILE[] = "sat.cnf";
//scratch file that captures the standard output of the external solver
const char SatSimChecker::OUTPUT_FILE[] = "sat.out";
#endif //MAGIC_FULL
//maps a BigInt variable id (implementation state * #spec states + spec
//state id) to a dense int id, assigned in order of first appearance
map<BigInt,int> SatSimChecker::varMap;
//inverse of varMap: element d is the BigInt variable with dense id d
vector<BigInt> SatSimChecker::varVec;
//the current clause set of the formula. each clause is a set of dense
//literals: 2*(d+1) for variable d and 2*(d+1)+1 for its negation
set< set<int> > SatSimChecker::clauses;
//clauses to be merged into / removed from "clauses" by the next
//SolveSat(); both are cleared on every Create/DestroySatManager()
set< set<int> > SatSimChecker::addClauses;
set< set<int> > SatSimChecker::remClauses;
//status codes not referenced in this file.
//NOTE(review): presumably the true/false/unknown assignment of a SAT
//variable -- confirm at the use sites
const int SatSimChecker::ASSIGNED_TRUE = 1500;
const int SatSimChecker::ASSIGNED_FALSE = 1510;
const int SatSimChecker::ASSIGNED_UNKNOWN = 1520;

/*********************************************************************/
//check simulation by SAT method. if the simulation relation exists
//return true and also compute the simulation relation from a
//satisfying assignment.
/*********************************************************************/
bool SatSimChecker::CheckSimulation()
{
  //we can construct the simulation relation only if we use HORNSAT
  assert(!Database::SIM_REL_FROM_SAT || (Database::SAT_SOLVER_USED == Database::SAT_SOLVER_HORN_AI));
  //create the sat manager
  CreateSatManager();
  //add the constraints. the calls are made outside assert() so that they
  //still run when NDEBUG is defined; only the checks are compiled out
  const bool initAdded = AddInitialConstraints();
  assert(initAdded);
  const bool nonInitAdded = AddNonInitialConstraints();
  assert(nonInitAdded);
  static_cast<void>(initAdded);
  static_cast<void>(nonInitAdded);
  //the size is converted so that it matches the %d specifier
  Util::Message(2,"number of HORNSAT variables = %d ...\n",static_cast<int>(varVec.size()));
  Statistics::stateNum = (Statistics::stateNum < varVec.size()) ? varVec.size() : Statistics::stateNum;
  bool satResult = SolveSat();
  //destroy the sat manager
  DestroySatManager();
  //return the result
  return satResult;
}

/*********************************************************************/
//create the SAT manager
/*********************************************************************/
void SatSimChecker::CreateSatManager()
{
  //these must always be cleared
  addClauses.clear();
  remClauses.clear();

  //if we are not doing incremental verification or if the last
  //refinement was not an LTS refinement
  if((!Database::INC_VERIFY) || (Database::lastRefType != Database::REFINE_LTS)) {
    //initialize data structures
    clauses.clear();
    varMap.clear();
    varVec.clear();
    //clear the global set of reachable states
    GlobalAbsLts::reach.clear();

    //if using horn solver
    if(Database::SAT_SOLVER_USED == Database::SAT_SOLVER_HORN_AI) {
      HornCheckerAI::Initialise();
    }
#ifdef MAGIC_FULL
    //if using chaff
    else if(Database::SAT_SOLVER_USED == Database::SAT_SOLVER_CHAFF) {
      //create the sat manager using native code
      manager = SAT_InitManager();
      //add the hook function for the variable value decision
      if(Database::SIM_REL_FROM_SAT) {
        SAT_AddHookFun(manager,InitTrueDecision,1);
      }
    }
    //SATO and GRASP are external programs that SolveSat() runs on a
    //DIMACS file; they need no manager. without this branch the
    //assertion below would reject these supported solvers
    else if((Database::SAT_SOLVER_USED == Database::SAT_SOLVER_SATO) ||
            (Database::SAT_SOLVER_USED == Database::SAT_SOLVER_GRASP)) {
      //nothing to initialise
    }
#endif //MAGIC_FULL
    //illegal
    else {
      assert(false);
    }
  }
  Util::Message(2,"SAT manager created ...\n");
}

#ifdef MAGIC_FULL
/*********************************************************************/
//this is the customized decision making procedure. this ensures that
//"true" is always tried as the first value of every variable.
/*********************************************************************/
void SatSimChecker::InitTrueDecision(SAT_Manager mng)
{
  const int varCount = SAT_NumVariables(mng);
  //decide the first still-unassigned variable (scanning upwards from
  //index 1) with sign 0, i.e. try "true" first
  //NOTE(review): if SAT_NumVariables counts variables 1..n inclusive,
  //the bound "< varCount" never looks at variable n -- confirm against
  //the chaff API
  for(int v = 1;v < varCount;++v) {
    if(SAT_GetVarAsgnment(mng,v) != UNKNOWN) continue;
    SAT_MakeDecision(mng,v,0);
    return;
  }
  //no candidate found: report "no more decisions" to the solver
  SAT_MakeDecision(mng,0,0);
}

/*********************************************************************/
//create the SAT file in DIMACS format
/*********************************************************************/
void SatSimChecker::CreateDimacsFile()
{
  FILE *satf = fopen(DIMACS_FILE,"w");
  fprintf(satf,"p cnf %d %d\n",varVec.size(),clauses.size());
  for(set< set<int> >::const_iterator i = clauses.begin();i != clauses.end();++i) {
    const set<int> &a = *i;
    for(set<int>::const_iterator j = a.begin();j != a.end();++j) {
      //convert the variable index from chaff to dimacs format
      int dimacsId = (((*j) % 2) == 1) ? (-(((*j) - 1) / 2)) : ((*j) / 2);
      fprintf(satf,"%d ",dimacsId);
    }
    fprintf(satf,"0\n");
  }
  fclose(satf);
}

/*********************************************************************/
//solve SAT using SATO
/*********************************************************************/
int SatSimChecker::SolveSatSato()
{
  char *envptr = getenv("MAGICROOT");
  if(envptr == NULL) {
    return -1;
  }
  string command = string(envptr) + "/sato3.2.1/sato -f -g " + DIMACS_FILE + " > " + OUTPUT_FILE;
  system(command.c_str());
  command = string("grep Model ") + OUTPUT_FILE;
  if(system(command.c_str()) == 0) {
    return 1;
  } else {
    command = string("grep unsatisfiable ") + OUTPUT_FILE;
    if(system(command.c_str()) == 0) {
      return 0;
    } else {
      return -1;
    }
  }
}

/*********************************************************************/
//solve SAT using grasp
/*********************************************************************/
int SatSimChecker::SolveSatGrasp()
{
  char *envptr = getenv("MAGICROOT");
  if(envptr == NULL) {
    return -1;
  }
  string command = string(envptr) + "/fgrasp/sat-grasp.st.linux " + DIMACS_FILE + " > " + OUTPUT_FILE;
  system(command.c_str());
  command = string("grep Satisfiable ") + OUTPUT_FILE;
  if(system(command.c_str()) == 0) {
    return 1;
  } else {
    command = string("grep Unsatisfiable ") + OUTPUT_FILE;
    if(system(command.c_str()) == 0) {
      return 0;
    } else {
      return -1;
    }
  }
}
#endif //MAGIC_FULL

/*********************************************************************/
//add the non-initial constraint clauses to the sat manager
/*********************************************************************/
bool SatSimChecker::AddNonInitialConstraints()
{
  //the spec lts
  const LtsInfo &specInfo = Database::GetLtsInfo(Database::specLtsName);
  //the states of the specification
  const set<string> &specStates = specInfo.GetStates();
  //number of spec lts states
  int ssn = specStates.size();
  //the set of reachable and frontier states
  set<BigInt> reach,front;
  GlobalAbsLts::GetInitStates(front);
  //generate sat constraints for each global abstract state
  bool res = true;
  while(!front.empty()) {
    reach.insert(front.begin(),front.end());
    set<BigInt> newFront = front;
    front.clear();
    for(set<BigInt>::const_iterator i = newFront.begin();i != newFront.end();++i) {
      map< Action,set<BigInt> > saa;
      GlobalAbsLts::GetSuccsAndActions(*i,saa);
      for(map< Action,set<BigInt> >::const_iterator j = saa.begin();j != saa.end();++j) {
	for(set<BigInt>::const_iterator k = j->second.begin();k != j->second.end();++k) {
	  if(reach.count(*k) == 0) front.insert(*k);
	  for(set<string>::const_iterator l = specStates.begin();l != specStates.end();++l) {
	    set<BigInt> clause;
	    clause.insert(((*i) * ssn + specInfo.GetStateId(*l) + 1) * 2 + 1);
	    set<string> specSuccs; specInfo.GetSuccsOnActionSpec(*l,j->first,specSuccs);
	    for(set<string>::const_iterator m = specSuccs.begin();m != specSuccs.end();++m) {
	      if(!AddLiteral(clause,((*k) * ssn + specInfo.GetStateId(*m) + 1) * 2)) break;
	    }
	    if(!clause.empty()) AddClause(clause);
	  }
	}
      }
    }
  }
  //all done
  Util::Message(2,"number of reachable implementation states : %d\n",reach.size());
  return res;
}

/*********************************************************************/
//add a clause to the boolean formula to be solved
/*********************************************************************/
void SatSimChecker::AddClause(const set<BigInt> &clause)
{
  set<int> newClause;
  for(set<BigInt>::const_iterator i = clause.begin();i != clause.end();++i) {
    BigInt x,y; BigInt::Div(x,y,*i,2);
    bool pos = (y.ToSL() == 0);
    BigInt varId = pos ? (x - 1) : ((((*i) - 1) / 2) - 1);
    int newId = -1;
    if(varMap.count(varId) == 0) {
      newId = varMap.size();
      varMap[varId] = newId;
      varVec.push_back(varId);
    } else newId = varMap[varId];
    newClause.insert(pos ? ((newId + 1) * 2) : (((newId + 1) * 2) + 1));
  }
  addClauses.insert(newClause);
}

/*********************************************************************/
//remove a clause from the boolean formula to be solved
/*********************************************************************/
void SatSimChecker::RemoveClause(const set<BigInt> &clause)
{
  set<int> newClause;
  for(set<BigInt>::const_iterator i = clause.begin();i != clause.end();++i) {
    BigInt x,y; BigInt::Div(x,y,*i,2);
    bool pos = (y.ToSL() == 0);
    BigInt varId = pos ? (x - 1) : ((((*i) - 1) / 2) - 1);
    int newId = -1;
    if(varMap.count(varId) == 0) {
      newId = varMap.size();
      varMap[varId] = newId;
      varVec.push_back(varId);
    } else newId = varMap[varId];
    newClause.insert(pos ? ((newId + 1) * 2) : (((newId + 1) * 2) + 1));
  }
  remClauses.insert(newClause);
}

/*********************************************************************/
//add constraint clauses due to the initial state
/*********************************************************************/
bool SatSimChecker::AddInitialConstraints()
{
  //the spec lts
  const LtsInfo &specInfo = Database::GetLtsInfo(Database::specLtsName);
  //the states of the specification
  const set<string> &specStates = specInfo.GetStates();
  //number of spec lts states
  int ssn = specStates.size();
  //the id of the initial state of the spec lts
  int initSpecId = specInfo.GetStateId(Database::specLtsName);

  set<BigInt> initImpls; GlobalAbsLts::GetInitStates(initImpls);
  set<BigInt> prevInitImpls; 
  if(Database::INC_VERIFY && (Database::lastRefType == Database::REFINE_LTS)) {
    GlobalAbsLts::GetPrevInitStates(prevInitImpls);
  }
  for(set<BigInt>::const_iterator i = prevInitImpls.begin();i != prevInitImpls.end();++i) {
    if(initImpls.erase(*i) == 0) {
      set<BigInt> clause;
      clause.insert(((*i) * ssn + initSpecId + 1) * 2);
      RemoveClause(clause);
    }
  }

  for(set<BigInt>::const_iterator i = initImpls.begin();i != initImpls.end();++i) {
    set<BigInt> clause;
    clause.insert(((*i) * ssn + initSpecId + 1) * 2);
    AddClause(clause);
  }
  return true;
}

/*********************************************************************/
//solve for satisfiability
/*********************************************************************/
bool SatSimChecker::SolveSat()
{
  //if using horn solver
  if(Database::SAT_SOLVER_USED == Database::SAT_SOLVER_HORN_AI) {
    //set variable number
    HornCheckerAI::SetVarNum(varVec.size());
    //remove clauses. a clause scheduled for both removal and addition
    //stays in the formula, so it is skipped here
    int count = 0;
    for(set< set<int> >::const_iterator i = remClauses.begin();i != remClauses.end();++i) {
      if((addClauses.count(*i) == 0) && (clauses.erase(*i) != 0)) {
        HornCheckerAI::RemoveClause(*i);
        ++count;
      }
    }
    Util::Message(2,"%d constraint clauses removed ...\n",count);
    //add clauses that are not already part of the formula
    count = 0;
    for(set< set<int> >::const_iterator i = addClauses.begin();i != addClauses.end();++i) {
      if(clauses.insert(*i).second) {
        HornCheckerAI::AddClause(*i);
        ++count;
      }
    }
    Util::Message(2,"%d constraint clauses added ...\n",count);
    //solve sat
    vector<int> elim;
    bool res = HornCheckerAI::SolveSat(elim);
    //create the elimination sequence only if the formula is unsatisfiable
    if(!res) {
      for(vector<int>::const_iterator j = elim.begin();j != elim.end();++j) {
        BigInt varId = varVec[*j];
        //read the size before indexing, which may insert a new entry
        size_t size = SimChecker::exclMap.size();
        SimChecker::exclMap[varId] = size;
      }
    }
    return res;
  }
#ifdef MAGIC_FULL
  //if using chaff
  else if(Database::SAT_SOLVER_USED == Database::SAT_SOLVER_CHAFF) {
    //update the set of clauses
    for(set< set<int> >::const_iterator i = remClauses.begin();i != remClauses.end();++i) {
      clauses.erase(*i);
    }
    for(set< set<int> >::const_iterator i = addClauses.begin();i != addClauses.end();++i) {
      clauses.insert(*i);
    }
    //add the clauses
    SAT_SetNumVariables(manager,varVec.size());
    for(set< set<int> >::const_iterator i = clauses.begin();i != clauses.end();++i) {
      //SAT_AddClause receives an int array; a vector owns it, so no
      //manual delete[] is needed
      vector<int> lits(i->begin(),i->end());
      SAT_AddClause(manager,lits.empty() ? NULL : &lits[0],lits.size());
    }
    Util::Message(2,"%d constraint clauses added ...\n",static_cast<int>(clauses.size()));
    return (SAT_Solve(manager) == SATISFIABLE);
  }
  //if using SATO or GRASP
  else if((Database::SAT_SOLVER_USED == Database::SAT_SOLVER_SATO) ||
          (Database::SAT_SOLVER_USED == Database::SAT_SOLVER_GRASP)) {
    //update the set of clauses
    for(set< set<int> >::const_iterator i = remClauses.begin();i != remClauses.end();++i) {
      clauses.erase(*i);
    }
    for(set< set<int> >::const_iterator i = addClauses.begin();i != addClauses.end();++i) {
      clauses.insert(*i);
    }
    //cleanup and create new file
    CleanupFiles();
    CreateDimacsFile();
    Util::Message(2,"%d constraint clauses added ...\n",static_cast<int>(clauses.size()));
    int x = 0;
    if(Database::SAT_SOLVER_USED == Database::SAT_SOLVER_SATO) {
      x = SolveSatSato();
    }
    else if(Database::SAT_SOLVER_USED == Database::SAT_SOLVER_GRASP) {
      x = SolveSatGrasp();
    }
    else { assert(false); }
    CleanupFiles();
    if(x == -1) {
      Util::Error("could not handle SAT instance ...\n");
      //whether or not Util::Error returns, never fall off the end of
      //this non-void function
      return false;
    }
    return (x == 1);
  }
#endif //MAGIC_FULL
  //illegal
  else {
    assert(false);
  }
  //only reached with assertions disabled and an unsupported solver
  return false;
}

/*********************************************************************/
//remove temporary files
/*********************************************************************/
void SatSimChecker::CleanupFiles()
{
#ifdef MAGIC_FULL
  //scratch files exchanged with the external SAT solvers
  const char *const scratch[] = { DIMACS_FILE, OUTPUT_FILE };
  for(size_t idx = 0;idx < sizeof(scratch) / sizeof(scratch[0]);++idx) {
    remove(scratch[idx]);
  }
#endif //MAGIC_FULL
  //NOTE(review): presumably a core dump left by a crashed solver;
  //this deletes any file named "core" in the working directory
  remove("core");
}

/*********************************************************************/
//destroy the SAT manager
/*********************************************************************/
void SatSimChecker::DestroySatManager()
{
  //the pending additions/removals never outlive a single check
  addClauses.clear();
  remClauses.clear();

  //with incremental verification the clause set, the reachable states
  //and the solver are kept so that the next check can build on them
  if(Database::INC_VERIFY) return;

  //otherwise drop the clause set and the global reachable states ...
  clauses.clear();
  GlobalAbsLts::reach.clear();

  //... and release whatever the solver allocated
  if(Database::SAT_SOLVER_USED == Database::SAT_SOLVER_HORN_AI) {
    HornCheckerAI::Cleanup();
  }
#ifdef MAGIC_FULL
  else if(Database::SAT_SOLVER_USED == Database::SAT_SOLVER_CHAFF) {
    SAT_ReleaseManager(manager);
    Util::Message(2,"SAT manager destroyed ...\n");
  }
#endif //MAGIC_FULL
}

/*********************************************************************/
//check reachability of ERROR state
/*********************************************************************/
bool SatSimChecker::CheckReachability()
{
  //we can construct the simulation relation only if we use HORNSAT
  assert(!Database::SIM_REL_FROM_SAT || (Database::SAT_SOLVER_USED == Database::SAT_SOLVER_HORN_AI));
  //create the sat manager
  CreateSatManager();
  //add the constraints, in this order: the final constraints only look
  //at variables already registered by the earlier ones. the calls are
  //made outside assert() so they still run when NDEBUG is defined
  const bool initAdded = AddInitialReachConstraints();
  assert(initAdded);
  const bool transAdded = AddTransReachConstraints();
  assert(transAdded);
  const bool finalAdded = AddFinalReachConstraints();
  assert(finalAdded);
  static_cast<void>(initAdded);
  static_cast<void>(transAdded);
  static_cast<void>(finalAdded);
  //sizes are converted so that they match the %d specifier
  Util::Message(2,"number of reachable implementation states : %d\n",static_cast<int>(GlobalAbsLts::reach.size()));
  Util::Message(2,"number of HORNSAT variables = %d ...\n",static_cast<int>(varVec.size()));
  Statistics::stateNum = (Statistics::stateNum < varVec.size()) ? varVec.size() : Statistics::stateNum;
  bool satResult = SolveSat();
  DestroySatManager();
  return satResult;
}

/*********************************************************************/
//add the initial reachability constraints that make the initial pair
//of states reachable
/*********************************************************************/
bool SatSimChecker::AddInitialReachConstraints()
{
  //reachability starts where simulation does: every pair of an initial
  //implementation state and the initial spec state gets a positive unit
  //clause, so the simulation initial constraints are reused as they are
  const bool added = AddInitialConstraints();
  return added;
}

/*********************************************************************/
//add the final reachability constraints that make the ERROR states
//unreachable
/*********************************************************************/
bool SatSimChecker::AddFinalReachConstraints()
{
  //the spec lts
  const LtsInfo &specInfo = Database::GetLtsInfo(Database::specLtsName);
  //number of spec lts states
  int ssn = specInfo.GetStates().size();
  //id of the ERROR state
  int errId = specInfo.GetStateId(Database::ERROR_STATE);
  
  //compute the set of states
  for(set<BigInt>::const_iterator i = GlobalAbsLts::reach.begin();i != GlobalAbsLts::reach.end();++i) {
    BigInt varId = (*i) * ssn + errId;
    if(varMap.count(varId) != 0) {
      set<BigInt> clause;
      clause.insert((varId + 1) * 2 + 1);
      AddClause(clause);
    } 
  }
  return true;    
}

/*********************************************************************/
//add the transition reachability constraints that make additional
//states reachable
/*********************************************************************/
bool SatSimChecker::AddTransReachConstraints()
{
  //the spec lts
  const LtsInfo &specInfo = Database::GetLtsInfo(Database::specLtsName);
  //the states of the specification
  const set<string> &specStates = specInfo.GetStates();
  //number of spec lts states
  int ssn = specStates.size();
  //the actions of the implementation
  set<Action> implActs; GlobalAbsLts::GetActions(implActs);
  //the set of reachable states and frontier states
  set<BigInt> &reach = GlobalAbsLts::reach;
  set<BigInt> front;
  if((!Database::INC_VERIFY) || (Database::lastRefType != Database::REFINE_LTS)) {
    reach.clear();
    GlobalAbsLts::GetInitStates(front);
  } else {
    GlobalAbsLts::GetSuccsChanged(front);
    for(set<BigInt>::const_iterator i = front.begin();i != front.end();++i) reach.erase(*i);
    //generate the constraints to be removed
    for(set<BigInt>::const_iterator i = front.begin();i != front.end();++i) {
      for(set<Action>::const_iterator j = implActs.begin();j != implActs.end();++j) {
	set<BigInt> implSuccs; GlobalAbsLts::GetPrevSuccsOnAction(*i,*j,implSuccs);
	for(set<BigInt>::const_iterator k = implSuccs.begin();k != implSuccs.end();++k) {
	  for(set<string>::const_iterator l = specStates.begin();l != specStates.end();++l) {
	    set<string> specSuccs; specInfo.GetSuccsOnActionSpec(*l,*j,specSuccs);
	    for(set<string>::const_iterator m = specSuccs.begin();m != specSuccs.end();++m) {
	      set<BigInt> clause;
	      clause.insert(((*i) * ssn + specInfo.GetStateId(*l) + 1) * 2 + 1);
	      AddLiteral(clause,((*k) * ssn + specInfo.GetStateId(*m) + 1) * 2);
	      if(!clause.empty()) RemoveClause(clause);
	    }
	  }
	}
      }
    }
  }
  //generate sat constraints for each global abstract state
  bool res = true;
  while(!front.empty()) {
    reach.insert(front.begin(),front.end());
    set<BigInt> newFront = front;
    front.clear();
    for(set<BigInt>::const_iterator i = newFront.begin();i != newFront.end();++i) {
      map< Action,set<BigInt> > saa;
      GlobalAbsLts::GetSuccsAndActions(*i,saa);
      for(map< Action,set<BigInt> >::const_iterator j = saa.begin();j != saa.end();++j) {
	for(set<BigInt>::const_iterator k = j->second.begin();k != j->second.end();++k) {
	  if(reach.count(*k) == 0) front.insert(*k);
	  for(set<string>::const_iterator l = specStates.begin();l != specStates.end();++l) {
	    set<string> specSuccs; specInfo.GetSuccsOnActionSpec(*l,j->first,specSuccs);
	    for(set<string>::const_iterator m = specSuccs.begin();m != specSuccs.end();++m) {
	      set<BigInt> clause;
	      clause.insert(((*i) * ssn + specInfo.GetStateId(*l) + 1) * 2 + 1);
	      AddLiteral(clause,((*k) * ssn + specInfo.GetStateId(*m) + 1) * 2);
	      if(!clause.empty()) AddClause(clause);
	    }
	  }
	}
      }
    }
  }
  //all done
  return res;
}

/*********************************************************************/
//add a literal to a clause. if the resulting clause is trivially
//true i.e. it contains both a variable and its negation then
//return false and make the clause empty. otherwise return true.
/*********************************************************************/
bool SatSimChecker::AddLiteral(set<BigInt> &clause,const BigInt &lit)
{
  //literals come in pairs (2k, 2k+1): an even literal is complemented by
  //its successor, an odd one by its predecessor
  BigInt quo,rem; BigInt::Div(quo,rem,lit,2);
  BigInt complement = (rem.ToSL() == 0) ? (lit + 1) : (lit - 1);
  if(clause.count(complement) == 0) {
    clause.insert(lit);
    return true;
  }
  //the clause would contain a literal and its negation: it is trivially
  //true, which callers detect through the emptied clause
  clause.clear();
  return false;
}

/*********************************************************************/
//end of SatSimChecker.cpp
/*********************************************************************/
