(* intset2.sml
 *
 * COPYRIGHT (c) 1996 Bell Laboratories.
 *)

(* Name(s): Stephen Adams; modified by A. Appel
   Department, Institution: Electronics & Computer Science,
      University of Southampton
   Address:  Electronics & Computer Science
             University of Southampton
	     Southampton  SO9 5NH
	     Great Britain
   E-mail:   sra@ecs.soton.ac.uk

   Comments:

     1.  The implementation is based on Binary search trees of Bounded
         Balance, similar to Nievergelt & Reingold, SIAM J. Computing
         2(1), March 1973.  The main advantage of these trees is that
         they keep the size of the tree in the node, giving a constant
         time size operation.

     2.  The bounded balance criterion is simpler than N&R's alpha.
         Simply, one subtree must not have more than `weight' times as
         many elements as the opposite subtree.  Rebalancing is
         guaranteed to reinstate the criterion for weight>2.23, but
         the occasional incorrect behaviour for weight=2 is not
         detrimental to performance.

     3.  There are two implementations of union.  The default,
         hedge_union, is much more complex and usually 20% faster.  I
         am not sure that the performance increase warrants the
         complexity (and time it took to write), but I am leaving it
         in for the competition.  It is derived from the original
         union by replacing the split_lt(gt) operations with a lazy
         version. The `obvious' version is called old_union.

     4.  Most time is spent in T', the rebalancing constructor.  If my
         understanding of the output of *<file> in the sml batch
         compiler is correct then the code produced by NJSML 0.75
         (sparc) for the final case is very disappointing.  Most
         invocations fall through to this case and most of these cases
         fall to the else part, i.e. the plain constructor,
         T(v,ln+rn+1,l,r).  The poor code allocates a 16 word vector
         and saves lots of registers into it.  In the common case it
         then retrieves a few of the registers and allocates the 5
         word T node.  The values that it retrieves were live in
         registers before the massive save.
*)

(* INTSET: interface to finite sets of integers. *)
signature INTSET =
sig
    (* The type of integer sets. *)
    type intset

    (* The set with no elements. *)
    val empty : intset

    (* Smallest element of a set; raises Match when the set is empty. *)
    val minimum : intset -> int

    (* The set containing exactly the given integer. *)
    val singleton : int -> intset

    (* Elements present in either argument. *)
    val union : intset * intset -> intset

    (* The set extended with one integer (unchanged if already present). *)
    val add : intset * int -> intset

    (* The set extended with every integer in the list. *)
    val addList : intset * int list -> intset

    (* Elements present in both arguments. *)
    val intersection : intset * intset -> intset

    (* Membership test. *)
    val member : int * intset -> bool

    (* The elements as a list, in ascending order. *)
    val members : intset -> int list

    (* Number of elements. *)
    val cardinality : intset -> int

    (* Elements of the first argument that are not in the second. *)
    val difference : intset * intset -> intset

    (* The set with one integer removed (no effect if it is absent). *)
    val delete : intset * int -> intset

    (* The set with every integer in the list removed. *)
    val deleteList : intset * int list -> intset
end

structure IntSet : INTSET =
    struct

	local

	    (* Balance factor.  T' rotates a node when one subtree holds at
	       least `weight' times as many elements as the other; concat3 and
	       concat use the same factor to decide which side to descend. *)
	    val weight = 3

	    (* Tree representation of a set.
	         E                : the empty tree
	         T (v, n, l, r)   : a node holding element v, with n the number
	                            of elements in the whole subtree (cached),
	                            l the left subtree and r the right subtree.
	       The search functions treat it as a binary search tree: elements
	       smaller than v are looked up in l, larger ones in r. *)
	    datatype  Set = E | T of int * int * Set * Set

	    (* size t: number of elements in t, read from the count cached in
	       the root node (0 for the empty tree). *)
	    fun size tree =
	        case tree
	          of E => 0
	           | T (_, count, _, _) => count
	    
	    (*fun N(v,l,r) = T(v,1+size(l)+size(r),l,r)*)
	    (* N (v, l, r): plain node constructor without rebalancing.  Builds
	       the node holding v over subtrees l and r and fills in its cached
	       size from the sizes cached in the roots of l and r. *)
	    fun N (v, l, r) =
	        case (l, r)
	          of (E, E) => T (v, 1, E, E)
	           | (E, T (_, rn, _, _)) => T (v, rn + 1, E, r)
	           | (T (_, ln, _, _), E) => T (v, ln + 1, l, E)
	           | (T (_, ln, _, _), T (_, rn, _, _)) => T (v, ln + rn + 1, l, r)

	    (* single_L (v, l, r): single left rotation.  The root of r becomes
	       the new root and v moves down to the left, keeping l and taking
	       over r's left subtree.  Defined only for a non-empty r. *)
	    fun single_L (root, left, T (pivot, _, inner, outer)) =
	        let val newLeft = N (root, left, inner)
	        in  N (pivot, newLeft, outer)
	        end
	    (* single_R (v, l, r): single right rotation, the mirror image of
	       single_L.  The root of l becomes the new root and v moves down
	       to the right.  Defined only for a non-empty l. *)
	    fun single_R (root, T (pivot, _, outer, inner), right) =
	        let val newRight = N (root, inner, right)
	        in  N (pivot, outer, newRight)
	        end

	    (* double_L (v, l, r): double left rotation.  The root of r's left
	       subtree becomes the new root; v heads the new left subtree and
	       r's old root heads the new right subtree.  Defined only when r
	       and r's left subtree are both non-empty. *)
	    fun double_L (root, left, T (rv, _, T (pivot, _, pl, pr), rr)) =
	        let val newLeft  = N (root, left, pl)
	            val newRight = N (rv, pr, rr)
	        in  N (pivot, newLeft, newRight)
	        end

	    (* double_R (v, l, r): double right rotation, the mirror image of
	       double_L.  The root of l's right subtree becomes the new root.
	       Defined only when l and l's right subtree are both non-empty. *)
	    fun double_R (root, T (lv, _, ll, T (pivot, _, pl, pr)), right) =
	        let val newLeft  = N (lv, ll, pl)
	            val newRight = N (root, pr, right)
	        in  N (pivot, newLeft, newRight)
	        end

	    (* T' (v, l, r): rebalancing node constructor.  Builds a tree
	       holding v together with the elements of l and r, applying at most
	       one single or double rotation at the root when one side is too
	       heavy relative to the other.  The clauses are tried in order:
	       the tiny and one-side-empty shapes come first, and the general
	       case with two non-empty subtrees comes last. *)
	    fun T' (v,E,E) = T(v,1,E,E)
	      (* one side empty, the other a single leaf: nothing to rotate *)
	      | T' (v,E,r as T(_,_,E,E))     = T(v,2,E,r)
	      | T' (v,l as T(_,_,E,E),E)     = T(v,2,l,E)

	      (* one side empty and the other has only an inner child:
	         double rotation, lifting that inner child to the root *)
	      | T' (v,l as E,r as T(rv,_,rl as T(rlv,_,rll,rlr),rr as E)) = 
					N(rlv,N(v,l,rll),N(rv,rlr,rr))
	      | T' (v,l as T(lv,_,ll as E,lr as T(lrv,_,lrl,lrr)),r as E) = 
					N(lrv,N(lv,ll,lrl),N(v,lrr,r))

	      (* these cases almost never happen with small weight:
	         one side is empty and the other has two non-empty children.
	         A single rotation is used when the inner grandchild is strictly
	         smaller than the outer one, a double rotation otherwise. *)
	      | T' (v,l as E,
		      r as T(rv,_,rl as T(rlv,rln,rll,rlr),rr as T(rrv,rrn,rrl,rrr))) =
		if rln<rrn then N(rv,N(v,l,rl),rr) 
			   else N(rlv,N(v,l,rll),N(rv,rlr,rr))
	      | T' (v,l as T(lv,_,ll as T(_,lln,_,_),lr as T(lrv,lrn,lrl,lrr)),
		      r as E) =
		if lln>lrn then N(lv,ll,N(v,lr,r))
			   else N(lrv,N(lv,ll,lrl),N(v,lrr,r))

	      (* one side empty and the other has no inner child:
	         single rotation *)
	      | T' (v,l as E,r as T(rv,_,rl as E,rr))  = N(rv,N(v,l,rl),rr)
	      | T' (v,l as T(lv,_,ll,lr as E),r as E)  = N(lv,ll,N(v,lr,r))

	      (* general case: both subtrees non-empty *)
	      | T' (v,l as T(lv,ln,ll,lr), r as T(rv,rn,rl,rr)) =
                if rn >= weight*ln 
		(* right side too heavy: rotate left *)
		then case rl
		      of T(rlv,rln,rll,rlr) =>
					if rln < size rr
					then  N(rv,N(v,l,rl),rr)  
		 			else  N(rlv,N(v,l,rll),N(rv,rlr,rr))
		       | E => N(rv,N(v,l,rl),rr)
	        else if ln>=weight*rn
		(* left side too heavy: rotate right *)
		then case lr
		      of T(lrv,lrn,lrl,lrr) =>
					if lrn < size ll 
					then  N(lv,ll,N(v,lr,r))
					else  N(lrv,N(lv,ll,lrl),N(v,lrr,r))
		       | E => N(lv,ll,N(v,lr,r))
		(* within bounds: plain node, size computed from the two counts *)
		else T(v,ln+rn+1,l,r)

	    (* add (set, x): the set extended with x.  Descends to the insertion
	       point and rebuilds the path with T' so each node on it is
	       rebalanced.  When x is already present the original tree is
	       returned as is. *)
	    fun add (tree, x) =
	        case tree
	          of E => T (x, 1, E, E)
	           | T (v, _, l, r) =>
	               if x < v then T' (v, add (l, x), r)
	               else if x > v then T' (v, l, add (r, x))
	               else tree

	    (* concat3 (l, v, r): join two trees and a separating element into
	       one tree.  Callers pass l with elements below v and r with
	       elements above v.  If either tree is empty, v is simply added to
	       the other.  Otherwise the join descends into whichever side is
	       more than `weight' times larger, rebalancing with T' on the way
	       back; when the sizes are comparable a plain node is built. *)
	    fun concat3 (l, v, r) =
	        case (l, r)
	          of (E, _) => add (r, v)
	           | (_, E) => add (l, v)
	           | (T (lv, ln, ll, lr), T (rv, rn, rl, rr)) =>
	               if weight * ln < rn then T' (rv, concat3 (l, v, rl), rr)
	               else if weight * rn < ln then T' (lv, ll, concat3 (lr, v, r))
	               else N (v, l, r)

	    (* split (t, x): partition t around the pivot x.
	       Returns (lo, hi) where lo holds the elements of t smaller than x
	       and hi the elements larger than x; x itself, if present in t,
	       appears in neither.

	       This is a hand-unrolled version of the simple one-level
	       recursion: each clause looks two levels into the tree so that a
	       recursive call skips a level.  Pieces are reassembled with
	       concat3 (tree, element, tree) or add (tree, element).

	       Fix: in the two clauses with one empty subtree, the branches in
	       which the pivot lies beyond the child's root used to rebuild the
	       result from v alone and dropped the child's inner subtree (rl,
	       respectively lr).  They now use add(rl,v) / add(lr,v), which is
	       the same value as before whenever that inner subtree is empty. *)
	    fun split(E,x) = (E,E)
	      (* single element *)
	      | split(t as T(v,_,E,E),x) =
	           if v>x then (E,t)
	           else if v<x then (t,E)
	           else (E,E)
	      (* no left subtree; in order the elements are v, rl, rv, rr *)
	      | split(t as T(v,_,E,r  as T(rv,_,rl,rr)),x) =
	           if rv>x
	           then if v<x
	                then let val (a,b) = split(rl,x)
	                      in (add(a,v),concat3(b,rv,rr))
	                     end
	                else if v>x
	                     then (E,t)
	                     else (E,r)
	           else if rv<x
	                then let val (a,b) = split(rr,x)
	                      (* v, all of rl, and rv are below x *)
	                      in (concat3(add(rl,v),rv,a), b)
	                     end
	                else (add(rl,v),rr)
	      (* no right subtree; in order the elements are ll, lv, lr, v *)
	      | split(t as T(v,_,l as T(lv,_,ll,lr),E),x) =
	           if lv<x
	           then if v>x
	                then let val (a,b) = split(lr,x)
	                      in (concat3(ll,lv,a),add(b,v))
	                     end
	                else if v<x
	                     then (t,E)
	                     else (l,E)
	           else if lv>x
	                then let val (a,b) = split(ll,x)
	                      (* lv, all of lr, and v are above x *)
	                      in (a, concat3(b,lv,add(lr,v)))
	                     end
	                else (ll, add(lr,v))
	      (* both subtrees present; in order: ll, lv, lr, v, rl, rv, rr *)
	      | split(T(v,_,l as T(lv,_,ll,lr), r as T(rv,_,rl,rr)),x) =
	           if v>x
	           then if lv>x
	                then let val (a,b) = split(ll,x)
	                      in (a,concat3(concat3(b,lv,lr),v,r))
	                     end
	                else if lv<x
	                     then let val (a,b) = split(lr,x)
	                           in (concat3(ll,lv,a),concat3(b,v,r))
	                          end
	                     else (ll,concat3(lr,v,r))
	           else if v<x
	                then if rv>x
	                     then let val (a,b) = split(rl,x)
	                           in (concat3(l,v,a),concat3(b,rv,rr))
	                          end
	                     else if rv<x
	                          then let val (a,b) = split(rr,x)
	                                in (concat3(l,v,concat3(rl,rv,a)),b)
	                               end
	                          else (concat3(l,v,rl),rr)
	                else (l,r)
(*              | split(T(v,_,l,r),x) =
	           if v>x 
		   then let val (a,b) = split(l,x)
		         in (a,concat3(b,v,r))
		        end
                   else if v<x then 
                        let val (a,b) = split(r,x)
		         in (concat3(l,v,a),b)
			end
		   else (l,r)
*)
	    (* min t: smallest element of t, found by following left links.
	       Raises Match on the empty tree. *)
	    fun min E = raise Match
	      | min (T (v, _, E, _)) = v
	      | min (T (_, _, l, _)) = min l

	    (* delmin t: t without its smallest element, rebalancing with T' on
	       the way back up.  Raises Match on the empty tree. *)
	    fun delmin E = raise Match
	      | delmin (T (_, _, E, r)) = r
	      | delmin (T (v, _, l, r)) = T' (v, delmin l, r)

	    (* delete' (l, r): merge the two subtrees of a removed node.  If
	       either is empty the other is the answer; otherwise the smallest
	       element of r becomes the new root over l and the rest of r. *)
	    fun delete' (E, r) = r
	      | delete' (l, E) = l
	      | delete' (l, r) = T' (min r, l, delmin r)

	    (* concat (s1, s2): join two trees with no separating element.
	       Callers pass s1 with elements below those of s2.  Like concat3,
	       it descends into the side that is more than `weight' times
	       larger; when the sizes are comparable, the smallest element of
	       s2 is pulled out and used as the new root. *)
	    fun concat (s1, s2) =
	        case (s1, s2)
	          of (E, _) => s2
	           | (_, E) => s1
	           | (T (v1, n1, l1, r1), T (v2, n2, l2, r2)) =>
	               if weight * n1 < n2 then T' (v2, concat (s1, l2), r2)
	               else if weight * n2 < n1 then T' (v1, l1, concat (r1, s2))
	               else T' (min s2, s1, delmin s2)

	    (* fold (f, base, set): combine the elements of set with f, starting
	       from base.  Each node's right subtree is folded first, then the
	       node's own element is combined via f (element, accumulator),
	       then the left subtree is folded. *)
	    fun fold (f, base, set) =
	        let fun walk (acc, E) = acc
	              | walk (acc, T (v, _, l, r)) =
	                  let val afterRight = walk (acc, r)
	                      val afterNode  = f (v, afterRight)
	                  in  walk (afterNode, l)
	                  end
	        in  walk (base, set)
	        end

	in

	    (* The exported set type is the tree datatype itself. *)
	    type  intset = Set

	    (* The empty set is the empty tree. *)
	    val empty = E
		
            (* Smallest element of a set; min raises Match on the empty set. *)
            val minimum = min

	    (* singleton x: the one-element set containing x, i.e. a single
	       node of size 1 with two empty subtrees. *)
	    fun singleton elem = T (elem, 1, E, E)

	    (* union (s1, s2): elements present in either set.  The root of s1
	       splits s2 into the parts below and above it; the matching halves
	       are united recursively and joined around the root with concat3. *)
	    fun union (s1, s2) =
	        case (s1, s2)
	          of (E, _) => s2
	           | (_, E) => s1
	           | (T (v, _, l, r), _) =>
	               let val (below, above) = split (s2, v)
	                   val lo = union (l, below)
	                   val hi = union (r, above)
	               in  concat3 (lo, v, hi)
	               end

	    (* Re-export the insertion function defined in the local part
	       above, so that it is visible outside the local block. *)
	    val add = add

	    (* difference (s1, s2): elements of s1 that are not in s2.  The root
	       of s2 splits s1 (discarding that root value from s1); the halves
	       are processed recursively against the matching subtrees of s2
	       and joined with concat. *)
	    fun difference (s1, s2) =
	        case (s1, s2)
	          of (E, _) => E
	           | (_, E) => s1
	           | (_, T (v, _, l, r)) =>
	               let val (below, above) = split (s1, v)
	               in  concat (difference (below, l), difference (above, r))
	               end

	    (* member (x, set): true exactly when x occurs in set.  Ordinary
	       binary-search descent from the root. *)
	    fun member (x, set) =
	        case set
	          of E => false
	           | T (v, _, l, r) =>
	               if x < v then member (x, l)
	               else if x > v then member (x, r)
	               else true

	    (*fun intersection (a,b) = difference(a,difference(a,b))*)

	    (* intersection (s1, s2): elements present in both sets.  The root
	       v of s2 splits s1; the halves are intersected recursively with
	       the matching subtrees of s2.  The results are joined around v
	       with concat3 when v is also in s1, and with concat otherwise. *)
	    fun intersection (s1, s2) =
	        case (s1, s2)
	          of (E, _) => E
	           | (_, E) => E
	           | (_, T (v, _, l, r)) =>
	               let val (below, above) = split (s1, v)
	                   val shared = member (v, s1)
	                   val lo = intersection (below, l)
	                   val hi = intersection (above, r)
	               in  if shared then concat3 (lo, v, hi)
	                   else concat (lo, hi)
	               end

	    (* members set: the elements of set as a list.  fold visits right
	       subtrees first and each element is consed onto the front, so the
	       list comes out in ascending order. *)
	    fun members set = fold (fn (x, rest) => x :: rest, [], set)

	    (* cardinality set: number of elements, read from the count cached
	       in the root node (0 for the empty set). *)
	    fun cardinality set =
	        case set
	          of E => 0
	           | T (_, count, _, _) => count
	    
	    (* delete (set, x): the set without x.  Descends to x, rebuilding
	       the path with T'; the node holding x is replaced by the merge of
	       its two subtrees (delete').  Deleting from the empty tree gives
	       the empty tree. *)
	    fun delete (set, x) =
	        case set
	          of E => E
	           | T (v, _, l, r) =>
	               if x < v then T' (v, delete (l, x), r)
	               else if x > v then T' (v, l, delete (r, x))
	               else delete' (l, r)

	    (* fromList l: the set of all integers in l.
	       Written as an explicit right fold (the last list element is
	       inserted first) rather than through a library fold: List.fold
	       is not part of the Standard ML Basis List structure, so
	       referring to it prevents the whole structure from compiling on
	       Basis-conforming compilers. *)
	    fun fromList l =
	        let fun build [] = E
	              | build (x :: rest) = add (build rest, x)
	        in  build l
	        end

	    (* addList (set, xs): set extended with every element of xs,
	       inserted one at a time from the front of the list. *)
	    fun addList (set, elems) =
	        case elems
	          of [] => set
	           | x :: rest => addList (add (set, x), rest)

	    (* deleteList (set, xs): set with every element of xs removed,
	       processed one at a time from the front of the list. *)
	    fun deleteList (set, elems) =
	        case elems
	          of [] => set
	           | x :: rest => deleteList (delete (set, x), rest)
	end
    end
