(* Parser *)
(* Frank Pfenning <fp@cs.cmu.edu> *)

(*
 * Handwritten shift/reduce parser to support
 * best possible error messages
 *)

(* Interface of the parser: a single entry point plus the exception
 * that carries the information needed for error recovery.
 *)
signature PARSE =
sig

    (* parse filename = the declarations parsed from the file *)
    val parse : string -> ExtSyn.env (* may raise ErrorMsg.Error *)

    (* raised at a syntax error, after the error message has been issued *)
    exception ParseError of (int * int) * Lex.lexresult Stream.front (* error region, continuation *)

end  (* signature PARSE *)


structure Parse :> PARSE =
struct

structure E = ExtSyn
structure PS = ParseState
structure M = Stream
structure T = Terminal
structure Lex = Lex

(******************)
(* Error Messages *)
(******************)

(* pp_tok t = the printed form of terminal t, wrapped in single quotes *)
fun pp_tok t = String.concat ["'", T.toString t, "'"]

(* pp_toks ts = the quoted terminals in ts as an English alternative list:
 *   []        ~> ""
 *   [a]       ~> 'a'
 *   [a,b]     ~> 'a' or 'b'
 *   [a,b,c]   ~> 'a', 'b' or 'c'
 * The separator before the last element is exactly " or ", so no doubled
 * space or dangling " or " is produced.
 *)
fun pp_toks (nil) = ""
  | pp_toks (t::nil) = pp_tok t
  | pp_toks (t1::t2::nil) = pp_tok t1 ^ " or " ^ pp_tok t2
  | pp_toks (t::ts) = pp_tok t ^ ", " ^ pp_toks ts

(* msg ^^ hint = msg followed, on a new line, by the hint in square brackets *)
fun ^^(msg, hint) = String.concat [msg, "\n[Hint: ", hint, "]"]
infix ^^

(* ParseError (region, tf): syntax error at region; tf is the token front
 * at the point of the error, used by p_dec_robust to resume via recover *)
exception ParseError of (int * int) * Lex.lexresult M.front

(* parse_error (region, str) tf reports str at region through ErrorMsg.error,
 * then raises ParseError (region, tf); it never returns normally *)
fun parse_error (region, str) tf =
    let
        val () = ErrorMsg.error (PS.ext region) str
    in
        raise ParseError (region, tf)
    end

(* msg_expected t' t = message saying terminal t' was expected but t was found *)
fun msg_expected t' t =
    String.concat ["expected ", pp_tok t', ", found: ", pp_tok t]

(* error_expected (region, t', t) tf reports that t' was expected where t
 * was found, then raises ParseError (region, tf) *)
fun error_expected (region, t', t) tf =
    parse_error (region, msg_expected t' t) tf

(* like error_expected, but appends error_hint to the message as a hint *)
fun error_expected_h (region, t', t, error_hint) tf =
    parse_error (region, msg_expected t' t ^^ error_hint) tf

(* msg_expected_list ts t = message saying one of the terminals ts was
 * expected but t was found *)
fun msg_expected_list ts t =
    String.concat ["expected one of ", pp_toks ts, ", found: ", pp_tok t]

(* error_expected_list (region, ts, t) tf reports that one of ts was expected
 * where t was found, then raises ParseError (region, tf) *)
fun error_expected_list (region, ts, t) tf =
    parse_error (region, msg_expected_list ts t) tf

(* like error_expected_list, but appends error_hint to the message as a hint *)
fun error_expected_list_h (region, ts, t, error_hint) tf =
    parse_error (region, msg_expected_list ts t ^^ error_hint) tf
 
(* location markOpt = "_" if there is no mark, otherwise Mark.show of the mark *)
fun location markOpt =
    case markOpt of
        NONE => "_"
      | SOME mark => Mark.show mark

(*******************)
(* Data structures *)
(*******************)

(* a region is a (left, right) pair of integers, see join and PS.ext uses below *)
type region = int * int
type prec = int                 (* precedence *)

(* stack items for shift/reduce parsing *)
(* every item carries the region of the source text it was built from;
 * Tok items are pushed by shift, the other items by push or by the
 * r_<nonterminal> reduce functions
 *)
datatype stack_item =
   Tok of T.terminal * region                          (* lexer token *)
 | Tp of E.tp * region                                 (* type *)
 | TpInfix of prec * (E.tp * E.tp -> E.tp) * region    (* type infix type operator, all right assoc *)
 | Dec of E.def * region                               (* top-level declaration *)
 | Proc of E.proc * region                             (* process *)
 | Parms of E.parm list * region                       (* typed parameters *)
 | Vars of E.chan list * region                        (* variables list *)
 | Alts of (E.label * E.tp) list * region              (* alternatives 'k : A *)
 | Branches of (E.value * E.proc) list * region        (* branches *)
 | Val of E.value * region                             (* value *)

 | Error of region                                     (* lexing or parsing error *)

(* the parse stack: Bot is the empty stack, S $ item pushes item on S;
 * $ is declared left associative so S $ a $ b has b on top *)
datatype stack
  = Bot
  | $ of stack * stack_item

infix 2 $

(* pp_item item = the constructor name of item, for debugging output;
 * only Tok additionally shows the terminal it carries *)
fun pp_item item =
    case item of
        Tok(t,_) => String.concat ["Tok(", pp_tok t, ")"]
      | Tp _ => "Tp"
      | TpInfix _ => "TpInfix"
      | Dec _ => "Dec"
      | Proc _ => "Proc"
      | Parms _ => "Parms"
      | Vars _ => "Vars"
      | Alts _ => "Alts"
      | Branches _ => "Branches"
      | Val _ => "Val"
      | Error _ => "Error"

(* pp_stack S = the items of S from bottom to top, separated by " $ " *)
fun pp_stack stack =
    case stack of
        Bot => "Bot"
      | below $ top => String.concat [pp_stack below, " $ ", pp_item top]

(* This is a hand-written shift/reduce parser
 * I have tried to resist the temptation to optimize or transform,
 * since it is very easy to make mistakes.
 *
 * Parsing functions are named p_<nonterminal>, possibly with a suffix
 * for intermediate parsing states.  Reducing functions are named
 * r_<nonterminal>, possibly with a suffix for intermediate states.
 * With few exceptions, parsing functions have type
 *
 * p_<nonterminal> : stack * Lex.lexresult M.front -> stack * Lex.lexresult M.front
 * r_<nonterminal> : stack -> stack
 *
 * Note that in the input and output of a parsing function, the first
 * token of the lexer stream is exposed (type Lex.lexresult M.front), which
 * makes for easy matching and composition of these functions.
 *
 * Generally p_<nt> will consume terminals for an <nt> from the lexresult
 * stream and return with the resulting abstract syntax for <nt> on the stack.
 * Generally r_<nt> will consume a mix of terminal and nonterminals on
 * the stack and push the abstract syntax for an <nt> onto the stack.
 *
 * While parsing an expression with infix, prefix, and postfix operators
 * we continue parsing, extending the expression until we encounter
 * a terminal that is not part of an expression.
 *
 * p_<nt> is a function that parses nonterminal <nt>
 * r_<nt> is a function that reduces nonterminal <nt>
 * c_<cond> is a function that checks condition <cond>
 * e_<error> is a function that reports error <error>
 * m_<nt> is a function that marks nonterminal <nt> with region information
 *)

(***********************)
(* Parsing Combinators *)
(***********************)

(* Always call 'first ST' to extract the first token for examination *)
(* not defined on an empty front (raises Match) *)
fun first (_, M.Cons((tok, _), _)) = tok
(* second ST = the token front component of the parser state ST *)
fun second (_, front) = front

(* shift ST moves the first token, with its region, from the front onto
 * the stack and forces the rest of the stream to expose the next token *)
fun shift (stack, M.Cons((tok, reg), rest)) =
    let
        val stack' = stack $ Tok(tok, reg)
    in
        (stack', M.force rest)
    end
(* reduce reduce_fun ST applies reduce_fun to the stack and leaves the
 * token front untouched *)
fun reduce reduce_fun (stack, front) =
    let
        val stack' = reduce_fun stack
    in
        (stack', front)
    end

(* drop ST discards the first token without pushing it; use sparingly *)
fun drop (stack, M.Cons((_, _), rest)) = (stack, M.force rest)
(* push item ST puts item on top of the stack, leaving the front untouched *)
fun push item (stack, front) =
    let
        val stack' = stack $ item
    in
        (stack', front)
    end

(* f >> g = left-to-right composition: first f, then g *)
fun >>(f,g) = g o f
(* x |> f = f applied to x, so a state can be fed into a pipeline *)
fun |>(arg,func) = func arg

infixr 2 >>
infix 1 |>

(* recover (S, tf) = (S, tf') *)
(* Skips tokens of tf until one is reached that starts a top-level
 * declaration ('type', 'proc', 'exec', 'fail') or is end-of-file; tf'
 * starts with that token.  An empty front is returned as is.
 * This supports recovery after a lex or parse error.
 *)
fun recover (stack, front) =
    let
        (* does this terminal stop the skipping? *)
        fun is_sync T.TYPE = true
          | is_sync T.PROC = true
          | is_sync T.EXEC = true
          | is_sync T.FAIL = true
          | is_sync T.EOF = true
          | is_sync _ = false

        fun skip (cur as M.Nil) = cur
          | skip (cur as M.Cons((tok, _), rest)) =
            if is_sync tok then cur else skip (M.force rest)
    in
        (stack, skip front)
    end

(* region manipulation *)
(* join r1 r2 = region from the left end of r1 to the right end of r2 *)
fun join (start, _) (_, finish) = (start, finish)
(* here ST = region of the first token of the front; not defined on an
 * empty front (raises Match) *)
fun here (_, M.Cons((_, reg), _)) = reg
(* placeholder region (0,0); not referenced elsewhere in this structure *)
val nowhere = (0,0)

(****************************)
(* Building abstract syntax *)
(****************************)

(* m_val (V, region) = V wrapped in E.MarkedValue, marked with the region *)
fun m_val (V, region as (_, _)) =
    let
        val marked = Mark.mark' (V, PS.ext region)
    in
        E.MarkedValue marked
    end

(* m_proc (P, region) = P wrapped in E.Marked, marked with the region *)
fun m_proc (P, region as (_, _)) =
    let
        val marked = Mark.mark' (P, PS.ext region)
    in
        E.Marked marked
    end

(* not marking types for now *)
(* m_tp (tau, region) = tau; the region is accepted but ignored *)
fun m_tp (tau, (_, _)) = tau

(***********)
(* Parsing *)
(***********)

(*
 * Refer to the grammar in readme.txt
 * Comments refer to the nonterminals shown there in <angle brackets>
 *)

(* <dec> *)
(* p_dec ST parses one top-level declaration, selected by the first token,
 * and leaves a Dec item on the stack.  At end-of-file ST is returned
 * unchanged.  Any other token is reported as a parse error.
 *)
fun p_dec ST = case first ST of
    (* 'type' <id> '=' <tp> *)
    T.TYPE => ST |> shift >> p_id >> p_terminal T.EQ >> p_tp >> reduce r_dec
    (* 'proc' <id> <parms> '=' <proc>; the empty Parms accumulator gets the region of 'proc' *)
  | T.PROC => ST |> shift >> p_id >> push (Parms(nil, here ST)) >> p_parms >> p_terminal T.EQ >> p_proc >> reduce r_dec
    (* 'exec' <id> *)
  | T.EXEC => ST |> shift >> p_id >> reduce r_dec
    (* 'fail' <dec>; errors in the following declaration are handled by p_dec_robust *)
  | T.FAIL  => ST |> shift >> p_dec_robust
  | T.EOF => ST
  | t => parse_error (here ST, "unexpected token " ^ pp_tok t ^ " at top level") (second ST)

(* <dec>, in the scope of a 'fail' *)
(* On entry the stack S has the 'fail' token on top.  The declaration that
 * follows is parsed with error messages suppressed.  If lexing or parsing
 * it fails, the token stream is advanced by recover to the start of the
 * next declaration (or end-of-file) and an Error item takes the place of
 * the declaration.  Either way the result is reduced to a Dec with E.Fail.
 *
 * A 'fail' directly followed by end-of-file is reported as a parse error.
 * Previously p_dec returned the state unchanged in that case and r_dec was
 * applied to a stack ending in the bare 'fail' token, for which it has no
 * clause, so an uncaught Match escaped.  The error is raised outside the
 * suppress/handle below, so it is a genuine error of the enclosing
 * context rather than an expected failure.
 *)
and p_dec_robust (ST as (S, ft)) =
    (case first ST of
         T.EOF => parse_error (here ST, "unexpected end of file after 'fail'") ft
       | _ =>
         ( ErrorMsg.suppress (fn () => ST |> p_dec >> reduce r_dec)
           handle Lex.LexError (r,ts) => recover (S, M.force ts) |> push (Error(r)) >> reduce r_dec
                | ParseError (r,ft') => recover (S, ft') |> push (Error(r)) >> reduce r_dec ))

(* reducing <dec> *)
(* Replaces the items of a complete declaration on top of the stack by a
 * single Dec item whose region spans from the leading keyword to the last
 * item.  There is no catch-all clause: any other stack shape raises Match.
 *)
     (* 'type' t '=' tau *)
and r_dec (S $ Tok(T.TYPE,r1) $ Tok(T.IDENT(t),_) $ Tok(T.EQ,_) $ Tp(tau,r2)) =
    S $ Dec(E.TpDef(t,tau,PS.ext(join r1 r2)), join r1 r2)
    (* 'proc' p parms '=' P; requires at least one parameter, the first is passed separately *)
  | r_dec (S $ Tok(T.PROC,r1) $ Tok(T.IDENT(p),_) $ Parms(xA::yBs, r2) $ Tok(T.EQ,_) $ Proc(P,r3)) =
    S $ Dec(E.ProcDef(p,xA,yBs,P,PS.ext(join r1 r3)), join r1 r3)
    (* 'exec' p *)
  | r_dec (S $ Tok(T.EXEC,r1) $ Tok(T.IDENT(p),r2)) =
    S $ Dec(E.Exec(p,PS.ext(join r1 r2)), join r1 r2)
    (* 'fail' followed by a declaration that did not lex or parse *)
  | r_dec (S $ Tok(T.FAIL,r1) $ Error(r2)) = S $ Dec(E.Fail(E.Error(PS.ext(r2)),PS.ext(join r1 r2)), join r1 r2)
    (* 'fail' followed by a well-formed declaration d *)
  | r_dec (S $ Tok(T.FAIL,r1) $ Dec(d,r2)) = S $ Dec(E.Fail(d,PS.ext(join r1 r2)), join r1 r2)
  (* debugging aid, disabled: *)
  (* | r_dec S = ( TextIO.print (pp_stack S ^ "\n") ; raise Match ) *)

(* p_parms ST parses one or more parameters '(' <id> ':' <tp> ')',
 * appending each to the Parms item that must already be on the stack *)
and p_parms ST =
    (case first ST of
         T.LPAREN =>
         let
             (* one parenthesized parameter, reduced into Parms *)
             val parse_parm = shift >> p_id >> p_terminal T.COLON >> p_tp
                              >> p_terminal T.RPAREN >> reduce r_parm
         in
             p_parms_next (parse_parm ST)
         end
       | t => error_expected (here ST, T.LPAREN, t) (second ST))

(* p_parms_next ST continues with another parameter if the next token is
 * '(' and otherwise returns ST unchanged *)
and p_parms_next ST =
    (case first ST of
         T.LPAREN => p_parms ST
       | _ => ST)

(* r_parm appends the parameter '(' x ':' tau ')' on top of the stack to
 * the Parms item below it and extends the Parms region to the ')' *)
and r_parm stack =
    (case stack of
         S $ Parms(parms, r_start) $ Tok(T.LPAREN, _) $ Tok(T.IDENT(x), _) $ Tok(T.COLON, _)
           $ Tp(tau, _) $ Tok(T.RPAREN, r_end) =>
         S $ Parms(parms @ [E.Tp(E.Var(x),tau)], join r_start r_end))

(* <tp> *)
(* p_tp ST parses a type.  Atomic types (names, 1, parenthesized types,
 * +{...} and &{...}) are shifted and reduced to a Tp item right away and
 * parsing of the type continues.  The infix operators are removed from the
 * token stream and handed to p_tp_prec as TpInfix items, with precedence 5
 * for T.STAR and 4 for T.LOLLI.  Any other token ends the type: the stack
 * is reduced by r_tp, which also receives ST for error reporting.
 *)
and p_tp ST = case first ST of
    T.IDENT(t) => ST |> shift >> reduce r_atom_tp >> p_tp
  | T.ONE => ST |> shift >> reduce r_atom_tp >> p_tp
  | T.STAR => ST |> drop >> p_tp_prec (TpInfix(5, E.Tensor, here ST))
  | T.LOLLI => ST |> drop >> p_tp_prec (TpInfix(4, E.Lolli, here ST))
  | T.LPAREN => ST |> shift >> p_tp >> p_terminal T.RPAREN >> reduce r_atom_tp >> p_tp
    (* the empty Alts accumulator gets the region of the '+' or '&' token *)
  | T.PLUS => ST |> shift >> p_terminal T.LBRACE >> push (Alts(nil, here ST)) >> p_alts
              >> p_terminal T.RBRACE >> reduce r_atom_tp >> p_tp
  | T.AMPERSAND => ST |> shift >> p_terminal T.LBRACE >> push (Alts(nil, here ST)) >> p_alts
                      >> p_terminal T.RBRACE >> reduce r_atom_tp >> p_tp
  | _ => ST |> reduce (r_tp ST) (* type complete: reduce *)

(* <alts> *)
(* p_alts ST parses one or more comma-separated alternatives <tag> ':' <tp>,
 * appending each to the Alts item that must already be on the stack *)
and p_alts ST =
    (case first ST of
         T.TAG _ =>
         let
             val st_tag = shift ST
             val st_colon = p_terminal T.COLON st_tag
             val st_tp = p_tp st_colon
         in
             p_alts_next (reduce r_alts st_tp)
         end
       | _ => parse_error (here ST, "missing tag") (second ST))

(* [ ',' <alts> ] *)
(* a comma is discarded and followed by more alternatives; any other
 * token leaves ST unchanged *)
and p_alts_next ST =
    (case first ST of
         T.COMMA => p_alts (drop ST)
       | _ => ST)

(* reduce <alts> *)
(* appends the alternative k ':' tau on top of the stack to the Alts item
 * below it and extends the Alts region to the end of tau *)
and r_alts stack =
    (case stack of
         S $ Alts(alts, r_start) $ Tok(T.TAG(k), _) $ Tok(T.COLON, _) $ Tp(tau, r_end) =>
         S $ Alts(alts @ [(k,tau)], join r_start r_end))

(* infix operators are all right associative *)
(* <tp>, operator precedence resolution *)
(* p_tp_prec opr ST decides what to do with the infix operator opr, which
 * has already been removed from the token stream, based on the stack:
 * - an earlier operator of strictly higher precedence between two types
 *   is reduced first, then the decision is repeated
 * - otherwise, after a type, opr is pushed and parsing of the type continues
 * - directly after another operator, or with no type to its left, a parse
 *   error is reported
 *)
and p_tp_prec (opr as TpInfix(p',f',r')) (ST as (S,ft)) = case S of
    S $ Tp(tau1,r1) $ TpInfix(p,f,_) $ Tp(tau2,r2) =>
    if p > p' then p_tp_prec opr (S $ Tp(m_tp(f(tau1,tau2), join r1 r2), join r1 r2), ft) (* reduce *)
    else ST |> push opr >> p_tp (* shift, for p = p' since right associative *)
  | S $ TpInfix(p,f,r) => parse_error (join r r', "consecutive infix operators") ft
  | S $ Tp _ => ST |> push opr >> p_tp (* shift *)
  | _ => parse_error (r', "leading infix operator") ft

(* p_tp_opt ST parses the optional type annotation of a cut, after the
 * identifier: either ':' <tp> '<-' or just '<-'; all tokens are shifted *)
and p_tp_opt ST =
    (case first ST of
         T.COLON => p_terminal T.LEFTARROW (p_tp (shift ST))
       | T.LEFTARROW => shift ST
       | t => error_expected_list_h (here ST, [T.COLON, T.LEFTARROW], t, "add 'call' keyword?") (second ST))

(* reduce atomic <tp> *)
(* Replaces the tokens of an atomic type on top of the stack by a Tp item.
 * There is no catch-all clause: any other stack shape raises Match.
 *)
    (* type name *)
and r_atom_tp (S $ Tok(T.IDENT(t),r)) = S $ Tp(m_tp(E.TpName(t), r),r)
    (* unit type *)
  | r_atom_tp (S $ Tok(T.ONE,r)) = S $ Tp(m_tp(E.One, r),r)
    (* '+' '{' alts '}' *)
  | r_atom_tp (S $ Tok(T.PLUS,r1) $ Tok(T.LBRACE,_) $ Alts(alts,_) $ Tok(T.RBRACE,r2)) =
    S $ Tp(m_tp(E.Plus(alts), join r1 r2), join r1 r2)
    (* '&' '{' alts '}' *)
  | r_atom_tp (S $ Tok(T.AMPERSAND,r1) $ Tok(T.LBRACE,_) $ Alts(alts,_) $ Tok(T.RBRACE,r2)) =
    S $ Tp(m_tp(E.With(alts), join r1 r2), join r1 r2)
    (* parenthesized type: same type, region extended to the parentheses *)
  | r_atom_tp (S $ Tok(T.LPAREN,r1) $ Tp(tau,_) $ Tok(T.RPAREN,r2)) = S $ Tp(tau,join r1 r2) (* mark? *)

(* reduce <tp> *)
(* r_tp ST S reduces a completed type on top of the stack S to one Tp item,
 * applying pending infix operators from the top down.  ST is the parser
 * state at the point where the type ended and is used only to report errors.
 *)
and r_tp ST (S $ Tp(tau1,r1) $ TpInfix(p,f,_) $ Tp(tau2,r2)) =
    r_tp ST (S $ Tp(m_tp(f(tau1,tau2), join r1 r2), join r1 r2))
  | r_tp ST (S $ Tp(tau1,r1) $ Tp(tau2,r2)) = parse_error (join r1 r2, "consecutive types") (second ST)
  | r_tp ST (S $ Tp(tau,r)) = S $ Tp(tau,r)  (* must come after (prefix and) infix cases *)
    (* operator without a right operand *)
  | r_tp ST (S $ TpInfix(p,f,r)) = parse_error (join r (here ST), "incomplete type") (second ST)
    (* no type item on top of the stack at all *)
  | r_tp ST S = parse_error (here ST, "empty type") (second ST)
  (* debugging aid, disabled: *)
  (* | r_tp ST S = ( TextIO.print (pp_stack S) ; raise Match ) *)

(* <proc> *)
(* p_proc ST parses one process, selected by its first token, and leaves a
 * Proc item on the stack.  Every form is reduced by r_proc; for 'recv' the
 * reduction takes place in p_cont_end once the closing ')' has been seen.
 *)
and p_proc ST = case first ST of
    (* cut: <id> [ ':' <tp> ] '<-' <proc> ';' <proc> *)
    T.IDENT(x) => ST |> shift >> p_tp_opt >> p_proc >> p_terminal T.SEMICOLON >> p_proc >> reduce r_proc
    (* 'send' <id> <val> *)
  | T.SEND => ST |> shift >> p_id >> p_val >> reduce r_proc
    (* 'recv' <id> '(' <cont> ')'; the empty Branches accumulator gets the region of 'recv' *)
  | T.RECV => ST |> shift >> p_id >> p_terminal T.LPAREN >> push (Branches(nil, here ST)) >> p_cont >> p_cont_end
    (* 'fwd' <id> <id> *)
  | T.FWD => ST |> shift >> p_id >> p_id >> reduce r_proc
    (* 'call' <id> <id>+; the empty Vars accumulator gets the region of 'call' *)
  | T.CALL => ST |> shift >> p_id >> push (Vars(nil, here(ST))) >> p_ids >> reduce r_proc
    (* '(' <proc> ')' *)
  | T.LPAREN => ST |> shift >> p_proc >> p_terminal T.RPAREN >> reduce r_proc
  | t => error_expected_list (here ST, [T.IDENT("<id>"), T.SEND, T.RECV, T.FWD, T.CALL, T.LPAREN], t) (second ST)

(* for better error message, if possible *)
(* p_cont_end ST expects the ')' that closes the branches of a 'recv' and
 * then reduces the whole 'recv' process; a '|' at this point gets a
 * dedicated hint *)
and p_cont_end ST =
    (case first ST of
         T.RPAREN => reduce r_proc (shift ST)
       | T.BAR =>
         error_expected_h (here ST, T.RPAREN, T.BAR, "need to quote a label in an earlier branch?")
                          (second ST) (* fix?? *)
       | t => error_expected (here ST, T.RPAREN, t) (second ST))

(* <val> *)
(* p_val ST parses a value.  Identifiers and parenthesized values are
 * reduced to a Val item by r_val_atom and may be followed by ',' (see
 * p_val_next).  A tag is only shifted here; it is combined with the value
 * that follows it later, by r_val.
 *)
and p_val ST = case first ST of
    T.IDENT(x) => ST |> shift >> reduce r_val_atom >> p_val_next
  | T.TAG(k) => ST |> shift >> p_val (* >> reduce (r_val ST) *)
    (* '(' [ <val> ] ')', where '(' ')' alone is the unit value *)
  | T.LPAREN => ST |> shift >> p_val_opt >> p_terminal T.RPAREN >> reduce r_val_atom >> p_val_next
  | t => error_expected_list (here ST, [T.IDENT("<id>"), T.TAG("<tag>"), T.LPAREN], t)
         (second ST)

(* p_val_next ST is called after an atomic value: a ',' is shifted and the
 * right component parsed (pairs are right associative); any other token
 * completes the value, which is then reduced by r_val, with ST passed
 * along for error messages *)
and p_val_next ST =
    (case first ST of
         T.COMMA => p_val (shift ST)
       | _ => reduce (r_val ST) ST)

(* p_val_opt ST is called after '(': if ')' follows immediately nothing is
 * consumed or reduced, otherwise a value is parsed and reduced by r_val *)
and p_val_opt ST =
    (case first ST of
         T.RPAREN => ST
       | _ => reduce (r_val ST) (p_val ST))

(* reduce <val> *)
(* r_val ST S reduces a completed value on top of the stack S to one Val
 * item, combining from the top down.  ST is the parser state at the point
 * where the value ended and is used only to report errors.
 *)
     (* tag k applied to value V *)
and r_val ST (S $ Tok(T.TAG(k),r1) $ Val(V,r2)) =
    r_val ST (S $ Val(m_val(E.Label(k,V), join r1 r2), join r1 r2))
    (* pair V1 ',' V2 *)
  | r_val ST (S $ Val(V1,r1) $ Tok(T.COMMA, _) $ Val(V2,r2)) =
    r_val ST (S $ Val(m_val(E.Pair(V1,V2), join r1 r2), join r1 r2))
    (* nothing left to combine *)
  | r_val ST (S $ Val(V,r)) = S $ Val(V,r)
  | r_val ST S = parse_error (here ST, "empty value") (second ST) (* fix?? *)

(* reduce atomic <val> *)
(* Replaces the tokens of an atomic value on top of the stack by a marked
 * Val item.  On any other stack shape the stack is printed and Match is
 * raised.
 *)
     (* '(' ')' is the unit value *)
and r_val_atom (S $ Tok(T.LPAREN, r1) $ Tok(T.RPAREN, r2)) =
    S $ Val(m_val(E.Unit, join r1 r2), join r1 r2)
    (* an identifier is a channel variable *)
  | r_val_atom (S $ Tok(T.IDENT(x), r)) = S $ Val(m_val(E.Channel(E.Var(x)), r), r)
    (* parenthesized value, marked again with the region including the parentheses *)
  | r_val_atom (S $ Tok(T.LPAREN, r1) $ Val(V,_) $ Tok(T.RPAREN, r2)) =
    S $ Val(m_val(V, join r1 r2), join r1 r2)
  | r_val_atom S = ( print (pp_stack S) ; raise Match )

(* <id>+ *)
(* p_ids ST parses one or more identifiers, appending each to the Vars
 * item that must already be on the stack *)
and p_ids ST =
    (case first ST of
         T.IDENT _ => p_ids_next (reduce r_id (shift ST))
       | t => error_expected (here ST, T.IDENT("<id>"), t) (second ST))

(* p_ids_next ST continues with p_ids if another identifier follows and
 * otherwise returns ST unchanged *)
and p_ids_next ST =
    (case first ST of
         T.IDENT _ => p_ids ST
       | _ => ST)

(* r_id appends the identifier on top of the stack, as a variable, to the
 * Vars item below it and extends the Vars region to that identifier *)
and r_id stack =
    (case stack of
         S $ Vars(vars, r_start) $ Tok(T.IDENT(y), r_end) =>
         S $ Vars(vars @ [E.Var(y)], join r_start r_end))

(* reduce <proc> *)
(* Replaces the items of a complete process on top of the stack by a single
 * marked Proc item.  There is no catch-all clause: any other stack shape
 * raises Match.
 *)
     (* 'recv' x '(' branches ')'; the region ends with the Branches item, the region of ')' is not used *)
and r_proc (S $ Tok(T.RECV, r1) $ Tok(T.IDENT(x), _) $ Tok(T.LPAREN,_) $ Branches(bs, r2) $ Tok(T.RPAREN,_)) =
    S $ Proc(m_proc(E.Recv(E.Var(x), E.Cont(bs)), join r1 r2), join r1 r2)
    (* 'send' x V *)
  | r_proc (S $ Tok(T.SEND, r1) $ Tok(T.IDENT(x), _) $ Val(V, r2)) =
    S $ Proc(m_proc(E.Send(E.Var(x), V), join r1 r2), join r1 r2)
    (* cut with type annotation: x ':' tau '<-' P ';' Q *)
  | r_proc (S $ Tok(T.IDENT(x), r1) $ Tok(T.COLON, _) $ Tp(tau, _)
                 $ Tok(T.LEFTARROW, _) $ Proc(P, _) $ Tok(T.SEMICOLON, _) $ Proc(Q, r2)) =
    S $ Proc(m_proc(E.Cut(E.Var(x), SOME(tau), P, Q), join r1 r2), join r1 r2)
    (* cut without type annotation: x '<-' P ';' Q *)
  | r_proc (S $ Tok(T.IDENT(x), r1) $ Tok(T.LEFTARROW, _) $ Proc(P, _) $ Tok(T.SEMICOLON, _) $ Proc(Q,r2)) =
    S $ Proc(m_proc(E.Cut(E.Var(x), NONE, P, Q), join r1 r2), join r1 r2)
    (* 'fwd' x y *)
  | r_proc (S $ Tok(T.FWD, r1) $ Tok(T.IDENT(x), _) $ Tok(T.IDENT(y), r2)) =
    S $ Proc(m_proc(E.Fwd(E.Var(x), E.Var(y)), join r1 r2), join r1 r2)
    (* 'call' p x ys; requires at least one argument, the first is passed separately *)
  | r_proc (S $ Tok(T.CALL, r1) $ Tok(T.IDENT(p), _) $ Vars(x::ys, r2)) =
    S $ Proc(m_proc(E.Call(p, x, ys), join r1 r2), join r1 r2)
    (* '(' P ')', marked again with the region including the parentheses *)
  | r_proc (S $ Tok(T.LPAREN, r1) $ Proc(P, _) $ Tok(T.RPAREN,r2)) =
    S $ Proc(m_proc(P, join r1 r2), join r1 r2)

(* <cont>: one branch, optionally followed by more branches *)
and p_cont ST = p_cont_next (p_branch ST)
(* a '|' is discarded and followed by further branches; any other token
 * leaves ST unchanged (each branch was already reduced by p_branch) *)
and p_cont_next ST =
    (case first ST of
         T.BAR => p_cont (drop ST)
       | _ => ST)

(* <branch>: <val> '=>' <proc>, reduced into the Branches item below it *)
and p_branch ST =
    let
        val st_pat = p_val ST
        val st_arrow = p_terminal T.RIGHTARROW st_pat
        val st_body = p_proc st_arrow
    in
        reduce r_branch st_body
    end

(* r_branch appends the branch V '=>' P on top of the stack to the Branches
 * item below it and extends the Branches region to the end of P *)
and r_branch stack =
    (case stack of
         S $ Branches(branches, r_start) $ Val(V,_) $ Tok(T.RIGHTARROW,_) $ Proc(P, r_end) =>
         S $ Branches(branches @ [(V,P)], join r_start r_end))

(* <id> *)
(* p_id ST moves an identifier token, with its region, from the front onto
 * the stack; any other token is a parse error *)
and p_id ST =
    (case first ST of
         T.IDENT(id) => push (Tok(T.IDENT(id), here ST)) (drop ST)
       | t => parse_error (here ST, "expected identifier, found " ^ pp_tok t) (second ST))

(* <tag> *)
(* p_tag ST shifts a tag token; any other token is a parse error *)
and p_tag ST =
    (case first ST of
         T.TAG _ => shift ST
       | t => parse_error (here ST, "expected tag, found " ^ pp_tok t) (second ST))

(* parse any token (terminal symbol) 't_needed' *)
(* the token is shifted if it equals t_needed, otherwise an
 * "expected ..., found: ..." error is reported at the token *)
and p_terminal t_needed ST =
    let
        val t = first ST
    in
        if t_needed = t
        then shift ST
        else error_expected (here ST, t_needed, t) (second ST)
    end

(* parse_decs token_front = the declarations parsed from token_front, in
 * order.  Each declaration is parsed on a fresh stack; parsing stops when
 * p_dec returns at end-of-file with the stack still empty. *)
fun parse_decs token_front =
    (case p_dec (Bot, token_front) of
         (Bot, M.Cons((T.EOF, _), _)) => nil (* whole file processed *)
       | (Bot $ Dec(d, _), rest) => d :: parse_decs rest)

(* parse filename = decs
 * first apply the lexer, then the parser to the resulting token stream
 * raise ErrorMsg.Error if not lexically or syntactically correct
 * or if an IO.Io exception occurs (e.g., the file cannot be opened)
 *
 * NOTE(review): PS.popfile is reached only if parse_decs returns
 * normally; confirm that skipping it on errors is intended.
 *)
fun parse filename =
    let
        (* lex and parse everything that can be read from instream *)
        fun parse_stream instream =
            let
                val () = PS.pushfile filename (* start at pos 0 in filename *)
                val token_stream = Lex.makeLexer (fn _ => TextIO.input instream)
                val decs = parse_decs (M.force token_stream)
                val () = PS.popfile ()
            in
                decs
            end
    in
        SafeIO.withOpenIn filename parse_stream
    end
    handle e as IO.Io _ => ( ErrorMsg.error NONE (exnMessage e)
                           ; raise ErrorMsg.Error )
         | Lex.LexError _ => raise ErrorMsg.Error
         | ParseError _ => raise ErrorMsg.Error

end (* structure Parse *)
