Files
haproxy-mcp/haproxy_mcp/file_ops.py
kappa 12fd3b5e8f Store SQLite DB on remote host via SCP for persistence
Instead of syncing JSON files back, the SQLite DB itself is now
the persistent store on the remote HAProxy host:
- Startup: download remote DB via SCP (skip migration if exists)
- After writes: upload local DB via SCP (WAL checkpoint first)
- JSON sync removed (sync_servers_json, sync_certs_json deleted)

New functions:
- ssh_ops: remote_download_file(), remote_upload_file() via SCP
- db: sync_db_to_remote(), _try_download_remote_db()

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 11:46:36 +09:00

419 lines
11 KiB
Python

"""File I/O operations for HAProxy MCP Server.
Most data access is now delegated to db.py (SQLite).
This module retains atomic file writes, map file I/O for HAProxy,
and provides backward-compatible function signatures.
"""
import fcntl
import os
import tempfile
from contextlib import contextmanager
from typing import Any, Generator, Optional
from .config import (
MAP_FILE,
WILDCARDS_MAP_FILE,
REMOTE_MODE,
logger,
)
from .validation import domain_to_backend
@contextmanager
def file_lock(lock_path: str) -> Generator[None, None, None]:
    """Hold an exclusive advisory lock while the context is active.

    When REMOTE_MODE is enabled no lock is taken: the remote host is
    assumed to have a single writer, and writes there are atomic.

    Args:
        lock_path: Path of the lock file (typically config_file.lock)

    Yields:
        None - the exclusive lock is released when the context exits
    """
    if not REMOTE_MODE:
        with open(lock_path, 'w') as handle:
            fcntl.flock(handle.fileno(), fcntl.LOCK_EX)
            try:
                yield
            finally:
                fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
    else:
        yield
def atomic_write_file(file_path: str, content: str) -> None:
    """Write content to file atomically using temp file + rename.

    The content is written to a temporary file in the target directory,
    flushed and fsync'd to disk, then renamed over the target. The fsync
    before rename closes the window where a crash could leave a
    zero-length or truncated file visible at the target path.

    Args:
        file_path: Target file path
        content: Content to write

    Raises:
        IOError: If write fails
    """
    if REMOTE_MODE:
        from .ssh_ops import remote_write_file
        remote_write_file(file_path, content)
        return
    dir_path = os.path.dirname(file_path)
    fd = None
    temp_path = None
    try:
        # Temp file in the same directory so os.rename stays on one
        # filesystem (cross-device rename would not be atomic).
        fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.tmp.')
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            fd = None  # fd is now owned by the file object
            f.write(content)
            f.flush()
            # Force data to disk before the rename makes it visible;
            # otherwise a crash can expose an empty file at file_path.
            os.fsync(f.fileno())
        os.rename(temp_path, file_path)
        temp_path = None  # Rename succeeded
    except OSError as e:
        raise IOError(f"Failed to write {file_path}: {e}") from e
    finally:
        if fd is not None:
            try:
                os.close(fd)
            except OSError:
                pass
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                pass
def _read_map_file(file_path: str) -> list[tuple[str, str]]:
    """Parse one HAProxy map file into (domain, backend) pairs.

    Blank lines, comment lines starting with '#', and lines with fewer
    than two whitespace-separated fields are skipped.

    Args:
        file_path: Path to the map file

    Returns:
        List of (domain, backend) tuples; empty if the file is missing.
    """
    try:
        raw = _read_file(file_path)
    except FileNotFoundError:
        logger.debug("Map file not found: %s", file_path)
        return []
    pairs: list[tuple[str, str]] = []
    for raw_line in raw.splitlines():
        stripped = raw_line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        fields = stripped.split()
        if len(fields) >= 2:
            pairs.append((fields[0], fields[1]))
    return pairs
def _read_file(file_path: str) -> str:
    """Return the contents of a file, locally or over SSH.

    In REMOTE_MODE the read is delegated to ssh_ops. For local reads a
    shared flock is taken for the duration of the read when the platform
    supports it; lock failures are logged at debug level and ignored.

    Args:
        file_path: Path to the file

    Returns:
        File contents as string

    Raises:
        FileNotFoundError: If file doesn't exist
    """
    if REMOTE_MODE:
        from .ssh_ops import remote_read_file
        return remote_read_file(file_path)
    with open(file_path, "r", encoding="utf-8") as handle:
        fileno = handle.fileno()
        try:
            fcntl.flock(fileno, fcntl.LOCK_SH)
        except OSError as exc:
            logger.debug("File locking not supported for %s: %s", file_path, exc)
        try:
            return handle.read()
        finally:
            try:
                fcntl.flock(fileno, fcntl.LOCK_UN)
            except OSError:
                pass
def get_map_contents() -> list[tuple[str, str]]:
    """Return every domain-to-backend mapping stored in SQLite.

    Wildcard entries (leading-dot domains) are included.

    Returns:
        List of (domain, backend) tuples.
    """
    from .db import db_get_map_contents as _db_get_map_contents
    return _db_get_map_contents()
def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
    """Partition mappings into exact-domain and wildcard groups.

    A wildcard entry is one whose domain begins with a dot. Relative
    order within each group is preserved.

    Args:
        entries: List of (domain, backend) tuples

    Returns:
        Tuple of (exact_entries, wildcard_entries)
    """
    exact = [(dom, be) for dom, be in entries if not dom.startswith(".")]
    wildcards = [(dom, be) for dom, be in entries if dom.startswith(".")]
    return exact, wildcards
def save_map_file(entries: list[tuple[str, str]]) -> None:
    """Regenerate HAProxy map files from the database.

    domains.map and wildcards.map are both rewritten from the current
    SQLite state. The entries argument is accepted only for backward
    compatibility during the transition and is not used.

    Raises:
        IOError: If map files cannot be written.
    """
    from .db import sync_map_files as _sync_map_files
    _sync_map_files()
def get_domain_backend(domain: str) -> Optional[str]:
    """Resolve a domain to its backend via SQLite (constant-time lookup).

    Args:
        domain: The domain to look up

    Returns:
        The backend name, or None when the domain is unknown.
    """
    from .db import db_get_domain_backend as _lookup
    return _lookup(domain)
def is_legacy_backend(backend: str) -> bool:
    """Tell whether a backend name refers to a legacy static backend.

    Dynamic pool backends are named pool_1 .. pool_100 and change with
    zero reloads; any other name (e.g. '{domain}_backend') is a legacy
    static backend that requires an HAProxy reload.

    Args:
        backend: Backend name to check.

    Returns:
        True if legacy backend, False if pool backend.
    """
    is_pool = backend.startswith("pool_")
    return not is_pool
def get_legacy_backend_name(domain: str) -> str:
    """Build the legacy static backend name for a domain.

    Args:
        domain: Domain name

    Returns:
        Backend name in the '{sanitized_domain}_backend' form
        (e.g. 'api_example_com_backend')
    """
    sanitized = domain_to_backend(domain)
    return f"{sanitized}_backend"
def get_backend_and_prefix(domain: str) -> tuple[str, str]:
    """Resolve a domain's backend and the server-name prefix to use.

    Pool backends reuse the pool name as the server prefix; legacy
    backends use the sanitized domain name instead.

    Args:
        domain: The domain name to look up

    Returns:
        Tuple of (backend_name, server_prefix)

    Raises:
        ValueError: If domain cannot be mapped to a valid backend
    """
    backend = get_domain_backend(domain) or get_legacy_backend_name(domain)
    if backend.startswith("pool_"):
        return backend, backend
    return backend, domain_to_backend(domain)
def load_servers_config() -> dict[str, Any]:
    """Fetch the full server configuration from SQLite.

    Returns:
        Server configuration dict in the legacy JSON-compatible shape.
    """
    from .db import db_load_servers_config as _load
    return _load()
def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> None:
    """Persist a server entry, then push the DB to the remote host.

    Args:
        domain: Domain name
        slot: Server slot (1 to MAX_SLOTS)
        ip: Server IP address
        http_port: HTTP port
    """
    from .db import db_add_server, sync_db_to_remote
    db_add_server(domain, slot, ip, http_port)
    # The remote copy of the SQLite DB is the persistent store; sync
    # after every write so it reflects the change.
    sync_db_to_remote()
def remove_server_from_config(domain: str, slot: int) -> None:
    """Delete a server entry, then push the DB to the remote host.

    Args:
        domain: Domain name
        slot: Server slot to remove
    """
    from .db import db_remove_server, sync_db_to_remote
    db_remove_server(domain, slot)
    # Keep the remote DB copy (the persistent store) up to date.
    sync_db_to_remote()
def remove_domain_from_config(domain: str) -> None:
    """Delete a domain's servers and domain entry, then sync the DB.

    Args:
        domain: Domain name to remove
    """
    from .db import db_remove_domain_servers, sync_db_to_remote
    db_remove_domain_servers(domain)
    # Keep the remote DB copy (the persistent store) up to date.
    sync_db_to_remote()
def get_shared_domain(domain: str) -> Optional[str]:
    """Return the parent domain whose pool this domain shares.

    Args:
        domain: Domain name to check

    Returns:
        Parent domain name if sharing, None otherwise
    """
    from .db import db_get_shared_domain as _get_shared
    return _get_shared(domain)
def add_shared_domain_to_config(domain: str, shares_with: str) -> None:
    """Record a pool-sharing relationship, then sync the DB.

    Args:
        domain: New domain name
        shares_with: Existing domain whose pool is shared
    """
    from .db import db_add_shared_domain, sync_db_to_remote
    db_add_shared_domain(domain, shares_with)
    # Keep the remote DB copy (the persistent store) up to date.
    sync_db_to_remote()
def get_domains_sharing_pool(pool: str) -> list[str]:
    """List every domain assigned to a given pool.

    Args:
        pool: Pool name (e.g., 'pool_5')

    Returns:
        List of domain names using this pool
    """
    from .db import db_get_domains_sharing_pool as _query
    return _query(pool)
def is_shared_domain(domain: str) -> bool:
    """Tell whether a domain piggybacks on another domain's pool.

    Args:
        domain: Domain name to check

    Returns:
        True if the domain has a _shares reference, False otherwise
    """
    from .db import db_is_shared_domain as _check
    return _check(domain)
# Certificate configuration functions
def load_certs_config() -> list[str]:
    """Fetch the certificate domain list from SQLite.

    Returns:
        Sorted list of domain names.
    """
    from .db import db_load_certs as _load_certs
    return _load_certs()
def add_cert_to_config(domain: str) -> None:
    """Register a domain for certificates, then sync the DB.

    Args:
        domain: Domain name to add
    """
    from .db import db_add_cert, sync_db_to_remote
    db_add_cert(domain)
    # Keep the remote DB copy (the persistent store) up to date.
    sync_db_to_remote()
def remove_cert_from_config(domain: str) -> None:
    """Deregister a certificate domain, then sync the DB.

    Args:
        domain: Domain name to remove
    """
    from .db import db_remove_cert, sync_db_to_remote
    db_remove_cert(domain)
    # Keep the remote DB copy (the persistent store) up to date.
    sync_db_to_remote()
# Domain map helper functions (used by domains.py)
def add_domain_to_map(domain: str, backend: str, is_wildcard: bool = False,
                      shares_with: Optional[str] = None) -> None:
    """Insert a domain into SQLite and regenerate the map files.

    Args:
        domain: Domain name (e.g., "example.com").
        backend: Backend pool name (e.g., "pool_5").
        is_wildcard: Whether this is a wildcard entry.
        shares_with: Parent domain if sharing a pool.
    """
    from .db import db_add_domain, sync_map_files
    db_add_domain(domain, backend, is_wildcard, shares_with)
    # HAProxy reads the map files, so they must track the DB state.
    sync_map_files()
def remove_domain_from_map(domain: str) -> None:
    """Delete a domain (exact and wildcard) and regenerate map files.

    Args:
        domain: Base domain name (without leading dot).
    """
    from .db import db_remove_domain, sync_map_files
    db_remove_domain(domain)
    # HAProxy reads the map files, so they must track the DB state.
    sync_map_files()
def find_available_pool() -> Optional[str]:
    """Pick the first pool not yet assigned to any domain.

    Delegates to a SQLite query (O(1)) instead of the previous O(n)
    list scan.

    Returns:
        Pool name (e.g., "pool_5") or None if all pools are in use.
    """
    from .db import db_find_available_pool as _find
    return _find()