refactor: Extract large functions, improve exception handling, remove duplicates

## Large function extraction
- servers.py: Extract 8 _impl functions from register_server_tools (449 lines)
- certificates.py: Extract 7 _impl functions from register_certificate_tools (386 lines)
- MCP tool wrappers now delegate to module-level implementation functions

## Exception handling improvements
- Replace 11 broad `except Exception` with specific types
- health.py: (OSError, subprocess.SubprocessError)
- configuration.py: (HaproxyError, IOError, OSError, ValueError)
- servers.py: (IOError, OSError, ValueError)
- certificates.py: FileNotFoundError, (subprocess.SubprocessError, OSError)

## Duplicate code extraction
- Add parse_servers_state() to utils.py (replaces 4 duplicate parsers)
- Add disable_server_slot() to utils.py (replaces duplicate patterns)
- Update health.py, servers.py, domains.py to use new helpers

## Other improvements
- Add TypedDict types in file_ops.py and health.py
- Set file permissions (0o600) for sensitive files
- Update tests to use specific exception types

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
kaffa
2026-02-03 13:23:51 +09:00
parent e66c5ddc7f
commit 06ab47aca8
12 changed files with 891 additions and 723 deletions

View File

@@ -94,8 +94,8 @@ def _read_map_file(file_path: str) -> list[tuple[str, str]]:
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, "r", encoding="utf-8") as f:
try: try:
fcntl.flock(f.fileno(), fcntl.LOCK_SH) fcntl.flock(f.fileno(), fcntl.LOCK_SH)
except OSError: except OSError as e:
pass # Continue without lock if not supported logger.debug("File locking not supported for %s: %s", file_path, e)
try: try:
for line in f: for line in f:
line = line.strip() line = line.strip()
@@ -107,10 +107,10 @@ def _read_map_file(file_path: str) -> list[tuple[str, str]]:
finally: finally:
try: try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN) fcntl.flock(f.fileno(), fcntl.LOCK_UN)
except OSError: except OSError as e:
pass logger.debug("File unlock failed for %s: %s", file_path, e)
except FileNotFoundError: except FileNotFoundError:
pass logger.debug("Map file not found: %s", file_path)
return entries return entries
@@ -420,16 +420,16 @@ def load_certs_config() -> list[str]:
with open(CERTS_FILE, "r", encoding="utf-8") as f: with open(CERTS_FILE, "r", encoding="utf-8") as f:
try: try:
fcntl.flock(f.fileno(), fcntl.LOCK_SH) fcntl.flock(f.fileno(), fcntl.LOCK_SH)
except OSError: except OSError as e:
pass logger.debug("File locking not supported for %s: %s", CERTS_FILE, e)
try: try:
data = json.load(f) data = json.load(f)
return data.get("domains", []) return data.get("domains", [])
finally: finally:
try: try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN) fcntl.flock(f.fileno(), fcntl.LOCK_UN)
except OSError: except OSError as e:
pass logger.debug("File unlock failed for %s: %s", CERTS_FILE, e)
except FileNotFoundError: except FileNotFoundError:
return [] return []
except json.JSONDecodeError as e: except json.JSONDecodeError as e:

View File

@@ -67,8 +67,8 @@ def haproxy_cmd(command: str) -> str:
raise HaproxyError("Invalid UTF-8 in response") raise HaproxyError("Invalid UTF-8 in response")
except HaproxyError: except HaproxyError:
raise raise
except Exception as e: except (OSError, BlockingIOError, BrokenPipeError) as e:
raise HaproxyError(str(e)) from e raise HaproxyError(f"Socket error: {e}") from e
def haproxy_cmd_checked(command: str) -> str: def haproxy_cmd_checked(command: str) -> str:

View File

@@ -14,6 +14,7 @@ from ..config import (
CERTS_DIR_CONTAINER, CERTS_DIR_CONTAINER,
ACME_HOME, ACME_HOME,
) )
from ..exceptions import HaproxyError
from ..validation import validate_domain from ..validation import validate_domain
from ..haproxy_client import haproxy_cmd from ..haproxy_client import haproxy_cmd
from ..file_ops import ( from ..file_ops import (
@@ -77,7 +78,11 @@ def load_cert_to_haproxy(domain: str) -> tuple[bool, str]:
haproxy_cmd(f"commit ssl cert {container_path}") haproxy_cmd(f"commit ssl cert {container_path}")
return True, "added" return True, "added"
except Exception as e: except HaproxyError as e:
logger.error("HAProxy error loading certificate %s: %s", domain, e)
return False, str(e)
except (IOError, OSError) as e:
logger.error("File error loading certificate %s: %s", domain, e)
return False, str(e) return False, str(e)
@@ -102,7 +107,8 @@ def unload_cert_from_haproxy(domain: str) -> tuple[bool, str]:
haproxy_cmd(f"del ssl cert {container_path}") haproxy_cmd(f"del ssl cert {container_path}")
return True, "unloaded" return True, "unloaded"
except Exception as e: except HaproxyError as e:
logger.error("HAProxy error unloading certificate %s: %s", domain, e)
return False, str(e) return False, str(e)
@@ -126,16 +132,13 @@ def restore_certificates() -> int:
return restored return restored
def register_certificate_tools(mcp): # =============================================================================
"""Register certificate management tools with MCP server.""" # Implementation functions (module-level)
# =============================================================================
@mcp.tool()
def haproxy_list_certs() -> str:
"""List all SSL/TLS certificates with expiry information.
Returns: def _haproxy_list_certs_impl() -> str:
List of certificates with domain, CA, created date, and renewal date """Implementation of haproxy_list_certs."""
"""
try: try:
result = subprocess.run( result = subprocess.run(
[ACME_SH, "--list"], [ACME_SH, "--list"],
@@ -152,7 +155,8 @@ def register_certificate_tools(mcp):
# Get HAProxy loaded certs # Get HAProxy loaded certs
try: try:
haproxy_certs = haproxy_cmd("show ssl cert") haproxy_certs = haproxy_cmd("show ssl cert")
except Exception: except HaproxyError as e:
logger.debug("Could not get HAProxy certs: %s", e)
haproxy_certs = "" haproxy_certs = ""
# Parse and format output # Parse and format output
@@ -190,17 +194,16 @@ def register_certificate_tools(mcp):
return "Error: Command timed out" return "Error: Command timed out"
except FileNotFoundError: except FileNotFoundError:
return "Error: acme.sh not found" return "Error: acme.sh not found"
except Exception as e: except subprocess.SubprocessError as e:
logger.error("Subprocess error listing certificates: %s", e)
return f"Error: {e}"
except OSError as e:
logger.error("OS error listing certificates: %s", e)
return f"Error: {e}" return f"Error: {e}"
@mcp.tool()
def haproxy_cert_info(
domain: Annotated[str, Field(description="Domain name to check (e.g., example.com)")]
) -> str:
"""Get detailed certificate information for a domain.
Shows expiry date, issuer, SANs, and file paths. def _haproxy_cert_info_impl(domain: str) -> str:
""" """Implementation of haproxy_cert_info."""
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
@@ -226,7 +229,8 @@ def register_certificate_tools(mcp):
try: try:
haproxy_certs = haproxy_cmd("show ssl cert") haproxy_certs = haproxy_cmd("show ssl cert")
loaded = "Yes" if container_path in haproxy_certs else "No" loaded = "Yes" if container_path in haproxy_certs else "No"
except Exception: except HaproxyError as e:
logger.debug("Could not check HAProxy cert status: %s", e)
loaded = "Unknown" loaded = "Unknown"
info = [ info = [
@@ -240,20 +244,13 @@ def register_certificate_tools(mcp):
return "\n".join(info) return "\n".join(info)
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
return "Error: Command timed out" return "Error: Command timed out"
except Exception as e: except (subprocess.SubprocessError, OSError) as e:
logger.error("Error getting certificate info for %s: %s", domain, e)
return f"Error: {e}" return f"Error: {e}"
@mcp.tool()
def haproxy_issue_cert(
domain: Annotated[str, Field(description="Primary domain (e.g., example.com)")],
wildcard: Annotated[bool, Field(default=True, description="Include wildcard (*.example.com). Default: true")]
) -> str:
"""Issue a new SSL/TLS certificate using acme.sh with Cloudflare DNS.
Automatically deploys to HAProxy via Runtime API (zero-downtime). def _haproxy_issue_cert_impl(domain: str, wildcard: bool) -> str:
"""Implementation of haproxy_issue_cert."""
Example: haproxy_issue_cert("example.com", wildcard=True)
"""
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
@@ -268,7 +265,7 @@ def register_certificate_tools(mcp):
token = line.split("=", 1)[1].strip().strip('"').strip("'") token = line.split("=", 1)[1].strip().strip('"').strip("'")
os.environ["CF_Token"] = token os.environ["CF_Token"] = token
break break
except Exception as e: except (IOError, OSError) as e:
logger.warning("Failed to read Cloudflare token: %s", e) logger.warning("Failed to read Cloudflare token: %s", e)
if not os.environ.get("CF_Token"): if not os.environ.get("CF_Token"):
@@ -322,20 +319,13 @@ def register_certificate_tools(mcp):
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s" return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s"
except Exception as e: except (subprocess.SubprocessError, OSError) as e:
logger.error("Error issuing certificate for %s: %s", domain, e)
return f"Error: {e}" return f"Error: {e}"
@mcp.tool()
def haproxy_renew_cert(
domain: Annotated[str, Field(description="Domain name to renew (e.g., example.com)")],
force: Annotated[bool, Field(default=False, description="Force renewal even if not due. Default: false")]
) -> str:
"""Renew an existing certificate.
Uses Runtime API for zero-downtime reload. def _haproxy_renew_cert_impl(domain: str, force: bool) -> str:
"""Implementation of haproxy_renew_cert."""
Example: haproxy_renew_cert("example.com", force=True)
"""
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
@@ -374,15 +364,15 @@ def register_certificate_tools(mcp):
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s" return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s"
except Exception as e: except FileNotFoundError:
return "Error: acme.sh not found"
except (subprocess.SubprocessError, OSError) as e:
logger.error("Error renewing certificate for %s: %s", domain, e)
return f"Error: {e}" return f"Error: {e}"
@mcp.tool()
def haproxy_renew_all_certs() -> str:
"""Renew all certificates that are due for renewal.
This runs the acme.sh cron job to check and renew all certificates. def _haproxy_renew_all_certs_impl() -> str:
""" """Implementation of haproxy_renew_all_certs."""
try: try:
logger.info("Running certificate renewal cron") logger.info("Running certificate renewal cron")
result = subprocess.run( result = subprocess.run(
@@ -415,19 +405,15 @@ def register_certificate_tools(mcp):
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
return "Error: Renewal cron timed out" return "Error: Renewal cron timed out"
except Exception as e: except FileNotFoundError:
return "Error: acme.sh not found"
except (subprocess.SubprocessError, OSError) as e:
logger.error("Error running certificate renewal cron: %s", e)
return f"Error: {e}" return f"Error: {e}"
@mcp.tool()
def haproxy_delete_cert(
domain: Annotated[str, Field(description="Domain name to delete certificate for")]
) -> str:
"""Delete a certificate from acme.sh and HAProxy.
WARNING: This permanently removes the certificate. The domain will lose HTTPS. def _haproxy_delete_cert_impl(domain: str) -> str:
"""Implementation of haproxy_delete_cert."""
Example: haproxy_delete_cert("example.com")
"""
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
@@ -459,7 +445,7 @@ def register_certificate_tools(mcp):
deleted.append("acme.sh") deleted.append("acme.sh")
else: else:
errors.append(f"acme.sh: {result.stderr}") errors.append(f"acme.sh: {result.stderr}")
except Exception as e: except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError) as e:
errors.append(f"acme.sh: {e}") errors.append(f"acme.sh: {e}")
# Remove PEM file # Remove PEM file
@@ -467,7 +453,7 @@ def register_certificate_tools(mcp):
try: try:
os.remove(host_path) os.remove(host_path)
deleted.append("PEM file") deleted.append("PEM file")
except Exception as e: except OSError as e:
errors.append(f"PEM file: {e}") errors.append(f"PEM file: {e}")
# Remove from config # Remove from config
@@ -481,16 +467,9 @@ def register_certificate_tools(mcp):
return "\n".join(result_parts) if result_parts else f"Certificate {domain} deleted" return "\n".join(result_parts) if result_parts else f"Certificate {domain} deleted"
@mcp.tool()
def haproxy_load_cert(
domain: Annotated[str, Field(description="Domain name to load certificate for")]
) -> str:
"""Load/reload a certificate into HAProxy (zero-downtime).
Use after manually updating a certificate file. def _haproxy_load_cert_impl(domain: str) -> str:
"""Implementation of haproxy_load_cert."""
Example: haproxy_load_cert("example.com")
"""
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
@@ -504,3 +483,89 @@ def register_certificate_tools(mcp):
return f"Certificate {domain} loaded into HAProxy ({msg})" return f"Certificate {domain} loaded into HAProxy ({msg})"
else: else:
return f"Error loading certificate: {msg}" return f"Error loading certificate: {msg}"
# =============================================================================
# MCP Tool Registration
# =============================================================================
def register_certificate_tools(mcp):
"""Register certificate management tools with MCP server."""
@mcp.tool()
def haproxy_list_certs() -> str:
"""List all SSL/TLS certificates with expiry information.
Returns:
List of certificates with domain, CA, created date, and renewal date
"""
return _haproxy_list_certs_impl()
@mcp.tool()
def haproxy_cert_info(
domain: Annotated[str, Field(description="Domain name to check (e.g., example.com)")]
) -> str:
"""Get detailed certificate information for a domain.
Shows expiry date, issuer, SANs, and file paths.
"""
return _haproxy_cert_info_impl(domain)
@mcp.tool()
def haproxy_issue_cert(
domain: Annotated[str, Field(description="Primary domain (e.g., example.com)")],
wildcard: Annotated[bool, Field(default=True, description="Include wildcard (*.example.com). Default: true")]
) -> str:
"""Issue a new SSL/TLS certificate using acme.sh with Cloudflare DNS.
Automatically deploys to HAProxy via Runtime API (zero-downtime).
Example: haproxy_issue_cert("example.com", wildcard=True)
"""
return _haproxy_issue_cert_impl(domain, wildcard)
@mcp.tool()
def haproxy_renew_cert(
domain: Annotated[str, Field(description="Domain name to renew (e.g., example.com)")],
force: Annotated[bool, Field(default=False, description="Force renewal even if not due. Default: false")]
) -> str:
"""Renew an existing certificate.
Uses Runtime API for zero-downtime reload.
Example: haproxy_renew_cert("example.com", force=True)
"""
return _haproxy_renew_cert_impl(domain, force)
@mcp.tool()
def haproxy_renew_all_certs() -> str:
"""Renew all certificates that are due for renewal.
This runs the acme.sh cron job to check and renew all certificates.
"""
return _haproxy_renew_all_certs_impl()
@mcp.tool()
def haproxy_delete_cert(
domain: Annotated[str, Field(description="Domain name to delete certificate for")]
) -> str:
"""Delete a certificate from acme.sh and HAProxy.
WARNING: This permanently removes the certificate. The domain will lose HTTPS.
Example: haproxy_delete_cert("example.com")
"""
return _haproxy_delete_cert_impl(domain)
@mcp.tool()
def haproxy_load_cert(
domain: Annotated[str, Field(description="Domain name to load certificate for")]
) -> str:
"""Load/reload a certificate into HAProxy (zero-downtime).
Use after manually updating a certificate file.
Example: haproxy_load_cert("example.com")
"""
return _haproxy_load_cert_impl(domain)

View File

@@ -117,7 +117,7 @@ def startup_restore() -> None:
cert_count = restore_certificates() cert_count = restore_certificates()
if cert_count > 0: if cert_count > 0:
logger.info("Restored %d certificates from config", cert_count) logger.info("Restored %d certificates from config", cert_count)
except Exception as e: except (HaproxyError, IOError, OSError, ValueError) as e:
logger.warning("Failed to restore certificates: %s", e) logger.warning("Failed to restore certificates: %s", e)
@@ -155,7 +155,7 @@ def register_config_tools(mcp):
try: try:
restored = restore_servers_from_config() restored = restore_servers_from_config()
return f"HAProxy configuration reloaded successfully ({restored} servers restored)" return f"HAProxy configuration reloaded successfully ({restored} servers restored)"
except Exception as e: except (HaproxyError, IOError, OSError, ValueError) as e:
logger.error("Failed to restore servers after reload: %s", e) logger.error("Failed to restore servers after reload: %s", e)
return f"HAProxy reloaded but server restore failed: {e}" return f"HAProxy reloaded but server restore failed: {e}"

View File

@@ -13,14 +13,12 @@ from ..config import (
WILDCARDS_MAP_FILE_CONTAINER, WILDCARDS_MAP_FILE_CONTAINER,
POOL_COUNT, POOL_COUNT,
MAX_SLOTS, MAX_SLOTS,
StateField,
STATE_MIN_COLUMNS,
SUBPROCESS_TIMEOUT, SUBPROCESS_TIMEOUT,
CERTS_DIR, CERTS_DIR,
logger, logger,
) )
from ..exceptions import HaproxyError from ..exceptions import HaproxyError
from ..validation import validate_domain, validate_ip from ..validation import validate_domain, validate_ip, validate_port_int
from ..haproxy_client import haproxy_cmd from ..haproxy_client import haproxy_cmd
from ..file_ops import ( from ..file_ops import (
get_map_contents, get_map_contents,
@@ -31,6 +29,7 @@ from ..file_ops import (
remove_server_from_config, remove_server_from_config,
remove_domain_from_config, remove_domain_from_config,
) )
from ..utils import parse_servers_state, disable_server_slot
def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]: def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]:
@@ -166,17 +165,17 @@ def register_domain_tools(mcp):
try: try:
domains = [] domains = []
state = haproxy_cmd("show servers state") state = haproxy_cmd("show servers state")
parsed_state = parse_servers_state(state)
# Build server map from HAProxy state # Build server map from HAProxy state
server_map: dict[str, list] = {} server_map: dict[str, list[str]] = {}
for line in state.split("\n"): for backend, servers_dict in parsed_state.items():
parts = line.split() for server_name, srv_info in servers_dict.items():
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.SRV_ADDR] != "0.0.0.0": if srv_info["addr"] != "0.0.0.0":
backend = parts[StateField.BE_NAME]
if backend not in server_map: if backend not in server_map:
server_map[backend] = [] server_map[backend] = []
server_map[backend].append( server_map[backend].append(
f"{parts[StateField.SRV_NAME]}={parts[StateField.SRV_ADDR]}:{parts[StateField.SRV_PORT]}" f"{server_name}={srv_info['addr']}:{srv_info['port']}"
) )
# Read from domains.map # Read from domains.map
@@ -220,7 +219,7 @@ def register_domain_tools(mcp):
return "Error: Invalid domain format" return "Error: Invalid domain format"
if not validate_ip(ip, allow_empty=True): if not validate_ip(ip, allow_empty=True):
return "Error: Invalid IP address format" return "Error: Invalid IP address format"
if not (1 <= http_port <= 65535): if not validate_port_int(http_port):
return "Error: Port must be between 1 and 65535" return "Error: Port must be between 1 and 65535"
# Use file locking for the entire pool allocation operation # Use file locking for the entire pool allocation operation
@@ -339,8 +338,7 @@ def register_domain_tools(mcp):
for slot in range(1, MAX_SLOTS + 1): for slot in range(1, MAX_SLOTS + 1):
server = f"{backend}_{slot}" server = f"{backend}_{slot}"
try: try:
haproxy_cmd(f"set server {backend}/{server} state maint") disable_server_slot(backend, server)
haproxy_cmd(f"set server {backend}/{server} addr 0.0.0.0 port 0")
except HaproxyError as e: except HaproxyError as e:
logger.warning( logger.warning(
"Failed to clear server %s/%s for domain %s: %s", "Failed to clear server %s/%s for domain %s: %s",

View File

@@ -12,14 +12,12 @@ from ..config import (
MAP_FILE, MAP_FILE,
SERVERS_FILE, SERVERS_FILE,
HAPROXY_CONTAINER, HAPROXY_CONTAINER,
StateField,
STATE_MIN_COLUMNS,
) )
from ..exceptions import HaproxyError from ..exceptions import HaproxyError
from ..validation import validate_domain, validate_backend_name from ..validation import validate_domain, validate_backend_name
from ..haproxy_client import haproxy_cmd from ..haproxy_client import haproxy_cmd
from ..file_ops import get_backend_and_prefix from ..file_ops import get_backend_and_prefix
from ..utils import parse_stat_csv from ..utils import parse_stat_csv, parse_servers_state
def register_health_tools(mcp): def register_health_tools(mcp):
@@ -83,7 +81,7 @@ def register_health_tools(mcp):
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
result["components"]["container"] = {"status": "timeout"} result["components"]["container"] = {"status": "timeout"}
result["status"] = "unhealthy" result["status"] = "unhealthy"
except Exception as e: except (OSError, subprocess.SubprocessError) as e:
result["components"]["container"] = {"status": "error", "error": str(e)} result["components"]["container"] = {"status": "error", "error": str(e)}
# Check configuration files # Check configuration files
@@ -144,12 +142,12 @@ def register_health_tools(mcp):
} }
# Parse server state for address info # Parse server state for address info
for line in state_output.split("\n"): parsed_state = parse_servers_state(state_output)
parts = line.split() backend_servers = parsed_state.get(backend, {})
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
server_name = parts[StateField.SRV_NAME] for server_name, srv_info in backend_servers.items():
addr = parts[StateField.SRV_ADDR] addr = srv_info["addr"]
port = parts[StateField.SRV_PORT] port = srv_info["port"]
# Skip disabled servers (0.0.0.0) # Skip disabled servers (0.0.0.0)
if addr == "0.0.0.0": if addr == "0.0.0.0":

View File

@@ -10,13 +10,11 @@ from ..config import (
MAX_SLOTS, MAX_SLOTS,
MAX_BULK_SERVERS, MAX_BULK_SERVERS,
MAX_SERVERS_JSON_SIZE, MAX_SERVERS_JSON_SIZE,
StateField,
StatField, StatField,
STATE_MIN_COLUMNS,
logger, logger,
) )
from ..exceptions import HaproxyError from ..exceptions import HaproxyError
from ..validation import validate_domain, validate_ip, validate_backend_name from ..validation import validate_domain, validate_ip, validate_backend_name, validate_port_int
from ..haproxy_client import haproxy_cmd, haproxy_cmd_checked, haproxy_cmd_batch from ..haproxy_client import haproxy_cmd, haproxy_cmd_checked, haproxy_cmd_batch
from ..file_ops import ( from ..file_ops import (
get_backend_and_prefix, get_backend_and_prefix,
@@ -24,6 +22,7 @@ from ..file_ops import (
add_server_to_config, add_server_to_config,
remove_server_from_config, remove_server_from_config,
) )
from ..utils import parse_servers_state, disable_server_slot
def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str: def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str:
@@ -51,65 +50,46 @@ def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str,
return server return server
def register_server_tools(mcp): # =============================================================================
"""Register server management tools with MCP server.""" # Implementation functions (module-level)
# =============================================================================
@mcp.tool()
def haproxy_list_servers(
domain: Annotated[str, Field(description="Domain name to list servers for (e.g., api.example.com)")]
) -> str:
"""List all servers for a domain with slot numbers, addresses, and status (UP/DOWN/MAINT).
Example: def _haproxy_list_servers_impl(domain: str) -> str:
haproxy_list_servers("api.example.com") """Implementation of haproxy_list_servers."""
# Output: pool_1_1: 10.0.0.1:8080 (UP)
# pool_1_2: 10.0.0.2:8080 (UP)
"""
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
try: try:
backend, _ = get_backend_and_prefix(domain) backend, _ = get_backend_and_prefix(domain)
servers = []
state = haproxy_cmd("show servers state") state = haproxy_cmd("show servers state")
parsed_state = parse_servers_state(state)
backend_servers = parsed_state.get(backend, {})
for line in state.split("\n"): if not backend_servers:
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
addr = parts[StateField.SRV_ADDR]
status = "active" if addr != "0.0.0.0" else "disabled"
servers.append(
f"{parts[StateField.SRV_NAME]}: {addr}:{parts[StateField.SRV_PORT]} ({status})"
)
if not servers:
return f"Backend {backend} not found" return f"Backend {backend} not found"
servers = []
for server_name, srv_info in backend_servers.items():
addr = srv_info["addr"]
port = srv_info["port"]
status = "active" if addr != "0.0.0.0" else "disabled"
servers.append(f"{server_name}: {addr}:{port} ({status})")
return f"Servers for {domain} ({backend}):\n" + "\n".join(servers) return f"Servers for {domain} ({backend}):\n" + "\n".join(servers)
except (HaproxyError, ValueError) as e: except (HaproxyError, ValueError) as e:
return f"Error: {e}" return f"Error: {e}"
@mcp.tool()
def haproxy_add_server(
domain: Annotated[str, Field(description="Domain name to add server to (e.g., api.example.com)")],
slot: Annotated[int, Field(description="Server slot number 1-10, or 0 for auto-select next available slot")],
ip: Annotated[str, Field(description="Server IP address (IPv4 like 10.0.0.1 or IPv6 like 2001:db8::1)")],
http_port: Annotated[int, Field(default=80, description="HTTP port for backend connection (default: 80)")]
) -> str:
"""Add a server to a domain's backend pool for load balancing.
Each domain can have up to 10 servers (slots 1-10). HAProxy distributes traffic def _haproxy_add_server_impl(domain: str, slot: int, ip: str, http_port: int) -> str:
across all configured servers using round-robin. """Implementation of haproxy_add_server."""
Example: haproxy_add_server("api.example.com", slot=1, ip="10.0.0.1", http_port=8080)
"""
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
if not ip: if not ip:
return "Error: IP address is required" return "Error: IP address is required"
if not validate_ip(ip): if not validate_ip(ip):
return "Error: Invalid IP address format" return "Error: Invalid IP address format"
if not (1 <= http_port <= 65535): if not validate_port_int(http_port):
return "Error: Port must be between 1 and 65535" return "Error: Port must be between 1 and 65535"
try: try:
@@ -118,13 +98,12 @@ def register_server_tools(mcp):
# Auto-select slot if slot <= 0 # Auto-select slot if slot <= 0
if slot <= 0: if slot <= 0:
state = haproxy_cmd("show servers state") state = haproxy_cmd("show servers state")
parsed_state = parse_servers_state(state)
backend_servers = parsed_state.get(backend, {})
used_slots: set[int] = set() used_slots: set[int] = set()
for line in state.split("\n"): for server_name, srv_info in backend_servers.items():
parts = line.split() if srv_info["addr"] != "0.0.0.0":
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
if parts[StateField.SRV_ADDR] != "0.0.0.0":
# Extract slot number from server name (e.g., pool_1_3 -> 3) # Extract slot number from server name (e.g., pool_1_3 -> 3)
server_name = parts[StateField.SRV_NAME]
try: try:
used_slots.add(int(server_name.rsplit("_", 1)[1])) used_slots.add(int(server_name.rsplit("_", 1)[1]))
except (ValueError, IndexError): except (ValueError, IndexError):
@@ -151,15 +130,9 @@ def register_server_tools(mcp):
except (ValueError, IOError) as e: except (ValueError, IOError) as e:
return f"Error: {e}" return f"Error: {e}"
@mcp.tool()
def haproxy_add_servers(
domain: Annotated[str, Field(description="Domain name to add servers to (e.g., api.example.com)")],
servers: Annotated[str, Field(description='JSON array of servers. Each object: {"slot": 1-10, "ip": "10.0.0.1", "http_port": 80}. Example: \'[{"slot":1,"ip":"10.0.0.1"},{"slot":2,"ip":"10.0.0.2"}]\'')]
) -> str:
"""Add multiple servers to a domain's backend at once (bulk operation).
Example: haproxy_add_servers("api.example.com", '[{"slot":1,"ip":"10.0.0.1"},{"slot":2,"ip":"10.0.0.2"}]') def _haproxy_add_servers_impl(domain: str, servers: str) -> str:
""" """Implementation of haproxy_add_servers."""
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
@@ -221,7 +194,7 @@ def register_server_tools(mcp):
except (ValueError, TypeError): except (ValueError, TypeError):
validation_errors.append(f"Server {i+1}: http_port must be an integer") validation_errors.append(f"Server {i+1}: http_port must be an integer")
continue continue
if not (1 <= http_port <= 65535): if not validate_port_int(http_port):
validation_errors.append(f"Server {i+1}: port must be between 1 and 65535") validation_errors.append(f"Server {i+1}: port must be between 1 and 65535")
continue continue
@@ -267,12 +240,13 @@ def register_server_tools(mcp):
except HaproxyError as e: except HaproxyError as e:
failed_slots.append(slot) failed_slots.append(slot)
errors.append(f"slot {slot}: {e}") errors.append(f"slot {slot}: {e}")
except Exception as e: except (IOError, OSError, ValueError) as e:
# Rollback only successfully added configs on unexpected error # Rollback only successfully added configs on unexpected error
logger.error("Unexpected error during bulk server add for %s: %s", domain, e)
for slot in successfully_added_slots: for slot in successfully_added_slots:
try: try:
remove_server_from_config(domain, slot) remove_server_from_config(domain, slot)
except Exception as rollback_error: except (IOError, OSError, ValueError) as rollback_error:
logger.error( logger.error(
"Failed to rollback server config for %s slot %d: %s", "Failed to rollback server config for %s slot %d: %s",
domain, slot, rollback_error domain, slot, rollback_error
@@ -283,7 +257,7 @@ def register_server_tools(mcp):
if slot not in successfully_added_slots: if slot not in successfully_added_slots:
try: try:
remove_server_from_config(domain, slot) remove_server_from_config(domain, slot)
except Exception as rollback_error: except (IOError, OSError, ValueError) as rollback_error:
logger.error( logger.error(
"Failed to rollback server config for %s slot %d: %s", "Failed to rollback server config for %s slot %d: %s",
domain, slot, rollback_error domain, slot, rollback_error
@@ -294,7 +268,7 @@ def register_server_tools(mcp):
for slot in failed_slots: for slot in failed_slots:
try: try:
remove_server_from_config(domain, slot) remove_server_from_config(domain, slot)
except Exception as rollback_error: except (IOError, OSError, ValueError) as rollback_error:
logger.error( logger.error(
"Failed to rollback server config for %s slot %d: %s", "Failed to rollback server config for %s slot %d: %s",
domain, slot, rollback_error domain, slot, rollback_error
@@ -311,15 +285,9 @@ def register_server_tools(mcp):
return "\n".join(result_parts) if result_parts else "No servers added" return "\n".join(result_parts) if result_parts else "No servers added"
@mcp.tool()
def haproxy_remove_server(
domain: Annotated[str, Field(description="Domain name to remove server from (e.g., api.example.com)")],
slot: Annotated[int, Field(description="Server slot number to remove (1-10)")]
) -> str:
"""Remove a server from a domain's backend at specified slot.
Example: haproxy_remove_server("api.example.com", slot=2) def _haproxy_remove_server_impl(domain: str, slot: int) -> str:
""" """Implementation of haproxy_remove_server."""
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
if not (1 <= slot <= MAX_SLOTS): if not (1 <= slot <= MAX_SLOTS):
@@ -338,11 +306,7 @@ def register_server_tools(mcp):
try: try:
# HTTP only - single server per slot # HTTP only - single server per slot
server = f"{server_prefix}_{slot}" server = f"{server_prefix}_{slot}"
# Batch both commands in single TCP connection disable_server_slot(backend, server)
haproxy_cmd_batch([
f"set server {backend}/{server} state maint",
f"set server {backend}/{server} addr 0.0.0.0 port 0"
])
return f"Removed server at slot {slot} from {domain} ({backend})" return f"Removed server at slot {slot} from {domain} ({backend})"
except HaproxyError as e: except HaproxyError as e:
# Rollback: re-add config if HAProxy command failed # Rollback: re-add config if HAProxy command failed
@@ -352,22 +316,9 @@ def register_server_tools(mcp):
except (ValueError, IOError) as e: except (ValueError, IOError) as e:
return f"Error: {e}" return f"Error: {e}"
@mcp.tool()
def haproxy_set_domain_state(
domain: Annotated[str, Field(description="Domain name (e.g., api.example.com)")],
state: Annotated[str, Field(description="Target state: 'ready' (normal), 'drain' (stop new connections), or 'maint' (maintenance)")]
) -> str:
"""Set state for all servers of a domain at once.
Example: haproxy_set_domain_state("api.example.com", state="drain") def _haproxy_set_domain_state_impl(domain: str, state: str) -> str:
"""Implementation of haproxy_set_domain_state."""
Example:
# Put all servers in maintenance for deployment
haproxy_set_domain_state("api.example.com", "maint")
# Re-enable all servers after deployment
haproxy_set_domain_state("api.example.com", "ready")
"""
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
if state not in ["ready", "drain", "maint"]: if state not in ["ready", "drain", "maint"]:
@@ -384,16 +335,15 @@ def register_server_tools(mcp):
except HaproxyError as e: except HaproxyError as e:
return f"Error: {e}" return f"Error: {e}"
parsed_state = parse_servers_state(servers_state)
backend_servers = parsed_state.get(backend, {})
changed = [] changed = []
errors = [] errors = []
for line in servers_state.split("\n"): for server_name, srv_info in backend_servers.items():
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
server_name = parts[StateField.SRV_NAME]
addr = parts[StateField.SRV_ADDR]
# Only change state for configured servers (not 0.0.0.0) # Only change state for configured servers (not 0.0.0.0)
if addr != "0.0.0.0": if srv_info["addr"] != "0.0.0.0":
try: try:
haproxy_cmd_checked(f"set server {backend}/{server_name} state {state}") haproxy_cmd_checked(f"set server {backend}/{server_name} state {state}")
changed.append(server_name) changed.append(server_name)
@@ -411,17 +361,9 @@ def register_server_tools(mcp):
return result return result
@mcp.tool()
def haproxy_wait_drain(
domain: Annotated[str, Field(description="Domain name to wait for (e.g., api.example.com)")],
timeout: Annotated[int, Field(default=30, description="Maximum seconds to wait (default: 30, max: 300)")]
) -> str:
"""Wait for all active connections to drain from a domain's servers.
Use after setting servers to 'drain' state before maintenance. def _haproxy_wait_drain_impl(domain: str, timeout: int) -> str:
"""Implementation of haproxy_wait_drain."""
Example: haproxy_wait_drain("api.example.com", timeout=60)
"""
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
if not (1 <= timeout <= 300): if not (1 <= timeout <= 300):
@@ -456,16 +398,9 @@ def register_server_tools(mcp):
return f"Timeout: Connections still active after {timeout}s" return f"Timeout: Connections still active after {timeout}s"
@mcp.tool()
def haproxy_set_server_state(
backend: Annotated[str, Field(description="Backend name (e.g., 'pool_1')")],
server: Annotated[str, Field(description="Server name (e.g., 'pool_1_1')")],
state: Annotated[str, Field(description="'ready' (enable), 'drain' (graceful shutdown), or 'maint' (maintenance)")]
) -> str:
"""Set server state for maintenance or traffic control.
Example: haproxy_set_server_state("pool_1", "pool_1_2", "maint") def _haproxy_set_server_state_impl(backend: str, server: str, state: str) -> str:
""" """Implementation of haproxy_set_server_state."""
if not validate_backend_name(backend): if not validate_backend_name(backend):
return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)" return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)"
if not validate_backend_name(server): if not validate_backend_name(server):
@@ -478,16 +413,9 @@ def register_server_tools(mcp):
except HaproxyError as e: except HaproxyError as e:
return f"Error: {e}" return f"Error: {e}"
@mcp.tool()
def haproxy_set_server_weight(
backend: Annotated[str, Field(description="Backend name (e.g., 'pool_1')")],
server: Annotated[str, Field(description="Server name (e.g., 'pool_1_1')")],
weight: Annotated[int, Field(description="Weight 0-256 (higher = more traffic, 0 = disabled)")]
) -> str:
"""Set server weight for load balancing ratio control.
Example: haproxy_set_server_weight("pool_1", "pool_1_1", weight=2) def _haproxy_set_server_weight_impl(backend: str, server: str, weight: int) -> str:
""" """Implementation of haproxy_set_server_weight."""
if not validate_backend_name(backend): if not validate_backend_name(backend):
return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)" return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)"
if not validate_backend_name(server): if not validate_backend_name(server):
@@ -499,3 +427,118 @@ def register_server_tools(mcp):
return f"Server {backend}/{server} weight set to {weight}" return f"Server {backend}/{server} weight set to {weight}"
except HaproxyError as e: except HaproxyError as e:
return f"Error: {e}" return f"Error: {e}"
# =============================================================================
# MCP Tool Registration
# =============================================================================
def register_server_tools(mcp):
"""Register server management tools with MCP server."""
@mcp.tool()
def haproxy_list_servers(
domain: Annotated[str, Field(description="Domain name to list servers for (e.g., api.example.com)")]
) -> str:
"""List all servers for a domain with slot numbers, addresses, and status (UP/DOWN/MAINT).
Example:
haproxy_list_servers("api.example.com")
# Output: pool_1_1: 10.0.0.1:8080 (UP)
# pool_1_2: 10.0.0.2:8080 (UP)
"""
return _haproxy_list_servers_impl(domain)
@mcp.tool()
def haproxy_add_server(
domain: Annotated[str, Field(description="Domain name to add server to (e.g., api.example.com)")],
slot: Annotated[int, Field(description="Server slot number 1-10, or 0 for auto-select next available slot")],
ip: Annotated[str, Field(description="Server IP address (IPv4 like 10.0.0.1 or IPv6 like 2001:db8::1)")],
http_port: Annotated[int, Field(default=80, description="HTTP port for backend connection (default: 80)")]
) -> str:
"""Add a server to a domain's backend pool for load balancing.
Each domain can have up to 10 servers (slots 1-10). HAProxy distributes traffic
across all configured servers using round-robin.
Example: haproxy_add_server("api.example.com", slot=1, ip="10.0.0.1", http_port=8080)
"""
return _haproxy_add_server_impl(domain, slot, ip, http_port)
@mcp.tool()
def haproxy_add_servers(
domain: Annotated[str, Field(description="Domain name to add servers to (e.g., api.example.com)")],
servers: Annotated[str, Field(description='JSON array of servers. Each object: {"slot": 1-10, "ip": "10.0.0.1", "http_port": 80}. Example: \'[{"slot":1,"ip":"10.0.0.1"},{"slot":2,"ip":"10.0.0.2"}]\'')]
) -> str:
"""Add multiple servers to a domain's backend at once (bulk operation).
Example: haproxy_add_servers("api.example.com", '[{"slot":1,"ip":"10.0.0.1"},{"slot":2,"ip":"10.0.0.2"}]')
"""
return _haproxy_add_servers_impl(domain, servers)
@mcp.tool()
def haproxy_remove_server(
domain: Annotated[str, Field(description="Domain name to remove server from (e.g., api.example.com)")],
slot: Annotated[int, Field(description="Server slot number to remove (1-10)")]
) -> str:
"""Remove a server from a domain's backend at specified slot.
Example: haproxy_remove_server("api.example.com", slot=2)
"""
return _haproxy_remove_server_impl(domain, slot)
@mcp.tool()
def haproxy_set_domain_state(
domain: Annotated[str, Field(description="Domain name (e.g., api.example.com)")],
state: Annotated[str, Field(description="Target state: 'ready' (normal), 'drain' (stop new connections), or 'maint' (maintenance)")]
) -> str:
"""Set state for all servers of a domain at once.
Example: haproxy_set_domain_state("api.example.com", state="drain")
Example:
# Put all servers in maintenance for deployment
haproxy_set_domain_state("api.example.com", "maint")
# Re-enable all servers after deployment
haproxy_set_domain_state("api.example.com", "ready")
"""
return _haproxy_set_domain_state_impl(domain, state)
@mcp.tool()
def haproxy_wait_drain(
domain: Annotated[str, Field(description="Domain name to wait for (e.g., api.example.com)")],
timeout: Annotated[int, Field(default=30, description="Maximum seconds to wait (default: 30, max: 300)")]
) -> str:
"""Wait for all active connections to drain from a domain's servers.
Use after setting servers to 'drain' state before maintenance.
Example: haproxy_wait_drain("api.example.com", timeout=60)
"""
return _haproxy_wait_drain_impl(domain, timeout)
@mcp.tool()
def haproxy_set_server_state(
backend: Annotated[str, Field(description="Backend name (e.g., 'pool_1')")],
server: Annotated[str, Field(description="Server name (e.g., 'pool_1_1')")],
state: Annotated[str, Field(description="'ready' (enable), 'drain' (graceful shutdown), or 'maint' (maintenance)")]
) -> str:
"""Set server state for maintenance or traffic control.
Example: haproxy_set_server_state("pool_1", "pool_1_2", "maint")
"""
return _haproxy_set_server_state_impl(backend, server, state)
@mcp.tool()
def haproxy_set_server_weight(
backend: Annotated[str, Field(description="Backend name (e.g., 'pool_1')")],
server: Annotated[str, Field(description="Server name (e.g., 'pool_1_1')")],
weight: Annotated[int, Field(description="Weight 0-256 (higher = more traffic, 0 = disabled)")]
) -> str:
"""Set server weight for load balancing ratio control.
Example: haproxy_set_server_weight("pool_1", "pool_1_1", weight=2)
"""
return _haproxy_set_server_weight_impl(backend, server, weight)

View File

@@ -1,7 +1,58 @@
"""Utility functions for HAProxy MCP Server.""" """Utility functions for HAProxy MCP Server."""
from typing import Dict, Generator from typing import Dict, Generator
from .config import StatField from .config import StatField, StateField, STATE_MIN_COLUMNS
from .haproxy_client import haproxy_cmd_batch
def parse_servers_state(state_output: str) -> dict[str, dict[str, dict[str, str]]]:
"""Parse 'show servers state' output.
Args:
state_output: Raw output from HAProxy 'show servers state' command
Returns:
Nested dict: {backend: {server: {addr: str, port: str, state: str}}}
Example:
state = haproxy_cmd("show servers state")
parsed = parse_servers_state(state)
# parsed["pool_1"]["pool_1_1"] == {"addr": "10.0.0.1", "port": "8080", "state": "2"}
"""
result: dict[str, dict[str, dict[str, str]]] = {}
for line in state_output.split("\n"):
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS:
backend = parts[StateField.BE_NAME]
server = parts[StateField.SRV_NAME]
addr = parts[StateField.SRV_ADDR]
port = parts[StateField.SRV_PORT]
state = parts[StateField.SRV_OP_STATE] if len(parts) > StateField.SRV_OP_STATE else ""
if backend not in result:
result[backend] = {}
result[backend][server] = {
"addr": addr,
"port": port,
"state": state,
}
return result
def disable_server_slot(backend: str, server: str) -> None:
"""Disable a server slot (set to maint and clear address).
Args:
backend: Backend name (e.g., 'pool_1')
server: Server name (e.g., 'pool_1_1')
Raises:
HaproxyError: If HAProxy command fails
"""
haproxy_cmd_batch([
f"set server {backend}/{server} state maint",
f"set server {backend}/{server} addr 0.0.0.0 port 0"
])
def parse_stat_csv(stat_output: str) -> Generator[Dict[str, str], None, None]: def parse_stat_csv(stat_output: str) -> Generator[Dict[str, str], None, None]:

View File

@@ -52,6 +52,18 @@ def validate_port(port: str) -> bool:
return 1 <= port_num <= 65535 return 1 <= port_num <= 65535
def validate_port_int(port: int) -> bool:
"""Validate port number as integer is in valid range.
Args:
port: Port number as integer
Returns:
True if port is valid (1-65535), False otherwise
"""
return isinstance(port, int) and 1 <= port <= 65535
def validate_backend_name(name: str) -> bool: def validate_backend_name(name: str) -> bool:
"""Validate backend or server name to prevent command injection. """Validate backend or server name to prevent command injection.

View File

@@ -1078,9 +1078,10 @@ class TestLoadCertToHaproxyError:
pem_file = certs_dir / "example.com.pem" pem_file = certs_dir / "example.com.pem"
pem_file.write_text("cert content") pem_file.write_text("cert content")
# Mock haproxy_cmd to raise exception # Mock haproxy_cmd to raise HaproxyError
from haproxy_mcp.exceptions import HaproxyError
with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)):
with patch("haproxy_mcp.tools.certificates.haproxy_cmd", side_effect=Exception("Connection failed")): with patch("haproxy_mcp.tools.certificates.haproxy_cmd", side_effect=HaproxyError("Connection failed")):
from haproxy_mcp.tools.certificates import load_cert_to_haproxy from haproxy_mcp.tools.certificates import load_cert_to_haproxy
success, msg = load_cert_to_haproxy("example.com") success, msg = load_cert_to_haproxy("example.com")

View File

@@ -547,7 +547,7 @@ class TestStartupRestoreFailures:
with patch("socket.socket", return_value=mock_sock): with patch("socket.socket", return_value=mock_sock):
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", return_value=0): with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", return_value=0):
with patch("haproxy_mcp.tools.certificates.restore_certificates", side_effect=Exception("Certificate error")): with patch("haproxy_mcp.tools.certificates.restore_certificates", side_effect=IOError("Certificate error")):
with caplog.at_level(logging.WARNING, logger="haproxy_mcp"): with caplog.at_level(logging.WARNING, logger="haproxy_mcp"):
from haproxy_mcp.tools.configuration import startup_restore from haproxy_mcp.tools.configuration import startup_restore
@@ -598,7 +598,7 @@ class TestHaproxyReloadFailures:
with patch("socket.socket", return_value=mock_sock): with patch("socket.socket", return_value=mock_sock):
with patch("haproxy_mcp.haproxy_client.reload_haproxy", return_value=(True, "Reloaded")): with patch("haproxy_mcp.haproxy_client.reload_haproxy", return_value=(True, "Reloaded")):
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=Exception("Restore failed")): with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=OSError("Restore failed")):
with patch("time.sleep", return_value=None): with patch("time.sleep", return_value=None):
from haproxy_mcp.tools.configuration import register_config_tools from haproxy_mcp.tools.configuration import register_config_tools
mcp = MagicMock() mcp = MagicMock()

View File

@@ -756,8 +756,8 @@ class TestHaproxyAddServersRollback:
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port): def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
if slot == 2: if slot == 2:
# Simulate unexpected error (not HaproxyError) # Simulate unexpected error (IOError is caught by the exception handler)
raise RuntimeError("Unexpected system error") raise IOError("Unexpected system error")
configured_slots.append(slot) configured_slots.append(slot)
return f"{server_prefix}_{slot}" return f"{server_prefix}_{slot}"
@@ -807,7 +807,7 @@ class TestHaproxyAddServersRollback:
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port): def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
if slot == 2: if slot == 2:
raise RuntimeError("Unexpected error") raise OSError("Unexpected error")
return f"{server_prefix}_{slot}" return f"{server_prefix}_{slot}"
def mock_remove_server_from_config(domain, slot): def mock_remove_server_from_config(domain, slot):