(* 15-150, Spring 2020 *)
(* Michael Erdmann & Frank Pfenning *)
(* Code for Lecture 11: Higher-Order Functions (cont) *)
(************************************************************************)
(* Pre-defined SML Composition operator o *)
(* incr adds one to its argument; double multiplies it by two. *)
val incr = fn x => x + 1
val double = fn x => 2 * x
(* o composes right-to-left: (f o g) x = f (g x). *)
val 21 = (incr o double) 10
val 22 = (double o incr) 10
(************************************************************************)
(* Combinators *)
(************************************************************************)
(* A handful of int -> int functions, each bound to an anonymous function: *)
val identity : int -> int = fn x => x
val addone : int -> int = fn x => x + 1
val subone : int -> int = fn x => x - 1
(* two and five ignore their argument. *)
val two : int -> int = fn _ => 2
val five : int -> int = fn _ => 5
val square : int -> int = fn x => x * x
val twice : int -> int = fn x => 2 * x
(* Combinators on functions of type 'a -> int.  Each lifts an int operation
   to functions pointwise: (h OP k) v = (h v) OP (k v).
   All four operators are right-associative at the default precedence 0. *)
infixr ++ -- ** //
(* Curried, point-wise definitions: *)
fun (h ++ k) (v : 'a) : int = h v + k v
fun (h -- k) (v : 'a) : int = h v - k v
fun (h ** k) (v : 'a) : int = h v * k v
(* Raises Div at any point v where k v = 0. *)
fun (h // k) (v : 'a) : int = h v div k v
fun MIN (h, k) (v : 'a) : int = Int.min (h v, k v)
fun MAX (h, k) (v : 'a) : int = Int.max (h v, k v)
(*
Once we have these combinators we can work at the 'a -> int function
level of thinking, rather than merely at the int level.
Here we look at combinators over int -> int functions:
*)
(* seven ignores its argument: it is 2 + 5 at every point. *)
val seven = two ++ five
(* infixr groups this as addone ++ (subone -- identity),
   i.e. (x+1) + ((x-1) - x), which simplifies to x. *)
val nochange = addone ++ (subone -- identity)
(* ((x+1) + (x-1)) div 2, which simplifies to x. *)
val nochange' = (addone ++ subone) // two
(* Tests: *)
val 7 = seven 373
val 10 = nochange 10
val 10 = nochange' 10
val 15 = (square ++ twice) 3
(* lowest x is the smaller of x*x and 2*x. *)
val lowest = MIN (square, twice)
val ~10 = lowest ~5
val 1 = lowest 1
val 6 = lowest 3
(*
As a side comment, you should know how to write functions
and operators like ++ both in curried form as above and in
lambda form as follows:
*)
infixr ++
(* The same pointwise sum in lambda form: given h and k, the result is
   the function fn v => h v + k v. *)
fun h ++ k = fn v => h v + k v
(* Test: infixr groups this as square ++ (twice ++ seven), so 9 + 6 + 7 = 22. *)
val 22 = (square ++ twice ++ seven) 3
(*******************************************************************)
(* A staged computation example, demonstrating efficiency gain: *)
(*******************************************************************)
(* horriblecomputation : int -> int
   A deliberately expensive function, used below to show the payoff of staging.
   ENSURES: horriblecomputation x ==> ackermann(y,1) * ackermann(y,2)
            * ackermann(y,3) * ackermann(y,4), where y = (|x| mod 3) + 2.
   NOTE: y is 2, 3 or 4.  For x = 10 (the value used below), y = 3, so the
         result is 13 * 29 * 61 * 125 = 2874625.  When y = 4 (e.g. x = 2),
         the Ackermann values grow so fast that the call will not finish in
         practice.
   NOTE(review): Int.abs raises Overflow on Int.minInt, so that single input
         raises instead of computing. *)
fun horriblecomputation(x:int):int =
let fun ackermann(0:int, n:int):int = n+1
| ackermann(m, 0) = ackermann(m-1, 1)
| ackermann(m, n) = ackermann(m-1, ackermann(m, n-1))
(* Int.abs(x) mod 3 is 0, 1 or 2, so y is 2, 3 or 4. *)
val y = Int.abs(x) mod 3 + 2
(* For n >= 0, count n returns ackermann(y, 4).  Along the way it
   recomputes that value n+1 times; the 0* throws away all but one copy,
   so the extra evaluations are pure wasted work. *)
fun count(0) = ackermann(y, 4)
| count(n) = count(n-1)+0*ackermann(y,4)
(* Controls how many redundant ackermann(y, 4) evaluations count performs. *)
val large = 10000
in
ackermann(y, 1)*ackermann(y, 2)*ackermann(y, 3)*count(large)
end
(* Unstaged, uncurried version.
   f1 : int * int -> int
   f1 (x, y) runs horriblecomputation x on every call, then adds y. *)
fun f1 (x:int, y:int) : int =
  let
    val h = horriblecomputation x
  in
    h + y
  end
(* Each call repeats the horrible computation, even though x stays 10: *)
val result1 = f1 (10, 5)
val result2 = f1 (10, 2)
val result3 = f1 (10, 18)
(* Unstaged, curried version.
   f2 : int -> int -> int
   The body runs only once both arguments are supplied, so the
   horrible computation happens on every full application. *)
fun f2 (x:int) (y:int) : int =
  let
    val h = horriblecomputation x
  in
    h + y
  end
(* f2' : int -> int
   Building f2' does no work; each application of it below still
   repeats the horrible computation: *)
val f2' = f2 10
val result1 = f2' 5
val result2 = f2' 2
val result3 = f2' 18
(* Staged, curried version.
   f3 : int -> int -> int
   The horrible computation runs as soon as x is supplied.  Only afterwards
   is the closure (fn y => ...) built, and that closure captures the result. *)
fun f3 (x:int) : int -> int =
  let
    val h = horriblecomputation x
  in
    fn y => h + y
  end
(* f3' : int -> int
   The horrible computation runs exactly once, while f3' is being declared.
   Each application below performs just one addition: *)
val f3' = f3 10
val res1 = f3' 5
val res2 = f3' 2
val res3 = f3' 18
(* NOTE: f2 and f3 have exactly the same type and describe the same
functional relationship between domain and range values,
but their efficiencies are vastly different whenever a partial
application (such as f2 10 versus f3 10) is reused for several values of y.
*)
(************************************************************************)
(* Combining higher-order functions with recursive datatypes: *)
(************************************************************************)
(* filter : ('a -> bool) -> 'a list -> 'a list
   REQUIRES: (may want to assume that p is total)
   ENSURES: filter p l ==> the sublist of l made of those elements x for
            which p x is true, kept in their original order.
   p is applied to the elements from left to right. *)
fun filter (p : 'a -> bool) (l : 'a list) : 'a list =
  case l of
    [] => []
  | x :: rest =>
      if p x then x :: filter p rest
      else filter p rest
(* odds : int list -> int list
   REQUIRES: true
   ENSURES: odds L ==> the sublist of L consisting of its odd integers.
            Negative odd integers are included, because n mod 2 is
            always 0 or 1 in SML. *)
val odds = filter (fn (n:int) => n mod 2 <> 0)
val [3, 7, ~11] = odds [3, 4, 7, 12, 10, ~11]
(* Generalize to a tree datatype: *)
(* A binary tree is either Empty or a Node holding a left subtree,
   a value and a right subtree. *)
datatype 'a tree = Empty | Node of 'a tree * 'a * 'a tree
(* Sample int tree used by the examples below:
          3
        /   \
       4     7
      / \     \
     1  12     2          *)
val tr : int tree =
  let
    fun leaf v = Node (Empty, v, Empty)
  in
    Node (Node (leaf 1, 4, leaf 12), 3, Node (Empty, 7, leaf 2))
  end
(* treemap : ('a -> 'b) -> 'a tree -> 'b tree
   REQUIRES: (maybe assume f is total)
   ENSURES: (treemap f t) evaluates to a tree with exactly the shape of t,
            in which each value x of t has been replaced by f x.
   f is applied in order: left subtree, then the node, then the right subtree. *)
fun treemap f t =
  case t of
    Empty => Empty
  | Node (l, v, r) =>
      let
        val l' = treemap f l
        val v' = f v
        val r' = treemap f r
      in
        Node (l', v', r')
      end
(* stringify : int tree -> string tree
   Replaces every integer in an int tree by its Int.toString form. *)
fun stringify t = treemap Int.toString t
val true =
  stringify tr =
    Node (Node (Node (Empty, "1", Empty), "4", Node (Empty, "12", Empty)),
          "3",
          Node (Empty, "7", Node (Empty, "2", Empty)))
(* combine : 'a tree * 'a tree -> 'a tree
   REQUIRES: true
   ENSURES: combine (t1, t2) evaluates to a tree containing every element
            of t1 and every element of t2 (duplicates are kept).
   NOTE: t2 is grafted in at the leftmost Empty of t1, so the result may
         be unbalanced. *)
fun combine (t1, t2) =
  case (t1, t2) of
    (Empty, _) => t2
  | (_, Empty) => t1   (* not needed for correctness; skips a needless walk *)
  | (Node (l, v, r), _) => Node (combine (l, t2), v, r)
(*
A variant of combine that grafts t2 in at the rightmost Empty of t1
instead of the leftmost one.  It has the same type and shadows the
definition above.
*)
fun combine (t1, t2) =
  case (t1, t2) of
    (Empty, _) => t2
  | (_, Empty) => t1
  | (Node (l, v, r), _) => Node (l, v, combine (r, t2))
(* treefilter : ('a -> bool) -> 'a tree -> 'a tree
   REQUIRES: (maybe assume p is total)
   ENSURES: (treefilter p t) evaluates to a tree consisting of all
            elements of t that satisfy p.
   A node whose value satisfies p is kept, with filtered subtrees.
   Otherwise, the node's two filtered subtrees are merged with combine.
   At each node, p is tested first, then the left subtree is filtered,
   then the right. *)
fun treefilter p t =
  case t of
    Empty => Empty
  | Node (l, v, r) =>
      let
        val keep = p v
        val l' = treefilter p l
        val r' = treefilter p r
      in
        if keep then Node (l', v, r') else combine (l', r')
      end
(* todds : int tree -> int tree
   REQUIRES: true
   ENSURES: todds T ==> a tree containing the odd integers of T. *)
val todds = treefilter (fn (n:int) => n mod 2 <> 0)
(* From tr, the odd values 1, 3 and 7 survive; 4, 12 and 2 are dropped. *)
val true = todds tr = Node (Node (Empty, 1, Empty), 3, Node (Empty, 7, Empty))
(* Observe that combine might produce an unbalanced tree, even if we *)
(* start with balanced trees. This is similar to the situation we *)
(* encountered when merging trees in the implementation of Msort. *)
(* As there, one possibility is to rebalance after each combine. *)
(* Doing so would produce an implementation of treefilter with work *)
(* and span like that of Msort. *)
(* treeFold folds a three-argument function over a tree.
   treeFold : ('a * 'b * 'a -> 'a) -> 'a -> 'b tree -> 'a
   Empty becomes z.  Node (l, x, r) becomes g (fold of l, x, fold of r).
   The left subtree is folded before the right one. *)
fun treeFold g z t =
  case t of
    Empty => z
  | Node (l, x, r) =>
      let
        val fromLeft = treeFold g z l
        val fromRight = treeFold g z r
      in
        g (fromLeft, x, fromRight)
      end
(* sumTree : int tree -> int
   Adds up every value stored in an int tree (Empty contributes 0). *)
fun sumTree t = treeFold (fn (leftSum, v, rightSum) => leftSum + v + rightSum) 0 t
(* 1 + 4 + 12 + 3 + 7 + 2 = 29 *)
val 29 : int = sumTree tr
(************************************************************************)
(* Here is a topic related to Algebraic Topology (a form of geometry).
One is interested in detecting n-dimensional holes in a space,
for instance in some high-dimensional data set.
(An example of a two-dimensional hole is the empty space within a balloon.)
One step in that process is detecting n-dimensional cycles.
(For intuition, think of one-dimensional cycles as edge-cycles in graphs,
two-dimensional cycles as the balloons themselves, etc.)
Functional approaches turn out to be useful in higher dimensions,
with the functions assigning values in some group to n-dimensional
facets, then defining geometric derivative operators.
The code below decides whether a function that assigns integers to
a list of oriented edges (represented as pairs) has a spatial
derivative that is identically zero (indicating a cycle).
(Of course, one can detect cycles in edge lists directly without
a functional intermediate, but this approach generalizes to
more interesting groups and dimensions.)
*)
(* dedge : ('a * 'a -> bool) -> ('a * 'a -> int) -> ('a * 'a) -> 'a -> int
   REQUIRES: eq is a total equality function,
             f(a,b) returns a value,
             eq(a,b) is false.
   ENSURES: dedge eq f (a,b) returns the spatial derivative of the oriented
            edge (a,b).  That is the function with value ~f(a,b) at vertex a,
            value +f(a,b) at vertex b, and value 0 at every other vertex.
   f is consulted only at a and at b. *)
fun dedge eq f (a, b) x =
  let
    val isStart = eq (a, x)
    val isEnd = eq (x, b)
  in
    if isEnd andalso not isStart then f (a, b)
    else if isStart andalso not isEnd then ~ (f (a, b))
    else 0
  end
(* deriv : ('a * 'a -> bool) -> ('a * 'a -> int) -> ('a * 'a) list -> 'a -> int
   REQUIRES: eq is a total equality function,
             f(e) returns a value for every edge e in the ('a * 'a) list,
             and no edge e in that list is degenerate, meaning
             eq(e) is false for every edge in the list.
   ENSURES: deriv eq f L returns the pointwise sum of the functions
            (dedge eq f e), for e in L.
            (Duplicates in L are fine; they count with multiplicity.) *)
fun deriv eq f L =
  let
    (* Add edge e's derivative, pointwise, to the derivative of the rest. *)
    fun addEdge (e, rest) = dedge eq f e ++ rest
    (* The derivative of the empty edge list: zero everywhere. *)
    val zero = fn _ => 0
  in
    foldr addEdge zero L
  end
(* cyclep : ('a * 'a -> bool) -> ('a * 'a -> int) -> ('a * 'a) list -> bool
   REQUIRES: as for deriv
   ENSURES: (cyclep eq f L) returns true if (deriv eq f L) is identically 0,
            and false otherwise.
   Only edge endpoints need to be checked, because each dedge is 0 at
   every vertex other than its two endpoints. *)
fun cyclep eq f L =
  let
    val df = deriv eq f L
    (* a1, b1, a2, b2, ... in list order; repeated vertices are harmless. *)
    val endpoints = List.concat (map (fn (a, b) => [a, b]) L)
  in
    List.all (fn v => df v = 0) endpoints
  end
local
  (* Weight 1 on every edge. *)
  val allOnes = fn _ => 1
in
  (* Closed 4-cycle A->B->C->D->A: each vertex is entered once and left once. *)
  val true = cyclep op= allOnes [("A", "B"), ("B", "C"), ("C", "D"), ("D", "A")]
  (* Open path: the endpoints A and D are unbalanced. *)
  val false = cyclep op= allOnes [("A", "B"), ("B", "C"), ("C", "D")]
  (* Closed triangle. *)
  val true = cyclep op= allOnes [("A", "B"), ("B", "C"), ("C", "A")]
  (* Reversing one triangle edge breaks the cycle under constant weights... *)
  val false = cyclep op= allOnes [("A", "B"), ("B", "C"), ("A", "C")]
  (* ...but giving that reversed edge weight ~1 restores it. *)
  val true = cyclep op= (fn ("A", "C") => ~1 | _ => 1)
                        [("A", "B"), ("B", "C"), ("A", "C")]
  (* Duplicates count with multiplicity: two copies of (A,C) at ~1 balance
     the weight-2 path A->B->C. *)
  val true = cyclep op= (fn ("A", "C") => ~1 | _ => 2)
                        [("A", "B"), ("B", "C"), ("A", "C"), ("A", "C")]
end
(************************************************************************)