(* canon.ml *)
(* 15-411 *)
(* by Roland Flury *)
(* @version $Id: canon.ml,v 1.3 2003/09/17 18:54:07 rflury Exp $ *)

module T=Ir
module TP = Temp
module E = Errormsg

(* Raised by the passes in this file when they meet input they cannot
   handle (e.g. an empty statement list while building a basic block);
   the string is a human-readable description of the problem. *)
exception CanonError of string


(* **************************************************************** *)
(* Tripleize *)
(* **************************************************************** *)

(* A binary tree of pending items that can be concatenated in constant
   time: [Zero] holds nothing, [One x] holds exactly [x], and
   [Join(l, r)] holds everything in [l] followed by everything in [r]. *)
type 'a join =
  | Zero
  | One of 'a
  | Join of 'a join * 'a join

(* [joinlist [j1; ...; jn]] builds the right-nested tree
   [Join(j1, Join(..., Join(jn, Zero)))]. *)
let joinlist list =
  List.fold_left (fun rest j -> Join(j, rest)) Zero (List.rev list)

(* [list_of_join j] flattens [j] into a list, left to right. *)
let list_of_join joinlist =
  (* [pending] is a work-list of sub-trees still to be flattened; the
     sub-tree whose items come last in the result is always at the head,
     so items can simply be consed onto [flat]. *)
  let rec drain pending flat =
    match pending with
    | [] -> flat
    | Zero :: rest -> drain rest flat
    | One x :: rest -> drain rest (x :: flat)
    | Join(l, r) :: rest -> drain (r :: l :: rest) flat
  in
    drain [joinlist] []

(* Allocates a fresh temp via TP.simpTemp, tagged with Pcc.Bogus. *)
let new_temp () =
  TP.simpTemp Pcc.Bogus

(* Sanity checks for joinlist / list_of_join; they are evaluated once,
   when this module is initialised. *)
let () =
  let empty = Zero in
  let three = One 3 in
  let five = One 5 in
  let pair = Join(three, five) in
  let triple = Join(five, pair) in
  let doubled = Join(triple, triple) in
  assert (list_of_join empty = []);
  assert (list_of_join three = [3]);
  assert (list_of_join five = [5]);
  assert (list_of_join pair = [3; 5]);
  assert (list_of_join triple = [5; 3; 5]);
  assert (list_of_join doubled = [5; 3; 5; 5; 3; 5]);
  assert (list_of_join (joinlist [pair; pair; pair]) = [3;5;3;5;3;5])
  (* Printf.printf "Joinlist tests succeeded!\n" *)

(* [store_exp e] allocates a fresh temp and returns it together with the
   MOVE statement that stores [e] into that temp. *)
let store_exp e =
  let dest = new_temp () in
  let move = T.MOVE(T.TEMP dest, e) in
    (dest, move)

(* [triple_exp' e] rewrites [e] so that the result of every BINOP, MEM,
   CALL and ALLOCA is first stored into a fresh temp.  It returns the
   replacement expression together with the join of statements that must
   run before that expression is used.  CONST, NAME and TEMP are returned
   unchanged with no statements.  ESEQ and PHI are not handled: the
   assertion fails if one is encountered. *)
let rec triple_exp' e =
  match e with
  | T.CONST _ | T.NAME _ | T.TEMP _ ->
      (e, Zero)
  | T.BINOP(op, lhs, rhs) ->
      let (lhs', lhs_pre) = triple_exp' lhs in
      let (rhs', rhs_pre) = triple_exp' rhs in
      let (dst, mv) = store_exp (T.BINOP(op, lhs', rhs')) in
      (T.TEMP dst, Join(Join(lhs_pre, rhs_pre), One mv))
  | T.MEM addr ->
      let (addr', addr_pre) = triple_exp' addr in
      let (dst, mv) = store_exp (T.MEM addr') in
      (T.TEMP dst, Join(addr_pre, One mv))
  | T.CALL(f, args) ->
      (* simplify every argument, then gather their statements in order *)
      let simplified = List.map triple_exp' args in
      let (args', arg_pres) = List.split simplified in
      let args_pre = joinlist arg_pres in
      let (dst, mv) = store_exp (T.CALL(f, args')) in
      (T.TEMP dst, Join(args_pre, One mv))
  | T.ALLOCA size ->
      let (size', size_pre) = triple_exp' size in
      let (dst, mv) = store_exp (T.ALLOCA size') in
      (T.TEMP dst, Join(size_pre, One mv))
  | T.ESEQ _ | T.PHI _ ->
      assert false

(* [triple_stmt' s] applies triple_exp' to the operands of [s] and returns
   the join of the generated statements followed by the rewritten [s].
   LABEL, JUMP, COMMENT and INVARIANT are passed through as they are.
   SEQ is not handled: the assertion fails if one is encountered. *)
and triple_stmt' s =
  match s with
  | T.MOVE(T.MEM addr, src) ->
      let (addr', addr_pre) = triple_exp' addr in
      let (src', src_pre) = triple_exp' src in
      let store = T.MOVE(T.MEM addr', src') in
      (* statements of the source run before those of the destination *)
      Join(src_pre, Join(addr_pre, One store))
  | T.MOVE(dst, src) ->
      let (dst', dst_pre) = triple_exp' dst in
      let (src', src_pre) = triple_exp' src in
      let move = T.MOVE(dst', src') in
      (* statements of the source run before those of the destination *)
      Join(src_pre, Join(dst_pre, One move))
  | T.EXP e ->
      let (e', pre) = triple_exp' e in
      Join(pre, One (T.EXP e'))
  | T.CJUMP(relop, lhs, rhs, l1, l2) ->
      (* NOTE(review): unlike MOVE, the operands' statements are emitted
         left-to-right here; the original author flagged this as a
         possible order-of-evaluation bug. *)
      let (lhs', lhs_pre) = triple_exp' lhs in
      let (rhs', rhs_pre) = triple_exp' rhs in
      let jump = T.CJUMP(relop, lhs', rhs', l1, l2) in
      Join(lhs_pre, Join(rhs_pre, One jump))
  | T.LABEL _ | T.JUMP _ | T.COMMENT _ | T.INVARIANT _ ->
      One s
  | T.SEQ _ ->
      assert false
		 
(* Like triple_exp', but returns the preceding statements as a plain list. *)
let triple_exp e =
  let (simple, pending) = triple_exp' e in
  let prelude = list_of_join pending in
    (simple, prelude)

(* Like triple_stmt', but returns the statements as a plain list. *)
let triple_stmt s =
  let joined = triple_stmt' s in
    list_of_join joined


(**************************************************************************)
(* Linearize *)
(**************************************************************************)

(* Infix operator %: sequences two statements as T.SEQ(x, y), except that
   a statement of the form EXP(CONST ...) is treated as a no-op and
   dropped, in which case the other statement is returned on its own.
   If both are no-ops, [y] is returned. *)
let (%) x y =
  match x with
  | T.EXP(T.CONST _) -> y
  | _ ->
      (match y with
       | T.EXP(T.CONST _) -> x
       | _ -> T.SEQ(x, y))

(* [commute stmt exp] is a conservative test of whether [stmt] and [exp]
   may be evaluated in either order: true when [stmt] is EXP of a CONST or
   NAME, or when [exp] is itself a NAME or CONST; false otherwise. *)
let commute stmt exp =
  let stmt_is_trivial =
    match stmt with
    | T.EXP(T.CONST _) | T.EXP(T.NAME _) -> true
    | _ -> false
  in
  let exp_is_constant =
    match exp with
    | T.NAME _ | T.CONST _ -> true
    | _ -> false
  in
  stmt_is_trivial || exp_is_constant


(* [hd li] returns the first element of [li].
   Raises CanonError "head of an empty list" if [li] is empty.
   Implemented by pattern matching rather than [try List.hd with ...]:
   a handler wrapped around the unapplied function only guards the
   evaluation of the function value, so it would never see the failure
   raised when the function is later applied to an empty list. *)
let hd = function
  | x :: _ -> x
  | [] -> raise (CanonError "head of an empty list")
(* [tl li] returns [li] without its first element.
   Raises CanonError "tail of an empty list" if [li] is empty.
   Implemented by pattern matching rather than [try List.tl with ...]:
   a handler wrapped around the unapplied function only guards the
   evaluation of the function value, so it would never see the failure
   raised when the function is later applied to an empty list. *)
let tl = function
  | _ :: rest -> rest
  | [] -> raise (CanonError "tail of an empty list")
    
(* [do_stmt stmt] rewrites [stmt] so that the statements hidden inside its
   sub-expressions are pulled out in front of it.  The result is a T.SEQ
   of those pulled-out statements and the rebuilt statement (see
   reorder_stmt).  LABEL, JUMP, COMMENT and INVARIANT are rebuilt
   unchanged; SEQ is processed component-wise. *)
let rec do_stmt stmt = 
  match stmt with
  | T.MOVE(T.TEMP(t), T.CALL(f, al)) -> 
      (* the call stays as the source of the move; only its arguments
         are reordered *)
      reorder_stmt al (fun al -> T.MOVE(T.TEMP(t), T.CALL(f, al)))
  | T.MOVE(T.TEMP(t), T.ALLOCA e) ->
      (* same treatment as CALL: the ALLOCA stays in place *)
      reorder_stmt [e] (fun al -> T.MOVE(T.TEMP(t), T.ALLOCA (hd al)))
  | T.MOVE(T.TEMP(t), b) -> 
      reorder_stmt [b] (fun li -> T.MOVE(T.TEMP(t), hd li))
  | T.MOVE(T.MEM(a), b) -> 
      (* watch out! eval first rhs, then lhs *)
      (* the list is [b; a], so [hd li] is the source and [hd (tl li)]
         the address *)
      reorder_stmt [b;a] (fun li -> T.MOVE(T.MEM(hd (tl li)), hd li))
  | T.MOVE(a,b) -> 
      (* watch out! eval first rhs, then lhs *)
      reorder_stmt [b;a] (fun li -> T.MOVE(hd (tl li), hd li))
  | T.EXP(T.CALL(f,al)) -> 
      (* a call evaluated only for effect stays in place; only its
         arguments are reordered *)
      reorder_stmt al (fun al -> T.EXP(T.CALL(f,al)))
  | T.EXP(e) -> 
      reorder_stmt [e] (fun li -> T.EXP(hd li))
  | T.LABEL(l) -> 
      T.LABEL(l)
  | T.JUMP(l) -> 
      T.JUMP(l)
  | T.CJUMP(op, a, b, t, f) -> 
      reorder_stmt [a;b] (fun li -> T.CJUMP(op, hd li, hd (tl li), t, f))
  | T.SEQ(a,b) -> 
      T.SEQ(do_stmt a, do_stmt b)
  | T.COMMENT(s,l) -> 
      T.COMMENT(s,l)
  | T.INVARIANT(li) -> 
      T.INVARIANT(li)

(* [do_exp exp] returns a pair (stmt, exp'): [stmt] is to be executed
   first, after which [exp'] stands for the value of [exp].  For CONST,
   NAME, TEMP and PHI the statement is T.nop and the expression is
   returned as is. *)
and do_exp exp = 
  match exp with
  | T.CONST(_) | T.NAME(_) | T.TEMP(_) | T.PHI(_) -> 
      (T.nop, exp)
  | T.BINOP(op, a, b) -> 
      reorder_exp [a;b] (fun li -> T.BINOP(op, hd li, hd (tl li)))
  | T.MEM(e) -> 
      reorder_exp [e] (fun li -> T.MEM(hd li))
  | T.ESEQ(s, e) -> 
      (* the statement part of the ESEQ goes in front of whatever had to
         be pulled out of [e] *)
      let (pre, post) = reorder_exp [e] (fun li -> hd li) in 
      (T.SEQ(do_stmt s, pre), post)
  | T.CALL(l, args) -> 
      (* the call result is moved into a fresh temp, and that temp
         replaces the call expression *)
      let tmp = TP.simpTemp Pcc.Bogus in (* PCC FIXME *)
      (do_stmt (T.MOVE(T.TEMP(tmp), 
		       T.CALL(l,args))), 
       T.TEMP(tmp))
  | T.ALLOCA e ->
      (* handled like CALL: result goes into a fresh temp *)
      let tmp = TP.simpTemp Pcc.Bogus in
      (do_stmt (T.MOVE(T.TEMP(tmp), T.ALLOCA e)),
       T.TEMP(tmp))

(* reorders a list of expressions *)
(* [reorder eli] applies do_exp to every expression of [eli] and returns
   (pre, exps): [pre] is the statement to execute first and [exps] are
   the replacement expressions, in the same order as [eli]. *)
and reorder eli = 
  (* check whether a stmt commutes with all exp in a list *)
  let comm li stmt = 
    List.fold_left (fun r (s,e) -> r && (commute stmt e)) true li
  in

  (* Iterates over li (which is the reversed eli), i.e. from the last
   * expression of eli towards the first, and stops at the first
   * expression whose statement does not commute with all the
   * expressions preceding it in eli; those preceding expressions then
   * get assigned to new temps.
   * pre: statement accumulated so far; res: expressions accumulated so
   * far, in eli order. *)
  let rec iter pre res li = 
    match li with
    | [] -> (pre, res)
    | (s,e) :: tail -> 
	if(comm tail s) then
	  iter (s % pre) (e :: res) tail
	else
	  (* does not commute, -> insert moves for all preceding exps *)
	  List.fold_left (fun (pre,res) (s,e) -> 
	    match (s,e) with
	    | ((T.EXP(T.CONST(_))|T.EXP(T.NAME(_))), (T.NAME(_)|T.CONST(_))) -> 
		(* trivial statement and constant expression: no temp needed *)
		(s % pre, e :: res)
	    | _ -> 
		let tmp = TP.simpTemp Pcc.Int in (* PCC FIXME *)
		(T.SEQ(s, T.MOVE(T.TEMP(tmp), e)) % pre, 
		 T.TEMP(tmp) :: res)
			 ) ((s % pre), (e :: res)) tail
  in
  
  (* linearize all sub-expressions *)
  let li = List.map (fun e -> do_exp e) eli in 
  let revli = List.rev li in (* revert list *)
  
  iter T.nop [] revli

(* [reorder_stmt eli f] reorders [eli] and returns the pulled-out
   statement followed by the statement [f] builds from the replacement
   expressions. *)
and reorder_stmt eli f = 
  let (pre, li) = reorder eli in
  T.SEQ(pre, f li)

(* [reorder_exp eli f] reorders [eli] and returns the pulled-out statement
   paired with the expression [f] builds from the replacement
   expressions. *)
and reorder_exp eli f =
  let (pre, li) = reorder eli in
  (pre, f li)


(* linearizes a program produced by translate *)
(* [linearize irList] takes a list of (function name, statement) pairs and
   returns a list of (function name, statement list) pairs: each statement
   is rewritten by do_stmt and its SEQ tree is then flattened, dropping
   statements of the form EXP(CONST ...).
   Each function is processed between TP.openFunLookUp and
   TP.saveFunLookUp.  If processing raises, saveFunLookUp is still called
   before the exception is re-raised (same policy as with_fname). *)
let linearize irList = 
  (* Flattens a SEQ tree onto [acc], preserving left-to-right order.
     An accumulator is used instead of list append so that left-nested
     SEQ trees are not flattened in quadratic time. *)
  let rec flatten stmt acc = 
    match stmt with
    | T.SEQ(a, b) -> flatten a (flatten b acc)
    | T.EXP(T.CONST(_)) -> acc
    | x -> x :: acc
  in
    List.map
      (fun (fName, stmt) -> 
	 TP.openFunLookUp fName;
	 let lin = 
	   try
	     flatten (do_stmt stmt) []
	   with
	   | error -> (TP.saveFunLookUp fName; raise error)
	 in
	   TP.saveFunLookUp fName;
	   (fName, lin))
      irList

(* [with_fname fname thunk] calls TP.openFunLookUp fname, runs [thunk ()]
   and then TP.saveFunLookUp fname, returning the thunk's result.  If an
   exception escapes from that sequence, TP.saveFunLookUp fname is called
   and the exception is re-raised. *)
let with_fname fname thunk =
  TP.openFunLookUp fname;
  try
    let result = thunk () in
    TP.saveFunLookUp fname;
    result
  with error ->
    TP.saveFunLookUp fname;
    raise error

(* [triple_function (fname, stmts)] applies triple_stmt to every statement
   of the function, inside with_fname, and returns the function name with
   the concatenated statement list. *)
let triple_function (fname, stmts) =
  let run () =
    let per_stmt = List.map triple_stmt stmts in
    (fname, List.concat per_stmt)
  in
  with_fname fname run

(* Applies triple_function to every (name, statements) pair of [list]. *)
let tripleize list =
  List.map (fun func -> triple_function func) list
  

(**************************************************************************)
(* Basic Blocks *)
(**************************************************************************)

(* [bBlocks stmts] cuts a statement list into basic blocks.  Every block
 * starts with a LABEL (a fresh one from TP.newLabel is inserted when the
 * statements do not provide it) and ends with a JUMP or CJUMP (a JUMP to
 * the next label is inserted when a block would run into a LABEL).
 * Each block is paired with a [ref false] marker that is used in tracing.
 * Raises CanonError if the statements end before a block is closed. *)
let bBlocks stmts = 
  (* Collects the rest of the current block.  [rev_block] holds the
   * statements gathered so far, newest first.  Returns the statements
   * left over after the block together with the finished block. *)
  let rec collect pending rev_block =
    match pending with
    | (T.JUMP _ as jump) :: rest -> 
	(rest, List.rev (jump :: rev_block))
    | (T.CJUMP _ as cjump) :: rest -> 
	(rest, List.rev (cjump :: rev_block))
    | T.LABEL(l) :: _ -> 
	(* running into a label: close the block with an explicit jump *)
	collect (T.JUMP(l) :: pending) rev_block
    | s :: rest -> 
	collect rest (s :: rev_block)
    | [] -> 
	raise (CanonError "Got empty list in Basic Block gen.")
  in
  (* Splits off one block after the other; [rev_blocks] is newest first. *)
  let rec split pending rev_blocks =
    match pending with
    | [] -> 
	List.rev rev_blocks
    | T.LABEL(l) :: rest -> 
	let (remaining, block) = collect rest [T.LABEL(l)] in
	split remaining ((block, ref false) :: rev_blocks)
    | _ -> 
	(* no leading label: invent one and retry *)
	split (T.LABEL(TP.newLabel ()) :: pending) rev_blocks
  in
  split stmts []
 
(* Replaces the statement list of every (name, statements) pair in
   [linIR] by its list of basic blocks, as computed by bBlocks. *)
let basicBlocks linIR = 
  let split_function (fName, stmts) = (fName, bBlocks stmts) in
  List.map split_function linIR

(**************************************************************************)
(* Tracing *)
(**************************************************************************)

(* [trace bblist] orders the basic blocks of one function into traces.
   [bblist] is a list of (statements, mark) pairs, where [mark] is a
   bool ref that is set to true once the block has been placed.  The
   result is a list of blocks (statement lists) in their final order.
   Raises CanonError if a block is empty or does not end in a JUMP or
   CJUMP. *)
let trace bblist = 
  (* Returns an unmarked block and marks it, 
     raises Not_found if all blocks have been marked *)
  let newTrace () =
    fst(List.find (fun (_,m) -> 
      (* the predicate marks the block as a side effect when it accepts it *)
      if(not !m) then (m := true; true) else false) bblist)
  in

  (* Find and mark a block with given label that is not yet marked. 
   * If the search fails, Not_found will be raised *)
  let findBlock label = 
    fst (List.find (function (s,m) -> 
      if(not !m) then (
	match s with
	| T.LABEL(l) :: tail when (l = label) -> 
	    m := true; true
	| _ -> false
       ) else 
	false
		   ) bblist)
  in

  (* Returns a tuple consisting of all statements but the last and the last *)
  let rec splitBlock = function
  | [] -> raise (CanonError ("Empty list in splitBlock."))
  | [a] -> ([], a)
  | head :: tail -> 
      let (a,b) = splitBlock tail in
      (head :: a, b)
  in
  
  (* Fall through on false *)
  (* tries to append another block *)
  (* [curr] is the block being placed; [res] holds the blocks placed so
     far, most recent first.  Returns the extended [res]. *)
  let rec next curr res = 
    match (splitBlock curr) with
    | (_, T.JUMP(l)) -> 
	(* continue the trace with the jump target if it is still unplaced *)
	(try
	  next (findBlock l) (curr :: res)
	with Not_found -> 
	  (curr :: res)
	)
    | (pre, T.CJUMP(op, e1, e2, t, f)) -> 
	(try
	  (* Rewrite, such that fall through on true if possible *)
	  (* the block labelled t follows directly; the jump is rebuilt
	     with the negated relop and swapped targets *)
	  next (findBlock t) 
	    ((pre @ [T.CJUMP(Ir.notRelop op, e1, e2, f, t)]) :: res)
	with Not_found -> 
	  (try
	    (* otherwise let the block labelled f follow, jump unchanged *)
	    next (findBlock f) (curr :: res)
	  with Not_found -> 
	    (* neither target is available: redirect the false target to
	       a fresh label whose block just jumps to f *)
	    let l = TP.newLabel () in
	    [T.LABEL(l);T.JUMP(f)] :: (pre @ [T.CJUMP(op, e1, e2, t, l)]) :: res
	  )
	)
    | _ -> raise (CanonError ("Empty block in trace."))
  in

  (* Starts a new trace from the first unmarked block until newTrace
     reports (via Not_found) that every block has been placed. *)
  let rec loop res = 
    (try
      loop (next (newTrace ()) res)
    with 
      Not_found -> List.rev res
    )
  in
  
  loop []

(* Replaces the basic-block list of every (name, blocks) pair in [bblist]
   by the ordering computed by trace. *)
let traceSchedule bblist = 
  let schedule_function (fName, blocks) = (fName, trace blocks) in
  List.map schedule_function bblist
