Instead of syncing JSON files back, the SQLite DB itself is now the persistent store on the remote HAProxy host: - Startup: download the remote DB via SCP (skip migration if it already exists) - After writes: upload the local DB via SCP (after a WAL checkpoint) - JSON sync removed (sync_servers_json and sync_certs_json deleted) New functions: - ssh_ops: remote_download_file(), remote_upload_file() via SCP - db: sync_db_to_remote(), _try_download_remote_db() Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
419 lines
11 KiB
Python
419 lines
11 KiB
Python
"""File I/O operations for HAProxy MCP Server.
|
|
|
|
Most data access is now delegated to db.py (SQLite).
|
|
This module retains atomic file writes, map file I/O for HAProxy,
|
|
and provides backward-compatible function signatures.
|
|
"""
|
|
|
|
import fcntl
|
|
import os
|
|
import tempfile
|
|
from contextlib import contextmanager
|
|
from typing import Any, Generator, Optional
|
|
|
|
from .config import (
|
|
MAP_FILE,
|
|
WILDCARDS_MAP_FILE,
|
|
REMOTE_MODE,
|
|
logger,
|
|
)
|
|
from .validation import domain_to_backend
|
|
|
|
|
|
@contextmanager
def file_lock(lock_path: str) -> Generator[None, None, None]:
    """Hold an exclusive flock on ``lock_path`` for the body of a ``with``.

    Locally, the lock file is opened in write mode (created if missing) and
    ``fcntl.LOCK_EX`` is held until the context exits; it is released in a
    ``finally`` so an exception in the body still unlocks. In REMOTE_MODE no
    lock is taken at all: the design relies on a single writer plus atomic
    writes on the remote host.

    Args:
        lock_path: Path of the lock file (typically ``<config_file>.lock``).

    Yields:
        None; the lock is held while the caller's block runs.
    """
    if not REMOTE_MODE:
        with open(lock_path, 'w') as handle:
            descriptor = handle.fileno()
            fcntl.flock(descriptor, fcntl.LOCK_EX)
            try:
                yield
            finally:
                fcntl.flock(descriptor, fcntl.LOCK_UN)
    else:
        # Remote mode: nothing to lock locally.
        yield
|
|
|
|
|
|
def atomic_write_file(file_path: str, content: str) -> None:
    """Write content to a file atomically using a temp file + rename.

    In REMOTE_MODE the write is delegated to ``ssh_ops.remote_write_file``.

    Locally, the steps are:

    1. Write the content (UTF-8) to a temporary file in the target's directory.
    2. Flush and fsync the temporary file.
    3. Give it the target's existing permission bits, or 0o644 for a new file.
    4. Rename it over the target.

    Readers therefore see either the old or the new file, never a partial one.

    Args:
        file_path: Target file path
        content: Content to write

    Raises:
        IOError: If write fails
    """
    if REMOTE_MODE:
        from .ssh_ops import remote_write_file
        remote_write_file(file_path, content)
        return

    # dirname() is "" for a bare filename; make the "current directory"
    # meaning explicit so the temp file lands next to the target.
    dir_path = os.path.dirname(file_path) or "."
    fd = None
    temp_path = None
    try:
        fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.tmp.')
        # mkstemp creates the file with mode 0o600. Renaming that over an
        # existing file would silently drop its permissions (e.g. making a
        # map file unreadable to a different process user), so carry the
        # target's mode over, defaulting to 0o644 for a new file.
        try:
            mode = os.stat(file_path).st_mode & 0o777
        except FileNotFoundError:
            mode = 0o644
        os.fchmod(fd, mode)
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            fd = None  # fd is now owned by the file object
            f.write(content)
            # Ensure the data is on disk before the rename publishes it;
            # otherwise a crash can leave an empty/truncated target.
            f.flush()
            os.fsync(f.fileno())
        # os.replace overwrites the destination atomically on POSIX.
        os.replace(temp_path, file_path)
        temp_path = None  # Rename succeeded
    except OSError as e:
        raise IOError(f"Failed to write {file_path}: {e}") from e
    finally:
        if fd is not None:
            try:
                os.close(fd)
            except OSError:
                pass
        if temp_path is not None:
            # Best-effort cleanup of the orphaned temp file on failure.
            try:
                os.unlink(temp_path)
            except OSError:
                pass
|
|
|
|
|
|
def _read_map_file(file_path: str) -> list[tuple[str, str]]:
    """Parse a map file into ``(domain, backend)`` pairs.

    Parsing rules:

    - Each line is stripped.
    - Blank lines and lines starting with ``#`` are skipped.
    - A line with fewer than two whitespace-separated fields is ignored.
    - Fields beyond the second are dropped.

    A missing file is logged at debug level and yields an empty list.

    Args:
        file_path: Path to the map file

    Returns:
        List of (domain, backend) tuples, in file order
    """
    pairs: list[tuple[str, str]] = []
    try:
        raw_text = _read_file(file_path)
    except FileNotFoundError:
        logger.debug("Map file not found: %s", file_path)
        return pairs

    for raw_line in raw_text.splitlines():
        stripped = raw_line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        fields = stripped.split()
        if len(fields) < 2:
            continue
        pairs.append((fields[0], fields[1]))
    return pairs
|
|
|
|
|
|
def _read_file(file_path: str) -> str:
    """Return the text of ``file_path``, read locally or over SSH.

    In REMOTE_MODE the read is delegated to ``ssh_ops.remote_read_file``.
    Locally the file is opened as UTF-8 and a shared ``flock`` is attempted.
    If locking fails with OSError (e.g. unsupported by the filesystem), the
    read still proceeds unlocked and the failure is logged at debug level.

    Args:
        file_path: Path to the file

    Returns:
        File contents as string

    Raises:
        FileNotFoundError: If file doesn't exist (raised by ``open`` locally)
    """
    if not REMOTE_MODE:
        with open(file_path, "r", encoding="utf-8") as handle:
            descriptor = handle.fileno()
            try:
                fcntl.flock(descriptor, fcntl.LOCK_SH)
            except OSError as e:
                logger.debug("File locking not supported for %s: %s", file_path, e)
            try:
                return handle.read()
            finally:
                # Unlock is best-effort; it may fail if the lock never took.
                try:
                    fcntl.flock(descriptor, fcntl.LOCK_UN)
                except OSError:
                    pass

    from .ssh_ops import remote_read_file
    return remote_read_file(file_path)
|
|
|
|
|
|
def get_map_contents() -> list[tuple[str, str]]:
    """Return every domain-to-backend mapping stored in SQLite.

    Thin wrapper over ``db.db_get_map_contents``.

    Returns:
        List of (domain, backend) tuples, wildcard entries included.
    """
    from .db import db_get_map_contents as _fetch_mappings

    mappings = _fetch_mappings()
    return mappings
|
|
|
|
|
|
def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
    """Partition map entries into exact-domain and wildcard groups.

    An entry is a wildcard when its domain starts with ``"."``. Relative
    order within each group matches the input order.

    Args:
        entries: List of (domain, backend) tuples

    Returns:
        Tuple of (exact_entries, wildcard_entries)
    """
    # Key: "is this a wildcard entry?"
    buckets: dict[bool, list[tuple[str, str]]] = {False: [], True: []}
    for domain, backend in entries:
        buckets[domain.startswith(".")].append((domain, backend))
    return buckets[False], buckets[True]
|
|
|
|
|
|
def save_map_file(entries: list[tuple[str, str]]) -> None:
    """Regenerate domains.map and wildcards.map from the database.

    The database is the source of truth. ``entries`` is accepted only so
    that older callers keep working, and it is ignored.

    Raises:
        IOError: If map files cannot be written.
    """
    del entries  # Unused; retained in the signature for backward compatibility.
    from .db import sync_map_files as _regenerate_maps

    _regenerate_maps()
|
|
|
|
|
|
def get_domain_backend(domain: str) -> Optional[str]:
    """Return the backend mapped to ``domain`` in SQLite.

    Delegates to ``db.db_get_domain_backend``.

    Args:
        domain: The domain to look up

    Returns:
        Backend name if found, None otherwise
    """
    from .db import db_get_domain_backend as _lookup_backend

    found = _lookup_backend(domain)
    return found
|
|
|
|
|
|
def is_legacy_backend(backend: str) -> bool:
    """Tell whether ``backend`` is a legacy static backend rather than a pool.

    Pool backends are named ``pool_<n>``; anything else, such as
    ``{domain}_backend``, is treated as legacy. Per the module's design,
    pools are dynamic (zero-reload) while legacy backends need a reload.

    Args:
        backend: Backend name to check.

    Returns:
        True if legacy backend, False if pool backend.
    """
    pool_prefix = "pool_"
    return backend[:len(pool_prefix)] != pool_prefix
|
|
|
|
|
|
def get_legacy_backend_name(domain: str) -> str:
    """Build the legacy static backend name for ``domain``.

    The name is ``domain_to_backend(domain)`` with a ``_backend`` suffix
    appended.

    Args:
        domain: Domain name

    Returns:
        Legacy backend name (e.g., 'api_example_com_backend')
    """
    base_name = domain_to_backend(domain)
    return "{}_backend".format(base_name)
|
|
|
|
|
|
def get_backend_and_prefix(domain: str) -> tuple[str, str]:
    """Resolve a domain's backend and the prefix used for its server names.

    The backend comes from the database. When there is no mapping, it falls
    back to the legacy ``{domain}_backend`` name. Pool backends use their own
    name as the server prefix; other backends use ``domain_to_backend(domain)``.

    Args:
        domain: The domain name to look up

    Returns:
        Tuple of (backend_name, server_prefix)

    Raises:
        ValueError: If domain cannot be mapped to a valid backend
            (presumably raised by ``domain_to_backend``).
    """
    resolved = get_domain_backend(domain) or get_legacy_backend_name(domain)
    prefix = resolved if resolved.startswith("pool_") else domain_to_backend(domain)
    return resolved, prefix
|
|
|
|
|
|
def load_servers_config() -> dict[str, Any]:
    """Return the server configuration held in SQLite.

    Thin wrapper over ``db.db_load_servers_config``.

    Returns:
        Dictionary with server configurations (legacy format compatible).
    """
    from .db import db_load_servers_config as _load_servers

    config = _load_servers()
    return config
|
|
|
|
|
|
def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> None:
    """Persist a server entry for ``domain``.

    The entry is written to SQLite, and then the database is synced via
    ``db.sync_db_to_remote``.

    Args:
        domain: Domain name
        slot: Server slot (1 to MAX_SLOTS)
        ip: Server IP address
        http_port: HTTP port
    """
    from .db import db_add_server as _insert_server
    from .db import sync_db_to_remote as _push_db

    _insert_server(domain, slot, ip, http_port)
    _push_db()
|
|
|
|
|
|
def remove_server_from_config(domain: str, slot: int) -> None:
    """Delete one server slot of ``domain`` from persistent storage.

    The row is removed in SQLite, and then the database is synced via
    ``db.sync_db_to_remote``.

    Args:
        domain: Domain name
        slot: Server slot to remove
    """
    from .db import db_remove_server as _delete_server
    from .db import sync_db_to_remote as _push_db

    _delete_server(domain, slot)
    _push_db()
|
|
|
|
|
|
def remove_domain_from_config(domain: str) -> None:
    """Drop a domain's persisted configuration (its servers and domain entry).

    Delegates to ``db.db_remove_domain_servers``, and then syncs the database
    via ``db.sync_db_to_remote``.

    Args:
        domain: Domain name to remove
    """
    from .db import db_remove_domain_servers as _purge_domain
    from .db import sync_db_to_remote as _push_db

    _purge_domain(domain)
    _push_db()
|
|
|
|
|
|
def get_shared_domain(domain: str) -> Optional[str]:
    """Return the parent domain whose pool ``domain`` shares.

    Thin wrapper over ``db.db_get_shared_domain``.

    Args:
        domain: Domain name to check

    Returns:
        Parent domain name if sharing, None otherwise
    """
    from .db import db_get_shared_domain as _lookup_parent

    parent = _lookup_parent(domain)
    return parent
|
|
|
|
|
|
def add_shared_domain_to_config(domain: str, shares_with: str) -> None:
    """Record that ``domain`` shares the pool of ``shares_with``.

    The relation is written to SQLite, and then the database is synced via
    ``db.sync_db_to_remote``.

    Args:
        domain: New domain name
        shares_with: Existing domain to share pool with
    """
    from .db import db_add_shared_domain as _link_domains
    from .db import sync_db_to_remote as _push_db

    _link_domains(domain, shares_with)
    _push_db()
|
|
|
|
|
|
def get_domains_sharing_pool(pool: str) -> list[str]:
    """List every domain that is assigned to ``pool``.

    Thin wrapper over ``db.db_get_domains_sharing_pool``.

    Args:
        pool: Pool name (e.g., 'pool_5')

    Returns:
        List of domain names using this pool
    """
    from .db import db_get_domains_sharing_pool as _domains_for_pool

    members = _domains_for_pool(pool)
    return members
|
|
|
|
|
|
def is_shared_domain(domain: str) -> bool:
    """Tell whether ``domain`` is recorded as sharing another domain's pool.

    Thin wrapper over ``db.db_is_shared_domain``.

    Args:
        domain: Domain name to check

    Returns:
        True if the domain shares another domain's pool, False otherwise
    """
    from .db import db_is_shared_domain as _check_shared

    shared = _check_shared(domain)
    return shared
|
|
|
|
|
|
# Certificate configuration functions
|
|
|
|
def load_certs_config() -> list[str]:
    """Return the certificate domain list stored in SQLite.

    Thin wrapper over ``db.db_load_certs``.

    Returns:
        Sorted list of domain names.
    """
    from .db import db_load_certs as _load_cert_domains

    cert_domains = _load_cert_domains()
    return cert_domains
|
|
|
|
|
|
def add_cert_to_config(domain: str) -> None:
    """Register ``domain`` in the certificate config.

    The domain is written to SQLite, and then the database is synced via
    ``db.sync_db_to_remote``.

    Args:
        domain: Domain name to add
    """
    from .db import db_add_cert as _insert_cert
    from .db import sync_db_to_remote as _push_db

    _insert_cert(domain)
    _push_db()
|
|
|
|
|
|
def remove_cert_from_config(domain: str) -> None:
    """Unregister ``domain`` from the certificate config.

    The domain is removed from SQLite, and then the database is synced via
    ``db.sync_db_to_remote``.

    Args:
        domain: Domain name to remove
    """
    from .db import db_remove_cert as _delete_cert
    from .db import sync_db_to_remote as _push_db

    _delete_cert(domain)
    _push_db()
|
|
|
|
|
|
# Domain map helper functions (used by domains.py)
|
|
|
|
def add_domain_to_map(domain: str, backend: str, is_wildcard: bool = False,
                      shares_with: Optional[str] = None) -> None:
    """Add a domain to SQLite, regenerate the map files, and persist the DB.

    Like the other write helpers in this module, this calls
    ``sync_db_to_remote()`` after modifying the database. Without that call
    the new mapping would exist only in the local copy and could be lost when
    the remote DB (the persistent store) is downloaded again on startup.

    Args:
        domain: Domain name (e.g., "example.com").
        backend: Backend pool name (e.g., "pool_5").
        is_wildcard: Whether this is a wildcard entry.
        shares_with: Parent domain if sharing a pool.
    """
    from .db import db_add_domain, sync_db_to_remote, sync_map_files

    db_add_domain(domain, backend, is_wildcard, shares_with)
    # Refresh the HAProxy map files first, then persist the database.
    sync_map_files()
    sync_db_to_remote()
|
|
|
|
|
|
def remove_domain_from_map(domain: str) -> None:
    """Remove a domain (exact + wildcard) from SQLite, sync maps, persist the DB.

    Like the other write helpers in this module, this calls
    ``sync_db_to_remote()`` after modifying the database. Otherwise a removed
    mapping could reappear once the remote DB (the persistent store) is
    downloaded again on startup.

    Args:
        domain: Base domain name (without leading dot).
    """
    from .db import db_remove_domain, sync_db_to_remote, sync_map_files

    db_remove_domain(domain)
    # Refresh the HAProxy map files first, then persist the database.
    sync_map_files()
    sync_db_to_remote()
|
|
|
|
|
|
def find_available_pool() -> Optional[str]:
    """Return the first pool not assigned to any domain.

    The search is delegated to a SQLite query in
    ``db.db_find_available_pool``.

    Returns:
        Pool name (e.g., "pool_5") or None if all pools are in use.
    """
    from .db import db_find_available_pool as _first_free_pool

    candidate = _first_free_pool()
    return candidate
|