open Sigs

(* A single definition site: an assignment of [assign_exp] to the temporary
   [assign_var], located by the CFG node it sits in and its position inside
   that node's block.  The type is re-exported from [Util] with its fields
   spelled out so they can be used unqualified in this file. *)
type definition = Util.definition = {
  label: CFG.node; (* CFG node whose block contains the assignment *)
  instr: int; (* Instruction # in its block, 0-indexed including label  *)
  assign_var: Temp.temp; (* temporary being assigned *)
  assign_exp: Ir.exp; (* expression assigned to the temporary *)
}

(* [definition name i stmt] is [Some d] when [stmt] moves an expression into
   a temporary; [d] records that temporary, the expression, and the location
   given by node [name] and instruction index [i].  Every other statement
   yields [None]. *)
let definition name i = function
    | Ir.MOVE (Ir.TEMP dst, src) ->
        Some { label = name; instr = i; assign_var = dst; assign_exp = src }
    | _ -> None

(* Renders a definition as "(<node>,<instr>){{<statement>}}", where the
   statement is the MOVE rebuilt from the definition's temporary and
   expression. *)
let string_of_definition def =
    let stmt_text =
        Ir.ir_pstmt2str (Ir.MOVE (Ir.TEMP def.assign_var, def.assign_exp)) in
    let location =
        Printf.sprintf "(%s,%d)" (CFG.string_of_node def.label) def.instr in
    location ^ "{{" ^ stmt_text ^ "}}"

(* Sets of [definition] values; alias of the implementation in [Util]. *)
module DefinitionSet = Util.DefinitionSet

(* Maps keyed by temporaries (this file looks them up with the
   [assign_var] of a definition). *)
module TempMap = Util.TempMap

(* Maps keyed by an instruction identifier; this file builds the keys as
   (node, index-in-block) pairs. *)
module InstrMap = Util.InstrMap

(* Reaching-definitions analysis packaged for the generic dataflow
   framework.  [info] is a set of definitions; [output] pairs two
   per-instruction maps, the first giving the definitions flowing into each
   instruction and the second the definitions flowing out of it. *)
module Analysis : ANALYSIS with type info = DefinitionSet.t
                            and type output = DefinitionSet.t InstrMap.t
                                            * DefinitionSet.t InstrMap.t =
  struct
    type info = DefinitionSet.t
    type output = info InstrMap.t * info InstrMap.t

    type cfg = CFG.t
    type node = CFG.node

    module MakeFlow (X : sig val cfg : CFG.t end) =
      struct
        (* Number of statements of the form [MOVE (TEMP _, _)] in [cfg].
           Counting does not depend on traversal order, so a left fold is
           used. *)
        let count_definitions cfg =
          CFG.bblock_fold
            (fun ~name:_ ~targets:_ bblock n ->
               List.fold_left
                 (fun n instr ->
                    match instr with
                    | Ir.MOVE (Ir.TEMP _, _) -> n + 1
                    | _ -> n)
                 n
                 bblock)
            0
            cfg

        (* Total number of definitions in [X.cfg]; used as the size of the
           bit sets below (one bit per definition). *)
        let n = count_definitions X.cfg

        module Lattice = Bitset.BV (struct let n = n end)

        let combine = Lattice.join
        let combine_list = Lattice.join_list

        (* Forward analysis: facts flow from predecessors to successors. *)
        let prev = CFG.pred X.cfg
        let next = CFG.succ X.cfg

        (* Carries the first node's name out of the fold in [entry]. *)
        exception Node of node

        (* The first node visited by [CFG.bblock_fold]: the callback raises
           on its very first call, so the fold goes no further.  If the
           callback is never called the fold yields the empty list and
           [List.hd] raises [Failure]. *)
        let entry =
          try
            List.hd (CFG.bblock_fold
                       (fun ~name ~targets:_ _ -> raise (Node name))
                       []
                       X.cfg)
          with
          | Node name -> name

        (* Association list from each node to the definitions in its block,
           in program order.  Each definition's [instr] field is the
           statement's position counted from the START of the block (0 for
           the first statement), which is the same numbering that
           [interpret_node] passes to [definition]; keeping the two in
           agreement is what lets [ikill] recognise a definition in
           [id_map].  The indices are produced by a left fold for that
           reason; the resulting list is reversed afterwards to restore
           program order. *)
        let defs =
          List.rev (* bblock_fold visits nodes in forward order, so
                      defs_list is built up backwards *)
            (CFG.bblock_fold
               (fun ~name ~targets:_ block defs_list ->
                  let (_, rev_block_defs) =
                    List.fold_left
                      (fun (i, acc) instr ->
                         match definition name i instr with
                         | Some def -> (i + 1, def :: acc)
                         | None -> (i + 1, acc))
                      (0, [])
                      block
                  in
                    (name, List.rev rev_block_defs) :: defs_list)
               []
               X.cfg)

        (* Every definition in the CFG, node by node, in program order. *)
        let flat_defs = List.flatten (List.map snd defs)

        (* maps temps to the set of definitions that define it *)
        let id_map (* temp => def-set *) =
          List.fold_right
            (fun def map ->
               let v = def.assign_var in
                 try
                   let set = TempMap.find v map in
                     TempMap.add v (DefinitionSet.add def set) map
                 with
                   Not_found ->
                     TempMap.add v (DefinitionSet.singleton def) map)
            flat_defs
            TempMap.empty

        (* Definitions generated by a block: for each temporary assigned in
           [def_list], only its last definition. *)
        let gen_set def_list =
          let map =
            (* The fold left is important; it ensures later definitions
               screen out earlier ones *)
            List.fold_left
              (fun map def ->
                 TempMap.add def.assign_var def map)
              TempMap.empty
              def_list
          in
            TempMap.fold
              (fun _ def set -> DefinitionSet.add def set)
              map
              DefinitionSet.empty

        (* All other definitions of the temporary that [def] assigns.
           Raises [Not_found] if that temporary has no entry in [id_map]. *)
        let kill_set_of_def def =
          let v = def.assign_var in
            DefinitionSet.remove def (TempMap.find v id_map)

        (* Definitions killed by a block: everything killed by one of its
           definitions, except the definitions the block itself generates. *)
        let kill_set def_list =
          let all_kills =
            List.fold_right
              DefinitionSet.union
              (List.map kill_set_of_def def_list)
              DefinitionSet.empty
          in
            DefinitionSet.diff all_kills (gen_set def_list)

        (* Bit position => definition; position [i] stands for
           [context.(i)]. *)
        let context = Array.of_list (flat_defs)

        (* Bit set of size [n] whose bit [i] is set exactly when
           [context.(i)] belongs to [defset]. *)
        let defset_to_bv context defset =
          Bitset.make n (fun i -> DefinitionSet.mem context.(i) defset)

        (* Node => bit set of [set_of_defs] applied to the node's
           definitions; shared by [gen_map] and [kill_map]. *)
        let node_bitsets set_of_defs =
          List.fold_right
            (fun (name, def_list) map ->
               CFG.NodeMap.add
                 name
                 (defset_to_bv context (set_of_defs def_list))
                 map)
            defs
            CFG.NodeMap.empty

        let gen_map = node_bitsets gen_set

        let kill_map = node_bitsets kill_set

        (* Both lookups raise [Not_found] for a node absent from [defs]. *)
        let gen node = CFG.NodeMap.find node gen_map

        let kill node = CFG.NodeMap.find node kill_map

        let initial = Lattice.bot

        (* Converts a bit set back into the set of definitions whose bits
           are set. *)
        let interpret lattice =
          Bitset.fold
            (fun i present set ->
               if present then
                 DefinitionSet.add context.(i) set
               else
                 set)
            DefinitionSet.empty
            lattice

        (*** newish code -- should probably rewrite above in terms of these
         * eventually.  for now, it's just to interpret' into a sensible
         * output type -wjl ***)

        (* Definitions generated by the single statement [stmt] sitting at
           position [i] of [node]: a singleton, or empty. *)
        let igen (node, i) stmt =
            match definition node i stmt with
            | Some def -> DefinitionSet.singleton def
            | None -> DefinitionSet.empty

        (* Local operator shorthands.  NOTE: from here to the end of the
           functor [@@] and [<=] no longer have their standard meaning. *)

        (* [tmap @@ t] looks [t] up in [tmap]; raises [Not_found]. *)
        let (@@) tmap t = TempMap.find t tmap

        let (++) dset1 dset2 = DefinitionSet.union dset1 dset2
        (* NB: -- is remove, not diff like it usually is -wjl *)
        let (--) dset d = DefinitionSet.remove d dset
        (* this should be a \\, but that's illegal -wjl *)
        let (//) dset1 dset2 = DefinitionSet.diff dset1 dset2

        (* Definitions killed by the single statement [stmt] at position
           [i] of [node]: every other definition of the same temporary. *)
        let ikill (node, i) stmt =
            match definition node i stmt with
            | Some def -> (id_map @@ def.assign_var) -- def
            | None -> DefinitionSet.empty

        (* [imap %% (i <= v)] binds [i] to [v] in [imap]; [<=] merely builds
           the pair. *)
        let (%%) imap (i, v) = InstrMap.add i v imap
        let (<=) x y = (x, y)

        (* Refines one node's result to instruction granularity.  Walking
           the block front to back, the flow into the first statement is the
           node's [inflow]; the flow out of each statement is
           gen + (in - kill) and becomes the flow into the next one.  Both
           flows are recorded under the key (name, position), where the
           position comes from [Util.mapi] (presumably counting from 0 --
           confirm in util.ml).  The node's outflow is accepted to keep the
           calling convention but is not needed: it is recomputed by the
           walk. *)
        let interpret_node name block inflow _outflow
                           instr_inflow instr_outflow =
            let named_instrs =
                Util.mapi (fun n instr -> ((name, n), instr)) block in
            let rec loop instrs flow instr_inflow instr_outflow =
                match instrs with
                | [] -> (instr_inflow, instr_outflow)
                | (id, instr) :: rest ->
                    let flow' = igen id instr ++ (flow // ikill id instr) in
                    let instr_inflow' = instr_inflow %% (id <= flow) in
                    let instr_outflow' = instr_outflow %% (id <= flow') in
                    loop rest flow' instr_inflow' instr_outflow'
            in
            loop named_instrs inflow instr_inflow instr_outflow

        (* Turns the per-node (in, out) maps into the per-instruction
           (in, out) maps of [output] by running [interpret_node] over every
           block.  Raises [Not_found] if a node is missing from either
           input map. *)
        let interpret' (in_node_map, out_node_map) =
            CFG.bblock_fold
                (fun ~name ~targets:_ block (in_instr_map, out_instr_map) ->
                    let inflow = CFG.NodeMap.find name in_node_map in
                    let outflow = CFG.NodeMap.find name out_node_map in
                    interpret_node name block inflow outflow
                                   in_instr_map out_instr_map)
                (InstrMap.empty, InstrMap.empty)
                X.cfg
      end
  end

(* The reaching-definitions analysis instantiated through the
   [Dataflow.IDA] functor (presumably the iterative dataflow solver --
   confirm in dataflow.ml). *)
module DFA = Dataflow.IDA (Analysis)
