(* Copyright 1989 by AT&T Bell Laboratories *)
(* closecps.sml *)

(* CLOSECPS: interface of the closure-conversion pass.  closeCPS takes a
 * whole CPS function and returns the rewritten CPS function. *)
signature CLOSECPS = sig
  val closeCPS : CPS.function -> CPS.function
end

functor CloseCPS(MachSpec : MACH_SPEC) : CLOSECPS = struct

local
  structure Collect0= Collect(MachSpec)
  structure CGoptions = Control.CG
  open CPS Access AllocProf SortedList 
in

(*****************************************************************************
 *      Misc and utility functions                                           * 
 *****************************************************************************)

(* Local shorthands: `say` prints diagnostics through Control.Print;
 * `error` reports an "impossible" internal condition via ErrorMsg. *)
val say = Control.Print.say
val error = ErrorMsg.impossible

(* How a variable is reached from the current environment:
 *   Direct         -- already directly available, no access code needed
 *   OFFpath(w,i)   -- the closure pointer w offset by i
 *   SELpath(w,i,t) -- field i, of type t, selected from closure w *)
datatype acc = Direct
             | OFFpath of lvar * int
             | SELpath of lvar * int * cty

(* A per-closure access map (lvar -> acc), and the set of directly
 * available lvars (a map used only as a set; its values are unused). *)
type accmap = acc IntmapF.intmap
type dirmap = int IntmapF.intmap

(* A closure: the access map seen from a closure pointer, the lvars the
 * closure holds, and the closures it refers to. *)
datatype closure = CL of {amap : accmap,
                          contents : lvar list,
                          subcl : closure list}

(* What the dictionary knows about an lvar: a closure record, a function
 * of some kind, or an ordinary value of some type. *)
datatype object = Closure of closure
                | Lambda of fun_kind
                | Value of cty

(*** the access environment for each function ***)
(* An env pairs a list of closure access maps (searched front to back by
 * whereIs) with the set of lvars that are already directly available.
 * The dirmap is used only as a set: dirM always stores 0. *)
abstype env = Env of accmap list * dirmap
with exception NotInAmap
     (* the initial environment: the given access maps, nothing direct *)
     fun initEnv (al) = Env (al,IntmapF.empty)

     (* dirM marks v as directly available; dirP tests whether it is.
      * NOTE(review): dirP treats any exception raised by IntmapF.lookup
      * as "absent". *)
     fun dirM(v,Env(al,d)) = Env(al,IntmapF.add(d,v,0))
     fun dirP(Env(_,d),v) = 
            (IntmapF.lookup d v; true) handle _ => false

     (* addE merges the access map n into the first map of the list via
      * IntmapF.overlay, or makes it the only map when the list is empty.
      * The commented-out variant pushed n as a separate list element. *)
(*   fun addE(Env(a,d),n) = Env(n::a,d) *)
     fun addE(Env(a::r,d),n) = Env((IntmapF.overlay(n,a))::r,d)
       | addE(Env([],d),n) = Env([n],d)

     (* Debug printers, used by whereIs when a variable cannot be found:
      * sayt prints a cty tag, pramap one access map, prdir the direct set. *)
     fun sayt(INTt) = say "[I]"
       | sayt(FLTt) = say "[R]"
       | sayt(PTRt) = say "[P]"
       | sayt(FUNt) = say "[F]"
       | sayt(CNTt) = say "[C]"

     fun pramap(a) = 
       let fun ppac(SELpath(v,i,ct)) = 
                 (say "  SELpath  "; say (lvarName v); say " at ";
                  say (makestring(i)); say " with type "; sayt ct)             
             | ppac(OFFpath(v,i)) = 
                 (say "  OFFpath  "; say (lvarName v); say " at ";
                  say (makestring(i)))
             (* access maps never store Direct *)
             | ppac _ = error "Direct path in ppac amap"
           fun pp(v,ac) = 
            (say ">>> "; say (lvarName v); ppac(ac); say " <<< \n")
        in app pp (IntmapF.members a)
       end

     (* the second component of each member (always 0) is ignored *)
     fun prdir(d) = 
       let fun pp(v,s) = 
             (say ">>> "; say (lvarName v); say " in directmap <<< \n")
        in app pp (IntmapF.members d)
       end

     (* whereIs(env,v): how to reach v.
      * - If v is already direct: (env, Direct), env unchanged.
      * - Otherwise the access maps are searched in order.  On success v is
      *   marked direct in the returned env (d'), and the map that contained
      *   v is removed from its position and overlaid onto the first of the
      *   remaining maps (or kept alone if none remain) -- see g.
      * - If no map contains v, every map and the direct set are printed
      *   and NotInAmap is raised.
      * NOTE(review): lookal's handler treats any exception from
      * IntmapF.lookup as "v is not in this map". *)
     fun whereIs(env as Env(al,d),v) =
       if dirP(env,v) then (env,Direct)
       else (let val d' = IntmapF.add(d,v,0)
                 fun g(am,[]) = [am]
                   | g(am,a::r) = (IntmapF.overlay(am,a))::r

                 (* h accumulates, in reverse, the maps already visited *)
                 fun lookal (am::r,v,h) = 
                       ((IntmapF.lookup am v,g(am,(rev h)@r))
                                        handle _ => lookal(r,v,am::h))
                   | lookal ([],v,_) = 
                       (say "**** lvar not found: ";
                        say (lvarName(v)); 
                        say " **** in the following amaps **** \n";
                        app pramap (al); prdir d; 
                        say "****raise exception NotInAmap ********* \n";
                        raise NotInAmap)
                 val (p,al')  = lookal(al,v,[])
              in (Env(al',d'),p)
             end)

     (* access(rootvar,env,header): make rootvar directly available.
      * For a SELpath/OFFpath through closure w, a SELECT/OFFSET that binds
      * rootvar is wrapped around header's result, and w is then made
      * available recursively.  The code fetching w therefore ends up
      * outside (before) the code that uses w.  Returns the final env and
      * the composed header. *)
     fun access(rootvar,env,header) = 
        let val (env',p) = whereIs(env,rootvar)
         in case p
             of Direct => (env',header)
              | SELpath(w,i,t) => 
                  access(w,env',fn ce => SELECT(i,VAR w,rootvar,t,header(ce)))
              | OFFpath(w,i) => 
                  access(w,env',fn ce => OFFSET(i,VAR w,rootvar,header(ce)))
        end
end (*** end of abstype env ***)

(*** Breadth-first walk over a list of closures, collecting their access
 *** maps in visiting order: first the given closures, then their
 *** subclosures, and so on, level by level. ***)
fun bfsamap closures =
  let fun visit ([], [], acc) = rev acc
        | visit ([], nextLevel, acc) = visit (nextLevel, [], acc)
        | visit (CL{amap, subcl, ...} :: rest, nextLevel, acc) =
            visit (rest, nextLevel @ subcl, amap :: acc)
   in visit (closures, [], [])
  end

(*** Emit the access code needed to make every VAR in vl directly
 *** available; non-VAR values are ignored.  Values are handled from the
 *** last to the first (the order of `fold`), threading the environment.
 *** Returns the final environment and a header that wraps a cexp in all
 *** the generated SELECT/OFFSET bindings. ***)
fun fixAccess (vl,env) =
  let fun walk [] = (env, fn x => x)
        | walk (VAR rootvar :: rest) =
            let val (envRest, hdrRest) = walk rest
                val (env', hdr) = access(rootvar, envRest, fn x => x)
             in (env', hdrRest o hdr)
            end
        | walk (_ :: rest) = walk rest
   in walk vl
  end

(*** build a record in the current environment:
 *** rewrite the record fields so that each one names a variable that is
 *** directly available.  A field (VAR v,p) whose v lives in a closure is
 *** replaced by an equivalent field on the closure pointer w:
 ***   v = SELECT(i,w)   ==>  (VAR w, SELp(i,p))
 ***   v = OFFSET(i,w)   ==>  (VAR w, p shifted by i)
 *** Returns the new fields (in the original order), the updated
 *** environment, and a header that binds any closure pointers that had to
 *** be fetched first. ***)
fun recordEl(l,env) =
 let (* (w+i, p) denotes the same value as (w, shift(i,p)) *)
     fun shift(i,OFFp j) = OFFp(i+j)
       | shift(i,SELp(j,p)) = SELp(i+j,p)

     fun h(u as (VAR rootvar,p),(l,env,header)) = 
          let val (_,np) = whereIs(env,rootvar)
           in case np of Direct => (u::l, env, header)
                       | SELpath(w,i,_) => 
                           let val (env',hdr) = access(w,env,fn x => x)
                            in ((VAR w,SELp(i,p))::l,env',header o hdr)
                           end
                       (* v and w are the same pointer: resolve w instead *)
                       | OFFpath(w,0) => h((VAR w,p),(l,env,header))
                       | OFFpath(w,i) =>
                           (* keep the caller's path p, shifted by i; the
                            * old code replaced it with OFFp i and lost p *)
                           let val (env',hdr) = access(w,env,fn x => x)
                            in ((VAR w,shift(i,p))::l,env',header o hdr)
                           end 
          end
       | h (u,(l,env,header)) = (u::l,env,header)
  in fold h l ([],env,fn x => x)
 end

(*** The dictionary: what is statically known about each lvar in scope --
 *** its value type, its function kind, or its closure layout. ***)
abstype dict = Dict of object IntmapF.intmap
with exception NotInDict
     fun emptyDict () = Dict (IntmapF.empty)

     (* Look v up.  Any failure of IntmapF.lookup is reported as NotInDict,
      * after a diagnostic when Control.CG.misc4 is 17. *)
     fun look(v,Dict d) = 
       (IntmapF.lookup d v) handle _ => 
         (if (!CGoptions.misc4=17) then
            (say "****Dict**** lvar not found: "; 
             say (lvarName v); say " in the following dict *** \n")
          else (); 
          raise NotInDict)

     (* record v as an ordinary value of type t *)
     fun augV(v,t,Dict d) = Dict(IntmapF.add(d,v,Value t))

     (* record v as a function of kind fk *)
     fun augL(v,fk,Dict d) = Dict(IntmapF.add(d,v,Lambda fk))

     (* The calling kind of v: CONT for continuation-typed values, the
      * recorded kind for functions, ESCAPE for everything else.
      * Raises NotInDict if v is unknown. *)
     fun getk(v,dict) =
       case (look(v,dict)) 
        of (Value CNTt) => CONT
         | (Lambda fk) => fk
         | _ => ESCAPE

     fun getkind(VAR v,dict) = getk(v,dict)
       | getkind(LABEL v,dict) = getk(v,dict)
       | getkind _ = error "getkind in closecps.sml 2312"

     (* augC(rk,fns,vl,cname,dict,env): handle the record cname of kind rk.
      * The record holds the functions fns (non-empty only for RK_CLOSURE)
      * and the field values vl.  Returns the extended dictionary, the
      * extended environment, and a header that builds the record after
      * fetching any fields that are not directly available. *)
     fun augC(RK_KNOWN,[],[],cname,dict,env) =
          (* an empty known-function closure is still built, as a
           * one-field record holding INT 0 *)
          let val nenv = dirM(cname,env) 
              val ndict = augV(cname,PTRt,dict)
              val header = 
                (fn ce => RECORD(RK_KNOWN,[(INT 0,OFFp 0)],cname,ce))
           in (ndict,nenv,header)
          end
       | augC(_,[],[],_,_,_) =
            error "augC tries to build empty closures in closecps" 
       | augC(rk,fns,vl,cname,dict,env) =
          let (* Split the VARs of vl into the closures they denote (subcl)
               * and typed values (vals); closure fields are typed PTRt.
               * Non-VAR entries are skipped. *)
              fun sc((VAR v)::r,l,al) = 
                   (case look(v,dict) 
                     of (Closure c) => sc(r,c::l,(v,PTRt)::al)
                      | (Value t) => sc(r,l,(v,t)::al)
                      | _ => error "augC in closecps 1231")
                | sc(_::r,l,al) = sc(r,l,al)
                | sc([],l,al) = (rev l,rev al)

              val (subcl,vals) = sc(vl,[],[])

              (* the VARs and LABELs of vl, as a SortedList set *)
              val contents = 
                let fun clean((VAR v)::r,l) = clean(r,enter(v,l))
                      | clean((LABEL v)::r,l) = clean(r,enter(v,l))
                      | clean(_::r,l) = clean(r,l)
                      | clean([],l) = l
                 in clean(vl,[])
                end 

              (* The access map seen from v, the function at index offd of
               * fns.  Slots are numbered with the functions of fns first,
               * then vals; each is addressed relative to v, hence "- offd".
               * cname itself sits at ~offd from v.
               * NOTE(review): this assumes vl lists the function labels
               * first, in the order of fns -- TODO confirm with Collect. *)
              fun mkamap(v,offd) = 
                let fun add1(x,(i,m)) = 
                      if x = v then (i+1,m)
                      else (i+1,IntmapF.add(m,x,(OFFpath(v,i-offd))))

                    fun add2((x,t),(i,m)) = 
                      (i+1,IntmapF.add(m,x,(SELpath(v,i-offd,t))))

                    val (d,amap) = revfold add1 fns (0,IntmapF.empty)
                    val (_,amap) = revfold add2 vals (d,amap)
                 in if v = cname then amap
                    else IntmapF.add(amap,cname,OFFpath(v,~offd))
                end

              (* Bind cname and every function of fns to this closure in the
               * dictionary, and make the functions reachable from cname in
               * the environment. *)
              val (ndict,nenv) = 
                let fun augAll((v,offd),Dict d) = 
                       Dict(IntmapF.add(d,v,
                             Closure(CL {subcl=subcl,contents=contents,
                                         amap=mkamap(v,offd)})))

                    (* pair each function with its index (result reversed) *)
                    fun formap (a::r,i,l) = formap(r,i+1,(a,i)::l)
                      | formap ([],_,l) = l

                    val index = formap(fns,0,[])
                    val bdict = augAll((cname,0),dict)
                    val am = fold (fn ((x,offd),b) 
                                     => IntmapF.add(b,x,OFFpath(cname,offd)))
                               index IntmapF.empty

                 in (fold augAll index bdict, addE(env,am))
                end

              (* Build the record with its original kind rk.  An unused
               * RK_BLOCK/RK_FBLOCK kind that used to be computed here was
               * dead code and has been removed. *)
              val (header,nenv) = 
                let val ul = map (fn v => (v,OFFp 0)) vl
                    val (ul',nenv',hdr) = recordEl(ul,nenv)
                 in (fn ce => hdr(RECORD(rk,ul',cname,ce)),dirM(cname,nenv'))
                end
           in (ndict,nenv,header)
          end

end (*** end of abstype dict ***)

(*****************************************************************************
 *             closeCPS : MAIN FUNCTIONS                                     *
 *****************************************************************************)
(* closefix(fe,dict): close-convert one function definition.
 * Each formal parameter is entered into the dictionary with its declared
 * type, unless it is already known.  All parameters are marked directly
 * available.  The access maps of any closures among the parameters are
 * searchable, in breadth-first order, while the body is rewritten. *)
fun closefix(fe as (fk,f,vl,cl,ce),dict) = 
 let fun scan(v::vl,t::cl,dict,r) = 
          let (* Guard only the lookup.  Wrapping the recursive call as
               * well made a failure further down the list look like
               * "v unknown" at every level, so each level retried the rest
               * of the scan and the error escaped only after exponential
               * re-work. *)
              val (dict',r') =
                ((case look(v,dict)
                   of Closure c => (dict,c::r)
                    | _ => (dict,r))
                 handle NotInDict => (augV(v,t,dict),r))
           in scan(vl,cl,dict',r')
          end
       | scan([],[],dict,r) = (dict,r)
       | scan _ = error "makenv in closecps 2331"

     val (ndict,closlist) = scan(vl,cl,dict,[])
     val nenv = fold dirM vl (initEnv(bfsamap closlist))

  in (fk,f,vl,cl,close(ce,ndict,nenv))
 end

(* close(cexp,dict,env): rewrite cexp so that every variable it uses is
 * directly available.  Access code (SELECT/OFFSET, via fixAccess,
 * recordEl and augC) is placed in front of each use.  Every new binding is
 * recorded in dict and marked direct in env for the continuation. *)
and close(cexp,dict,env) =
 let (* records whose layout augC tracks: build the record, extend dict
      * and env with it, then close the continuation *)
     fun closeBlock(rk,fns,vl,cname,ce) =
       let val (dict',env',header) = augC(rk,fns,vl,cname,dict,env)
        in header(close(ce,dict',env'))
       end
  in case cexp
      of FIX(fl,body) =>
          let val dict' = fold (fn ((fk,f,_,_,_),dict) 
                                  => augL(f,fk,dict)) fl dict
              val nfl = map (fn fe => closefix(fe,dict')) fl
           in FIX(nfl,close(body,dict',env))
          end
       | APP(f,args) =>
          let val (_,header) = fixAccess(f::args,env)
           in if not(!CGoptions.allocprof) then header(cexp)
              else (case getkind(f,dict) 
                     of ESCAPE => header(profStdCall cexp)
                      | CONT => header(profCSCntCall cexp)
                      | KNOWN => header(profKnownCall cexp)
                      | _ => error "APP node in close in closecps 1213")
          end
       | SWITCH(v,c,el) =>
          let val (env',header) = fixAccess([v],env)
              val el' = map (fn e => close(e,dict,env')) el
           in header(SWITCH(v,c,el'))
          end
       | RECORD(rk as RK_CLOSURE fns,ul,cname,ce) =>
          closeBlock(rk,fns,map #1 ul,cname,ce)
       | RECORD(rk as RK_KNOWN,ul,cname,ce) =>
          closeBlock(rk,[],map #1 ul,cname,ce)
       | RECORD(rk as RK_CONT,ul,cname,ce) =>
          closeBlock(rk,[],map #1 ul,cname,ce)
       | RECORD(rk as RK_BLOCK,ul,cname,ce) =>
          closeBlock(rk,[],map #1 ul,cname,ce)
       | RECORD(rk as RK_FBLOCK,ul,cname,ce) =>
          closeBlock(rk,[],map #1 ul,cname,ce)
       | RECORD(rk as RK_I32BLOCK,ul,cname,ce) =>
          closeBlock(rk,[],map #1 ul,cname,ce)
       | RECORD(k,vl,w,ce) =>
          (* any other record kind: only its fields need fixing up *)
          let val (vl',env',header) = recordEl(vl,env)
              val ce' = close(ce,augV(w,PTRt,dict),dirM(w,env'))
              val record = RECORD(k,vl',w,ce')
           in if not(!CGoptions.allocprof) then header(record)
              else header(profRecord (length vl) record)
          end
       | SELECT(i,v,w,t,ce) =>
          let val (env',header) = fixAccess([v],env)
           in header(SELECT(i,v,w,t,close(ce,augV(w,t,dict),dirM(w,env'))))
          end
       | OFFSET(i,v,w,ce) => error "OFFSET in cps/closure.sml!"
       | BRANCH(i,args,c,e1,e2) =>
          let val (env',header) = fixAccess(args,env)
           in header(BRANCH(i,args,c,close(e1,dict,env'),close(e2,dict,env')))
          end
       | SETTER(i,args,ce) =>
          let val (env',header) = fixAccess(args,env)
           in header(SETTER(i,args,close(ce,dict,env')))
          end
       | LOOKER(i,args,w,t,ce) =>
          let val (env',header) = fixAccess(args,env)
           in header(LOOKER(i,args,w,t,close(ce,augV(w,t,dict),dirM(w,env'))))
          end
       | ARITH(i,args,w,t,ce) =>
          let val (env',header) = fixAccess(args,env)
           in header(ARITH(i,args,w,t,close(ce,augV(w,t,dict),dirM(w,env'))))
          end
       | PURE(i,args,w,t,ce) =>
          let val (env',header) = fixAccess(args,env)
           in header(PURE(i,args,w,t,close(ce,augV(w,t,dict),dirM(w,env'))))
          end
 end

   (* fprint1 always prints a banner line s followed by the CPS function;
    * fprint does the same only when Control.CG.comment is set. *)
   fun fprint1 (function, s : string) =
     (say "\n"; say s; say "\n \n"; PPCps.printcps0 function)

   fun fprint (function, s : string) =
     if !CGoptions.comment then fprint1 (function, s) else ()

(*** closeCPS: closure-convert a whole CPS function.  It runs four phases
 *** in sequence; when Control.CG.comment is set, the result of each phase
 *** is printed through fprint.  The result of the last phase is returned. ***)
fun closeCPS(fe as (fk,f,vl,cl,ce)) = 
  let (*** free variable analysis and branch prediction ***)
      val (fe1,freevars) = FreeClose.freemapClose(fe)

      val _ = fprint(fe1,"After FreeClose:")

      (*** bottom-up collecting layout information ***)
      val fe2 = Collect0.collectCPS(fe1,freevars)

      val _ = fprint(fe2,"After Collect:")

      (*** back-patching the access information ***)
      val fe3 = closefix(fe2,emptyDict())

      val _ = fprint(fe3,"After CloseCPS:")

      (*** unrebinding (or alpha conversion) ***)
      val fe4 = UnRebind.unrebind(fe3)

      val _ = fprint(fe4,"After UnRebind:")
      (* Disabled self-check, kept for debugging.  It split each phase's
       * output into its FIXes (GlobalFix.globalfix) and checked each
       * result of the last phase for unbound free variables, duplicate
       * bindings, and ill-formed FP/GP records.  On any failure it dumped
       * all four phase outputs. *)
(*
      val carg1 = GlobalFix.globalfix(fe1)
      val carg2 = GlobalFix.globalfix(fe2)
      val carg3 = GlobalFix.globalfix(fe3)
      val carg4 = GlobalFix.globalfix(fe4)

      exception BadFPRecord
      exception BadGPRecord

      fun valid (fk,v,args,cl,ce) = 
       let val fltset = Intset.new()
           val isflt = Intset.mem fltset
           fun addf(v,FLTt) = Intset.add fltset v
             | addf(v,_) = ()
           val _ = List2.app2 addf (args,cl)

           fun fprecord(VAR v,OFFp 0) = isflt v
             | fprecord(VAR v,OFFp i) = false
             | fprecord(VAR v,SELp(_,OFFp 0)) = true
             | fprecord(VAR v,_) = false
             | fprecord(REAL _, _) = true
             | fprecord _ = false

           fun gprecord(VAR v,OFFp 0) = not(isflt v)
             | gprecord(REAL _, _) = false
             | gprecord _ = true

           fun fpc (a::r) = if fprecord a then fpc r
                            else raise BadFPRecord
             | fpc [] = ()

           fun gpc (a::r) = if gprecord a then gpc r
                            else raise BadGPRecord
             | gpc [] = ()

           fun h cexp = case cexp 
             of APP(v,args) => []
              | SWITCH(v,c,l) => fold (op @) (map h l) []
              | RECORD(RK_FBLOCK,l,w,e) => (fpc l; w::(h e))
              | RECORD(_,l,w,e) => (gpc l; w::(h e))
              | SELECT(_,v,w,t,e) => (addf(w,t); w::(h e))
              | OFFSET(_,v,w,e) => w::(h e)
              | SETTER(_,vl,e) => (h e)
              | LOOKER(_,vl,w,t,e) => (addf(w,t); w::(h e))
              | ARITH(_,vl,w,t,e) => (addf(w,t); w::(h e))
              | PURE(_,vl,w,t,e) => (addf(w,t); w::(h e))
              | BRANCH(_,vl,c,e1,e2) => ((h e1)@(h e2))
              | FIX(fl,e) => error "bugs in valid in closecps"
           val l = (h ce)@args
        in (length l) = (length (uniq l))
       end
                      
      fun check(fe1,fe2,fe3,fe4 as (fk,v,args,cl,ce)) = 
        let val free = FreeMap.freevars(ce)
            val unbound = difference(free,uniq(args))

            fun ilist(v::r) = (say (lvarName v); say " "; ilist r)
              | ilist [] = say "\n"

            val (normal,sss) = 
              (if valid fe4 then 
                (unbound=[],"*** Unbound free variables after closecps : \n")
               else (false,"*** duplicated variable bindings in closecps :\n"))
              handle _ => (false, "*** BadFPorGPrecord in closecps : \n")

         in if normal then ()
            else (say sss;
                  ilist unbound;
                  fprint1(fe1,"*** in the FreeClose CPS expressions: \n");
                  fprint1(fe2,"*** in the Collect CPS expressions: \n");
                  fprint1(fe3,"*** in the CloseCPS CPS expressions: \n");
                  fprint1(fe4,"*** in the UnRebind CPS expressions: \n");
                  say "******************************************* \n")
        end


      fun merge4(a1::r1,a2::r2,a3::r3,a4::r4) = 
            (a1,a2,a3,a4)::(merge4(r1,r2,r3,r4))
        | merge4([],[],[],[]) = []
        | merge4 _ = error "different number of FIXes in closecps.sml"

      val _ = app check (merge4(carg1,carg2,carg3,carg4))
*)
   in fe4
  end (* function closeCPS *)

end (* local *)
end (* functor CloseCPS *)
