import logging
from typing import Any
from peewee import fn
from model.model import db, Device, Flow

logger = logging.getLogger(__name__)


def get_device_retransmission_ratios() -> list[dict[str, Any]]:
    """
    Query all eligible devices and calculate their TCP retransmission ratios.

    A device is eligible when it is currently ARP spoofed, has a non-null
    ``arp_spoofing_start_time``, and has at least one TCP flow in which it is
    the source or the destination.

    Returns:
        List of dictionaries, one per eligible device, with the keys
        ``device_id``, ``name``, ``manufacturer``, ``mac_addr``, ``ip_addr``,
        ``is_arp_spoofed``, ``arp_spoofing_error``, ``tcp_packets``,
        ``total_retransmits``, ``retransmission_ratio`` and
        ``exceeds_threshold``.

    Raises:
        Exception: any error raised while querying is logged and re-raised.
    """
    logger.debug("Querying device retransmission ratios")

    device_ratios = []

    try:
        with db.atomic():
            # Query devices with their packet and retransmission counts.
            query = (
                Device.select(
                    Device.id,
                    Device.preferred_name,
                    Device.preferred_mfg,
                    Device.mac_addr,
                    Device.ip_addr,
                    Device.is_arp_spoofed,
                    # Must be selected explicitly: only selected columns are
                    # loaded onto the result rows. The caller filters on this
                    # value to decide which devices may be disabled.
                    Device.arp_spoofing_error,
                    fn.SUM(Flow.packet_count).alias("total_packets"),
                    fn.SUM(Flow.tcp_retransmit).alias("total_retransmits"),
                )
                .left_outer_join(
                    Flow,
                    on=(
                        (Flow.src_device == Device.id) | (Flow.dst_device == Device.id)
                    ),
                )
                .where(
                    Device.is_arp_spoofed,  # Only check currently ARP spoofed devices
                    # NOTE(review): filtering on a Flow column in WHERE drops
                    # devices with no TCP flows, so the outer join behaves
                    # like an inner join here.
                    Flow.protocol == "tcp",
                    Device.arp_spoofing_start_time.is_null(False),  # has start time
                )
                .group_by(Device)
            )

            for device_data in query:
                # Fall back to 1 so the ratio denominator can never be zero.
                tcp_packets = device_data.total_packets or 1
                total_retransmits = device_data.total_retransmits or 0

                # Retransmits as a share of (packets + retransmits).
                retransmission_ratio = total_retransmits / (
                    tcp_packets + total_retransmits
                )

                device_info = {
                    "device_id": device_data.id,
                    "name": device_data.preferred_name,
                    "manufacturer": device_data.preferred_mfg,
                    "mac_addr": device_data.mac_addr,
                    "ip_addr": device_data.ip_addr,
                    "is_arp_spoofed": device_data.is_arp_spoofed,
                    "arp_spoofing_error": device_data.arp_spoofing_error,
                    "tcp_packets": tcp_packets,
                    "total_retransmits": total_retransmits,
                    "retransmission_ratio": retransmission_ratio,
                    "exceeds_threshold": _exceeds_retransmission_threshold(
                        tcp_packets, retransmission_ratio
                    ),
                }

                device_ratios.append(device_info)

                logger.debug(
                    "Device %s (%s): %s/%s = %.3f",
                    device_data.id,
                    device_data.preferred_name or "Unknown",
                    total_retransmits,
                    tcp_packets,
                    retransmission_ratio,
                )

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error querying device retransmission ratios: %s", e)
        raise

    return device_ratios


def _exceeds_retransmission_threshold(
    tcp_packets: int, retransmission_ratio: float
) -> bool:
    """
    Decide whether a device's retransmission ratio is too high.

    The check is a packet-count-dependent rule set. With more than 1000
    packets, a ratio above 0.5 is enough. With more than 20 packets, the
    ratio must exceed 0.7. Either rule matching triggers the threshold.

    Args:
        tcp_packets: Total number of TCP packets seen for the device
        retransmission_ratio: Retransmission ratio computed for the device

    Returns:
        True if any rule is satisfied, False otherwise
    """
    # Each rule: (packet count that must be exceeded, ratio that must be exceeded).
    rules = ((1000, 0.5), (20, 0.7))
    return any(
        tcp_packets > min_packets and retransmission_ratio > min_ratio
        for min_packets, min_ratio in rules
    )


def disable_arp_spoofing_for_devices(device_ids: list[int]) -> int:
    """
    Turn off ARP spoofing for the given devices and flag them as errored.

    The update sets ``is_arp_spoofed`` to False and ``arp_spoofing_error``
    to 1 for every device whose id is in ``device_ids``, inside a single
    atomic block.

    Args:
        device_ids: IDs of the devices to update; an empty list is a no-op

    Returns:
        Number of rows reported as updated by the database (0 if
        ``device_ids`` is empty)

    Raises:
        Exception: any error raised by the update is logged and re-raised.
    """
    if not device_ids:
        return 0

    logger.info(f"Disabling ARP spoofing for {len(device_ids)} devices")

    # arp_spoofing_error codes: 0 = no error, 1 = error,
    # 2 = error that the user has overridden.
    new_values = {"is_arp_spoofed": False, "arp_spoofing_error": 1}

    try:
        with db.atomic():
            update_query = Device.update(**new_values).where(
                Device.id.in_(device_ids)
            )
            updated_count = update_query.execute()

            logger.info(
                f"Successfully disabled ARP spoofing for {updated_count} devices"
            )
            return updated_count

    except Exception as e:
        logger.error(f"Error disabling ARP spoofing for devices: {e}")
        raise


def check_and_disable_arp_spoofing_for_high_retransmission() -> dict[str, Any]:
    """
    Check all devices for TCP retransmission ratios and disable ARP spoofing
    for devices with high retransmission ratios.

    A device is disabled only if it exceeds the retransmission threshold AND
    its ``arp_spoofing_error`` is 0. Devices with a non-zero error code
    (e.g. 2, meaning the user has overridden a previous error) are left
    untouched, and are not logged or reported as disabled.

    Returns:
        Dictionary with operation results including:
        - devices_checked: number of devices evaluated
        - devices_disabled: number of devices with ARP spoofing disabled
        - disabled_device_details: list of devices that had ARP spoofing disabled
        - all_device_ratios: list of all device retransmission data
    """
    logger.info("Starting TCP retransmission ratio check for ARP spoofing")

    # Get retransmission ratios for all eligible devices.
    device_ratios = get_device_retransmission_ratios()
    devices_checked = len(device_ratios)

    # Apply the threshold and the "no prior error" filter once, so the update,
    # the log lines and the reported details all describe the same devices.
    # Previously, skipped devices (non-zero error code) were still logged
    # and reported as disabled.
    devices_to_disable = [
        device
        for device in device_ratios
        if device["exceeds_threshold"] and device["arp_spoofing_error"] == 0
    ]

    devices_disabled = disable_arp_spoofing_for_devices(
        [device["device_id"] for device in devices_to_disable]
    )

    # Log details for the devices that were actually targeted by the update.
    for device in devices_to_disable:
        logger.info(
            f"Disabled ARP spoofing for device {device['device_id']} "
            f"({device['name'] or 'Unknown'}) - "
            f"retransmission ratio: {device['retransmission_ratio']:.3f}"
        )

    result = {
        "devices_checked": devices_checked,
        "devices_disabled": devices_disabled,
        "disabled_device_details": devices_to_disable,
        "all_device_ratios": device_ratios,
    }

    logger.info(
        f"TCP retransmission check complete. "
        f"Checked {devices_checked} devices, disabled ARP spoofing for {devices_disabled} devices"
    )

    return result
