(* checker.ml *)
(* 15-411 *)
(* by Roland Flury *)
(* @version $Id: checker.ml,v 1.7 2003/09/17 18:54:07 rflury Exp $ *)

module HA=Hashtbl
module E=Errormsg
module TP = Temp
open Absyn

(* Raised by the context lookups getStructType / getFunType when a name that
   should already have been entered is missing; the payload is a
   human-readable description of what was not found. *)
exception CheckerError of string
(* Raised to abort checking right after an error has been reported through
   E.error (directly or via typeMismatch / typeMismatchText). *)
exception EXIT

(* Prototype of a function: its return type and the list of
   (parameter name, parameter type) pairs in declaration order *)
type funType = ctype * (string * ctype) list

(* Description of the fields of a struct: (field name, field type) pairs in
   declaration order *)
type structType = (string * ctype) list


(*** Context ***)

(* Function declarations (Sigma): function name -> prototype.
   Filled by addDecl for every FunDecl; cleared at the start of check. *)
let sigma : (string, funType) HA.t = HA.create 57

(* Struct-type context (Delta): struct name -> list of its fields.
   Filled by addDecl for every StructDecl; cleared at the start of check. *)
let delta : (string, structType) HA.t = HA.create 57

(* Variable context (Gamma): variable name -> type.
   addDecl enters every VarDecl here: the top-level ones during check, and
   the parameters and locals of a function during validate, which clears the
   table before each function body. *)
let gamma : (string, ctype) HA.t  = HA.create 57 

(* Context Xi: the declared return type of the function currently being
   checked (set by validate, consulted when checking 'return') *)
let returnType = ref Void

(* set to true if there is a return on ALL execution paths *)
let returnSeen = ref false
    
(*****************************************************************************)
(* HELPERS *)
(*****************************************************************************)

(* Translate a syntactic type (typ) into the checker's semantic type (ctype).
   A user-defined type name becomes the record type of that name; no check is
   made here that the record has actually been declared. *)
let rec typ2ctype t =
  match t with
  | VOID -> Void
  | Int -> TyInt
  | Bool -> TyBool
  | User (Ident (name, _)) -> TyRecord name
  | Pointer target -> TyPointer (typ2ctype target)

(* Returns the ctype annotation of an already type-checked expression
   (stored as the last constructor argument, or in the l-value). *)
let rec ctypeOfExp e =
  match e with
  | LVal lv -> ctypeOfLval lv
  | ConstExp (_, ty)
  | OpExp (_, _, _, _, ty)
  | Alloc (_, _, _, _, ty)
  | StackAlloc (_, _, _, _, ty)
  | Offset (_, _, ty)
  | Size (_, _, ty)
  | Reference (_, _, ty)
  | Call (_, _, _, ty) -> ty

(* Returns the ctype annotation of an already type-checked l-value *)
and ctypeOfLval lv =
  match lv with
  | Var (_, ty) | Deref (_, _, ty) | Field (_, _, _, ty) -> ty

(* Returns a string representation of a ctype, used in error messages.
   At the outermost level a struct type is shown with the types of its
   fields, e.g. "struct point{int, int}"; struct types nested inside it are
   shown by name only, so recursive structs are not expanded forever.
   Reports an error and raises EXIT if a struct is missing from delta. *)
let ctype2string t =
  let rec render ty expand =
    match ty with
    | TyInt -> "int"
    | TyBool -> "bool"
    | NS -> "NS"
    | Void -> "void"
    | TyPointer inner -> "(" ^ render inner expand ^ ")*"
    | TyRecord id ->
        if not expand then id
        else begin
          let fields =
            (try HA.find delta id
             with Not_found ->
               E.error (0,0) "Checker failed in ctyp2string: lost struct-info";
               raise EXIT) in
          let shown = List.map (fun (_, fty) -> render fty false) fields in
          "struct " ^ id ^ "{" ^ String.concat ", " shown ^ "}"
        end
  in
  render t true

(* Report a 'Type mismatch' error at [pos] for an expression of type
   [typrecv] where type [typexp] was required. *)
let typeMismatch pos typrecv typexp =
  let got = ctype2string typrecv
  and wanted = ctype2string typexp in
  E.error pos
    ("Type mismatch: This Expression has type \n  " ^ got ^
     "\nbut expected type is\n  " ^ wanted)

(* Report a 'Type mismatch' error at [pos] for an expression of type
   [typerecv]; [textexp] is appended verbatim and describes what was
   expected instead. *)
let typeMismatchText pos typerecv textexp =
  let got = ctype2string typerecv in
  E.error pos
    (String.concat ""
       ["Type mismatch: This Expression has type \n  "; got; "\n"; textexp])


(* Returns the field list recorded in delta for struct [sName].
   Raises CheckerError if no such struct has been entered. *)
let getStructType sName =
  if HA.mem delta sName then HA.find delta sName
  else
    raise (CheckerError ("Type declaration for struct-type " ^
                         sName ^ " not found in delta"))

(* Returns the prototype recorded in sigma for function [fName].
   Raises CheckerError if no such function has been entered. *)
let getFunType fName =
  if HA.mem sigma fName then HA.find sigma fName
  else
    raise (CheckerError ("Function declaration for function " ^
                         fName ^ " not found in sigma"))

(* Record (or overwrite) the prototype [typ] of function [fName] in sigma.
   Hashtbl.replace keeps at most one binding per function. Hashtbl.add would
   stack a new binding on top of the old one, and the stale prototype would
   resurface if the newer binding were ever removed. *)
let setFunType fName typ =
  HA.replace sigma fName typ

(*****************************************************************************)
(* Types *)
(*****************************************************************************)

(* Valid types: int, bool, a struct type already entered in delta, or a
   pointer to a valid type.  Every other type (e.g. NS, Void) is invalid. *)
let rec tValid t =
  match t with
  | TyPointer target -> tValid target
  | TyRecord name -> HA.mem delta name
  | TyInt | TyBool -> true
  | _ -> false


(* Compatibility of types: int and bool match themselves, pointers match
   when their pointee types are compatible, and struct types match by name
   (the struct must be declared in delta).  NS is the pointee type given to
   NULL: a pointer to NS matches another pointer to NS and any pointer to a
   valid type. *)
let rec typCmp a b =
  match a, b with
  | TyInt, TyInt | TyBool, TyBool | NS, NS -> true
  | TyPointer pa, TyPointer pb ->
      (match pa, pb with
       | NS, NS -> true
       | _, NS -> tValid pa
       | NS, _ -> tValid pb
       | _ -> typCmp pa pb)
  | TyRecord na, TyRecord nb -> HA.mem delta na && na = nb
  | _ -> false


(*****************************************************************************)
(* Extension of a context *)
(*****************************************************************************)

(* Check the declaration of function [fName] (named at [p1], return type
   [ty] written at [p2], parameter list [tylist]) and enter its prototype
   into sigma.  Every violation is reported via E.error and raises EXIT. *)
let addFunDecl fName p1 ty p2 tylist =
  (* Check for redeclaration *)
  if HA.mem sigma fName then begin
    E.error p1 ("Redeclaration of function '" ^ fName ^ "'");
    raise EXIT
  end;
  let isMain = (fName = "main") in
  (* Check the return type: void, or a valid non-struct type *)
  let cty = typ2ctype ty in
  (match (cty, tValid cty) with
   | (Void, _) -> ()
   | (TyRecord(_), _) ->
       E.error p2 ("Invalid return type: only int, bool" ^
                   " and pointer types are allowed");
       raise EXIT
   | (_, true) ->
       (* Only allow void 'main' functions *)
       if isMain then begin
         E.error p2 ("Main must be a void-function");
         raise EXIT
       end
   | _ ->
       (* Abort like every other declaration error instead of entering an
          unusable prototype and continuing. *)
       E.error p2 "Invalid return type";
       raise EXIT);
  (* Only allow 'main' without arguments *)
  if isMain && tylist <> [] then begin
    E.error p2 ("Main is not allowed to have parameters");
    raise EXIT
  end;
  (* Check that all parameters differ and have proper types *)
  let seen = ref [] in (* parameter names met so far *)
  let checkParam (Ident(id,pi), pty_ast, pty) =
    if List.mem id !seen then begin
      E.error pi ("Parameter name '" ^ id ^ "' reused");
      raise EXIT
    end;
    seen := id :: !seen;
    let pcty = typ2ctype pty_ast in
    if not (tValid pcty) then begin
      E.error pty "Invalid type";
      raise EXIT
    end;
    (id, pcty)
  in
  let params = List.map checkParam tylist in
  HA.add sigma fName (cty, params)

(* Returns true when [t] is a pointer (possibly through several levels of
   indirection) to the struct named [name]; this is the only way a field may
   refer to the struct that is being declared. *)
let rec isPtrToStruct name = function
  | TyPointer(TyRecord(n)) -> n = name
  | TyPointer(t) -> isPtrToStruct name t
  | _ -> false

(* Check the declaration of struct [sName] (named at [p1]) with the given
   [fields] and enter its field list into delta.  A field type must be valid
   at this point or be a pointer to [sName] itself.  Every violation is
   reported via E.error and raises EXIT. *)
let addStructDecl sName p1 fields =
  (* Check for redeclaration *)
  if HA.mem delta sName then begin
    E.error p1 ("Redeclaration of struct '" ^ sName ^ "'");
    raise EXIT
  end;
  (* Check that all fields differ and have proper types *)
  let seen = ref [] in (* field names met so far *)
  let checkField (Ident(id,pi), fty_ast, pty) =
    if List.mem id !seen then begin
      E.error pi ("Field name '" ^ id ^ "' reused");
      raise EXIT
    end;
    seen := id :: !seen;
    let cty = typ2ctype fty_ast in
    (* The struct is only entered into delta after all its fields have been
       checked, so a self-reference is only accepted through a pointer. *)
    if not (tValid cty) && not (isPtrToStruct sName cty) then begin
      (match cty with
       | TyRecord(id2) ->
           if id2 = sName then E.error pty "Incomplete type "
           else E.error pty "Invalid type"
       | _ -> E.error pty "Invalid type");
      raise EXIT
    end;
    (id, cty)
  in
  let checked = List.map checkField fields in
  HA.add delta sName checked

(* Check the declaration of variable [id] (named at [p1], type [ty] written
   at [p]) and enter it into gamma.  Every violation is reported via E.error
   and raises EXIT. *)
let addVarDecl id p1 ty p =
  if HA.mem gamma id then begin
    E.error p1 ("Redeclaration of variable '" ^ id ^ "'");
    raise EXIT
  end;
  let cty = typ2ctype ty in
  if not (tValid cty) then begin
    (* Reported at the type, like invalid parameter and field types *)
    E.error p ("Invalid type");
    raise EXIT
  end;
  HA.add gamma id cty

(* Extend the matching context with one declaration: functions go to sigma,
   structs to delta, variables to gamma.  Errors are reported via E.error
   and abort with EXIT. *)
let addDecl = function
  | FunDecl(Ident(fName,p1), ty, p2, tylist, _) ->
      addFunDecl fName p1 ty p2 tylist
  | StructDecl(Ident(sName,p1), fields) ->
      addStructDecl sName p1 fields
  | VarDecl(Ident(id,p1),ty,p) ->
      addVarDecl id p1 ty p


(*****************************************************************************)
(* Expressions *)
(*****************************************************************************)

(* Type-check an expression and return a copy in which every node is
   annotated with its ctype (the last constructor argument).  Errors are
   reported via E.error and abort with EXIT.  Comparing pointers of
   incompatible types only produces a warning (E.warning).  Taking the
   address of a variable marks it as escaping (TP.addEscape). *)
let rec exp = function
(* Constants *)
  | ConstExp(IntConst(i,p),_) -> ConstExp(IntConst(i,p),TyInt)
  | ConstExp(BoolConst(b,p),_) -> ConstExp(BoolConst(b,p),TyBool)
  | ConstExp(NULL(p), _) -> ConstExp(NULL(p), TyPointer(NS))
(* Binary Operators *)
  | OpExp(e1, op, Some(e2), p, _) ->
      let exp1 = exp e1 in
      let exp2 = exp e2 in
      let ty1 = ctypeOfExp exp1 in
      let ty2 = ctypeOfExp exp2 in
      (match (ty1, op, ty2) with
      | (TyInt, (PLUS|MINUS|TIMES|DIVIDE|MOD|BITAND|
                 BITOR|BITXOR|SHIFTLEFT|SHIFTRIGHT), TyInt) ->
          OpExp(exp1, op, Some(exp2), p, TyInt)
      | (TyInt, (EQ|NEQ|LT|LTE|GT|GTE), TyInt) ->
          OpExp(exp1, op, Some(exp2), p, TyBool)
      | (TyBool, (LOGICAND|LOGICOR|LOGICXOR), TyBool) ->
          OpExp(exp1, op, Some(exp2), p, TyBool)
      | (TyPointer(t), (PPLUS|PMINUS), TyInt) ->
          OpExp(exp1, op, Some(exp2), p, TyPointer(t))
      | (TyPointer(_), (PEQ|PNEQ), TyPointer(_)) ->
          if typCmp ty1 ty2 then
            OpExp(exp1, op, Some(exp2), p, TyBool)
          else begin
            (* Pointers of incompatible types are treated as never equal,
               so '==' folds to false and '!=' folds to true.
               NOTE(review): the folded constant discards exp1/exp2 together
               with any side effects (e.g. calls) they contain. *)
            let value = (match op with PNEQ -> true | _ -> false) in
            E.warning p ("The types of the operands differ - " ^
                         "This expression will always evaluate to " ^
                         (if value then "true" else "false"));
            ConstExp(BoolConst(value,p), TyBool)
          end
      | _ -> E.error p "Type mismatch"; raise EXIT
      )
(* Unary Operators *)
  | OpExp(e1, op, None, p, _) ->
      let exp1 = exp e1 in
      let ty1 = ctypeOfExp exp1 in
      (match (ty1, op) with
      | (TyInt, (BITNOT|UMINUS)) ->
          OpExp(exp1, op, None, p, TyInt)
      | (TyBool, LOGICNOT) ->
          OpExp(exp1, op, None, p, TyBool)
      | _ -> E.error p "Type mismatch"; raise EXIT
      )
  | LVal(l) -> LVal(lval l)
  | Reference(l, p, _) ->
      let cl = lval l in
      (* Detect escaping variables (conservative approximation).
         NOTE(review): only a directly referenced variable is marked; the
         base variable of e.g. &x.f is not -- confirm this is intended. *)
      (match cl with
      | Var(Ident(name,_),_) -> TP.addEscape name
      | _ -> ()
      );
      Reference(cl, p, TyPointer(ctypeOfLval cl))
  | Alloc(e, p1, ty, p2, _) ->
      let cty = typ2ctype ty in
      let exp1 = exp e in
      let te = ctypeOfExp exp1 in
      (* The element count must be an int *)
      if not (typCmp te TyInt) then begin
        typeMismatch p1 te TyInt;
        raise EXIT
      end;
      if not (tValid cty) then begin
        E.error p2 "Invalid type";
        raise EXIT
      end;
      Alloc(exp1, p1, ty, p2, TyPointer(cty))
  | StackAlloc(_, _, _, _, _) ->
      (* NOTE(review): presumably StackAlloc nodes are only created after
         type checking; reaching this raises Assert_failure. *)
      assert false
  | Offset(e,p,_) ->
      let exp1 = exp e in
      let cty = ctypeOfExp exp1 in
      (match cty with
      | TyPointer(_) -> Offset(exp1,p,TyInt)
      | _ ->
          typeMismatchText p cty "but a pointer-type is expected";
          raise EXIT
      )
  | Size(e,p,_) ->
      let exp1 = exp e in
      let cty = ctypeOfExp exp1 in
      (match cty with
      | TyPointer(_) -> Size(exp1,p,TyInt)
      | _ ->
          typeMismatchText p cty "but a pointer-type is expected";
          raise EXIT
      )
  | Call(Ident(id,p1), param, p2, _) ->
      let (rettype, reqParam) =
        (try
          HA.find sigma id
        with Not_found ->
          E.error p1 ("Unknown function '" ^ id ^ "'");
          raise EXIT
        ) in
      (* Check the arity up front: a wrong argument count is reported before
         any argument is checked, and an Invalid_argument raised while
         checking an argument can no longer be misreported as an arity
         error. *)
      if List.length param <> List.length reqParam then begin
        E.error p2 "Wrong number of arguments";
        raise EXIT
      end;
      let li =
        List.map2 (fun (e, pe) (_, prototy) ->
          let exp1 = exp e in
          let cty = ctypeOfExp exp1 in
          if not (typCmp prototy cty) then begin
            typeMismatch pe cty prototy;
            raise EXIT
          end;
          (exp1, pe)
        ) param reqParam in
      Call(Ident(id,p1), li, p2, rettype)

(* L-Values: type-check an l-value and return it annotated with its ctype.
   Errors are reported via E.error and abort with EXIT. *)
and lval = function
  | Var(Ident(name,pos),_) ->
      let cty =
        (try HA.find gamma name
         with Not_found ->
           E.error pos ("Unknown variable '" ^ name ^ "'");
           raise EXIT) in
      Var(Ident(name,pos), cty)
  | Deref(e, p, _) ->
      let checked = exp e in
      (match ctypeOfExp checked with
       | TyPointer(target) -> Deref(checked, p, target)
       | other ->
           typeMismatchText p other "but a pointer-type is expected";
           raise EXIT)
  | Field(l, Ident(fieldName,p1), p2, _) ->
      let base = lval l in
      (match ctypeOfLval base with
       | TyRecord(sName) ->
           let layout =
             (try HA.find delta sName
              with Not_found ->
                E.error p2 ("Unknown struct '" ^ sName ^ "'");
                raise EXIT) in
           let fieldTy =
             (try List.assoc fieldName layout
              with Not_found ->
                E.error p1 ("Unknown field '" ^ fieldName ^ "'");
                raise EXIT) in
           Field(base, Ident(fieldName,p1), p2, fieldTy)
       | other ->
           typeMismatchText p2 other "but a struct-type is expected";
           raise EXIT)


(*****************************************************************************)
(* Statements *)
(*****************************************************************************)

(* Type-check a statement and return it with all expressions annotated.
   [loop] is true if the stmt is in a loop, else false; 'break' and
   'continue' are only accepted when it is true.
   Maintains [returnSeen]: a 'return' sets it.  An 'if' without 'else' and
   the loops restore the value it had before them, because a return there
   does not cover all paths.  An 'if'/'else' adds a return only if both
   branches return.
   Errors are reported via E.error and abort with EXIT. *)
let rec stmt loop = function
  | Assign(l, p1, e, p2) ->
      let cl = lval l in
      let ce = exp e in
      let tl = ctypeOfLval cl in
      let te = ctypeOfExp ce in
      if not (typCmp tl te) then begin
        typeMismatch p2 te tl;
        raise EXIT
      end;
      (match te with
      | TyRecord(_) ->
          E.error p2 "Deep copy of structs is not allowed";
          raise EXIT
      | _ ->
          (* Return the sides checked above.  Re-running lval/exp on the raw
             AST would repeat their side effects (duplicate warnings,
             duplicate escape registrations). *)
          Assign(cl, p1, ce, p2))
  | StmtExp(e,p) -> StmtExp(exp e, p)
  | Return(Some(e),p) ->
      let ce = exp e in
      let te = ctypeOfExp ce in
      if not (typCmp te !returnType) then begin
        typeMismatch p te !returnType;
        raise EXIT
      end;
      returnSeen := true;
      Return(Some(ce),p)
  | Return(None, p) ->
      if not (!returnType = Void) then begin
        typeMismatch p Void !returnType;
        raise EXIT
      end;
      returnSeen := true;
      Return(None, p)
  | Continue(p) ->
      if loop then Continue(p)
      else begin
        E.error p "Continue only allowed in loops";
        raise EXIT
      end
  | Break(p) ->
      if loop then Break(p)
      else begin
        E.error p "Break only allowed in loops";
        raise EXIT
      end
  | Blank -> Blank
(* Control Statements *)
  | If(e, p1, li, p2) ->
      let tmp = !returnSeen in (* Return in if-block does not count *)
      let ce = exp e in
      let te = ctypeOfExp ce in
      if not (typCmp te TyBool) then begin
        typeMismatch p1 te TyBool;
        raise EXIT
      end;
      let cli = List.map (fun s -> stmt loop s) li in
      returnSeen := tmp;
      If(ce, p1, cli, p2)
  | IfElse(e, p1, li1, li2, p2) ->
      let tmp = !returnSeen in
      let ce = exp e in
      let te = ctypeOfExp ce in
      if not (typCmp te TyBool) then begin
        typeMismatch p1 te TyBool;
        raise EXIT
      end;
      let cli1 = List.map (fun s -> stmt loop s) li1 in
      let ret1 = !returnSeen in
      returnSeen := false;
      let cli2 = List.map (fun s -> stmt loop s) li2 in
      (* Returns on all paths if it already did, or if both branches do *)
      returnSeen := (tmp || (!returnSeen && ret1));
      IfElse(ce, p1, cli1, cli2, p2)
  | For(s1, p1, e, p2, s2, p3, li) ->
      let tmp = !returnSeen in (* Return in loop does not count *)
      let ce = exp e in
      let te = ctypeOfExp ce in
      if not (typCmp te TyBool) then begin
        typeMismatch p1 te TyBool;
        raise EXIT
      end;
      let cs1 = stmt loop s1 in
      let cs2 = stmt true s2 in
      let cli = List.map (fun s -> stmt true s) li in
      returnSeen := tmp;
      For(cs1, p1, ce, p2, cs2, p3, cli)
  | While(e, p, li) ->
      let tmp = !returnSeen in (* Return in loop does not count *)
      let ce = exp e in
      let te = ctypeOfExp ce in
      if not (typCmp te TyBool) then begin
        typeMismatch p te TyBool;
        raise EXIT
      end;
      let cli = List.map (fun s -> stmt true s) li in
      returnSeen := tmp;
      While(ce, p, cli)


(*****************************************************************************)
(* Type-check Functions *)
(*****************************************************************************)

(* Type-check one top-level declaration.  Struct and variable declarations
   and foreign functions were already handled by addDecl and are returned
   unchanged.  For a function with a body:
   - reset the per-function state (gamma, returnSeen, returnType, and the
     escape-analysis scope in TP);
   - enter its parameters and locals into gamma;
   - check every statement;
   - unless the function is void, require a return on all paths.
   NOTE(review): clearing gamma also drops the top-level VarDecls that check
   entered, so they are not visible in function bodies -- confirm that this
   is intended. *)
let validate d =
  match d with
  | StructDecl _ | VarDecl _ | FunDecl (_, _, _, _, Foreign _) -> d
  | FunDecl (Ident (fName, p1), tr, p2, paramlist, Body (dli, sli)) ->
      HA.clear gamma;
      TP.checkNewFun fName;   (* enter new scope for escaping analysis *)
      returnSeen := false;
      returnType := typ2ctype tr;
      List.iter (fun (i, t, p) -> addDecl (VarDecl (i, t, p))) paramlist;
      List.iter addDecl dli;
      let csli = List.map (stmt false) sli in
      TP.checkedFun fName;    (* store escaping analysis for this fun *)
      let returnsOk =
        (match tr with
         | VOID -> true
         | _ -> !returnSeen) in
      if returnsOk then
        FunDecl (Ident (fName, p1), tr, p2, paramlist, Body (dli, csli))
      else begin
        E.error p1 ("Probably not all paths of execution " ^
                    "contain a return statement");
        raise EXIT
      end


(*****************************************************************************)
(* Type-check a Program *)
(*****************************************************************************)

(* Type-check a program.  First every top-level declaration is entered into
   the contexts, so a function may be called before its definition.  Then
   each declaration is validated.  The hash-table contexts are reset first,
   so check can be run more than once. *)
let check (Program(li)) =
  HA.clear delta;
  HA.clear sigma;
  (* gamma must be reset as well.  Otherwise variables left over from a
     previous run make top-level VarDecls fail with "Redeclaration of
     variable". *)
  HA.clear gamma;
  List.iter addDecl li;
  Program(List.map validate li)
