(* Copyright 1989 by AT&T Bell Laboratories *)
(* collect.sml *)

(****************************************************************************
 *                                                                          *
 *  ASSUMPTIONS: (1) Four possible combinations of bindings in the same     *
 *                   FIX : known,escape,cont,known+escape;                  *
 *                                                                          *
 *               (2) Continuation function is never recursive; there is     *
 *                   at most ONE continuation function definition per FIX;  *
 *                                                                          *
 *               (3) The outermost function is always a non-recursive       *
 *                   escaping function.                                     *
 *                                                                          *
 *               (4) A FIX with recursive known functions has no            *
 *                   non-recursive known functions.                         *
 *                                                                          *
 ****************************************************************************)

(* COLLECT: interface of the closure-collection pass.
 *
 * collectCPS takes the outermost CPS function together with a function
 * that maps a function name to (a list of lvars used as that function's
 * free variables, an optional replacement list), and returns the
 * rewritten outermost function.
 *)
signature COLLECT =
  sig
    val collectCPS : (CPS.function * (CPS.lvar -> 
                               (CPS.lvar list * CPS.lvar list option)))
                     -> CPS.function 
  end

functor Collect(MachSpec: MACH_SPEC) : COLLECT = struct

local 
  structure ClosureUtil = ClosUtil(MachSpec)
  open CPS Access SortedList ClosureUtil
in 

(*** miscellaneous and utility functions ***)

(* Code-generator controls; the statistics refs bumped with `inc` below
 * (knownGen, escapeGen, calleeGen, knownClGen) live here. *)
structure CGoptions = Control.CG

(* Report a violated internal invariant of this phase. *)
fun error msg = ErrorMsg.impossible msg

(* Zero-offset access path, used for every field of the records built
 * when the closure of a FIX is constructed. *)
val OFFp0 = OFFp 0

(* Number of entries in a CPS type list that are not FLTt, i.e. the
 * values that do not go in float registers. *)
fun gplength l =
  let fun count ([], n) = n
        | count (t::rest, n) =
            count (rest, case t of FLTt => n | _ => n + 1)
   in count (l, 0)
  end

(* Number of FLTt entries in a CPS type list. *)
fun fplength l =
  let fun count ([], n) = n
        | count (t::rest, n) =
            count (rest, case t of FLTt => n + 1 | _ => n)
   in count (l, 0)
  end

(* partition f l = (elements of l satisfying f, elements that do not).
 * Both result lists keep the original order of l; f is applied to the
 * elements from the last one to the first. *)
fun partition f l =
  let fun split [] = (nil, nil)
        | split (x::rest) =
            let val (yes, no) = split rest
             in if f x then (x::yes, no) else (yes, x::no)
            end
   in split l
  end

(* Collect the variables (VAR v) occurring in a list of CPS values into a
 * set built with `enter`; all other kinds of value are ignored. *)
fun clean vl =
  let fun gather ([], acc) = acc
        | gather (x::rest, acc) =
            gather (rest, case x of VAR v => enter(v, acc) | _ => acc)
   in gather (vl, [])
  end

(* Count the general-purpose registers needed for a list of CPS types:
 * FLTt needs none, CNTt needs 1 + !numCSgp, and anything else needs 1.
 * (numCSgp comes from ClosUtil; presumably the callee-save gp count.) *)
fun numgp l =
  let fun count ([], n) = n
        | count (FLTt::rest, n) = count (rest, n)
        | count (CNTt::rest, n) = count (rest, n + 1 + !numCSgp)
        | count (_::rest, n) = count (rest, n + 1)
   in count (l, 0)
  end

(*** count the number of fp registers needed for a list of CPS types ***)
(* FLTt needs one float register, CNTt needs !numCSfp (from ClosUtil),
 * and every other type needs none. *)
fun numfp l =
  let fun count ([], n) = n
        | count (FLTt::rest, n) = count (rest, n + 1)
        | count (CNTt::rest, n) = count (rest, n + !numCSfp)
        | count (_::rest, n) = count (rest, n)
   in count (l, 0)
  end

(* Float registers usable for arguments; two are held back because the
 * code generator needs 1 or 2 temps. *)
val maxfpregs = MachSpec.numFloatRegs - 2

(* exceedLimit cl: true when arguments of CPS types cl would not fit in
 * registers.  Without unboxed floats every argument competes for the
 * non-callee-save gp registers; with unboxed floats the gp and fp demands
 * (numgp / numfp) are checked separately. *)
fun exceedLimit cl =
  if not MachSpec.unboxedFloats
  then length cl >= MachSpec.numRegs - MachSpec.numCalleeSaves
  else numgp cl >= MachSpec.numRegs orelse numfp cl >= maxfpregs


(* true iff the option carries a value *)
fun some opt =
  case opt
   of SOME _ => true
    | NONE => false

(* Split the bindings of a FIX by function kind into
 * (escaping, known, recursive known, continuation) lists.  Each list is
 * in reverse order of appearance in fl; any other kind is an internal
 * error. *)
fun partBindings fl =
  let fun step (fe as (fk,_,_,_,_), (el,kl,rl,cl)) =
        case fk
         of ESCAPE => (fe::el,kl,rl,cl)
          | KNOWN => (el,fe::kl,rl,cl)
          | KNOWN_REC => (el,kl,fe::rl,cl)
          | CONT => (el,kl,rl,fe::cl)
          | _ => error "partBindings in closure phase 231"
      fun loop ([], acc) = acc
        | loop (fe::rest, acc) = loop (rest, step (fe, acc))
   in loop (fl, ([],[],[],[]))
  end

(* Like partBindings, but keeps only the function names, each inserted
 * into a set with `enter`:
 * (escaping, known, recursive known, continuation) names. *)
fun partVariables fl =
  let fun step ((fk,f,_,_,_), (el,kl,rl,cl)) =
        case fk
         of ESCAPE => (enter(f,el),kl,rl,cl)
          | KNOWN => (el,enter(f,kl),rl,cl)
          | KNOWN_REC => (el,kl,enter(f,rl),cl)
          | CONT => (el,kl,rl,enter(f,cl))
          | _ => error "partVariables in closure phase 231"
      fun loop ([], acc) = acc
        | loop (fe::rest, acc) = loop (rest, step (fe, acc))
   in loop (fl, ([],[],[],[]))
  end

(****************************************************************************
 *                                                                          * 
 *                   collectCPS: MAIN FUNCTION                              *
 *                                                                          *
 ****************************************************************************)

(* collectCPS: closure-collection pass over one complete CPS function.
 *
 *   (fk,v,args,cl,ce) -- the outermost function: kind, name, formals, the
 *                        formals' CPS types, and body.  Assumption (3) in
 *                        the header says it is a non-recursive escaping
 *                        function.
 *   freevars          -- given a function name, returns (free, lpv).  free
 *                        is used as that function's free-variable list; for
 *                        recursive known functions, lpv = SOME l makes l the
 *                        free list and records the rest of free as "spill".
 *
 * Result: the function renamed to a fresh label (dupLvar v), with formals
 * label::v::args' of types PTRt::PTRt::cl' (args'/cl' come from augArgs),
 * and with its body rewritten by collect at depth 0.
 *)
fun collectCPS((fk,v,args,cl,ce),freevars) = let 

(****************************************************************************
 *  collectfix : collecting information from functions in the FIX.          *
 ****************************************************************************)

(* collectfix(parentKind, depth, initEnv, bindings): process the bindings of
 * one FIX that occurs at nesting depth `depth` under environment initEnv.
 * Returns (newbinds, newEnv, newfree, kill, header):
 *   newbinds      -- rewritten bindings: known (including recursive known),
 *                    then escaping, then continuation functions;
 *   newEnv        -- environment for the body of the FIX;
 *   newfree, kill -- collect adds newfree to, and removes kill from, the
 *                    info of the FIX body;
 *   header        -- wrapper collect applies around the rebuilt FIX; when a
 *                    closure record is built for this FIX, its RECORD is
 *                    wrapped in here.
 * NOTE(review): parentKind is never referenced in this body.
 *)
fun collectfix(parentKind, depth, initEnv, bindings) = 
let

(*** temporarily recB and recV are always made to be empty ***)
(* NOTE(review): partBindings does route KNOWN_REC bindings into recB, so
 * any emptiness must be arranged by an earlier phase -- confirm. *)
val (escapeB,knownB,recB,calleeB) = partBindings(bindings)
val (escapeV,knownV,recV,calleeV) = partVariables(bindings)

(*** check whether the assumption No. 2 is valid ***)
(* KNOWN:  no escaping and no continuation bindings;
 * CONT:   exactly one continuation binding and nothing else;
 * ESCAPE: some escaping bindings (known ones may be mixed in), no
 *         continuation.
 * Any other mix is rejected. *)
val fixKind = 
  case (escapeB,knownB,calleeB,recB) 
   of ([],_,[],_) => KNOWN
    | ([],[],[v],[]) => CONT
    | (_,_,[],_) => ESCAPE
    | _ => error "Assumption No.2 is violated in closure phase"

(* Assumption No.4: recursive known functions (recB) may not share a FIX
 * with non-recursive known functions (knownB).  recFlag is not referenced
 * again below; evaluating this binding performs the check. *)
val recFlag = 
  case (recB,knownB) 
   of ([],_) => false
    | (_::_,[]) => true
    | _ => error "Assumption No.4 is violated in closure phase"

(****************************************************************************
 *      Processing known, escaping and continuation function bindings.      *
 ****************************************************************************)

(*** make a new label for each non-recursive known function of the FIX ***)
(* callc is true when the function's free list (first component of
 * freevars v) mentions an escaping function of this same FIX. *)
val knownB =
  map (fn (_,v,args,cl,body) => 
        let val free = #1 (freevars v)
            val callc = length(free) <> length(difference(free,escapeV))
            val _ = inc CGoptions.knownGen 
         in {v=v,label=dupLvar(v),args=args,cl=cl,body=body,callc=callc}
        end) knownB

(*** make a new label for each escaping function of the FIX ***)
val escapeB =
  map (fn (_,v,args,cl,body) => 
         (inc CGoptions.escapeGen; 
          {v=v,label=dupLvar(v),args=args,cl=cl,body=body})) escapeB

(*** make a new label for each continuation function of the FIX ***)
val calleeB =
  map (fn (_,v,args,cl,body) => 
         (inc CGoptions.calleeGen;
          {v=v,label=dupLvar(v),args=args,cl=cl,body=body})) calleeB

(****************************************************************************
 *      Processing recursive known function bindings.                       *
 ****************************************************************************)

(*** Get the call graph of all known functions in this FIX. ***)
(* Each entry is (info, n, fns): fns are the free variables that are
 * recursive known functions of this FIX (the call-graph edges), n is
 * length fns, and `other` holds the remaining free variables.  When
 * freevars returns SOME l, l is used as the free list and the rest of the
 * original free list is kept in `spill`. *)
val recB = 
  map (fn (fe as (_,v,_,_,_)) =>
        let val (free,lpv) = freevars v
            val (free,spill) =  
                  (case lpv of NONE => (free,NONE)
                             | SOME l => (l,SOME(difference(free,l))))
            val (fns,other) = partition (member recV) free
         in ({v=v,fe=fe,other=other,spill=spill},length fns,fns)
        end) recB

(*** Compute the closure of the call graph of the known functions. ***)
(* Each round replaces a node's neighbour set with itself merged with the
 * neighbour sets of its neighbours; rounds repeat until no set changes
 * size, giving the transitive closure of the call graph. *)
val recB = 
  let fun closeCallGraph g =
        let fun getNeighbors l =
              fold (fn (({v,fe,other,spill},_,nbrs),n) =>
	              if member l v then merge(nbrs,n) else n) g l
            fun traverse ((x,len,nbrs),(l,change)) =
	       let val nbrs' = getNeighbors nbrs
	           val len' = length nbrs'
    	        in ((x,len',nbrs')::l,change orelse len<>len')
	       end
            val (g',change) = fold traverse g (nil,false)
         in if change then closeCallGraph g' else g'
        end
   in closeCallGraph recB
  end

(*** Compute the closure of the set of free variables ***)
(* `other` becomes the union of this function's `other` with the `other`
 * of every function it reaches (fns, already closed above); the saved
 * binding fe is unpacked into args/cl/body. *)
val recB = 
  let fun gatherNbrs l init = 
        fold (fn (({v,other,...},_,nbrs),free) =>
                if member l v then merge(other,free) else free) recB init

   in map (fn ({v,fe=(_,_,args,cl,body),other,spill},_,fns) =>
                   {v=v,args=args,body=body,cl=cl,spill=spill,
                    other=gatherNbrs fns other,fns=fns}) recB
  end

(*** Collects all spilled variables from the recursive known functions ***)
val collected = foldmerge (map ((fn NONE => [] | SOME l => l) o #spill) recB)

(*** See which recursive known function requires a closure, pass 1. ***)
(* Produces (info, free, callc, needc) per function, where free is `other`
 * minus this FIX's escaping functions, refined by freeAnalysis (ClosUtil,
 * not shown here). *)
val recB = fold
 (fn ((x as {v,args,cl,other,fns,spill,...}),k) =>
   let val free = difference(other,escapeV)

       (*** callc means that all its free variables need to be spilled ***)
       val callc = ((length other) <> (length free))  (* calls an escape fun *)
                   orelse (member collected v)        (* forced-spill *)

       (*** needc means that it needs to use the closure in its body ***)
       val needc = some spill

       val free = freeAnalysis(free,initEnv)
       val len = length free

       (*** If the function has too many extra arguments to fit into
            registers, then we must put them in the closure. ***)
       val allty = cl @ (map (get_cty initEnv) free)
       val callc = callc orelse (exceedLimit(allty) andalso len > 1)

    in ((x,free,callc,needc)::k)
   end) recB nil

(*** Make the collected free variables more concrete. ***)
(* This FIX's own recursive known and escaping function names are removed
 * from the spilled set before freeAnalysis. *)
val collected = difference(difference(collected,recV),escapeV)
val collected = freeAnalysis(collected,initEnv)

(*** See which known function requires a closure, pass 2. ***)
(* A function inherits callc / needc when any function in its (closed) fns
 * has it (checkNbrs1 / checkNbrs2).  With callc, its free variables are
 * not passed as extra arguments (free := nil) but merged into the
 * collected set.  With a non-empty spill while `collected` is non-empty,
 * they are both passed and merged into the collected set. *)
val (recB,collected) = 
  let fun checkNbrs1 l init =
       fold (fn (({v,...},_,callc,_),c) => 
                  c orelse (callc andalso (member l v))) recB init

      fun checkNbrs2 l init =
       fold (fn (({v,...},_,_,needc),c) => 
                  c orelse (needc andalso (member l v))) recB init

    
      fun g(({v,args,body,cl,fns,other,spill},free,callc,needc),(l,c)) =
        let val callc = checkNbrs1 fns callc
            val needc = checkNbrs2 fns needc
            val (free,collects) = 
              case (callc,spill,collected) 
               of (true,_,_) => (nil,merge(free,c))
                | (_,SOME (_::_),_::_) => (free,merge(free,c))
                | _ => (free,c)

            val _ = if callc then (inc CGoptions.knownClGen)
                    else (inc CGoptions.knownGen)

         in ({v=v,label=dupLvar(v),args=args,cl=cl,body=body,
              free=free,callc=callc orelse needc}::l,collects)
        end
   in fold g recB (nil,collected)
  end

(*** Modify the set of free variables for known functions accordingly ***)
(* If this FIX binds any recursive known or escaping function, allocate a
 * closure name cname and add it to every recursive known function's free
 * list.  The commented-out alternative would add it only when callc. *)
val (recB,closureInfo,closVL) = 
  case (recB,escapeB)
   of ([],[]) => ([],NONE,[])
    | _ => 
       (let val cname = closureLvar()
         in (map (fn {v,label,args,cl,body,free,callc} =>
                    let val newfree = enter(cname,free)
                          (* if callc then enter(cname,free) else free *)
                     in {v=v,label=label,args=args,cl=cl,body=body,
                         callc=callc,free=newfree}
                    end) recB, 
            SOME cname, [cname])
        end)

(*** add the closure into the environment if there is one ***)
(*** it will be updated, i.e. add more freevars later ***)
val baseClu = mkClu(depth+1,collected,initEnv)   (*** ? ***)
val baseEnv = 
  case closureInfo 
   of NONE => initEnv
    | SOME cname => augSpill(cname,collected,depth+1,initEnv)
      (*** considered as a depth+1 object to avoid cycles in closures ***)

(****************************************************************************
 *      Final construction of the environment for each function             *
 ****************************************************************************)

(*** add escapeB and recB to the env because they may be recursive ***)   
val baseEnv = 
  case closureInfo 
   of NONE => baseEnv
    | SOME cname =>
        (fold (fn ({v,label,...},env) => 
                      augEscapeF(v,label,cname,env)) escapeB baseEnv)

val baseEnv = 
  fold (fn ({v,label,free,...},env) =>
          augKnownF(v,label,free,env)) recB baseEnv

(*** Final construction of the environment for each recB function. ***)
(* Each recursive known function is rebuilt with its free variables
 * appended as extra formals (types from get_cty env); those variables are
 * removed from its clu.  Entries are (binding, clu, v, label, free). *)
val recFrags = 
  map (fn {v,label,args,cl,body,free,...} =>
        let val (args',cl',env) = augArgs(args,cl,depth+1,baseEnv) 
            val (nce,info) = collect(KNOWN,depth+1,env,body)
            val newinfo = delvInfo(closVL,delvInfo(uniq args',info))
            val (clu,_) = formatInfo(depth+1,env,newinfo)
            val clu = delClu(clu,free)           
            val nargs = args'@free 
            val ncl = cl'@(map (get_cty env) free)
         in ((KNOWN,label,nargs,ncl,nce),clu,v,label,free)
        end) recB


(*** Final construction of the environment for each known function. ***)
(* knownregs (ClosUtil, not shown) selects which free variables become
 * extra arguments, given gp/fp budgets derived from maxgpfree/maxfpfree
 * minus what the formals use.  Whatever is left in clu forces callc.
 * needCL is true when any known function has callc. *)
val (knownB,needCL) =
  fold (fn ({v,label,args,cl,body,callc},(l,base)) =>
         let val (args',cl',env) = augArgs(args,cl,depth+1,baseEnv) 
             val (nce,info) = collect(KNOWN,depth+1,env,body)
             val newinfo = delvInfo(closVL,delvInfo(uniq args',info))
             val (clu,_) = formatInfo(depth+1,env,newinfo)
             val free = 
               let val gpmax = max(maxgpfree - gplength(cl') - 2,0)
                   val fpmax = max(maxfpfree - fplength(cl') - 1,0)
                in knownregs(clu,env,gpmax,fpmax)
               end
             val clu = delClu(clu,free)   
             val callc = callc orelse (not (isEmpClu clu))
          in ({v=v,label=label,args=args',cl=cl',body=nce,free=free,
               callc=callc,clu=clu} :: l, base orelse callc)
         end) knownB ([],false)

(* If no closure name exists yet but some known function needs one,
 * allocate it now. *)
val (baseEnv,closureInfo,closVL) = 
  case closureInfo 
   of SOME _ => (baseEnv,closureInfo,closVL)
    | _ => if needCL then (let val cname = closureLvar()
                            in (augSpill(cname,collected,depth+1,baseEnv),
                                SOME cname, [cname])
                           end)
           else (baseEnv,closureInfo,closVL)

(* Known functions with callc also receive the closure name (closVL) as an
 * extra argument. *)
val knownFrags = 
  map (fn {v,label,args,cl,body,callc,free,clu} =>
        let val free = if callc then merge(closVL,free) else free
            val nargs = args @ free 
            val ncl = cl @ (map (get_cty baseEnv) free)
         in ((KNOWN,label,nargs,ncl,body),clu,v,label,free)
        end) knownB

(*** Update the known function information ***)
val baseEnv = 
  fold (fn ((frag,clu,v,label,free),env) =>
          augKnownF(v,label,free,env)) knownFrags baseEnv

(*** Final construction of the environment for each escape function. ***)
(* Escaping functions get two leading PTRt formals: a fresh lvar and the
 * original name v. *)
val escapeFrags =
  map (fn {v,label,args,cl,body} =>
        let val (args',cl',env) = augArgs(args,cl,depth+1,baseEnv)
            val (nce,info) = collect(ESCAPE,depth+1,env,body)
            val nargs = mkLvar()::v::args'
            val ncl = PTRt::PTRt::cl'
            val newinfo = delvInfo(closVL,delvInfo(uniq args',info))
            val (clu,_) = formatInfo(depth+1,env,newinfo)
         in ((ESCAPE,label,nargs,ncl,nce),clu)
	end) escapeB

(*** Final construction of the environment for each continuation. ***)
(* Continuations get one leading PTRt formal; csregs is derived from the
 * formatInfo target by tgtCSregs and completed after mkClosure below. *)
val calleeFrags =
  map (fn {v,label,args,cl,body} =>
        let val (args',cl',env) = augArgs(args,cl,depth+1,baseEnv)
            val (nce,info) = collect(CONT,depth+1,env,body)
            val newinfo = delvInfo(closVL,delvInfo(uniq args',info))
            val (clu,tgt) = formatInfo(depth+1,env,newinfo)
            val csregs = tgtCSregs tgt

            val nargs = mkLvar()::args'
            val ncl = PTRt::cl'
         in ((nargs,ncl,nce),clu,v,label,csregs)
        end) calleeB

(*** From now on knownB and recB are treated in the same way ***)
val knownFrags = knownFrags@recFrags
val knownFree = difference(foldmerge (map #5 knownFrags),closVL)

(*** Gather all the free variables in the calleesave registers ***)
val calleeFree = foldmerge (map (freeCSvars o #5) calleeFrags)

(*** Collecting all the free variable information together ***)
val allCluList = (map #2 escapeFrags)@(map #2 calleeFrags)@(map #2 knownFrags)
val totalClu = fold mergeClu allCluList baseClu

(*** Decide the final closure representation ***)
(* mkClosure is a ClosUtil helper not shown here. *)
val (contents,hdr,kill,newfree,newEnv) 
       = mkClosure(depth,baseEnv,totalClu,calleeFree,knownFree,fixKind)

(*** Build the big closure for the current FIX ***)
(* Without a closure name: no contents means no closure; exactly one
 * content value serves as the closure itself; otherwise a fresh RK_CONT
 * record of the contents is allocated.
 * With a closure name: its record holds the escaping functions' labels
 * followed by the contents, and the record kind follows fixKind. *)
val (closureInfo,header,kill,newfree,newEnv) = 
  case closureInfo 
   of NONE => 
         (case contents
           of [] => (NONE,hdr,kill,newfree,newEnv)
            | [v] => (SOME v,hdr,kill,newfree,newEnv)
            | _ => (let val cname = closureLvar()
                        val ul = map (fn x => (VAR x,OFFp0)) contents
                        val nfree = merge(contents,newfree)
                        val hdr' = (fn ce => hdr(RECORD(RK_CONT,ul,cname,ce)))
                        val env = augValue(cname,PTRt,depth,newEnv)
                     in (SOME cname,hdr',enter(cname,kill),nfree,env)
                    end))
    | (u as SOME cname) => 
         (let val nfree = merge(contents,newfree)
              val l = (map (LABEL o #label) escapeB)@(map VAR contents) 
              val ul = map (fn x => (x,OFFp0)) l
              val rk = case fixKind of ESCAPE => RK_CLOSURE(map #v escapeB)
                                     | CONT => RK_CONT
                                     | _ => RK_KNOWN

              val hdr' = (fn ce => hdr(RECORD(rk,ul,cname,ce)))
              val env = augValue(cname,PTRt,depth,newEnv)
           in (u,hdr',enter(cname,kill),nfree,env)
          end)

(*** fix the calleesaveregs info in the calleeFrags ***)
(* The callee-save formals from mkCSformal are inserted right after each
 * continuation's first formal. *)
val calleeFrags = 
  map (fn ((args,cl,ce),_,v,label,csregs) =>
        let val ncsregs = fixhdCSregs(closureInfo,csregs)
            val (args',cl') = mkCSformal(ncsregs,newEnv)
            val nargs = (hd args)::(args'@(tl args))
            val ncl = (hd cl)::(cl'@(tl cl))
         in ((CONT,label,nargs,ncl,ce),v,label,ncsregs)
        end) calleeFrags

(*** The environment that will be propagated to the body of this FIX ***)
val newEnv = 
  fold (fn ((_,v,label,calleeregs),env) =>
          augCalleeF(v,label,calleeregs,env)) calleeFrags newEnv

(*** the result function bindings after all above transformations ***)
val newbinds = (map #1 knownFrags)@(map #1 escapeFrags)@(map #1 calleeFrags)

in  (newbinds,newEnv,newfree,kill,header)

end (* function collectfix *)


(****************************************************************************
 *  collect : collecting information from CPS expressions.                  *
 ****************************************************************************)

(* collect(k,d,env,ce): rewrite expression ce, which occurs inside a
 * function of kind k at depth d, under environment env.  Returns
 * (ce', info): variables bound by ce are removed from info (delvInfo) and
 * variables it mentions are added (addvInfo); the info representation
 * itself belongs to ClosUtil. *)
and collect(k,d,env,ce) =
 (case ce
   of FIX(fl,body) =>
       (* header may wrap the rebuilt FIX in a closure-record allocation *)
       let val (nfl,nenv,new,kill,header) = collectfix(k,d,env,fl)
           val (nb,info) = collect(k,d,nenv,body)
	in (header(FIX(nfl,nb)), addvInfo(new,delvInfo(kill,info)))
       end
    | APP(f,args) => 
       let val (nce,free,cregs) = procAPP(f,args,env)    
           val info = initInfo(k,d,cregs,free)
        in (nce,info)
       end
    | SWITCH(v,c,el) => 
       let val l = map (fn e => collect(k,d,env,e)) el
           val nel = map #1 l 
           val info = foldInfo(map #2 l)
        in (SWITCH(v,c,nel), addvInfo(clean [v],info))
       end
    | RECORD(rk,l,v,ce) =>
       let val env = augValue(v,PTRt,d,env)
           val (nce,info) = collect(k,d,env,ce)
        in (RECORD(rk,l,v,nce),addvInfo(clean(map #1 l),delvInfo([v],info)))
       end
    (* OFFSET is not expected in the input to this phase *)
    | OFFSET(i,v,w,ce) => error "OFFSET in cps/closure.sml!"
    | SELECT(i,v,w,t,ce) =>
       let val env = augValue(w,t,d,env)
           val (nce,info) = collect(k,d,env,ce)
	in (SELECT(i,v,w,t,nce),addvInfo(clean [v],delvInfo([w],info)))
       end
    | BRANCH(i,args,c,ce1,ce2) => 
       let val (nce1,info1) = collect(k,d,env,ce1)
           val (nce2,info2) = collect(k,d,env,ce2)
           val info = addvInfo(clean args,mergeInfo(info1,info2))
        in (BRANCH(i,args,c,nce1,nce2),info)
       end
    | SETTER(i,args,ce) =>
       let val (nce,info) = collect(k,d,env,ce)
        in (SETTER(i,args,nce),addvInfo(clean args,info))
       end
    | LOOKER(i,args,w,t,ce) =>
       let val env = augValue(w,t,d,env)
           val (nce,info) = collect(k,d,env,ce)
	in (LOOKER(i,args,w,t,nce),addvInfo(clean args,delvInfo([w],info)))
       end
    | ARITH(i,args,w,t,ce) =>
       let val env = augValue(w,t,d,env)
           val (nce,info) = collect(k,d,env,ce)
	in (ARITH(i,args,w,t,nce),addvInfo(clean args,delvInfo([w],info)))
       end
    | PURE(i,args,w,t,ce) =>
       let val env = augValue(w,t,d,env)
           val (nce,info) = collect(k,d,env,ce)
	in (PURE(i,args,w,t,nce),addvInfo(clean args,delvInfo([w],info)))
       end)

(* Rebuild the outermost function: fresh label, two leading PTRt formals
 * (label and the original name v), and the body collected at depth 0
 * starting from an empty environment. *)
val label = dupLvar(v)
val (args',cl',env) = augArgs(args,cl,0,emptyEnv())
val nargs = label::v::args'
val ncl = PTRt::PTRt::cl'
val (nce,_) = collect(fk,0,env,ce)

in  
   (fk,label,nargs,ncl,nce)

end (* function collectCPS *)

end (* local *)

end (* functor Collect *)

