(* Lexer *)
(* Author: Frank Pfenning <fp@cs.cmu.edu> *)

(*
 * A simple stream-processing lexer that supports single-line and
 * nestable comments.  A lexical error exception carries a continuation
 * stream so that lexing can continue past the error.
 *)

signature LEX =
sig
    (* A token paired with its source region (left position, right position). *)
    type lexresult = Terminal.terminal * (int * int)

    (* Builds the token stream for the characters produced by the given
     * chunked source function. *)
    val makeLexer : (int -> string) -> lexresult Stream.stream

    (* Raised once a lexical error has been reported; the payload is the
     * stream of tokens following the error, so lexing can be resumed. *)
    exception LexError of lexresult Stream.stream
end

structure Lex :> LEX =
struct

structure PS = ParseState
structure M = Stream
structure T = Terminal

(* lexresult = (token, (lpos, rpos)), where [lpos,rpos) is the
 * character region of the token in the input *)
type lexresult = T.terminal * (int * int)

exception LexError of lexresult Stream.stream (* continuation stream *)

(* error (lpos, rpos) msg token_stream
 * Reports 'msg' for the region [lpos,rpos) through ErrorMsg.error and then
 * raises LexError carrying 'token_stream', the stream of tokens from which
 * lexing resumes.  It never returns normally, so it can be used at any
 * result type.
 *)
fun error (lpos, rpos) msg token_stream =
    ( ErrorMsg.error (PS.ext(lpos,rpos)) msg
    ; raise LexError token_stream )

(* identifiers start with a letter or an underscore ... *)
fun id_start_char c =
    Char.isAlpha c
    orelse c = #"_"

(* ... and continue with letters, underscores, digits, or dollar signs *)
fun id_char c =
    id_start_char c
    orelse Char.isDigit c
    orelse c = #"$"

(* run_cond cond (n, accum, cs) = (string, n', cs')
 * accumulate characters in character stream 'cs' satisfying 'cond' in
 * 'accum' and return string, number of characters, and remaining
 * character stream
 *)
fun run_cond cond (n, accum, cs) =
    case M.force cs
     of M.Nil => (String.implode (List.rev accum), n, cs)
      | M.Cons (c, cs') =>
        if cond c
        then run_cond cond (n+1, c::accum, cs')
        else (String.implode (List.rev accum), n, cs)

(* lex_code (pos, charstream) = (token, lpos, rpos, cs')
 * Skips whitespace and comments starting at position 'pos', then lexes
 * one token.  [lpos,rpos) is the token's source region and cs' is the
 * remaining character stream.  Every newline is recorded with PS.newline
 * so that error messages can be mapped back to lines.  At end of input
 * the token is T.EOF with an empty region.  On a lexical error the error
 * is reported and LexError is raised with a token stream that resumes
 * after the offending input.
 *)
fun lex_code (pos, charstream) =
    case M.force charstream
     of M.Nil => (T.EOF, pos, pos, charstream)
      (* Whitespace *)
      | M.Cons (#" ", cs) => lex_code (pos+1, cs)
      | M.Cons (#"\t", cs) => lex_code (pos+1, cs)
      | M.Cons (#"\r", cs) => lex_code (pos+1, cs)
      | M.Cons (#"\n", cs) =>
        ( PS.newline pos        (* track newlines for error messages *)
        ; lex_code (pos+1, cs) )
      (* Separators *)
      | M.Cons (#",", cs) => (T.COMMA, pos, pos+1, cs)
      | M.Cons (#":", cs) => (T.COLON, pos, pos+1, cs)
      | M.Cons (#"/", cs) =>
        (* a slash is only legal as the start of a comment *)
        (case M.force cs
          of M.Cons (#"/", cs') => lex_comment_line (pos+2, cs')
           | M.Cons (#"*", cs') => lex_comment (pos+2, cs', 1)
           | M.Cons (c, _) =>
             (* Resume right after the slash, not after c: c may be a
              * newline that must still reach PS.newline, or the first
              * character of a legitimate token that must not be lost. *)
             error (pos, pos+2) ("illegal sequence: '/" ^ Char.toString c ^ "'")
                   (fn () => lexer (pos+1, cs))
           | M.Nil => error (pos, pos+1) ("illegal final character: '/'")
                      (fn () => lexer (pos+1, cs)))
      | M.Cons (#"(", cs) => (T.LPAREN, pos, pos+1, cs)
      | M.Cons (#")", cs) => (T.RPAREN, pos, pos+1, cs)
      | M.Cons (#"{", cs) => (T.LBRACE, pos, pos+1, cs)
      | M.Cons (#"}", cs) => (T.RBRACE, pos, pos+1, cs)
      | M.Cons (#"=", cs) =>
        (* an arrow is one two-character token; a lone equals sign is T.EQ *)
        (case M.force cs
          of M.Cons(#">", cs) => (T.RIGHTARROW, pos, pos+2, cs)
           | _ => (T.EQ, pos, pos+1, cs))
      | M.Cons (#"*", cs) => (T.STAR, pos, pos+1, cs)
      | M.Cons (#"+", cs) => (T.PLUS, pos, pos+1, cs)
      | M.Cons (#"|", cs) => (T.BAR, pos, pos+1, cs)
      | M.Cons (#"'", cs) =>
        (* label: a tick followed by at least one identifier character;
         * the region includes the tick *)
        let val (label_string, n, cs) = run_cond id_char (0, [], cs)
            val () = if n = 0 then error (pos, pos+1) "tick not followed by label name"
                                         (fn () => lexer (pos+1, cs))
                     else ()
        in (T.LABEL(label_string), pos, pos+n+1, cs) end
      | M.Cons (c, cs') =>
        if Char.isDigit c
        then let val (num_string, n, cs) = run_cond Char.isDigit (0, [], charstream)
                 (* Int.fromString raises Overflow when the value exceeds Int *)
                 val num = Option.valOf(Int.fromString(num_string))
                     handle Overflow => error (pos, pos+n) ("number '" ^ num_string ^ "' too large")
                                        (fn () => lexer (pos+n, cs))
             in (T.NAT(num), pos, pos+n, cs) end
        else if id_start_char c
        then (case run_cond id_char (0, [], charstream)
               of ("read", n, cs) => (T.READ, pos, pos+n, cs)
                | ("write", n, cs) => (T.WRITE, pos, pos+n, cs)
                | ("cut", n, cs) => (T.CUT, pos, pos+n, cs)
                | ("id", n, cs) => (T.ID, pos, pos+n, cs)
                | ("call", n, cs) => (T.CALL, pos, pos+n, cs)
                | ("type", n, cs) => (T.TYPE, pos, pos+n, cs)
                | ("proc", n, cs) => (T.PROC, pos, pos+n, cs)
                | ("fail", n, cs) => (T.FAIL, pos, pos+n, cs)
                | ("value", n, cs) => (T.VALUE, pos, pos+n, cs)
                | (ident, n, cs) => (T.IDENT(ident), pos, pos+n, cs))
        else error (pos, pos+1) ("illegal character: '" ^ Char.toString c ^ "'")
                   (fn () => lexer (pos+1, cs'))

(* single-line comment // ... \n
 * Skips up to and including the next newline (recorded with PS.newline),
 * then continues lexing code.  End of input inside the comment yields EOF.
 *)
and lex_comment_line (pos, charstream) =
    case M.force charstream
     of M.Nil => (T.EOF, pos, pos, charstream)
      | M.Cons (#"\n", cs) =>
        ( PS.newline pos
        ; lex_code (pos+1, cs) )
      | M.Cons (_, cs) => lex_comment_line (pos+1, cs)

(* delimited comment /* ... */
 * Comments may be nested; 'depth' (always >= 1) is the number of currently
 * open delimiters.  Lexing of code resumes once depth drops to 0.
 * Reaching end of input with an open comment is a lexical error.
 *)
and lex_comment (pos, charstream, depth) = (* depth >= 1 *)
    case M.force charstream
     of M.Nil => error (pos, pos) ("unclosed delimited comment: reached end of file")
                       (fn () => lexer (pos, charstream))
      | M.Cons(#"\n", cs) =>
        ( PS.newline pos
        ; lex_comment (pos+1, cs, depth) )
      | M.Cons(#"*", cs) =>
        (case M.force cs
          of M.Cons(#"/", cs) =>
             (if depth = 1 then lex_code (pos+2, cs)
              else lex_comment (pos+2, cs, depth-1))
           | _ => lex_comment (pos+1, cs, depth))
      | M.Cons(#"/", cs) =>
        (case M.force cs
          of M.Cons(#"*", cs) => lex_comment (pos+2, cs, depth+1)
           | _ => lex_comment (pos+1, cs, depth))
      | M.Cons(_, cs) => lex_comment (pos+1, cs, depth)

(* lexer (pos, charstream) = front of the token stream starting at 'pos'.
 * The tail lazily lexes the rest, starting at the end of the current token.
 *)
and lexer (pos, charstream) =
    let val (token, left_pos, right_pos, charstream) = lex_code (pos, charstream)
    in M.Cons ((token, (left_pos, right_pos)), fn () => lexer (right_pos, charstream)) end

(*
 * Infrastructure to turn a chunked character source (a string, a file,
 * or an interactive stream) into a character stream.  buffered_stream is
 * used by makeLexer; str_stream is currently unused and not exported.
 *)

(* buffered_stream source
 * 'source' is called with 1024 (presumably a maximum chunk size; confirm
 * against callers) and returns the next chunk of input; an empty chunk
 * means end of input.  Chunks are requested lazily.  The thunk that reads
 * the next chunk is memoized and created once per chunk, shared by every
 * cell of the previous chunk, so re-forcing cells (which happens whenever
 * a token-stream tail is forced again) never calls 'source' a second time
 * and never skips input.
 *)
fun buffered_stream source =
    let
        (* the characters str[i..len), followed by the stream 'rest' *)
        fun use_buf (str, len, i, rest) =
            if i = len
            then rest
            else fn () => M.Cons (String.sub (str, i), use_buf (str, len, i+1, rest))

        (* a memoized thunk that reads one chunk the first time it is forced *)
        fun refill_buf () =
            let
                val memo = ref (fn () => raise Match)
                fun read () =
                    let val ans =
                            case source 1024 of
                                "" => M.Nil
                              | s => M.Cons (String.sub (s, 0),
                                             use_buf (s, size s, 1, refill_buf ()))
                    in memo := (fn () => ans); ans end
            in memo := read; (fn () => (!memo) ()) end
    in refill_buf () end

(* str_stream str = character stream over the characters of 'str' *)
fun str_stream str =
    let val r = ref false
    in buffered_stream (fn _ => if !r then "" else (r := true; str)) end

(* makeLexer source = token stream over the characters supplied by 'source'.
 * Each force of the returned stream builds a fresh buffered_stream over
 * 'source'.  Positions start at 1 for error messages.
 *)
fun makeLexer source = fn () => lexer (1, buffered_stream source)

end (* struct Lex *)
