(* Constant Propagation *)

module Dset = Util.DefinitionSet

(* Build the reaching-definition record for the assignment [var := exp].
   [label] names the block containing the assignment; [instr] is presumably
   the statement's index within that block (the disabled [gen] below passed
   the loop counter) -- confirm against Reaching. *)
let newvar ~label ~var ~exp ~instr =
  { Reaching.label;
    Reaching.assign_var = var;
    Reaching.assign_exp = exp;
    Reaching.instr }

(* Field projections on a Reaching definition record. *)
let label { Reaching.label = l; _ } = l
let var { Reaching.assign_var = v; _ } = v
let expression { Reaching.assign_exp = e; _ } = e
let instr { Reaching.instr = i; _ } = i

module TempMap = Util.TempMap

(* [get map v] is the definition set recorded for temp [v] in [map].  A temp
   with no entry yields the empty set instead of raising [Not_found]. *)
let get map v =
  match TempMap.find v map with
  | defs -> defs
  | exception Not_found -> Dset.empty

(*
let ($$) map (k, v) =
  let set = Dset.add v (get map k) in
    TempMap.add k set map
*)

(* [map %% (k, v)] binds [k] to the singleton set {v}, replacing (not
   extending) whatever set [k] was previously bound to. *)
let (%%) map binding =
  let (key, def) = binding in
  let only_def = Dset.singleton def in
  TempMap.add key only_def map
      
(* [k <= v] simply builds the pair [(k, v)]; it reads as "k gets v" in
   expressions such as [map %% (var <= def)].
   NOTE: this shadows the polymorphic comparison [Stdlib.(<=)] for the rest
   of this file. *)
let (<=) = fun k v -> (k, v)

(* [temp ==> map] is [Some def] when exactly one definition of [temp]
   reaches according to [map], and [None] otherwise (no definitions, or
   several).  Propagation is only attempted in the single-definition case.
   This could be improved, e.g. by also accepting several reaching
   definitions that are all equivalent. -wjl *)
let (==>) temp map =
  let defs = get map temp in
  match Dset.cardinal defs with
  | 1 -> Some (Dset.choose defs)
  | _ -> None

(*
let gen label instr map stmt =
  match stmt with
  | Ir.MOVE(Ir.TEMP var, exp) ->
      map %% (var <= (newvar ~label ~instr ~var ~exp))
  | _ -> map 
*)

(* Constant-propagate into an expression.  Each [Ir.TEMP t] whose single
   reaching definition (looked up in [info], a temp -> definition-set map)
   assigns an [Ir.CONST] is replaced by that constant expression.  A temp
   reached by zero or several definitions, or by a non-constant one, is left
   untouched.  The operands of BINOP, the address of MEM and the argument
   list of CALL are rewritten recursively; the callee of a CALL is not.
   ESEQ and PHI are not expected here (assert false).  Any other expression
   is returned as is. *)
let rec cprop_exp info = function
  | Ir.TEMP t as use ->
      (match t ==> info with
       | None -> use
       | Some def ->
           (match expression def with
            | Ir.CONST (_, _) as constant -> constant
            | _ -> use))
  | Ir.BINOP (op, lhs, rhs) ->
      Ir.BINOP (op, cprop_exp info lhs, cprop_exp info rhs)
  | Ir.MEM addr -> Ir.MEM (cprop_exp info addr)
  | Ir.CALL (fn, args) ->
      let args' = List.map (cprop_exp info) args in
      Ir.CALL (fn, args')
  | Ir.ESEQ (_, _) | Ir.PHI _ -> assert false
  | other -> other
  
(* Constant-propagate into the uses of one statement, given [info], the
   temp -> definition-set map of definitions reaching the statement.

   The destination of a MOVE is a definition rather than a use, so a plain
   destination (e.g. a TEMP) is never rewritten.  The address of a store,
   [MOVE (MEM addr, src)], is different: evaluating [addr] reads its temps, so
   it is rewritten just like the MEM addresses that [cprop_exp] already
   rewrites on the load side.  EXP and both operands of CJUMP are rewritten as
   well.  Any other statement is returned unchanged. *)
let prop_stmt info stmt =
  match stmt with
  | Ir.MOVE (Ir.MEM addr, src) ->
      (* Store: the address is a use, only the memory cell is defined. *)
      Ir.MOVE (Ir.MEM (cprop_exp info addr), cprop_exp info src)
  | Ir.MOVE (dst, src) ->
      Ir.MOVE (dst, cprop_exp info src)
  | Ir.EXP e ->
      Ir.EXP (cprop_exp info e)
  | Ir.CJUMP (relop, e1, e2, lt, lf) ->
      Ir.CJUMP (relop, cprop_exp info e1, cprop_exp info e2, lt, lf)
  | _ -> stmt

(*
    (* old -wjl *)
let propagate_block label block info =
  let rec loop i info block =
    match block with
    | [] -> []
    | stmt :: rest ->
	let stmt = prop_stmt label i info stmt in
	let info = gen label i info stmt in
	  stmt :: (loop (i+1) info rest)
  in
    loop 0 info block

let propagate node_blocks ~inflow =
  let info label = map_of_set (CFG.NodeMap.find label inflow) in
    List.map
      (fun (label, block) -> propagate_block label block (info label))
      node_blocks
		
*)

(* [table @@ key] looks [key] up in [table], a Util.InstrMap.  In
   [propagate_block] the key is a (block label, statement index) pair.
   Presumably raises [Not_found] for a missing key, if InstrMap is a stdlib
   Map -- confirm.
   NOTE: this shadows [Stdlib.(@@)] (function application) for the rest of
   this file. *)
let (@@) table key =
  Util.InstrMap.find key table

(* Constant-propagate through every statement of the block named [label].
   For the statement at index [idx], the definitions reaching it are
   [inflow @@ (label, idx)].  They are turned into a temp-keyed map and used
   to rewrite that statement. *)
let propagate_block label (block : Ir.stmt list) inflow =
  let rewrite idx stmt =
    let reaching = inflow @@ (label, idx) in
    prop_stmt (Util.temp_map_of_def_set reaching) stmt
  in
  Util.mapi rewrite block

(* Entry point: run constant propagation over every (label, statements) pair
   in [node_blocks], using the per-statement reaching-definition table
   [inflow].  The resulting list preserves the order of [node_blocks]. *)
let propagate node_blocks ~inflow =
  let propagate_node (label, block) =
    propagate_block label block inflow
  in
  List.map propagate_node node_blocks
