
structure CheckModule :> CHECK_MODULE =
   struct

      open ILModule
      open SubstModule
      open FirstModule
      open ContextModule
      open EquivModule

      (* Ill-formedness detected by this structure is reported by raising
         TypeError, which is the same exception as Misc.TypeError. *)
      exception TypeError = Misc.TypeError


      (* ------------------------------------------------------------------
         Kinds and constructors.

         checkKind ctx k         () when k is a well-formed kind in ctx.
         inferCon ctx c          a kind for c in ctx.  Type formers
                                 (arrow, product, sum, recursive, tag and
                                 ref types, and the base types) receive the
                                 singleton kind Ksing c.
         kindOfTypeFormer ctx c  shared tail of the type-former cases: each
                                 listed part must have kind Ktype, and the
                                 result is Ksing c.
         checkCon ctx c k        () when subkind accepts the inferred kind
                                 of c against k.
         ------------------------------------------------------------------ *)

      fun checkKind _ Ktype = ()
        | checkKind ctx (Ksing con) = checkCon ctx con Ktype
        | checkKind ctx (Kpi (kdom, kcod)) =
             let
                val () = checkKind ctx kdom
             in
                (* the codomain is checked under one extra variable of kind kdom *)
                checkKind (extendKind ctx kdom) kcod
             end
        | checkKind ctx (Ksigma (kleft, kright)) =
             let
                val () = checkKind ctx kleft
             in
                checkKind (extendKind ctx kleft) kright
             end
        | checkKind _ Kunit = ()

      and inferCon ctx (con as Cvar (i, _)) = selfify con (lookupKind ctx i)
        | inferCon ctx (Clam (kdom, body)) =
             let
                val () = checkKind ctx kdom
             in
                Kpi (kdom, inferCon (extendKind ctx kdom) body)
             end
        | inferCon ctx (Capp (fnc, arg)) =
             (case inferCon ctx fnc of
                 Kpi (kdom, kcod) =>
                    let
                       val () = checkCon ctx arg kdom
                    in
                       substKind arg kcod
                    end
               | _ => raise TypeError)
        | inferCon ctx (Cpair (left, right)) =
             let
                val kleft = inferCon ctx left
                val kright = inferCon ctx right
             in
                (* non-dependent pair: kright is shifted under the sigma binder *)
                Ksigma (kleft, liftKind 1 kright)
             end
        | inferCon ctx (Cpi1 con) =
             (case inferCon ctx con of
                 Ksigma (kleft, _) => kleft
               | _ => raise TypeError)
        | inferCon ctx (Cpi2 con) =
             (case inferCon ctx con of
                 Ksigma (_, kright) => substKind (Cpi1 con) kright
               | _ => raise TypeError)
        | inferCon _ Cunit = Kunit
        | inferCon ctx (con as Carrow (dom, cod)) = kindOfTypeFormer ctx con [dom, cod]
        | inferCon ctx (con as Cprod parts) = kindOfTypeFormer ctx con parts
        | inferCon ctx (con as Csum parts) = kindOfTypeFormer ctx con parts
        | inferCon ctx (con as Crec body) =
             (* the body is checked under one extra variable of kind Ktype *)
             kindOfTypeFormer (extendKind ctx Ktype) con [body]
        | inferCon ctx (con as Ctag inner) = kindOfTypeFormer ctx con [inner]
        | inferCon ctx (con as Cref inner) = kindOfTypeFormer ctx con [inner]
        | inferCon _ Cexn = Ksing Cexn
        | inferCon _ Cbool = Ksing Cbool
        | inferCon _ Cint = Ksing Cint
        | inferCon _ Cchar = Ksing Cchar
        | inferCon _ Cstring = Ksing Cstring

      and kindOfTypeFormer ctx con parts =
         let
            (* parts are checked left to right *)
            val () = List.app (fn part => checkCon ctx part Ktype) parts
         in
            Ksing con
         end

      and checkCon ctx con kind = subkind ctx (inferCon ctx con) kind

      (* Checks that ty is a type (kind Ktype), then returns its weak-head
         normal form as computed by whnf. *)
      fun whnfAnnot ctx ty =
         let
            val () = checkCon ctx ty Ktype
         in
            whnf ctx ty
         end

      (* Signature well-formedness.  In Ssigma and Spi the second component
         is checked in ctx extended with the kind fstsg of the first. *)
      fun checkSg ctx (Sval ty) = checkCon ctx ty Ktype
        | checkSg ctx (Scon kind) = checkKind ctx kind
        | checkSg ctx (Ssigma (first, second)) = checkBinder ctx first second
        | checkSg ctx (Spi (first, second)) = checkBinder ctx first second
        | checkSg _ Sunit = ()
        | checkSg ctx (Snamed (_, inner)) = checkSg ctx inner

      and checkBinder ctx first second =
         let
            val () = checkSg ctx first
         in
            checkSg (extendKind ctx (fstsg first)) second
         end


      structure PrimType =
         PrimTypeFun (structure Param =
                         struct
                            open ILModule
                            val Cunittype = Cprod []
                         end)

      (* Component i of tl; an out-of-range index becomes TypeError. *)
      fun nthComponent (tl, i) =
         List.nth (tl, i) handle Subscript => raise TypeError

      (* ListPair.appEq, with a length mismatch reported as TypeError.  The
         handler also covers UnequalLengths escaping from f itself. *)
      fun appEqOrFail f pair =
         ListPair.appEq f pair handle ListPair.UnequalLengths => raise TypeError

      (* SOME (a, b) when both options are present, NONE otherwise. *)
      fun zipOption (SOME a, SOME b) = SOME (a, b)
        | zipOption _ = NONE


      (* ------------------------------------------------------------------
         Terms and modules.

         inferTerm ctx e       the type of e (TypeError if ill-typed).
         inferTermWhnf ctx e   that type in weak-head normal form.
         checkTerm ctx e t     () when equiv accepts e's type against t at
                               kind Ktype.
         inferModule ctx m     (sg, known): sg is m's signature and known is
                               SOME c when a constructor c for m's
                               constructor-level content is available (the
                               path for Mvar, c for Mcon, Cunit for Mval,
                               Munit and Mlam, pairs and projections of
                               these).  It is NONE for Mapp, Mlet and Mseal.
         checkModule ctx m sg  checks m's signature against sg with subsg
                               and returns known.
         ------------------------------------------------------------------ *)

      fun inferTerm ctx (Tvar v) = lookupType ctx v
        | inferTerm ctx (Tlam (v, dom, body)) =
             let
                val () = checkCon ctx dom Ktype
             in
                Carrow (dom, inferTerm (extendType ctx v dom) body)
             end
        | inferTerm ctx (Tapp (func, arg)) =
             (case inferTermWhnf ctx func of
                 Carrow (dom, cod) =>
                    let
                       val () = checkTerm ctx arg dom
                    in
                       cod
                    end
               | _ => raise TypeError)
        | inferTerm ctx (Ttuple el) = Cprod (map (inferTerm ctx) el)
        | inferTerm ctx (Tproj (tuple, i)) =
             (case inferTermWhnf ctx tuple of
                 Cprod tl => nthComponent (tl, i)
               | _ => raise TypeError)
        | inferTerm ctx (Tinj (payload, i, annot)) =
             (case whnfAnnot ctx annot of
                 Csum tl =>
                    let
                       val summand = nthComponent (tl, i)
                       val () = checkTerm ctx payload summand
                    in
                       annot
                    end
               | _ => raise TypeError)
        | inferTerm _ (Tcase (_, [])) = raise TypeError
        | inferTerm ctx (Tcase (scrutinee, (v1, e1) :: arms)) =
             (case inferTermWhnf ctx scrutinee of
                 Csum (t1 :: ts) =>
                    let
                       (* the first arm fixes the result type; the other arms
                          are checked against it, pairwise with the summands *)
                       val result = inferTerm (extendType ctx v1 t1) e1
                       val () =
                          appEqOrFail
                             (fn ((vi, ei), ti) => checkTerm (extendType ctx vi ti) ei result)
                             (arms, ts)
                    in
                       result
                    end
               | _ => raise TypeError)
        | inferTerm ctx (Troll (body, annot)) =
             (case whnfAnnot ctx annot of
                 Crec inner =>
                    let
                       val () = checkTerm ctx body (substCon annot inner)
                    in
                       annot
                    end
               | _ => raise TypeError)
        | inferTerm ctx (Tunroll body) =
             (case inferTermWhnf ctx body of
                 recTy as Crec inner => substCon recTy inner
               | _ => raise TypeError)
        | inferTerm ctx (Ttag (tag, payload)) =
             (case inferTermWhnf ctx tag of
                 Ctag ty =>
                    let
                       val () = checkTerm ctx payload ty
                    in
                       Cexn
                    end
               | _ => raise TypeError)
        | inferTerm ctx (Tiftag (tag, packet, v, hit, miss)) =
             (case inferTermWhnf ctx tag of
                 Ctag ty =>
                    let
                       val () = checkTerm ctx packet Cexn
                       val result = inferTerm (extendType ctx v ty) hit
                       val () = checkTerm ctx miss result
                    in
                       result
                    end
               | _ => raise TypeError)
        | inferTerm ctx (Tnewtag ty) =
             let
                val () = checkCon ctx ty Ktype
             in
                Ctag ty
             end
        | inferTerm ctx (Traise (packet, ty)) =
             let
                val () = checkTerm ctx packet Cexn
                val () = checkCon ctx ty Ktype
             in
                ty
             end
        | inferTerm ctx (Thandle (body, v, handler)) =
             let
                val result = inferTerm ctx body
                val () = checkTerm (extendType ctx v Cexn) handler result
             in
                result
             end
        | inferTerm ctx (Tref init) = Cref (inferTerm ctx init)
        | inferTerm ctx (Tderef cell) =
             (case inferTermWhnf ctx cell of
                 Cref ty => ty
               | _ => raise TypeError)
        | inferTerm ctx (Tassign (cell, newval)) =
             (case inferTermWhnf ctx cell of
                 Cref ty =>
                    let
                       val () = checkTerm ctx newval ty
                    in
                       Cprod []
                    end
               | _ => raise TypeError)
        | inferTerm _ (Tbool _) = Cbool
        | inferTerm ctx (Tif (cond, thenArm, elseArm)) =
             let
                val () = checkTerm ctx cond Cbool
                val result = inferTerm ctx thenArm
                val () = checkTerm ctx elseArm result
             in
                result
             end
        | inferTerm _ (Tint _) = Cint
        | inferTerm _ (Tchar _) = Cchar
        | inferTerm _ (Tstring _) = Cstring
        | inferTerm ctx (Tlet (v, bound, body)) =
             let
                val boundTy = inferTerm ctx bound
             in
                inferTerm (extendType ctx v boundTy) body
             end
        | inferTerm ctx (Tletm (v, m, body, ty)) =
             let
                val (sg, _) = inferModule ctx m
                val () = checkCon ctx ty Ktype
                (* ty is well-formed in ctx; liftCon 1 shifts it past the
                   module binder introduced by extendSg *)
                val () = checkTerm (extendSg ctx v sg) body (liftCon 1 ty)
             in
                ty
             end
        | inferTerm ctx (Tprim (prim, args)) =
             let
                val (argTys, resultTy) = PrimType.primtype prim
                val () = appEqOrFail (fn (arg, argTy) => checkTerm ctx arg argTy) (args, argTys)
             in
                resultTy
             end
        | inferTerm ctx (Tsnd m) =
             (case inferModule ctx m of
                 (Sval ty, _) => ty
               | _ => raise TypeError)

      and inferTermWhnf ctx e = whnf ctx (inferTerm ctx e)

      and checkTerm ctx e ty = equiv ctx (inferTerm ctx e) ty Ktype

      and inferModule ctx (Mvar v) =
             let
                val (i, sg) = lookupSg ctx v
                val path = Cvar (i, NONE)
             in
                (selfifySg path sg, SOME path)
             end
        | inferModule ctx (Mval e) = (Sval (inferTerm ctx e), SOME Cunit)
        | inferModule ctx (Mcon c) = (Scon (inferCon ctx c), SOME c)
        | inferModule _ Munit = (Sunit, SOME Cunit)
        | inferModule ctx (Mpair (m1, m2)) =
             let
                val (sg1, known1) = inferModule ctx m1
                val (sg2, known2) = inferModule ctx m2
                val known = Option.map Cpair (zipOption (known1, known2))
             in
                (* non-dependent pair: sg2 is shifted under the sigma binder *)
                (Ssigma (sg1, liftSg 1 sg2), known)
             end
        | inferModule ctx (Mdpair (v, m1, m2)) =
             let
                val (sg1, known1) = inferModule ctx m1
                val (sg2, known2) = inferModule (extendSg ctx v sg1) m2
                val known =
                   Option.map
                      (fn (c1, c2) => Cpair (c1, substCon c1 c2))
                      (zipOption (known1, known2))
             in
                (Ssigma (sg1, sg2), known)
             end
        | inferModule ctx (Mpi1 m) =
             (case inferModule ctx m of
                 (Ssigma (sg1, _), SOME c) => (sg1, SOME (Cpi1 c))
               | _ => raise TypeError)
        | inferModule ctx (Mpi2 m) =
             (case inferModule ctx m of
                 (Ssigma (_, sg2), SOME c) => (substSg (Cpi1 c) sg2, SOME (Cpi2 c))
               | _ => raise TypeError)
        | inferModule ctx (Mlam (v, dom, body)) =
             let
                val () = checkSg ctx dom
                val (cod, _) = inferModule (extendSg ctx v dom) body
             in
                (Spi (dom, cod), SOME Cunit)
             end
        | inferModule ctx (Mapp (func, arg)) =
             (case inferModule ctx func of
                 (Spi (dom, cod), _) =>
                    (case checkModule ctx arg dom of
                        SOME c => (substSg c cod, NONE)
                      | NONE => raise TypeError)
               | _ => raise TypeError)
        | inferModule ctx (Min (name, m)) =
             let
                val (sg, known) = inferModule ctx m
             in
                (Snamed (name, sg), known)
             end
        | inferModule ctx (Mout m) =
             (case inferModule ctx m of
                 (Snamed (_, sg), known) => (sg, known)
               | _ => raise TypeError)
        | inferModule ctx (Mlet (v, m1, m2, sg)) =
             let
                val () = checkSg ctx sg
                val (sg1, _) = inferModule ctx m1
                val _ = checkModule (extendSg ctx v sg1) m2 (liftSg 1 sg)
             in
                (sg, NONE)
             end
        | inferModule ctx (Mletd (v, m1, m2)) =
             (case inferModule ctx m1 of
                 (sg1, SOME c1) =>
                    let
                       val (sg2, known2) = inferModule (extendSg ctx v sg1) m2
                       val known = Option.map (substCon c1) known2
                    in
                       (substSg c1 sg2, known)
                    end
               | _ => raise TypeError)
        | inferModule ctx (Mlete (v, e, m)) =
             let
                val ty = inferTerm ctx e
             in
                inferModule (extendType ctx v ty) m
             end
        | inferModule ctx (Mseal (m, sg)) =
             let
                val () = checkSg ctx sg
                val _ = checkModule ctx m sg
             in
                (sg, NONE)
             end

      and checkModule ctx m sg =
         let
            val (actual, known) = inferModule ctx m
            val _ = subsg ctx actual sg
         in
            known
         end

      (* Whole-program check: sg must be well-formed in the empty context and
         m must check against it there. *)
      fun checkProgram sg m =
         let
            val () = checkSg empty sg
            val _ = checkModule empty m sg
         in
            ()
         end

   end
