(*

	FoxNet: The Fox Project's Communication Protocol Implementation Effort
	Edoardo Biagioni (esb@cs.cmu.edu)
	Brian Milnes (Brian.Milnes@cs.cmu.edu)
	Nick Haines  (Nick.Haines@cs.cmu.edu)
	Fox Project
	School of Computer Science
	Carnegie Mellon University
	Pittsburgh, PA 15213-3891

		i.	Abstract

	A one's complement, 16-bit checksum module. 

		ii.	Table of Contents

	i.	Abstract
	ii.	Table of Contents
	1.	functor Checksum
	2.	constants
	3.	one_s functions
	4.	function check_partial
	5.	function complete_partial
	6.	function checksum


		1.	functor Checksum
*)


functor Checksum (structure V: VENDOR
                  structure Debug: DEBUG): CHECKSUM =
 struct

  (* State threaded through successive calls to check_partial.
       previous: one's complement sum of every complete 16-bit word seen
                 so far, kept in host byte order (see build_bytes).
       odd:      a data byte still waiting for the byte that follows it. *)
  type partial_state = {previous: Word16.word, odd: Word8.word option}
  val initial_state = {previous = Word16.fromInt 0, odd = NONE}

  (* Not raised anywhere in this functor; presumably required by the
     CHECKSUM signature -- confirm. *)
  exception Checksum_Bounds

(*
		2.	constants
*)

  (* Plain integer constants. *)
  val max_int = 0xffff
  val limit = 0x10000
  val byte_mask = 0xff

  (* Word-typed constants.  Several are not referenced inside this
     functor. *)
  val max_word = Word.fromInt max_int
  val zero8 = Word8.fromInt 0
  val zero16 = Word16.fromInt 0
  val one16 = Word16.fromInt 1
  val max16 = Word16.fromInt max_int
  val zero32 = Word32.fromInt 0
  val one32 = Word32.fromInt 1
  val byte_mask32 = Word32.fromInt byte_mask
  val max32 = Word32.fromInt max_int      (* mask selecting the low 16 bits *)
  val limit32 = Word32.fromInt limit

(*
		3.	one_s functions
*)

  (* Bitwise complement of a 16-bit word. *)
  fun one_s_complement w = Word16.notb w

  (* One's complement addition of two 16-bit words: add them as ordinary
     numbers, then feed the carry out of bit 15 back into bit 0
     (end-around carry).  The raw sum never exceeds 0x1fffe, so a single
     wrap always yields a value that fits in 16 bits. *)
  fun one_s_add (a, b) =
       let val total = Word32.fromInt (Word16.toInt a + Word16.toInt b)
           val carry = Word32.>> (total, 0w16)
           val low = Word32.andb (total, max32)
       in Word16.fromInt (Word32.toInt (Word32.+ (low, carry)))
       end

(*
		4.	function check_partial

	RFC 1071 lists several techniques for computing the checksum
	quickly.  Two are used here.  Deferred carries: the 16-bit halves
	of each 32-bit word are added into a 32-bit accumulator, and the
	carries are folded back into 16 bits only at the end.  Loop
	unwinding: the aligned middle of the array is consumed four bytes
	per step.  The code also relies on the byte-order independence of
	the one's complement sum.  A sum computed with bytes paired one
	position off is corrected by a single byte swap.
*)

  local
   (* Join two consecutive data bytes into one 16-bit word in host order.
      "left" is the earlier byte of the pair.  It becomes the high byte
      when Word16.bigEndian is true and the low byte otherwise. *)
   fun build_bytes (left, right) =
        let val first = Word32.fromInt (Word8.toInt left)
            val second = Word32.fromInt (Word8.toInt right)
            val (high, low) =
                  if Word16.bigEndian then (first, second) else (second, first)
            val joined = Word32.orb (Word32.<< (high, Word31.fromInt 8), low)
        in Word16.fromInt (Word32.toInt joined)
        end

   (* Exchange the two bytes of a 16-bit word. *)
   fun swap_bytes w =
        Word16.orb (Word16.<< (w, 0w8), Word16.>> (w, 0w8))

   (* Turn a host-order 16-bit value into its big-endian form. *)
   fun make_big_endian w =
        if not Word16.bigEndian then swap_bytes w else w

   (* Split an array into (leading bytes, 32-bit words, trailing bytes).
      The leading and trailing pieces hold the (fewer than four) bytes
      reported by Word_Array.alignment_f / alignment_r.  These are
      presumably the bytes outside 4-byte alignment -- confirm against
      Word_Array.  Arrays shorter than four bytes are returned whole as
      the leading piece, with empty middle and trailing pieces. *)
   fun convert_array array =
        let val head_len = (0w4 - Word_Array.alignment_f array) mod 0w4
            val tail_len = Word_Array.alignment_r array mod 0w4
            val words = Word_Array.to32 array
            val bytes = Word_Array.to8 array
            val total = Word_Array.W8.U_Big.F.length bytes
        in if total < 0w4 then
             (bytes,
              Word_Array.W32.Big.F.create (zero32, 0w0),
              Word_Array.W8.Big.F.create (0w0, 0w0))
           else
             (Word_Array.W8.U_Big.R.seek (bytes, total - head_len),
              Word_Array.W32.align words,
              Word_Array.W8.U_Big.F.seek (bytes, total - tail_len))
        end

   (* Consume a byte cursor, as produced by Word_Array.W8.U_Big.F.next.
      Each pending byte is paired with the byte after it, and the
      resulting word is added into "sum".  Returns the new sum and any
      byte still left without a partner. *)
   fun absorb_bytes (sum, pending, NONE) = (sum, pending)
     | absorb_bytes (sum, pending, SOME (byte, rest)) =
        let val cursor = Word_Array.W8.U_Big.F.next rest
        in case pending of
             NONE => absorb_bytes (sum, SOME byte, cursor)
           | SOME first =>
               absorb_bytes (one_s_add (sum, build_bytes (first, byte)),
                             NONE, cursor)
        end

   (* Add the high and low 16-bit halves of a 32-bit word. *)
   fun add_halves w =
        Word32.+ (Word32.>> (w, 0w16), Word32.andb (w, max32))

   (* Repeatedly fold carries above bit 15 back in until the value fits
      in 16 bits. *)
   fun fold_loop n =
        if Word32.<= (n, max32) then n
        else fold_loop (add_halves n)

   (* Merge the sum of the leading bytes ("sum", with possibly a pending
      byte "held") with the sum of the rest of the array ("region", with
      possibly an unpaired last byte "trailing").  The region's bytes
      were paired starting from its own first byte.  If "held" is
      present, that pairing is one byte off from the true data
      alignment, and swapping the bytes of "region" corrects it.  The
      held and trailing bytes then form the one missing word. *)
   fun merge_partial (sum, held, region, trailing) =
        case (held, trailing) of
          (NONE, _) =>
            {previous = one_s_add (sum, region), odd = trailing}
        | (SOME lead, NONE) =>
            {previous = one_s_add (sum, swap_bytes region), odd = SOME lead}
        | (SOME lead, SOME last) =>
            {previous = one_s_add (one_s_add (sum, swap_bytes region),
                                   build_bytes (lead, last)),
             odd = NONE}

   (* Process one array on top of an existing partial state.  fold_check
      is the per-word accumulator step used for the 32-bit middle part. *)
   fun sum_region (array8, {previous, odd}, fold_check) =
        let val (head, words, tail) = convert_array array8
            val (head_sum, head_odd) =
                  absorb_bytes (previous, odd, Word_Array.W8.U_Big.F.next head)
            (* Assumes W32.Native reads words in host byte order.  Their
               16-bit halves then match the host-order words built by
               build_bytes. *)
            val raw = Word_Array.W32.Native.F.fold fold_check zero32 words
            val word_sum = Word16.fromInt (Word32.toInt (fold_loop raw))
            val (tail_sum, tail_odd) =
                  absorb_bytes (word_sum, NONE, Word_Array.W8.U_Big.F.next tail)
        in merge_partial (head_sum, head_odd, tail_sum, tail_odd)
        end

   (* Accumulator step with no intermediate folding.  It is only used for
      arrays of at most 0xffff bytes.  That is at most 16383 words, each
      adding at most 0x1fffe, so the 32-bit total stays below 2^31. *)
   fun fold_64k_max (new, accumulator) =
        Word32.+ (add_halves new, accumulator)

   (* Accumulator step for arrays of any length.  The accumulator is
      folded on every step, so it can never overflow. *)
   fun fold_unlimited (new, accumulator) =
        Word32.+ (add_halves new, add_halves accumulator)

  in
   (* Add the bytes of "array" to a partial checksum state and return the
      updated state.  Calls can be chained over consecutive pieces of the
      data. *)
   fun check_partial (array, partial) =
        let val byte_count = Word_Array.W8.U_Big.F.length (Word_Array.to8 array)
            val folder =
                  if byte_count <= max_word then fold_64k_max
                  else fold_unlimited
        in sum_region (array, partial, folder)
        end

(*
		5.	function complete_partial
*)

   (* Finish a partial checksum.  A leftover byte is treated as the first
      byte of a final word whose second byte is zero.  The one's
      complement sum is returned in big-endian form.  It is NOT
      complemented here. *)
   fun complete_partial {previous, odd} =
        make_big_endian
          (case odd of
             NONE => previous
           | SOME byte => one_s_add (previous, build_bytes (byte, zero8)))

  end (* local *)

(*
		6.	function checksum
*)

  (* One-shot checksum of a single array: start from the empty state, then
     finish immediately. *)
  fun checksum array =
       let val state = check_partial (array, initial_state)
       in complete_partial state
       end

 end (* struct *)


