(* typecheck.sml
   Implements the type inference routines. *)


(* Interface of the type checker: unification, classification of states
   as polymorphic or overloaded, and generation of types from lists of
   cells paired with value lists. *)
signature TYPE_CHECKER =
    sig
    (* Raised with an explanatory message when inference cannot proceed. *)
    exception InferError of string
    (* Raised by unify when two types have incompatible shapes. *)
    exception UnificationFailure
    (* Raised by unify when a variable occurs in the type it must equal. *)
    exception OccursCheck
    (* Substitution unifying the two given types. *)
    val unify : CDSInternal.typeExp -> CDSInternal.typeExp -> CDSInternal.subst
    (* Right-nested Arrow built from a list of at least two types. *)
    val genCurryType : CDSInternal.typeExp list -> CDSInternal.typeExp
    (* True when both the input and the output cv list are generic. *)
    val isPoly : (CDSBasic.cell * CDSBasic.value list) list ->
		  (CDSBasic.cell * CDSBasic.value list) list -> bool
    (* Same as isPoly, with the curried inputs given as a list of cv lists. *)
    val isPolyC : (CDSBasic.cell * CDSBasic.value list) list list -> 
		  (CDSBasic.cell * CDSBasic.value list) list -> bool
    (* True when the input or the output cv list is overloaded. *)
    val isOverloaded : (CDSBasic.cell * CDSBasic.value list) list ->
	               (CDSBasic.cell * CDSBasic.value list) list -> bool
    (* Same as isOverloaded, for curried inputs. *)
    val isOverloadedC : (CDSBasic.cell * CDSBasic.value list) list list -> 
		        (CDSBasic.cell * CDSBasic.value list) list -> bool
    (* Polymorphic arrow type for an input and an output cv list. *)
    val genPolyType : (CDSBasic.cell * CDSBasic.value list) list -> 
	(CDSBasic.cell * CDSBasic.value list) list -> CDSInternal.typeExp
    (* Polymorphic curried arrow type for a list of input cv lists. *)
    val genPolyCType : (CDSBasic.cell * CDSBasic.value list) list list -> 
	(CDSBasic.cell * CDSBasic.value list) list -> CDSInternal.typeExp
    (* Splits a cv list by product tags, n times, consing onto the third argument. *)
    val separate : int -> (CDSBasic.cell * 'a) list -> (CDSBasic.cell * 'a) list list -> 
	(CDSBasic.cell * 'a) list list
    (* Replaces non-variable leaves of a type by variables, recording the pairs. *)
    val genDummy : CDSInternal.typeExp -> (CDSInternal.typeExp * CDSInternal.typeExp) list ->
	(CDSInternal.typeExp * (CDSInternal.typeExp * CDSInternal.typeExp) list)
    (* The remaining values are implemented further down in the functor. *)
    val match : (CDSInternal.typeExp * CDSInternal.typeExp) list *
	        ((CDSBasic.cell * CDSBasic.value list) list * CDSInternal.typeExp) list *
	        (CDSInternal.typeExp * CDSInternal.typeExp) list *
		((CDSBasic.cell * CDSBasic.value list) list * CDSInternal.typeExp) list ->
		CDSInternal.subst
    val genMeet : CDSInternal.typeExp -> (CDSInternal.typeExp * CDSInternal.typeExp) list ->
	CDSInternal.typeExp
    val genEmbeddedMeet : CDSInternal.typeExp -> (CDSInternal.typeExp * CDSInternal.typeExp) list ->
	CDSInternal.typeExp
    val typeState : bool -> ((CDSBasic.cell * CDSBasic.value) list) -> 
	CDSInternal.typeExp
    val typeCVlist : bool -> (CDSBasic.cell * CDSBasic.value list) list -> CDSInternal.typeExp
    val inferType : CDSBasic.expr -> CDSInternal.TYPE
    end;


functor TypeCheckerFUN (structure Internal : INTERNAL
			structure Type : TYPE
			structure Subtype : SUBTYPE
			structure Printer : PRINTER) : TYPE_CHECKER =
  struct
  local open CDSBasic
             CDSInternal
	     CDSEnv
  in
  (* Raised with a descriptive message when a type cannot be inferred or *)
  (* when a helper is handed data of an unexpected shape. *)
  exception InferError of string


      (* *** UNIFICATION *** *)

  (* Raised by unify/unifyList when the two types cannot be made equal. *)
  exception UnificationFailure
  (* Raised by unify when a variable occurs inside the type it is unified with. *)
  exception OccursCheck

  (* unify t u: substitution that makes t and u equal.                  *)
  (* Raises UnificationFailure when the shapes cannot be reconciled and *)
  (* OccursCheck when a variable would be bound to a type containing it.*)
  (* Components of Arrow, And and Prod are handled left to right; each  *)
  (* partial substitution is applied to what is left before going on.   *)
  fun unify (Dcds name1) (Dcds name2) =
        if name1 <> name2 then raise UnificationFailure else emptySub
    | unify (Alpha v) other =
        (case other of
             Alpha w => if v = w then emptySub else newSub v other
           | _ => if occursIn v other then raise OccursCheck
                  else newSub v other)
    | unify other (Alpha v) =
        (* Symmetric case: put the variable on the left. *)
        unify (Alpha v) other
    | unify (Arrow(dom1,cod1)) (Arrow(dom2,cod2)) =
        let val domSub = unify dom1 dom2
            val codSub = unify (apply domSub cod1) (apply domSub cod2)
        in compose codSub domSub
        end
    | unify (And left) (And right) = unifyList left right
    | unify (Prod left) (Prod right) = unifyList left right
    | unify _ _ = raise UnificationFailure

  (* Pointwise unification of two lists; lists of different length fail. *)
  and unifyList [] [] = emptySub
    | unifyList (x::xs) (y::ys) =
        let val headSub = unify x y
            val restSub = unifyList (map (apply headSub) xs)
                                    (map (apply headSub) ys)
        in compose restSub headSub
        end
    | unifyList _ _ = raise UnificationFailure

      (* Unification for the application step.  Because of /\ types the *)
      (* operator and the operand each come as a list of candidate types *)
      (* and every pairing is tried.  A pairing yields a result type when *)
      (* the operator unifies with (arg -> res) for fresh arg and res and *)
      (* Subtype.subtypeV accepts the operand against the instantiated    *)
      (* arg.  Pairings raising UnificationFailure or                     *)
      (* Subtype.SubtypingFailure are skipped.  New results are appended  *)
      (* to the accumulator. *)
  fun unifyApp [] t2list result = result
    | unifyApp (funTy::funTys) t2list result =
      let fun tryArgs [] = []
            | tryArgs (argTy::argTys) =
              (let val resVar = Alpha(newVar())
                   val argVar = Alpha(newVar())
                   val funSub = unify funTy (Arrow(argVar,resVar))
                   val subSub = Subtype.subtypeV(argTy, apply funSub argVar)
                   val resTy = apply subSub (apply funSub resVar)
               in resTy :: tryArgs argTys
               end
               handle UnificationFailure => tryArgs argTys
                    | Subtype.SubtypingFailure => tryArgs argTys)
          val found = tryArgs t2list
      in unifyApp funTys t2list (result @ found)
      end

  (* Composition step over two lists of candidate types.  For each pairing *)
  (* the second type is unified with (a -> b) and the first with (c -> d); *)
  (* Subtype.subtypeV must accept b against c, and the pairing then yields *)
  (* a -> d under the combined substitution.  Failing pairings are skipped *)
  (* and new results are appended to the accumulator. *)
  fun unifyComp [] t2list result = result
    | unifyComp (outerTy::outerTys) t2list result =
      let fun tryInner [] = []
            | tryInner (innerTy::innerTys) =
              (let val a = Alpha(newVar())
                   val b = Alpha(newVar())
                   val c = Alpha(newVar())
                   val d = Alpha(newVar())
                   val innerSub = unify innerTy (Arrow(a,b))
                   val outerSub = unify outerTy (Arrow(c,d))
                   val subSub = Subtype.subtypeV(apply innerSub b,
                                                 apply outerSub c)
                   val total = compose subSub (compose outerSub innerSub)
                   val composed = apply total (Arrow(a,d))
               in composed :: tryInner innerTys
               end
               handle UnificationFailure => tryInner innerTys
                    | Subtype.SubtypingFailure => tryInner innerTys)
          val found = tryInner t2list
      in unifyComp outerTys t2list (result @ found)
      end

  (* Fixpoint step.  Each candidate is unified with (a -> b); when        *)
  (* Subtype.subtypeV accepts b against a, the instantiated b is consed   *)
  (* onto the accumulator (so results come out in reverse order).         *)
  (* Candidates that fail to unify or to subtype are dropped. *)
  fun unifyFix tys result =
      case tys of
          [] => result
        | funTy::others =>
            ((let val argVar = Alpha(newVar())
                  val resVar = Alpha(newVar())
                  val arrowSub = unify funTy (Arrow(argVar,resVar))
                  val resTy = apply arrowSub resVar
                  val subSub = Subtype.subtypeV(resTy, apply arrowSub argVar)
              in unifyFix others ((apply subSub resTy)::result)
              end)
             handle UnificationFailure => unifyFix others result
                  | Subtype.SubtypingFailure => unifyFix others result)

  (* Currying step.  Each candidate that unifies with (a * b) -> c        *)
  (* contributes a -> (b -> c), consed onto the accumulator; candidates   *)
  (* that do not unify are dropped. *)
  fun unifyCurry tys result =
      case tys of
          [] => result
        | ty::others =>
            ((let val a = Alpha(newVar())
                  val b = Alpha(newVar())
                  val c = Alpha(newVar())
                  val sub = unify ty (Arrow(Prod [a, b], c))
                  val curried = apply sub (Arrow(a,Arrow(b,c)))
              in unifyCurry others (curried::result)
              end)
             handle UnificationFailure => unifyCurry others result)

  (* Uncurrying step.  Each candidate that unifies with a -> (b -> c)     *)
  (* contributes (a * b) -> c, consed onto the accumulator; candidates    *)
  (* that do not unify are dropped. *)
  fun unifyUncurry tys result =
      case tys of
          [] => result
        | ty::others =>
            ((let val a = Alpha(newVar())
                  val b = Alpha(newVar())
                  val c = Alpha(newVar())
                  val sub = unify ty (Arrow(a,Arrow(b,c)))
                  val uncurried = apply sub (Arrow(Prod [a, b], c))
              in unifyUncurry others (uncurried::result)
              end)
             handle UnificationFailure => unifyUncurry others result)

  (* Pairing step over two lists of candidate types.  For each pairing,   *)
  (* t1 is unified with new1 -> new2 and t2 with new3 -> new4.  The two   *)
  (* domains are then related with Subtype.subtypeV, first as             *)
  (* (new1, new3) and, if that raises SubtypingFailure, as (new3, new1).  *)
  (* The result is an arrow into Prod [new2,new4] whose domain is new3 in *)
  (* the first case and new1 in the second.  Pairings for which           *)
  (* unification or both subtyping attempts fail are skipped.  New        *)
  (* results are appended to the accumulator. *)
  fun unifyPair [] t2list result = result
    | unifyPair (t1::t1list) t2list result =
      let fun unifyPairList t1 [] = []
	    | unifyPairList t1 (t2::t2list) =
	      (let val new1 = Alpha(newVar())
		   val new2 = Alpha(newVar())
		   val new3 = Alpha(newVar())
		   val new4 = Alpha(newVar())
		   val S1 = unify t1 (Arrow(new1,new2))
		   val S2 = unify t2 (Arrow(new3,new4))
	       in 
		   (* First try the domains in the order (t1's, t2's). *)
		   let val S = Subtype.subtypeV(apply S1 new1, apply S2 new3)
		       val t = apply (compose S (compose S2 S1)) 
			       (Arrow(new3,Prod [new2,new4]))
		   in t::(unifyPairList t1 t2list)
		   end handle Subtype.SubtypingFailure =>
		       (* Otherwise try them the other way around. *)
		       let val S = Subtype.subtypeV(apply S2 new3, 
						    apply S1 new1)
			   val t = apply (compose S (compose S2 S1)) 
			           (Arrow(new1,Prod [new2,new4]))
		       in t::(unifyPairList t1 t2list)
		       end
	       end) 
		   (* Neither direction worked, or unification failed: skip. *)
		   handle UnificationFailure => unifyPairList t1 t2list
			| Subtype.SubtypingFailure => unifyPairList t1 t2list
      in (unifyPair t1list t2list (result@(unifyPairList t1 t2list)))
      end


      (* *** CDS0 STATE MANIPULATION *** *)
      (* Will maintain information on input and output cell and value  *)
      (* usage in an algorithm in lists of the form [ (cell, values) ].*)
      (* The following functions update and access these lists. *)
 (* Record value v for cell c.  If c already has an entry, v is put at  *)
 (* the front of its value list; otherwise a new entry is added at the  *)
 (* end of the list. *)
 fun updateCell (c,v) cvs =
     case cvs of
         [] => [(c, [v])]
       | (entry as (c',vlist))::others =>
           if c <> c' then entry :: updateCell (c,v) others
           else (c', v::vlist) :: others

  (* Make sure cell c has an entry, without recording any value for it. *)
  (* An existing entry is left untouched; a missing one is added at the *)
  (* end with an empty value list. *)
  fun updateNoVal c cvs =
      case cvs of
          [] => [(c,[])]
        | (entry as (c',_))::others =>
            if c = c' then cvs
            else entry :: updateNoVal c others

  (* Feed every (cell, value) pair of a state, in order, into updateCell, *)
  (* starting from the given accumulator. *)
  fun update [] result = result
    | update (cv::rest) result = update rest (updateCell cv result)

  (* Look for the argument expression of a functional cell.  Returns     *)
  (* (true, x) for Cell_fun(x,c), looks through Cell_with, and returns   *)
  (* (false, Expr_state []) for any other cell.  Grafted cells are not   *)
  (* looked through: that clause is commented out below. *)
  fun findState (Cell_fun(x,c)) = (true, x)
(*    | findState (Cell_graft(c,t)) = findState c  *)
    | findState (Cell_with(c,b)) = findState c
    | findState c = (false, Expr_state [])

  (* Strip the functional part of a cell: Cell_fun(x,c) becomes c.  The  *)
  (* search goes through Cell_graft and Cell_with wrappers, which are    *)
  (* rebuilt around the result; any other cell is returned unchanged. *)
  fun findName cell =
      case cell of
          Cell_fun(_,name) => name
        | Cell_graft(inner,tag) => Cell_graft(findName inner,tag)
        | Cell_with(inner,cond) => Cell_with(findName inner,cond)
        | _ => cell

  (* Remove outermost product tags (1 or 2, written either as a string   *)
  (* or as an integer expression) from a cell.  Stripping stops at the   *)
  (* first graft carrying any other tag, which is returned as is, and at *)
  (* any cell that is not a graft. *)
  fun stripProductTags cell =
      case cell of
          Cell_graft(c, Tag_str "1") => stripProductTags c
        | Cell_graft(c, Tag_arexpr(Arexpr_int 1)) => stripProductTags c
        | Cell_graft(c, Tag_str "2") => stripProductTags c
        | Cell_graft(c, Tag_arexpr(Arexpr_int 2)) => stripProductTags c
        | _ => cell

      (* Given a state, extract a (cell, value list) list for the input  *)
      (* cells and values occurring in it.  Only entries whose cell is   *)
      (* functional (per findState) contribute: the pairs of the cell's  *)
      (* argument state are recorded, and if the entry's value is a      *)
      (* valof, the cell it names is recorded without a value.           *)
      (* Raises InferError if the argument of a functional cell is not a *)
      (* state. *)
  fun getInputCV [] result = result
    | getInputCV ((c,v)::rest) result =
        let val (isFun, arg) = findState c
            val inState =
                (case arg of
                     Expr_state st => st
                   | _ => raise InferError "functional cell not a state")
        in if not isFun then getInputCV rest result
           else let val withState = update inState result
                    val withValof =
                        (case v of
                             Val_valof cin => updateNoVal cin withState
                           | _ => withState)
                in getInputCV rest withValof
                end
        end

      (* Given a state, extract a (cell, value list) list for the output *)
      (* cells and values occurring in it.  With order = 1 every pair is *)
      (* recorded as it stands.  Otherwise the state is treated as       *)
      (* functional: the cell is reduced with findName, an "output"      *)
      (* value is recorded under it, and any other value only registers  *)
      (* the cell. *)
  fun getOutputCV order cvs result =
      case cvs of
          [] => result
        | (c,v)::rest =>
            if order = 1
                then getOutputCV order rest (updateCell (c,v) result)
            else
                let val cname = findName c
                    val result' =
                        (case v of
                             Val_output v' => updateCell (cname, v') result
                           | _ => updateNoVal cname result)
                in getOutputCV order rest result'
                end

      (* Same as getInputCV, but working on a cvlist (cells paired with  *)
      (* value lists) instead of a state.  Called when we have a         *)
      (* higher-order state.  Every valof among a functional cell's      *)
      (* values registers the cell it names without a value. *)
  fun getInCVfromCVList [] result = result
    | getInCVfromCVList ((c,vlist)::rest) result =
        let val (isFun, arg) = findState c
            val inState =
                (case arg of
                     Expr_state st => st
                   | _ => raise InferError "functional cell not a state")
            fun noteValofs [] acc = acc
              | noteValofs (Val_valof cin :: vs) acc =
                  noteValofs vs (updateNoVal cin acc)
              | noteValofs (_ :: vs) acc = noteValofs vs acc
        in if isFun
               then getInCVfromCVList rest
                        (noteValofs vlist (update inState result))
           else getInCVfromCVList rest result
        end

      (* Like getOutputCV, but working on a cvlist and without an order  *)
      (* argument: the state is known to be functional.  Each "output"   *)
      (* value is recorded under the cell's name; any other value only   *)
      (* registers the name.  A cell with an empty value list adds       *)
      (* nothing. *)
  fun getOutCVfromCVList cvs result =
      case cvs of
          [] => result
        | (c,vlist)::rest =>
            let val cname = findName c
                fun record (Val_output v', acc) = updateCell (cname,v') acc
                  | record (_, acc) = updateNoVal cname acc
                fun recordAll [] acc = acc
                  | recordAll (v::vs) acc = recordAll vs (record (v, acc))
            in getOutCVfromCVList rest (recordAll vlist result)
            end


      (* *** TYPE UTILITY FUNCTIONS *** *)

      (* For debugging: one "(cell, [values])" string per entry. *)
  fun printCVlist cvs =
      map (fn (c,vlist) =>
              "(" ^ (Printer.unparseCell c) ^ ", [" ^
              (Printer.unparseValList vlist) ^ "])")
          cvs

  (* For debugging: renders a list of (cv list, type) pairs as a single  *)
  (* string, terminated by a closing brace and a newline. *)
  fun printCVTlist cvts =
      case cvts of
          [] => "}\n"
        | (CVlist, t)::CVT =>
            "{[" ^ (implode(printCVlist CVlist)) ^ "] : "
            ^ (Printer.printType t) ^ ", " ^ (printCVTlist CVT)

      (* Can a certain dcds, given as a list of (cell, values, _)        *)
      (* triples, be the type of a certain list of cell, value list      *)
      (* pairs?  Every pair must be covered by some triple: the cell     *)
      (* accepted by Type.includedCell and the values by                 *)
      (* Type.includedValues.  An empty list of pairs always matches. *)
  fun getSingleMatch cvs cvas =
      let (* Is one cell/value-list pair covered by some triple? *)
          fun covered (c1, v1) =
              let fun scan [] = false
                    | scan ((c2,v2,_)::more) =
                        (Type.includedCell(c1,c2) andalso
                         Type.includedValues(v1,v2))
                        orelse scan more
              in scan cvas
              end
          fun allCovered [] = true
            | allCovered (cv::more) = covered cv andalso allCovered more
      in allCovered cvs
      end


      (* Given a list of cell, value list pairs, collect the names of    *)
      (* all dcds in typeList that match it according to getSingleMatch. *)
  fun getMatches cvlist =
        let val cells = map (#1) cvlist
            val definitions = map (#2) (!typeList)
            val maxLength = fold max (map Type.cellLength cells) 0
            val maxDepth = fold max (map Type.countDepth definitions) 0
                (* Unroll to maxLength + maxDepth: if the recursive part *)
                (* of a dcds follows some plain part, it would otherwise *)
                (* not get to apply enough tags under listDcds. *)
            val unrollTo = maxLength + maxDepth
            val candidates =
                zip (map (#1) (!typeList))
                    (map (Type.listDcds unrollTo) definitions)
            fun keep [] = []
              | keep ((name,d)::rest) =
                  if getSingleMatch cvlist d then name :: keep rest
                  else keep rest
        in keep candidates
        end

  (* Delete every occurrence of e from a list, keeping the order of the  *)
  (* remaining elements. *)
  fun remove (e, l) =
      case l of
          [] => []
        | x::xs => if e = x then remove (e, xs)
                   else x :: remove (e, xs)

  (* Delete from l2 every occurrence of every element of the first list. *)
  fun removeList toDrop l2 =
      case toDrop of
          [] => l2
        | e::es => removeList es (remove(e,l2))

      (* Given two names of dcds, find out if t1 <=e t2: looks t2 up in  *)
      (* the hierarchy and checks that t1 is listed among its children   *)
      (* with kind ext.  A Find exception is answered with false. *)
  fun isSubByExt (t1, t2) =
      (let val (_, t2Children) = Type.typeLookup(t2,!hierarchy)
       in find(t1,t2Children) = ext
       end) handle Find => false

     (* Given a list of dcds names, remove all subtypes by extension,    *)
     (* i.e. every name t for which some other name u in the list has    *)
     (* isSubByExt(t,u).  Processes the list head first: the head is     *)
     (* dropped when it is a subtype by extension of a later name;       *)
     (* otherwise it is kept and every later name that is a subtype by   *)
     (* extension of the head is dropped before recursing.               *)
     (* The second test previously used isSubByExt(t1,t2), the same      *)
     (* direction as the first, so it could never select anything once   *)
     (* the first test had failed, and the result depended on the order  *)
     (* of the input.  It now mirrors removeSupertypes. *)
  fun removeExtSub [] = []
    | removeExtSub (t1::rest) =
      let (* Is t1 a subtype by extension of some name in the list? *)
          fun toRemove (t1,[]) = false
	    | toRemove (t1,t2::l) = isSubByExt(t1,t2) orelse toRemove(t1,l)
          (* Names in the list that are subtypes by extension of t1. *)
	  fun extSubsOf t1 [] = []
	    | extSubsOf t1 (t2::l) =
	      if isSubByExt(t2,t1) then t2::(extSubsOf t1 l) else extSubsOf t1 l
      in if (toRemove(t1,rest)) then (removeExtSub rest)
	 else let val dead = extSubsOf t1 rest
		  val remainder = removeList dead rest
	      in t1::(removeExtSub remainder)
	      end
      end

     (* Given a list of dcds names, discard every name that is a         *)
     (* supertype (per Subtype.subtype) of another name in the list.     *)
     (* The head is dropped when some later name is a subtype of it;     *)
     (* otherwise it is kept and its supertypes are removed from the     *)
     (* remainder before recursing. *)
  fun removeSupertypes [] = []
    | removeSupertypes (t1::rest) =
      let fun hasSubtypeIn [] = false
            | hasSubtypeIn (t2::l) =
                Subtype.subtype(t2,t1) orelse hasSubtypeIn l
          fun supersIn [] = []
            | supersIn (t2::l) =
                if Subtype.subtype(t1,t2) then t2 :: supersIn l
                else supersIn l
      in if hasSubtypeIn rest then removeSupertypes rest
         else t1 :: removeSupertypes (removeList (supersIn rest) rest)
      end

      (* Construct the right-nested arrow T1 -> (T2 -> ... -> Tn) from a *)
      (* list of at least 2 types.  Shorter lists are not matched. *)
  fun genCurryType (T1::T2::rest) =
      (case rest of
           [] => Arrow(T1, T2)
         | _ => Arrow(T1, genCurryType (T2::rest)))

  (* Split a cv list by the outermost product tag of each cell, removing *)
  (* that tag.  Cells tagged 2 are consed onto tag2, cells tagged 1 onto *)
  (* other (so both come out in reverse order).  A cell that is not      *)
  (* grafted with tag 1 or 2 raises InferError. *)
  fun quotient cvs tag2 other =
      let (* 2 or 1 for a product tag, 0 for anything else. *)
          fun side (Tag_str "2") = 2
            | side (Tag_arexpr(Arexpr_int 2)) = 2
            | side (Tag_str "1") = 1
            | side (Tag_arexpr(Arexpr_int 1)) = 1
            | side _ = 0
      in case cvs of
             [] => (tag2,other)
           | (Cell_graft(c',t),vlist)::rest =>
               (case side t of
                    2 => quotient rest ((c',vlist)::tag2) other
                  | 1 => quotient rest tag2 ((c',vlist)::other)
                  | _ => raise InferError "separate: quotient: not right tags")
           | _ => raise InferError "separate: quotient: not right tags"
      end

      (* Break down a list of cell, value list pairs by cell tag:        *)
      (* first is .1.1....1 n times, then .2.1....1, last is .2.         *)
      (* Each round splits off the cells tagged 2 and continues with the *)
      (* cells tagged 1; the pieces are consed onto result. *)
  fun separate n cvlist result =
      let val (tag2,other) = quotient cvlist [] []
      in if n = 1 then other::(tag2::result)
         else separate (n-1) other (tag2::result)
      end

      (* Given a list of types, construct a typeExp: the type itself for *)
      (* a singleton, a Meet of (duplicates t) for two or more, and an   *)
      (* InferError for the empty list.                                  *)
      (* NOTE(review): if duplicates can shrink the list to one element, *)
      (* the result is a one-element Meet -- confirm this is intended. *)
  fun finalPolish t =
      case t of
          [] => raise InferError "term does not have a type"
        | [single] => single
        | _ => Meet (duplicates t)

  (* The components of a Meet, or a singleton list for any other type. *)
  fun meetRemove t =
      case t of
          Meet tlist => tlist
        | _ => [t]

  (* All arrows t1 -> t2 with t1 from the first list and t2 from the     *)
  (* second, grouped by t1 in the order of the first list. *)
  fun makeArrowMeet [] _ = []
    | makeArrowMeet (t1::t1list) t2list =
        (map (fn t2 => Arrow(t1,t2)) t2list) @ (makeArrowMeet t1list t2list)

  (* Curried version of makeArrowMeet for a list of at least two lists   *)
  (* of types: arrows are built from the right, so each result has the   *)
  (* shape t1 -> (t2 -> ... -> tn).  Shorter lists are not matched. *)
  fun makeArrowCurryMeet (first::second::more) =
      (case more of
           [] => makeArrowMeet first second
         | _ => makeArrowMeet first (makeArrowCurryMeet (second::more)))
      

      (* Drop the dcds names that appear on refineOnlyList. *)
  fun elimRefTypes names =
      case names of
          [] => []
        | n::rest =>
            if member(n,!refineOnlyList) then elimRefTypes rest
            else n :: elimRefTypes rest


      (* Given a (cell,value list) list it will come up with a type.*)
      (* This can be higher-order. *)
      (* refInf = false removes the names on refineOnlyList from the     *)
      (* first-order candidates.  An empty list gets a fresh type        *)
      (* variable.  Whether the list is higher-order is decided from its *)
      (* first cell only.  Several candidate types are combined by       *)
      (* finalPolish, which raises InferError when there is none. *)
  fun getType refInf cvList =
      let val _ = if !trace then output(std_out, "  getType: ["
					  ^(implode(printCVlist cvList))^"]\n")
		  else ()
      in
        (* if empty it can match anything *)
      if cvList = nil then Alpha(newVar())
	(* else find out if it's higher-order *)
      else let val firstCell = #1(hd cvList)
	       val (found, state) = findState firstCell
	   in if found then 
	       let val order = cellOrder firstCell
		   (* Uncurried higher-order cv list *)
	       in if order <= 2 then
		   let val inputCV = getInCVfromCVList cvList []
		       val outputCV = getOutCVfromCVList cvList []
		       val Tin = getType refInf inputCV
		       val Tout = getType refInf outputCV
		       (* Every input alternative is paired with every output one. *)
		   in finalPolish (makeArrowMeet (meetRemove Tin) (meetRemove Tout))
		   end
		  (* Curried higher-order cv list *)
		  else let val uncurriedCVList = 
		      Internal.uncurryCVList(order-2,cvList)
			   val inputCV = getInCVfromCVList uncurriedCVList []
			   val outputCV = getOutCVfromCVList uncurriedCVList []
			   (* One cv list per curried argument. *)
			   val inputCVlist = separate (order-2) inputCV []
			   val TinList = map (getType refInf) inputCVlist
			   val Tout = getType refInf outputCV
		       in finalPolish (makeArrowCurryMeet (map meetRemove (TinList @ [Tout])))
		       end
	       end
	      (* First-order cv list.  Check for matches with dcds's. *)
	      else let val names = getMatches cvList
		           (* Get rid of refinement types if not refInf *)
		       val names' = if not refInf then elimRefTypes names else names
		           (* Get rid of subtypes by extension *)
		       val names'' = removeExtSub names'
			   (* Now get rid of all supertypes *)
		       val types = map Dcds (removeSupertypes names'')
		         (* Also check if it's a product type *)
		         (* Any InferError raised while trying the product *)
		         (* reading, including one from the nested getType *)
		         (* calls, falls back to the dcds matches alone. *)
		   in (let val dots = separate 1 cvList []
			   val dot1 = if null dots orelse (length dots) <> 2
					  then raise InferError 
					   "separate: quotient: not right tags"
				      else hd dots
			   val dot2 = hd(tl dots)
			   val dot1Types = getType refInf dot1
			   val dot2Types = getType refInf dot2
			   val prodType = [Prod [dot1Types, dot2Types]]
		       in finalPolish(types @ prodType)
		       end) handle InferError _ => finalPolish types
		   end
	   end
      end



      (* *** POLYMORPHISM *** *)

      (* generic_ functions verify whether _ is just a variable, or a    *)
      (* combination of variables.  This does not include something like *)
      (* $C.foo, or ($V,$W), which do not match everything. *)
  fun genericCell (Cell_name _) = false
    | genericCell (Cell_var _) = true
    | genericCell (Cell_fun(e,c)) =
        (case e of
             Expr_state x => genericState x andalso genericCell c
           | _ => false)
    | genericCell (Cell_graft(c,t)) =
        (* Only product tags keep a grafted cell generic. *)
        genericCell c andalso
        (case t of
             Tag_str "1" => true
           | Tag_str "2" => true
           | Tag_arexpr(Arexpr_int 1) => true
           | Tag_arexpr(Arexpr_int 2) => true
           | _ => false)
    | genericCell (Cell_with(c,bexp)) = false

  and genericVal value =
      (case value of
           Val_arexpr(Arexpr_var _) => true
         | Val_output v => genericVal v
         | Val_valof c => genericCell c
         | _ => false)

  and genericState [] = true
    | genericState ((c,v)::x) =
        genericCell c andalso genericVal v andalso genericState x


      (* overloaded_ functions will pick up the $C.foo and ($V,$W)       *)
      (* rejected by generic_.  The idea is that something is overloaded *)
      (* if it can match different unrelated cds's.                      *)
      (* NOTE(review): overloadedState requires every entry to be        *)
      (* overloaded and is true for the empty state, whereas the cell    *)
      (* and value tests are "some part is overloaded" -- confirm that   *)
      (* this asymmetry is intended. *)
  fun overloadedCell cell =
      (case cell of
           Cell_var _ => true
         | Cell_fun(e,c) =>
             (case e of
                  Expr_state x => overloadedState x orelse overloadedCell c
                | _ => false)
         | Cell_graft(c,t) => overloadedCell c orelse overloadedTag t
         | Cell_with(c,bexp) => overloadedCell c
         | _ => false)

  and overloadedVal value =
      (case value of
           Val_arexpr a => overloadedArexpr a
         | Val_output v => overloadedVal v
         | Val_valof c => overloadedCell c
         | Val_pair(v1,v2) => overloadedVal v1 orelse overloadedVal v2
         | Val_with(_,_) => true
         | _ => false)

  and overloadedTag (Tag_arexpr a) = overloadedArexpr a
    | overloadedTag _ = false

  and overloadedArexpr a =
      let fun either (a1,a2) = overloadedArexpr a1 orelse overloadedArexpr a2
      in case a of
             Arexpr_var _ => true
           | Arexpr_minus a' => overloadedArexpr a'
           | Arexpr_plus operands => either operands
           | Arexpr_sub operands => either operands
           | Arexpr_mult operands => either operands
           | Arexpr_div operands => either operands
           | _ => false
      end

  and overloadedState [] = true
    | overloadedState ((c,v)::x) =
        (overloadedCell c orelse overloadedVal v) andalso overloadedState x


      (* Determine if a CV contains only generic references: every cell  *)
      (* and every value of every entry must be generic. *)
  fun generic [] = true
    | generic ((c,vlist)::cvlist) =
        genericCell c andalso forAll genericVal vlist andalso generic cvlist

      (* Used for curried inputCVs, represented as a list of CVs: a      *)
      (* generic CV gets a fresh type variable, any other CV the type    *)
      (* computed by getType. *)
  fun checkGeneric _ [] = []
    | checkGeneric refInf (inCV::rest) =
        let val headTy = if generic inCV then Alpha(newVar())
                         else getType refInf inCV
        in headTy :: checkGeneric refInf rest
        end

      (* Determine if a certain state (given as 2 cv lists) is           *)
      (* polymorphic: both the input and the output CV must be generic. *)
  fun isPoly inCV outCV =
      if generic inCV then generic outCV else false

      (* Same as isPoly but for curried inputCVs: every input CV and the *)
      (* output CV must be generic. *)
  fun isPolyC inCVlist outCV =
      if forAll generic inCVlist then generic outCV else false


      (* Determine if a CV contains some overloaded references: true as  *)
      (* soon as an entry has an overloaded cell or passes               *)
      (* forAll overloadedVal on its values.                             *)
      (* NOTE(review): the value test is forAll, not exists; if forAll   *)
      (* holds vacuously on an empty value list, a cell without values   *)
      (* counts as overloaded -- confirm this is intended. *)
  fun overloadedCV [] = false
    | overloadedCV ((c,vlist)::cvlist) =
      overloadedCell c orelse forAll overloadedVal vlist
      orelse overloadedCV cvlist

      (* Determine if a certain state (given as 2 cv lists) is           *)
      (* overloaded: either CV being overloaded is enough. *)
  fun isOverloaded inCV outCV =
      if overloadedCV inCV then true else overloadedCV outCV

      (* Same as isOverloaded but for curried inputCVs: some input CV or *)
      (* the output CV must be overloaded. *)
  fun isOverloadedC inCVlist outCV =
      if exists overloadedCV inCVlist then true else overloadedCV outCV


      (* Given a list of type pairs, constructs a substitution to unify  *)
      (* them all simultaneously.  The pairs are solved from the end of  *)
      (* the list: the substitution for the tail is applied to both      *)
      (* sides of the head pair before that pair is unified, so that     *)
      (* constraints sharing a variable are propagated (as unifyList     *)
      (* does for its components).  Previously each pair was unified on  *)
      (* its own, so pairs such as (a,b) and (a,c) never related b to c. *)
      (* Raises UnificationFailure or OccursCheck when the pairs cannot  *)
      (* all be satisfied. *)
  fun generateSub [] = emptySub
    | generateSub ((t1,t2)::ts) = 
      let val rest = generateSub ts
	  val s = unify (apply rest t1) (apply rest t2)
      in compose s rest
      end

      (* Given a list of CV lists paired with a type, construct the flat *)
      (* list of (cell, values, type) triples in which every entry of a  *)
      (* CV list carries the type its list was paired with. *)
  fun extractCVT [] = []
    | extractCVT ((cvlist,t)::cvtlist) =
      (map (fn (c,vs) => (c,vs,t)) cvlist) @ (extractCVT cvtlist)

      (* Does v1 occur somewhere in v2? *)
      (* Equal values always match.  A v1 of the form Val_with(s,b) is   *)
      (* looked up as the plain variable s.  Otherwise v2 is taken       *)
      (* apart: output wrappers are removed, arithmetic expressions are  *)
      (* searched for a variable equal to v1, both components of a pair  *)
      (* are searched, and a Val_with(s,b) is treated as the variable s. *)
      (* When trace is set, every call (recursive ones included) prints  *)
      (* both values. *)
  fun usedValue v1 v2 = 
      let val _ = if !trace then
	            let val v1str = Printer.unparseVal v1
			val v2str = Printer.unparseVal v2
		    in output(std_out, "  usedValue: v1 = "^v1str^
			      ",  v2 = "^v2str^"\n")
		    end
		  else ()
	  (* Does value v appear as a variable inside arithmetic expression a? *)
	  fun findInArexpr v a =
	  case a of
	      (Arexpr_var s) => (v = Val_arexpr(Arexpr_var s))
	    | (Arexpr_minus a') => findInArexpr v a'
	    | (Arexpr_plus(a1,a2)) => findInArexpr v a1 orelse 
		                      findInArexpr v a2
	    | (Arexpr_sub(a1,a2)) => findInArexpr v a1 orelse 
		                     findInArexpr v a2
	    | (Arexpr_mult(a1,a2)) => findInArexpr v a1 orelse 
		                      findInArexpr v a2
	    | (Arexpr_div(a1,a2)) => findInArexpr v a1 orelse 
		                     findInArexpr v a2
	    | _ => false
      in if v1 = v2 then true
	 else case v1 of
	  (* A constrained variable is searched for as the bare variable. *)
	  (Val_with(s,b)) => usedValue (Val_arexpr(Arexpr_var s)) v2
	| _ => case v2 of
	      (Val_output v) => usedValue v1 v
	    | (Val_arexpr a) => findInArexpr v1 a
	    | (Val_pair(v2',v2'')) => (usedValue v1 v2') orelse 
		                      (usedValue v1 v2'')
	    | (Val_with(s,b)) => usedValue v1 (Val_arexpr(Arexpr_var s))
	    | _ => false
      end

      (* used (x, l): does x occur (per usedValue) in some element of l? *)
      (* usedInList (v1list, v2list): is every value of v1list used in   *)
      (* v2list?  An empty v1list is trivially used. *)
  fun used (x, ys) =
      case ys of
          [] => false
        | y::rest => usedValue x y orelse used (x, rest)
  fun usedInList (v1list, v2list) =
      case v1list of
          [] => true
        | v1::rest => used(v1,v2list) andalso usedInList(rest,v2list)

      (* Given 2 CVs*type lists, finds matches between the input cells *)
      (* and output cells and values and gives a substitution.  For a  *)
      (* match, the cells must match (i.e., equal up to extra product  *)
      (* tags), and every input value must be used somewhere in the    *)
      (* output.  Matches between input CVs are not calculated         *)
      (* explicitly--this should be taken care of by both matching the *)
      (* output.  Every matching (input type, output type) pair is     *)
      (* handed to generateSub.  When trace is set, both arguments are *)
      (* printed first. *)
  fun polyMatches (inCVType, outCVType) =
      let val _ = if !trace 
		      then output(std_out, "  polyMatches:\n    in = "^
				  (printCVTlist inCVType))
		  else ()
	  val _ = if !trace
		    then output(std_out, "    out = "^(printCVTlist outCVType))
		  else ()
	  (* Flatten both sides to (cell, values, type) triples. *)
	  val outCVTList = extractCVT outCVType
	  val inCVTList = extractCVT inCVType
	  (* Generate a list of type pairs which should be unified. *)
	  (* compare: pairs one input triple with every matching output triple. *)
	  fun compare CVT [] = []
	    | compare (c,vlist,t) ((c1,v1list,t1)::CVTlist) =
	      if stripProductTags c = stripProductTags c1 andalso 
		  usedInList(vlist,v1list)
		  then (t,t1)::(compare (c,vlist,t) CVTlist)
	      else compare (c,vlist,t) CVTlist
	  (* compareList: does so for every input triple, in order. *)
	  fun compareList [] _ = []
	    | compareList ((c,vlist,t)::CVT1) CVT2 = 
	      (compare (c,vlist,t) CVT2)@(compareList CVT1 CVT2)
	  val typePairs = compareList inCVTList outCVTList
      in generateSub typePairs
      end

      (* Generate matches between the input CVTs of a curried algo:      *)
      (* every CVT is matched, with polyMatches, against each CVT that   *)
      (* follows it, and the resulting substitutions are composed.       *)
      (* Raises InferError on an empty list; a single CVT gives the      *)
      (* empty substitution. *)
  fun polyCMatches cvts =
      case cvts of
          [] => raise InferError "polyCMatches: empty list"
        | [_] => emptySub
        | first::others =>
            let (* Match the first CVT against each of the later ones. *)
                fun againstEach [] = emptySub
                  | againstEach (next::more) =
                      compose (polyMatches(first,next)) (againstEach more)
                val laterSub = polyCMatches others
                val firstSub = againstEach others
            in compose laterSub firstSub
            end

      (* Breaks a poly type down to its non-product components and       *)
      (* packages each with the part of the cv list it describes, for    *)
      (* matching purposes.  A two-component product splits the cv list  *)
      (* by product tag; Alpha, Dcds, And and Meet are paired with the   *)
      (* whole cv list.  Arrow types, and any remaining form, raise      *)
      (* InferError. *)
  fun reconstructCVT (ty, cvList) =
      case ty of
          Prod [t1, t2] =>
            (case separate 1 cvList [] of
                 [dot1, dot2] =>
                   (reconstructCVT(t1,dot1)) @ (reconstructCVT(t2,dot2))
               | _ => raise InferError "reconstructCVT: not a product")
        | Arrow _ => raise InferError "reconstructCVT: arrow type"
        | Alpha _ => [(cvList, ty)]
        | Dcds _ => [(cvList, ty)]
        | And _ => [(cvList, ty)]
        | Meet _ => [(cvList, ty)]
        | _ => raise InferError ("reconstructCVT: unexpected type = "^
                                 (Printer.printType ty))

      (* Takes a curried type and cv lists and pairs each list with the  *)
      (* corresponding part of the type: one arrow domain per input cv   *)
      (* list, and whatever remains for the output cv list.  Raises      *)
      (* InferError when input cv lists remain but the type is no arrow. *)
  fun breakType (ty, inCVs, outCV) =
      case (ty, inCVs) of
          (_, []) => reconstructCVT(ty,outCV)
        | (Arrow(t1,t2), inCV::rest) =>
            (reconstructCVT(t1,inCV)) @ (breakType(t2,rest,outCV))
        | _ => raise InferError "breakType: not arrow type"


      (* Generate a polymorphic type. The output cell vars must   *)
      (* "match up" with the input cell vars, and the input value *)
      (* vars with the output ones. More detail: the output cells *)
      (* will be bound first to something during a query; that's  *)
      (* why we'll check for dangling cell variables in the input.*)
      (* For each cell in inCV check that it appears in outCV +/- *)
      (* a 1 or 2 tag.  For values it's the other way around. We  *)
      (* assume both inCV and outCV are generic. *)
  fun genPolyType inCV outCV = 
      let val (Tin,inCVType) = genPolySingleType inCV
	  val (Tout,outCVType) = genPolySingleType outCV
	  (* Identify the type variables of matching input and output parts. *)
	  val S = polyMatches(inCVType,outCVType)
      in apply S (Arrow(Tin,Tout))
      end

      (* Curried variant: one cv list per argument.  The inputs are     *)
      (* matched against each other (polyCMatches) and, flattened,      *)
      (* against the output; the type is the right-nested arrow over    *)
      (* the argument types and the output type.  Raises InferError     *)
      (* when there is no input cv list. *)
  and genPolyCType inCVlist outCV = 
      let val (Tout,outCVType) = genPolySingleType outCV
	  val TinCVTypeList = map genPolySingleType inCVlist
	  val TinList = map (#1) TinCVTypeList
	  val inCVTypes = map (#2) TinCVTypeList
	  val inCVTypeList = flatten inCVTypes
	  val Sin = polyCMatches inCVTypes
	  val S = compose (polyMatches(inCVTypeList,outCVType)) Sin
	  fun foldArrow [] t2 = raise InferError "genPolyCType: empty input"
	    | foldArrow [t1] t2 = Arrow(t1,t2)
	    | foldArrow (t1::ts) t2 = Arrow(t1,foldArrow ts t2)
	  val t = foldArrow TinList Tout
      in apply S t
      end

      (* Generate a poly type for a (cell * value list) list. *)
      (* Check if higher order.  Return a pair consisting of  *)
      (* a type and cell * values * type list to keep track of matches. *)
      (* An empty list gets a fresh type variable and no match entries. *)
  and genPolySingleType [] = (Alpha(newVar()),[])
    | genPolySingleType cvList =
      let 
	  val _ = if !trace then output(std_out, "  genPolySingleType: ["
					  ^(implode(printCVlist cvList))^"]\n")
		  else ()
	  (* Only the first cell decides whether the list is higher-order. *)
	  val firstCell = #1(hd cvList)
	  val (found, state) = findState firstCell
      in if found then  (* higher-order poly state *)
	  let val order = cellOrder firstCell
	  in if order <= 2 then (* uncurried higher-order *)
	      let val inputCV = getInCVfromCVList cvList []
		  val outputCV = getOutCVfromCVList cvList []
		  val t = genPolyType inputCV outputCV
	      in case t of
		  (Arrow(t1,t2)) => (t,(reconstructCVT(t1,inputCV)) @
				     (reconstructCVT(t2,outputCV)))
		| _ => raise InferError "genPolySingleType: non arrow type"
	      end
	     else (* curried higher-order *)
		 let val uncurriedCVList = 
		     Internal.uncurryCVList(order-2,cvList)
		     val inputCV = getInCVfromCVList uncurriedCVList []
		     val outputCV = getOutCVfromCVList uncurriedCVList []
		     val inputCVList = separate (order-2) inputCV []
		     val t = genPolyCType inputCVList outputCV
		 in (t,breakType(t,inputCVList,outputCV))
		 end
	  end
	 else (* first-order poly state *)
	     (* Try to read the list as a product; if the cells do not *)
	     (* carry product tags, fall back to a single fresh variable. *)
	     (let val dots = separate 1 cvList []
		  val dot1 = if null dots orelse (length dots) <> 2
				 then raise InferError 
				     "separate: quotient: not right tags"
			     else hd dots
		  val dot2 = hd(tl dots)
		  val (dot1Type,dot1CVT) = genPolySingleType dot1
		  val (dot2Type,dot2CVT) = genPolySingleType dot2
	        (* Only save the component cvLists * types for matching *)
	      in (Prod [dot1Type,dot2Type], dot1CVT @ dot2CVT)
	      end) handle InferError "separate: quotient: not right tags" =>
		  (* Only this exact message is handled; others propagate. *)
		  let val t = Alpha(newVar())
		  in (t,[(cvList,t)])
		  end
      end


      (* Given a type, generates a similar one involving just *)
      (* variables, and keeps track of changes.  E.g., *)
      (* 'a -> int ===> ('a -> 'b, [(int,'b)]) *)
  (* genDummy t l                                                        *)
  (*   t : the type to abstract                                          *)
  (*   l : association list (original type, replacement variable) built  *)
  (*       so far                                                        *)
  (* Returns the abstracted type together with the extended list.  Type  *)
  (* variables are kept, arrows and products are traversed left to       *)
  (* right, and every other type is replaced by the variable already     *)
  (* recorded for it in l, or by a fresh one that is then recorded.      *)
  (* genDummyList does the same for a list of types; it accumulates its  *)
  (* results in REVERSE order, which is why the Prod case applies rev.   *)
  fun genDummy ty seen =
      (case ty of
           Alpha _ => (ty, seen)
         | Arrow(dom,cod) =>
               let val (dom',seen1) = genDummy dom seen
                   val (cod',seen2) = genDummy cod seen1
               in (Arrow(dom',cod'), seen2)
               end
         | Prod parts =>
               let val (revParts,seen1) = genDummyList parts ([],seen)
               in (Prod (rev revParts), seen1)
               end
         | _ =>
               ((find(ty,seen), seen)
                handle Find =>
                    let val fresh = Alpha(newVar())
                    in (fresh, (ty,fresh)::seen)
                    end))

   and genDummyList pending (acc,seen) =
       (case pending of
            [] => (acc,seen)
          | ty::more =>
                let val (ty',seen1) = genDummy ty seen
                in genDummyList more (ty'::acc, seen1)
                end)

      (* Each cell in inCV must occur in outCV in a compatible form. *)
      (* This only applies to variable cell names.  For constants,   *)
      (* only check if values match. *)
  (* sameCV inCV outCV                                                   *)
  (* An empty inCV never matches.  Otherwise inCV is walked front to     *)
  (* back:                                                               *)
  (*   - a generic or overloaded cell must be found in outCV by search,  *)
  (*     and then the walk continues with the remaining cells;           *)
  (*   - a regular cell with no values fails immediately;                *)
  (*   - a regular cell with values is decided by usedInList against the *)
  (*     values recorded for that cell in outCV (false when the cell is  *)
  (*     absent, i.e. when Find is raised).                              *)
  (* NOTE(review): in the last case the answer is returned without       *)
  (* looking at the cells that follow -- confirm this is intended.       *)
  fun sameCV inCV outCV =
      let fun walk [] = true
            | walk ((c,vals)::more) =
                if (genericCell c) orelse (overloadedCell c)
                    then search(c, outCV) andalso walk more
                else if null vals
                    then false
                (* We have a regular cell name, but values could be matching vars *)
                else (usedInList(vals, find(c,outCV)) handle Find => false)
      in (not (null inCV)) andalso walk inCV
      end

      (* Return a substitution (in terms of the dummy types) *)
      (* to account for matches. *)
  (* match (matchIn, inCVType, matchOut, outCVType)                      *)
  (*   matchIn / matchOut   : (type, dummy) association lists, as        *)
  (*                          produced by genDummy                       *)
  (*   inCVType / outCVType : (cvList, type) entries for the two sides   *)
  (* Each entry first has its type replaced by the dummy recorded for it *)
  (* (left unchanged when find raises Find).  Then every input entry is  *)
  (* compared with every output entry; when sameCV accepts the pair of   *)
  (* cvLists a binding is produced.  The bindings, in input-major order, *)
  (* are handed to generateSub.                                          *)
  fun match (matchIn, inCVType, matchOut, outCVType) =
      let
          (* if neither type is an alpha there is no binding;           *)
          (* otherwise bind the alpha to the other type, and when both  *)
          (* are alphas bind the input one to the output one.           *)
          fun bind (a as Alpha _, other) = [(a, other)]
            | bind (other, a as Alpha _) = [(a, other)]
            | bind _ = []

          fun retag table entries =
              map (fn (cvs, ty) => (cvs, (find(ty,table) handle Find => ty)))
                  entries

          val ins = retag matchIn inCVType
          val outs = retag matchOut outCVType

          (* bindings of one input entry against all output entries *)
          fun againstOuts (inCVs, inTy) =
              let fun scan [] = []
                    | scan ((outCVs, outTy)::more) =
                        if sameCV inCVs outCVs
                            then bind (inTy, outTy) @ scan more
                        else scan more
              in scan outs
              end

          fun collect [] = []
            | collect (entry::more) = againstOuts entry @ collect more

          (* this will not work with products *)
      in generateSub (collect ins)
      end
      
      (* Instantiate "skeleton" type with items from matchList and /\ them *)
  (* genMeet t matchList                                                 *)
  (*   t         : the skeleton type                                     *)
  (*   matchList : (original type, dummy variable) pairs                 *)
  (* Starts from the one-element list [t] and processes the pairs in     *)
  (* order.  A pair whose dummy does not occur in the first type of the  *)
  (* current list is skipped.  Otherwise, if the original type is a Meet *)
  (* the list is instantiated once per alternative (so the /\ is         *)
  (* distributed); if not, every type in the list is instantiated with   *)
  (* the single unifier.  The result is the sole remaining type, or a    *)
  (* Meet of all of them.                                                *)
  (* Raises InferError when the second component of a pair is not an     *)
  (* Alpha, or when the working list has become empty.                   *)
  fun genMeet t matchList =
      let
          (* all of tys under each substitution, grouped by substitution *)
          fun instantiate [] _ = []
            | instantiate (s::ss) tys = (map (apply s) tys) @ (instantiate ss tys)

          fun step (concrete, dummy as Alpha v) tys =
                (* Need only check the first type for the variable: *)
                (* if it's in one, it's in all.                     *)
                let val probe = hd tys
                        handle Hd => raise InferError "genMeet: empty tlist"
                in if not (occursIn v probe) then tys
                   else case concrete of
                            Meet alts =>
                                instantiate (map (fn alt => unify dummy alt) alts) tys
                          | _ => map (apply (unify concrete dummy)) tys
                end
            | step _ _ = raise InferError "genMeet: t2 not an alpha"

          fun loop [] [only] = only
            | loop [] tys = Meet tys
            | loop (pair::more) tys = loop more (step pair tys)
      in loop matchList [t]
      end

	(* Same as above, but don't distribute the /\ *)
    (* genEmbeddedMeet t matchList                                       *)
    (* Processes the (original type, dummy variable) pairs in order over *)
    (* the working list [t].  When the dummy occurs in the first type of *)
    (* the list, every type is instantiated with the unifier of the      *)
    (* original type and the dummy; a Meet on the original side is NOT   *)
    (* split into alternatives here.  Returns the sole type, or a Meet   *)
    (* if the list does not have exactly one element.                    *)
    (* The InferError messages carry the genMeet prefix; they are kept   *)
    (* unchanged.                                                        *)
    fun genEmbeddedMeet t matchList =
        let
            fun step (concrete, dummy as Alpha v) tys =
                  (* Need only check the first type for the variable: *)
                  (* if it's in one, it's in all.                     *)
                  let val probe = hd tys
                          handle Hd => raise InferError "genMeet: empty tlist"
                  in if occursIn v probe
                         then map (apply (unify concrete dummy)) tys
                     else tys
                  end
              | step _ _ = raise InferError "genMeet: t2 not an alpha"

            fun loop [] [only] = only
              | loop [] tys = Meet tys
              | loop (pair::more) tys = loop more (step pair tys)
        in loop matchList [t]
        end

  (* printMatches pairs                                                  *)
  (* Renders each pair of types as its two printed forms joined by an    *)
  (* arrow and followed by a comma and a space; the pieces are           *)
  (* concatenated in list order.  The empty list gives the empty string. *)
  fun printMatches pairs =
      (case pairs of
           [] => ""
         | (src,dst)::more =>
               let val shown = (Printer.printType src)^" -> "^
                               (Printer.printType dst)^", "
               in shown^(printMatches more)
               end)

      (* First get a 'dummy' poly type, so we can then figure out *)
      (* matches.  Apply substitutions to get a 'skeleton' type. *)
      (* Figure out matches with dcds's.  For each alpha, for each *)
      (* matching type, instantiate the skeleton.  /\ together all *)
      (* filled out skeletons. *)
  (* genOverloadedType refInf inCV outCV                                 *)
  (*   refInf : passed through to genOverloadedSingleType                *)
  (*   inCV / outCV : input and output (cell * value list) lists         *)
  (* Steps: type both sides, abstract both types with genDummy, compute  *)
  (* the matching substitution, apply it to the dummy arrow to get the   *)
  (* skeleton, then let genMeet instantiate the skeleton from the        *)
  (* combined match lists.  Progress is written to std_out when the      *)
  (* trace flag is set.                                                  *)
  fun genOverloadedType refInf inCV outCV =
      let
          (* Build the message only when tracing is switched on. *)
          fun note mk = if !trace then output(std_out, mk ()) else ()

          val (Tin,inCVType) = genOverloadedSingleType refInf inCV
          val (Tout,outCVType) = genOverloadedSingleType refInf outCV
          val _ = note (fn () => "  genOverloadedType:\n    inCVT = "^
                                 (printCVTlist inCVType))
          val _ = note (fn () => "    outCVT = "
                                 ^(printCVTlist outCVType))
          val (dummyIn, matchIn) = genDummy Tin []
          val (dummyOut, matchOut) = genDummy Tout []
          val skeleton = apply (match(matchIn, inCVType, matchOut, outCVType))
                               (Arrow(dummyIn, dummyOut))
          val _ = note (fn () => "  genOverloadedType:"^
                                 " skeleton = "^(Printer.printType skeleton)^"\n")
          val matchList = matchIn @ matchOut
          val _ = note (fn () => "  genOverloadedType:"^
                                 " matches = "^(printMatches matchList)^"\n")
      in genMeet skeleton matchList
      end

  (* genOverloadedCType refInf inCVlist outCV                            *)
  (* Curried counterpart of genOverloadedType: inCVlist holds one cvList *)
  (* per curried argument.  Every argument and the output are typed and  *)
  (* abstracted with genDummy (each argument with its own empty match    *)
  (* list), the per-argument matching substitutions are composed in      *)
  (* argument order, and the skeleton is the right-nested arrow of the   *)
  (* dummy argument types ending in the dummy output type.               *)
  and genOverloadedCType refInf inCVlist outCV =
      let val typedIns = map (genOverloadedSingleType refInf) inCVlist
          val (Tout,outCVType) = genOverloadedSingleType refInf outCV
          val dummiedIns = map (fn (ty,_) => genDummy ty []) typedIns
          val (dummyOut, matchOut) = genDummy Tout []

          val inCVTypes = map (fn (_,cvt) => cvt) typedIns
          val dummyInList = map (fn (d,_) => d) dummiedIns
          val matchInList = map (fn (_,m) => m) dummiedIns

          (* substitution for the first argument composed onto the rest *)
          fun substFor [] = emptySub
            | substFor ((matchIn, inCVType)::more) =
                let val here = match(matchIn, inCVType, matchOut, outCVType)
                in compose here (substFor more)
                end

          val S = substFor (zip matchInList inCVTypes)
          val skeleton = fold Arrow dummyInList dummyOut
          val matchList = fold op@ matchInList matchOut
      in genMeet (apply S skeleton) matchList
      end

  (* genOverloadedSingleType refInf cvList                               *)
  (* Result: a pair (type, entries) where entries is a list of           *)
  (* (cvList, type) pairs kept for matching.                             *)
  (*   - empty cvList: a fresh type variable and no entries.             *)
  (*   - findState reports the first cell as found: higher-order,        *)
  (*     uncurried when cellOrder <= 2, else curried.                    *)
  (*   - otherwise: split with separate 1 and type as a product; if the  *)
  (*     split fails with the InferError tagged as coming from separate, *)
  (*     the whole list gets a fresh variable when it is generic and the *)
  (*     type returned by getType otherwise.                             *)
  (* Raises InferError if the uncurried case does not yield an arrow.    *)
  and genOverloadedSingleType _ [] = (Alpha(newVar()),[])
    | genOverloadedSingleType refInf cvList =
      let
          val _ = if !trace then output(std_out, "  genOverloadedSingleType: ["
                                          ^(implode(printCVlist cvList))^"]\n")
                  else ()
          val (headCell, _) = hd cvList
          val (isHigherOrder, _) = findState headCell

          (* uncurried higher-order *)
          fun uncurriedHO () =
              let val ins = getInCVfromCVList cvList []
                  val outs = getOutCVfromCVList cvList []
                  val whole = genOverloadedType refInf ins outs
              in case whole of
                     Arrow(argT,resT) =>
                         (whole, reconstructCVT(argT,ins) @ reconstructCVT(resT,outs))
                   | _ => raise InferError
                              "genOverloadedSingleType: non arrow type"
              end

          (* curried higher-order *)
          fun curriedHO order =
              let val flat = Internal.uncurryCVList(order-2,cvList)
                  val ins = getInCVfromCVList flat []
                  val outs = getOutCVfromCVList flat []
                  val inGroups = separate (order-2) ins []
                  val whole = genOverloadedCType refInf inGroups outs
              in (whole, breakType(whole,inGroups,outs))
              end

          (* the two ways of typing a list that cannot be split *)
          fun asFresh () =
              let val fresh = Alpha(newVar())
              in (fresh,[(cvList,fresh)])
              end
          fun asDeclared () =
              let val declared = getType refInf cvList
              in (declared, [(cvList, declared)])
              end

          (* first-order poly state *)
          fun firstOrder () =
              (let val (left, right) =
                       case separate 1 cvList [] of
                           [d1, d2] => (d1, d2)
                         | _ => raise InferError
                                    "separate: quotient: not right tags"
                   val (leftT,leftCVT) = genOverloadedSingleType refInf left
                   val (rightT,rightCVT) = genOverloadedSingleType refInf right
                   (* Only save the component cvLists * types for matching *)
               in (Prod [leftT,rightT], leftCVT @ rightCVT)
               end) handle InferError "separate: quotient: not right tags" =>
                        (* The overloaded and the remaining case currently *)
                        (* do the same thing; both tests are kept.         *)
                        (if generic cvList then asFresh ()
                         else if overloadedCV cvList then asDeclared ()
                         else asDeclared ())
      in if isHigherOrder
             then let val order = cellOrder headCell
                  in if order <= 2 then uncurriedHO () else curriedHO order
                  end
         else firstOrder ()
      end


      (* Is a cvList an instance of a product of some arrow type, e.g.  *)
      (* (int -> int) * int.  We have to treat this separately, because *)
      (* even though it looks like a state of ground type it isn't, so  *)
      (* it might be overloaded. *)
  (* arrowProd cvList                                                    *)
  (* True when the largest value of degree over the cell names in cvList *)
  (* (with 0 as the floor, so also for an empty list) is greater than 1. *)
  fun arrowProd cvList =
      let val degrees = map (fn (c:cell, _:value list) => degree c) cvList
      in (fold max degrees 0) > 1
      end


      (* Find global type for a state.  The refInf argument tells us if we *)
      (* are doing refinement type inference or regular type inference. *)
      (* This matters in 2 ways:  overloaded ground states are OK for *)
      (* refinement inference, and we need to eliminate refinement only *)
      (* types from regular type inference. *)
  (* typeState refInf x                                                  *)
  (*   refInf : true for refinement inference, false for regular         *)
  (*            inference; passed on to getType and the genOverloaded    *)
  (*            functions                                                *)
  (*   x      : the state, a list of (cell, value) pairs                 *)
  (* The empty state gets a fresh type variable.  Otherwise the order of *)
  (* the first cell selects the algorithm:                               *)
  (*   1     - simple state, typed from its output cells;                *)
  (*   2     - first-order algorithm, typed as input -> output;          *)
  (*   other - the state is uncurried by (order - 2) and typed as a      *)
  (*           curried algorithm.                                        *)
  (* Raises InferError for an overloaded ground state when refInf is     *)
  (* false and the state is not a product involving an arrow type.       *)
  fun typeState _ [] = Alpha (newVar())
    | typeState refInf (x as ((firstCell,_)::_)) =
      let val order = cellOrder firstCell

          (* fresh variable for a generic cvList, declared type otherwise *)
          fun plainType cv =
              if generic cv then Alpha(newVar()) else getType refInf cv

      (* Simple state, first-order algo, or higher-order? *)
      in case order of
             1 => let val outputCV = getOutputCV order x []
                  in if null outputCV then Alpha(newVar())
                     else if not (overloadedCV outputCV)
                         then getType refInf outputCV
                     else if refInf orelse (arrowProd outputCV)
                         then #1(genOverloadedSingleType refInf outputCV)
                     else raise InferError "ground state cannot be overloaded"
                  end
           | 2 => let val inputCV = getInputCV x []
                      val outputCV = getOutputCV order x []
                  in (* Do we have a fully polymorphic algo? *)
                     if isPoly inputCV outputCV
                         then genPolyType inputCV outputCV
                     (* Do we have an overloaded algo? *)
                     else if isOverloaded inputCV outputCV
                         then genOverloadedType refInf inputCV outputCV
                     else let val Tin = plainType inputCV
                              val Tout = plainType outputCV
                          in Arrow(Tin, Tout)
                          end
                  end
           | _ => let val uncurriedState = Internal.uncurry(order-2, x)
                      val inputCV = getInputCV uncurriedState []
                      val outputCV = getOutputCV 2 uncurriedState []
                      val inputCVlist = separate (order-2) inputCV []
                  in (* Do we have a fully polymorphic curried algo? *)
                     if isPolyC inputCVlist outputCV
                         then genPolyCType inputCVlist outputCV
                     (* Do we have a curried overloaded algo? *)
                     else if isOverloadedC inputCVlist outputCV
                         then genOverloadedCType refInf inputCVlist outputCV
                     else let val TinList = checkGeneric refInf inputCVlist
                              val Tout = plainType outputCV
                          in genCurryType (TinList @ [Tout])
                          end
                  end
      end


      (*  Used by the refine module to type a (cell, value list) list after *)
      (* instantiating variables. *)
  (* typeCVlist refInf cvList                                            *)
  (* Types cvList with genOverloadedSingleType, discards the match       *)
  (* entries, and passes the type as a one-element list to finalPolish.  *)
  fun typeCVlist refInf cvList =
      finalPolish [#1 (genOverloadedSingleType refInf cvList)]


    (* Given a list of types put /\ around it if it contains >1 element *)
    (* meetWrap t                                                        *)
    (* Passes the list of types through duplicates (defined elsewhere;   *)
    (* presumably drops repeated entries -- confirm) and then returns    *)
    (* the single survivor, or a Meet of the survivors when more than    *)
    (* one is left.  Raises InferError when none is left.                *)
    fun meetWrap t =
        (case duplicates t of
             [] => raise InferError "term does not have a type"
           | [only] => only
           | several => Meet several)

    (* doGlobals gT1 gT2 meetFun unifyFun                                *)
    (* When both global types are present, combines them with            *)
    (* meetFun (given unifyFun) and wraps the result in TYPE; in every   *)
    (* other case the result is UNTYPED.                                 *)
    fun doGlobals gT1 gT2 meetFun unifyFun =
        (case (gT1, gT2) of
             (TYPE gt1, TYPE gt2) => TYPE (meetFun gt1 gt2 unifyFun)
           | _ => UNTYPED)

    (* doOneGlobal gT meetFun unifyFun                                   *)
    (* One-argument version of doGlobals: UNTYPED stays UNTYPED, and a   *)
    (* present type is transformed by meetFun (given unifyFun) and       *)
    (* wrapped in TYPE again.                                            *)
    fun doOneGlobal gT meetFun unifyFun =
        (case gT of
             UNTYPED => UNTYPED
           | TYPE t => TYPE (meetFun t unifyFun))


        (* Applies unifyOp to all elements of t if t is a /\ *)
        (* Also wraps a /\ around result if it has >1 element *)
    (* doOneMeet t unifyOp                                               *)
    (* Hands the members of t (its components when t is a Meet, the      *)
    (* singleton list of t otherwise) to unifyOp together with an empty  *)
    (* list as last argument, and wraps the resulting list via meetWrap. *)
    fun doOneMeet t unifyOp =
        let val members = case t of
                              Meet tlist => tlist
                            | _ => [t]
        in meetWrap (unifyOp members [])
        end

        (* Combines type expressions involving /\, by essentially *)
        (* trying out all possibilities. *)
    (* combMeet t1 t2 unifyOp                                            *)
    (* Turns each of t1 and t2 into the list of its members (the         *)
    (* components of a Meet, or the singleton list otherwise), passes    *)
    (* both lists and an empty list to unifyOp, and wraps the resulting  *)
    (* list via meetWrap.                                                *)
    fun combMeet t1 t2 unifyOp =
        let fun members (Meet tlist) = tlist
              | members single = [single]
        in meetWrap (unifyOp (members t1) (members t2) [])
        end

    (* isMeetHead t                                                      *)
    (* For an arrow, the answer is isMeet applied to the argument type.  *)
    (* For a Meet, the answer is true when isMeetHead holds for at least *)
    (* one component (all components are examined).                      *)
    (* Any other type -- a type variable, a product, ... -- has no       *)
    (* argument position and answers false.  Without this last clause    *)
    (* such a type raised Match, which escaped from combLeftMeet before  *)
    (* the unifier could accept or reject the term.                      *)
    fun isMeetHead (Arrow(t1, t2)) = isMeet t1
      | isMeetHead (Meet tlist) = fold (fn (x,y) => x orelse y) (map isMeetHead tlist) false
      | isMeetHead _ = false

        (* If left side is a meet, try out all possibilities *)
    (* combLeftMeet t1 t2 unifyOp                                        *)
    (* Like combMeet, except that when isMeetHead holds for t1 the right *)
    (* operand t2 is always passed as a single element, even if it is a  *)
    (* Meet.  In every case t1 contributes its members (the components   *)
    (* of a Meet, or the singleton list otherwise).  The list returned   *)
    (* by unifyOp is wrapped via meetWrap.                               *)
    fun combLeftMeet t1 t2 unifyOp =
        let fun members (Meet tlist) = tlist
              | members single = [single]
            val keepRightWhole = isMeetHead t1
            val left = members t1
            val right = if keepRightWhole then [t2] else members t2
        in meetWrap (unifyOp left right [])
        end


        (* Find refinement and global type for an expression *)
        (* Global type is stored as TYPE T or UNTYPED, because *)
        (* typing is optional. *)
    (* inferType e                                                       *)
    (* Returns the global type of expression e as TYPE t or UNTYPED.     *)
    (*   Expr_id      - looks the name up in nameExpTypeList.  A stored  *)
    (*                  type whose time stamp equals currentTimeStamp is *)
    (*                  reused through a fresh instance; otherwise the   *)
    (*                  bound expression is inferred again and the       *)
    (*                  result stored under the current time stamp.      *)
    (*                  Raises InferError for an unknown name.           *)
    (*   Expr_state   - typeState with refinement inference off.         *)
    (*   Expr_algo    - converted with Internal.algoToState, then typed  *)
    (*                  as a state.                                      *)
    (*   apply, compose, pair - both operands are inferred, left first,  *)
    (*                  and combined with the matching unifier.          *)
    (*   fix, curry, uncurry - the operand is inferred and transformed   *)
    (*                  with the matching unifier.                       *)
    (*   Expr_prod    - product of the two operand types when both are   *)
    (*                  typed.                                           *)
    and inferType (Expr_id s) =
        let val (stamp, body, cached) = lookupExpType(s,!nameExpTypeList)
                handle Lookup _ =>
                    raise InferError ("nonexistent identifier: "^s)
            (* infer the bound expression again and remember the result *)
            fun refresh () =
                let val gT = inferType body
                    val _ = storeExpType nameExpTypeList
                                (s, !currentTimeStamp, body, gT)
                in gT
                end
        in case cached of
               UNTYPED => refresh ()
             | TYPE t' => if stamp = !currentTimeStamp
                              then TYPE (#1(freshInst t' emptyInst))
                          else refresh ()
        end
      | inferType (Expr_state x) = TYPE (typeState false x)
      | inferType (Expr_algo a) =
        TYPE (typeState false (Internal.algoToState(a,[])))
      | inferType (Expr_apply(e1,e2)) =
        doGlobals (inferType e1) (inferType e2) combLeftMeet unifyApp
      | inferType (Expr_compose(e1,e2)) =
        doGlobals (inferType e1) (inferType e2) combMeet unifyComp
      | inferType (Expr_fix(e)) =
        doOneGlobal (inferType e) doOneMeet unifyFix
      | inferType (Expr_curry(e)) =
        doOneGlobal (inferType e) doOneMeet unifyCurry
      | inferType (Expr_uncurry(e)) =
        doOneGlobal (inferType e) doOneMeet unifyUncurry
      | inferType (Expr_pair(e1,e2)) =
        doGlobals (inferType e1) (inferType e2) combMeet unifyPair
      | inferType (Expr_prod(e1,e2)) =
        (case (inferType e1, inferType e2) of
             (TYPE t1, TYPE t2) => TYPE (Prod [t1, t2])
           | _ => UNTYPED)

  end
  end;
