(* pcfcode.sml 
   Implementation of PCF-related functions: unparsing, maintenance
   of PCF environment, compilation to sequential combinators. *)

(* Interface of the PCF front end.  Both entry points take a pair of
   callbacks: the first receives a compiled CDS expression, the second a
   CDS command. *)
signature PCF_CODE =
    sig
    (* Raised by pcfCommand when it is given the quit command. *)
    exception ExitPCFLoop
    (* Execute one parsed PCF command, using the callbacks to process the
       resulting expression or command. *)
    val pcfCommand : 
	      ((CDSBasic.expr -> unit)*(CDSBasic.command_tree -> unit)) ->
	      PCFBasic.PCF_ParseTree -> unit
    (* Interactive loop: print a prompt, read one PCF command and execute
       it with pcfCommand, repeatedly. *)
    val pcf : ((CDSBasic.expr -> unit)*(CDSBasic.command_tree -> unit)) -> unit
    end
    

functor PcfCodeFUN (structure PCFParser : PARSER
		    structure PCFLex : LEXER
		    structure Interface : INTERFACE
		    structure Parser : PARSER_INTERFACE
		    structure Printer : PRINTER) : PCF_CODE =
    struct
    local open CDSBasic 
	  open PCFBasic
    in
    exception ExitPCFLoop

    (* Intermediate form of a PCF expression, produced by todeBruijn and
       consumed by toCC.  Identifiers are split into Free (no enclosing
       binder was found) and Bound (position of the binder in the list of
       enclosing binders, 0 being the innermost).  The remaining
       constructors mirror the PCF constructs one for one, except that
       let/letrec are already expanded into Apply/Abs/Fix. *)
    datatype deBruijn = Free of string
                      | Bound of int
		      | Abs of string * deBruijn	(* binder name kept, but not used by toCC *)
		      | Apply of deBruijn * deBruijn
		      | Intconst of int
		      | Boolconst of bool
		      | Fix of deBruijn			(* generated for letrec *)
		      | Couple of deBruijn * deBruijn
		      | Fst of deBruijn
		      | Snd of deBruijn
		      | Cons of deBruijn * deBruijn
		      | Head of deBruijn
                      | Tail of deBruijn
                      | Nil
		      | Null of deBruijn
		      | Cond of deBruijn * deBruijn * deBruijn
		      (* binary operators *)
		      | Plus of deBruijn * deBruijn
		      | Minus of deBruijn * deBruijn
		      | Times of deBruijn * deBruijn
		      | Div of deBruijn * deBruijn
		      | Equal of deBruijn * deBruijn
		      | Less of deBruijn * deBruijn
		      | Grtr of deBruijn * deBruijn
		      | Leq of deBruijn * deBruijn
		      | Geq of deBruijn * deBruijn
		      | And of deBruijn * deBruijn
		      | Or of deBruijn * deBruijn

    (* findIndex target candidates start
       Position of the first element of `candidates` equal to `target`,
       counting from `start`.  Raises FindIndex when no element matches. *)
    exception FindIndex
    fun findIndex target candidates start =
	  let fun scan ([], _) = raise FindIndex
		| scan (candidate :: rest, pos) =
		      if target = candidate then pos else scan (rest, pos + 1)
	  in scan (candidates, start)
	  end

    (* todeBruijn term binders
       Translate a PCF parse tree into a deBruijn term.  `binders` holds the
       names of the enclosing lambda binders, innermost first: an identifier
       found in it becomes Bound <position>, any other identifier becomes
       Free.  let is expanded to an application of a lambda, and letrec to
       an application of a lambda to a Fix. *)
    fun todeBruijn term binders =
	  let (* convert a subterm under the same binders *)
	      fun conv sub = todeBruijn sub binders
	      fun both (left, right) = (conv left, conv right)
	  in
	      case term of
		  pcf_Bool b => Boolconst b
		| pcf_Int i => Intconst i
		| pcf_Ident name =>
		      (Bound (findIndex name binders 0) handle FindIndex => Free name)
		| pcf_App operands => Apply (both operands)
		| pcf_Lam (name, body) => Abs (name, todeBruijn body (name :: binders))
		| pcf_Let (name, bound, body) =>
		      conv (pcf_App (pcf_Lam (name, body), bound))
		| pcf_Letrec (name, bound, body) =>
		      let val user = conv (pcf_Lam (name, body))
			  val knot = conv (pcf_Lam (name, bound))
		      in Apply (user, Fix knot)
		      end
		| pcf_Cond (test, yes, no) => Cond (conv test, conv yes, conv no)
		| pcf_Couple operands => Couple (both operands)
		| pcf_Fst sub => Fst (conv sub)
		| pcf_Snd sub => Snd (conv sub)
		| pcf_Cons operands => Cons (both operands)
		| pcf_Head sub => Head (conv sub)
		| pcf_Tail sub => Tail (conv sub)
		| pcf_Nil => Nil
		| pcf_Null sub => Null (conv sub)
		| pcf_Bop binop =>
		      (case binop of
			   pcf_Plus operands => Plus (both operands)
			 | pcf_Minus operands => Minus (both operands)
			 | pcf_Times operands => Times (both operands)
			 | pcf_Div operands => Div (both operands)
			 | pcf_Equal operands => Equal (both operands)
			 | pcf_Less operands => Less (both operands)
			 | pcf_Grtr operands => Grtr (both operands)
			 | pcf_Leq operands => Leq (both operands)
			 | pcf_Geq operands => Geq (both operands)
			 | pcf_And operands => And (both operands)
			 | pcf_Or operands => Or (both operands))
	  end

	(* Build the CDS state that stands for a constant: a boolean fills
	   cell "B" with the string "tt" or "ff", an integer fills cell "N"
	   with the corresponding arithmetic expression. *)
    fun getBoolConst b =
	  let val tag = if b then "tt" else "ff"
	  in Expr_state [(Cell_name "B", Val_string tag)]
	  end

    fun getIntConst n =
	  let val value = Val_arexpr (Arexpr_int n)
	  in Expr_state [(Cell_name "N", value)]
	  end

	(* getConst x builds the expression curry(fst).x, i.e. curry(fst)
	   applied to x. *)
    fun getConst x =
	  let val curriedFst = Expr_curry (Expr_id "fst")
	  in Expr_apply (curriedFst, x)
	  end

	(* Access path for bound identifier number n: snd composed with n
	   copies of fst, i.e. snd o (fst)^n. *)
    fun getBound n =
	  if n = 0
	  then Expr_id "snd"
	  else
	      let val outer = getBound (n - 1)
	      in Expr_compose (outer, Expr_id "fst")
	      end

    (* toCC term
       Compile a deBruijn term into a CDS combinator expression.
       Constants, free identifiers and nil are wrapped with getConst, bound
       identifiers become the access path built by getBound, abstraction
       becomes curry, application is uncurry(id) composed with the pair of
       both sides, and every primitive is the identifier naming it composed
       with its compiled argument(s). *)
    fun toCC term =
	  let (* unary primitive: <prim> o <compiled argument> *)
	      fun after (prim, arg) = Expr_compose (Expr_id prim, toCC arg)
	  in
	      case term of
		  Boolconst b => getConst (getBoolConst b)
		| Intconst n => getConst (getIntConst n)
		| Bound n => getBound n
		| Free name => getConst (Expr_id name)
		| Apply (func, arg) =>
		      Expr_compose (Expr_uncurry (Expr_id "id"),
				    Expr_pair (toCC func, toCC arg))
		| Abs (_, body) => Expr_curry (toCC body)
		| Fix body => Expr_apply (Expr_id "Yenv", toCC body)
		| Cond (test, yes, no) =>
		      let val branches = Expr_pair (toCC test, toCC yes)
		      in Expr_compose (Expr_id "cond", Expr_pair (branches, toCC no))
		      end
		| Couple (left, right) => Expr_pair (toCC left, toCC right)
		| Fst sub => after ("fst", sub)
		| Snd sub => after ("snd", sub)
		| Cons (head, tail) => getBop ("cons", head, tail)
		| Head sub => after ("hd", sub)
		| Tail sub => after ("tl", sub)
		| Nil => getConst (Expr_id "nil")
		| Null sub => after ("null", sub)
		| Plus (left, right) => getBop ("plus", left, right)
		| Minus (left, right) => getBop ("minus", left, right)
		| Times (left, right) => getBop ("times", left, right)
		| Div (left, right) => getBop ("div", left, right)
		| Equal (left, right) => getBop ("equal", left, right)
		| Less (left, right) => getBop ("less", left, right)
		| Grtr (left, right) => getBop ("grtr", left, right)
		| Leq (left, right) => getBop ("leq", left, right)
		| Geq (left, right) => getBop ("geq", left, right)
		| And (left, right) => getBop ("land", left, right)
		| Or (left, right) => getBop ("lor", left, right)
	  end

	(* Binary primitive: <name> o <compiled left, compiled right> *)
    and getBop (name, left, right) =
	  let val operands = Expr_pair (toCC left, toCC right)
	  in Expr_compose (Expr_id name, operands)
	  end

    (* Report an identifier that has no definition in the CDS environment. *)
    fun reportUndefined id =
	  output(std_out, "Error: Identifier "^id^" not defined\n")

    (* pcfCommand (ProcessEXP, ProcessCOM) command
       Execute one parsed PCF command:
	 PCF_Exp e      compile e, apply the code to "emptyenv" and hand the
			result to ProcessEXP;
	 PCF_Val(s,e)   compile e the same way and hand the abbreviation of
			s to ProcessCOM;
	 PCF_Load file  run every command of the file through pcfCommand,
			reporting an Io failure on std_out;
	 PCF_Print s    print the expression currently bound to s;
	 PCF_Quit       raise ExitPCFLoop;
	 PCF_Empty      do nothing.
       CDSEnv.Lookup is reported on std_out in every clause that can meet an
       unknown identifier, so that such an error does not escape to the
       caller (the interactive loop has no handler for it).
       NOTE(review): this assumes that ProcessEXP and CDSEnv.lookupExpType
       signal an unknown name with CDSEnv.Lookup, as the PCF_Val clause
       already expected of ProcessCOM -- confirm against CDSEnv. *)
    fun pcfCommand (ProcessEXP,_) (PCF_Exp e) = 
	  (let val code = toCC (todeBruijn e [])
	   in ProcessEXP (Expr_apply(code, Expr_id "emptyenv"))
	   end handle CDSEnv.Lookup id => reportUndefined id)
      | pcfCommand (_,ProcessCOM) (PCF_Val(s,e)) = 
	  (let val code = toCC (todeBruijn e [])
	   in 
	       ProcessCOM(Com_abbreviate(s,Expr_apply(code, Expr_id "emptyenv")))
	   end handle CDSEnv.Lookup id => reportUndefined id)
      | pcfCommand (ProcessEXP,ProcessCOM) (PCF_Load file) = 
	(Parser.pcf_load (file, pcfCommand (ProcessEXP,ProcessCOM))
	    handle Io s => output(std_out, "Error: "^s^".\n"))
      | pcfCommand _ (PCF_Print s) = 
	(let val (_,e,_) = CDSEnv.lookupExpType(s,!CDSEnv.nameExpRTypeList)
	 in output(std_out, s^" = "^(Printer.unparseExpr e)^"\n")
	 end handle CDSEnv.Lookup id => reportUndefined id)
      | pcfCommand _ (PCF_Quit) = raise ExitPCFLoop
      | pcfCommand _ (PCF_Empty) = ()

    (* pcf (ProcessEXP, ProcessCOM)
       Interactive loop: print the "$ " prompt, read one command from the
       keyboard and execute it with pcfCommand.  Parse, lexer and comment
       errors are reported on std_out and the loop goes on; any other
       exception (ExitPCFLoop in particular) propagates to the caller and
       thereby ends the loop. *)
    fun pcf (ProcessEXP, ProcessCOM) =
	let (* read and execute a single command, reporting syntax errors *)
	    fun readAndRun () =
		(let val command = Parser.pcf_kbd()
		 in pcfCommand (ProcessEXP,ProcessCOM) command
		 end)
		handle PCFParser.ParseError =>
			   output(std_out, "Error: Parsing error.\n")
		     | PCFLex.LexError =>
			   output(std_out, "Error: Lexer: illegal symbol used.\n")
		     | Interface.CommentError msg =>
			   output(std_out, "Error: "^msg^"\n")
	    fun loop () : unit =
		(output(std_out, "$ ");
		 readAndRun ();
		 loop ())
	in loop ()
	end

    end
    end;
    
