(* Proof checker for (purely) structural logic
 * 15-836 Fall 2023
 * Frank Pfenning
 *)

(* Interface to the proof checker: propositions, proof terms,
 * antecedents, pretty-printers for each, and the checker itself.
 *)
signature STRUCTURAL =
sig

(* Propositions *)
datatype prop = Atom of string         (* P *)
              | Implies of prop * prop (* A -> B *)
              | And of prop * prop     (* A /\ B *)
              | True                   (* T *)
              | Or of prop * prop      (* A \/ B *)

(* pp_prop A = printable concrete syntax for A *)
val pp_prop : prop -> string

(* Variables naming antecedents; bound by Bind and Bind2 in proofs *)
type var = string

(* Proof terms, one constructor per sequent calculus rule.
 * Right rules (..R) analyze the succedent.  Left rules (..L) carry the
 * variable x of the antecedent they decompose; the variables bound in
 * their bind/bind2 arguments name the new antecedents of the premise.
 *)
datatype proof = Id of var                      (* Id x *)
               | ImpliesR of bind               (* ->R (x.M) *)
               | ImpliesL of var * proof * bind (* ->L x (y.P) *)
               | AndR of proof * proof          (* /\R M N *)
               | AndL of var * bind2            (* /\L x (y.z.P) *)
               | AndL1 of var * bind            (* /\L1 x (y.P) *)
               | AndL2 of var * bind            (* /\L2 x (y.P) *)
               | TrueR                          (* TR *)
               | TrueL of var * proof           (* TL x M *)
               | OrR1 of proof                  (* \/R1 M *)
               | OrR2 of proof                  (* \/R2 N *)
               | OrL of var * bind * bind       (* \/L x (y.P) (z.Q) *)
     and bind = Bind of var * proof             (* x.M *)
     and bind2 = Bind2 of var * var * proof     (* x.y.M *)

(* pp_proof M = printable concrete syntax for M *)
val pp_proof : proof -> string

(* Ante(x,A) represents the antecedent x : A *)
datatype antecedent = Ante of var * prop

(* pp_ante (Ante(x,A)) = printable form of x : A *)
val pp_ante : antecedent -> string

(* The antecedents Gamma of a sequent Gamma |- A *)
type state = antecedent list

(* check Gamma M A = () if Gamma |- M : A
 * raises Fail with an explanatory message otherwise
 *)
val check : state -> proof -> prop -> unit (* may raise Fail *)

(* entails Gamma M A prints the sequent Gamma |- A and the proof M,
 * then returns true if Gamma |- M : A; raises Fail otherwise
 *)
val entails : state -> proof -> prop -> bool (* prints; may raise Fail *)

end (* signature STRUCTURAL *)

structure Structural :> STRUCTURAL =
struct

(* Propositions; constructors must agree with the spec in STRUCTURAL *)
datatype prop = Atom of string         (* P *)
              | Implies of prop * prop (* A -> B *)
              | And of prop * prop     (* A /\ B *)
              | True                   (* T *)
              | Or of prop * prop      (* A \/ B *)

(* parens s = s enclosed in a pair of parentheses *)
fun parens s = "(" ^ s ^ ")"

(* pp_prop A = concrete syntax for A
 * Each immediate subformula that is neither an atom nor T is printed
 * in parentheses (see pp_pprop), so no operator precedence or
 * associativity conventions are needed to read the output.
 *)
fun pp_prop (Atom(P)) = P
  | pp_prop (Implies(A,B)) = pp_pprop A ^ " -> " ^ pp_pprop B
  | pp_prop (And(A,B)) = pp_pprop A ^ " /\\ " ^ pp_pprop B
  | pp_prop (True) = "T"
  | pp_prop (Or(A,B)) = pp_pprop A ^ " \\/ " ^ pp_pprop B
(* pp_pprop A = pp_prop A, parenthesized unless A is an atom or T *)
and pp_pprop (P as Atom _) = pp_prop P
  | pp_pprop (P as True) = pp_prop P
  | pp_pprop A = parens (pp_prop A)

type var = string

(* Proof terms, one constructor per rule of the sequent calculus *)
datatype proof = Id of var                      (* Id x *)
               | ImpliesR of bind               (* ->R (x.M) *)
               | ImpliesL of var * proof * bind (* ->L x (y.P) *)
               | AndR of proof * proof          (* /\R M N *)
               | AndL of var * bind2            (* /\L x (y.z.P) *)
               | AndL1 of var * bind            (* /\L1 x (y.P) *)
               | AndL2 of var * bind            (* /\L2 x (y.P) *)
               | TrueR                          (* TR *)
               | TrueL of var * proof           (* TL x M *)
               | OrR1 of proof                  (* \/R1 M *)
               | OrR2 of proof                  (* \/R2 N *)
               | OrL of var * bind * bind       (* \/L x (y.P) (z.Q) *)
     and bind = Bind of var * proof             (* x.M *)
     and bind2 = Bind2 of var * var * proof     (* x.y.M *)

(* pp_proof M = concrete syntax for M
 * Arguments are separated by single spaces; sub-proofs are always
 * parenthesized (pp_pproof), as are binders (pp_bind, pp_bind2).
 *)
fun pp_proof (Id(x)) = "Id " ^ x
  | pp_proof (ImpliesR(xM)) = "->R " ^ pp_bind xM
  | pp_proof (ImpliesL(x,M,yP)) = "->L " ^ x ^ " " ^ pp_pproof M ^ " " ^ pp_bind yP
  | pp_proof (AndR(M,N)) = "/\\R " ^ pp_pproof M ^ " " ^ pp_pproof N
  | pp_proof (AndL(x,xyP)) = "/\\L " ^ x ^ " " ^ pp_bind2 xyP
  | pp_proof (AndL1(x,yP)) = "/\\L1 " ^ x ^ " " ^ pp_bind yP
  | pp_proof (AndL2(x,yP)) = "/\\L2 " ^ x ^ " " ^ pp_bind yP
  | pp_proof (TrueR) = "TR"
    (* separate the variable from the sub-proof, like every other L rule *)
  | pp_proof (TrueL(x,P)) = "TL " ^ x ^ " " ^ pp_pproof P
  | pp_proof (OrR1(M)) = "\\/R1 " ^ pp_pproof M
  | pp_proof (OrR2(M)) = "\\/R2 " ^ pp_pproof M
  | pp_proof (OrL(x,yP,zQ)) = "\\/L " ^ x ^ " " ^ pp_bind yP ^ " " ^ pp_bind zQ
(* pp_pproof M = pp_proof M in parentheses *)
and pp_pproof M = parens (pp_proof M)

(* pp_bind (x.M) = "(x. M)" and pp_bind2 (x.y.M) = "(x.y. M)" *)
and pp_bind (Bind(x,M)) = parens (x ^ ". " ^ pp_proof M)
and pp_bind2 (Bind2(x,y,M)) = parens (x ^ "." ^ y ^ ". " ^ pp_proof M)

(* Ante(x,A) represents the antecedent x : A *)
datatype antecedent = Ante of var * prop

(* pp_ante (x:A) = "(x : A)" *)
fun pp_ante (Ante(x,A)) = parens (x ^ " : " ^ pp_prop A)

type state = antecedent list

(* we write Gamma, x:A for adding x:A to Gamma *)
(* requires x not to be in Gamma *)

(* check_fresh Gamma x = () if no Ante(x,_) is in Gamma
 * raises Fail if x is already declared, since the new binding would shadow it
 *)
fun check_fresh (Ante(y,_)::Gamma) x =
    if x = y then raise Fail ("illegal shadowing of variable " ^ x)
    else check_fresh Gamma x
  | check_fresh nil _ = ()

(* tpof Gamma x = A if x:A in Gamma (the first such entry)
 * raises Fail if x is not declared in Gamma
 *)
fun tpof (Ante(y,B)::Gamma) x = if x = y then B else tpof Gamma x
  | tpof nil x = raise Fail ("variable " ^ x ^ " unknown or out of scope")

(* assume Gamma (x:A) = Gamma, x:A
 * raises Fail if x is already declared in Gamma
 *)
fun assume Gamma (Ante(x,A)) =
    ( check_fresh Gamma x
    ; Ante(x,A)::Gamma )

(* check Gamma M A = () if Gamma |- M : A
 * raises Fail otherwise
 * Antecedents stay in Gamma after a left rule uses them, so a variable
 * may be used any number of times in a proof.
 *)
fun check Gamma (Id(x)) A =
    (* identity: x:A in Gamma proves A *)
    if A <> tpof Gamma x then raise Fail ("mismatch on type of variable " ^ x)
    else ()

    (* ->R: from Gamma, x:A |- M : B conclude Gamma |- A -> B *)
  | check Gamma (ImpliesR(Bind(x,M))) (Implies(A,B)) =
    check (assume Gamma (Ante(x,A))) M B
  | check _ (ImpliesR _) _ = raise Fail ("term not of type _ -> _")

    (* ->L: given x:A->B, prove A with M, then continue with y:B *)
  | check Gamma (ImpliesL(x,M,Bind(y,P))) C =
    (case tpof Gamma x
      of Implies(A,B) =>
         ( check Gamma M A
         ; check (assume Gamma (Ante(y,B))) P C )
       | _ => raise Fail ("variable " ^ x ^ " not of type _ -> _"))

    (* /\R: prove both conjuncts *)
  | check Gamma (AndR(M,N)) (And(A,B)) =
    ( check Gamma M A
    ; check Gamma N B )
  | check _ (AndR _) _ = raise Fail ("term not of type _ /\\ _ ")

    (* /\L: given x:A/\B, continue with y:A and z:B *)
  | check Gamma (AndL(x,Bind2(y,z,P))) C =
    (case tpof Gamma x
      of And(A,B) =>
         check (assume (assume Gamma (Ante(y,A))) (Ante(z,B))) P C
       | _ => raise Fail ("variable " ^ x ^ " not of type _ /\\ _"))

    (* /\L1: given x:A/\B, continue with y:A *)
  | check Gamma (AndL1(x,Bind(y,P))) C =
    (case tpof Gamma x
      of And(A,_) => check (assume Gamma (Ante(y,A))) P C
       | _ => raise Fail ("variable " ^ x ^ " not of type _ /\\ _"))

    (* /\L2: given x:A/\B, continue with y:B *)
  | check Gamma (AndL2(x,Bind(y,P))) C =
    (case tpof Gamma x
      of And(_,B) => check (assume Gamma (Ante(y,B))) P C
       | _ => raise Fail ("variable " ^ x ^ " not of type _ /\\ _"))

    (* TR: T holds outright *)
  | check _ (TrueR) True = ()
  | check _ (TrueR) _ = raise Fail ("term not of type T")

    (* TL: given x:T, nothing new is learned; continue with the same Gamma *)
  | check Gamma (TrueL(x,P)) C =
    (case tpof Gamma x
      of True => check Gamma P C
       | _ => raise Fail ("variable " ^ x ^ " not of type T"))

    (* \/R1, \/R2: prove the left or the right disjunct *)
  | check Gamma (OrR1(M)) (Or(A,_)) = check Gamma M A
  | check _ (OrR1 _) _ = raise Fail ("term not of type _ \\/ _")

  | check Gamma (OrR2(N)) (Or(_,B)) = check Gamma N B
  | check _ (OrR2 _) _ = raise Fail ("term not of type _ \\/ _")

    (* \/L: given x:A\/B, prove C both from y:A and from z:B *)
  | check Gamma (OrL(x,Bind(y,P),Bind(z,Q))) C =
    (case tpof Gamma x
      of Or(A,B) => ( check (assume Gamma (Ante(y,A))) P C
                    ; check (assume Gamma (Ante(z,B))) Q C )
       | _ => raise Fail ("variable " ^ x ^ " not of type _ \\/ _"))

(* entails Gamma M A prints the sequent Gamma |- A and then the proof M,
 * then returns true if Gamma |- M : A; otherwise check raises Fail
 * (after the printing has already happened)
 *)
fun entails Gamma M A =
    let val ante = String.concatWith ", " (List.map pp_ante Gamma)
        val succ = pp_prop A
        val () = print (ante ^ " |- " ^ succ ^ "\n")
        val () = print (pp_proof M ^ "\n")
        val () = check Gamma M A
    in true end

end (* structure Structural *)

