(* Lexer *)
(* Authors: Frank Pfenning <fp@cs.cmu.edu> 
 *          Ankush Das <ankushd@cs.cmu.edu>
 *)

signature LEX =
sig
    (* one lexed token together with its source region (lpos, rpos) *)
    type lexresult = Terminal.terminal * (int * int)

    (* makeLexer source = lazily produced stream of tokens; 'source' is
     * called repeatedly to obtain raw input, and returning "" signals
     * the end of input *)
    val makeLexer : (int -> string) -> lexresult Stream.stream

    (* signalled after a lexing error has been reported; the argument
     * is the token stream that continues after the offending input *)
    exception LexError of lexresult Stream.stream
end

structure Lex :> LEX =
struct

structure PS = ParseState
structure M = Stream
structure T = Terminal

(* lexresult = (token, (lpos, rpos)), where [lpos,rpos) is the token's
 * source region *)
type lexresult = T.terminal * (int * int)

(* LexError token_stream is raised after a lexing error has been reported;
 * token_stream lexes the input that follows the offending text, so a
 * caller can continue past the error *)
exception LexError of lexresult Stream.stream (* continuation stream *)

(* error (lpos, rpos) msg token_stream
 * reports 'msg' for the source region [lpos,rpos) through ErrorMsg.error
 * and then raises LexError token_stream.  It never returns normally, so
 * it can be used wherever a value of any type is expected.
 *)
fun error (lpos, rpos) msg token_stream =
    ( ErrorMsg.error (PS.ext(lpos,rpos)) msg
    ; raise LexError token_stream )

(* characters that may begin an identifier *)
fun id_start_char c =
    Char.isAlpha c
    orelse c = #"_"

(* characters that may continue an identifier, a label, or a pragma name *)
fun id_char c =
    id_start_char c
    orelse Char.isDigit c
    orelse c = #"'"


(* run_cond cond (n, accum, cs) = (string, n', cs')
 * accumulate characters in character stream 'cs' satisfying 'cond' in
 * 'accum' (kept in reverse order) and return the accumulated string,
 * the updated character count n', and the remaining character stream
 * (starting at the first character that fails 'cond')
 *)
fun run_cond cond (n, accum, cs) =
    case M.force cs
     of M.Nil => (String.implode (List.rev accum), n, cs)
      | M.Cons (c, cs') =>
        if cond c
        then run_cond cond (n+1, c::accum, cs')
        else (String.implode (List.rev accum), n, cs)

(* lex_code (pos, charstream) = (token, lpos, rpos, cs')
 * Skips whitespace and comments starting at character position 'pos'
 * (recording every newline with PS.newline), then lexes one token.
 * token is the lexed token, [lpos,rpos) is the source region,
 * and cs' is the remaining character stream.  At end of input the
 * result is T.EOF with the empty region [pos,pos).  Lexing errors are
 * reported through 'error', whose continuation resumes just past the
 * offending text.
 *)
fun lex_code (pos, charstream) =
    case M.force charstream
     of M.Nil => (T.EOF, pos, pos, charstream)
      (* Pragma *)
      | M.Cons (#"#", cs) => lex_pragma (pos+1, cs)
      (* Whitespace *)
      | M.Cons (#" ", cs) => lex_code (pos+1, cs)
      | M.Cons (#"\t", cs) => lex_code (pos+1, cs)
      | M.Cons (#"\011", cs) => lex_code (pos+1, cs) (* vertical tab *)
      | M.Cons (#"\012", cs) => lex_code (pos+1, cs) (* form feed *)
      | M.Cons (#"\r", cs) => lex_code (pos+1, cs)
      | M.Cons (#"\n", cs) =>
        ( PS.newline pos        (* track newlines for error messages *)
        ; lex_code (pos+1, cs) )
      (* Separators *)
      | M.Cons (#"\\", cs) => (T.BACKSLASH, pos, pos+1, cs)
      | M.Cons (#".", cs) => (T.DOT, pos, pos+1, cs)
      | M.Cons (#",", cs) => (T.COMMA, pos, pos+1, cs)
      | M.Cons (#"!", cs) => (T.EXCL, pos, pos+1, cs)
      | M.Cons (#"?", cs) => (T.QUEST, pos, pos+1, cs)
      | M.Cons (#"$", cs) => (T.DOLLAR, pos, pos+1, cs)
      | M.Cons (#":", cs) => (T.COLON, pos, pos+1, cs)
      (* '%' starts a comment that runs to the end of the line *)
      | M.Cons (#"%", cs) => lex_comment_line (pos+1, cs)
      | M.Cons (#"[", cs) => (T.LBRACKET, pos, pos+1, cs)
      | M.Cons (#"]", cs) => (T.RBRACKET, pos, pos+1, cs)
      (* '(' may begin a delimited comment "(*" or a banana "(|" *)
      | M.Cons (#"(", cs) =>
        (case M.force cs
          of M.Cons(#"*", cs) => lex_comment (pos+2, cs, 1)
           | M.Cons(#"|", cs) => (T.LBANANA, pos, pos+2, cs)
           | _ => (T.LPAREN, pos, pos+1, cs))
      | M.Cons (#")", cs) => (T.RPAREN, pos, pos+1, cs)
      | M.Cons (#"-", cs) =>
        (case M.force cs
          of M.Cons(#">", cs) => (T.ARROW, pos, pos+2, cs)
           | _ => (T.MINUS, pos, pos+1, cs))
      | M.Cons (#"/", cs) =>
        (case M.force cs
          of M.Cons(#"\\", cs) => (T.SLASHBACK, pos, pos+2, cs)
           | _ => (T.SLASH, pos, pos+1, cs))
      | M.Cons (#"=", cs) =>
        (case M.force cs
          of M.Cons(#">", cs) => (T.RIGHTARROW, pos, pos+2, cs)
           | _ => (T.EQ, pos, pos+1, cs))
      | M.Cons (#"*", cs) => (T.STAR, pos, pos+1, cs)
      | M.Cons (#"+", cs) => (T.PLUS, pos, pos+1, cs)
      | M.Cons (#"&", cs) => (T.AMPERSAND, pos, pos+1, cs)
      | M.Cons (#"|", cs) =>
        (case M.force cs
          of M.Cons(#")", cs) => (T.RBANANA, pos, pos+2, cs)
           | _ => (T.BAR, pos, pos+1, cs))
      (* Labels: a tick immediately followed by one or more id_char's *)
      | M.Cons (#"'", cs) =>
        let val (label_string, n, cs) = run_cond id_char (0, [], cs)
            val () = if n = 0 then error (pos, pos+1) "tick not followed by label name"
                                         (fn () => lexer (pos+1, cs))
                     else ()
        in (T.LABEL(label_string), pos, pos+n+1, cs) end
      | M.Cons (c, cs') =>
        if Char.isDigit c
        then (* natural number literal; Int.fromString raises Overflow
              * when the digits do not fit in an int *)
             let val (num_string, n, cs) = run_cond Char.isDigit (0, [], charstream)
                 val num = Option.valOf(Int.fromString(num_string))
                     handle Overflow => error (pos, pos+n) ("number '" ^ num_string ^ "' too large")
                                        (fn () => lexer (pos+n, cs))
             in (T.NAT(num), pos, pos+n, cs) end
        else if id_start_char c
        then (* keyword or identifier *)
             (case run_cond id_char (0, [], charstream)
               of ("case", n, cs) => (T.CASE, pos, pos+n, cs)
                | ("of", n, cs) => (T.OF, pos, pos+n, cs)
                | ("fold", n, cs) => (T.FOLD, pos, pos+n, cs)
                | ("unfold", n, cs) => (T.UNFOLD, pos, pos+n, cs)
                | ("type", n, cs) => (T.TYPE, pos, pos+n, cs)
                | ("decl", n, cs) => (T.DECL, pos, pos+n, cs)
                | ("defn", n, cs) => (T.DEFN, pos, pos+n, cs)
                | ("norm", n, cs) => (T.NORM, pos, pos+n, cs)
                | ("conv", n, cs) => (T.CONV, pos, pos+n, cs)
                | ("eval", n, cs) => (T.EVAL, pos, pos+n, cs)
                | ("fail", n, cs) => (T.FAIL, pos, pos+n, cs)
                | (ident, n, cs) => (T.IDENT(ident), pos, pos+n, cs))
        else error (pos, pos+1) ("illegal character: '" ^ Char.toString c ^ "'")
                   (fn () => lexer (pos+1, cs'))

(* single-line comment % ... \n
 * skips up to and including the next newline, then resumes lexing;
 * end of input inside the comment yields T.EOF *)
and lex_comment_line (pos, charstream) =
    case M.force charstream
     of M.Nil => (T.EOF, pos, pos, charstream)
      | M.Cons (#"\n", cs) =>
        ( PS.newline pos
        ; lex_code (pos+1, cs) )
      | M.Cons (_, cs) => lex_comment_line (pos+1, cs)

(* single-line pragma #<pragma> ...
 * 'pos' is the position just after the '#' and 'charstream' the input
 * following it, so the token region starts at pos-1.  The pragma name
 * (including '#') is read with id_char; for #options and #test the rest
 * of the line becomes the pragma argument.  Any other name is an error.
 *)
and lex_pragma (pos, charstream) =
    case run_cond id_char (1, [#"#"], charstream)
      of ("#options", n, cs) =>
         (case run_cond (fn c => c <> #"\n") (0, [], cs)
           (* do not consume the newline: lex_code will see it and
            * record it with PS.newline *)
           of (line, m, cs) => (T.PRAGMA("#options", line), pos-1, pos-1+n+m, cs))
       | ("#test", n, cs) =>
         (case run_cond (fn c => c <> #"\n") (0, [], cs)
           of (line, m, cs) => (T.PRAGMA("#test", line), pos-1, pos-1+n+m, cs))
       | (s, n, cs) => error (pos-1, pos-1+n) ("unrecognized pragma: " ^ s)
                             (fn () => lexer (pos-1+n, cs))

(* delimited comment (* ... *)
 * 'depth' (>= 1) counts the currently open "(*", so comments nest;
 * after the matching "*)" lexing resumes with lex_code.  Reaching end
 * of input inside a comment is an error. *)
and lex_comment (pos, charstream, depth) = (* depth >= 1 *)
    case M.force charstream
     of M.Nil => error (pos, pos) ("unclosed delimited comment: reached end of file")
                       (fn () => lexer (pos, charstream))
      | M.Cons(#"\n", cs) =>
        ( PS.newline pos
        ; lex_comment (pos+1, cs, depth) )
      | M.Cons(#"*", cs) =>
        (case M.force cs
          of M.Cons(#")", cs) =>
             (if depth = 1 then lex_code (pos+2, cs)
              else lex_comment (pos+2, cs, depth-1))
           | _ => lex_comment (pos+1, cs, depth))
      | M.Cons(#"(", cs) =>
        (case M.force cs
          of M.Cons(#"*", cs) => lex_comment (pos+2, cs, depth+1)
           | _ => lex_comment (pos+1, cs, depth))
      | M.Cons(_, cs) => lex_comment (pos+1, cs, depth)

(* lexer (pos, charstream) = front of the token stream starting at 'pos'.
 * NOTE: the tail is not memoized; forcing it more than once re-lexes the
 * input from 'charstream' (and calls PS.newline again for every newline
 * crossed). *)
and lexer (pos, charstream) =
    let val (token, left_pos, right_pos, charstream) = lex_code (pos, charstream)
    in M.Cons ((token, (left_pos, right_pos)), fn () => lexer (right_pos, charstream)) end

(* some infrastructure to allow strings, files, and
 * interactive streams to be lexed
 *)

(* buffered_stream source = character stream over the input produced by
 * repeated calls 'source 1024'; a result of "" ends the stream.
 * Each chunk is read exactly once: the stream for the following chunk
 * is allocated once per chunk and shared by all of that chunk's
 * character nodes, so forcing any node again (for example when the
 * token stream is re-forced) replays the same characters instead of
 * calling 'source' a second time.
 *)
fun buffered_stream source =
    let
        (* chars_from (str, i, rest) streams str[i], str[i+1], ...
         * and then continues with 'rest' *)
        fun chars_from (str, i, rest) =
            if i = String.size str
            then rest
            else fn () => M.Cons (String.sub (str, i), chars_from (str, i+1, rest))

        (* refill () = stream that reads the next chunk from 'source' the
         * first time it is forced and returns that same result on every
         * later force *)
        fun refill () =
            let
                val memo = ref (fn () => raise Match)
                fun read () =
                    let val ans =
                            case source 1024 of
                                "" => M.Nil
                              | s => M.Cons (String.sub (s, 0), chars_from (s, 1, refill ()))
                    in memo := (fn () => ans); ans end
            in memo := read; (fn () => (!memo) ()) end
    in refill () end

(* str_stream str = character stream over the fixed string 'str'
 * (the source returns 'str' once and "" afterwards).
 * Not exported through LEX. *)
fun str_stream str =
    let val r = ref false
    in buffered_stream (fn _ => if !r then "" else (r := true; str)) end

(* makeLexer source = token stream over the input supplied by 'source'
 * (see buffered_stream); start counting at pos = 1 for error messages *)
fun makeLexer source = fn () => lexer (1, buffered_stream source)

end (* struct Lex *)
