(* Monadic reflection via callcc.

   Based on Andrzej Filinski's PhD thesis and his paper Representing
   Layered Monads, Proceedings of POPL '99.

   Kevin Watkins, October 2004 *)

(* First, the reflection of SML/NJ's intrinsic continuation-passing monad. *)

(* Reflection and reification for the built-in continuation monad.  A
   behavior is a computation that never returns a value: the only way it
   can finish is by throwing to some continuation. *)
signature REFLECT =
sig
  type behavior
  (* reflect f: give f the current continuation, as a function from values
     to behaviors, run the behavior f builds, and return whatever value is
     eventually thrown to that continuation. *)
  val reflect : (('a -> behavior) -> behavior) -> 'a
  (* reify f: a behavior that, when run, calls f () and runs its result. *)
  val reify : (unit -> behavior) -> behavior
  (* Raised by an injected behavior that is run while no projection is
     active, i.e. while the handler cell still holds its initial value. *)
  exception Embed
  (* embed (): a fresh injection/projection pair between 'a and behavior. *)
  val embed : unit -> ('a -> behavior) * (behavior -> 'a)
end

structure Reflect :> REFLECT =
struct
  (* An uninhabited type: building a Void would need a void already. *)
  datatype void = Void of void
  val coerce = fn (x:void) => raise Match (* Never runs: no value of type void can exist. *)
  (* A behavior cannot return normally (that would produce a void), so it
     can only finish by throwing to a continuation. *)
  type behavior = unit -> void
  structure C = SMLofNJ.Cont
  (* Capture the continuation k, give f a way to resume k with a value, and
     run the resulting behavior.  Since that behavior cannot return, callcc
     returns only when something throws to k. *)
  fun reflect f = C.callcc (fn k => coerce (f (fn x => fn () => C.throw k x) ()))
  fun reify f = fn () => f () ()
  exception Embed
  (* The handler of the innermost active projection.  Injected behaviors
     call whatever handler is installed here when they run. *)
  val r = ref (fn x => raise Embed) : (exn -> void) ref
  (* The local exception E tags values of this particular embedding so they
     can pass through the single exn-typed cell r.
     - inject x: a behavior that hands E x to the current handler.
     - project b: capture the continuation k, remember the current handler
       in s, and install a handler that, on E x, restores s and throws x to
       k; then run b.  Projections therefore nest.  The installed handler
       matches only this embedding's E; a value with any other tag raises
       Match. *)
  fun 'a embed () =
    let exception E of 'a in
      (fn x => fn () => (!r)(E x),
       fn b => C.callcc(fn k => let val s=(!r) in r:=(fn E x => (r:=s; C.throw k x)); coerce(b()) end))
    end
end

(* Next, the reflection of an arbitrary monad in terms of the reflection of
   SML/NJ's continuation-passing monad, with each monad layered on the one
   beneath it.
*)

(* A monad, plus the two extra operations the reflection machinery needs:
   glue builds a monadic value from a thunk (the thunked monads below delay
   the call, while ListR just makes it), and show renders a string-valued
   computation as a plain string. *)
signature MONAD =
sig
  type 'a t
  val unit : 'a -> 'a t
  val bind : ('a -> 'b t) -> ('a t -> 'b t)
  val glue : (unit -> 'a t) -> 'a t
  val show : string t -> string
end

(* What a new layer needs from the layer beneath it: an injection/projection
   pair between arbitrary values and behaviors, and a way to run a
   string-producing computation to completion. *)
signature LINKAGE =
sig
  val embed : unit -> ('a -> Reflect.behavior) * (Reflect.behavior -> 'a)
  val run : (unit -> string) -> string
end

(* A monad with reflection.  reflect performs a monadic value as an effect
   in direct style, and reify packages a direct-style computation back into
   a monadic value.  It is itself a LINKAGE, so further layers can be
   stacked on top of it. *)
signature REFLECT_MONAD =
sig
  include MONAD
  include LINKAGE
  val reflect : 'a t -> 'a
  val reify : (unit -> 'a) -> 'a t
end

(* The bottom layer: embeds directly through Reflect, and runs a
   computation by simply calling it. *)
structure Base :> LINKAGE =
struct
  val embed = Reflect.embed
  fun run t = t ()
end

(* Build reflection for monad M on top of the layer Link. *)
functor ReflectMonad(structure M : MONAD structure Link : LINKAGE)
  :> REFLECT_MONAD where type 'a t = 'a M.t =
struct
  structure R = Reflect
  open M
  (* One embedding of exn t into the lower layer, shared by everything
     below.  The local exceptions in reify and embed let it carry values of
     any type. *)
  val (f : exn t -> R.behavior, g : R.behavior -> exn t) = Link.embed()
  (* Capture the current continuation k, then bind m so that each result a
     resumes k.  The resumed behavior is projected back into exn t by g,
     inside glue.  Finally, inject the combined computation as a behavior
     with f. *)
  fun reflect m = R.reflect (fn k => f (bind (fn a => glue (fn () => g (k a))) m))
  fun map f = bind (unit o f)
  (* Run t in direct style, tagging its answer with a fresh local exception
     E and injecting it as unit (E answer).  g projects the reified behavior
     back into exn t, and map strips the tag again. *)
  fun 'a reify t =
    let exception E of 'a in
      map (fn E x => x) (glue (fn () => g (R.reify (f o unit o E o t))))
    end
  (* AF's (Andrzej Filinski's) version -- not sure which is more efficient
  fun 'a reify t = let exception E of 'a in glue (fn () => map (fn E x => x) (g (R.reify (f o unit o E o t)))) end
  *)
  (* Embedding for a layer stacked on this one.  To inject, wrap with unit
     and then f.  To project, reflect the result of g and strip the local
     tag. *)
  fun 'a embed () =
    let exception E of 'a in
      (f o unit o E, (fn E x => x) o reflect o g)
    end
  (* Reify t in this monad, render it with show, and hand the rendering to
     the lower layer's run. *)
  fun run t = Link.run (fn () => show (reify t))
end

(* The list monad, layered directly on Base.  show joins the elements with
   commas, or gives the string nil for an empty list. *)
structure ListR = ReflectMonad(
  structure Link = Base
  structure M = struct
    type 'a t = 'a list
    fun unit x = [x]
    fun bind f = foldr op@ nil o map f
    fun glue t = t ()
    fun show nil = "nil"
      | show (h::t) = foldl (fn (s,t) => t^","^s) h t
  end)

(* Nondeterminism as thunked lists of results, layered directly on Base.
   show joins the results with spaces, or gives the empty string if there
   are none. *)
structure NdR = ReflectMonad(
  structure Link = Base
  structure M = struct
    type 'a t = unit -> 'a list
    fun unit x = fn () => [x]
    fun bind f t = fn () => foldr op@ nil (map (fn x => f x ()) (t ()))
    fun glue t = fn () => t () ()
    fun show t = case t () of nil => "" | h::t => foldl (fn (s,t) => t^" "^s) h t
  end)

type state = int

(* Integer state, layered on NdR.  show runs the computation from the
   initial state 0 and returns its value, discarding the final state. *)
structure StateR = ReflectMonad(
  structure Link = NdR
  structure M = struct
    type 'a t = state -> 'a * state
    fun unit x = fn s => (x,s)
    fun bind f t = fn s => let val (x,s)=t s in f x s end
    fun glue t = fn s => t () s
    (* Concatenating onto the empty string pins x to type string. *)
    fun show t = let val (x,s)=t 0 in ""^x end
  end)

(* Result of a computation that may fail with a string message. *)
datatype 'a r = Ok of 'a | Exn of string

(* Exceptions, layered on StateR.  show returns the value, or the empty
   string if the computation failed. *)
structure ExnR = ReflectMonad(
  structure Link = StateR
  structure M = struct
    type 'a t = unit -> 'a r
    fun unit x = fn () => Ok x
    fun bind f t = fn () => case t() of Ok x=>f x () | Exn e=>Exn e
    fun glue t = fn () => t () ()
    fun show t = case t() of Ok x=>x | Exn x=>""
  end)

(* Resumptions: a computation either has finished (Done) or has taken one
   step and can be resumed (Susp). *)
datatype 'a front = Done of 'a | Susp of 'a res
withtype 'a res = unit -> 'a front

(* Resumptions, layered on StateR.  show keeps resuming Susp steps until it
   reaches Done. *)
structure ResR = ReflectMonad(
  structure Link = StateR
  structure M = struct
    type 'a t = 'a res
    fun unit x = fn () => Done x
    (* bind pushes f down through every suspended step. *)
    fun bind f t = fn () => bind_ f (t ())
    and bind_ f (Done x) = f x ()
      | bind_ f (Susp t) = Susp (bind f t)
    fun glue t = fn () => t () ()
    fun show t = show_ (t ())
    and show_ (Done x) = x
      | show_ (Susp t) = show t
  end)

(* Example effects and combinators built from the layers above. *)
structure Test = struct
  (* Simple examples *)
  (* Read and write the StateR state. *)
  fun get() = StateR.reflect (fn s => (s,s))
  fun set n = StateR.reflect (fn s => ((),n))
  (* Raise and handle ExnR failures.  fhandle runs t under reify and passes
     the message to h if t failed. *)
  fun fraise s = ExnR.reflect (fn () => Exn s)
  fun fhandle t h = case ExnR.reify t () of Ok x => x | Exn e => h e
  (* Concurrency examples *)
  (* Suspend once, so another thread can take a step. *)
  fun yield() = ResR.reflect (fn () => Susp (ResR.unit ()))
  (* Parallel or.  The two threads are stepped alternately.  The result is
     true as soon as the stepping thread finishes with true.  Once one
     thread finishes with false, the result is the other thread's. *)
  fun por(t1,t2) = ResR.reflect (rpor (ResR.reify t1,ResR.reify t2))
  and rpor(t1,t2) = fn () => rpor_(t1(),t2)
  and rpor_(Done true, _) = Done true
    | rpor_(Done false, t2) = t2 () (* I would have written Susp(t2), which suspends once before continuing with t2; t2 () continues with it immediately. *)
    | rpor_(Susp t1, t2) = Susp (rpor(t2,t1))
  (* Run t without letting other threads interleave.  All of t's steps are
     forced at once; if t suspended at all, a single suspension is added
     before the result. *)
  fun atomically t = ResR.reflect (atom (ResR.reify t))
  and atom t = fn () => atom_ (t (), false)
  and atom_ (Done x, true) = Susp (ResR.unit x)
    | atom_ (r as Done x, false) = r
    | atom_ (Susp t, y) = atom_ (t (), true)
  (* Parallel composition.  While both threads are running, NdR chooses
     nondeterministically which one to advance.  Once one thread is Done,
     the other continues alone and its result is paired with the finished
     one. *)
  fun par(t1,t2) = ResR.reflect (rpar (ResR.reify t1,ResR.reify t2))
  and rpar(t1,t2) = fn () =>
        if NdR.reflect (fn ()=>[true,false])
        then rpar_(t1(), Susp t2)
        else rpar_(Susp t1, t2())
  and rpar_(Done a,Done b) = Done (a,b)
    | rpar_(Done a,Susp tb) = Susp (ResR.bind (fn b => ResR.unit (a,b)) tb)
    | rpar_(Susp ta,Done b) = Susp (ResR.bind (fn a => ResR.unit (a,b)) ta)
    | rpar_(Susp ta,Susp tb) = Susp (rpar(ta,tb))
  (* State access preceded by a suspension point, so other threads can
     interleave before each access. *)
  fun store n = (yield (); set n)
  fun fetch () = (yield (); get ())
end

(* Examples.  Each val t below rebinds t to one example's result string. *)
(* Set the state and return normally. *)
val t = ExnR.run (fn () => (Test.set 3; "ok"));
(* An uncaught fraise: ExnR's show turns the failure into the empty string. *)
val t = ExnR.run (fn () => (Test.set 4; Test.fraise "err"; "ok"));
(* A handled fraise whose handler also reads the state.  NOTE(review): ExnR
   is layered on StateR, so the handler presumably sees the state set inside
   the failed computation (8); confirm by running. *)
val t = ExnR.run (fn () => (Test.set 5; Test.fhandle (fn () => (Test.set 8; Test.fraise "err"; "ok")) (fn x => x ^ ", " ^ Int.toString (Test.get ()))));
(* Two interleaved threads sharing the state.  The outermost layer is NdR,
   so the result lists the final value from every scheduling branch,
   separated by spaces. *)
val t = ResR.run (fn () => (Test.par (fn () => Test.store 3, fn () => Test.store (Test.fetch() + 1)); Int.toString (Test.fetch())));
(* The same, but the second thread's fetch-and-store runs atomically. *)
val t = ResR.run (fn () => (Test.par (fn () => Test.store 3, fn () => Test.atomically (fn () => Test.store (Test.fetch() + 1))); Int.toString (Test.fetch())));
(* Example of the tree shape of the layered effects.  ExnR and ResR are both
   built on StateR ([st] < [exn] and [st] < [res]) but not on each other,
   and this little example uses both inside one StateR.run.  por finishes
   even though loop never does, because rpor switches to the other thread
   after each step. *)
fun loop() = (Test.yield(); loop())
val t = StateR.run (fn () => ("exn part: " ^ ExnR.show (ExnR.reify (fn () => (Test.set (NdR.reflect (fn () => [3,5,7])); if Test.get()>6 then Test.fraise "err" else (); "ok"))) ^ " res part: " ^ ResR.show (ResR.reify (fn () => (Test.por (loop, fn () => true); Int.toString (Test.get()))))));