(* Node containers re-exported from [CFG] under short names. *)
module NodeSet = CFG.NodeSet
module NodeMap = CFG.NodeMap

(* Containers keyed by pairs of nodes; [NodePairMap] is used below to map
   DAG edges (source, destination) to their path-numbering values. *)
module NodePairSet = CFG.NodePairSet
module NodePairMap = CFG.NodePairMap

(* A DAG as an adjacency map: each node is bound to the set of its
   successors.  A node with no outgoing edge may simply be absent from the
   map (lookups go through [get], and [(-$)] drops empty bindings). *)
type dag = NodeSet.t NodeMap.t

(* Information about a back edge [source -> dest] of the CFG, built when the
   edge has no value in the DAG edge-value map:
   - [from_entry]: value of the fake DAG edge [entry -> dest];
   - [to_exit]: value of the fake DAG edge [source -> exit]. *)
type backval = {
  source : NodeSet.elt;
  dest : NodeSet.elt;
  from_entry : int;
  to_exit : int
}

(* Raised by the edge-value lookup in [add_profiles] when an edge turns out
   to be a back edge; carries what is needed to instrument that edge. *)
exception Backedge of backval



(* DAG building *)

(* Successor set of [node] in [dag]; a node with no binding has none. *)
let get dag node =
  let found =
    try Some (NodeMap.find node dag) with Not_found -> None
  in
  match found with
  | Some successors -> successors
  | None -> NodeSet.empty

(* [dag +$ (start, finish)] is [dag] extended with the edge
   [start -> finish]. *)
let (+$) dag (start, finish) =
  let successors = get dag start in
  let successors = NodeSet.add finish successors in
  NodeMap.add start successors dag

(* [dag -$ (start, finish)] is [dag] without the edge [start -> finish].
   When [start] is left without successors its binding is removed, so that
   an edgeless DAG is [NodeMap.is_empty]. *)
let (-$) dag (start, finish) =
  let remaining = NodeSet.remove finish (get dag start) in
  if not (NodeSet.is_empty remaining) then NodeMap.add start remaining dag
  else NodeMap.remove start dag

(* Adjacency map holding every edge of [cfg]: each node is mapped to the set
   of its CFG successors. *)
let all_edges cfg =
  let add_out_edges dag start =
    List.fold_left
      (fun acc finish -> acc +$ (start, finish))
      dag
      (CFG.succ cfg start)
  in
  NodeSet.fold
    (fun start dag -> add_out_edges dag start)
    (CFG.nodes cfg)
    NodeMap.empty
	     
(* Make [dag] (initially all CFG edges) acyclic, as in Ball-Larus path
   profiling.  Every back edge [b -> a] reported by [Dominators.backedges]
   is removed and replaced by two fake edges, [entry -> a] and [b -> exit],
   where [entry] and [exit] are the CFG's entry and exit nodes. *)
let remove_backedges cfg dag =
  let dominators = Dominators.dominators cfg in
  let backedges = Dominators.backedges dominators cfg in
  let entry = CFG.entry cfg in
  let exit = CFG.exit cfg in
  let without_backedges = List.fold_left (-$) dag backedges in
  (* The fake edges let paths that cross a back edge still be numbered:
     one path ends at [b], a new one starts at [a]. *)
  List.fold_left
    (fun acc (b, a) -> (acc +$ (entry, a)) +$ (b, exit))
    without_backedges
    backedges
    

(* The acyclic graph used for path numbering: all edges of [cfg], with back
   edges replaced by fake entry/exit edges. *)
let makedag cfg =
  let every_edge = all_edges cfg in
  remove_backedges cfg every_edge

(*
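 (* Obsolete earlier version of this code, kept for reference only; it no longer compiles (the first [makedag] below refers to the unbound names [start] and [nodes]). *)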
let reachable_from dag start target =
  let rec loop node visited =
    if NodeSet.mem node visited then
      visited
    else
      let next = get dag node in
      if NodeSet.mem target next then
	failwith "Found it!"
      else
	NodeSet.fold loop next (NodeSet.add node visited)
  in
  try
    let _ = loop start NodeSet.empty in false
  with
    Failure "Found it!" -> true

let makedag cfg =
  List.fold_right
    (fun n dag ->
       (List.fold_right
	  (fun finish dag -> dag +$ (start, finish))
	  (CFG.succ cfg start)
	  dag))
    nodes
    NodeMap.empty
	     
let remove_backedges cfg dag =
  let dominators = Dominators.dominators cfg in
  let backedges = Dominators.backedges dominators cfg in
      let n2s = CFG.string_of_node in
      print_string "DEBUG: backedges: ";
      List.iter (fun (m, n) -> print_string ("(" ^ n2s m ^ ", " ^ n2s n ^ ") "))
                backedges;
      print_newline ();
  List.fold_left (-$) dag backedges

let makedag cfg =
  remove_backedges cfg (all_edges cfg)
*)

(* Compute update values *)

(* Every edge of [dag] as a list of (start, finish) pairs. *)
let edges dag =
  let add_from start successors acc =
    NodeSet.fold (fun finish acc -> (start, finish) :: acc) successors acc
  in
  NodeMap.fold add_from dag []

(* [any_incoming dag m] is true when some edge of [dag] ends at [m].
   The adjacency sets are scanned in place rather than materialising the
   whole edge list on every call ([rev_toposort] calls this once per removed
   edge, so building the list each time is quadratic work). *)
let any_incoming dag m =
  NodeMap.fold
    (fun _ successors found ->
       found || NodeSet.exists (fun m' -> m = m') successors)
    dag
    false

(* Queue-based topological sort of [dag] starting from [start_nodes].
   A node is emitted when dequeued; each of its outgoing edges is then
   deleted, and a successor left with no incoming edge is enqueued.  The
   result lists the emitted nodes in reverse order (last emitted first).
   Raises [Failure] if the queue runs dry while edges remain (a cycle). *)
let rev_toposort dag start_nodes =
  let remaining = ref dag in
  let order = ref [] in
  let ready = Queue.create () in
  NodeSet.iter (fun n -> Queue.add n ready) start_nodes;
  let finished () = NodeMap.is_empty !remaining && Queue.is_empty ready in
  let release n m =
    remaining := !remaining -$ (n, m);
    if not (any_incoming !remaining m) then Queue.add m ready
  in
  while not (finished ()) do
    if Queue.is_empty ready then
      failwith "toposort: all remaining edges part of a cycle";
    let n = Queue.take ready in
    (* print_string ("taken: " ^ CFG.string_of_node n ^ "\n"); *) (* DEBUG *)
    order := n :: !order;
    NodeSet.iter (release n) (get !remaining n)
  done;
  !order

(* Nodes of [cfg] that have no CFG predecessor. *)
let start_nodes cfg =
  let has_no_pred n =
    match CFG.pred cfg n with
    | [] -> true
    | _ :: _ -> false
  in
  NodeSet.filter has_no_pred (CFG.nodes cfg)

(* Infix helpers for [edge_values].  NOTE(review): [@@] and [<=] shadow the
   Stdlib application and comparison operators for the rest of this file;
   below this point [a <= b] builds the pair [(a, b)]. *)
let (@@) = fun m k -> NodeMap.find k m          (* may raise Not_found *)
let (%%) = fun m (k, v) -> NodeMap.add k v m
let (%%%) = fun m (k, v) -> NodePairMap.add k v m
let (<=) = fun k v -> (k, v)

(* Ball-Larus edge values for [dag].
   [nodes] is expected to list the nodes so that every successor of a node
   comes before it (the order returned by [rev_toposort]).  For each node v:
   NumPaths(v) is 1 if v has no successor, otherwise the sum of NumPaths(w)
   over its successors w; the edge (v, w) gets as value the running sum of
   NumPaths over the successors of v visited before w.
   Returns the map from edges (v, w) to their values.
   Raises [Not_found] if a successor was not visited before its predecessor. *)
let edge_values dag nodes =
  let visit (numpaths, vals) v =
    let successors = get dag v in
    if NodeSet.is_empty successors then
      (NodeMap.add v 1 numpaths, vals)
    else
      NodeSet.fold
        (fun w (numpaths, vals) ->
           let paths_so_far = NodeMap.find v numpaths in
           let vals = NodePairMap.add (v, w) paths_so_far vals in
           let numpaths =
             NodeMap.add v (paths_so_far + NodeMap.find w numpaths) numpaths
           in
           (numpaths, vals))
        successors
        (NodeMap.add v 0 numpaths, vals)
  in
  snd (List.fold_left visit (NodeMap.empty, NodePairMap.empty) nodes)

(* Code fixups *)

(* Trampoline block: add the constant [inc] to temp [counter], then jump to
   [dest].  Returns the fresh label of the block together with its code. *)
let increment counter inc dest =
  let label = Temp.newLabel () in
  let amount = Ir.CONST (Int32.of_int inc, Pcc.Bogus) in
  let bump =
    Ir.MOVE (Ir.TEMP counter, Ir.BINOP (Ir.PLUS, Ir.TEMP counter, amount))
  in
  (label, [ Ir.LABEL label; bump; Ir.JUMP dest ])

(* Same conditional jump as [instr] (operator and operands kept) but with
   targets [first] and [second].  [instr] must be a CJUMP. *)
let change_cjump instr first second =
  match instr with
  | Ir.CJUMP (relop, lhs, rhs, _, _) -> Ir.CJUMP (relop, lhs, rhs, first, second)
  | _ -> assert false

(* Last element of a list.  Raises [Failure] on the empty list. *)
let rec last = function
  | [] -> failwith "last: empty list"
  | [ only ] -> only
  | _ :: rest -> last rest

(* [list] with its last element replaced by [y].  Raises [Failure] on the
   empty list. *)
let replace_last list y =
  match List.rev list with
  | [] -> failwith "replace_last: empty list"
  | _ :: rev_prefix -> List.rev (y :: rev_prefix)

(* Labels the final instruction of [block] may jump to: one for a JUMP or
   for a CJUMP whose two targets coincide, two (in CJUMP order) otherwise.
   The block must end in a JUMP or a CJUMP. *)
let jumpdest block =
  match last block with
  | Ir.JUMP target -> [target]
  | Ir.CJUMP (_, _, _, t1, t2) when t1 = t2 -> [t1]
  | Ir.CJUMP (_, _, _, t1, t2) -> [t1; t2]
  | _ -> assert false

(* TODO: add a "string" for the function name, ... *)
let counter_init_on_entry cfg name instrs var =
  (* Only the entry block of [cfg] is touched: [var := 0] is inserted right
     after its first instruction (presumably the block label — TODO
     confirm).  An empty entry block is a programming error. *)
  if name <> CFG.entry cfg then instrs
  else
    match instrs with
    | [] -> assert false
    | header :: rest ->
        let reset = Ir.MOVE (Ir.TEMP var, Ir.CONST (Int32.zero, Pcc.Bogus)) in
        header :: reset :: rest

(* [l1] with the elements of [l2] inserted just before its last element.
   [l1] must not be empty. *)
let append_but_last l1 l2 =
  match List.rev l1 with
  | [] -> assert false
  | final :: rev_prefix -> List.rev_append rev_prefix (l2 @ [ final ])

(* ... and print it *)
let print_path_on_exit cfg name var instrs =
  (* For a block that is a predecessor of the exit node of [cfg], insert
     calls to print_int on [var] and to print_newline just before the
     block's final instruction; other blocks are returned unchanged. *)
  let feeds_exit = List.mem name (CFG.pred cfg (CFG.exit cfg)) in
  if not feeds_exit then instrs
  else
    match instrs with
    | [] -> assert false
    | _ :: _ ->
        let report =
          [Ir.EXP (Frame.X86_FRAME.externalCall "print_int" [Ir.TEMP var]);
           Ir.EXP (Frame.X86_FRAME.externalCall "print_newline" [])]
        in
        append_but_last instrs report

(* Trampoline for the back edge [data.source -> data.dest]: add
   [data.to_exit] to [var], print [var] and a newline, reset [var] to
   [data.from_entry], then jump to the label of [data.dest].
   Returns the fresh label of the block together with its code. *)
let increment_backedge var data =
  let num k = Ir.CONST (Int32.of_int k, Pcc.Bogus) in
  let counter = Ir.TEMP var in
  let call = Frame.X86_FRAME.externalCall in
  let label = Temp.newLabel () in
  let code =
    [ Ir.LABEL label;
      Ir.MOVE (counter, Ir.BINOP (Ir.PLUS, counter, num data.to_exit));
      Ir.EXP (call "print_int" [counter]);
      Ir.EXP (call "print_newline" []);
      Ir.MOVE (counter, num data.from_entry);
      Ir.JUMP (CFG.label data.dest) ]
  in
  (label, code)

(* Retarget the final jump of [instrs]: an unconditional JUMP always goes to
   [new_dest]; in a CJUMP, each target equal to [old_dest] is replaced by
   [new_dest].  The block must end in a JUMP or CJUMP. *)
let fix_backjump old_dest new_dest instrs =
  let retarget target = if target = old_dest then new_dest else target in
  let jump =
    match last instrs with
    | Ir.CJUMP (op, e1, e2, t1, t2) -> Ir.CJUMP (op, e1, e2, retarget t1, retarget t2)
    | Ir.JUMP _ -> Ir.JUMP new_dest
    | _ -> assert false
  in
  replace_last instrs jump

(* Rebuild [cfg], with every block passed through [print_path_on_exit] so
   that the counter [var] is printed on the way to the exit node. *)
let add_output var cfg =
  let instrument_block ~name ~targets:_ instrs acc =
    print_path_on_exit cfg name var instrs :: acc
  in
  let rev_blocks = CFG.bblock_fold instrument_block [] cfg in
  CFG.make (List.rev rev_blocks)

(* Instrument [cfg] for path profiling using the edge values [vals]
   (as computed by [edge_values]).  A fresh temp holds the path counter:
   - the entry block sets it to 0 ([counter_init_on_entry]);
   - each outgoing edge (u, v) that has a value in [vals] is redirected
     through a trampoline adding that value ([increment]);
   - any other edge is handled as a back edge, through a trampoline built by
     [increment_backedge] from the values of the fake edges [entry -> v]
     and [u -> exit].
   Edges are instrumented independently, so in a two-way branch with one
   back-edge arm the other arm still gets its increment.  Trampolines are
   appended after the original blocks, and [add_output] then adds printing
   of the counter on the way to the exit.
   Assumes the order of [targets] matches the labels returned by [jumpdest]
   — TODO confirm against CFG.bblock_fold. *)
let add_profiles cfg vals =
  let entry = CFG.entry cfg in
  let exit = CFG.exit cfg in
  let valof (u, v) =
    try
      (* NodePairMap.find will raise Not_found on back-edges *)
      NodePairMap.find (u, v) vals
    with
    | Not_found ->
        raise (Backedge
                 { source = u;
                   dest = v;
                   from_entry = NodePairMap.find (entry, v) vals;
                   to_exit = NodePairMap.find (u, exit) vals })
  in
  let var = Temp.simpTemp Pcc.Bogus in
  (* Trampoline (label, code) for the edge [src -> node], whose jump target
     in the source block is the label [dest]. *)
  let edge_fix src node dest =
    try
      let inc = valof (src, node) in
      increment var inc dest
    with Backedge r -> increment_backedge var r
  in
  let instrument_block ~name ~targets instrs (rev_code, fixups) =
    let instrs = counter_init_on_entry cfg name instrs var in
    match targets with
    | [d_node] ->
        (match jumpdest instrs with
         | [d] ->
             let (label, fix) = edge_fix name d_node d in
             (replace_last instrs (Ir.JUMP label) :: rev_code, fix :: fixups)
         | _ -> assert false)
    | [d_node; d_node'] ->
        (match jumpdest instrs with
         | [d; d'] ->
             let (l, fix) = edge_fix name d_node d in
             let (l', fix') = edge_fix name d_node' d' in
             let b = replace_last instrs (change_cjump (last instrs) l l') in
             (b :: rev_code, fix :: fix' :: fixups)
         | _ -> assert false)
    | [] when name = exit -> (rev_code, fixups)
    | _ -> assert false
  in
  let (rev_code, fixups) = CFG.bblock_fold instrument_block ([], []) cfg in
  add_output var (CFG.make ((List.rev rev_code) @ fixups))



(*****************************************************************)


(* Debug helper: write [dag] to [filename] as a Graphviz "dot" digraph.
   Every node of [cfg] is declared, then every DAG edge leaving those nodes
   is listed.  Node names are [CFG.string_of_node] minus its first
   character.  The file is closed even if writing raises; the exception is
   then re-raised. *)
let print filename cfg dag =
  let file = open_out filename in
  let out = Format.formatter_of_out_channel file in
  let print fmt = Format.fprintf out fmt in
  let ns = CFG.nodes cfg in
  let n2s n =
    let s = CFG.string_of_node n in
    String.sub s 1 (String.length s - 1)
  in
  (try
     print "digraph DAG {\n";
     NodeSet.iter (fun n -> print "  %s;\n" (n2s n)) ns;
     print "\n";
     NodeSet.iter
       (fun start ->
          NodeSet.iter
            (fun finish ->
               let start = n2s start in
               let finish = n2s finish in
               print "  %s -> %s;\n" start finish)
            (get dag start))
       ns;
     print "}\n";
     (* The formatter keeps its own buffer: flush it into [file] before the
        channel is closed, or the tail of the graph may be lost. *)
     Format.pp_print_flush out ()
   with e ->
     close_out_noerr file;
     raise e);
  close_out file
	 
  


(* Instrument [cfg] for path profiling and return the resulting CFG.
   [fname] (presumably the function's name) is only used by the disabled
   DOT dump below. *)
let instrument fname cfg =
  let dag = makedag cfg in
  (* let () = print (fname ^ "-dag.dot") cfg dag in *) (* DEBUG *)
  let order = rev_toposort dag (start_nodes cfg) in
  add_profiles cfg (edge_values dag order)

