(*

	FoxNet: The Fox Project's Communication Protocol Implementation Effort
	Edoardo Biagioni (Edoardo.Biagioni@cs.cmu.edu)
	Brian Milnes (Brian.Milnes@cs.cmu.edu)
	Ken Cline (Kenneth.Cline@cs.cmu.edu)
        Nick Haines (Nick.Haines@cs.cmu.edu)
	Fox Project
	School of Computer Science
	Carnegie Mellon University
	Pittsburgh, Pa 15139-3891

		i.	Abstract

	A cache of DNS resource records, indexed by domain name, used to
	answer DNS questions from previously received messages.


		ii.	Table of Contents

	i.	Abstract
	ii.	Table of Contents
	1.	signature Dns_Cache
	2.	functor Dns_Cache



	1.	signature Dns_Cache

*)

  (* Interface to the DNS cache: messages are fed in with add, and
     questions are answered from what has been stored with lookup. *)
  signature DNS_CACHE =
    sig
      (* the DNS message type handled by the cache *)
      type message
      (* the DNS question type used as the lookup request *)
      type question
      (* record the contents of a message in the cache *)
      val add: message -> unit
      (* try to answer a question from the cache; NONE when the cache
	 holds no matching record *)
      val lookup: question -> message option
    end

(*
	2.	functor Dns_Cache

*)

functor Dns_Cache (structure B: FOX_BASIS
		   structure DnsM: DNS_EXTERN
		   val debug_level: int ref option): DNS_CACHE =
struct
  (* Tracing support for this module: an instance of the Trace functor
     with module name "dnscache.fun".  The makestring supplied here
     yields NONE for every argument. *)
  structure Trace = Trace (structure V = B.V
			   val debug_level = debug_level
			   val module_name = "dnscache.fun"
			   val makestring = fn _ => NONE)

  (* Local abbreviations for the DnsM types used by the cache. *)
  type message = DnsM.message
  type question = DnsM.question
  type domain_name = DnsM.Domain_Name.T
  type ip_number = DnsM.internet_address

  (* maximum size of cache - for simplicity's sake, we clear the entire
     cache and start fresh when the store reaches this size.
     The test is made by add, before it inserts the records of a new
     message, and it only clears when B.Store.size is strictly greater
     than max_size. *)
  val max_size = 128

  (* Absolute expiration time of a resource record: the current time plus
     the record's ttl field taken as a number of seconds. *)
  fun ttl_to_time (DnsM.RR {ttl = seconds, ...}) =
        let
          val now = B.V.Time.now ()
          val lifetime = B.V.Time.fromSeconds (Word32.toInt seconds)
        in
          B.V.Time.+ (now, lifetime)
        end
  (* The later of two times; the second argument is returned unless the
     first is strictly greater. *)
  fun time_max (a, b) =
        case B.V.Time.> (a, b) of
          true => a
        | false => b
  (* merge (rr, entries) adds rr to a list of (record, expiration) pairs.
     If a record equal to rr (according to DnsM.equal_rr) is already in
     the list, that entry's expiration becomes the later of its current
     expiration and the expiration implied by the ttl of the incoming rr;
     every other entry, before and after the match, is kept unchanged.
     If no equal record is present, rr is appended at the end of the list
     with a freshly computed expiration. *)
  fun merge (rr, []) = [(rr, ttl_to_time rr)]
    | merge (rr, (first_rr, expiration)::rest) =
        if DnsM.equal_rr (rr, first_rr)
          (* keep the tail: only the matching entry is refreshed *)
          then (first_rr, time_max (expiration, ttl_to_time rr)) :: rest
        else (first_rr, expiration) :: merge (rr, rest)
      
  local
    (* purge (entries, now) keeps only the (record, expiration) pairs whose
       expiration is not strictly earlier than now, and traces every
       record it discards. *)
    fun purge (entries, now) =
          case entries of
            [] => []
          | (entry as (rr, expiration)) :: others =>
              if B.V.Time.< (expiration, now)
                then (Trace.trace_print
                        (fn () => "Outdating:\n" ^ DnsM.makestring_rr rr);
                      purge (others, now))
              else entry :: purge (others, now)

    (* Purge the list found under key and write the outcome back to the
       store: the key is removed when nothing survives, otherwise it is
       re-added with the surviving entries.  Returns the resulting store
       paired with the surviving list. *)
    fun outdate ((_, key), (store, entries)) =
          let
            val live = purge (entries, B.V.Time.now ())
          in
            case live of
              [] => (B.Store.remove (store, key), [])
            | _ => (B.Store.add (store, key, live), live)
          end
  in
    (* Like B.Store.look, but expired entries are dropped from the result
       and from the returned store before the caller sees them. *)
    fun look args =
          (case B.Store.look args of
             NONE => NONE
           | SOME found => SOME (outdate (args, found)))
  end

  (* add_rr (rr, cache) returns a store in which rr is recorded under its
     name.  A record whose ttl is zero is not stored: the store is
     returned as given.  When the name already has entries, rr is merged
     into them; otherwise a new single-entry list is created. *)
  fun add_rr (DnsM.RR {ttl = 0w0:word32, ...}, cache) = cache
    | add_rr (rr as DnsM.RR {name, ...}, cache) =
        let
          val (store, entries) =
                case B.Store.look (cache, name) of
                  SOME (c, rrs) => (c, merge (rr, rrs))
                | NONE => (cache, [(rr, ttl_to_time rr)])
        in
          B.Store.add (store, name, entries)
        end
  (* add_rrs (rrs, cache) adds every record of rrs to the store.  The tail
     of the list is processed before its head, so the last record is
     added first. *)
  fun add_rrs (rrs, cache) =
        case rrs of
          [] => cache
        | rr :: rest => add_rr (rr, add_rrs (rest, cache))
    
  (* Build a new store that hashes and compares keys with the
     DnsM.Domain_Name operations. *)
  fun new_cache () =
        B.Store.new (DnsM.Domain_Name.hash, DnsM.Domain_Name.equal)
  (* The cache itself: a mutable reference to a store mapping each domain
     name to a list of (resource record, expiration time) pairs. *)
  val cache : (domain_name, (DnsM.rr * B.V.Time.time) list) B.Store.T ref
            = ref (new_cache ())

  (* Print the current cache contents: a heading, then for each key its
     domain name in brackets followed by one indented line per stored
     record.  Expiration times are not shown. *)
  fun dump () =
    let
      fun show_rr (rr, _) = "    " ^ DnsM.makestring_rr rr ^ "\n"
      fun show_rrs entries =
            case entries of
              [] => ""
            | entry :: others => show_rr entry ^ show_rrs others
      fun show_item (key, rrs) =
            "  [" ^ DnsM.Domain_Name.makestring key ^ "]" ^ "\n" ^
            show_rrs rrs
      val listing = B.Store.makestring (!cache, show_item, "")
    in
      print ("Current cache contents:\n" ^ listing)
    end

  (* Record the contents of a message in the cache.  The message is traced
     first; if the store currently holds more than max_size entries it is
     replaced by a new empty cache.  The resource records are then added
     section by section: additional first, then authority, then answer.
     A message that failed to parse is ignored. *)
  fun add (m as DnsM.Message {answer, authority, additional, ...}) =
        let
          val _ = Trace.trace_print
                    (fn () => "adding message:\n" ^ DnsM.makestring m)
          val _ = if B.Store.size (!cache) > max_size
                    then cache := new_cache ()
                  else ()
          val with_additional = add_rrs (additional, !cache)
          val with_authority = add_rrs (authority, with_additional)
        in
          cache := add_rrs (answer, with_authority)
        end
    | add (DnsM.Parse_Error _) = ()

  (* Answer a question from the cache.  The entries stored under the
     question's name are first purged of expired records (and the purged
     store is written back to the cache), then filtered by query type;
     each record returned carries the ttl remaining until its expiration.

     Allow recursion in case a canonical name record is found, in
     which case the original name is replaced with the canonical name.
     Only one level of recursion is allowed: depth is 0 for the original
     question, and the lookup of a canonical name is made with depth 1.

     Returns SOME message whose answer section holds the matching records,
     or NONE when the cache has no matching record. *)
  fun recursive_lookup (q as DnsM.Question {name, rr_qtype, rr_class}, depth) =
    let
      (* the query type that selects records of the given record type *)
      fun type_to_qtype (DnsM.A _)     = DnsM.A_Q
        | type_to_qtype (DnsM.NS _)    = DnsM.NS_Q
        | type_to_qtype (DnsM.MD _)    = DnsM.MD_Q
        | type_to_qtype (DnsM.MF _ )   = DnsM.MF_Q
        | type_to_qtype (DnsM.CNAME _) = DnsM.CNAME_Q
        | type_to_qtype (DnsM.SOA)     = DnsM.SOA_Q
        | type_to_qtype (DnsM.MB _)    = DnsM.MB_Q
        | type_to_qtype (DnsM.MG _)    = DnsM.MG_Q
        | type_to_qtype (DnsM.MR _)    = DnsM.MR_Q
        | type_to_qtype (DnsM.NULL)    = DnsM.NULL_Q
        | type_to_qtype (DnsM.WKS)     = DnsM.WKS_Q
        | type_to_qtype (DnsM.PTR _)   = DnsM.PTR_Q
        | type_to_qtype (DnsM.HINFO _) = DnsM.HINFO_Q
        | type_to_qtype (DnsM.MINFO)   = DnsM.MINFO_Q
        | type_to_qtype (DnsM.MX _)    = DnsM.MX_Q
        | type_to_qtype (DnsM.TXT _)   = DnsM.TXT_Q

      (* Rebuild a record with its ttl set to the number of seconds left
         before expiration.  The clock may have advanced past expiration
         since the entry was purged, so the remaining time is clamped at
         zero rather than converting a negative interval to a word. *)
      fun update_ttl (DnsM.RR {name, rr_type, rr_class, ...}, expiration) =
            let
              val now = B.V.Time.now ()
              val ttl =
                    if B.V.Time.< (expiration, now)
                      then Word32.fromInt 0
                    else Word32.fromInt
                           (B.V.Time.toSeconds (B.V.Time.- (expiration, now)))
            in
              DnsM.RR {name=name, rr_type=rr_type, rr_class=rr_class, ttl=ttl}
            end

      (* Select the cached records that answer the question.  A CNAME
         record met at depth 0 redirects the whole question to the
         canonical name: if that lookup succeeds its answer is the result
         and the remaining entries are not examined.  At any other depth
         CNAME records are skipped.
         NOTE(review): because this clause precedes the type comparison,
         a question of type CNAME_Q never matches a cached CNAME record;
         confirm whether that is intended. *)
      fun matching_rrs [] = []
        | matching_rrs ((DnsM.RR {rr_type = DnsM.CNAME cname, ...}, _)
                        :: tail) =
            if depth=0 then
              case (recursive_lookup (DnsM.Question
                                      {name=cname, rr_qtype=rr_qtype,
                                       rr_class=rr_class}, 1)) of
                SOME (DnsM.Message {answer, ...}) => answer
              | SOME (DnsM.Parse_Error _) => matching_rrs tail
              | NONE => matching_rrs tail
            else matching_rrs tail
        | matching_rrs ((h as DnsM.RR {rr_type, ...}, expiration)::tail) =
            if (type_to_qtype rr_type = rr_qtype)
              then update_ttl (h, expiration) :: matching_rrs tail
            else matching_rrs tail

      (* look, unlike B.Store.look, drops expired records before they can
         be returned; the store it yields is saved back into the cache. *)
      val rrs = (case look (!cache, name) of
                   NONE => []
                 | SOME (c, l) => (cache := c; matching_rrs l))
      val header = DnsM.Header
        {query = false, opcode = DnsM.Query, aa = false,
         tc = false, rd = false, ra = false,
         rcode = DnsM.No_Error}
    in
      case rrs of
        [] => NONE
      | _ =>
          let
            val m = DnsM.Message
                      {header= header, question= [q],
                       answer= rrs, authority= [], additional= []}
          in
            Trace.trace_print
              (fn () => "Cache hit, found:\n" ^ DnsM.makestring m);
            SOME m
          end
    end

  fun lookup arg = recursive_lookup (arg, 0)

end
