/******************************** CPPFile *****************************

* FileName [Predicate.cpp]

* PackageName [main]

* Synopsis [Method definitions of Predicate class.]

* SeeAlso [Predicate.h]

* Author [Sagar Chaki]

* Copyright [ Copyright (c) 2002 by Carnegie Mellon University. All
* Rights Reserved. This software is for educational purposes only.
* Permission is given to academic institutions to use, copy, and
* modify this software and its documentation provided that this
* introductory message is not removed, that this software and its
* documentation is used for the institutions' internal research and
* educational purposes, and that no monies are exchanged. No guarantee
* is expressed or implied by the distribution of this code. Send
* bug-reports and/or questions to: chaki+@cs.cmu.edu. ]

**********************************************************************/

#include <cstdio>
#include <cassert>
#include <gmp.h>
#include <string>
#include <list>
#include <vector>
#include <set>
#include <map>
#include <typeinfo>
using namespace std;

#include "Util.h"
#include "Node.h"
#include "Predicate.h"
#include "Database.h"
#define YYSTYPE int
#include "StdcParser.h"
using namespace magic;

/*********************************************************************/
//static members
/*********************************************************************/
//type codes for a predicate. the single-expression constructor
//asserts that its type is REQUIRED or TENTATIVE; NONE (-1) is a
//separate sentinel, presumably "no type" -- confirm against Predicate.h
const int Predicate::NONE(-1);
const int Predicate::REQUIRED(0);
const int Predicate::TENTATIVE(1);

/*********************************************************************/
//constructors
/*********************************************************************/
//build a predicate holding the single expression e. t must be
//REQUIRED or TENTATIVE. h seeds the expression's history map
//(location -> count).
Predicate::Predicate(const Expr &e,int t,const pair<const ContLoc*,int> &h)
{
  assert((t == REQUIRED) || (t == TENTATIVE));
  type = t;

  //not yet known to be complete: one valuation for e plus one for
  //"e does not hold"
  complete = false;
  valNum = 2;

  valToExpr.push_back(e);
  history.push_back(map<const ContLoc*,int>());
  history.back().insert(h);
}

//build a predicate from expressions v and their parallel histories h
//(one map per expression). the predicate starts out marked incomplete,
//so valNum reserves one extra valuation for "none of them holds".
Predicate::Predicate(const vector<Expr> &v,int t,const vector< map<const ContLoc*,int> > &h)
{
  assert(v.size() == h.size());
  valToExpr.assign(v.begin(),v.end());
  history = h;
  type = t;
  complete = false;
  valNum = valToExpr.size() + 1;
}

//copy constructor: member-wise copy via the assignment operator.
Predicate::Predicate(const Predicate &rhs)
{
  operator = (rhs);
}

/*********************************************************************/
//operators
/*********************************************************************/
//member-wise assignment of all five data members; returns *this.
const Predicate &Predicate::operator = (const Predicate &rhs)
{
  history = rhs.history;
  valNum = rhs.valNum;
  complete = rhs.complete;
  type = rhs.type;
  valToExpr = rhs.valToExpr;
  return *this;
}

//two predicates are equal when their textual forms coincide, i.e. they
//list the same expressions in the same order. type, history and
//completeness do not take part in the comparison.
bool Predicate::operator == (const Predicate &rhs) const
{
  const string mine = ToString();
  return mine.compare(rhs.ToString()) == 0;
}

/*********************************************************************/
//string representation
/*********************************************************************/
//render as "<( e0)( e1)...>" using each expression's own ToString.
string Predicate::ToString() const
{
  string out("<");
  for(size_t k = 0;k < valToExpr.size();++k) {
    out.append("( ");
    out.append(valToExpr[k].ToString());
    out.append(")");
  }
  out.append(">");
  return out;
}

/*********************************************************************/
//return true iff the predicate consists of exactly one expression and
//ExprManager::IsTrue accepts that expression, i.e. the predicate is
//trivially true.
/*********************************************************************/
bool Predicate::IsTrue() const
{
  if(valToExpr.size() != 1) return false;
  return ExprManager::IsTrue(valToExpr[0]);
}

/*********************************************************************/
//return true iff the predicate consists of exactly one expression and
//ExprManager::IsFalse accepts that expression, i.e. the predicate is
//trivially false.
/*********************************************************************/
bool Predicate::IsFalse() const
{
  const bool single = (valToExpr.size() == 1);
  return single ? ExprManager::IsFalse(valToExpr[0]) : false;
}

/*********************************************************************/
//merge the expressions of rhs into this predicate, one rhs expression
//e1 at a time (later rhs expressions see the effect of earlier ones):
//  - if e1 is a binary comparison, two integer-specific rules are
//    tried first (see below): e1 may be discarded because an existing
//    equality/inequality pair already covers it, or an existing
//    inequality may be split into an equality plus e1.
//  - otherwise e1 is discarded if it is equivalent to, or the
//    negation of, some expression here;
//  - else it is incorporated if it is disjoint with every expression;
//  - else, if e1 (or its negation) implies every expression, the
//    predicate (required by an assert to hold exactly one expression
//    e0) is rebuilt as {!e0, e1} (or {!e0, !e1});
//  - else, if every expression implies e1, !e1 is added;
//  - otherwise e1 stays in rhs.
//on return rhs holds only the expressions that were neither
//incorporated nor discarded, with their histories. returns true iff at
//least one expression was incorporated. completeness and valNum are
//recomputed at the end.
/*********************************************************************/
bool Predicate::Merge(Predicate &rhs)
{  
  //the vector of expressions and history to be retained in rhs
  vector<Expr> retVTE;
  vector< map<const ContLoc*,int> > retHist;
  
  bool res = false;
  //the constant 1, used to build the X+1 / X-1 shapes below
  Expr one = ExprManager::GetIntConstExpr(1);
  for(size_t i = 0;i < rhs.valToExpr.size();++i) {
    Expr e1 = rhs.valToExpr[i];
    //set once e1 has been fully handled (discarded or incorporated) by
    //one of the integer-specific rules; the rest of the loop is skipped
    bool iterFound = false;
    const BasicExpr *be1 = e1.GetExpr();
    if(typeid(*be1) == typeid(BinaryExpr)) {
      const BinaryExpr *be2 = static_cast<const BinaryExpr*>(be1);
      short int op1 = be2->op;
      /***************************************************************/
      //check if one of the following conditions hold
      /***************************************************************/
      //condition 1: this predicate is of the form (X + 1 > Y) where
      //both (X = Y) and (X > Y) are existing predicates
      /***************************************************************/
      //condition 2: this predicate is of the form (X - 1 < Y) where
      //both (X = Y) and (X < Y) are existing predicates
      /***************************************************************/
      //over integers (X + 1 > Y) is the same as (X = Y) or (X > Y), so
      //e1 adds nothing and is dropped: it is neither incorporated nor
      //kept in rhs. (X > Y) may also appear mirrored as (Y < X), which
      //the "neg" branch matches.
      if((op1 == '>') || (op1 == '<')) {
	short sign = (op1 == '>') ? '+' : '-';
	short neg = (op1 == '>') ? '<' : '>';
	bool eqFlag = false,opFlag = false;
	for(size_t j = 0;j < valToExpr.size();++j) {
	  const BasicExpr *be5 = valToExpr[j].GetExpr();
	  if(typeid(*be5) == typeid(BinaryExpr)) {
	    const BinaryExpr *XY = static_cast<const BinaryExpr*>(be5);
	    if(XY->op == MAGIC_EQ_OP) {
	      //existing (X = Y): does (X sign 1) op1 Y rebuild e1?
	      Expr XS1 = ExprManager::GetBinaryExpr(XY->lhs,one,sign);
	      Expr XS1OPY = ExprManager::GetBinaryExpr(XS1,XY->rhs,op1);
	      eqFlag = (XS1OPY == e1) ? true : eqFlag;
	    } else if(XY->op == op1) {
	      //existing (X op1 Y)
	      Expr XS1 = ExprManager::GetBinaryExpr(XY->lhs,one,sign);
	      Expr XS1OPY = ExprManager::GetBinaryExpr(XS1,XY->rhs,op1);
	      opFlag = (XS1OPY == e1) ? true : opFlag;
	    } else if(XY->op == neg) {
	      //existing (Y neg X), the mirrored form of (X op1 Y)
	      Expr YS1 = ExprManager::GetBinaryExpr(XY->rhs,one,sign);
	      Expr YS1OPX = ExprManager::GetBinaryExpr(YS1,XY->lhs,op1);
	      opFlag = (YS1OPX == e1) ? true : opFlag;
	    }
	    if(eqFlag && opFlag) { iterFound = true; break; }
	  }
	}
      }
      if(iterFound) continue;
      /***************************************************************/
      //check if one of the following two conditions holds:
      /***************************************************************/
      //condition 1: this predicate is of the form (X + 1 OP Y) such
      //that (X OP Y) is an existing predicate where OP is "<" or "<="
      /***************************************************************/
      //condition 2: this predicate is of the form (X - 1 OP Y) such
      //that (X OP Y) is an existing predicate where OP is ">" or ">="
      /***************************************************************/
      //(X OP Y) may also appear mirrored as (Y neg X). on a match the
      //existing expression is split into two disjoint pieces: it is
      //replaced by (X = Y - 1) for "<", (X = Y + 1) for ">", or
      //(X = Y) for "<=" / ">=", and e1 is appended as the other piece.
      if((op1 == '<') || (op1 == MAGIC_LE_OP) || (op1 == '>') || (op1 == MAGIC_GE_OP)) {
	short sign = 0,neg = 0;
	if(op1 == '<') { sign = '+'; neg = '>'; }
	else if(op1 == MAGIC_LE_OP) { sign = '+'; neg = MAGIC_GE_OP; }
	else if(op1 == '>') { sign = '-'; neg = '<'; }
	else { sign = '-'; neg = MAGIC_LE_OP; }
	bool flag = false; Expr X,Y;
	for(size_t j = 0;j < valToExpr.size();++j) {
	  const BasicExpr *be5 = valToExpr[j].GetExpr();
	  if(typeid(*be5) == typeid(BinaryExpr)) {
	    const BinaryExpr *XY = static_cast<const BinaryExpr*>(be5);
	    if(XY->op == op1) {
	      Expr XS1 = ExprManager::GetBinaryExpr(XY->lhs,one,sign);
	      Expr XS1OPY = ExprManager::GetBinaryExpr(XS1,XY->rhs,op1);
	      if(XS1OPY == e1) { X = XY->lhs; Y = XY->rhs; flag = true; }
	    } else if(XY->op == neg) {
	      //mirrored form: the roles of lhs and rhs swap
	      Expr YS1 = ExprManager::GetBinaryExpr(XY->rhs,one,sign);
	      Expr YS1OPX = ExprManager::GetBinaryExpr(YS1,XY->lhs,op1);
	      if(YS1OPX == e1) { X = XY->rhs; Y = XY->lhs; flag = true; }
	    }
	    if(flag) {
	      //replace slot j in place (it keeps history[j]) and append e1
	      //with its own history from rhs
	      if((op1 == '<') || (op1 == '>')) {
		Expr YNS1 = ExprManager::GetBinaryExpr(Y,one,(sign == '+') ? '-' : '+');
		valToExpr[j] = ExprManager::GetBinaryExpr(X,YNS1,MAGIC_EQ_OP);
	      } else valToExpr[j] = ExprManager::GetBinaryExpr(X,Y,MAGIC_EQ_OP);
	      valToExpr.push_back(e1);
	      history.push_back(rhs.history[i]);
	      res = true;	      
	      iterFound = true;
	      break;
	    }
	  }
	}
      }
      if(iterFound) continue;
    }

    //compare e1 against the existing expressions. sameFlag / oppFlag
    //are set if e1 is equivalent to / the negation of some expression
    //(the scan stops there). disjFlag, impliesFlag (e1 => e2),
    //impliedByFlag (e2 => e1) and notImpliesFlag (!e1 => e2) stay true
    //only if the relation holds for every expression examined.
    Expr ne1 = ExprManager::NegateExpr(e1);
    bool sameFlag = false,oppFlag = false,disjFlag = true;
    bool impliesFlag = true,impliedByFlag = true,notImpliesFlag = true;
    for(size_t j = 0;j < valToExpr.size();++j) {
      Expr e2 = valToExpr[j];
      if(ExprManager::EquivalentTo(e1,e2)) {
	sameFlag = true;
	break;
      }      
      if(ExprManager::NegationOf(e1,e2)) {      
	oppFlag = true;
	break;
      }
      if(!ExprManager::DisjointWith(e1,e2)) {      
	disjFlag = false;
      }
      if(!ExprManager::Implies(e1,e2)) {      
	impliesFlag = false;
      }
      if(!ExprManager::Implies(e2,e1)) {      
	impliedByFlag = false;
      }
      if(!ExprManager::Implies(ne1,e2)) {      
	notImpliesFlag = false;
      }
    }
    //if it is equivalent to or the negation of an existing expression
    //then drop it (it is not kept in rhs either)
    if(sameFlag || oppFlag) {
      continue;
    } else if(disjFlag) {
      //disjoint with everything: add it as a new alternative
      valToExpr.push_back(e1);
      history.push_back(rhs.history[i]);
      res = true;
    } else if(impliesFlag) {
      //e1 implies the single existing expression e0: replace {e0} by
      //the disjoint pair {!e0, e1}; the old history[0] now belongs to !e0
      assert(valToExpr.size() == 1);
      Expr neg = ExprManager::NegateExpr(valToExpr[0]);
      valToExpr.clear();
      valToExpr.push_back(neg);
      valToExpr.push_back(e1);
      history.push_back(rhs.history[i]);
      res = true;
    } else if(notImpliesFlag) {
      //!e1 implies the single existing expression e0, hence !e0 implies
      //e1: replace {e0} by the disjoint pair {!e0, !e1}
      assert(valToExpr.size() == 1);
      Expr neg = ExprManager::NegateExpr(valToExpr[0]);
      valToExpr.clear();
      valToExpr.push_back(neg);
      valToExpr.push_back(ne1);
      history.push_back(rhs.history[i]);
      res = true;
    } else if(impliedByFlag) {
      //every existing expression implies e1, so !e1 is disjoint with
      //all of them: add !e1
      valToExpr.push_back(ne1);
      history.push_back(rhs.history[i]);
      res = true;
    } else {
      //no usable relation: leave e1 in rhs
      retVTE.push_back(e1);
      retHist.push_back(rhs.history[i]);
    }
  }
  
  //modify the set of expressions in rhs accordingly
  assert(retVTE.size() == retHist.size());
  rhs.valToExpr = retVTE;
  rhs.history = retHist;
  
  //recompute completeness
  ComputeCompletenessAndValNum();
  return res;
}

/*********************************************************************/
//compute the completeness of this predicate
/*********************************************************************/
void Predicate::ComputeCompletenessAndValNum()
{
  //create the disjunction of the expressions and check if it is true.
  set<Expr> a;
  a.insert(valToExpr.begin(),valToExpr.end());
  set<Expr> d;
  complete = ExprManager::ProveImplies(d,a);

  valNum = valToExpr.size();
  valNum = complete ? valNum : (valNum + 1);
}

/*********************************************************************/
//return true iff at least one of the predicate's expressions contains
//some lvalue named in arg (as decided by Expr::ContainsLvalue).
/*********************************************************************/
bool Predicate::ContainsLvalue(const set<string> &arg) const
{
  bool found = false;
  for(size_t k = 0;(k < valToExpr.size()) && !found;++k) {
    found = valToExpr[k].ContainsLvalue(arg);
  }
  return found;
}

/*********************************************************************/
//given the list of lhs and rhs of an assignment return the WP of this
//predicate and whether the WP is different from this predicate
/*********************************************************************/
bool Predicate::ComputeWP(const list<Expr> &lhsList,const list<Expr> &rhsList,Predicate &res) const
{
  bool modified = false;
  size_t size = valToExpr.size();
  vector<Expr> vte;
  vector< map<const ContLoc*,int> > hist;
  for(size_t i = 0;i < size;++i) {
    Expr b = ExprManager::ComputeWP(lhsList,rhsList,valToExpr[i]);
    if((!ExprManager::IsTrue(b)) && (!ExprManager::IsFalse(b))) {
      vte.push_back(b);
      hist.push_back(history[i]);
    }
    if(!(b == valToExpr[i])) modified = true;
  }
  
  res = Predicate(vte,type,hist);
  if(!vte.empty()) {
    res.ComputeCompletenessAndValNum();
  }
  return modified;
}

/*********************************************************************/
//compute the conjunction of this predicate with the supplied
//expression. if the supplied expression is totally disjoint with
//this predicate then return an empty predicate.
/*********************************************************************/
void Predicate::Conjunct(const Expr &expr,Predicate &res) const
{
  Expr neg = ExprManager::NegateExpr(expr);
  set<Expr>negList; negList.insert(neg);
  
  size_t size = valToExpr.size();
  vector<Expr> vte;
  vector< map<const ContLoc*,int> > hist;
  for(size_t i = 0;i < size;++i) {
    const Expr &b = valToExpr[i];
    set<Expr> bList; bList.insert(b);
    
    //add this expression only if this expression is not
    //disjoint with the argument expression
    if(!ExprManager::ProveImplies(bList,negList)) {
      vte.push_back(b);
      hist.push_back(history[i]);
    }
  }
  
  res = Predicate(vte,type,hist);
  if(!vte.empty()) {
    res.ComputeCompletenessAndValNum();
  }
}

/*********************************************************************/
//return the first valuation; valuations run from 0 to valNum - 1.
/*********************************************************************/
short Predicate::GetInitValuation() const
{
  const short first = 0;
  return first;
}

/*********************************************************************/
//return true iff val is the last valuation of this predicate, i.e.
//valNum - 1.
/*********************************************************************/
bool Predicate::IsFinalValuation(short val) const
{
  return (valNum - 1) == val;
}

/*********************************************************************/
//return the valuation following val, wrapping around modulo valNum.
/*********************************************************************/
short Predicate::GetNextValuation(short val) const
{
  const int succ = val + 1;
  return succ % valNum;
}

/*********************************************************************/
//given a valuation and a base index for the expressions return a
//concrete expression for the predicate. the valuation is given as a
//set of expressions implicitly conjuncted.
/*********************************************************************/
void Predicate::ToExpr(short val,set<Expr> &res) const
{
  short index = complete ? val : (val - 1);
  if(index == -1) {
    for(size_t i = 0;i < valToExpr.size();++i) {
      res.insert(ExprManager::NegateExpr(valToExpr[i]));
    }
  } else if((index >= 0) && (index < (short)(valToExpr.size()))) {
    res.insert(valToExpr[index]);
  } else {
    Util::Error("illegal predicate valuation in ToExpr ...\n");
  }
}

/*********************************************************************/
//given a valuation, return the pair (F,T): T holds the expression the
//valuation assigns true (at most one) and F holds every other
//expression, i.e. those assigned false. for an incomplete predicate
//valuation 0 means no expression is true. as in ToExpr, a valuation
//outside the predicate's range is reported via Util::Error.
/*********************************************************************/
pair< set<Expr>,set<Expr> > Predicate::ToExprSets(short val) const
{
  pair< set<Expr>,set<Expr> > res;
  //position of the true expression; -1 stands for "none true", which
  //only an incomplete predicate can represent
  const short index = complete ? val : (val - 1);
  const short lowest = complete ? 0 : -1;
  if((index < lowest) || (index >= (short)(valToExpr.size()))) {
    Util::Error("illegal predicate valuation in ToExprSets ...\n");
  }
  for(size_t i = 0;i < valToExpr.size();++i) {
    if(index == (short)(i)) {
      res.second.insert(valToExpr[i]);
    } else {
      res.first.insert(valToExpr[i]);
    }
  }
  return res;
}

/*********************************************************************/
//append to res the valuations this predicate may take after the
//assignment lhsList := rhsList (an empty lhsList means no assignment,
//so each expression is used as is). before.first / before.second hold
//expressions known to be false / true beforehand. if the WP of some
//expression is known true, that expression's valuation is the only one
//this call contributes. otherwise every expression whose WP is not
//known false is possible, plus the "none true" valuation 0 when the
//predicate is incomplete. entries already in res on entry are kept.
/*********************************************************************/
void Predicate::GetPossibleValuations(const pair< set<Expr>,set<Expr> > &before,
				      const list<Expr> &lhsList,const list<Expr> &rhsList,vector<short> &res) const
{
  //valuations contributed by this call start at this position
  const size_t start = res.size();
  bool empty = lhsList.empty();
  //for each expression
  for(size_t i = 0;i < valToExpr.size();++i) {
    //valuation in which expression i is the true one
    const short val = (short)(complete ? i : (i + 1));
    //compute the WP of the expression
    Expr m = empty ? valToExpr[i] : ExprManager::ComputeWP(lhsList,rhsList,valToExpr[i]);
    //if the WP is known to be true then this is the only possible
    //valuation: drop the tentative ones this call pushed earlier
    if(before.second.count(m) != 0) {
      res.resize(start);
      res.push_back(val);
      return;
    }
    //otherwise this valuation is possible unless its WP is known false
    if(before.first.count(m) == 0) {
      res.push_back(val);
    }
  }
  //finally, an incomplete predicate may also have no true expression
  if(!complete) {
    res.push_back(0);
  }
}

/*********************************************************************/
//append to res the valuations of this predicate consistent with arg.
//if arg implies some expression, that expression's valuation is the
//only one this call contributes. otherwise every expression not
//provably disjoint with arg is possible, plus the "none true"
//valuation 0 when the predicate is incomplete. entries already in res
//on entry are kept.
/*********************************************************************/
void Predicate::GetConsistentValuations(const Expr &arg,vector<short> &res) const
{
  //valuations contributed by this call start at this position
  const size_t start = res.size();
  //for each expression
  for(size_t i = 0;i < valToExpr.size();++i) {
    //valuation in which expression i is the true one
    const short val = (short)(complete ? i : (i + 1));
    //if this expression is implied by the argument then this is the
    //only possible valuation: drop the tentative ones pushed earlier
    if(ExprManager::Implies(arg,valToExpr[i])) {
      res.resize(start);
      res.push_back(val);
      return;
    }
    //otherwise this valuation is possible unless the expression is
    //disjoint with the argument
    if(!ExprManager::DisjointWith(arg,valToExpr[i])) {
      res.push_back(val);
    }
  }
  //finally, an incomplete predicate may also have no true expression
  if(!complete) {
    res.push_back(0);
  }
}

/*********************************************************************/
//sanitize the predicate w.r.t. the supplied cont loc. remove all
//expressions whose history has the same location at the maximum
//predicate inference loop count.
/*********************************************************************/
void Predicate::Sanitize(const ContLoc *loc,Predicate &res) const
{
  vector<Expr> vte;
  vector< map<const ContLoc*,int> > hist;
  for(size_t i = 0;i < valToExpr.size();++i) {
    map<const ContLoc*,int>::const_iterator j = history[i].find(loc);
    if(j != history[i].end()) {
      if(j->second <= Database::MAX_PRED_INFER_LOOP) {
	vte.push_back(valToExpr[i]);
	map<const ContLoc*,int> hmap = history[i];
	hmap[loc] += 1;
	hist.push_back(hmap);
      } 
    } else {
      vte.push_back(valToExpr[i]);
      map<const ContLoc*,int> hmap = history[i];
      hmap[loc] = 1;
      hist.push_back(hmap);
    }
  }  

  res = Predicate(vte,type,hist);
  if(!vte.empty()) {
    res.ComputeCompletenessAndValNum();
  }
}

/*********************************************************************/
//end of Predicate.cpp
/*********************************************************************/
