open Sigs

module TempSet = Util.TempSet
module InstrMap = Util.InstrMap

(* Infix shorthands for the temp-set algebra in the liveness equations. *)

(** [a ++ b] is the union of temp sets [a] and [b]. *)
let ( ++ ) a b = TempSet.union a b

(** [a -- b] is [a] with every temp of [b] removed. *)
let ( -- ) a b = TempSet.diff a b

(** Backward liveness analysis over IR temporaries.

    The flow runs on the reversed CFG: [entry] is the CFG exit, and
    [prev]/[next] are the CFG successor/predecessor functions.

    [output] is a pair of per-instruction maps, keyed by
    [(block name, position from Util.mapi)]. The first map holds the temps
    live before each instruction; the second holds the temps live after it. *)
module Analysis : ANALYSIS with type info = TempSet.t
                            and type output = TempSet.t InstrMap.t
                                            * TempSet.t InstrMap.t =
struct
  type info = TempSet.t
  type output = info InstrMap.t * info InstrMap.t

  type cfg = CFG.t
  type node = CFG.node

  module MakeFlow(X : sig val cfg : CFG.t end) =
  struct
    (* Every temp mentioned anywhere in the CFG. A temp's index in this
       array is its bit position in [Lattice]. *)
    let context =
      CFG.bblock_fold
        (fun ~name:_ ~targets:_ block temps ->
          Util.block_alltemps block ++ temps)
        TempSet.empty
        X.cfg
      |> TempSet.elements
      |> Array.of_list

    let size = Array.length context

    (* Not referenced below; kept so the functor's contents are unchanged. *)
    let freetemps = Util.freetemps

    (* Temps read by [stmt].
       NOTE(review): for [MOVE (TEMP t, e)] this must not count [t] itself as
       a use -- confirm against [Util.stmt_freetemps]. *)
    let igen stmt = Util.stmt_freetemps stmt

    (* Temps defined by [stmt]. Only a move into a [TEMP] defines a temp.
       Every other statement, including a move whose target is not a [TEMP],
       kills nothing. [SEQ] is not expected inside a block. *)
    let ikill stmt =
      match stmt with
      | Ir.SEQ (_, _) -> assert false
      | Ir.MOVE (Ir.TEMP t, _) -> TempSet.singleton t
      | _ -> TempSet.empty

    (* Backward liveness transfer for one statement. Given the set
       [live_after] live after [stmt], return the set live before it. *)
    let transfer_instr live_after stmt =
      (live_after -- ikill stmt) ++ igen stmt

    (* Exported interface *)

    module Lattice = Bitset.BV(struct let n = size end)

    let combine = Lattice.join
    let combine_list = Lattice.join_list

    (* Liveness flows backwards: start at the CFG exit and walk successors
       as if they were predecessors. *)
    let entry = CFG.exit X.cfg
    let prev = CFG.succ X.cfg
    let next = CFG.pred X.cfg

    let initial = Lattice.bot

    (* Encode a temp set as a bit vector indexed by [context]. *)
    let bitvec tempset =
      Bitset.make size (fun i -> TempSet.mem context.(i) tempset)

    (* Upward-exposed uses of the block: temps read before being written
       within it. Computed by applying [transfer_instr] from the last
       instruction back to the first, starting from the empty set. *)
    let gen node =
      CFG.code X.cfg node
      |> List.rev
      |> List.fold_left transfer_instr TempSet.empty
      |> bitvec

    (* Every temp the block defines anywhere. *)
    let kill node =
      CFG.code X.cfg node
      |> List.fold_left (fun killed stmt -> killed ++ ikill stmt) TempSet.empty
      |> bitvec

    (* Decode a bit vector back into the temp set it represents. *)
    let interpret lattice =
      Bitset.fold
        (fun i present set ->
          if present then TempSet.add context.(i) set else set)
        TempSet.empty
        lattice

    (* Per-instruction liveness for block [name].

       [inflow] is the set live after the block's last instruction. The
       fourth argument is not used, because the set live at the block's start
       is recomputed instruction by instruction.

       Each instruction's live-before set is added to [instr_inflow], and its
       live-after set to [instr_outflow], keyed by [(name, position)]. *)
    let interpret_node name block inflow (_outflow : TempSet.t)
        instr_inflow instr_outflow =
      let step (live_after, ins, outs) (id, stmt) =
        let live_before = transfer_instr live_after stmt in
        ( live_before,
          InstrMap.add id live_before ins,
          InstrMap.add id live_after outs )
      in
      let _, ins, outs =
        Util.mapi (fun n instr -> ((name, n), instr)) block
        |> List.rev
        |> List.fold_left step (inflow, instr_inflow, instr_outflow)
      in
      (ins, outs)

    (* Expand per-block results into per-instruction maps.

       The value found in [in_node_map] for a block is used as the set live
       at the END of that block. The [out_node_map] value is looked up but
       not used by [interpret_node].

       NOTE(review): [CFG.NodeMap.find] presumably raises [Not_found] when a
       block is missing from either map. *)
    let interpret' (in_node_map, out_node_map) =
      CFG.bblock_fold
        (fun ~name ~targets:_ block (in_instr_map, out_instr_map) ->
          let inflow = CFG.NodeMap.find name in_node_map in
          let outflow = CFG.NodeMap.find name out_node_map in
          interpret_node name block inflow outflow in_instr_map out_instr_map)
        (InstrMap.empty, InstrMap.empty)
        X.cfg
  end
end

(* Dataflow solver instantiated with the liveness [Analysis] module above.
   NOTE(review): the solving strategy is whatever [Dataflow.IDA] implements;
   see that module for details. *)
module DFA = Dataflow.IDA (Analysis)
