"""
Captures and analyzes packets from the network.

"""

from multiprocessing import Queue
from typing import Any

import scapy.all as sc

import shared.config as config
from shared.networking_helpers import get_default_route
import shared.system_stats as system_stats


def start_packet_collector(packet_queue: "Queue[Any]", data_dir: str) -> None:
    """
    Sniffs packets on the default-route interface and pushes them to a queue.

    Before sniffing starts, the current ``Ip.ForwDatagrams`` counter is read
    from the system netstat data and offered to the config store under
    ``packets_forwarded_at_db_init``.

    This function loops forever and does not return under normal operation.

    Args:
        packet_queue: Queue that receives every sniffed packet.
        data_dir: Directory passed to ``config.initialize``.

    """

    def add_packet_to_queue(pkt: Any) -> None:
        """
        Adds a packet to the packet queue.

        """
        packet_queue.put(pkt)

    default_route = get_default_route()
    iface = default_route.interface
    host_ip_addr = default_route.host_ip

    # Save the number of packets forwarded at the time of db initialization.
    netstat = system_stats.netstat()
    packets_forwarded = int(netstat["Ip"]["ForwDatagrams"])
    config.initialize(data_dir)
    # Use get (with a default) so the value is only set if it doesn't exist yet.
    config.get("packets_forwarded_at_db_init", packets_forwarded)

    sc.load_layer("tls")

    # Avoid capturing packets to/from the host itself, except ARP, which we
    # need for discovery. The filter never changes, so build it once.
    capture_filter = f"(not arp and host not {host_ip_addr}) or arp"

    while True:
        # Continuously sniff packets for 30 second intervals (as sniff might
        # crash). Packets are delivered through the prn callback, so tell
        # scapy not to also accumulate them in its (discarded) result list.
        sc.sniff(
            prn=add_packet_to_queue,
            iface=iface,
            filter=capture_filter,
            timeout=30,
            store=False,
        )
