(* Copyright 1989 by AT&T Bell Laboratories *)
(* equal.sml --- constructing generic equality functions *)

(* EQUAL -- interface of the equality-function generator.
 *
 * equal {stringequal, polyequal} env ty  yields a lambda expression for an
 * equality test at type ty.
 *   stringequal : the expression used when the type is the string tycon.
 *   polyequal   : the expression returned as a fallback when a specialized
 *                 function cannot be constructed for ty.
 *   env         : the environment handed to the type pretty-printer when
 *                 debugging output is enabled.
 *)
signature EQUAL = 
  sig val equal : {stringequal : Lambda.lexp, polyequal : Lambda.lexp}
                  -> Modules.env -> Types.ty -> Lambda.lexp
  end

structure Equal : EQUAL = struct

 open ErrorMsg Types Lambda Access BasicTypes TypesUtil 
 open EqTypes Transtypes PrettyPrint
 structure LT = LambdaType

(* transDcon dcon -- lambda-level view of a data constructor: its name, its
 * representation, and the lambda type of its `orig` type when one is
 * recorded, otherwise of its `typ` field. *)
fun transDcon(DATACON{name,rep,typ,orig,...}) =
    let val srcTy = case orig
                      of SOME t => t
                       | NONE => typ
     in (name, rep, transTyLty srcTy)
    end
(* transBoolDcon dcon -- like transDcon, but the lambda type is always BOOL;
 * the constructor's own type is ignored. *)
fun transBoolDcon(DATACON{name,rep,...}) = (name, rep, LT.inj LT.BOOL)

(* The two boolean constructors in lambda form ... *)
val (trueDcon', falseDcon') =
    (transBoolDcon trueDcon, transBoolDcon falseDcon)

(* ... and the constant expressions built from them (applied to an empty
 * record). *)
val (trueLexp, falseLexp) =
    (CON(trueDcon', RECORD[]), CON(falseDcon', RECORD[]))

(* Local shorthand: within this structure `error` means ErrorMsg.impossible.
 * It is used below for cases the code treats as unreachable. *)
val error = ErrorMsg.impossible

(* argType(conTy, args) -- domain of a constructor type.
 *   - For a polymorphic type whose body is a two-argument CONty, the first
 *     argument (the domain) is instantiated with args via applyTyfun.
 *   - For a monomorphic two-argument CONty with no args, the first argument
 *     is returned as is.
 *   - Anything else is reported through `error`. *)
fun argType(conTy, args) =
    case (conTy, args)
      of (POLYty{tyfun=TYFUN{arity,body=CONty(_,[domain,_])},...}, _) =>
           applyTyfun(TYFUN{arity=arity,body=domain}, args)
       | (CONty(_,[domain,_]), []) => domain
       | _ => error "Equal.argType"

(* polyArgType ty -- uninstantiated domain of a constructor type: strips any
 * POLYty wrappers and returns the first argument of the two-argument CONty
 * underneath.  Any other shape is reported through `error`. *)
fun polyArgType ty =
    case ty
      of POLYty{tyfun=TYFUN{body,...},...} => polyArgType body
       | CONty(_,[domain,_]) => domain
       | _ => error "Equal.polyArgType"

(* Raised when a specialized equality function cannot be built;
 * `equal` handles it by returning its polyequal argument. *)
exception Poly

(* Lambda type of booleans. *)
val boolty = LT.inj LT.BOOL

(* eqty t -- lambda type of an equality function on t:  (t * t) -> bool *)
fun eqty t =
    let val pairTy = LT.injRECORD[t,t]
     in LT.injARROW(pairTy, boolty)
    end

(* Types given to the equality primitives. *)
val inteqty   = eqty LT.injINT
and int32eqty = eqty LT.injINT32
and booleqty  = eqty boolty
and realeqty  = eqty LT.injREAL
and boxedeqty = eqty LT.injBOXED

(* Raised by the cache lookup in `equal` when a type has no entry yet. *)
exception Notfound

(* eqType(ty,ty') -- compare two types after pruning both.
 *
 * Raises Poly as soon as either side is a type variable, a POLYty, or a
 * CONty headed by a RELtyc.  For two constructor types with equal heads the
 * arguments are compared pairwise.  With different heads the left side is
 * reduced with reduceType and the comparison retried; if that raises
 * ReduceType the right side is reduced instead; if that raises ReduceType
 * as well the result is false.  Every remaining combination is false. *)
fun eqType(ty,ty') =
    let fun eq(CONty(RELtyc _, _), CONty _) = raise Poly
          | eq(CONty _, CONty(RELtyc _, _)) = raise Poly
          | eq(lhs as CONty(tyc, args), rhs as CONty(tyc', args')) =
              if eqTycon(tyc, tyc')
              then List2.all2 eqType (args, args')
              else (eqType(reduceType lhs, rhs)
                    handle ReduceType =>
                      (eqType(lhs, reduceType rhs)
                       handle ReduceType => false))
          | eq(VARty _, _) = raise Poly
          | eq(_, VARty _) = raise Poly
          | eq(POLYty _, _) = raise Poly
          | eq(_, POLYty _) = raise Poly
          | eq _ = false
     in eq(prune ty, prune ty')
    end

(* equal {stringequal,polyequal} env concreteType
 *
 * Builds a lambda expression that tests two values of concreteType for
 * equality.
 *
 *   stringequal  - expression used when the type is stringTycon.
 *   polyequal    - expression returned whenever Poly is raised during
 *                  construction (e.g. by eqType, by atomeq on an unknown
 *                  tycon, by an unhandled shape of type, or when the
 *                  unfolding budget is used up).
 *   env          - passed to PPType.ppType for debugging output only.
 *
 * Record types and datatypes get their own named equality functions.  Each
 * is registered in `cache` *before* its body is generated, so recursive
 * occurrences of the same type resolve to the function's variable.  The
 * functions collected in the cache are finally bound together with FIX.
 * `test` starts with a budget of 10 that is decremented each time a
 * datatype constructor argument is entered; reaching 0 raises Poly.
 *)
fun equal {stringequal,polyequal} (env:Modules.env) (concreteType : ty) =
    let (* One entry per type that has its own equality function: the type,
         * the VAR naming the function, and a ref that initially holds that
         * VAR and is later overwritten with the function's definition. *)
        val cache : (ty * lexp * lexp ref) list ref = ref nil

        (* Debugging aid shared by enter, find and test: when
         * Control.debugging is set, print `label` followed by `ty`. *)
        fun debugType (label, ty) =
            if !Control.debugging
            then with_pp (ErrorMsg.defaultConsumer())
                  (fn ppstrm =>
                    (add_string ppstrm label;
                     PPType.resetPPType();
                     PPType.ppType env ppstrm ty))
            else ()

        (* Register ty with a fresh function variable; returns the variable
         * and the ref to patch with the function body. *)
        fun enter ty =
            let val v = VAR(mkLvar())
                val r = ref v
             in debugType("enter: ", ty);
                cache := (ty, v, r) :: !cache;
                (v,r)
            end

        (* Variable of the cached equality function for ty; raises Notfound
         * if there is none.  eqType may raise Poly from here. *)
        fun find ty =
            let fun f ((t,v,_)::rest) =
                      if eqType(ty,t) then v else f rest
                  | f nil = (if !Control.debugging
                              then Control.Print.say "find-notfound\n"
                              else ();
                             raise Notfound)
             in debugType("find: ", ty);
                f (!cache)
            end

        (* Equality for the built-in type constructors; Poly for any other. *)
        fun atomeq tyc =
            if equalTycon(tyc,intTycon)
               orelse equalTycon(tyc,wordTycon)
               orelse equalTycon(tyc,word8Tycon) then PRIM(P.IEQL,inteqty)
            else if equalTycon(tyc,word32Tycon) then PRIM(P.IEQL,int32eqty)
            else if equalTycon(tyc,boolTycon) then PRIM(P.IEQL,booleqty)
            else if equalTycon(tyc,refTycon) then PRIM(P.PTREQL,boxedeqty)
            else if equalTycon(tyc,arrayTycon) then PRIM(P.PTREQL,boxedeqty)
            else if equalTycon(tyc,realTycon) then PRIM(P.FEQLd,realeqty)
            else if equalTycon(tyc,stringTycon) then stringequal
            else raise Poly

        (* A function of one argument v (of type argty) that binds x to
         * SELECT(0,v) and y to SELECT(1,v) and then evaluates body. *)
        fun mkPairFn (v, x, y, argty, body) =
            FN(v, argty,
               APP(FN(x, LT.BOGUS,
                      APP(FN(y, LT.BOGUS, body), SELECT(1, VAR v))),
                   SELECT(0, VAR v)))

        (* test(ty,depth) -- expression denoting the equality function for
         * ty; depth is the remaining datatype-unfolding budget.
         * The order of the case arms matters: the single-constructor
         * datatype patterns must precede the general DATAtyc pattern. *)
        fun test(ty,0) = raise Poly
          | test(ty,depth) =
            (debugType("test: ", ty);
             case ty
              of VARty(ref(INSTANTIATED t)) => test(t,depth)
               | CONty(DEFtyc _, _) => test(reduceType ty,depth)
               | CONty(GENtyc{kind=ref(FORMDEFtyc tyc),...},args) =>
                   test(CONty(tyc,args),depth)
               | CONty(RECORDtyc _, tyl) =>
                   (* records: compare field by field, stopping at the
                    * first field that differs *)
                   (find ty handle Notfound =>
                     let val v = mkLvar() and x = mkLvar() and y = mkLvar()
                         val (eqv, patch) = enter ty
                         fun fieldEq(n, fieldTy) =
                             APP(test(fieldTy,depth),
                                 RECORD[SELECT(n, VAR x), SELECT(n, VAR y)])
                         fun loop(_, nil) = trueLexp
                           | loop(n, [fieldTy]) = fieldEq(n, fieldTy)
                           | loop(n, fieldTy::rest) =
                               SWITCH(fieldEq(n, fieldTy), boolsign,
                                      [(DATAcon(trueDcon'), loop(n+1,rest)),
                                       (DATAcon(falseDcon'), falseLexp)],
                                      NONE)
                         val lt = transTyLty(ty)
                      in patch := mkPairFn(v, x, y, LT.injRECORD[lt,lt],
                                           loop(0,tyl));
                         eqv
                     end)
               | CONty(tyc as GENtyc{kind=ref(PRIMtyc),eq=ref YES,...}, _) =>
                   atomeq tyc
               (* assume that an equality datatype has been converted
                  to an ABStyc in an abstype declaration *)
               | CONty(GENtyc{kind=ref(ABStyc tyc),eq=ref NO,...}, tyl) =>
                   test(mkCONty(tyc,tyl),depth)
               | CONty(tyc as GENtyc{kind=ref(DATAtyc
                                  [DATACON{const=false,rep=REF,...}]),...},
                       _) =>
                   atomeq tyc
               | CONty(GENtyc{kind=ref(DATAtyc
                                  [DATACON{const=false,rep=TRANSPARENT,
                                           typ,orig,...}]), ...},
                       tyl) =>
                   (* single transparent constructor: delegate to the
                    * equality of the constructor's argument type *)
                   (find ty handle Notfound =>
                     let val (eqv,patch) = enter ty
                         val v = mkLvar()
                         val a = (case orig of SOME x => x | NONE => typ)
                         val ty' = argType(a,tyl)
                         val lt = transTyLty(ty')
                      in patch := FN(v,LT.injRECORD[lt,lt],
                                     APP(test(ty',depth-1), VAR v));
                         eqv
                     end)
               | CONty(GENtyc{kind=ref(DATAtyc dcons),...}, tyl) =>
                   (find ty handle Notfound =>
                     let val v = mkLvar() and x = mkLvar() and y = mkLvar()
                         val (eqv, patch) = enter ty
                         (* comparison of x and y once both are known to
                          * carry constructor c *)
                         fun inside (DATACON{const=true,...}) = trueLexp
                           | inside (c as DATACON{typ,const=false,orig,...}) =
                               let val a = (case orig
                                              of SOME z => z
                                               | NONE => typ)
                                   val b = transTyLty(polyArgType(a))
                                   val argt = argType(typ,tyl)
                                   val header = unwrapOp(b,transTyLty(argt))
                                in APP(test(argt,depth-1),
                                       RECORD[header(DECON(transDcon c,VAR x)),
                                              header(DECON(transDcon c,VAR y))])
                               end
                         val lt = transTyLty ty
                         val argty = LT.injRECORD[lt,lt]
                         val pty = LT.injARROW(argty,boolty)
                         val structural =
                             case dcons
                               of [] => error "empty data types in equal.sml"
                                | [dcon] => inside dcon
                                | DATACON{sign,...}::_ =>
                                    let fun concase dcon =
                                            let val dcon' =
                                                    DATAcon(transDcon dcon)
                                             in (dcon',
                                                 SWITCH(VAR y, sign,
                                                        [(dcon', inside dcon)],
                                                        SOME(falseLexp)))
                                            end
                                     in SWITCH(VAR x,sign,
                                               map concase dcons,NONE)
                                    end
                         (* pointer-equal values are equal; only otherwise
                          * fall back to the structural comparison *)
                         val guarded =
                             SWITCH(APP(PRIM(P.PTREQL,pty),
                                        RECORD[VAR x, VAR y]),
                                    boolsign,
                                    [(DATAcon(trueDcon'), trueLexp),
                                     (DATAcon(falseDcon'), structural)],
                                    NONE)
                      in patch := mkPairFn(v, x, y, argty, guarded);
                         eqv
                     end)
               | _ => raise Poly)

        val body = test(concreteType,10)
        val fl = !cache

     in case fl
         of nil => body
          | _ => FIX(map (fn (_,VAR v,_) => v
                           | _ => error "Equal #324") fl,
                     map (fn (ty,_,_) => eqty(transTyLty ty)) fl,
                     map (fn (_,_,e) => !e) fl,
                     body)
    end
    handle Poly => polyequal

		
end (* structure Equal *)

