diff --git a/Dockerfile b/Dockerfile index f624164..596699d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,6 +12,8 @@ COPY haproxy_mcp/ ./haproxy_mcp/ FROM python:3.11-slim +RUN apt-get update && apt-get install -y --no-install-recommends openssh-client && rm -rf /var/lib/apt/lists/* + WORKDIR /app COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages diff --git a/haproxy_mcp/config.py b/haproxy_mcp/config.py index ffa3900..db199e3 100644 --- a/haproxy_mcp/config.py +++ b/haproxy_mcp/config.py @@ -44,6 +44,13 @@ MAX_SLOTS: int = int(os.getenv("HAPROXY_MAX_SLOTS", "10")) # Container configuration HAPROXY_CONTAINER: str = os.getenv("HAPROXY_CONTAINER", "haproxy") +# SSH remote execution (when MCP runs on a different host from HAProxy) +SSH_HOST: str = os.getenv("SSH_HOST", "") # Empty = local mode +SSH_USER: str = os.getenv("SSH_USER", "root") +SSH_KEY: str = os.getenv("SSH_KEY", "") # Path to SSH private key +SSH_PORT: int = int(os.getenv("SSH_PORT", "22")) +REMOTE_MODE: bool = bool(SSH_HOST) + # Validation patterns - compiled once for performance DOMAIN_PATTERN = re.compile( r'^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?' diff --git a/haproxy_mcp/file_ops.py b/haproxy_mcp/file_ops.py index ba06697..847eed7 100644 --- a/haproxy_mcp/file_ops.py +++ b/haproxy_mcp/file_ops.py @@ -12,6 +12,7 @@ from .config import ( WILDCARDS_MAP_FILE, SERVERS_FILE, CERTS_FILE, + REMOTE_MODE, logger, ) from .validation import domain_to_backend @@ -21,22 +22,19 @@ from .validation import domain_to_backend def file_lock(lock_path: str) -> Generator[None, None, None]: """Acquire exclusive file lock for atomic operations. - This context manager provides a consistent locking mechanism for - read-modify-write operations on configuration files to prevent - race conditions during concurrent access. + In REMOTE_MODE, locking is skipped (single-writer assumption + with atomic writes on the remote host). 
Args: lock_path: Path to the lock file (typically config_file.lock) Yields: None - the lock is held for the duration of the context - - Example: - with file_lock("/path/to/config.json.lock"): - config = load_config() - config["key"] = "value" - save_config(config) """ + if REMOTE_MODE: + yield + return + with open(lock_path, 'w') as lock_file: fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) try: @@ -55,6 +53,11 @@ def atomic_write_file(file_path: str, content: str) -> None: Raises: IOError: If write fails """ + if REMOTE_MODE: + from .ssh_ops import remote_write_file + remote_write_file(file_path, content) + return + dir_path = os.path.dirname(file_path) fd = None temp_path = None @@ -91,29 +94,49 @@ def _read_map_file(file_path: str) -> list[tuple[str, str]]: """ entries = [] try: - with open(file_path, "r", encoding="utf-8") as f: - try: - fcntl.flock(f.fileno(), fcntl.LOCK_SH) - except OSError as e: - logger.debug("File locking not supported for %s: %s", file_path, e) - try: - for line in f: - line = line.strip() - if not line or line.startswith("#"): - continue - parts = line.split() - if len(parts) >= 2: - entries.append((parts[0], parts[1])) - finally: - try: - fcntl.flock(f.fileno(), fcntl.LOCK_UN) - except OSError as e: - logger.debug("File unlock failed for %s: %s", file_path, e) + content = _read_file(file_path) + for line in content.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + parts = line.split() + if len(parts) >= 2: + entries.append((parts[0], parts[1])) except FileNotFoundError: logger.debug("Map file not found: %s", file_path) return entries +def _read_file(file_path: str) -> str: + """Read a file locally or remotely based on REMOTE_MODE. 
+ + Args: + file_path: Path to the file + + Returns: + File contents as string + + Raises: + FileNotFoundError: If file doesn't exist + """ + if REMOTE_MODE: + from .ssh_ops import remote_read_file + return remote_read_file(file_path) + + with open(file_path, "r", encoding="utf-8") as f: + try: + fcntl.flock(f.fileno(), fcntl.LOCK_SH) + except OSError as e: + logger.debug("File locking not supported for %s: %s", file_path, e) + try: + return f.read() + finally: + try: + fcntl.flock(f.fileno(), fcntl.LOCK_UN) + except OSError: + pass + + def get_map_contents() -> list[tuple[str, str]]: """Read both domains.map and wildcards.map and return combined entries. @@ -250,24 +273,14 @@ def get_backend_and_prefix(domain: str) -> tuple[str, str]: def load_servers_config() -> dict[str, Any]: - """Load servers configuration from JSON file with file locking. + """Load servers configuration from JSON file. Returns: Dictionary with server configurations """ try: - with open(SERVERS_FILE, "r", encoding="utf-8") as f: - try: - fcntl.flock(f.fileno(), fcntl.LOCK_SH) - except OSError: - logger.debug("File locking not supported for %s", SERVERS_FILE) - try: - return json.load(f) - finally: - try: - fcntl.flock(f.fileno(), fcntl.LOCK_UN) - except OSError: - pass + content = _read_file(SERVERS_FILE) + return json.loads(content) except FileNotFoundError: return {} except json.JSONDecodeError as e: @@ -398,19 +411,9 @@ def load_certs_config() -> list[str]: List of domain names """ try: - with open(CERTS_FILE, "r", encoding="utf-8") as f: - try: - fcntl.flock(f.fileno(), fcntl.LOCK_SH) - except OSError as e: - logger.debug("File locking not supported for %s: %s", CERTS_FILE, e) - try: - data = json.load(f) - return data.get("domains", []) - finally: - try: - fcntl.flock(f.fileno(), fcntl.LOCK_UN) - except OSError as e: - logger.debug("File unlock failed for %s: %s", CERTS_FILE, e) + content = _read_file(CERTS_FILE) + data = json.loads(content) + return data.get("domains", []) except 
FileNotFoundError: return [] except json.JSONDecodeError as e: diff --git a/haproxy_mcp/haproxy_client.py b/haproxy_mcp/haproxy_client.py index 5411ea4..de74758 100644 --- a/haproxy_mcp/haproxy_client.py +++ b/haproxy_mcp/haproxy_client.py @@ -1,7 +1,6 @@ """HAProxy Runtime API client functions.""" import socket -import subprocess import select import time @@ -14,6 +13,7 @@ from .config import ( SUBPROCESS_TIMEOUT, ) from .exceptions import HaproxyError +from .ssh_ops import run_command def haproxy_cmd(command: str) -> str: @@ -147,23 +147,23 @@ def reload_haproxy() -> tuple[bool, str]: Tuple of (success, message) """ try: - validate = subprocess.run( + validate = run_command( ["podman", "exec", HAPROXY_CONTAINER, "haproxy", "-c", "-f", "/usr/local/etc/haproxy/haproxy.cfg"], - capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT + timeout=SUBPROCESS_TIMEOUT, ) if validate.returncode != 0: return False, f"Config validation failed:\n{validate.stderr}" - result = subprocess.run( + result = run_command( ["podman", "kill", "--signal", "USR2", HAPROXY_CONTAINER], - capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT + timeout=SUBPROCESS_TIMEOUT, ) if result.returncode != 0: return False, f"Reload failed: {result.stderr}" return True, "OK" - except subprocess.TimeoutExpired: + except TimeoutError: return False, f"Command timed out after {SUBPROCESS_TIMEOUT} seconds" except FileNotFoundError: - return False, "podman command not found" + return False, "ssh/podman command not found" except OSError as e: return False, f"OS error: {e}" diff --git a/haproxy_mcp/ssh_ops.py b/haproxy_mcp/ssh_ops.py new file mode 100644 index 0000000..a8f66ec --- /dev/null +++ b/haproxy_mcp/ssh_ops.py @@ -0,0 +1,149 @@ +"""SSH remote execution for HAProxy MCP Server. + +When REMOTE_MODE is enabled (SSH_HOST is set), file I/O and subprocess +commands are executed on the remote HAProxy host via SSH. 
+""" + +import subprocess +from .config import ( + SSH_HOST, + SSH_USER, + SSH_KEY, + SSH_PORT, + REMOTE_MODE, + SUBPROCESS_TIMEOUT, + logger, +) + + +def _ssh_base_cmd() -> list[str]: + """Build base SSH command with options.""" + cmd = [ + "ssh", + "-o", "StrictHostKeyChecking=no", + "-o", "BatchMode=yes", + "-o", "ConnectTimeout=10", + "-p", str(SSH_PORT), + ] + if SSH_KEY: + cmd.extend(["-i", SSH_KEY]) + cmd.append(f"{SSH_USER}@{SSH_HOST}") + return cmd + + +def remote_exec(command: str, timeout: int = SUBPROCESS_TIMEOUT) -> subprocess.CompletedProcess: + """Execute a command on the remote host via SSH. + + Args: + command: Shell command to execute remotely + timeout: Command timeout in seconds + + Returns: + CompletedProcess with stdout/stderr + + Raises: + subprocess.TimeoutExpired: If command times out + OSError: If SSH command fails to execute + """ + ssh_cmd = _ssh_base_cmd() + [command] + logger.debug("SSH exec: %s", command) + return subprocess.run( + ssh_cmd, + capture_output=True, + text=True, + timeout=timeout, + ) + + +def remote_read_file(path: str) -> str: + """Read a file from the remote host. + + Args: + path: Absolute file path on remote host + + Returns: + File contents as string + + Raises: + FileNotFoundError: If file doesn't exist on remote + IOError: If read fails + """ + result = remote_exec(f"cat {path}") + if result.returncode != 0: + stderr = result.stderr.strip() + if "No such file" in stderr: + raise FileNotFoundError(f"Remote file not found: {path}") + raise IOError(f"Failed to read remote file {path}: {stderr}") + return result.stdout + + +def remote_write_file(path: str, content: str) -> None: + """Write content to a file on the remote host atomically. + + Uses temp file + mv for atomic write, matching local behavior. 
+ + Args: + path: Absolute file path on remote host + content: Content to write + + Raises: + IOError: If write fails (or the SSH command times out) + """ + # Content is piped via stdin, so no shell escaping of the data is needed + ssh_cmd = _ssh_base_cmd() + # Atomic write: write to temp file, then rename (path quoted against spaces) + remote_script = f"tmpf=$(mktemp \"{path}.tmp.XXXXXX\") && cat > \"$tmpf\" && mv \"$tmpf\" \"{path}\"" + ssh_cmd.append(remote_script) + + logger.debug("SSH write: %s (%d bytes)", path, len(content)) + try: + result = subprocess.run( + ssh_cmd, input=content, capture_output=True, + text=True, timeout=SUBPROCESS_TIMEOUT, + ) + except subprocess.TimeoutExpired as e: + raise IOError(f"Timed out writing remote file {path}") from e + if result.returncode != 0: + raise IOError(f"Failed to write remote file {path}: {result.stderr.strip()}") + + +def remote_file_exists(path: str) -> bool: + """Check if a file exists on the remote host. + + Args: + path: Absolute file path on remote host + + Returns: + True if file exists + """ + result = remote_exec(f'test -f "{path}" && echo yes || echo no') + return result.stdout.strip() == "yes" + + +def run_command(args: list[str], timeout: int = SUBPROCESS_TIMEOUT) -> subprocess.CompletedProcess: + """Execute a command locally or remotely based on REMOTE_MODE.
+ + Args: + args: Command and arguments as list + timeout: Command timeout in seconds + + Returns: + CompletedProcess with stdout/stderr + """ + try: + if REMOTE_MODE: + # Join args into one shell command for SSH, quoting each + # argument POSIX-safely (escape embedded single quotes) + quoted = [] + for a in args: + if not a or " " in a or "'" in a or '"' in a: + quoted.append("'" + a.replace("'", "'\\''") + "'") + else: + quoted.append(a) + return remote_exec(" ".join(quoted), timeout=timeout) + return subprocess.run( + args, capture_output=True, text=True, timeout=timeout, + ) + except subprocess.TimeoutExpired as e: + # Translate so callers' "except TimeoutError" clauses actually fire + raise TimeoutError(str(e)) from e diff --git a/haproxy_mcp/tools/certificates.py b/haproxy_mcp/tools/certificates.py index 4ad170b..6f04484 100644 --- a/haproxy_mcp/tools/certificates.py +++ b/haproxy_mcp/tools/certificates.py @@ -1,7 +1,6 @@ """Certificate management tools for HAProxy MCP Server.""" import os -import subprocess from datetime import datetime from typing import Annotated @@ -13,6 +12,7 @@ from ..config import ( CERTS_DIR, CERTS_DIR_CONTAINER, ACME_HOME, + REMOTE_MODE, ) from ..exceptions import HaproxyError from ..validation import validate_domain @@ -21,58 +21,48 @@ from ..file_ops import ( load_certs_config, add_cert_to_config, remove_cert_from_config, + _read_file, ) +from ..ssh_ops import run_command, remote_file_exists # acme.sh script path (derived from ACME_HOME) -ACME_SH = os.path.join(ACME_HOME, "acme.sh") +ACME_SH = f"{ACME_HOME}/acme.sh" # Longer timeout for certificate operations (ACME can be slow) CERT_TIMEOUT = 120 +def _file_exists(path: str) -> bool: + """Check file existence locally or remotely.""" + if REMOTE_MODE: + return remote_file_exists(path) + return os.path.exists(path) + + def get_pem_paths(domain: str) -> tuple[str, str]: - """Get host and container PEM paths for a domain.
- - Args: - domain: Domain name - - Returns: - Tuple of (host_path, container_path) - """ + """Get host and container PEM paths for a domain.""" return ( - os.path.join(CERTS_DIR, f"{domain}.pem"), - os.path.join(CERTS_DIR_CONTAINER, f"{domain}.pem") + f"{CERTS_DIR}/{domain}.pem", + f"{CERTS_DIR_CONTAINER}/{domain}.pem", ) def load_cert_to_haproxy(domain: str) -> tuple[bool, str]: - """Load a certificate into HAProxy via Runtime API (zero-downtime). - - Args: - domain: Domain name - - Returns: - Tuple of (success, message) - """ + """Load a certificate into HAProxy via Runtime API (zero-downtime).""" host_path, container_path = get_pem_paths(domain) - if not os.path.exists(host_path): + if not _file_exists(host_path): return False, f"PEM file not found: {host_path}" try: - # Read PEM content - with open(host_path, "r", encoding="utf-8") as f: - pem_content = f.read() + pem_content = _read_file(host_path) - # Check if cert already loaded result = haproxy_cmd("show ssl cert") if container_path in result: - # Update existing cert haproxy_cmd(f"set ssl cert {container_path} <<\n{pem_content}\n") haproxy_cmd(f"commit ssl cert {container_path}") return True, "updated" else: - # Add new cert haproxy_cmd(f"new ssl cert {container_path}") haproxy_cmd(f"set ssl cert {container_path} <<\n{pem_content}\n") haproxy_cmd(f"commit ssl cert {container_path}") @@ -87,40 +77,24 @@ def load_cert_to_haproxy(domain: str) -> tuple[bool, str]: def unload_cert_from_haproxy(domain: str) -> tuple[bool, str]: - """Unload a certificate from HAProxy via Runtime API. 
- - Args: - domain: Domain name - - Returns: - Tuple of (success, message) - """ + """Unload a certificate from HAProxy via Runtime API.""" _, container_path = get_pem_paths(domain) try: - # Check if cert is loaded result = haproxy_cmd("show ssl cert") if container_path not in result: return True, "not loaded" - - # Delete from HAProxy runtime haproxy_cmd(f"del ssl cert {container_path}") return True, "unloaded" - except HaproxyError as e: logger.error("HAProxy error unloading certificate %s: %s", domain, e) return False, str(e) def restore_certificates() -> int: - """Restore all certificates from config to HAProxy on startup. - - Returns: - Number of certificates restored - """ + """Restore all certificates from config to HAProxy on startup.""" domains = load_certs_config() restored = 0 - for domain in domains: success, msg = load_cert_to_haproxy(domain) if success: @@ -128,23 +102,13 @@ def restore_certificates() -> int: logger.debug("Certificate %s: %s", domain, msg) else: logger.warning("Failed to restore certificate %s: %s", domain, msg) - return restored -# ============================================================================= -# Implementation functions (module-level) -# ============================================================================= - - def _haproxy_list_certs_impl() -> str: """Implementation of haproxy_list_certs.""" try: - result = subprocess.run( - [ACME_SH, "--list"], - capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT, - env={**os.environ, "HOME": os.path.expanduser("~")} - ) + result = run_command([ACME_SH, "--list"], timeout=SUBPROCESS_TIMEOUT) if result.returncode != 0: return f"Error: {result.stderr}" @@ -152,16 +116,14 @@ def _haproxy_list_certs_impl() -> str: if len(lines) <= 1: return "No certificates found" - # Get HAProxy loaded certs try: haproxy_certs = haproxy_cmd("show ssl cert") except HaproxyError as e: logger.debug("Could not get HAProxy certs: %s", e) haproxy_certs = "" - # Parse and format output certs = 
[] - for line in lines[1:]: # Skip header + for line in lines[1:]: parts = line.split() if len(parts) >= 4: domain = parts[0] @@ -178,11 +140,10 @@ def _haproxy_list_certs_impl() -> str: else: renew = part - # Check deployment status host_path, container_path = get_pem_paths(domain) if container_path in haproxy_certs: status = "loaded" - elif os.path.exists(host_path): + elif _file_exists(host_path): status = "file exists (not loaded)" else: status = "not deployed" @@ -190,15 +151,12 @@ def _haproxy_list_certs_impl() -> str: certs.append(f"• {domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}") return "\n\n".join(certs) if certs else "No certificates found" - except subprocess.TimeoutExpired: + except TimeoutError: return "Error: Command timed out" except FileNotFoundError: return "Error: acme.sh not found" - except subprocess.SubprocessError as e: - logger.error("Subprocess error listing certificates: %s", e) - return f"Error: {e}" except OSError as e: - logger.error("OS error listing certificates: %s", e) + logger.error("Error listing certificates: %s", e) return f"Error: {e}" @@ -208,24 +166,26 @@ def _haproxy_cert_info_impl(domain: str) -> str: return "Error: Invalid domain format" host_path, container_path = get_pem_paths(domain) - if not os.path.exists(host_path): + if not _file_exists(host_path): return f"Error: Certificate not found for {domain}" try: - # Use openssl to get certificate info - result = subprocess.run( + result = run_command( ["openssl", "x509", "-in", host_path, "-noout", "-subject", "-issuer", "-dates", "-ext", "subjectAltName"], - capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT + timeout=SUBPROCESS_TIMEOUT, ) if result.returncode != 0: return f"Error reading certificate: {result.stderr}" - # Get file info - stat = os.stat(host_path) - modified = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S") + # Get file modification time + stat_result = run_command(["stat", "-c", "%Y", host_path]) + if 
stat_result.returncode == 0: + ts = int(stat_result.stdout.strip()) + modified = datetime.fromtimestamp(ts).strftime("%Y-%m-%d %H:%M:%S") + else: + modified = "unknown" - # Check HAProxy status try: haproxy_certs = haproxy_cmd("show ssl cert") loaded = "Yes" if container_path in haproxy_certs else "No" @@ -242,9 +202,9 @@ def _haproxy_cert_info_impl(domain: str) -> str: result.stdout.strip() ] return "\n".join(info) - except subprocess.TimeoutExpired: + except TimeoutError: return "Error: Command timed out" - except (subprocess.SubprocessError, OSError) as e: + except OSError as e: logger.error("Error getting certificate info for %s: %s", domain, e) return f"Error: {e}" @@ -254,62 +214,29 @@ def _haproxy_issue_cert_impl(domain: str, wildcard: bool) -> str: if not validate_domain(domain): return "Error: Invalid domain format" - # Check if CF_Token is available - if not os.environ.get("CF_Token"): - secrets_file = os.path.expanduser("~/.secrets/cloudflare.ini") - if os.path.exists(secrets_file): - try: - with open(secrets_file) as f: - for line in f: - if "=" in line and "token" in line.lower(): - token = line.split("=", 1)[1].strip().strip('"').strip("'") - os.environ["CF_Token"] = token - break - except (IOError, OSError) as e: - logger.warning("Failed to read Cloudflare token: %s", e) - - if not os.environ.get("CF_Token"): - return "Error: CF_Token not set. Export CF_Token or add to ~/.secrets/cloudflare.ini" - - # Check if certificate already exists - cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc") - if os.path.exists(cert_dir): + cert_dir = f"{ACME_HOME}/{domain}_ecc" + if _file_exists(cert_dir): return f"Error: Certificate for {domain} already exists. Use haproxy_renew_cert to renew." 
- # Build acme.sh command (without reload - we'll do it via Runtime API) host_path, _ = get_pem_paths(domain) - - # Create PEM after issuance install_cmd = f"cat {ACME_HOME}/{domain}_ecc/fullchain.cer {ACME_HOME}/{domain}_ecc/{domain}.key > {host_path}" - cmd = [ - ACME_SH, "--issue", - "--dns", "dns_cf", - "-d", domain - ] - + cmd = [ACME_SH, "--issue", "--dns", "dns_cf", "-d", domain] if wildcard: cmd.extend(["-d", f"*.{domain}"]) - cmd.extend(["--reloadcmd", install_cmd]) try: logger.info("Issuing certificate for %s", domain) - result = subprocess.run( - cmd, - capture_output=True, text=True, timeout=CERT_TIMEOUT, - env={**os.environ, "HOME": os.path.expanduser("~")} - ) + result = run_command(cmd, timeout=CERT_TIMEOUT) if result.returncode != 0: error_msg = result.stderr or result.stdout return f"Error issuing certificate:\n{error_msg}" - # Load into HAProxy via Runtime API (zero-downtime) - if os.path.exists(host_path): + if _file_exists(host_path): success, msg = load_cert_to_haproxy(domain) if success: - # Save to config for persistence add_cert_to_config(domain) return f"Certificate issued and loaded for {domain} ({msg})" else: @@ -317,9 +244,9 @@ def _haproxy_issue_cert_impl(domain: str, wildcard: bool) -> str: else: return f"Certificate issued but PEM file not created. Check {host_path}" - except subprocess.TimeoutExpired: + except TimeoutError: return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s" - except (subprocess.SubprocessError, OSError) as e: + except OSError as e: logger.error("Error issuing certificate for %s: %s", domain, e) return f"Error: {e}" @@ -329,8 +256,8 @@ def _haproxy_renew_cert_impl(domain: str, force: bool) -> str: if not validate_domain(domain): return "Error: Invalid domain format" - cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc") - if not os.path.exists(cert_dir): + cert_dir = f"{ACME_HOME}/{domain}_ecc" + if not _file_exists(cert_dir): return f"Error: No certificate found for {domain}. 
Use haproxy_issue_cert first." cmd = [ACME_SH, "--renew", "-d", domain] @@ -339,11 +266,7 @@ def _haproxy_renew_cert_impl(domain: str, force: bool) -> str: try: logger.info("Renewing certificate for %s", domain) - result = subprocess.run( - cmd, - capture_output=True, text=True, timeout=CERT_TIMEOUT, - env={**os.environ, "HOME": os.path.expanduser("~")} - ) + result = run_command(cmd, timeout=CERT_TIMEOUT) output = result.stdout + result.stderr @@ -351,10 +274,8 @@ def _haproxy_renew_cert_impl(domain: str, force: bool) -> str: return f"Certificate for {domain} not due for renewal. Use force=True to force renewal." if "Cert success" in output or result.returncode == 0: - # Reload into HAProxy via Runtime API success, msg = load_cert_to_haproxy(domain) if success: - # Ensure in config add_cert_to_config(domain) return f"Certificate renewed and reloaded for {domain} ({msg})" else: @@ -362,11 +283,11 @@ def _haproxy_renew_cert_impl(domain: str, force: bool) -> str: else: return f"Error renewing certificate:\n{output}" - except subprocess.TimeoutExpired: + except TimeoutError: return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s" except FileNotFoundError: return "Error: acme.sh not found" - except (subprocess.SubprocessError, OSError) as e: + except OSError as e: logger.error("Error renewing certificate for %s: %s", domain, e) return f"Error: {e}" @@ -375,19 +296,12 @@ def _haproxy_renew_all_certs_impl() -> str: """Implementation of haproxy_renew_all_certs.""" try: logger.info("Running certificate renewal cron") - result = subprocess.run( - [ACME_SH, "--cron"], - capture_output=True, text=True, timeout=CERT_TIMEOUT * 3, - env={**os.environ, "HOME": os.path.expanduser("~")} - ) + result = run_command([ACME_SH, "--cron"], timeout=CERT_TIMEOUT * 3) output = result.stdout + result.stderr - - # Count renewals renewed = output.count("Cert success") skipped = output.count("Skip") - # Reload any renewed certs into HAProxy if renewed > 0: domains = 
load_certs_config() reloaded = 0 @@ -403,11 +317,11 @@ def _haproxy_renew_all_certs_impl() -> str: else: return "Renewal check completed" - except subprocess.TimeoutExpired: + except TimeoutError: return "Error: Renewal cron timed out" except FileNotFoundError: return "Error: acme.sh not found" - except (subprocess.SubprocessError, OSError) as e: + except OSError as e: logger.error("Error running certificate renewal cron: %s", e) return f"Error: {e}" @@ -417,46 +331,44 @@ def _haproxy_delete_cert_impl(domain: str) -> str: if not validate_domain(domain): return "Error: Invalid domain format" - cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc") + cert_dir = f"{ACME_HOME}/{domain}_ecc" host_path, _ = get_pem_paths(domain) - if not os.path.exists(cert_dir) and not os.path.exists(host_path): + if not _file_exists(cert_dir) and not _file_exists(host_path): return f"Error: No certificate found for {domain}" errors = [] deleted = [] - # Unload from HAProxy first (zero-downtime) success, msg = unload_cert_from_haproxy(domain) if success: deleted.append(f"HAProxy ({msg})") else: errors.append(f"HAProxy unload: {msg}") - # Remove from acme.sh - if os.path.exists(cert_dir): + if _file_exists(cert_dir): try: - result = subprocess.run( + result = run_command( [ACME_SH, "--remove", "-d", domain], - capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT, - env={**os.environ, "HOME": os.path.expanduser("~")} + timeout=SUBPROCESS_TIMEOUT, ) if result.returncode == 0: deleted.append("acme.sh") else: errors.append(f"acme.sh: {result.stderr}") - except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError) as e: + except (TimeoutError, OSError) as e: errors.append(f"acme.sh: {e}") - # Remove PEM file - if os.path.exists(host_path): + if _file_exists(host_path): try: - os.remove(host_path) - deleted.append("PEM file") + result = run_command(["rm", "-f", host_path]) + if result.returncode == 0: + deleted.append("PEM file") + else: + errors.append(f"PEM file: 
{result.stderr}") except OSError as e: errors.append(f"PEM file: {e}") - # Remove from config remove_cert_from_config(domain) result_parts = [] @@ -474,7 +386,7 @@ def _haproxy_load_cert_impl(domain: str) -> str: return "Error: Invalid domain format" host_path, _ = get_pem_paths(domain) - if not os.path.exists(host_path): + if not _file_exists(host_path): return f"Error: PEM file not found: {host_path}" success, msg = load_cert_to_haproxy(domain) @@ -485,11 +397,6 @@ def _haproxy_load_cert_impl(domain: str) -> str: return f"Error loading certificate: {msg}" -# ============================================================================= -# MCP Tool Registration -# ============================================================================= - - def register_certificate_tools(mcp): """Register certificate management tools with MCP server.""" diff --git a/haproxy_mcp/tools/configuration.py b/haproxy_mcp/tools/configuration.py index 78330bb..611a0a3 100644 --- a/haproxy_mcp/tools/configuration.py +++ b/haproxy_mcp/tools/configuration.py @@ -1,6 +1,5 @@ """Configuration management tools for HAProxy MCP Server.""" -import subprocess import time from ..config import ( @@ -18,6 +17,7 @@ from ..file_ops import ( get_domain_backend, get_backend_and_prefix, ) +from ..ssh_ops import run_command def restore_servers_from_config() -> int: @@ -167,17 +167,17 @@ def register_config_tools(mcp): Validation result or error details """ try: - result = subprocess.run( + result = run_command( ["podman", "exec", HAPROXY_CONTAINER, "haproxy", "-c", "-f", "/usr/local/etc/haproxy/haproxy.cfg"], - capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT + timeout=SUBPROCESS_TIMEOUT, ) if result.returncode == 0: return "Configuration is valid" return f"Configuration errors:\n{result.stderr}" - except subprocess.TimeoutExpired: + except TimeoutError: return f"Error: Command timed out after {SUBPROCESS_TIMEOUT} seconds" except FileNotFoundError: - return "Error: podman command not found" + 
return "Error: ssh/podman command not found" except OSError as e: return f"Error: OS error: {e}" diff --git a/haproxy_mcp/tools/domains.py b/haproxy_mcp/tools/domains.py index c473b59..59b5256 100644 --- a/haproxy_mcp/tools/domains.py +++ b/haproxy_mcp/tools/domains.py @@ -1,8 +1,6 @@ """Domain management tools for HAProxy MCP Server.""" -import fcntl import os -import subprocess from typing import Annotated, Optional from pydantic import Field @@ -15,8 +13,10 @@ from ..config import ( MAX_SLOTS, SUBPROCESS_TIMEOUT, CERTS_DIR, + REMOTE_MODE, logger, ) +from ..ssh_ops import run_command, remote_file_exists from ..exceptions import HaproxyError from ..validation import validate_domain, validate_ip, validate_port_int from ..haproxy_client import haproxy_cmd @@ -115,6 +115,13 @@ def _rollback_domain_addition( logger.error("Failed to rollback map file after HAProxy error") +def _file_exists(path: str) -> bool: + """Check file existence locally or remotely.""" + if REMOTE_MODE: + return remote_file_exists(path) + return os.path.exists(path) + + def check_certificate_coverage(domain: str) -> tuple[bool, str]: """Check if a domain is covered by an existing certificate. 
@@ -124,34 +131,35 @@ def check_certificate_coverage(domain: str) -> tuple[bool, str]: Returns: Tuple of (is_covered, certificate_name or message) """ - if not os.path.isdir(CERTS_DIR): + if REMOTE_MODE: + dir_check = run_command(["test", "-d", CERTS_DIR]) + if dir_check.returncode != 0: + return False, "Certificate directory not found" + elif not os.path.isdir(CERTS_DIR): return False, "Certificate directory not found" # Check for exact match first - exact_pem = os.path.join(CERTS_DIR, f"{domain}.pem") - if os.path.exists(exact_pem): + exact_pem = f"{CERTS_DIR}/{domain}.pem" + if _file_exists(exact_pem): return True, domain # Check for wildcard coverage (e.g., api.example.com covered by *.example.com) parts = domain.split(".") if len(parts) >= 2: - # Try parent domain (example.com for api.example.com) parent_domain = ".".join(parts[1:]) - parent_pem = os.path.join(CERTS_DIR, f"{parent_domain}.pem") + parent_pem = f"{CERTS_DIR}/{parent_domain}.pem" - if os.path.exists(parent_pem): - # Verify the certificate has wildcard SAN + if _file_exists(parent_pem): try: - result = subprocess.run( + result = run_command( ["openssl", "x509", "-in", parent_pem, "-noout", "-ext", "subjectAltName"], - capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT + timeout=SUBPROCESS_TIMEOUT, ) if result.returncode == 0: - # Check if wildcard covers this domain wildcard = f"*.{parent_domain}" if wildcard in result.stdout: return True, f"{parent_domain} (wildcard)" - except (subprocess.TimeoutExpired, OSError): + except (TimeoutError, OSError): pass return False, "No matching certificate" @@ -235,96 +243,92 @@ def register_domain_tools(mcp): return "Error: Cannot specify both ip and share_with (shared domains use existing servers)" # Use file locking for the entire pool allocation operation - lock_path = f"{MAP_FILE}.lock" - with open(lock_path, 'w') as lock_file: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + from ..file_ops import file_lock + with file_lock(f"{MAP_FILE}.lock"): + 
# Read map contents once for both existence check and pool lookup + entries = get_map_contents() + + # Check if domain already exists (using cached entries) + for domain_entry, backend in entries: + if domain_entry == domain: + return f"Error: Domain {domain} already exists (mapped to {backend})" + + # Build used pools and registered domains sets + used_pools: set[str] = set() + registered_domains: set[str] = set() + for entry_domain, backend in entries: + if backend.startswith("pool_"): + used_pools.add(backend) + if not entry_domain.startswith("."): + registered_domains.add(entry_domain) + + # Handle share_with: reuse existing domain's pool + if share_with: + share_backend = get_domain_backend(share_with) + if not share_backend: + return f"Error: Domain {share_with} not found" + if not share_backend.startswith("pool_"): + return f"Error: Cannot share with legacy backend {share_backend}" + pool = share_backend + else: + # Find available pool + pool = _find_available_pool(entries, used_pools) + if not pool: + return f"Error: All {POOL_COUNT} pool backends are in use" + + # Check if this is a subdomain of an existing domain + is_subdomain, parent_domain = _check_subdomain(domain, registered_domains) + try: - # Read map contents once for both existence check and pool lookup - entries = get_map_contents() - - # Check if domain already exists (using cached entries) - for domain_entry, backend in entries: - if domain_entry == domain: - return f"Error: Domain {domain} already exists (mapped to {backend})" - - # Build used pools and registered domains sets - used_pools: set[str] = set() - registered_domains: set[str] = set() - for entry_domain, backend in entries: - if backend.startswith("pool_"): - used_pools.add(backend) - if not entry_domain.startswith("."): - registered_domains.add(entry_domain) - - # Handle share_with: reuse existing domain's pool - if share_with: - share_backend = get_domain_backend(share_with) - if not share_backend: - return f"Error: Domain 
{share_with} not found" - if not share_backend.startswith("pool_"): - return f"Error: Cannot share with legacy backend {share_backend}" - pool = share_backend - else: - # Find available pool - pool = _find_available_pool(entries, used_pools) - if not pool: - return f"Error: All {POOL_COUNT} pool backends are in use" - - # Check if this is a subdomain of an existing domain - is_subdomain, parent_domain = _check_subdomain(domain, registered_domains) - + # Save to disk first (atomic write for persistence) + entries.append((domain, pool)) + if not is_subdomain: + entries.append((f".{domain}", pool)) try: - # Save to disk first (atomic write for persistence) - entries.append((domain, pool)) - if not is_subdomain: - entries.append((f".{domain}", pool)) - try: - save_map_file(entries) - except IOError as e: - return f"Error: Failed to save map file: {e}" - - # Update HAProxy maps via Runtime API - try: - _update_haproxy_maps(domain, pool, is_subdomain) - except HaproxyError as e: - _rollback_domain_addition(domain, entries) - return f"Error: Failed to update HAProxy map: {e}" - - # Handle server configuration based on mode - if share_with: - # Save shared domain reference - add_shared_domain_to_config(domain, share_with) - result = f"Domain {domain} added, sharing pool {pool} with {share_with}" - elif ip: - # Add server to slot 1 - add_server_to_config(domain, 1, ip, http_port) - try: - server = f"{pool}_1" - haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}") - haproxy_cmd(f"set server {pool}/{server} state ready") - except HaproxyError as e: - remove_server_from_config(domain, 1) - return f"Domain {domain} added to {pool} but server config failed: {e}" - result = f"Domain {domain} added to {pool} with server {ip}:{http_port}" - else: - result = f"Domain {domain} added to {pool} (no servers configured)" - - if is_subdomain: - result += f" (subdomain of {parent_domain}, no wildcard)" - - # Check certificate coverage - cert_covered, cert_info = 
check_certificate_coverage(domain) - if cert_covered: - result += f"\nSSL: Using certificate {cert_info}" - else: - result += f"\nSSL: No certificate found. Use haproxy_issue_cert(\"{domain}\") to issue one." - - return result + save_map_file(entries) + except IOError as e: + return f"Error: Failed to save map file: {e}" + # Update HAProxy maps via Runtime API + try: + _update_haproxy_maps(domain, pool, is_subdomain) except HaproxyError as e: - return f"Error: {e}" - finally: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + _rollback_domain_addition(domain, entries) + return f"Error: Failed to update HAProxy map: {e}" + + # Handle server configuration based on mode + if share_with: + # Save shared domain reference + add_shared_domain_to_config(domain, share_with) + result = f"Domain {domain} added, sharing pool {pool} with {share_with}" + elif ip: + # Add server to slot 1 + add_server_to_config(domain, 1, ip, http_port) + try: + server = f"{pool}_1" + haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}") + haproxy_cmd(f"set server {pool}/{server} state ready") + except HaproxyError as e: + remove_server_from_config(domain, 1) + return f"Domain {domain} added to {pool} but server config failed: {e}" + result = f"Domain {domain} added to {pool} with server {ip}:{http_port}" + else: + result = f"Domain {domain} added to {pool} (no servers configured)" + + if is_subdomain: + result += f" (subdomain of {parent_domain}, no wildcard)" + + # Check certificate coverage + cert_covered, cert_info = check_certificate_coverage(domain) + if cert_covered: + result += f"\nSSL: Using certificate {cert_info}" + else: + result += f"\nSSL: No certificate found. Use haproxy_issue_cert(\"{domain}\") to issue one." 
+ + return result + + except HaproxyError as e: + return f"Error: {e}" @mcp.tool() def haproxy_remove_domain( diff --git a/haproxy_mcp/tools/health.py b/haproxy_mcp/tools/health.py index d904c47..39208d9 100644 --- a/haproxy_mcp/tools/health.py +++ b/haproxy_mcp/tools/health.py @@ -1,7 +1,6 @@ """Health check tools for HAProxy MCP Server.""" import json -import os import subprocess import time from typing import Annotated, Any @@ -18,6 +17,8 @@ from ..validation import validate_domain, validate_backend_name from ..haproxy_client import haproxy_cmd from ..file_ops import get_backend_and_prefix from ..utils import parse_stat_csv, parse_servers_state +from ..ssh_ops import run_command, remote_file_exists +from ..config import REMOTE_MODE def register_health_tools(mcp): @@ -65,9 +66,9 @@ def register_health_tools(mcp): # Check container status try: - container_result = subprocess.run( + container_result = run_command( ["podman", "inspect", "--format", "{{.State.Status}}", HAPROXY_CONTAINER], - capture_output=True, text=True, timeout=5 + timeout=5, ) if container_result.returncode == 0: container_status = container_result.stdout.strip() @@ -88,7 +89,8 @@ def register_health_tools(mcp): files_ok = True file_status: dict[str, str] = {} for name, path in [("map_file", MAP_FILE), ("servers_file", SERVERS_FILE)]: - if os.path.exists(path): + exists = remote_file_exists(path) if REMOTE_MODE else __import__('os').path.exists(path) + if exists: file_status[name] = "ok" else: file_status[name] = "missing" diff --git a/k8s/deployment.yaml b/k8s/deployment.yaml index e54d52e..453d70b 100644 --- a/k8s/deployment.yaml +++ b/k8s/deployment.yaml @@ -32,8 +32,20 @@ spec: value: "10.253.100.107" - name: HAPROXY_PORT value: "9999" + - name: SSH_HOST + value: "10.253.100.107" + - name: SSH_USER + value: "root" + - name: SSH_KEY + value: "/root/.ssh/id_rsa" + - name: SSH_PORT + value: "22" - name: LOG_LEVEL value: "INFO" + volumeMounts: + - name: ssh-key + mountPath: /root/.ssh + readOnly: 
true readinessProbe: tcpSocket: port: 8000 @@ -51,3 +63,8 @@ spec: limits: memory: "256Mi" cpu: "500m" + volumes: + - name: ssh-key + secret: + secretName: haproxy-ssh-key + defaultMode: 0600