(* Loop hoisting (loop-invariant code motion).

   For each natural loop, [hoist] selects assignments whose value does not
   change across iterations, inserts copies of them at the end of the
   loop's preheader blocks (just before each block's final statement) and
   deletes the originals from the loop. *)

(* Sets of definition records (the [Reaching] record type). *)
module Dset = Util.DefinitionSet

(* Field accessors for definition records. *)
let block { Reaching.label; _ } = label
let instr_num { Reaching.instr; _ } = instr
let assign_var { Reaching.assign_var; _ } = assign_var
let assign_exp { Reaching.assign_exp; _ } = assign_exp

(* Position of a definition in the CFG: (node, statement index in node). *)
let id def = block def, instr_num def

(* [rdmap @@ t] is the set of definitions bound to temp [t] in [rdmap], or
   the empty set when [t] is unbound.  NOTE: this shadows [Stdlib.(@@)]
   for the rest of the file. *)
let (@@) map temp =
  try Util.TempMap.find temp map with Not_found -> Dset.empty

(* [s ++ d] is the definition set [s] with [d] added. *)
let (++) s d = Dset.add d s

(* [is_invariant' cfg dom rdefs loop vdefs def] decides whether the
   assignment [def] computes a loop-invariant value for [loop].

   The right-hand side must be a constant, a temp, or a binary operation
   whose operands are constants or temps.  A temp operand [t] is invariant
   when either
   - no definition of [t] reaching [def] lies inside [loop], or
   - exactly one definition of [t] reaches [def] and it is itself
     invariant (checked recursively).
   The same per-operand rule is applied to both operands of a two-temp
   BINOP, so mixed cases (one operand defined only outside the loop, the
   other by a single invariant in-loop definition) are accepted.

   [vdefs] holds the definitions already on the current recursion path.
   Meeting one of them again means the dependency is cyclic, and the
   definition is treated as variant.  Assignments to reserved temps
   ([Temp.reserved]) and DIV/MOD expressions are never invariant (the
   latter presumably because they may trap).  [cfg] and [dom] are only
   threaded through the recursion.

   May raise [Not_found] if [rdefs] has no entry for a definition's
   position. *)
let rec is_invariant' cfg dom rdefs loop vdefs def =
  if Dset.mem def vdefs || Temp.reserved (assign_var def) then
    false
  else
    let rdmap =
      Util.temp_map_of_def_set (Util.InstrMap.find (id def) rdefs) in
    let in_loop d = CFG.NodeSet.mem (block d) loop in
    (* Invariance of a single temp operand, per the rule above. *)
    let operand_invariant t =
      let reaching = rdmap @@ t in
      (not (Dset.exists in_loop reaching))
      || (Dset.cardinal reaching = 1
          && is_invariant' cfg dom rdefs loop (vdefs ++ def)
               (Dset.choose reaching))
    in
    match assign_exp def with
    | Ir.BINOP (Ir.DIV, _, _)
    | Ir.BINOP (Ir.MOD, _, _) -> false
    | Ir.CONST (_, _)
    | Ir.BINOP (_, Ir.CONST (_, _), Ir.CONST (_, _)) -> true
    | Ir.TEMP t
    | Ir.BINOP (_, Ir.TEMP t, Ir.CONST (_, _))
    | Ir.BINOP (_, Ir.CONST (_, _), Ir.TEMP t) ->
        operand_invariant t
    | Ir.BINOP (_, Ir.TEMP t, Ir.TEMP t') ->
        operand_invariant t && operand_invariant t'
    | _ -> false

(* [is_invariant cfg dom rdefs loop def]: is [def] loop-invariant for
   [loop]?  Starts the recursive check with an empty visited set. *)
let is_invariant cfg dom rdefs loop =
  is_invariant' cfg dom rdefs loop Dset.empty

(* [node_set nodes] is the set of CFG nodes in the list [nodes]. *)
let node_set list =
  List.fold_left (fun acc n -> CFG.NodeSet.add n acc) CFG.NodeSet.empty list

(* [loop_exits cfg loop] is the set of nodes outside [loop] that are a
   direct successor of some node in [loop] (the targets of exit edges). *)
let loop_exits cfg loop =
  CFG.NodeSet.fold
    (fun node exits ->
       List.fold_left
         (fun exits target ->
            if CFG.NodeSet.mem target loop then exits
            else CFG.NodeSet.add target exits)
         exits
         (CFG.succ cfg node))
    loop
    CFG.NodeSet.empty

(* [last_instr cfg node] is the position (node, index) of the final
   statement of [node]'s code; the index is -1 for an empty block. *)
let last_instr cfg node =
  (node, List.length (CFG.code cfg node) - 1)

(* [liveout_node liveout cfg node] looks up, in the per-instruction map
   [liveout], the entry for [node]'s last statement.  Raises [Not_found]
   when that position has no entry. *)
let liveout_node liveout cfg node =
  let key = last_instr cfg node in
  Util.InstrMap.find key liveout

(* [dominates_all_loop_exits cfg dom liveout loop def] holds when [def]'s
   block dominates every exit target of [loop] (see [loop_exits]) at which
   the assigned temp is in the live-out set.

   NOTE(review): liveness is taken from the live-out of the exit target's
   *last* instruction, not from the target's entry.  A temp used in the
   target block but dead afterwards is therefore not counted as live
   there.  Confirm this is intended. *)
let dominates_all_loop_exits cfg dom liveout loop def =
  let var = assign_var def in
  let home = block def in
  let is_live_at node = Util.TempSet.mem var (liveout_node liveout cfg node) in
  let live_exits = CFG.NodeSet.filter is_live_at (loop_exits cfg loop) in
  CFG.NodeSet.for_all (Dominators.dom dom home) live_exits

(* [unique_def cfg loop t] holds when exactly one statement of the form
   [MOVE (TEMP t, _)] appears in the code of the nodes of [loop]. *)
let unique_def cfg loop t =
  let assigns_t = function
    | Ir.MOVE (Ir.TEMP t', _) -> t = t'
    | _ -> false
  in
  let count_in node acc =
    List.fold_left
      (fun n stmt -> if assigns_t stmt then n + 1 else n)
      acc
      (CFG.code cfg node)
  in
  CFG.NodeSet.fold count_in loop 0 = 1

(* [not_live_from_preheader cfg liveout start def] holds when the temp
   assigned by [def] is not in the live-out set of any predecessor of the
   loop header [start].

   NOTE(review): the predecessor set is not filtered against the loop, so
   in-loop predecessors of [start] (e.g. a back-edge source) are checked
   too.  [hoist_targets] removes loop nodes from its preheader set. *)
let not_live_from_preheader cfg liveout start def =
  let var = assign_var def in
  let preds = node_set (CFG.pred cfg start) in
  not
    (CFG.NodeSet.exists
       (fun node -> Util.TempSet.mem var (liveout_node liveout cfg node))
       preds)

(* [may_hoist cfg dom liveout start loop def] checks the hoisting side
   conditions for an invariant [def], in this order:
   its block dominates the live loop exits, it is the only definition of
   its temp in [loop], and the temp is not live out of [start]'s
   predecessors. *)
let may_hoist cfg dom liveout start loop def =
  if not (dominates_all_loop_exits cfg dom liveout loop def) then false
  else if not (unique_def cfg loop (assign_var def)) then false
  else not_live_from_preheader cfg liveout start def

(* [defs cfg loop] appears intended to list the code of each node of
   [loop].

   NOTE(review): the fold is only partially applied.  Neither [loop] nor
   an initial accumulator is passed, so [defs cfg loop] returns a function
   that still expects a node set and an accumulator, and [loop] is
   ignored.  Nothing in this file calls it.  Supplying the missing
   arguments would change its type, so confirm external callers first. *)
let defs cfg loop =
  let prepend_code node acc = CFG.code cfg node :: acc in
  CFG.NodeSet.fold prepend_code

(* [loop_fold cfg loop f init] folds [f] over every statement of every
   node in [loop].  [f node i stmt acc] receives the node, the statement's
   0-based index within that node's code, the statement, and the
   accumulator.  Nodes are visited in [CFG.NodeSet.fold] order;
   statements within a node are visited in order. *)
let loop_fold cfg loop f init =
  CFG.NodeSet.fold
    (fun node acc ->
       (* Thread the statement index alongside the accumulator and keep
          only the accumulator; the final index is not needed. *)
       let step (acc, i) stmt = (f node i stmt acc, i + 1) in
       fst (List.fold_left step (acc, 0) (CFG.code cfg node)))
    loop
    init

(* [def_to_stmt def] rebuilds the assignment [MOVE (TEMP var, exp)]
   described by [def]. *)
let def_to_stmt def =
  let dst = Ir.TEMP (assign_var def) in
  Ir.MOVE (dst, assign_exp def)

(* [deflist_of_dset defs] turns a set of hoisted definitions into MOVE
   statements, in the order they must be placed in a preheader.

   A hoisted definition may read a temp assigned by another hoisted
   definition ([is_invariant'] accepts an operand whose single reaching
   definition is itself invariant).  Set order is unrelated to those data
   dependences, so definitions are emitted depth-first: the producers in
   [defs] of a definition's operand temps come first, then the definition.
   Independent definitions follow the set's increasing order.  A
   definition is marked visited before its producers are explored, so
   recursion terminates even if [defs] contains a cycle. *)
let deflist_of_dset defs =
  (* Temps read by an expression of the shapes [is_invariant'] accepts. *)
  let reads exp =
    let temp = function Ir.TEMP t -> [t] | _ -> [] in
    match exp with
    | Ir.BINOP (_, lhs, rhs) -> temp lhs @ temp rhs
    | other -> temp other
  in
  let rec visit ((seen, out) as acc) def =
    if Dset.mem def seen then acc
    else
      let acc = (Dset.add def seen, out) in
      let (seen, out) =
        List.fold_left
          (fun acc t ->
             Dset.fold
               (fun d acc -> if assign_var d = t then visit acc d else acc)
               defs acc)
          acc
          (reads (assign_exp def))
      in
      (seen, def_to_stmt def :: out)
  in
  let (_, out) =
    Dset.fold (fun def acc -> visit acc def) defs (Dset.empty, []) in
  List.rev out
			

(* [match_def block inum stmt] is [Some def] describing the assignment when
   [stmt] is [MOVE (TEMP t, exp)] at position ([block], [inum]), and
   [None] for any other statement. *)
let match_def block inum = function
  | Ir.MOVE (Ir.TEMP t, exp) ->
      Some { Reaching.label = block;
             Reaching.instr = inum;
             Reaching.assign_var = t;
             Reaching.assign_exp = exp }
  | _ -> None

(* [collect_defs cfg nodes] gathers every [MOVE (TEMP t, e)] statement in
   the code of [nodes] as a definition record (see [match_def]). *)
let collect_defs cfg nodes =
  let add_if_def block inum stmt found =
    match match_def block inum stmt with
    | None -> found
    | Some def -> Dset.add def found
  in
  loop_fold cfg nodes add_if_def Dset.empty

(* [hoist_targets cfg rdefs liveout] chooses, for each natural loop of
   [cfg], the definitions to hoist.  Each result is
   [((back, start), loop, preheader, defs)], where:
   - [preheader] is the set of predecessors of the loop header [start]
     that lie outside [loop];
   - [defs] are the definitions in [loop] that are invariant
     ([is_invariant]), pass [may_hoist], and whose operands' in-loop
     reaching definitions are hoisted as well.
   [rdefs] maps instruction positions to the definitions reaching them;
   [liveout] maps instruction positions to temp sets (used by
   [may_hoist]). *)
let hoist_targets cfg rdefs liveout =
  let (--) = CFG.NodeSet.diff in
  let dom = Dominators.dominators cfg in
  let loops = Dominators.natural_loops dom cfg in
  (* Temps read by an expression of the shapes [is_invariant] accepts. *)
  let operand_temps exp =
    let temp = function Ir.TEMP t -> [t] | _ -> [] in
    match exp with
    | Ir.BINOP (_, lhs, rhs) -> temp lhs @ temp rhs
    | other -> temp other
  in
  (* A def whose operand is produced by an in-loop def that stays in the
     loop would, once moved to the preheader, read the operand's pre-loop
     value.  Such defs are dropped.  Dropping one can invalidate another,
     so this repeats until the set is stable. *)
  let rec keep_closed loop hoisted =
    let hoisted_ids =
      Dset.fold
        (fun d ids -> Util.InstrSet.add (id d) ids)
        hoisted Util.InstrSet.empty
    in
    let left_behind d =
      CFG.NodeSet.mem (block d) loop
      && not (Util.InstrSet.mem (id d) hoisted_ids)
    in
    let operands_ready def =
      let rdmap =
        Util.temp_map_of_def_set (Util.InstrMap.find (id def) rdefs) in
      List.for_all
        (fun t -> not (Dset.exists left_behind (rdmap @@ t)))
        (operand_temps (assign_exp def))
    in
    let pruned = Dset.filter operands_ready hoisted in
    if Dset.cardinal pruned = Dset.cardinal hoisted then hoisted
    else keep_closed loop pruned
  in
  List.map
    (fun ((back, start), loop) ->
       let defs = collect_defs cfg loop in
       let defs = Dset.filter (is_invariant cfg dom rdefs loop) defs in
       let defs = Dset.filter (may_hoist cfg dom liveout start loop) defs in
       let defs = keep_closed loop defs in
       let preheader = (node_set (CFG.pred cfg start)) -- loop in
       ((back, start), loop, preheader, defs))
    loops

(* [print_hoist_targets out hoist_targets] prints on formatter [out], for
   each entry produced by [hoist_targets], the loop, its hoist locations
   (the preheader node set) and the definitions selected for hoisting. *)
let print_hoist_targets out hoist_targets =
  let print fmt = Format.fprintf out fmt in
  List.iter
    (fun ((back, start), loop, preheader, defs) ->
       print "loop:\n";
       Dominators.print_loop out ((back, start), loop);
       print "hoist locations:\n";
       Dominators.print_nodeset out preheader;
       print "\n";
       print "hoisted definitions:\n";
       Dset.iter (Util.print_def out) defs)
    hoist_targets
	   
(* [append_defs defs block] inserts the statements [defs] into [block]
   just before its final statement, keeping everything else in order.
   Raises [Failure "broken block invariant"] on an empty block.
   Tail-recursive: the block is reversed once instead of recursing. *)
let append_defs defs block =
  match List.rev block with
  | [] -> failwith "broken block invariant"
  | last :: earlier_rev -> List.rev_append earlier_rev (defs @ [last])

(* Function composition: [(f $ g) x = f (g x)]. *)
let ($) f g = fun x -> f (g x)

(* [cfg_update_preheaders cfg hoist_targets] rebuilds [cfg].  Every node
   that belongs to some target's [preheader] set gets that target's
   hoisted definitions inserted just before its final statement (see
   [append_defs]).  Nodes whose successor list is empty are left out of
   the block list passed to [CFG.make].
   NOTE(review): presumably [CFG.make] recreates the exit node itself;
   confirm. *)
let cfg_update_preheaders cfg hoist_targets =
  let with_hoisted name block =
    List.fold_right
      (fun (_, _, preheader, defs) code ->
         if CFG.NodeSet.mem name preheader
         then append_defs (deflist_of_dset defs) code
         else code)
      hoist_targets
      block
  in
  let rebuild ~name ~targets block acc =
    let block = with_hoisted name block in
    match targets with
    | [] -> acc
    | _ -> block :: acc
  in
  CFG.make (List.rev (CFG.bblock_fold rebuild [] cfg))

(* [cfg_delete cfg hoist_targets] rebuilds [cfg] without the statements at
   the (node, index) positions of the hoisted definitions.  Positions
   refer to each node's code as passed in.  As in [cfg_update_preheaders],
   nodes with an empty successor list are left out of the list given to
   [CFG.make]. *)
let cfg_delete cfg hoist_targets =
  let doomed =
    List.fold_left
      (fun acc (_, _, _, defs) ->
         Dset.fold
           (fun d acc -> Util.InstrSet.add (block d, instr_num d) acc)
           defs acc)
      Util.InstrSet.empty
      hoist_targets
  in
  let rec prune name inum = function
    | [] -> []
    | stmt :: rest ->
        if Util.InstrSet.mem (name, inum) doomed
        then prune name (inum + 1) rest
        else stmt :: prune name (inum + 1) rest
  in
  let rebuild ~name ~targets block acc =
    match targets with
    | [] -> acc
    | _ -> prune name 0 block :: acc
  in
  CFG.make (List.rev (CFG.bblock_fold rebuild [] cfg))
	  
(* [hoist cfg rdefs liveout] runs loop hoisting on [cfg].  It computes the
   hoist targets, inserts the hoisted definitions into the preheaders, then
   deletes the originals, and returns the resulting CFG. *)
let hoist cfg rdefs liveout =
  let targets = hoist_targets cfg rdefs liveout in
  let with_preheaders = cfg_update_preheaders cfg targets in
  cfg_delete with_preheaders targets
