"""File I/O operations for HAProxy MCP Server. Most data access is now delegated to db.py (SQLite). This module retains atomic file writes, map file I/O for HAProxy, and provides backward-compatible function signatures. """ import fcntl import os import tempfile from contextlib import contextmanager from typing import Any, Generator, Optional from .config import ( MAP_FILE, WILDCARDS_MAP_FILE, REMOTE_MODE, logger, ) from .validation import domain_to_backend @contextmanager def file_lock(lock_path: str) -> Generator[None, None, None]: """Acquire exclusive file lock for atomic operations. In REMOTE_MODE, locking is skipped (single-writer assumption with atomic writes on the remote host). Args: lock_path: Path to the lock file (typically config_file.lock) Yields: None - the lock is held for the duration of the context """ if REMOTE_MODE: yield return with open(lock_path, 'w') as lock_file: fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) try: yield finally: fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) def atomic_write_file(file_path: str, content: str) -> None: """Write content to file atomically using temp file + rename. Args: file_path: Target file path content: Content to write Raises: IOError: If write fails """ if REMOTE_MODE: from .ssh_ops import remote_write_file remote_write_file(file_path, content) return dir_path = os.path.dirname(file_path) fd = None temp_path = None try: fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.tmp.') with os.fdopen(fd, 'w', encoding='utf-8') as f: fd = None # fd is now owned by the file object f.write(content) os.rename(temp_path, file_path) temp_path = None # Rename succeeded except OSError as e: raise IOError(f"Failed to write {file_path}: {e}") from e finally: if fd is not None: try: os.close(fd) except OSError: pass if temp_path is not None: try: os.unlink(temp_path) except OSError: pass def _read_map_file(file_path: str) -> list[tuple[str, str]]: """Read a single map file and return list of (domain, backend) tuples. Args: file_path: Path to the map file Returns: List of (domain, backend) tuples from the map file """ entries = [] try: content = _read_file(file_path) for line in content.splitlines(): line = line.strip() if not line or line.startswith("#"): continue parts = line.split() if len(parts) >= 2: entries.append((parts[0], parts[1])) except FileNotFoundError: logger.debug("Map file not found: %s", file_path) return entries def _read_file(file_path: str) -> str: """Read a file locally or remotely based on REMOTE_MODE. Args: file_path: Path to the file Returns: File contents as string Raises: FileNotFoundError: If file doesn't exist """ if REMOTE_MODE: from .ssh_ops import remote_read_file return remote_read_file(file_path) with open(file_path, "r", encoding="utf-8") as f: try: fcntl.flock(f.fileno(), fcntl.LOCK_SH) except OSError as e: logger.debug("File locking not supported for %s: %s", file_path, e) try: return f.read() finally: try: fcntl.flock(f.fileno(), fcntl.LOCK_UN) except OSError: pass def get_map_contents() -> list[tuple[str, str]]: """Get all domain-to-backend mappings from SQLite. Returns: List of (domain, backend) tuples including wildcards. """ from .db import db_get_map_contents return db_get_map_contents() def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]: """Split entries into exact domains and wildcards. Args: entries: List of (domain, backend) tuples Returns: Tuple of (exact_entries, wildcard_entries) """ exact = [] wildcards = [] for domain, backend in entries: if domain.startswith("."): wildcards.append((domain, backend)) else: exact.append((domain, backend)) return exact, wildcards def save_map_file(entries: list[tuple[str, str]]) -> None: """Sync map files from the database. Regenerates domains.map and wildcards.map from the current database state. The entries parameter is ignored (kept for backward compatibility during transition). Raises: IOError: If map files cannot be written. """ from .db import sync_map_files sync_map_files() def get_domain_backend(domain: str) -> Optional[str]: """Look up the backend for a domain from SQLite (O(1)). Args: domain: The domain to look up Returns: Backend name if found, None otherwise """ from .db import db_get_domain_backend return db_get_domain_backend(domain) def is_legacy_backend(backend: str) -> bool: """Check if backend is a legacy static backend (not a dynamic pool). Pool backends: pool_1, pool_2, ..., pool_100 (dynamic, zero-reload) Legacy backends: {domain}_backend (static, requires reload) Args: backend: Backend name to check. Returns: True if legacy backend, False if pool backend. """ return not backend.startswith("pool_") def get_legacy_backend_name(domain: str) -> str: """Convert domain to legacy backend name format. Args: domain: Domain name Returns: Legacy backend name (e.g., 'api_example_com_backend') """ return f"{domain_to_backend(domain)}_backend" def get_backend_and_prefix(domain: str) -> tuple[str, str]: """Look up backend and determine server name prefix for a domain. Args: domain: The domain name to look up Returns: Tuple of (backend_name, server_prefix) Raises: ValueError: If domain cannot be mapped to a valid backend """ backend = get_domain_backend(domain) if not backend: backend = get_legacy_backend_name(domain) if backend.startswith("pool_"): server_prefix = backend else: server_prefix = domain_to_backend(domain) return backend, server_prefix def load_servers_config() -> dict[str, Any]: """Load servers configuration from SQLite. Returns: Dictionary with server configurations (legacy format compatible). """ from .db import db_load_servers_config return db_load_servers_config() def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> None: """Add server configuration to persistent storage. Args: domain: Domain name slot: Server slot (1 to MAX_SLOTS) ip: Server IP address http_port: HTTP port """ from .db import db_add_server, sync_servers_json db_add_server(domain, slot, ip, http_port) sync_servers_json() def remove_server_from_config(domain: str, slot: int) -> None: """Remove server configuration from persistent storage. Args: domain: Domain name slot: Server slot to remove """ from .db import db_remove_server, sync_servers_json db_remove_server(domain, slot) sync_servers_json() def remove_domain_from_config(domain: str) -> None: """Remove domain from persistent config (servers + domain entry). Args: domain: Domain name to remove """ from .db import db_remove_domain_servers, sync_servers_json db_remove_domain_servers(domain) sync_servers_json() def get_shared_domain(domain: str) -> Optional[str]: """Get the parent domain that this domain shares a pool with. Args: domain: Domain name to check Returns: Parent domain name if sharing, None otherwise """ from .db import db_get_shared_domain return db_get_shared_domain(domain) def add_shared_domain_to_config(domain: str, shares_with: str) -> None: """Add a domain that shares a pool with another domain. Args: domain: New domain name shares_with: Existing domain to share pool with """ from .db import db_add_shared_domain, sync_servers_json db_add_shared_domain(domain, shares_with) sync_servers_json() def get_domains_sharing_pool(pool: str) -> list[str]: """Get all domains that use a specific pool. Args: pool: Pool name (e.g., 'pool_5') Returns: List of domain names using this pool """ from .db import db_get_domains_sharing_pool return db_get_domains_sharing_pool(pool) def is_shared_domain(domain: str) -> bool: """Check if a domain is sharing another domain's pool. Args: domain: Domain name to check Returns: True if domain has _shares reference, False otherwise """ from .db import db_is_shared_domain return db_is_shared_domain(domain) # Certificate configuration functions def load_certs_config() -> list[str]: """Load certificate domain list from SQLite. Returns: Sorted list of domain names. """ from .db import db_load_certs return db_load_certs() def add_cert_to_config(domain: str) -> None: """Add a domain to the certificate config. Args: domain: Domain name to add """ from .db import db_add_cert, sync_certs_json db_add_cert(domain) sync_certs_json() def remove_cert_from_config(domain: str) -> None: """Remove a domain from the certificate config. Args: domain: Domain name to remove """ from .db import db_remove_cert, sync_certs_json db_remove_cert(domain) sync_certs_json() # Domain map helper functions (used by domains.py) def add_domain_to_map(domain: str, backend: str, is_wildcard: bool = False, shares_with: Optional[str] = None) -> None: """Add a domain to SQLite and sync map files. Args: domain: Domain name (e.g., "example.com"). backend: Backend pool name (e.g., "pool_5"). is_wildcard: Whether this is a wildcard entry. shares_with: Parent domain if sharing a pool. """ from .db import db_add_domain, sync_map_files db_add_domain(domain, backend, is_wildcard, shares_with) sync_map_files() def remove_domain_from_map(domain: str) -> None: """Remove a domain (exact + wildcard) from SQLite and sync map files. Args: domain: Base domain name (without leading dot). """ from .db import db_remove_domain, sync_map_files db_remove_domain(domain) sync_map_files() def find_available_pool() -> Optional[str]: """Find the first available pool not assigned to any domain. Uses SQLite query for O(1) lookup vs previous O(n) list scan. Returns: Pool name (e.g., "pool_5") or None if all pools are in use. """ from .db import db_find_available_pool return db_find_available_pool()