(* -------------------------------------------------------------*)
(*	Tools.ml: derive introduction/elimination rules from definitions	*)
(* ------------------------------------------------------------ *)


signature TOOLS =
  sig
  (* true exactly when the term is a schematic variable (Var) *)
  val is_schvar : term -> bool
  (* Const(s,_) $ a1 $ ... $ an  |->  (s, [a1, ..., an]);
     raises Bind when the head of the application is not a Const *)
  val dest_term : term -> string * term list
  (* the object-level proposition t of a premise whose conclusion
     (after stripping assumptions) is Trueprop t; raises Bind otherwise *)
  val get_prop : term -> term
  (* split a conclusion A1 & ... & An into one theorem per conjunct *)
  val strip_conj_concl : thm -> thm list
  (* introduction rules derived from a definition (an iff, unfolded with
     iff_def); one rule per case obtained after splitting the premise *)
  val mk_intr_rule : thm -> thm list
  val mk_intr_rules : thm list -> thm list
  (* elimination rule derived from the reverse direction of a definition,
     with up to two leading existentials eliminated when possible *)
  val mk_elim_rule  : thm -> thm
  val mk_elim_rules  : thm list -> thm list 
  (* Selectors on rule-set records, looked up by name;
     each raises Match when no record has the given name. *)
  val get_def :
    {def: thm,elim: thm,intr: thm list,name: string} list -> string -> thm 
  val get_intr :
    {def: thm,elim: thm,intr: thm list,name: string} list -> string -> thm list
  val get_elim :
    {def: thm,elim: thm,intr: thm list,name: string} list -> string -> thm 
  (* print the record with the given name (Match if absent) *)
  val show : {def:thm,elim:thm,intr:thm list,name:string} list -> string -> unit
  (* print one record: name, definition, introduction and elimination rules *)
  val show_rules : {def:thm,elim:thm,intr:thm list,name:string} -> unit
  (* fetch the axioms named by the first components; second components
     are ignored *)
  val extract : theory -> (string * string) list -> thm list
  (* build the rule-set record for the axiom named by the string *)
  val mk_drules : theory -> string -> 
       {def: thm,elim: thm,intr: thm list,name: string} 
  (* merge a non-empty list of theories (Match on the empty list) *)
  val merge_theory_list : theory list -> theory
  (* discharge premises of the form t = u by unifying t with u (via refl) *)
  val strip_equal : thm -> thm
end;

functor ToolsFun () : TOOLS =
struct
local open Syntax 
in

(* -------------------------------------------------------------------- *)
(*	Discriminators for terms - used to make				*)
(*	sure that only safe resolution rules are used			*)
(* -------------------------------------------------------------------- *)

(* True exactly when the term is a schematic variable (Var). *)
fun is_schvar (Var(_,_))  = true |  is_schvar _ = false;

(* -------------------------------------------------------------------- *)
(*									*)
(*		Destructor for terms:	 				*)
(*	dest_term (Const (s,_) $ a1 $ ... $ an) = (s, [a1, ..., an])	*)
(*	Raises Bind when the head of the application is not a Const.	*)
(*									*)
(* -------------------------------------------------------------------- *)

local (* peel applications off the left spine, collecting the arguments
         in left-to-right order *)
      fun strip_app (f $ a, args) = strip_app (f, a :: args)
        | strip_app (f, args) = (f, args)
in
fun dest_term t = 
  let val (Const(s,_), rest) = strip_app (t, []) in (s, rest) end 
end;

(* The proposition t of a premise whose conclusion, after stripping its
   assumptions, is Trueprop t.  Raises Bind for any other shape. *)
fun get_prop prem = let val ("Trueprop", [t]) =   
  dest_term (Logic.strip_assums_concl prem) in t end;

(* -------------------------------------------------------------------- *)
(*									*)
(*	ML functions for forwards proof					*)
(*									*)
(*	strip_exists:	  replaces every premise EX x. P(x) by P(?x),	*)
(*			  repeatedly, for nested existentials		*)
(*									*)
(*	strip_conj:	  finds all assumptions of the form A & B	*)
(*			  and replaces them by [| A ; B |]		*)
(*			  (the left conjunct is not split further)	*)
(*									*)
(*	strip_disj:	  finds all assumptions of the form A | B	*)
(*			  and produces two new theorems (one which 	*)
(*			  assumes A, and one which assumes B); with	*)
(*			  several such assumptions every combination	*)
(*			  is produced.  Always returns a list.		*)
(*									*)
(*	strip_equal:	  discharges assumptions t = u by unifying	*)
(*			  t with u (resolution with refl)		*)
(*									*)
(*      strip_conj_concl: breaks apart a conjunctive conclusion, and 	*)
(*			  returns a list of theorems 			*)
(*									*)
(*      strip_imp_concl:  strip implications from the conclusion	*)
(*									*)
(*      strip_not_concl:  simplifies a negated conclusion ~Q by 	*)
(*			  putting Q into the assumption list and	*)
(*			  making the conclusion anything.		*)
(*									*)
(* -------------------------------------------------------------------- *)

(* Apply (step i) for premise indices i = 1, 2, ... of a theorem.  The
   premise count is re-read after every step, because a step may add
   premises (strip_conj) or remove them (strip_equal). *)
fun for_each_prem step th =
  let fun loop i th =
        if i > length (prems_of th) then th
        else loop (i + 1) (step i th)
  in loop 1 th end;

(* Eliminate existentials in premise i, as long as resolution with exI
   succeeds. *)
fun strip_exists_asm i th = strip_exists_asm i (exI RSN (i, th)) 
                        handle THM _ => th;

fun strip_exists th = for_each_prem strip_exists_asm th;

(* Split premise i once if it is a conjunction (it becomes two premises,
   at i and i+1). *)
fun strip_conj_asm i th = conjI RSN (i, th) 
                        handle THM _ => th;

fun strip_conj th = for_each_prem strip_conj_asm th;

(* this is an awful hack: resolving premise i with disjI1 and disjI2 must
   yield exactly two theorems, otherwise premise i is left alone.  The
   second disjunct is split recursively.  Each result keeps the same number
   of premises as th, since disjI1/disjI2 have one premise each. *)
fun strip_disj_asm i th = 
  let fun resolution (state, i, rules) =
        biresolution false (map (fn rl => (false, rl)) rules) i state
  in
    case Sequence.list_of_s (resolution (th, i, [disjI1, disjI2])) of
        [th1, th2] => th1 :: strip_disj_asm i th2
      | _ => [th]
  end
  handle THM _ => [th];

(* Split every disjunctive premise.  The whole list of partial results is
   threaded through each premise index, so:
   - non-disjunctive premises do not produce duplicate copies of th;
   - the unsplit original is not returned alongside its split cases;
   - several disjunctive premises yield every combination. *)
fun strip_disj th =
  let val n = length (prems_of th)
      fun loop i ths =
        if i > n then ths
        else loop (i + 1) (flat (map (strip_disj_asm i) ths))
  in loop 1 [th] end;

(* Discharge premise i while it is an equation t = u whose sides unify.
   The next premise then moves to index i and is tried as well. *)
fun strip_equal_asm i th = strip_equal_asm i (refl RSN (i, th))
                        handle THM _ => th; 

(* not needed at the moment?  *)

fun strip_equal th = for_each_prem strip_equal_asm th;
                
val strip = (map ((* strip_equal o *)strip_conj o strip_exists)) o
                    strip_disj o strip_exists;

fun strip_conj_concl th = 
   (th RS conjunct1)::(strip_conj_concl (th RS conjunct2))
                          handle THM _ => [th];
                          
fun strip_imp_concl th = strip_imp_concl (th RS mp) handle THM _ => th;

fun strip_not_concl th =  (th RS notE) handle THM _ => th;


(* Returns [sym'd-at-i, th] when premise i resolves with sym, else []. *)
fun gen_sym_asm i th = [sym RSN (i,th), th] handle THM _ => [];

(* Concatenates gen_sym_asm i th over every premise index i.
   NOTE(review): th itself appears once per symmetric premise; this
   function is currently unused, so that behaviour is left unchanged. *)
fun gen_sym th =
  let val n = length (prems_of th)
      fun loop i = if i > n then [] else gen_sym_asm i th @ loop (i + 1)
  in loop 1 end;
 

(* -------------------------------------------------------------------- *)
(*		Introduction Rules					*)
(* -------------------------------------------------------------------- *)

(* Unfold the definition with iff_def.  Take its first conjunct as a rule
   (premise ==> conclusion), then split that premise into one rule per
   case via strip. *)
fun mk_intr_rule def = 
     strip ((rewrite_rule [iff_def] def) RS conjunct1 RS mp);

fun mk_intr_rules thmlist = flat (map mk_intr_rule thmlist);

(* -------------------------------------------------------------------- *)
(*		Elimination Rules					*)
(* -------------------------------------------------------------------- *)

(* Elimination for a doubly-existential major premise. *)
val ex2E = prove_goal thy
 "[| EX x y. P(x,y); !! x y. P(x,y) ==> R |] ==> R"
 (fn prems =>
  [ (cut_facts_tac prems 1),
    (REPEAT (eresolve_tac [exE] 1 ORELSE ares_tac prems 1)) ]);


(* Take the second conjunct of the unfolded definition as a rule, then try
   to eliminate a leading existential (exE), else a double one (ex2E),
   else return the rule as is.  Only resolution failures (THM) are treated
   as "not applicable"; anything else, e.g. an interrupt, propagates. *)
fun mk_elim_rule def = 
   let val th = ((rewrite_rule [iff_def] def) RS conjunct2 RS mp)
   in
      (th RS exE)
      handle THM _ => ((th RS ex2E) handle THM _ => th)
   end;

fun mk_elim_rules thmlist = map mk_elim_rule thmlist;


(* -------------------------------------------------------------------- *)
(*		Selectors and Printing Functions			*)
(* -------------------------------------------------------------------- *)

type drules = {def: thm, elim: thm, intr: thm list, name: string};

(* Build the rule-set record for the axiom named s of theory thy. *)
fun mk_drules thy s = let val def = get_axiom thy s;
              val intr = mk_intr_rule def;
              val elim = mk_elim_rule def
          in {name = s, def = def, intr = intr, elim = elim }
          end;

(* Fetch the axioms named by the first components of the pairs. *)
fun extract thy x = map (get_axiom thy o fst) x; 

(* First record with the given name.  Raises Match when there is none,
   which is the exception the selectors below have always raised. *)
fun find_drules ([] : drules list) (_ : string) : drules = raise Match
  | find_drules (r :: reclist) s =
      if #name r = s then r else find_drules reclist s;

fun get_def reclist (s:string) = #def (find_drules reclist s); 

fun get_intr reclist (s:string) = #intr (find_drules reclist s); 

fun get_elim reclist (s:string) = #elim (find_drules reclist s); 

(* Print one record: name, definition, introduction and elimination rules. *)
fun show_rules {name = s, def = def, intr = intr, elim = elim} = 
    (print "-----------------------------------------------------\n";
     print "\nRule Name:  "; 
     print (s:string);
     print "\n\nDefinition:\n"; 
     prth def; 
     print "\nIntroduction Rules:\n"; 
     prths intr; 
     print "\nElimination Rules:\n"; 
     prth elim;
     print "-----------------------------------------------------\n\n\n");

(* Print the record named s.  Raises Match, printing nothing, if it is
   absent. *)
fun show reclist s = show_rules (find_drules reclist s);


(* Right-nested merge of a non-empty list of theories; Match on []. *)
fun merge_theory_list [] = raise Match
  | merge_theory_list [th] = th
  | merge_theory_list (th :: thlist) =
       merge_theories (th, merge_theory_list thlist);

end;
end;
