import datetime
import logging
from typing import Any

from model.model import db, Device, Flow, IpNode
from model.hosts import record_ip
from model.ghostery import GhosteryDB

logger = logging.getLogger(__name__)


# Upper bound on rows / IN-list values sent in a single SQL statement. Some
# backends (notably SQLite, whose older builds default to 999 bound parameters
# per statement) reject statements with too many parameters. A Flow row binds
# 17 values, so 50 rows stays under that limit.
# NOTE(review): can be raised if the deployed backend allows more.
_DB_BATCH_SIZE = 50


def _batched(items: list, size: int) -> list[list]:
    """Split *items* into consecutive slices of at most *size* elements."""
    return [items[start : start + size] for start in range(0, len(items), size)]


def _match_device(mac_addr, ip_addr, devices_by_mac: dict, devices_by_ip: dict):
    """Return the device for one flow endpoint, or None if nothing matches.

    A match on *mac_addr* takes precedence over a match on *ip_addr*; falsy
    addresses are never looked up.
    """
    if mac_addr and mac_addr in devices_by_mac:
        return devices_by_mac[mac_addr]
    if ip_addr and ip_addr in devices_by_ip:
        return devices_by_ip[ip_addr]
    return None


def write_pending_flows_to_db(
    flow_dict: dict[tuple, dict[str, Any]],
    gateway_mac_addr: str,
    ghostery_db: GhosteryDB,
) -> None:
    """Write the flows in *flow_dict* into the Flow table using bulk operations.

    Args:
        flow_dict: Maps a 7-tuple flow key ``(src_mac_addr, dst_mac_addr,
            src_ip_addr, dst_ip_addr, src_port, dst_port, protocol)`` to a
            stats dict. The stats dict must hold ``start_ts``, ``end_ts``,
            ``byte_count`` and ``pkt_count``; ``tcp_retransmit`` and
            ``tcp_rst`` are optional and default to 0.
        gateway_mac_addr: Currently unused by this function.
        ghostery_db: Forwarded to ``record_ip`` for IPs that have no IpNode.

    Each endpoint is matched to a Device with ``is_arp_spoofed`` set, first by
    MAC address and then by IP address. Matched devices get ``last_seen`` set
    to the current (naive, local) time. For an unmatched endpoint whose IP has
    no IpNode row yet, ``record_ip`` is called once per distinct IP, and its
    return value is used as that IP's host. All reads and writes run inside a
    single ``db.atomic()`` block, and large statements are split into batches
    of ``_DB_BATCH_SIZE``.
    """
    # Collect the distinct addresses up front so each lookup table below is
    # filled with a few batched queries instead of one query per flow. The
    # strict 7-way unpack rejects malformed keys before any DB work starts.
    all_mac_addrs = set()
    all_ip_addrs = set()
    for src_mac_addr, dst_mac_addr, src_ip_addr, dst_ip_addr, _, _, _ in flow_dict:
        all_mac_addrs.update(addr for addr in (src_mac_addr, dst_mac_addr) if addr)
        all_ip_addrs.update(addr for addr in (src_ip_addr, dst_ip_addr) if addr)

    with db.atomic():
        devices_by_mac = {
            device.mac_addr: device
            for batch in _batched(list(all_mac_addrs), _DB_BATCH_SIZE)
            for device in Device.select().where(
                Device.mac_addr.in_(batch) & Device.is_arp_spoofed
            )
        }
        devices_by_ip = {
            device.ip_addr: device
            for batch in _batched(list(all_ip_addrs), _DB_BATCH_SIZE)
            for device in Device.select().where(
                Device.ip_addr.in_(batch) & Device.is_arp_spoofed
            )
        }
        hosts_by_ip = {
            host.ip_addr: host
            for batch in _batched(list(all_ip_addrs), _DB_BATCH_SIZE)
            for host in IpNode.select().where(IpNode.ip_addr.in_(batch))
        }

        current_time = datetime.datetime.now()
        devices_to_update = set()
        # Insertion-ordered "set": an IP lacking an IpNode is recorded only
        # once, however many flows mention it.
        missing_host_ips: dict[str, None] = {}
        rows = []

        for flow_key, flow_stat_dict in flow_dict.items():
            (
                src_mac_addr,
                dst_mac_addr,
                src_ip_addr,
                dst_ip_addr,
                src_port,
                dst_port,
                protocol,
            ) = flow_key

            src_device = _match_device(
                src_mac_addr, src_ip_addr, devices_by_mac, devices_by_ip
            )
            dst_device = _match_device(
                dst_mac_addr, dst_ip_addr, devices_by_mac, devices_by_ip
            )

            # Bump matched devices. Only device-less endpoints can trigger
            # creation of a missing host.
            for device, ip_addr in (
                (src_device, src_ip_addr),
                (dst_device, dst_ip_addr),
            ):
                if device is not None:
                    device.last_seen = current_time
                    devices_to_update.add(device)
                elif ip_addr and ip_addr not in hosts_by_ip:
                    missing_host_ips[ip_addr] = None

            rows.append(
                {
                    "start_ts": flow_stat_dict["start_ts"],
                    "end_ts": flow_stat_dict["end_ts"],
                    "src_device": src_device,
                    "dst_device": dst_device,
                    "src_host": None,  # linked after missing hosts exist
                    "dst_host": None,  # linked after missing hosts exist
                    "src_ip_addr": src_ip_addr,
                    "dst_ip_addr": dst_ip_addr,
                    "src_device_mac_addr": src_mac_addr,
                    "dst_device_mac_addr": dst_mac_addr,
                    "src_port": src_port,
                    "dst_port": dst_port,
                    "protocol": protocol,
                    "byte_count": flow_stat_dict["byte_count"],
                    "packet_count": flow_stat_dict["pkt_count"],
                    "tcp_retransmit": flow_stat_dict.get("tcp_retransmit", 0),
                    "tcp_rst": flow_stat_dict.get("tcp_rst", 0),
                }
            )

        for ip_addr in missing_host_ips:
            hosts_by_ip[ip_addr] = record_ip(ip_addr, ghostery_db)

        if devices_to_update:
            Device.bulk_update(
                list(devices_to_update),
                fields=[Device.last_seen],
                batch_size=_DB_BATCH_SIZE,
            )

        # NOTE(review): every endpoint with an IP is linked to that IP's
        # IpNode if one exists, including endpoints already matched to a
        # Device. This preserves the previous behaviour; confirm whether
        # device endpoints should carry a host link.
        for row in rows:
            src_ip = row["src_ip_addr"]
            dst_ip = row["dst_ip_addr"]
            row["src_host"] = hosts_by_ip.get(src_ip) if src_ip else None
            row["dst_host"] = hosts_by_ip.get(dst_ip) if dst_ip else None

        for batch in _batched(rows, _DB_BATCH_SIZE):
            Flow.insert_many(batch).execute()

    logger.debug("Wrote %d flows to database.", len(flow_dict))
