open Sigs

(* Sets of temps.  Aliased from [Util] so that the sets returned by
   [Util.freetemps] can be combined directly with the ones built here. *)
module TempSet = Util.TempSet

(* Finite maps keyed by temps, ordered with the polymorphic [compare]. *)
module TempMap =
  Map.Make (struct
    type t = Temp.temp

    let compare (a : t) (b : t) = compare a b
  end)

(* Finite maps keyed by IR expressions, ordered structurally with the
   polymorphic [compare]. *)
module ExpMap =
  Map.Make (struct
    type t = Ir.exp

    let compare (a : t) (b : t) = compare a b
  end)

(* One candidate expression occurrence: a statement [lhs <- expr] whose
   right-hand side is a BINOP.  An occurrence is identified by the block it
   sits in and its position there, so syntactically equal expressions at
   different statements are distinct values.
   NOTE: the sets/maps below order aexps with the polymorphic [compare],
   which inspects these fields in declaration order; reordering the fields
   changes that ordering. *)
type aexp = {
  label: CFG.node;      (* label of the basic block containing the statement *)
  instr: int;           (* 0-based index of the statement within that block *)
  lhs: Temp.temp;       (* destination temp of the move *)
  expr: Ir.exp;         (* the BINOP expression being computed *)
  freevars: TempSet.t;  (* [lhs] plus every temp free in [expr] *)
}

(* Sets of expression occurrences, ordered with the polymorphic [compare]. *)
module AexpSet =
  Set.Make (struct
    type t = aexp

    let compare (a : t) (b : t) = compare a b
  end)

(* Finite maps keyed by expression occurrences, ordered with the
   polymorphic [compare]. *)
module AexpMap =
  Map.Make (struct
    type t = aexp

    let compare (a : t) (b : t) = compare a b
  end)


(* Expression-availability analysis over the IR control-flow graph
   (gen/kill bitsets, merged with [Lattice.meet]).

   The tracked facts are individual statements of the form
   [t <- BINOP (...)] (see [aexp]).  A block generates an occurrence it
   computes, provided no later statement in the block overwrites one of the
   temps it mentions.  It kills every occurrence that mentions a temp it
   writes.

   [info] maps each available expression to a temp holding its value.
   [output] is a pair of per-node [info] maps whose meaning is fixed by
   [Dataflow.IDA] (presumably block entry / block exit -- TODO confirm). *)
module Analysis : ANALYSIS with type info = Temp.temp ExpMap.t
                            and type output = Temp.temp ExpMap.t CFG.NodeMap.t
                                            * Temp.temp ExpMap.t CFG.NodeMap.t =
struct
  type info = Temp.temp ExpMap.t
  type output = info CFG.NodeMap.t * info CFG.NodeMap.t

  type cfg = CFG.t
  type node = CFG.node

  module MakeFlow (X : sig val cfg : CFG.t end) =
  struct
    let freetemps = Util.freetemps

    (* [expression label i stmt] recognises the candidate shape
       [MOVE (TEMP t, BINOP ...)] as statement [i] of block [label] and
       packages it as an [aexp]; any other statement yields [None]. *)
    let expression label i stmt =
      match stmt with
      | Ir.MOVE (Ir.TEMP t, (Ir.BINOP (_, _, _) as e)) ->
          Some {
            label = label;
            instr = i;
            lhs = t;
            expr = e;
            freevars = TempSet.add t (freetemps e);
          }
      | _ -> None

    (* The temp written by [stmt] when it is a move into a temp.  Every such
       write invalidates expressions mentioning that temp, not only the
       BINOP moves that [expression] recognises. *)
    let defined_temp stmt =
      match stmt with
      | Ir.MOVE (Ir.TEMP t, _) -> Some t
      | _ -> None

    (* Index [aexp] under every temp it mentions: its destination and its
       operands, i.e. [aexp.freevars]. *)
    let update_temp_map aexp map =
      TempSet.fold
        (fun temp map ->
           let set =
             try TempMap.find temp map
             with Not_found -> AexpSet.empty
           in
           TempMap.add temp (AexpSet.add aexp set) map)
        aexp.freevars
        map

    (* One pass over every block.  It collects all candidate occurrences
       ([aexps]) and, for each temp, the occurrences mentioning it
       ([temp_map]). *)
    let (aexps, temp_map) =
      CFG.bblock_fold
        (fun ~name ~targets:_ block acc ->
           let rec loop i stmts (aexps, temp_map) =
             match stmts with
             | [] -> (aexps, temp_map)
             | stmt :: rest ->
                 (match expression name i stmt with
                  | None -> loop (i + 1) rest (aexps, temp_map)
                  | Some aexp ->
                      loop (i + 1) rest
                        (AexpSet.add aexp aexps,
                         update_temp_map aexp temp_map))
           in
           loop 0 block acc)
        (AexpSet.empty, TempMap.empty)
        X.cfg

    (* Number of tracked occurrences = width of every bitset. *)
    let n = AexpSet.cardinal aexps

    (* Bit [i] of a lattice value stands for [context.(i)].  Elements are
       laid out in increasing [AexpSet] order. *)
    let context = Array.of_list (AexpSet.elements aexps)

    module Lattice = Bitset.BV(struct let n = n end)

    let combine = Lattice.meet
    let combine_list = Lattice.meet_list

    let prev = CFG.pred X.cfg
    let next = CFG.succ X.cfg

    (* Occurrences that mention [temp]; empty when no candidate does. *)
    let aexps_using temp =
      try TempMap.find temp temp_map
      with Not_found -> AexpSet.empty

    (* Occurrences killed by [stmt]: those mentioning the temp it writes.
       NOTE(review): stores to memory and calls are not treated as kills.
       This is only sound if BINOP operands never read memory; confirm
       against the IR canonicaliser. *)
    let kill_of stmt =
      match defined_temp stmt with
      | Some t -> aexps_using t
      | None -> AexpSet.empty

    (* Occurrence generated by statement [i] of block [label], if any.
       A move such as [x <- x + 5] generates nothing, because it overwrites
       one of its own operands. *)
    let gen_of label i stmt =
      match expression label i stmt with
      | Some aexp when not (TempSet.mem aexp.lhs (freetemps aexp.expr)) ->
          AexpSet.singleton aexp
      | Some _ | None -> AexpSet.empty

    (* Occurrences available at the end of [block] regardless of its input.
       Walking forward, each statement first removes what it kills, then
       adds what it generates. *)
    let genset label block =
      let rec loop i stmts avail =
        match stmts with
        | [] -> avail
        | stmt :: rest ->
            let avail =
              AexpSet.union (gen_of label i stmt)
                (AexpSet.diff avail (kill_of stmt))
            in
            loop (i + 1) rest avail
      in
      loop 0 block AexpSet.empty

    (* Occurrences [block] makes unavailable: everything killed by any of its
       statements, minus what a later statement of the block regenerates. *)
    let killset label block =
      let rec loop i stmts killed =
        match stmts with
        | [] -> killed
        | stmt :: rest ->
            let killed =
              AexpSet.diff (AexpSet.union killed (kill_of stmt))
                (gen_of label i stmt)
            in
            loop (i + 1) rest killed
      in
      loop 0 block AexpSet.empty

    (* Encode an occurrence set as an [n]-bit vector indexed by [context]. *)
    let bv set =
      Bitset.make n (fun i -> AexpSet.mem context.(i) set)

    (* Precomputed per-block gen and kill bitsets. *)
    let (genmap, killmap) =
      CFG.bblock_fold
        (fun ~name ~targets:_ block (genmap, killmap) ->
           (CFG.NodeMap.add name (bv (genset name block)) genmap,
            CFG.NodeMap.add name (bv (killset name block)) killmap))
        (CFG.NodeMap.empty, CFG.NodeMap.empty)
        X.cfg

    (* First block reported by [CFG.all_blocks]; [List.hd] raises [Failure]
       if the CFG has no blocks. *)
    let entry = List.hd (CFG.all_blocks X.cfg)

    (* NOTE(review): with [meet] as the merge, non-entry nodes usually need
       to start at the top element; confirm how [Dataflow.IDA] uses this. *)
    let initial = Lattice.bot
    let gen node = CFG.NodeMap.find node genmap
    let kill node = CFG.NodeMap.find node killmap

    (* Translate a lattice value into a map from each available expression
       to a temp holding its value.  If several set bits share the same
       [expr], the binding from the last one visited by [Bitset.fold] wins. *)
    let interpret lat =
      Bitset.fold
        (fun i present map ->
           if present then
             let aexp = context.(i) in
             ExpMap.add aexp.expr aexp.lhs map
           else
             map)
        ExpMap.empty
        lat

    let interpret' maps = maps
  end
end
		 
(* The dataflow solver for [Analysis], built by the [Dataflow.IDA] functor. *)
module DFA = Dataflow.IDA(Analysis)
