/*******************************************************************\

Module: SMV Language Type Checking

Author: Daniel Kroening, kroening@cs.cmu.edu

\*******************************************************************/

#include <expr_util.h>
#include <trfalse.h>
#include <typecheck.h>

#include "smv_typecheck.h"
#include "expr2smv.h"

// Type checker for parsed SMV modules: walks the modules held by an
// smv_parset, assigns types to their expressions and adds the
// resulting symbols to a contextt. Errors are written to the stream
// passed to the typecheckt base class.
class smv_typecheckt:public typecheckt
 {
 public:
  // _smv_parse: parse result whose modules are checked
  // _context:   symbol table that receives the generated symbols
  // _err:       error stream, forwarded to typecheckt
  smv_typecheckt(smv_parset &_smv_parse,
                 contextt &_context,
                 std::ostream &_err):
                 typecheckt(_err),
                 smv_parse(_smv_parse), context(_context),
                 module(0) // no module selected until convert(modulet &)
   { }

  virtual ~smv_typecheckt() { }

  // type check a whole module and add its symbols to the context
  void convert(smv_parset::modulet &smv_module);

  // type check a list of expressions, each must be boolean
  void convert(expr_listt &list);

  // translate a variable declaration into a type
  void convert(const smv_parset::mc_vart &var, typet &dest);

  // add one symbol per declared variable to the context
  void convert(smv_parset::mc_varst &vars);

  // record a checked "lhs = rhs" definition as a macro symbol
  void convert_define(const exprt &expr);

  // type check a single expression (recursively, in place)
  virtual void convert(exprt &expr);

  // type check every module of the parse result
  virtual void typecheck();

  // overload to use SMV syntax
  
  //virtual std::string to_string(const typet &type);
  virtual std::string to_string(const exprt &expr);

 protected:
  smv_parset &smv_parse;
  contextt &context;

  // module currently being converted; set by convert(modulet &)
  smv_parset::modulet *module;
 };

/*******************************************************************\

Function: smv_typecheckt::convert

  Inputs: expr - expression to be type checked; modified in place

 Outputs: none; expr and its operands get their type set, and some
          ids/identifiers are rewritten (see below)

 Purpose: Type checks one expression, operands first.

          symbol, next_symbol  looked up in module->vars, typed from
                               the declaration, identifier prefixed
                               with "smv::vars::"
          and, or, not         operands of one common type; if that
                               type is not bool the id becomes
                               bitand, bitor, bitnot
          =>                   two boolean operands, result bool
          =, notequal          two operands of equal type, result bool
          extractbit           one bit-vector operand, attribute
                               "index" must be within its width
          constant             "0" / "1" become false / true
          cond                 operands alternate condition, value
          AG                   one boolean operand, result bool

          Any violation writes a message to err and throws 0.

\*******************************************************************/

void smv_typecheckt::convert(exprt &expr)
 {
  // bottom-up: the checks below rely on the operand types
  Forall_operands(it, expr)
    convert(*it);

  if(expr.id=="symbol" || 
     expr.id=="next_symbol")
   {
    const std::string &identifier=expr.get("identifier");

    smv_parset::mc_varst::const_iterator it=
      module->vars.find(identifier);

    if(it==module->vars.end())
     {
      err_location(expr);
      err << "Variable " << identifier << " not found" << std::endl;
      throw 0;
     }

    convert(it->second, expr.type());

    // same naming scheme as used for the symbols in the context
    expr.set("identifier", "smv::vars::"+identifier);
   }
  else if(expr.id=="and" || expr.id=="or" || expr.id=="not")
   {
    if(expr.operands().size()==0)
     {
      err_location(expr);
      err << "Expected operands for " << expr << std::endl;
      throw 0;
     }

    // the result has the type of the operands
    expr.type()=expr.operands()[0].type();

    forall_operands(it, expr)
     {
      if(expr.type()!=it->type())
       {
        err_location(expr);
        err << "Expected operands of same type in " 
            << to_string(expr) << std::endl;
        throw 0;
       }
     }

    // on non-boolean operands these are the bit-wise operators
    if(expr.type().id!="bool")
     {
      if(expr.id=="and")
        expr.id="bitand";
      else if(expr.id=="or")
        expr.id="bitor"; // was "bitand", which turned every | into &
      else if(expr.id=="not")
        expr.id="bitnot";
     }
   }
  else if(expr.id=="=>")
   {
    if(expr.operands().size()!=2)
     {
      err_location(expr);
      err << "Expected two operands for " << expr << std::endl;
      throw 0;
     }

    expr.type()=typet("bool");

    forall_operands(it, expr)
     {
      if(expr.type()!=it->type())
       {
        err_location(expr);
        err << "Expected operands of boolean type in "
            << to_string(expr) << std::endl;
        throw 0;
       }
     }
   }
  else if(expr.id=="=" || expr.id=="notequal")
   {
    if(expr.operands().size()!=2)
     {
      err_location(expr);
      err << "Expected two operands for " << expr << std::endl;
      throw 0;
     }

    expr.type()=typet("bool");

    if(expr.operands()[0].type()!=expr.operands()[1].type())
     {
      err_location(expr);
      err << "Expected operands of same type in "
          << to_string(expr) << std::endl;
      throw 0;
     }
   }
  else if(expr.id=="extractbit")
   {
    if(expr.operands().size()!=1)
     {
      err_location(expr);
      err << "Expected one operand for " << expr << std::endl;
      throw 0;
     }

    expr.type()=typet("bool");

    const std::string &optypeid=expr.operands()[0].type().id;

    if(optypeid!="unsignedbv" &&
       optypeid!="signedbv" &&
       optypeid!="bv")
     {
      err_location(expr);
      err << "Expected bit vector operand in "
          << to_string(expr) << std::endl;
      throw 0;
     }

    const int index=atoi(expr.get("index").c_str());
    const int width=
      atoi(expr.operands()[0].type().get("width").c_str());

    // negative indices are out of bounds as well
    if(index<0 || index>=width)
     {
      err_location(expr);
      err << "Index out of bounds in "
          << to_string(expr) << std::endl;
      throw 0;
     }
   }
  else if(expr.id=="constant")
   {
    const std::string &value=expr.get("value");

    // only the two boolean constants are supported
    if(value=="0")
      expr.make_false();
    else if(value=="1")
      expr.make_true();
    else
     {
      err_location(expr);
      err << "Unexpected constant: " << value << std::endl;
      throw 0;
     }
   }
  else if(expr.id=="cond") // cases
   {
    if(expr.operands().size()<2)
     {
      err_location(expr);
      err << "Expected at least two operands for " << expr.id
          << " expression" << std::endl;
      throw 0;
     }

    if(expr.operands().size()%2)
     {
      err_location(expr);
      err << "Expected even number of operands for " << expr.id
          << " expression" << std::endl;
      throw 0;
     }

    // the first value (operand 1) determines the result type
    expr.type()=expr.operands()[1].type();

    // operands alternate: condition, value, condition, value, ...
    bool condition=true;

    forall_operands(it, expr)
     {
      if(condition)
       {
        if(it->type()!=typet("bool"))
         {
          err_location(expr);
          err << "Condition must be of boolean type in "
              << to_string(expr) << std::endl;
          throw 0;
         }
       }
      else
       {
        if(it->type()!=expr.type())
         {
          err_location(expr);
          err << "Expressions must be of same type in "
              << to_string(expr) << std::endl;
          throw 0;
         }
       }

      condition=!condition;
     }
   } 
  else if(expr.id=="AG")
   {
    if(expr.operands().size()!=1)
     {
      err_location(expr);
      // message used to claim "two operands" although one is required
      err << "Expected one operand for " << expr.id
          << " operator" << std::endl;
      throw 0;
     }

    expr.type()=typet("bool");

    if(expr.type()!=expr.operands()[0].type())
     {
      err_location(expr);
      err << "Expected operand of boolean type in "
          << to_string(expr) << std::endl;
      throw 0;
     }
   }
  else
   {
    // report the location like all other errors do
    err_location(expr);
    err << "No type checking for " << expr << std::endl;
    throw 0;
   }
 }

/*******************************************************************\

Function: smv_typecheckt::to_string

  Inputs: expr - expression to be printed

 Outputs: the text expr2smv() produced for expr

 Purpose: Renders an expression through expr2smv() so that error
          messages show SMV syntax. The return value of expr2smv()
          is not inspected.

\*******************************************************************/

std::string smv_typecheckt::to_string(const exprt &expr)
 {
  std::string smv_text;

  // expr2smv writes into its second argument
  expr2smv(expr, smv_text);

  return smv_text;
 }

/*******************************************************************\

Function: smv_typecheckt::convert

  Inputs: list - expressions to be type checked; modified in place

 Outputs: none

 Purpose: Type checks every expression of the list and insists that
          each of them is of type bool. The first entry that is not
          boolean is reported on err and 0 is thrown.

\*******************************************************************/

void smv_typecheckt::convert(expr_listt &list)
 {
  Forall_expr_list(entry, list)
   {
    convert(*entry);

    // boolean entries need no further attention
    if(entry->type().id=="bool")
      continue;

    err_location(*entry);
    err << "boolean expression expected" << std::endl;
    throw 0;
   }
 }

/*******************************************************************\

Function: smv_typecheckt::convert

  Inputs: var  - variable declaration from the parser
          dest - type object to be filled in

 Outputs: dest describes the type of var

 Purpose: Maps a declaration to a type: BOOL yields "bool", ARRAY
          yields "unsignedbv" whose "width" is var.size. For ARRAY
          only id and width of dest are written, anything else dest
          already carries is left alone. Any other kind of variable
          is reported on err and 0 is thrown.

\*******************************************************************/

void smv_typecheckt::convert(const smv_parset::mc_vart &var, typet &dest)
 {
  if(var.type==smv_parset::mc_vart::ARRAY)
   {
    // updated in place, dest is not reset first
    dest.id="unsignedbv";
    dest.set("width", var.size);
    return;
   }

  if(var.type==smv_parset::mc_vart::BOOL)
   {
    dest=typet("bool");
    return;
   }

  // neither of the two supported kinds
  err << "unexpected type" << std::endl;
  throw 0;
 }

/*******************************************************************\

Function: smv_typecheckt::convert

  Inputs: vars - the declared variables of the current module

 Outputs: none; one symbol per variable is added to the context

 Purpose: Creates a symbol named "smv::vars::<name>" for every
          declared variable. The symbol has a nil value, belongs to
          the module symbol of the current module, and is marked as
          input unless the variable is used with next. The result
          of context.add() is not inspected.

\*******************************************************************/

void smv_typecheckt::convert(smv_parset::mc_varst &vars)
 {
  // a single symbol object is filled in anew for every variable
  symbolt symbol;

  symbol.module=smv_module_symbol(module->name);
  symbol.mode="smv";

  smv_parset::mc_varst::const_iterator entry=vars.begin();

  while(entry!=vars.end())
   {
    const smv_parset::mc_vart &declaration=entry->second;

    symbol.base_name=entry->first;
    symbol.name="smv::vars::"+symbol.base_name;
    symbol.value.make_nil();

    // variables never used with next are treated as inputs
    symbol.is_input=!declaration.used_with_next;

    convert(declaration, symbol.type);

    context.add(symbol);

    ++entry;
   }
 }

/*******************************************************************\

Function: smv_typecheckt::convert_define

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

void smv_typecheckt::convert_define(const exprt &expr)
 {
  if(expr.id!="=" || expr.operands().size()!=2)
    throw "convert_define expects equality";

  const exprt &op0=expr.operands()[0];
  const exprt &op1=expr.operands()[1];

  if(op0.id!="symbol")
    throw "convert_define expects symbol on left hand side";

  const std::string &identifier=op0.get("identifier");

  symbolst::iterator it=context.symbols.find(identifier);

  if(it==context.symbols.end())
    throw "convert_define failed to find symbol "+identifier;

  if(!it->second.value.is_nil())
   {
    err_location(expr);
    err << "symbol " << it->second.base_name 
        << " defined twice" << std::endl;
    err << "original definition: " << it->second.value << std::endl;
    throw 0;
   }

  it->second.value=op1;
  it->second.is_input=FALSE;
  it->second.is_macro=TRUE;
 }

/*******************************************************************\

Function: smv_typecheckt::convert

  Inputs: smv_module - parsed module; its expression lists are type
                       checked in place, init and trans are moved
                       out of it

 Outputs: none; the variable symbols and the module symbol are added
          to the context

 Purpose: Makes smv_module the current module, creates the variable
          symbols, type checks all expression lists, records the
          definitions and builds the module symbol. Its value has
          id "trans" and three operands: the init expressions are
          moved into operand 1, the trans expressions into operand
          2, nothing is moved into operand 0 here. gen_and() is
          applied to each of the three operands. spec, invar and
          fairness are type checked only.

\*******************************************************************/

void smv_typecheckt::convert(smv_parset::modulet &smv_module)
 {
  // symbol lookups during conversion go through this pointer
  module=&smv_module;

  convert(smv_module.vars);

  // type check all sections, in the same order as ever
  convert(smv_module.spec);
  convert(smv_module.init);
  convert(smv_module.trans);
  convert(smv_module.define);
  convert(smv_module.invar);
  convert(smv_module.fairness);

  Forall_expr_list(definition, smv_module.define)
    convert_define(*definition);

  // now assemble the symbol that stands for the module itself
  symbolt module_symbol;

  module_symbol.mode="smv";
  module_symbol.base_name=smv_module.name;
  module_symbol.name=smv_module_symbol(smv_module.name);
  module_symbol.module=module_symbol.name;
  module_symbol.type.id="module";
  module_symbol.value.id="trans";
  module_symbol.value.operands().resize(3);

  Forall_expr_list(constraint, smv_module.init)
    module_symbol.value.operands()[1].move_to_operands(*constraint);

  Forall_expr_list(constraint, smv_module.trans)
    module_symbol.value.operands()[2].move_to_operands(*constraint);

  Forall_operands(part, module_symbol.value)
    gen_and(*part);

  context.add(module_symbol);
 }

/*******************************************************************\

Function: smv_typecheckt::typecheck

  Inputs:

 Outputs:

 Purpose:

\*******************************************************************/

void smv_typecheckt::typecheck()
 {
  for(smv_parset::modulest::iterator it=smv_parse.modules.begin();
      it!=smv_parse.modules.end(); it++)
    convert(*it);
 }

/*******************************************************************\

Function: smv_typecheck

  Inputs: smv_parse - parse result to be type checked
          context   - symbol table that receives the symbols
          err       - stream for error messages

 Outputs: the result of typecheck_main() of the checker object
          (presumably true on failure - confirm against typecheckt)

 Purpose: Entry point: sets up an smv_typecheckt for the given
          arguments and runs it.

\*******************************************************************/

bool smv_typecheck(smv_parset &smv_parse,
                   contextt &context,
                   std::ostream &err)
 {
  smv_typecheckt checker(smv_parse, context, err);

  const bool result=checker.typecheck_main();

  return result;
 }

/*******************************************************************\

Function: smv_module_symbol

  Inputs: module - name of an SMV module

 Outputs: the module name prefixed with "smv::module::"

 Purpose: Builds the identifier under which the symbol of a module
          is stored.

\*******************************************************************/

std::string smv_module_symbol(const std::string &module)
 {
  std::string symbol_name("smv::module::");

  symbol_name+=module;

  return symbol_name;
 }
