(*
 * This is an example showing that array bounds checking
 * is not needed for doing quicksort on an array.
 * The code is copied from SML/NJ lib with some modification.
 *)

(* 16 type annotations, which occupy about 40 lines *)

structure Array_QSort =
struct

  datatype order = LESS | EQUAL | GREATER

  (* Array primitives are assumed with index-refined types: every access
     carries a static proof obligation that the index is below the length. *)
  assert sub <| {n:nat, i:nat | i < n } 'a array(n) * int(i) -> 'a
  and update <| {n:nat, i:nat| i < n } 'a array(n) * int(i) * 'a -> unit
  and length <| {n:nat} 'a array(n) -> int(n)

  (* sortRange(arr, start, n, cmp) sorts, in place, the n elements of arr
     beginning at index start, using cmp to compare elements.  After the
     call no element in the range compares GREATER than its successor.
     The annotation requires start+n <= size, where size is the length
     of arr. *)
  fun('a){size:nat} sortRange(arr, start, n, cmp) =
    let
      (* item i reads arr[i] *)
      fun item i = sub(arr,i)
      where item <| {i:nat | i < size } int(i) -> 'a

      (* swap(i,j) exchanges arr[i] and arr[j] *)
      fun swap (i,j) =
        let
          val tmp = item i
        in
          update(arr, i, item j); update(arr, j, tmp)
        end
      where swap <| {i:nat, j:nat | i < size /\ j < size } int(i) * int(j) -> unit

      (* vecswap(i,j,n) exchanges arr[i+k] with arr[j+k] for k = 0 .. n-1 *)
      fun vecswap (i,j,n) =
        if (n = 0) then () else (swap(i,j);vecswap(i+1,j+1,n-1))
      where vecswap <| {i:nat, j:nat, n:nat | i+n <= size /\ j+n <= size }
                       int(i) * int(j) * int(n) -> unit

      (* insertSort is called by sort when there are fewer than 7 elements
         to be sorted.  outer walks i over start+1 .. start+n-1; inner moves
         the element at i downwards while its predecessor compares GREATER. *)
      fun insertSort (start, n) =
        let
          val limit = start+n
          fun outer i =
            if i >= limit then ()
            else
              let
                fun inner j =
                  if j <= start then outer(i+1)
                  else
                    let
                      val j' = j - 1
                    in
                      case cmp(item j',item j) of
                        GREATER => (swap(j,j'); inner j')
                      | _ => outer(i+1)
                    end
                where inner <| {j:nat | j < size } int(j) -> unit
              in
                inner i
              end
          where outer <| {i:nat} int(i) -> unit
        in
          outer(start+1)
        end
      where insertSort <| {start:nat, n:nat | start+n <= size } int(start) * int(n) -> unit

      (* calculate the median of three: the result is one of the indices
         a, b, c, chosen by comparing the elements stored at them *)
      fun med3(a,b,c) =
        let
          val a' = item a
          val b' = item b
          val c' = item c
        in
          case (cmp(a', b'),cmp(b', c')) of
            (LESS, LESS) => b
          | (LESS, _) => (case cmp(a', c') of LESS => c | _ => a)
          | (_, GREATER) => b
          | _ => (case cmp(a', c') of LESS => a | _ => c)
          (* end case *)
        end
      where med3 <| {a:nat,b:nat,c:nat | a < size /\ b < size /\ c < size }
                    int(a) * int(b) * int(c) -> [n:nat | n < size ] int(n)

      (* generate the pivot for splitting the elements:
           n <= 7   the middle index,
           n <= 40  median of the first, middle and last elements,
           else     median of three medians taken at spacing n div 8 *)
      fun getPivot (a,n) =
        if n <= 7 then a + n div 2
        else
          let
            val p1 = a
            val pm = a + n div 2
            val pn = a + n - 1
          in
            if n <= 40 then med3(p1,pm,pn)
            else
              let
                val d = n div 8
                val p1 = med3(p1,p1+d,p1+2*d)
                val pm = med3(pm-d,pm,pm+d)
                val pn = med3(pn-2*d,pn-d,pn)
              in
                med3(p1,pm,pn)
              end
          end
      where getPivot <| {a:nat,n:nat | 1 < n /\ a + n <= size }
                        int(a) * int(n) -> [p:nat | p < size] int(p)

      (* quickSort(a,n) sorts the n elements starting at index a; the
         annotation requires 7 <= n.  The pivot is first moved to index a. *)
      fun quickSort (arg as (a, n)) =
        let
          (* this was defined as a higher order function in SML/NJ library *)
          (* Scans pb upwards until it passes limit or reaches an element
             GREATER than the pivot (item a).  Elements EQUAL to the pivot
             are swapped down to position pa, and pa advances with pb. *)
          fun bottom(limit, arg as (pa, pb)) =
            if pb > limit then arg
            else
              case cmp(item pb,item a) of
                GREATER => arg
              | LESS => bottom(limit, (pa, pb+1))
              | _ => (swap arg; bottom(limit, (pa+1,pb+1)))
          where bottom <| {l:nat, ppa:nat, ppb:nat | l < size /\ ppa <= ppb <= l+1 }
                          int(l) * (int(ppa) * int(ppb)) ->
                          [pa:nat, pb:nat | ppa <= pa <= pb <= l+1] (int(pa) * int(pb))

          (* this was defined as a higher order function in SML/NJ library *)
          (* Scans pc downwards until it drops below limit or reaches an
             element LESS than the pivot.  Elements EQUAL to the pivot are
             swapped up to position pd, and pd retreats with pc. *)
          fun top(limit, arg as (pc, pd)) =
            if limit > pc then arg
            else
              case cmp(item pc,item a) of
                LESS => arg
              | GREATER => top(limit, (pc-1,pd))
              | _ => (swap arg; top(limit, (pc-1,pd-1)))
          where top <| {l:nat, ppc:nat, ppd:nat | 0 < l <= ppc+1 /\ ppc <= ppd < size }
                       int(l) * (int(ppc) * int(ppd)) ->
                       [pc:nat, pd:nat | l <= pc+1 /\ pc <= pd <= ppd] (int(pc) * int(pd))

          (* Alternates bottom and top, exchanging the two elements they
             stopped on, until the scans meet (pb >= pc). *)
          fun split (pa,pb,pc,pd) =
            let
              val (pa,pb) = bottom(pc, (pa,pb))
              val (pc,pd) = top(pb, (pc,pd))
            in
              if pb >= pc then (pa,pb,pc,pd)
              else (swap(pb,pc); split(pa,pb+1,pc-1,pd))
            end
          where split <| {ppa:nat, ppb:nat, ppc:nat, ppd:nat |
                          0 < ppa <= ppb <= ppc+1 /\ ppc <= ppd < size }
                         int(ppa) * int(ppb) * int(ppc) * int(ppd) ->
                         [pa:nat, pb:nat, pc:nat, pd:nat |
                          ppa <= pa <= pb <= pc+1 /\ pc <= pd <= ppd ]
                         (int(pa) * int(pb) * int(pc) * int(pd))

          (* These bindings depend on one another (swap needs pm, split
             needs pa and pc), so they are written as sequential val
             declarations rather than joined with "and". *)
          val pm = getPivot arg
          val _ = swap(a,pm)
          val pa = a + 1
          val pc = a + (n-1)
          val (pa,pb,pc,pd) = split(pa,pa,pc,pc)
          val pn = a + n

          (* exchange the block of r elements at the left end of the range
             with the r elements ending just before pb *)
          val r = min(pa - a, pb - pa)
          val _ = vecswap(a, pb-r, r)

          (* exchange the r elements starting at pb with the last r
             elements of the range *)
          val r = min(pd - pc, pn - pd - 1)
          val _ = vecswap(pb, pn-r, r)

          (* recurse on the leading pb - pa elements ... *)
          val n' = pb - pa
          val _ = (if n' > 1 then sort(a,n') else ()) <| unit

          (* ... and on the trailing pd - pc elements *)
          val n' = pd - pc
          val _ = (if n' > 1 then sort(pn-n',n') else ()) <| unit
        in
          ()
        end
      where quickSort <| {a:nat, n:nat | 7 <= n /\ a+n <= size } int(a) * int(n) -> unit

      (* sort dispatches on the range length: insertion sort below 7
         elements, quicksort otherwise *)
      and sort (arg as (_, n)) =
        if n < 7 then insertSort arg
        else quickSort arg
      where sort <| {a:nat, n:nat | a+n <= size } int(a) * int(n) -> unit
    in
      sort (start,n)
    end
  where sortRange <| {start:nat, n:nat | start+n <= size }
                     'a array(size) * int(start) * int(n) * ('a * 'a -> order) -> unit

  (* sorted checks if an array is well-sorted: it returns false as soon as
     an element compares GREATER than its successor, and true otherwise.
     Arrays of length 0 or 1 are reported as sorted. *)
  fun('a){size:nat} sorted cmp arr =
    let
      val len = length arr
      (* s(v,i) compares v, the element preceding index i, with arr[i] *)
      fun s(v,i) =
        let
          val v' = sub(arr,i)
        in
          case cmp(v,v') of
            GREATER => false
          | _ => if i+1 = len then true else s(v',i+1)
        end
      where s <| {i:nat | i < size } 'a * int(i) -> bool
    in
      if len <= 1 then true else s(sub(arr,0),1)
    end
  where sorted <| ('a * 'a -> order) -> 'a array(size) -> bool

end