import asyncio
from datetime import datetime
import ipaddress
import logging
from pathlib import Path
from random import randrange
import subprocess
import time
from typing import Any

from aiohttp import ClientSession, FormData
import toml

import shared.config as config
import shared.system_stats as system_stats
from model.database_manager import DatabaseClient

# Logger for this module, named after the module's import path.
logger: logging.Logger = logging.getLogger(__name__)

# Base URL of the metrics server. Endpoint paths such as "/post", "/enroll"
# and "/upload_log" are appended to it by the functions in this module.
METRICS_DOMAIN: str = "https://iot-transparency.lan.cmu.edu"


def get_version() -> str:
    version = "unknown"
    pyproject_toml_file = Path(__file__).parent.parent / "pyproject.toml"
    if pyproject_toml_file.exists() and pyproject_toml_file.is_file():
        data = toml.load(pyproject_toml_file)
        # check project.version
        if "project" in data and "version" in data["project"]:
            version = data["project"]["version"]

    return version


class MetricsClient:
    """Client for sending metrics events to the server.

    The aiohttp ``ClientSession`` is created lazily on first use and is
    re-created after :meth:`close`, so one client can be reused indefinitely.
    """

    def __init__(self, db_client: DatabaseClient | None = None) -> None:
        # Created on demand by _get_session(): aiohttp sessions should be
        # created while an event loop is running, which need not hold here.
        self.session: ClientSession | None = None
        # Source of device data for metrics_ping(); clients that only call
        # send() can omit it.
        self.db_client: DatabaseClient | None = db_client

    def _get_session(self) -> ClientSession:
        """Return an open HTTP session, creating one if none exists or it was closed."""
        if self.session is None or self.session.closed:
            self.session = ClientSession()
        return self.session

    @staticmethod
    async def _read_error_body(response: Any) -> Any:
        """Return a response body as parsed JSON, or as text if it is not JSON.

        Error responses (e.g. from an intermediate proxy) may not be JSON, and a
        strict ``response.json()`` would raise and hide the HTTP status.
        """
        try:
            # content_type=None: parse regardless of the Content-Type header.
            return await response.json(content_type=None)
        except ValueError:
            return await response.text(errors="replace")

    async def close(self) -> None:
        """Close the HTTP session if one is open. Safe to call repeatedly."""
        if self.session is not None and not self.session.closed:
            await self.session.close()
        self.session = None

    async def send(self, event: str, payload: dict[str, Any]) -> dict[str, str] | None:
        """Send a metrics event to the server's ``/post`` endpoint.

        The request body is *payload* plus the common fields ``user_id`` (the
        configured ``installation_id``), ``event``, ``version`` and
        ``client_ts``; keys in *payload* take precedence over these.

        A 401 response means the server considers the study complete; the
        study is then marked complete locally via :meth:`complete_study`.

        Returns:
            The decoded JSON body on HTTP 200/201, otherwise ``None``.
            Failures are logged, never raised.
        """
        logger.info(f"sending metrics event {event}")
        try:
            session = self._get_session()
            body = {
                "user_id": config.get("installation_id"),
                "event": event,
                "version": get_version(),
                "client_ts": time.time(),
                **payload,
            }
            async with session.post(f"{METRICS_DOMAIN}/post", json=body) as response:
                status = response.status
                if status in (200, 201):
                    logger.debug(f"metrics post received {event}")
                    return await response.json()
                if status == 401:
                    # Our server returns 401 unauthorized if the user has
                    # completed the study. Record that before reading the body
                    # so an unreadable body cannot skip it.
                    self.complete_study()
                    res = await self._read_error_body(response)
                    logger.info(f"metrics post refused {res}")
                elif status == 403:  # forbidden
                    res = await self._read_error_body(response)
                    logger.error(f"metrics post refused {res}")
                else:
                    error = await self._read_error_body(response)
                    logger.error(
                        f"metrics post {event} failed with status code: {status} {error}"
                    )
        except Exception:
            logger.exception(f"metrics post {event} failed")
        return None

    @staticmethod
    async def _check_ipv6_connectivity() -> bool:
        """Return True if one ICMPv6 echo to 2606:4700:4700::1111 succeeds.

        ``ping6`` runs in a worker thread so the event loop is not blocked for
        up to five seconds. A missing or unrunnable ``ping6`` counts as False.
        """

        def _ping() -> bool:
            try:
                result = subprocess.run(
                    ["ping6", "-c", "1", "-W", "3", "2606:4700:4700::1111"],
                    capture_output=True,
                    timeout=5,
                )
            except (subprocess.TimeoutExpired, OSError):
                return False
            return result.returncode == 0

        return await asyncio.to_thread(_ping)

    async def metrics_ping(
        self, first_run: bool, data_dir: str
    ) -> dict[str, str] | None:
        """Collect device and system statistics and send them as a ``ping`` event.

        Nothing is sent when the user has not enrolled (no ``installation_id``),
        has completed the study, or devices cannot be fetched from the database.

        Args:
            first_run: Reported as ``first_event_since_start``.
            data_dir: Passed to ``system_stats.diskstat`` for disk statistics.

        Returns:
            The server's JSON response, or ``None`` if nothing was sent or the
            send failed.
        """
        logger.info("preparing metrics ping event")
        config_items = config.items()
        # do not send data if onboarding is not complete, or if user has completed study
        if not config_items.get("installation_id"):
            return None
        if config_items.get("study_complete"):
            return None

        metrics_query_start = datetime.now()
        if not self.db_client:
            logger.error("No database client found")
            return None
        # NOTE(review): wait_for_response() is called synchronously; if it
        # blocks, it stalls the event loop — consider asyncio.to_thread.
        devices_req = self.db_client.send_request("get_devices_for_metrics")
        devices = self.db_client.wait_for_response(devices_req)
        # Check for a missing response before touching its attributes.
        if devices is None:
            logger.error("No devices found")
            return None
        if not devices.success:
            logger.error(f"Error getting devices: {devices.error}")
            return None

        netstat = system_stats.netstat()
        # Packets forwarded since the baseline stored in config under
        # "packets_forwarded_at_db_init". This method only reads that key; if
        # it is absent the current total is the baseline, so 0 is reported.
        packets_forwarded_total = int(netstat["Ip"]["ForwDatagrams"])
        packets_forwarded_at_db_init = config_items.get(
            "packets_forwarded_at_db_init", packets_forwarded_total
        )
        packets_forwarded = packets_forwarded_total - packets_forwarded_at_db_init

        memstat = system_stats.memstat()
        loadavg = system_stats.loadavg()
        diskstat = system_stats.diskstat(data_dir)
        packet_counts = config_items.get("packet_counts", {})

        ipv6_connectivity = await self._check_ipv6_connectivity()

        # Covers the device query, stats collection and the IPv6 check.
        metrics_query_duration = (datetime.now() - metrics_query_start).total_seconds()

        payload = {
            "first_event_since_start": first_run,
            "devices": devices.data,
            "query_duration": metrics_query_duration,
            "packets_forwarded": packets_forwarded,
            "packet_counts": packet_counts,
            "netstat": netstat,
            "memstat": memstat,
            "diskstat": diskstat,
            "loadavg": loadavg,
            "ipv6_connectivity": ipv6_connectivity,
        }

        logger.debug(
            f"metrics query took {metrics_query_duration} seconds, sending to server now"
        )
        return await self.send("ping", payload)

    async def metrics_loop(
        self, wait_time: int = 3600, data_dir: str = "/data", jitter: int = 600
    ) -> None:
        """Send a ping event periodically until the study is complete.

        Between pings the loop sleeps ``wait_time`` seconds plus a random
        ``0 .. jitter-1`` seconds. It stops when ``study_complete`` is set and
        ``study_complete_reenable`` is not. A failing ping is logged and does
        not stop the loop.

        Args:
            wait_time: Base interval between pings, in seconds.
            data_dir: Forwarded to :meth:`metrics_ping`.
            jitter: Upper bound (exclusive) of the random extra delay; values
                <= 0 disable jitter.
        """
        first_run = True
        logger.info("starting metrics_loop")
        while True:
            # NOTE(review): with study_complete_reenable set, the loop keeps
            # running while metrics_ping() skips sending — presumably so pings
            # resume if the study is re-enabled; confirm.
            if config.get("study_complete") and not config.get(
                "study_complete_reenable"
            ):
                logger.info("Study complete, terminating metrics loop")
                return

            logger.info("sending metrics event")
            try:
                await self.metrics_ping(first_run, data_dir)
            except Exception:
                # One bad ping (e.g. unexpected stats format) must not end the loop.
                logger.exception("metrics ping failed")
            first_run = False
            extra_delay = randrange(jitter) if jitter > 0 else 0
            await asyncio.sleep(wait_time + extra_delay)

    def complete_study(self) -> None:
        """Mark the study as complete in config and clear the re-enable flag."""
        config.set("study_complete", True)
        config.set("study_complete_reenable", False)
        # TODO: disable ARP spoofing for all devices
        # with db.atomic():
        #     Device.update(is_arp_spoofed=False).execute()


async def pageload_send(page: str) -> None:
    """Send a ``page_load`` metrics event for *page* using a short-lived client.

    Use this where no long-lived :class:`MetricsClient` is available to the
    caller. Nothing is sent if the user has not enrolled (no
    ``installation_id``) or has completed the study.
    """
    if not config.get("installation_id"):
        return
    if config.get("study_complete"):
        return

    client = MetricsClient()
    try:
        _ = await client.send("page_load", {"page": page})
    finally:
        # Always release the HTTP session, even if send() raises or is cancelled.
        await client.close()


async def enroll(activation_code: str) -> int:
    """Enroll this installation in the study and notify the metrics server.

    The activation code is stored as ``installation_id`` and both study flags
    are reset in config before the request is sent, regardless of its outcome.

    Returns:
        The HTTP status of the ``/enroll`` request, or 999 if the request
        raised before a response was obtained.
    """
    logger.info(f"enrolling in study with user ID {activation_code}")
    enrollment_settings = (
        ("installation_id", activation_code),
        ("study_complete", False),
        ("study_complete_reenable", False),
    )
    for setting, value in enrollment_settings:
        config.set(setting, value)
    app_version = get_version()
    enroll_endpoint = METRICS_DOMAIN + "/enroll"

    logger.info("sending enrollment event")
    async with ClientSession() as http:
        try:
            request_body = {
                "user_id": activation_code,
                "version": app_version,
                "client_ts": time.time(),
            }
            async with http.post(enroll_endpoint, json=request_body) as reply:
                status_code = reply.status
                logger.info(f"activation response status: {status_code}")
            return status_code
        except Exception as err:
            logger.error(err)
    # Sentinel status: the request could not be completed.
    return 999


async def log_send(filename: str, ext: str = "log") -> tuple[bool, str]:
    """Upload a file to the metrics server's ``/upload_log`` endpoint.

    Args:
        filename: Path to the file to send.
        ext: Value sent in the ``ext`` form field (e.g. ``"log"`` or ``"pcap"``).

    Returns:
        ``(success, message)``: ``success`` is True when the server answered
        HTTP 200/201; ``message`` describes the outcome. Never raises.
    """
    installation_id = config.get("installation_id")
    if not installation_id:
        # FormData cannot serialize a missing user_id; report it clearly instead
        # of surfacing an opaque serialization error.
        logger.error("log file upload skipped: no installation_id configured")
        return False, "Error uploading log file: not enrolled (no installation_id)"

    upload_url = f"{METRICS_DOMAIN}/upload_log"
    try:
        async with ClientSession() as session:
            with open(filename, "rb") as f:
                data = FormData()
                data.add_field("user_id", installation_id)
                data.add_field("file", f, filename=Path(filename).name)
                data.add_field("ext", ext)

                async with session.post(upload_url, data=data) as response:
                    if response.status in (200, 201):
                        logger.info(f"log file upload successful: {filename}")
                        return True, "Log file uploaded successfully"
                    # Error bodies are not guaranteed to be JSON; fall back to
                    # text so the rejection reason is still reported.
                    try:
                        error = await response.json(content_type=None)
                    except ValueError:
                        error = await response.text(errors="replace")
                    logger.error(f"log file rejected: {error}")
                    return False, f"Log file rejected by server: {error}"
    except Exception as e:
        logger.error(f"Error uploading log file: {e}")
        return False, f"Error uploading log file: {e}"


async def _capture_packets(argv: list[str], duration: float) -> None:
    """Run the capture command *argv* for at most *duration* seconds.

    tcpdump exits normally on SIGTERM, flushing its ``-w`` output; SIGKILL
    (what ``subprocess.run`` uses on timeout) can drop the last buffered
    packets, so it is only a fallback. The process is never left running,
    even if this coroutine is cancelled.
    """
    proc = await asyncio.create_subprocess_exec(*argv)
    try:
        await asyncio.wait_for(proc.wait(), timeout=duration)
    except asyncio.TimeoutError:
        pass  # expected: the capture window elapsed
    finally:
        if proc.returncode is None:
            try:
                proc.terminate()
                await asyncio.wait_for(proc.wait(), timeout=5)
            except ProcessLookupError:
                pass  # exited between the timeout and the signal
            except asyncio.TimeoutError:
                proc.kill()
                await proc.wait()


async def send_tcpdump(ip: str, path: str, duration: float = 60) -> tuple[bool, str]:
    """Capture traffic for *ip* with tcpdump and upload the capture to the server.

    The capture runs for *duration* seconds without blocking the event loop,
    is written to ``<path>/tcpdump-<ip>-<YYYYmmdd-HHMMSS>.pcap`` and is then
    uploaded via ``log_send`` with the ``"pcap"`` extension.

    Args:
        ip: IPv4/IPv6 address to capture packets for. Validated first because
            it is embedded in the tcpdump filter expression.
        path: Directory in which the capture file is written.
        duration: Capture length in seconds.

    Returns:
        The ``(success, message)`` tuple from ``log_send``, or
        ``(False, message)`` if *ip* is invalid or tcpdump cannot be started.
    """
    try:
        ipaddress.ip_address(ip)
    except ValueError:
        return False, f"Invalid IP address: {ip!r}"

    logger.info(f"Starting tcpdump packet capture for {ip}")
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    filename = str(Path(path) / f"tcpdump-{ip}-{timestamp}.pcap")
    # Same filter expression as before; passed as one argv element, so no
    # shell is involved and no quoting is needed.
    capture_filter = (
        f"ip and not icmp and host {ip} or (src {ip} and not dst {ip}) "
        f"or (dst {ip} and not src {ip})"
    )
    argv = ["tcpdump", "-n", "-e", "-tttt", "-w", filename, capture_filter]
    try:
        await _capture_packets(argv, duration)
    except OSError as e:  # e.g. tcpdump not installed or not executable
        logger.error(f"tcpdump failed to start for {ip}: {e}")
        return False, f"Error running tcpdump: {e}"

    logger.info(f"Sending packet capture for {ip} to server")
    return await log_send(filename, "pcap")
