import asyncio
import logging
from multiprocessing import Queue
from random import randrange
import traceback
from typing import Any
import socket
from scapy.packet import Packet

from aiohttp import ClientSession

import shared.config as config
from shared.db_helpers import update_device_metadata
import scan.arp_scanner
import scan.homeassistant
import scan.mdns
import scan.ssdp

logger = logging.getLogger(__name__)


class Scanner:
    """Run one scan cycle across the ARP, Home Assistant, mDNS and SSDP scanners.

    Progress is published through ``config.set("scan_step", ...)``: the value
    is reset to 0 at the start of a cycle and incremented once for every
    finished step (the four scanner tasks plus the final metadata update).
    """

    def __init__(
        self,
        db_queue: "Queue[tuple[str, dict[str, Any]]]",
        ha_client: scan.homeassistant.HomeAssistantClient | None,
        aiohttp_session: ClientSession | None,
        packet_queue: "Queue[Packet | str] | None",
    ):
        """Store the collaborators used by the individual scan tasks.

        Args:
            db_queue: Queue handed to the SSDP/mDNS scanners and to
                ``update_device_metadata``.
            ha_client: Home Assistant client, or ``None`` to skip the Home
                Assistant step.
            aiohttp_session: HTTP session handed to the SSDP scanner, or
                ``None`` to skip the SSDP step.
            packet_queue: Queue handed to the ARP scanner.
        """
        self.db_queue: "Queue[tuple[str, dict[str, Any]]]" = db_queue
        self.ha_client: scan.homeassistant.HomeAssistantClient | None = ha_client
        self.aiohttp_session: ClientSession | None = aiohttp_session
        # Serialises counter updates between the concurrently running scan tasks.
        self.scan_step_lock: asyncio.Lock = asyncio.Lock()
        self.scan_step_counter: list[int] = [0]  # Use list for mutable reference
        self.packet_queue: "Queue[Packet | str] | None" = packet_queue

    def reset_scan_progress(self) -> None:
        """Reset the scan progress counter and the published ``scan_step`` to 0."""
        self.scan_step_counter[0] = 0
        config.set("scan_step", 0)

    async def increment_scan_step(self) -> None:
        """Increment the scan step counter and publish it to ``config``.

        The update is guarded by an ``asyncio.Lock``, so it is safe between
        coroutines on the same event loop (it is not a thread lock).
        """
        async with self.scan_step_lock:
            self.scan_step_counter[0] += 1
            config.set("scan_step", self.scan_step_counter[0])

    async def arp_scan_task(self) -> None:
        """Run the ARP scan; log failures and always count the step as done."""
        try:
            logger.info("Scanning for ARP devices")
            await scan.arp_scanner.arp_scan(self.packet_queue)
            logger.info("ARP scan complete")
        except Exception as e:
            logger.error(f"Error in arp_scanner: {e}")
            logger.debug(traceback.format_exc())
        finally:
            # Progress must advance even when the scan fails.
            await self.increment_scan_step()

    async def ha_scan_task(self) -> None:
        """Fetch DHCP data and devices from Home Assistant.

        Skipped (but still counted as a finished step) when no client is
        available. Failures are logged and never propagate.
        """
        try:
            if not self.ha_client:
                logger.info(
                    "No Home Assistant API client available, skipping API calls"
                )
                return

            logger.info("Getting DHCP data from Home Assistant")
            await self.ha_client.dhcp()
            logger.info("Getting devices from Home Assistant REST API")
            await self.ha_client.get_and_save_devices()
        except Exception as e:
            logger.error(f"Error in Home Assistant API client: {e}")
            logger.debug(traceback.format_exc())
        finally:
            # Runs on the early return above as well.
            await self.increment_scan_step()

    async def ssdp_scan_task(self) -> None:
        """Run the SSDP scan.

        Skipped (but still counted as a finished step) when no aiohttp session
        is available. Failures are logged and never propagate.
        """
        try:
            if not self.aiohttp_session:
                logger.error("No aiohttp_session available, skipping SSDP scan")
                return
            logger.info("Scanning for SSDP devices")
            # NOTE(review): 15 is presumably a timeout in seconds, matching the
            # mDNS call's ``timeout=15`` - confirm against scan.ssdp.scan.
            await scan.ssdp.scan(15, self.db_queue, self.aiohttp_session)
            logger.info("SSDP scan complete")
        except Exception as e:
            logger.error(f"Error in ssdp scan: {e}")
            logger.debug(traceback.format_exc())
        finally:
            await self.increment_scan_step()

    async def mdns_scan_task(self) -> None:
        """Run the mDNS scan; log failures and always count the step as done."""
        try:
            logger.info("Scanning for mDNS devices")
            await scan.mdns.scan(timeout=15, db_queue=self.db_queue)
            logger.info("mDNS scan complete")
        except Exception as e:
            logger.error(f"Error in mdns scan: {e}")
            logger.debug(traceback.format_exc())
        finally:
            await self.increment_scan_step()

    async def run_parallel_scans(self) -> None:
        """Execute one full scan cycle in two concurrent phases.

        Phase one runs the ARP and Home Assistant tasks together; phase two
        runs mDNS and SSDP together. Afterwards the device metadata update is
        triggered. Every step advances ``scan_step`` exactly once, whether it
        succeeds or fails, and no step failure propagates to the caller.
        """
        # Reset progress counter for new scan cycle
        self.reset_scan_progress()

        # Run the ARP and Home Assistant scans first
        initial_tasks = [
            asyncio.create_task(self.arp_scan_task()),
            asyncio.create_task(self.ha_scan_task()),
        ]

        # Wait for all tasks to complete, ignoring exceptions
        _ = await asyncio.gather(*initial_tasks, return_exceptions=True)

        await asyncio.sleep(1)

        # Run mDNS and SSDP scans next, as they work better if some IP/MAC pairs are already known
        second_tasks = [
            asyncio.create_task(self.mdns_scan_task()),
            asyncio.create_task(self.ssdp_scan_task()),
        ]

        _ = await asyncio.gather(*second_tasks, return_exceptions=True)

        try:
            try:
                await asyncio.sleep(5)
                logger.info("Updating device metadata")
                update_device_metadata(self.db_queue)
            finally:
                # Final step increment for metadata update. Done in ``finally``
                # like the scan tasks above, so a failed update cannot leave
                # the published progress one step short for this cycle.
                await self.increment_scan_step()
        except Exception as e:
            # The outer handler also covers a failing increment, so nothing
            # escapes into the caller's scan loop.
            # NOTE(review): the message names "friendly_organizer", which does
            # not appear elsewhere in this file - possibly a stale name.
            logger.error(f"Error in friendly_organizer: {e}")
            logger.debug(traceback.format_exc())


async def scan_loop(
    wait_time: int = 3600,
    ha_url: str = "",
    db_queue: "Queue[tuple[str, dict[str, Any]]] | None" = None,
    aiohttp_session: ClientSession | None = None,
    ha_addon: bool = False,
    packet_queue: "Queue[Packet | str] | None" = None,
):
    """Run scan cycles forever, sleeping between them.

    Args:
        wait_time: Base number of seconds to sleep between cycles; a random
            extra delay of 0-599 seconds is added each time.
        ha_url: Passed to ``HomeAssistantClient`` as ``ha_url``.
        db_queue: Queue used by the scanners and for the ``retention_cleanup``
            request after each cycle. When ``None`` the function logs a
            warning and returns immediately.
        aiohttp_session: Session to use. If it is ``None`` or closed, a new
            ``ClientSession`` is created; a session created here is closed
            when this coroutine exits.
        ha_addon: Passed to ``HomeAssistantClient`` as ``is_addon``.
        packet_queue: Passed to the ``Scanner`` (used by the ARP scan).
    """
    if db_queue is None:
        logger.warning("db_queue is None, declining to scan")
        return

    # The session created by this function (if any). Caller-provided sessions
    # are never closed here; sessions we create ourselves are.
    owned_session: ClientSession | None = None

    try:
        # Set up Home Assistant client
        ha_client: scan.homeassistant.HomeAssistantClient | None = None
        try:
            if aiohttp_session is None or aiohttp_session.closed:
                aiohttp_session = owned_session = ClientSession()
            ha_client = scan.homeassistant.HomeAssistantClient(
                is_addon=ha_addon,
                ha_url=ha_url,
                session=aiohttp_session,
                db_queue=db_queue,
            )
            await ha_client.ws_connect()
        except socket.gaierror as e:
            # Name resolution failure: treated as "Home Assistant unavailable".
            ha_client = None
            logger.warning(f"Unable to connect to Home Assistant: {e}")
        except Exception as e:
            ha_client = None
            logger.error(f"Error connecting to Home Assistant: {e}")
            logger.debug(traceback.format_exc())

        # Create scanner instance
        scanner = Scanner(db_queue, ha_client, aiohttp_session, packet_queue)

        while True:
            logger.info("Starting scan")

            if aiohttp_session is None or aiohttp_session.closed:
                aiohttp_session = owned_session = ClientSession()
                # Hand the replacement to the scanner; otherwise it keeps using
                # the closed (or missing) session and the SSDP scan never
                # benefits from the new one.
                scanner.aiohttp_session = aiohttp_session
                # NOTE(review): ha_client still holds the session it was
                # constructed with - confirm whether it must be rebuilt too.

            # Run all scans in parallel
            await scanner.run_parallel_scans()

            config.set("first_scan_complete", True)

            # Include the random jitter so the logged duration matches the
            # actual sleep below.
            sleep_seconds = wait_time + randrange(600)
            logger.info(f"Scan complete. Sleeping for {sleep_seconds} seconds")

            # clean up database to remove old data
            db_queue.put(("retention_cleanup", {}))

            # check TCP retransmission ratios and disable ARP spoofing if needed
            # db_queue.put(("tcp_retransmission_check", {}))

            await asyncio.sleep(sleep_seconds)
    finally:
        # Reached only via cancellation or an unexpected error, since the loop
        # never ends on its own; release the session we created.
        if owned_session is not None and not owned_session.closed:
            await owned_session.close()
