functor TransSpecialFun(structure TransPrep: TRANS_PREP): TRANS_SPECIAL =

  struct

    structure TransPrem = TransPrep
    structure TransCommon = TransPrep.TransCommon

    open TransCommon
    open TransPrep

    open Hash
    open Pos
    open Evaluate
    open Abstract
    open ParserDefault
    open ParseTreeStruct
    open Interface
    open Options

    (* Decide whether the expression `e' contains a temporal operator.
       Every visited node is tested with `isTemporal'; descending
       further is delegated to `ptTransform' (presumably it applies the
       supplied function to the sub-terms of a node -- confirm against
       its definition).  The first temporal node found aborts the walk
       through a local exception, so the result is `true' as soon as one
       is seen and `false' if the walk completes normally. *)
    fun hasTemporal e =
        let exception FoundTemporal
            fun visit node =
                (case isTemporal node of
                     true => raise FoundTemporal
                   | false => ptTransform visit node)
        in
            let val _ = visit e
            in false
            end handle FoundTemporal => true
        end

    (* Customize an AsstVarsTree to a particular property under a
       particular abstraction.

       Arguments (all curried):
         options, findObject - passed through to the evaluation, cone
                               and rebuild helpers
         prims      - primitive variables; consulted by `varsFrom',
                      `getPrimitiveIndexVar', the cone getters and
                      `rebuildAsstVars'
         cone       - variable cone handed to `getVarCone' /
                      `getVarConeNoNext'
         limit      - bound handed to `varsFrom', `getTypeValues', the
                      cone getters and `rebuildAsstVars'
         specList   - list of property expressions
         abs        - abstraction handed to `getTypeAbstraction'
         atree      - the assignment tree to customize

       Result: (varsSpec, cone, pvars, result) where
         varsSpec - union of the cones of all variables occurring in the
                    specs,
         cone, pvars - second and third components returned by
                    `rebuildAsstVars' on the trimmed tree (note that this
                    `cone' is NOT the `cone' argument),
         result   - the trimmed tree with its variable annotations
                    redistributed by the local `loop'.

       Raises SympBug / ProverError on the malformed or unsupported
       types described at the raise sites below. *)
    fun customizeAsstVars0 options findObject prims cone limit specList abs atree =
        let val funName = "customizeAsstVars0"
            val _ = pushFunStackLazy(funName,
                                     fn()=>("specList=["^(ptlist2str ", " specList)^"],\n"
                                            ^"Assignment Tree = {"^(avt2str atree)^"}"))
            val evalExpr = evaluateExpr options findObject
            val varsFrom = varsFrom options findObject
            val debug = lazyVerbDebug options funName
            (* Elements of lst1 that have no vtEq-equal element in lst2 *)
            fun listDiff lst1 lst2 =
                  List.filter(fn x=>not(List.exists(fn y=>vtEq(x,y)) lst2)) lst1
            (* Subtract from the given variable sets the variables (as
               reported by `getAsstVars') of every subtree in the list;
               what remains is not covered by any of the subtrees. *)
            fun doList {norm=norm, next=next, init=init} (h::t) =
                 let val {norm=hnorm, next=hnext, init=hinit} = getAsstVars h
                     val newvars = {norm=listDiff norm hnorm,
                                    next=listDiff next hnext,
                                    init=listDiff init hinit}
                 in doList newvars t
                 end
              | doList vars [] = vars
            (* Re-annotate a tree with `vars'.  Children of a list node
               keep their own variables; children of every other compound
               node receive the variables of their parent. *)
            fun looppairs vars lst = List.map(fn(e,t)=>(e, loop vars t)) lst
            and loop vars (ListAsstTree(_,lst)) =
                 let val newList = List.map(fn t=>loop(getAsstVars t) t) lst
                 in (case doList vars lst of
                         {norm=[],next=[],init=[]} => ListAsstTree(vars, newList)
                         (* Some variables are covered by no child: prepend
                            a Nop node that carries them. *)
                       | remains =>
                             ListAsstTree(vars, (NopAsstTree remains)::newList))
                 end
              | loop vars (LetAsstTree(_, defs, t)) = LetAsstTree(vars, defs, loop vars t)
              | loop vars (CaseAsstTree(_, sel, lst)) =
                 CaseAsstTree(vars, sel, looppairs vars lst)
              | loop vars (IfAsstTree(_, lst, last)) =
                 IfAsstTree(vars, looppairs vars lst, loop vars last)
              | loop vars (ChooseAsstTree(_, paramsOpt, lst)) =
                 ChooseAsstTree(vars, paramsOpt, looppairs vars lst)
              | loop vars (ForeachAsstTree(_, params, t)) =
                 ForeachAsstTree(vars, params, loop vars t)
              | loop vars (LabeledAsstTree(_, label, t)) =
                 LabeledAsstTree(vars, label, loop vars t)
              | loop vars (NopAsstTree _) = NopAsstTree vars
              (* Assume that other atomic assignments don't need adjustment *)
              | loop _ t = t
            fun isDelayed (DelayedVar _) = true
              | isDelayed _ = false
            (* Variables occurring in any of the specs *)
            val specVars = List.map cvValue (unionC (List.map (varsFrom prims limit) specList))
            val isSpecTemporal = List.exists hasTemporal specList
            val _ = debug(fn()=>"customizeAsstVars: isSpecTemporal = "
                          ^(if isSpecTemporal then "true" else "false")
                          ^", specVars = ["^(ptlist2str ", " specVars)^"]\n")
            (* Cone of the variable expression `v', computed by recursion
               on the type of `v': composite types are split into
               component expressions (each evaluated with `evalExpr') and
               their cones are united; a primitive type consults the cone
               getters. *)
            fun coneFromVar v =
                let val funName = "coneFromVar"
                    val _ = pushFunStackLazy(funName, fn()=> pt2string v)
                    val vType = getExprType findObject v
                    (* Shared implementation of the FunType and ArrayType
                       cases: `t' is the whole type (used only in the error
                       message) and `t1' its second constructor argument.
                       The values of `t1' (of its abstract type, when an
                       abstraction is defined for it) are enumerated, `v'
                       is applied to each of them, and the cones of the
                       resulting expressions are united. *)
                    fun indexedCone (t, t1) =
                        let val abst = getTypeAbstraction options abs t1
                            val dlist =
                                (case abst of
                                     SOME t1' => getTypeValues options limit (getAbsType t1')
                                   | NONE => getTypeValues options limit t1)
                        in
                            (case dlist of
                                 SOME ds =>
                                     unionC (List.map coneFromVar
                                             (List.map (fn d => evalExpr(Appl (dp, v, d))) ds))
                               | NONE => raise ProverError
                                     ("Infinite type in the function range / array index,\n"
                                      ^"cannot enumerate all elements:\n   "
                                      ^(pt2string t)))
                        end
                    (* Tracing wrapper: records entry and result of every
                       step on the function stack and delegates to loop'. *)
                    fun loop x =
                        let val funName = "coneFromVar/loop"
                            val _ = pushFunStackLazy(funName, fn()=>pt2string x)
                            val res = loop' x
                            val _ = popFunStackLazy(funName,
                                                    fn()=>"["^(cvlist2str ", " res)^"]")
                        in res
                        end
                    (* One ExtractTuple expression per component, indexed
                       0 .. length-1 *)
                    and loop' (TupleType (_, tlist)) =
                        let val items = List.map (fn n => evalExpr(ExtractTuple (n, v)))
                                        (List.tabulate (List.length tlist, fn i => i))
                        in
                            unionC (List.map coneFromVar items)
                        end
                      (* One ExtractRecord expression per record field *)
                      | loop' (RecordType (_, rlist)) =
                        let fun makeParts(RecordField{name=n, ...}) =
                                 evalExpr (ExtractRecord (n, v))
                              | makeParts x = raise SympBug
                                 ("trans_special.sml: RecordType contains a non-RecordField: "
                                  ^(pt2string x))
                            val items = List.map makeParts rlist
                        in
                            unionC (List.map coneFromVar items)
                        end
                      | loop' (t as FunType (_, t1, _)) = indexedCone (t, t1)
                      (* Exactly the same as FunType *)
                      | loop' (t as ArrayType (_, t1, _)) = indexedCone (t, t1)
                      | loop' (EnumType (_, clist)) =
                        let (* Only constructors whose Type is a FunType
                               yield an extractor expression. *)
                            fun extractor (a as TypeConstr{Type=tp,...}) =
                                 (case tp of
                                      FunType _ => SOME(fn x=> ExtractAppl(a, x))
                                    | _ => NONE)
                              | extractor a = raise SympBug
                                 ("trans_special EnumType: not a TypeConstr:\n  "
                                  ^(pt2string a))
                            (* A constructor whose extraction raises
                               EvalError is silently left out. *)
                            val extracted =
                                List.mapPartial (fn n => (Option.map(fn extr =>
                                                                     evalExpr (extr v))
                                                          (extractor n))
                                                 handle EvalError _ => NONE)
                                clist
                            (* The index variable of `v' must already be
                               registered in `prims'. *)
                            val innerX =
                              (case getPrimitiveIndexVar prims v of
                                   SOME x => x
                                 | NONE => raise SympBug
                                      ("trans_special:  EnumType index should "
                                       ^"already be in the cone.")
                                     (*
                                      let val un = Id(dp,newName())
                                        val x = StateVar {name = un, uname = un,
                                          Type = AbstractType clist,
                                          id = un}
                                      in
                                        addPrimitiveIndexVarDestructive prims v
                                        x;
                                        x
                                      end
                                      *)
                                     )
                        in
                            unionC (List.map coneFromVar (innerX :: extracted))
                        end
                      (* Instantiate the closure's definition with the
                         actual arguments and continue with the resulting
                         parameterless closure. *)
                      | loop' (TypeInst (_, args, TypeClosure parms)) =
                        let val {name = name,
                                 uname = uname,
                                 params = params,
                                 def = def,
                                 recursive = recursive,
                                 parent = parent} = parms
                            val nt = TypeClosure {name = name, uname = uname,
                                                  params = [],
                                                  def = instantiateType
                                                  (args, params) def,
                                                  recursive = recursive,
                                                  parent = parent}
                        in
                            loop nt
                        end
                      | loop' (TypeClosure {def = newt, recursive = r,...}) =
                          if r then raise SympBug
                              "trans_special: Recursive types not yet supported"
                          else loop newt
                      | loop' (vt as Uid _) =
                          (case findObject vt of
                               SOME newt => loop newt
                             | NONE => raise SympBug
                                 "trans_special:  Uid not defined")
                      | loop' (StaticFormalType{value=SOME t, ...}) = loop t
                      (* The var is of a primitive type.  Include it into its own cone
                         along with what `getVarCone' returns. *)
                      | loop' t =
                        let val abst = getTypeAbstraction options abs t
                            (* Cone of `v' wrapped with one flag; the
                               next-state cone is only followed for
                               temporal specs. *)
                            fun ff f =
                                let val vt = vtWrap f v
                                    val getCone =
                                        if isSpecTemporal then
                                            getVarCone options findObject prims limit cone
                                        else getVarConeNoNext options findObject prims limit cone
                                in
                                    unionC[[PVar v], getCone vt]
                                end
                            val flags = if isSpecTemporal then [NormalFlag, InitFlag,NextFlag]
                                        else [NormalFlag, InitFlag]
                            val vCone = unionC (List.map ff flags)
                        in
                            (case abst of
                                 (* The type is abstracted: delayed cone
                                    entries are forced, i.e. replaced by the
                                    cones of the variables occurring in
                                    their values. *)
                                 SOME _ =>
                                   let val (delayed, nonDelayed) = List.partition isDelayed vCone
                                       val forceDelayed = unionC
                                             (List.map ((varsFrom prims limit) o cvValue) delayed)
                                       val resultForced = unionC
                                             (List.map (coneFromVar o cvValue) forceDelayed)
                                   in
                                       unionC [nonDelayed, resultForced]
                                   end
                               | NONE => vCone)
                        end
                    val res = loop vType
                    val _ = popFunStackLazy(funName, fn()=>"["^(cvlist2str ", " res)^"]")
                  in
                      res
                  end (* coneFromVar *)

            val unionPt = union ptEq
            (* All variables in the cones of the spec variables *)
            val varsSpec = unionPt(List.map ((List.map cvValue) o coneFromVar) specVars)
            val _ = debug(fn()=>"\ncustomizeAsstVars: varsSpec = ["
                          ^(ptlist2strDebug ", " varsSpec)^"]\n")
            (* Insert variables from `vars' in place of relevant AsstVars vars *)
            fun filterRelevant vars ({norm = norm, next = next, init = init}: AsstVars) =
                let (* Keep `v' (wrapped with `flag') when it is svPartOf
                       the name of some variable in `vts'. *)
                    fun doVar (vts, flag) v =
                        if List.exists (fn x => svPartOf (v, vtName x)) vts then
                              SOME(vtWrap flag v)
                        else NONE
                    fun doVars(vts, flag) = List.mapPartial (doVar(vts, flag)) vars
                in
                    {norm = doVars(norm, NormalFlag),
                     (* Next-state variables matter only for temporal specs *)
                     next = if isSpecTemporal then doVars(next, NextFlag) else [],
                     init = doVars(init, InitFlag)}
                end
            fun isEmptyVars {norm = [], next = [], init = []} = true
              | isEmptyVars _ = false
            (* Helpers around `trim': the plain variants drop the subtrees
               for which `trim' returns NONE, the ...Nop variants replace
               them with an empty NopAsstTree so that their position is
               kept. *)
            fun nonEmpty (vars: ParseTree list) lst = List.mapPartial(trim vars) lst
            and nonEmptyNop vars t =
                (case trim vars t of
                     SOME r => r
                   | NONE => NopAsstTree{norm = [], next = [], init = []})
            and nonEmptyPairsNop (vars: ParseTree list) lst =
              let fun ff (e,t) = (e, nonEmptyNop vars t)
              in
                  List.map ff lst
              end
            and nonEmptyPairs (vars: ParseTree list) lst =
              let fun ff(e, t) = Option.map(fn r => (e, r))(trim vars t)
              in
                  List.mapPartial ff lst
              end
            (* Restrict `tree' to the variables in `vars'.  Returns NONE
               when no variable of the node is relevant.  For a
               non-temporal spec, normal assignments are turned into init
               assignments and next-state assignments are dropped. *)
            and trim vars tree =
                let fun filterU U =
                        let val V = filterRelevant vars U
                        in
                            if (isEmptyVars V) then NONE
                            else SOME V
                        end
                    (* Move the normal vars to the init vars *)
                    fun norm2init{norm=norm, init=init, next=next} =
                        let val names = unionPt [List.map vtName norm, List.map vtName init]
                            val newInit = List.map (vtWrap InitFlag) names
                        in
                            {norm=[], init=newInit, next=next}
                        end
                    val doPairs = nonEmptyPairs vars
                    fun loop (NormalAsstTree(U, name, e)) =
                          Option.map(fn V=>
                                       if isSpecTemporal then
                                           NormalAsstTree(V, name, e)
                                       else InitAsstTree(norm2init V, name, e)) (filterU U)
                      | loop (NextAsstTree(U, name, e)) =
                          if isSpecTemporal then
                              Option.map(fn V=> (NextAsstTree(V, name, e))) (filterU U)
                          else NONE
                      | loop (InitAsstTree(U, name, e)) =
                          Option.map(fn V=> (InitAsstTree(V, name, e))) (filterU U)
                      | loop (NopAsstTree U) =
                          Option.map(fn V=> (NopAsstTree V)) (filterU U)
                      | loop (ListAsstTree(U, atreel)) =
                          Option.map(fn V=> (ListAsstTree(V, nonEmpty vars atreel))) (filterU U)
                      (* Unlike `if' and `choose', a case node keeps every
                         branch; an irrelevant branch becomes an empty Nop. *)
                      | loop (CaseAsstTree(U, sel, atreel)) =
                          let fun ff V =
                              let val pairs = nonEmptyPairsNop vars atreel
                              in CaseAsstTree(V, sel, pairs)
                              end
                          in
                              Option.map ff (filterU U)
                          end
                      | loop (IfAsstTree(U, atreel, last)) =
                          let fun ff V = IfAsstTree(V, doPairs atreel, nonEmptyNop vars last)
                          in
                              Option.map ff (filterU U)
                          end
                      | loop (ChooseAsstTree(U, paramsOpt, atreel)) =
                          Option.map(fn V=> (ChooseAsstTree(V, paramsOpt, doPairs atreel)))
                                    (filterU U)
                      | loop (ForeachAsstTree(U, params, atree)) =
                          Option.map(fn V=>
                                     (ForeachAsstTree(V, params,
                                                      nonEmptyNop vars atree))) (filterU U)
                      | loop (LabeledAsstTree(U, label, atree)) =
                          Option.map(fn V=>
                                     (LabeledAsstTree(V, label,
                                                  nonEmptyNop vars atree))) (filterU U)
                      | loop (LetAsstTree(U, defs, atree)) =
                          Option.map(fn V=>
                                     (LetAsstTree(V, defs,
                                                  nonEmptyNop vars atree))) (filterU U)
                in
                    loop tree
                end (* trim *)

            val trimmed = (case trim varsSpec atree of
                               SOME tree => tree
                             | NONE => NopAsstTree {norm=[], next=[], init=[]})
            val opt = if isSpecTemporal then
                        rebuildAsstOptionsBalance rebuildAsstOptionsDefault
                      else rebuildAsstOptionsDefault
            (* From here on `cone' denotes the rebuilt cone and shadows the
               `cone' argument.
               NOTE(review): the tree returned by `rebuildAsstVars' is
               discarded and the final result is computed from `trimmed';
               confirm that this is intended. *)
            val (_, cone, pvars) =
                  rebuildAsstVars options findObject prims limit (NONE, trimmed, opt)
            val _ = debug(fn()=>"\ncustomizeAsstVars: trimmed tree =\n"
                          ^(avt2str trimmed))
            val result = loop (getAsstVars trimmed) trimmed
            val _ = popFunStackLazy(funName,
                                    fn()=>("varsSpec = ["^(ptlist2str ", " varsSpec)^"],\n"
                                           ^"Assignment Tree = {"^(avt2str result)^"}"))
        in
          (varsSpec, cone, pvars, result)
        end

    (* Convenience wrapper around `customizeAsstVars0': takes the same
       arguments and returns only the fourth component of its result,
       the customized assignment tree. *)
    fun customizeAsstVars options findObject prims cone limit specList abs atree =
        (case customizeAsstVars0 options findObject prims cone limit specList abs atree of
             (_, _, _, customized) => customized)

    (* Perform COI reduction on the transition relation `trans' with
       respect to the properties `specs', together with the other
       transformations required before the transition relation is
       generated.  Every atomic model has its assignments, cone and
       primitive variables replaced by the ones `customizeAsstVars0'
       computes for it; synchronous and asynchronous compositions are
       processed component-wise, left component first.
       Returns (state variables collected over all atomic models,
                merged primitive variables, rewritten relation).
       Raises SympBug for any other form of transition relation. *)
    fun specializeTransRel options findObject abs trans specs =
        let val {limit = bound, ...} = options
            (* Accumulators updated while the relation is traversed *)
            val seenVars = ref []
            val mergedPvars = ref (makePrimitiveVars ())
            fun specialize (TransAtomic (AtomicModel atom)) =
                let val (vars, cone', pvars', assts') =
                        customizeAsstVars0 options findObject
                                           (#pvars atom) (#cone atom) bound specs abs
                                           (#assts atom)
                in
                    seenVars := union ptEq [vars, !seenVars];
                    mergedPvars := mergePrimitiveVars (!mergedPvars, pvars');
                    TransAtomic (AtomicModel {name = #name atom,
                                              uname = #uname atom,
                                              assts = assts',
                                              cone = cone',
                                              pvars = pvars',
                                              absModules = #absModules atom})
                end
              | specialize (TransSync2 (left, right)) =
                let val left' = specialize left
                    val right' = specialize right
                in
                    TransSync2 (left', right')
                end
              | specialize (TransAsync2 (left, right)) =
                let val left' = specialize left
                    val right' = specialize right
                in
                    TransAsync2 (left', right')
                end
              | specialize _ = raise SympBug
                  ("TransSpecial/specializeTransRel:\n  "
                   ^"sorry, closed form parallel composition is not implemented yet")
            (* The traversal must finish before the accumulators are read *)
            val trans' = specialize trans
        in
            (!seenVars, !mergedPvars, trans')
        end

    (* Specialize `model' to the properties `specs': its transition
       relation is run through `specializeTransRel', and the state
       variables and primitive variables of the resulting model are the
       ones that call returns.  `findObject' and `abs' are carried over
       from the original model unchanged. *)
    fun specializeModel options (model: Model) specs =
        let val {trans = oldTrans, findObject = lookup, abs = abstraction, ...} = model
            val (vars, prims, trans') =
                specializeTransRel options lookup abstraction oldTrans specs
        in
            {trans = trans',
             findObject = lookup,
             stateVars = vars,
             pvars = prims,
             abs = abstraction}: Model
        end
	

  end
