(* Type checking MinML *)
(* Operates on de Bruijn form of expressions *)

(* Interface of the MinML type checker.  All entry points work on the
   de Bruijn representation declared in DBMinML. *)
signature TYPING =
sig

  (* Raised, with a human-readable message, whenever type checking fails *)
  exception Error of string

  (* typeOf (e) = t for the t such that |- e : t *)
  (* raises Error if no such t exists *)
  val typeOf : DBMinML.exp -> T.typ

  (* typeOpt (e) = SOME(t) for the t such that |- e : t *)
  (* typeOpt (e) = NONE if no such t exists *)
  (* (non-raising variant of typeOf: an Error is turned into NONE) *)
  val typeOpt : DBMinML.exp -> T.typ option

  (* typeModule (m) = () if the module m is well-typed *)
  val typeModule : DBMinML.module -> unit   (* raises Error if ill-typed *)
  (* typeModuleBool (m) = true if typeModule (m) succeeds, *)
  (* false if it raises Error *)
  val typeModuleBool : DBMinML.module -> bool

end;  (* signature TYPING *)

structure Typing :> TYPING =
struct

  open T
  open P
  open DBMinML

  exception Error of string

  (* typing contexts are lists of types, most recently bound variable first *)
  (* ctx ++ ty = the context ctx extended with a new innermost binding ty *)
  infix 1 ++
  fun ctx ++ ty = ty :: ctx
  (* lookup : var -> typ list -> typ *)
  (* lookup x ctx = the x-th entry of ctx, counting from 1 at the head *)
  (* raises Error "Unbound variable" if ctx is exhausted before x reaches 1 *)
  fun lookup _ nil = raise Error "Unbound variable"
    | lookup 1 (t::_) = t
    | lookup x (_::rest) = lookup (x - 1) rest

  (* type checking *)
  (* typeOf' : Bindings.bindings -> typ list -> exp -> typ *)
  (* typeOf' bind ctx e = t, the type of the de Bruijn expression e.
     ctx holds the types of the enclosing binders, innermost first, so that
     variable 1 refers to the head of ctx.  bind is passed unchanged to the
     Module functions that answer questions about classes, fields and
     methods.
     Raises Error with a descriptive message if e is ill-typed. *)
  fun typeOf' bind _ (Int _) = INT
    | typeOf' bind _ (Bool _) = BOOL
    | typeOf' bind ctx (If (e, e1, e2)) =
      (* the condition must be BOOL and both branches must have the same type *)
      let val t = typeOf' bind ctx e
      in case t
          of BOOL => let val t1 = typeOf' bind ctx e1
                         val t2 = typeOf' bind ctx e2
                     in
                       if t1 = t2 then t1 else
                       raise Error "IF branches differ in type"
                     end
           | _ => raise Error "Expected BOOL for IF conditional"
      end
    | typeOf' bind ctx (Primop (primop, elist)) =
      (* match the arguments one by one against the primop's domain *)
      let val (domain, codomain) = typeOfPrimop primop
          fun check_type (t::tlist) (e::elist) =
              if t = typeOf' bind ctx e then
                check_type tlist elist
              else
                raise Error "Unexpected type of primop argument"
            | check_type nil nil = codomain
              (* reached when the argument count differs from the domain length *)
            | check_type _ _ = raise Error "Impossible"
      in
        check_type domain elist
      end
    | typeOf' bind ctx (Fn (t1, (x, e))) =
      (* the body is checked with the parameter type as the innermost binding *)
      let val t2 = typeOf' bind (ctx++t1) e
      in
        T.ARROW(t1, t2)
      end
    | typeOf' bind ctx (Rec (t, (x, e))) =
      (* the body must have exactly the type declared for the recursive name *)
      let val t' = typeOf' bind (ctx++t) e
      in
        if t = t' then t else raise Error "Unexpected type for REC"
      end
    | typeOf' bind ctx (Let (e1, (x, e2))) =
      let val t1 = typeOf' bind ctx e1
      in
        typeOf' bind (ctx++t1) e2
      end
    | typeOf' bind ctx (Apply (e1, e2)) =
      (* the function must have ARROW type whose domain equals the argument type *)
      let val t1 = typeOf' bind ctx e1
      in case t1
          of ARROW (t11, t12)
             => let val t2 = typeOf' bind ctx e2
                in
                  if t11 = t2 then t12 else
                  raise Error "Actual type does not match formal type"
                end
           | _ => raise Error "Expected to apply value of ARROW type"
      end
    | typeOf' bind ctx (Var x) = lookup x ctx
    | typeOf' bind ctx (Object (c, args)) =
      let
          (* checkField' (xtra, rest) (lab, t): scan rest for the initializer
             labelled lab.  If found with type t, return all the other
             initializers; xtra accumulates the ones skipped on the way.
             Raises Error if lab is absent or its initializer has another type. *)
          fun checkField' (xtra, []) (lab, t) = raise Error ("Missing field " ^ lab)
            | checkField' (xtra, (lab',e)::tl) (lab, t) =
              if lab = lab'
              then if t = typeOf' bind ctx e
                   then xtra @ tl
                   else raise Error ("Type mismatch in field " ^ lab)
              else checkField' ((lab',e)::xtra, tl) (lab,t)
          fun checkField args (lab,t) = checkField' ([],args) (lab,t) (* removes the arg if type-correct *)
          (* checkObject super args fields: consume one initializer per declared
             field, then continue with the fields of the superclass (if any).
             Initializers left over once the chain ends are an error. *)
          fun checkObject NONE [] ([]:(label * T.typ) list) = ()
            | checkObject NONE ((l,_)::_) [] = raise Error ("Superfluous fields "^l^",...")
            | checkObject (SOME c) args [] =
              let val (_, superclass, repn) = Module.lookupClassIn bind c
              in checkObject superclass args repn
              end
            | checkObject sc args (field::tl) = checkObject sc (checkField args field) tl

          val (abstract, superclass, repn) = Module.lookupClassIn bind c
          val _ = if abstract
                  then raise Error "Instantiation of abstract classes not allowed"
                  else ()
          val _ = checkObject superclass args repn
      in
          T.CLASS c
      end
    | typeOf' bind ctx (Proj (f, e)) =
      let
          val t = typeOf' bind ctx e
          val class = case t of T.CLASS c => c
                              | _ => raise Error ("Projecting a non-object")
      in
          case Module.fieldTypeIn bind class f of
              (* name the field and class in the message itself rather than
                 printing them to stdout before raising *)
              NONE => raise Error ("Projecting a non-existent field " ^ f ^ " of class " ^ class)
            | SOME t => t
      end
    | typeOf' bind ctx (Call (m, args)) =
      let
          val (argst, ret) = case Module.methodTypeIn bind m of
                                 SOME x => x
                               | NONE => raise Error "No such method"
          (* every argument must be an object whose class is a subclass of the
             corresponding class in the method's argument list *)
          fun checkArgs [] [] = ()
            | checkArgs (arg::tl) (argt::tl') =
              (case typeOf' bind ctx arg of
                   T.CLASS c => if Module.isSubclassIn bind (c, argt)
                                then checkArgs tl tl'
                                else raise Error ("Method argument type mismatch: " ^ c ^ " not a subclass of " ^ argt)
                 | _ => raise Error "Method argument not an object")
            | checkArgs _ _ = raise Error "Wrong number of method args"
          val _ = checkArgs args argst
      in
          ret
      end

  (* typeOfIn bind expr = type of the closed expression expr, checked in the
     empty typing context against the bindings bind *)
  fun typeOfIn bind expr = typeOf' bind [] expr
  (* typeOf e = type of e with respect to Module.curBindings (); raises Error
     if e is ill-typed *)
  fun typeOf e =
    let val current = Module.curBindings ()
    in typeOfIn current e
    end
  (* typeOpt e = SOME t when typeOf e = t, NONE when typeOf e raises Error *)
  fun typeOpt e =
    (SOME (typeOf e))
    handle Error _ => NONE

  (* typeModule m validates the declarations of m from first to last, each
     one against the declarations that precede it, and finally calls
     Module.check m.  Raises Error at the first problem found. *)
  fun typeModule m =
    let
      val bind = Module.imagine m

      (* requireClass seen x: x must name a class among the declarations seen *)
      fun requireClass seen x =
        if Module.isClassNamed x seen
        then ()
        else raise Error ("No such class "^x)

      (* checkDecl seen d: validate declaration d given the earlier
         declarations seen (most recent first) *)
      fun checkDecl seen (Class (_, _, NONE, _)) = ()
        | checkDecl seen (Class (_, _, SOME super, _)) = requireClass seen super
        | checkDecl seen (Method (_, params, _)) = app (requireClass seen) params
        | checkDecl seen (Extend (meth, args, impl)) =
          let
            val (targs, ret) =
              case Module.methodSpec meth seen of
                  SOME spec => spec
                | NONE => raise Error ("No such method "^meth)
            val argClasses = map #2 args
            val () = app (requireClass seen) argClasses
            val () =
              if Module.listSubclassIn bind (argClasses, targs) then ()
              else raise Error "Type mismatch in args of method extention."
            (* the context is reversed so that the last parameter becomes
               de Bruijn variable 1 in the body *)
            val implType = typeOf' bind (rev (map T.CLASS argClasses)) impl
          in
            if implType = ret then ()
            else raise Error "Type mismatch in body of method extention."
          end
    in
      ignore (foldl (fn (d, seen) => (checkDecl seen d; d :: seen)) [] m);
      Module.check m
    end

  (* typeModuleBool m = true if typeModule m returns normally, false if it
     raises Error; other exceptions propagate *)
  fun typeModuleBool m =
    let val () = typeModule m
    in true
    end
    handle Error _ => false

end;  (* structure Typing *)
