(* Statics *)
(* implements a context-passing linear type-checker *)

signature STATICS =
sig

    (* check env Delta P zC ext = Delta'
     * type-checks process P, which provides channel zC (of the form
     * z : C), using the antecedents in Delta and the definitions in env.
     * Returns Delta with the usage status of each variable updated
     * (context-passing style); ext is the source region used for
     * error reporting.
     *)
    val check : IntSyn.env -> IntSyn.ctx -> IntSyn.proc -> IntSyn.typing -> IntSyn.ext -> IntSyn.ctx
    (* may raise ErrorMsg.Error *)

end (* signature STATICS *)

structure Statics :> STATICS =
struct

(* ERROR(ext, msg) reports a type error with message msg at source
 * region ext; it is used in positions of arbitrary result type, so it
 * does not return normally (signals ErrorMsg.Error) *)
fun ERROR(ext, msg) = ErrorMsg.ERROR ext msg

structure I = IntSyn

(* lookup that may raise ErrorMsg.Error *)

(* select k alts ext = the type associated with label k in alts *)
fun select k alts ext =
    (case I.select k alts
      of SOME(elem) => elem
       | NONE => ERROR(ext, "missing alternative " ^ I.pp_tag k))

(* lookup_procname env p ext = the definition of process p in env *)
fun lookup_procname env p ext =
    (case I.lookup_procname env p
      of SOME(d) => d
       | NONE => ERROR(ext, "process " ^ p ^ " undefined"))

(* lookup_ntp x Delta ext = the (unexpanded) type of x in Delta
 * error if x is not declared in Delta or has already been used *)
fun lookup_ntp x Delta ext =
    (case I.lookup_tpn x Delta
      of SOME(I.TpN(x, nA, I.Used)) => ERROR(ext, "variable " ^ I.pp_chan x ^ " already used")
       | SOME(I.TpN(x, nA, _)) => nA
       | NONE => ERROR(ext, "variable " ^ I.pp_chan x ^ " unknown or out of scope"))

(* lookup_status x Delta ext = the usage status of x in Delta
 * x must be declared in Delta; raises Match otherwise (internal invariant) *)
fun lookup_status x Delta ext =
    (case I.lookup_tpn x Delta
      of SOME(I.TpN(x, _, s)) => s
       | NONE => raise Match)

(* lookup_tp env x Delta ext = the expanded type of the available variable x *)
fun lookup_tp env x Delta ext =
    I.expand env (lookup_ntp x Delta ext)

(* printing "x : A" for error messages, for the succedent (pp_xR)
 * and for an antecedent that must be declared in Delta (pp_xL) *)
fun pp_xR env z nC = I.pp_parm env (I.Tp(z, nC))
fun pp_xL env x Delta =
    (case I.lookup_tpn x Delta
      of SOME(I.TpN(x, nA, _)) =>
         I.pp_parm env (I.Tp(x, nA))
       | _ => raise Match) (* must be defined *)

(*******************)
(* type comparison *)
(*******************)

(* 
 * either type equality (default) or subtyping,
 * based on the Flags.subtp setting.  Call
 * mpass with -s or --subtyping to enable subtyping
 *)

fun compare_tp env nA nB =
    if !Flags.subtp then I.subtp env nA nB
    else I.eqtp env nA nB

(************************)
(* context manipulation *)
(************************)

(* update (x : A @ s) Delta = Delta with the entry for x replaced
 * x must already be declared in Delta (raises Match otherwise) *)
fun update (I.TpN(x, nA, s)) (I.TpN(y, nB, t)::Delta) =
    if x = y then I.TpN(x, nA, s)::Delta
    else I.TpN(y, nB, t)::update (I.TpN(x, nA, s)) Delta
  | update _ nil = raise Match

(* add (x : A @ s) Delta ext = Delta extended with x
 * error if x is already declared in Delta (no shadowing) *)
fun add (I.TpN(x, nA, s)) Delta ext =
    if List.exists (fn I.TpN(y, _, _) => x = y) Delta
    then ERROR(ext, "shadowing on " ^ I.pp_chan x)
    else I.TpN(x, nA, s)::Delta

(* remove x Delta = Delta without the entry for x, at end of x's scope
 * x must be present and marked Used (raises Match otherwise) *)
fun remove x (I.TpN(y, nB, I.Used)::Delta) =
    if x = y then Delta else (I.TpN(y, nB, I.Used)::remove x Delta)
  | remove x (I.TpN(y, nB, s)::Delta) =
    I.TpN(y, nB, s)::remove x Delta
  | remove x nil = raise Match
    
(* mark_used xs Delta = Delta with every variable in xs marked Used *)
fun mark_used xs (I.TpN(y, nB, s)::Delta) =
    if List.exists (fn x => x = y) xs
    then I.TpN(y, nB, I.Used)::mark_used xs Delta
    else I.TpN(y, nB, s)::mark_used xs Delta
  | mark_used xs nil = nil

(* all_used Delta ext = Delta, error if some variable is still Must *)
fun all_used (I.TpN(x, nA, I.Must)::Delta) ext =
    ERROR(ext, "variable " ^ I.pp_chan x ^ " unused")
  | all_used (I.TpN(x, nA, s)::Delta) ext = (* May or Used *)
    I.TpN(x, nA, s)::all_used Delta ext
  | all_used nil ext = nil

(* check_status x Delta s ext succeeds if x has status s in Delta *)
fun check_status x Delta s ext =
    if lookup_status x Delta ext = s then ()
    else ERROR(ext, "inconsistent status of variable " ^ I.pp_chan x ^ " across branches")

(* every variable of Delta1 has the same status in Delta2 *)
fun check_subset (I.TpN(x, _, s)::Delta1) Delta2 ext =
    ( check_status x Delta2 s ext
    ; check_subset Delta1 Delta2 ext )
  | check_subset nil Delta2 ext = ()

(* Delta1 and Delta2 agree on the status of every variable *)
fun check_equal Delta1 Delta2 ext =
    ( check_subset Delta1 Delta2 ext
    ; check_subset Delta2 Delta1 ext )

(* all_equal Deltas ext = the common output context of all branches
 * error if branches disagree on some variable or there is no branch *)
fun all_equal (Delta1::Delta2::Deltas) ext =
    ( check_equal Delta1 Delta2 ext
    ; all_equal (Delta2::Deltas) ext )
  | all_equal (Delta::nil) ext = Delta
  | all_equal nil ext = ERROR(ext, "at least one branch required")

(* must2may Delta = Delta with every Must relaxed to May
 * (used for the first premise of a cut, which need not use anything) *)
fun must2may (I.TpN(x, nA, I.Must)::Delta) = I.TpN(x, nA, I.May)::must2may Delta
  | must2may (I.TpN(x, nA, s)::Delta) = I.TpN(x, nA, s)::must2may Delta (* s = Used or May *)
  | must2may nil = nil

(* may2must Delta Delta' = Delta' where every variable that was Must in
 * Delta and is still unused (May) in Delta' is reverted to Must *)
fun may2must (I.TpN(x, nA, I.Must)::Delta) Delta' = (* was: Must, became May *)
    (case lookup_status x Delta' NONE               (* status should always be defined *)
      of I.May => may2must Delta (update (I.TpN(x, nA, I.Must)) Delta') (* revert to Must if unused *)
       | I.Used => may2must Delta Delta' (* remains Used *)
       | I.Must => raise Match) (* impossible *)
  | may2must (I.TpN(x, nA, s)::Delta) Delta' = (* s = Used or May *)
    may2must Delta Delta'
  | may2must nil Delta' = Delta'

(* minimal inference for cut:
 * the type of x in x <- P ; Q can be omitted only when P is a
 * (possibly marked) call p(x, ...), whose declared type is used *)

fun infer_cut env Delta (I.Cut(x, NONE, I.Call(p, x', ys), Q)) zC ext =
    let val () = if x = x' then ()
                 else ERROR(ext, "variable mismatch: " ^ I.pp_chan x' ^ " not equal provided channel " ^ I.pp_chan x)
        val I.ProcDef(p, I.Tp(_,nA), uBs, _, ext') = lookup_procname env p ext
    in nA end
  | infer_cut env Delta (I.Cut(x, NONE, I.Marked(marked_P), Q)) zC ext =
    infer_cut env Delta (I.Cut(x, NONE, Mark.data marked_P, Q)) zC ext (* don't update ext *)
  | infer_cut env Delta (I.Cut(x, NONE, P, Q)) zC ext =
    ERROR(ext, "omitted type of " ^ I.pp_chan x ^ " not inferrable")

(* right_rule P z = true if P communicates along z, the provided channel
 * requires P is a send or receive
 *)
fun right_rule (I.Send(x,_)) z = (x = z)
  | right_rule (I.Recv(x,_)) z = (x = z)

(* check_channel V ext succeeds if V is a (possibly marked) channel *)
fun check_channel (I.Channel _) ext = ()
  | check_channel (I.MarkedValue(marked_V)) ext =
    check_channel (Mark.data marked_V) (Mark.ext marked_V)
  | check_channel V ext = ERROR(ext, "message " ^ I.pp_value V ^ " must be a channel")

(* get_channel V ext = x if V is the (possibly marked) channel x *)
fun get_channel (I.Channel(x)) ext = x
  | get_channel (I.MarkedValue(marked_V)) ext =
    get_channel (Mark.data marked_V) (Mark.ext marked_V)
  | get_channel V ext = ERROR(ext, "message " ^ I.pp_value V ^ " must be a channel")

(* get_channel_opt V = SOME(x) if V is the (possibly marked) channel x *)
fun get_channel_opt (I.Channel(x)) = SOME(x)
  | get_channel_opt (I.MarkedValue(marked_V)) = get_channel_opt (Mark.data marked_V)
  | get_channel_opt V = NONE                      

(* wording of a failed type comparison, depending on Flags.subtp *)
fun tpcomp () = if !Flags.subtp then " not subtype of " else " not equal to "

(* project_pair K yOpt Kpairs ext = (y, Kpairs')
 * requires every pattern in K to be a pair (y', V2) with y' a channel.
 * y is the channel of the first pattern; the processes of later branches
 * are rewritten with I.subst_proc [(y,y')] so all branches bind y
 * (presumably replacing y' by y).  Kpairs' holds the residual
 * branches (V2, P) in their original order.
 * Called initially with yOpt = NONE and Kpairs = nil; an empty K is
 * reported as a missing branch instead of raising Match.
 *)
fun project_pair ((I.Pair(V1,V2), P)::K) NONE _ ext =
    let val y = get_channel V1 ext
    in project_pair K (SOME(y)) [(V2,P)] ext end
  | project_pair ((I.Pair(V1,V2), P)::K) (SOME(y)) Kpairs ext =
    let val z = get_channel V1 ext
        val P' = I.subst_proc [(y,z)] P
    in project_pair K (SOME(y)) (Kpairs @ [(V2,P')]) ext end
  | project_pair ((I.MarkedValue(marked_V), P)::K) yOpt Kpairs ext =
    project_pair ((Mark.data marked_V, P)::K) yOpt Kpairs (Mark.ext marked_V)
  | project_pair ((I.Channel(x),P)::K) yOpt Kpairs ext =
    ERROR(ext, "overlapping variable pattern " ^ I.pp_value (I.Channel(x)))
  | project_pair ((V,P)::branches) yOpt Kpairs ext =
    ERROR(ext, "pattern " ^ I.pp_value V ^ " not a pair (y,_)")
  | project_pair nil (SOME(y)) Kpairs ext = (y, Kpairs)
  | project_pair nil NONE _ ext =
    (* no branch at all: previously an uncaught Match *)
    ERROR(ext, "missing branch for pair pattern (y,_)")

(* project_unit K ext = the processes of all branches of K,
 * requiring every pattern to be () *)
fun project_unit ((I.Unit, P)::K) ext =
    P::project_unit K ext
  | project_unit ((I.MarkedValue(marked_V), P)::K) ext =
    project_unit ((Mark.data marked_V, P)::K) (Mark.ext marked_V)
  | project_unit ((I.Channel(x),P)::K) ext =
    ERROR(ext, "overlapping variable pattern " ^ I.pp_value (I.Channel(x)))
  | project_unit ((V,P)::K) ext =
    ERROR(ext, "pattern " ^ I.pp_value V ^ " not unit ()")
  | project_unit nil ext = nil

(* project_label K l Kl Kother ext = (Kl', Kother')
 * splits the labeled branches of K into those tagged l (with the label
 * stripped, appended to Kl) and the rest (appended to Kother),
 * preserving order; every pattern must be of the form k(_) *)
fun project_label ((I.Label(k,Vk), P)::K) l Kl Kother ext =
    if k = l
    then project_label K l (Kl @ [(Vk,P)]) Kother ext
    else project_label K l Kl (Kother @ [(I.Label(k,Vk),P)]) ext
  | project_label ((I.MarkedValue(marked_V), P)::K) l Kl Kother ext =
    project_label ((Mark.data marked_V, P)::K) l Kl Kother (Mark.ext marked_V)
  | project_label ((I.Channel(x),P)::K) l Kl Kother ext =
    ERROR(ext, "overlapping variable pattern " ^ I.pp_value (I.Channel(x)))
  | project_label ((V,P)::K) l Kl Kother ext =
    ERROR(ext, "pattern " ^ I.pp_value V ^ " not tagged 'k(_)")
  | project_label nil l Kl Kother ext = (Kl, Kother)

(* check env Delta P zC ext = Delta'
 * checkR env Delta P zC ext = Delta' if P ends in a right rule
 * checkL env Delta P zC ext = Delta' if P ends in a left rule
 * where Delta - Delta' |- P :: (z : C)
 * raises ErrorMsg.Error if no such Delta' exists
 *
 * variables in Delta are marked as
 * Must (must be used, never in Delta')
 * May (may be used, could be May or Used in Delta')
 * Used (was used, so no longer available)
 * Right (succedent/offered channel)
 *)
fun check env Delta (I.Fwd(x,y)) (I.TpN(x', nA, I.Right)) ext =
    (* y : B |- fwd x y :: (x : A) if B <: A *)
    ( if x = x' then () else ERROR(ext, "variable mismatch: " ^ I.pp_chan x ^ " not equal to provided channel " ^ I.pp_chan x') ;
      case lookup_ntp y Delta ext
       of nB => if compare_tp env nB nA
                then all_used (mark_used [y] Delta) ext
                else ERROR(ext, "type mismatch: " ^ pp_xL env y Delta ^ tpcomp() ^ pp_xR env x' nA))
                                      
  | check env Delta (I.Cut(x,SOME(nA),P,Q)) (I.TpN(z, nC, I.Right)) ext =
    (* 
     * Delta1, Delta2 |- x:A <- P(x) ; Q(x) :: (z : C)
     * if Delta1 |- P(x) :: (x : A)
     * and Delta2, x:A |- Q(x) :: (z : C)
     *)
    let val () = if x <> z then () else ERROR(ext, "channel " ^ I.pp_chan x ^ " shadowing provided channel")
        (* P may, but need not, use the variables Q must use *)
        val Delta1 = check env (must2may Delta) P (I.TpN(x, nA, I.Right)) ext
        val Delta2 = may2must Delta Delta1 (* revert May to Must *)
        val Delta' = check env (add (I.TpN(x, nA, I.Must)) Delta2 ext) Q (I.TpN(z, nC, I.Right)) ext
    in remove x Delta' end (* end of scope for x *)

  | check env Delta (I.Cut(x,NONE,P,Q)) (I.TpN(z, nC, I.Right)) ext =
    let val () = if x <> z then () else ERROR(ext, "channel " ^ I.pp_chan x ^ " shadowing provided channel")
        val nA = infer_cut env Delta (I.Cut(x, NONE, P, Q)) (I.TpN(z, nC, I.Right)) ext
    in check env Delta (I.Cut(x, SOME(nA), P, Q)) (I.TpN(z, nC, I.Right)) ext end

  | check env Delta (I.Call(p, x, ys)) (I.TpN(z, nC, I.Right)) ext =
    (*
     * ys : Bs' |- p(x, ys) :: (x : A')
     * if p is declared as (x : A) <- p (us : Bs)
     * with A <: A' and Bs' <: Bs (pointwise)
     *)
    let val I.ProcDef(p, I.Tp(x',nA), uBs, _, ext') = lookup_procname env p ext
        val () = if x = z then ()
                 else ERROR(ext, "variable mismatch: " ^ I.pp_chan x ^ " does not match provided channel " ^ I.pp_chan z)
        val () = if compare_tp env nA nC then ()
                 else ERROR(ext, "type mismatch: " ^ I.pp_parm env (I.Tp(x,nA)) ^ tpcomp () ^ pp_xR env z nC)
        val () = if List.length ys = List.length uBs then ()
                 else ERROR(ext, "process " ^ p ^ " requires " ^ Int.toString (List.length uBs) ^ " arguments"
                                 ^ " but given " ^ Int.toString (List.length ys))
        (* look up the arguments left to right, marking each one Used as
         * soon as it is consumed, so that passing the same linear channel
         * twice is reported as "already used" instead of being accepted *)
        val (Delta', revyBs) =
            List.foldl (fn (y, (D, acc)) =>
                           let val nB' = lookup_ntp y D ext
                           in (mark_used [y] D, I.Tp(y, nB')::acc) end)
                       (Delta, nil) ys
        val yBs = List.rev revyBs
        val () = ListPair.app (fn (I.Tp(y,nB'), I.Tp(u,nB)) =>
                                  if compare_tp env nB' nB then ()
                                  else ERROR(ext, "type mismatch: " ^ I.pp_parm env (I.Tp(y,nB'))
                                                  ^ tpcomp () ^ I.pp_parm env (I.Tp(u,nB))))
                              (yBs, uBs)
    in all_used Delta' ext end

  | check env Delta (I.Marked(marked_P)) zC ext =
    check env Delta (Mark.data marked_P) zC (Mark.ext marked_P)

  | check env Delta P (I.TpN(z, nC, I.Right)) ext =
    if right_rule P z
    then checkR env Delta P (I.TpN(z, nC, I.Right)) ext
    else checkL env Delta P (I.TpN(z, nC, I.Right)) ext

(* P not a mark *)
and checkR env Delta (I.Send(x, V)) (I.TpN(z, nC, I.Right)) ext = (* x = z *)
    all_used (checkRvalue env Delta V nC ext) ext
  | checkR env Delta (I.Recv(x, I.Cont(K))) (I.TpN(z, nC, I.Right)) ext = (* x = z *)
    checkRcont env Delta K nC ext

(* checkRvalue env Delta V A ext = Delta' where V : [A] is sent along
 * the provided channel, consuming the channels occurring in V *)
and checkRvalue env Delta (I.Channel(x')) nA ext =
    (* x' : B |- x' : [A] if B <: A *)
    let val nB = lookup_ntp x' Delta ext
    in if compare_tp env nB nA (* nB <: nA *)
       then mark_used [x'] Delta
       else ERROR(ext, "type mismatch: " ^ pp_xL env x' Delta ^ tpcomp() ^ I.pp_tpn env nA)
    end
  | checkRvalue env Delta (I.Label(k,V)) nA ext =
    (* 
     * Delta |- k(V) : [+{l:Al}]
     * if Delta |- V : [Ak]
     *)
    (case I.expand env nA
      of I.Plus(alts) =>
         (case select k alts ext
           of nAk => checkRvalue env Delta V nAk ext)
       | _ => ERROR(ext, "type mismatch: " ^ I.pp_tpn env nA ^ " not an internal choice +{...}"))
  | checkRvalue env Delta (I.Pair(V1,V2)) nA ext =
    (*
     * Delta, y:B |- (y,V) : [A1 * A2]
     * if y:B |- y : [A1]   (iff B <: A1)
     * and Delta |- V : [A2]
     *)
    ( check_channel V1 ext      (* V1 = y *)
    ; case I.expand env nA
       of I.Tensor(nA1,nA2) =>
          let val Delta1 = checkRvalue env Delta V1 nA1 ext (* V1 = y *)
              val Delta2 = checkRvalue env Delta1 V2 nA2 ext
          in Delta2 end
        | _ => ERROR(ext, "type mismatch: " ^ I.pp_tpn env nA ^ " is not a tensor (_ * _)"))
  | checkRvalue env Delta (I.Unit) nA ext =
    (* . |- () : [1] *)
    (case I.expand env nA
      of I.One => Delta
       | _ => ERROR(ext, "type mismatch: " ^ I.pp_tpn env nA ^ " not equal 1"))

  | checkRvalue env Delta (I.MarkedValue(marked_V)) nA ext =
    checkRvalue env Delta (Mark.data marked_V) nA (Mark.ext marked_V)

(* checkRcont env Delta K A ext = Delta' where the continuation K
 * receives a message of type <A> along the provided channel *)
and checkRcont env Delta [(V, P)] nA ext =
    (*
     * Delta |- (x' => P(x')) : <A> (inversion)
     * if Delta |- P(x') :: (x' : A)
     *)
    (case get_channel_opt V
      of SOME(x') => check env Delta P (I.TpN(x', nA, I.Right)) ext
       | NONE => checkRcont_tp env Delta [(V,P)] (I.expand env nA) ext)
  | checkRcont env Delta K nA ext =
    checkRcont_tp env Delta K (I.expand env nA) ext

and checkRcont_tp env Delta K (I.With(alts)) ext =
    (*
     * Delta |- K : <&{l:Al}>
     * if Delta |- K @ l(_) : <Al> for all l
     *)
    let val Deltas = checkRalts env Delta K alts ext
    in all_equal Deltas ext end
    
  | checkRcont_tp env Delta K (I.Lolli(nA1,nA2)) ext =
    (*
     * Delta |- K : <A1 -o A2>
     * if Delta, y:A1 |- K @ (y,_) : <A2>
     *)
    let val (y, Kpairs) = project_pair K NONE nil ext (* y substituted into all branches *)
        val Delta' = checkRcont env (add (I.TpN(y, nA1, I.Must)) Delta ext) Kpairs nA2 ext
    in remove y Delta' end

  | checkRcont_tp env Delta branches A ext =
    ERROR(ext, "type mismatch: " ^ I.pp_tp env A ^ " is not external choice or linear function type")

(* checkRalts env Delta K alts ext = one output context per alternative;
 * every alternative needs a branch and every branch an alternative *)
and checkRalts env Delta K ((l,Al)::alts) ext =
    (case project_label K l nil nil ext
      of (nil, _) => ERROR(ext, "missing branch for " ^ I.pp_tag l)
       | (Kl, Kother) => checkRcont env Delta Kl Al ext
                         :: checkRalts env Delta Kother alts ext)
  | checkRalts env Delta nil nil ext = nil
  | checkRalts env Delta ((V,P)::_) nil ext =
    ERROR(ext, "extraneous branch with pattern " ^ I.pp_value V)

(* P not a mark; the channel x of P is an antecedent, consumed here *)
and checkL env Delta (I.Send(x, V)) zC ext =
    all_used (checkLvalue env (mark_used [x] Delta) (lookup_ntp x Delta ext) V zC ext) ext

  | checkL env Delta (I.Recv(x, I.Cont(K))) zC ext =
    checkLcont env (mark_used [x] Delta) (lookup_ntp x Delta ext) K zC ext

(* checkLvalue env Delta A V zC ext = Delta' where V is sent along an
 * antecedent of type A *)
and checkLvalue env Delta nA (I.Channel(x')) (I.TpN(z, nB, I.Right)) ext =
    (* [A] |- x' :: (x' : B) if A <: B *)
    let val () = if x' = z then ()
                 else ERROR (ext, "variable mismatch: " ^ I.pp_chan x' ^ " does not match provided channel " ^ I.pp_chan z)
    in if compare_tp env nA nB (* nA <: nB *)
       then Delta
       else ERROR(ext, "type mismatch: " ^ I.pp_tpn env nA ^ tpcomp() ^ I.pp_tpn env nB)
    end
  | checkLvalue env Delta nA (I.Label(k,V)) zC ext =
    (*
     * Delta, [&{l:Al}] |- k(V) :: (z : C)
     * if Delta, [Ak] |- V :: (z : C)
     *)
    (case I.expand env nA
      of I.With(alts) =>
         (case select k alts ext
           of nAk => checkLvalue env Delta nAk V zC ext)
       | _ => ERROR(ext, "type mismatch: " ^ I.pp_tpn env nA ^ " not an external choice &{...}"))
  | checkLvalue env Delta nA (I.Pair(V1,V2)) zC ext =
    (*
     * Delta, y:A1', [A1 -o A2] |- (y, V2) :: (z : C)
     * if y:A1' |- y : [A1]  (iff A1' <: A1)
     * and Delta, [A2] |- V2 :: (z : C)
     *)
    ( check_channel V1 ext      (* V1 = y *)
    ; case I.expand env nA
       of I.Lolli(nA1,nA2) =>
          let val Delta1 = checkRvalue env Delta V1 nA1 ext
              val Delta2 = checkLvalue env Delta1 nA2 V2 zC ext
          in Delta2 end
        | _ => ERROR(ext, "type mismatch: " ^ I.pp_tpn env nA ^ " not a linear function (_ -o _)"))
  | checkLvalue env Delta nA (I.Unit) zC ext =
    ERROR(ext, "sending '()' to the provider is not permitted (would require type 'bot')")

  | checkLvalue env Delta nA (I.MarkedValue(marked_V)) zC ext =
    checkLvalue env Delta nA (Mark.data marked_V) zC (Mark.ext marked_V)

(* checkLcont env Delta A K zC ext = Delta' where the continuation K
 * receives a message of type <A> along an antecedent *)
and checkLcont env Delta nA [] zC ext =
    ERROR(ext, "no matching branch for type " ^ I.pp_tpn env nA)
  | checkLcont env Delta nA [(V, P)] zC ext =
    (*
     * Delta, <A> |- (x' => P(x')) :: (z : C)
     * if Delta, x':A |- P(x') :: (z : C)
     *)
    (case get_channel_opt V
      of SOME(x') => let val Delta' = check env (add (I.TpN(x', nA, I.Must)) Delta ext) P zC ext
                     in remove x' Delta' end
       | NONE => checkLcont_tp env Delta (I.expand env nA) [(V, P)] zC ext)
  | checkLcont env Delta nA K zC ext =
    checkLcont_tp env Delta (I.expand env nA) K zC ext

and checkLcont_tp env Delta (I.Plus(alts)) K zC ext =
    (*
     * Delta, <+{l:Al}> |- K :: (z : C)
     * if Delta, <Al> |- K @ l(_) :: (z : C) for all l
     *)
    let val Deltas = checkLalts env Delta alts K zC ext
    in all_equal Deltas ext end
  | checkLcont_tp env Delta (I.Tensor(nA1,nA2)) K zC ext =
    (*
     * Delta, <A1 * A2> |- K :: (z : C)
     * if Delta, y : A1, <A2> |- K @ (y,_) :: (z : C)
     *)
    let val (y,Kpairs) = project_pair K NONE nil ext (* y substituted into all branches *)
        val Delta' = checkLcont env (add (I.TpN(y, nA1, I.Must)) Delta ext) nA2 Kpairs zC ext
    in remove y Delta' end
  | checkLcont_tp env Delta (I.One) K zC ext =
    (*
     * Delta, <1> |- K :: (z : C)
     * if Delta |- K @ () :: (z : C)
     *)
    let val Punits = project_unit K ext
        val P = (case Punits
                  of nil => ERROR(ext, "no matching branch for unit type 1")
                   | P::nil => P
                   | _ => ERROR(ext, "more than one matching branch for unit type 1"))
    in check env Delta P zC ext end
  | checkLcont_tp env Delta A K zC ext =
    (* report the offending type, as checkRcont_tp does *)
    ERROR(ext, "type mismatch: " ^ I.pp_tp env A ^ " is not internal choice, tensor, or unit type")

(* checkLalts env Delta alts K zC ext = one output context per alternative;
 * every alternative needs a branch and every branch an alternative *)
and checkLalts env Delta ((l,Al)::alts) K zC ext =
    (case project_label K l nil nil ext
      of (nil, _) => ERROR(ext, "missing branch for " ^ I.pp_tag l)
       | (Kl, Kother) => checkLcont env Delta Al Kl zC ext
                         :: checkLalts env Delta alts Kother zC ext)
  | checkLalts env Delta nil nil zC ext = nil
  | checkLalts env Delta nil ((V,P)::_) zC ext =
    ERROR(ext, "extraneous branch with pattern " ^ I.pp_value V)

(* execute directly with queues *)
(* unfinished translation below, kept for reference only (not compiled) *)
(*
fun trans env Delta (I.Fwd(x,y)) zC = I.Fwd(x,y)
  | trans env Delta (I.Cut(x,SOME(nA),P,Q)) zC =
    I.Cut(x, SOME(nA), trans env Delta P (I.TpN(x,nA,I.Right)),
          trans env (add (I.TpN(x,nA,I.Must)) Delta NONE) Q zC)
  | trans env Delta (I.Call(p, x, ys)) zC = I.Call(p, x, ys)
  | trans env Delta (I.Marked(marked_P)) zC =
    trans env Delta (Mark.data marked_P) zC
  | trans env Delta P (I.TpN(z, nC, _)) =
    if right_rule P z
    then transR env Delta P (I.TpN(z, nC, _))
    else transL env Delta P (I.TpN(z, nC, _))
and transR env Delta (I.Send(x,V)) (I.TpN(_, nC, _)) = I.Send(x, transRvalue env Delta V zC)
  | transR env Delta (I.Recv(x,I.Cont(K))) zC = I.Recv(x, transRcont env Delta K zC)
and transRvalue env Delta (I.Channel(x'))
*)

end
