import datetime
import logging

from peewee import (
    Model,
    BooleanField,
    DateTimeField,
    ForeignKeyField,
    IntegerField,
    TextField,
)
from playhouse.sqlite_ext import SqliteExtDatabase, JSONField  # pyright: ignore[reportMissingImports, reportUnknownVariableType]
from playhouse.shortcuts import ThreadSafeDatabaseMetadata  # pyright: ignore[reportMissingImports, reportUnknownVariableType]
# from playhouse.migrate import SqliteMigrator, migrate  # pyright: ignore[reportMissingImports, reportUnknownVariableType]

import shared.config as config
import shared.system_stats as system_stats

logger = logging.getLogger(__name__)

# Database schema version.
# If you change the schema, you must increment this version number (and add a
# matching step in db_initialize_tables() if existing data must be migrated or
# discarded).
DB_VERSION: int = 101

# SQLite pragmas applied to every connection.
# https://docs.peewee-orm.com/en/latest/peewee/database.html#recommended-settings
db_pragmas: dict[str, str | int] = {
    # https://docs.peewee-orm.com/en/latest/peewee/database.html#recommended-settings
    "journal_mode": "wal",  # write-ahead logging
    "cache_size": -1 * 64000,  # 64MB (a negative cache_size is a size in KiB rather than pages)
    "foreign_keys": 1,  # enforce ForeignKeyField constraints
    "ignore_check_constraints": 0,  # keep CHECK constraints enforced
    # 0 == OFF: no fsync. Faster, but per the SQLite docs recent commits can be
    # lost (or the file damaged) on an OS crash / power loss.
    "synchronous": 0,
}

# Create the database; initialization is deferred (the path is None here and
# must be supplied later via db.init(...) before use).
# https://docs.peewee-orm.com/en/latest/peewee/database.html#deferring-initialization
db: SqliteExtDatabase = SqliteExtDatabase(None, pragmas=db_pragmas)  # pyright: ignore[reportUnknownVariableType]


class BaseModel(Model):
    """Base class for every model in this module.

    Binds all models to the module-level ``db`` (deferred-initialized).
    """

    class Meta:
        database: SqliteExtDatabase = db  # pyright: ignore[reportUnknownVariableType]
        # https://docs.peewee-orm.com/en/latest/peewee/database.html#thread-safety-and-multiple-databases
        # Keeps each model's database binding thread-local, so the database can be
        # swapped at run time without races between threads (see the docs above).
        model_metadata_class: ThreadSafeDatabaseMetadata = ThreadSafeDatabaseMetadata  # pyright: ignore[reportUnknownVariableType]


class Device(BaseModel):
    """A device discovered on the network.

    Fields are grouped by origin: addresses auto-populated by the Inspector,
    the "preferred" display info, user-set overrides, ARP-spoofing state,
    first/last-seen timestamps, and raw discovery data (MAC vendor,
    DHCP/mDNS hostnames, SSDP, Home Assistant).
    """

    # Auto populated by Inspector
    mac_addr: TextField | None = TextField(index=True, null=True)
    ip_addr: TextField | None = TextField(index=True, null=True)
    ip6s: JSONField | None = JSONField(null=True)  # pyright: ignore[reportUnknownVariableType]

    # Preferred device info
    preferred_name: TextField | None = TextField(null=True)
    preferred_model: TextField | None = TextField(null=True)
    preferred_mfg: TextField | None = TextField(null=True)

    # User-set names
    user_name: TextField | None = TextField(null=True)
    user_model: TextField | None = TextField(null=True)
    user_mfg: TextField | None = TextField(null=True)

    # arp spoofing
    is_iot: BooleanField | None = BooleanField(default=False)
    is_iot_user_override: BooleanField | None = BooleanField(default=False)
    is_arp_spoofable: BooleanField | None = BooleanField(null=True)
    is_arp_spoofed: BooleanField | None = BooleanField(null=True)
    arp_spoofing_error: IntegerField = IntegerField(
        default=0
    )  # 0: no error, 1: error, 2: error, user has overridden
    arp_spoofing_start_time: DateTimeField | None = DateTimeField(null=True)

    # time seen
    # Pass the callable, not its result: peewee calls it for every new row.
    # Calling now() here would freeze the default at module import time, so
    # every device would get the process start time.
    first_seen: DateTimeField | None = DateTimeField(default=datetime.datetime.now)
    last_seen: DateTimeField | None = DateTimeField(default=datetime.datetime.now)

    # discovery data
    # all_names: JSONField | None = JSONField(null=True)  # pyright: ignore[reportUnknownVariableType]
    # all_manufacturers: JSONField | None = JSONField(null=True)  # pyright: ignore[reportUnknownVariableType]
    mac_vendor: TextField | None = TextField(null=True)
    dhcp_hostname: TextField | None = TextField(null=True)
    mdns_hostname: TextField | None = TextField(index=True, null=True)

    # SSDP data
    ssdp_name: TextField | None = TextField(null=True)
    ssdp_model: TextField | None = TextField(null=True)
    ssdp_mfg: TextField | None = TextField(null=True)
    ssdp_data: JSONField | None = JSONField(null=True)  # pyright: ignore[reportUnknownVariableType]

    # home assistant data
    ha_name: TextField | None = TextField(null=True)
    ha_model: TextField | None = TextField(null=True)
    ha_mfg: TextField | None = TextField(null=True)


# IP/hostname/company/entity graph
# IP address/hostname mapping is many-to-many, so we model IpNode and IpHostEdge
# CNAMEs may also be a many-to-many relationship, so we model CnameEdge
# Each hostname can then have a matching entity, EntityNode (many-to-one relationship)
# And each entity can have a matching company, CompanyNode (many-to-one relationship)


class CompanyNode(BaseModel):
    """Represents a larger company, e.g. Google.

    Top of the ownership graph: many EntityNode rows may reference one company
    through ``EntityNode.company_id`` (backref ``entities``).
    """

    # Explicit text primary key (replaces peewee's default auto-increment id)
    id = TextField(index=True, primary_key=True)
    company_name = TextField(null=True)
    description = TextField(null=True)
    website = TextField(null=True)
    # presumably the company's privacy-policy URL (inferred from name — confirm)
    privacy_url = TextField(null=True)
    country = TextField(null=True)
    # Callable default: evaluated per row at insert time
    created_at = DateTimeField(default=datetime.datetime.now)


class EntityNode(BaseModel):
    """Represents an entity, e.g. Google Play Services (corresponds to trackers in the Ghostery dataset).

    Many entities may belong to one CompanyNode; many HostNode rows may point
    at one entity (backref ``hosts``).
    """

    # Explicit text primary key (replaces peewee's default auto-increment id)
    id = TextField(index=True, primary_key=True)
    # many-to-one relationship with CompanyNode, so we use ForeignKeyField
    company_id = ForeignKeyField(CompanyNode, backref="entities", index=True)
    entity_name = TextField(null=True, index=True)
    category = TextField(null=True, index=True)
    website = TextField(null=True)
    # Callable default: evaluated per row at insert time
    created_at = DateTimeField(default=datetime.datetime.now)


# IP addresses map to one or more hostnames
class HostNode(BaseModel):
    """Represents a hostname node in the DNS/ownership graph.

    Hostnames connect to IP addresses through IpHostEdge, to other hostnames
    through CnameEdge, and optionally to an owning EntityNode.
    """

    hostname = TextField(index=True)
    # TLD/public suffix/matching domain on entity list
    # e.g. play.google.com, reddit-image.s3.amazonaws.com
    domain = TextField(null=True)
    # many-to-one relationship with EntityNode, so we use ForeignKeyField.
    # Nullable — presumably when no entity matched this host (confirm).
    entity_id = ForeignKeyField(EntityNode, backref="hosts", index=True, null=True)
    # Callable default: evaluated per row at insert time
    created_at = DateTimeField(default=datetime.datetime.now)


class IpNode(BaseModel):
    """Represents an IP address node in the DNS/ownership graph.

    Linked to hostnames via IpHostEdge and referenced by Flow.src_host/dst_host.
    """

    ip_addr = TextField(index=True)
    country = TextField(null=True)
    # JSON value; its shape is not defined here
    # (presumably a list of category labels — TODO confirm)
    categories = JSONField(null=True)
    # Callable default: evaluated per row at insert time
    created_at = DateTimeField(default=datetime.datetime.now)


class IpHostEdge(BaseModel):
    """Represents an IP-to-hostname edge in the DNS/ownership graph.

    Note the backref names: ``IpNode.ips`` and ``HostNode.hosts`` both yield
    IpHostEdge rows, not IpNode/HostNode rows.
    """

    ip = ForeignKeyField(IpNode, backref="ips", index=True)
    host = ForeignKeyField(HostNode, backref="hosts", index=True)
    # Edge type: 'forward_dns', 'reverse_dns', 'cname'
    # NOTE(review): hostname-to-hostname CNAME links are also modelled by
    # CnameEdge; confirm what a 'cname' edge_type means here.
    edge_type = TextField(index=True)
    # Callable default: evaluated per row at insert time
    created_at = DateTimeField(default=datetime.datetime.now)


class CnameEdge(BaseModel):
    """Represents a CNAME edge between two HostNode rows in the DNS/ownership graph"""

    # presumably the alias (the name that has the CNAME record) — TODO confirm direction
    src = ForeignKeyField(HostNode, backref="cname_src", index=True)
    # presumably the CNAME target — TODO confirm direction
    dst = ForeignKeyField(HostNode, backref="cname_dst", index=True)
    # Callable default: evaluated per row at insert time
    created_at = DateTimeField(default=datetime.datetime.now)


# Legacy Host model - deprecated; use HostNode/IpHostEdge instead
# class Host(BaseModel):
#     ip_addr = TextField(null=True, index=True)
#     hostname = TextField(null=True)
#     reg_domain = TextField(null=True, index=True)
#     owner = TextField(null=True)
#     owner_url = TextField(null=True)
#     owner_privacy = TextField(null=True)
#     categories = JSONField(null=True)
#     country = TextField(null=True)
#     ip_src = TextField(null=True)
#     hostname_src = TextField(null=True)


class Flow(BaseModel):
    """A network flow record between a source and a destination endpoint.

    Each endpoint is linked both by foreign key (to a Device and to an IpNode)
    and by denormalized MAC/IP/port copies; all of these are nullable.
    """

    # Device row for each endpoint, when one is known
    src_device = ForeignKeyField(Device, backref="flow_src", index=True, null=True)
    dst_device = ForeignKeyField(Device, backref="flow_dst", index=True, null=True)
    # Despite the "host" names, these reference IpNode rows
    src_host = ForeignKeyField(IpNode, backref="flow_src_host", index=True, null=True)
    dst_host = ForeignKeyField(IpNode, backref="flow_dst_host", index=True, null=True)

    # presumably the first/last packet times of the flow — TODO confirm
    start_ts = DateTimeField(null=True, index=True)
    end_ts = DateTimeField(null=True, index=True)
    # Raw endpoint identifiers copied onto the flow row
    src_device_mac_addr = TextField(null=True)
    dst_device_mac_addr = TextField(null=True)
    src_ip_addr = TextField(null=True)
    dst_ip_addr = TextField(null=True)
    src_port = IntegerField(null=True)
    dst_port = IntegerField(null=True)
    protocol = TextField(null=True)
    # Counters; all start at 0
    byte_count = IntegerField(default=0)
    packet_count = IntegerField(default=0)
    tcp_retransmit = IntegerField(default=0)
    tcp_rst = IntegerField(default=0)


class MDNS(BaseModel):
    """An mDNS record attached to a Device (reachable as ``Device.mdns``)."""

    device = ForeignKeyField(Device, backref="mdns")
    device_manufacturer = TextField(null=True)
    device_model = TextField(null=True)
    friendly_name = TextField(null=True)
    hostname = TextField(null=True)
    service = TextField(null=True)
    full_name = TextField(null=True)
    # Stored as plain text; presumably serialized record properties — TODO confirm format
    properties = TextField(null=True)
    # presumably the discovery integration that produced this record — TODO confirm
    integration = TextField(null=True)


# Used to store devices where we have an mDNS record with an IPv6 address, but no IPv4 address to match it to
class PendingDevice(BaseModel):
    """mDNS-derived device info that has IPv6 address(es) but no IPv4 match yet."""

    # IPv6 addresses, stored as JSON; the index covers the whole stored value
    ip6s = JSONField(index=True, null=True)
    device_manufacturer = TextField(null=True)
    friendly_name = TextField(null=True)
    hostname = TextField(index=True, null=True)
    full_name = TextField(null=True)
    service = TextField(null=True)
    # Stored as plain text; presumably serialized record properties — TODO confirm format
    properties = TextField(null=True)
    integration = TextField(null=True)


def db_initialize_tables() -> None:
    """Create any missing tables and run schema-version migrations.

    The schema version recorded by a previous run is read from
    ``config["db_version"]``:

    * ``None`` (fresh install): tables are simply created.
    * older than ``DB_VERSION``: scanner state is reset, then every migration
      step whose threshold the stored version falls below is applied, oldest
      first. Tables dropped by a step are recreated empty afterwards.
    * newer than ``DB_VERSION`` (the app was downgraded): a warning is logged
      and no migration runs.

    In every case, missing tables are then created and ``db_version`` is set
    to ``DB_VERSION``. The database work runs inside one ``db.atomic()`` block.
    """
    # Single list of every model managed here, so the "drop everything"
    # migration and table creation cannot drift apart. peewee resolves
    # foreign-key dependencies itself when ordering drops and creates.
    all_models = [
        Device,
        Flow,
        MDNS,
        PendingDevice,
        IpNode,
        HostNode,
        EntityNode,
        CompanyNode,
        IpHostEdge,
        CnameEdge,
    ]

    existing_db_version: int | None = config.get("db_version")

    if existing_db_version is not None and existing_db_version > DB_VERSION:
        # db_version is still overwritten with the older DB_VERSION below, so a
        # later upgrade will re-run any migrations newer than DB_VERSION.
        logger.warning(
            "Database version %s is newer than this build's version %s; skipping migrations.",
            existing_db_version,
            DB_VERSION,
        )

    with db.atomic():
        if existing_db_version is not None and existing_db_version < DB_VERSION:
            # Stored schema is older than this build: apply each step below
            # whose threshold the stored version falls under.
            logger.info(
                "Database version %s less than new version %s. Running migrations.",
                existing_db_version,
                DB_VERSION,
            )

            # Always reset the scanner on a version update; this forces the
            # "scanning for devices" page to be shown on first launch.
            config.set("first_scan_complete", False)

            if existing_db_version < 15:
                # Fixed issue #100: clean up bad data by re-baselining the
                # packet counters on the current Ip/ForwDatagrams value.
                netstat = system_stats.netstat()
                packets_forwarded_total = int(netstat["Ip"]["ForwDatagrams"])
                config.set("packets_forwarded_at_db_init", packets_forwarded_total)
                config.set("packet_counts", None)
            if existing_db_version < 100:
                # Incompatible older schema: drop every table (recreated empty below).
                db.drop_tables(all_models)  # pyright: ignore[reportUnknownMemberType]
            if existing_db_version < 101:
                # Wipe the flow table, as TCP retransmits were counted incorrectly.
                db.drop_tables([Flow])  # pyright: ignore[reportUnknownMemberType]
            # Future migrations go here. Example of an in-place column addition
            # (re-enable the playhouse.migrate import at the top of the file):
            # if existing_db_version < 102:
            #     migrator: SqliteMigrator = SqliteMigrator(db) # pyright: ignore[reportUnknownVariableType]
            #     migrate(
            #         migrator.add_column("Device", "arp_spoofing_start_time", DateTimeField(null=True)), # pyright: ignore[reportUnknownMemberType]
            #         migrator.add_column("Device", "arp_spoofing_error", IntegerField(default=0)), # pyright: ignore[reportUnknownMemberType]
            #     )

        # Create any tables that don't exist yet (including ones dropped above).
        db.create_tables(all_models)  # pyright: ignore[reportUnknownMemberType]
        config.set("db_version", DB_VERSION)
