import itertools
import logging
import multiprocessing as mp
import queue
import time
import traceback
import uuid
from dataclasses import dataclass
from enum import IntEnum
from threading import Thread
from typing import Any

from peewee import OperationalError, fn

from model.model import (
    db,
    db_pragmas,
    db_initialize_tables,
    Device,
    Flow,
    MDNS,
    PendingDevice,
)
import model.devices
import model.devices_write as devices_write
import model.flow as flow
import model.friendly_organizer as friendly_organizer
import model.ghostery as ghostery
import model.hosts as hosts
import model.tcp_monitor as tcp_monitor

logger = logging.getLogger(__name__)


class Priority(IntEnum):
    """Request priority levels; a numerically smaller level is dequeued sooner."""

    # The explicit integers are significant: requests are ordered by comparing
    # these values, and the priority queue serves the smallest first.
    HIGH = 1
    NORMAL = 2
    LOW = 3


@dataclass
class DBRequest:
    """A unit of work for the database worker thread.

    Requests order by ``priority`` (lower value first) and, for equal
    priorities, by creation order. A ``queue.PriorityQueue`` therefore hands
    them out FIFO within each priority level; ``heapq`` on its own is not
    stable. ``None``, which may be queued as a shutdown signal, compares as
    greater than every request. It is only dequeued after all pending requests,
    instead of raising ``TypeError`` during the heap comparison.

    Attributes:
        request_id: Identifier used to match a DBResponse to this request.
        operation: Name of the operation the worker should run.
        data: Arguments for the operation, looked up by key.
        priority: Scheduling priority.
        request_response: Whether the worker should publish a DBResponse.
    """

    request_id: str
    operation: str
    data: dict[str, Any]
    priority: Priority = Priority.NORMAL
    request_response: bool = True

    # Process-wide tie-breaker source. It has no annotation, so it is a plain
    # class attribute and not a dataclass field; __init__, __eq__ and
    # __repr__ are unaffected.
    _sequence_counter = itertools.count()

    def __post_init__(self) -> None:
        # Creation order, used to keep equal-priority requests FIFO.
        self._sequence: int = next(DBRequest._sequence_counter)

    def _sort_key(self) -> tuple[int, int]:
        """Return the (priority, creation order) key used for ordering."""
        return (self.priority, self._sequence)

    def __lt__(self, other: "DBRequest | None") -> bool:
        if other is None:
            # Every request is served before a None shutdown signal.
            return True
        if not isinstance(other, DBRequest):
            return NotImplemented
        return self._sort_key() < other._sort_key()

    def __gt__(self, other: "DBRequest | None") -> bool:
        # Also used as the reflected operation for ``None < request``.
        if other is None:
            return False
        if not isinstance(other, DBRequest):
            return NotImplemented
        return self._sort_key() > other._sort_key()


@dataclass
class DBResponse:
    request_id: str
    success: bool
    data: Any = None
    error: str | None = None


class DatabaseManager:
    """Run all database work on one dedicated worker thread.

    ``_db_worker`` consumes ``DBRequest`` objects from ``request_queue`` in
    priority order. When a request asks for a reply, it puts a ``DBResponse``
    on ``response_queue``. ``_ipc_listener`` forwards ``(operation, data)``
    tuples received on ``db_ipc_queue`` (from other processes) into the same
    request queue as low-priority, fire-and-forget requests.
    """

    # Operation name of the internal shutdown request. stop() queues a real
    # DBRequest rather than a bare None so the item can be ordered against
    # other requests in the PriorityQueue.
    _SHUTDOWN_OPERATION = "__shutdown__"

    # How many times a request that fails with peewee's OperationalError is
    # put back in the queue before the failure is reported instead.
    MAX_OPERATIONAL_RETRIES = 5

    def __init__(
        self,
        db_path: str,
        ghostery_db_path: str,
    ):
        """Create the queues. No threads run until ``start()`` is called.

        Args:
            db_path: Path handed to ``db.init`` by the worker thread.
            ghostery_db_path: Path used to open the Ghostery database.
        """
        # Use PriorityQueue for requests (handles priority)
        self.request_queue: queue.PriorityQueue[DBRequest | None] = (
            queue.PriorityQueue()
        )
        # Regular queue for responses
        self.response_queue: queue.Queue[DBResponse | None] = queue.Queue()
        # Multiprocess queue for requests from other processes
        self.db_ipc_queue: mp.Queue[tuple[str, dict[str, Any]]] = mp.Queue()
        self.db_thread: Thread | None = None
        self.ipc_listener_thread: Thread | None = None
        self.db_path: str = db_path
        self.ghostery_db_path: str = ghostery_db_path
        # OperationalError retry attempts so far, per request_id. Only the
        # worker thread (via _handle_request) reads or writes this.
        self._operational_retries: dict[str, int] = {}

    def start(self):
        """Start the database worker thread and the IPC listener thread."""
        self.db_thread = Thread(target=self._db_worker)
        self.db_thread.start()
        self.ipc_listener_thread = Thread(target=self._ipc_listener)
        self.ipc_listener_thread.start()

    def stop(self):
        """Stop the IPC listener, then the database worker, and join both.

        The IPC listener is stopped first, so it cannot forward new requests
        into the request queue after the worker has exited.
        """
        if self.ipc_listener_thread:
            self.db_ipc_queue.put(("TERMINATE", {}))
            self.ipc_listener_thread.join()
        if self.db_thread:
            # Shutdown signal. It is a real (LOW priority) request because
            # None cannot be compared with DBRequest inside the PriorityQueue.
            self.request_queue.put(
                DBRequest(
                    request_id=str(uuid.uuid4()),
                    operation=self._SHUTDOWN_OPERATION,
                    data={},
                    priority=Priority.LOW,
                    request_response=False,
                )
            )
            self.db_thread.join()

    def _db_worker(self):
        """Database worker thread. Reads from priority queue and handles requests."""

        # Initialize database
        logger.info("Initializing database")
        db.init(self.db_path, db_pragmas)  # pyright: ignore[reportUnknownMemberType]
        db_initialize_tables()

        ghostery_db = ghostery.GhosteryDB(self.ghostery_db_path)

        try:
            # Process all items by priority
            while True:
                request = self.request_queue.get()
                # None is still accepted as a shutdown signal for compatibility.
                if request is None or request.operation == self._SHUTDOWN_OPERATION:
                    break
                response = self._handle_request(request, ghostery_db)
                if response is not None:
                    self.response_queue.put(response)
        finally:
            # Best-effort close of the connection used by this worker; a
            # failure here must not mask the shutdown.
            try:
                db.close()
            except Exception as e:
                logger.warning(f"Error closing database: {e}")

    @staticmethod
    def _error_response(request: DBRequest, error: str) -> DBResponse | None:
        """Build a failure response, or None for fire-and-forget requests.

        Nobody waits on a request sent with ``request_response=False``. A
        response queued for it would never be consumed and would only pile up
        in the response queue.
        """
        if not request.request_response:
            return None
        return DBResponse(request.request_id, False, error=error)

    def _handle_request(
        self, request: DBRequest, ghostery_db: ghostery.GhosteryDB
    ) -> DBResponse | None:
        """Execute one request and build its response.

        Args:
            request: The request to run. ``request.operation`` selects the
                handler and ``request.data`` supplies its arguments.
            ghostery_db: Ghostery database handle for handlers that need it.

        Returns:
            A ``DBResponse`` (success or failure) when
            ``request.request_response`` is True. None when the request is
            fire-and-forget, or when it was re-queued after an
            ``OperationalError``; in that case the retry produces the reply.
        """
        operation = request.operation
        data = request.data
        request_response = request.request_response
        response = None

        try:
            if operation == "get_devices":
                response = model.devices.get_devices()
            elif operation == "get_devices_for_metrics":
                response = model.devices.get_devices_for_metrics()
            elif operation == "get_entities_by_device":
                response = model.devices.get_entities_by_device(
                    device_id=data["device_id"],
                    prefer_dns_type=data.get("prefer_dns_type", None),
                    start_ts=data.get("start_ts"),
                )
            elif operation == "get_device_by_id":
                response = Device.get_by_id(data["id"])
            elif operation == "get_device_details":
                response = model.devices.get_device_details(data["device_id"])
            elif operation == "get_debug_data":
                devices_data = model.devices.get_devices()
                mdns_data = list(MDNS.select().dicts())
                pendingdevice_data = list(PendingDevice.select().dicts())
                total_packet_count = (
                    Flow.select(fn.SUM(Flow.packet_count)).scalar() or 0
                )
                response = {
                    "devices_data": devices_data,
                    "mdns_data": mdns_data,
                    "pendingdevice_data": pendingdevice_data,
                    "total_packet_count": total_packet_count,
                }
            elif operation == "update_device":
                response = devices_write.update_device(
                    device_id=data["device_id"],
                    user_name=data.get("user_name"),
                    user_model=data.get("user_model"),
                    user_mfg=data.get("user_mfg"),
                    is_iot_form_status=data.get("is_iot_form_status"),
                )
            elif operation == "toggle_device_arp_spoof":
                response = devices_write.toggle_device_arp_spoof(
                    device_id=data["device_id"],
                    desired_state=data["desired_state"],
                )
            elif operation == "record_device":
                _ = devices_write.record_device(
                    mac_addr=data["mac_addr"],
                    ip_addr=data["ip_addr"],
                    dhcp_hostname=data["dhcp_hostname"],
                )
            elif operation == "record_dns_a":
                _ = hosts.record_dns_a(
                    hostname=data["hostname"],
                    ip_address=data["ip_address"],
                    edge_type=data["edge_type"],
                    ghostery_db=ghostery_db,
                )
            elif operation == "record_dns_cname":
                _ = hosts.record_dns_cname(
                    source_hostname=data["source_hostname"],
                    target_hostname=data["target_hostname"],
                    ghostery_db=ghostery_db,
                )
            elif operation == "write_pending_flows":
                flow.write_pending_flows_to_db(
                    data["flow_dict"], data["gateway_mac_addr"], ghostery_db
                )
            elif operation == "save_mdns_record":
                # Replies (if requested) with success and no data, via the
                # common path below.
                _ = devices_write.save_mdns_record(
                    ip4=data["ip4"],
                    ip6s=data["ip6s"],
                    mdns_name=data["mdns_name"],
                    service=data["service"],
                    server=data["server"],
                    properties=data["properties"],
                    integration=data["integration"],
                    integration_data=data["integration_data"],
                )
            elif operation == "clear_pending_queue":
                _ = devices_write.clear_pending_queue()
            elif operation == "ssdp_save":
                _ = devices_write.ssdp_save(data)
            elif operation == "update_device_metadata":
                _ = friendly_organizer.update_device_metadata()
            elif operation == "ha_dhcp":
                _ = devices_write.ha_dhcp_save(data)
            elif operation == "ha_devices":
                _ = devices_write.ha_devices_save(data)
            elif operation == "retention_cleanup":
                hours = data.get("hours", 168)
                _ = devices_write.retention_cleanup(hours)
            elif operation == "tcp_retransmission_check":
                response = (
                    tcp_monitor.check_and_disable_arp_spoofing_for_high_retransmission()
                )
            elif operation == "get_retransmission_ratios":
                response = tcp_monitor.get_device_retransmission_ratios()
            else:
                logger.warning(f"Unknown operation: {operation}")
                return self._error_response(
                    request, f"Unknown operation: {operation}"
                )
        except OperationalError as e:
            attempts = self._operational_retries.get(request.request_id, 0)
            if attempts < self.MAX_OPERATIONAL_RETRIES:
                self._operational_retries[request.request_id] = attempts + 1
                logger.error(f"Database error: {e}. Putting request back in queue.")
                self.request_queue.put(request)
                # Reply only when the re-queued request finishes. Replying now
                # would report a false success, and the retry's reply would
                # then be orphaned in the response queue.
                return None
            self._operational_retries.pop(request.request_id, None)
            logger.error(f"Database error: {e}. Giving up after {attempts} retries.")
            logger.debug(traceback.format_exc())
            return self._error_response(request, str(e))
        except Exception as e:
            self._operational_retries.pop(request.request_id, None)
            logger.error(f"Error handling request: {e}")
            logger.debug(traceback.format_exc())
            return self._error_response(request, str(e))

        # Completed; drop any retry bookkeeping left from earlier attempts.
        self._operational_retries.pop(request.request_id, None)

        if request_response:
            logger.debug(
                f"Returning response for request {request.request_id} {operation}"
            )
            return DBResponse(request.request_id, True, data=response)
        return None

    def _ipc_listener(self) -> None:
        """Listener for requests from other processes.

        Reads ``(operation, data)`` tuples from the multiprocessing queue and
        forwards each one to the database worker as a LOW priority,
        fire-and-forget request. This lets other processes (not threads)
        reach the database manager. A ``"TERMINATE"`` message ends the loop.
        """

        db_client = DatabaseClient(self.request_queue, self.response_queue)

        while True:
            message = self.db_ipc_queue.get()
            try:
                # Unpacking is inside the try so one malformed message is
                # logged instead of killing the listener thread.
                message_type, data = message
                if message_type == "TERMINATE":
                    logger.info(
                        "Received message to terminate. Exiting IPC listener."
                    )
                    return
                _ = db_client.send_request(
                    message_type,
                    data,
                    priority=Priority.LOW,
                    request_response=False,
                )
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                logger.debug(traceback.format_exc())


class DatabaseClient:
    """Submit requests to the database worker and wait for their responses.

    Several clients may share the same request and response queues. The
    response queue can therefore hold replies meant for other clients;
    ``wait_for_response`` sets those aside and puts them back.
    """

    def __init__(
        self,
        request_queue: "queue.Queue[DBRequest | None]",
        response_queue: "queue.Queue[DBResponse | None]",
    ):
        """Bind the client to the worker's request and response queues."""
        self.request_queue: queue.Queue[DBRequest | None] = request_queue
        self.response_queue: queue.Queue[DBResponse | None] = response_queue
        # Requests still awaiting a response, keyed by request ID. Only
        # requests that ask for a response are stored. Entries are removed
        # when wait_for_response returns for them, so fire-and-forget traffic
        # (e.g. forwarded IPC messages) no longer accumulates here.
        self.pending_requests: dict[str, DBRequest] = {}

    def send_request(
        self,
        operation: str,
        data: dict[str, Any] | None = None,
        priority: Priority = Priority.NORMAL,
        request_response: bool = True,
    ) -> str:
        """Queue a request and return its ID.

        Args:
            operation: Name of the operation for the worker to run.
            data: Operation arguments; None is sent as an empty dict.
            priority: Queue priority (lower value is served first).
            request_response: Whether the worker should publish a DBResponse.

        Returns:
            The generated request ID, to pass to ``wait_for_response``.
        """
        request_id = str(uuid.uuid4())
        request = DBRequest(
            request_id, operation, data or {}, priority, request_response
        )
        if request_response:
            self.pending_requests[request_id] = request
        self.request_queue.put(request)
        return request_id

    def wait_for_response(self, request_id: str, timeout: float = 60) -> DBResponse:
        """Wait for the response to ``request_id``.

        Responses for other requests that arrive in the meantime are
        buffered and put back on the response queue before returning.

        Args:
            request_id: ID returned by ``send_request``.
            timeout: Maximum total number of seconds to wait. The deadline
                is not reset when unrelated responses arrive.

        Returns:
            The matching DBResponse. On timeout, a failed DBResponse with the
            error "No response received for query." is returned instead.
        """
        # TODO prune items from response queue if they are old/abandoned requests
        deadline = time.monotonic() + timeout
        responses_buffer: list[DBResponse] = []
        result: DBResponse | None = None

        while result is None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                response = self.response_queue.get(timeout=remaining)
            except queue.Empty:
                # No response in allotted time, stop waiting
                break
            if response is None:
                # None was in the queue, skip over it
                continue
            if response.request_id == request_id:
                result = response
            else:
                # Not the one we're waiting for; keep it for its owner
                logger.debug(
                    f"Waiting for {request_id}, got {response.request_id}, adding to buffer"
                )
                responses_buffer.append(response)

        # Put back buffered responses on both the success and timeout paths
        for buffered in responses_buffer:
            self.response_queue.put(buffered)
        self.pending_requests.pop(request_id, None)

        if result is None:
            # Error case: no response in allotted time
            return DBResponse(request_id, False, error="No response received for query.")

        logger.debug(
            f"Received response for request {request_id}, queue size: {self.response_queue.qsize()}"
        )
        return result


# Usage example
# if __name__ == "__main__":
#     # Setup
#     db_manager = DatabaseManager("path/to/app.db", "path/to/ghostery.db")
#     db_manager.start()

#     client = DatabaseClient(db_manager.request_queue, db_manager.response_queue)

#     # Send requests with different priorities
#     high_priority_id = client.send_request(
#         "get_device_by_id", {"id": 123}, Priority.HIGH
#     )
#     _ = client.send_request(
#         "retention_cleanup", {"hours": 168}, Priority.LOW, request_response=False
#     )

#     # Get responses (wait_for_response always returns a DBResponse)
#     response = client.wait_for_response(high_priority_id)
#     if response.success:
#         print(f"High priority result: {response.data}")

#     # Cleanup
#     db_manager.stop()
