(* A cfg node is identified by the label that begins its basic block
   (see [block_name]). *)
type node = Temp.label

(* Printable form of a node; simply delegates to [Temp.label2string]. *)
let string_of_node = Temp.label2string

(* Ordered collections keyed by cfg nodes and by pairs of nodes.  All four
   order their keys with the polymorphic [compare], specialised to the key
   type.  NOTE(review): an earlier comment stated that labels are ints, which
   would make this cheap; confirm against the definition of [Temp.label]. *)
module NodeMap = Map.Make (struct
  type t = node
  let compare (a : t) (b : t) = compare a b
end)

module NodeSet = Set.Make (struct
  type t = node
  let compare (a : t) (b : t) = compare a b
end)

module NodePairMap = Map.Make (struct
  type t = node * node
  let compare (a : t) (b : t) = compare a b
end)

module NodePairSet = Set.Make (struct
  type t = node * node
  let compare (a : t) (b : t) = compare a b
end)

(* A control-flow graph whose nodes are basic blocks, each named by the label
   that starts it. *)
type t = {
  (* All basic blocks in order; [make] appends a synthetic return block. *)
  bblocks : Ir.stmt list list;
  (* node -> nodes that jump to it *)
  preds : node list NodeMap.t;
  (* node -> jump targets of its block *)
  succs : node list NodeMap.t;
  (* Every node that names a block. *)
  nodes : NodeSet.t;
  (* node -> statements of its block *)
  code : Ir.stmt list NodeMap.t;
  (* Name of the first block. *)
  entry : node;
}
           
(* Infix helpers for building and querying [NodeMap]s.
   NB: [( <= )] and [( @@ )] shadow the Stdlib comparison and application
   operators for the rest of this file (and for any module that opens it).
   [( %% )] has multiplicative precedence, which binds tighter than [( <= )],
   so its right operand must be parenthesised: [m %% (k <= v)]. *)

(* [m %% (k, v)] is [m] with [k] bound to [v], replacing any old binding. *)
let ( %% ) m (k, v) = NodeMap.add k v m

(* [k <= v] just builds the pair [(k, v)] for use with [( %% )]. *)
let ( <= ) k v = (k, v)

(* [m @@ k] is the value bound to [k] in [m]; raises [Not_found] if none. *)
let ( @@ ) m k = NodeMap.find k m
(* let get map key default = try map @@ key with Not_found -> default *)

(* Name of a basic block: the label of its leading [Ir.LABEL] statement.
   Fails with [Assert_failure] when the block does not start with a label. *)
let block_name = function
  | Ir.LABEL lab :: _ -> lab
  | _ -> assert false (* basic block that doesn't begin with a label *)

(* Labels a basic block can transfer control to, read off its final
   statement: a JUMP yields its label, a CJUMP yields both labels (just one
   when they coincide).  While walking toward the end, a remaining suffix of
   two or more statements that starts with [Ir.LABEL] of the return label
   yields no targets; this is how the synthetic return block built by [make]
   becomes a sink.  Fails with [Assert_failure] on an empty block or one whose
   last statement is not a jump. -wjl *)
let rec block_targets = function
  | [] -> assert false (* no blocks *)
  | [ Ir.JUMP target ] -> [ target ]
  | [ Ir.CJUMP (_, _, _, on_true, on_false) ] ->
      if on_true = on_false then [ on_true ] else [ on_true; on_false ]
  | [ _ ] -> assert false (* basic block doesn't end with a jump *)
  | Ir.LABEL lab :: _ when lab = Temp.getReturnLabel () -> []
  | _ :: rest -> block_targets rest

(* [initial_nodemap' bblocks x] maps the name of every block in [bblocks] to
   the same value [x] (a later block with a duplicate name wins).  It works on
   a bare block list so that [make] can call it before the cfg record exists. *)
let initial_nodemap' bblocks x =
  List.fold_left
    (fun acc bb -> acc %% (block_name bb <= x))
    NodeMap.empty
    bblocks

(* Like [initial_nodemap'], taking the blocks from an existing cfg. *)
let initial_nodemap cfg x =
  let { bblocks; _ } = cfg in
  initial_nodemap' bblocks x

(* Build a cfg from a list of basic blocks.

   Each block must start with an [Ir.LABEL] (its node name) and end with a
   JUMP or CJUMP, otherwise [block_name] / [block_targets] fail with
   [Assert_failure].  A synthetic block [LABEL ret; JUMP ret] for the return
   label is appended; [block_targets] gives it no targets, so it acts as the
   exit node.  The entry is the name of the first block; for an empty input
   that is the synthetic return block itself.

   Each predecessor list holds the most recently processed predecessor first.
   Raises [Not_found] if a block jumps to a label that names no block. *)
let make bblocks =
  (* NB!!: fake basic block for the return label, appended at the end. -wjl *)
  let return_label = Temp.getReturnLabel () in
  let bblocks = bblocks @ [ [ Ir.LABEL return_label; Ir.JUMP return_label ] ] in
  (* Fold one block into the accumulated maps.  succs, code and nodes get
     exactly one entry per block here; preds only changes when some block
     jumps to a node, which is why it starts pre-populated with empty lists. *)
  let add_block (succs, preds, code, nodes) bb =
    let name = block_name bb in
    let targets = block_targets bb in
    let succs = succs %% (name <= targets) in
    let code = code %% (name <= bb) in
    let preds =
      List.fold_left
        (fun acc target -> acc %% (target <= (name :: (acc @@ target))))
        preds
        targets
    in
    (succs, preds, code, NodeSet.add name nodes)
  in
  let empty_preds = initial_nodemap' bblocks [] in
  let succs, preds, code, nodes =
    List.fold_left add_block
      (NodeMap.empty, empty_preds, NodeMap.empty, NodeSet.empty)
      bblocks
  in
  (* bblocks cannot be empty here (the return block was appended above), so
     the assert is unreachable. *)
  let entry =
    match bblocks with
    | first :: _ -> block_name first
    | [] -> assert false
  in
  { bblocks; preds; succs; nodes; code; entry }


(* Accessors.  [pred], [succ] and [code] raise [Not_found] for a node that
   does not name a block of [cfg]. *)
let pred cfg node = NodeMap.find node cfg.preds

let succ cfg node = NodeMap.find node cfg.succs

let code cfg node = NodeMap.find node cfg.code

let nodes cfg = cfg.nodes

let entry cfg = cfg.entry

(* The exit is always the return label; the cfg argument is ignored.
   NB: this shadows [Stdlib.exit] for code that follows or opens this file. *)
let exit _cfg = Temp.getReturnLabel ()

(* Names of all blocks, in block order. *)
let all_blocks cfg = List.map block_name cfg.bblocks

(* Fold [process_bblock ~name ~targets bb acc] over the blocks of [cfg] in
   order, where [name] and [targets] are computed by [block_name] and
   [block_targets] for each block [bb]. *)
let bblock_fold process_bblock initial cfg =
  let step acc bb =
    let name = block_name bb in
    let targets = block_targets bb in
    process_bblock ~name ~targets bb acc
  in
  List.fold_left step initial cfg.bblocks

(* Depth-first preorder of the nodes reachable from [entry cfg].  Each node
   appears once, ahead of every node first reached through it.  The
   successors of a node are explored from the last element of
   [succ cfg node] to the first. *)
let dfs cfg =
  let rec visit node ((seen, rev_order) as state) =
    if NodeSet.mem node seen then state
    else
      List.fold_left
        (fun st next -> visit next st)
        (NodeSet.add node seen, node :: rev_order)
        (List.rev (succ cfg node))
  in
  let _, rev_order = visit (entry cfg) (NodeSet.empty, []) in
  List.rev rev_order

(* Breadth-first order of the nodes reachable from [entry cfg].
   - The entry comes first, then every node at distance 1, then distance 2,
     and so on.
   - Within a level, nodes appear in the order their parents were dequeued
     and, for each parent, in [succ] order.
   - Each node appears exactly once.
   NB: an earlier version was a recursive DFS that appended with [@] and then
   reversed.  That returned the reverse of [dfs]'s preorder (in quadratic
   time), not a breadth-first order. *)
let bfs cfg =
  let start = entry cfg in
  let queue = Queue.create () in
  Queue.add start queue;
  (* [visited] holds every node ever enqueued, so each node is queued once. *)
  let rec loop visited order =
    if Queue.is_empty queue then List.rev order
    else begin
      let node = Queue.take queue in
      let visited =
        List.fold_left
          (fun visited next ->
            if NodeSet.mem next visited then visited
            else begin
              Queue.add next queue;
              NodeSet.add next visited
            end)
          visited
          (succ cfg node)
      in
      loop visited (node :: order)
    end
  in
  loop (NodeSet.singleton start) []

(* A node is its own label, so this is the identity. *)
let label n = n

(* Write [cfg] to [filename] as a Graphviz DOT digraph.
   - The output has one statement per node, then one edge per
     (node, successor) pair, with nodes visited in [NodeSet] order.
   - A node ID is its label string with the first character dropped, so
     [string_of_node] must not return an empty string ([String.sub] would
     raise Invalid_argument).
   - The formatter is flushed before the file is closed.
   - If printing raises (e.g. [Not_found] from [succ]), the file is still
     closed and the exception is re-raised. *)
let print filename cfg =
  let file = open_out filename in
  let out = Format.formatter_of_out_channel file in
  let print fmt = Format.fprintf out fmt in
  let n2s n =
    let s = string_of_node n in
    String.sub s 1 (String.length s - 1)
  in
  let write_graph () =
    let ns = nodes cfg in
    print "digraph D {\n";
    NodeSet.iter (fun node -> print "  %s;\n" (n2s node)) ns;
    print "\n";
    NodeSet.iter
      (fun start ->
        List.iter
          (fun finish ->
            let start = n2s start in
            let finish = n2s finish in
            print "  %s -> %s;\n" start finish)
          (succ cfg start))
      ns;
    print "}\n";
    (* Push any text still held by the pretty-printer into [file]; this also
       flushes [file].  A plain [flush file] does not reach the formatter. *)
    Format.pp_print_flush out ()
  in
  (try write_graph () with e -> close_out_noerr file; raise e);
  close_out file

