/*******************************************************************\

Module:

Author: Daniel Kroening, kroening@cs.cmu.edu

\*******************************************************************/

#include <algorithm>

#include "simplify.h"
#include "mp_arith.h"
#include "arith_tools.h"
#include "replace_expr.h"
#include "trfalse.h"
#include "bitvector.h"

/*******************************************************************\

Function: simplify_implt::sort_operands

  Inputs: operand list

 Outputs: modifies operand list
          returns true iff nothing was changed

 Purpose: sort operands of an expression according to ordering
          defined by operator<

\*******************************************************************/

bool simplify_implt::sort_operands(exprt::operandst &operands)
 {
  // Sort only when some neighbouring pair is out of order, so that an
  // already ordered list is reported as unchanged and not re-sorted.
  bool already_sorted=TRUE;

  for(unsigned i=1; already_sorted && i<operands.size(); i++)
    if(ordering(operands[i], operands[i-1]))
      already_sorted=FALSE;

  if(already_sorted)
    return TRUE;

  std::sort(operands.begin(), operands.end(), ordering);

  return FALSE;
 }

/*******************************************************************\

Function: simplify_implt::simplify_typecast

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_typecast(exprt &expr)
 {
  // Simplifies a typecast expression:
  //  - folds casts of constants into constants of the target type,
  //  - removes an inner cast in (T)(T')x when T and T' have the same
  //    bit-vector kind and T is not wider than T',
  //  - pushes narrowing bit-vector casts into +, -, unary- and *.
  // Returns TRUE iff expr was left unchanged.

  if(expr.operands().size()!=1) return TRUE;

  const std::string &expr_type_id=expr.type().id;
  const exprt &operand=expr.operands()[0];
  const std::string &op_type_id=operand.type().id;

  unsigned expr_width=bv_width(expr.type());
  unsigned op_width=bv_width(operand.type());

  if(operand.is_constant())
   {
    const std::string &value=operand.get("value");

    exprt new_expr;
    new_expr.type()=expr.type();
    new_expr.id="constant";

    if(op_type_id=="integer" || op_type_id=="natural")
     {
      mp_integer int_value=string2integer(value);

      if(expr_type_id=="bool")
       {
        new_expr.set("value", (int_value!=0)?"true":"false");
        expr.swap(new_expr);
        return FALSE;
       }

      if(expr_type_id=="unsignedbv" || expr_type_id=="signedbv")
       {
        new_expr.set("value", integer2binary(int_value, expr_width));
        expr.swap(new_expr);
        return FALSE;
       }

      if(expr_type_id=="integer")
       {
        new_expr.set("value", value);
        expr.swap(new_expr);
        return FALSE;
       }

      // casts to real, complex and floatbv are not folded
     }
    else if(op_type_id=="bool")
     {
      // true becomes 1 and false becomes 0; the folded constant must be
      // swapped into expr before reporting a change
      if((expr_type_id=="unsignedbv" || expr_type_id=="signedbv") &&
         (operand.is_true() || operand.is_false()))
       {
        if(!from_integer(operand.is_true()?1:0, new_expr))
         {
          expr.swap(new_expr);
          return FALSE;
         }
       }
     }
    else if(op_type_id=="unsignedbv" ||
            op_type_id=="signedbv")
     {
      mp_integer int_value=binary2integer(value, op_type_id=="signedbv");

      if(expr_type_id=="bool")
       {
        new_expr.make_bool(int_value!=0);
        expr.swap(new_expr);
        return FALSE;
       }

      if(expr_type_id=="unsignedbv" || expr_type_id=="signedbv")
       {
        new_expr.set("value", integer2binary(int_value, expr_width));
        expr.swap(new_expr);
        return FALSE;
       }
     }

    // constants of rational and real type are not folded
   }
  else if(operand.id=="typecast") // typecast of typecast
   {
    // (T)(T')x ==> (T)x when T and T' have the same kind and T is not
    // wider than T'
    if(operand.operands().size()==1 &&
       op_type_id==expr_type_id &&
       (expr_type_id=="unsignedbv" || expr_type_id=="signedbv") &&
       expr_width<=op_width)
     {
      exprt tmp;
      tmp.swap((irept &)expr.operands()[0].operands()[0]);
      expr.operands()[0].swap(tmp);
      return FALSE;
     }
   }

  // propagate type casts into arithmetic operators:
  // (T)(a op b) ==> (T)a op (T)b for narrowing (or same-width) casts

  if((op_type_id=="unsignedbv" || op_type_id=="signedbv") &&
     (expr_type_id=="unsignedbv" || expr_type_id=="signedbv") &&
     (operand.id=="+" || operand.id=="-" || operand.id=="unary-" || operand.id=="*") &&
     expr_width<=op_width)
   {
    exprt new_expr;
    new_expr.swap(expr.operands()[0]);
    new_expr.type()=expr.type();

    Forall_operands(it, new_expr)
      it->make_typecast(expr.type());

    expr.swap(new_expr);

    return FALSE;
   }

  return TRUE;
 }

/*******************************************************************\

Function: simplify_implt::simplify_multiplication

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_multiplication(exprt &expr)
 {
  // Simplifies a product: any zero factor makes the product zero,
  // constant factors of the product's own type are folded into the first
  // one, a constant factor of one is dropped, and a single remaining
  // factor replaces the product. Returns TRUE iff expr was left unchanged.

  // check to see if it is a number type
  if(!is_number(expr.type()))
    return TRUE;

  exprt::operandst &operands=expr.operands();

  // result of the simplification
  bool result=TRUE;

  // position of the (first) constant factor; the others are folded into it
  exprt::operandst::iterator constant;

  // true if we have found a constant
  bool found=FALSE;

  // scan all the operands
  for(exprt::operandst::iterator it=operands.begin();
      it!=operands.end();)
   {
    // give up on non-numeric factors, but still report factors that
    // were already erased
    if(!is_number(it->type())) return result;

    // if one of the operands is zero the result is zero
    if(it->is_zero())
     {
      expr.make_zero();
      return FALSE;
     }

    // true if the given operand has to be erased
    bool do_erase=FALSE;

    // if this is a constant of the same type as the result
    if(it->is_constant() && it->type()==expr.type())
     {
      if(found)
       {
        // update the constant factor; erasing later elements keeps
        // 'constant' (an earlier element) valid
        if(!constant->mul(*it)) do_erase=TRUE;
       }
      else
       {
        // set it as the constant factor if this is the first
        constant=it;
        found=TRUE;
       }
     }

    if(do_erase)
     {
      it=operands.erase(it);
      result=FALSE;
     }
    else
      it++;
   }

  // the folded constant factor may itself be zero
  if(found && constant->is_zero())
   {
    expr.make_zero();
    return FALSE;
   }

  // a factor of one is dropped when other factors remain
  if(found && operands.size()>1 && constant->is_one())
   {
    operands.erase(constant);
    result=FALSE;
   }

  // a product of one factor is that factor
  if(operands.size()==1)
   {
    exprt product(operands[0]);
    expr.swap(product);

    result=FALSE;
   }

  return result;
 }

/*******************************************************************\

Function: simplify_implt::simplify_division

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_division(exprt &expr)
 {
  // Folds x/y for signedbv, unsignedbv, natural and integer when the
  // operands share the result type. Division by a known zero is never
  // folded. Returns TRUE iff expr was left unchanged.

  if(!is_number(expr.type()) || expr.operands().size()!=2)
    return TRUE;

  const std::string &type_id=expr.type().id;

  if(type_id!="signedbv" && type_id!="unsignedbv" &&
     type_id!="natural" && type_id!="integer")
    return TRUE;

  exprt &dividend=expr.operands()[0];
  exprt &divisor=expr.operands()[1];

  if(expr.type()!=dividend.type() || expr.type()!=divisor.type())
    return TRUE;

  mp_integer n, d;
  const bool n_known=!to_integer(dividend, n);
  const bool d_known=!to_integer(divisor, d);

  // leave x/0 alone
  if(d_known && d==0)
    return TRUE;

  // x/1 = x and 0/x = 0: in both cases the result is the dividend
  if((d_known && d==1) || (n_known && n==0))
   {
    exprt quotient;
    quotient.swap(dividend);
    expr.swap(quotient);
    return FALSE;
   }

  if(!n_known || !d_known)
    return TRUE;

  exprt folded;
  folded.type()=expr.type();

  mp_integer quotient_value=n/d;

  if(from_integer(quotient_value, folded))
    return TRUE;

  expr.swap(folded);
  return FALSE;
 }

/*******************************************************************\

Function: simplify_implt::simplify_modulo

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_modulo(exprt &expr)
 {
  // Folds x mod y for signedbv, unsignedbv, natural and integer when the
  // operands share the result type. Modulo by a known zero is never
  // folded. Returns TRUE iff expr was left unchanged.

  if(!is_number(expr.type()) || expr.operands().size()!=2)
    return TRUE;

  const std::string &type_id=expr.type().id;

  if(type_id!="signedbv" && type_id!="unsignedbv" &&
     type_id!="natural" && type_id!="integer")
    return TRUE;

  exprt &dividend=expr.operands()[0];
  exprt &divisor=expr.operands()[1];

  if(expr.type()!=dividend.type() || expr.type()!=divisor.type())
    return TRUE;

  mp_integer n, d;
  const bool n_known=!to_integer(dividend, n);
  const bool d_known=!to_integer(divisor, d);

  // leave x mod 0 alone (division by zero)
  if(d_known && d==0)
    return TRUE;

  // x mod 1 = 0 and 0 mod x = 0
  if((d_known && d==1) || (n_known && n==0))
   {
    expr.make_zero();
    return FALSE;
   }

  if(!n_known || !d_known)
    return TRUE;

  exprt folded;
  folded.type()=expr.type();

  mp_integer remainder=n%d;

  if(from_integer(remainder, folded))
    return TRUE;

  expr.swap(folded);
  return FALSE;
 }

/*******************************************************************\

Function: simplify_implt::simplify_addition_substraction

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_addition_substraction(exprt &expr)
 {
  // For "+": drops zero summands, folds constant summands of the sum's
  // own type into the first one, and collapses sums with fewer than two
  // summands. For binary "-": rewrites a-b into a+(-b).
  // Returns TRUE iff expr was left unchanged.

  if(!is_number(expr.type()))
    return TRUE;

  bool result=TRUE;

  exprt::operandst &operands=expr.operands();

  if(expr.id=="+")
   {
    // first constant summand; later constants are accumulated into it
    exprt::operandst::iterator const_sum;
    bool const_sum_set=FALSE;

    for(exprt::operandst::iterator it=operands.begin();
        it!=operands.end();)
     {
      // give up on non-numeric summands, but still report summands that
      // were already erased
      if(!is_number(it->type())) return result;

      bool do_erase=FALSE;

      if(it->is_zero())
        do_erase=TRUE;
      else if(it->is_constant() &&
              it->type()==expr.type())
       {
        if(!const_sum_set)
         {
          const_sum=it;
          const_sum_set=TRUE;
         }
        else
         {
          // erasing later elements keeps 'const_sum' valid
          if(!const_sum->sum(*it)) do_erase=TRUE;
         }
       }

      if(do_erase)
       {
        it=operands.erase(it);
        result=FALSE;
       }
      else
        it++;
     }

    if(operands.size()==0)
     {
      // the empty sum is zero
      expr.make_zero();
      return FALSE;
     }
    else if(operands.size()==1)
     {
      exprt tmp(operands[0]);
      expr.swap(tmp);
      return FALSE;
     }
   }
  else if(expr.id=="-")
   {
    // a-b ==> a+(-b); expr.type() is known to be numeric already
    if(operands.size()==2 &&
       is_number(operands[0].type()) &&
       is_number(operands[1].type()))
     {
      exprt tmp2;
      tmp2.id="unary-";
      tmp2.type()=expr.type();
      tmp2.move_to_operands(operands[1]);

      exprt tmp;
      tmp.id="+";
      tmp.type()=expr.type();
      tmp.move_to_operands(operands[0]);
      tmp.move_to_operands(tmp2);

      expr.swap(tmp);
      return FALSE;
     }
   }

  return result;
 }

/*******************************************************************\

Function: simplify_implt::simplify_if_implies

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_if_implies(exprt &expr, const exprt &cond,
     bool truth, bool &new_truth)
 {
  if(expr == cond) {
   new_truth = truth;
   return FALSE;
  }

  if(truth && cond.id == "<" && expr.id == "<")
   {
    if(cond.operands()[0] == expr.operands()[0] &&
	cond.operands()[1].is_constant() &&
	expr.operands()[1].is_constant() &&
	cond.operands()[1].type() == expr.operands()[1].type())
     {
      const irep_string &type_id = cond.operands()[1].type().id;
      if(type_id=="integer" || type_id=="natural")
       {
	if(string2integer(cond.operands()[1].get("value")) >=
	  string2integer(expr.operands()[1].get("value")))
	 {
	  new_truth = TRUE;
	  return FALSE;
	 }
       }
      else if(type_id=="unsignedbv")
       {
	const mp_integer i1, i2;
	if(binary2integer(cond.operands()[1].get("value"), FALSE) >=
	  binary2integer(expr.operands()[1].get("value"), FALSE))
	 {
	  new_truth = TRUE;
	  return FALSE;
	 }
       }
      else if(type_id=="signedbv")
       {
	const mp_integer i1, i2;
	if(binary2integer(cond.operands()[1].get("value"), TRUE) >=
	  binary2integer(expr.operands()[1].get("value"), TRUE))
	 {
	  new_truth = TRUE;
	  return FALSE;
	 }
       }
     }
    if(cond.operands()[1] == expr.operands()[1] &&
	cond.operands()[0].is_constant() &&
	expr.operands()[0].is_constant() &&
	cond.operands()[0].type() == expr.operands()[0].type())
     {
      const irep_string &type_id = cond.operands()[1].type().id;
      if(type_id=="integer" || type_id=="natural")
       {
	if(string2integer(cond.operands()[1].get("value")) <=
	  string2integer(expr.operands()[1].get("value")))
	 {
	  new_truth = TRUE;
	  return FALSE;
	 }
       }
      else if(type_id=="unsignedbv")
       {
	const mp_integer i1, i2;
	if(binary2integer(cond.operands()[1].get("value"), FALSE) <=
	  binary2integer(expr.operands()[1].get("value"), FALSE))
	 {
	  new_truth = TRUE;
	  return FALSE;
	 }
       }
      else if(type_id=="signedbv")
       {
	const mp_integer i1, i2;
	if(binary2integer(cond.operands()[1].get("value"), TRUE) <=
	  binary2integer(expr.operands()[1].get("value"), TRUE))
	 {
	  new_truth = TRUE;
	  return FALSE;
	 }
       }
     }
   }

  return TRUE;
 }

/*******************************************************************\

Function: simplify_implt::simplify_if_recursive

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_if_recursive(exprt &expr, const exprt &cond, bool truth)
 {
  // Replaces expr by a boolean constant when cond having the value
  // 'truth' decides it; otherwise recurses into every operand.
  // Returns TRUE iff nothing was changed.
  bool implied_value;

  if(expr.type().id == "bool" &&
     !simplify_if_implies(expr, cond, truth, implied_value))
   {
    if(implied_value)
      expr.make_true();
    else
      expr.make_false();

    return FALSE;
   }

  // visit all operands, even after one of them has changed
  bool unchanged = TRUE;

  Forall_operands(it, expr)
    if(!simplify_if_recursive(*it, cond, truth))
      unchanged = FALSE;

  return unchanged;
 }

/*******************************************************************\

Function: simplify_implt::simplify_if_conj

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_if_conj(exprt &expr, const exprt &cond)
 {
  // Used on the branch taken when the conjunction 'cond' holds: any
  // sub-expression equal to one of its conjuncts becomes true.
  // Returns TRUE iff nothing was changed.
  bool is_conjunct = FALSE;

  forall_operands(it, cond)
    if(expr == *it)
     {
      is_conjunct = TRUE;
      break;
     }

  if(is_conjunct)
   {
    expr.make_true();
    return FALSE;
   }

  bool unchanged = TRUE;

  Forall_operands(it, expr)
    if(!simplify_if_conj(*it, cond))
      unchanged = FALSE;

  return unchanged;
 }

/*******************************************************************\

Function: simplify_implt::simplify_if_disj

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_if_disj(exprt &expr, const exprt &cond)
 {
  // Used on the branch taken when the disjunction 'cond' fails: any
  // sub-expression equal to one of its disjuncts becomes false.
  // Returns TRUE iff nothing was changed.
  bool is_disjunct = FALSE;

  forall_operands(it, cond)
    if(expr == *it)
     {
      is_disjunct = TRUE;
      break;
     }

  if(is_disjunct)
   {
    expr.make_false();
    return FALSE;
   }

  bool unchanged = TRUE;

  Forall_operands(it, expr)
    if(!simplify_if_disj(*it, cond))
      unchanged = FALSE;

  return unchanged;
 }

/*******************************************************************\

Function: simplify_implt::simplify_if_branch

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_if_branch
 (exprt &trueexpr, exprt &falseexpr, const exprt &cond)
 {
  // Simplifies both branches of an if-then-else using what is known
  // about 'cond' in each, then re-simplifies any branch that changed.
  // Returns TRUE iff neither branch was changed.
  bool tresult, fresult;

  // then-branch: cond holds
  if(cond.id == "and")
    tresult = simplify_if_conj(trueexpr, cond);
  else
    tresult = simplify_if_recursive(trueexpr, cond, true);

  // else-branch: cond does not hold
  if(cond.id == "or")
    fresult = simplify_if_disj(falseexpr, cond);
  else
    fresult = simplify_if_recursive(falseexpr, cond, false);

  if(!tresult) simplify(trueexpr);
  if(!fresult) simplify(falseexpr);

  return tresult && fresult;
 }

/*******************************************************************\

Function: simplify_implt::simplify_if_cond

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_if_cond(exprt &expr)
 {
  // For a conjunction, simplifies each conjunct assuming all the others
  // hold, re-simplifying and repeating until a round changes nothing.
  // Returns TRUE iff nothing was changed.
  bool result = TRUE;

  for(;;)
   {
    bool unchanged = TRUE;

    if(expr.id == "and" && !expr.find("operands").is_nil())
     {
      exprt::operandst &operands = expr.operands();

      for(exprt::operandst::iterator target = operands.begin();
          target != operands.end(); target++)
        for(exprt::operandst::iterator assumed = operands.begin();
            assumed != operands.end(); assumed++)
          if(target != assumed)
            unchanged = simplify_if_recursive(*target, *assumed, true) &&
                        unchanged;
     }

    if(unchanged)
      break;

    simplify(expr);
    result = FALSE;
   }

  return result;
 }

/*******************************************************************\

Function: simplify_implt::simplify_if_main

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_if_main(exprt &expr)
 {
  // Simplifies the if-then-else expression c ? a : b.
  // Returns TRUE iff expr was left unchanged.

  std::vector<exprt> &operands=expr.operands();
  bool result = TRUE;

  if(operands.size()==3)
   {
    exprt &cond = operands[0];
    exprt &truevalue = operands[1];
    exprt &falsevalue = operands[2];

    // c ? a : a  ==>  a
    if(truevalue == falsevalue)
     {
      exprt tmp;
      tmp.swap(truevalue);
      expr.swap(tmp);
      return FALSE;
     }

    if(simplify_if)
     {
      // (not c) ? a : b  ==>  c ? b : a
      if(cond.id == "not" && cond.operands().size()==1)
       {
        exprt tmp;
        tmp.swap(cond.operands()[0]);
        cond.swap(tmp);
        truevalue.swap(falsevalue);
        result = FALSE; // expr has been rewritten
       }

      // use the condition to simplify itself and both branches
      result = simplify_if_cond(cond) && result;
      result = simplify_if_branch(truevalue, falsevalue, cond) && result;
     }

    // true ? a : b  ==>  a
    if(cond.is_true())
     {
      exprt tmp;
      tmp.swap(truevalue);
      expr.swap(tmp);
      return FALSE;
     }

    // false ? a : b  ==>  b
    if(cond.is_false())
     {
      exprt tmp;
      tmp.swap(falsevalue);
      expr.swap(tmp);
      return FALSE;
     }
   }

  return result;
 }

/*******************************************************************\

Function: simplify_implt::simplify_switch

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_switch(exprt &expr)
 {
  // std::vector<exprt> &operands=expr.operands();

  return TRUE;
 }

/*******************************************************************\

Function: simplify_implt::simplify_boolean

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_boolean(exprt &expr)
 {
  // Simplifies boolean connectives (not, =>, <=>, and, or, xor) and
  // reflexive = / notequal. Returns TRUE iff expr was left unchanged.

  if(expr.find("operands").is_nil()) return TRUE;

  std::vector<exprt> &operands=expr.operands();

  if(expr.type().id!="bool") return TRUE;

  if(expr.id=="not")
   {
    if(operands.size()!=1 ||
       operands[0].type().id!="bool")
      return TRUE;

    if(operands[0].id=="not") // (not not a) == a
     {
      if(operands[0].operands().size()==1)
       {
        exprt tmp(operands[0].operands()[0]);
        expr.swap(tmp);
        return FALSE;
       }
     }
    else if(operands[0].is_false())
     {
      expr.make_true();
      return FALSE;
     }
    else if(operands[0].is_true())
     {
      expr.make_false();
      return FALSE;
     }

    // not (a = b) ==> a notequal b, and vice versa
    if(operands[0].id=="=" || operands[0].id=="notequal")
     {
      exprt tmp(operands[0]);
      tmp.id=(operands[0].id=="=")?"notequal":"=";
      expr.swap(tmp);
      return FALSE;
     }
   }
  else if(expr.id=="=>")
   {
    if(operands.size()!=2 ||
       operands[0].type().id!="bool" ||
       operands[1].type().id!="bool")
      return TRUE;

    if(operands[0].is_false())
     {
      expr.make_true();
      return FALSE;
     }
    else if(operands[0].is_true())
     {
      exprt tmp(operands[1]);
      expr.swap(tmp);
      return FALSE;
     }
    else if(operands[1].is_false())
     {
      // a => false ==> not a
      expr.id="not";
      operands.erase(operands.begin()+1);
      return FALSE;
     }
    else if(operands[1].is_true())
     {
      expr.make_true();
      return FALSE;
     }
   }
  else if(expr.id=="<=>")
   {
    if(operands.size()!=2 ||
       operands[0].type().id!="bool" ||
       operands[1].type().id!="bool")
      return TRUE;

    if(operands[0].is_false())
     {
      // false <=> b ==> not b
      expr.id="not";
      operands.erase(operands.begin());
      return FALSE;
     }
    else if(operands[0].is_true())
     {
      exprt tmp(operands[1]);
      expr.swap(tmp);
      return FALSE;
     }
    else if(operands[1].is_false())
     {
      // a <=> false ==> not a
      expr.id="not";
      operands.erase(operands.begin()+1);
      return FALSE;
     }
    else if(operands[1].is_true())
     {
      exprt tmp(operands[0]);
      expr.swap(tmp);
      return FALSE;
     }
   }
  else if(expr.id=="or" || expr.id=="and" || expr.id=="xor")
   {
    if(operands.size()==0) return TRUE;

    bool result=TRUE;

    for(exprt::operandst::iterator it=operands.begin();
        it!=operands.end();)
     {
      // give up on non-boolean operands, but still report operands that
      // were already erased
      if(it->type().id!="bool") return result;

      bool erase=FALSE;
      bool is_true=it->is_true();
      bool is_false=it->is_false();

      if(expr.id=="and" && is_false)
       {
        expr.make_false();
        return FALSE;
       }
      else if(expr.id=="or" && is_true)
       {
        expr.make_true();
        return FALSE;
       }

      // neutral elements: true for and; false for or and xor
      if(expr.id=="and")
        erase=is_true;
      else
        erase=is_false;

      // Every operand before 'it' has been kept, so *(it-1) is the last
      // kept operand. Looking at it directly avoids holding an iterator
      // that operands.erase() may invalidate.
      if(it!=operands.begin() &&
         *it==*(it-1) &&
         (expr.id=="or" || expr.id=="and"))
        erase=TRUE; // erase adjacent duplicate operands

      if(erase)
       {
        it=operands.erase(it);
        result=FALSE;
       }
      else
        it++;
     }

    if(operands.size()==0)
     {
      if(expr.id=="and")
        expr.make_true();
      else
        expr.make_false();

      return FALSE;
     }
    else if(operands.size()==1)
     {
      exprt tmp(operands[0]);
      expr.swap(tmp);
      return FALSE;
     }

    return result;
   }
  else if(expr.id=="=" || expr.id=="notequal")
   {
    // a = a ==> true, a notequal a ==> false
    if(operands.size()==2 && operands[0]==operands[1])
     {
      if(expr.id=="=")
       {
        expr.make_true();
        return FALSE;
       }
      else
       {
        expr.make_false();
        return FALSE;
       }
     }
   }

  return TRUE;
 }

/*******************************************************************\

Function: simplify_implt::simplify_inequality

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_inequality(exprt &expr)
 {
  // Simplifies a binary relation (=, notequal, <, <=, >, >=) whose two
  // operands have the same type: identical operands, two integer
  // constants, and comparisons against zero.
  // Returns TRUE iff expr was left unchanged.

  std::vector<exprt> &operands=expr.operands();

  if(expr.type().id!="bool") return TRUE;

  if(operands.size()!=2) return TRUE;

  // types must match
  if(operands[0].type()!=operands[1].type())
    return TRUE;

  if(operands[0]==operands[1])
   {
    if(expr.id=="<=" || expr.id==">=" || expr.id=="=")
     {
      expr.make_true();
      return FALSE;
     }
    else if(expr.id=="<" || expr.id==">" || expr.id=="notequal")
     {
      expr.make_false();
      return FALSE;
     }
   }

  mp_integer int_value0, int_value1;
  bool ok0, ok1;

  ok0=!to_integer(operands[0], int_value0);
  ok1=!to_integer(operands[1], int_value1);

  if(ok0 && ok1)
   {
    if(expr.id=="<=")
     {
      expr.make_bool(int_value0 <= int_value1);
      return FALSE;
     }
    else if(expr.id=="<")
     {
      expr.make_bool(int_value0 <  int_value1);
      return FALSE;
     }
    else if(expr.id==">=")
     {
      expr.make_bool(int_value0 >= int_value1);
      return FALSE;
     }
    else if(expr.id==">")
     {
      expr.make_bool(int_value0 >  int_value1);
      return FALSE;
     }
    else if(expr.id=="=")
     {
      expr.make_bool(int_value0 == int_value1);
      return FALSE;
     }
    else if(expr.id=="notequal")
     {
      expr.make_bool(int_value0 != int_value1);
      return FALSE;
     }
   }

  // is one zero?

  if(ok0 && int_value0==0)
   {
    const std::string &other_type_id=operands[1].type().id;

    if(other_type_id=="unsignedbv" || other_type_id=="natural")
     {
      // zero is always smaller or equal something non-negative ...
      if(expr.id=="<=")
       {
        expr.make_true();
        return FALSE;
       }

      // ... and never greater than it
      if(expr.id==">")
       {
        expr.make_false();
        return FALSE;
       }
     }

    return zero_compare(expr, 0);
   }
  else if(ok1 && int_value1==0)
   {
    const std::string &other_type_id=operands[0].type().id;

    if(other_type_id=="unsignedbv" || other_type_id=="natural")
     {
      // something non-negative is always greater or equal zero ...
      if(expr.id==">=")
       {
        expr.make_true();
        return FALSE;
       }

      // ... and never smaller than it
      if(expr.id=="<")
       {
        expr.make_false();
        return FALSE;
       }
     }

    return zero_compare(expr, 1);
   }

  return TRUE;
 }

/*******************************************************************\

Function: simplify_implt::zero_compare

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::zero_compare(exprt &expr, unsigned side)
 {
  // Simplifies a binary relation whose operand at index 'side' is zero:
  //   -x rel 0     ==>  x rel' 0   (rel' = rel for =/notequal, else flipped)
  //   a+-b rel 0   ==>  a rel b    (and -b+a likewise)
  // Orderings are only rewritten in integer, rational and real, because
  // with bit-vector wrap-around neither rewrite preserves <, <=, > or >=.
  // Returns TRUE iff expr was left unchanged.

  unsigned other_side=1-side;

  exprt &operand=expr.operands()[other_side];

  const bool is_equality=expr.id=="=" || expr.id=="notequal";

  const std::string &op_type_id=operand.type().id;
  const bool unbounded=op_type_id=="integer" ||
                       op_type_id=="rational" ||
                       op_type_id=="real";

  if(!is_equality && !unbounded)
    return TRUE;

  if(operand.id=="unary-")
   {
    if(operand.operands().size()!=1) return TRUE;

    if(!is_equality)
     {
      // -x < 0 <=> x > 0, and likewise for the other orderings
      if(expr.id=="<")       expr.id=">";
      else if(expr.id==">")  expr.id="<";
      else if(expr.id=="<=") expr.id=">=";
      else if(expr.id==">=") expr.id="<=";
      else return TRUE;
     }

    exprt tmp;
    tmp.swap(operand.operands()[0]);
    operand.swap(tmp);
    return FALSE;
   }
  else if(operand.id=="+")
   {
    // simplify a+-b=0 to a=b

    exprt::operandst &summands=operand.operands();

    if(summands.size()!=2) return TRUE;

    // the negated summand is the first one if that is a unary minus
    // (-b+a), otherwise the second one (a+-b)
    unsigned neg=(summands[0].id=="unary-")?0:1;
    unsigned pos=1-neg;

    if(summands[neg].id!="unary-" ||
       summands[neg].operands().size()!=1)
      return TRUE;

    exprt tmp;
    tmp.id=expr.id;
    tmp.type()=expr.type();
    tmp.operands().resize(2);
    tmp.operands()[side].swap(summands[neg].operands()[0]);
    tmp.operands()[other_side].swap(summands[pos]);
    expr.swap(tmp);
    return FALSE;
   }

  return TRUE;
 }

/*******************************************************************\

Function: simplify_implt::simplify_relation

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_relation(exprt &expr)
 {
  // Only the (in)equality relations are simplified so far; everything
  // else is reported as unchanged.
  const bool is_comparison=
    expr.id=="=" || expr.id=="notequal" ||
    expr.id=="<" || expr.id=="<=" ||
    expr.id==">" || expr.id==">=";

  if(!is_comparison)
    return TRUE;

  return simplify_inequality(expr);
 }

/*******************************************************************\

Function: simplify_implt::simplify_lambda

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_lambda(exprt &expr)
 {
  bool result=TRUE;

  return result;
 }

/*******************************************************************\

Function: simplify_implt::simplify_index

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_index(exprt &expr)
 {
  // Beta-reduces (lambda i: e)(x) to e[i/x] when x has the type of the
  // bound variable i. Returns TRUE iff expr was left unchanged.
  if(expr.operands().size()!=2) return TRUE;

  exprt &function=expr.operands()[0];

  if(function.id!="lambda" || function.operands().size()!=2)
    return TRUE;

  exprt &argument=expr.operands()[1];
  exprt &bound=function.operands()[0];

  if(argument.type()!=bound.type())
    return TRUE;

  exprt body;
  body.swap(function.operands()[1]);

  replace_expr(bound, argument, body);

  expr.swap(body);
  return FALSE;
 }

/*******************************************************************\

Function: sort_and_join

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

// Operator/type pairs for which simplify_implt::sort_and_join flattens
// nested applications (op(a, op(b, c)) becomes op(a, b, c)) and sorts the
// operands, i.e. the operator is treated as associative and commutative.
//
// floatbv "+" and "*" are deliberately absent. floatbv is assumed to be
// IEEE-754 floating point, whose addition and multiplication are not
// associative, so re-associating or reordering them changes the result.
// The bitwise operators on floatbv are unaffected.
struct saj_tablet
 {
  const char *id;
  const char *type_id;
 } const saj_table[]=
 {
   { "+",      "integer"    },
   { "+",      "natural"    },
   { "+",      "real"       },
   { "+",      "complex"    },
   { "+",      "rational"   },
   { "+",      "unsignedbv" },
   { "+",      "signedbv"   },
   { "*",      "integer"    },
   { "*",      "natural"    },
   { "*",      "real"       },
   { "*",      "complex"    },
   { "*",      "rational"   },
   { "*",      "unsignedbv" },
   { "*",      "signedbv"   },
   { "and",    "bool"       },
   { "or",     "bool"       },
   { "xor",    "bool"       },
   { "bitand", "unsignedbv" },
   { "bitand", "signedbv"   },
   { "bitand", "floatbv"    },
   { "bitor",  "unsignedbv" },
   { "bitor",  "signedbv"   },
   { "bitor",  "floatbv"    },
   { "bitxor", "unsignedbv" },
   { "bitxor", "signedbv"   },
   { "bitxor", "floatbv"    },
   { NULL,     NULL         }
 };

// Returns true iff operator 'id' on type 'type_id' is listed in saj_table.
bool sort_and_join(const std::string &id, const std::string &type_id)
 {
  for(const saj_tablet *entry=saj_table; entry->id!=NULL; entry++)
    if(id==entry->id && type_id==entry->type_id)
      return true;

  return false;
 }

/*******************************************************************\

Function: simplify_implt::sort_and_join

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::sort_and_join(exprt &expr)
 {
  // For operator/type pairs accepted by ::sort_and_join, flattens nested
  // applications of the same operator (of the same type) into expr and
  // then sorts the operands. Returns TRUE iff expr was left unchanged.

  bool result=TRUE;

  if(expr.find("operands").is_nil()) return TRUE;

  if(!::sort_and_join(expr.id, expr.type().id)) return TRUE;

  // check operand types
  forall_operands(it, expr)
    if(it->type()!=expr.type()) return TRUE;

  exprt::operandst &operands=expr.operands();

  // join expressions
  for(unsigned i=0; i<operands.size();)
   {
    exprt &op=operands[i];
    bool join=(op.id==expr.id && op.type()==expr.type());

    if(join)
     {
      forall_operands(it, op)
        if(it->type()!=expr.type()) join=FALSE;
     }

    if(!join)
     {
      i++;
      continue;
     }

    // Take the nested operands out of 'op' first. Inserting a range that
    // lives inside an element of 'operands' into 'operands' itself is
    // unsafe when the insertion reallocates the vector.
    exprt::operandst nested;
    nested.swap(op.operands());

    operands.erase(operands.begin()+i);
    operands.insert(operands.begin()+i, nested.begin(), nested.end());

    // continue after the operands that were just spliced in
    i+=nested.size();

    result=FALSE;
   }

  // sort it
  result=sort_operands(operands) && result;

  return result;
 }

/*******************************************************************\

Function: simplify_implt::simplify_unary_minus

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

bool simplify_implt::simplify_unary_minus(exprt &expr)
 {
  // Simplifies -(-x) to x and folds the negation of integer, signedbv
  // and unsignedbv constants. Returns TRUE iff expr was left unchanged.

  if(expr.operands().size()!=1)
    return TRUE;

  if(!is_number(expr.type()))
    return TRUE;

  if(expr.type()!=expr.operands()[0].type())
    return TRUE;

  exprt &operand=expr.operands()[0];

  if(operand.id=="unary-")
   {
    // the inner negation must have exactly one operand
    if(operand.operands().size()!=1)
      return TRUE;

    if(!is_number(operand.operands()[0].type()))
      return TRUE;

    exprt tmp;
    tmp.swap(expr.operands()[0].operands()[0]);
    expr.swap(tmp);
    return FALSE;
   }
  else if(operand.id=="constant")
   {
    const std::string &type_id=expr.type().id;

    if(type_id=="integer" ||
       type_id=="signedbv" ||
       type_id=="unsignedbv")
     {
      mp_integer int_value;

      if(to_integer(expr.operands()[0], int_value))
        return TRUE;

      exprt tmp;
      tmp.type()=expr.type();

      // from_integer reports failure (TRUE) when the value cannot be
      // represented in the target type
      if(from_integer(-int_value, tmp))
        return TRUE;

      expr.swap(tmp);

      return FALSE;
     }
   }

  return TRUE;
 }

/*******************************************************************\

Function: simplify_implt::simplify_node

  Inputs: expr - expression to simplify; simplify() calls this after
                 the operands of expr have been simplified

 Outputs: modifies expr in place
          returns TRUE iff nothing was changed

 Purpose: simplifies the top-level node of expr: first applies
          sort_and_join, then the rewrite selected by expr.id

\*******************************************************************/

bool simplify_implt::simplify_node(exprt &expr)
 {
  // nodes without an "operands" entry are left alone
  if(expr.find("operands").is_nil()) return TRUE;

  bool result=TRUE;

  result=sort_and_join(expr) && result;

  if(expr.id=="typecast")
    result=simplify_typecast(expr) && result;
  else if(expr.id=="=" || expr.id=="notequal" ||
          expr.id==">" || expr.id=="<" ||
          expr.id==">=" || expr.id=="<=")
    result=simplify_relation(expr) && result;
  else if(expr.id=="if")
    result=simplify_if_main(expr) && result;
  else if(expr.id=="lambda")
    result=simplify_lambda(expr) && result;
  else if(expr.id=="index")
    result=simplify_index(expr) && result;
  else if(expr.id=="switch")
    result=simplify_switch(expr) && result;
  else if(expr.id=="/")
    result=simplify_division(expr) && result;
  else if(expr.id=="mod")
    result=simplify_modulo(expr) && result;
  else if(expr.id=="ashr" || expr.id=="lshr" || expr.id=="shl" ||
          expr.id=="bitnot" || expr.id=="bitand" || expr.id=="bitor" ||
          expr.id=="bitxor")
   {
    // shifts and bitwise operators: no node-level rewrite here
   }
  else if(expr.id=="+" || expr.id=="-")
    result=simplify_addition_substraction(expr) && result;
  else if(expr.id=="*")
    result=simplify_multiplication(expr) && result;
  else if(expr.id=="unary-")
    result=simplify_unary_minus(expr) && result;
  else if(expr.id=="=>"  || expr.id=="<=>" ||
          expr.id=="not" || expr.id=="or" ||
          expr.id=="xor" || expr.id=="and")
    result=simplify_boolean(expr) && result;
  else if(expr.id=="comma")
   {
    // (a, b, ..., z) --> z : the value of a comma expression is its
    // last operand
    if(expr.operands().size()!=0)
     {
      exprt tmp;
      tmp.swap(expr.operands()[expr.operands().size()-1]);
      expr.swap(tmp);
      result=FALSE;
     }
   }

  return result;
 }

/*******************************************************************\

Function: simplify_implt::simplify

  Inputs: expr - expression to simplify

 Outputs: modifies expr in place
          returns TRUE iff nothing was changed

 Purpose: bottom-up simplification: every operand is simplified
          recursively, then the node itself via simplify_node

\*******************************************************************/

bool simplify_implt::simplify(exprt &expr)
 {
  bool unchanged=TRUE;

  // only descend if the node actually has an "operands" entry
  if(!expr.find("operands").is_nil())
   {
    Forall_operands(op, expr)
      unchanged=simplify(*op) && unchanged; // recursive call
   }

  // the node itself is always visited, even if an operand changed
  unchanged=simplify_node(expr) && unchanged;

  return unchanged;
 }

/*******************************************************************\

Function: constant_propagationt::collect_node

  Inputs: expr - candidate formula

 Outputs: TRUE if expr is a boolean "=" or "notequal" with exactly
          two operands, at least one of which is a constant
          FALSE otherwise

 Purpose: decides whether expr can contribute an entry to the
          constant substitution map

\*******************************************************************/

bool constant_propagationt::collect_node(const exprt &expr)
 {
  if(expr.type().id!="bool")
    return FALSE;

  if(expr.id!="=" && expr.id!="notequal")
    return FALSE;

  if(expr.operands().size()!=2)
    return FALSE;

  return expr.operands()[0].is_constant() ||
         expr.operands()[1].is_constant();
 }

/*******************************************************************\

Function: constant_propagationt::collect

  Inputs: expr    - boolean formula (non-boolean input is ignored)
          premise - TRUE if expr is assumed to hold,
                    FALSE if expr is assumed not to hold

 Outputs: fills bool_is_true / bool_is_false and constants;
          marks the equations that seeded constants as "#protected"

 Purpose: gathers facts from a formula: its assumed truth value,
          variable/constant equalities, and (recursively) facts from
          conjuncts of a true "and", disjuncts of a false "or",
          the operand of a "not", and both sides of a false "=>"

\*******************************************************************/

void constant_propagationt::collect(exprt &expr, bool premise)
 {
  if(expr.type().id!="bool")
    return;

  // remember the assumed truth value of anything that is not
  // already a boolean literal
  if(!expr.is_true() && !expr.is_false())
   {
    sett &known=premise?bool_is_true:bool_is_false;
    known.insert(expr);
   }

  if(collect_node(expr))
   {
    // "a=b" assumed true and "a notequal b" assumed false both say a==b
    const bool is_equation=(expr.id=="=");

    if(is_equation==premise)
     {
      exprt &lhs=expr.operands()[0];
      exprt &rhs=expr.operands()[1];

      // map the other side to the constant side (lhs wins if both
      // sides are constants)
      if(lhs.is_constant())
        constants.insert(std::pair<exprt, exprt>(rhs, lhs));
      else
        constants.insert(std::pair<exprt, exprt>(lhs, rhs));

      // keeps apply_constants() from rewriting this equation itself
      expr.set("#protected", TRUE);
     }

    return;
   }

  const bool decompose=
    (expr.id=="and" && premise) ||  // every conjunct holds
    (expr.id=="or" && !premise);    // no disjunct holds

  if(decompose)
   {
    Forall_operands(it, expr)
      collect(*it, premise);
   }
  else if(expr.id=="not")
   {
    if(expr.operands().size()==1)
      collect(expr.operands()[0], !premise);
   }
  else if(expr.id=="=>")
   {
    // a failing implication: antecedent holds, consequent does not
    if(expr.operands().size()==2 && !premise)
     {
      collect(expr.operands()[0], TRUE);
      collect(expr.operands()[1], FALSE);
     }
   }
 }

/*******************************************************************\

Function: constant_propagationt::apply_constants

  Inputs: expr - expression to rewrite in place

 Outputs: returns TRUE iff nothing was changed

 Purpose: bottom-up, replaces every subexpression that is a key of
          the constants map by its mapped constant. Skips subtrees
          marked "#protected" and address_of / dereference nodes.

\*******************************************************************/

bool constant_propagationt::apply_constants(exprt &expr)
 {
  // the equations that seeded the map (marked by collect) stay as is
  if(expr.get_bool("#protected"))
    return TRUE;

  // no substitution underneath address_of / dereference
  // (presumably because their operand must remain an lvalue)
  if(expr.id=="address_of" || expr.id=="dereference")
    return TRUE;

  bool result=TRUE;

  Forall_operands(it, expr)
    result=apply_constants(*it) && result; // recursive call

  mapt::iterator entry=constants.find(expr);

  // Only report a change if the substitution really alters expr.
  // An equation between two equal constants, e.g. "5=5" in a formula
  // that is collected but never simplified, yields the self-mapping
  // 5->5. Counting that as a change on every pass could keep
  // simplify_implt::simplify(h_sequentt &) looping forever.
  if(entry!=constants.end() && expr!=entry->second)
   {
    expr=entry->second;
    result=FALSE;
   }

  return result;
 }

/*******************************************************************\

Function: constant_propagationt::clean_protected

  Inputs: expr - expression tree to clean

 Outputs: removes every "#protected" marker from expr and all of
          its subexpressions

 Purpose: undoes the marking done by collect() once the
          substitutions have been applied

\*******************************************************************/

void constant_propagationt::clean_protected(exprt &expr)
 {
  if(expr.get_bool("#protected"))
    expr.remove("#protected");

  Forall_operands(op, expr)
    clean_protected(*op); // recursive call
 }

/*******************************************************************\

Function: constant_propagationt::apply_bool

  Inputs: expr    - expression to rewrite in place
          premise - if TRUE, expr is replaced by false when it is
                    found in bool_is_false
          claim   - if TRUE, expr is replaced by true when it is
                    found in bool_is_true

 Outputs: returns TRUE iff nothing was changed

 Purpose: replaces boolean subformulae by true/false using the facts
          gathered by collect(). The two flags are swapped below a
          "not" and for the antecedent of a "=>". An "or" is only
          descended into when premise is set, an "and" only when
          claim is set, a "=>" only when premise is set. This
          presumably mirrors the cases that collect() decomposed for
          learning, so a formula is not rewritten by the facts it
          produced itself. Any other node (and any non-boolean
          expression) has its operands searched with both flags set.

\*******************************************************************/

bool constant_propagationt::apply_bool(exprt &expr,
                                       bool premise,
                                       bool claim)
 {
  // non-boolean node: look for boolean subformulae in its operands,
  // which may be replaced in either direction
  if(expr.type().id!="bool")
   {
    bool result=TRUE;

    Forall_operands(it, expr)
      result=apply_bool(*it, TRUE, TRUE) && result;

    return result;
   }

  // lookup expression
  // (the two-argument overload replaces expr and returns FALSE on a hit)

  if(premise)
    if(!apply_bool(expr, TRUE)) return FALSE;

  if(claim)
    if(!apply_bool(expr, FALSE)) return FALSE;

  // not found
  // check operands
  
  if(expr.id=="or" || expr.id=="and")
   {
    if((expr.id=="or" && premise) ||
       (expr.id=="and" && claim))
     {
      bool result=TRUE;

      Forall_operands(it, expr)
        result=apply_bool(*it, premise, claim) && result;

      return result;
     }
   }
  else if(expr.id=="not")
   {
    // negation swaps the roles of premise and claim
    if(expr.operands().size()==1)
      return apply_bool(expr.operands()[0], claim, premise);
   }
  else if(expr.id=="=>")
   {
    // a => b: the antecedent a has the opposite polarity of the
    // implication, the consequent b the same
    if(expr.operands().size()==2 && premise)
     {
      bool result=TRUE;
      result=apply_bool(expr.operands()[0], claim, premise) && result;
      result=apply_bool(expr.operands()[1], premise, claim) && result;
      return result;
     }
   }
  else // anything else was not used for learning
   {
    bool result=TRUE;

    Forall_operands(it, expr)
      result=apply_bool(*it, TRUE, TRUE) && result;

    return result;
   }
    
  return TRUE;
 }

/*******************************************************************\

Function: constant_propagationt::apply_bool

  Inputs: expr    - boolean expression to look up
          premise - TRUE:  look expr up in bool_is_false
                    FALSE: look expr up in bool_is_true

 Outputs: on a hit, expr is overwritten by the boolean literal
          (false resp. true) and FALSE is returned;
          otherwise expr is untouched and TRUE is returned

 Purpose: replaces a single formula whose truth value is known

\*******************************************************************/

bool constant_propagationt::apply_bool(exprt &expr,
                                       bool premise)
 {
  if(premise)
   {
    if(bool_is_false.find(expr)==bool_is_false.end())
      return TRUE;

    expr.make_false();
   }
  else
   {
    if(bool_is_true.find(expr)==bool_is_true.end())
      return TRUE;

    expr.make_true();
   }

  return FALSE;
 }

/*******************************************************************\

Function: constant_propagationt::constant_propagation

  Inputs: sequent - sequent to rewrite in place

 Outputs: returns TRUE iff no formula was changed; every matched
          formula that was rewritten gets its 'changed' flag set

 Purpose: learns facts from all premises (assumed true) and, if the
          sequent has no code, from all claims (assumed false), then
          uses them to rewrite the matched formulae

\*******************************************************************/

bool constant_propagationt::constant_propagation(h_sequentt &sequent)
 {
  // claims take part (learning and rewriting) only if there is no code
  const bool use_claims=sequent.code.is_nil();

  // 1. learn from every formula, matched or not
  Forall_formulae(it, sequent.premise)
    collect(*it, TRUE);

  if(use_claims)
    Forall_formulae(it, sequent.claim)
      collect(*it, FALSE);

  // 2. rewrite the matched formulae, flagging the ones that changed
  //    (as simplify_implt::simplify does for expression simplification)
  bool result=TRUE;

  Forall_formulae(it, sequent.premise)
    if(it->match)
     {
      bool unchanged=apply_bool(*it, TRUE, FALSE);
      unchanged=apply_constants(*it) && unchanged;

      if(!unchanged)
       {
        it->changed=TRUE;
        result=FALSE;
       }
     }

  if(use_claims)
   {
    Forall_formulae(it, sequent.claim)
      if(it->match)
       {
        bool unchanged=apply_bool(*it, FALSE, TRUE);
        unchanged=apply_constants(*it) && unchanged;

        if(!unchanged)
         {
          it->changed=TRUE;
          result=FALSE;
         }
       }
   }

  // 3. collect() marked equations "#protected" in every formula it
  //    visited, not just the matched ones; strip the marker from all
  //    of them so it does not leak into the unmatched formulae
  Forall_formulae(it, sequent.premise)
    clean_protected(*it);

  if(use_claims)
    Forall_formulae(it, sequent.claim)
      clean_protected(*it);

  return result;
 }

/*******************************************************************\

Function: simplify_implt::do_constant_propagation

  Inputs: sequent - sequent to rewrite in place

 Outputs: returns TRUE iff nothing was changed

 Purpose: runs one round of constant propagation with a fresh
          constant_propagationt (no facts carried over between rounds)

\*******************************************************************/

bool simplify_implt::do_constant_propagation(h_sequentt &sequent)
 {
  constant_propagationt propagator;
  return propagator.constant_propagation(sequent);
 }

/*******************************************************************\

Function: simplify_implt::simplify

  Inputs: sequent - sequent to simplify in place

 Outputs: returns TRUE iff nothing was changed;
          sets 'changed' on every matched formula that simplify()
          rewrote

 Purpose: repeatedly simplifies the matched premises and claims and,
          if enabled, runs constant propagation, until a whole round
          makes no progress

\*******************************************************************/

bool simplify_implt::simplify(h_sequentt &sequent)
 {
  bool unchanged=TRUE;

  for(bool progress=TRUE; progress; )
   {
    progress=FALSE;

    Forall_formulae(it, sequent.premise)
     {
      if(it->match && !simplify(*it))
       {
        progress=TRUE;
        it->changed=TRUE;
       }
     }

    Forall_formulae(it, sequent.claim)
     {
      if(it->match && !simplify(*it))
       {
        progress=TRUE;
        it->changed=TRUE;
       }
     }

    if(constant_propagation && !do_constant_propagation(sequent))
      progress=TRUE;

    if(progress)
      unchanged=FALSE;
   }

  return unchanged;
 }

/*******************************************************************\

Function: simplify

  Inputs: expr - expression to simplify

 Outputs: modifies expr in place
          returns TRUE iff nothing was changed

 Purpose: convenience wrapper around simplify_implt::simplify

\*******************************************************************/

bool simplify(exprt &expr)
 {
  simplify_implt simplifier;
  return simplifier.simplify(expr);
 }

/*******************************************************************\

Function: simplify

  Inputs: sequent              - sequent to simplify in place
          constant_propagation - also run constant propagation
          simplify_if          - value for simplify_implt::simplify_if

 Outputs: returns TRUE iff nothing was changed

 Purpose: configures a simplify_implt and simplifies the matched
          formulae of the sequent

\*******************************************************************/

bool simplify(h_sequentt &sequent, bool constant_propagation,
              bool simplify_if)
 {
  simplify_implt simplifier;

  simplifier.simplify_if=simplify_if;
  simplifier.constant_propagation=constant_propagation;

  return simplifier.simplify(sequent);
 }

/*******************************************************************\

Function: simplify

  Inputs: sequent              - sequent to simplify in place
          constant_propagation - also run constant propagation
          simplify_if          - value for simplify_implt::simplify_if
          fnums                - formula numbers to operate on

 Outputs: returns TRUE iff nothing was changed

 Purpose: selects formulae via match_fnum (presumably by setting
          their 'match' flags - confirm), then simplifies the sequent

\*******************************************************************/

bool simplify(h_sequentt &sequent, bool constant_propagation,
              bool simplify_if,
              const fnumst &fnums)
 {
  sequent.match_fnum(fnums);

  const bool unchanged=
    simplify(sequent, constant_propagation, simplify_if);

  return unchanged;
 }
