#!/usr/bin/env python3

from __future__ import annotations

import argparse
import asyncio
import datetime
import logging
import multiprocessing
import os
import signal
import sys
from typing import Any
from threading import Thread
from logging.handlers import RotatingFileHandler

from aiohttp import ClientSession

import shared.config as config

import app.main
import inspector.arp_spoofer
import inspector.packet_processor
import inspector.packet_collector
from model.database_manager import (
    DatabaseManager,
    DatabaseClient,
)
import scan.main
from shared.metrics import MetricsClient

# Set up basic logging configuration immediately
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logger = logging.getLogger(__name__)

# set custom log levels for libraries
# these can be changed for debugging as needed
logging.getLogger("peewee").setLevel(logging.ERROR)
logging.getLogger("scapy.runtime").setLevel(logging.ERROR)
logging.getLogger("filelock").setLevel(logging.ERROR)


# Add file handler once data directory is known
def log_add_file_handler(data_dir: str) -> None:
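    """Attach a rotating file handler so logs also persist under data_dir."""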
    log_file = os.path.join(data_dir, "iot-transparency.log")
    file_handler = RotatingFileHandler(
        log_file, maxBytes=5_000_000, backupCount=1
    )  # limit to 5 MB per file, keep 1 backup
    file_handler.setFormatter(
        logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    logging.getLogger().addHandler(file_handler)


async def main(
    arp_spoof_allowed: bool,
    data_dir: str,
    listen_ip: str,
    port: str,
    app_root: str,
    ha_url: str,
    ha_addon: bool,
    debug_mode: bool,
) -> None:
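    """Bring up the IoT Transparency stack: initialize config and the
    database manager, start the async scan/metrics/ARP-spoofing loops,
    the packet collector/processor worker processes, and the web-server
    thread, then wait until shutdown and tear everything down."""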
    logger.info("IoT Transparency starting")

    # initialize config
    config.initialize(data_dir)
    installation_id: str | None = config.get("installation_id")
    config_items = config.items()
    config.get("installation_time", datetime.datetime.now().isoformat())
    logger.info(f"Installation ID {installation_id}")

    if not arp_spoof_allowed:
        logger.info("ARP spoofing not allowed")
        config.set("arp_spoof", False)
    else:
        logger.info("ARP spoofing allowed")
        config.set("arp_spoof", True)

    study_complete = config_items.get("study_complete")
    # "pending_restart" marks a re-enable request that takes effect on this restart
    study_complete_reenable = config_items.get("study_complete_reenable")
    if study_complete_reenable == "pending_restart":
        study_complete_reenable = True
        config.set("study_complete_reenable", True)

    # initialize sqlite database and database manager thread
    # https://docs.peewee-orm.com/en/latest/peewee/database.html#deferring-initialization
    db_path: str = os.path.join(data_dir, "data.sqlite3")
    ghostery_db_path: str = os.path.join(
        "data", "domains", "trackerdb.db"
    )  # in extension data, not data_dir
    db_manager = DatabaseManager(db_path, ghostery_db_path)
    db_manager.start()
    db_client: DatabaseClient = DatabaseClient(
        db_manager.request_queue, db_manager.response_queue
    )
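    # cross-process queue of (name, payload) messages bound for the database manager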
    db_ipc_queue: multiprocessing.Queue[tuple[str, dict[str, Any]]] = (
        db_manager.db_ipc_queue
    )

    # track async background tasks so they can be cancelled on shutdown
    background_tasks: set[asyncio.Task[Any]] = set()

    # IPC queues
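    # raw packets flow from the collector process to the processor via this queue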
    packet_queue: "multiprocessing.Queue[Any]" = multiprocessing.Queue()

    metrics_client: MetricsClient = MetricsClient(db_client)
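    # shared HTTP session for the scan loop; closed during shutdown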
    scan_aiohttp_session: ClientSession | None = ClientSession()

    # Start async tasks
    logger.info("Starting async tasks")

    # scan loop (async)
    scan_loop_task: asyncio.Task[Any] | None = None
    if not study_complete or study_complete_reenable:
        scan_loop_task = asyncio.create_task(
            scan.main.scan_loop(
                wait_time=600,
                db_queue=db_ipc_queue,
                aiohttp_session=scan_aiohttp_session,
                ha_addon=ha_addon,
                ha_url=ha_url,
                packet_queue=packet_queue,
            )
        )
        background_tasks.add(scan_loop_task)
        scan_loop_task.add_done_callback(background_tasks.discard)

    # metrics loop (async)
    metrics_loop_task: asyncio.Task[Any] | None = None
    if not study_complete or study_complete_reenable:
        metrics_loop_task = asyncio.create_task(
            metrics_client.metrics_loop(wait_time=3600, data_dir=data_dir)
        )
        background_tasks.add(metrics_loop_task)
        metrics_loop_task.add_done_callback(background_tasks.discard)

    # arp spoofing loop (async)
    arp_spoofing_loop_task: asyncio.Task[Any] | None = None
    if arp_spoof_allowed and (not study_complete or study_complete_reenable):
        arp_spoofing_loop_task = asyncio.create_task(
            inspector.arp_spoofer.spoof_internet_traffic_loop(db_client)
        )
        background_tasks.add(arp_spoofing_loop_task)
        arp_spoofing_loop_task.add_done_callback(background_tasks.discard)

    # Start processes
    logger.info("Starting processes")

    # web server (thread)
    app_thread = Thread(
        target=app.main.run_main_async,
        args=(
            data_dir,
            listen_ip,
            port,
            app_root,
            db_ipc_queue,
            debug_mode,
            ha_addon,
            ha_url,
            db_client,
        ),
    )
    app_thread.start()

    # packet processor (process)
    packet_processor_process: multiprocessing.Process | None = None
    if not study_complete or study_complete_reenable:
        packet_processor_process = multiprocessing.Process(
            target=inspector.packet_processor.process_packet,
            args=(packet_queue, db_ipc_queue, data_dir),
        )
        packet_processor_process.start()

    # packet collector (process)
    collector_process: multiprocessing.Process | None = None
    if not study_complete or study_complete_reenable:
        logger.info("Starting packet collector")
        collector_process = multiprocessing.Process(
            target=inspector.packet_collector.start_packet_collector,
            args=(packet_queue, data_dir),
        )
        collector_process.start()

    # define signal handler
    def signal_handler(sig: int, _: Any) -> None:
        logger.info(f"Process received interrupt signal {sig}")
        if packet_processor_process:
            packet_processor_process.terminate()
            packet_processor_process.join()
        if collector_process:
            collector_process.terminate()
            collector_process.join()
        sys.exit(0)

    _ = signal.signal(signal.SIGINT, signal_handler)
    _ = signal.signal(signal.SIGTERM, signal_handler)

    try:
        # Wait for KeyboardInterrupt instead of joining processes here
        # This allows the asyncio event loop to continue running
        logger.info("Main process waiting. Press Ctrl+C to terminate.")
        completion_teardown = False
        while True:
            # If study is complete, tear down unneeded tasks
            if not completion_teardown and (
                config.get("study_complete")
                and not config.get("study_complete_reenable")
            ):
                logger.info("Study complete, tearing down unneeded tasks")
                packet_queue.put("TERMINATE")
                if metrics_loop_task:
                    _ = metrics_loop_task.cancel()
                if scan_loop_task:
                    _ = scan_loop_task.cancel()
                if arp_spoofing_loop_task:
                    _ = arp_spoofing_loop_task.cancel()
                completion_teardown = True
            # TODO find a strategy that does not use busy waiting
            await asyncio.sleep(10)  # poll every 10 seconds instead of spinning

    # handle quitting
    except KeyboardInterrupt:
        logger.info("Main process interrupted")
    finally:
        # Close clients and sessions, then cancel remaining background tasks
        logger.info("Cancelling background tasks")
        await metrics_client.close()
        if scan_aiohttp_session and not scan_aiohttp_session.closed:
            await scan_aiohttp_session.close()

        # Terminate packet processor
        packet_queue.put("TERMINATE")
        if packet_processor_process:
            packet_processor_process.join()
            logger.info("Packet processor terminated")

        if collector_process:
            collector_process.terminate()
            collector_process.join()
            logger.info("Packet collector terminated")

        # Terminate database manager
        db_manager.stop()
        logger.info("Database manager terminated")

        # Wait for http connections to close, and for threads to finish
        await asyncio.sleep(0.5)

        for task in background_tasks:
            _ = task.cancel()
        if background_tasks:
            _ = await asyncio.gather(*background_tasks, return_exceptions=True)
            logger.info("All background tasks cancelled")

        app_thread.join()
        logger.info("App thread terminated")


if __name__ == "__main__":
    logger.debug("Starting IoT transparency server")

    # spawn new processes instead of fork
    # forking a multithreaded process is problematic
    multiprocessing.set_start_method("spawn")

    parser: argparse.ArgumentParser = argparse.ArgumentParser(
        description="IoT Transparency server"
    )
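    # tri-state boolean flags below: absent -> False, bare flag -> True,
    # explicit value parsed as "true"/"false"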
    _ = parser.add_argument(
        "--debug",
        help="Debug mode",
        nargs="?",  # Makes the value optional
        const=True,  # Used when flag present without value
        default=False,  # Used when flag absent
        type=lambda x: x.lower() == "true",
    )
    _ = parser.add_argument(
        "--allow-arp-spoofing",
        help="Allow ARP spoofing",
        nargs="?",  # Makes the value optional
        const=True,  # Used when flag present without value
        default=False,  # Used when flag absent
        type=lambda x: x.lower() == "true",
    )
    _ = parser.add_argument(
        "--data-dir", help="User data directory", default="user-data"
    )
    _ = parser.add_argument(
        "--listen-ip", help="IP for server to be accessible from", default="*"
    )
    _ = parser.add_argument("--port", help="Port for server", default="8080")
    _ = parser.add_argument("--app-root", help="Application root URL", default="")
    _ = parser.add_argument(
        "--ha-addon",
        help="Home Assistant add-on",
        nargs="?",  # Makes the value optional
        const=True,  # Used when flag present without value
        default=False,  # Used when flag absent
        type=lambda x: x.lower() == "true",
    )
    _ = parser.add_argument(
        "--ha-url", help="Home Assistant URL", default="http://homeassistant.local:8123"
    )

    args = parser.parse_args()
    logger.info(f"args are {args}")
    logger.info(f"debug mode is {args.debug}")

    if args.debug:
        logger.info("Debug mode enabled")
        # enable debug logging
        logging.getLogger().setLevel(logging.DEBUG)

        # save logs to file
        log_add_file_handler(args.data_dir)

    asyncio.run(
        main(
            arp_spoof_allowed=args.allow_arp_spoofing,
            data_dir=args.data_dir,
            listen_ip=args.listen_ip,
            port=args.port,
            app_root=args.app_root,
            ha_addon=args.ha_addon,
            ha_url=args.ha_url,
            debug_mode=args.debug,
        )
    )
