#open "prelude";;
#open "terms";;
#open "equations";;

(****************** Critical pairs *********************)

(* All pairs (u,sig) such that the subterm N/u (which must not be a
   variable) unifies with M, with principal unifier sig *)

let super M =
  (* Walk a term top-down.  Every non-variable subterm t/u that unifies
     with M contributes (u, sigma), u being a list of 1-based argument
     indices.  The root position [] (when it unifies) comes first, then
     the positions inside the arguments, left to right.  Variables
     contribute nothing. *)
  let rec overlaps t =
    match t with
      Term(_, args) ->
        let rec in_args i = function
            [] -> []
          | arg :: others ->
              let here =
                map (fun (path, sigma) -> (i :: path, sigma)) (overlaps arg) in
              here @ in_args (i+1) others in
        let inside = in_args 1 args in
        (try ([], unify M t) :: inside
         with Failure _ -> inside)
    | _ -> [] in
  overlaps
;;

(* All pairs (u,sig) with u <> [] (proper subterms only, root excluded)
   such that N/u unifies with M *)

let super_strict M N =
  (* Like super M N, but the root of N is never tried: each argument of N
     is searched with super and its positions are prefixed with the
     1-based argument index.  A variable N yields []. *)
  match N with
    Term(_, args) ->
      let rec scan i = function
          [] -> []
        | arg :: others ->
            let found =
              map (fun (path, sigma) -> (i :: path, sigma)) (super M arg) in
            found @ scan (i+1) others in
      scan 1 args
  | _ -> []
;;

(* Critical pairs of L1=R1 with L2=R2 *)
(* critical_pairs : term_pair -> term_pair -> term_pair list *)
let critical_pairs (Equation(L1,R1)) (Equation(L2,R2)) =
  (* For every overlap (u, sigma) of L1 into L2 (root position included),
     emit the equation  sigma(replace L2 u R1) = sigma(R2). *)
  let overlaps = super L1 L2 in
  map
    (fun (u, sigma) ->
       Equation(substitute sigma (replace L2 u R1), substitute sigma R2))
    overlaps
;;

(* Strict critical pairs of L1=R1 with L2=R2 *)
(* strict_critical_pairs : term_pair -> term_pair -> term_pair list *)
let strict_critical_pairs (Equation(L1,R1)) (Equation(L2,R2)) =
  (* Same construction as critical_pairs, restricted to overlaps of L1
     strictly inside L2 (the root position of L2 is excluded). *)
  let build (pos, sigma) =
    let left = substitute sigma (replace L2 pos R1)
    and right = substitute sigma R2 in
    Equation(left, right) in
  map build (super_strict L1 L2)
;;

(* All critical pairs of eq1 with eq2 *)
let mutual_critical_pairs eq1 eq2 =
  (* Overlaps of eq1 strictly inside eq2, followed by overlaps of eq2 into
     eq1 at any position (root included). *)
  let outer = critical_pairs eq2 eq1 in
  let inner = strict_critical_pairs eq1 eq2 in
  inner @ outer
;;

(* Renaming of variables *)

let rename n (Equation(t1,t2)) =
  (* Add n to every variable index on both sides of the equation;
     function symbols and term structure are left unchanged. *)
  let rec shift t =
    match t with
      Var k -> Var(k + n)
    | Term(op, args) -> Term(op, map shift args) in
  Equation(shift t1, shift t2)
;;

(************************ Completion ******************************)

let deletion_message rule =
  (* Report that a rule was removed: prints Rule, its number, then hands
     the trailing text to message (defined elsewhere). *)
  match rule with
    Rule(number,_,_,_) ->
      print_string "Rule ";
      print_int number;
      message " deleted"
;;

(* Generate failure message *)
let non_orientable eq =
  (* Print the equation as  lhs = rhs  followed by a newline. *)
  match eq with
    Equation(lhs, rhs) ->
      pretty_term lhs;
      print_string " = ";
      pretty_term rhs;
      print_newline()
;;

let fetch_rule k =
  (* Look up the rule numbered k in a rule list.
     Returns (v, Equation(L,R)), where v is the second field of the rule
     (presumably its variable bound, used by callers as a renaming
     offset -- confirm) and L, R are its two sides.
     Raises Failure (message: rule deleted) when no rule has number k;
     callers match on that exact message, so it must not change. *)
  let rec find = function
      [] -> failwith "rule deleted"
    | Rule(n,v,L,R) :: rest ->
        (* structural equality: rule numbers are compared by value *)
        if n = k then (v, Equation(L,R)) else find rest in
  find
;;

(* Improved Knuth-Bendix completion procedure *)

let kb_completion greater = kbrec where rec kbrec n rules =
  (* kbrec n rules failures k l eqs  (the last four are the arguments of
     process below):
       n        rule counter; a rule entered here is numbered n+1
       rules    current rule list
       failures equations found non-orientable so far
       k, l     next pair of rule numbers whose critical pairs are due
                (k runs from 1 to l, then l advances, up to n)
       eqs      equations still waiting to be processed
     Returns the final rule list once every pair has been examined and no
     non-orientable equation remains; otherwise prints the remaining
     equations and raises Failure with message kb_completion. *)
  let normal_form = mrewrite_all rules in
  process where rec process failures k l eqs =
(*******
     print_string "***kb_completion "; print_int n; print_newline();
     pretty_rules rules;
     do_list non_orientable failures;
     print_int k; print_string " "; print_int l; print_newline();
     do_list non_orientable eqs;
********)
     match eqs with
       [] ->
         (* Queue empty: move on to the next rule pair, or stop once the
            last pair (k = l = n) has been handled. *)
         if k<l then next_criticals failures (k+1) l else
         if l<n then next_criticals failures 1 (l+1) else
          (match failures with
             [] -> rules (* successful completion *)
           | _  -> message "Non-orientable equations :";
                   do_list non_orientable failures;
                   failwith "kb_completion")
     | Equation(M,N)::eqs ->
         (* Normalize both sides with the current rules.  Identical sides:
            drop the equation.  Orientable by greater: enter it as a rule.
            Otherwise keep it aside as a failure. *)
         let M' = normal_form M
         and N' = normal_form N
         and enter_rule left right =
           (* Add left -> right as rule number n+1 and print it.  Rules for
              which reducible left L holds (presumably: their left side L
              is reducible by the new rule) are deleted and turned back
              into equations.  The remaining rules get their right side
              renormalized.  Completion then restarts with counter n+1,
              an empty failure list, and the old failures re-queued. *)
           let new_rule = mk_rule (n+1) left right in
             pretty_rule new_rule;
             let left_reducible (Rule(_,_,L,_)) = reducible left L in
             let redl,irredl = partition left_reducible rules in
               do_list deletion_message redl;
               let irreds = (map right_reduce irredl
                    where right_reduce (Rule(m,_,L,R)) =
                          mk_rule m L (mrewrite_all (new_rule::rules) R))
               and eqs' = map (fun (Rule(_,_,L,R)) -> Equation(L,R)) redl in
                 kbrec (n+1) (new_rule::irreds) [] k l (eqs @ eqs' @ failures) in
         if eq_term M' N' then process failures k l eqs else
         if greater M' N' then enter_rule M' N' else
         if greater N' M' then enter_rule N' M' else
           process (Equation(M',N')::failures) k l eqs
  and next_criticals failures k l =
(*****
    print_string "***next_criticals ";
    print_int k; print_string " "; print_int l ; print_newline();
*****)
    (* Compute the critical pairs between rules k and l, then process them.
       A rule number that no longer exists (fetch_rule fails with the
       message rule deleted) is skipped: a missing rule l advances to
       (1, l+1), a missing rule k advances to (k+1, l).
       NOTE(review): the calls to process sit inside try bodies, so they
       are not tail calls and the stack grows with every step. *)
    try
      let (v, el) = fetch_rule l rules in
        if k = l then
          (* overlaps of rule l with a variable-renamed copy of itself *)
          process failures k l (strict_critical_pairs el (rename v el))
        else
          (try
             let (_, ek) = fetch_rule k rules in
               process failures k l (mutual_critical_pairs el (rename v ek))
           with Failure "rule deleted" ->
             next_criticals failures (k+1) l)
    with Failure "rule deleted" ->
      next_criticals failures 1 (l+1)
;;

(* complete_rules is assumed to be locally confluent and is checked to be
   Noetherian with respect to the ordering greater; rules may be any list
   of rules. *)

let kb_complete greater complete_rules rules =
  (* Start completion from complete_rules (validated by check_rules, whose
     result seeds the rule counter and the initial pair indices), feeding
     the rules of the second list in as equations; then print the
     resulting rule set in reversed list order. *)
  let n = check_rules complete_rules in
  let initial_eqs =
    map (fun (Rule(_,_,lhs,rhs)) -> Equation(lhs,rhs)) rules in
  let completed = kb_completion greater n complete_rules [] n n initial_eqs in
  message "Canonical set found :";
  pretty_rules (rev completed);
  ()
;;


