(* Copyright 1989 by AT&T Bell Laboratories *)
(* transtypes.sml *)

signature TRANSTYPES = sig

  (* Translate an ML type to the lambda type used for its runtime
   * representation. *)
  val transTyLty : Types.ty -> LambdaType.lty
  (* unwrapOp (t1,t2) e: convert e, whose representation is t1, to
   * representation t2 (e.g. t1 = BOXED yields UNWRAP(t2,e)).  The identity
   * when the types are equal or when representation analysis is off. *)
  val unwrapOp : LambdaType.lty * LambdaType.lty -> Lambda.lexp -> Lambda.lexp
  (* wrapOp (t,t') e: convert e, whose representation is t', to
   * representation t (e.g. t = BOXED yields WRAP(t',e)).  The identity
   * when the types are equal or when representation analysis is off. *)
  val wrapOp : LambdaType.lty * LambdaType.lty -> Lambda.lexp -> Lambda.lexp
  (* specialWrapperGen withSharing returns (unwrapOp, wrapOp, buildheader).
   * With sharing, each generated wrapper is bound once to a variable and
   * buildheader wraps a base expression in those bindings; without
   * sharing, wrappers are inlined and buildheader adds nothing. *)
  val specialWrapperGen : 
            bool ->
           ((LambdaType.lty * LambdaType.lty -> Lambda.lexp -> Lambda.lexp)
            * (LambdaType.lty * LambdaType.lty -> Lambda.lexp -> Lambda.lexp)
            * (Lambda.lexp -> Lambda.lexp))
end

structure Transtypes : TRANSTYPES  = struct

  open Access Types BasicTypes  Lambda List2
  structure LT = LambdaType

(****************************************************************************
 *                  UTILITY FUNCTIONS AND CONSTANTS                         * 
 ****************************************************************************) 

(* Signal an internal invariant violation with the given message; used
 * wherever a value of any type is expected, via ErrorMsg.impossible. *)
val err = fn (msg : string) => ErrorMsg.impossible msg

(* Print a message through the compiler's Control.Print channel. *)
val say = fn (msg : string) => Control.Print.say msg

(* Short local names for the TypesUtil helpers used in this structure. *)
val eqTycon    = TypesUtil.eqTycon
val applyTyfun = TypesUtil.applyTyfun
val equalTycon = TypesUtil.equalTycon

(* Identity function; serves as the no-op conversion. *)
val ident = fn x => x
(* fromto (lo, hi) = [lo, lo+1, ..., hi-1]; the empty list when lo >= hi.
 * Built back-to-front with an accumulator, so it runs in constant stack. *)
fun fromto (lo, hi) =
    let fun collect (n, acc) =
            if lo < n then collect (n - 1, (n - 1) :: acc) else acc
     in collect (hi, [])
    end
(* Lambda type used by transTyLty for the ulist type constructor.
 * NOTE(review): the meaning of the 0 argument to LT.LIST is defined in
 * LambdaType, not here. *)
val LSTty = LT.inj(LT.LIST 0)

(* merge2 (xs, ys, i): zip xs and ys into triples (x, y, k), where k counts
 * up from i.  The two lists must have equal length; otherwise err is
 * called once the shorter list runs out. *)
fun merge2 (xs, ys, start) =
    let fun loop (x :: xs', y :: ys', k, acc) = loop (xs', ys', k + 1, (x, y, k) :: acc)
          | loop ([], [], _, acc) = List.rev acc
          | loop _ = err "type conflicts in transtypes-merge2"
     in loop (xs, ys, start, [])
    end

(* mkmod t: a type with the same outer shape as t (record, strict record
 * or arrow) in which every component is RBOXED; any other type maps to
 * BOXED. *)
fun mkmod t =
    let fun allRboxed fields = map (fn _ => LT.inj LT.RBOXED) fields
     in case LT.out t
          of LT.RECORD fields => LT.injRECORD (allRboxed fields)
           | LT.SRECORD fields => LT.inj (LT.SRECORD (allRboxed fields))
           | LT.ARROW _ => LT.injARROW (LT.inj LT.RBOXED, LT.inj LT.RBOXED)
           | _ => LT.injBOXED
    end

(* arity0 t: false for record, strict-record and arrow types, true for
 * every other type.  The RBOXED conversions use it to choose between a
 * plain BOXED conversion and a mkmod-shaped one. *)
fun arity0 t =
    let val structured =
            case LT.out t
              of LT.RECORD _ => true
               | LT.SRECORD _ => true
               | LT.ARROW _ => true
               | _ => false
     in not structured
    end

(* option x: true exactly when x is SOME _. *)
fun option opt =
    case opt of SOME _ => true | NONE => false

(* force (conv, le): apply the optional conversion conv to le, or return
 * le unchanged when there is no conversion. *)
fun force (conv, le) =
    case conv of SOME f => f le | NONE => le

(****************************************************************************
 * Translate an ML type (Types.ty) to the lambda type (LambdaType.lty)      *
 *               transTyLty : Types.ty -> LambdaType.lty                    *
 ****************************************************************************)

(* transTyLty ty: the lambda type representing values of ML type ty.
 *   - unit and the empty record become LT.injINT; other records become
 *     RECORD of their translated fields; arrows become LT.injARROW;
 *   - int/word/word8 become INT, word32 INT32, bool BOOL, real REAL,
 *     cont SRCONT, ulist LSTty; any other type constructor is RBOXED;
 *   - abbreviations and formal/relative/abstract tycons are looked
 *     through, as are datatypes with a single TRANSPARENT constructor;
 *   - bound type variables (IBOUND) become RBOXED or BOXED according to
 *     the map built from TypesUtil.getRecTyvarMap;
 *   - everything else is BOXED.
 * A Subscript raised anywhere in the translation (presumably from List.nth
 * on an arrow type with fewer than two arguments) prints a message and
 * yields BOXED.  Clause order matters: earlier patterns shadow later ones. *)
fun transTyLty ty = 
  let fun transty(POLYty{tyfun=TYFUN{body,arity},...},_) = 
             (* translate the body; f tells the IBOUND case, per bound
              * variable index, whether to use RBOXED (true) or BOXED *)
             let val f = TypesUtil.getRecTyvarMap(arity,body)
              in transty(body,f)
             end

        (* an instantiated type variable stands for its instance *)
        | transty(VARty(ref (INSTANTIATED ty)),f) = transty(ty,f)

        (* unit / empty record is an INT; otherwise a RECORD of fields *)
        | transty(CONty(RECORDtyc label,args),f) = 
             let val args' = map (fn y => transty(y,f)) args
              in case args' of [] => LT.injINT
                             | _ => LT.injRECORD args'
             end  

        (* the following clauses look through tycon indirections *)
        | transty(CONty(GENtyc{kind=ref(FORMDEFtyc tyc),...},args),f) = 
           	  transty(CONty(tyc,args),f)

        (* type abbreviation: expand it with the actual arguments *)
        | transty(CONty(DEFtyc{tyfun,...},args),f) =
             let val t = applyTyfun(tyfun,args)
              in transty(t,f)
             end

        | transty(CONty(FORMtyc{spec=tc,...},args),f) =
             transty(CONty(tc,args),f)

        | transty(CONty(OPENFORMtyc{spec=tc,...},args),f) =
             transty(CONty(tc,args),f)

        (* a RELtyc whose spec is ERRORtyc must not reach this point *)
        | transty(CONty(RELtyc{spec=ERRORtyc,...},args),f) = 
	    err "Unexpected RELtyc in transtypes\n"

        | transty(CONty(RELtyc{spec=tc,...},args),f) = 
             transty(CONty(tc,args),f)

        | transty(CONty(ABSFBtyc(_,tc),args),f) = 
             transty(CONty(tc,args),f)

(*        | transty(CONty(GENtyc{kind=ref(ABStyc tc),...},args),f) =
             transty(CONty(tc,args),f)
*)

        (* datatype with exactly one constructor, whose rep is TRANSPARENT:
         * represented like the constructor's argument.  argType takes the
         * domain of the constructor's arrow type, instantiating it with
         * args when the constructor type is polymorphic. *)
        | transty(CONty(GENtyc{kind=ref(DATAtyc[DATACON{typ,rep=TRANSPARENT,
                            ...}]),...},args),f) =
             let fun argType(POLYty{tyfun=TYFUN{arity,
                                    body=CONty(_,[domain,_])},...}, args) =
                       applyTyfun(TYFUN{arity=arity,body=domain},args)
                   | argType(CONty(_,[argty,_]), []) = argty
                   | argType _ = err "Transtypes.argType"
              in transty(argType(typ,args),f)
             end

        (* primitive type constructors; any other tycon is RBOXED *)
        | transty(CONty(tycon,args),f) =
           if	   equalTycon(tycon,intTycon) then LT.injINT
           else if equalTycon(tycon,wordTycon) then LT.inj LT.INT
           else if equalTycon(tycon,word8Tycon) then LT.inj LT.INT
           else if equalTycon(tycon,word32Tycon) then LT.inj LT.INT32
           else if equalTycon(tycon,boolTycon) then LT.inj LT.BOOL
           else if equalTycon(tycon,realTycon) then LT.injREAL
           else if equalTycon(tycon,arrowTycon) then
                    (* List.nth raises Subscript if args is too short *)
                    (let val t1 = transty(List.nth(args,0),f)
                         val t2 = transty(List.nth(args,1),f)
                      in LT.injARROW(t1,t2)
                     end)
           else if equalTycon(tycon,contTycon) then LT.inj LT.SRCONT
           else if equalTycon(tycon,unitTycon) then LT.injINT
           else if equalTycon(tycon,ulistTycon) then LSTty
           else LT.inj LT.RBOXED

        (* bound type variable: the map f chooses RBOXED or BOXED *)
        | transty(IBOUND i, f) = if (f i) then LT.inj LT.RBOXED else LT.injBOXED

        (* anything not matched above, e.g. a VARty that is not INSTANTIATED *)
        | transty _ = LT.injBOXED

   (* top level: no enclosing POLYty, so every bound index maps to false *)
   in (transty (ty,fn i => false) 
         handle Subscript => (print "Subscript in transTyLty ... \n"; LT.injBOXED))
  end



  (* Representation-analysis switch (the trailing note suggests it stands in
   * for Control.CG.representations; confirm).  When false, the rebinding
   * of transTyLty below returns BOGUS types and the wrapOp/unwrapOp built
   * by specialWrapperGen become the identity. *)
  val rep_flag = ref true (* Control.CG.representations *)


  (* Rebinding of transTyLty gated by rep_flag.  It calls the translator
   * defined above, which this val shadows; `val` is non-recursive, so the
   * name on the right refers to the earlier binding.  With rep_flag
   * cleared, the result collapses to LT.BOGUS, keeping only an outermost
   * arrow as ARROW(BOGUS, BOGUS). *)
  val transTyLty =
      let val translate = transTyLty
          fun erase t =
              case LT.out t
                of LT.ARROW _ => LT.injARROW(LT.BOGUS,LT.BOGUS)
                 | _ => LT.BOGUS
       in fn x => let val t = translate x
                   in if !rep_flag then t else erase t
                  end
      end

(* specialWrapperGen withSharing: build the representation converters.
 * Returns (unwrapOp, wrapOp, buildheader):
 *   unwrapOp (t1,t2) e  converts e from representation t1 to t2;
 *   wrapOp   (t,t')  e  converts e from representation t' to t;
 *   buildheader base    wraps base in the bindings for shared wrappers.
 * With withSharing, each generated wrapper is bound once to a fresh lvar,
 * memoised per (source, target, direction), and referenced by variable.
 * Without it, wrappers are inlined and buildheader returns base unchanged. *)
fun specialWrapperGen withSharing = let

  (* Pending (lvar, wrapper) bindings recorded by doWrap, newest first. *)
  val wrappers_list : (lvar * lexp) list ref = ref []
  fun addwrappers p = (wrappers_list := p::(!wrappers_list))

  (* Memo table used by getW when withSharing is true.  The key is
   * RECORD[source, target, mark]; the mark keeps the wrap and unwrap
   * entries for the same pair of types apart. *)
  val wrapTable : (lexp->lexp) option LT.ltyMap ref = ref LT.empty

  (* buildheader base: wrap base in APP(FN(v,BOGUS,_),le) for each recorded
   * wrapper.  foldl visits the list newest-first, so the oldest binding
   * ends up outermost.  A wrapper built from other wrappers is recorded
   * after them (its components are generated before its own doWrap), so
   * each wrapper variable is bound before it is used.  wrappers_list is
   * read at call time: call this only after all conversions are built. *)
  fun buildheader base =
      foldl (fn ((v,le),b) => APP(FN(v,LT.BOGUS,b),le)) base (!wrappers_list)

  (* getW (mark, t1, t2, makeWrapper): NONE when t1 and t2 are LT.eq;
   * otherwise makeWrapper applied to the unfolded types.  With sharing,
   * the result is memoised in wrapTable. *)
  val getW =
    if withSharing
    then (fn (mark,t1,t2,makeWrapper) =>
            if LT.eq(t1,t2) then NONE
            else let val key = LT.injRECORD[t1,t2,mark]
                  in case LT.look(!wrapTable, key)
                       of SOME cached => cached
                        | NONE =>
                            let val res = makeWrapper(LT.out t1, LT.out t2)
                             in wrapTable := LT.enter(!wrapTable,key,res);
                                res
                            end
                 end)
    else (fn (_,t1,t2,makeWrapper) =>
            if LT.eq(t1,t2) then NONE else makeWrapper(LT.out t1, LT.out t2))

  (* Direction marks for wrapTable keys; any two distinct types will do. *)
  val markWrap = LT.injINT
  val markUnwrap = LT.injREAL

  (* With sharing, bind a wrapper expression to a fresh lvar (emitted later
   * by buildheader) and refer to it by variable; otherwise keep it inline. *)
  val doWrap =
    if withSharing
    then fn exp => let val w = mkLvar()
                    in addwrappers(w,exp); VAR w
                   end
    else fn exp => exp

  (* recordConv (convs, srcTy, mkRec): shared by the RECORD and SRECORD cases
   * of both unwrapOp and wrapOp.  convs holds one optional conversion per
   * field; srcTy is the type of the incoming record; mkRec rebuilds it
   * (RECORD or SRECORD).  NONE when no field needs converting; otherwise a
   * conversion that selects every field, converts it, and rebuilds. *)
  fun recordConv (convs, srcTy, mkRec) =
      if List.exists option convs
      then let val v = mkLvar()
               val fields = map (fn i => SELECT(i,VAR v)) (fromto(0,length convs))
               val e = doWrap(FN(v,srcTy,mkRec(map2 force (convs,fields))))
            in SOME (fn le => APP(e,le))
           end
      else NONE

  (* unwrapOp (t1,t2): conversion from representation t1 to t2, or NONE
   * when no code is needed.  Mutually recursive with wrapOp. *)
  fun unwrapOp(t1,t2) = getW(markUnwrap,t1,t2,
       fn (LT.BOXED,LT.RBOXED) => NONE
        | (LT.RBOXED,LT.BOXED) => NONE
        | (LT.RECORD [],LT.INT) => NONE
        | (LT.SRECORD [],LT.INT) => NONE
        | (LT.INT,LT.RECORD []) => NONE
        | (LT.INT,LT.SRECORD []) => NONE
        | (LT.RECORD args, LT.RECORD args') =>
            recordConv(map2 unwrapOp (args,args'), t1, RECORD)
        | (LT.SRECORD args, LT.SRECORD args') =>
            recordConv(map2 unwrapOp (args,args'), t1, SRECORD)
        | (LT.ARROW(lt1,lt2),LT.ARROW(rt1,rt2)) =>
            (* result: unwrap lt2 to rt2; argument: wrap rt1 back to lt1 *)
            let val bigS = unwrapOp(lt2,rt2)
                val bigG = wrapOp(lt1,rt1)
             in case (bigS,bigG)
                  of (NONE,NONE) => NONE
                   | _ => let val r = mkLvar()
                              val v = mkLvar()
                              val re = force(bigS,VAR r)
                              val ve = force(bigG,VAR v)
                              val w = mkLvar()
                              val e = doWrap(FN(w,t1,FN(v,rt1,APP(FN(r,LT.BOGUS,re),
                                                                  APP(VAR w,ve)))))
                           in SOME (fn le => APP(e,le))
                          end
            end
        | (LT.BOXED,_) => SOME (fn le => UNWRAP(t2,le))
        | (LT.RBOXED,_) =>
            if arity0 t2 then unwrapOp(LT.injBOXED,t2)
            else let val mt = mkmod t2
                     val bigS = unwrapOp(mt,t2)
                  in case bigS
                       of NONE => SOME (fn le => UNWRAP(mt,le))
                        | SOME h => SOME (fn le => h(UNWRAP(mt,le)))
                 end
        | (LT.CONT args,LT.CONT args') => err "src-level cont in lookUNWRAP"
        | _ => err "type conflicts happened in transtypes-lookUNWRAP2"
       )

  (* wrapOp (t,t'): conversion from representation t' to t, or NONE when
   * no code is needed. *)
  and wrapOp(t,t') = getW(markWrap,t,t',
       fn (LT.BOXED,LT.RBOXED) => NONE
        | (LT.RBOXED,LT.BOXED) => NONE
        | (LT.RECORD [],LT.INT) => NONE
        | (LT.SRECORD [],LT.INT) => NONE
        | (LT.INT,LT.RECORD []) => NONE
        | (LT.INT,LT.SRECORD []) => NONE
        | (LT.RECORD args, LT.RECORD args') =>
            recordConv(map2 wrapOp (args,args'), t', RECORD)
        | (LT.SRECORD args, LT.SRECORD args') =>
            recordConv(map2 wrapOp (args,args'), t', SRECORD)
        | (LT.ARROW(lt1,lt2),LT.ARROW(rt1,rt2)) =>
            (* result: wrap rt2 back to lt2; argument: unwrap lt1 to rt1 *)
            let val bigG = wrapOp(lt2,rt2)
                val bigS = unwrapOp(lt1,rt1)
             in case (bigS,bigG)
                  of (NONE,NONE) => NONE
                   | _ => let val r = mkLvar()
                              val v = mkLvar()
                              val re = force(bigG,VAR r)
                              val ve = force(bigS,VAR v)
                              val w = mkLvar()
                              val e = doWrap(FN(w,t',FN(v,lt1,APP(FN(r,LT.BOGUS,re),
                                                                  APP(VAR w,ve)))))
                           in SOME (fn le => APP(e,le))
                          end
            end
        | (LT.BOXED,_) => SOME (fn le => WRAP(t',le))
        | (LT.RBOXED,_) =>
            if arity0 t' then wrapOp(LT.injBOXED,t')
            else let val mt = mkmod t'
                     val bigG = wrapOp(mt,t')
                  in case bigG
                       of NONE => SOME (fn le => WRAP(mt,le))
                        | SOME h => SOME (fn le => WRAP(mt,h(le)))
                 end
        | (LT.CONT args, LT.CONT args') => err "source-level cont in wrapOp"
        | _ => err "type conflicts in transtypes-wrapOp2"
       )

  (* Exported versions: rep_flag is read when the pair of types is supplied
   * (identity if it is off); the conversion itself is looked up each time
   * the returned function is applied to an expression. *)
  val unwrapOp = fn (x,y) =>
        if (!rep_flag) then (fn le => force(unwrapOp(x,y),le))
        else ident

  val wrapOp = fn (x,y) =>
        if (!rep_flag) then (fn le => force(wrapOp(x,y),le))
        else ident

 in (unwrapOp,wrapOp,buildheader)
end

(* Structure-level converters: the non-sharing instance of
 * specialWrapperGen.  Its buildheader is not needed here, since without
 * sharing no wrapper bindings are recorded. *)
local
  val (unwrapNoShare, wrapNoShare, _) = specialWrapperGen false
in
val unwrapOp = unwrapNoShare
val wrapOp = wrapNoShare
end


end
