(** Abstract Syntax Tree for Sax *)

type label = [%import: Ast.label]
type tpname = [%import: Ast.tpname]
type modename = [%import: Ast.modename]
type mode = [%import: Ast.mode]
type tp = [%import: Ast.tp]
type varname = [%import: Ast.varname]
type expname = [%import: Ast.expname]
type pat = [%import: Ast.pat]
type exp = [%import: Ast.exp]
and atom = [%import: Ast.atom]
type parm = [%import: Ast.parm]
type defn = [%import: Ast.defn]
type env =  [%import: Ast.env]

(** [is_lowercase c] is [true] exactly when [c] is one of the ASCII
    letters ['a'] through ['z']. *)
let is_lowercase c =
  match c with
  | 'a' .. 'z' -> true
  | _ -> false

(** [string2mode m] classifies the mode name [m]: a name whose first
    character is a lowercase ASCII letter is a mode variable, and any other
    name is a mode constant.
    @raise Invalid_argument if [m] is empty (first-character access). *)
let string2mode m =
  match is_lowercase m.[0] with
  | true -> ModeVar m
  | false -> ModeConst m

(** Pretty-printers that render Sax types, patterns, expressions and
    definitions back into concrete-syntax strings. *)
module Print = struct

(** [indent n s] prefixes [s] with [n] spaces. A non-positive [n] adds
    nothing. The earlier recursive definition never terminated when [n] was
    negative. *)
let indent n s = String.make (max n 0) ' ' ^ s

(** [parens s] wraps [s] in round brackets. *)
let parens s = "(" ^ s ^ ")"

(** [pp_label l] renders a label verbatim. *)
let pp_label (l : label) = l

(** [pp_mode m] renders a mode variable or a mode constant by its name. *)
let pp_mode m = match m with
  | ModeVar(mvar) -> mvar
  | ModeConst(mconst) -> mconst

(** [pp_modes ms] renders a list of modes separated by single spaces. *)
let pp_modes ms = String.concat " " (List.map pp_mode ms)

(** [pp_tp tau] renders a type. [*] and [->] types are always wrapped in
    parentheses, so the output does not rely on operator precedence. *)
let rec pp_tp (tau : tp) : string = match tau with
  | Times(tau1, tau2) -> parens (pp_tp tau1 ^ " * " ^ pp_tp tau2)
  | One -> "1"
  | Plus(alts) -> "+" ^ pp_alts alts
  | Arrow(tau1, tau2) -> parens (pp_tp tau1 ^ " -> " ^ pp_tp tau2)  (* Lab 3 *)
  | With(alts) -> "&" ^ pp_alts alts
  | Down(tau) -> "<" ^ pp_tp tau ^ ">"                             (* Lab 4 *)
  | Up(tau) -> "^" ^ parens (pp_tp tau)                            (* Lab 4 *)
  | Flat(m, tau) -> "[" ^ pp_mode m ^ "]" ^ parens (pp_tp tau)    (* Lab 4 *)
  | TpInst(a, []) -> a                                             (* Lab 4 *)
  | TpInst(a, ms) -> a ^ "[" ^ pp_modes ms ^ "]"                  (* Lab 4 *)

(* Renders the labelled alternatives shared by [Plus] and [With] as
   "{l1 : tau1, l2 : tau2, ...}". *)
and pp_alts alts =
  "{" ^ String.concat ", " (List.map (fun (l, tau_l) -> pp_label l ^ " : " ^ pp_tp tau_l) alts) ^ "}"

(** [pp_pat pat] renders a pattern. *)
let rec pp_pat (pat : pat) : string = match pat with
  | PairPat(p,q) -> "(" ^ pp_pat p ^ ", " ^ pp_pat q ^ ")"
  | UnitPat -> "()"
  | InjPat(l,p) -> pp_label l ^ "(" ^ pp_pat p ^ ")"
  | ShiftPat(p) -> "<" ^ pp_pat p ^ ">"
  | VarPat(x) -> x

(** [pp_exp col e] renders expression [e]. [col] is the column at which
    multi-line constructs ([match ... with], [record]) place their
    branch/field lines and their closing [end]. Nested branch and field
    bodies are printed at [col + 4]. *)
let rec pp_exp (col : int) (e : exp) : string = match e with
  | Var(x) -> x
  | Pair(e1, e2) -> "(" ^ pp_exp col e1 ^ ", " ^ pp_exp col e2 ^ ")"
  | Unit -> "()"
  | Inj(k, e) -> pp_label k ^ "(" ^ pp_exp col e ^ ")"
  | MatchWith(e, branches) ->
     "match " ^ pp_exp col e ^ " with\n"
     ^ pp_branches col branches
     ^ indent col "end"
  | Fun(x, e) -> parens ("fun " ^ x ^ " => " ^ pp_exp col e) (* Lab 3 *)
  | Record(fields) ->                                        (* Lab 3 *)
     "record\n"
     ^ pp_fields col fields
     ^ indent col "end"
  | Call(f, []) -> "(" ^ f ^ ")"
  | Call(f, args) -> "(" ^ f ^ " " ^ String.concat " " (List.map (pp_atom col) args) ^ ")"
  | Shift(e) -> "<" ^ pp_exp col e ^ ">"
  | Susp(e) -> "susp " ^ parens (pp_exp col e)

(* Renders each match branch on its own line, each terminated by "\n". *)
and pp_branches col branches = match branches with
  | branch :: rest -> pp_branch col branch ^ "\n" ^ pp_branches col rest
  | [] -> ""

(* Renders one branch "| pat => body" indented to [col]. *)
and pp_branch col (pat, e) =
  indent col ("| " ^ pp_pat pat ^ " => " ^ pp_exp (col+4) e)

(* Lab 3: renders each record field "| l => body" on its own line,
   indented to [col] and terminated by "\n". *)
and pp_fields col fields = match fields with
  | (l, e) :: rest -> indent col ("| " ^ pp_label l ^ " => " ^ pp_exp (col+4) e ^ "\n")
                      ^ pp_fields col rest
  | [] -> ""

(* Lab 3: renders an application argument. Unit and variables are printed
   bare. Unit is special-cased so it prints as "()" rather than "(())".
   Other expressions are parenthesized. Projections print as ".k". *)
and pp_atom col atom = match atom with
  | Exp(Unit) -> "()"
  | Exp(Var(x)) -> x
  | Exp(e) -> parens (pp_exp col e)
  | Dot(k) -> "." ^ k
  | Force -> "." ^ "force"      (* Lab 4 *)

(** [pp_parm (x, tau)] renders a parameter as "(x : tau)". *)
let pp_parm (x, tau) = parens (x ^ " : " ^ pp_tp tau)

(** [pp_parms y_sigmas] renders parameters separated by single spaces. *)
let pp_parms y_sigmas = String.concat " " (List.map pp_parm y_sigmas)

(** [pp_defn defn] renders one top-level definition, terminated by "\n".
    Expression bodies are printed on the following line at column 4. *)
let pp_defn defn = match defn with
  (* NOTE(review): unlike [TpInst], brackets are printed even when [ms] is
     empty ("type a[] = ..."); confirm the parser accepts this form. *)
  | TypeDefn(a, ms, tau) -> "type " ^ a ^ "[" ^ pp_modes ms ^ "]"
                            ^ " = " ^ pp_tp tau ^ "\n"
  | ExpDefn(f, y_sigmas, tau, e) ->
     "defn " ^ f ^ " " ^ pp_parms y_sigmas ^ " : " ^ pp_tp tau ^ " =\n"
     ^ indent 4 (pp_exp 4 e) ^ "\n"
  | InstDefn(f, y_sigmas, tau) ->
     "inst " ^ f ^ " " ^ pp_parms y_sigmas ^ " : " ^ pp_tp tau ^ "\n"

(** [pp_env defns] renders all definitions, separated by blank lines. Each
    definition ends in "\n", and the separator adds another. *)
let pp_env defns = String.concat "\n" (List.map pp_defn defns)

end (* module Print *)
