(* 15-150, Lec 2 *)
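(* Notation
T ~ T' means T and T' contain the same elements, counting duplicates
  (i.e. the same multiset of elements)
T ~ L,R means T contains exactly the elements of L and R combined
L <= x means every element in tree L is <= x
x <= R means x is <= every element in tree R
A tree T is sorted when its in-order traversal (trav T) is nondecreasing
*)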
(*
Control.Print.printDepth := 20;
Control.printWarnings := false;
*)
datatype tree = Empty | Node of tree * int * tree

(* trav : tree -> int list
 * REQUIRES true
 * ENSURES trav T ==> the elements of T listed in in-order
 *   (left subtree, root, right subtree).
 *
 * Uses an accumulator so every element is consed exactly once: O(n) total.
 * The direct form  trav L @ [x] @ trav R  re-copies the left subtree's
 * list at every node, costing O(n * depth), i.e. O(n^2) on left-leaning
 * trees.
 *)
fun trav T =
  let
    (* go (T, acc) ==> (in-order elements of T) @ acc *)
    fun go (Empty, acc) = acc
      | go (Node (L, x, R), acc) = go (L, x :: go (R, acc))
  in
    go (T, [])
  end
(* splitAt : int * tree -> tree * tree
 * REQUIRES T sorted
 * ENSURES splitAt (x, T) ==> (L, R) where
 *   - every element of L is <= x and every element of R is >= x,
 *   - L and R together hold exactly the elements of T (T ~ L,R), and
 *   - L and R are themselves sorted (merge relies on this).
 *   Elements equal to x end up in L.
 *)
fun splitAt (_, Empty) = (Empty, Empty)
  | splitAt (x, Node (L, y, R)) =                (* L <= y <= R *)
      if x < y then
        (* y and all of R are > x: they stay on the right; split L *)
        let
          val (small, big) = splitAt (x, L)      (* small <= x <= big *)
        in
          (small, Node (big, y, R))
        end
      else
        (* y <= x: y and all of L stay on the left; split R *)
        let
          val (small, big) = splitAt (x, R)      (* small <= x <= big *)
        in
          (Node (L, y, small), big)
        end
(* leaf : int -> tree
 * leaf x ==> the single-node tree holding x with two empty subtrees. *)
val leaf = fn x => Node (Empty, x, Empty)
(* Exercise splitAt on a sorted tree whose in-order elements are 1..5. *)
val t12345 = Node (Node (leaf 1, 2, leaf 3), 4, leaf 5)
val (t12_3, t3_45) = splitAt (3, t12345)
(* Check the whole ENSURES clause, not only element preservation:
 * the left piece is <= 3, the right piece is >= 3, and together they
 * still list 1..5 in order. Without the first two checks, a splitAt
 * that returned (Empty, T) would pass. *)
val true = List.all (fn y => y <= 3) (trav t12_3)
val true = List.all (fn y => y >= 3) (trav t3_45)
val [1,2,3,4,5] = trav t12_3 @ trav t3_45
(* merge : tree * tree -> tree
 * REQUIRES T1, T2 sorted
 * ENSURES merge (T1, T2) ==> T' where T' is sorted and T' holds exactly
 *   the elements of T1 and T2 combined.
 *
 * Keeps T1's root x as the new root, splits T2 around x, and merges the
 * matching halves on each side. Each recursive call is on a subtree of T1,
 * so the recursion terminates.
 *)
fun merge (Empty, T2) = T2
  | merge (T1, Empty) = T1
  | merge (Node (left, x, right), T2) =      (* left <= x <= right, T1 sorted *)
      let
        (* below <= x <= above, both sorted, by the spec of splitAt *)
        val (below, above) = splitAt (x, T2)
      in
        Node (merge (left, below), x, merge (right, above))
      end
(* Merging the two splitAt pieces back together restores 1..5 in order. *)
val true = (trav (merge (t12_3, t3_45)) = [1, 2, 3, 4, 5])
(* msort : tree -> tree
 * REQUIRES true
 * ENSURES msort T ==> T' where T' is sorted and T' holds exactly the
 *   elements of T (T ~ T').
 *
 * Sorts both subtrees recursively, then merges the root back in as a
 * one-node tree.
 *)
fun msort Empty = Empty
  | msort (Node (left, x, right)) =
      let
        val sortedLeft = msort left
        val sortedRight = msort right
      in
        merge (sortedLeft, merge (leaf x, sortedRight))
      end
(* An unsorted tree with a duplicate; its in-order elements are 5 2 1 3 3 4. *)
val t521334 =
  Node (Node (leaf 5, 2, leaf 1),
        3,
        Node (leaf 3, 4, Empty))
val t123345 = msort t521334
(* msort must keep the duplicate 3 and produce an in-order sorted tree. *)
val true = (trav t123345 = [1, 2, 3, 3, 4, 5])