(* 15-150, Spring 2020 *)
(* Michael Erdmann & Frank Pfenning *)
(* Code for Lecture 12: Continuations *)
(************************************************************************)
(* Continuations *)
(* Basic CPS (Continuation-Passing Style) *)
(* add : int * int * (int -> 'a) -> 'a
   Computes x + y and hands the result to the continuation k. *)
fun add (x, y, k) =
  let
    val total = x + y
  in
    k total
  end
(* mult : int * int * (int -> 'a) -> 'a
   Computes x * y and hands the result to the continuation k. *)
fun mult (x, y, k) =
  let
    val product = x * y
  in
    k product
  end
(* The product of two sums, written in direct (non-CPS) style. *)
val direct =
  let
    val left = 3 + 4
    val right = 5 + 6
  in
    left * right
  end
(* The following is reminiscent of how one might write assembly code
using registers r1,r2,r3:
add 3 4 r1
add 5 6 r2
mult r1 r2 r3
return r3
Here we have the power of continuations to act much like registers,
i.e., they "are passed a result" and then do something with it.
*)
val 77 =
  let
    (* Each continuation plays the role of one "register" step. *)
    fun k3 r3 = r3                        (* return r3 *)
    fun k2 r1 r2 = mult (r1, r2, k3)      (* mult r1 r2 r3 *)
    fun k1 r1 = add (5, 6, k2 r1)         (* add 5 6 r2 *)
  in
    add (3, 4, k1)                        (* add 3 4 r1 *)
  end
(************************************************************************)
(* Summing a list of integers, via different implementations: *)
(* Using straightforward recursion: *)
(* sum : int list -> int
   REQUIRES: true
   ENSURES: sum L evaluates to the total of the integers in L
            (0 when L is empty).
*)
fun sum (L : int list) : int =
  case L of
    [] => 0
  | x :: rest => x + sum rest
(* Using tail-recursion: *)
(* tsum : int list * int -> int
   REQUIRES: true
   ENSURES: tsum (L, acc) evaluates to acc plus the total of the
            integers in L. The recursive call is a tail call.
*)
fun tsum (L : int list, acc : int) : int =
  case L of
    [] => acc
  | x :: rest => tsum (rest, x + acc)
(* Using continuation-passing style: *)
(* csum : int list -> (int -> 'a) -> 'a
   REQUIRES: true
   ENSURES: csum L k passes the total of the integers in L to k and
            evaluates to whatever k returns.
*)
fun csum (L : int list) (k : int -> 'a) : 'a =
  case L of
    [] => k 0
  | x :: rest =>
      let
        (* augmented continuation: add x to the sum of the rest *)
        fun k' s = k (x + s)
      in
        csum rest k'
      end
(* Using higher-order list functions: *)
(* lsum : int list -> int
   REQUIRES: true
   ENSURES: lsum L evaluates to the total of the integers in L,
            computed with a left fold starting from 0.
*)
fun lsum L = List.foldl (fn (x, acc) => x + acc) 0 L
(* The same sum as a partially applied right fold:
   Lsum : int list -> int *)
val Lsum = List.foldr (fn (x, acc) => x + acc) 0
(* testing: *)
local
  (* one shared input for every summing implementation *)
  val nums = [1, 2, 3, 4]
in
  val 10 = sum nums
  val 10 = tsum (nums, 0)
  val 10 = csum nums (fn total => total)
  val 10 = lsum nums
  val 10 = Lsum nums
end
(************************************************************************)
(* Here is an example that we did not cover in lecture. *)
(* Evaluation of operator/operand trees: *)
(* Operator/operand trees for integer arithmetic.
   Each binary constructor carries its left and right operand trees. *)
datatype arith =
Plus of arith * arith (* addition of two subexpressions *)
| Mult of arith * arith (* multiplication of two subexpressions *)
| Divide of arith * arith (* integer division: dividend, then divisor *)
| Value of int (* an integer constant *)
(* Sample *)
(* 3 * 4 + 6 *)
val a18 =
  let
    val twelve = Mult (Value 3, Value 4)
  in
    Plus (twelve, Value 6)
  end
(* (3 + 4) * (5 + 6) *)
val a77 =
  let
    val seven = Plus (Value 3, Value 4)
    val eleven = Plus (Value 5, Value 6)
  in
    Mult (seven, eleven)
  end
(* (3 + 4) * (5 / 0): contains a division by zero *)
val abad =
  let
    val seven = Plus (Value 3, Value 4)
    val undefined = Divide (Value 5, Value 0)
  in
    Mult (seven, undefined)
  end
(* A direct evaluator:
   eval : arith -> int
   REQUIRES: true
   ENSURES: eval A is the integer denoted by A read as an arithmetic
            expression; a division whose divisor evaluates to 0
            raises exception Div.
*)
fun eval A =
  case A of
    Value n => n
  | Plus (l, r) => eval l + eval r
  | Mult (l, r) => eval l * eval r
  | Divide (l, r) => eval l div eval r
(* An evaluator based on continuations:
   eval' : arith * (int -> 'a option) -> 'a option
   REQUIRES: true
   ENSURES: eval'(A, k) == k (eval A) if no divisor in A evaluates to 0;
            NONE otherwise.
*)
fun eval' (A, k) =
  let
    (* Evaluate l, then r, and pass both results to combine. *)
    fun both (l, r, combine) =
      eval' (l, fn v1 => eval' (r, fn v2 => combine (v1, v2)))
  in
    case A of
      Value n => k n
    | Plus (l, r) => both (l, r, fn (v1, v2) => k (v1 + v2))
    | Mult (l, r) => both (l, r, fn (v1, v2) => k (v1 * v2))
    | Divide (l, r) =>
        (* a zero divisor abandons k and answers NONE directly *)
        both (l, r, fn (_, 0) => NONE
                     | (v1, v2) => k (v1 div v2))
  end
(* We can therefore redefine our top-level eval function as:
evalc : arith -> int option
REQUIRES: true
ENSURES: evalc(A) evaluates to
{SOME(n), if A has value n when viewed
{ as a mathematical expression;
{NONE, otherwise.
*)
fun evalc a = eval' (a, SOME)  (* initial continuation: wrap the result in SOME *)
(* Both evaluators must agree on the well-defined samples. *)
val (18, SOME 18) = (eval a18, evalc a18)
val (77, SOME 77) = (eval a77, evalc a77)
(* Uncomment the next expression to see it raise the exception Div *)
(*
val _ = eval abad
*)
val NONE = evalc abad
(************************************************************************)
(* Comparing tree contents to list contents: *)
(* A tree datatype, with integers stored in internal nodes.
   Empty is the tree with no nodes; Node (l, n, r) is an internal node
   holding the integer n, with left subtree l and right subtree r. *)
datatype tree = Empty | Node of tree * int * tree
(* Sample: *)
val t1 =
  let
    (* a node with two empty subtrees *)
    fun leaf n = Node (Empty, n, Empty)
  in
    Node (Node (leaf 1, 2, leaf 3), 4, leaf 5)
  end
(* 4
/ \
2 5
/ \
1 3
*)
(* The following function generates a list corresponding to the
inorder traversal of a tree:
inorder : tree * int list -> int list
REQUIRES: true
ENSURES: inorder(T, acc) == L @ acc, where L consists of the
elements of T as encountered in an in-order traversal of T.
*)
fun inorder (T, acc) =
  case T of
    Empty => acc
  | Node (left, n, right) =>
      let
        (* everything that must come after n: the right subtree, then acc *)
        val afterN = inorder (right, acc)
      in
        inorder (left, n :: afterN)
      end
(* The following function decides whether the inorder traversal of
the given tree is the given list.
The function relies on "inorder".
Specifically, the function first creates a list representing
the inorder traversal of the tree, then it compares that list
against the given list.
treematch : tree -> int list -> bool
REQUIRES: true
ENSURES: (treematch T L) evaluates to true if L consists of the
elements of T as encountered in an in-order traversal of T,
and to false otherwise.
*)
fun treematch T L =
  let
    (* the full in-order listing of T is built before any comparison *)
    val contents = inorder (T, [])
  in
    contents = L
  end
local
  val matchesT1 = treematch t1
in
  val false = matchesT1 [1, 2, 3, 4]           (* list too short *)
  val true = matchesT1 [1, 2, 3, 4, 5]         (* exact match *)
  val false = matchesT1 [1, 2, 3, 4, 5, 6]     (* list too long *)
end
(* We now re-implement the previous treematch by interleaving the tree
traversal and list comparison, using continuations. This allows the
function to stop searching as soon as it finds a mismatch.
*)
(* The following function decides whether the inorder traversal of a
tree corresponds to a prefix of a list.
prefix : tree -> int list -> (int list -> bool) -> bool
REQUIRES: k is total.
ENSURES:
prefix T L k ==> { true, if L == L1 @ L2, such that
{ the inorder traversal of T is equal to L1,
{ and k(L2) == true.
{
{ false, otherwise.
*)
fun prefix T L k =
  case T of
    Empty => k L
  | Node (left, n, right) =>
      let
        (* Receives what is left of the list once the left subtree has
           been matched: the next element must be n, and the right
           subtree must then match before k sees the remainder. *)
        fun afterLeft [] = false
          | afterLeft (y :: ys) = (n = y) andalso prefix right ys k
      in
        prefix left L afterLeft
      end
(* NOTE: prefix may look slightly complicated, but this basic format
will be very useful to us when we look at regular expressions.
Also, observe that the code for prefix does not have a general return
type 'a, but rather the specific type bool (as does the
continuation). This is useful here since we want to do logical
operations on the values returned by recursive calls to prefix.
In general, with an arbitrary return type, such computations are not
possible. So, if the specs in a problem specify a polymorphic return
type for a continuation in a function f as well as for f, then
generally your recursive function calls to f cannot wait for or
perform any computations on the results of those recursive calls.
Instead, the recursive calls to f should be tail calls, with an
augmented continuation.
*)
(* Decides whether the inorder traversal of the given tree is exactly
   the given list, by delegating to "prefix" with a continuation that
   accepts only an empty remainder.
   treematch' : tree -> int list -> bool
   REQUIRES: true
   ENSURES: (treematch' T L) evaluates to true if L consists of the
            elements of T as encountered in an in-order traversal of T,
            and to false otherwise.
*)
fun treematch' T L =
  let
    (* succeed only when nothing of L is left over *)
    fun nothingLeft [] = true
      | nothingLeft (_ :: _) = false
  in
    prefix T L nothingLeft
  end
local
  val matchesT1 = treematch' t1
in
  val false = matchesT1 [1, 2, 3, 4]           (* list too short *)
  val true = matchesT1 [1, 2, 3, 4, 5]         (* exact match *)
  val false = matchesT1 [1, 2, 3, 4, 5, 6]     (* list too long *)
end
(************************************************************************)
(*
Search problems can frequently be implemented with continuations,
namely a "success continuation" and a "failure continuation".
Intuitively, the success continuation says what to do if the search
is successful while the failure continuation says what to do otherwise.
The skill in writing search code with such continuations is
figuring out when to call the continuations, when to pass them as
arguments, and when to augment a given continuation. Details
depend on the particular problem, but a basic template is given by
the following search problem:
We have a tree and are looking for an element that satisfies a
predicate. If found, we pass the element to the success
continuation. If we run out of nodes in the current subtree
being searched, we call the failure continuation. The failure
continuation takes () : unit as an argument. This means we aren't
passing any information to the failure continuation. Instead, all
relevant information is already stored in the failure
continuation's closure environment. The failure continuation thus
implements the backtracking aspect of the search.
*)
(* A polymorphic binary tree: Node (l, x, r) stores an element x of
   type 'a together with a left subtree l and a right subtree r.
   This declaration reuses the names tree, Empty and Node from the
   integer-only tree declared earlier in this file, so code below this
   point refers to this polymorphic version. *)
datatype 'a tree = Empty | Node of 'a tree * 'a * 'a tree
(* search : ('a -> bool) -> 'a tree -> ('a -> 'b) -> (unit -> 'b) -> 'b
   REQUIRES: p is total.
   ENSURES: search p T sc fc == sc(x), where x is the first element of T
            in pre-order (node, then left subtree, then right subtree)
            with p(x) == true; if T has no such element, the result
            is fc().
*)
fun search p T sc fc =
  case T of
    Empty => fc ()
  | Node (left, x, right) =>
      if p x then sc x
      else
        let
          (* backtrack: if the left subtree fails, try the right one *)
          fun tryRight () = search p right sc fc
        in
          search p left sc tryRight
        end
(* findeven : int tree -> string
   REQUIRES: true
   ENSURES: findeven(T) evaluates to the decimal string of the first
            even integer met in a pre-order traversal of T, or to
            "none found" when T contains no even integer.
*)
fun findeven T =
  let
    fun isEven n = n mod 2 = 0
    fun noneFound () = "none found"
  in
    search isEven T Int.toString noneFound
  end
local
  (* a node with two empty subtrees *)
  fun leaf n = Node (Empty, n, Empty)
in
  val "2" = findeven (Node (Empty, 1, leaf 2))
  val "none found" = findeven (Node (Empty, 1, leaf 3))
  val "4" = findeven (Node (leaf 4, 1, leaf 3))
  val "4" = findeven (Node (Node (leaf 6, 4, Empty), 1, leaf 3))
  val "4" = findeven (Node (Node (Empty, 4, leaf 6), 1, leaf 3))
end
(************************************************************************)