import logging
import os
import json
import re
import functools

import tldextract

from model.model import (
    db,
    IpNode,
    HostNode,
    CnameEdge,
    IpHostEdge,
    EntityNode,
    CompanyNode,
)
from shared.networking_helpers import get_country_from_ip_addr, get_reverse_dns
from model.ghostery import GhosteryDB

logger = logging.getLogger(__name__)

# Global cache for domain_map.json
domain_map_cache = None


@functools.lru_cache(maxsize=8192)
def get_ddg_data(domain: str) -> dict:
    global domain_map_cache
    if domain_map_cache is None:
        with open(os.path.join("data", "domains", "domain_map.json"), "r") as f:
            domain_map_cache = json.load(f)
    return domain_map_cache.get(domain)


def _resolve_domain_metadata(
    hostname: str, ghostery_db: GhosteryDB
) -> tuple[str | None, dict | None, dict | None]:
    """Look up ``(domain, entity_dict, company_dict)`` for ``hostname``.

    Sources are tried in order of preference:

    1. The Ghostery database (highest-quality list).
    2. The DuckDuckGo tracker-radar map (``get_ddg_data``), keyed on the
       domain that ``tldextract`` extracts from ``hostname``.
    3. A special case for bare addresses in Apple's 17.0.0.0/8 block, which
       reuse Ghostery's "apple.com" ownership data under the pseudo-domain
       "Other network addresses".

    If neither fallback matches, the (possibly empty) Ghostery values are
    returned unchanged. Any element of the result may be falsy.
    """
    domain, entity_dict, company_dict = ghostery_db.get_data_for_domain(hostname)
    if domain:
        return domain, entity_dict, company_dict

    ext = tldextract.extract(hostname)
    try:
        domain = ext.top_domain_under_public_suffix
    except AttributeError:
        # NOTE(review): presumably guards against tldextract releases that
        # lack this attribute. Confirm against the pinned version.
        logger.error("Failed to extract domain from hostname: %s", hostname)
        return None, None, None

    ddg_data = get_ddg_data(domain)
    if ddg_data:
        entity_dict = {
            "id": ext.domain,  # excludes TLD
            "name": domain,
        }
        company_dict = {
            "id": domain,
            "name": ddg_data.get("displayName"),
        }
    elif re.fullmatch(r"17\.\d{1,3}\.\d{1,3}\.\d{1,3}", ext.domain):
        # Special case for Apple (17.0.0.0/8).
        _, entity_dict, company_dict = ghostery_db.get_data_for_domain("apple.com")
        domain = "Other network addresses"

    return domain, entity_dict, company_dict


def record_hostname(hostname: str, ghostery_db: GhosteryDB) -> HostNode | None:
    """Record a hostname in the database, attaching its owning entity if known.

    Ownership data comes from ``_resolve_domain_metadata``. An entity is
    linked only when both an entity and a company description were found.

    Args:
        hostname: Hostname (or address string) to record.
        ghostery_db: Ghostery lookup used as the primary ownership source.

    Returns:
        The ``HostNode`` for ``hostname`` (created if needed), or ``None`` if
        ``hostname`` is empty or ``None``.
    """
    if not hostname:
        logger.error("record_hostname: hostname is None")
        return None

    domain, entity_dict, company_dict = _resolve_domain_metadata(
        hostname, ghostery_db
    )

    with db.atomic():
        entity = None
        if entity_dict and company_dict:
            entity, _company = record_entity_and_company(entity_dict, company_dict)

        hn = HostNode.get_or_create(hostname=hostname)[0]
        # Only overwrite fields we actually resolved; keep any prior values.
        if domain:
            hn.domain = domain
        if entity:
            hn.entity_id = entity.id
        hn.save()

    return hn


def record_entity_and_company(
    entity: dict, company: dict
) -> tuple[EntityNode, CompanyNode]:
    """Ensure a CompanyNode and an EntityNode exist for the given descriptions.

    Each row is looked up by the dict's ``"id"`` value. The remaining values
    are passed to ``get_or_create`` as ``defaults``, i.e. they are used when
    a row has to be created. The entity is linked to the company through
    ``company.get("id")``. Both operations share one ``db.atomic()`` block.

    Args:
        entity: Mapping read via ``.get`` for "id", "name", "category" and
            "website" (missing keys become ``None``).
        company: Mapping read via ``.get`` for "id", "name", "description",
            "website", "privacy_url" and "country".

    Returns:
        ``(entity_node, company_node)``.
    """
    with db.atomic():
        company_defaults = {
            "company_name": company.get("name"),
            "description": company.get("description"),
            "website": company.get("website"),
            "privacy_url": company.get("privacy_url"),
            "country": company.get("country"),
        }
        company_node, _ = CompanyNode.get_or_create(
            id=company.get("id"), defaults=company_defaults
        )

        entity_defaults = {
            "entity_name": entity.get("name"),
            "company_id": company.get("id"),
            "category": entity.get("category"),
            "website": entity.get("website"),
        }
        entity_node, _ = EntityNode.get_or_create(
            id=entity.get("id"), defaults=entity_defaults
        )

    return entity_node, company_node


def record_ip(ip_address: str, ghostery_db: GhosteryDB) -> IpNode:
    """Return the IpNode for ``ip_address``, creating it the first time it is seen.

    An existing node is returned as-is, and no lookups are repeated. A new node
    is created with the country from ``get_country_from_ip_addr``. A
    reverse-DNS lookup is then attempted; if it yields a hostname, that
    hostname is recorded with ``record_hostname`` and linked to the IP by a
    "reverse_dns" IpHostEdge. All of this runs inside one ``db.atomic()``.
    """
    with db.atomic():
        existing = IpNode.get_or_none(ip_addr=ip_address)
        if existing is not None:
            return existing

        new_node = IpNode.create(
            ip_addr=ip_address, country=get_country_from_ip_addr(ip_address)
        )

        ptr_name = get_reverse_dns(ip_address)
        host = None
        if not ptr_name:
            logger.error(f"record_ip: reverse lookup failed for {ip_address}")
        else:
            host = record_hostname(ptr_name, ghostery_db)

        # record_hostname may return None, so only link a real host node.
        if host:
            IpHostEdge.create(ip=new_node, host=host, edge_type="reverse_dns")

        return new_node


def record_dns_cname(
    source_hostname: str, target_hostname: str, ghostery_db: GhosteryDB
) -> CnameEdge | None:
    """Add a CNAME edge ``source_hostname -> target_hostname``.

    Both endpoints are recorded through ``record_hostname``. The edge is
    created only if an identical (src, dst) edge does not already exist.

    Args:
        source_hostname: Alias name (left-hand side of the CNAME).
        target_hostname: Canonical name the alias points to.
        ghostery_db: Ghostery lookup forwarded to ``record_hostname``.

    Returns:
        The existing or newly created ``CnameEdge``, or ``None`` if either
        hostname could not be recorded (e.g. it was empty).
    """
    with db.atomic():
        source_node = record_hostname(source_hostname, ghostery_db)
        target_node = record_hostname(target_hostname, ghostery_db)

        if not source_node or not target_node:
            logger.error("record_dns_cname: source_node or target_node is None")
            return None

        # One lookup-or-insert call, matching how the other recorders in this
        # module use get_or_create.
        edge, _created = CnameEdge.get_or_create(src=source_node, dst=target_node)
        return edge


def record_dns_a(
    hostname: str, ip_address: str, edge_type: str, ghostery_db: GhosteryDB
) -> IpHostEdge | None:
    """Add a DNS A edge between a hostname and an IP address.

    The hostname is recorded via ``record_hostname`` and the IP via
    ``record_ip``. ``record_ip`` runs even if the hostname could not be
    recorded, so the IP node is still stored in that case. The edge is
    created only if an identical (ip, host, edge_type) edge does not exist.

    Args:
        hostname: Name that resolved to ``ip_address``.
        ip_address: Address string passed to ``record_ip``.
        edge_type: Label stored on the IpHostEdge.
        ghostery_db: Ghostery lookup forwarded to the recorders.

    Returns:
        The existing or newly created ``IpHostEdge``, or ``None`` if either
        endpoint could not be recorded.
    """
    with db.atomic():
        host_node = record_hostname(hostname, ghostery_db)
        ip_node = record_ip(ip_address, ghostery_db)

        if not host_node or not ip_node:
            logger.error("record_dns_a: host_node or ip_node is None")
            return None

        # One lookup-or-insert call, matching how the other recorders in this
        # module use get_or_create.
        edge, _created = IpHostEdge.get_or_create(
            ip=ip_node,
            host=host_node,
            edge_type=edge_type,
        )
        return edge
