In this chapter we will discuss memoization, a programming technique for caching the results of previous computations so that they can be quickly retrieved without repeated effort. Memoization is fundamental to the implementation of lazy data structures, either "by hand" or using the implementation provided by the SML/NJ compiler.
We begin with a discussion of memoization to increase the efficiency of computing a recursively-defined function whose pattern of recursion involves a substantial amount of redundant computation. The problem is to compute the number of ways to parenthesize an expression consisting of a sequence of n factors (that is, n-1 multiplications) as a function of n. For example, the expression
2*3*4*5
can be parenthesized in 5 ways:
((2*3)*4)*5, (2*(3*4))*5, (2*3)*(4*5), 2*(3*(4*5)), 2*((3*4)*5).
A simple recurrence expresses the number of ways of parenthesizing a sequence of n factors:
fun sum f 0 = 0
  | sum f n = (f n) + sum f (n-1)

fun p 1 = 1
  | p n = sum (fn k => (p k) * (p (n-k))) (n-1)
where sum f n computes the sum of the values f(k) of a function f, with k running from 1 to n. This program is extremely inefficient because of the redundancy in the pattern of the recursive calls.
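To see the recurrence in action, the definitions can be run directly; a small check, consistent with the example above (note that p counts by the number of factors, so the four-factor expression 2*3*4*5 corresponds to p 4):

```sml
fun sum f 0 = 0
  | sum f n = (f n) + sum f (n-1)

fun p 1 = 1
  | p n = sum (fn k => (p k) * (p (n-k))) (n-1)

val five = p 4  (* 5, the parenthesizations of 2*3*4*5 *)
```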
What can we do about this problem? One solution is to be clever and solve the recurrence. As it happens this recurrence has a closed-form solution (the Catalan numbers). But in many cases there is no known closed form, and something else must be done to cut down the overhead. In this case a simple caching technique proves effective. The idea is to maintain a table of values of the function that is filled in whenever the function is applied. If the function is called on an argument n, the table is consulted to see whether the value has already been computed; if so, it is simply returned. If not, we compute the value and store it in the table for future use. This ensures that no redundant computations are performed.

We will maintain the table as an array so that its entries can be accessed in constant time. The penalty is that the array has a fixed size, so we can only record the values of the function at some pre-determined set of arguments. Once we exceed the bounds of the table, we must compute the value the "hard way". An alternative is to use a dictionary (e.g., a balanced binary search tree) which has no a priori size limitation, but which takes logarithmic time to perform a lookup. For simplicity we'll use a solution based on arrays.
Here's the code to implement a memoized version of the parenthesization function:
local
  val limit = 100
  val memopad = Array.array (limit, NONE)
in
  fun p' 1 = 1
    | p' n = sum (fn k => (p k) * (p (n-k))) (n-1)
  and p n =
      if n < limit then
        case Array.sub (memopad, n) of
            SOME r => r
          | NONE =>
            let
                val r = p' n
            in
                Array.update (memopad, n, SOME r);
                r
            end
      else
        p' n
end
The main idea is to modify the original definition so that the recursive calls consult and update the memo pad. The "exported" version of the function is the one that refers to the memo pad. Notice that the definitions of p and p' are mutually recursive!
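Assembled with the sum helper in scope, the memoized function can be loaded and exercised as a unit; values such as p 15 now return quickly because each subproblem is computed at most once:

```sml
fun sum f 0 = 0
  | sum f n = (f n) + sum f (n-1)

local
  val limit = 100
  val memopad : int option Array.array = Array.array (limit, NONE)
in
  fun p' 1 = 1
    | p' n = sum (fn k => (p k) * (p (n-k))) (n-1)
  and p n =
      if n < limit then
        case Array.sub (memopad, n) of
            SOME r => r
          | NONE =>
            let val r = p' n
            in Array.update (memopad, n, SOME r); r end
      else
        p' n
end

val c = p 15  (* only linearly many table misses along the way *)
```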
Lazy evaluation is a combination of delayed evaluation and memoization. Delayed evaluation is implemented using thunks, functions of type unit -> 'a. To delay the evaluation of an expression exp of type 'a, simply write fn () => exp. This is a value of type unit -> 'a; the expression exp is effectively "frozen" until the function is applied. To "thaw" the expression, simply apply the thunk to the null tuple, (). Here's a simple example:
val thunk = fn () => print "hello\n" (* nothing printed *)
val _ = thunk () (* prints hello *)
While this example is especially simple-minded, remarkable effects can be achieved by combining delayed evaluation with memoization. To do so, we will consider the following signature of suspensions:
signature SUSP = sig
  type 'a susp
  val force : 'a susp -> 'a
  val delay : (unit -> 'a) -> 'a susp
end
The function delay takes a suspended computation (in the form of a thunk) and yields a suspension. Its job is to "memoize" the suspension so that the suspended computation is evaluated at most once --- once the result is computed, the value is stored in a reference cell so that subsequent forces are fast. The implementation is slick. Here's the code to do it:
structure Susp :> SUSP = struct
  type 'a susp = unit -> 'a

  fun force t = t ()

  fun delay (t : 'a susp) =
      let
          exception Impossible
          val memo : 'a susp ref = ref (fn () => raise Impossible)
          fun t' () =
              let val r = t () in memo := (fn () => r); r end
      in
          memo := t';
          fn () => (!memo)()
      end
end
It's worth discussing this code in detail because it is rather tricky. Suspensions are just thunks; force simply applies the suspension to the null tuple to force its evaluation. What about delay? When applied, delay allocates a reference cell containing a thunk that, if forced, raises an internal exception. This can never happen for reasons that will become apparent in a moment; it is merely a placeholder with which we initialize the reference cell. We then define another thunk t' that, when forced, does three things:

1. It forces t to obtain its value r.
2. It replaces the contents of the memo pad with the constant function that immediately returns r.
3. It returns r as its result.

We then assign t' to the memo pad (hence obliterating the placeholder), and return a thunk dt that, when forced, simply forces the contents of the memo pad. Whenever dt is forced, it immediately forces the contents of the memo pad. However, the contents of the memo pad change as a result of forcing it, so that subsequent forces exhibit different behavior. Specifically, the first time dt is forced, it forces the thunk t', which then forces t to obtain its value r, "zaps" the memo pad, and returns r. The second time dt is forced, it forces the contents of the memo pad, as before, but this time it contains the constant function that immediately returns r. Altogether we have ensured that t is forced at most once by using a form of "self-modifying" code.
Here's an example to illustrate the effect of delaying a thunk:
val t = Susp.delay (fn () => print "hello\n")
val _ = Susp.force t (* prints hello *)
val _ = Susp.force t (* silent *)
Notice that "hello" is printed once, not twice! The reason is that the suspended computation is evaluated at most once, so the message is printed at most once on the screen.
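The memoization can also be observed directly by counting evaluations. A small sketch, assuming the Susp structure above is in scope (the names count and t are our own):

```sml
val count = ref 0
val t = Susp.delay (fn () => (count := !count + 1; 42))

val a = Susp.force t  (* the thunk runs; count becomes 1 *)
val b = Susp.force t  (* same value again, but count is still 1 *)
```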
The constructs for manipulating lazy data structures provided by the SML/NJ compiler may be explained in terms of suspensions. For the sake of specificity we'll consider the implementation of streams, but the same ideas apply to any lazy datatype.
The type declaration
datatype lazy 'a stream = Cons of 'a * 'a stream
expands into the following pair of type declarations
datatype 'a stream_ = Cons_ of 'a * 'a stream
withtype 'a stream = 'a stream_ Susp.susp
The first defines the type of stream values, the result of forcing a stream computation; the second defines the type of stream computations, which are suspensions yielding stream values. Thus streams are represented by suspended (unevaluated, memoized) computations of stream values, which are formed by applying the constructor Cons_ to a value and another stream.
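To see the representation concretely, here is a hand-expanded stream of successive integers, written directly against the Susp structure rather than with the lazy extension (the names from and take are our own):

```sml
datatype 'a stream_ = Cons_ of 'a * 'a stream_ Susp.susp
type 'a stream = 'a stream_ Susp.susp

(* the stream n, n+1, n+2, ...; each tail is a memoized suspension *)
fun from n : int stream =
    Susp.delay (fn () => Cons_ (n, from (n+1)))

(* force just enough of the stream to extract its first n elements *)
fun take (s, 0) = []
  | take (s, n) =
      let val Cons_ (h, t) = Susp.force s
      in h :: take (t, n-1) end

val front = take (from 1, 3)  (* [1, 2, 3] *)
```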
The value constructor Cons, when used to build a stream, automatically suspends computation. This is achieved by regarding Cons e as shorthand for Susp.delay (fn () => Cons_ e). When used in a pattern, the value constructor Cons induces a use of force. For example, the binding
For example, the binding
val Cons (h, t) = e
becomes
val Cons_ (h, t) = Susp.force e
which forces the right-hand side before performing pattern matching.
A similar transformation applies to non-lazy function definitions --- the argument is forced before pattern matching commences. Thus the "eager" tail function
fun stl (Cons (_, t)) = t
expands into
fun stl_ (Cons_ (_, t)) = t
and stl s = stl_ (Susp.force s)
which forces the argument as soon as it is applied.
On the other hand, lazy function definitions defer pattern matching until the result is forced. Thus the lazy tail function
fun lstl (Cons (_, t)) = t
expands into
fun lstl_ (Cons_ (_, t)) = t
and lstl s = Susp.delay (fn () => lstl_ (Susp.force s))
which creates a suspension that, when forced, performs the pattern match.
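The operational difference shows up as soon as an effect is attached to forcing. A sketch using the expanded representation and the Susp structure from earlier (noisy is our own test stream):

```sml
datatype 'a stream_ = Cons_ of 'a * 'a stream_ Susp.susp
type 'a stream = 'a stream_ Susp.susp

(* eager tail: forces its argument as soon as it is applied *)
fun stl_ (Cons_ (_, t)) = t
and stl s = stl_ (Susp.force s)

(* lazy tail: defers the pattern match until the result is forced *)
fun lstl_ (Cons_ (_, t)) = t
and lstl s = Susp.delay (fn () => lstl_ (Susp.force s))

(* a stream that announces whenever one of its cells is computed *)
fun noisy n : int stream =
    Susp.delay (fn () => (print "forced\n"; Cons_ (n, noisy (n+1))))

val t1 = lstl (noisy 1)  (* silent: the match is suspended *)
val t2 = stl (noisy 1)   (* prints "forced": the argument is forced at once *)
```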
Finally, the recursive stream definition
val rec lazy ones = Cons (1, ones)
expands into the following recursive function definition:
val rec ones = Susp.delay (fn () => Cons (1, ones))
Unfortunately this is not quite legal in SML: the right-hand side of a val rec binding must be a function expression, whereas here it is an application of a function to another function. This can be remedied either by extending SML to admit such definitions, or by extending the Susp package to include an operation for building recursive suspensions such as this one. Since it is an interesting exercise in itself, we'll explore the latter alternative.
We seek to add a function to the Susp package with signature
val loopback : ('a susp -> 'a susp) -> 'a susp
that, when applied to a function f mapping suspensions to suspensions, yields a suspension s whose behavior is the same as f(s), the application of f to the resulting suspension. In the above example the function in question is
fun ones_loop s = Susp.delay (fn () => Cons (1, s))
We use loopback to define ones as follows:
val ones = Susp.loopback ones_loop
The idea is that ones should be equivalent to Susp.delay (fn () => Cons (1, ones)), as in the original definition; this is indeed the result of evaluating Susp.loopback ones_loop, assuming Susp.loopback is implemented properly.
How is loopback implemented? We use a technique known as backpatching. Here's the code:
fun loopback f =
    let
        exception Circular
        val r = ref (fn () => raise Circular)
        val t = fn () => (!r)()
    in
        r := f t; t
    end
First we allocate a reference cell which is initialized to a placeholder that, if forced, raises the exception Circular. Then we define a thunk that, when forced, forces the contents of this reference cell. This will be the return value of loopback. But before returning, we assign to the reference cell the result of applying the given function to the result thunk. This "ties the knot" to ensure that the output is "looped back" to the input.

Observe that if the loop function touches its input suspension before yielding an output suspension, the exception Circular will be raised.
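Putting backpatching to work, here is a self-contained sketch with loopback added to the Susp signature, used to tie off ones and observe its first few elements (take is our own helper):

```sml
signature SUSP = sig
  type 'a susp
  val force : 'a susp -> 'a
  val delay : (unit -> 'a) -> 'a susp
  val loopback : ('a susp -> 'a susp) -> 'a susp
end

structure Susp :> SUSP = struct
  type 'a susp = unit -> 'a
  fun force t = t ()
  fun delay (t : 'a susp) =
      let
          exception Impossible
          val memo : 'a susp ref = ref (fn () => raise Impossible)
          fun t' () = let val r = t () in memo := (fn () => r); r end
      in
          memo := t'; fn () => (!memo)()
      end
  fun loopback f =
      let
          exception Circular
          val r = ref (fn () => raise Circular)
          val t = fn () => (!r)()
      in
          r := f t; t
      end
end

datatype 'a stream_ = Cons_ of 'a * 'a stream_ Susp.susp

fun ones_loop s = Susp.delay (fn () => Cons_ (1, s))
val ones = Susp.loopback ones_loop

fun take (s, 0) = []
  | take (s, n) =
      let val Cons_ (h, t) = Susp.force s
      in h :: take (t, n-1) end

val front = take (ones, 3)  (* [1, 1, 1] *)
```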