open Printf

(** Interface every congruence-closure backend must provide so the test
    driver can run the same cases against each of them. Identifiers are
    passed as [Z.t] values. *)
module type CC_SIG = sig
  (** Backend state; created by [cc_new] and updated through [merge]. *)
  type cc

  (** A constraint accepted by [merge]. *)
  type eqn

  (** [mk_eqn a b] builds the constraint [a = b]. *)
  val mk_eqn : Z.t -> Z.t -> eqn

  (** [mk_defn c a b] builds the constraint [c = app a b]. *)
  val mk_defn : Z.t -> Z.t -> Z.t -> eqn

  (** [cc_new size] creates a fresh state; the driver passes the declared
      size of the test case. *)
  val cc_new : Z.t -> cc

  (** [merge cc e] records constraint [e] in [cc]. *)
  val merge : cc -> eqn -> unit

  (** [check_eq cc a b] asks the backend whether [a] and [b] are equal. *)
  val check_eq : cc -> Z.t -> Z.t -> bool
end

(** Adapter exposing the [Congbare] implementation through [CC_SIG]. *)
module Bare : CC_SIG = struct
  type cc = Congbare.cc
  type eqn = Congbare.eqn

  (* Constraint builders map directly onto the backend's constructors. *)
  let mk_eqn a b = Congbare.Eqn (a, b)
  let mk_defn c a b = Congbare.Defn (c, a, b)

  (* Operations are forwarded unchanged. *)
  let cc_new size = Congbare.cc_new size
  let merge cc e = Congbare.merge cc e
  let check_eq cc a b = Congbare.check_eq cc a b
end

(** Adapter exposing the [Congpath] implementation through [CC_SIG]. *)
module Path : CC_SIG = struct
  type cc = Congpath.cc
  type eqn = Congpath.eqn

  (* Constraint builders map directly onto the backend's constructors. *)
  let mk_eqn a b = Congpath.Eqn (a, b)
  let mk_defn c a b = Congpath.Defn (c, a, b)

  (* Operations are forwarded unchanged. *)
  let cc_new size = Congpath.cc_new size
  let merge cc e = Congpath.merge cc e
  let check_eq cc a b = Congpath.check_eq cc a b
end

(* One constraint from the ops section of a test case. Operands are the
   integer ids of domain elements; they are converted to Z.t only when
   handed to a backend. *)
type op =
  | Eq of int * int
  | Def of int * int * int   (* c = app a b *)

(* One query from the checks section: ask the backend whether ids [a] and
   [b] are equal and compare the answer with [expect]. *)
type check = { a:int; b:int; expect:bool }

(* A parsed test case.
   - name:   taken from the "name" header, or the file's basename when the
             header is absent.
   - size:   taken from the mandatory "size" header; passed to the backend's
             cc_new and used as the exclusive upper bound for ids.
   - ops:    constraints, in file order.
   - checks: equality queries, in file order. *)
type case = {
  name: string;
  size: int;
  ops: op list;
  checks: check list;
}

(** [read_file path] returns the contents of the text file at [path].

    The file is read line by line and every line is re-emitted followed by
    ['\n'], so the result always ends with a newline when the file is
    non-empty, even if the file itself lacks a trailing one.

    The channel is closed on every exit path, including when reading fails.

    @raise Sys_error if the file cannot be opened or read. *)
let read_file path =
  let ic = open_in path in
  Fun.protect
    ~finally:(fun () -> close_in_noerr ic)
    (fun () ->
      let buf = Buffer.create 1024 in
      (* [match ... with exception] keeps the recursive call in tail
         position, so arbitrarily long files do not grow the stack. *)
      let rec slurp () =
        match input_line ic with
        | line ->
            Buffer.add_string buf line;
            Buffer.add_char buf '\n';
            slurp ()
        | exception End_of_file -> Buffer.contents buf
      in
      slurp ())

(** [trim s] returns [s] without leading and trailing blanks, where a blank
    is a space, tab, carriage return or newline. Returns [""] when [s] is
    empty or consists only of blanks. *)
let trim s =
  let blank c = c = ' ' || c = '\t' || c = '\r' || c = '\n' in
  let len = String.length s in
  (* Index of the first non-blank character, or [len] if there is none. *)
  let rec first i = if i < len && blank s.[i] then first (i + 1) else i in
  let start = first 0 in
  (* Index of the last non-blank character, never scanning left of [start]. *)
  let rec last j = if j >= start && blank s.[j] then last (j - 1) else j in
  let stop = last (len - 1) in
  if start > stop then "" else String.sub s start (stop - start + 1)

(** [bool_of_string_exn s] parses ["true"] or ["false"], ignoring case and
    surrounding blanks.

    @raise Failure with message ["bad bool: <normalized input>"] for any
    other input. *)
let bool_of_string_exn s =
  let normalized = String.lowercase_ascii (trim s) in
  if normalized = "true" then true
  else if normalized = "false" then false
  else failwith ("bad bool: " ^ normalized)

(** [parse_case path] reads and parses the test-case file at [path].

    File layout:
    - header lines of the form [key: value], up to a line [--ops--];
      the [size] key is mandatory, [name] defaults to the file's basename;
    - one op per line ([eq a b] or [def c a b]), up to a line [--checks--];
    - one check per line ([? a b = true] or [check a b true]; the [=] is
      optional in both forms), up to a line [--end--]. Anything after
      [--end--] is ignored.

    In the ops and checks sections, blank lines and lines starting with
    [#] are skipped. Header lines are split at their first [':'], so a
    value may itself contain colons.

    @raise Failure if [size] is missing or not an integer, or if an op or
    check line is malformed (including non-numeric operands).
    @raise Sys_error if the file cannot be read. *)
let parse_case path : case =
  let lines = read_file path |> String.split_on_char '\n' in
  (* Split a line list at the first line that equals [tag] once trimmed.
     The tag line itself is dropped; if the tag never occurs, everything
     is returned as the first component. *)
  let rec take_until tag acc = function
    | [] -> (List.rev acc, [])
    | l :: ls when trim l = tag -> (List.rev acc, ls)
    | l :: ls -> take_until tag (l :: acc) ls
  in
  let header, rest = take_until "--ops--" [] lines in
  let ops_lines, rest = take_until "--checks--" [] rest in
  let check_lines, _ = take_until "--end--" [] rest in
  let tbl = Hashtbl.create 5 in
  List.iter
    (fun l ->
      (* Split at the first ':' only, so "name: a:b" keeps the value "a:b"
         instead of being discarded. Lines without ':' are ignored. *)
      match String.index_opt l ':' with
      | Some i ->
          let k = String.sub l 0 i in
          let v = String.sub l (i + 1) (String.length l - i - 1) in
          Hashtbl.replace tbl (trim k) (trim v)
      | None -> ())
    header;
  let name =
    match Hashtbl.find_opt tbl "name" with
    | Some n -> n
    | None -> Filename.basename path
  in
  let size =
    match Option.bind (Hashtbl.find_opt tbl "size") int_of_string_opt with
    | Some n -> n
    | None -> failwith ("missing/invalid size in " ^ path)
  in
  (* Blank-separated tokens of an already trimmed line. *)
  let tokens l = String.split_on_char ' ' l |> List.filter (( <> ) "") in
  let is_skippable l = l = "" || l.[0] = '#' in
  let parse_op l =
    let l = trim l in
    if is_skippable l then None
    else
      let bad () = failwith ("bad op line: " ^ l) in
      (* Report a non-numeric operand against the offending line rather
         than as a bare "int_of_string" failure. *)
      let int_ s =
        match int_of_string_opt s with Some n -> n | None -> bad ()
      in
      match tokens l with
      | [ "eq"; a; b ] -> Some (Eq (int_ a, int_ b))
      | [ "def"; c; a; b ] -> Some (Def (int_ c, int_ a, int_ b))
      | _ -> bad ()
  in
  let ops = List.filter_map parse_op ops_lines in
  let parse_check l =
    let l = trim l in
    if is_skippable l then None
    else
      let bad () = failwith ("bad check line: " ^ l) in
      let int_ s =
        match int_of_string_opt s with Some n -> n | None -> bad ()
      in
      (* forms: "? a b = true" or "check a b true" *)
      match tokens l with
      | ("?" | "check") :: a :: b :: rest ->
          let expect =
            match rest with
            | [ "="; v ] | [ v ] -> bool_of_string_exn v
            | _ -> bad ()
          in
          Some { a = int_ a; b = int_ b; expect }
      | _ -> bad ()
  in
  let checks = List.filter_map parse_check check_lines in
  { name; size; ops; checks }

(** [load_cases base_dir] parses every [*.tcase] file found directly in
    [base_dir/tests-cong], visiting files in sorted file-name order, and
    returns each case paired with its name.

    Exceptions from [Sys.readdir] and [parse_case] propagate to the caller. *)
let load_cases base_dir : (string * case) list =
  let dir = Filename.concat base_dir "tests-cong" in
  let is_case_file f = Filename.check_suffix f ".tcase" in
  let load f =
    let c = parse_case (Filename.concat dir f) in
    (c.name, c)
  in
  let entries = Sys.readdir dir |> Array.to_list |> List.sort compare in
  List.map load (List.filter is_case_file entries)

(** [as_z i] converts a native integer id to the [Z.t] form the backends
    expect. *)
let as_z (i : int) : Z.t = Z.of_int i

(** [exec_case (module C) c] runs test case [c] against backend [C], prints
    a report on stdout and returns [true] iff every check produced the
    expected answer.

    The report lists the case name and size, each constraint, a one-line
    nested "formula" view of the constraints, each check with expected and
    actual result, and a final PASS/FAIL status.

    Ids are validated before the backend is created: if the case references
    a negative id, or an id greater than or equal to [c.size], an error is
    printed on stderr and the process exits with status 2. *)
let exec_case (module C: CC_SIG) (c:case) : bool =
  let z = as_z in
  (* Bounds validation, done before any backend call so that malformed
     tests get a readable error instead of an out-of-bounds failure.
     Both ends are checked: a negative id would otherwise also break
     [const_name] below, since [i mod n] is negative for negative [i]. *)
  let min_id = ref max_int and max_id = ref (-1) in
  let note x =
    if x < !min_id then min_id := x;
    if x > !max_id then max_id := x
  in
  List.iter
    (function
      | Eq (a, b) -> note a; note b
      | Def (d, a, b) -> note d; note a; note b)
    c.ops;
  List.iter (fun { a; b; _ } -> note a; note b) c.checks;
  if !min_id < 0 then (
    eprintf "Error: test '%s' references negative id %d.\n" c.name !min_id;
    exit 2
  );
  if !max_id >= c.size then (
    eprintf "Error: test '%s' has size %d but references id %d.\n" c.name c.size !max_id;
    exit 2
  );
  let cc = C.cc_new (z c.size) in
  (* Single unified mapping for all domain elements: a..z, then a1,b1,... *)
  let const_name i =
    let alpha = "abcdefghijklmnopqrstuvwxyz" in
    let n = String.length alpha in
    let ch = alpha.[i mod n] in
    if i < n then String.make 1 ch else (String.make 1 ch) ^ string_of_int (i / n)
  in
  (* Function symbols share the constant naming scheme. *)
  let fn_name id = const_name id in
  let apply_op = function
    | Eq (a,b) -> C.merge cc (C.mk_eqn (z a) (z b))
    | Def (d,a,b) -> C.merge cc (C.mk_defn (z d) (z a) (z b))
  in
  (* Feed every constraint to the backend, then run all queries. *)
  List.iter apply_op c.ops;
  let results =
    List.map
      (fun {a;b;expect} ->
        let got = C.check_eq cc (z a) (z b) in
        (a,b,expect,got))
      c.checks
  in
  let all_ok = List.for_all (fun (_,_,e,g) -> e = g) results in
  printf "- %s\n  size: %d\n" c.name c.size;
  printf "  constraints:\n";
  let pr_eq a b = printf "    %s = %s\n" (const_name a) (const_name b) in
  let pr_def d a b = printf "    %s = %s(%s)\n" (const_name d) (fn_name a) (const_name b) in
  List.iter (function | Eq (a,b) -> pr_eq a b | Def (d,a,b) -> pr_def d a b) c.ops;
  (* Nested single-line conjunction view. [defs] maps a defined id to its
     (function, argument) pair; a later definition of the same id replaces
     an earlier one. [used_as_head] records every id that appears in
     function position of some definition. *)
  let defs : (int, (int * int)) Hashtbl.t = Hashtbl.create 32 in
  let used_as_head : (int, unit) Hashtbl.t = Hashtbl.create 32 in
  let eqs = ref [] in
  List.iter (function
    | Def (d,a,b) -> Hashtbl.replace defs d (a,b); Hashtbl.replace used_as_head a ()
    | Eq (a,b) -> eqs := (a,b) :: !eqs
  ) c.ops;
  (* [nf t] unfolds definitions along the function position of [t] and
     returns the innermost head together with the rendered arguments,
     outermost application last; [render t] prints that as "h(x, y, ...)".
     Unfolding stops at an id already on the current path ([seen]) or past
     depth 1000, so cyclic definitions cannot loop. Results are memoised
     per id in [memo_nf]. *)
  let memo_nf : (int, (int * string list)) Hashtbl.t = Hashtbl.create 64 in
  let rec render ?(depth=0) ?(seen=[]) (t:int) : string =
    if depth > 1000 || List.mem t seen then const_name t
    else
      let (h,args) = nf ~depth ~seen t in
      match args with
      | [] -> const_name h
      | _  -> Printf.sprintf "%s(%s)" (const_name h) (String.concat ", " args)
  and nf ?(depth=0) ?(seen=[]) (t:int) : (int * string list) =
    if depth > 1000 || List.mem t seen then (t, [])
    else match Hashtbl.find_opt memo_nf t with
    | Some r -> r
    | None ->
        let r = match Hashtbl.find_opt defs t with
        | None -> (t, [])
        | Some (f,b) ->
            let (h, argsf) = nf ~depth:(depth+1) ~seen:(t::seen) f in
            let arg_s = render ~depth:(depth+1) ~seen:(t::seen) b in
            (h, argsf @ [arg_s])
        in Hashtbl.replace memo_nf t r; r
  in
  (* Keep only defs whose LHS is not used as a head elsewhere; the others
     are intermediate steps already folded into a larger term. Sorted so
     the output does not depend on hash-table iteration order. *)
  let to_show_defs =
    Hashtbl.fold
      (fun d _ acc -> if Hashtbl.mem used_as_head d then acc else d :: acc)
      defs []
    |> List.sort compare
  in
  (* Keep eqs only for symbols that have no def on either side. *)
  let to_show_eqs =
    List.filter (fun (a,b) -> not (Hashtbl.mem defs a) && not (Hashtbl.mem defs b)) (List.rev !eqs)
  in
  let lit_strings =
    (List.map (fun d -> Printf.sprintf "%s = %s" (const_name d) (render d)) to_show_defs)
    @ (List.map (fun (a,b) -> Printf.sprintf "%s = %s" (const_name a) (const_name b)) to_show_eqs)
  in
  if lit_strings <> [] then printf "  formula: %s\n" (String.concat " ∧ " lit_strings);
  List.iter (fun (a,b,e,g) ->
    printf "  check: %s = %s  expect: %-5s  got: %-5s%s\n"
      (const_name a) (const_name b)
      (string_of_bool e) (string_of_bool g)
      (if e = g then "" else "  <-- mismatch")
  ) results;
  printf "  status: %s\n\n" (if all_ok then "PASS" else "FAIL");
  all_ok

(** [run_all (module C) base_dir] loads every case under [base_dir], runs
    each against backend [C] in load order, and prints a summary line with
    the passed, failed and total counts. *)
let run_all (module C: CC_SIG) base_dir =
  let cases = load_cases base_dir in
  let passed = ref 0 and total = ref 0 in
  List.iter
    (fun (_, c) ->
      if exec_case (module C) c then incr passed;
      incr total)
    cases;
  printf "Summary: %d passed, %d failed, %d total\n" !passed (!total - !passed) !total

(** [run_one (module C) base_dir name] runs the single case called [name]
    against backend [C].

    If no case has that name and [name] looks like [testN] (the prefix
    "test" followed by an integer), [N] is used as a zero-based index into
    the loaded cases. Anything that cannot be resolved is reported on
    stderr; the function returns normally in that case. *)
let run_one (module C: CC_SIG) base_dir name =
  let cases = load_cases base_dir in
  let unknown () = eprintf "Unknown test '%s'\n" name in
  let run c = ignore (exec_case (module C) c) in
  match List.assoc_opt name cases with
  | Some c -> run c
  | None ->
      (* Fallback: support legacy 'testNN' indexing by order. *)
      let len = String.length name in
      if len < 5 || String.sub name 0 4 <> "test" then unknown ()
      else (
        match int_of_string_opt (String.sub name 4 (len - 4)) with
        | None -> unknown ()
        | Some idx ->
            let arr = Array.of_list (List.map snd cases) in
            let count = Array.length arr in
            if idx >= 0 && idx < count then run arr.(idx)
            else eprintf "Unknown test index '%s' (have %d cases)\n" name count)

(* Entry point.

   Usage: <exe> [all | <test-name>]
   - no argument, or "all": run every case and print a summary;
   - otherwise: run the case named by the first argument (any further
     arguments are ignored).

   The test directory is looked up relative to the executable: the base
   directory is the parent of the directory that contains it. *)
let () =
  let exe = Sys.argv.(0) in
  let build_dir = Filename.dirname exe in
  let base_dir = Filename.dirname build_dir in
  (* Which implementation? Identify by executable name: run_cong_bare vs run_cong_path *)
  let ends_with s suf =
    let ls = String.length s and lf = String.length suf in
    ls >= lf && String.sub s (ls - lf) lf = suf
  in
  let exe_name = Filename.basename exe in
  (* Also accept the suffix on the extension-less name, so a binary such as
     "run_cong_path.exe" still selects the Path backend. *)
  let use_path =
    ends_with exe_name "_path"
    || ends_with (Filename.remove_extension exe_name) "_path"
  in
  let module Impl =
    (val (if use_path then (module Path : CC_SIG) else (module Bare : CC_SIG))
       : CC_SIG)
  in
  (* Single-digit shorthands are zero-padded: "test3" becomes "test03".
     Only 0..9 qualify; a negative number is passed through untouched so the
     error message shows what the user actually typed. *)
  let normalize arg =
    if String.length arg >= 5 && String.sub arg 0 4 = "test" then (
      match int_of_string_opt (String.sub arg 4 (String.length arg - 4)) with
      | Some n when 0 <= n && n < 10 -> Printf.sprintf "test0%d" n
      | _ -> arg)
    else arg
  in
  match List.tl (Array.to_list Sys.argv) with
  | [] | "all" :: _ -> run_all (module Impl) base_dir
  | arg :: _ -> run_one (module Impl) base_dir (normalize arg)
