import logging
import datetime
from typing import Any
from peewee import fn, DoesNotExist
from playhouse.shortcuts import model_to_dict
from model.model import (
    db,
    Device,
    Flow,
    IpNode,
    IpHostEdge,
    EntityNode,
    CompanyNode,
    HostNode,
    CnameEdge,
)
import model.friendly_organizer as friendly_organizer

logger = logging.getLogger(__name__)


def get_devices() -> list[dict[str, Any]]:
    """
    Return every device as a dict, including a computed has_flows flag and the
    timestamp of its latest flow activity, sorted by manufacturer then name.
    """
    with db.atomic():
        # Use EXISTS subquery for better performance when we only need a boolean
        # Sort in the database query to avoid Python postprocessing
        # Use .dicts() to get dictionaries directly, avoiding model_to_dict overhead
        query = (
            Device.select(
                Device,
                fn.EXISTS(
                    Flow.select().where(
                        (Flow.src_device == Device.id) | (Flow.dst_device == Device.id)
                    )
                ).alias("has_flows"),
                fn.MAX(Flow.end_ts).alias("latest_flow_activity"),
            )
            .left_outer_join(
                Flow,
                on=((Flow.src_device == Device.id) | (Flow.dst_device == Device.id)),
            )
            .group_by(Device)
            .order_by(
                # is_null() must come before asc() so devices without a
                # manufacturer/name sort after the ones that have one
                Device.preferred_mfg.is_null(),
                Device.preferred_mfg.asc(),
                Device.preferred_name.is_null(),
                Device.preferred_name.asc(),
            )
            .dicts()
        )

    return list(query)
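
# Illustrative shape of one row returned by get_devices() (values are invented;
# the full column set follows the Device model):
#
#   {
#       "id": 3,
#       "preferred_name": "Living Room TV",
#       "preferred_mfg": "Acme",
#       ...                                   # remaining Device columns
#       "has_flows": True,                    # EXISTS subquery alias
#       "latest_flow_activity": datetime.datetime(2024, 5, 1, 12, 30, 0),
#   }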


def get_devices_for_metrics() -> list[dict[str, Any]]:
    """
    Return per-device rows for metrics: a trimmed set of Device columns plus
    flow counters aggregated over flows where the device is source or
    destination. Timestamps are serialized to ISO-8601 strings.
    """
    with db.atomic():
        query = (
            Device.select(
                Device.id,
                Device.preferred_name,
                Device.preferred_model,
                Device.preferred_mfg,
                Device.user_name,
                Device.user_model,
                Device.user_mfg,
                Device.is_iot,
                Device.is_iot_user_override,
                Device.is_arp_spoofable,
                Device.is_arp_spoofed,
                Device.arp_spoofing_error,
                Device.arp_spoofing_start_time,
                Device.mac_vendor,
                Device.dhcp_hostname,
                Device.mdns_hostname,
                Device.ssdp_name,
                Device.ssdp_mfg,
                Device.ssdp_model,
                Device.ha_name,
                Device.ha_mfg,
                Device.ha_model,
                Device.first_seen,
                Device.last_seen,
                fn.COUNT(Flow.id).alias("flow_count"),
                fn.SUM(Flow.byte_count).alias("bytes"),
                fn.SUM(Flow.packet_count).alias("packets"),
                fn.SUM(Flow.tcp_retransmit).alias("tcp_retransmits"),
                fn.SUM(Flow.tcp_rst).alias("tcp_rsts"),
            )
            .left_outer_join(
                Flow,
                on=((Flow.src_device == Device.id) | (Flow.dst_device == Device.id)),
            )
            .group_by(Device)
            .dicts()
        )
    devices = []
    for device in query:
        device["first_seen"] = device["first_seen"].isoformat()
        device["last_seen"] = device["last_seen"].isoformat()
        if device["arp_spoofing_start_time"] is not None:
            device["arp_spoofing_start_time"] = device[
                "arp_spoofing_start_time"
            ].isoformat()
        devices.append(device)
    return devices
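
# Sketch of how a caller might consume get_devices_for_metrics(); note that
# first_seen / last_seen / arp_spoofing_start_time arrive as ISO-8601 strings,
# not datetime objects:
#
#   for row in get_devices_for_metrics():
#       print(row["preferred_name"], row["flow_count"], row["bytes"], row["last_seen"])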


# def unknown_count():
#     return PendingDevice.select(PendingDevice.hostname).distinct().count()


def get_device_details(
    device_id: int,
) -> dict[str, Any]:
    """
    Get device details including friendly organizer default values.
    Returns a dictionary with device data and default values for editing.
    """
    with db.atomic():
        device = (
            Device.select(
                Device,
                fn.EXISTS(
                    Flow.select().where(
                        (Flow.src_device == Device.id) | (Flow.dst_device == Device.id)
                    )
                ).alias("has_flows"),
            )
            .left_outer_join(
                Flow,
                on=((Flow.src_device == Device.id) | (Flow.dst_device == Device.id)),
            )
            .where(Device.id == device_id)
            .group_by(Device)
            .get()
        )

        # Calculate default values using friendly_organizer
        default_name = friendly_organizer.determine_preferred_name(
            device, use_user_name=False
        )
        default_model = friendly_organizer.determine_preferred_model(
            device, use_user_model=False
        )
        default_mfg, _ = friendly_organizer.determine_preferred_manufacturer(
            device, use_user_mfg=False
        )
        device_dict = model_to_dict(device, backrefs=False, recurse=False)
        device_dict["has_flows"] = (
            device.has_flows
        )  # needed as model_to_dict drops the has_flows field
        device_dict["default_name"] = default_name
        device_dict["default_model"] = default_model
        device_dict["default_mfg"] = default_mfg
        return device_dict
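
# Illustrative return value of get_device_details() (values are invented; the
# Device columns themselves are elided):
#
#   {
#       "id": 3,
#       ...                        # all Device columns
#       "has_flows": True,
#       "default_name": "Acme Speaker",
#       "default_model": "X100",
#       "default_mfg": "Acme",
#   }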


def _get_flow_stats_for_device(
    device_id: int, start_ts: datetime.datetime | None = None
) -> dict[int, dict[str, Any]]:
    """
    Get aggregated flow statistics for all hosts that communicated with a device.

    Returns:
        Dictionary mapping host_id -> flow statistics
    """
    # Get destination hosts for flows where this device is the source
    src_flows_conditions = (Flow.src_device == device_id) & (
        Flow.dst_host.is_null(False)
    )  # pyright: ignore[reportUnknownVariableType]
    if start_ts is not None:
        src_flows_conditions = src_flows_conditions & (Flow.end_ts > start_ts)  # pyright: ignore[reportOperatorIssue, reportUnknownVariableType]

    src_flows_query = (
        Flow.select(
            Flow.dst_host,
            fn.SUM(Flow.byte_count).alias("src_bytes"),
            fn.SUM(Flow.packet_count).alias("src_packets"),
            fn.MAX(Flow.end_ts).alias("last_activity"),
        )
        .where(src_flows_conditions)
        .group_by(Flow.dst_host)
        .dicts()
    )

    # Get source hosts for flows where this device is the destination
    dst_flows_conditions = (Flow.dst_device == device_id) & (
        Flow.src_host.is_null(False)
    )
    if start_ts is not None:
        dst_flows_conditions = dst_flows_conditions & (Flow.end_ts > start_ts)  # pyright: ignore[reportOperatorIssue, reportUnknownVariableType]

    dst_flows_query = (
        Flow.select(
            Flow.src_host,
            fn.SUM(Flow.byte_count).alias("dst_bytes"),
            fn.SUM(Flow.packet_count).alias("dst_packets"),
            fn.MAX(Flow.end_ts).alias("last_activity"),
        )
        .where(dst_flows_conditions)
        .group_by(Flow.src_host)
        .dicts()
    )

    # Build a dictionary of host_id -> flow stats
    host_flow_stats = {}

    # Add stats from flows where device is source (talking to dst_host)
    for flow_dict in src_flows_query:
        host_id = flow_dict.get("dst_host")
        if host_id:
            if host_id not in host_flow_stats:
                host_flow_stats[host_id] = {
                    "src_bytes": 0,
                    "src_packets": 0,
                    "dst_bytes": 0,
                    "dst_packets": 0,
                    "last_activity": None,
                }

            host_flow_stats[host_id]["src_bytes"] += flow_dict.get("src_bytes", 0)
            host_flow_stats[host_id]["src_packets"] += flow_dict.get("src_packets", 0)

            # Take the most recent timestamp
            last_activity = flow_dict.get("last_activity")
            if last_activity and (
                host_flow_stats[host_id]["last_activity"] is None
                or last_activity > host_flow_stats[host_id]["last_activity"]
            ):
                host_flow_stats[host_id]["last_activity"] = last_activity

    # Add stats from flows where device is destination (talking to src_host)
    for flow_dict in dst_flows_query:
        host_id = flow_dict.get("src_host")
        if host_id:
            if host_id not in host_flow_stats:
                host_flow_stats[host_id] = {
                    "dst_bytes": 0,
                    "dst_packets": 0,
                    "src_bytes": 0,
                    "src_packets": 0,
                    "last_activity": None,
                }

            host_flow_stats[host_id]["dst_bytes"] += flow_dict.get("dst_bytes", 0)
            host_flow_stats[host_id]["dst_packets"] += flow_dict.get("dst_packets", 0)

            # Take the most recent timestamp
            last_activity = flow_dict.get("last_activity")
            if last_activity and (
                host_flow_stats[host_id]["last_activity"] is None
                or last_activity > host_flow_stats[host_id]["last_activity"]
            ):
                host_flow_stats[host_id]["last_activity"] = last_activity

    return host_flow_stats
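
# Shape of the mapping returned by _get_flow_stats_for_device() (IDs and
# numbers are illustrative):
#
#   {
#       42: {   # IpNode id
#           "src_bytes": 1024,     # flows where the device was the source
#           "src_packets": 8,
#           "dst_bytes": 4096,     # flows where the device was the destination
#           "dst_packets": 12,
#           "last_activity": datetime.datetime(2024, 5, 1, 12, 30, 0),
#       },
#   }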


def _find_preferred_host_edge(host_edges: list[Any], prefer_dns_type: list[str]) -> Any:
    """Find the preferred host edge based on DNS type preference list."""
    preferred_ip_host_edge = None
    # Try each preference in order
    for dns_type in prefer_dns_type:
        for edge in host_edges:
            if edge.edge_type == dns_type:
                preferred_ip_host_edge = edge
                break
        if preferred_ip_host_edge:
            break
    if not preferred_ip_host_edge:
        # pick the first edge
        preferred_ip_host_edge = host_edges[0]
    return preferred_ip_host_edge
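
# Example: with prefer_dns_type=["sni", "forward_dns", "reverse_dns"], an IP
# that has both a reverse_dns edge and an sni edge resolves through the sni
# edge; an IP with none of the preferred types falls back to its first edge.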


def _follow_cname_chain(host_id: int) -> int:
    """Follow CNAME chain to get final hostname."""
    final_host_id = host_id
    cname_count = 0
    while True:
        cname_edge = CnameEdge.get_or_none(dst=final_host_id)
        if not cname_edge:
            break
        final_host_id = cname_edge.src
        cname_count += 1
        if cname_count > 10:  # Prevent infinite loops
            logger.warning(f"Too many CNAME redirects for host {host_id}")
            break
    return final_host_id
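
# Example: with CnameEdge rows (src=5, dst=7) and (src=2, dst=5),
# _follow_cname_chain(7) walks dst -> src twice and returns 2; chains longer
# than 10 hops are cut short with a warning.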


def _resolve_host_to_entities(
    final_host_id: int,
) -> tuple[str, str, EntityNode | None, CompanyNode | None, HostNode | None]:
    """
    Resolve a host ID to its entity and company information.

    Returns:
        Tuple of (entity_name, company_name, entity, company, host)
    """
    entity_name = "IP addresses without a known domain"
    company_name = "Other companies"
    entity: EntityNode | None = None
    company: CompanyNode | None = None
    host: HostNode | None = None

    try:
        host = HostNode.get_by_id(final_host_id)
        entity_name = (
            host.domain or "IP addresses without a known domain"
        )  # fallback if no entity
        entity = EntityNode.get_by_id(host.entity_id)
        entity_name = entity.entity_name
        company = CompanyNode.get_by_id(entity.company_id)
        company_name = company.company_name

        # if entity and company name are the same, use the domain name instead
        if entity_name == company_name:
            entity_name = host.domain or "IP addresses without a known domain"
    except DoesNotExist:
        pass

    return entity_name, company_name, entity, company, host
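
# Example outcomes of _resolve_host_to_entities() (names are illustrative):
#   - host, entity and company all found:
#       ("Acme Ads", "Acme Inc", <EntityNode>, <CompanyNode>, <HostNode>)
#   - host found but its entity/company rows are missing:
#       ("example.com", "Other companies", None, None, <HostNode>)
#   - host row missing entirely:
#       ("IP addresses without a known domain", "Other companies", None, None, None)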


def _initialize_company_in_result(
    result: dict, company_name: str, company: CompanyNode | None
) -> None:
    """Initialize a company entry in the result dictionary if it doesn't exist."""
    if company and company_name not in result:
        result[company_name] = {
            "id": company.id,
            "name": company.company_name,
            "description": company.description,
            "website": company.website,
            "privacy_url": company.privacy_url,
            "country": company.country,
            "category": [],
            "entities": {},
            "src_bytes": 0,
            "dst_bytes": 0,
            "src_packets": 0,
            "dst_packets": 0,
            "last_activity": datetime.datetime.min,
        }
    elif company_name not in result:
        result[company_name] = {
            "id": None,
            "name": "Other companies",
            # keep the same shape as known companies so
            # _initialize_entity_in_result can append categories safely
            "category": [],
            "entities": {},
            "src_bytes": 0,
            "dst_bytes": 0,
            "src_packets": 0,
            "dst_packets": 0,
            "last_activity": datetime.datetime.min,
        }


def _initialize_entity_in_result(
    result: dict, company_name: str, entity_name: str, entity: EntityNode | None
) -> None:
    """Initialize an entity entry in the result dictionary if it doesn't exist."""
    if entity and entity_name not in result[company_name]["entities"]:
        result[company_name]["entities"][entity_name] = {
            "name": entity_name,
            "website": entity.website,
            "category": entity.category,
            "hosts": [],
            "src_bytes": 0,
            "dst_bytes": 0,
            "src_packets": 0,
            "dst_packets": 0,
            "last_activity": datetime.datetime.min,
        }
        if entity.category not in result[company_name]["category"]:
            result[company_name]["category"].append(entity.category)
    elif entity_name not in result[company_name]["entities"]:
        result[company_name]["entities"][entity_name] = {
            "name": entity_name,
            "hosts": [],
            "src_bytes": 0,
            "dst_bytes": 0,
            "src_packets": 0,
            "dst_packets": 0,
            "last_activity": datetime.datetime.min,
        }


def _add_or_update_host_in_result(
    result: dict, company_name: str, entity_name: str, host_entry: dict
) -> None:
    """Add a new host or update an existing host in the result structure."""
    # Check if host already exists (by hostname) and update if needed
    existing_host_index = None
    for i, existing_host in enumerate(
        result[company_name]["entities"][entity_name]["hosts"]
    ):
        if existing_host["hostname"] == host_entry["hostname"]:
            existing_host_index = i
            break

    if existing_host_index is not None:
        # Update existing host stats
        existing_host = result[company_name]["entities"][entity_name]["hosts"][
            existing_host_index
        ]

        existing_host["src_bytes"] += host_entry.get("src_bytes", 0)
        existing_host["src_packets"] += host_entry.get("src_packets", 0)
        existing_host["dst_bytes"] += host_entry.get("dst_bytes", 0)
        existing_host["dst_packets"] += host_entry.get("dst_packets", 0)

        # Take the most recent timestamp
        if host_entry["last_activity"] > existing_host["last_activity"]:
            existing_host["last_activity"] = host_entry["last_activity"]
    else:
        # Add new host
        result[company_name]["entities"][entity_name]["hosts"].append(host_entry)


def _aggregate_stats_to_higher_levels(
    result: dict, company_name: str, entity_name: str, flow_stats: dict
) -> None:
    """Aggregate flow statistics from host level up to entity and company levels."""
    # Aggregate stats to entity level
    result[company_name]["entities"][entity_name]["src_bytes"] += flow_stats.get(
        "src_bytes", 0
    )
    result[company_name]["entities"][entity_name]["src_packets"] += flow_stats.get(
        "src_packets", 0
    )
    result[company_name]["entities"][entity_name]["dst_bytes"] += flow_stats.get(
        "dst_bytes", 0
    )
    result[company_name]["entities"][entity_name]["dst_packets"] += flow_stats.get(
        "dst_packets", 0
    )

    # Propagate most recent timestamp to entity level
    entity_last_activity = result[company_name]["entities"][entity_name][
        "last_activity"
    ]
    if (
        flow_stats["last_activity"]
        and flow_stats["last_activity"] > entity_last_activity
    ):
        result[company_name]["entities"][entity_name]["last_activity"] = flow_stats[
            "last_activity"
        ]

    # Aggregate stats to company level
    result[company_name]["src_bytes"] += flow_stats.get("src_bytes", 0)
    result[company_name]["dst_bytes"] += flow_stats.get("dst_bytes", 0)
    result[company_name]["src_packets"] += flow_stats.get("src_packets", 0)
    result[company_name]["dst_packets"] += flow_stats.get("dst_packets", 0)

    # Propagate most recent timestamp to company level
    company_last_activity = result[company_name]["last_activity"]
    if (
        flow_stats["last_activity"]
        and flow_stats["last_activity"] > company_last_activity
    ):
        result[company_name]["last_activity"] = flow_stats["last_activity"]


def _sort_result_structure(result: dict) -> dict:
    """Sort the result structure by company and entity names, with special ordering."""
    # sort result by company name, with "Other companies" at the end
    result = dict(
        sorted(result.items(), key=lambda x: (x[0] == "Other companies", x[0] or ""))
    )

    # sort result by entity name, with "IP addresses without a known domain" at the end
    for company in result.values():
        company["entities"] = dict(
            sorted(
                company["entities"].items(),
                key=lambda x: (
                    x[0] == "IP addresses without a known domain",
                    x[0] or "",
                ),
            )
        )

    return result
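
# Example: {"Zeta Corp": ..., "Acme Inc": ..., "Other companies": ...} sorts to
# Acme Inc, Zeta Corp, Other companies; entities inside each company are sorted
# the same way, with "IP addresses without a known domain" last.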


def get_entities_by_device(
    device_id: int,
    prefer_dns_type: list[str] | None = None,
    start_ts: datetime.datetime | None = None,
) -> dict[str, Any]:
    """
    Get all entities and companies associated with a device's network flows.

    Args:
        device_id: The ID of the device to analyze
        prefer_dns_type: List of DNS types to prefer in order when multiple paths exist (['sni', 'forward_dns', 'reverse_dns'])
        start_ts: Optional timestamp filter - only include flows after this timestamp

    Returns:
        A nested dictionary with companies as top-level keys, containing entities as second-level keys,
        and hosts as third-level keys, with flow statistics aggregated at each level.
    """
    if not prefer_dns_type:
        prefer_dns_type = ["sni", "forward_dns", "reverse_dns"]

    logger.debug(f"Getting entities for device_id: {device_id}")

    with db.atomic():
        # Get flow statistics for all hosts that communicated with this device
        host_flow_stats = _get_flow_stats_for_device(device_id, start_ts)

        unique_ip_ids_list = list(host_flow_stats.keys())

        # Get all IP nodes for these unique IDs
        ip_nodes = IpNode.select().where(IpNode.id.in_(unique_ip_ids_list))
        ip_nodes_list = list(ip_nodes)

        # Build the result structure
        result = {}

        for ip_node in ip_nodes_list:
            # Get flow stats for this IP
            flow_stats = host_flow_stats.get(
                ip_node.id,
                {
                    "src_bytes": 0,
                    "src_packets": 0,
                    "dst_bytes": 0,
                    "dst_packets": 0,
                    "last_activity": None,
                },
            )

            # Get all host edges for this IP
            host_edges = list(IpHostEdge.select().where(IpHostEdge.ip == ip_node))

            if not host_edges:
                continue

            # Find the preferred host edge
            preferred_ip_host_edge = _find_preferred_host_edge(
                host_edges, prefer_dns_type
            )
            host_id = preferred_ip_host_edge.host

            # Follow CNAME chain to get final hostname
            final_host_id = _follow_cname_chain(host_id)

            # Resolve host to entity and company information
            entity_name, company_name, entity, company, host = (
                _resolve_host_to_entities(final_host_id)
            )

            # Initialize company and entity in result structure
            _initialize_company_in_result(result, company_name, company)
            _initialize_entity_in_result(result, company_name, entity_name, entity)

            # Create host entry with flow stats
            host_entry = {
                "hostname": host.hostname if host else ip_node.ip_addr,
                "src_bytes": flow_stats.get("src_bytes", 0),
                "src_packets": flow_stats.get("src_packets", 0),
                "dst_bytes": flow_stats.get("dst_bytes", 0),
                "dst_packets": flow_stats.get("dst_packets", 0),
                "last_activity": flow_stats.get("last_activity", datetime.datetime.min),
            }

            # Add or update host in result structure
            _add_or_update_host_in_result(result, company_name, entity_name, host_entry)

            # Aggregate statistics to entity and company levels
            _aggregate_stats_to_higher_levels(
                result, company_name, entity_name, flow_stats
            )

        # Sort and return the result
        return _sort_result_structure(result)
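
# Illustrative shape of the structure returned by get_entities_by_device()
# (names and numbers are invented, elided fields marked with ...):
#
#   {
#       "Acme Inc": {
#           "id": 12, "name": "Acme Inc", "description": ..., "website": ...,
#           "privacy_url": ..., "country": ..., "category": ["advertising"],
#           "src_bytes": 2048, "dst_bytes": 8192,
#           "src_packets": 16, "dst_packets": 24,
#           "last_activity": datetime.datetime(2024, 5, 1, 12, 30, 0),
#           "entities": {
#               "Acme Ads": {
#                   "name": "Acme Ads", "website": ..., "category": "advertising",
#                   "src_bytes": ..., "dst_bytes": ...,
#                   "src_packets": ..., "dst_packets": ...,
#                   "last_activity": ...,
#                   "hosts": [
#                       {"hostname": "ads.acme.example",
#                        "src_bytes": ..., "src_packets": ...,
#                        "dst_bytes": ..., "dst_packets": ...,
#                        "last_activity": ...},
#                   ],
#               },
#           },
#       },
#       "Other companies": {...},   # fallback bucket, sorted last
#   }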
