refactor: Extract large functions, improve exception handling, remove duplicates

## Large function extraction
- servers.py: Extract 8 _impl functions from register_server_tools (449 lines)
- certificates.py: Extract 7 _impl functions from register_certificate_tools (386 lines)
- MCP tool wrappers now delegate to module-level implementation functions

## Exception handling improvements
- Replace 11 broad `except Exception` with specific types
- health.py: (OSError, subprocess.SubprocessError)
- configuration.py: (HaproxyError, IOError, OSError, ValueError)
- servers.py: (IOError, OSError, ValueError)
- certificates.py: FileNotFoundError, (subprocess.SubprocessError, OSError)

## Duplicate code extraction
- Add parse_servers_state() to utils.py (replaces 4 duplicate parsers)
- Add disable_server_slot() to utils.py (replaces duplicate patterns)
- Update health.py, servers.py, domains.py to use new helpers

## Other improvements
- Add TypedDict types in file_ops.py and health.py
- Set file permissions (0o600) for sensitive files
- Update tests to use specific exception types

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
kaffa
2026-02-03 13:23:51 +09:00
parent e66c5ddc7f
commit 06ab47aca8
12 changed files with 891 additions and 723 deletions

View File

@@ -94,8 +94,8 @@ def _read_map_file(file_path: str) -> list[tuple[str, str]]:
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, "r", encoding="utf-8") as f:
try: try:
fcntl.flock(f.fileno(), fcntl.LOCK_SH) fcntl.flock(f.fileno(), fcntl.LOCK_SH)
except OSError: except OSError as e:
pass # Continue without lock if not supported logger.debug("File locking not supported for %s: %s", file_path, e)
try: try:
for line in f: for line in f:
line = line.strip() line = line.strip()
@@ -107,10 +107,10 @@ def _read_map_file(file_path: str) -> list[tuple[str, str]]:
finally: finally:
try: try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN) fcntl.flock(f.fileno(), fcntl.LOCK_UN)
except OSError: except OSError as e:
pass logger.debug("File unlock failed for %s: %s", file_path, e)
except FileNotFoundError: except FileNotFoundError:
pass logger.debug("Map file not found: %s", file_path)
return entries return entries
@@ -420,16 +420,16 @@ def load_certs_config() -> list[str]:
with open(CERTS_FILE, "r", encoding="utf-8") as f: with open(CERTS_FILE, "r", encoding="utf-8") as f:
try: try:
fcntl.flock(f.fileno(), fcntl.LOCK_SH) fcntl.flock(f.fileno(), fcntl.LOCK_SH)
except OSError: except OSError as e:
pass logger.debug("File locking not supported for %s: %s", CERTS_FILE, e)
try: try:
data = json.load(f) data = json.load(f)
return data.get("domains", []) return data.get("domains", [])
finally: finally:
try: try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN) fcntl.flock(f.fileno(), fcntl.LOCK_UN)
except OSError: except OSError as e:
pass logger.debug("File unlock failed for %s: %s", CERTS_FILE, e)
except FileNotFoundError: except FileNotFoundError:
return [] return []
except json.JSONDecodeError as e: except json.JSONDecodeError as e:

View File

@@ -67,8 +67,8 @@ def haproxy_cmd(command: str) -> str:
raise HaproxyError("Invalid UTF-8 in response") raise HaproxyError("Invalid UTF-8 in response")
except HaproxyError: except HaproxyError:
raise raise
except Exception as e: except (OSError, BlockingIOError, BrokenPipeError) as e:
raise HaproxyError(str(e)) from e raise HaproxyError(f"Socket error: {e}") from e
def haproxy_cmd_checked(command: str) -> str: def haproxy_cmd_checked(command: str) -> str:

View File

@@ -14,6 +14,7 @@ from ..config import (
CERTS_DIR_CONTAINER, CERTS_DIR_CONTAINER,
ACME_HOME, ACME_HOME,
) )
from ..exceptions import HaproxyError
from ..validation import validate_domain from ..validation import validate_domain
from ..haproxy_client import haproxy_cmd from ..haproxy_client import haproxy_cmd
from ..file_ops import ( from ..file_ops import (
@@ -77,7 +78,11 @@ def load_cert_to_haproxy(domain: str) -> tuple[bool, str]:
haproxy_cmd(f"commit ssl cert {container_path}") haproxy_cmd(f"commit ssl cert {container_path}")
return True, "added" return True, "added"
except Exception as e: except HaproxyError as e:
logger.error("HAProxy error loading certificate %s: %s", domain, e)
return False, str(e)
except (IOError, OSError) as e:
logger.error("File error loading certificate %s: %s", domain, e)
return False, str(e) return False, str(e)
@@ -102,7 +107,8 @@ def unload_cert_from_haproxy(domain: str) -> tuple[bool, str]:
haproxy_cmd(f"del ssl cert {container_path}") haproxy_cmd(f"del ssl cert {container_path}")
return True, "unloaded" return True, "unloaded"
except Exception as e: except HaproxyError as e:
logger.error("HAProxy error unloading certificate %s: %s", domain, e)
return False, str(e) return False, str(e)
@@ -126,6 +132,364 @@ def restore_certificates() -> int:
return restored return restored
# =============================================================================
# Implementation functions (module-level)
# =============================================================================
def _haproxy_list_certs_impl() -> str:
"""Implementation of haproxy_list_certs."""
try:
result = subprocess.run(
[ACME_SH, "--list"],
capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT,
env={**os.environ, "HOME": os.path.expanduser("~")}
)
if result.returncode != 0:
return f"Error: {result.stderr}"
lines = result.stdout.strip().split("\n")
if len(lines) <= 1:
return "No certificates found"
# Get HAProxy loaded certs
try:
haproxy_certs = haproxy_cmd("show ssl cert")
except HaproxyError as e:
logger.debug("Could not get HAProxy certs: %s", e)
haproxy_certs = ""
# Parse and format output
certs = []
for line in lines[1:]: # Skip header
parts = line.split()
if len(parts) >= 4:
domain = parts[0]
ca = "unknown"
created = "unknown"
renew = "unknown"
for part in parts:
if "Google" in part or "LetsEncrypt" in part or "ZeroSSL" in part:
ca = part
elif part.endswith("Z") and "T" in part:
if created == "unknown":
created = part
else:
renew = part
# Check deployment status
host_path, container_path = get_pem_paths(domain)
if container_path in haproxy_certs:
status = "loaded"
elif os.path.exists(host_path):
status = "file exists (not loaded)"
else:
status = "not deployed"
certs.append(f"{domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}")
return "\n\n".join(certs) if certs else "No certificates found"
except subprocess.TimeoutExpired:
return "Error: Command timed out"
except FileNotFoundError:
return "Error: acme.sh not found"
except subprocess.SubprocessError as e:
logger.error("Subprocess error listing certificates: %s", e)
return f"Error: {e}"
except OSError as e:
logger.error("OS error listing certificates: %s", e)
return f"Error: {e}"
def _haproxy_cert_info_impl(domain: str) -> str:
"""Implementation of haproxy_cert_info."""
if not validate_domain(domain):
return "Error: Invalid domain format"
host_path, container_path = get_pem_paths(domain)
if not os.path.exists(host_path):
return f"Error: Certificate not found for {domain}"
try:
# Use openssl to get certificate info
result = subprocess.run(
["openssl", "x509", "-in", host_path, "-noout",
"-subject", "-issuer", "-dates", "-ext", "subjectAltName"],
capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT
)
if result.returncode != 0:
return f"Error reading certificate: {result.stderr}"
# Get file info
stat = os.stat(host_path)
modified = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
# Check HAProxy status
try:
haproxy_certs = haproxy_cmd("show ssl cert")
loaded = "Yes" if container_path in haproxy_certs else "No"
except HaproxyError as e:
logger.debug("Could not check HAProxy cert status: %s", e)
loaded = "Unknown"
info = [
f"Certificate: {domain}",
f"File: {host_path}",
f"Modified: {modified}",
f"Loaded in HAProxy: {loaded}",
"---",
result.stdout.strip()
]
return "\n".join(info)
except subprocess.TimeoutExpired:
return "Error: Command timed out"
except (subprocess.SubprocessError, OSError) as e:
logger.error("Error getting certificate info for %s: %s", domain, e)
return f"Error: {e}"
def _haproxy_issue_cert_impl(domain: str, wildcard: bool) -> str:
"""Implementation of haproxy_issue_cert."""
if not validate_domain(domain):
return "Error: Invalid domain format"
# Check if CF_Token is available
if not os.environ.get("CF_Token"):
secrets_file = os.path.expanduser("~/.secrets/cloudflare.ini")
if os.path.exists(secrets_file):
try:
with open(secrets_file) as f:
for line in f:
if "=" in line and "token" in line.lower():
token = line.split("=", 1)[1].strip().strip('"').strip("'")
os.environ["CF_Token"] = token
break
except (IOError, OSError) as e:
logger.warning("Failed to read Cloudflare token: %s", e)
if not os.environ.get("CF_Token"):
return "Error: CF_Token not set. Export CF_Token or add to ~/.secrets/cloudflare.ini"
# Check if certificate already exists
cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc")
if os.path.exists(cert_dir):
return f"Error: Certificate for {domain} already exists. Use haproxy_renew_cert to renew."
# Build acme.sh command (without reload - we'll do it via Runtime API)
host_path, _ = get_pem_paths(domain)
# Create PEM after issuance
install_cmd = f"cat {ACME_HOME}/{domain}_ecc/fullchain.cer {ACME_HOME}/{domain}_ecc/{domain}.key > {host_path}"
cmd = [
ACME_SH, "--issue",
"--dns", "dns_cf",
"-d", domain
]
if wildcard:
cmd.extend(["-d", f"*.{domain}"])
cmd.extend(["--reloadcmd", install_cmd])
try:
logger.info("Issuing certificate for %s", domain)
result = subprocess.run(
cmd,
capture_output=True, text=True, timeout=CERT_TIMEOUT,
env={**os.environ, "HOME": os.path.expanduser("~")}
)
if result.returncode != 0:
error_msg = result.stderr or result.stdout
return f"Error issuing certificate:\n{error_msg}"
# Load into HAProxy via Runtime API (zero-downtime)
if os.path.exists(host_path):
success, msg = load_cert_to_haproxy(domain)
if success:
# Save to config for persistence
add_cert_to_config(domain)
return f"Certificate issued and loaded for {domain} ({msg})"
else:
return f"Certificate issued but HAProxy loading failed: {msg}"
else:
return f"Certificate issued but PEM file not created. Check {host_path}"
except subprocess.TimeoutExpired:
return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s"
except (subprocess.SubprocessError, OSError) as e:
logger.error("Error issuing certificate for %s: %s", domain, e)
return f"Error: {e}"
def _haproxy_renew_cert_impl(domain: str, force: bool) -> str:
"""Implementation of haproxy_renew_cert."""
if not validate_domain(domain):
return "Error: Invalid domain format"
cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc")
if not os.path.exists(cert_dir):
return f"Error: No certificate found for {domain}. Use haproxy_issue_cert first."
cmd = [ACME_SH, "--renew", "-d", domain]
if force:
cmd.append("--force")
try:
logger.info("Renewing certificate for %s", domain)
result = subprocess.run(
cmd,
capture_output=True, text=True, timeout=CERT_TIMEOUT,
env={**os.environ, "HOME": os.path.expanduser("~")}
)
output = result.stdout + result.stderr
if "Skip" in output and "Not yet due" in output:
return f"Certificate for {domain} not due for renewal. Use force=True to force renewal."
if "Cert success" in output or result.returncode == 0:
# Reload into HAProxy via Runtime API
success, msg = load_cert_to_haproxy(domain)
if success:
# Ensure in config
add_cert_to_config(domain)
return f"Certificate renewed and reloaded for {domain} ({msg})"
else:
return f"Certificate renewed but HAProxy reload failed: {msg}"
else:
return f"Error renewing certificate:\n{output}"
except subprocess.TimeoutExpired:
return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s"
except FileNotFoundError:
return "Error: acme.sh not found"
except (subprocess.SubprocessError, OSError) as e:
logger.error("Error renewing certificate for %s: %s", domain, e)
return f"Error: {e}"
def _haproxy_renew_all_certs_impl() -> str:
"""Implementation of haproxy_renew_all_certs."""
try:
logger.info("Running certificate renewal cron")
result = subprocess.run(
[ACME_SH, "--cron"],
capture_output=True, text=True, timeout=CERT_TIMEOUT * 3,
env={**os.environ, "HOME": os.path.expanduser("~")}
)
output = result.stdout + result.stderr
# Count renewals
renewed = output.count("Cert success")
skipped = output.count("Skip")
# Reload any renewed certs into HAProxy
if renewed > 0:
domains = load_certs_config()
reloaded = 0
for domain in domains:
success, _ = load_cert_to_haproxy(domain)
if success:
reloaded += 1
return f"Renewed {renewed} certificate(s), reloaded {reloaded} into HAProxy"
elif skipped > 0:
return f"No certificates due for renewal ({skipped} checked)"
elif result.returncode != 0:
return f"Error running renewal:\n{output}"
else:
return "Renewal check completed"
except subprocess.TimeoutExpired:
return "Error: Renewal cron timed out"
except FileNotFoundError:
return "Error: acme.sh not found"
except (subprocess.SubprocessError, OSError) as e:
logger.error("Error running certificate renewal cron: %s", e)
return f"Error: {e}"
def _haproxy_delete_cert_impl(domain: str) -> str:
"""Implementation of haproxy_delete_cert."""
if not validate_domain(domain):
return "Error: Invalid domain format"
cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc")
host_path, _ = get_pem_paths(domain)
if not os.path.exists(cert_dir) and not os.path.exists(host_path):
return f"Error: No certificate found for {domain}"
errors = []
deleted = []
# Unload from HAProxy first (zero-downtime)
success, msg = unload_cert_from_haproxy(domain)
if success:
deleted.append(f"HAProxy ({msg})")
else:
errors.append(f"HAProxy unload: {msg}")
# Remove from acme.sh
if os.path.exists(cert_dir):
try:
result = subprocess.run(
[ACME_SH, "--remove", "-d", domain],
capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT,
env={**os.environ, "HOME": os.path.expanduser("~")}
)
if result.returncode == 0:
deleted.append("acme.sh")
else:
errors.append(f"acme.sh: {result.stderr}")
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError) as e:
errors.append(f"acme.sh: {e}")
# Remove PEM file
if os.path.exists(host_path):
try:
os.remove(host_path)
deleted.append("PEM file")
except OSError as e:
errors.append(f"PEM file: {e}")
# Remove from config
remove_cert_from_config(domain)
result_parts = []
if deleted:
result_parts.append(f"Deleted: {', '.join(deleted)}")
if errors:
result_parts.append(f"Errors: {'; '.join(errors)}")
return "\n".join(result_parts) if result_parts else f"Certificate {domain} deleted"
def _haproxy_load_cert_impl(domain: str) -> str:
"""Implementation of haproxy_load_cert."""
if not validate_domain(domain):
return "Error: Invalid domain format"
host_path, _ = get_pem_paths(domain)
if not os.path.exists(host_path):
return f"Error: PEM file not found: {host_path}"
success, msg = load_cert_to_haproxy(domain)
if success:
add_cert_to_config(domain)
return f"Certificate {domain} loaded into HAProxy ({msg})"
else:
return f"Error loading certificate: {msg}"
# =============================================================================
# MCP Tool Registration
# =============================================================================
def register_certificate_tools(mcp): def register_certificate_tools(mcp):
"""Register certificate management tools with MCP server.""" """Register certificate management tools with MCP server."""
@@ -136,62 +500,7 @@ def register_certificate_tools(mcp):
Returns: Returns:
List of certificates with domain, CA, created date, and renewal date List of certificates with domain, CA, created date, and renewal date
""" """
try: return _haproxy_list_certs_impl()
result = subprocess.run(
[ACME_SH, "--list"],
capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT,
env={**os.environ, "HOME": os.path.expanduser("~")}
)
if result.returncode != 0:
return f"Error: {result.stderr}"
lines = result.stdout.strip().split("\n")
if len(lines) <= 1:
return "No certificates found"
# Get HAProxy loaded certs
try:
haproxy_certs = haproxy_cmd("show ssl cert")
except Exception:
haproxy_certs = ""
# Parse and format output
certs = []
for line in lines[1:]: # Skip header
parts = line.split()
if len(parts) >= 4:
domain = parts[0]
ca = "unknown"
created = "unknown"
renew = "unknown"
for part in parts:
if "Google" in part or "LetsEncrypt" in part or "ZeroSSL" in part:
ca = part
elif part.endswith("Z") and "T" in part:
if created == "unknown":
created = part
else:
renew = part
# Check deployment status
host_path, container_path = get_pem_paths(domain)
if container_path in haproxy_certs:
status = "loaded"
elif os.path.exists(host_path):
status = "file exists (not loaded)"
else:
status = "not deployed"
certs.append(f"{domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}")
return "\n\n".join(certs) if certs else "No certificates found"
except subprocess.TimeoutExpired:
return "Error: Command timed out"
except FileNotFoundError:
return "Error: acme.sh not found"
except Exception as e:
return f"Error: {e}"
@mcp.tool() @mcp.tool()
def haproxy_cert_info( def haproxy_cert_info(
@@ -201,47 +510,7 @@ def register_certificate_tools(mcp):
Shows expiry date, issuer, SANs, and file paths. Shows expiry date, issuer, SANs, and file paths.
""" """
if not validate_domain(domain): return _haproxy_cert_info_impl(domain)
return "Error: Invalid domain format"
host_path, container_path = get_pem_paths(domain)
if not os.path.exists(host_path):
return f"Error: Certificate not found for {domain}"
try:
# Use openssl to get certificate info
result = subprocess.run(
["openssl", "x509", "-in", host_path, "-noout",
"-subject", "-issuer", "-dates", "-ext", "subjectAltName"],
capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT
)
if result.returncode != 0:
return f"Error reading certificate: {result.stderr}"
# Get file info
stat = os.stat(host_path)
modified = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
# Check HAProxy status
try:
haproxy_certs = haproxy_cmd("show ssl cert")
loaded = "Yes" if container_path in haproxy_certs else "No"
except Exception:
loaded = "Unknown"
info = [
f"Certificate: {domain}",
f"File: {host_path}",
f"Modified: {modified}",
f"Loaded in HAProxy: {loaded}",
"---",
result.stdout.strip()
]
return "\n".join(info)
except subprocess.TimeoutExpired:
return "Error: Command timed out"
except Exception as e:
return f"Error: {e}"
@mcp.tool() @mcp.tool()
def haproxy_issue_cert( def haproxy_issue_cert(
@@ -254,76 +523,7 @@ def register_certificate_tools(mcp):
Example: haproxy_issue_cert("example.com", wildcard=True) Example: haproxy_issue_cert("example.com", wildcard=True)
""" """
if not validate_domain(domain): return _haproxy_issue_cert_impl(domain, wildcard)
return "Error: Invalid domain format"
# Check if CF_Token is available
if not os.environ.get("CF_Token"):
secrets_file = os.path.expanduser("~/.secrets/cloudflare.ini")
if os.path.exists(secrets_file):
try:
with open(secrets_file) as f:
for line in f:
if "=" in line and "token" in line.lower():
token = line.split("=", 1)[1].strip().strip('"').strip("'")
os.environ["CF_Token"] = token
break
except Exception as e:
logger.warning("Failed to read Cloudflare token: %s", e)
if not os.environ.get("CF_Token"):
return "Error: CF_Token not set. Export CF_Token or add to ~/.secrets/cloudflare.ini"
# Check if certificate already exists
cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc")
if os.path.exists(cert_dir):
return f"Error: Certificate for {domain} already exists. Use haproxy_renew_cert to renew."
# Build acme.sh command (without reload - we'll do it via Runtime API)
host_path, _ = get_pem_paths(domain)
# Create PEM after issuance
install_cmd = f"cat {ACME_HOME}/{domain}_ecc/fullchain.cer {ACME_HOME}/{domain}_ecc/{domain}.key > {host_path}"
cmd = [
ACME_SH, "--issue",
"--dns", "dns_cf",
"-d", domain
]
if wildcard:
cmd.extend(["-d", f"*.{domain}"])
cmd.extend(["--reloadcmd", install_cmd])
try:
logger.info("Issuing certificate for %s", domain)
result = subprocess.run(
cmd,
capture_output=True, text=True, timeout=CERT_TIMEOUT,
env={**os.environ, "HOME": os.path.expanduser("~")}
)
if result.returncode != 0:
error_msg = result.stderr or result.stdout
return f"Error issuing certificate:\n{error_msg}"
# Load into HAProxy via Runtime API (zero-downtime)
if os.path.exists(host_path):
success, msg = load_cert_to_haproxy(domain)
if success:
# Save to config for persistence
add_cert_to_config(domain)
return f"Certificate issued and loaded for {domain} ({msg})"
else:
return f"Certificate issued but HAProxy loading failed: {msg}"
else:
return f"Certificate issued but PEM file not created. Check {host_path}"
except subprocess.TimeoutExpired:
return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s"
except Exception as e:
return f"Error: {e}"
@mcp.tool() @mcp.tool()
def haproxy_renew_cert( def haproxy_renew_cert(
@@ -336,46 +536,7 @@ def register_certificate_tools(mcp):
Example: haproxy_renew_cert("example.com", force=True) Example: haproxy_renew_cert("example.com", force=True)
""" """
if not validate_domain(domain): return _haproxy_renew_cert_impl(domain, force)
return "Error: Invalid domain format"
cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc")
if not os.path.exists(cert_dir):
return f"Error: No certificate found for {domain}. Use haproxy_issue_cert first."
cmd = [ACME_SH, "--renew", "-d", domain]
if force:
cmd.append("--force")
try:
logger.info("Renewing certificate for %s", domain)
result = subprocess.run(
cmd,
capture_output=True, text=True, timeout=CERT_TIMEOUT,
env={**os.environ, "HOME": os.path.expanduser("~")}
)
output = result.stdout + result.stderr
if "Skip" in output and "Not yet due" in output:
return f"Certificate for {domain} not due for renewal. Use force=True to force renewal."
if "Cert success" in output or result.returncode == 0:
# Reload into HAProxy via Runtime API
success, msg = load_cert_to_haproxy(domain)
if success:
# Ensure in config
add_cert_to_config(domain)
return f"Certificate renewed and reloaded for {domain} ({msg})"
else:
return f"Certificate renewed but HAProxy reload failed: {msg}"
else:
return f"Error renewing certificate:\n{output}"
except subprocess.TimeoutExpired:
return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s"
except Exception as e:
return f"Error: {e}"
@mcp.tool() @mcp.tool()
def haproxy_renew_all_certs() -> str: def haproxy_renew_all_certs() -> str:
@@ -383,40 +544,7 @@ def register_certificate_tools(mcp):
This runs the acme.sh cron job to check and renew all certificates. This runs the acme.sh cron job to check and renew all certificates.
""" """
try: return _haproxy_renew_all_certs_impl()
logger.info("Running certificate renewal cron")
result = subprocess.run(
[ACME_SH, "--cron"],
capture_output=True, text=True, timeout=CERT_TIMEOUT * 3,
env={**os.environ, "HOME": os.path.expanduser("~")}
)
output = result.stdout + result.stderr
# Count renewals
renewed = output.count("Cert success")
skipped = output.count("Skip")
# Reload any renewed certs into HAProxy
if renewed > 0:
domains = load_certs_config()
reloaded = 0
for domain in domains:
success, _ = load_cert_to_haproxy(domain)
if success:
reloaded += 1
return f"Renewed {renewed} certificate(s), reloaded {reloaded} into HAProxy"
elif skipped > 0:
return f"No certificates due for renewal ({skipped} checked)"
elif result.returncode != 0:
return f"Error running renewal:\n{output}"
else:
return "Renewal check completed"
except subprocess.TimeoutExpired:
return "Error: Renewal cron timed out"
except Exception as e:
return f"Error: {e}"
@mcp.tool() @mcp.tool()
def haproxy_delete_cert( def haproxy_delete_cert(
@@ -428,58 +556,7 @@ def register_certificate_tools(mcp):
Example: haproxy_delete_cert("example.com") Example: haproxy_delete_cert("example.com")
""" """
if not validate_domain(domain): return _haproxy_delete_cert_impl(domain)
return "Error: Invalid domain format"
cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc")
host_path, _ = get_pem_paths(domain)
if not os.path.exists(cert_dir) and not os.path.exists(host_path):
return f"Error: No certificate found for {domain}"
errors = []
deleted = []
# Unload from HAProxy first (zero-downtime)
success, msg = unload_cert_from_haproxy(domain)
if success:
deleted.append(f"HAProxy ({msg})")
else:
errors.append(f"HAProxy unload: {msg}")
# Remove from acme.sh
if os.path.exists(cert_dir):
try:
result = subprocess.run(
[ACME_SH, "--remove", "-d", domain],
capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT,
env={**os.environ, "HOME": os.path.expanduser("~")}
)
if result.returncode == 0:
deleted.append("acme.sh")
else:
errors.append(f"acme.sh: {result.stderr}")
except Exception as e:
errors.append(f"acme.sh: {e}")
# Remove PEM file
if os.path.exists(host_path):
try:
os.remove(host_path)
deleted.append("PEM file")
except Exception as e:
errors.append(f"PEM file: {e}")
# Remove from config
remove_cert_from_config(domain)
result_parts = []
if deleted:
result_parts.append(f"Deleted: {', '.join(deleted)}")
if errors:
result_parts.append(f"Errors: {'; '.join(errors)}")
return "\n".join(result_parts) if result_parts else f"Certificate {domain} deleted"
@mcp.tool() @mcp.tool()
def haproxy_load_cert( def haproxy_load_cert(
@@ -491,16 +568,4 @@ def register_certificate_tools(mcp):
Example: haproxy_load_cert("example.com") Example: haproxy_load_cert("example.com")
""" """
if not validate_domain(domain): return _haproxy_load_cert_impl(domain)
return "Error: Invalid domain format"
host_path, _ = get_pem_paths(domain)
if not os.path.exists(host_path):
return f"Error: PEM file not found: {host_path}"
success, msg = load_cert_to_haproxy(domain)
if success:
add_cert_to_config(domain)
return f"Certificate {domain} loaded into HAProxy ({msg})"
else:
return f"Error loading certificate: {msg}"

View File

@@ -117,7 +117,7 @@ def startup_restore() -> None:
cert_count = restore_certificates() cert_count = restore_certificates()
if cert_count > 0: if cert_count > 0:
logger.info("Restored %d certificates from config", cert_count) logger.info("Restored %d certificates from config", cert_count)
except Exception as e: except (HaproxyError, IOError, OSError, ValueError) as e:
logger.warning("Failed to restore certificates: %s", e) logger.warning("Failed to restore certificates: %s", e)
@@ -155,7 +155,7 @@ def register_config_tools(mcp):
try: try:
restored = restore_servers_from_config() restored = restore_servers_from_config()
return f"HAProxy configuration reloaded successfully ({restored} servers restored)" return f"HAProxy configuration reloaded successfully ({restored} servers restored)"
except Exception as e: except (HaproxyError, IOError, OSError, ValueError) as e:
logger.error("Failed to restore servers after reload: %s", e) logger.error("Failed to restore servers after reload: %s", e)
return f"HAProxy reloaded but server restore failed: {e}" return f"HAProxy reloaded but server restore failed: {e}"

View File

@@ -13,14 +13,12 @@ from ..config import (
WILDCARDS_MAP_FILE_CONTAINER, WILDCARDS_MAP_FILE_CONTAINER,
POOL_COUNT, POOL_COUNT,
MAX_SLOTS, MAX_SLOTS,
StateField,
STATE_MIN_COLUMNS,
SUBPROCESS_TIMEOUT, SUBPROCESS_TIMEOUT,
CERTS_DIR, CERTS_DIR,
logger, logger,
) )
from ..exceptions import HaproxyError from ..exceptions import HaproxyError
from ..validation import validate_domain, validate_ip from ..validation import validate_domain, validate_ip, validate_port_int
from ..haproxy_client import haproxy_cmd from ..haproxy_client import haproxy_cmd
from ..file_ops import ( from ..file_ops import (
get_map_contents, get_map_contents,
@@ -31,6 +29,7 @@ from ..file_ops import (
remove_server_from_config, remove_server_from_config,
remove_domain_from_config, remove_domain_from_config,
) )
from ..utils import parse_servers_state, disable_server_slot
def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]: def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]:
@@ -166,18 +165,18 @@ def register_domain_tools(mcp):
try: try:
domains = [] domains = []
state = haproxy_cmd("show servers state") state = haproxy_cmd("show servers state")
parsed_state = parse_servers_state(state)
# Build server map from HAProxy state # Build server map from HAProxy state
server_map: dict[str, list] = {} server_map: dict[str, list[str]] = {}
for line in state.split("\n"): for backend, servers_dict in parsed_state.items():
parts = line.split() for server_name, srv_info in servers_dict.items():
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.SRV_ADDR] != "0.0.0.0": if srv_info["addr"] != "0.0.0.0":
backend = parts[StateField.BE_NAME] if backend not in server_map:
if backend not in server_map: server_map[backend] = []
server_map[backend] = [] server_map[backend].append(
server_map[backend].append( f"{server_name}={srv_info['addr']}:{srv_info['port']}"
f"{parts[StateField.SRV_NAME]}={parts[StateField.SRV_ADDR]}:{parts[StateField.SRV_PORT]}" )
)
# Read from domains.map # Read from domains.map
seen_domains: set[str] = set() seen_domains: set[str] = set()
@@ -220,7 +219,7 @@ def register_domain_tools(mcp):
return "Error: Invalid domain format" return "Error: Invalid domain format"
if not validate_ip(ip, allow_empty=True): if not validate_ip(ip, allow_empty=True):
return "Error: Invalid IP address format" return "Error: Invalid IP address format"
if not (1 <= http_port <= 65535): if not validate_port_int(http_port):
return "Error: Port must be between 1 and 65535" return "Error: Port must be between 1 and 65535"
# Use file locking for the entire pool allocation operation # Use file locking for the entire pool allocation operation
@@ -339,8 +338,7 @@ def register_domain_tools(mcp):
for slot in range(1, MAX_SLOTS + 1): for slot in range(1, MAX_SLOTS + 1):
server = f"{backend}_{slot}" server = f"{backend}_{slot}"
try: try:
haproxy_cmd(f"set server {backend}/{server} state maint") disable_server_slot(backend, server)
haproxy_cmd(f"set server {backend}/{server} addr 0.0.0.0 port 0")
except HaproxyError as e: except HaproxyError as e:
logger.warning( logger.warning(
"Failed to clear server %s/%s for domain %s: %s", "Failed to clear server %s/%s for domain %s: %s",

View File

@@ -12,14 +12,12 @@ from ..config import (
MAP_FILE, MAP_FILE,
SERVERS_FILE, SERVERS_FILE,
HAPROXY_CONTAINER, HAPROXY_CONTAINER,
StateField,
STATE_MIN_COLUMNS,
) )
from ..exceptions import HaproxyError from ..exceptions import HaproxyError
from ..validation import validate_domain, validate_backend_name from ..validation import validate_domain, validate_backend_name
from ..haproxy_client import haproxy_cmd from ..haproxy_client import haproxy_cmd
from ..file_ops import get_backend_and_prefix from ..file_ops import get_backend_and_prefix
from ..utils import parse_stat_csv from ..utils import parse_stat_csv, parse_servers_state
def register_health_tools(mcp): def register_health_tools(mcp):
@@ -83,7 +81,7 @@ def register_health_tools(mcp):
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
result["components"]["container"] = {"status": "timeout"} result["components"]["container"] = {"status": "timeout"}
result["status"] = "unhealthy" result["status"] = "unhealthy"
except Exception as e: except (OSError, subprocess.SubprocessError) as e:
result["components"]["container"] = {"status": "error", "error": str(e)} result["components"]["container"] = {"status": "error", "error": str(e)}
# Check configuration files # Check configuration files
@@ -144,34 +142,34 @@ def register_health_tools(mcp):
} }
# Parse server state for address info # Parse server state for address info
for line in state_output.split("\n"): parsed_state = parse_servers_state(state_output)
parts = line.split() backend_servers = parsed_state.get(backend, {})
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
server_name = parts[StateField.SRV_NAME]
addr = parts[StateField.SRV_ADDR]
port = parts[StateField.SRV_PORT]
# Skip disabled servers (0.0.0.0) for server_name, srv_info in backend_servers.items():
if addr == "0.0.0.0": addr = srv_info["addr"]
continue port = srv_info["port"]
server_info: dict[str, Any] = { # Skip disabled servers (0.0.0.0)
"name": server_name, if addr == "0.0.0.0":
"addr": f"{addr}:{port}", continue
"status": "unknown"
}
# Get status from stat output server_info: dict[str, Any] = {
if server_name in status_map: "name": server_name,
server_info["status"] = status_map[server_name]["status"] "addr": f"{addr}:{port}",
server_info["check_status"] = status_map[server_name]["check_status"] "status": "unknown"
server_info["weight"] = status_map[server_name]["weight"] }
result["servers"].append(server_info) # Get status from stat output
result["total_count"] += 1 if server_name in status_map:
server_info["status"] = status_map[server_name]["status"]
server_info["check_status"] = status_map[server_name]["check_status"]
server_info["weight"] = status_map[server_name]["weight"]
if server_info["status"] == "UP": result["servers"].append(server_info)
result["healthy_count"] += 1 result["total_count"] += 1
if server_info["status"] == "UP":
result["healthy_count"] += 1
# Determine overall status # Determine overall status
if result["total_count"] == 0: if result["total_count"] == 0:

View File

@@ -10,13 +10,11 @@ from ..config import (
MAX_SLOTS, MAX_SLOTS,
MAX_BULK_SERVERS, MAX_BULK_SERVERS,
MAX_SERVERS_JSON_SIZE, MAX_SERVERS_JSON_SIZE,
StateField,
StatField, StatField,
STATE_MIN_COLUMNS,
logger, logger,
) )
from ..exceptions import HaproxyError from ..exceptions import HaproxyError
from ..validation import validate_domain, validate_ip, validate_backend_name from ..validation import validate_domain, validate_ip, validate_backend_name, validate_port_int
from ..haproxy_client import haproxy_cmd, haproxy_cmd_checked, haproxy_cmd_batch from ..haproxy_client import haproxy_cmd, haproxy_cmd_checked, haproxy_cmd_batch
from ..file_ops import ( from ..file_ops import (
get_backend_and_prefix, get_backend_and_prefix,
@@ -24,6 +22,7 @@ from ..file_ops import (
add_server_to_config, add_server_to_config,
remove_server_from_config, remove_server_from_config,
) )
from ..utils import parse_servers_state, disable_server_slot
def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str: def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str:
@@ -51,6 +50,390 @@ def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str,
return server return server
# =============================================================================
# Implementation functions (module-level)
# =============================================================================
def _haproxy_list_servers_impl(domain: str) -> str:
"""Implementation of haproxy_list_servers."""
if not validate_domain(domain):
return "Error: Invalid domain format"
try:
backend, _ = get_backend_and_prefix(domain)
state = haproxy_cmd("show servers state")
parsed_state = parse_servers_state(state)
backend_servers = parsed_state.get(backend, {})
if not backend_servers:
return f"Backend {backend} not found"
servers = []
for server_name, srv_info in backend_servers.items():
addr = srv_info["addr"]
port = srv_info["port"]
status = "active" if addr != "0.0.0.0" else "disabled"
servers.append(f"{server_name}: {addr}:{port} ({status})")
return f"Servers for {domain} ({backend}):\n" + "\n".join(servers)
except (HaproxyError, ValueError) as e:
return f"Error: {e}"
def _haproxy_add_server_impl(domain: str, slot: int, ip: str, http_port: int) -> str:
"""Implementation of haproxy_add_server."""
if not validate_domain(domain):
return "Error: Invalid domain format"
if not ip:
return "Error: IP address is required"
if not validate_ip(ip):
return "Error: Invalid IP address format"
if not validate_port_int(http_port):
return "Error: Port must be between 1 and 65535"
try:
backend, server_prefix = get_backend_and_prefix(domain)
# Auto-select slot if slot <= 0
if slot <= 0:
state = haproxy_cmd("show servers state")
parsed_state = parse_servers_state(state)
backend_servers = parsed_state.get(backend, {})
used_slots: set[int] = set()
for server_name, srv_info in backend_servers.items():
if srv_info["addr"] != "0.0.0.0":
# Extract slot number from server name (e.g., pool_1_3 -> 3)
try:
used_slots.add(int(server_name.rsplit("_", 1)[1]))
except (ValueError, IndexError):
pass
for s in range(1, MAX_SLOTS + 1):
if s not in used_slots:
slot = s
break
else:
return f"Error: No available slots (all {MAX_SLOTS} slots in use)"
elif not (1 <= slot <= MAX_SLOTS):
return f"Error: Slot must be between 1 and {MAX_SLOTS}, or 0/-1 for auto-select"
# Save to persistent config FIRST (disk-first pattern)
add_server_to_config(domain, slot, ip, http_port)
try:
server = configure_server_slot(backend, server_prefix, slot, ip, http_port)
return f"Added to {domain} ({backend}) slot {slot}:\n{server}{ip}:{http_port}"
except HaproxyError as e:
# Rollback config on HAProxy failure
remove_server_from_config(domain, slot)
return f"Error: {e}"
except (ValueError, IOError) as e:
return f"Error: {e}"
def _haproxy_add_servers_impl(domain: str, servers: str) -> str:
"""Implementation of haproxy_add_servers."""
if not validate_domain(domain):
return "Error: Invalid domain format"
# Check JSON size before parsing
if len(servers) > MAX_SERVERS_JSON_SIZE:
return f"Error: servers JSON exceeds maximum size ({MAX_SERVERS_JSON_SIZE} bytes)"
# Parse JSON array
try:
server_list = json.loads(servers)
except json.JSONDecodeError as e:
return f"Error: Invalid JSON - {e}"
if not isinstance(server_list, list):
return "Error: servers must be a JSON array"
if not server_list:
return "Error: servers array is empty"
if len(server_list) > MAX_BULK_SERVERS:
return f"Error: Cannot add more than {MAX_BULK_SERVERS} servers at once"
# Validate all servers first before adding any
validated_servers = []
validation_errors = []
for i, srv in enumerate(server_list):
if not isinstance(srv, dict):
validation_errors.append(f"Server {i+1}: must be an object")
continue
# Extract and validate slot
slot = srv.get("slot")
if slot is None:
validation_errors.append(f"Server {i+1}: missing 'slot' field")
continue
try:
slot = int(slot)
except (ValueError, TypeError):
validation_errors.append(f"Server {i+1}: slot must be an integer")
continue
if not (1 <= slot <= MAX_SLOTS):
validation_errors.append(f"Server {i+1}: slot must be between 1 and {MAX_SLOTS}")
continue
# Extract and validate IP
ip = srv.get("ip")
if not ip:
validation_errors.append(f"Server {i+1}: missing 'ip' field")
continue
if not validate_ip(ip):
validation_errors.append(f"Server {i+1}: invalid IP address '{ip}'")
continue
# Extract and validate port
http_port = srv.get("http_port", 80)
try:
http_port = int(http_port)
except (ValueError, TypeError):
validation_errors.append(f"Server {i+1}: http_port must be an integer")
continue
if not validate_port_int(http_port):
validation_errors.append(f"Server {i+1}: port must be between 1 and 65535")
continue
validated_servers.append({"slot": slot, "ip": ip, "http_port": http_port})
# Return validation errors if any
if validation_errors:
return "Validation errors:\n" + "\n".join(f"{e}" for e in validation_errors)
# Check for duplicate slots
slots = [s["slot"] for s in validated_servers]
if len(slots) != len(set(slots)):
return "Error: Duplicate slot numbers in servers array"
# Get backend info
try:
backend, server_prefix = get_backend_and_prefix(domain)
except ValueError as e:
return f"Error: {e}"
# Save ALL servers to config FIRST (disk-first pattern)
for server_config in validated_servers:
slot = server_config["slot"]
ip = server_config["ip"]
http_port = server_config["http_port"]
add_server_to_config(domain, slot, ip, http_port)
# Then update HAProxy
added = []
errors = []
failed_slots = []
successfully_added_slots = []
try:
for server_config in validated_servers:
slot = server_config["slot"]
ip = server_config["ip"]
http_port = server_config["http_port"]
try:
configure_server_slot(backend, server_prefix, slot, ip, http_port)
successfully_added_slots.append(slot)
added.append(f"slot {slot}: {ip}:{http_port}")
except HaproxyError as e:
failed_slots.append(slot)
errors.append(f"slot {slot}: {e}")
except (IOError, OSError, ValueError) as e:
# Rollback only successfully added configs on unexpected error
logger.error("Unexpected error during bulk server add for %s: %s", domain, e)
for slot in successfully_added_slots:
try:
remove_server_from_config(domain, slot)
except (IOError, OSError, ValueError) as rollback_error:
logger.error(
"Failed to rollback server config for %s slot %d: %s",
domain, slot, rollback_error
)
# Also rollback configs that weren't yet processed
for server_config in validated_servers:
slot = server_config["slot"]
if slot not in successfully_added_slots:
try:
remove_server_from_config(domain, slot)
except (IOError, OSError, ValueError) as rollback_error:
logger.error(
"Failed to rollback server config for %s slot %d: %s",
domain, slot, rollback_error
)
return f"Error: {e}"
# Rollback failed slots from config
for slot in failed_slots:
try:
remove_server_from_config(domain, slot)
except (IOError, OSError, ValueError) as rollback_error:
logger.error(
"Failed to rollback server config for %s slot %d: %s",
domain, slot, rollback_error
)
# Build result message
result_parts = []
if added:
result_parts.append(f"Added {len(added)} servers to {domain} ({backend}):")
result_parts.extend(f"{s}" for s in added)
if errors:
result_parts.append(f"Failed to add {len(errors)} servers:")
result_parts.extend(f"{e}" for e in errors)
return "\n".join(result_parts) if result_parts else "No servers added"
def _haproxy_remove_server_impl(domain: str, slot: int) -> str:
"""Implementation of haproxy_remove_server."""
if not validate_domain(domain):
return "Error: Invalid domain format"
if not (1 <= slot <= MAX_SLOTS):
return f"Error: Slot must be between 1 and {MAX_SLOTS}"
try:
backend, server_prefix = get_backend_and_prefix(domain)
# Get current server info for potential rollback
config = load_servers_config()
old_config = config.get(domain, {}).get(str(slot), {})
# Remove from persistent config FIRST (disk-first pattern)
remove_server_from_config(domain, slot)
try:
# HTTP only - single server per slot
server = f"{server_prefix}_{slot}"
disable_server_slot(backend, server)
return f"Removed server at slot {slot} from {domain} ({backend})"
except HaproxyError as e:
# Rollback: re-add config if HAProxy command failed
if old_config:
add_server_to_config(domain, slot, old_config.get("ip", ""), old_config.get("http_port", 80))
return f"Error: {e}"
except (ValueError, IOError) as e:
return f"Error: {e}"
def _haproxy_set_domain_state_impl(domain: str, state: str) -> str:
"""Implementation of haproxy_set_domain_state."""
if not validate_domain(domain):
return "Error: Invalid domain format"
if state not in ["ready", "drain", "maint"]:
return "Error: State must be 'ready', 'drain', or 'maint'"
try:
backend, _ = get_backend_and_prefix(domain)
except ValueError as e:
return f"Error: {e}"
# Get active servers for this domain
try:
servers_state = haproxy_cmd("show servers state")
except HaproxyError as e:
return f"Error: {e}"
parsed_state = parse_servers_state(servers_state)
backend_servers = parsed_state.get(backend, {})
changed = []
errors = []
for server_name, srv_info in backend_servers.items():
# Only change state for configured servers (not 0.0.0.0)
if srv_info["addr"] != "0.0.0.0":
try:
haproxy_cmd_checked(f"set server {backend}/{server_name} state {state}")
changed.append(server_name)
except HaproxyError as e:
errors.append(f"{server_name}: {e}")
if not changed and not errors:
return f"No active servers found for {domain}"
result = f"Set {len(changed)} servers to '{state}' for {domain}"
if changed:
result += ":\n" + "\n".join(f"{s}" for s in changed)
if errors:
result += f"\n\nErrors ({len(errors)}):\n" + "\n".join(f"{e}" for e in errors)
return result
def _haproxy_wait_drain_impl(domain: str, timeout: int) -> str:
"""Implementation of haproxy_wait_drain."""
if not validate_domain(domain):
return "Error: Invalid domain format"
if not (1 <= timeout <= 300):
return "Error: Timeout must be between 1 and 300 seconds"
try:
backend, _ = get_backend_and_prefix(domain)
except ValueError as e:
return f"Error: {e}"
start_time = time.time()
while time.time() - start_time < timeout:
try:
stats = haproxy_cmd("show stat")
total_connections = 0
for line in stats.split("\n"):
parts = line.split(",")
if len(parts) > StatField.SCUR and parts[0] == backend and parts[1] not in ["FRONTEND", "BACKEND", ""]:
try:
scur = int(parts[StatField.SCUR]) if parts[StatField.SCUR] else 0
total_connections += scur
except ValueError:
pass
if total_connections == 0:
elapsed = int(time.time() - start_time)
return f"All connections drained for {domain} ({elapsed}s)"
time.sleep(1)
except HaproxyError as e:
return f"Error checking connections: {e}"
return f"Timeout: Connections still active after {timeout}s"
def _haproxy_set_server_state_impl(backend: str, server: str, state: str) -> str:
"""Implementation of haproxy_set_server_state."""
if not validate_backend_name(backend):
return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)"
if not validate_backend_name(server):
return "Error: Invalid server name (use alphanumeric, underscore, hyphen only)"
if state not in ["ready", "drain", "maint"]:
return "Error: state must be 'ready', 'drain', or 'maint'"
try:
haproxy_cmd_checked(f"set server {backend}/{server} state {state}")
return f"Server {backend}/{server} set to {state}"
except HaproxyError as e:
return f"Error: {e}"
def _haproxy_set_server_weight_impl(backend: str, server: str, weight: int) -> str:
"""Implementation of haproxy_set_server_weight."""
if not validate_backend_name(backend):
return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)"
if not validate_backend_name(server):
return "Error: Invalid server name (use alphanumeric, underscore, hyphen only)"
if not (0 <= weight <= 256):
return "Error: weight must be between 0 and 256"
try:
haproxy_cmd_checked(f"set server {backend}/{server} weight {weight}")
return f"Server {backend}/{server} weight set to {weight}"
except HaproxyError as e:
return f"Error: {e}"
# =============================================================================
# MCP Tool Registration
# =============================================================================
def register_server_tools(mcp): def register_server_tools(mcp):
"""Register server management tools with MCP server.""" """Register server management tools with MCP server."""
@@ -65,29 +448,7 @@ def register_server_tools(mcp):
# Output: pool_1_1: 10.0.0.1:8080 (UP) # Output: pool_1_1: 10.0.0.1:8080 (UP)
# pool_1_2: 10.0.0.2:8080 (UP) # pool_1_2: 10.0.0.2:8080 (UP)
""" """
if not validate_domain(domain): return _haproxy_list_servers_impl(domain)
return "Error: Invalid domain format"
try:
backend, _ = get_backend_and_prefix(domain)
servers = []
state = haproxy_cmd("show servers state")
for line in state.split("\n"):
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
addr = parts[StateField.SRV_ADDR]
status = "active" if addr != "0.0.0.0" else "disabled"
servers.append(
f"{parts[StateField.SRV_NAME]}: {addr}:{parts[StateField.SRV_PORT]} ({status})"
)
if not servers:
return f"Backend {backend} not found"
return f"Servers for {domain} ({backend}):\n" + "\n".join(servers)
except (HaproxyError, ValueError) as e:
return f"Error: {e}"
@mcp.tool() @mcp.tool()
def haproxy_add_server( def haproxy_add_server(
@@ -103,53 +464,7 @@ def register_server_tools(mcp):
Example: haproxy_add_server("api.example.com", slot=1, ip="10.0.0.1", http_port=8080) Example: haproxy_add_server("api.example.com", slot=1, ip="10.0.0.1", http_port=8080)
""" """
if not validate_domain(domain): return _haproxy_add_server_impl(domain, slot, ip, http_port)
return "Error: Invalid domain format"
if not ip:
return "Error: IP address is required"
if not validate_ip(ip):
return "Error: Invalid IP address format"
if not (1 <= http_port <= 65535):
return "Error: Port must be between 1 and 65535"
try:
backend, server_prefix = get_backend_and_prefix(domain)
# Auto-select slot if slot <= 0
if slot <= 0:
state = haproxy_cmd("show servers state")
used_slots: set[int] = set()
for line in state.split("\n"):
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
if parts[StateField.SRV_ADDR] != "0.0.0.0":
# Extract slot number from server name (e.g., pool_1_3 -> 3)
server_name = parts[StateField.SRV_NAME]
try:
used_slots.add(int(server_name.rsplit("_", 1)[1]))
except (ValueError, IndexError):
pass
for s in range(1, MAX_SLOTS + 1):
if s not in used_slots:
slot = s
break
else:
return f"Error: No available slots (all {MAX_SLOTS} slots in use)"
elif not (1 <= slot <= MAX_SLOTS):
return f"Error: Slot must be between 1 and {MAX_SLOTS}, or 0/-1 for auto-select"
# Save to persistent config FIRST (disk-first pattern)
add_server_to_config(domain, slot, ip, http_port)
try:
server = configure_server_slot(backend, server_prefix, slot, ip, http_port)
return f"Added to {domain} ({backend}) slot {slot}:\n{server}{ip}:{http_port}"
except HaproxyError as e:
# Rollback config on HAProxy failure
remove_server_from_config(domain, slot)
return f"Error: {e}"
except (ValueError, IOError) as e:
return f"Error: {e}"
@mcp.tool() @mcp.tool()
def haproxy_add_servers( def haproxy_add_servers(
@@ -160,156 +475,7 @@ def register_server_tools(mcp):
Example: haproxy_add_servers("api.example.com", '[{"slot":1,"ip":"10.0.0.1"},{"slot":2,"ip":"10.0.0.2"}]') Example: haproxy_add_servers("api.example.com", '[{"slot":1,"ip":"10.0.0.1"},{"slot":2,"ip":"10.0.0.2"}]')
""" """
if not validate_domain(domain): return _haproxy_add_servers_impl(domain, servers)
return "Error: Invalid domain format"
# Check JSON size before parsing
if len(servers) > MAX_SERVERS_JSON_SIZE:
return f"Error: servers JSON exceeds maximum size ({MAX_SERVERS_JSON_SIZE} bytes)"
# Parse JSON array
try:
server_list = json.loads(servers)
except json.JSONDecodeError as e:
return f"Error: Invalid JSON - {e}"
if not isinstance(server_list, list):
return "Error: servers must be a JSON array"
if not server_list:
return "Error: servers array is empty"
if len(server_list) > MAX_BULK_SERVERS:
return f"Error: Cannot add more than {MAX_BULK_SERVERS} servers at once"
# Validate all servers first before adding any
validated_servers = []
validation_errors = []
for i, srv in enumerate(server_list):
if not isinstance(srv, dict):
validation_errors.append(f"Server {i+1}: must be an object")
continue
# Extract and validate slot
slot = srv.get("slot")
if slot is None:
validation_errors.append(f"Server {i+1}: missing 'slot' field")
continue
try:
slot = int(slot)
except (ValueError, TypeError):
validation_errors.append(f"Server {i+1}: slot must be an integer")
continue
if not (1 <= slot <= MAX_SLOTS):
validation_errors.append(f"Server {i+1}: slot must be between 1 and {MAX_SLOTS}")
continue
# Extract and validate IP
ip = srv.get("ip")
if not ip:
validation_errors.append(f"Server {i+1}: missing 'ip' field")
continue
if not validate_ip(ip):
validation_errors.append(f"Server {i+1}: invalid IP address '{ip}'")
continue
# Extract and validate port
http_port = srv.get("http_port", 80)
try:
http_port = int(http_port)
except (ValueError, TypeError):
validation_errors.append(f"Server {i+1}: http_port must be an integer")
continue
if not (1 <= http_port <= 65535):
validation_errors.append(f"Server {i+1}: port must be between 1 and 65535")
continue
validated_servers.append({"slot": slot, "ip": ip, "http_port": http_port})
# Return validation errors if any
if validation_errors:
return "Validation errors:\n" + "\n".join(f"{e}" for e in validation_errors)
# Check for duplicate slots
slots = [s["slot"] for s in validated_servers]
if len(slots) != len(set(slots)):
return "Error: Duplicate slot numbers in servers array"
# Get backend info
try:
backend, server_prefix = get_backend_and_prefix(domain)
except ValueError as e:
return f"Error: {e}"
# Save ALL servers to config FIRST (disk-first pattern)
for server_config in validated_servers:
slot = server_config["slot"]
ip = server_config["ip"]
http_port = server_config["http_port"]
add_server_to_config(domain, slot, ip, http_port)
# Then update HAProxy
added = []
errors = []
failed_slots = []
successfully_added_slots = []
try:
for server_config in validated_servers:
slot = server_config["slot"]
ip = server_config["ip"]
http_port = server_config["http_port"]
try:
configure_server_slot(backend, server_prefix, slot, ip, http_port)
successfully_added_slots.append(slot)
added.append(f"slot {slot}: {ip}:{http_port}")
except HaproxyError as e:
failed_slots.append(slot)
errors.append(f"slot {slot}: {e}")
except Exception as e:
# Rollback only successfully added configs on unexpected error
for slot in successfully_added_slots:
try:
remove_server_from_config(domain, slot)
except Exception as rollback_error:
logger.error(
"Failed to rollback server config for %s slot %d: %s",
domain, slot, rollback_error
)
# Also rollback configs that weren't yet processed
for server_config in validated_servers:
slot = server_config["slot"]
if slot not in successfully_added_slots:
try:
remove_server_from_config(domain, slot)
except Exception as rollback_error:
logger.error(
"Failed to rollback server config for %s slot %d: %s",
domain, slot, rollback_error
)
return f"Error: {e}"
# Rollback failed slots from config
for slot in failed_slots:
try:
remove_server_from_config(domain, slot)
except Exception as rollback_error:
logger.error(
"Failed to rollback server config for %s slot %d: %s",
domain, slot, rollback_error
)
# Build result message
result_parts = []
if added:
result_parts.append(f"Added {len(added)} servers to {domain} ({backend}):")
result_parts.extend(f"{s}" for s in added)
if errors:
result_parts.append(f"Failed to add {len(errors)} servers:")
result_parts.extend(f"{e}" for e in errors)
return "\n".join(result_parts) if result_parts else "No servers added"
@mcp.tool() @mcp.tool()
def haproxy_remove_server( def haproxy_remove_server(
@@ -320,37 +486,7 @@ def register_server_tools(mcp):
Example: haproxy_remove_server("api.example.com", slot=2) Example: haproxy_remove_server("api.example.com", slot=2)
""" """
if not validate_domain(domain): return _haproxy_remove_server_impl(domain, slot)
return "Error: Invalid domain format"
if not (1 <= slot <= MAX_SLOTS):
return f"Error: Slot must be between 1 and {MAX_SLOTS}"
try:
backend, server_prefix = get_backend_and_prefix(domain)
# Get current server info for potential rollback
config = load_servers_config()
old_config = config.get(domain, {}).get(str(slot), {})
# Remove from persistent config FIRST (disk-first pattern)
remove_server_from_config(domain, slot)
try:
# HTTP only - single server per slot
server = f"{server_prefix}_{slot}"
# Batch both commands in single TCP connection
haproxy_cmd_batch([
f"set server {backend}/{server} state maint",
f"set server {backend}/{server} addr 0.0.0.0 port 0"
])
return f"Removed server at slot {slot} from {domain} ({backend})"
except HaproxyError as e:
# Rollback: re-add config if HAProxy command failed
if old_config:
add_server_to_config(domain, slot, old_config.get("ip", ""), old_config.get("http_port", 80))
return f"Error: {e}"
except (ValueError, IOError) as e:
return f"Error: {e}"
@mcp.tool() @mcp.tool()
def haproxy_set_domain_state( def haproxy_set_domain_state(
@@ -368,48 +504,7 @@ def register_server_tools(mcp):
# Re-enable all servers after deployment # Re-enable all servers after deployment
haproxy_set_domain_state("api.example.com", "ready") haproxy_set_domain_state("api.example.com", "ready")
""" """
if not validate_domain(domain): return _haproxy_set_domain_state_impl(domain, state)
return "Error: Invalid domain format"
if state not in ["ready", "drain", "maint"]:
return "Error: State must be 'ready', 'drain', or 'maint'"
try:
backend, _ = get_backend_and_prefix(domain)
except ValueError as e:
return f"Error: {e}"
# Get active servers for this domain
try:
servers_state = haproxy_cmd("show servers state")
except HaproxyError as e:
return f"Error: {e}"
changed = []
errors = []
for line in servers_state.split("\n"):
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
server_name = parts[StateField.SRV_NAME]
addr = parts[StateField.SRV_ADDR]
# Only change state for configured servers (not 0.0.0.0)
if addr != "0.0.0.0":
try:
haproxy_cmd_checked(f"set server {backend}/{server_name} state {state}")
changed.append(server_name)
except HaproxyError as e:
errors.append(f"{server_name}: {e}")
if not changed and not errors:
return f"No active servers found for {domain}"
result = f"Set {len(changed)} servers to '{state}' for {domain}"
if changed:
result += ":\n" + "\n".join(f"{s}" for s in changed)
if errors:
result += f"\n\nErrors ({len(errors)}):\n" + "\n".join(f"{e}" for e in errors)
return result
@mcp.tool() @mcp.tool()
def haproxy_wait_drain( def haproxy_wait_drain(
@@ -422,39 +517,7 @@ def register_server_tools(mcp):
Example: haproxy_wait_drain("api.example.com", timeout=60) Example: haproxy_wait_drain("api.example.com", timeout=60)
""" """
if not validate_domain(domain): return _haproxy_wait_drain_impl(domain, timeout)
return "Error: Invalid domain format"
if not (1 <= timeout <= 300):
return "Error: Timeout must be between 1 and 300 seconds"
try:
backend, _ = get_backend_and_prefix(domain)
except ValueError as e:
return f"Error: {e}"
start_time = time.time()
while time.time() - start_time < timeout:
try:
stats = haproxy_cmd("show stat")
total_connections = 0
for line in stats.split("\n"):
parts = line.split(",")
if len(parts) > StatField.SCUR and parts[0] == backend and parts[1] not in ["FRONTEND", "BACKEND", ""]:
try:
scur = int(parts[StatField.SCUR]) if parts[StatField.SCUR] else 0
total_connections += scur
except ValueError:
pass
if total_connections == 0:
elapsed = int(time.time() - start_time)
return f"All connections drained for {domain} ({elapsed}s)"
time.sleep(1)
except HaproxyError as e:
return f"Error checking connections: {e}"
return f"Timeout: Connections still active after {timeout}s"
@mcp.tool() @mcp.tool()
def haproxy_set_server_state( def haproxy_set_server_state(
@@ -466,17 +529,7 @@ def register_server_tools(mcp):
Example: haproxy_set_server_state("pool_1", "pool_1_2", "maint") Example: haproxy_set_server_state("pool_1", "pool_1_2", "maint")
""" """
if not validate_backend_name(backend): return _haproxy_set_server_state_impl(backend, server, state)
return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)"
if not validate_backend_name(server):
return "Error: Invalid server name (use alphanumeric, underscore, hyphen only)"
if state not in ["ready", "drain", "maint"]:
return "Error: state must be 'ready', 'drain', or 'maint'"
try:
haproxy_cmd_checked(f"set server {backend}/{server} state {state}")
return f"Server {backend}/{server} set to {state}"
except HaproxyError as e:
return f"Error: {e}"
@mcp.tool() @mcp.tool()
def haproxy_set_server_weight( def haproxy_set_server_weight(
@@ -488,14 +541,4 @@ def register_server_tools(mcp):
Example: haproxy_set_server_weight("pool_1", "pool_1_1", weight=2) Example: haproxy_set_server_weight("pool_1", "pool_1_1", weight=2)
""" """
if not validate_backend_name(backend): return _haproxy_set_server_weight_impl(backend, server, weight)
return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)"
if not validate_backend_name(server):
return "Error: Invalid server name (use alphanumeric, underscore, hyphen only)"
if not (0 <= weight <= 256):
return "Error: weight must be between 0 and 256"
try:
haproxy_cmd_checked(f"set server {backend}/{server} weight {weight}")
return f"Server {backend}/{server} weight set to {weight}"
except HaproxyError as e:
return f"Error: {e}"

View File

@@ -1,7 +1,58 @@
"""Utility functions for HAProxy MCP Server.""" """Utility functions for HAProxy MCP Server."""
from typing import Dict, Generator from typing import Dict, Generator
from .config import StatField from .config import StatField, StateField, STATE_MIN_COLUMNS
from .haproxy_client import haproxy_cmd_batch
def parse_servers_state(state_output: str) -> dict[str, dict[str, dict[str, str]]]:
"""Parse 'show servers state' output.
Args:
state_output: Raw output from HAProxy 'show servers state' command
Returns:
Nested dict: {backend: {server: {addr: str, port: str, state: str}}}
Example:
state = haproxy_cmd("show servers state")
parsed = parse_servers_state(state)
# parsed["pool_1"]["pool_1_1"] == {"addr": "10.0.0.1", "port": "8080", "state": "2"}
"""
result: dict[str, dict[str, dict[str, str]]] = {}
for line in state_output.split("\n"):
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS:
backend = parts[StateField.BE_NAME]
server = parts[StateField.SRV_NAME]
addr = parts[StateField.SRV_ADDR]
port = parts[StateField.SRV_PORT]
state = parts[StateField.SRV_OP_STATE] if len(parts) > StateField.SRV_OP_STATE else ""
if backend not in result:
result[backend] = {}
result[backend][server] = {
"addr": addr,
"port": port,
"state": state,
}
return result
def disable_server_slot(backend: str, server: str) -> None:
"""Disable a server slot (set to maint and clear address).
Args:
backend: Backend name (e.g., 'pool_1')
server: Server name (e.g., 'pool_1_1')
Raises:
HaproxyError: If HAProxy command fails
"""
haproxy_cmd_batch([
f"set server {backend}/{server} state maint",
f"set server {backend}/{server} addr 0.0.0.0 port 0"
])
def parse_stat_csv(stat_output: str) -> Generator[Dict[str, str], None, None]: def parse_stat_csv(stat_output: str) -> Generator[Dict[str, str], None, None]:

View File

@@ -52,6 +52,18 @@ def validate_port(port: str) -> bool:
return 1 <= port_num <= 65535 return 1 <= port_num <= 65535
def validate_port_int(port: int) -> bool:
"""Validate port number as integer is in valid range.
Args:
port: Port number as integer
Returns:
True if port is valid (1-65535), False otherwise
"""
return isinstance(port, int) and 1 <= port <= 65535
def validate_backend_name(name: str) -> bool: def validate_backend_name(name: str) -> bool:
"""Validate backend or server name to prevent command injection. """Validate backend or server name to prevent command injection.

View File

@@ -1078,9 +1078,10 @@ class TestLoadCertToHaproxyError:
pem_file = certs_dir / "example.com.pem" pem_file = certs_dir / "example.com.pem"
pem_file.write_text("cert content") pem_file.write_text("cert content")
# Mock haproxy_cmd to raise exception # Mock haproxy_cmd to raise HaproxyError
from haproxy_mcp.exceptions import HaproxyError
with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)):
with patch("haproxy_mcp.tools.certificates.haproxy_cmd", side_effect=Exception("Connection failed")): with patch("haproxy_mcp.tools.certificates.haproxy_cmd", side_effect=HaproxyError("Connection failed")):
from haproxy_mcp.tools.certificates import load_cert_to_haproxy from haproxy_mcp.tools.certificates import load_cert_to_haproxy
success, msg = load_cert_to_haproxy("example.com") success, msg = load_cert_to_haproxy("example.com")

View File

@@ -547,7 +547,7 @@ class TestStartupRestoreFailures:
with patch("socket.socket", return_value=mock_sock): with patch("socket.socket", return_value=mock_sock):
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", return_value=0): with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", return_value=0):
with patch("haproxy_mcp.tools.certificates.restore_certificates", side_effect=Exception("Certificate error")): with patch("haproxy_mcp.tools.certificates.restore_certificates", side_effect=IOError("Certificate error")):
with caplog.at_level(logging.WARNING, logger="haproxy_mcp"): with caplog.at_level(logging.WARNING, logger="haproxy_mcp"):
from haproxy_mcp.tools.configuration import startup_restore from haproxy_mcp.tools.configuration import startup_restore
@@ -598,7 +598,7 @@ class TestHaproxyReloadFailures:
with patch("socket.socket", return_value=mock_sock): with patch("socket.socket", return_value=mock_sock):
with patch("haproxy_mcp.haproxy_client.reload_haproxy", return_value=(True, "Reloaded")): with patch("haproxy_mcp.haproxy_client.reload_haproxy", return_value=(True, "Reloaded")):
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=Exception("Restore failed")): with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=OSError("Restore failed")):
with patch("time.sleep", return_value=None): with patch("time.sleep", return_value=None):
from haproxy_mcp.tools.configuration import register_config_tools from haproxy_mcp.tools.configuration import register_config_tools
mcp = MagicMock() mcp = MagicMock()

View File

@@ -756,8 +756,8 @@ class TestHaproxyAddServersRollback:
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port): def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
if slot == 2: if slot == 2:
# Simulate unexpected error (not HaproxyError) # Simulate unexpected error (IOError is caught by the exception handler)
raise RuntimeError("Unexpected system error") raise IOError("Unexpected system error")
configured_slots.append(slot) configured_slots.append(slot)
return f"{server_prefix}_{slot}" return f"{server_prefix}_{slot}"
@@ -807,7 +807,7 @@ class TestHaproxyAddServersRollback:
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port): def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
if slot == 2: if slot == 2:
raise RuntimeError("Unexpected error") raise OSError("Unexpected error")
return f"{server_prefix}_{slot}" return f"{server_prefix}_{slot}"
def mock_remove_server_from_config(domain, slot): def mock_remove_server_from_config(domain, slot):