import functools
import ipaddress
import logging
import os
import socket
import struct
from typing import NamedTuple

import ifaddr
import tldextract
import geoip2.database

from shared.ttl_cache import ttl_cache

# Logger for this module, named after the module via __name__.
logger = logging.getLogger(__name__)


class DefaultRoute(NamedTuple):
    """Snapshot of the host's IPv4 default route, as built by get_default_route()."""

    # Next-hop gateway address in dotted-quad form.
    gateway_ip: str
    # Gateway MAC address looked up in the ARP cache, or None when not found.
    gateway_mac: str | None
    # Name of the network interface the default route goes through.
    interface: str
    # The host's own IP address on that interface.
    host_ip: str
    # MAC address of the host interface, or None if it could not be read.
    host_mac: str | None
    # Network containing host_ip, derived from the interface's prefix length.
    subnet: ipaddress.IPv4Network | None


def _read_default_gateway() -> tuple[str, str] | None:
    """Return (interface, gateway_ip) of the preferred IPv4 default route, or None.

    Parses /proc/net/route. Only routes whose destination AND mask are both
    zero (a true 0.0.0.0/0 route, not e.g. a VPN's 0.0.0.0/1 split route) and
    that carry the RTF_GATEWAY flag (0x2) are considered. Among those, the
    route with the lowest metric wins; on a tie the first one listed is kept.
    """
    rtf_gateway = 0x0002
    best: tuple[int, str, str] | None = None  # (metric, iface, gateway_ip)
    with open("/proc/net/route") as fh:
        next(fh, None)  # skip the "Iface Destination Gateway ..." header line
        for line in fh:
            # Columns: Iface Destination Gateway Flags RefCnt Use Metric Mask ...
            fields = line.split()
            if len(fields) < 8:
                continue
            if fields[1] != "00000000" or fields[7] != "00000000":
                continue  # not a 0.0.0.0/0 default route
            try:
                flags = int(fields[3], 16)
                metric = int(fields[6])
                raw_gateway = int(fields[2], 16)
            except ValueError:
                continue  # malformed line
            if not flags & rtf_gateway:
                continue
            if best is None or metric < best[0]:
                # The kernel prints the network-order address as a hex integer in
                # host byte order, so re-pack it in native order ("=L") to recover
                # the address bytes ("<L" was only correct on little-endian hosts).
                gateway_ip = socket.inet_ntoa(struct.pack("=L", raw_gateway))
                best = (metric, fields[0], gateway_ip)
    if best is None:
        return None
    return best[1], best[2]


def _read_gateway_mac(gateway_ip: str) -> str | None:
    """Return the MAC address of gateway_ip from /proc/net/arp, or None if unresolved."""
    with open("/proc/net/arp") as fh:
        next(fh, None)  # skip the "IP address HW type Flags ..." header line
        for line in fh:
            # Columns: IP address, HW type, Flags, HW address, Mask, Device
            fields = line.split()
            if len(fields) < 4 or fields[0] != gateway_ip:
                continue
            if fields[3] == "00:00:00:00:00:00":
                continue  # all-zero MAC: presumably an unresolved entry
            return fields[3]
    return None


def _read_interface_mac(iface: str) -> str | None:
    """Return the MAC address of iface read from sysfs, or None if it cannot be read."""
    try:
        with open(f"/sys/class/net/{iface}/address") as fh:
            return fh.readline().strip()
    except OSError:
        logger.warning("Could not read MAC address for interface %s", iface)
        return None


def _find_host_ipv4(
    iface: str, gateway_ip: str
) -> tuple[str, ipaddress.IPv4Network] | None:
    """Return (host_ip, subnet) for iface's IPv4 address, or None if it has none.

    Prefers the address whose subnet contains the gateway; otherwise falls back
    to the first IPv4 address on the interface. IPv6 entries are skipped
    (ifaddr reports those as tuples rather than strings).
    """
    gateway = ipaddress.IPv4Address(gateway_ip)
    candidates: list[tuple[str, ipaddress.IPv4Network]] = []
    for adapter in ifaddr.get_adapters():
        if adapter.name != iface:
            continue
        for ip in adapter.ips:
            if not isinstance(ip.ip, str):
                continue  # IPv6 address (ifaddr uses a tuple for these)
            subnet = ipaddress.IPv4Network(
                f"{ip.ip}/{ip.network_prefix}", strict=False
            )
            candidates.append((ip.ip, subnet))
    for host_ip, subnet in candidates:
        if gateway in subnet:
            return host_ip, subnet
    return candidates[0] if candidates else None


@ttl_cache(maxsize=2, ttl=60)
def get_default_route() -> DefaultRoute | None:
    """Returns a DefaultRoute object containing the default gateway IP, MAC, interface, host IP, host MAC, and subnet.

    Returns None (after logging an error) when no default gateway exists or
    the gateway's interface has no IPv4 address.
    """

    # Note that this function uses Linux-specific APIs (/proc and /sys).
    # The third-party libraries to read these values were either unmaintained, overly complex, or had other issues, so I decided to parse the files directly.
    # Other mechanisms would need to be used for other operating systems.
    # ref: https://stackoverflow.com/a/6556951

    route = _read_default_gateway()
    if route is None:
        logger.error("No default gateway found in /proc/net/route")
        return None
    iface, gateway_ip = route

    gateway_mac = _read_gateway_mac(gateway_ip)
    host_mac = _read_interface_mac(iface)

    host = _find_host_ipv4(iface, gateway_ip)
    if host is None:
        logger.error("No subnet found for interface %s", iface)
        return None
    host_ip, subnet = host

    return DefaultRoute(gateway_ip, gateway_mac, iface, host_ip, host_mac, subnet)


@functools.lru_cache(maxsize=8192)
def is_private_ip_addr(
    ip_addr: ipaddress.IPv4Address | ipaddress.IPv6Address | str,
) -> bool:
    """Returns True if the given IP address is a private local IP address."""

    if isinstance(ip_addr, str):
        ip_addr = ipaddress.ip_address(ip_addr)
    return ip_addr.is_private


@functools.lru_cache(maxsize=8192)
def is_ipv4_addr(ip_string: str) -> bool:
    """Checks if ip_string is a valid IPv4 address in strict dotted-quad form.

    Returns True only for four decimal octets such as "192.168.1.1".
    Shorthand forms ("1", "127.1"), strings with surrounding whitespace or
    trailing characters, and non-string inputs all return False.
    """
    # socket.inet_aton is deliberately not used: it accepts shorthand forms
    # like "1" or "127.1" (and, with some C libraries, trailing garbage), and
    # raises ValueError on embedded NULs instead of returning False.
    if not isinstance(ip_string, str):
        return False
    try:
        ipaddress.IPv4Address(ip_string)
    except ValueError:
        return False
    return True


# Bundled data lives in "../data" relative to this file's real (symlink-resolved)
# location. The GeoLite2 country database is opened once, at import time.
# NOTE(review): the dated snapshot directory name is hard-coded; updating the
# database requires editing this path.
_module_dir = os.path.dirname(os.path.realpath(__file__))
data_directory = os.path.join(_module_dir, "..", "data")
_country_db_path = os.path.join(
    data_directory, "GeoLite2-Country_20191224", "GeoLite2-Country.mmdb"
)
ip_country_parser = geoip2.database.Reader(_country_db_path)


@functools.lru_cache(maxsize=8192)
def get_country_from_ip_addr(
    remote_ip_addr: ipaddress.IPv4Address | ipaddress.IPv6Address | None,
) -> str | None:
    """Returns country for IP."""

    if not remote_ip_addr:
        return None

    if is_private_ip_addr(remote_ip_addr):
        return "(local network)"

    try:
        country = ip_country_parser.country(remote_ip_addr).country.name
        if country:
            return country
    except Exception:
        pass

    return None


@functools.lru_cache(maxsize=8192)
def get_reg_domain(full_domain: str) -> str:
    """Reduce ``full_domain`` to its registered domain via tldextract.

    - Falsy input yields "".
    - The sentinel "(local network)" is passed through unchanged.
    - Any "?" characters are removed before extraction. If the input contained
      one, a single "?" is appended to the result (presumably a marker carried
      over from the caller; verify its meaning there).
    - If no registered domain can be extracted, the input is returned as-is.
    """
    if not full_domain:
        return ""
    if full_domain == "(local network)":
        return full_domain

    had_question_mark = "?" in full_domain
    extracted = tldextract.extract(full_domain.replace("?", ""))
    registered = extracted.registered_domain
    if not registered:
        return full_domain
    return registered + "?" if had_question_mark else registered


@ttl_cache(maxsize=2, ttl=5)
def get_arp_table() -> dict[str, str]:
    """Returns a dictionary of IP addresses to MAC addresses from /proc/net/arp.

    The column header line and malformed (short or blank) lines are skipped.
    Entries with an all-zero hardware address are also skipped, presumably
    because they are not yet resolved. When an IP appears more than once, the
    last entry wins.
    """

    arp_table: dict[str, str] = {}
    with open("/proc/net/arp") as fh:
        # Skip the "IP address HW type Flags HW address Mask Device" header.
        # Parsing it as data would have inserted a bogus {"IP": "type"} entry.
        next(fh, None)
        for line in fh:
            fields = line.split()
            if len(fields) < 4:
                continue
            ip_addr, mac = fields[0], fields[3]
            if mac == "00:00:00:00:00:00":
                continue
            arp_table[ip_addr] = mac
    return arp_table


@functools.lru_cache(maxsize=8192)
def get_reverse_dns(ip_addr: str) -> str:
    """Resolve ``ip_addr`` to a fully qualified domain name via socket.getfqdn.

    Results are memoised (up to 8192 entries) and never expire, so repeated
    lookups of the same address skip the resolver entirely.
    """
    fqdn = socket.getfqdn(ip_addr)
    return fqdn
