(* 15-150, Spring 2023 *)
(* Michael Erdmann & Karl Crary *)
(* Code for Lecture 11: Higher-Order Functions (cont) *)
(************************************************************************)
(* Pre-defined SML Composition operator o *)
(* incr : int -> int -- adds one to its argument. *)
val incr = fn n => n + 1
(* double : int -> int -- multiplies its argument by two. *)
val double = fn n => 2 * n
(* (f o g) x = f (g x): the right-hand function is applied first. *)
val 21 = (incr o double) 10   (* double, then incr *)
val 22 = (double o incr) 10   (* incr, then double *)
(************************************************************************)
(* Combinators *)
(************************************************************************)
(* Some int -> int functions: *)
(* Each binding below is a function of type int -> int. *)
val identity : int -> int = fn n => n        (* returns its argument *)
val addone   : int -> int = fn n => n + 1    (* successor *)
val subone   : int -> int = fn n => n - 1    (* predecessor *)
val two      : int -> int = fn _ => 2        (* constant 2 *)
val five     : int -> int = fn _ => 5        (* constant 5 *)
val square   : int -> int = fn n => n * n    (* n squared *)
val twice    : int -> int = fn n => 2 * n    (* n doubled *)
(* Some combinators for functions of type 'a -> int : *)
(* All four operators are declared right-associative at the default
   precedence (0), so e.g.  f ++ g -- h  parses as  f ++ (g -- h). *)
infixr ++ -- ** //
(* We define the combinators using the pointwise principle:
   apply both functions to the same point, then combine the two ints. *)
(* ++ : pointwise sum *)
fun (f ++ g) (pt : 'a) : int = f pt + g pt
(* -- : pointwise difference *)
fun (f -- g) (pt : 'a) : int = f pt - g pt
(* ** : pointwise product *)
fun (f ** g) (pt : 'a) : int = f pt * g pt
(* // : pointwise integer quotient (div) *)
fun (f // g) (pt : 'a) : int = f pt div g pt
(* MIN (f, g) : the pointwise minimum of f and g. *)
fun MIN (f, g) (x : 'a) : int =
  let
    val fx = f x
    val gx = g x
  in
    Int.min (fx, gx)
  end
(* MAX (f, g) : the pointwise maximum of f and g. *)
fun MAX (f, g) (x : 'a) : int =
  let
    val fx = f x
    val gx = g x
  in
    Int.max (fx, gx)
  end
(*
Once we have these combinators we can work at the 'a -> int function
level of thinking, rather than merely at the int level.
This approach is sometimes called "Point-Free Programming".
Here we look at combinators over int -> int functions:
*)
(* seven : the constant function 7, i.e. 2 + 5 at every point. *)
val seven : int -> int = two ++ five
(* ++ and -- are infixr at the same precedence, so the right-hand pair
   groups first:  (x+1) + ((x-1) - x)  =  x. *)
val nochange : int -> int = addone ++ (subone -- identity)
(* ((x+1) + (x-1)) div 2  =  x. *)
val nochange' : int -> int = (addone ++ subone) // two
(* Tests: *)
val 7 = seven 373
val 10 = nochange 10
val 10 = nochange' 10
val 15 = (square ++ twice) 3   (* 9 + 6 *)
(* lowest : int -> int -- pointwise minimum of square and twice. *)
val lowest : int -> int = MIN (square, twice)
val ~10 = lowest ~5   (* min (25, ~10) *)
val 1 = lowest 1      (* min (1, 2) *)
val 6 = lowest 3      (* min (9, 6) *)
(*
As a side comment, you should know how to write functions
and operators like ++ both in curried form as above and in
lambda form as follows:
*)
(* Re-declare the fixity and rebind ++ ; this shadows the curried definition. *)
infixr ++
(* Lambda form: ++ takes the pair of functions and explicitly returns
   the function that sums their results at a point. *)
fun f ++ g = fn pt => f pt + g pt
(* Test: *)
(* ++ is infixr, so the chain nests to the right: 9 + (6 + 7). *)
val 22 = (square ++ (twice ++ seven)) 3
(*******************************************************************)
(* A staged computation example, demonstrating efficiency gain: *)
(*******************************************************************)
(* horriblecomputation : int -> int
   A deliberately expensive function of x, used to make the cost of
   repeated (unstaged) evaluation noticeable.
   With y = Int.abs(x) mod 3 + 2, the result is
     ackermann(y,1) * ackermann(y,2) * ackermann(y,3) * ackermann(y,4).
   NOTE(review): y can be 2, 3 or 4. For y = 4 the Ackermann values are
   presumably far too large to compute in practice -- confirm before
   calling with such an x. The examples in this file use x = 10 (y = 3).
*)
fun horriblecomputation(x:int):int =
(* ackermann : int * int -> int -- the two-argument Ackermann recursion. *)
let fun ackermann(0:int, n:int):int = n+1
| ackermann(m, 0) = ackermann(m-1, 1)
| ackermann(m, n) = ackermann(m-1, ackermann(m, n-1))
(* Int.abs(x) mod 3 is 0, 1 or 2, so y is always 2, 3 or 4. *)
val y = Int.abs(x) mod 3 + 2
(* count n calls ackermann(y,4) once at every step from n down to 0.
   The 0* factor discards every call except the base case, so for n >= 0
   count n = ackermann(y,4); the extra calls only consume time. *)
fun count(0) = ackermann(y, 4)
| count(n) = count(n-1)+0*ackermann(y,4)
(* number of wasted ackermann(y,4) evaluations performed by count *)
val large = 10000
in
ackermann(y, 1)*ackermann(y, 2)*ackermann(y, 3)*count(large)
end
(* Unstaged uncurried version: *)
(* f1 : int * int -> int *)
(* Takes both arguments at once as a pair, so the expensive value is
   recomputed from x on every call. *)
fun f1 (x:int, y:int) : int =
  let
    val expensive = horriblecomputation x
  in
    expensive + y
  end
(* The horrible computation is performed each time: *)
(* Three calls with the same x = 10: each one redoes the expensive work. *)
val result1 : int = f1 (10, 5)
val result2 : int = f1 (10, 2)
val result3 : int = f1 (10, 18)
(* Unstaged curried version: *)
(* f2 : int -> int -> int *)
(* Curried, but the whole body -- including the expensive call -- sits
   behind both parameters, so it only runs once y has been supplied too. *)
fun f2 (x:int) (y:int) : int =
  let
    val expensive = horriblecomputation x
  in
    expensive + y
  end
(* f2' : int -> int *)
(* Partial application: no expensive work happens at this binding. *)
val f2' : int -> int = f2 10
(* The horrible computation is again performed each time: *)
val result1 : int = f2' 5
val result2 : int = f2' 2
val result3 : int = f2' 18
(* Staged curried version:
Note how we now create the function (fn y => ...)
*after* the horrible computation.
*)
(* f3 : int -> int -> int *)
(* Staged: the expensive value is computed as soon as x is supplied, and
   the returned function merely closes over it. *)
fun f3 (x:int) : int -> int =
  let
    val expensive = horriblecomputation x
    (* the second stage: cheap, reuses the value computed above *)
    fun addTo y = expensive + y
  in
    addTo
  end
(* The horrible computation is performed once,
during the declaration of f3':
*)
(* f3' : int -> int *)
(* The expensive work happens here, once. *)
val f3' : int -> int = f3 10
(* Each of these calls only performs an addition. *)
val res1 : int = f3' 5
val res2 : int = f3' 2
val res3 : int = f3' 18
(* NOTE: f2 and f3 have exactly the same type and describe the same
functional relationship between domain and range values,
but their efficiencies are vastly different.
*)
(************************************************************************)
(* Combining higher-order functions with recursive datatypes: *)
(************************************************************************)
(* Recall filter for lists from lecture 10: *)
(* filter : ('a -> bool) -> 'a list -> 'a list
REQUIRES: (may want to assume that p is total)
ENSURES: filter p l ==> sublist containing the elements of l for which p holds.
*)
(* Walks the list front to back, keeping exactly the elements that
   satisfy p, in their original order. *)
fun filter (p : 'a -> bool) (l : 'a list) : 'a list =
  case l of
    nil => nil
  | x :: rest =>
      if p x then x :: filter p rest
      else filter p rest
(* odds : int list -> int list
REQUIRES: true
ENSURES: odds L ==> sublist of L consisting of the odd integers.
*)
(* mod takes the sign of the divisor, so negative odd integers also give
   remainder 1 (see ~11 in the test below). *)
val odds : int list -> int list = filter (fn n => n mod 2 = 1)
val [3, 7, ~11] = odds [3, 4, 7, 12, 10, ~11]
(************************************************************************)
(* Generalizations to a tree datatype: *)
(* A binary tree with a value at every internal node. *)
datatype 'a tree =
    Empty
  | Node of 'a tree * 'a * 'a tree
(* Example tree, assembled from its two subtrees. *)
val tr : int tree =
  let
    fun single v = Node (Empty, v, Empty)
    val leftSub = Node (single 1, 4, single 12)
    val rightSub = Node (Empty, 7, single 2)
  in
    Node (leftSub, 3, rightSub)
  end
(* sketch of tr: 3
/ \
4 7
/ \ \
1 12 2
*)
(* treemap : ('a -> 'b) -> 'a tree -> 'b tree
REQUIRES: (maybe assume f is total)
ENSURES: (treemap f t) evaluates to a tree isomorphic to t
in which each x of t has been replaced by f(x).
*)
(* Rebuilds the tree with the same shape, applying f to each node value. *)
fun treemap f t =
  case t of
    Empty => Empty
  | Node (left, x, right) =>
      Node (treemap f left, f x, treemap f right)
(* turn each integer of an int tree into a string: *)
(* stringify : int tree -> string tree *)
val stringify : int tree -> string tree = treemap Int.toString
(* Same shape as tr, with every integer rendered as a string. *)
val Node (Node (Node (Empty, "1", Empty), "4", Node (Empty, "12", Empty)),
          "3",
          Node (Empty, "7", Node (Empty, "2", Empty)))
  = stringify tr
(* combine : 'a tree * 'a tree -> 'a tree
REQUIRES: true
ENSURES: combine(t1,t2) evaluates to a tree containing all the elements of
t1 and all the elements of t2 (with duplicates when such exist).
NOTE: The function finds the leftmost Empty of t1 and inserts t2 there.
*)
(* Descends along the left spine of t1 and hangs t2 at the first Empty. *)
fun combine (t1, t2) =
  case (t1, t2) of
    (Empty, _) => t2
  | (_, Empty) => t1   (* don't really need this; just a speedup *)
  | (Node (left, x, right), _) => Node (combine (left, t2), x, right)
(*
Here is a version that finds the rightmost Empty of t1 and inserts t2 there.
*)
(* Descends along the right spine of t1 and hangs t2 at the first Empty. *)
fun combine (t1, t2) =
  case (t1, t2) of
    (Empty, _) => t2
  | (_, Empty) => t1   (* shortcut: nothing to attach *)
  | (Node (left, x, right), _) => Node (left, x, combine (right, t2))
(* treefilter : ('a -> bool) -> 'a tree -> 'a tree
REQUIRES: (maybe assume p is total)
ENSURES: (treefilter p t) evaluates to a tree consisting of all
elements of t that satisfy p.
*)
(* Keeps a node when its value satisfies p; otherwise drops the node and
   merges its two filtered subtrees with combine. *)
fun treefilter p t =
  case t of
    Empty => Empty
  | Node (left, x, right) =>
      let
        (* the predicate is tested before either subtree is filtered *)
        val keep = p x
        val left' = treefilter p left
        val right' = treefilter p right
      in
        if keep then Node (left', x, right')
        else combine (left', right')
      end
(* todds : int tree -> int tree
REQUIRES: true
ENSURES: todds T ==> a tree containing the odd integers of T.
*)
val todds : int tree -> int tree = treefilter (fn n => n mod 2 = 1)
(* Of tr's elements, only 1, 3 and 7 are odd. *)
val Node (Node (Empty, 1, Empty), 3, Node (Empty, 7, Empty)) = todds tr
(* Observe that combine might produce an unbalanced tree, even if we *)
(* start with balanced trees. This is similar to the situation we *)
(* encountered when merging trees in the implementation of Msort. *)
(* As there, one possibility is to rebalance after each combine. *)
(* Doing so would produce an implementation of treefilter with work *)
(* and span like that of Msort. *)
(************************************************************************)
(* There are various ways to define folding functions for trees. *)
(* We discussed a particular natural paradigm in lecture for folding *)
(* over arbitrary datatypes: *)
(* Replace constant constructors by constants, and *)
(* replace other constructors by functions. *)
(* Here is a function that folds a three-argument function over a tree: *)
(* By analogy to foldr for lists, we want treeFold to have type
('b * 'a * 'b -> 'b) -> 'b -> 'a tree -> 'b
The code below has this type, but SML prints it out differently
because it always starts with type variable 'a.
*)
(* treeFold : ('a * 'b * 'a -> 'a) -> 'a -> 'b tree -> 'a *)
(* Replaces every Empty by z and every Node by an application of f to the
   folded left subtree, the node value, and the folded right subtree. *)
fun treeFold f z t =
  case t of
    Empty => z
  | Node (l, x, r) => f (treeFold f z l, x, treeFold f z r)
(* sum up all the integers in a tree. *)
(* sumTree : int tree -> int *)
val sumTree : int tree -> int =
  treeFold (fn (leftSum, v, rightSum) => leftSum + v + rightSum) 0
(* Notice the point-free programming in defining sumTree! *)
val 29 : int = sumTree tr   (* 1 + 4 + 12 + 3 + 7 + 2 *)
(* If we had the alternate definition of trees shown below, we would *)
(* have the corresponding alternate folding function shown below. *)
(* A binary tree that stores values only at its leaves.
   Note: this declaration re-binds the constructor name Node. *)
datatype 'a leafy =
    Leaf of 'a
  | Node of 'a leafy * 'a leafy
(* Example leafy tree, assembled from its two subtrees. *)
val lfy : int leafy =
  let
    val leftSub = Node (Leaf 1, Leaf 12)
    val rightSub = Node (Leaf 5, Node (Leaf 7, Leaf 8))
  in
    Node (leftSub, rightSub)
  end
(* sketch of lfy: .
/ \
/ \
. .
/ \ / \
1 12 5 .
/ \
7 8
*)
(* By analogy to foldr for lists, we want leafyFold to have type
('b * 'b -> 'b) -> ('a -> 'b) -> 'a leafy -> 'b
The code below has this type, but SML prints it out differently
because it always starts with type variable 'a.
*)
(* leafyFold : ('a * 'a -> 'a) -> ('b -> 'a) -> 'b leafy -> 'a *)
(* Replaces every Leaf by an application of g and every Node by an
   application of f to the two folded subtrees. *)
fun leafyFold f g t =
  case t of
    Leaf x => g x
  | Node (l, r) => f (leafyFold f g l, leafyFold f g r)
(* Again, sum up all the integers, now in a leafy tree. *)
(* sumLeafy : int leafy -> int *)
(* Leaves contribute their own value; internal nodes add their subtrees. *)
val sumLeafy : int leafy -> int = leafyFold (op +) (fn n => n)
(* Notice the point-free programming in defining sumLeafy! *)
val 33 : int = sumLeafy lfy   (* 1 + 12 + 5 + 7 + 8 *)
(************************************************************************)
(* Re-declared here because the leafy datatype also introduced a
   constructor named Node, which shadowed the one for trees. *)
datatype 'a tree =
    Empty
  | Node of 'a tree * 'a * 'a tree
(* The same example tree as before, rebuilt with the new constructors. *)
val tr : int tree =
  let
    fun single v = Node (Empty, v, Empty)
    val leftSub = Node (single 1, 4, single 12)
    val rightSub = Node (Empty, 7, single 2)
  in
    Node (leftSub, 3, rightSub)
  end
(* It is less clear how to fold with a folding function of type
'a * 'b -> 'b. One possibility is to mimic our flatten2 motif:
tfold : ('a * 'b -> 'b) -> 'b -> 'a tree -> 'b
*)
(* Threads an accumulator through the tree from right to left: first the
   right subtree, then the node value, then the left subtree. *)
fun tfold f z t =
  case t of
    Empty => z
  | Node (l, x, r) =>
      let
        val acc = f (x, tfold f z r)
      in
        tfold f acc l
      end
(* Consing each element onto the accumulator yields the in-order listing. *)
fun inorder T = tfold (fn (x, acc) => x :: acc) nil T
val [1, 4, 12, 3, 7, 2] = inorder tr
(* Another possibility is to fold with functions of type 'a * 'a -> 'a.
See the posted notes.
*)
(************************************************************************)