import os
import asyncio
import fnmatch
from ipaddress import ip_address
import json
import logging
from multiprocessing import Queue
from typing import Any

from zeroconf import IPVersion, ServiceStateChange, Zeroconf
from zeroconf.asyncio import (
    AsyncServiceBrowser,
    AsyncServiceInfo,
    AsyncZeroconf,
    AsyncZeroconfServiceTypes,
)

from shared.db_helpers import save_mdns_record, clear_pending_queue
from shared.networking_helpers import get_default_route

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Location of the JSON lookup table read by async_save_device_info():
# <directory of this file>/../data/mdns.json. The table is indexed by mDNS
# service type and yields a list of candidate integration descriptors.
_module_dir = os.path.dirname(os.path.realpath(__file__))
homeassistant_data_file = os.path.join(_module_dir, "..", "data", "mdns.json")

# Queue that discovered records are pushed onto. It stays None until scan()
# assigns the queue it was given; the zeroconf callback path reads it from
# here because the handler signature has no room for extra arguments.
db_queue_global: "Queue[tuple[str, dict[str, Any]]] | None" = None


# Strong references to in-flight lookup tasks. The event loop only keeps weak
# references to tasks, so a task nobody else references may be garbage
# collected before it finishes.
_pending_lookup_tasks: "set[asyncio.Task[None]]" = set()


def _on_lookup_done(task: "asyncio.Task[None]") -> None:
    """Drop a finished lookup task and log the exception it raised, if any."""
    _pending_lookup_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        # Without this, a failure inside the lookup would go unreported.
        logger.error("[mDNS] Failed to save device info", exc_info=exc)


def async_on_service_state_change(
    zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange
) -> None:
    """Service-browser handler: schedule a lookup for every newly added service.

    State changes other than ``ServiceStateChange.Added`` are ignored. The
    lookup itself runs as a background task on the current event loop so this
    handler returns immediately.

    Args:
        zeroconf: Zeroconf instance the change was observed on.
        service_type: mDNS service type of the changed service.
        name: Full mDNS name of the changed service.
        state_change: Kind of change reported by the browser.
    """
    if state_change is not ServiceStateChange.Added:
        return
    task = asyncio.ensure_future(async_save_device_info(zeroconf, service_type, name))
    # Keep the task alive until it completes, and surface any failure.
    _pending_lookup_tasks.add(task)
    task.add_done_callback(_on_lookup_done)


def _any_property_matches(
    properties: "dict[bytes, Any]", match_properties: "dict[str, str]"
) -> bool:
    """Return True if at least one TXT property matches its glob pattern.

    Property values are lower-cased before being compared with
    ``fnmatch.fnmatch``. Values that are missing or not ``bytes`` are skipped.
    """
    for match_key, match_search in match_properties.items():
        raw_value = properties.get(str.encode(match_key))
        if not isinstance(raw_value, bytes):
            # Key absent, or present without a bytes value: nothing to compare.
            continue
        # TXT values are arbitrary bytes; never let a bad byte abort the lookup.
        prop_val = raw_value.decode("utf-8", errors="replace").lower()
        if fnmatch.fnmatch(prop_val, match_search):
            return True
    return False


async def async_save_device_info(
    zeroconf: Zeroconf, service_type: str, mdns_name: str
) -> None:
    """Resolve one mDNS service and queue a record for each matching integration.

    The service is resolved with a 3000 ms request timeout. Its addresses are
    split into IPv4 and IPv6; when several IPv4 addresses are present the one
    inside the default route's subnet is preferred. The integration table in
    ``homeassistant_data_file`` is then consulted for ``service_type`` and
    ``save_mdns_record`` is called for matching entries.

    Matching rules, per integration entry:
      * no ``match_device_name`` and no ``match_properties``: always matches;
      * ``match_device_name``: glob-matched against the lower-cased mDNS name;
      * ``match_properties``: matches if any listed property value matches.
    The first entry matched by either of the first two rules ends the search.
    An entry matched through its properties is saved once and the search
    continues with the remaining entries.

    Args:
        zeroconf: Zeroconf instance used to resolve the service.
        service_type: mDNS service type, used as the key into the table.
        mdns_name: Full mDNS name of the service instance.
    """
    if service_type.startswith("_services._dns-sd._udp.local."):
        # these services cause a BadTypeInNameException if we try to read more info
        return
    info = AsyncServiceInfo(service_type, mdns_name)
    await info.async_request(zeroconf, 3000)

    # Split the resolved addresses by IP version.
    ip4s = []
    ip6s = []
    for ip_str in info.parsed_addresses():
        ip = ip_address(ip_str)
        if ip.version == 4:
            ip4s.append(ip)
        elif ip.version == 6:
            ip6s.append(ip)

    ip4 = ip4s[0] if ip4s else None
    if len(ip4s) > 1:
        # if we have multiple ipv4 addresses, prefer the one on the same subnet
        # e.g. if the device has a tailscale and local ip, we use the local one
        subnet = get_default_route().subnet
        if subnet is not None:
            ip4 = next((ip for ip in ip4s if ip in subnet), ip4)
        logger.warning(f"Multiple IPv4 addresses for {mdns_name}: {ip4s}, using {ip4}")
    ip4_str = str(ip4) if ip4 else None

    if not (info.server or ip4):
        logger.warning(f"No server or IP address for {mdns_name}")
        return

    # Check the queue before touching the disk: without it nothing can be saved.
    if db_queue_global is None:
        logger.warning(f"DB_QUEUE is None, not saving {mdns_name} ({service_type})")
        return

    with open(homeassistant_data_file, "r", encoding="utf-8") as f:
        homeassistant_data = json.load(f)
    integrations_for_service = homeassistant_data.get(service_type, [])

    def _save(integration: "dict[str, Any]") -> None:
        """Queue one record for this service under the given integration."""
        save_mdns_record(
            db_queue_global,
            ip4_str,
            [str(ip6) for ip6 in ip6s],
            mdns_name,
            service_type,
            info.server,
            info.properties,
            integration.get("integration_domain"),
            integration,
        )

    # go through the list of possible integrations, and see if they match
    for integration in integrations_for_service:
        match_name = integration.get("match_device_name")
        match_properties = integration.get("match_properties")

        # if they have neither a match_name or match_properties, then they apply
        if (not match_name and not match_properties) or (
            match_name and fnmatch.fnmatch(mdns_name.lower(), match_name)
        ):
            _save(integration)
            return  # return since we only expect one match per service name

        # if they have match_properties, we have to check the properties of the mdns record
        # NOTE(review): a single matching property is enough here; confirm
        # whether every listed property should be required to match instead.
        if match_properties and _any_property_matches(
            info.properties, match_properties
        ):
            # Saved once per integration, however many properties matched.
            _save(integration)


async def scan(timeout: int, db_queue: "Queue[tuple[str, dict[str, Any]]]") -> None:
    """Browse all advertised mDNS service types for ``timeout`` seconds.

    ``db_queue`` is stored in the module-level ``db_queue_global`` so that the
    browser callback can reach it. The scan binds to the host IP of the
    default route and returns early (with a warning) when there is none.
    Once the browse window has elapsed, the browser and the zeroconf instance
    are shut down, and after a further 5 second pause ``clear_pending_queue``
    is called on the queue.

    Args:
        timeout: How long to keep browsing, in seconds.
        db_queue: Queue that discovered records are pushed onto.
    """
    logger.info("Starting mDNS scanner")
    global db_queue_global
    db_queue_global = db_queue
    primary_ip = get_default_route().host_ip
    if not primary_ip:
        logger.warning("No primary interface found")
        return
    logger.info(f"Primary interface for mDNS is {primary_ip}")

    aiozc = AsyncZeroconf(interfaces=[primary_ip], ip_version=IPVersion.All)
    aiobrowser = None
    try:
        # scan all services
        services = list(
            await AsyncZeroconfServiceTypes.async_find(
                aiozc=aiozc,
                interfaces=[primary_ip],
                ip_version=IPVersion.All,
            )
        )

        logger.debug(f"Browsing {len(services)} services")
        aiobrowser = AsyncServiceBrowser(
            aiozc.zeroconf, services, handlers=[async_on_service_state_change]
        )

        # let the scanner run for `timeout` seconds, then terminate it
        await asyncio.sleep(timeout)
    finally:
        # Always release the sockets, even if discovery failed or we were cancelled.
        logger.info("Terminating mDNS scanner")
        try:
            if aiobrowser is not None:
                await aiobrowser.async_cancel()
        finally:
            await aiozc.async_close()

    # clear pending queue of ipv6 only records
    await asyncio.sleep(5)
    if db_queue is not None:
        clear_pending_queue(db_queue)
