refactor: Extract large functions, improve exception handling, remove duplicates

## Large function extraction
- servers.py: Extract 8 _impl functions from register_server_tools (449 lines)
- certificates.py: Extract 7 _impl functions from register_certificate_tools (386 lines)
- MCP tool wrappers now delegate to module-level implementation functions

## Exception handling improvements
- Replace 11 broad `except Exception` with specific types
- health.py: (OSError, subprocess.SubprocessError)
- configuration.py: (HaproxyError, IOError, OSError, ValueError)
- servers.py: (IOError, OSError, ValueError)
- certificates.py: FileNotFoundError, (subprocess.SubprocessError, OSError)

## Duplicate code extraction
- Add parse_servers_state() to utils.py (replaces 4 duplicate parsers)
- Add disable_server_slot() to utils.py (replaces duplicate patterns)
- Update health.py, servers.py, domains.py to use new helpers

## Other improvements
- Add TypedDict types in file_ops.py and health.py
- Set file permissions (0o600) for sensitive files
- Update tests to use specific exception types

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
kaffa
2026-02-03 13:23:51 +09:00
parent e66c5ddc7f
commit 06ab47aca8
12 changed files with 891 additions and 723 deletions

View File

@@ -94,8 +94,8 @@ def _read_map_file(file_path: str) -> list[tuple[str, str]]:
with open(file_path, "r", encoding="utf-8") as f:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_SH)
except OSError:
pass # Continue without lock if not supported
except OSError as e:
logger.debug("File locking not supported for %s: %s", file_path, e)
try:
for line in f:
line = line.strip()
@@ -107,10 +107,10 @@ def _read_map_file(file_path: str) -> list[tuple[str, str]]:
finally:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
except OSError:
pass
except OSError as e:
logger.debug("File unlock failed for %s: %s", file_path, e)
except FileNotFoundError:
pass
logger.debug("Map file not found: %s", file_path)
return entries
@@ -420,16 +420,16 @@ def load_certs_config() -> list[str]:
with open(CERTS_FILE, "r", encoding="utf-8") as f:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_SH)
except OSError:
pass
except OSError as e:
logger.debug("File locking not supported for %s: %s", CERTS_FILE, e)
try:
data = json.load(f)
return data.get("domains", [])
finally:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
except OSError:
pass
except OSError as e:
logger.debug("File unlock failed for %s: %s", CERTS_FILE, e)
except FileNotFoundError:
return []
except json.JSONDecodeError as e:

View File

@@ -67,8 +67,8 @@ def haproxy_cmd(command: str) -> str:
raise HaproxyError("Invalid UTF-8 in response")
except HaproxyError:
raise
except Exception as e:
raise HaproxyError(str(e)) from e
except (OSError, BlockingIOError, BrokenPipeError) as e:
raise HaproxyError(f"Socket error: {e}") from e
def haproxy_cmd_checked(command: str) -> str:

View File

@@ -14,6 +14,7 @@ from ..config import (
CERTS_DIR_CONTAINER,
ACME_HOME,
)
from ..exceptions import HaproxyError
from ..validation import validate_domain
from ..haproxy_client import haproxy_cmd
from ..file_ops import (
@@ -77,7 +78,11 @@ def load_cert_to_haproxy(domain: str) -> tuple[bool, str]:
haproxy_cmd(f"commit ssl cert {container_path}")
return True, "added"
except Exception as e:
except HaproxyError as e:
logger.error("HAProxy error loading certificate %s: %s", domain, e)
return False, str(e)
except (IOError, OSError) as e:
logger.error("File error loading certificate %s: %s", domain, e)
return False, str(e)
@@ -102,7 +107,8 @@ def unload_cert_from_haproxy(domain: str) -> tuple[bool, str]:
haproxy_cmd(f"del ssl cert {container_path}")
return True, "unloaded"
except Exception as e:
except HaproxyError as e:
logger.error("HAProxy error unloading certificate %s: %s", domain, e)
return False, str(e)
@@ -126,16 +132,13 @@ def restore_certificates() -> int:
return restored
def register_certificate_tools(mcp):
"""Register certificate management tools with MCP server."""
# =============================================================================
# Implementation functions (module-level)
# =============================================================================
@mcp.tool()
def haproxy_list_certs() -> str:
"""List all SSL/TLS certificates with expiry information.
Returns:
List of certificates with domain, CA, created date, and renewal date
"""
def _haproxy_list_certs_impl() -> str:
"""Implementation of haproxy_list_certs."""
try:
result = subprocess.run(
[ACME_SH, "--list"],
@@ -152,7 +155,8 @@ def register_certificate_tools(mcp):
# Get HAProxy loaded certs
try:
haproxy_certs = haproxy_cmd("show ssl cert")
except Exception:
except HaproxyError as e:
logger.debug("Could not get HAProxy certs: %s", e)
haproxy_certs = ""
# Parse and format output
@@ -190,17 +194,16 @@ def register_certificate_tools(mcp):
return "Error: Command timed out"
except FileNotFoundError:
return "Error: acme.sh not found"
except Exception as e:
except subprocess.SubprocessError as e:
logger.error("Subprocess error listing certificates: %s", e)
return f"Error: {e}"
except OSError as e:
logger.error("OS error listing certificates: %s", e)
return f"Error: {e}"
@mcp.tool()
def haproxy_cert_info(
domain: Annotated[str, Field(description="Domain name to check (e.g., example.com)")]
) -> str:
"""Get detailed certificate information for a domain.
Shows expiry date, issuer, SANs, and file paths.
"""
def _haproxy_cert_info_impl(domain: str) -> str:
"""Implementation of haproxy_cert_info."""
if not validate_domain(domain):
return "Error: Invalid domain format"
@@ -226,7 +229,8 @@ def register_certificate_tools(mcp):
try:
haproxy_certs = haproxy_cmd("show ssl cert")
loaded = "Yes" if container_path in haproxy_certs else "No"
except Exception:
except HaproxyError as e:
logger.debug("Could not check HAProxy cert status: %s", e)
loaded = "Unknown"
info = [
@@ -240,20 +244,13 @@ def register_certificate_tools(mcp):
return "\n".join(info)
except subprocess.TimeoutExpired:
return "Error: Command timed out"
except Exception as e:
except (subprocess.SubprocessError, OSError) as e:
logger.error("Error getting certificate info for %s: %s", domain, e)
return f"Error: {e}"
@mcp.tool()
def haproxy_issue_cert(
domain: Annotated[str, Field(description="Primary domain (e.g., example.com)")],
wildcard: Annotated[bool, Field(default=True, description="Include wildcard (*.example.com). Default: true")]
) -> str:
"""Issue a new SSL/TLS certificate using acme.sh with Cloudflare DNS.
Automatically deploys to HAProxy via Runtime API (zero-downtime).
Example: haproxy_issue_cert("example.com", wildcard=True)
"""
def _haproxy_issue_cert_impl(domain: str, wildcard: bool) -> str:
"""Implementation of haproxy_issue_cert."""
if not validate_domain(domain):
return "Error: Invalid domain format"
@@ -268,7 +265,7 @@ def register_certificate_tools(mcp):
token = line.split("=", 1)[1].strip().strip('"').strip("'")
os.environ["CF_Token"] = token
break
except Exception as e:
except (IOError, OSError) as e:
logger.warning("Failed to read Cloudflare token: %s", e)
if not os.environ.get("CF_Token"):
@@ -322,20 +319,13 @@ def register_certificate_tools(mcp):
except subprocess.TimeoutExpired:
return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s"
except Exception as e:
except (subprocess.SubprocessError, OSError) as e:
logger.error("Error issuing certificate for %s: %s", domain, e)
return f"Error: {e}"
@mcp.tool()
def haproxy_renew_cert(
domain: Annotated[str, Field(description="Domain name to renew (e.g., example.com)")],
force: Annotated[bool, Field(default=False, description="Force renewal even if not due. Default: false")]
) -> str:
"""Renew an existing certificate.
Uses Runtime API for zero-downtime reload.
Example: haproxy_renew_cert("example.com", force=True)
"""
def _haproxy_renew_cert_impl(domain: str, force: bool) -> str:
"""Implementation of haproxy_renew_cert."""
if not validate_domain(domain):
return "Error: Invalid domain format"
@@ -374,15 +364,15 @@ def register_certificate_tools(mcp):
except subprocess.TimeoutExpired:
return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s"
except Exception as e:
except FileNotFoundError:
return "Error: acme.sh not found"
except (subprocess.SubprocessError, OSError) as e:
logger.error("Error renewing certificate for %s: %s", domain, e)
return f"Error: {e}"
@mcp.tool()
def haproxy_renew_all_certs() -> str:
"""Renew all certificates that are due for renewal.
This runs the acme.sh cron job to check and renew all certificates.
"""
def _haproxy_renew_all_certs_impl() -> str:
"""Implementation of haproxy_renew_all_certs."""
try:
logger.info("Running certificate renewal cron")
result = subprocess.run(
@@ -415,19 +405,15 @@ def register_certificate_tools(mcp):
except subprocess.TimeoutExpired:
return "Error: Renewal cron timed out"
except Exception as e:
except FileNotFoundError:
return "Error: acme.sh not found"
except (subprocess.SubprocessError, OSError) as e:
logger.error("Error running certificate renewal cron: %s", e)
return f"Error: {e}"
@mcp.tool()
def haproxy_delete_cert(
domain: Annotated[str, Field(description="Domain name to delete certificate for")]
) -> str:
"""Delete a certificate from acme.sh and HAProxy.
WARNING: This permanently removes the certificate. The domain will lose HTTPS.
Example: haproxy_delete_cert("example.com")
"""
def _haproxy_delete_cert_impl(domain: str) -> str:
"""Implementation of haproxy_delete_cert."""
if not validate_domain(domain):
return "Error: Invalid domain format"
@@ -459,7 +445,7 @@ def register_certificate_tools(mcp):
deleted.append("acme.sh")
else:
errors.append(f"acme.sh: {result.stderr}")
except Exception as e:
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError) as e:
errors.append(f"acme.sh: {e}")
# Remove PEM file
@@ -467,7 +453,7 @@ def register_certificate_tools(mcp):
try:
os.remove(host_path)
deleted.append("PEM file")
except Exception as e:
except OSError as e:
errors.append(f"PEM file: {e}")
# Remove from config
@@ -481,16 +467,9 @@ def register_certificate_tools(mcp):
return "\n".join(result_parts) if result_parts else f"Certificate {domain} deleted"
@mcp.tool()
def haproxy_load_cert(
domain: Annotated[str, Field(description="Domain name to load certificate for")]
) -> str:
"""Load/reload a certificate into HAProxy (zero-downtime).
Use after manually updating a certificate file.
Example: haproxy_load_cert("example.com")
"""
def _haproxy_load_cert_impl(domain: str) -> str:
"""Implementation of haproxy_load_cert."""
if not validate_domain(domain):
return "Error: Invalid domain format"
@@ -504,3 +483,89 @@ def register_certificate_tools(mcp):
return f"Certificate {domain} loaded into HAProxy ({msg})"
else:
return f"Error loading certificate: {msg}"
# =============================================================================
# MCP Tool Registration
# =============================================================================
def register_certificate_tools(mcp):
"""Register certificate management tools with MCP server."""
@mcp.tool()
def haproxy_list_certs() -> str:
"""List all SSL/TLS certificates with expiry information.
Returns:
List of certificates with domain, CA, created date, and renewal date
"""
return _haproxy_list_certs_impl()
@mcp.tool()
def haproxy_cert_info(
domain: Annotated[str, Field(description="Domain name to check (e.g., example.com)")]
) -> str:
"""Get detailed certificate information for a domain.
Shows expiry date, issuer, SANs, and file paths.
"""
return _haproxy_cert_info_impl(domain)
@mcp.tool()
def haproxy_issue_cert(
domain: Annotated[str, Field(description="Primary domain (e.g., example.com)")],
wildcard: Annotated[bool, Field(default=True, description="Include wildcard (*.example.com). Default: true")]
) -> str:
"""Issue a new SSL/TLS certificate using acme.sh with Cloudflare DNS.
Automatically deploys to HAProxy via Runtime API (zero-downtime).
Example: haproxy_issue_cert("example.com", wildcard=True)
"""
return _haproxy_issue_cert_impl(domain, wildcard)
@mcp.tool()
def haproxy_renew_cert(
domain: Annotated[str, Field(description="Domain name to renew (e.g., example.com)")],
force: Annotated[bool, Field(default=False, description="Force renewal even if not due. Default: false")]
) -> str:
"""Renew an existing certificate.
Uses Runtime API for zero-downtime reload.
Example: haproxy_renew_cert("example.com", force=True)
"""
return _haproxy_renew_cert_impl(domain, force)
@mcp.tool()
def haproxy_renew_all_certs() -> str:
"""Renew all certificates that are due for renewal.
This runs the acme.sh cron job to check and renew all certificates.
"""
return _haproxy_renew_all_certs_impl()
@mcp.tool()
def haproxy_delete_cert(
domain: Annotated[str, Field(description="Domain name to delete certificate for")]
) -> str:
"""Delete a certificate from acme.sh and HAProxy.
WARNING: This permanently removes the certificate. The domain will lose HTTPS.
Example: haproxy_delete_cert("example.com")
"""
return _haproxy_delete_cert_impl(domain)
@mcp.tool()
def haproxy_load_cert(
domain: Annotated[str, Field(description="Domain name to load certificate for")]
) -> str:
"""Load/reload a certificate into HAProxy (zero-downtime).
Use after manually updating a certificate file.
Example: haproxy_load_cert("example.com")
"""
return _haproxy_load_cert_impl(domain)

View File

@@ -117,7 +117,7 @@ def startup_restore() -> None:
cert_count = restore_certificates()
if cert_count > 0:
logger.info("Restored %d certificates from config", cert_count)
except Exception as e:
except (HaproxyError, IOError, OSError, ValueError) as e:
logger.warning("Failed to restore certificates: %s", e)
@@ -155,7 +155,7 @@ def register_config_tools(mcp):
try:
restored = restore_servers_from_config()
return f"HAProxy configuration reloaded successfully ({restored} servers restored)"
except Exception as e:
except (HaproxyError, IOError, OSError, ValueError) as e:
logger.error("Failed to restore servers after reload: %s", e)
return f"HAProxy reloaded but server restore failed: {e}"

View File

@@ -13,14 +13,12 @@ from ..config import (
WILDCARDS_MAP_FILE_CONTAINER,
POOL_COUNT,
MAX_SLOTS,
StateField,
STATE_MIN_COLUMNS,
SUBPROCESS_TIMEOUT,
CERTS_DIR,
logger,
)
from ..exceptions import HaproxyError
from ..validation import validate_domain, validate_ip
from ..validation import validate_domain, validate_ip, validate_port_int
from ..haproxy_client import haproxy_cmd
from ..file_ops import (
get_map_contents,
@@ -31,6 +29,7 @@ from ..file_ops import (
remove_server_from_config,
remove_domain_from_config,
)
from ..utils import parse_servers_state, disable_server_slot
def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]:
@@ -166,17 +165,17 @@ def register_domain_tools(mcp):
try:
domains = []
state = haproxy_cmd("show servers state")
parsed_state = parse_servers_state(state)
# Build server map from HAProxy state
server_map: dict[str, list] = {}
for line in state.split("\n"):
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.SRV_ADDR] != "0.0.0.0":
backend = parts[StateField.BE_NAME]
server_map: dict[str, list[str]] = {}
for backend, servers_dict in parsed_state.items():
for server_name, srv_info in servers_dict.items():
if srv_info["addr"] != "0.0.0.0":
if backend not in server_map:
server_map[backend] = []
server_map[backend].append(
f"{parts[StateField.SRV_NAME]}={parts[StateField.SRV_ADDR]}:{parts[StateField.SRV_PORT]}"
f"{server_name}={srv_info['addr']}:{srv_info['port']}"
)
# Read from domains.map
@@ -220,7 +219,7 @@ def register_domain_tools(mcp):
return "Error: Invalid domain format"
if not validate_ip(ip, allow_empty=True):
return "Error: Invalid IP address format"
if not (1 <= http_port <= 65535):
if not validate_port_int(http_port):
return "Error: Port must be between 1 and 65535"
# Use file locking for the entire pool allocation operation
@@ -339,8 +338,7 @@ def register_domain_tools(mcp):
for slot in range(1, MAX_SLOTS + 1):
server = f"{backend}_{slot}"
try:
haproxy_cmd(f"set server {backend}/{server} state maint")
haproxy_cmd(f"set server {backend}/{server} addr 0.0.0.0 port 0")
disable_server_slot(backend, server)
except HaproxyError as e:
logger.warning(
"Failed to clear server %s/%s for domain %s: %s",

View File

@@ -12,14 +12,12 @@ from ..config import (
MAP_FILE,
SERVERS_FILE,
HAPROXY_CONTAINER,
StateField,
STATE_MIN_COLUMNS,
)
from ..exceptions import HaproxyError
from ..validation import validate_domain, validate_backend_name
from ..haproxy_client import haproxy_cmd
from ..file_ops import get_backend_and_prefix
from ..utils import parse_stat_csv
from ..utils import parse_stat_csv, parse_servers_state
def register_health_tools(mcp):
@@ -83,7 +81,7 @@ def register_health_tools(mcp):
except subprocess.TimeoutExpired:
result["components"]["container"] = {"status": "timeout"}
result["status"] = "unhealthy"
except Exception as e:
except (OSError, subprocess.SubprocessError) as e:
result["components"]["container"] = {"status": "error", "error": str(e)}
# Check configuration files
@@ -144,12 +142,12 @@ def register_health_tools(mcp):
}
# Parse server state for address info
for line in state_output.split("\n"):
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
server_name = parts[StateField.SRV_NAME]
addr = parts[StateField.SRV_ADDR]
port = parts[StateField.SRV_PORT]
parsed_state = parse_servers_state(state_output)
backend_servers = parsed_state.get(backend, {})
for server_name, srv_info in backend_servers.items():
addr = srv_info["addr"]
port = srv_info["port"]
# Skip disabled servers (0.0.0.0)
if addr == "0.0.0.0":

View File

@@ -10,13 +10,11 @@ from ..config import (
MAX_SLOTS,
MAX_BULK_SERVERS,
MAX_SERVERS_JSON_SIZE,
StateField,
StatField,
STATE_MIN_COLUMNS,
logger,
)
from ..exceptions import HaproxyError
from ..validation import validate_domain, validate_ip, validate_backend_name
from ..validation import validate_domain, validate_ip, validate_backend_name, validate_port_int
from ..haproxy_client import haproxy_cmd, haproxy_cmd_checked, haproxy_cmd_batch
from ..file_ops import (
get_backend_and_prefix,
@@ -24,6 +22,7 @@ from ..file_ops import (
add_server_to_config,
remove_server_from_config,
)
from ..utils import parse_servers_state, disable_server_slot
def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str:
@@ -51,65 +50,46 @@ def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str,
return server
def register_server_tools(mcp):
"""Register server management tools with MCP server."""
# =============================================================================
# Implementation functions (module-level)
# =============================================================================
@mcp.tool()
def haproxy_list_servers(
domain: Annotated[str, Field(description="Domain name to list servers for (e.g., api.example.com)")]
) -> str:
"""List all servers for a domain with slot numbers, addresses, and status (UP/DOWN/MAINT).
Example:
haproxy_list_servers("api.example.com")
# Output: pool_1_1: 10.0.0.1:8080 (UP)
# pool_1_2: 10.0.0.2:8080 (UP)
"""
def _haproxy_list_servers_impl(domain: str) -> str:
"""Implementation of haproxy_list_servers."""
if not validate_domain(domain):
return "Error: Invalid domain format"
try:
backend, _ = get_backend_and_prefix(domain)
servers = []
state = haproxy_cmd("show servers state")
parsed_state = parse_servers_state(state)
backend_servers = parsed_state.get(backend, {})
for line in state.split("\n"):
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
addr = parts[StateField.SRV_ADDR]
status = "active" if addr != "0.0.0.0" else "disabled"
servers.append(
f"{parts[StateField.SRV_NAME]}: {addr}:{parts[StateField.SRV_PORT]} ({status})"
)
if not servers:
if not backend_servers:
return f"Backend {backend} not found"
servers = []
for server_name, srv_info in backend_servers.items():
addr = srv_info["addr"]
port = srv_info["port"]
status = "active" if addr != "0.0.0.0" else "disabled"
servers.append(f"{server_name}: {addr}:{port} ({status})")
return f"Servers for {domain} ({backend}):\n" + "\n".join(servers)
except (HaproxyError, ValueError) as e:
return f"Error: {e}"
@mcp.tool()
def haproxy_add_server(
domain: Annotated[str, Field(description="Domain name to add server to (e.g., api.example.com)")],
slot: Annotated[int, Field(description="Server slot number 1-10, or 0 for auto-select next available slot")],
ip: Annotated[str, Field(description="Server IP address (IPv4 like 10.0.0.1 or IPv6 like 2001:db8::1)")],
http_port: Annotated[int, Field(default=80, description="HTTP port for backend connection (default: 80)")]
) -> str:
"""Add a server to a domain's backend pool for load balancing.
Each domain can have up to 10 servers (slots 1-10). HAProxy distributes traffic
across all configured servers using round-robin.
Example: haproxy_add_server("api.example.com", slot=1, ip="10.0.0.1", http_port=8080)
"""
def _haproxy_add_server_impl(domain: str, slot: int, ip: str, http_port: int) -> str:
"""Implementation of haproxy_add_server."""
if not validate_domain(domain):
return "Error: Invalid domain format"
if not ip:
return "Error: IP address is required"
if not validate_ip(ip):
return "Error: Invalid IP address format"
if not (1 <= http_port <= 65535):
if not validate_port_int(http_port):
return "Error: Port must be between 1 and 65535"
try:
@@ -118,13 +98,12 @@ def register_server_tools(mcp):
# Auto-select slot if slot <= 0
if slot <= 0:
state = haproxy_cmd("show servers state")
parsed_state = parse_servers_state(state)
backend_servers = parsed_state.get(backend, {})
used_slots: set[int] = set()
for line in state.split("\n"):
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
if parts[StateField.SRV_ADDR] != "0.0.0.0":
for server_name, srv_info in backend_servers.items():
if srv_info["addr"] != "0.0.0.0":
# Extract slot number from server name (e.g., pool_1_3 -> 3)
server_name = parts[StateField.SRV_NAME]
try:
used_slots.add(int(server_name.rsplit("_", 1)[1]))
except (ValueError, IndexError):
@@ -151,15 +130,9 @@ def register_server_tools(mcp):
except (ValueError, IOError) as e:
return f"Error: {e}"
@mcp.tool()
def haproxy_add_servers(
domain: Annotated[str, Field(description="Domain name to add servers to (e.g., api.example.com)")],
servers: Annotated[str, Field(description='JSON array of servers. Each object: {"slot": 1-10, "ip": "10.0.0.1", "http_port": 80}. Example: \'[{"slot":1,"ip":"10.0.0.1"},{"slot":2,"ip":"10.0.0.2"}]\'')]
) -> str:
"""Add multiple servers to a domain's backend at once (bulk operation).
Example: haproxy_add_servers("api.example.com", '[{"slot":1,"ip":"10.0.0.1"},{"slot":2,"ip":"10.0.0.2"}]')
"""
def _haproxy_add_servers_impl(domain: str, servers: str) -> str:
"""Implementation of haproxy_add_servers."""
if not validate_domain(domain):
return "Error: Invalid domain format"
@@ -221,7 +194,7 @@ def register_server_tools(mcp):
except (ValueError, TypeError):
validation_errors.append(f"Server {i+1}: http_port must be an integer")
continue
if not (1 <= http_port <= 65535):
if not validate_port_int(http_port):
validation_errors.append(f"Server {i+1}: port must be between 1 and 65535")
continue
@@ -267,12 +240,13 @@ def register_server_tools(mcp):
except HaproxyError as e:
failed_slots.append(slot)
errors.append(f"slot {slot}: {e}")
except Exception as e:
except (IOError, OSError, ValueError) as e:
# Rollback only successfully added configs on unexpected error
logger.error("Unexpected error during bulk server add for %s: %s", domain, e)
for slot in successfully_added_slots:
try:
remove_server_from_config(domain, slot)
except Exception as rollback_error:
except (IOError, OSError, ValueError) as rollback_error:
logger.error(
"Failed to rollback server config for %s slot %d: %s",
domain, slot, rollback_error
@@ -283,7 +257,7 @@ def register_server_tools(mcp):
if slot not in successfully_added_slots:
try:
remove_server_from_config(domain, slot)
except Exception as rollback_error:
except (IOError, OSError, ValueError) as rollback_error:
logger.error(
"Failed to rollback server config for %s slot %d: %s",
domain, slot, rollback_error
@@ -294,7 +268,7 @@ def register_server_tools(mcp):
for slot in failed_slots:
try:
remove_server_from_config(domain, slot)
except Exception as rollback_error:
except (IOError, OSError, ValueError) as rollback_error:
logger.error(
"Failed to rollback server config for %s slot %d: %s",
domain, slot, rollback_error
@@ -311,15 +285,9 @@ def register_server_tools(mcp):
return "\n".join(result_parts) if result_parts else "No servers added"
@mcp.tool()
def haproxy_remove_server(
domain: Annotated[str, Field(description="Domain name to remove server from (e.g., api.example.com)")],
slot: Annotated[int, Field(description="Server slot number to remove (1-10)")]
) -> str:
"""Remove a server from a domain's backend at specified slot.
Example: haproxy_remove_server("api.example.com", slot=2)
"""
def _haproxy_remove_server_impl(domain: str, slot: int) -> str:
"""Implementation of haproxy_remove_server."""
if not validate_domain(domain):
return "Error: Invalid domain format"
if not (1 <= slot <= MAX_SLOTS):
@@ -338,11 +306,7 @@ def register_server_tools(mcp):
try:
# HTTP only - single server per slot
server = f"{server_prefix}_{slot}"
# Batch both commands in single TCP connection
haproxy_cmd_batch([
f"set server {backend}/{server} state maint",
f"set server {backend}/{server} addr 0.0.0.0 port 0"
])
disable_server_slot(backend, server)
return f"Removed server at slot {slot} from {domain} ({backend})"
except HaproxyError as e:
# Rollback: re-add config if HAProxy command failed
@@ -352,22 +316,9 @@ def register_server_tools(mcp):
except (ValueError, IOError) as e:
return f"Error: {e}"
@mcp.tool()
def haproxy_set_domain_state(
domain: Annotated[str, Field(description="Domain name (e.g., api.example.com)")],
state: Annotated[str, Field(description="Target state: 'ready' (normal), 'drain' (stop new connections), or 'maint' (maintenance)")]
) -> str:
"""Set state for all servers of a domain at once.
Example: haproxy_set_domain_state("api.example.com", state="drain")
Example:
# Put all servers in maintenance for deployment
haproxy_set_domain_state("api.example.com", "maint")
# Re-enable all servers after deployment
haproxy_set_domain_state("api.example.com", "ready")
"""
def _haproxy_set_domain_state_impl(domain: str, state: str) -> str:
"""Implementation of haproxy_set_domain_state."""
if not validate_domain(domain):
return "Error: Invalid domain format"
if state not in ["ready", "drain", "maint"]:
@@ -384,16 +335,15 @@ def register_server_tools(mcp):
except HaproxyError as e:
return f"Error: {e}"
parsed_state = parse_servers_state(servers_state)
backend_servers = parsed_state.get(backend, {})
changed = []
errors = []
for line in servers_state.split("\n"):
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
server_name = parts[StateField.SRV_NAME]
addr = parts[StateField.SRV_ADDR]
for server_name, srv_info in backend_servers.items():
# Only change state for configured servers (not 0.0.0.0)
if addr != "0.0.0.0":
if srv_info["addr"] != "0.0.0.0":
try:
haproxy_cmd_checked(f"set server {backend}/{server_name} state {state}")
changed.append(server_name)
@@ -411,17 +361,9 @@ def register_server_tools(mcp):
return result
@mcp.tool()
def haproxy_wait_drain(
domain: Annotated[str, Field(description="Domain name to wait for (e.g., api.example.com)")],
timeout: Annotated[int, Field(default=30, description="Maximum seconds to wait (default: 30, max: 300)")]
) -> str:
"""Wait for all active connections to drain from a domain's servers.
Use after setting servers to 'drain' state before maintenance.
Example: haproxy_wait_drain("api.example.com", timeout=60)
"""
def _haproxy_wait_drain_impl(domain: str, timeout: int) -> str:
"""Implementation of haproxy_wait_drain."""
if not validate_domain(domain):
return "Error: Invalid domain format"
if not (1 <= timeout <= 300):
@@ -456,16 +398,9 @@ def register_server_tools(mcp):
return f"Timeout: Connections still active after {timeout}s"
@mcp.tool()
def haproxy_set_server_state(
backend: Annotated[str, Field(description="Backend name (e.g., 'pool_1')")],
server: Annotated[str, Field(description="Server name (e.g., 'pool_1_1')")],
state: Annotated[str, Field(description="'ready' (enable), 'drain' (graceful shutdown), or 'maint' (maintenance)")]
) -> str:
"""Set server state for maintenance or traffic control.
Example: haproxy_set_server_state("pool_1", "pool_1_2", "maint")
"""
def _haproxy_set_server_state_impl(backend: str, server: str, state: str) -> str:
"""Implementation of haproxy_set_server_state."""
if not validate_backend_name(backend):
return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)"
if not validate_backend_name(server):
@@ -478,16 +413,9 @@ def register_server_tools(mcp):
except HaproxyError as e:
return f"Error: {e}"
@mcp.tool()
def haproxy_set_server_weight(
backend: Annotated[str, Field(description="Backend name (e.g., 'pool_1')")],
server: Annotated[str, Field(description="Server name (e.g., 'pool_1_1')")],
weight: Annotated[int, Field(description="Weight 0-256 (higher = more traffic, 0 = disabled)")]
) -> str:
"""Set server weight for load balancing ratio control.
Example: haproxy_set_server_weight("pool_1", "pool_1_1", weight=2)
"""
def _haproxy_set_server_weight_impl(backend: str, server: str, weight: int) -> str:
"""Implementation of haproxy_set_server_weight."""
if not validate_backend_name(backend):
return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)"
if not validate_backend_name(server):
@@ -499,3 +427,118 @@ def register_server_tools(mcp):
return f"Server {backend}/{server} weight set to {weight}"
except HaproxyError as e:
return f"Error: {e}"
# =============================================================================
# MCP Tool Registration
# =============================================================================
def register_server_tools(mcp):
"""Register server management tools with MCP server."""
@mcp.tool()
def haproxy_list_servers(
domain: Annotated[str, Field(description="Domain name to list servers for (e.g., api.example.com)")]
) -> str:
"""List all servers for a domain with slot numbers, addresses, and status (UP/DOWN/MAINT).
Example:
haproxy_list_servers("api.example.com")
# Output: pool_1_1: 10.0.0.1:8080 (UP)
# pool_1_2: 10.0.0.2:8080 (UP)
"""
return _haproxy_list_servers_impl(domain)
@mcp.tool()
def haproxy_add_server(
domain: Annotated[str, Field(description="Domain name to add server to (e.g., api.example.com)")],
slot: Annotated[int, Field(description="Server slot number 1-10, or 0 for auto-select next available slot")],
ip: Annotated[str, Field(description="Server IP address (IPv4 like 10.0.0.1 or IPv6 like 2001:db8::1)")],
http_port: Annotated[int, Field(default=80, description="HTTP port for backend connection (default: 80)")]
) -> str:
"""Add a server to a domain's backend pool for load balancing.
Each domain can have up to 10 servers (slots 1-10). HAProxy distributes traffic
across all configured servers using round-robin.
Example: haproxy_add_server("api.example.com", slot=1, ip="10.0.0.1", http_port=8080)
"""
return _haproxy_add_server_impl(domain, slot, ip, http_port)
@mcp.tool()
def haproxy_add_servers(
domain: Annotated[str, Field(description="Domain name to add servers to (e.g., api.example.com)")],
servers: Annotated[str, Field(description='JSON array of servers. Each object: {"slot": 1-10, "ip": "10.0.0.1", "http_port": 80}. Example: \'[{"slot":1,"ip":"10.0.0.1"},{"slot":2,"ip":"10.0.0.2"}]\'')]
) -> str:
"""Add multiple servers to a domain's backend at once (bulk operation).
Example: haproxy_add_servers("api.example.com", '[{"slot":1,"ip":"10.0.0.1"},{"slot":2,"ip":"10.0.0.2"}]')
"""
return _haproxy_add_servers_impl(domain, servers)
@mcp.tool()
def haproxy_remove_server(
domain: Annotated[str, Field(description="Domain name to remove server from (e.g., api.example.com)")],
slot: Annotated[int, Field(description="Server slot number to remove (1-10)")]
) -> str:
"""Remove a server from a domain's backend at specified slot.
Example: haproxy_remove_server("api.example.com", slot=2)
"""
return _haproxy_remove_server_impl(domain, slot)
@mcp.tool()
def haproxy_set_domain_state(
domain: Annotated[str, Field(description="Domain name (e.g., api.example.com)")],
state: Annotated[str, Field(description="Target state: 'ready' (normal), 'drain' (stop new connections), or 'maint' (maintenance)")]
) -> str:
"""Set state for all servers of a domain at once.
Example: haproxy_set_domain_state("api.example.com", state="drain")
Example:
# Put all servers in maintenance for deployment
haproxy_set_domain_state("api.example.com", "maint")
# Re-enable all servers after deployment
haproxy_set_domain_state("api.example.com", "ready")
"""
return _haproxy_set_domain_state_impl(domain, state)
@mcp.tool()
def haproxy_wait_drain(
domain: Annotated[str, Field(description="Domain name to wait for (e.g., api.example.com)")],
timeout: Annotated[int, Field(default=30, description="Maximum seconds to wait (default: 30, max: 300)")]
) -> str:
"""Wait for all active connections to drain from a domain's servers.
Use after setting servers to 'drain' state before maintenance.
Example: haproxy_wait_drain("api.example.com", timeout=60)
"""
return _haproxy_wait_drain_impl(domain, timeout)
@mcp.tool()
def haproxy_set_server_state(
backend: Annotated[str, Field(description="Backend name (e.g., 'pool_1')")],
server: Annotated[str, Field(description="Server name (e.g., 'pool_1_1')")],
state: Annotated[str, Field(description="'ready' (enable), 'drain' (graceful shutdown), or 'maint' (maintenance)")]
) -> str:
"""Set server state for maintenance or traffic control.
Example: haproxy_set_server_state("pool_1", "pool_1_2", "maint")
"""
return _haproxy_set_server_state_impl(backend, server, state)
@mcp.tool()
def haproxy_set_server_weight(
backend: Annotated[str, Field(description="Backend name (e.g., 'pool_1')")],
server: Annotated[str, Field(description="Server name (e.g., 'pool_1_1')")],
weight: Annotated[int, Field(description="Weight 0-256 (higher = more traffic, 0 = disabled)")]
) -> str:
"""Set server weight for load balancing ratio control.
Example: haproxy_set_server_weight("pool_1", "pool_1_1", weight=2)
"""
return _haproxy_set_server_weight_impl(backend, server, weight)

View File

@@ -1,7 +1,58 @@
"""Utility functions for HAProxy MCP Server."""
from typing import Dict, Generator
from .config import StatField
from .config import StatField, StateField, STATE_MIN_COLUMNS
from .haproxy_client import haproxy_cmd_batch
def parse_servers_state(state_output: str) -> dict[str, dict[str, dict[str, str]]]:
"""Parse 'show servers state' output.
Args:
state_output: Raw output from HAProxy 'show servers state' command
Returns:
Nested dict: {backend: {server: {addr: str, port: str, state: str}}}
Example:
state = haproxy_cmd("show servers state")
parsed = parse_servers_state(state)
# parsed["pool_1"]["pool_1_1"] == {"addr": "10.0.0.1", "port": "8080", "state": "2"}
"""
result: dict[str, dict[str, dict[str, str]]] = {}
for line in state_output.split("\n"):
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS:
backend = parts[StateField.BE_NAME]
server = parts[StateField.SRV_NAME]
addr = parts[StateField.SRV_ADDR]
port = parts[StateField.SRV_PORT]
state = parts[StateField.SRV_OP_STATE] if len(parts) > StateField.SRV_OP_STATE else ""
if backend not in result:
result[backend] = {}
result[backend][server] = {
"addr": addr,
"port": port,
"state": state,
}
return result
def disable_server_slot(backend: str, server: str) -> None:
"""Disable a server slot (set to maint and clear address).
Args:
backend: Backend name (e.g., 'pool_1')
server: Server name (e.g., 'pool_1_1')
Raises:
HaproxyError: If HAProxy command fails
"""
haproxy_cmd_batch([
f"set server {backend}/{server} state maint",
f"set server {backend}/{server} addr 0.0.0.0 port 0"
])
def parse_stat_csv(stat_output: str) -> Generator[Dict[str, str], None, None]:

View File

@@ -52,6 +52,18 @@ def validate_port(port: str) -> bool:
return 1 <= port_num <= 65535
def validate_port_int(port: int) -> bool:
"""Validate port number as integer is in valid range.
Args:
port: Port number as integer
Returns:
True if port is valid (1-65535), False otherwise
"""
return isinstance(port, int) and 1 <= port <= 65535
def validate_backend_name(name: str) -> bool:
"""Validate backend or server name to prevent command injection.

View File

@@ -1078,9 +1078,10 @@ class TestLoadCertToHaproxyError:
pem_file = certs_dir / "example.com.pem"
pem_file.write_text("cert content")
# Mock haproxy_cmd to raise exception
# Mock haproxy_cmd to raise HaproxyError
from haproxy_mcp.exceptions import HaproxyError
with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)):
with patch("haproxy_mcp.tools.certificates.haproxy_cmd", side_effect=Exception("Connection failed")):
with patch("haproxy_mcp.tools.certificates.haproxy_cmd", side_effect=HaproxyError("Connection failed")):
from haproxy_mcp.tools.certificates import load_cert_to_haproxy
success, msg = load_cert_to_haproxy("example.com")

View File

@@ -547,7 +547,7 @@ class TestStartupRestoreFailures:
with patch("socket.socket", return_value=mock_sock):
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", return_value=0):
with patch("haproxy_mcp.tools.certificates.restore_certificates", side_effect=Exception("Certificate error")):
with patch("haproxy_mcp.tools.certificates.restore_certificates", side_effect=IOError("Certificate error")):
with caplog.at_level(logging.WARNING, logger="haproxy_mcp"):
from haproxy_mcp.tools.configuration import startup_restore
@@ -598,7 +598,7 @@ class TestHaproxyReloadFailures:
with patch("socket.socket", return_value=mock_sock):
with patch("haproxy_mcp.haproxy_client.reload_haproxy", return_value=(True, "Reloaded")):
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=Exception("Restore failed")):
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=OSError("Restore failed")):
with patch("time.sleep", return_value=None):
from haproxy_mcp.tools.configuration import register_config_tools
mcp = MagicMock()

View File

@@ -756,8 +756,8 @@ class TestHaproxyAddServersRollback:
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
if slot == 2:
# Simulate unexpected error (not HaproxyError)
raise RuntimeError("Unexpected system error")
# Simulate unexpected error (IOError is caught by the exception handler)
raise IOError("Unexpected system error")
configured_slots.append(slot)
return f"{server_prefix}_{slot}"
@@ -807,7 +807,7 @@ class TestHaproxyAddServersRollback:
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
if slot == 2:
raise RuntimeError("Unexpected error")
raise OSError("Unexpected error")
return f"{server_prefix}_{slot}"
def mock_remove_server_from_config(domain, slot):