(*
import "../COMPARE/COMPARE";
import "../SET/SET";
*)

(* This is an implementation of SET based on 2-3-4 trees, encoded as binary
   trees whose nodes record whether they belong to their parent's 2-3-4 node
   (cf. Sedgewick: "Algorithms", Chapter 15, "Balanced Trees").
   Alexander Horz (03-07-89)
*)

functor TreeSet(C :COMPARE)
    : sig
         include SET
         exception Compare
         val disjoint : set -> set -> bool
         val get  : E.element -> set -> E.element
      end =
struct
    (* Raised when compare reports None (incomparable elements) during a
       search or insertion; get also raises it when the key is absent. *)
    exception Compare
    (* Internal: aborts a reduce early in subset and disjoint. *)
    exception Stop

    structure E = C
    open E

    (* Element equality derived from the ordering: two elements are equal
       exactly when compare reports Equal. *)
    structure Eq =
        struct
            type element = element
            val put = put
            val format = format
            fun eq e1 e2 = compare e1 e2 = Equal
        end

    (* Link tag stored in every node.  In add, a freshly inserted node is a
       Child, two adjacent Child nodes are repaired by rotation or colour
       flip at the Sibling above them, and a Child root with a Child child is
       turned into a Sibling -- i.e. Child plays the role of a "red" node and
       Sibling of a "black" node in Sedgewick's red-black encoding. *)
    datatype relation =
        Child
      | Sibling

    (* Binary search tree: everything in the left subtree compares Less than
       the node's element, everything in the right subtree Greater. *)
    datatype set =
        EMPTY
      | TREE of (relation * set * element * set)

    (* Raised by choose on the empty set. *)
    exception Choose

    val empty_set = EMPTY

    (* add key t: insert key into t.  If t already holds an element that
       compares Equal to key, that element is replaced by key.
       Raises Compare if key is incomparable with an element on its path. *)
    fun add key t =
        let
            (* f inserts below a node.  Only Sibling nodes repair a pair of
               adjacent Child nodes produced by the recursive call. *)
            fun f EMPTY =
                TREE(Child,EMPTY,key,EMPTY)
              | f (TREE(Child,l,k,r)) =
                  (case compare key k of
                    Greater =>
                        TREE(Child,l,k, f r)
                  | Less =>
                        TREE(Child,f l, k, r)
                  | Equal =>
                        TREE(Child,l,key,r)
                  | None =>
                        raise Compare)
              | f (TREE(Sibling,l,k,r)) =
                case compare key k of
                    Greater =>
                    ( case f r
                        (* new right child is Child with a Child left child *)
                        of r as TREE(Child, rl as TREE(Child,rll,rlk,rlr) ,rk,rr) =>
                            (case l
                                 (* left is Child too: colour flip *)
                                 of TREE(Child,ll,lk,lr) =>
                                     TREE(Child,TREE(Sibling,ll,lk,lr),k,
                                          TREE(Sibling,rl,rk,rr))
                               (* otherwise: double rotation *)
                               | _ => TREE(Sibling,TREE(Child,l,k,rll),rlk,
                                            TREE(Child,rlr,rk,rr)))
                      (* new right child is Child with a Child right child *)
                      | r as TREE(Child,rl,rk, rr as TREE(Child,rrl,rrk,rrr)) =>
                            (case l
                                 of TREE(Child,ll,lk,lr) =>
                                     TREE(Child,TREE(Sibling,ll,lk,lr),k,
                                          TREE(Sibling,rl,rk,rr))
                               (* otherwise: single rotation *)
                               | _ => TREE(Sibling,TREE(Child,l,k,rl),rk,rr))
                      | r => TREE(Sibling,l,k,r) )
                 |  Less =>
                    (  case f l
                        (* mirror image of the Greater case *)
                        of l as TREE(Child,ll,lk, lr as TREE(Child,lrl,lrk,lrr)) =>
                            (case r
                                 of TREE(Child,rl,rk,rr) =>
                                     TREE(Child,TREE(Sibling,ll,lk,lr),k,
                                          TREE(Sibling,rl,rk,rr))
                               | _ => TREE(Sibling,TREE(Child,ll,lk,lrl),lrk,
                                           TREE(Child,lrr,k,r)))
                      | l as TREE(Child,ll as TREE(Child,lll,llk,llr),lk,lr) =>
                            (case r
                                 of TREE(Child,rl,rk,rr) =>
                                     TREE(Child,TREE(Sibling,ll,lk,lr),k,
                                          TREE(Sibling,rl,rk,rr))
                               | _ => TREE(Sibling,ll,lk,TREE(Child,lr,k,r)))
                      | l => TREE(Sibling,l,k,r))
                 | Equal =>
                       TREE(Sibling,l,key,r)
                 | None =>
                       raise Compare
        in
            (* No Sibling sits above the root, so a Child root that has a
               Child child is repaired here by retagging it as Sibling. *)
            case f t of
                TREE(Child, l as TREE(Child,_,_,_), k, r) => TREE(Sibling,l,k,r)
              | TREE(Child,l, k, r as TREE(Child,_,_,_)) => TREE(Sibling,l,k,r)
              | t => t
        end

    (* Internal result of a search: the stored element, or nothing. *)
    datatype search = Found of element | Absent

    (* lookup key s: the element stored in s that compares Equal to key.
       Raises Compare if compare reports None on the search path. *)
    fun lookup key s =
        let
            fun look EMPTY = Absent
              | look (TREE(_,l,k,r)) =
                (case compare k key of
                     Greater => look l
                   | Less => look r
                   | Equal => Found k
                   | None => raise Compare)
        in
            look s
        end

    (* member key s: true when s holds an element comparing Equal to key. *)
    fun member key s =
        (case lookup key s of
             Found _ => true
           | Absent => false)

    (* get key s: the stored element that compares Equal to key.
       NOTE: an absent key raises Compare -- the same exception used for
       incomparable elements -- so callers cannot tell the two apart. *)
    fun get key s =
        (case lookup key s of
             Found k => k
           | Absent => raise Compare)

    (* reduce f start s: fold f over the elements in ascending order,
       i.e. f en (... (f e1 start)). *)
    fun reduce f start s  =
        let fun scan EMPTY value =
                value
              | scan (TREE(_,l,k,r)) value =
                scan r (f k (scan l value))
        in
            scan s start
        end

    (* revreduce f start s: like reduce, but in descending order. *)
    fun revreduce f start s  =
        let fun scan EMPTY value =
                value
              | scan (TREE(_,l,k,r)) value =
                scan l (f k (scan r value))
        in
            scan s start
        end

    (* eqsets s1 s2: true when s1 and s2 hold pairwise-Equal elements,
       whatever the shapes of the two trees.  Both trees are walked in
       ascending order in lock-step using explicit stacks, stopping at the
       first mismatch. *)
    fun eqsets EMPTY EMPTY =
        true
      | eqsets (tree1 as (TREE _)) (tree2 as (TREE _)) =
        let datatype pos = L | R | M
            exception Done
            (* getvalue: return the next element in order together with the
               remaining stack.  A stack entry tagged L still has its left
               subtree to visit, M has its own element next, and R has only
               its right subtree left. *)
            fun getvalue(stack as ((a,position)::b)) =
                (case a
                     of (TREE(_,l,k,r)) =>
                         (case position
                              of L => getvalue ((l,L)::(a,M)::b)
                            | M => (k,case r of  EMPTY => b | _ => (a,R)::b)
                            | R => getvalue ((r,L)::b)
                                  )
                   | EMPTY => getvalue b
                         )
              | getvalue(nil) = raise Done
            (* f: compare the two streams element by element; stacks running
               out at different times means the sizes differ. *)
            fun f (nil,nil) = true
              | f (s1 as (_ :: _),s2 as (_ :: _ )) =
                let val (v1,news1) = getvalue s1
                    and (v2,news2) = getvalue s2
                in ((compare v1 v2) = Equal) andalso f(news1,news2)
                end
              | f _ = false
        in f ((tree1,L)::nil,(tree2,L)::nil) handle Done => false
        end
      | eqsets _ _ =
        false

    fun empty EMPTY = true
      | empty _ = false

    (* list s: the elements of s in ascending order. *)
    fun list (s :set) =
        revreduce (fn (e :element) => (fn (l :element list) => e::l)) [] s

    (* set l: the set of the elements of l (later Equal elements are
       replaced by earlier ones, since the fold inserts from the right). *)
    fun set l =
        List.fold (fn (e, s) => add e s) l empty_set

    (* filter f s: the elements of s satisfying f. *)
    fun filter f s =
        reduce (fn a => (fn s' =>
                   if f a then
                         add a s'
                   else
                         s'))
               empty_set
               s

    (* remove e s: s without the element comparing Equal to e.
       TODO: this re-inserts every other element (O(n log n)); a real
       deletion would be cheaper. *)
    fun remove e s =
        reduce (fn e' => (fn s' =>
                   if (compare e e' = Equal) then
                          s'
                   else
                          add e' s'))
               empty_set
               s

    (* choose s: the root element of s and s without it.
       Raises Choose on the empty set. *)
    fun choose (s as TREE(_,_,k,_)) = (k, remove k s)
      | choose EMPTY = raise Choose

    (* difference s1 s2: the elements of s1 not in s2.
       TODO: an O(n + m) merge of the two ordered traversals would be
       asymptotically cheaper than one lookup per element. *)
    fun difference s1 s2 =
        filter (fn e => not (member e s2))
               s1

    (* subset s1 s2: every element of s1 is in s2.  Stop aborts the fold
       at the element after the first miss. *)
    fun subset s1 s2 =
        reduce (fn e => (fn true  => member e s2
                          | false => raise Stop))
               true
               s1
        handle Stop => false

    (* disjoint s1 s2: no element of s1 is in s2.  The accumulator turns
       true at the first shared element; Stop aborts the fold afterwards. *)
    fun disjoint s1 s2 =
        not (reduce (fn e => (fn false  => member e s2
                               | true => raise Stop))
                    false
                    s1    )
           handle Stop => false

    fun singleton e = add e EMPTY

    (* size s: number of elements (O(n): the size is not cached). *)
    fun size s = reduce (fn e => (fn c => c + 1)) 0 s

    (* union s1 s2: inserts the elements of the smaller set into the larger.
       Because add replaces Equal elements, the smaller set's elements
       (s1's on a tie) win when both sets hold Equal elements. *)
    fun union s1 s2 =
        if (size s1 :int) > (size s2) then
            reduce (fn e => (fn s' => add e s')) s1 s2
        else
            reduce (fn e => (fn s' => add e s')) s2 s1

    (* intersection s1 s2: the elements present in both sets.  Walks the
       smaller set and looks each element up in the larger one, keeping the
       larger set's stored element (s2's on a tie). *)
    fun intersection s1 s2 =
        let
            (* keep: for each element of small found in large, add the
               element stored in large *)
            fun keep small large =
                reduce (fn e => fn acc =>
                           (case lookup e large of
                                Found stored => add stored acc
                              | Absent => acc))
                       empty_set
                       small
        in
            if (size s1 :int) > (size s2) then
                keep s2 s1
            else
                keep s1 s2
        end

    (* put_set os s: write s to os as "{ e1 e2 ... }" in ascending order,
       surrounded by newlines. *)
    fun put_set os s =
        (outputc os "\n{ ";
         reduce (fn e => (fn () => (put os e; outputc os " "))) {} s;
         outputc os "}\n")

    (* format s: pretty-print s as "(e1 e2 ...)" in ascending order, with
       break points between elements.  This uses a single in-order traversal
       rather than repeated choose, which rebuilt the tree once per
       element. *)
    fun format s =
        case list s of
            nil => Pretty.string "()"
          | first :: rest =>
                let
                    fun items nil = [Pretty.string ")"]
                      | items (e :: es) =
                          Pretty.break 1 :: E.format e :: items es
                in
                    Pretty.block (1, Pretty.string "(" :: E.format first :: items rest)
                end
end

