module A = Absyn

(** Sets of strings, ordered by [String.compare] (the same order the
    polymorphic [compare] gives on strings, without the generic
    structural walk). *)
module StringSet = Set.Make (String)

(** Sets of [A.ctype] values, ordered by the polymorphic structural
    [compare]. *)
module CtypeSet = Set.Make (struct
  type t = A.ctype

  (* Not recursive: the [compare] on the right is the polymorphic one
     already in scope. *)
  let compare (x : t) (y : t) = compare x y
end)

(** A second set of [A.ctype] values, ordered by the polymorphic structural
    [compare].

    NOTE(review): despite the name, this module is built with [Set.Make],
    not [Map.Make], so it offers the set interface.  Confirm with its users
    whether a map was intended before changing it. *)
module CtypeMap = Set.Make (struct
  type t = A.ctype

  (* Not recursive: the [compare] on the right is the polymorphic one
     already in scope. *)
  let compare (x : t) (y : t) = compare x y
end)

(** [reachable_types ctype] returns the set of pointer and record types
    encountered while walking [ctype].  The walk follows the target of
    every pointer and the field types that [Checker.getStructType] reports
    for every record.  [A.TyInt], [A.TyBool] and [A.Void] contribute
    nothing.

    Each record name is expanded at most once, so self-referential records
    do not make the walk loop.

    @raise Failure if an [A.NS] type is encountered. *)
let reachable_types ctype =
  (* [walk ty (seen, acc)] threads two accumulators: [seen] holds the names
     of the records already expanded and [acc] the types collected so far.
     Every branch returns a set that contains [acc]. *)
  let rec walk ty (seen, acc) =
    match ty with
    | A.TyInt
    | A.TyBool
    | A.Void -> (seen, acc)
    | A.TyPointer target ->
      let (seen, acc) = walk target (seen, acc) in
      (seen, CtypeSet.add ty acc)
    | A.TyRecord name ->
      if StringSet.mem name seen then
        (seen, acc)
      else
        let field_types = List.map snd (Checker.getStructType name) in
        (* Mark the record before descending so self-references stop here. *)
        let seen = StringSet.add name seen in
        (* [walk] only ever grows the set it is handed, so its result can be
           threaded straight through the fold; taking a union with the
           previous accumulator would be redundant work. *)
        List.fold_right walk field_types (seen, CtypeSet.add ty acc)
    | A.NS -> failwith "Nonsense type -- can this happen?"
  in
  snd (walk ctype (StringSet.empty, CtypeSet.empty))

(** [reachable_union types] is the union of [reachable_types t] over every
    [t] in [types]; the empty list yields [CtypeSet.empty]. *)
let reachable_union types =
  (* Right fold written out by hand: the tail of the list is processed
     before its head. *)
  let rec collect = function
    | [] -> CtypeSet.empty
    | typ :: rest ->
      let from_rest = collect rest in
      CtypeSet.union from_rest (reachable_types typ)
  in
  collect types

(** [optimize_decl decl] rewrites one top-level declaration.  Variable and
    struct declarations come back unchanged.  For a function, the body is
    rewritten by [optimize_body], which is given the set of types that
    [reachable_union] computes from the function's return type and argument
    types (each converted with [Checker.typ2ctype]). *)
let rec optimize_decl = function
  | (A.VarDecl _ | A.StructDecl _) as decl -> decl
  | A.FunDecl (name, ret, e, args, body) ->
    let ret_ctype = Checker.typ2ctype ret in
    let arg_ctypes =
      List.map (fun (_, arg_typ, _) -> Checker.typ2ctype arg_typ) args
    in
    let reachable = reachable_union (ret_ctype :: arg_ctypes) in
    let new_body = optimize_body reachable body in
    A.FunDecl (name, ret, e, args, new_body)

(** [optimize_body types body] leaves a foreign body as it is.  For a
    regular body it keeps the local declarations and rewrites every
    statement with [optimize_stmt types]. *)
and optimize_body types = function
  | A.Foreign e -> A.Foreign e
  | A.Body (decls, stmts) ->
    let rewrite stmt = optimize_stmt types stmt in
    A.Body (decls, List.map rewrite stmts)
(** [optimize_stmt types stmt] rebuilds [stmt] with every directly contained
    expression passed through [optimize_exp types] and every nested
    statement passed through [optimize_stmt types].  Positions and the
    assignment target are carried over untouched; statements that hold no
    expression are returned as they are. *)
and optimize_stmt types stmt =
  let on_exp e = optimize_exp types e in
  let on_block stmts = List.map (optimize_stmt types) stmts in
  match stmt with
  | A.Assign (lval, pos, rhs, pos') ->
    A.Assign (lval, pos, on_exp rhs, pos')
  | A.IfElse (cond, pos, trues, falses, pos') ->
    let cond = on_exp cond in
    let trues = on_block trues in
    let falses = on_block falses in
    A.IfElse (cond, pos, trues, falses, pos')
  | A.If (cond, pos, trues, pos') ->
    let cond = on_exp cond in
    let trues = on_block trues in
    A.If (cond, pos, trues, pos')
  | A.For (init, pos1, cond, pos2, step, pos3, body) ->
    let init = optimize_stmt types init in
    let cond = on_exp cond in
    let step = optimize_stmt types step in
    let body = on_block body in
    A.For (init, pos1, cond, pos2, step, pos3, body)
  | A.While (cond, pos, body) ->
    let cond = on_exp cond in
    let body = on_block body in
    A.While (cond, pos, body)
  | A.StmtExp (e, pos) -> A.StmtExp (on_exp e, pos)
  | A.Return (Some e, pos) -> A.Return (Some (on_exp e), pos)
  (* Nothing to rewrite in these. *)
  | A.Return (None, _)
  | A.Continue _
  | A.Break _
  | A.Blank -> stmt

(** [optimize_exp types exp] rewrites [exp] recursively through operator
    operands, [A.Offset] / [A.Size] arguments and call arguments.  An
    [A.Alloc] whose [ctype] is not a member of [types] is turned into an
    [A.StackAlloc] carrying the same (rewritten) components; otherwise it
    stays an [A.Alloc].  Constants, references and lvalues are returned as
    they are (lvalues are not descended into).

    @raise Assert_failure if [exp] already contains an [A.StackAlloc]. *)
and optimize_exp types exp =
  let recur e = optimize_exp types e in
  match exp with
  | A.ConstExp _
  | A.Reference _
  | A.LVal _ -> exp
  | A.OpExp (lhs, op, rhs, pos, typ) ->
    (* The second operand is optional (absent for unary operators). *)
    let rhs =
      match rhs with
      | None -> None
      | Some e -> Some (recur e)
    in
    A.OpExp (recur lhs, op, rhs, pos, typ)
  | A.Offset (e, pos, typ) -> A.Offset (recur e, pos, typ)
  | A.Size (e, pos, typ) -> A.Size (recur e, pos, typ)
  | A.Call (f, arglist, pos, typ) ->
    (* Each argument is paired with a position; only the expression half
       is rewritten. *)
    let rewrite_arg (arg, arg_pos) = (recur arg, arg_pos) in
    A.Call (f, List.map rewrite_arg arglist, pos, typ)
  | A.StackAlloc _ ->
    assert false (* Can't happen yet! *)
  | A.Alloc (arg, pos, tp, pos', ctype) ->
    let arg = recur arg in
    if CtypeSet.mem ctype types
    then A.Alloc (arg, pos, tp, pos', ctype)
    else A.StackAlloc (arg, pos, tp, pos', ctype)

(** [optimize_program program] applies [optimize_decl] to every declaration
    of [program], keeping them in their original order. *)
let optimize_program program =
  match program with
  | A.Program decls -> A.Program (List.map optimize_decl decls)
