from __future__ import annotations

import asyncio
import concurrent.futures
import datetime
import sys
import time
import traceback
import threading
import requests
import logging
from typing import Any, TypeVar
from collections.abc import Coroutine, Callable, Awaitable

# Module-level logger, named after this module; used by the helpers below.
logger = logging.getLogger(__name__)

# Type variables for generic function types: T is the return type of the
# callables passed to add_cpu_task(), U the awaited result type of the
# coroutine functions passed to add_io_task().
T = TypeVar("T")
U = TypeVar("U")


class PeriodicTaskRunner:
    """
    Runs registered tasks repeatedly and concurrently on one asyncio event loop.

    Register tasks with add_cpu_task() / add_io_task(), then await start().
    Each task loops for as long as ``running`` is True: it runs once, then
    sleeps ``sleep_time`` seconds. Call stop() to end the loops and shut down
    the worker pool.
    """

    def __init__(self, max_workers: int = 3) -> None:
        """
        Args:
            max_workers: Size of the thread pool shared by all CPU tasks.
        """
        self.max_workers: int = max_workers
        # Shared by every CPU task. This is a *thread* pool, so pure-Python
        # CPU-bound work still contends for the GIL.
        self.executor: concurrent.futures.ThreadPoolExecutor = (
            concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
        )
        # Registered tasks as (runner coroutine method, sleep_time, task_func,
        # args tuple, kwargs dict); consumed by start().
        self.periodic_tasks: list[
            tuple[Callable[..., Any], int, Callable[..., Any], Any, Any]
        ] = []
        # Loop condition for every periodic task: set by start(), cleared by stop().
        self.running: bool = False

    def add_cpu_task(
        self,
        task_func: Callable[..., T],
        sleep_time: int,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Register a blocking (CPU-bound) callable to be run periodically.

        Args:
            task_func: Plain (non-async) callable, run in the thread pool.
            sleep_time: Seconds to sleep between executions.
            args: Positional arguments passed to task_func on every run.
            kwargs: Keyword arguments passed to task_func on every run.
        """
        self.periodic_tasks.append(
            (self.run_cpu_task_periodically, sleep_time, task_func, args, kwargs)
        )

    def add_io_task(
        self,
        task_func: Callable[..., Awaitable[U]],
        sleep_time: int,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Register an async (IO-bound) callable to be run periodically.

        Args:
            task_func: Coroutine function, awaited on the event loop.
            sleep_time: Seconds to sleep between executions.
            args: Positional arguments passed to task_func on every run.
            kwargs: Keyword arguments passed to task_func on every run.
        """
        self.periodic_tasks.append(
            (self.run_io_task_periodically, sleep_time, task_func, args, kwargs)
        )

    async def run_cpu_task_periodically(
        self,
        task_func: Callable[..., T],
        sleep_time: int,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Runs a CPU-bound task periodically in the runner's ThreadPoolExecutor.

        Args:
            task_func: The blocking function to run periodically
            sleep_time: Time in seconds between task executions
            args: Positional arguments to pass to task_func
            kwargs: Keyword arguments to pass to task_func
        """
        import functools

        loop = asyncio.get_running_loop()
        # run_in_executor() forwards positional arguments only, so bind the
        # keyword arguments up front with functools.partial.
        call = functools.partial(task_func, *args, **kwargs)
        while self.running:
            logger.info("Starting CPU task %s", task_func)
            await loop.run_in_executor(self.executor, call)
            logger.info("Sleeping CPU task %s for %ss", task_func, sleep_time)
            await asyncio.sleep(sleep_time)

    async def run_io_task_periodically(
        self,
        task_func: Callable[..., Awaitable[U]],
        sleep_time: int,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Runs an IO-bound task periodically using asyncio.

        Args:
            task_func: The async function to run periodically
            sleep_time: Time in seconds between task executions
            args: Positional arguments to pass to task_func
            kwargs: Keyword arguments to pass to task_func
        """
        while self.running:
            logger.info("Starting IO task %s", task_func)
            await task_func(*args, **kwargs)
            logger.info("Sleeping IO task %s for %ss", task_func, sleep_time)
            await asyncio.sleep(sleep_time)

    async def start(self) -> None:
        """
        Start every registered task and wait until all of their loops end.

        The loops end after stop() clears ``running``. If any task raises, the
        exception propagates out of this call (asyncio.gather semantics).
        """
        logger.info("Task runner starting")
        self.running = True

        tasks: list[Coroutine[Any, Any, None]] = [
            runner_func(task_func, sleep_time, *args, **kwargs)
            for runner_func, sleep_time, task_func, args, kwargs in self.periodic_tasks
        ]

        # Run all tasks concurrently
        await asyncio.gather(*tasks)

    def stop(self) -> None:
        """
        Signal all periodic tasks to stop and shut down the thread pool.

        Tasks exit at their next loop check. shutdown(wait=True) blocks until
        in-flight CPU calls finish. The executor cannot be reused afterwards,
        so CPU tasks cannot be scheduled again on this instance.
        """
        self.running = False
        self.executor.shutdown(wait=True)


class SafeLoopThread:
    """
    A wrapper to repeatedly execute a function in a background thread; if the
    function crashes, the traceback is reported and the loop is restarted.

    Usage:

    def my_func(a, b=1):
        pass

    SafeLoopThread(my_func, args=['a'], kwargs={'b': 2}, sleep_time=1)

    The thread starts from the constructor and loops forever. By default it
    is NOT a daemon thread, so it keeps the interpreter alive. Pass
    ``daemon=True`` to let the process exit without waiting for it.

    TODO: Rewrite this as a decorator.

    """

    def __init__(
        self,
        func: Callable[..., Any],
        args: list[Any] | None = None,
        kwargs: dict[Any, Any] | None = None,
        sleep_time: int = 1,
        daemon: bool = False,
    ) -> None:
        """
        Args:
            func: Callable invoked as func(*args, **kwargs) on every iteration.
            args: Positional arguments for func (defaults to none).
            kwargs: Keyword arguments for func (defaults to none).
            sleep_time: Seconds to sleep between calls and before restarting
                after a crash. A falsy value means no sleep.
            daemon: Whether the background thread is a daemon thread.
        """
        self._func = func
        # None defaults avoid sharing one mutable list/dict across instances.
        self._func_args = args if args is not None else []
        self._func_kwargs = kwargs if kwargs is not None else {}
        self._sleep_time = sleep_time

        # Keep a handle so the thread can be inspected or joined later.
        self._thread = threading.Thread(
            target=self._execute_repeated_func_safe, daemon=daemon
        )
        self._thread.start()

    def _repeat_func(self) -> None:
        """Call the function forever, sleeping between calls; never returns normally."""

        logger.info(
            "[SafeLoopThread] Starting %s %s %s",
            self._func,
            self._func_args,
            self._func_kwargs,
        )

        while True:
            self._func(*self._func_args, **self._func_kwargs)
            if self._sleep_time:
                time.sleep(self._sleep_time)

    def _execute_repeated_func_safe(self) -> None:
        """Run _repeat_func, reporting and restarting after every exception."""

        while True:
            try:
                self._repeat_func()

            except Exception as e:
                err_msg = "".join(
                    [
                        "=" * 80 + "\n",
                        "Time: %s\n" % datetime.datetime.today(),
                        "Function: %s %s %s\n"
                        % (self._func, self._func_args, self._func_kwargs),
                        "Exception: %s\n" % e,
                        traceback.format_exc() + "\n\n\n",
                    ]
                )

                # Reported on stderr as well as the logger, so crashes are
                # visible even when logging is not configured.
                sys.stderr.write(err_msg + "\n")
                logger.warning(err_msg)

                # Same falsy guard as _repeat_func: sleep_time=0/None means
                # restart immediately instead of crashing the thread.
                if self._sleep_time:
                    time.sleep(self._sleep_time)


def get_os() -> str:
    """
    Return the current operating system as 'mac' or 'linux'.

    The result is derived from ``sys.platform``.

    Raises:
        RuntimeError: "Windows is not supported." on Windows, or
            "Unsupported operating system." on any other platform.
    """

    os_platform = sys.platform

    if os_platform.startswith("darwin"):
        return "mac"

    if os_platform.startswith("linux"):
        return "linux"

    # Windows is detected explicitly only to give a clearer error message.
    if os_platform.startswith("win"):
        raise RuntimeError("Windows is not supported.")

    raise RuntimeError("Unsupported operating system.")


def http_request(
    method: str = "get",
    field_to_extract: str = "",
    args: list[Any] | None = None,
    kwargs: dict[Any, Any] | None = None,
) -> Any:
    """
    Issue an HTTP GET or POST with ``requests`` and parse the JSON response.

    The response must have status 200 and be a JSON object whose "success"
    field is truthy.

    Args:
        method: "get" or "post".
        field_to_extract: Key to return from the response object. If empty,
            the response is only validated.
        args: Positional arguments forwarded to requests.get/post.
        kwargs: Keyword arguments forwarded to requests.get/post.

    Returns:
        response[field_to_extract], or "" when field_to_extract is empty.

    Raises:
        RuntimeError: if method is neither "get" nor "post".
        IOError: if the request fails to complete, the status code is not 200,
            the body is not a JSON object, "success" is missing or falsy, or
            field_to_extract is absent. Each failure is logged, except the
            expected "No data for this ip_addr" error.
    """
    if method not in ("get", "post"):
        raise RuntimeError("Unsupported method: %s" % method)

    # None defaults avoid shared mutable default arguments.
    args = [] if args is None else args
    kwargs = {} if kwargs is None else kwargs

    # Make the request.
    # NOTE(review): no default timeout is applied; unless the caller passes
    # kwargs={"timeout": ...}, requests may block indefinitely.
    try:
        send = requests.get if method == "get" else requests.post
        r = send(*args, **kwargs)
    except Exception as ex:
        logger.warning(
            "[http_request] Error: request with args %s failed to complete: %s",
            args,
            ex,
        )
        raise IOError(f"request with args {args} failed to complete: {ex}") from ex

    # Check the status code
    if r.status_code != 200:
        logger.warning(
            "[http_request] Error: request with args %s failed with status code %s",
            args,
            r.status_code,
        )
        raise IOError(
            f"request with args {args} failed with status code {r.status_code}"
        )

    # Parse the response as JSON
    try:
        response_dict = r.json()
    except Exception as ex:
        logger.warning(
            "[http_request] Error: unable to parse the response as JSON: %s - %s",
            ex,
            r.text,
        )
        raise IOError(f"unable to parse the response as JSON: {ex}") from ex

    # A JSON array/scalar has no "success" field: report as IOError rather
    # than letting a TypeError/KeyError escape.
    if not isinstance(response_dict, dict):
        logger.warning(
            "[http_request] Error: request with args %s did not return a JSON object: %r",
            args,
            response_dict,
        )
        raise IOError(f"request with args {args} did not return a JSON object")

    # Check if success
    if not response_dict.get("success"):
        err_msg = response_dict.get("error", "")
        # Ignore the most common error (which is not an error because it takes
        # time for the backend to analyze an IP address) so that we won't
        # overwhelm the log with this message
        if err_msg != "No data for this ip_addr":
            logger.warning(
                "[http_request] Error: request with args %s did not succeed with error message: %s",
                args,
                err_msg,
            )
        raise IOError(f"request with args {args} did not succeed: {err_msg}")

    # Return the field
    if not field_to_extract:
        return ""

    if field_to_extract not in response_dict:
        logger.warning(
            "[http_request] Error: request with args %s did not return the field %s",
            args,
            field_to_extract,
        )
        raise IOError(
            f"request with args {args} did not return the field {field_to_extract}"
        )
    return response_dict[field_to_extract]
