(* Proof checker for (purely) linear logic
 * 15-836 Fall 2023
 * Frank Pfenning, based on
 * Jeff Polakow, Linear Logic Programming with an Ordered Context
 * PPDP 2000, pp. 68-79
 *)

signature LINEAR =
sig

(* Propositions of purely linear logic *)
datatype prop = Atom of string        (* P *)
              | Lolli of prop * prop  (* A -o B *)
              | Tensor of prop * prop (* A * B *)
              | With of prop * prop   (* A & B *)
              | One                   (* 1 *)
              | Plus of prop * prop   (* A + B *)

(* pp_prop A = concrete syntax for A *)
val pp_prop : prop -> string

type var = string

(* Proof terms; each comment shows the concrete syntax of the constructor,
 * where x, y, z range over variables and M, N, P, Q over proofs *)
datatype proof = Id of var                    (* Id x *)
               | LolliR of bind               (* -oR (x.M) *)
               | LolliL of var * proof * bind (* -oL x M (y.P) *)
               | TensorR of proof * proof     (* *R M N *)
               | TensorL of var * bind2       (* *L x (y.z.P) *)
               | WithR of proof * proof       (* &R M N *)
               | WithL1 of var * bind         (* &L1 x (y.P) *)
               | WithL2 of var * bind         (* &L2 x (y.P) *)
               | OneR                         (* 1R *)
               | OneL of var * proof          (* 1L x M *)
               | PlusR1 of proof              (* +R1 M *)
               | PlusR2 of proof              (* +R2 N *)
               | PlusL of var * bind * bind   (* +L x (y.P) (z.Q) *)
     and bind = Bind of var * proof           (* x.M *)
     and bind2 = Bind2 of var * var * proof   (* x.y.M *)

(* pp_proof M = concrete syntax for M *)
val pp_proof : proof -> string

(* Avail(x,A): antecedent x:A may still be consumed
 * Used(x,A):  antecedent x:A has already been consumed *)
datatype antecedent = Avail of var * prop
                    | Used of var * prop

(* pp_ante J = printed form of antecedent J *)
val pp_ante : antecedent -> string

(* A state is a linear context annotated with the usage status of
 * each antecedent *)
type state = antecedent list

(* check Delta M A = Delta' where M : A consumes exactly the antecedents
 * that are Avail in Delta and Used in Delta' (input/output states).
 * raises Fail if M is not a proof of A from the resources in Delta *)
val check : state -> proof -> prop -> state

(* entails Delta M A prints the sequent and the proof M, then returns true
 * if M proves A and leaves no antecedent of Delta unused.
 * raises Fail otherwise; it never returns false *)
val entails : state -> proof -> prop -> bool

end (* signature LINEAR *)

structure Linear :> LINEAR =
struct

(* Propositions; the constructors must agree with signature LINEAR *)
datatype prop = Atom of string        (* P *)
              | Lolli of prop * prop  (* A -o B *)
              | Tensor of prop * prop (* A * B *)
              | With of prop * prop   (* A & B *)
              | One                   (* 1 *)
              | Plus of prop * prop   (* A + B *)

fun parens s = "(" ^ s ^ ")"
fun brackets s = "[" ^ s ^ "]"

(* pp_prop A = concrete syntax for A
 * pp_pprop A = pp_prop A, parenthesized unless A is an atom or 1
 * Compound operands of binary connectives are always parenthesized,
 * so the output never relies on precedence or associativity.
 *)
fun pp_prop (Atom(P)) = P
  | pp_prop (Lolli(A,B)) = pp_pprop A ^ " -o " ^ pp_pprop B
  | pp_prop (Tensor(A,B)) = pp_pprop A ^ " * " ^ pp_pprop B
  | pp_prop (With(A,B)) = pp_pprop A ^ " & " ^ pp_pprop B
  | pp_prop (One) = "1"
  | pp_prop (Plus(A,B)) = pp_pprop A ^ " + " ^ pp_pprop B
and pp_pprop (P as Atom _) = pp_prop P
  | pp_pprop (P as One) = pp_prop P
  | pp_pprop A = parens (pp_prop A)

type var = string

(* Proof terms; the constructors must agree with signature LINEAR *)
datatype proof = Id of var                    (* Id x *)
               | LolliR of bind               (* -oR (x.M) *)
               | LolliL of var * proof * bind (* -oL x M (y.P) *)
               | TensorR of proof * proof     (* *R M N *)
               | TensorL of var * bind2       (* *L x (y.z.P) *)
               | WithR of proof * proof       (* &R M N *)
               | WithL1 of var * bind         (* &L1 x (y.P) *)
               | WithL2 of var * bind         (* &L2 x (y.P) *)
               | OneR                         (* 1R *)
               | OneL of var * proof          (* 1L x M *)
               | PlusR1 of proof              (* +R1 M *)
               | PlusR2 of proof              (* +R2 N *)
               | PlusL of var * bind * bind   (* +L x (y.P) (z.Q) *)
     and bind = Bind of var * proof           (* x.M *)
     and bind2 = Bind2 of var * var * proof   (* x.y.M *)

(* pp_proof M = concrete syntax for M; proof arguments are always
 * parenthesized (pp_pproof) and binders print as "(x. M)" *)
fun pp_proof (Id(x)) = "Id " ^ x
  | pp_proof (LolliR(xM)) = "-oR " ^ pp_bind xM
  | pp_proof (LolliL(x,M,yP)) = "-oL " ^ x ^ " " ^ pp_pproof M ^ " " ^ pp_bind yP
  | pp_proof (TensorR(M,N)) = "*R " ^ pp_pproof M ^ " " ^ pp_pproof N
  | pp_proof (TensorL(x,xyP)) = "*L " ^ x ^ " " ^ pp_bind2 xyP
  | pp_proof (WithR(M,N)) = "&R " ^ pp_pproof M ^ " " ^ pp_pproof N
  | pp_proof (WithL1(x,yP)) = "&L1 " ^ x ^ " " ^ pp_bind yP
  | pp_proof (WithL2(x,yP)) = "&L2 " ^ x ^ " " ^ pp_bind yP
  | pp_proof (OneR) = "1R"
  | pp_proof (OneL(x,P)) = "1L " ^ x ^ " " ^ pp_pproof P
  | pp_proof (PlusR1(M)) = "+R1 " ^ pp_pproof M
  | pp_proof (PlusR2(M)) = "+R2 " ^ pp_pproof M
  | pp_proof (PlusL(x,yP,zQ)) = "+L " ^ x ^ " " ^ pp_bind yP ^ " " ^ pp_bind zQ
and pp_pproof M = parens (pp_proof M)

and pp_bind (Bind(x,M)) = parens (x ^ ". " ^ pp_proof M)
and pp_bind2 (Bind2(x,y,M)) = parens (x ^ "." ^ y ^ ". " ^ pp_proof M)

(* Avail(x,A): x:A may still be consumed; Used(x,A): x:A has been consumed *)
datatype antecedent = Avail of var * prop
                    | Used of var * prop

(* pp_ante J = "(x : A)" if J is available, "[x : A]" if it is used *)
fun pp_ante (Avail(x,A)) = parens (x ^ " : " ^ pp_prop A)
  | pp_ante (Used(x,A)) = brackets (x ^ " : " ^ pp_prop A)

type state = antecedent list

(* subset Delta1 Delta2 = true iff every antecedent of Delta1 occurs in Delta2 *)
fun subset Delta1 Delta2 =
    List.all (fn J1:antecedent => List.exists (fn J2 => J1 = J2) Delta2) Delta1

(* eq Delta1 Delta2 = true iff both states contain the same antecedents,
 * regardless of order *)
fun eq Delta1 Delta2 = subset Delta1 Delta2 andalso subset Delta2 Delta1

(* all_used Delta = true iff no antecedent of Delta is still Avail *)
fun all_used Delta = List.all (fn Used _ => true | Avail _ => false) Delta

(* var_of J = the variable labelling antecedent J *)
fun var_of (Avail(y,_)) = y
  | var_of (Used(y,_)) = y

(* we write '+' for disjoint union on states *)

(* pull x (Delta + J) = (J, Delta) where J is the first antecedent
 * (Avail or Used) labelled x; the remaining antecedents keep their order.
 * raises Fail if there is no antecedent labelled x
 *)
fun pull x Delta =
    let (* seen is the already-inspected prefix in reverse order, which
         * avoids the quadratic cost of appending to its end *)
        fun loop seen (J::DeltaR) =
            if var_of J = x then (J, List.revAppend (seen, DeltaR))
            else loop (J::seen) DeltaR
          | loop seen nil =
            raise Fail ("variable " ^ x ^ " unknown or out of scope")
    in loop nil Delta end

(* pull_avail x (Delta + Avail(x,A)) = (A, Delta)
 * raises Fail if x is unknown or has already been used
 *)
fun pull_avail x Delta =
    (case pull x Delta
      of (Avail(_,A), Delta') => (A, Delta')
       | (Used _, _) => raise Fail ("variable " ^ x ^ " unavailable"))

(* pull_used x (Delta + Used(x,A)) = (A, Delta)
 * raises Fail if x is unknown or is still available
 *)
fun pull_used x Delta =
    (case pull x Delta
      of (Used(_,A), Delta') => (A, Delta')
       | (Avail _, _) => raise Fail ("variable " ^ x ^ " unused"))

(* check_fresh x Delta = () if no antecedent of Delta (Avail or Used)
 * is labelled x
 * raises Fail otherwise
 *)
fun check_fresh x Delta =
    if List.exists (fn J => var_of J = x) Delta
    then raise Fail ("illegal shadowing of variable " ^ x)
    else ()

(* assume x A Delta = Avail(x,A)::Delta
 * raises Fail if x is not fresh in Delta
 *)
fun assume x A Delta =
    ( check_fresh x Delta
    ; Avail(x,A)::Delta )

(* check Delta M A = Delta' if Delta - Delta' |- M : A
 * where Delta - Delta' is set difference, i.e., the antecedents that are
 * Avail in Delta and Used in Delta' (variables make antecedents unique).
 * Every clause removes the variables it binds before returning, so Delta'
 * mentions the same variables as Delta.
 * Unlike ordered checking, Delta may contain Used _ antecedents.
 * raises Fail if M is not a proof of A from the resources in Delta
 *)
fun check Delta (Id(x)) A =
    let val (B, Delta') = pull_avail x Delta
    in if A <> B then raise Fail ("mismatch on type of variable " ^ x)
       else Used(x,B)::Delta'
    end

  | check Delta (LolliR(Bind(x,M))) (Lolli(A,B)) =
    (* x:A is local to M and must be consumed there *)
    let val (_, Delta') = pull_used x (check (assume x A Delta) M B)
    in Delta' end
  | check Delta (LolliR _) D = raise Fail ("term not of type _ -o _")

  | check Delta (LolliL(x,M,Bind(y,P))) C =
    (case pull_avail x Delta
      of (Lolli(A,B), Delta') =>
         let val Delta1 = check Delta' M A (* or with Used(x,_)? *)
             val (_, Delta2) = pull_used y (check (assume y B Delta1) P C)
         in Used(x,Lolli(A,B))::Delta2 end
       | (_, _) => raise Fail ("variable " ^ x ^ " not of type _ -o _"))

  | check Delta (TensorR(M,N)) (Tensor(A,B)) =
    (* resources left over by M are threaded into N *)
    let val Delta1 = check Delta M A
        val Delta2 = check Delta1 N B
    in Delta2 end
  | check Delta (TensorR _) D = raise Fail ("term not of type _ * _ ")

  | check Delta (TensorL(x,Bind2(y,z,P))) C =
    (case pull_avail x Delta
      of (Tensor(A,B), Delta') =>
         let val Delta1 = check (assume y A (assume z B Delta')) P C
             val (_, Delta2) = pull_used y Delta1
             val (_, Delta3) = pull_used z Delta2
         in Used(x,Tensor(A,B))::Delta3 end
       | (_, _) => raise Fail ("variable " ^ x ^ " not of type _ * _"))

  | check Delta (WithR(M,N)) (With(A,B)) =
    (* both conjuncts must be proved from exactly the same resources *)
    let val Delta1 = check Delta M A
        val Delta2 = check Delta N B
    in if eq Delta1 Delta2 then Delta1
       else raise Fail ("branches do not use the same variables")
    end
  | check Delta (WithR _) D = raise Fail ("term not of type _ & _")

  | check Delta (WithL1(x,Bind(y,P))) C =
    (case pull_avail x Delta
      of (With(A,B), Delta') =>
         let val (_, Delta'') = pull_used y (check (assume y A Delta') P C)
         in Used(x,With(A,B))::Delta'' end
       | (_, _) => raise Fail ("variable " ^ x ^ " not of type _ & _"))

  | check Delta (WithL2(x,Bind(y,P))) C =
    (case pull_avail x Delta
      of (With(A,B), Delta') =>
         (* return the state after P so that resources P consumed stay Used *)
         let val (_, Delta'') = pull_used y (check (assume y B Delta') P C)
         in Used(x,With(A,B))::Delta'' end
       | (_, _) => raise Fail ("variable " ^ x ^ " not of type _ & _"))

  | check Delta (OneR) One = Delta
  | check Delta (OneR) D = raise Fail ("term not of type 1")

  | check Delta (OneL(x,P)) C =
    (case pull_avail x Delta
      of (One, Delta') => check (Used(x,One)::Delta') P C
       | (_, _) => raise Fail ("variable " ^ x ^ " not of type 1"))

  | check Delta (PlusR1(M)) (Plus(A,B)) = check Delta M A
  | check Delta (PlusR1 _) D = raise Fail ("term not of type _ + _")

  | check Delta (PlusR2(N)) (Plus(A,B)) = check Delta N B
  | check Delta (PlusR2 _) D = raise Fail ("term not of type _ + _")

  | check Delta (PlusL(x,Bind(y,P),Bind(z,Q))) C =
    (case pull_avail x Delta
      of (Plus(A,B), Delta') =>
         (* both branches must consume exactly the same resources *)
         let val (_, Delta1) = pull_used y (check (assume y A Delta') P C)
             val (_, Delta2) = pull_used z (check (assume z B Delta') Q C)
         in if eq Delta1 Delta2
            then Used(x,Plus(A,B))::Delta1
            else raise Fail ("branches for " ^ x ^ " do not use the same variables")
         end
       | (_, _) => raise Fail ("variable " ^ x ^ " not of type _ + _"))


(* entails Delta M A prints the sequent Delta |- A and the proof M,
 * then returns true if M checks against A and every antecedent of the
 * resulting state is Used.
 * raises Fail otherwise (it never returns false)
 *)
fun entails Delta M A =
    let val ante = String.concatWith ", " (List.map pp_ante Delta)
        val succ = pp_prop A
        val () = print (ante ^ " |- " ^ succ ^ "\n")
        val () = print (pp_proof M ^ "\n")
        val Delta' = check Delta M A
        val () = if all_used Delta' then ()
                 else raise Fail "not all variables used"
    in true end

end (* structure Linear *)
