import logging
import threading

import model.model as model

# Module-level logger named after this module.
# NOTE(review): nothing in the code shown here logs through it — confirm it is
# still needed or used elsewhere.
logger = logging.getLogger(__name__)


class ARPCache(object):
    """Stores a mapping between IP and MAC addresses"""

    def __init__(self) -> None:
        # Initialize the variables
        self._ip_mac_cache: dict[str, str] = {}
        self._mac_ip_cache: dict[str, str] = {}
        self._lock = threading.Lock()
        # Load previous ARP cache from the database
        with model.db.atomic():
            for device in model.Device.select():
                self._ip_mac_cache[device.ip_addr] = device.mac_addr
                self._mac_ip_cache[device.mac_addr] = device.ip_addr

    def update(self, ip_addr: str, mac_addr: str) -> None:
        """Updates the cache with the given IP and MAC addresses."""
        with self._lock:
            self._ip_mac_cache[ip_addr] = mac_addr
            self._mac_ip_cache[mac_addr] = ip_addr

    def get_mac_addr(self, ip_addr: str) -> str | None:
        """Returns the MAC address for the given IP address."""
        with self._lock:
            return self._ip_mac_cache.get(ip_addr)

    def get_ip_addr(self, mac_addr: str) -> str | None:
        """Returns the IP address for the given MAC address."""
        with self._lock:
            return self._mac_ip_cache.get(mac_addr)
