(*
 * compile/cunit.sml: CM's representation of `compiled units'
 *
 *   Copyright (c) 1995 by AT&T Bell Laboratories
 *
 * author: Matthias Blume (blume@cs.princeton.edu)
 *)
functor CUnitFun (structure Iid: IID
		  structure Control: CONTROL): CUNIT = struct

    (* Short local names for the compiler structures used throughout. *)
    structure Iid = Iid
    structure Compiler = Iid.Compiler
    structure Pid = Compiler.PersStamps
    structure Env = Compiler.Environment
    structure Pickle = Compiler.PickleEnv
    structure Bare = Compiler.BareEnvironment
    structure Err = Compiler.ErrorMsg
    structure Comp = Compiler.Compile
    structure Print = Compiler.Control.Print
    structure Source = Compiler.Source

    (* FormatError: a binfile does not follow the expected layout.
     * Outdated:    a binfile is not newer than its source, or it references
     *              iids that are not provided.
     * Compile:     a compilation phase failed; the string names the phase
     *              or reason.
     * NoCodeBug:   code was requested from a unit whose code segments are
     *              no longer available. *)
    exception FormatError and Outdated and Compile of string and NoCodeBug

    type pid = Pid.persstamp
    type senv = Env.staticEnv
    type denv = Env.dynenv
    type symenv = Env.symenv
    type env = Env.environment
    type iid = Iid.t
    type iidset = Iid.set
    type obj = System.Unsafe.object
    type objvec = obj Vector.vector

    (* State of a unit's executable code:
     *  DISCARDED - code has been dropped (see discardCode)
     *  CSEGS     - raw code segments (from codegen or a binfile)
     *  CLOSURE   - result of Comp.applyCode on the segments, cached by
     *              codeClosure *)
    datatype code =
	DISCARDED
      | CSEGS of Comp.csegments
      | CLOSURE of objvec -> obj

    (* A compiled unit. *)
    datatype t =
	CU of {
	       imports: pid list,	(* pids passed to Comp.execute *)
	       exportPid: pid option,	(* at most one export pid *)
	       references: iidset,	(* iids this unit was built against *)
	       iid: iid,		(* this unit's own iid *)
	       senv: senv,		(* static environment *)
	       penv: Compiler.PickleEnv.pickledEnv, (* pickled senv, as stored in the binfile *)
	       env: env option ref,	(* full env, filled in by execute *)
	       code: code ref,		(* see datatype code *)
	       lambda_i: Comp.clambda option (* inlinable part, if any *)
	      }

    (* Return the raw code segments of a unit.  Raises NoCodeBug if the
     * segments are gone (turned into a closure or discarded). *)
    fun codeSegments (CU { code, ... }) =
	case !code of
	    CSEGS segments => segments
	  | _ => raise NoCodeBug

    (* Return the executable closure of a unit.  On first use the code
     * segments are applied and the resulting closure is cached in the
     * unit's code cell.  Raises NoCodeBug if the code was discarded. *)
    fun codeClosure (CU { code = cell, ... }) = let
	fun install segs = let
	    val clos = Comp.applyCode segs
	in
	    cell := CLOSURE clos;
	    clos
	end
    in
	case !cell of
	    CLOSURE clos => clos
	  | CSEGS segs => install segs
	  | DISCARDED => raise NoCodeBug
    end

    (* Drop whatever code (segments or closure) the unit holds. *)
    fun discardCode (CU r) = #code r := DISCARDED

    (* Field accessors. *)
    fun iid (CU r) = #iid r
    fun senv (CU r) = #senv r
    (* Symbolic environment delta computed from the export pid and the
     * inlinable lambda. *)
    fun symenv (CU r) = Comp.symDelta (#exportPid r, #lambda_i r)
    (* Full environment, NONE until execute has run. *)
    fun env (CU r) = !(#env r)

    (* Word32 bit operations under short symbolic names (made infix
     * below), used for big-endian (de)serialization of header fields. *)
    val op << = Word32.<<
    val op >> = Word32.>>
    val op || = Word32.orb
    val op && = Word32.andb
    val fromInt = Word32.fromInt
    val toInt = Word32.toInt
    (* Widen a byte to a Word32, and keep only the low byte of a Word32. *)
    val fromByte = Word32.fromLargeWord o Word8.toLargeWord
    val toByte = Word8.fromLargeWord o Word32.toLargeWord
    infix << >> || &&

    (*
     * layout of binfiles (P = bytesPerPid = 16 bytes per pid):
     *  - 0..x-1:		magic string (length = x)
     *  - x..x+3:		# of imports (= y)
     *  - x+4..x+7:		# of exports (= z, either 0 or 1)
     *  - x+8..x+11:		size in bytes of the env iid area (= w,
     *			a multiple of Iid.size)
     *  - x+12..x+15:		size lambda_i
     *  - x+16..x+19:		size reserved area1 (must be 0)
     *  - x+20..x+23:		size reserved area2 (must be 0)
     *  - x+24..x+27:		size code
     *  - x+28..x+31:		size env
     *  - x+32..x+P*y+31:	import pids
     *  - x+P*y+32..x+P*(y+z)+31:	export pids
     *  - ...			env iids (w bytes: the unit's own iid
     *			followed by the iids it references)
     *  - 			lambda_i
     *  - 			reserved area1
     *  - 			reserved area2
     *  -			code (segments, each preceded by its own
     *			4-byte size; "size code" includes these fields)
     *  -			pickled_env
     *  EOF
     *
     * All counts and sizes are represented in big-endian layout.
     * This should be tracked by the run-time system header file
     * "runtime/include/bin-file.h"
     *)

    (* The binfile magic string: 6 chars of compiler version, 6 chars of
     * architecture name, 3 chars of runtime stamp, then a newline. *)
    val MAGIC = let
	(* Force s to exactly n characters: truncate it if too long,
	 * right-pad it with blanks if too short. *)
	fun fit (n, s) =
	    if size s > n then substring (s, 0, n)
	    else if size s < n then fit (n, s ^ " ")
	    else s
	(* [1, 2, 3] -> ["1", ".", "2", ".", "3"] *)
	fun dotted [] = []
	  | dotted [x : int] = [makestring x]
	  | dotted (x :: rest) = makestring x :: "." :: dotted rest
	val versionField =
	    fit (6, concat (dotted (#version_id Compiler.version)))
	val archField = let
	    val arch = Compiler.architecture
	    (* drop a leading `.' from the architecture name *)
	    val trimmed =
		if String.sub (arch, 0) = #"." then
		    substring (arch, 1, size arch - 1)
		else arch
	in
	    fit (6, trimmed)
	end
	val stampField = fit (3, !System.runtimeStamp)
    in
	concat [versionField, archField, stampField, "\n"]
    end

    val magicBytes = size MAGIC
    val bytesPerPid = 16

    (* Read exactly n bytes from binary stream s and return them as a
     * string.  A short read (truncated file) raises FormatError. *)
    fun sin (s, n) = let
	val str = Byte.bytesToString (BinIO.inputN (s, n))
    in
	if String.size str <> n then raise FormatError else str
    end

(** NOTE: this should use the Pack32Big structure, once it is available **)
    (* Read a 4-byte big-endian integer.  Raises FormatError on a short
     * read. *)
    fun readInt32 s = let
	val bytes = BinIO.inputN (s, 4)
	(* shift the accumulator one byte left and or in the next byte,
	 * most significant byte first *)
	fun step (b, acc) = (acc << 0w8) || fromByte b
    in
	if Word8Vector.length bytes <> 4 then raise FormatError
	else toInt (Word8Vector.foldl step 0w0 bytes)
    end

(** NOTE: we should really be using Word8Vector.vector for PIDS. **)
    (* Write a string's characters as raw bytes. *)
    fun sout (s, text) = let
	val bytes = Byte.stringToBytes text
    in
	BinIO.output (s, bytes)
    end

    (* Write i as a 4-byte big-endian integer. *)
    fun writeInt32 s i = let
	val w = fromInt i
	fun byteAt shift = BinIO.output1 (s, toByte (w >> shift))
    in
	app byteAt [0w24, 0w16, 0w8, 0w0]
    end

    (* Conversions between pids and their raw byte-string form. *)
    val mkPid = Pid.stringToStamp
    val mkString = Pid.stampToString

    (* Read one fixed-size pid / iid from a binfile. *)
    fun readPid stream = let
	val raw = sin (stream, bytesPerPid)
    in
	mkPid raw
    end
    fun readIid stream = let
	val raw = sin (stream, Iid.size)
    in
	Iid.intern raw
    end

    (* Call reader rid on s n times and return the results in read order.
     * If n <= 0, return [] without reading. *)
    fun readIdList rid (s, n) = let
	fun collect k =
	    if k <= 0 then []
	    else let
		val first = rid s
	    in
		first :: collect (k - 1)
	    end
    in
	collect n
    end

    (* Read n pids / iids. *)
    fun readPidList args = readIdList readPid args
    fun readIidList args = readIdList readIid args

    (* Write each pid / iid in list order, in raw byte form. *)
    fun writePidList (_, []) = ()
      | writePidList (s, p :: ps) = (sout (s, mkString p); writePidList (s, ps))
    fun writeIidList (_, []) = ()
      | writeIidList (s, i :: is) = (sout (s, Iid.extern i); writeIidList (s, is))

    (* Raise FormatError unless the stream starts with this compiler's
     * MAGIC string. *)
    fun checkMagic s =
	if sin (s, magicBytes) <> MAGIC then raise FormatError else ()

    (*
     * Read and validate a binfile header: the magic string, the eight
     * size/count fields, the import and export pids, and the env iids.
     * The first env iid is the unit's own iid; the rest are the iids it
     * references.  Raises FormatError on any inconsistency.
     *)
    fun readHeader s = let
	val _ = checkMagic s
	val ni = readInt32 s
	val ne = readInt32 s
	val cmInfoSzB = readInt32 s
	val sLam = readInt32 s
	val sa1 = readInt32 s
	val sa2 = readInt32 s
	val cs = readInt32 s
	val es = readInt32 s
	(* The import count must be non-negative, and the env-iid area must
	 * hold a whole, non-negative number of iids.  Otherwise the file is
	 * corrupt: truncating with `div' would desynchronize the stream. *)
	val _ = if ni < 0 orelse cmInfoSzB < 0
		   orelse cmInfoSzB mod Iid.size <> 0
		then raise FormatError
		else ()
	val nei = cmInfoSzB div Iid.size
	val imports = readPidList (s, ni)
	val exportPid = (case ne
	       of 0 => NONE
		| 1 => SOME (readPid s)
		| _ => raise FormatError
	      (* end case *))
	val envIids = readIidList (s, nei)
    in
	case envIids
	 of (iid :: references) => {
		nImports=ni, nExports=ne, nEnvPids=nei, lambdaSz=sLam,
		res1Sz=sa1, res2Sz=sa2, codeSz=cs, envSz=es,
		imports=imports, exportPid=exportPid,
		references=references, iid=iid
	      }
	  | [] => raise FormatError
    end

    (* Read exactly n bytes from s and un-blast them into a value.  A short
     * read means a truncated binfile and raises FormatError.  Defined with
     * `fun' so that it is polymorphic: it is used both for the plambda
     * and for the pickled environment, and a `val' bound to an `o'
     * application would not be generalized (value restriction). *)
    fun blastRead (s, n) = let
	val bytes = BinIO.inputN (s, n)
    in
	if Word8Vector.length bytes = n then System.Unsafe.blastRead bytes
	else raise FormatError
    end
    fun blastReadPlambda s: Comp.plambda = blastRead s

    (* Read code segments totalling n bytes, where each segment is a 4-byte
     * size followed by that many bytes.  Must be called with second
     * arg >= 0.  Raises FormatError if a segment size is negative or
     * overruns the remaining byte count. *)
    fun readCodeList (_, 0) = []
      | readCodeList (s, n) = let
	    val sz = readInt32 s
	    val n' = n - sz - 4
	in
	    if sz < 0 orelse n' < 0 then raise FormatError
	    else let
		val c = sin (s, sz)
	    in
		c :: readCodeList (s, n')
	    end
	end

    (*
     * Read a complete compiled unit from binary stream s.
     *  - check: applied to the unit's set of referenced iids right after
     *    the header has been read (e.g. check_iids raises Outdated).
     *  - unpickle: turns the pickled static environment into a senv.
     *  - keep_code: keep the code segments in the unit, or discard them.
     * Raises FormatError if the file does not follow the binfile layout.
     *)
    fun readUnit (s, check, unpickle, keep_code) = let
	val {nImports=_, nExports=_, nEnvPids=_, lambdaSz=lamSz,
	     res1Sz, res2Sz, codeSz=cs, envSz=es,
	     imports, exportPid, references, iid} = readHeader s
	val references = Iid.makeset references
	val _ = check references
	(* The lambda area directly follows the env iids and is read exactly
	 * once.  (Skipping it before blast-reading it would consume the code
	 * bytes instead.) *)
	val lambda_i =
	    if lamSz = 0 then NONE
	    else if lamSz < 0 then raise FormatError
	    else SOME (Comp.unpickle (blastReadPlambda (s, lamSz)))
	(* the reserved areas are always empty *)
	val _ = if res1Sz = 0 andalso res2Sz = 0 then () else raise FormatError
	val code = case readCodeList (s, cs) of
	    [] => raise FormatError
	  | c0 :: cn => { c0 = c0, cn = cn }
	(* everything after the code must be exactly the pickled env *)
	val remain =
	    Position.toInt (BinIO.endPosIn s) -
	    Position.toInt (BinIO.getPosIn s)
	val _ = if remain = es then () else raise FormatError
	val pickled_env = blastRead (s, es)
    in
	CU {
	    imports = imports,
	    exportPid = exportPid,
	    references = references,
	    iid = iid,
	    senv = unpickle pickled_env,
	    penv = pickled_env,
	    env = ref NONE,
	    code = ref (if keep_code then CSEGS code else DISCARDED),
	    lambda_i = lambda_i
	   }
    end

    local
	(* Render a pid as lowercase hex: two digits per character of its
	 * Pid.stampToString form. *)
	fun show1 pid = let
	    val hexDigits = "0123456789abcdef"
	    fun showc (c, r) = let
		val i = Char.ord c
	    in
		String.sub (hexDigits, i div 16) ::
		String.sub (hexDigits, i mod 16) :: r
	    end
	in
	    String.implode (foldr showc [] (String.explode (Pid.stampToString pid)))
	end
	(* Print "m: <hex> <hex> ...\n" via Control.say. *)
	fun showl (m, l) =
	    Control.say (concat (m :: ":" ::
				 (foldr (fn (x, r) => " " :: show1 x :: r)
				  ["\n"] l)))
	(* Dump both pid lists, only when Control.debug NONE holds. *)
	fun pids_debug (m1, m2) (p1, p2) =
	    if Control.debug NONE then
		(showl ("\n!-- " ^ m1, p1); showl ("!-- " ^ m2, p2))
	    else
		()
	(* A unit's import pids and its (0 or 1) export pids. *)
	fun sel_imp_exp (CU { imports, exportPid, ... }) =
	    (imports, case exportPid of NONE => [] | SOME pid => [pid])
    in
	(* Hook for reporting provided vs. requested iids; currently a no-op. *)
	fun prov_req_iids _ = ()
	(* Debug dump of a unit's import and export pids. *)
	val imp_exp_pids = (pids_debug ("imports", "exports")) o sel_imp_exp
    end

    (* True iff every requested iid is among the provided ones. *)
    fun iids_ok provided requested = let
	val () = prov_req_iids (provided, requested)
    in
	Iid.isSubset (requested, provided)
    end

    (* Raise Outdated unless iids_ok p r holds. *)
    fun check_iids p r =
	if not (iids_ok p r) then raise Outdated else ()

    (* A `check' argument for readUnit that accepts anything. *)
    fun dont_bother _ = ()

    (* Convert a bare static environment to an Environment static env by
     * wrapping it with empty dynamic and symbolic parts and coercing. *)
    fun b2e be = let
	val empty = Bare.emptyEnv
	val whole = Bare.mkenv { static = be,
				 dynamic = Bare.dynamicPart empty,
				 symbolic = Bare.symbolicPart empty }
    in
	Env.staticPart (Compiler.CoerceEnv.b2e whole)
    end

    (* The converse: Environment static env -> bare static env. *)
    fun e2b e = let
	val empty = Env.emptyEnv
	val whole = Env.mkenv { static = e,
				dynamic = Env.dynamicPart empty,
				symbolic = Env.symbolicPart empty }
    in
	Bare.staticPart (Compiler.CoerceEnv.e2b whole)
    end

    (* Unpickle environment e relative to context ctxt and coerce the result
     * to an Environment static env. *)
    fun unpickle ctxt e = let
	val bare = Pickle.unPickleEnv { env = e, context = ctxt }
    in
	b2e bare
    end

    (*
     * recover: { opener, binfile, se, sourcetime, provided, keep_code }
     *            -> { u: t, bintime: Time.time } option
     *  - recover a compiled unit from binfile (read through opener ())
     *    unless the source (sourcetime) is not strictly older than the
     *    binfile, or the unit references iids not contained in provided.
     *  - se is the context used to unpickle the static environment.
     *  - returns SOME { u, bintime } on success and NONE on any failure.
     *  - the stream is closed on every path after it has been opened.
     *)
    fun recover { opener, binfile, se, sourcetime, provided, keep_code } = let
	val bintime = AbsPath.modTime binfile
	val _ =
	    case sourcetime of
		NONE => ()
	      | SOME st =>
		    if Time.< (st, bintime) then
			()
		    else
			raise Outdated
	val _ = Control.vsay (concat ["[recovering ",
				      AbsPath.elab binfile, "..."])
	(* terminate the progress message before propagating a failure *)
	fun fail exn = (Control.vsay " failed]\n"; raise exn)
	val s = opener () handle exn => fail exn
	val u = readUnit (s, check_iids provided, unpickle se, keep_code)
	    handle exn => (BinIO.closeIn s; fail exn)
	(* close before the debug dump so a failure there cannot leak s *)
	val _ = BinIO.closeIn s
    in
	Control.vsay " done]\n";
	imp_exp_pids u;
	SOME { u = u, bintime = bintime }
    end handle _ => NONE

    (* Blast-write value x to binary stream s and return the number of
     * bytes written.  BinIO.output requires a Word8Vector.vector, so the
     * size must be Word8Vector.length, not String.size. *)
    fun blastWrite (s, x) = let
	val bytes = System.Unsafe.blastWrite x
    in
	BinIO.output (s, bytes);
	Word8Vector.length bytes
    end
    fun blastWritePlambda (s, l: Comp.plambda) = blastWrite (s, l)

    (*
     * Write compiled unit u to binary stream s in binfile layout.
     * The lambda_i and env sizes are only known after their blasted forms
     * have been written, so 0 placeholders are written first and patched
     * afterwards by seeking back to sa2_pos / es_pos.  The stream is left
     * positioned just after the env-size field, not at EOF.
     * codeSegments raises NoCodeBug if u holds no code segments.  Unless
     * keep_code is set, u's code is discarded afterwards.
     *)
    fun writeUnit (s, u, keep_code) = let
	val CU { imports, exportPid, references, iid, penv, lambda_i, ... } = u
	(* own iid first, then the referenced iids (readHeader relies on this) *)
	val envIids = iid :: (Set.makelist references)
	val ni = length imports
	val (ne, epl) =
	    case exportPid of NONE => (0, []) | SOME p => (1, [p])
	val nei = length envIids
	val cmInfoSzB = nei * Iid.size
	val dummy_sa2 = 0
	val res1Sz = 0
	val res2Sz = 0
	val { c0, cn } = codeSegments u
	fun csize c = (String.size c) + 4 (* including size field *)
	val cs = foldl (fn (c, a) => (csize c) + a) (csize c0) cn
	val dummy_es = 0
	val _ = sout (s, MAGIC);
	val _ = app (writeInt32 s) [ni, ne, cmInfoSzB]
	(* position of the lambda_i size field, patched below *)
	val sa2_pos = BinIO.getPosOut s
	val _ = app (writeInt32 s) [dummy_sa2, res1Sz, res2Sz, cs];
	(* position of the env size field, patched below *)
	val es_pos = BinIO.getPosOut s
	val _ = writeInt32 s dummy_es
	val _ = writePidList (s, imports);
	val _ = writePidList (s, epl);
	val _ = writeIidList (s, envIids); (* arena1: env iids *)
	val sa2 = case lambda_i of
	    NONE => 0
	  | SOME l => blastWritePlambda (s, Comp.pickle l) (* arena 2: lambda_i *)
	(* arena3 (reserved area1) is empty *)
	(* arena4 (reserved area2) is empty *)
	(* each code segment is preceded by its 4-byte size *)
	fun codeOut c = (writeInt32 s (size c); sout (s, c))
	val _ = codeOut c0
	val _ = app codeOut cn
	val es = blastWrite (s, penv)
    in
	(* the placeholder is already 0, so only a non-empty lambda_i needs
	 * patching *)
	if sa2 > 0 then
	    (BinIO.setPosOut (s, sa2_pos);
	     writeInt32 s sa2)
	else ();
	BinIO.setPosOut (s, es_pos);
	writeInt32 s es;
	(* BinIO.setPosOut (s, BinIO.endPosOut s); *)
	if keep_code then () else discardCode u
    end

    (* Best-effort removal of a file: any exception raised by
     * OS.FileSys.remove is ignored. *)
    fun deleteFile path = let
	fun attempt () = OS.FileSys.remove path
    in
	attempt () handle _ => ()
    end

    (* Write u to the stream returned by opener (), guarded against
     * interrupts.  If anything fails, the stream is closed, the partially
     * written binfile is deleted, and the exception is re-raised. *)
    fun saveUnit (opener, binfile, u, keep_code) = let
	val s = opener ()
	fun write () = writeUnit (s, u, keep_code)
	fun abandon exn = (BinIO.closeOut s; raise exn)
    in
	(Interrupt.guarded write handle exn => abandon exn);
	BinIO.closeOut s;
	Control.vsay (concat ["[wrote ", AbsPath.elab binfile, "]\n"]);
	()
    end handle exn => let
	val binstring = AbsPath.elab binfile
    in
	deleteFile binstring;
	Control.vsay (concat ["[writing ", binstring, " failed]\n"]);
	raise exn
    end

    local
	(*
	 * Our own `coreEnvRef' so that `Batch' can set it.  The cell is
	 * seeded once, when the functor is applied, from the compiler's
	 * Env.coreEnvRef.
	 *)
	val cell = ref (#get Env.coreEnvRef ())
	fun get () = !cell
	fun set x = cell := x
    in
	val coreEnvRef = { get = get, set = set }
    end

    (* Raise Compile "<phase> failed" if any errors have been recorded
     * in errs. *)
    fun check errs phase =
	if not (Err.anyErrors errs) then ()
	else raise Compile (concat [phase, " failed"])

    (*
     * create_pid (pidopt, splitting)
     *     { ast, source, name, opener, binfile, senv, symenv,
     *       provided, keep_code } -> t
     *  - create a compiled unit from an abstract syntax tree and
     *    a static environment by compiling the source file
     *    (elaborate, translate, inline, split, codegen)
     *  - if pidopt is SOME s, the single import pid is replaced by the
     *    pid made from s; any other number of imports raises Compile
     *  - splitting is passed to Comp.split as its `enable' flag
     *  - cache the unit as a binary file in the file system (saveUnit)
     *  - name is currently not used
     *)
    fun create_pid (pidopt, splitting)
	{ ast, source, name, opener, binfile, senv, symenv,
	  provided, keep_code } = let
	val errors = Err.errors source
	val senv' = e2b senv
	val corenv = e2b (#get coreEnvRef ())
	(* each `before check ...' runs after the phase has returned and
	 * raises Compile "<phase> failed" if that phase reported errors *)
	val { absyn, newenv, exportLexp, staticPid, exportPid } =
	    Comp.elaborate { errors = errors, compenv = senv', corenv = corenv,
			     transform = (fn x => x), ast = ast }
	    before check errors "elaboration"
	val statenv = Bare.layerStatic (newenv, senv')
	val { genLambda, imports } =
	    Comp.translate { errors = errors, absyn = absyn, corenv = corenv,
			     exportLexp = exportLexp, statenv = statenv,
			     exportPid = exportPid }
	    before check errors "translation"

	(* optionally substitute a fixed ("fake") import pid *)
	val imports =
	    case pidopt of
		NONE => imports
	      | SOME s => let
		    val _ = Control.vsay ("fake import pid = \"" ^ s ^ "\"\n")
		    val p = mkPid s
		in
		    case imports of
			[_] => [p]
		      | _ => raise Compile "core compilation failed"
		end

	val lambda = Comp.inline { genLambda = genLambda, imports = imports,
				   symenv = symenv }

	(* lambda_e is compiled to code; lambda_i is kept in the unit *)
	val { lambda_e, lambda_i, pid = lambdaPid } =
	    Comp.split { lambda = lambda, enable = splitting }

	val code = Comp.codegen { errors = errors, lambda = lambda_e }
	    before check errors "codegen"
	val cref = ref (CSEGS code)
	val u = CU {
		    imports = imports,
		    exportPid = exportPid,
		    (* the unit records the provided iids as its references *)
		    references = provided,
		    iid = Iid.new { senv = staticPid, lambda = lambdaPid },
		    senv = b2e newenv,
		    penv = Pickle.pickleEnv { env = newenv, context = senv },
		    env = ref NONE,
		    code = cref,
		    lambda_i = lambda_i
		   }
    in
	saveUnit (opener, binfile, u, keep_code);
	imp_exp_pids u;
	u
    end

    (* Normal compilation: keep the real import pids and enable splitting. *)
    fun create args = create_pid (NONE, true) args

    (* A unit is valid if its referenced iids are all provided.  When the
     * third argument is true, a unit whose code was discarded is invalid. *)
    fun isValid (CU { code, references, ... }, provided, needCode) =
	case (!code, needCode) of
	    (DISCARDED, true) => false
	  | _ => iids_ok provided references

    (* Parse the source file `file' (with description desc).  Returns the
     * AST together with the Source.inputSource used for error reporting.
     * The underlying text stream is closed on every path, including when
     * source creation fails.  A parse failure (Comp.Compile msg) is
     * re-raised as Compile msg. *)
    fun parse { file, desc } = let
	val file = AbsPath.elab file
	val s = TextIO.openIn file
	fun closeAndRaise exn = (TextIO.closeIn s; raise exn)
	val source =
	    Source.newSource (desc, 1, s, false,
			      {
			       linewidth = !Print.linewidth,
			       flush = Print.flush,
			       consumer = Print.say
			      },
			      Compiler.Index.openIndexFile file)
	    handle exn => closeAndRaise exn
	val ast = Comp.parse source
	    handle Comp.Compile msg => closeAndRaise (Compile msg)
		 | exn => closeAndRaise exn
    in
	TextIO.closeIn s; { ast = ast, source = source }
    end

    (* Run the unit's code against dynamic environment denv, build the
     * full environment, cache it in the unit's env cell, and return it. *)
    fun execute (u as CU { imports, exportPid, senv, env, ... }, denv) = let
	val dynamic = Comp.execute { executable = codeClosure u,
				     imports = imports,
				     exportPid = exportPid,
				     dynenv = denv }
	val full = Env.mkenv { static = senv, dynamic = dynamic,
			       symbolic = symenv u }
    in
	env := SOME full;
	full
    end

    (* Compute the pid of a static environment. *)
    fun senv2pid se = Comp.makePid (e2b se)

    (* Compile boot source file sf to binfile bf in static env senv and
     * symbolic env sye.  Returns the new unit's static env and symbolic
     * env delta.  The parsed source is closed even if compilation fails.
     * (create_pid already performs the import/export debug dump.) *)
    fun compileBootFile (pidopt, splitting) (sf, bf, senv, sye) = let
	val sfs = AbsPath.elab sf
	val _ = Control.vsay (concat ["[compiling (boot) ", sfs,
				      " -> ", AbsPath.elab bf, "]\n"])
	val { ast, source } = parse { file = sf, desc = sfs }
	val u as CU { senv = unitSenv, ... } =
	    create_pid (pidopt, splitting)
	      { ast = ast, source = source, name = sf,
		opener = fn () => BinIO.openOut (AbsPath.elab bf),
		binfile = bf, senv = senv, symenv = sye,
		provided = Set.empty, keep_code = false }
	    handle exn => (Source.closeSource source; raise exn)
	val _ = Source.closeSource source
    in
	(unitSenv, symenv u)
    end

    (* Load a unit, with its code, from binfile bf without checking its
     * iids.  senv is the unpickling context.  The stream is closed before
     * the debug dump, so a failure there cannot leak it. *)
    fun fetchUnit (bf, senv) = let
	val s = BinIO.openIn (AbsPath.elab bf)
	val cu = readUnit (s, dont_bother, unpickle senv, true)
	    handle exn => (BinIO.closeIn s; raise exn)
	val _ = BinIO.closeIn s
    in
	imp_exp_pids cu;
	cu
    end

    (* Load a unit from binfile bf, drop its code, and return its static
     * env and symbolic env delta. *)
    fun fetchObjectEnv (bf, se) = let
	val cu = fetchUnit (bf, se)
	val () = discardCode cu
    in
	(senv cu, symenv cu)
    end

end
