Files
haproxy-mcp/haproxy_mcp/file_ops.py
kappa e40d69a1b1 feat: Add SSH remote execution for HAProxy on remote host
MCP server can now manage HAProxy running on a remote host via SSH.
When SSH_HOST env var is set, all file I/O and subprocess commands
(podman, acme.sh, openssl) are routed through SSH instead of local exec.

- Add ssh_ops.py module with remote_exec, run_command, file I/O helpers
- Modify file_ops.py to support remote reads/writes via SSH
- Update all tools (domains, certificates, health, configuration) for SSH
- Fix domains.py: replace direct fcntl usage with file_lock context manager
- Add openssh-client to Docker image for SSH connectivity
- Update k8s deployment with SSH env vars and SSH key secret mount

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 22:56:54 +09:00

457 lines
12 KiB
Python

"""File I/O operations for HAProxy MCP Server."""
import fcntl
import json
import os
import tempfile
from contextlib import contextmanager
from typing import Any, Generator, Optional
from .config import (
MAP_FILE,
WILDCARDS_MAP_FILE,
SERVERS_FILE,
CERTS_FILE,
REMOTE_MODE,
logger,
)
from .validation import domain_to_backend
@contextmanager
def file_lock(lock_path: str) -> Generator[None, None, None]:
    """Hold an exclusive flock(2) lock for the duration of the context.

    When REMOTE_MODE is enabled this is a no-op: the remote host is
    assumed to have a single writer and writes there are atomic.

    Args:
        lock_path: Path to the lock file (typically config_file.lock)

    Yields:
        None - the lock is held until the context exits
    """
    if REMOTE_MODE:
        yield
        return
    with open(lock_path, 'w') as handle:
        lock_fd = handle.fileno()
        fcntl.flock(lock_fd, fcntl.LOCK_EX)
        try:
            yield
        finally:
            # Always release, even if the protected section raised.
            fcntl.flock(lock_fd, fcntl.LOCK_UN)
def atomic_write_file(file_path: str, content: str) -> None:
    """Write content to file atomically using temp file + rename.

    In REMOTE_MODE the write is delegated to the SSH helper, which
    handles atomicity on the remote host.

    Args:
        file_path: Target file path
        content: Content to write

    Raises:
        IOError: If write fails
    """
    if REMOTE_MODE:
        from .ssh_ops import remote_write_file
        remote_write_file(file_path, content)
        return
    dir_path = os.path.dirname(file_path)
    fd = None
    temp_path = None
    try:
        # Temp file must live in the same directory so the rename stays
        # on one filesystem (cross-device rename is not atomic).
        fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.tmp.')
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            fd = None  # fd is now owned by the file object
            f.write(content)
        # os.replace, unlike os.rename, atomically overwrites an existing
        # target on every platform (os.rename raises on Windows when
        # file_path already exists).
        os.replace(temp_path, file_path)
        temp_path = None  # Rename succeeded
    except OSError as e:
        raise IOError(f"Failed to write {file_path}: {e}") from e
    finally:
        # Best-effort cleanup of the raw fd and the orphaned temp file
        # if anything failed before the rename.
        if fd is not None:
            try:
                os.close(fd)
            except OSError:
                pass
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                pass
def _read_map_file(file_path: str) -> list[tuple[str, str]]:
    """Parse one HAProxy map file into (domain, backend) tuples.

    Blank lines and '#' comment lines are skipped; a missing file is
    treated as empty.

    Args:
        file_path: Path to the map file

    Returns:
        List of (domain, backend) tuples from the map file
    """
    try:
        raw = _read_file(file_path)
    except FileNotFoundError:
        logger.debug("Map file not found: %s", file_path)
        return []
    result: list[tuple[str, str]] = []
    for raw_line in raw.splitlines():
        stripped = raw_line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        fields = stripped.split()
        if len(fields) >= 2:
            result.append((fields[0], fields[1]))
    return result
def _read_file(file_path: str) -> str:
    """Read a file locally or remotely based on REMOTE_MODE.

    Locally a shared flock is attempted for the duration of the read;
    lock failures are logged and ignored (e.g. unsupported filesystems).

    Args:
        file_path: Path to the file

    Returns:
        File contents as string

    Raises:
        FileNotFoundError: If file doesn't exist
    """
    if REMOTE_MODE:
        from .ssh_ops import remote_read_file
        return remote_read_file(file_path)
    with open(file_path, "r", encoding="utf-8") as handle:
        fd = handle.fileno()
        try:
            fcntl.flock(fd, fcntl.LOCK_SH)
        except OSError as exc:
            logger.debug("File locking not supported for %s: %s", file_path, exc)
        try:
            return handle.read()
        finally:
            # Unlock is best-effort; it may fail where locking did.
            try:
                fcntl.flock(fd, fcntl.LOCK_UN)
            except OSError:
                pass
def get_map_contents() -> list[tuple[str, str]]:
    """Read both domains.map and wildcards.map and return combined entries.

    Returns:
        List of (domain, backend) tuples from both map files,
        exact domains first, then wildcards.
    """
    return _read_map_file(MAP_FILE) + _read_map_file(WILDCARDS_MAP_FILE)
def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
    """Split entries into exact domains and wildcards.

    A wildcard entry is one whose domain begins with "." (HAProxy
    suffix-match convention); everything else is an exact domain.

    Args:
        entries: List of (domain, backend) tuples

    Returns:
        Tuple of (exact_entries, wildcard_entries)
    """
    exact_entries = [pair for pair in entries if not pair[0].startswith(".")]
    wildcard_entries = [pair for pair in entries if pair[0].startswith(".")]
    return exact_entries, wildcard_entries
def save_map_file(entries: list[tuple[str, str]]) -> None:
    """Save domain-to-backend entries to map files.

    Splits entries into two files for 2-stage routing:
    - domains.map: Exact matches (map_str, O(log n))
    - wildcards.map: Wildcard entries starting with "." (map_dom, O(n))

    Args:
        entries: List of (domain, backend) tuples.

    Raises:
        IOError: If map files cannot be written.
    """
    exact_entries, wildcard_entries = split_domain_entries(entries)

    # Exact domains (map_str - fast O(log n) lookup), sorted for stable diffs.
    exact_header = (
        "# Exact Domain to Backend mapping (for map_str)\n"
        "# Format: domain backend_name\n"
        "# Uses ebtree for O(log n) lookup performance\n\n"
    )
    exact_body = "".join(f"{d} {b}\n" for d, b in sorted(exact_entries))
    atomic_write_file(MAP_FILE, exact_header + exact_body)

    # Wildcards (map_dom - O(n) but small set).
    wildcard_header = (
        "# Wildcard Domain to Backend mapping (for map_dom)\n"
        "# Format: .domain.com backend_name (matches *.domain.com)\n"
        "# Uses map_dom for suffix matching\n\n"
    )
    wildcard_body = "".join(f"{d} {b}\n" for d, b in sorted(wildcard_entries))
    atomic_write_file(WILDCARDS_MAP_FILE, wildcard_header + wildcard_body)
def get_domain_backend(domain: str) -> Optional[str]:
    """Look up the backend for a domain from the map files.

    Args:
        domain: The domain to look up

    Returns:
        Backend name of the first matching entry, None if not found
    """
    matches = (b for d, b in get_map_contents() if d == domain)
    return next(matches, None)
def is_legacy_backend(backend: str) -> bool:
    """Check if backend is a legacy static backend (not a dynamic pool).

    Pool backends: pool_1, pool_2, ..., pool_100 (dynamic, zero-reload)
    Legacy backends: {domain}_backend (static, requires reload)

    Args:
        backend: Backend name to check.

    Returns:
        True if legacy backend, False if pool backend.
    """
    is_pool = backend.startswith("pool_")
    return not is_pool
def get_legacy_backend_name(domain: str) -> str:
    """Convert domain to legacy backend name format.

    Args:
        domain: Domain name

    Returns:
        Legacy backend name (e.g., 'api_example_com_backend')
    """
    sanitized = domain_to_backend(domain)
    return sanitized + "_backend"
def get_backend_and_prefix(domain: str) -> tuple[str, str]:
    """Look up backend and determine server name prefix for a domain.

    Falls back to the legacy backend name when the domain is not mapped.

    Args:
        domain: The domain name to look up

    Returns:
        Tuple of (backend_name, server_prefix)

    Raises:
        ValueError: If domain cannot be mapped to a valid backend
    """
    backend = get_domain_backend(domain) or get_legacy_backend_name(domain)
    if backend.startswith("pool_"):
        # Pool backends use the pool name itself as the server prefix.
        return backend, backend
    return backend, domain_to_backend(domain)
def load_servers_config() -> dict[str, Any]:
    """Load servers configuration from JSON file.

    Missing or corrupt files yield an empty configuration; corruption
    is logged as a warning.

    Returns:
        Dictionary with server configurations
    """
    try:
        return json.loads(_read_file(SERVERS_FILE))
    except FileNotFoundError:
        return {}
    except json.JSONDecodeError as e:
        logger.warning("Corrupt config file %s: %s", SERVERS_FILE, e)
        return {}
def save_servers_config(config: dict[str, Any]) -> None:
    """Save servers configuration to JSON file atomically.

    Uses temp file + rename for atomic write to prevent race conditions.

    Args:
        config: Dictionary with server configurations
    """
    serialized = json.dumps(config, indent=2)
    atomic_write_file(SERVERS_FILE, serialized)
def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> None:
    """Add server configuration to persistent storage with file locking.

    Args:
        domain: Domain name
        slot: Server slot (1 to MAX_SLOTS)
        ip: Server IP address
        http_port: HTTP port
    """
    with file_lock(f"{SERVERS_FILE}.lock"):
        config = load_servers_config()
        # Slots are keyed by string because the store is JSON.
        config.setdefault(domain, {})[str(slot)] = {"ip": ip, "http_port": http_port}
        save_servers_config(config)
def remove_server_from_config(domain: str, slot: int) -> None:
    """Remove server configuration from persistent storage with file locking.

    The domain entry itself is dropped once its last slot is removed.

    Args:
        domain: Domain name
        slot: Server slot to remove
    """
    with file_lock(f"{SERVERS_FILE}.lock"):
        config = load_servers_config()
        servers = config.get(domain, {})
        if str(slot) in servers:
            del servers[str(slot)]
            if not servers:
                del config[domain]
        save_servers_config(config)
def remove_domain_from_config(domain: str) -> None:
    """Remove domain from persistent config with file locking.

    Args:
        domain: Domain name to remove
    """
    with file_lock(f"{SERVERS_FILE}.lock"):
        config = load_servers_config()
        # pop with default makes removal a no-op for unknown domains.
        config.pop(domain, None)
        save_servers_config(config)
def get_shared_domain(domain: str) -> Optional[str]:
    """Get the parent domain that this domain shares a pool with.

    Args:
        domain: Domain name to check

    Returns:
        Parent domain name if sharing, None otherwise
    """
    return load_servers_config().get(domain, {}).get("_shares")
def add_shared_domain_to_config(domain: str, shares_with: str) -> None:
    """Add a domain that shares a pool with another domain.

    Any existing entry for the domain is replaced by the share marker.

    Args:
        domain: New domain name
        shares_with: Existing domain to share pool with
    """
    with file_lock(f"{SERVERS_FILE}.lock"):
        current = load_servers_config()
        current[domain] = {"_shares": shares_with}
        save_servers_config(current)
def get_domains_sharing_pool(pool: str) -> list[str]:
    """Get all domains that use a specific pool.

    Wildcard entries (leading ".") are excluded.

    Args:
        pool: Pool name (e.g., 'pool_5')

    Returns:
        List of domain names using this pool
    """
    return [
        entry_domain
        for entry_domain, entry_backend in get_map_contents()
        if entry_backend == pool and not entry_domain.startswith(".")
    ]
def is_shared_domain(domain: str) -> bool:
    """Check if a domain is sharing another domain's pool.

    Args:
        domain: Domain name to check

    Returns:
        True if domain has _shares reference, False otherwise
    """
    return "_shares" in load_servers_config().get(domain, {})
# Certificate configuration functions
def load_certs_config() -> list[str]:
    """Load certificate domain list from JSON file.

    Missing or corrupt files yield an empty list; corruption is logged.

    Returns:
        List of domain names
    """
    try:
        parsed = json.loads(_read_file(CERTS_FILE))
    except FileNotFoundError:
        return []
    except json.JSONDecodeError as e:
        logger.warning("Corrupt certificates config %s: %s", CERTS_FILE, e)
        return []
    return parsed.get("domains", [])
def save_certs_config(domains: list[str]) -> None:
    """Save certificate domain list to JSON file atomically.

    Domains are stored sorted for deterministic file contents.

    Args:
        domains: List of domain names
    """
    payload = json.dumps({"domains": sorted(domains)}, indent=2)
    atomic_write_file(CERTS_FILE, payload)
def add_cert_to_config(domain: str) -> None:
    """Add a domain to the certificate config.

    Adding an already-present domain is a no-op for the stored list.

    Args:
        domain: Domain name to add
    """
    with file_lock(f"{CERTS_FILE}.lock"):
        current = load_certs_config()
        if domain not in current:
            current.append(domain)
        save_certs_config(current)
def remove_cert_from_config(domain: str) -> None:
    """Remove a domain from the certificate config.

    Removing an absent domain is a no-op for the stored list.

    Args:
        domain: Domain name to remove
    """
    with file_lock(f"{CERTS_FILE}.lock"):
        current = load_certs_config()
        if domain in current:
            current.remove(domain)
        save_certs_config(current)