(* Copyright 1989 by AT&T Bell Laboratories *)
(* lambdatype.sml *)

(* LAMBDATYPE: interface to the lambda-language types.  A type is handled
 * through the abstract type lty; lty' is the one-level view obtained with
 * out and turned back into an lty with inj. *)
signature LAMBDATYPE = sig 

(* abstract handle on a lambda type *)
type lty

(* one level of type structure; the components are again lty handles *)
datatype lty'
  = INT
  | INT32
  | BOOL
  | REAL
  | SRCONT
  | BOXED
  | RECORD of lty list
  | SRECORD of lty list
  | ARROW of lty * lty
  | CONT of lty
  | LIST of int                   (* list type with the length parity *)
  | GREC of (int * lty) list      (* records with partial type information *)
  | SUM of (int * lty) list       (* sum types such as list *)
  | RBOXED                        (* recursively boxed type variables *) 
  | BOT of lty                    (* refinement bottom type *)
  | REF of lty * (lty vector)     (* refinement arrow type *)

(* inj wraps a node into an lty, out unwraps it again; rehashcons rebuilds
 * a type from its leaves upward, passing every node through inj. *)
val inj: lty' -> lty
val out: lty -> lty'
val rehashcons: lty -> lty

(* feeds the structure of a type into the given CRC value *)
val crc: CRC.crc -> lty -> unit	

(* shorthands for inj applied to the corresponding constructor *)
val injINT : lty
val injINT32 : lty
val injBOXED: lty
val injREAL: lty
val injARROW: lty*lty -> lty
val injRECORD: lty list -> lty

(* equality test on lty handles *)
val eq : lty * lty -> bool


(* in structure LambdaType this is inj BOXED *)
val BOGUS : lty

(* printLty writes a compact rendering of a type.
 * equivLty is the equivalence test; compatLty is the weaker compatibility
 * test (in LambdaType it holds whenever equivLty holds).
 * arrowLty, contLty and recordLty take a type apart and report an internal
 * error on a type of the wrong shape; selectLty (t,i) picks component i.
 * mergeLty combines two types, merging the entries of two GREC types.
 * sizeLty and lengthLty give a size and a component count. *)
val printLty : lty -> unit
val equivLty : lty * lty -> bool
val compatLty : lty * lty -> bool
val arrowLty : lty -> lty * lty
val contLty : lty -> lty
val recordLty : lty -> lty list
val selectLty : lty * int -> lty
val mergeLty : lty * lty -> lty
val sizeLty : lty -> int
val lengthLty : lty -> int 

(* finite maps keyed by lty: enter returns an extended map, look returns
 * NONE when the key has no entry *)
type 'a ltyMap
val empty: 'a ltyMap
val enter : 'a ltyMap * lty * 'a -> 'a ltyMap
val look  : 'a ltyMap * lty -> 'a option

end

abstraction LambdaType : LAMBDATYPE = struct 


(* One level of a lambda type.  Every component is an lty, i.e. a cell
 * that has already been built by inj. *)
datatype lty' 
  = INT 
  | INT32
  | BOOL
  | REAL
  | SRCONT
  | BOXED
  | RECORD of lty list
  | SRECORD of lty list
  | ARROW of lty * lty
  | CONT of lty
  | LIST of int                   (* list type with the length parity *)
  | GREC of (int * lty) list      (* records with partial type information *)
  | SUM of (int * lty) list       (* sum types such as list *)
  | RBOXED                        (* recursively boxed type variables *)
  | BOT of lty                    (* refinement bottom type *)
  | REF of lty * (lty vector)     (* refinement arrow type *)

(* A type is a ref cell holding (hash number, node); inj fills in the
 * hash number and hands out one shared cell per distinct node. *)
withtype lty = (int * lty') ref

(* NOTE(review): LambdatyMap is not raised or handled anywhere in this
 * structure and is not exported by LAMBDATYPE. *)
exception LambdatyMap

(* int <-> word conversions used by the hashing code in inj *)
val itow = Word.fromInt
val wtoi = Word.toInt

(* N is the number of buckets in table; NN - 0w1 is the mask that inj's
 * combine applies to hash numbers; P is the multiplier combine uses. *)
val N = 1024  (* a power of 2 *)
val NN = itow(N*N)
val P = 0w1019  (* a prime < 1024, so that N*N*P < maxint *)


structure Weak = System.Weak

(* The hash-cons table: N buckets, each a list of weak pointers to the
 * cells that inj has handed out.  Initially every bucket is empty. *)
val table: lty Weak.weak list Array.array = Array.array(N,nil)

(* getnum: the first component (the hash number) stored in a cell. *)
fun getnum cell =
  let val (num, _) = !cell
   in num
  end
(* tagnums: flattens [(i1,t1),(i2,t2),...] into
 * [i1, getnum t1, i2, getnum t2, ...]. *)
fun tagnums cells =
  foldr (fn ((tag, ty), acc) => tag :: getnum ty :: acc) nil cells

(* vector2list: the elements of v as a list, in index order. *)
fun vector2list v =
  let (* walk the indices downward so the list comes out in order *)
      fun collect (i, acc) =
            if i < 0 then acc
            else collect (i - 1, Vector.sub (v, i) :: acc)
   in collect (Vector.length v - 1, nil)
  end



(* crc: given a CRC value, returns a function that feeds a type into it.
 * Each node contributes a constructor tag (0..15) followed by its
 * components; the hash number stored in a cell is not fed in.
 * The feeding primitives come from CRCUtil.mkUtils (state, true). *)
fun crc state = let
    val { app = emit, int = emitInt, list = emitList,
	  vector = emitVector, ... } = CRCUtil.mkUtils (state, true)
    (* walk: feed the node held in a cell *)
    fun walk (ref (_, node)) = walkNode node
    (* walkTagged: feed one (tag, type) entry of a GREC or SUM *)
    and walkTagged (tag, ty) = (emitInt tag; walk ty)
    and walkNode node =
	  case node
	    of INT => emit 0
	     | INT32 => emit 1
	     | BOOL => emit 2
	     | REAL => emit 3
	     | SRCONT => emit 4
	     | BOXED => emit 5
	     | RECORD fields => (emit 6; emitList walk fields)
	     | SRECORD fields => (emit 7; emitList walk fields)
	     | ARROW (dom, rng) => (emit 8; walk dom; walk rng)
	     | CONT arg => (emit 9; walk arg)
	     | LIST parity => (emit 10; emitInt parity)
	     | GREC cells => (emit 11; emitList walkTagged cells)
	     | SUM cells => (emit 12; emitList walkTagged cells)
	     | RBOXED => emit 13
	     | BOT arg => (emit 14; walk arg)
	     | REF (base, refs) => (emit 15; walk base; emitVector walk refs)
in
    walk
end

(* revcat (xs, acc): the reverse of xs followed by acc. *)
fun revcat (pending, acc) =
  case pending
    of nil => acc
     | x :: more => revcat (more, x :: acc)

(* inj: hash-consing constructor.  Computes a hash number for the node t,
 * searches the matching bucket of table for a live cell holding the same
 * hash number and an equal node, and returns that cell; otherwise it
 * allocates a new cell, records a weak pointer to it and returns it.
 * The bucket is rewritten on every call (see look below). *)
fun inj t = let 
  (* combine: folds a non-empty list of ints into one word,
   * a + (combine rest) * P, masked with NN - 0w1.  A one-element list is
   * returned unmasked.  There is no clause for the empty list; every call
   * below passes at least the constructor tag. *)
  fun combine [x] = itow x
    | combine(a::rest) = Word.andb(itow a +(combine rest)*P, NN - 0w1)
  (* hash: one tag per constructor, combined with the stored hash numbers
   * (getnum) of the immediate components; components are not traversed. *)
  fun hash(INT) = 0w1
    | hash(INT32) = 0w2
    | hash(BOOL) = 0w3
    | hash(REAL) = 0w4
    | hash(SRCONT) = 0w5
    | hash(BOXED) = 0w6
    | hash(RECORD l) = combine(7::map getnum l)
    | hash(SRECORD l) = combine(8::map getnum l)
    | hash(ARROW(a,b)) = combine[9,getnum a,getnum b]
    | hash(CONT t) = combine[10,getnum t]
    | hash(LIST i) = combine[11,i]
    | hash(GREC l) = combine(12::tagnums l)
    | hash(SUM l) = combine(13::tagnums l)
    | hash(RBOXED) = 0w14
    | hash(BOT t) = combine[15, getnum t]
    | hash(REF(t,tv)) = combine(16::map getnum(t::vector2list tv))

  (* h is the hash number stored in the cell; index keeps the bits of h
   * selected by the mask N-1 and picks the bucket. *)
  val h = wtoi(hash t)
  val index = wtoi(Word.andb(itow h, itow(N-1)))
  (* look (l, bucket): l holds the live entries already passed, in
   * reverse order.
   *  - hit: the entry moves to the front of the bucket; the entries
   *    passed before it (revcat restores their order) and the unexamined
   *    rest follow; the existing cell is returned.
   *  - weak pointer whose target is gone (NONE): the entry is dropped.
   *  - end of bucket: a new cell is created and its weak pointer goes to
   *    the front, followed by the surviving entries in original order.
   * t=t' compares the two nodes with the built-in equality, under which
   * component cells (refs) are compared by identity. *)
  fun look(l,w::rest) = 
      (case Weak.strong w
	of SOME(r as ref(h',t')) =>
	  if h=h' andalso t=t' 
	    then (Array.update(table,index,
			       w::revcat(l,rest));
		  r)
	  else look(w::l,rest)
	| NONE => look(l,rest))
    | look(l,nil) = 
	 let val r = ref(h,t)
	 in Array.update(table,index,
			 Weak.weak r :: rev l);
	   r
	 end
in look(nil,Array.sub(table,index))
end

(* out: the node stored in a cell (its second component). *)
fun out cell =
  let val (_, node) = !cell
   in node
  end

(* Cells for the common base types, built once when the structure is
 * elaborated, plus inj composed with the ARROW and RECORD constructors.
 * injRBOXED is not exported by LAMBDATYPE; in this file it is mentioned
 * only inside the commented-out clauses of compat. *)
val injINT = inj INT
val injINT32 = inj INT32
val injBOXED= inj BOXED
val injRBOXED= inj RBOXED
val injREAL = inj REAL
val injARROW = inj o ARROW
val injRECORD = inj o RECORD

(* eq: built-in equality on the two arguments.  On lty values, which are
 * ref cells, this is identity of the cell. *)
fun eq (lhs, rhs) = (op =) (lhs, rhs)

(* BOGUS is built from the same node (BOXED) as injBOXED. *)
val BOGUS = inj BOXED    

(* Output and error helpers used by the rest of this structure. *)
val say = Control.Print.say

(* warn: writes "Warning: ", the message and a newline through say. *)
fun warn (msg : string) = (say "Warning: "; say msg; say "\n")

(* err: reports an internal error by passing msg to ErrorMsg.impossible. *)
fun err (msg : string) = ErrorMsg.impossible msg

(* printLty: writes a compact rendering of a type through say.
 *   INT "I", INT32 "I32", BOOL "B", REAL "F", BOXED "P", RBOXED "R",
 *   SRCONT "CONT"; LIST 0/1/2 "UL"/"OL"/"EL", any other LIST "L";
 *   CONT t     t followed by " cont"
 *   ARROW      (t1->t2)
 *   GREC       [<i,t>,<j,t>,...]     ("[<>]" when empty)
 *   SUM        $<i,t>,<j+t>,...$     (an empty SUM is reported via err)
 *   RECORD     {t1,t2,...}           (the empty record prints as "I")
 *   SRECORD    #S{t1,t2,...}S#       (the empty one prints as "I")
 *   BOT t      t followed by " BOTTOM"
 *   REF (t,z)  t==> followed by element 0 of z
 * A REF whose vector is empty prints nothing after "==>".  The other
 * functions in this file that handle REF accept vectors of any length,
 * so the printer must not raise on an empty one. *)
fun printLty t =
  let (* one <tag SEP type> entry; lead opens it, sep follows the tag *)
      fun cell (lead, sep) (i : int, x) =
            (say (lead ^ makestring i ^ sep); printLty x; say ">")
      (* a comma followed by a type, used for record fields after the first *)
      fun commaThen x = (say ","; printLty x)
   in case out t
        of INT => say "I"
         | INT32 => say "I32"
         | BOOL => say "B"
         | REAL => say "F"
         | BOXED => say "P"
         | RBOXED => say "R"
         | SRCONT => say "CONT"
         | LIST 0 => say "UL"
         | LIST 1 => say "OL"
         | LIST 2 => say "EL"
         | LIST _ => say "L"
         | CONT arg => (printLty arg; say " cont")
         | ARROW (t1, t2) =>
             (say "("; printLty t1; say "->"; printLty t2; say ")")
         | GREC [] => say "[<>]"
         | GREC (first :: more) =>
             (say "["; cell ("<", ",") first;
              app (cell (",<", ",")) more; say "]")
         | SUM [] => err "SUMty with empty cells in lambdatype"
         | SUM (first :: more) =>
             (* entries after the first separate tag and type with "+" *)
             (say "$"; cell ("<", ",") first;
              app (cell (",<", "+")) more; say "$")
         | RECORD [] => say "I"
         | RECORD (first :: more) =>
             (say "{"; printLty first; app commaThen more; say "}")
         | SRECORD [] => say "I"
         | SRECORD (first :: more) =>
             (say "#S{"; printLty first; app commaThen more; say "}S#")
         | BOT arg => (printLty arg; say " BOTTOM")
         | REF (base, z) =>
             (printLty base; say "==>";
              (* only element 0 is shown; skip it when z is empty *)
              if Vector.length z > 0 then printLty (Vector.sub (z, 0))
              else ())
  end


(* equivLty: two types are equivalent when they are the same cell, or
 * when eqv relates their nodes in either argument order. *)
fun equivLty(a,b) =
  if eq(a,b) then true
  else let val a' = out a
           val b' = out b
        in eqv(a',b') orelse eqv(b',a')
       end

(* eqv: one direction of the equivalence test on nodes; equivLty tries
 * both argument orders.  The cases are tried from top to bottom. *)
and eqv (lhs, rhs) =
  case (lhs, rhs)
    of (INT, BOOL) => true
     | (INT, RECORD nil) => true
     | (INT, SRECORD nil) => true
     | (BOXED, RBOXED) => true
     | (BOT s, BOT t) => equivLty (s, t)
     | (ARROW (arg, res), SRCONT) =>
         (case out res
            of BOXED => equivLty (arg, injBOXED)
             | _ => false)
     | (SRECORD fields, GREC cells) => eqRG (fields, cells)
     | (GREC l, GREC r) => eqGG (l, r)
     | (SRECORD l, SRECORD r) => eqRR (l, r)
     | (RECORD l, RECORD r) => eqRR (l, r)
     | (ARROW (arg, res), CONT t) =>
         (case out res
            of BOXED => equivLty (arg, t)
             | _ => false)
     | (ARROW (a1, r1), ARROW (a2, r2)) =>
         equivLty (a1, a2) andalso equivLty (r1, r2)
     | (CONT s, CONT t) => equivLty (s, t)
     | (SUM l, SUM r) => eqGG (l, r)        (* a temporary definition *)
     | (SUM _, other) => equivLty (injBOXED, inj other)
     | (REF (s, zs), REF (t, zt)) =>  (* is this right? *)
         equivLty (s, t) andalso
           (let val n1 = Vector.length zs
                val n2 = Vector.length zt
                (* compares every pair of elements in index order and
                 * returns the conjunction of the results *)
                fun scan (i, ok) =
                      if i < n2
                      then scan (i + 1,
                                 equivLty (Vector.sub (zs, i),
                                           Vector.sub (zt, i))
                                 andalso ok)
                      else ok
             in (n1 = n2) andalso scan (0, true)
            end)
     | _ => false

(* compatLty: equivalent types are compatible; otherwise the two nodes
 * are handed to compat in both argument orders. *)
and compatLty(a,b) =
  if equivLty(a,b) then true
  else let val a' = out a
           val b' = out b
        in compat(a',b') orelse compat(b',a')
       end

(* compat: one direction of the compatibility test on nodes; compatLty
 * tries both argument orders.  INT, BOOL and SRCONT are each compatible
 * with BOXED and RBOXED; records, arrows and continuations are compared
 * component by component. *)
and compat (INT,BOOL) = true
  | compat (INT,RECORD nil) = true
  | compat (INT,SRECORD nil) = true
  | compat (INT,BOXED) = true
  | compat (INT,RBOXED) = true
  | compat (BOOL,BOXED) = true
  | compat (BOOL,RBOXED) = true
  | compat (SRCONT,BOXED) = true
  | compat (SRCONT,RBOXED) = true
  | compat (BOXED,RBOXED) = true
(*
  | compat (ARROW(t1,t2),BOXED) = true
  | compat (ARROW(t1,t2),RBOXED) = 
                compat(out t1,RBOXED) andalso compat(out t2,RBOXED)
  | compat (RECORD l,BOXED) = true
  | compat (RECORD l,RBOXED) = compatRR(l,map (fn _ => injRBOXED) l)
*)
  | compat (ARROW(arg,res),SRCONT) =
      (* the arrow's result must be BOXED *)
      (out res = BOXED) andalso compatLty(arg,injBOXED)
  | compat (RECORD fs,RECORD gs) = compatRR(fs,gs)
  | compat (ARROW(arg,res),CONT k) =
      (out res = BOXED) andalso compatLty(arg,k)
  | compat (ARROW(a1,r1),ARROW(a2,r2)) =
      compatLty(a1,a2) andalso compatLty(r1,r2)
  | compat (CONT k1,CONT k2) = compatLty(k1,k2)
  | compat _ = false

(* compatRR: true when both lists have the same length and are pairwise
 * compatible; stops at the first incompatible pair. *)
and compatRR(xs,ys) =
  case (xs,ys)
    of (nil,nil) => true
     | (x::xs',y::ys') => compatLty(x,y) andalso compatRR(xs',ys')
     | _ => false

(* eqRR: true when both lists have the same length and are pairwise
 * equivalent; stops at the first non-equivalent pair. *)
and eqRR(xs,ys) =
  case (xs,ys)
    of (nil,nil) => true
     | (x::xs',y::ys') => equivLty(x,y) andalso eqRR(xs',ys')
     | _ => false

(* eqRG (fields, cells): every (i,t) in cells must be equivalent to
 * component i of fields.  Any exception raised while checking an entry
 * (for instance by List.nth on an index outside the list) makes the
 * result false. *)
and eqRG(fields,cells) =
  case cells
    of nil => true
     | (i,t)::more =>
         (equivLty(List.nth(fields,i),t) andalso eqRG(fields,more))
           handle _ => false

(* eqGG: walks two tagged lists side by side.  When the head tags differ
 * the entry with the smaller tag is skipped; when they agree the two
 * types must be equivalent.  Reaching the end of either list succeeds.
 * NOTE(review): presumably both lists are sorted by increasing tag --
 * confirm against the code that builds GREC and SUM nodes. *)
and eqGG(nil,_) = true
  | eqGG(_,nil) = true
  | eqGG(a as ((i,t)::l),b as ((j,s)::r)) =
      if i = j then equivLty(t,s) andalso eqGG(l,r)
      else if i < j then eqGG(l,b)
      else eqGG(a,r)

(* the BOXED type; arrowLty and contLty return it for SRCONT and CONT *)
val boxedLty = inj BOXED

(* arrowLty: the (argument, result) pair of an arrow type.  CONT t gives
 * (t, boxedLty) after a warning; SRCONT gives (boxedLty, boxedLty).
 * Any other type is printed and reported through err. *)
fun arrowLty t =
  let fun bad () = (say "*** "; printLty t; say "\n";
                    err "lambdatype.sml arrowLty 324")
   in case out t
        of SRCONT => (boxedLty,boxedLty) (* For THROW *)
         | CONT arg => (warn "arrowLty on CONT"; (arg,boxedLty))
         | ARROW p => p
         | _ => bad ()
  end

(* contLty: the argument type of a continuation type; boxedLty for
 * SRCONT.  Any other type is printed and reported through err. *)
fun contLty t =
  let fun bad () = (say "*** "; printLty t; say "\n";
                    err "lambdatype.sml contLty 301")
   in case out t
        of SRCONT => boxedLty
         | CONT lt => lt
         | _ => bad ()
  end

(* recordLty: the component types of a RECORD or SRECORD.  An empty
 * record, or a type that is not a record, is reported through err. *)
fun recordLty t =
  let (* hands back a non-empty component list, errs on an empty one *)
      fun nonEmpty (nil, msg) = err msg
        | nonEmpty (fields, _) = fields
   in case out t
        of RECORD l => nonEmpty (l, "lambdatype.sml nil recordLty")
         | SRECORD l => nonEmpty (l, "lambdatype.sml nil (S) recordLty")
         | _ => (say "*** "; printLty t; say "\n";
                 err "lambdatype.sml recordLty 212")
  end

(* selectLty (t,i): component i of a RECORD or SRECORD (by position), or
 * the type paired with tag i in a GREC.  Every failure is reported
 * through err after the offending type has been printed.
 * NOTE(review): the outer handler catches every exception, so if err
 * raises, the 131/132 reports are presumably followed by the 133 one. *)
fun selectLty (t,i) =
  let fun report code = (say "*** "; printLty t; say "\n"; err code)
      (* search a GREC's entries for tag i *)
      fun find nil = report "selectLty in lambdatype 132"
        | find ((j,s)::rest) = if i = j then s else find rest
      fun pick node =
            case node
              of RECORD l => List.nth(l,i)
               | SRECORD l => List.nth(l,i)
               | GREC l => find l
               | _ => report "selectLty in lambdatype 131"
   in pick (out t)
        handle _ => (say "*** "; printLty t;
                     say "."; say(makestring i);
                     say "\n";
                     err "selectLty in lambdatype 133")
  end

(* mergelist: merges two tagged lists, taking the smaller head tag
 * first; entries whose tags agree become one entry with the types
 * combined by mergeLty.  When one list runs out the rest of the other
 * is kept as it is. *)
fun mergelist(xs,ys) =
  case (xs,ys)
    of (_,[]) => xs
     | ([],_) => ys
     | ((i,t)::l,(j,s)::r) =>
         if i = j then (i,mergeLty(t,s)) :: mergelist(l,r)
         else if i < j then (i,t) :: mergelist(l,ys)
         else (j,s) :: mergelist(xs,r)

(* mergeLty: two GREC types merge entry by entry (mergelist).  In every
 * other case the types must be equivalent: if exactly one is a GREC the
 * other one is returned, otherwise t1.  Non-equivalent types are both
 * printed and reported through err. *)
and mergeLty(t1,t2) =
  let (* result when t1 and t2 are equivalent, otherwise report `code` *)
      fun agree (result, code) =
            if equivLty(t1,t2) then result
            else (say "*** "; printLty t1; say "\n";
                  say "*** "; printLty t2; say "\n";
                  err code)
   in case (out t1, out t2)
        of (GREC r, GREC l) => inj(GREC (mergelist(r,l)))
         | (GREC _, _) => agree (t2, "mergeLty in lambdatype 136")
         | (_, GREC _) => agree (t1, "mergeLty in lambdatype 137")
         | _ => agree (t1, "mergeLty in lambdatype 138")
  end

(* sizeLty: a non-empty RECORD or SRECORD has the summed size of its
 * components; every other type, including the empty records, has
 * size 1. *)
fun sizeLty t =
  let fun total fields = foldr (op +) 0 (map sizeLty fields)
   in case out t
        of RECORD (fields as _ :: _) => total fields
         | SRECORD (fields as _ :: _) => total fields
         | _ => 1  (* in the future, real and arrowty should be 2 *)
  end

(* lengthLty: the number of components of a RECORD or SRECORD (an empty
 * one is reported through err); 1 for every other type. *)
fun lengthLty t =
  let fun count (nil, msg) = err msg
        | count (fields, _) = length fields
   in case out t
        of RECORD l => count (l, "RECORD nil in lengthLty")
         | SRECORD l => count (l, "SRECORD nil in lengthLty")
         | _ => 1
  end

(* A map keyed by lty: an IntmapF map whose int keys are the hash numbers
 * stored in the cells (see look and enter); each entry is an association
 * list of the (type, value) pairs that share that hash number. *)
type 'a ltyMap = (lty * 'a) list IntmapF.intmap

(* the map with no entries *)
val empty = IntmapF.empty
(* look (m,t): the value entered for t, if any.  The entry for t's hash
 * number is fetched (a missing entry counts as empty) and searched for
 * a pair whose key is the very cell t. *)
fun look(m,t as ref(h,_)) =
  let val bucket = IntmapF.lookup m h handle IntmapF.IntmapF => nil
      fun find nil = NONE
        | find ((t',x)::rest) = if t=t' then SOME x else find rest
   in find bucket
  end

(* enter (m,t,x): a map extending m with the binding t |-> x.  The pair
 * is put in front of the entry stored under t's hash number, so look
 * finds it before any older binding for the same t.
 * Only IntmapF.IntmapF (no entry for this hash number yet) is treated
 * as an empty entry.  The handler used to name the unqualified
 * identifier IntmapF, which is a variable pattern and therefore caught
 * every exception; look already handles the qualified exception. *)
fun enter(m, t as ref(h,_), x) =
  let val bucket = IntmapF.lookup m h handle IntmapF.IntmapF => nil
   in IntmapF.add(m, h, (t,x)::bucket)
  end


    (* rehashcons: rebuilds a type from its leaves upward.  Every
     * component is rebuilt first, then the node is passed through inj;
     * nodes without components go through inj unchanged. *)
    fun rehashcons t =
	  let val rebuilt =
		case out t
		  of CONT arg => CONT (rehashcons arg)
		   | ARROW (dom, rng) => ARROW (rehashcons dom, rehashcons rng)
		   | GREC cells => GREC (map withint cells)
		   | SUM cells => SUM (map withint cells)
		   | RECORD fields => RECORD (map rehashcons fields)
		   | SRECORD fields => SRECORD (map rehashcons fields)
		   | BOT arg => BOT (rehashcons arg)
		   | REF (base, z) =>
		       REF (rehashcons base,
			    Vector.tabulate
			      (Vector.length z,
			       fn i => rehashcons (Vector.sub (z, i))))
		   | leaf => leaf
	   in inj rebuilt
	  end
    (* withint: rebuilds the type of one tagged entry, keeping its tag *)
    and withint (tag, ty) = (tag, rehashcons ty)


end
