(* socket.sml
 *
 * COPYRIGHT (c) 1995 AT&T Bell Laboratories.
 *)

structure Sock : SOCKET =
  struct

    structure CI = CInterface
    structure W8A = Word8Array
    structure W8V = Word8Vector

  (* sockFn is CI.c_function partially applied to the library name
   * "SMLNJ-Sockets"; applying it to a C function name (for example
   * sockFn "getAddrFamily") yields the ML-side binding for that function.
   * Every C binding in this structure is obtained through it.
   *)
    val sockFn = CI.c_function "SMLNJ-Sockets"

  (* short names for the byte-sequence types used by the I/O operations *)
    type w8vector = W8V.vector
    type w8array = W8A.array

  (* to inherit the various socket related types *)
    open PreSock

  (* bind socket C functions *)
  (* NOTE(review): netdbFun is bound to the same library name as sockFn and is
   * not referenced anywhere else in the visible part of this file -- confirm
   * whether it is still needed.
   *)
    val netdbFun = CI.c_function "SMLNJ-Sockets"

  (* placeholder address wrapping an empty byte vector; the recv...From
   * operations return it when no receive is actually attempted.
   *)
    val dummyAddr = ADDR(W8V.fromList[])

  (* witness types for the socket parameter *)
    datatype dgram = DGRAM
    datatype 'a stream = STREAM
    datatype passive = PASSIVE
    datatype active = ACTIVE

(* I'm not sure where this should go *)
    local
      val queryFamily : addr -> af = sockFn "getAddrFamily"
    in
  (* familyOfAddr: ask the C side for the address family of a socket address
   * and wrap the answer in the AF constructor.
   *)
    fun familyOfAddr (ADDR rawAddr) = let
	  val family = queryFamily rawAddr
	  in
	    AF family
	  end
    end

  (* address families *)
    structure AF =
      struct
      (* the table of address-family constants, fetched from the C side on
       * every call
       *)
	val listAddrFamilies : unit -> CI.system_const list
	      = sockFn "listAddrFamilies"

	local
	(* look a family up by name in a freshly fetched table *)
	  fun lookup name = AF(CI.bindSysConst (name, listAddrFamilies ()))
	in
	val unix = lookup "UNIX"
	val inet = lookup "INET"
	end

      (* all known families, each paired with its name *)
	fun list () = let
	      fun entry (sc as (_, name)) = (name, AF sc)
	      in
		List.map entry (listAddrFamilies ())
	      end

      (* the name component of a family constant *)
	fun toString (AF(_, name)) = name

      (* NONE when CI.findSysConst does not find the name in the table *)
	fun fromString name = (
	      case CI.findSysConst(name, listAddrFamilies ())
	       of (SOME sc) => SOME(AF sc)
		| NONE => NONE
	      (* end case *))
      end

  (* socket types *)
    structure SOCK =
      struct
      (* the table of socket-type constants, fetched from the C side on every
       * call
       *)
	val listSockTypes : unit -> CI.system_const list
	      = sockFn "listSockTypes"

	local
	(* look a socket type up by name in a freshly fetched table *)
	  fun lookup name = SOCKTY(CI.bindSysConst (name, listSockTypes ()))
	in
	val stream = lookup "STREAM"
	val dgram = lookup "DGRAM"
(** NOTE: the following may not be supported!! **)
	val raw = lookup "RAW"
	val rdm = lookup "RDM"
	val seqPacket = lookup "SEQPACKET"
	end

      (* all known socket types, each paired with its name *)
	fun list () = let
	      fun entry (sc as (_, name)) = (name, SOCKTY sc)
	      in
		List.map entry (listSockTypes ())
	      end

      (* the name component of a socket-type constant *)
	fun toString (SOCKTY(_, name)) = name

      (* NONE when CI.findSysConst does not find the name in the table *)
	fun fromString name = (case CI.findSysConst(name, listSockTypes ())
	       of (SOME sc) => SOME(SOCKTY sc)
		| NONE => NONE
	      (* end case *))
      end

  (* socket control operations *)
    structure Ctl =
      struct
	local
	(* Shared plumbing for the "ctl..." C entry points.  A read calls the C
	 * function with (fd, NONE); a write calls it with (fd, SOME value) and
	 * discards whatever comes back.
	 *)
	  fun readOpt cfn (SOCK fd) = cfn (fd, NONE)
	  fun writeOpt cfn (SOCK fd, newVal) = ignore (cfn (fd, SOME newVal))
	  fun mkGet cname = readOpt (sockFn cname)
	  fun mkSet cname = writeOpt (sockFn cname)
	in
      (* get/set socket options; each getter/setter pair shares one C entry
       * point, which is bound once, when this structure is elaborated.
       *)
	val getDEBUG : ('a, 'b) sock -> bool
	      = mkGet "ctlDEBUG"
	val setDEBUG : (('a, 'b) sock * bool) -> unit
	      = mkSet "ctlDEBUG"
	val getREUSEADDR : ('a, 'b) sock -> bool
	      = mkGet "ctlREUSEADDR"
	val setREUSEADDR : (('a, 'b) sock * bool) -> unit
	      = mkSet "ctlREUSEADDR"
	val getKEEPALIVE : ('a, 'b) sock -> bool
	      = mkGet "ctlKEEPALIVE"
	val setKEEPALIVE : (('a, 'b) sock * bool) -> unit
	      = mkSet "ctlKEEPALIVE"
	val getDONTROUTE : ('a, 'b) sock -> bool
	      = mkGet "ctlDONTROUTE"
	val setDONTROUTE : (('a, 'b) sock * bool) -> unit
	      = mkSet "ctlDONTROUTE"

      (* the linger setting as exchanged with the C side: an optional count
       * of seconds
       *)
	val getLINGER' : ('a, 'b) sock -> int option
	      = mkGet "ctlLINGER"
(* NOTE: probably should do some range checking on the argument *)
	val setLINGER' : (('a, 'b) sock * int option) -> unit
	      = mkSet "ctlLINGER"

      (* Time-valued view of the linger setting *)
	fun getLINGER sock = let
	      fun toTime secs = Time.fromSeconds secs
	      in
		case getLINGER' sock
		 of (SOME secs) => SOME(toTime secs)
		  | NONE => NONE
		(* end case *)
	      end
(** FIX ME **
		| (SOME t) => SOME(PreBasis.TIME{sec=t, usec=0})
**)
	fun setLINGER (sock, timeout) = (case timeout
	       of (SOME t) => setLINGER'(sock, SOME(Time.toSeconds t))
		| NONE => setLINGER'(sock, NONE)
	      (* end case *))
(** FIX ME **
	  | setLINGER (SOME(PreBasis.TIME{sec, ...})) = setLINGER'(SOME sec)
**)

	val getBROADCAST : ('a, 'b) sock -> bool
	      = mkGet "ctlBROADCAST"
	val setBROADCAST : (('a, 'b) sock * bool) -> unit
	      = mkSet "ctlBROADCAST"
	val getOOBINLINE : ('a, 'b) sock -> bool
	      = mkGet "ctlOOBINLINE"
	val setOOBINLINE : (('a, 'b) sock * bool) -> unit
	      = mkSet "ctlOOBINLINE"

(* NOTE: probably should do some range checking on the argument *)
	val getSNDBUF : ('a, 'b) sock -> int
	      = mkGet "ctlSNDBUF"
	val setSNDBUF : (('a, 'b) sock * int) -> unit
	      = mkSet "ctlSNDBUF"
(* NOTE: probably should do some range checking on the argument *)
	val getRCVBUF : ('a, 'b) sock -> int
	      = mkGet "ctlRCVBUF"
	val setRCVBUF : (('a, 'b) sock * int) -> unit
	      = mkSet "ctlRCVBUF"

      (* read-only queries that take just the descriptor *)
	local
	  val typeOfFD  : int -> CI.system_const = sockFn "getTYPE"
	  val errorOfFD : int -> bool = sockFn "getERROR"
	in
	fun getTYPE (SOCK fd) = SOCKTY(typeOfFD fd)
	fun getERROR (SOCK fd) = errorOfFD fd
	end (* local *)

      (* address queries: run the C function on the descriptor and wrap the
       * result in ADDR
       *)
	local
	  fun wrapName (cfn : int -> addr) (SOCK fd) = ADDR(cfn fd)
	in
	val getPeerName	= wrapName (sockFn "getPeerName")
	val getSockName	= wrapName (sockFn "getSockName")
	end

      (* the address family is derived from the socket's own bound name *)
	fun getSockAF sock = familyOfAddr(getSockName sock)
	val getSockType = getTYPE	(** REDUNDANT **)

      (* the remaining operations are stubs that always raise Fail *)
	fun getSockProto _ = raise Fail "getSockProto"

	fun setNBIO (_, _) = raise Fail "Sock.Ctl.setNBIO unimplemented"
	fun getNREAD _ = raise Fail "Sock.Ctl.getNREAD unimplemented"
	fun getATMARK _ = raise Fail "Sock.Ctl.getATMARK unimplemented"

	end (* local *)
      end (* Ctl *)

  (* socket address operations *)
    (* two socket addresses are the same when their underlying values are equal *)
    fun sameAddr (ADDR lhs, ADDR rhs) = (lhs = rhs)
(** NOTE: do we want this function, or functions for each specific socket
 ** type?
 **)
    (* addrToString is not implemented: it still asks for the address family
     * first (so any failure in that lookup surfaces as before) and then
     * always raises Fail.
     *)
    fun addrToString addr = let
	  val _ = familyOfAddr addr
	  in
	    raise Fail "addrToString"
	  end

  (* socket management *)
    local
    (* raw C entry points; all of them work on the bare descriptor *)
      val cAccept  : int -> (int * addr)  = sockFn "accept"
      val cBind    : (int * addr) -> unit = sockFn "bind"
      val cConnect : (int * addr) -> unit = sockFn "connect"
      val cListen  : (int * int) -> unit  = sockFn "listen"
      val cClose   : int -> unit          = sockFn "close"
    in
  (* accept: returns the new socket together with the address reported by the
   * C side.
   *)
    fun accept (SOCK fd) = (case cAccept fd
	   of (newFD, peer) => (SOCK newFD, ADDR peer)
	  (* end case *))
  (* bind and connect unwrap the socket and the address and pass both on *)
    fun bind (SOCK fd, ADDR local) = cBind (fd, local)
    fun connect (SOCK fd, ADDR remote) = cConnect (fd, remote)
(** Should do some range checking on backLog *)
    fun listen (SOCK fd, backLog) = cListen (fd, backLog)
    fun close (SOCK fd) = cClose fd
    end

    datatype shutdown_mode = NO_RECVS | NO_SENDS | NO_RECVS_OR_SENDS
    local
      val shutdown' : (int * int) -> unit = sockFn "shutdown"
    in
  (* shutdown: translate the mode into the integer code expected by the C
   * entry point and invoke it on the descriptor.
   * NOTE(review): 0/1/2 presumably match the "how" argument of shutdown(2) --
   * confirm against the C glue.
   *)
    fun shutdown (SOCK fd, mode) = let
	  val howCode = (case mode
		 of NO_RECVS => 0
		  | NO_SENDS => 1
		  | NO_RECVS_OR_SENDS => 2
		(* end case *))
	  in
	    shutdown' (fd, howCode)
	  end
    end

    (* pollDesc is a stub: it always raises Fail *)
    fun pollDesc (SOCK _) = raise Fail "Sock.pollDesc unimplemented"

  (* Sock I/O option types.  Both are records of booleans whose fields are
   * handed unchanged to the C send/receive entry points by the primed I/O
   * operations (sendVec', recvArr', ...) defined later in this structure.
   * NOTE(review): the fields presumably correspond to the MSG_DONTROUTE,
   * MSG_OOB and MSG_PEEK flags -- confirm against the C glue.
   *)
    type out_flags = {don't_route : bool, oob : bool}
    type in_flags = {peek : bool, oob : bool}

    local
    (* chk (len, buf, i, szOpt) validates a slice request against a buffer of
     * length len and returns (buf, start, count), where count is the number
     * of elements in the slice:
     *   - NONE    selects everything from index i to the end (len - i items);
     *   - SOME sz selects exactly sz items starting at index i.
     * Raises Subscript when i is negative, when i lies past the end, when sz
     * is negative, or when fewer than sz items remain after index i.
     *)
      fun chk (len, buf, i, NONE) =
	    if ((i < 0) orelse (len < i))
	      then raise Subscript
	      else (buf, i, len - i)
	| chk (len, buf, i, SOME sz) =
	    if ((i < 0) orelse (sz < 0) orelse (len-i < sz))
	      then raise Subscript
	    (* the count is sz itself; subtracting the start index here would
	     * shorten the slice and could even make the count negative.
	     *)
	      else (buf, i, sz)
    in
  (* vbuf / abuf: slice validation for vector and array buffers respectively *)
    fun vbuf {buf, i, sz} = chk (W8V.length buf, buf, i, sz)
    fun abuf {buf, i, sz} = chk (W8A.length buf, buf, i, sz)
    end (* local *)

  (* default flags: the values used by the unprimed send/recv operations,
   * which take no flags argument.  All options are off by default.
   *)
    val dfltDon'tRoute = false
    val dfltOOB = false
    val dfltPeek = false

  (* Sock output operations *)
    local
    (* the vector and array variants are served by the same C entry point *)
      val sendV : (int * w8vector * int * int * bool * bool) -> int
	    = sockFn "sendBuf"
      val sendA : (int * w8array * int * int * bool * bool) -> int
	    = sockFn "sendBuf"
    (* hand a validated slice to the C sender; an empty slice is reported as
     * 0 without calling into C at all.
     *)
      fun push sender (fd, (data, start, count), dontRoute, oob) =
	    if (count > 0)
	      then sender (fd, data, start, count, dontRoute, oob)
	      else 0
    in
  (* Each operation validates its buffer slice with vbuf/abuf (which may raise
   * Subscript) and returns the result of the C call, or 0 for an empty slice.
   * The unprimed forms use the default flags; the primed forms take them from
   * the caller.
   *)
    fun sendVec (SOCK fd, buffer) =
	  push sendV (fd, vbuf buffer, dfltDon'tRoute, dfltOOB)
    fun sendArr (SOCK fd, buffer) =
	  push sendA (fd, abuf buffer, dfltDon'tRoute, dfltOOB)
    fun sendVec' (SOCK fd, buffer, {don't_route, oob}) =
	  push sendV (fd, vbuf buffer, don't_route, oob)
    fun sendArr' (SOCK fd, buffer, {don't_route, oob}) =
	  push sendA (fd, abuf buffer, don't_route, oob)
    end (* local *)

    local
    (* the vector and array variants are served by the same C entry point *)
      val sendToV : (int * w8vector * int * int * bool * bool * addr) -> int
	    = sockFn "sendBufTo"
      val sendToA : (int * w8array * int * int * bool * bool * addr) -> int
	    = sockFn "sendBufTo"
    (* hand a validated slice and a destination to the C sender; an empty
     * slice is reported as 0 without calling into C at all.
     *)
      fun pushTo sender (fd, dest, (data, start, count), dontRoute, oob) =
	    if (count > 0)
	      then sender (fd, data, start, count, dontRoute, oob, dest)
	      else 0
    in
  (* Addressed sends.  Each operation validates its buffer slice with
   * vbuf/abuf (which may raise Subscript) and returns the result of the C
   * call, or 0 for an empty slice.  The unprimed forms use the default flags;
   * the primed forms take them from the caller.
   *)
    fun sendVecTo (SOCK fd, ADDR addr, buffer) =
	  pushTo sendToV (fd, addr, vbuf buffer, dfltDon'tRoute, dfltOOB)
    fun sendArrTo (SOCK fd, ADDR addr, buffer) =
	  pushTo sendToA (fd, addr, abuf buffer, dfltDon'tRoute, dfltOOB)
    fun sendVecTo' (SOCK fd, ADDR addr, buffer, {don't_route, oob}) =
	  pushTo sendToV (fd, addr, vbuf buffer, don't_route, oob)
    fun sendArrTo' (SOCK fd, ADDR addr, buffer, {don't_route, oob}) =
	  pushTo sendToA (fd, addr, abuf buffer, don't_route, oob)
    end (* local *)

  (* Sock input operations *)
    local
      val recvV' : (int * int * bool * bool) -> w8vector
	    = sockFn "recv"
    (* recvV (sock, nbytes, peek, oob): receive into a fresh vector.
     *   - a request for 0 bytes yields an empty vector without calling C;
     *   - a negative request raises Size, the same exception recvFromV
     *     raises for a negative size, so that the vector-receiving
     *     operations agree with one another;
     *   - otherwise the C "recv" entry point is invoked on the descriptor.
     *)
      fun recvV (_, 0, _, _) = W8V.fromList[]
	| recvV (SOCK fd, nbytes, peek, oob) = if (nbytes < 0)
	    then raise Size
	    else recvV' (fd, nbytes, peek, oob)
      val recvA : (int * w8array * int * int * bool * bool) -> int
	    = sockFn "recvBuf"
    in
  (* recvVec / recvVec': vector receive with default / caller-supplied flags *)
    fun recvVec (sock, sz) = recvV (sock, sz, dfltPeek, dfltOOB)
  (* recvArr / recvArr': receive into the array slice described by buffer.
   * The slice is validated by abuf (which may raise Subscript); an empty
   * slice yields 0 without calling C, otherwise the result of the C
   * "recvBuf" call is returned.
   *)
    fun recvArr (SOCK fd, buffer) = let
	  val (buf, i, sz) = abuf buffer
	  in
	    if (sz > 0)
	      then recvA(fd, buf, i, sz, dfltPeek, dfltOOB)
	      else 0
	  end
    fun recvVec' (sock, sz, {peek, oob}) = recvV (sock, sz, peek, oob)
    fun recvArr' (SOCK fd, buffer, {peek, oob}) = let
	  val (buf, i, sz) = abuf buffer
	  in
	    if (sz > 0) then recvA(fd, buf, i, sz, peek, oob) else 0
	  end
    end (* local *)

    local
      val recvFromV' : (int * int * bool * bool) -> (w8vector * addr)
	    = sockFn "recvFrom"
    (* vector receive that also reports the sender: a zero-size request is
     * answered with an empty vector and the dummy address without calling C;
     * a negative size raises Size.
     *)
      fun recvFromV (SOCK fd, sz, peek, oob) =
	    if (sz = 0)
	      then (W8V.fromList[], dummyAddr)
	    else if (sz < 0)
	      then raise Size
	      else (case recvFromV' (fd, sz, peek, oob)
		 of (data, from) => (data, ADDR from)
		(* end case *))
      val recvFromA : (int * w8array * int * int * bool * bool) -> (int * addr)
	    = sockFn "recvBufFrom"
    (* array receive that also reports the sender: the target is everything
     * from index i to the end of buf (validated by abuf, which may raise
     * Subscript); with no room left the answer is (0, dummyAddr) and C is
     * not called.
     *)
      fun fillFrom (fd, {buf, i}, peek, oob) = let
	    val (arr, start, room) = abuf{buf=buf, i=i, sz=NONE}
	    in
	      if (room > 0)
		then (case recvFromA(fd, arr, start, room, peek, oob)
		   of (n, from) => (n, ADDR from)
		  (* end case *))
		else (0, dummyAddr)
	    end
    in
  (* the unprimed forms use the default flags; the primed forms take them
   * from the caller.
   *)
    fun recvVecFrom (sock, sz) = recvFromV (sock, sz, dfltPeek, dfltOOB)
    fun recvArrFrom (SOCK fd, {buf, i}) =
	  fillFrom (fd, {buf=buf, i=i}, dfltPeek, dfltOOB)
    fun recvVecFrom' (sock, sz, {peek, oob}) = recvFromV (sock, sz, peek, oob)
    fun recvArrFrom' (SOCK fd, {buf, i}, {peek, oob}) =
	  fillFrom (fd, {buf=buf, i=i}, peek, oob)
    end (* local *)

  end (* Socket *)
