import datetime
from ipaddress import ip_address
import logging
import re
from typing import Any

from model.model import db, Device, Flow, MDNS, PendingDevice
import model.friendly_organizer as friendly_organizer
import shared.config as config
from shared.networking_helpers import (
    get_default_route,
)

logger = logging.getLogger(__name__)


def record_device(
    mac_addr: str | None = None,
    ip_addr: str | None = None,
    dhcp_hostname: str | None = None,
) -> None:
    """
    Create or refresh the Device row for a host seen on the local network.

    The observation is ignored unless all of these hold:
      * both ``mac_addr`` and ``ip_addr`` are provided,
      * ``ip_addr`` parses as an IP address inside the default route's subnet,
      * ``mac_addr`` is neither this host's MAC nor the gateway's MAC.
    An unparsable ``ip_addr`` is logged at debug level; the other refusals
    are silent.

    A matching device is looked up by MAC first and then by IP. If none
    exists, a new one is created. The stored IP is then set to ``ip_addr``.
    ``dhcp_hostname`` is only filled in when the device has none yet.
    ``last_seen`` is set to the current local time.

    Args:
        mac_addr: MAC address of the observed host.
        ip_addr: IP address of the observed host, as a string.
        dhcp_hostname: hostname learned via DHCP, if any.
    """
    if not ip_addr or not mac_addr:
        return
    try:
        parsed_ip = ip_address(ip_addr)
    except ValueError:
        # A malformed address (e.g. from a sniffed packet) must not crash the caller.
        logger.debug(
            f"attempting to record device with invalid IP {ip_addr!r}. refusing"
        )
        return

    with db.atomic():
        default_route = get_default_route()
        if parsed_ip not in default_route.subnet:
            return
        # Never record this host or the gateway as a device.
        if mac_addr in (default_route.host_mac, default_route.gateway_mac):
            return

        # Prefer a MAC match; fall back to an IP match.
        device = Device.get_or_none(mac_addr=mac_addr)
        if device is None:
            device = Device.get_or_none(ip_addr=ip_addr)
        # NOTE(review): on an IP-only match the stored mac_addr is left as-is,
        # even if it differs from mac_addr — confirm this is intended.

        if device is None:
            # Never seen this device before, so create one.
            logger.debug(f"creating new device {mac_addr} {ip_addr}")
            device = Device.create(
                mac_addr=mac_addr, ip_addr=ip_addr, dhcp_hostname=dhcp_hostname
            )

        if device.ip_addr != ip_addr:
            device.ip_addr = ip_addr

        # Only fill in a missing hostname. Hostnames retrieved from Home
        # Assistant are preferred over those retrieved by packet sniffing.
        if not device.dhcp_hostname and dhcp_hostname:
            device.dhcp_hostname = dhcp_hostname

        device.last_seen = datetime.datetime.now()
        device.save()


# def record_host(
#     ip_addr: IPv4Address | IPv6Address | None = None,
#     hostname: str | None = None,
#     source: str | None = None,
# ) -> Any:
#     if ip_addr and is_private_ip_addr(ip_addr):
#         return

#     with db.atomic():
#         # see if we have record for this IP or hostname
#         record = Host.get_or_none(ip_addr=ip_addr)
#         if not record:
#             record = Host.get_or_none(hostname=hostname)

#         # otherwise make a new record
#         if not record:
#             record = Host.create(ip_addr=ip_addr, ip_src=source, hostname=hostname)

#         # add missing fields
#         if ip_addr:
#             if not record.ip_addr:
#                 record.ip_src = source
#         # forward DNS lookup if hostname is provided but IP is not
#         elif hostname and not record.ip_addr:
#             try:
#                 # Get IP from hostname
#                 resolved_ip = ip_address(socket.gethostbyname(hostname))
#                 if not is_private_ip_addr(resolved_ip):
#                     ip_addr = resolved_ip
#                     record.ip_src = "forward-dns"
#             except socket.gaierror:
#                 pass

#         # if hostname is provided, it came from dns or a high-quality source
#         # so we blindly overwrite the existing value
#         if hostname:
#             record.hostname_src = source
#         # otherwise we try a reverse dns lookup
#         # but we check that fields are undefined before writing
#         elif ip_addr and not record.hostname:
#             try:
#                 hostname = socket.getnameinfo((str(ip_addr), 0), 0)[0]
#                 record.hostname_src = "reverse-dns"
#             except socket.gaierror:
#                 pass

#         reg_domain = get_reg_domain(hostname)
#         if not record.ip_addr:
#             record.ip_addr = ip_addr
#         if not record.hostname:
#             record.hostname = hostname
#         if not record.country:
#             record.country = get_country_from_ip_addr(ip_addr)
#         if not record.reg_domain:
#             record.reg_domain = get_reg_domain(hostname)
#         if not record.owner:
#             record.owner = get_tracker_company(reg_domain)
#         record.save()

#     return record


def _resolve_integration_field(
    integration_data: dict[str, Any],
    field: str,
    properties: dict[bytes, bytes | None],
) -> Any:
    """Resolve one integration-mapped field from a static value or an mDNS property."""
    # ``integration_data[field]`` may hold ``static_value`` (used verbatim) or
    # ``properties_key`` (looked up in the mDNS properties). static_value wins.
    spec = integration_data.get(field)
    if not spec:
        return None
    if "static_value" in spec:
        return spec["static_value"]
    if "properties_key" in spec:
        # Property values are returned as stored (bytes or None).
        # NOTE(review): consider decoding to str before persisting — confirm
        # what the MDNS / PendingDevice fields expect.
        return properties.get(spec["properties_key"].encode())
    return None


def save_mdns_record(
    ip4: str | None,
    ip6s: list[str],
    mdns_name: str | None,
    service: str,
    server: str | None,
    properties: dict[bytes, bytes | None],
    integration: str,
    integration_data: dict[str, Any],
) -> None:
    """
    Store an mDNS service record, attaching it to a known device when possible.

    Manufacturer, model and friendly name are resolved from ``integration_data``
    (either a static value or a key into ``properties``).

    The device is matched by mDNS hostname (``server``) first, then by IPv4.
    On a match, the device's hostname and IPv4 are updated and the new IPv6
    addresses are merged in. ``last_seen`` is refreshed, and an MDNS row for
    (device, service) is created if one does not already exist. Records that
    match on neither hostname nor IPv4 (e.g. IPv6-only) are stored in the
    PendingDevice table instead.

    Args:
        ip4: IPv4 address advertised by the record, if any.
        ip6s: IPv6 addresses advertised by the record.
        mdns_name: full mDNS service instance name.
        service: service type.
        server: mDNS hostname of the advertising host.
        properties: raw mDNS properties.
        integration: name of the integration the record is associated with.
        integration_data: field mapping used to derive manufacturer/model/name.
    """
    device_manufacturer = _resolve_integration_field(
        integration_data, "device_manufacturer", properties
    )
    device_model = _resolve_integration_field(
        integration_data, "device_model", properties
    )
    friendly_name = _resolve_integration_field(
        integration_data, "friendly_name", properties
    )

    with db.atomic():
        device: Device | None = None
        if server:
            device = Device.get_or_none(mdns_hostname=server)
        if device is None and ip4:
            device = Device.get_or_none(ip_addr=ip4)
        # IPv6 addresses are deliberately not used for matching here. Records
        # not matched on hostname or IPv4 go to the pending queue instead.

        if device is None:
            # Possibly an IP without a known MAC yet, or only IPv6 addresses.
            # NOTE(review): device_model is computed but not queued here —
            # confirm whether PendingDevice has a device_model field.
            PendingDevice.get_or_create(
                ip6s=ip6s,
                service=service,
                hostname=server,
                full_name=mdns_name,
                properties=properties,
                integration=integration,
                device_manufacturer=device_manufacturer,
                friendly_name=friendly_name,
            )
            return

        if server and device.mdns_hostname != server:
            device.mdns_hostname = server
        if ip4 and device.ip_addr != ip4:
            device.ip_addr = ip4
        if ip6s:
            # Build a fresh, de-duplicated list and reassign it. This avoids
            # aliasing the caller's list and mutating the field in place.
            merged_ip6s = list(device.ip6s or [])
            for ip6 in ip6s:
                if ip6 not in merged_ip6s:
                    merged_ip6s.append(ip6)
            device.ip6s = merged_ip6s
        device.last_seen = datetime.datetime.now()
        device.save()

        # Record the service for this device (existing rows are left unchanged).
        MDNS.get_or_create(
            device=device,
            service=service,
            defaults={
                "hostname": server,
                "full_name": mdns_name,
                "properties": properties,
                "integration": integration,
                "device_manufacturer": device_manufacturer,
                "device_model": device_model,
                "friendly_name": friendly_name,
            },
        )


def clear_pending_queue() -> None:
    """
    Attach queued mDNS records from PendingDevice to known devices where possible.

    Each pending record is matched to a Device by mDNS hostname first, then by
    the first of its IPv6 addresses that belongs to a known device. A matched
    record is copied into the MDNS table (unless a row for that device/service
    already exists) and removed from the queue. Unmatched records stay queued.
    """
    hostname_device_map: dict[str, int] = {}
    ip6_device_map: dict[str, int] = {}
    with db.atomic():
        devices = Device.select(Device.id, Device.ip6s, Device.mdns_hostname)  # type: ignore
        for device in devices:
            # Skip devices without an mDNS hostname. A None key would otherwise
            # match every pending record that also lacks a hostname.
            if device.mdns_hostname:
                hostname_device_map[device.mdns_hostname] = device.id
            for ip6 in device.ip6s or []:
                ip6_device_map[ip6] = device.id

        # Materialize the queue up front, because matched rows are deleted
        # while we iterate.
        for pending in list(PendingDevice.select()):
            device_id = hostname_device_map.get(pending.hostname)
            if device_id is None:
                device_id = next(
                    (
                        ip6_device_map[ip6]
                        for ip6 in (pending.ip6s or [])
                        if ip6 in ip6_device_map
                    ),
                    None,
                )
            if device_id is None:
                continue

            logger.debug(f"clearing pending device {pending} device {device_id}")
            # NOTE(review): device_model is not carried over here — confirm
            # whether PendingDevice stores it.
            MDNS.get_or_create(
                device=device_id,
                service=pending.service,
                defaults={
                    "hostname": pending.hostname,
                    "full_name": pending.full_name,
                    "properties": pending.properties,
                    "integration": pending.integration,
                    "device_manufacturer": pending.device_manufacturer,
                    "friendly_name": pending.friendly_name,
                },
            )
            pending.delete_instance()


def ssdp_save(
    discovered_device_dict: dict[str, Any],
) -> None:
    """
    Attach SSDP discovery data to the already-known device with the same IP.

    The whole discovery dict is stored in ``ssdp_data``. If the parsed device
    description is reachable at
    ``location_contents["root"]["device"]["device"]``, its fields populate:
      * ``ssdp_name``  <- ``friendlyName`` (falling back to ``modelName``),
      * ``ssdp_mfg``   <- ``manufacturer``,
      * ``ssdp_model`` <- ``modelName``.
    Fields that are absent leave the stored value unchanged. Discovery results
    for IPs that are not in the Device table are logged and skipped.

    Args:
        discovered_device_dict: discovery result. It must contain
            ``device_ip_addr``; a missing key raises ``KeyError``.
    """
    absent = object()

    def description_field(key: str) -> Any:
        # Best-effort lookup: a missing level, or a level of the wrong type,
        # yields the ``absent`` sentinel instead of raising.
        try:
            return discovered_device_dict["location_contents"]["root"]["device"][
                "device"
            ][key]
        except (KeyError, TypeError):
            return absent

    with db.atomic():
        dev = Device.get_or_none(ip_addr=discovered_device_dict["device_ip_addr"])
        if not dev:
            logger.info(
                f"Device {discovered_device_dict['device_ip_addr']} not found in DB, skipping"
            )
            return
        dev.ssdp_data = discovered_device_dict

        friendly_name = description_field("friendlyName")
        manufacturer = description_field("manufacturer")
        model_name = description_field("modelName")

        # friendlyName is the UPnP human-readable device name. Use modelName
        # only when no usable friendly name is advertised.
        if friendly_name is not absent and friendly_name:
            dev.ssdp_name = friendly_name
        elif model_name is not absent:
            dev.ssdp_name = model_name
        if manufacturer is not absent:
            dev.ssdp_mfg = manufacturer
        if model_name is not absent:
            dev.ssdp_model = model_name
        dev.save()


def ha_dhcp_save(ha_dhcp_dict: list[dict[str, str]]) -> None:
    """
    Save the Home Assistant DHCP discovery data to the database.

    Entries without a ``mac_address`` are logged and skipped. MACs are
    lowercased before lookup.

    For a device already known by MAC, the Home Assistant IP and hostname
    overwrite the stored ones (Home Assistant is the preferred hostname
    source), but only when present, so missing values never erase stored
    data. ``last_seen`` is refreshed. Unknown MACs are passed to
    ``record_device``, which applies its own subnet/host/gateway filtering.

    Args:
        ha_dhcp_dict: list of discovery entries. The keys read are
            ``mac_address``, ``ip_address`` and ``hostname``.
    """
    with db.atomic():
        for device in ha_dhcp_dict:
            raw_mac = device.get("mac_address")
            if not raw_mac:
                logger.warning(
                    f"No MAC address found in device data for device {device}"
                )
                continue
            mac_addr = raw_mac.lower()
            ip_addr = device.get("ip_address")
            dhcp_hostname = device.get("hostname")

            existing_device = Device.get_or_none(mac_addr=mac_addr)
            if existing_device is None:
                # Create a new device record (subject to record_device's checks).
                record_device(
                    ip_addr=ip_addr, mac_addr=mac_addr, dhcp_hostname=dhcp_hostname
                )
                continue

            # Overwrite only with values Home Assistant actually provided.
            if ip_addr:
                existing_device.ip_addr = ip_addr
            if dhcp_hostname:
                existing_device.dhcp_hostname = dhcp_hostname
            existing_device.last_seen = datetime.datetime.now()
            existing_device.save()


_HA_MAC_RE = re.compile(r"^([0-9a-fA-F]{2}[:-]){5}([0-9a-fA-F]{2})$")
_HA_IPV4_RE = re.compile(r"(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})")


def _normalize_ha_mac(value: Any) -> str:
    """Return ``value`` as a lowercase, colon-separated MAC string."""
    # Lowercase to match how ha_dhcp_save stores MACs.
    return str(value).lower().replace("-", ":")


def _find_ha_device_mac(device: dict[str, Any]) -> str | None:
    """Return the best MAC candidate from a Home Assistant device entry, if any."""
    connections = device.get("connections") or []
    identifiers = device.get("identifiers") or []
    # 1. A connection explicitly typed as "mac".
    for cxn_type, cxn_id in connections:
        if cxn_type == "mac" and cxn_id:
            return _normalize_ha_mac(cxn_id)
    # 2. Any connection value, then 3. any identifier value, that looks like a MAC.
    for _, candidate in [*connections, *identifiers]:
        if _HA_MAC_RE.match(str(candidate)):
            return _normalize_ha_mac(candidate)
    return None


def _find_ha_device_ip(device: dict[str, Any]) -> str | None:
    """Return the first valid IPv4 address embedded in ``configuration_url``, if any."""
    # configuration_url may be missing or None, so guard before the regex search.
    url = device.get("configuration_url")
    if not url:
        return None
    found = _HA_IPV4_RE.search(str(url))
    if not found:
        return None
    candidate = found.group(0)
    try:
        # Reject digit runs that are not real addresses (e.g. "999.1.1.1").
        ip_address(candidate)
    except ValueError:
        return None
    return candidate


def ha_devices_save(devices: list[dict[str, Any]]) -> None:
    """
    Copy Home Assistant device names/manufacturers/models onto known devices.

    A Home Assistant device is matched by MAC when one can be found in its
    connections or identifiers. Otherwise it is matched by an IPv4 address
    taken from its ``configuration_url``. Entries with no match in the Device
    table are ignored; nothing new is created. Music Assistant entries are
    skipped.

    Args:
        devices: Home Assistant device entries. The keys read are
            ``identifiers``, ``connections``, ``configuration_url``, ``name``,
            ``manufacturer`` and ``model``.
    """
    with db.atomic():
        for device in devices:
            identifiers = device.get("identifiers") or []
            # Music Assistant devices report the Home Assistant host's IP, not
            # the actual device's, so they cannot be matched reliably.
            if any(integration == "music_assistant" for integration, _ in identifiers):
                continue

            mac = _find_ha_device_mac(device)
            ip = _find_ha_device_ip(device)

            # Only save Home Assistant info when it definitely matches a known
            # device. A found MAC is authoritative; there is no IP fallback then.
            if mac:
                db_device = Device.get_or_none(mac_addr=mac)
            elif ip:
                db_device = Device.get_or_none(ip_addr=ip)
            else:
                db_device = None

            if db_device is None:
                continue
            db_device.ha_name = device["name"]
            db_device.ha_mfg = device["manufacturer"]
            db_device.ha_model = device["model"]
            db_device.save()


def retention_cleanup(hours: int = 168) -> None:
    """
    Cleanup old records from the database.

    This function is called periodically. It deletes:
      * flows whose ``end_ts`` is older than the retention cutoff;
      * devices whose ``last_seen`` is older than the cutoff, together with
        every flow (as source or destination) and MDNS record referencing them;
      * all PendingDevice rows, if the table was last wiped before the cutoff
        or the stored wipe timestamp is missing or unparsable.

    Args:
        hours: retention window in hours (the default of 168 is one week).
    """
    logger.info("Starting retention cleanup...")
    oldest_retained_time = datetime.datetime.now() - datetime.timedelta(hours=hours)

    # Run the flow/device deletions in one transaction, so dependent rows and
    # the devices they reference are removed together.
    with db.atomic():
        # Delete old flows.
        Flow.delete().where(Flow.end_ts < oldest_retained_time).execute()

        # Use a subquery rather than a materialized id list. This way the
        # DELETEs below do not bind one SQL parameter per device.
        old_device_ids = Device.select(Device.id).where(
            Device.last_seen < oldest_retained_time
        )
        old_device_count = old_device_ids.count()

        if old_device_count:
            logger.info(f"Found {old_device_count} old devices to delete")

            # Delete flows referencing old devices (both src and dst).
            flows_deleted = (
                Flow.delete()
                .where(
                    (Flow.src_device.in_(old_device_ids))
                    | (Flow.dst_device.in_(old_device_ids))
                )
                .execute()
            )
            logger.info(f"Deleted {flows_deleted} flows referencing old devices")

            # Delete MDNS records referencing old devices.
            mdns_deleted = (
                MDNS.delete().where(MDNS.device.in_(old_device_ids)).execute()
            )
            logger.info(f"Deleted {mdns_deleted} MDNS records for old devices")

            # Delete the old devices themselves.
            devices_deleted = (
                Device.delete().where(Device.last_seen < oldest_retained_time).execute()
            )
            logger.info(f"Deleted {devices_deleted} old devices")

    # Clean up the PendingDevice table based on the last wipe time.
    last_pending_wipe = config.get("last_pending_device_wipe")
    should_wipe_pending = False

    if last_pending_wipe is None:
        should_wipe_pending = True
    else:
        # Parse the stored timestamp and check whether it is older than the
        # retention period.
        try:
            last_wipe_time = datetime.datetime.fromisoformat(last_pending_wipe)
            if last_wipe_time < oldest_retained_time:
                should_wipe_pending = True
        except (ValueError, TypeError):
            # Invalid timestamp format, so wipe the table.
            should_wipe_pending = True

    if should_wipe_pending:
        pending_deleted = PendingDevice.delete().execute()
        logger.info(f"Deleted {pending_deleted} pending device records")
        config.set("last_pending_device_wipe", datetime.datetime.now().isoformat())

    # TODO add time storage to other models and delete old records
    # Host.delete().where(Host.last_seen < oldest_retained_time).execute()

    logger.info("Retention cleanup completed.")


def update_device(
    device_id: int,
    user_name: str | None = None,
    user_model: str | None = None,
    user_mfg: str | None = None,
    is_iot_form_status: str | None = None,
) -> Device:
    """
    Update device information and recalculate preferred values.

    ``user_name``, ``user_model`` and ``user_mfg`` each replace the stored
    user value; ``None`` or ``""`` clears it. Any non-empty user value sets
    ``is_iot_user_override``. An ``is_iot_form_status`` of ``"on"`` / ``"off"``
    sets ``is_iot`` to ``True`` / ``False`` and also sets the override. Any
    other value leaves ``is_iot`` unchanged.

    After the user fields are saved, the preferred name, model and
    manufacturer are recomputed via ``friendly_organizer`` and saved.

    Args:
        device_id: primary key of the device. Presumably an unknown id makes
            ``Device.get_by_id`` raise (peewee DoesNotExist).
        user_name: user-supplied display name.
        user_model: user-supplied model.
        user_mfg: user-supplied manufacturer.
        is_iot_form_status: form checkbox state, ``"on"`` or ``"off"``.

    Returns:
        The updated Device instance.
    """
    with db.atomic():
        device = Device.get_by_id(device_id)

        user_fields = (
            ("user_name", user_name),
            ("user_model", user_model),
            ("user_mfg", user_mfg),
        )
        for field_name, value in user_fields:
            if value:
                setattr(device, field_name, value)
                device.is_iot_user_override = True
            else:
                # Blank form fields clear the user value.
                setattr(device, field_name, None)

        if is_iot_form_status in ("on", "off"):
            # Store a real boolean rather than the raw form string.
            device.is_iot = is_iot_form_status == "on"
            device.is_iot_user_override = True

        device.save()

    # Recompute preferred values in a separate transaction, after the user
    # fields above have been saved.
    with db.atomic():
        preferred_name = friendly_organizer.determine_preferred_name(device)
        preferred_model = friendly_organizer.determine_preferred_model(device)
        preferred_mfg, _ = friendly_organizer.determine_preferred_manufacturer(
            device, preferred_name=preferred_name
        )
        device.preferred_name = preferred_name
        device.preferred_model = preferred_model
        device.preferred_mfg = preferred_mfg
        device.save()

    return device


def toggle_device_arp_spoof(device_id: int, desired_state: bool) -> dict[str, Any]:
    """
    Set ARP spoofing for a device to ``desired_state`` and report the result.

    Nothing is written if the device is already in the desired state. When
    enabling, ``arp_spoofing_start_time`` is set to now. An
    ``arp_spoofing_error`` of 1 is changed to 2 to mark a user override of the
    previous error.

    Args:
        device_id: primary key of the device.
        desired_state: True to enable ARP spoofing, False to disable it.

    Returns:
        A dict with:
          * ``device``: the Device;
          * ``enabled_count``: number of devices with ARP spoofing enabled
            after the change;
          * ``changed``: whether the state was actually modified.
    """
    with db.atomic():
        device = Device.get_by_id(device_id)

        # Capture this before mutating. Comparing afterwards would always
        # report False.
        changed = device.is_arp_spoofed != desired_state
        if changed:
            device.is_arp_spoofed = desired_state
            if desired_state:
                device.arp_spoofing_start_time = datetime.datetime.now()

                # Set the user override if an error was previously set.
                if device.arp_spoofing_error == 1:
                    device.arp_spoofing_error = 2
            device.save()

        enabled_count = Device.select().where(Device.is_arp_spoofed).count()

    return {
        "device": device,
        "enabled_count": enabled_count,
        "changed": changed,
    }
