(* An optimized byte-copy function used in the Fox project, annotated with
 * dependent types so that all array bound checks can be eliminated.
 * The code carries 15 type annotations (the `where f <| ...` clauses plus
 * the two inline `<|` annotations on `front`), not counting the `assert`
 * declarations of the primitive operations. *)
structure Bcopy = struct

  (* Primitive accessors.  They are introduced with `assert`, so their types
   * are trusted rather than checked here.  A k-byte access at index i
   * requires i + (k - 1) < n, which is what lets every call site below have
   * its bound check discharged statically. *)
  assert sub1 <| {n:nat, i:nat | i < n } array(n) * int(i) -> byte1
     and update1 <| {n:nat, i:nat | i < n } array(n) * int(i) * byte1 -> unit

  assert sub2 <| {n:nat, i:nat | i + 1 < n } array(n) * int(i) -> byte2
     and update2 <| {n:nat, i:nat | i + 1 < n } array(n) * int(i) * byte2 -> unit

  assert sub4 <| {n:nat, i:nat | i + 3 < n } array(n) * int(i) -> byte4
     and update4 <| {n:nat, i:nat | i + 3 < n } array(n) * int(i) * byte4 -> unit

  (* Shift-left, bitwise-or and shift-right on 4-byte values. *)
  assert << <| byte4 * int -> byte4
     and || <| byte4 * byte4 -> byte4
     and >> <| byte4 * int -> byte4

  (* Copy src[srcpos .. endsrc-1] to dest[destpos ..] one byte at a time.
   * Places no alignment requirement on either position; it is the fallback
   * for odd lengths and mismatched alignments, and copies the head/tail
   * around the word-sized fast paths. *)
  fun{m:nat, n:nat, endsrc:nat} unaligned(src, srcpos, endsrc, dest, destpos) =
    let
      fun loop(i, j) =
        if (i >= endsrc) then ()
        else (update1(dest, j, sub1(src, i)); loop(i+1, j+1))
      where loop <| {i:nat, j:nat | j + endsrc - i <= n } int(i) * int(j) -> unit
    in
      loop(srcpos, destpos)
    end
  where unaligned <|
    {srcpos:nat, destpos:nat | endsrc <= m /\ destpos + endsrc - srcpos <= n }
    array(m) * int(srcpos) * int(endsrc) * array(n) * int(destpos) -> unit

  (* Copy a short run.  Lengths 1, 2, 4, 8, 16 and 20 are fully unrolled
   * byte copies; any other length falls back to the byte loop in
   * `unaligned`. *)
  fun{m:nat, n:nat, endsrc:nat} common(src, srcpos, endsrc, dest, destpos) =
    case endsrc - srcpos of
      1 => (update1(dest, destpos, sub1(src, srcpos)))
    | 2 => (update1(dest, destpos, sub1(src, srcpos));
            update1(dest, destpos+1, sub1(src, srcpos+1)))
    | 4 => (update1(dest, destpos, sub1(src, srcpos));
            update1(dest, destpos+1, sub1(src, srcpos+1));
            update1(dest, destpos+2, sub1(src, srcpos+2));
            update1(dest, destpos+3, sub1(src, srcpos+3)))
    | 8 => (update1(dest, destpos, sub1(src, srcpos));
            update1(dest, destpos+1, sub1(src, srcpos+1));
            update1(dest, destpos+2, sub1(src, srcpos+2));
            update1(dest, destpos+3, sub1(src, srcpos+3));
            update1(dest, destpos+4, sub1(src, srcpos+4));
            update1(dest, destpos+5, sub1(src, srcpos+5));
            update1(dest, destpos+6, sub1(src, srcpos+6));
            update1(dest, destpos+7, sub1(src, srcpos+7)))
    | 16 => (update1(dest, destpos, sub1(src, srcpos));
             update1(dest, destpos+1, sub1(src, srcpos+1));
             update1(dest, destpos+2, sub1(src, srcpos+2));
             update1(dest, destpos+3, sub1(src, srcpos+3));
             update1(dest, destpos+4, sub1(src, srcpos+4));
             update1(dest, destpos+5, sub1(src, srcpos+5));
             update1(dest, destpos+6, sub1(src, srcpos+6));
             update1(dest, destpos+7, sub1(src, srcpos+7));
             update1(dest, destpos+8, sub1(src, srcpos+8));
             update1(dest, destpos+9, sub1(src, srcpos+9));
             update1(dest, destpos+10, sub1(src, srcpos+10));
             update1(dest, destpos+11, sub1(src, srcpos+11));
             update1(dest, destpos+12, sub1(src, srcpos+12));
             update1(dest, destpos+13, sub1(src, srcpos+13));
             update1(dest, destpos+14, sub1(src, srcpos+14));
             update1(dest, destpos+15, sub1(src, srcpos+15)))
    | 20 => (update1(dest, destpos, sub1(src, srcpos));
             update1(dest, destpos+1, sub1(src, srcpos+1));
             update1(dest, destpos+2, sub1(src, srcpos+2));
             update1(dest, destpos+3, sub1(src, srcpos+3));
             update1(dest, destpos+4, sub1(src, srcpos+4));
             update1(dest, destpos+5, sub1(src, srcpos+5));
             update1(dest, destpos+6, sub1(src, srcpos+6));
             update1(dest, destpos+7, sub1(src, srcpos+7));
             update1(dest, destpos+8, sub1(src, srcpos+8));
             update1(dest, destpos+9, sub1(src, srcpos+9));
             update1(dest, destpos+10, sub1(src, srcpos+10));
             update1(dest, destpos+11, sub1(src, srcpos+11));
             update1(dest, destpos+12, sub1(src, srcpos+12));
             update1(dest, destpos+13, sub1(src, srcpos+13));
             update1(dest, destpos+14, sub1(src, srcpos+14));
             update1(dest, destpos+15, sub1(src, srcpos+15));
             update1(dest, destpos+16, sub1(src, srcpos+16));
             update1(dest, destpos+17, sub1(src, srcpos+17));
             update1(dest, destpos+18, sub1(src, srcpos+18));
             update1(dest, destpos+19, sub1(src, srcpos+19)))
    | _ => unaligned(src, srcpos, endsrc, dest, destpos)
  where common <|
    {srcpos:nat, destpos:nat | endsrc <= m /\ destpos + endsrc - srcpos <= n }
    array(m) * int(srcpos) * int(endsrc) * array(n) * int(destpos) -> unit

  (* Copy a run whose length is a multiple of 16 using four 4-byte
   * loads/stores per iteration.  Called by `aligned` with word-aligned
   * source and destination positions (NOTE(review): presumably sub4/update4
   * require word alignment; that is not expressed in their types). *)
  fun{m:nat, n:nat, endsrc:nat} sixteen(src, srcpos, endsrc, dest, destpos) =
    let
      fun loop(i, j) =
        if i >= endsrc then ()
        else (update4(dest, j, sub4(src, i));
              update4(dest, j+4, sub4(src, i+4));
              update4(dest, j+8, sub4(src, i+8));
              update4(dest, j+12, sub4(src, i+12));
              loop(i+16, j+16))
      where loop <|
        {i:nat, j:nat | (endsrc - i) mod 16 = 0 /\ j + endsrc - i <= n }
        int(i) * int(j) -> unit
    in
      loop(srcpos, destpos)
    end
  where sixteen <|
    {srcpos:nat, destpos:nat | endsrc <= m /\ (endsrc - srcpos) mod 16 = 0 /\
                               destpos + endsrc - srcpos <= n }
    array(m) * int(srcpos) * int(endsrc) * array(n) * int(destpos) -> unit

  (* Copy when source and destination share the same alignment mod 4.
   * `front` = (4 - srcalign) mod 4 bytes are copied singly so that, with
   * srcalign = srcpos mod 4 as supplied by `copy`, midsrc/middest fall on a
   * word boundary.  The largest multiple of 16 of what remains goes through
   * `sixteen`, and the leftover tail is copied singly. *)
  fun{srcalign:nat} aligned(src, srcpos, endsrc, dest, destpos, srcalign, bytes) =
    let
      val front = (case srcalign of 0 => 0 | 1 => 3 | 2 => 2 | 3 => 1)
        <| [i:nat | (srcalign = 0 /\ i = 0) \/ (srcalign = 1 /\ i = 3) \/
                    (srcalign = 2 /\ i = 2) \/ (srcalign = 3 /\ i = 1) ] int(i)
      val rest = bytes - front
      val tail = rest mod 16
      val middle = rest - tail
      val midsrc = srcpos + front
      val middest = destpos + front
      val backsrc = midsrc + middle
      val backdest = middest + middle
    in
      unaligned(src, srcpos, midsrc, dest, destpos);
      sixteen(src, midsrc, backsrc, dest, middest);
      unaligned(src, backsrc, endsrc, dest, backdest)
    end
  where aligned <|
    {m:nat, n:nat, srcpos:nat, endsrc:nat, destpos:nat, bytes:nat |
       endsrc <= m /\ srcpos + bytes = endsrc /\ destpos + bytes <= n /\ 16 <= bytes }
    array(m) * int(srcpos) * int(endsrc) * array(n) * int(destpos) *
    int(srcalign) * int(bytes) -> unit

  (* Copy a run of length = 2 (mod 8) whose source is offset by one half-word
   * from the destination, for a little-endian byte order.  The first source
   * half-word seeds `carry` in the low 16 bits; each destination word is
   * carry || (next source word << 16), i.e. the lowest-addressed bytes end
   * up in the least-significant positions.  The final carried half-word is
   * stored with update2.  Called via `eight` from `semialigned`, which
   * positions srcpos at 2 (mod 4). *)
  fun{m:nat, n:nat, endsrc:nat} eightlittle(src, srcpos, endsrc, dest, destpos) =
    let
      assert makebyte2 <| byte4 -> byte2
         and makebyte4 <| byte2 -> byte4
      fun loop(i, j, carry) =
        if i >= endsrc then update2(dest, j, makebyte2(carry))
        else
          let val srcv = sub4(src, i)
          in
            update4(dest, j, ||(carry, <<(srcv, 16)));
            let
              val i = i + 4
              val j = j + 4
              val carry = >>(srcv, 16)
              val srcv = sub4(src, i)
            in
              update4(dest, j, ||(carry, <<(srcv, 16)));
              loop(i+4, j+4, >>(srcv, 16))
            end
          end
      where loop <|
        {i:nat, j:nat | i <= endsrc /\ (endsrc - i) mod 8 = 0 /\ j + endsrc - i + 2 <= n }
        int(i) * int(j) * byte4 -> unit
    in
      loop(srcpos+2, destpos, makebyte4(sub2(src, srcpos)))
    end
  where eightlittle <|
    {srcpos:nat, destpos:nat | endsrc <= m /\ srcpos <= endsrc /\
                               (endsrc - srcpos) mod 8 = 2 /\ destpos + endsrc - srcpos <= n }
    array(m) * int(srcpos) * int(endsrc) * array(n) * int(destpos) -> unit

  (* Big-endian counterpart of `eightlittle`: the first source half-word
   * seeds `carry` in the high 16 bits, each destination word is
   * carry || (next source word >> 16), so the lowest-addressed bytes end up
   * in the most-significant positions.  The final carried half-word is
   * shifted back down and stored with update2. *)
  fun{m:nat, n:nat, endsrc:nat} eightbig(src, srcpos, endsrc, dest, destpos) =
    let
      assert makebyte2 <| byte4 -> byte2
         and makebyte4 <| byte2 -> byte4
      fun loop(i, j, carry) =
        if i >= endsrc then update2(dest, j, makebyte2(>>(carry, 16)))
        else
          let val srcv = sub4(src, i)
          in
            update4(dest, j, ||(carry, >>(srcv, 16)));
            let
              val i = i + 4
              val j = j + 4
              val carry = <<(srcv, 16)
              val srcv = sub4(src, i)
            in
              update4(dest, j, ||(carry, >>(srcv, 16)));
              loop(i + 4, j + 4, <<(srcv, 16))
            end
          end
      where loop <|
        {i:nat, j:nat | i <= endsrc /\ (endsrc - i) mod 8 = 0 /\ j + endsrc - i + 2 <= n }
        int(i) * int(j) * byte4 -> unit
    in
      loop(srcpos + 2, destpos, <<(makebyte4(sub2(src, srcpos)), 16))
    end
  where eightbig <|
    {srcpos:nat, destpos:nat | endsrc <= m /\ srcpos <= endsrc /\
                               (endsrc - srcpos) mod 8 = 2 /\ destpos + endsrc - srcpos <= n }
    array(m) * int(srcpos) * int(endsrc) * array(n) * int(destpos) -> unit

  (* Machine byte order; compared against `Little` to pick the half-word
   * shifting copy.  Declared via `assert`. *)
  assert endian <| int and Little <| int

  (* Dispatch to the half-word shifting copy matching the machine's byte
   * order.  eightlittle assembles words with the lowest-addressed byte least
   * significant, so it must be the one used when endian = Little. *)
  fun eight(src, srcpos, endsrc, dest, destpos) =
    if endian = Little
    then eightlittle(src, srcpos, endsrc, dest, destpos)
    else eightbig(src, srcpos, endsrc, dest, destpos)
  where eight <|
    {m:nat, n:nat, endsrc:nat, srcpos:nat, destpos:nat |
       endsrc <= m /\ srcpos <= endsrc /\ (endsrc - srcpos) mod 8 = 2 /\
       destpos + endsrc - srcpos <= n }
    array(m) * int(srcpos) * int(endsrc) * array(n) * int(destpos) -> unit

  (* Copy when source and destination alignments (mod 4) have the same
   * parity but differ, as selected by `copy`.  `front` = (2 - srcalign) mod 4
   * bytes are copied singly so that midsrc = 2 (mod 4); with the alignments
   * differing by 2 this makes middest word-aligned.  The middle, chosen with
   * length = 2 (mod 8), goes through `eight`, and the tail is copied
   * singly. *)
  fun{srcalign:nat} semialigned(src, srcpos, endsrc, dest, destpos, srcalign, bytes) =
    let
      val front = (case srcalign of 0 => 2 | 2 => 0 | 1 => 1 | 3 => 3)
        <| [i:nat | (srcalign = 0 /\ i = 2) \/ (srcalign = 2 /\ i = 0) \/
                    (srcalign = 1 /\ i = 1) \/ (srcalign = 3 /\ i = 3) ] int(i)
      val rest = bytes - front
      val tail = (rest - 2) mod 8
      val middle = rest - tail
      val midsrc = srcpos + front
      val middest = destpos + front
      val backsrc = midsrc + middle
      val backdest = middest + middle
    in
      unaligned(src, srcpos, midsrc, dest, destpos);
      eight(src, midsrc, backsrc, dest, middest);
      unaligned(src, backsrc, endsrc, dest, backdest)
    end
  where semialigned <|
    {m:nat, n:nat, srcpos:nat, endsrc:nat, destpos:nat, bytes:nat |
       endsrc <= m /\ srcpos + bytes = endsrc /\ destpos + bytes <= n /\ 16 <= bytes }
    array(m) * int(srcpos) * int(endsrc) * array(n) * int(destpos) *
    int(srcalign) * int(bytes) -> unit

  (* Entry point: copy `bytes` bytes from src[srcpos ..] to dest[destpos ..].
   * - Runs shorter than 25 bytes go through `common`.
   * - Longer runs are dispatched on alignment mod 4:
   *   - equal alignments use `aligned`;
   *   - alignments differing by 2 use `semialigned`;
   *   - anything else falls back to the byte-at-a-time `unaligned`. *)
  fun copy(src, srcpos, bytes, dest, destpos) =
    if (bytes < 25) then common(src, srcpos, srcpos + bytes, dest, destpos)
    else
      let
        val srcalign = srcpos mod 4
        val destalign = destpos mod 4
        val endsrc = srcpos + bytes
      in
        if srcalign = destalign
        then aligned(src, srcpos, endsrc, dest, destpos, srcalign, bytes)
        else if (srcalign + destalign) mod 2 = 0
        then semialigned(src, srcpos, endsrc, dest, destpos, srcalign, bytes)
        else unaligned(src, srcpos, endsrc, dest, destpos)
      end
  where copy <|
    {m:nat, n:nat, srcpos:nat, bytes:int, destpos:nat |
       srcpos + bytes <= m /\ destpos + bytes <= n }
    array(m) * int(srcpos) * int(bytes) * array(n) * int(destpos) -> unit

end