(* Proof checker for ordered logic
 * 15-836 Fall 2023
 * Frank Pfenning, based on:
 * Jeff Polakow, Linear Logic Programming with an Ordered Context
 * PPDP 2000, pp 68-79
 *)

(* ORDERED: interface to a proof checker for ordered logic.
 * A context (state) is an ordered list of antecedents, each marked
 * Avail (not yet consumed) or Used (consumed); check threads such a
 * state through a proof term, marking the antecedents it consumes.
 *)
signature ORDERED =
sig

(* propositions of ordered logic; the comment on each constructor
 * gives its concrete (infix) notation *)
datatype prop = Atom of string       (* P *)
              | Under of prop * prop (* A \ B *)
              | Over of prop * prop  (* B / A *)
              | Fuse of prop * prop  (* A . B *)
              | Twist of prop * prop (* A o B *)
              | With of prop * prop  (* A & B *)
              | One                  (* 1 *)
              | Plus of prop * prop  (* A + B *)

(* pp_prop A = printable rendering of A, connectives infix *)
val pp_prop : prop -> string

(* names of hypotheses in the context and in proof terms *)
type var = string

(* proof terms, one constructor per inference rule;
 * bind = one bound variable, bind2 = two bound variables *)
datatype proof = Id of var                    (* Id x *)
               | UnderR of bind               (* \R (x.M) *)
               | UnderL of var * proof * bind (* \L x M (y.P) *)
               | OverR of bind                (* /R (x.M) *)
               | OverL of var * proof * bind  (* /L x M (y.P) *)
               | FuseR of proof * proof       (* .R M N *)
               | FuseL of var * bind2         (* .L x (y.z.P) *)
               | TwistR of proof * proof      (* oR N M *)
               | TwistL of var * bind2        (* oL x (z.y.P) *)
               | WithR of proof * proof       (* &R M N *)
               | WithL1 of var * bind         (* &L1 x (y.P) *)
               | WithL2 of var * bind         (* &L2 x (z.P) *)
               | OneR                         (* 1R *)
               | OneL of var * proof          (* 1L x M *)
               | PlusR1 of proof              (* +R M *)
               | PlusR2 of proof              (* +R N *)
               | PlusL of var * bind * bind   (* +L x (y.P) (z.Q) *)
     and bind = Bind of var * proof           (* x.M *)
     and bind2 = Bind2 of var * var * proof   (* x.y.M *)

(* pp_proof M = printable rendering of proof term M *)
val pp_proof : proof -> string

(* an antecedent x : A that is either still available or already used *)
datatype antecedent = Avail of var * prop
                    | Used of var * prop

(* pp_ante prints Avail antecedents in parentheses, Used ones in brackets *)
val pp_ante : antecedent -> string

(* an ordered context; the position of each antecedent is significant *)
type state = antecedent list

(* a sequent: context, proof term, and succedent *)
type seq = state * proof * prop

(* check Omega M A = Omega' where Omega' is Omega with exactly the
 * antecedents consumed by M (as a proof of A) marked Used.
 * The implementation assumes every antecedent of Omega is Avail.
 * Raises Fail if M is not a valid proof or shadows a variable.
 *)
val check : state -> proof -> prop -> state (* uses I/O states; may raise Fail *)
(* entails Omega M A prints the sequent and M, then checks M.
 * Returns true if M proves A and consumes every antecedent of Omega;
 * never returns false: raises Fail otherwise.
 *)
val entails : state -> proof -> prop -> bool (* may raise Fail *)

end (* signature ORDERED *)

structure Ordered :> ORDERED =
struct

(* propositions; must agree with the ORDERED signature *)
datatype prop = Atom of string       (* P *)
              | Under of prop * prop (* A \ B *)
              | Over of prop * prop  (* B / A *)
              | Fuse of prop * prop  (* A . B *)
              | Twist of prop * prop (* A o B *)
              | With of prop * prop  (* A & B *)
              | One                  (* 1 *)
              | Plus of prop * prop  (* A + B *)

fun parens s = "(" ^ s ^ ")"
fun brackets s = "[" ^ s ^ "]"

(* pp_prop A = A printed with infix connectives, where compound operands
 * are parenthesized by pp_pprop (atoms and 1 are printed bare).
 * Note: Over(B,A) prints as "B / A", matching its use in check.
 *)
fun pp_prop (Atom(P)) = P
  | pp_prop (Under(A,B)) = pp_pprop A ^ " \\ " ^ pp_pprop B
  | pp_prop (Over(A,B)) = pp_pprop A ^ " / " ^ pp_pprop B
  | pp_prop (Fuse(A,B)) = pp_pprop A ^ " . " ^ pp_pprop B
  | pp_prop (Twist(A,B)) = pp_pprop A ^ " o " ^ pp_pprop B
  | pp_prop (With(A,B)) = pp_pprop A ^ " & " ^ pp_pprop B
  | pp_prop (One) = "1"
  | pp_prop (Plus(A,B)) = pp_pprop A ^ " + " ^ pp_pprop B
and pp_pprop (P as Atom _) = pp_prop P
  | pp_pprop (P as One) = pp_prop P
  | pp_pprop A = parens (pp_prop A)

type var = string

(* proof terms; must agree with the ORDERED signature *)
datatype proof = Id of var                    (* Id x *)
               | UnderR of bind               (* \R (x.M) *)
               | UnderL of var * proof * bind (* \L x M (y.P) *)
               | OverR of bind                (* /R (x.M) *)
               | OverL of var * proof * bind  (* /L x M (y.P) *)
               | FuseR of proof * proof       (* .R M N *)
               | FuseL of var * bind2         (* .L x (y.z.P) *)
               | TwistR of proof * proof      (* oR N M *)
               | TwistL of var * bind2        (* oL x (z.y.P) *)
               | WithR of proof * proof       (* &R M N *)
               | WithL1 of var * bind         (* &L1 x (y.P) *)
               | WithL2 of var * bind         (* &L2 x (z.P) *)
               | OneR                         (* 1R *)
               | OneL of var * proof          (* 1L x M *)
               | PlusR1 of proof              (* +R M *)
               | PlusR2 of proof              (* +R N *)
               | PlusL of var * bind * bind   (* +L x (y.P) (z.Q) *)
     and bind = Bind of var * proof           (* x.M *)
     and bind2 = Bind2 of var * var * proof   (* x.y.M *)

(* pp_proof M = M printed with rule names as prefix constructors;
 * every sub-proof is parenthesized (pp_pproof) and every argument
 * is separated by a single space.
 *)
fun pp_proof (Id(x)) = "Id " ^ x
  | pp_proof (UnderR(xM)) = "\\R " ^ pp_bind xM
  | pp_proof (UnderL(x,M,yP)) = "\\L " ^ x ^ " " ^ pp_pproof M ^ " " ^ pp_bind yP
  | pp_proof (OverR(xM)) = "/R " ^ pp_bind xM
  | pp_proof (OverL(x,M,yP)) = "/L " ^ x ^ " " ^ pp_pproof M ^ " " ^ pp_bind yP
  | pp_proof (FuseR(M,N)) = ".R " ^ pp_pproof M ^ " " ^ pp_pproof N
  | pp_proof (FuseL(x,yzP)) = ".L " ^ x ^ " " ^ pp_bind2 yzP
  | pp_proof (TwistR(M,N)) = "oR " ^ pp_pproof M ^ " " ^ pp_pproof N
  | pp_proof (TwistL(x,zyP)) = "oL " ^ x ^ " " ^ pp_bind2 zyP
  | pp_proof (WithR(M,N)) = "&R " ^ pp_pproof M ^ " " ^ pp_pproof N
  | pp_proof (WithL1(x,yP)) = "&L1 " ^ x ^ " " ^ pp_bind yP
  | pp_proof (WithL2(x,yP)) = "&L2 " ^ x ^ " " ^ pp_bind yP
  | pp_proof (OneR) = "1R"
  | pp_proof (OneL(x,P)) = "1L " ^ x ^ " " ^ pp_pproof P
  | pp_proof (PlusR1(M)) = "+R1 " ^ pp_pproof M
  | pp_proof (PlusR2(M)) = "+R2 " ^ pp_pproof M
  | pp_proof (PlusL(x,yP,zQ)) = "+L " ^ x ^ " " ^ pp_bind yP ^ " " ^ pp_bind zQ
and pp_pproof M = parens (pp_proof M)

and pp_bind (Bind(x,M)) = parens (x ^ ". " ^ pp_proof M)
and pp_bind2 (Bind2(x,y,M)) = parens (x ^ "." ^ y ^ ". " ^ pp_proof M)

(* an antecedent is either still available or already consumed *)
datatype antecedent = Avail of var * prop
                    | Used of var * prop

(* Avail antecedents print as (x : A), Used ones as [x : A] *)
fun pp_ante (Avail(x,A)) = parens (x ^ " : " ^ pp_prop A)
  | pp_ante (Used(x,A)) = brackets (x ^ " : " ^ pp_prop A)

type state = antecedent list

type seq = state * proof * prop

(* eq OmegaL OmegaR = true iff the two states are identical, including
 * which antecedents are marked Used; used to require that the two
 * branches of &R and +L consume the same antecedents
 *)
fun eq (OmegaL:state) OmegaR = (OmegaL = OmegaR)

fun is_avail (Avail _) = true
  | is_avail (Used _) = false

fun is_used (Used _) = true
  | is_used (Avail _) = false

fun all_avail Omega = List.all is_avail Omega

fun all_used Omega = List.all is_used Omega

(* span p Omega = (OmegaL, OmegaR) where Omega = OmegaL @ OmegaR,
 * every antecedent of OmegaL satisfies p, and OmegaR is [] or starts
 * with an antecedent that does not satisfy p.
 * The prefix is accumulated in reverse, so this is linear in |Omega|.
 *)
fun span p Omega =
    let fun loop acc (ante::rest) =
            if p ante then loop (ante::acc) rest
            else (List.rev acc, ante::rest)
          | loop acc nil = (List.rev acc, nil)
    in loop [] Omega end

(* initial_avail (OmegaL @ OmegaR) = (OmegaL, OmegaR)
 * where OmegaL = Avail*, OmegaR = [Used _] @ _ or []
 *)
fun initial_avail Omega = span is_avail Omega

(* final_avail (OmegaL @ OmegaR) = (OmegaL, OmegaR)
 * where OmegaL = (_ @ [Used _]) or [] and OmegaR = Avail*
 *)
fun final_avail Omega =
    let
        val (OmegaRrev, OmegaLrev) = initial_avail (List.rev Omega)
    in
        (List.rev OmegaLrev, List.rev OmegaRrev)
    end

(* split_used_avail (OmegaL @ OmegaR) = (OmegaL, OmegaR)
 * where OmegaL = Used*, OmegaR = Avail*
 * raises Fail if there are no such OmegaL, OmegaR
 *)
fun split_used_avail Omega =
    let val (OmegaL, OmegaR) = span is_used Omega
    in
        if all_avail OmegaR then (OmegaL, OmegaR)
        else raise Fail "not consecutively consumed"
    end

(* split_middle Omega (OmegaL @ OmegaR) = (OmegaL, OmegaR)
 * where |Omega| = |OmegaL|
 * raises Subscript if Omega is longer than the second argument
 * (internal error; callers preserve the length)
 *)
fun split_middle Omega Omega' =
    let val n = List.length Omega
    in (List.take (Omega', n), List.drop (Omega', n)) end

(* ante_var J(x,A) = x *)
fun ante_var (Avail(y,_)) = y
  | ante_var (Used(y,_)) = y

(* split x (OmegaL @ [J(x,A)] @ OmegaR) = (OmegaL, J(x,A), OmegaR)
 * where J is Avail or Used and OmegaL has no antecedent for x
 * (i.e., the leftmost antecedent for x is chosen)
 * raises Fail if there is no such J(x,A)
 *)
fun split x Omega =
    let fun loop acc (ante::rest) =
            if ante_var ante = x then (List.rev acc, ante, rest)
            else loop (ante::acc) rest
          | loop acc nil = raise Fail ("variable " ^ x ^ " unknown or out of scope")
    in loop [] Omega end

(* split_avail x (OmegaL @ [Avail(x,A)] @ OmegaR) = (OmegaL, Avail(x,A), OmegaR)
 * raises Fail if there is no such Avail(x,A)
 * split_avail x Omega assumes Omega = Avail*; raises Match (internal
 * error) if the antecedent found for x is Used
 *)
fun split_avail x Omega =
    (case split x Omega
      of (OmegaL, Avail(y,A), OmegaR) => (OmegaL, Avail(y,A), OmegaR)
       | (_, Used _, _) => raise Match) (* internal error *)

(* split_used x (OmegaL @ [Used(x,A)] @ OmegaR) = (OmegaL, Used(x,A), OmegaR)
 * raises Fail if there is no such Used(x,A)
 *)
fun split_used x Omega =
    (case split x Omega
      of (OmegaL, Used(y,A), OmegaR) => (OmegaL, Used(y,A), OmegaR)
       | (_, Avail _, _) => raise Fail ("variable " ^ x ^ " unused"))

(* check_unavail x Omega = () if Omega contains no Avail(x,A)
 * raises Fail otherwise (a new binder for x would shadow it)
 *)
fun check_unavail x Omega =
    if List.exists (fn Avail(y,_) => x = y | Used _ => false) Omega
    then raise Fail ("illegal shadowing of variable " ^ x)
    else ()

(* diff Omega Omega' = Omega - Omega' see (Polakow 2000)
 * assume Omega = Avail*
 * not used in the code, just in assertions
 *)
fun diff (Avail(x,A)::Omega) (Avail(x',A')::Omega') = diff Omega Omega'
  | diff (Avail(x,A)::Omega) (Used(x',A')::Omega') =
    Avail(x,A)::diff Omega Omega'
  | diff nil nil = nil
  | diff _ _ = raise Match (* internal error *)

(* check Omega M A = Omega' if Omega - Omega' |- M : A
 * see above for Omega - Omega' = diff Omega Omega'
 * (Omega' is Omega with the antecedents consumed by M marked Used)
 * assumes Omega = Avail*
 * raises Fail if there is no such Omega'
 * raises Fail if there is variable shadowing
 *
 * Left rules with a minor premise (\L, /L) check the continuation P
 * first; the minor premise M is then checked against the antecedents
 * that are still available directly next to the new hypothesis.
 *)
fun check Omega (Id(x)) A =
    (case split_avail x Omega
      of (OmegaL, Avail(x,B), OmegaR) =>
         if A <> B then raise Fail ("mismatch on type of variable " ^ x)
         else (OmegaL @ [Used(x,B)] @ OmegaR))

  (* x : A is added at the left end and must be consumed by M *)
  | check Omega (UnderR(Bind(x,M))) (Under(A,B)) =
    ( check_unavail x Omega
    ; case split_used x (check ([Avail(x,A)] @ Omega) M B)
       of (nil, Used(x,A), OmegaR') => OmegaR' )
  | check Omega (UnderR(Bind(x,M))) D = raise Fail ("term not of type _ \\ _")

  | check Omega (UnderL(x,M,Bind(y,P))) C =
    (case split_avail x Omega
      of (OmegaL, Avail(x,Under(A,B)), OmegaR) =>
         let val () = ( check_unavail y OmegaL ; check_unavail y OmegaR )
         in case split_used y (check (OmegaL @ [Avail(y,B)] @ OmegaR) P C) (* check P : C first *)
             of (OmegaL', Used(y,B), OmegaR') =>
                let val (OmegaLL', OmegaLR) = final_avail OmegaL' (* OmegaLR = Avail* *)
                    val OmegaLR' = check OmegaLR M A (* M : A uses resources left of x *)
                in OmegaLL' @ OmegaLR' @ [Used(x,Under(A,B))] @ OmegaR' end
         end
       | (_, Avail(x,D), _) => raise Fail ("variable " ^ x ^ " not of type _ \\ _"))

  (* x : A is added at the right end and must be consumed by M *)
  | check Omega (OverR(Bind(x,M))) (Over(B,A)) =
    ( check_unavail x Omega
    ; case split_used x (check (Omega @ [Avail(x,A)]) M B)
        of (OmegaL', Used(x,A), nil) => OmegaL' )
  | check Omega (OverR(Bind(x,M))) D = raise Fail ("term not of type _ / _ ")

  | check Omega (OverL(x,M,Bind(y,P))) C =
    (case split_avail x Omega
      of (OmegaL, Avail(x,Over(B,A)), OmegaR) =>
         let val () = ( check_unavail y OmegaL ; check_unavail y OmegaR )
         in case split_used y (check (OmegaL @ [Avail(y,B)] @ OmegaR) P C) (* check P : C first *)
             of (OmegaL', Used(y,B), OmegaR') =>
                let
                    val (OmegaRL, OmegaRR') = initial_avail OmegaR' (* OmegaRL = Avail* *)
                    val OmegaRL' = check OmegaRL M A (* M : A uses resources right of x *)
                in OmegaL' @ [Used(x,Over(B,A))] @ OmegaRL' @ OmegaRR' end
         end
       (* previously missing: a non-/ type escaped as an uncaught Match *)
       | (_, Avail(x,D), _) => raise Fail ("variable " ^ x ^ " not of type _ / _"))

  | check Omega (FuseR(M,N)) (Fuse(A,B)) =
    let val Omega' = check Omega M A
        val (OmegaL', OmegaR) = split_used_avail Omega' (* OmegaR = Avail* *)
        val OmegaR' = check OmegaR N B
    in OmegaL' @ OmegaR' end
  | check Omega (FuseR(M,N)) D = raise Fail ("term not of type _ . _")

  | check Omega (FuseL(x,Bind2(y,z,P))) C =
    (case split_avail x Omega
      of (OmegaL, Avail(x,Fuse(A,B)), OmegaR) =>
         let val () = ( check_unavail z OmegaL ; check_unavail z OmegaR )
             (* also rejects y = z *)
             val () = ( check_unavail y OmegaL ; check_unavail y ([Avail(z,B)] @ OmegaR) )
         in case split_used y (check (OmegaL @ [Avail(y,A), Avail(z,B)] @ OmegaR) P C)
             of (OmegaL', Used(y,A), OmegaR') =>
                (case split_used z OmegaR'
                  of (nil, Used(z,B), OmegaRR') =>
                     OmegaL' @ [Used(x,Fuse(A,B))] @ OmegaRR')
         end
       | (_, Avail(x,D), _) => raise Fail ("variable " ^ x ^ " not of type _ . _"))

  | check Omega (TwistR(M,N)) (Twist(A,B)) =
    let val Omega' = check Omega N B                    (* check N : B first *)
        val (OmegaL', OmegaR) = split_used_avail Omega' (* OmegaR = Avail* *)
        val OmegaR' = check OmegaR M A
    in OmegaL' @ OmegaR' end
  | check Omega (TwistR(M,N)) D = raise Fail ("term not of type _ o _ ")

  | check Omega (TwistL(x,Bind2(z,y,P))) C =
    (case split_avail x Omega
      of (OmegaL, Avail(x,Twist(A,B)), OmegaR) =>
         let val () = ( check_unavail y OmegaL ; check_unavail y OmegaR )
             (* also rejects z = y *)
             val () = ( check_unavail z OmegaL ; check_unavail z (Avail(y,A)::OmegaR) )
         in case split_used z (check (OmegaL @ [Avail(z,B), Avail(y,A)] @ OmegaR) P C)
             of (OmegaL', Used(z,B), OmegaR') =>
                (case split_used y OmegaR'
                  of (nil, Used(y,A), OmegaRR') =>
                     OmegaL' @ [Used(x,Twist(A,B))] @ OmegaRR')
         end
       | (_, Avail(x,D), _) => raise Fail ("variable " ^ x ^ " not of type _ o _"))

  (* both conjuncts must be proved from the same antecedents *)
  | check Omega (WithR(M,N)) (With(A,B)) =
    let val Omega1 = check Omega M A
        val Omega2 = check Omega N B
        val () = if eq Omega1 Omega2 then ()
                 else raise Fail ("branches do not use the same variables")
    in Omega1 end
  | check Omega (WithR _) D = raise Fail ("term not of type _ & _")

  | check Omega (WithL1(x,Bind(y,P))) C =
    (case split_avail x Omega
      of (OmegaL, Avail(x, With(A,B)), OmegaR) =>
         let val () = ( check_unavail y OmegaL ; check_unavail y OmegaR )
         in case split_used y (check (OmegaL @ [Avail(y,A)] @ OmegaR) P C)
             of (OmegaL', Used(y,A), OmegaR') => OmegaL' @ [Used(x,With(A,B))] @ OmegaR'
         end
       | (_, Avail(x, D), _) => raise Fail ("variable " ^ x ^ " not of type _ & _"))

  | check Omega (WithL2(x,Bind(y,P))) C =
    (case split_avail x Omega
      of (OmegaL, Avail(x, With(A,B)), OmegaR) =>
         let val () = ( check_unavail y OmegaL ; check_unavail y OmegaR )
         in case split_used y (check (OmegaL @ [Avail(y,B)] @ OmegaR) P C)
             of (OmegaL', Used(y,B), OmegaR') => OmegaL' @ [Used(x,With(A,B))] @ OmegaR'
         end
       | (_, Avail(x, D), _) => raise Fail ("variable " ^ x ^ " not of type _ & _"))

  | check Omega (OneR) One = Omega
  | check Omega (OneR) D = raise Fail ("term not of type 1")

  (* x : 1 is removed, P is checked, and x is reinserted (as Used)
   * at its original position, i.e. after |OmegaL| antecedents *)
  | check Omega (OneL(x,P)) C =
    (case split_avail x Omega
      of (OmegaL, Avail(x, One), OmegaR) =>
         (case split_middle OmegaL (check (OmegaL @ OmegaR) P C)
           of (OmegaL', OmegaR') => (OmegaL' @ [Used(x, One)] @ OmegaR'))
       | (_, Avail(x,_), _) => raise Fail ("variable " ^ x ^ " not of type 1"))

  | check Omega (PlusR1(M)) (Plus(A,B)) = check Omega M A
  | check Omega (PlusR1(M)) D = raise Fail ("term not of type _ + _")

  | check Omega (PlusR2(N)) (Plus(A,B)) = check Omega N B
  | check Omega (PlusR2(N)) D = raise Fail ("term not of type _ + _")

  (* both branches must consume the same antecedents *)
  | check Omega (PlusL(x,Bind(y,P),Bind(z,Q))) C =
    (case split_avail x Omega
      of (OmegaL, Avail(x, Plus(A,B)), OmegaR) =>
         let val () = ( check_unavail y OmegaL ; check_unavail y OmegaR )
             val () = ( check_unavail z OmegaL ; check_unavail z OmegaR )
         in case split_used y (check (OmegaL @ [Avail(y,A)] @ OmegaR) P C)
             of (OmegaL1', Used(y,A), OmegaR1') =>
                (case split_used z (check (OmegaL @ [Avail(z,B)] @ OmegaR) Q C)
                  of (OmegaL2', Used(z,B), OmegaR2') =>
                     if eq OmegaL1' OmegaL2' andalso eq OmegaR1' OmegaR2'
                     then (OmegaL1' @ [Used(x, Plus(A,B))] @ OmegaR1')
                     else raise Fail ("branches for " ^ x ^ " do not use the same variables"))
         end
       | (_, Avail(x, D), _) => raise Fail ("variable " ^ x ^ " not of type _ + _"))

(* entails Omega M A prints the sequent Omega |- A and the proof M,
 * then returns true if M proves A consuming every antecedent;
 * raises Fail otherwise (never returns false)
 *)
fun entails Omega M A =
    let val ante = String.concatWith " " (List.map pp_ante Omega)
        val succ = pp_prop A
        val () = print (ante ^ " |- " ^ succ ^ "\n")
        val () = print (pp_proof M ^ "\n")
        val Omega' = check Omega M A
        val () = if all_used Omega' then ()
                 else raise Fail "not all variables used"
    in true end

end (* structure Ordered *)
