refactor: Extract large functions, improve exception handling, remove duplicates
## Large function extraction - servers.py: Extract 8 _impl functions from register_server_tools (449 lines) - certificates.py: Extract 7 _impl functions from register_certificate_tools (386 lines) - MCP tool wrappers now delegate to module-level implementation functions ## Exception handling improvements - Replace 11 broad `except Exception` with specific types - health.py: (OSError, subprocess.SubprocessError) - configuration.py: (HaproxyError, IOError, OSError, ValueError) - servers.py: (IOError, OSError, ValueError) - certificates.py: FileNotFoundError, (subprocess.SubprocessError, OSError) ## Duplicate code extraction - Add parse_servers_state() to utils.py (replaces 4 duplicate parsers) - Add disable_server_slot() to utils.py (replaces duplicate patterns) - Update health.py, servers.py, domains.py to use new helpers ## Other improvements - Add TypedDict types in file_ops.py and health.py - Set file permissions (0o600) for sensitive files - Update tests to use specific exception types Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -94,8 +94,8 @@ def _read_map_file(file_path: str) -> list[tuple[str, str]]:
|
|||||||
with open(file_path, "r", encoding="utf-8") as f:
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
try:
|
try:
|
||||||
fcntl.flock(f.fileno(), fcntl.LOCK_SH)
|
fcntl.flock(f.fileno(), fcntl.LOCK_SH)
|
||||||
except OSError:
|
except OSError as e:
|
||||||
pass # Continue without lock if not supported
|
logger.debug("File locking not supported for %s: %s", file_path, e)
|
||||||
try:
|
try:
|
||||||
for line in f:
|
for line in f:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
@@ -107,10 +107,10 @@ def _read_map_file(file_path: str) -> list[tuple[str, str]]:
|
|||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
|
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
|
||||||
except OSError:
|
except OSError as e:
|
||||||
pass
|
logger.debug("File unlock failed for %s: %s", file_path, e)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
logger.debug("Map file not found: %s", file_path)
|
||||||
return entries
|
return entries
|
||||||
|
|
||||||
|
|
||||||
@@ -420,16 +420,16 @@ def load_certs_config() -> list[str]:
|
|||||||
with open(CERTS_FILE, "r", encoding="utf-8") as f:
|
with open(CERTS_FILE, "r", encoding="utf-8") as f:
|
||||||
try:
|
try:
|
||||||
fcntl.flock(f.fileno(), fcntl.LOCK_SH)
|
fcntl.flock(f.fileno(), fcntl.LOCK_SH)
|
||||||
except OSError:
|
except OSError as e:
|
||||||
pass
|
logger.debug("File locking not supported for %s: %s", CERTS_FILE, e)
|
||||||
try:
|
try:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
return data.get("domains", [])
|
return data.get("domains", [])
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
|
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
|
||||||
except OSError:
|
except OSError as e:
|
||||||
pass
|
logger.debug("File unlock failed for %s: %s", CERTS_FILE, e)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
return []
|
return []
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
|
|||||||
@@ -67,8 +67,8 @@ def haproxy_cmd(command: str) -> str:
|
|||||||
raise HaproxyError("Invalid UTF-8 in response")
|
raise HaproxyError("Invalid UTF-8 in response")
|
||||||
except HaproxyError:
|
except HaproxyError:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except (OSError, BlockingIOError, BrokenPipeError) as e:
|
||||||
raise HaproxyError(str(e)) from e
|
raise HaproxyError(f"Socket error: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
def haproxy_cmd_checked(command: str) -> str:
|
def haproxy_cmd_checked(command: str) -> str:
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from ..config import (
|
|||||||
CERTS_DIR_CONTAINER,
|
CERTS_DIR_CONTAINER,
|
||||||
ACME_HOME,
|
ACME_HOME,
|
||||||
)
|
)
|
||||||
|
from ..exceptions import HaproxyError
|
||||||
from ..validation import validate_domain
|
from ..validation import validate_domain
|
||||||
from ..haproxy_client import haproxy_cmd
|
from ..haproxy_client import haproxy_cmd
|
||||||
from ..file_ops import (
|
from ..file_ops import (
|
||||||
@@ -77,7 +78,11 @@ def load_cert_to_haproxy(domain: str) -> tuple[bool, str]:
|
|||||||
haproxy_cmd(f"commit ssl cert {container_path}")
|
haproxy_cmd(f"commit ssl cert {container_path}")
|
||||||
return True, "added"
|
return True, "added"
|
||||||
|
|
||||||
except Exception as e:
|
except HaproxyError as e:
|
||||||
|
logger.error("HAProxy error loading certificate %s: %s", domain, e)
|
||||||
|
return False, str(e)
|
||||||
|
except (IOError, OSError) as e:
|
||||||
|
logger.error("File error loading certificate %s: %s", domain, e)
|
||||||
return False, str(e)
|
return False, str(e)
|
||||||
|
|
||||||
|
|
||||||
@@ -102,7 +107,8 @@ def unload_cert_from_haproxy(domain: str) -> tuple[bool, str]:
|
|||||||
haproxy_cmd(f"del ssl cert {container_path}")
|
haproxy_cmd(f"del ssl cert {container_path}")
|
||||||
return True, "unloaded"
|
return True, "unloaded"
|
||||||
|
|
||||||
except Exception as e:
|
except HaproxyError as e:
|
||||||
|
logger.error("HAProxy error unloading certificate %s: %s", domain, e)
|
||||||
return False, str(e)
|
return False, str(e)
|
||||||
|
|
||||||
|
|
||||||
@@ -126,6 +132,364 @@ def restore_certificates() -> int:
|
|||||||
return restored
|
return restored
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Implementation functions (module-level)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_list_certs_impl() -> str:
|
||||||
|
"""Implementation of haproxy_list_certs."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[ACME_SH, "--list"],
|
||||||
|
capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT,
|
||||||
|
env={**os.environ, "HOME": os.path.expanduser("~")}
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return f"Error: {result.stderr}"
|
||||||
|
|
||||||
|
lines = result.stdout.strip().split("\n")
|
||||||
|
if len(lines) <= 1:
|
||||||
|
return "No certificates found"
|
||||||
|
|
||||||
|
# Get HAProxy loaded certs
|
||||||
|
try:
|
||||||
|
haproxy_certs = haproxy_cmd("show ssl cert")
|
||||||
|
except HaproxyError as e:
|
||||||
|
logger.debug("Could not get HAProxy certs: %s", e)
|
||||||
|
haproxy_certs = ""
|
||||||
|
|
||||||
|
# Parse and format output
|
||||||
|
certs = []
|
||||||
|
for line in lines[1:]: # Skip header
|
||||||
|
parts = line.split()
|
||||||
|
if len(parts) >= 4:
|
||||||
|
domain = parts[0]
|
||||||
|
ca = "unknown"
|
||||||
|
created = "unknown"
|
||||||
|
renew = "unknown"
|
||||||
|
|
||||||
|
for part in parts:
|
||||||
|
if "Google" in part or "LetsEncrypt" in part or "ZeroSSL" in part:
|
||||||
|
ca = part
|
||||||
|
elif part.endswith("Z") and "T" in part:
|
||||||
|
if created == "unknown":
|
||||||
|
created = part
|
||||||
|
else:
|
||||||
|
renew = part
|
||||||
|
|
||||||
|
# Check deployment status
|
||||||
|
host_path, container_path = get_pem_paths(domain)
|
||||||
|
if container_path in haproxy_certs:
|
||||||
|
status = "loaded"
|
||||||
|
elif os.path.exists(host_path):
|
||||||
|
status = "file exists (not loaded)"
|
||||||
|
else:
|
||||||
|
status = "not deployed"
|
||||||
|
|
||||||
|
certs.append(f"• {domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}")
|
||||||
|
|
||||||
|
return "\n\n".join(certs) if certs else "No certificates found"
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
return "Error: Command timed out"
|
||||||
|
except FileNotFoundError:
|
||||||
|
return "Error: acme.sh not found"
|
||||||
|
except subprocess.SubprocessError as e:
|
||||||
|
logger.error("Subprocess error listing certificates: %s", e)
|
||||||
|
return f"Error: {e}"
|
||||||
|
except OSError as e:
|
||||||
|
logger.error("OS error listing certificates: %s", e)
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_cert_info_impl(domain: str) -> str:
|
||||||
|
"""Implementation of haproxy_cert_info."""
|
||||||
|
if not validate_domain(domain):
|
||||||
|
return "Error: Invalid domain format"
|
||||||
|
|
||||||
|
host_path, container_path = get_pem_paths(domain)
|
||||||
|
if not os.path.exists(host_path):
|
||||||
|
return f"Error: Certificate not found for {domain}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use openssl to get certificate info
|
||||||
|
result = subprocess.run(
|
||||||
|
["openssl", "x509", "-in", host_path, "-noout",
|
||||||
|
"-subject", "-issuer", "-dates", "-ext", "subjectAltName"],
|
||||||
|
capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return f"Error reading certificate: {result.stderr}"
|
||||||
|
|
||||||
|
# Get file info
|
||||||
|
stat = os.stat(host_path)
|
||||||
|
modified = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
# Check HAProxy status
|
||||||
|
try:
|
||||||
|
haproxy_certs = haproxy_cmd("show ssl cert")
|
||||||
|
loaded = "Yes" if container_path in haproxy_certs else "No"
|
||||||
|
except HaproxyError as e:
|
||||||
|
logger.debug("Could not check HAProxy cert status: %s", e)
|
||||||
|
loaded = "Unknown"
|
||||||
|
|
||||||
|
info = [
|
||||||
|
f"Certificate: {domain}",
|
||||||
|
f"File: {host_path}",
|
||||||
|
f"Modified: {modified}",
|
||||||
|
f"Loaded in HAProxy: {loaded}",
|
||||||
|
"---",
|
||||||
|
result.stdout.strip()
|
||||||
|
]
|
||||||
|
return "\n".join(info)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
return "Error: Command timed out"
|
||||||
|
except (subprocess.SubprocessError, OSError) as e:
|
||||||
|
logger.error("Error getting certificate info for %s: %s", domain, e)
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_issue_cert_impl(domain: str, wildcard: bool) -> str:
|
||||||
|
"""Implementation of haproxy_issue_cert."""
|
||||||
|
if not validate_domain(domain):
|
||||||
|
return "Error: Invalid domain format"
|
||||||
|
|
||||||
|
# Check if CF_Token is available
|
||||||
|
if not os.environ.get("CF_Token"):
|
||||||
|
secrets_file = os.path.expanduser("~/.secrets/cloudflare.ini")
|
||||||
|
if os.path.exists(secrets_file):
|
||||||
|
try:
|
||||||
|
with open(secrets_file) as f:
|
||||||
|
for line in f:
|
||||||
|
if "=" in line and "token" in line.lower():
|
||||||
|
token = line.split("=", 1)[1].strip().strip('"').strip("'")
|
||||||
|
os.environ["CF_Token"] = token
|
||||||
|
break
|
||||||
|
except (IOError, OSError) as e:
|
||||||
|
logger.warning("Failed to read Cloudflare token: %s", e)
|
||||||
|
|
||||||
|
if not os.environ.get("CF_Token"):
|
||||||
|
return "Error: CF_Token not set. Export CF_Token or add to ~/.secrets/cloudflare.ini"
|
||||||
|
|
||||||
|
# Check if certificate already exists
|
||||||
|
cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc")
|
||||||
|
if os.path.exists(cert_dir):
|
||||||
|
return f"Error: Certificate for {domain} already exists. Use haproxy_renew_cert to renew."
|
||||||
|
|
||||||
|
# Build acme.sh command (without reload - we'll do it via Runtime API)
|
||||||
|
host_path, _ = get_pem_paths(domain)
|
||||||
|
|
||||||
|
# Create PEM after issuance
|
||||||
|
install_cmd = f"cat {ACME_HOME}/{domain}_ecc/fullchain.cer {ACME_HOME}/{domain}_ecc/{domain}.key > {host_path}"
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
ACME_SH, "--issue",
|
||||||
|
"--dns", "dns_cf",
|
||||||
|
"-d", domain
|
||||||
|
]
|
||||||
|
|
||||||
|
if wildcard:
|
||||||
|
cmd.extend(["-d", f"*.{domain}"])
|
||||||
|
|
||||||
|
cmd.extend(["--reloadcmd", install_cmd])
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info("Issuing certificate for %s", domain)
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True, text=True, timeout=CERT_TIMEOUT,
|
||||||
|
env={**os.environ, "HOME": os.path.expanduser("~")}
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
error_msg = result.stderr or result.stdout
|
||||||
|
return f"Error issuing certificate:\n{error_msg}"
|
||||||
|
|
||||||
|
# Load into HAProxy via Runtime API (zero-downtime)
|
||||||
|
if os.path.exists(host_path):
|
||||||
|
success, msg = load_cert_to_haproxy(domain)
|
||||||
|
if success:
|
||||||
|
# Save to config for persistence
|
||||||
|
add_cert_to_config(domain)
|
||||||
|
return f"Certificate issued and loaded for {domain} ({msg})"
|
||||||
|
else:
|
||||||
|
return f"Certificate issued but HAProxy loading failed: {msg}"
|
||||||
|
else:
|
||||||
|
return f"Certificate issued but PEM file not created. Check {host_path}"
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s"
|
||||||
|
except (subprocess.SubprocessError, OSError) as e:
|
||||||
|
logger.error("Error issuing certificate for %s: %s", domain, e)
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_renew_cert_impl(domain: str, force: bool) -> str:
|
||||||
|
"""Implementation of haproxy_renew_cert."""
|
||||||
|
if not validate_domain(domain):
|
||||||
|
return "Error: Invalid domain format"
|
||||||
|
|
||||||
|
cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc")
|
||||||
|
if not os.path.exists(cert_dir):
|
||||||
|
return f"Error: No certificate found for {domain}. Use haproxy_issue_cert first."
|
||||||
|
|
||||||
|
cmd = [ACME_SH, "--renew", "-d", domain]
|
||||||
|
if force:
|
||||||
|
cmd.append("--force")
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info("Renewing certificate for %s", domain)
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True, text=True, timeout=CERT_TIMEOUT,
|
||||||
|
env={**os.environ, "HOME": os.path.expanduser("~")}
|
||||||
|
)
|
||||||
|
|
||||||
|
output = result.stdout + result.stderr
|
||||||
|
|
||||||
|
if "Skip" in output and "Not yet due" in output:
|
||||||
|
return f"Certificate for {domain} not due for renewal. Use force=True to force renewal."
|
||||||
|
|
||||||
|
if "Cert success" in output or result.returncode == 0:
|
||||||
|
# Reload into HAProxy via Runtime API
|
||||||
|
success, msg = load_cert_to_haproxy(domain)
|
||||||
|
if success:
|
||||||
|
# Ensure in config
|
||||||
|
add_cert_to_config(domain)
|
||||||
|
return f"Certificate renewed and reloaded for {domain} ({msg})"
|
||||||
|
else:
|
||||||
|
return f"Certificate renewed but HAProxy reload failed: {msg}"
|
||||||
|
else:
|
||||||
|
return f"Error renewing certificate:\n{output}"
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s"
|
||||||
|
except FileNotFoundError:
|
||||||
|
return "Error: acme.sh not found"
|
||||||
|
except (subprocess.SubprocessError, OSError) as e:
|
||||||
|
logger.error("Error renewing certificate for %s: %s", domain, e)
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_renew_all_certs_impl() -> str:
|
||||||
|
"""Implementation of haproxy_renew_all_certs."""
|
||||||
|
try:
|
||||||
|
logger.info("Running certificate renewal cron")
|
||||||
|
result = subprocess.run(
|
||||||
|
[ACME_SH, "--cron"],
|
||||||
|
capture_output=True, text=True, timeout=CERT_TIMEOUT * 3,
|
||||||
|
env={**os.environ, "HOME": os.path.expanduser("~")}
|
||||||
|
)
|
||||||
|
|
||||||
|
output = result.stdout + result.stderr
|
||||||
|
|
||||||
|
# Count renewals
|
||||||
|
renewed = output.count("Cert success")
|
||||||
|
skipped = output.count("Skip")
|
||||||
|
|
||||||
|
# Reload any renewed certs into HAProxy
|
||||||
|
if renewed > 0:
|
||||||
|
domains = load_certs_config()
|
||||||
|
reloaded = 0
|
||||||
|
for domain in domains:
|
||||||
|
success, _ = load_cert_to_haproxy(domain)
|
||||||
|
if success:
|
||||||
|
reloaded += 1
|
||||||
|
return f"Renewed {renewed} certificate(s), reloaded {reloaded} into HAProxy"
|
||||||
|
elif skipped > 0:
|
||||||
|
return f"No certificates due for renewal ({skipped} checked)"
|
||||||
|
elif result.returncode != 0:
|
||||||
|
return f"Error running renewal:\n{output}"
|
||||||
|
else:
|
||||||
|
return "Renewal check completed"
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
return "Error: Renewal cron timed out"
|
||||||
|
except FileNotFoundError:
|
||||||
|
return "Error: acme.sh not found"
|
||||||
|
except (subprocess.SubprocessError, OSError) as e:
|
||||||
|
logger.error("Error running certificate renewal cron: %s", e)
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_delete_cert_impl(domain: str) -> str:
|
||||||
|
"""Implementation of haproxy_delete_cert."""
|
||||||
|
if not validate_domain(domain):
|
||||||
|
return "Error: Invalid domain format"
|
||||||
|
|
||||||
|
cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc")
|
||||||
|
host_path, _ = get_pem_paths(domain)
|
||||||
|
|
||||||
|
if not os.path.exists(cert_dir) and not os.path.exists(host_path):
|
||||||
|
return f"Error: No certificate found for {domain}"
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
deleted = []
|
||||||
|
|
||||||
|
# Unload from HAProxy first (zero-downtime)
|
||||||
|
success, msg = unload_cert_from_haproxy(domain)
|
||||||
|
if success:
|
||||||
|
deleted.append(f"HAProxy ({msg})")
|
||||||
|
else:
|
||||||
|
errors.append(f"HAProxy unload: {msg}")
|
||||||
|
|
||||||
|
# Remove from acme.sh
|
||||||
|
if os.path.exists(cert_dir):
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[ACME_SH, "--remove", "-d", domain],
|
||||||
|
capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT,
|
||||||
|
env={**os.environ, "HOME": os.path.expanduser("~")}
|
||||||
|
)
|
||||||
|
if result.returncode == 0:
|
||||||
|
deleted.append("acme.sh")
|
||||||
|
else:
|
||||||
|
errors.append(f"acme.sh: {result.stderr}")
|
||||||
|
except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError) as e:
|
||||||
|
errors.append(f"acme.sh: {e}")
|
||||||
|
|
||||||
|
# Remove PEM file
|
||||||
|
if os.path.exists(host_path):
|
||||||
|
try:
|
||||||
|
os.remove(host_path)
|
||||||
|
deleted.append("PEM file")
|
||||||
|
except OSError as e:
|
||||||
|
errors.append(f"PEM file: {e}")
|
||||||
|
|
||||||
|
# Remove from config
|
||||||
|
remove_cert_from_config(domain)
|
||||||
|
|
||||||
|
result_parts = []
|
||||||
|
if deleted:
|
||||||
|
result_parts.append(f"Deleted: {', '.join(deleted)}")
|
||||||
|
if errors:
|
||||||
|
result_parts.append(f"Errors: {'; '.join(errors)}")
|
||||||
|
|
||||||
|
return "\n".join(result_parts) if result_parts else f"Certificate {domain} deleted"
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_load_cert_impl(domain: str) -> str:
|
||||||
|
"""Implementation of haproxy_load_cert."""
|
||||||
|
if not validate_domain(domain):
|
||||||
|
return "Error: Invalid domain format"
|
||||||
|
|
||||||
|
host_path, _ = get_pem_paths(domain)
|
||||||
|
if not os.path.exists(host_path):
|
||||||
|
return f"Error: PEM file not found: {host_path}"
|
||||||
|
|
||||||
|
success, msg = load_cert_to_haproxy(domain)
|
||||||
|
if success:
|
||||||
|
add_cert_to_config(domain)
|
||||||
|
return f"Certificate {domain} loaded into HAProxy ({msg})"
|
||||||
|
else:
|
||||||
|
return f"Error loading certificate: {msg}"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# MCP Tool Registration
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def register_certificate_tools(mcp):
|
def register_certificate_tools(mcp):
|
||||||
"""Register certificate management tools with MCP server."""
|
"""Register certificate management tools with MCP server."""
|
||||||
|
|
||||||
@@ -136,62 +500,7 @@ def register_certificate_tools(mcp):
|
|||||||
Returns:
|
Returns:
|
||||||
List of certificates with domain, CA, created date, and renewal date
|
List of certificates with domain, CA, created date, and renewal date
|
||||||
"""
|
"""
|
||||||
try:
|
return _haproxy_list_certs_impl()
|
||||||
result = subprocess.run(
|
|
||||||
[ACME_SH, "--list"],
|
|
||||||
capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT,
|
|
||||||
env={**os.environ, "HOME": os.path.expanduser("~")}
|
|
||||||
)
|
|
||||||
if result.returncode != 0:
|
|
||||||
return f"Error: {result.stderr}"
|
|
||||||
|
|
||||||
lines = result.stdout.strip().split("\n")
|
|
||||||
if len(lines) <= 1:
|
|
||||||
return "No certificates found"
|
|
||||||
|
|
||||||
# Get HAProxy loaded certs
|
|
||||||
try:
|
|
||||||
haproxy_certs = haproxy_cmd("show ssl cert")
|
|
||||||
except Exception:
|
|
||||||
haproxy_certs = ""
|
|
||||||
|
|
||||||
# Parse and format output
|
|
||||||
certs = []
|
|
||||||
for line in lines[1:]: # Skip header
|
|
||||||
parts = line.split()
|
|
||||||
if len(parts) >= 4:
|
|
||||||
domain = parts[0]
|
|
||||||
ca = "unknown"
|
|
||||||
created = "unknown"
|
|
||||||
renew = "unknown"
|
|
||||||
|
|
||||||
for part in parts:
|
|
||||||
if "Google" in part or "LetsEncrypt" in part or "ZeroSSL" in part:
|
|
||||||
ca = part
|
|
||||||
elif part.endswith("Z") and "T" in part:
|
|
||||||
if created == "unknown":
|
|
||||||
created = part
|
|
||||||
else:
|
|
||||||
renew = part
|
|
||||||
|
|
||||||
# Check deployment status
|
|
||||||
host_path, container_path = get_pem_paths(domain)
|
|
||||||
if container_path in haproxy_certs:
|
|
||||||
status = "loaded"
|
|
||||||
elif os.path.exists(host_path):
|
|
||||||
status = "file exists (not loaded)"
|
|
||||||
else:
|
|
||||||
status = "not deployed"
|
|
||||||
|
|
||||||
certs.append(f"• {domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}")
|
|
||||||
|
|
||||||
return "\n\n".join(certs) if certs else "No certificates found"
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
return "Error: Command timed out"
|
|
||||||
except FileNotFoundError:
|
|
||||||
return "Error: acme.sh not found"
|
|
||||||
except Exception as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def haproxy_cert_info(
|
def haproxy_cert_info(
|
||||||
@@ -201,47 +510,7 @@ def register_certificate_tools(mcp):
|
|||||||
|
|
||||||
Shows expiry date, issuer, SANs, and file paths.
|
Shows expiry date, issuer, SANs, and file paths.
|
||||||
"""
|
"""
|
||||||
if not validate_domain(domain):
|
return _haproxy_cert_info_impl(domain)
|
||||||
return "Error: Invalid domain format"
|
|
||||||
|
|
||||||
host_path, container_path = get_pem_paths(domain)
|
|
||||||
if not os.path.exists(host_path):
|
|
||||||
return f"Error: Certificate not found for {domain}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Use openssl to get certificate info
|
|
||||||
result = subprocess.run(
|
|
||||||
["openssl", "x509", "-in", host_path, "-noout",
|
|
||||||
"-subject", "-issuer", "-dates", "-ext", "subjectAltName"],
|
|
||||||
capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT
|
|
||||||
)
|
|
||||||
if result.returncode != 0:
|
|
||||||
return f"Error reading certificate: {result.stderr}"
|
|
||||||
|
|
||||||
# Get file info
|
|
||||||
stat = os.stat(host_path)
|
|
||||||
modified = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
|
|
||||||
# Check HAProxy status
|
|
||||||
try:
|
|
||||||
haproxy_certs = haproxy_cmd("show ssl cert")
|
|
||||||
loaded = "Yes" if container_path in haproxy_certs else "No"
|
|
||||||
except Exception:
|
|
||||||
loaded = "Unknown"
|
|
||||||
|
|
||||||
info = [
|
|
||||||
f"Certificate: {domain}",
|
|
||||||
f"File: {host_path}",
|
|
||||||
f"Modified: {modified}",
|
|
||||||
f"Loaded in HAProxy: {loaded}",
|
|
||||||
"---",
|
|
||||||
result.stdout.strip()
|
|
||||||
]
|
|
||||||
return "\n".join(info)
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
return "Error: Command timed out"
|
|
||||||
except Exception as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def haproxy_issue_cert(
|
def haproxy_issue_cert(
|
||||||
@@ -254,76 +523,7 @@ def register_certificate_tools(mcp):
|
|||||||
|
|
||||||
Example: haproxy_issue_cert("example.com", wildcard=True)
|
Example: haproxy_issue_cert("example.com", wildcard=True)
|
||||||
"""
|
"""
|
||||||
if not validate_domain(domain):
|
return _haproxy_issue_cert_impl(domain, wildcard)
|
||||||
return "Error: Invalid domain format"
|
|
||||||
|
|
||||||
# Check if CF_Token is available
|
|
||||||
if not os.environ.get("CF_Token"):
|
|
||||||
secrets_file = os.path.expanduser("~/.secrets/cloudflare.ini")
|
|
||||||
if os.path.exists(secrets_file):
|
|
||||||
try:
|
|
||||||
with open(secrets_file) as f:
|
|
||||||
for line in f:
|
|
||||||
if "=" in line and "token" in line.lower():
|
|
||||||
token = line.split("=", 1)[1].strip().strip('"').strip("'")
|
|
||||||
os.environ["CF_Token"] = token
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Failed to read Cloudflare token: %s", e)
|
|
||||||
|
|
||||||
if not os.environ.get("CF_Token"):
|
|
||||||
return "Error: CF_Token not set. Export CF_Token or add to ~/.secrets/cloudflare.ini"
|
|
||||||
|
|
||||||
# Check if certificate already exists
|
|
||||||
cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc")
|
|
||||||
if os.path.exists(cert_dir):
|
|
||||||
return f"Error: Certificate for {domain} already exists. Use haproxy_renew_cert to renew."
|
|
||||||
|
|
||||||
# Build acme.sh command (without reload - we'll do it via Runtime API)
|
|
||||||
host_path, _ = get_pem_paths(domain)
|
|
||||||
|
|
||||||
# Create PEM after issuance
|
|
||||||
install_cmd = f"cat {ACME_HOME}/{domain}_ecc/fullchain.cer {ACME_HOME}/{domain}_ecc/{domain}.key > {host_path}"
|
|
||||||
|
|
||||||
cmd = [
|
|
||||||
ACME_SH, "--issue",
|
|
||||||
"--dns", "dns_cf",
|
|
||||||
"-d", domain
|
|
||||||
]
|
|
||||||
|
|
||||||
if wildcard:
|
|
||||||
cmd.extend(["-d", f"*.{domain}"])
|
|
||||||
|
|
||||||
cmd.extend(["--reloadcmd", install_cmd])
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info("Issuing certificate for %s", domain)
|
|
||||||
result = subprocess.run(
|
|
||||||
cmd,
|
|
||||||
capture_output=True, text=True, timeout=CERT_TIMEOUT,
|
|
||||||
env={**os.environ, "HOME": os.path.expanduser("~")}
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.returncode != 0:
|
|
||||||
error_msg = result.stderr or result.stdout
|
|
||||||
return f"Error issuing certificate:\n{error_msg}"
|
|
||||||
|
|
||||||
# Load into HAProxy via Runtime API (zero-downtime)
|
|
||||||
if os.path.exists(host_path):
|
|
||||||
success, msg = load_cert_to_haproxy(domain)
|
|
||||||
if success:
|
|
||||||
# Save to config for persistence
|
|
||||||
add_cert_to_config(domain)
|
|
||||||
return f"Certificate issued and loaded for {domain} ({msg})"
|
|
||||||
else:
|
|
||||||
return f"Certificate issued but HAProxy loading failed: {msg}"
|
|
||||||
else:
|
|
||||||
return f"Certificate issued but PEM file not created. Check {host_path}"
|
|
||||||
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s"
|
|
||||||
except Exception as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def haproxy_renew_cert(
|
def haproxy_renew_cert(
|
||||||
@@ -336,46 +536,7 @@ def register_certificate_tools(mcp):
|
|||||||
|
|
||||||
Example: haproxy_renew_cert("example.com", force=True)
|
Example: haproxy_renew_cert("example.com", force=True)
|
||||||
"""
|
"""
|
||||||
if not validate_domain(domain):
|
return _haproxy_renew_cert_impl(domain, force)
|
||||||
return "Error: Invalid domain format"
|
|
||||||
|
|
||||||
cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc")
|
|
||||||
if not os.path.exists(cert_dir):
|
|
||||||
return f"Error: No certificate found for {domain}. Use haproxy_issue_cert first."
|
|
||||||
|
|
||||||
cmd = [ACME_SH, "--renew", "-d", domain]
|
|
||||||
if force:
|
|
||||||
cmd.append("--force")
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info("Renewing certificate for %s", domain)
|
|
||||||
result = subprocess.run(
|
|
||||||
cmd,
|
|
||||||
capture_output=True, text=True, timeout=CERT_TIMEOUT,
|
|
||||||
env={**os.environ, "HOME": os.path.expanduser("~")}
|
|
||||||
)
|
|
||||||
|
|
||||||
output = result.stdout + result.stderr
|
|
||||||
|
|
||||||
if "Skip" in output and "Not yet due" in output:
|
|
||||||
return f"Certificate for {domain} not due for renewal. Use force=True to force renewal."
|
|
||||||
|
|
||||||
if "Cert success" in output or result.returncode == 0:
|
|
||||||
# Reload into HAProxy via Runtime API
|
|
||||||
success, msg = load_cert_to_haproxy(domain)
|
|
||||||
if success:
|
|
||||||
# Ensure in config
|
|
||||||
add_cert_to_config(domain)
|
|
||||||
return f"Certificate renewed and reloaded for {domain} ({msg})"
|
|
||||||
else:
|
|
||||||
return f"Certificate renewed but HAProxy reload failed: {msg}"
|
|
||||||
else:
|
|
||||||
return f"Error renewing certificate:\n{output}"
|
|
||||||
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s"
|
|
||||||
except Exception as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def haproxy_renew_all_certs() -> str:
|
def haproxy_renew_all_certs() -> str:
|
||||||
@@ -383,40 +544,7 @@ def register_certificate_tools(mcp):
|
|||||||
|
|
||||||
This runs the acme.sh cron job to check and renew all certificates.
|
This runs the acme.sh cron job to check and renew all certificates.
|
||||||
"""
|
"""
|
||||||
try:
|
return _haproxy_renew_all_certs_impl()
|
||||||
logger.info("Running certificate renewal cron")
|
|
||||||
result = subprocess.run(
|
|
||||||
[ACME_SH, "--cron"],
|
|
||||||
capture_output=True, text=True, timeout=CERT_TIMEOUT * 3,
|
|
||||||
env={**os.environ, "HOME": os.path.expanduser("~")}
|
|
||||||
)
|
|
||||||
|
|
||||||
output = result.stdout + result.stderr
|
|
||||||
|
|
||||||
# Count renewals
|
|
||||||
renewed = output.count("Cert success")
|
|
||||||
skipped = output.count("Skip")
|
|
||||||
|
|
||||||
# Reload any renewed certs into HAProxy
|
|
||||||
if renewed > 0:
|
|
||||||
domains = load_certs_config()
|
|
||||||
reloaded = 0
|
|
||||||
for domain in domains:
|
|
||||||
success, _ = load_cert_to_haproxy(domain)
|
|
||||||
if success:
|
|
||||||
reloaded += 1
|
|
||||||
return f"Renewed {renewed} certificate(s), reloaded {reloaded} into HAProxy"
|
|
||||||
elif skipped > 0:
|
|
||||||
return f"No certificates due for renewal ({skipped} checked)"
|
|
||||||
elif result.returncode != 0:
|
|
||||||
return f"Error running renewal:\n{output}"
|
|
||||||
else:
|
|
||||||
return "Renewal check completed"
|
|
||||||
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
return "Error: Renewal cron timed out"
|
|
||||||
except Exception as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def haproxy_delete_cert(
|
def haproxy_delete_cert(
|
||||||
@@ -428,58 +556,7 @@ def register_certificate_tools(mcp):
|
|||||||
|
|
||||||
Example: haproxy_delete_cert("example.com")
|
Example: haproxy_delete_cert("example.com")
|
||||||
"""
|
"""
|
||||||
if not validate_domain(domain):
|
return _haproxy_delete_cert_impl(domain)
|
||||||
return "Error: Invalid domain format"
|
|
||||||
|
|
||||||
cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc")
|
|
||||||
host_path, _ = get_pem_paths(domain)
|
|
||||||
|
|
||||||
if not os.path.exists(cert_dir) and not os.path.exists(host_path):
|
|
||||||
return f"Error: No certificate found for {domain}"
|
|
||||||
|
|
||||||
errors = []
|
|
||||||
deleted = []
|
|
||||||
|
|
||||||
# Unload from HAProxy first (zero-downtime)
|
|
||||||
success, msg = unload_cert_from_haproxy(domain)
|
|
||||||
if success:
|
|
||||||
deleted.append(f"HAProxy ({msg})")
|
|
||||||
else:
|
|
||||||
errors.append(f"HAProxy unload: {msg}")
|
|
||||||
|
|
||||||
# Remove from acme.sh
|
|
||||||
if os.path.exists(cert_dir):
|
|
||||||
try:
|
|
||||||
result = subprocess.run(
|
|
||||||
[ACME_SH, "--remove", "-d", domain],
|
|
||||||
capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT,
|
|
||||||
env={**os.environ, "HOME": os.path.expanduser("~")}
|
|
||||||
)
|
|
||||||
if result.returncode == 0:
|
|
||||||
deleted.append("acme.sh")
|
|
||||||
else:
|
|
||||||
errors.append(f"acme.sh: {result.stderr}")
|
|
||||||
except Exception as e:
|
|
||||||
errors.append(f"acme.sh: {e}")
|
|
||||||
|
|
||||||
# Remove PEM file
|
|
||||||
if os.path.exists(host_path):
|
|
||||||
try:
|
|
||||||
os.remove(host_path)
|
|
||||||
deleted.append("PEM file")
|
|
||||||
except Exception as e:
|
|
||||||
errors.append(f"PEM file: {e}")
|
|
||||||
|
|
||||||
# Remove from config
|
|
||||||
remove_cert_from_config(domain)
|
|
||||||
|
|
||||||
result_parts = []
|
|
||||||
if deleted:
|
|
||||||
result_parts.append(f"Deleted: {', '.join(deleted)}")
|
|
||||||
if errors:
|
|
||||||
result_parts.append(f"Errors: {'; '.join(errors)}")
|
|
||||||
|
|
||||||
return "\n".join(result_parts) if result_parts else f"Certificate {domain} deleted"
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def haproxy_load_cert(
|
def haproxy_load_cert(
|
||||||
@@ -491,16 +568,4 @@ def register_certificate_tools(mcp):
|
|||||||
|
|
||||||
Example: haproxy_load_cert("example.com")
|
Example: haproxy_load_cert("example.com")
|
||||||
"""
|
"""
|
||||||
if not validate_domain(domain):
|
return _haproxy_load_cert_impl(domain)
|
||||||
return "Error: Invalid domain format"
|
|
||||||
|
|
||||||
host_path, _ = get_pem_paths(domain)
|
|
||||||
if not os.path.exists(host_path):
|
|
||||||
return f"Error: PEM file not found: {host_path}"
|
|
||||||
|
|
||||||
success, msg = load_cert_to_haproxy(domain)
|
|
||||||
if success:
|
|
||||||
add_cert_to_config(domain)
|
|
||||||
return f"Certificate {domain} loaded into HAProxy ({msg})"
|
|
||||||
else:
|
|
||||||
return f"Error loading certificate: {msg}"
|
|
||||||
|
|||||||
@@ -117,7 +117,7 @@ def startup_restore() -> None:
|
|||||||
cert_count = restore_certificates()
|
cert_count = restore_certificates()
|
||||||
if cert_count > 0:
|
if cert_count > 0:
|
||||||
logger.info("Restored %d certificates from config", cert_count)
|
logger.info("Restored %d certificates from config", cert_count)
|
||||||
except Exception as e:
|
except (HaproxyError, IOError, OSError, ValueError) as e:
|
||||||
logger.warning("Failed to restore certificates: %s", e)
|
logger.warning("Failed to restore certificates: %s", e)
|
||||||
|
|
||||||
|
|
||||||
@@ -155,7 +155,7 @@ def register_config_tools(mcp):
|
|||||||
try:
|
try:
|
||||||
restored = restore_servers_from_config()
|
restored = restore_servers_from_config()
|
||||||
return f"HAProxy configuration reloaded successfully ({restored} servers restored)"
|
return f"HAProxy configuration reloaded successfully ({restored} servers restored)"
|
||||||
except Exception as e:
|
except (HaproxyError, IOError, OSError, ValueError) as e:
|
||||||
logger.error("Failed to restore servers after reload: %s", e)
|
logger.error("Failed to restore servers after reload: %s", e)
|
||||||
return f"HAProxy reloaded but server restore failed: {e}"
|
return f"HAProxy reloaded but server restore failed: {e}"
|
||||||
|
|
||||||
|
|||||||
@@ -13,14 +13,12 @@ from ..config import (
|
|||||||
WILDCARDS_MAP_FILE_CONTAINER,
|
WILDCARDS_MAP_FILE_CONTAINER,
|
||||||
POOL_COUNT,
|
POOL_COUNT,
|
||||||
MAX_SLOTS,
|
MAX_SLOTS,
|
||||||
StateField,
|
|
||||||
STATE_MIN_COLUMNS,
|
|
||||||
SUBPROCESS_TIMEOUT,
|
SUBPROCESS_TIMEOUT,
|
||||||
CERTS_DIR,
|
CERTS_DIR,
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
from ..exceptions import HaproxyError
|
from ..exceptions import HaproxyError
|
||||||
from ..validation import validate_domain, validate_ip
|
from ..validation import validate_domain, validate_ip, validate_port_int
|
||||||
from ..haproxy_client import haproxy_cmd
|
from ..haproxy_client import haproxy_cmd
|
||||||
from ..file_ops import (
|
from ..file_ops import (
|
||||||
get_map_contents,
|
get_map_contents,
|
||||||
@@ -31,6 +29,7 @@ from ..file_ops import (
|
|||||||
remove_server_from_config,
|
remove_server_from_config,
|
||||||
remove_domain_from_config,
|
remove_domain_from_config,
|
||||||
)
|
)
|
||||||
|
from ..utils import parse_servers_state, disable_server_slot
|
||||||
|
|
||||||
|
|
||||||
def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]:
|
def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]:
|
||||||
@@ -166,18 +165,18 @@ def register_domain_tools(mcp):
|
|||||||
try:
|
try:
|
||||||
domains = []
|
domains = []
|
||||||
state = haproxy_cmd("show servers state")
|
state = haproxy_cmd("show servers state")
|
||||||
|
parsed_state = parse_servers_state(state)
|
||||||
|
|
||||||
# Build server map from HAProxy state
|
# Build server map from HAProxy state
|
||||||
server_map: dict[str, list] = {}
|
server_map: dict[str, list[str]] = {}
|
||||||
for line in state.split("\n"):
|
for backend, servers_dict in parsed_state.items():
|
||||||
parts = line.split()
|
for server_name, srv_info in servers_dict.items():
|
||||||
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.SRV_ADDR] != "0.0.0.0":
|
if srv_info["addr"] != "0.0.0.0":
|
||||||
backend = parts[StateField.BE_NAME]
|
if backend not in server_map:
|
||||||
if backend not in server_map:
|
server_map[backend] = []
|
||||||
server_map[backend] = []
|
server_map[backend].append(
|
||||||
server_map[backend].append(
|
f"{server_name}={srv_info['addr']}:{srv_info['port']}"
|
||||||
f"{parts[StateField.SRV_NAME]}={parts[StateField.SRV_ADDR]}:{parts[StateField.SRV_PORT]}"
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Read from domains.map
|
# Read from domains.map
|
||||||
seen_domains: set[str] = set()
|
seen_domains: set[str] = set()
|
||||||
@@ -220,7 +219,7 @@ def register_domain_tools(mcp):
|
|||||||
return "Error: Invalid domain format"
|
return "Error: Invalid domain format"
|
||||||
if not validate_ip(ip, allow_empty=True):
|
if not validate_ip(ip, allow_empty=True):
|
||||||
return "Error: Invalid IP address format"
|
return "Error: Invalid IP address format"
|
||||||
if not (1 <= http_port <= 65535):
|
if not validate_port_int(http_port):
|
||||||
return "Error: Port must be between 1 and 65535"
|
return "Error: Port must be between 1 and 65535"
|
||||||
|
|
||||||
# Use file locking for the entire pool allocation operation
|
# Use file locking for the entire pool allocation operation
|
||||||
@@ -339,8 +338,7 @@ def register_domain_tools(mcp):
|
|||||||
for slot in range(1, MAX_SLOTS + 1):
|
for slot in range(1, MAX_SLOTS + 1):
|
||||||
server = f"{backend}_{slot}"
|
server = f"{backend}_{slot}"
|
||||||
try:
|
try:
|
||||||
haproxy_cmd(f"set server {backend}/{server} state maint")
|
disable_server_slot(backend, server)
|
||||||
haproxy_cmd(f"set server {backend}/{server} addr 0.0.0.0 port 0")
|
|
||||||
except HaproxyError as e:
|
except HaproxyError as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to clear server %s/%s for domain %s: %s",
|
"Failed to clear server %s/%s for domain %s: %s",
|
||||||
|
|||||||
@@ -12,14 +12,12 @@ from ..config import (
|
|||||||
MAP_FILE,
|
MAP_FILE,
|
||||||
SERVERS_FILE,
|
SERVERS_FILE,
|
||||||
HAPROXY_CONTAINER,
|
HAPROXY_CONTAINER,
|
||||||
StateField,
|
|
||||||
STATE_MIN_COLUMNS,
|
|
||||||
)
|
)
|
||||||
from ..exceptions import HaproxyError
|
from ..exceptions import HaproxyError
|
||||||
from ..validation import validate_domain, validate_backend_name
|
from ..validation import validate_domain, validate_backend_name
|
||||||
from ..haproxy_client import haproxy_cmd
|
from ..haproxy_client import haproxy_cmd
|
||||||
from ..file_ops import get_backend_and_prefix
|
from ..file_ops import get_backend_and_prefix
|
||||||
from ..utils import parse_stat_csv
|
from ..utils import parse_stat_csv, parse_servers_state
|
||||||
|
|
||||||
|
|
||||||
def register_health_tools(mcp):
|
def register_health_tools(mcp):
|
||||||
@@ -83,7 +81,7 @@ def register_health_tools(mcp):
|
|||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
result["components"]["container"] = {"status": "timeout"}
|
result["components"]["container"] = {"status": "timeout"}
|
||||||
result["status"] = "unhealthy"
|
result["status"] = "unhealthy"
|
||||||
except Exception as e:
|
except (OSError, subprocess.SubprocessError) as e:
|
||||||
result["components"]["container"] = {"status": "error", "error": str(e)}
|
result["components"]["container"] = {"status": "error", "error": str(e)}
|
||||||
|
|
||||||
# Check configuration files
|
# Check configuration files
|
||||||
@@ -144,34 +142,34 @@ def register_health_tools(mcp):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Parse server state for address info
|
# Parse server state for address info
|
||||||
for line in state_output.split("\n"):
|
parsed_state = parse_servers_state(state_output)
|
||||||
parts = line.split()
|
backend_servers = parsed_state.get(backend, {})
|
||||||
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
|
|
||||||
server_name = parts[StateField.SRV_NAME]
|
|
||||||
addr = parts[StateField.SRV_ADDR]
|
|
||||||
port = parts[StateField.SRV_PORT]
|
|
||||||
|
|
||||||
# Skip disabled servers (0.0.0.0)
|
for server_name, srv_info in backend_servers.items():
|
||||||
if addr == "0.0.0.0":
|
addr = srv_info["addr"]
|
||||||
continue
|
port = srv_info["port"]
|
||||||
|
|
||||||
server_info: dict[str, Any] = {
|
# Skip disabled servers (0.0.0.0)
|
||||||
"name": server_name,
|
if addr == "0.0.0.0":
|
||||||
"addr": f"{addr}:{port}",
|
continue
|
||||||
"status": "unknown"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get status from stat output
|
server_info: dict[str, Any] = {
|
||||||
if server_name in status_map:
|
"name": server_name,
|
||||||
server_info["status"] = status_map[server_name]["status"]
|
"addr": f"{addr}:{port}",
|
||||||
server_info["check_status"] = status_map[server_name]["check_status"]
|
"status": "unknown"
|
||||||
server_info["weight"] = status_map[server_name]["weight"]
|
}
|
||||||
|
|
||||||
result["servers"].append(server_info)
|
# Get status from stat output
|
||||||
result["total_count"] += 1
|
if server_name in status_map:
|
||||||
|
server_info["status"] = status_map[server_name]["status"]
|
||||||
|
server_info["check_status"] = status_map[server_name]["check_status"]
|
||||||
|
server_info["weight"] = status_map[server_name]["weight"]
|
||||||
|
|
||||||
if server_info["status"] == "UP":
|
result["servers"].append(server_info)
|
||||||
result["healthy_count"] += 1
|
result["total_count"] += 1
|
||||||
|
|
||||||
|
if server_info["status"] == "UP":
|
||||||
|
result["healthy_count"] += 1
|
||||||
|
|
||||||
# Determine overall status
|
# Determine overall status
|
||||||
if result["total_count"] == 0:
|
if result["total_count"] == 0:
|
||||||
|
|||||||
@@ -10,13 +10,11 @@ from ..config import (
|
|||||||
MAX_SLOTS,
|
MAX_SLOTS,
|
||||||
MAX_BULK_SERVERS,
|
MAX_BULK_SERVERS,
|
||||||
MAX_SERVERS_JSON_SIZE,
|
MAX_SERVERS_JSON_SIZE,
|
||||||
StateField,
|
|
||||||
StatField,
|
StatField,
|
||||||
STATE_MIN_COLUMNS,
|
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
from ..exceptions import HaproxyError
|
from ..exceptions import HaproxyError
|
||||||
from ..validation import validate_domain, validate_ip, validate_backend_name
|
from ..validation import validate_domain, validate_ip, validate_backend_name, validate_port_int
|
||||||
from ..haproxy_client import haproxy_cmd, haproxy_cmd_checked, haproxy_cmd_batch
|
from ..haproxy_client import haproxy_cmd, haproxy_cmd_checked, haproxy_cmd_batch
|
||||||
from ..file_ops import (
|
from ..file_ops import (
|
||||||
get_backend_and_prefix,
|
get_backend_and_prefix,
|
||||||
@@ -24,6 +22,7 @@ from ..file_ops import (
|
|||||||
add_server_to_config,
|
add_server_to_config,
|
||||||
remove_server_from_config,
|
remove_server_from_config,
|
||||||
)
|
)
|
||||||
|
from ..utils import parse_servers_state, disable_server_slot
|
||||||
|
|
||||||
|
|
||||||
def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str:
|
def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str:
|
||||||
@@ -51,6 +50,390 @@ def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str,
|
|||||||
return server
|
return server
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Implementation functions (module-level)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_list_servers_impl(domain: str) -> str:
|
||||||
|
"""Implementation of haproxy_list_servers."""
|
||||||
|
if not validate_domain(domain):
|
||||||
|
return "Error: Invalid domain format"
|
||||||
|
|
||||||
|
try:
|
||||||
|
backend, _ = get_backend_and_prefix(domain)
|
||||||
|
state = haproxy_cmd("show servers state")
|
||||||
|
parsed_state = parse_servers_state(state)
|
||||||
|
backend_servers = parsed_state.get(backend, {})
|
||||||
|
|
||||||
|
if not backend_servers:
|
||||||
|
return f"Backend {backend} not found"
|
||||||
|
|
||||||
|
servers = []
|
||||||
|
for server_name, srv_info in backend_servers.items():
|
||||||
|
addr = srv_info["addr"]
|
||||||
|
port = srv_info["port"]
|
||||||
|
status = "active" if addr != "0.0.0.0" else "disabled"
|
||||||
|
servers.append(f"• {server_name}: {addr}:{port} ({status})")
|
||||||
|
|
||||||
|
return f"Servers for {domain} ({backend}):\n" + "\n".join(servers)
|
||||||
|
except (HaproxyError, ValueError) as e:
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_add_server_impl(domain: str, slot: int, ip: str, http_port: int) -> str:
|
||||||
|
"""Implementation of haproxy_add_server."""
|
||||||
|
if not validate_domain(domain):
|
||||||
|
return "Error: Invalid domain format"
|
||||||
|
if not ip:
|
||||||
|
return "Error: IP address is required"
|
||||||
|
if not validate_ip(ip):
|
||||||
|
return "Error: Invalid IP address format"
|
||||||
|
if not validate_port_int(http_port):
|
||||||
|
return "Error: Port must be between 1 and 65535"
|
||||||
|
|
||||||
|
try:
|
||||||
|
backend, server_prefix = get_backend_and_prefix(domain)
|
||||||
|
|
||||||
|
# Auto-select slot if slot <= 0
|
||||||
|
if slot <= 0:
|
||||||
|
state = haproxy_cmd("show servers state")
|
||||||
|
parsed_state = parse_servers_state(state)
|
||||||
|
backend_servers = parsed_state.get(backend, {})
|
||||||
|
used_slots: set[int] = set()
|
||||||
|
for server_name, srv_info in backend_servers.items():
|
||||||
|
if srv_info["addr"] != "0.0.0.0":
|
||||||
|
# Extract slot number from server name (e.g., pool_1_3 -> 3)
|
||||||
|
try:
|
||||||
|
used_slots.add(int(server_name.rsplit("_", 1)[1]))
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass
|
||||||
|
for s in range(1, MAX_SLOTS + 1):
|
||||||
|
if s not in used_slots:
|
||||||
|
slot = s
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return f"Error: No available slots (all {MAX_SLOTS} slots in use)"
|
||||||
|
elif not (1 <= slot <= MAX_SLOTS):
|
||||||
|
return f"Error: Slot must be between 1 and {MAX_SLOTS}, or 0/-1 for auto-select"
|
||||||
|
|
||||||
|
# Save to persistent config FIRST (disk-first pattern)
|
||||||
|
add_server_to_config(domain, slot, ip, http_port)
|
||||||
|
|
||||||
|
try:
|
||||||
|
server = configure_server_slot(backend, server_prefix, slot, ip, http_port)
|
||||||
|
return f"Added to {domain} ({backend}) slot {slot}:\n{server} → {ip}:{http_port}"
|
||||||
|
except HaproxyError as e:
|
||||||
|
# Rollback config on HAProxy failure
|
||||||
|
remove_server_from_config(domain, slot)
|
||||||
|
return f"Error: {e}"
|
||||||
|
except (ValueError, IOError) as e:
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_add_servers_impl(domain: str, servers: str) -> str:
|
||||||
|
"""Implementation of haproxy_add_servers."""
|
||||||
|
if not validate_domain(domain):
|
||||||
|
return "Error: Invalid domain format"
|
||||||
|
|
||||||
|
# Check JSON size before parsing
|
||||||
|
if len(servers) > MAX_SERVERS_JSON_SIZE:
|
||||||
|
return f"Error: servers JSON exceeds maximum size ({MAX_SERVERS_JSON_SIZE} bytes)"
|
||||||
|
|
||||||
|
# Parse JSON array
|
||||||
|
try:
|
||||||
|
server_list = json.loads(servers)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return f"Error: Invalid JSON - {e}"
|
||||||
|
|
||||||
|
if not isinstance(server_list, list):
|
||||||
|
return "Error: servers must be a JSON array"
|
||||||
|
|
||||||
|
if not server_list:
|
||||||
|
return "Error: servers array is empty"
|
||||||
|
|
||||||
|
if len(server_list) > MAX_BULK_SERVERS:
|
||||||
|
return f"Error: Cannot add more than {MAX_BULK_SERVERS} servers at once"
|
||||||
|
|
||||||
|
# Validate all servers first before adding any
|
||||||
|
validated_servers = []
|
||||||
|
validation_errors = []
|
||||||
|
|
||||||
|
for i, srv in enumerate(server_list):
|
||||||
|
if not isinstance(srv, dict):
|
||||||
|
validation_errors.append(f"Server {i+1}: must be an object")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract and validate slot
|
||||||
|
slot = srv.get("slot")
|
||||||
|
if slot is None:
|
||||||
|
validation_errors.append(f"Server {i+1}: missing 'slot' field")
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
slot = int(slot)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
validation_errors.append(f"Server {i+1}: slot must be an integer")
|
||||||
|
continue
|
||||||
|
if not (1 <= slot <= MAX_SLOTS):
|
||||||
|
validation_errors.append(f"Server {i+1}: slot must be between 1 and {MAX_SLOTS}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract and validate IP
|
||||||
|
ip = srv.get("ip")
|
||||||
|
if not ip:
|
||||||
|
validation_errors.append(f"Server {i+1}: missing 'ip' field")
|
||||||
|
continue
|
||||||
|
if not validate_ip(ip):
|
||||||
|
validation_errors.append(f"Server {i+1}: invalid IP address '{ip}'")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract and validate port
|
||||||
|
http_port = srv.get("http_port", 80)
|
||||||
|
try:
|
||||||
|
http_port = int(http_port)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
validation_errors.append(f"Server {i+1}: http_port must be an integer")
|
||||||
|
continue
|
||||||
|
if not validate_port_int(http_port):
|
||||||
|
validation_errors.append(f"Server {i+1}: port must be between 1 and 65535")
|
||||||
|
continue
|
||||||
|
|
||||||
|
validated_servers.append({"slot": slot, "ip": ip, "http_port": http_port})
|
||||||
|
|
||||||
|
# Return validation errors if any
|
||||||
|
if validation_errors:
|
||||||
|
return "Validation errors:\n" + "\n".join(f" • {e}" for e in validation_errors)
|
||||||
|
|
||||||
|
# Check for duplicate slots
|
||||||
|
slots = [s["slot"] for s in validated_servers]
|
||||||
|
if len(slots) != len(set(slots)):
|
||||||
|
return "Error: Duplicate slot numbers in servers array"
|
||||||
|
|
||||||
|
# Get backend info
|
||||||
|
try:
|
||||||
|
backend, server_prefix = get_backend_and_prefix(domain)
|
||||||
|
except ValueError as e:
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
# Save ALL servers to config FIRST (disk-first pattern)
|
||||||
|
for server_config in validated_servers:
|
||||||
|
slot = server_config["slot"]
|
||||||
|
ip = server_config["ip"]
|
||||||
|
http_port = server_config["http_port"]
|
||||||
|
add_server_to_config(domain, slot, ip, http_port)
|
||||||
|
|
||||||
|
# Then update HAProxy
|
||||||
|
added = []
|
||||||
|
errors = []
|
||||||
|
failed_slots = []
|
||||||
|
successfully_added_slots = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
for server_config in validated_servers:
|
||||||
|
slot = server_config["slot"]
|
||||||
|
ip = server_config["ip"]
|
||||||
|
http_port = server_config["http_port"]
|
||||||
|
try:
|
||||||
|
configure_server_slot(backend, server_prefix, slot, ip, http_port)
|
||||||
|
successfully_added_slots.append(slot)
|
||||||
|
added.append(f"slot {slot}: {ip}:{http_port}")
|
||||||
|
except HaproxyError as e:
|
||||||
|
failed_slots.append(slot)
|
||||||
|
errors.append(f"slot {slot}: {e}")
|
||||||
|
except (IOError, OSError, ValueError) as e:
|
||||||
|
# Rollback only successfully added configs on unexpected error
|
||||||
|
logger.error("Unexpected error during bulk server add for %s: %s", domain, e)
|
||||||
|
for slot in successfully_added_slots:
|
||||||
|
try:
|
||||||
|
remove_server_from_config(domain, slot)
|
||||||
|
except (IOError, OSError, ValueError) as rollback_error:
|
||||||
|
logger.error(
|
||||||
|
"Failed to rollback server config for %s slot %d: %s",
|
||||||
|
domain, slot, rollback_error
|
||||||
|
)
|
||||||
|
# Also rollback configs that weren't yet processed
|
||||||
|
for server_config in validated_servers:
|
||||||
|
slot = server_config["slot"]
|
||||||
|
if slot not in successfully_added_slots:
|
||||||
|
try:
|
||||||
|
remove_server_from_config(domain, slot)
|
||||||
|
except (IOError, OSError, ValueError) as rollback_error:
|
||||||
|
logger.error(
|
||||||
|
"Failed to rollback server config for %s slot %d: %s",
|
||||||
|
domain, slot, rollback_error
|
||||||
|
)
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
# Rollback failed slots from config
|
||||||
|
for slot in failed_slots:
|
||||||
|
try:
|
||||||
|
remove_server_from_config(domain, slot)
|
||||||
|
except (IOError, OSError, ValueError) as rollback_error:
|
||||||
|
logger.error(
|
||||||
|
"Failed to rollback server config for %s slot %d: %s",
|
||||||
|
domain, slot, rollback_error
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build result message
|
||||||
|
result_parts = []
|
||||||
|
if added:
|
||||||
|
result_parts.append(f"Added {len(added)} servers to {domain} ({backend}):")
|
||||||
|
result_parts.extend(f" • {s}" for s in added)
|
||||||
|
if errors:
|
||||||
|
result_parts.append(f"Failed to add {len(errors)} servers:")
|
||||||
|
result_parts.extend(f" • {e}" for e in errors)
|
||||||
|
|
||||||
|
return "\n".join(result_parts) if result_parts else "No servers added"
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_remove_server_impl(domain: str, slot: int) -> str:
|
||||||
|
"""Implementation of haproxy_remove_server."""
|
||||||
|
if not validate_domain(domain):
|
||||||
|
return "Error: Invalid domain format"
|
||||||
|
if not (1 <= slot <= MAX_SLOTS):
|
||||||
|
return f"Error: Slot must be between 1 and {MAX_SLOTS}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
backend, server_prefix = get_backend_and_prefix(domain)
|
||||||
|
|
||||||
|
# Get current server info for potential rollback
|
||||||
|
config = load_servers_config()
|
||||||
|
old_config = config.get(domain, {}).get(str(slot), {})
|
||||||
|
|
||||||
|
# Remove from persistent config FIRST (disk-first pattern)
|
||||||
|
remove_server_from_config(domain, slot)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# HTTP only - single server per slot
|
||||||
|
server = f"{server_prefix}_{slot}"
|
||||||
|
disable_server_slot(backend, server)
|
||||||
|
return f"Removed server at slot {slot} from {domain} ({backend})"
|
||||||
|
except HaproxyError as e:
|
||||||
|
# Rollback: re-add config if HAProxy command failed
|
||||||
|
if old_config:
|
||||||
|
add_server_to_config(domain, slot, old_config.get("ip", ""), old_config.get("http_port", 80))
|
||||||
|
return f"Error: {e}"
|
||||||
|
except (ValueError, IOError) as e:
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_set_domain_state_impl(domain: str, state: str) -> str:
|
||||||
|
"""Implementation of haproxy_set_domain_state."""
|
||||||
|
if not validate_domain(domain):
|
||||||
|
return "Error: Invalid domain format"
|
||||||
|
if state not in ["ready", "drain", "maint"]:
|
||||||
|
return "Error: State must be 'ready', 'drain', or 'maint'"
|
||||||
|
|
||||||
|
try:
|
||||||
|
backend, _ = get_backend_and_prefix(domain)
|
||||||
|
except ValueError as e:
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
# Get active servers for this domain
|
||||||
|
try:
|
||||||
|
servers_state = haproxy_cmd("show servers state")
|
||||||
|
except HaproxyError as e:
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
parsed_state = parse_servers_state(servers_state)
|
||||||
|
backend_servers = parsed_state.get(backend, {})
|
||||||
|
|
||||||
|
changed = []
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
for server_name, srv_info in backend_servers.items():
|
||||||
|
# Only change state for configured servers (not 0.0.0.0)
|
||||||
|
if srv_info["addr"] != "0.0.0.0":
|
||||||
|
try:
|
||||||
|
haproxy_cmd_checked(f"set server {backend}/{server_name} state {state}")
|
||||||
|
changed.append(server_name)
|
||||||
|
except HaproxyError as e:
|
||||||
|
errors.append(f"{server_name}: {e}")
|
||||||
|
|
||||||
|
if not changed and not errors:
|
||||||
|
return f"No active servers found for {domain}"
|
||||||
|
|
||||||
|
result = f"Set {len(changed)} servers to '{state}' for {domain}"
|
||||||
|
if changed:
|
||||||
|
result += ":\n" + "\n".join(f" • {s}" for s in changed)
|
||||||
|
if errors:
|
||||||
|
result += f"\n\nErrors ({len(errors)}):\n" + "\n".join(f" • {e}" for e in errors)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_wait_drain_impl(domain: str, timeout: int) -> str:
|
||||||
|
"""Implementation of haproxy_wait_drain."""
|
||||||
|
if not validate_domain(domain):
|
||||||
|
return "Error: Invalid domain format"
|
||||||
|
if not (1 <= timeout <= 300):
|
||||||
|
return "Error: Timeout must be between 1 and 300 seconds"
|
||||||
|
|
||||||
|
try:
|
||||||
|
backend, _ = get_backend_and_prefix(domain)
|
||||||
|
except ValueError as e:
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
try:
|
||||||
|
stats = haproxy_cmd("show stat")
|
||||||
|
total_connections = 0
|
||||||
|
for line in stats.split("\n"):
|
||||||
|
parts = line.split(",")
|
||||||
|
if len(parts) > StatField.SCUR and parts[0] == backend and parts[1] not in ["FRONTEND", "BACKEND", ""]:
|
||||||
|
try:
|
||||||
|
scur = int(parts[StatField.SCUR]) if parts[StatField.SCUR] else 0
|
||||||
|
total_connections += scur
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if total_connections == 0:
|
||||||
|
elapsed = int(time.time() - start_time)
|
||||||
|
return f"All connections drained for {domain} ({elapsed}s)"
|
||||||
|
|
||||||
|
time.sleep(1)
|
||||||
|
except HaproxyError as e:
|
||||||
|
return f"Error checking connections: {e}"
|
||||||
|
|
||||||
|
return f"Timeout: Connections still active after {timeout}s"
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_set_server_state_impl(backend: str, server: str, state: str) -> str:
|
||||||
|
"""Implementation of haproxy_set_server_state."""
|
||||||
|
if not validate_backend_name(backend):
|
||||||
|
return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)"
|
||||||
|
if not validate_backend_name(server):
|
||||||
|
return "Error: Invalid server name (use alphanumeric, underscore, hyphen only)"
|
||||||
|
if state not in ["ready", "drain", "maint"]:
|
||||||
|
return "Error: state must be 'ready', 'drain', or 'maint'"
|
||||||
|
try:
|
||||||
|
haproxy_cmd_checked(f"set server {backend}/{server} state {state}")
|
||||||
|
return f"Server {backend}/{server} set to {state}"
|
||||||
|
except HaproxyError as e:
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _haproxy_set_server_weight_impl(backend: str, server: str, weight: int) -> str:
|
||||||
|
"""Implementation of haproxy_set_server_weight."""
|
||||||
|
if not validate_backend_name(backend):
|
||||||
|
return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)"
|
||||||
|
if not validate_backend_name(server):
|
||||||
|
return "Error: Invalid server name (use alphanumeric, underscore, hyphen only)"
|
||||||
|
if not (0 <= weight <= 256):
|
||||||
|
return "Error: weight must be between 0 and 256"
|
||||||
|
try:
|
||||||
|
haproxy_cmd_checked(f"set server {backend}/{server} weight {weight}")
|
||||||
|
return f"Server {backend}/{server} weight set to {weight}"
|
||||||
|
except HaproxyError as e:
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# MCP Tool Registration
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def register_server_tools(mcp):
|
def register_server_tools(mcp):
|
||||||
"""Register server management tools with MCP server."""
|
"""Register server management tools with MCP server."""
|
||||||
|
|
||||||
@@ -65,29 +448,7 @@ def register_server_tools(mcp):
|
|||||||
# Output: pool_1_1: 10.0.0.1:8080 (UP)
|
# Output: pool_1_1: 10.0.0.1:8080 (UP)
|
||||||
# pool_1_2: 10.0.0.2:8080 (UP)
|
# pool_1_2: 10.0.0.2:8080 (UP)
|
||||||
"""
|
"""
|
||||||
if not validate_domain(domain):
|
return _haproxy_list_servers_impl(domain)
|
||||||
return "Error: Invalid domain format"
|
|
||||||
|
|
||||||
try:
|
|
||||||
backend, _ = get_backend_and_prefix(domain)
|
|
||||||
servers = []
|
|
||||||
state = haproxy_cmd("show servers state")
|
|
||||||
|
|
||||||
for line in state.split("\n"):
|
|
||||||
parts = line.split()
|
|
||||||
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
|
|
||||||
addr = parts[StateField.SRV_ADDR]
|
|
||||||
status = "active" if addr != "0.0.0.0" else "disabled"
|
|
||||||
servers.append(
|
|
||||||
f"• {parts[StateField.SRV_NAME]}: {addr}:{parts[StateField.SRV_PORT]} ({status})"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not servers:
|
|
||||||
return f"Backend {backend} not found"
|
|
||||||
|
|
||||||
return f"Servers for {domain} ({backend}):\n" + "\n".join(servers)
|
|
||||||
except (HaproxyError, ValueError) as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def haproxy_add_server(
|
def haproxy_add_server(
|
||||||
@@ -103,53 +464,7 @@ def register_server_tools(mcp):
|
|||||||
|
|
||||||
Example: haproxy_add_server("api.example.com", slot=1, ip="10.0.0.1", http_port=8080)
|
Example: haproxy_add_server("api.example.com", slot=1, ip="10.0.0.1", http_port=8080)
|
||||||
"""
|
"""
|
||||||
if not validate_domain(domain):
|
return _haproxy_add_server_impl(domain, slot, ip, http_port)
|
||||||
return "Error: Invalid domain format"
|
|
||||||
if not ip:
|
|
||||||
return "Error: IP address is required"
|
|
||||||
if not validate_ip(ip):
|
|
||||||
return "Error: Invalid IP address format"
|
|
||||||
if not (1 <= http_port <= 65535):
|
|
||||||
return "Error: Port must be between 1 and 65535"
|
|
||||||
|
|
||||||
try:
|
|
||||||
backend, server_prefix = get_backend_and_prefix(domain)
|
|
||||||
|
|
||||||
# Auto-select slot if slot <= 0
|
|
||||||
if slot <= 0:
|
|
||||||
state = haproxy_cmd("show servers state")
|
|
||||||
used_slots: set[int] = set()
|
|
||||||
for line in state.split("\n"):
|
|
||||||
parts = line.split()
|
|
||||||
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
|
|
||||||
if parts[StateField.SRV_ADDR] != "0.0.0.0":
|
|
||||||
# Extract slot number from server name (e.g., pool_1_3 -> 3)
|
|
||||||
server_name = parts[StateField.SRV_NAME]
|
|
||||||
try:
|
|
||||||
used_slots.add(int(server_name.rsplit("_", 1)[1]))
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
pass
|
|
||||||
for s in range(1, MAX_SLOTS + 1):
|
|
||||||
if s not in used_slots:
|
|
||||||
slot = s
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
return f"Error: No available slots (all {MAX_SLOTS} slots in use)"
|
|
||||||
elif not (1 <= slot <= MAX_SLOTS):
|
|
||||||
return f"Error: Slot must be between 1 and {MAX_SLOTS}, or 0/-1 for auto-select"
|
|
||||||
|
|
||||||
# Save to persistent config FIRST (disk-first pattern)
|
|
||||||
add_server_to_config(domain, slot, ip, http_port)
|
|
||||||
|
|
||||||
try:
|
|
||||||
server = configure_server_slot(backend, server_prefix, slot, ip, http_port)
|
|
||||||
return f"Added to {domain} ({backend}) slot {slot}:\n{server} → {ip}:{http_port}"
|
|
||||||
except HaproxyError as e:
|
|
||||||
# Rollback config on HAProxy failure
|
|
||||||
remove_server_from_config(domain, slot)
|
|
||||||
return f"Error: {e}"
|
|
||||||
except (ValueError, IOError) as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def haproxy_add_servers(
|
def haproxy_add_servers(
|
||||||
@@ -160,156 +475,7 @@ def register_server_tools(mcp):
|
|||||||
|
|
||||||
Example: haproxy_add_servers("api.example.com", '[{"slot":1,"ip":"10.0.0.1"},{"slot":2,"ip":"10.0.0.2"}]')
|
Example: haproxy_add_servers("api.example.com", '[{"slot":1,"ip":"10.0.0.1"},{"slot":2,"ip":"10.0.0.2"}]')
|
||||||
"""
|
"""
|
||||||
if not validate_domain(domain):
|
return _haproxy_add_servers_impl(domain, servers)
|
||||||
return "Error: Invalid domain format"
|
|
||||||
|
|
||||||
# Check JSON size before parsing
|
|
||||||
if len(servers) > MAX_SERVERS_JSON_SIZE:
|
|
||||||
return f"Error: servers JSON exceeds maximum size ({MAX_SERVERS_JSON_SIZE} bytes)"
|
|
||||||
|
|
||||||
# Parse JSON array
|
|
||||||
try:
|
|
||||||
server_list = json.loads(servers)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
return f"Error: Invalid JSON - {e}"
|
|
||||||
|
|
||||||
if not isinstance(server_list, list):
|
|
||||||
return "Error: servers must be a JSON array"
|
|
||||||
|
|
||||||
if not server_list:
|
|
||||||
return "Error: servers array is empty"
|
|
||||||
|
|
||||||
if len(server_list) > MAX_BULK_SERVERS:
|
|
||||||
return f"Error: Cannot add more than {MAX_BULK_SERVERS} servers at once"
|
|
||||||
|
|
||||||
# Validate all servers first before adding any
|
|
||||||
validated_servers = []
|
|
||||||
validation_errors = []
|
|
||||||
|
|
||||||
for i, srv in enumerate(server_list):
|
|
||||||
if not isinstance(srv, dict):
|
|
||||||
validation_errors.append(f"Server {i+1}: must be an object")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Extract and validate slot
|
|
||||||
slot = srv.get("slot")
|
|
||||||
if slot is None:
|
|
||||||
validation_errors.append(f"Server {i+1}: missing 'slot' field")
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
slot = int(slot)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
validation_errors.append(f"Server {i+1}: slot must be an integer")
|
|
||||||
continue
|
|
||||||
if not (1 <= slot <= MAX_SLOTS):
|
|
||||||
validation_errors.append(f"Server {i+1}: slot must be between 1 and {MAX_SLOTS}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Extract and validate IP
|
|
||||||
ip = srv.get("ip")
|
|
||||||
if not ip:
|
|
||||||
validation_errors.append(f"Server {i+1}: missing 'ip' field")
|
|
||||||
continue
|
|
||||||
if not validate_ip(ip):
|
|
||||||
validation_errors.append(f"Server {i+1}: invalid IP address '{ip}'")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Extract and validate port
|
|
||||||
http_port = srv.get("http_port", 80)
|
|
||||||
try:
|
|
||||||
http_port = int(http_port)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
validation_errors.append(f"Server {i+1}: http_port must be an integer")
|
|
||||||
continue
|
|
||||||
if not (1 <= http_port <= 65535):
|
|
||||||
validation_errors.append(f"Server {i+1}: port must be between 1 and 65535")
|
|
||||||
continue
|
|
||||||
|
|
||||||
validated_servers.append({"slot": slot, "ip": ip, "http_port": http_port})
|
|
||||||
|
|
||||||
# Return validation errors if any
|
|
||||||
if validation_errors:
|
|
||||||
return "Validation errors:\n" + "\n".join(f" • {e}" for e in validation_errors)
|
|
||||||
|
|
||||||
# Check for duplicate slots
|
|
||||||
slots = [s["slot"] for s in validated_servers]
|
|
||||||
if len(slots) != len(set(slots)):
|
|
||||||
return "Error: Duplicate slot numbers in servers array"
|
|
||||||
|
|
||||||
# Get backend info
|
|
||||||
try:
|
|
||||||
backend, server_prefix = get_backend_and_prefix(domain)
|
|
||||||
except ValueError as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|
||||||
# Save ALL servers to config FIRST (disk-first pattern)
|
|
||||||
for server_config in validated_servers:
|
|
||||||
slot = server_config["slot"]
|
|
||||||
ip = server_config["ip"]
|
|
||||||
http_port = server_config["http_port"]
|
|
||||||
add_server_to_config(domain, slot, ip, http_port)
|
|
||||||
|
|
||||||
# Then update HAProxy
|
|
||||||
added = []
|
|
||||||
errors = []
|
|
||||||
failed_slots = []
|
|
||||||
successfully_added_slots = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
for server_config in validated_servers:
|
|
||||||
slot = server_config["slot"]
|
|
||||||
ip = server_config["ip"]
|
|
||||||
http_port = server_config["http_port"]
|
|
||||||
try:
|
|
||||||
configure_server_slot(backend, server_prefix, slot, ip, http_port)
|
|
||||||
successfully_added_slots.append(slot)
|
|
||||||
added.append(f"slot {slot}: {ip}:{http_port}")
|
|
||||||
except HaproxyError as e:
|
|
||||||
failed_slots.append(slot)
|
|
||||||
errors.append(f"slot {slot}: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
# Rollback only successfully added configs on unexpected error
|
|
||||||
for slot in successfully_added_slots:
|
|
||||||
try:
|
|
||||||
remove_server_from_config(domain, slot)
|
|
||||||
except Exception as rollback_error:
|
|
||||||
logger.error(
|
|
||||||
"Failed to rollback server config for %s slot %d: %s",
|
|
||||||
domain, slot, rollback_error
|
|
||||||
)
|
|
||||||
# Also rollback configs that weren't yet processed
|
|
||||||
for server_config in validated_servers:
|
|
||||||
slot = server_config["slot"]
|
|
||||||
if slot not in successfully_added_slots:
|
|
||||||
try:
|
|
||||||
remove_server_from_config(domain, slot)
|
|
||||||
except Exception as rollback_error:
|
|
||||||
logger.error(
|
|
||||||
"Failed to rollback server config for %s slot %d: %s",
|
|
||||||
domain, slot, rollback_error
|
|
||||||
)
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|
||||||
# Rollback failed slots from config
|
|
||||||
for slot in failed_slots:
|
|
||||||
try:
|
|
||||||
remove_server_from_config(domain, slot)
|
|
||||||
except Exception as rollback_error:
|
|
||||||
logger.error(
|
|
||||||
"Failed to rollback server config for %s slot %d: %s",
|
|
||||||
domain, slot, rollback_error
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build result message
|
|
||||||
result_parts = []
|
|
||||||
if added:
|
|
||||||
result_parts.append(f"Added {len(added)} servers to {domain} ({backend}):")
|
|
||||||
result_parts.extend(f" • {s}" for s in added)
|
|
||||||
if errors:
|
|
||||||
result_parts.append(f"Failed to add {len(errors)} servers:")
|
|
||||||
result_parts.extend(f" • {e}" for e in errors)
|
|
||||||
|
|
||||||
return "\n".join(result_parts) if result_parts else "No servers added"
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def haproxy_remove_server(
|
def haproxy_remove_server(
|
||||||
@@ -320,37 +486,7 @@ def register_server_tools(mcp):
|
|||||||
|
|
||||||
Example: haproxy_remove_server("api.example.com", slot=2)
|
Example: haproxy_remove_server("api.example.com", slot=2)
|
||||||
"""
|
"""
|
||||||
if not validate_domain(domain):
|
return _haproxy_remove_server_impl(domain, slot)
|
||||||
return "Error: Invalid domain format"
|
|
||||||
if not (1 <= slot <= MAX_SLOTS):
|
|
||||||
return f"Error: Slot must be between 1 and {MAX_SLOTS}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
backend, server_prefix = get_backend_and_prefix(domain)
|
|
||||||
|
|
||||||
# Get current server info for potential rollback
|
|
||||||
config = load_servers_config()
|
|
||||||
old_config = config.get(domain, {}).get(str(slot), {})
|
|
||||||
|
|
||||||
# Remove from persistent config FIRST (disk-first pattern)
|
|
||||||
remove_server_from_config(domain, slot)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# HTTP only - single server per slot
|
|
||||||
server = f"{server_prefix}_{slot}"
|
|
||||||
# Batch both commands in single TCP connection
|
|
||||||
haproxy_cmd_batch([
|
|
||||||
f"set server {backend}/{server} state maint",
|
|
||||||
f"set server {backend}/{server} addr 0.0.0.0 port 0"
|
|
||||||
])
|
|
||||||
return f"Removed server at slot {slot} from {domain} ({backend})"
|
|
||||||
except HaproxyError as e:
|
|
||||||
# Rollback: re-add config if HAProxy command failed
|
|
||||||
if old_config:
|
|
||||||
add_server_to_config(domain, slot, old_config.get("ip", ""), old_config.get("http_port", 80))
|
|
||||||
return f"Error: {e}"
|
|
||||||
except (ValueError, IOError) as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def haproxy_set_domain_state(
|
def haproxy_set_domain_state(
|
||||||
@@ -368,48 +504,7 @@ def register_server_tools(mcp):
|
|||||||
# Re-enable all servers after deployment
|
# Re-enable all servers after deployment
|
||||||
haproxy_set_domain_state("api.example.com", "ready")
|
haproxy_set_domain_state("api.example.com", "ready")
|
||||||
"""
|
"""
|
||||||
if not validate_domain(domain):
|
return _haproxy_set_domain_state_impl(domain, state)
|
||||||
return "Error: Invalid domain format"
|
|
||||||
if state not in ["ready", "drain", "maint"]:
|
|
||||||
return "Error: State must be 'ready', 'drain', or 'maint'"
|
|
||||||
|
|
||||||
try:
|
|
||||||
backend, _ = get_backend_and_prefix(domain)
|
|
||||||
except ValueError as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|
||||||
# Get active servers for this domain
|
|
||||||
try:
|
|
||||||
servers_state = haproxy_cmd("show servers state")
|
|
||||||
except HaproxyError as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|
||||||
changed = []
|
|
||||||
errors = []
|
|
||||||
|
|
||||||
for line in servers_state.split("\n"):
|
|
||||||
parts = line.split()
|
|
||||||
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
|
|
||||||
server_name = parts[StateField.SRV_NAME]
|
|
||||||
addr = parts[StateField.SRV_ADDR]
|
|
||||||
# Only change state for configured servers (not 0.0.0.0)
|
|
||||||
if addr != "0.0.0.0":
|
|
||||||
try:
|
|
||||||
haproxy_cmd_checked(f"set server {backend}/{server_name} state {state}")
|
|
||||||
changed.append(server_name)
|
|
||||||
except HaproxyError as e:
|
|
||||||
errors.append(f"{server_name}: {e}")
|
|
||||||
|
|
||||||
if not changed and not errors:
|
|
||||||
return f"No active servers found for {domain}"
|
|
||||||
|
|
||||||
result = f"Set {len(changed)} servers to '{state}' for {domain}"
|
|
||||||
if changed:
|
|
||||||
result += ":\n" + "\n".join(f" • {s}" for s in changed)
|
|
||||||
if errors:
|
|
||||||
result += f"\n\nErrors ({len(errors)}):\n" + "\n".join(f" • {e}" for e in errors)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def haproxy_wait_drain(
|
def haproxy_wait_drain(
|
||||||
@@ -422,39 +517,7 @@ def register_server_tools(mcp):
|
|||||||
|
|
||||||
Example: haproxy_wait_drain("api.example.com", timeout=60)
|
Example: haproxy_wait_drain("api.example.com", timeout=60)
|
||||||
"""
|
"""
|
||||||
if not validate_domain(domain):
|
return _haproxy_wait_drain_impl(domain, timeout)
|
||||||
return "Error: Invalid domain format"
|
|
||||||
if not (1 <= timeout <= 300):
|
|
||||||
return "Error: Timeout must be between 1 and 300 seconds"
|
|
||||||
|
|
||||||
try:
|
|
||||||
backend, _ = get_backend_and_prefix(domain)
|
|
||||||
except ValueError as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
while time.time() - start_time < timeout:
|
|
||||||
try:
|
|
||||||
stats = haproxy_cmd("show stat")
|
|
||||||
total_connections = 0
|
|
||||||
for line in stats.split("\n"):
|
|
||||||
parts = line.split(",")
|
|
||||||
if len(parts) > StatField.SCUR and parts[0] == backend and parts[1] not in ["FRONTEND", "BACKEND", ""]:
|
|
||||||
try:
|
|
||||||
scur = int(parts[StatField.SCUR]) if parts[StatField.SCUR] else 0
|
|
||||||
total_connections += scur
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if total_connections == 0:
|
|
||||||
elapsed = int(time.time() - start_time)
|
|
||||||
return f"All connections drained for {domain} ({elapsed}s)"
|
|
||||||
|
|
||||||
time.sleep(1)
|
|
||||||
except HaproxyError as e:
|
|
||||||
return f"Error checking connections: {e}"
|
|
||||||
|
|
||||||
return f"Timeout: Connections still active after {timeout}s"
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def haproxy_set_server_state(
|
def haproxy_set_server_state(
|
||||||
@@ -466,17 +529,7 @@ def register_server_tools(mcp):
|
|||||||
|
|
||||||
Example: haproxy_set_server_state("pool_1", "pool_1_2", "maint")
|
Example: haproxy_set_server_state("pool_1", "pool_1_2", "maint")
|
||||||
"""
|
"""
|
||||||
if not validate_backend_name(backend):
|
return _haproxy_set_server_state_impl(backend, server, state)
|
||||||
return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)"
|
|
||||||
if not validate_backend_name(server):
|
|
||||||
return "Error: Invalid server name (use alphanumeric, underscore, hyphen only)"
|
|
||||||
if state not in ["ready", "drain", "maint"]:
|
|
||||||
return "Error: state must be 'ready', 'drain', or 'maint'"
|
|
||||||
try:
|
|
||||||
haproxy_cmd_checked(f"set server {backend}/{server} state {state}")
|
|
||||||
return f"Server {backend}/{server} set to {state}"
|
|
||||||
except HaproxyError as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def haproxy_set_server_weight(
|
def haproxy_set_server_weight(
|
||||||
@@ -488,14 +541,4 @@ def register_server_tools(mcp):
|
|||||||
|
|
||||||
Example: haproxy_set_server_weight("pool_1", "pool_1_1", weight=2)
|
Example: haproxy_set_server_weight("pool_1", "pool_1_1", weight=2)
|
||||||
"""
|
"""
|
||||||
if not validate_backend_name(backend):
|
return _haproxy_set_server_weight_impl(backend, server, weight)
|
||||||
return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)"
|
|
||||||
if not validate_backend_name(server):
|
|
||||||
return "Error: Invalid server name (use alphanumeric, underscore, hyphen only)"
|
|
||||||
if not (0 <= weight <= 256):
|
|
||||||
return "Error: weight must be between 0 and 256"
|
|
||||||
try:
|
|
||||||
haproxy_cmd_checked(f"set server {backend}/{server} weight {weight}")
|
|
||||||
return f"Server {backend}/{server} weight set to {weight}"
|
|
||||||
except HaproxyError as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
|
|||||||
@@ -1,7 +1,58 @@
|
|||||||
"""Utility functions for HAProxy MCP Server."""
|
"""Utility functions for HAProxy MCP Server."""
|
||||||
|
|
||||||
from typing import Dict, Generator
|
from typing import Dict, Generator
|
||||||
from .config import StatField
|
from .config import StatField, StateField, STATE_MIN_COLUMNS
|
||||||
|
from .haproxy_client import haproxy_cmd_batch
|
||||||
|
|
||||||
|
|
||||||
|
def parse_servers_state(state_output: str) -> dict[str, dict[str, dict[str, str]]]:
|
||||||
|
"""Parse 'show servers state' output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_output: Raw output from HAProxy 'show servers state' command
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Nested dict: {backend: {server: {addr: str, port: str, state: str}}}
|
||||||
|
|
||||||
|
Example:
|
||||||
|
state = haproxy_cmd("show servers state")
|
||||||
|
parsed = parse_servers_state(state)
|
||||||
|
# parsed["pool_1"]["pool_1_1"] == {"addr": "10.0.0.1", "port": "8080", "state": "2"}
|
||||||
|
"""
|
||||||
|
result: dict[str, dict[str, dict[str, str]]] = {}
|
||||||
|
for line in state_output.split("\n"):
|
||||||
|
parts = line.split()
|
||||||
|
if len(parts) >= STATE_MIN_COLUMNS:
|
||||||
|
backend = parts[StateField.BE_NAME]
|
||||||
|
server = parts[StateField.SRV_NAME]
|
||||||
|
addr = parts[StateField.SRV_ADDR]
|
||||||
|
port = parts[StateField.SRV_PORT]
|
||||||
|
state = parts[StateField.SRV_OP_STATE] if len(parts) > StateField.SRV_OP_STATE else ""
|
||||||
|
|
||||||
|
if backend not in result:
|
||||||
|
result[backend] = {}
|
||||||
|
result[backend][server] = {
|
||||||
|
"addr": addr,
|
||||||
|
"port": port,
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def disable_server_slot(backend: str, server: str) -> None:
|
||||||
|
"""Disable a server slot (set to maint and clear address).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backend: Backend name (e.g., 'pool_1')
|
||||||
|
server: Server name (e.g., 'pool_1_1')
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HaproxyError: If HAProxy command fails
|
||||||
|
"""
|
||||||
|
haproxy_cmd_batch([
|
||||||
|
f"set server {backend}/{server} state maint",
|
||||||
|
f"set server {backend}/{server} addr 0.0.0.0 port 0"
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
def parse_stat_csv(stat_output: str) -> Generator[Dict[str, str], None, None]:
|
def parse_stat_csv(stat_output: str) -> Generator[Dict[str, str], None, None]:
|
||||||
|
|||||||
@@ -52,6 +52,18 @@ def validate_port(port: str) -> bool:
|
|||||||
return 1 <= port_num <= 65535
|
return 1 <= port_num <= 65535
|
||||||
|
|
||||||
|
|
||||||
|
def validate_port_int(port: int) -> bool:
|
||||||
|
"""Validate port number as integer is in valid range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
port: Port number as integer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if port is valid (1-65535), False otherwise
|
||||||
|
"""
|
||||||
|
return isinstance(port, int) and 1 <= port <= 65535
|
||||||
|
|
||||||
|
|
||||||
def validate_backend_name(name: str) -> bool:
|
def validate_backend_name(name: str) -> bool:
|
||||||
"""Validate backend or server name to prevent command injection.
|
"""Validate backend or server name to prevent command injection.
|
||||||
|
|
||||||
|
|||||||
@@ -1078,9 +1078,10 @@ class TestLoadCertToHaproxyError:
|
|||||||
pem_file = certs_dir / "example.com.pem"
|
pem_file = certs_dir / "example.com.pem"
|
||||||
pem_file.write_text("cert content")
|
pem_file.write_text("cert content")
|
||||||
|
|
||||||
# Mock haproxy_cmd to raise exception
|
# Mock haproxy_cmd to raise HaproxyError
|
||||||
|
from haproxy_mcp.exceptions import HaproxyError
|
||||||
with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)):
|
with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)):
|
||||||
with patch("haproxy_mcp.tools.certificates.haproxy_cmd", side_effect=Exception("Connection failed")):
|
with patch("haproxy_mcp.tools.certificates.haproxy_cmd", side_effect=HaproxyError("Connection failed")):
|
||||||
from haproxy_mcp.tools.certificates import load_cert_to_haproxy
|
from haproxy_mcp.tools.certificates import load_cert_to_haproxy
|
||||||
|
|
||||||
success, msg = load_cert_to_haproxy("example.com")
|
success, msg = load_cert_to_haproxy("example.com")
|
||||||
|
|||||||
@@ -547,7 +547,7 @@ class TestStartupRestoreFailures:
|
|||||||
|
|
||||||
with patch("socket.socket", return_value=mock_sock):
|
with patch("socket.socket", return_value=mock_sock):
|
||||||
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", return_value=0):
|
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", return_value=0):
|
||||||
with patch("haproxy_mcp.tools.certificates.restore_certificates", side_effect=Exception("Certificate error")):
|
with patch("haproxy_mcp.tools.certificates.restore_certificates", side_effect=IOError("Certificate error")):
|
||||||
with caplog.at_level(logging.WARNING, logger="haproxy_mcp"):
|
with caplog.at_level(logging.WARNING, logger="haproxy_mcp"):
|
||||||
from haproxy_mcp.tools.configuration import startup_restore
|
from haproxy_mcp.tools.configuration import startup_restore
|
||||||
|
|
||||||
@@ -598,7 +598,7 @@ class TestHaproxyReloadFailures:
|
|||||||
|
|
||||||
with patch("socket.socket", return_value=mock_sock):
|
with patch("socket.socket", return_value=mock_sock):
|
||||||
with patch("haproxy_mcp.haproxy_client.reload_haproxy", return_value=(True, "Reloaded")):
|
with patch("haproxy_mcp.haproxy_client.reload_haproxy", return_value=(True, "Reloaded")):
|
||||||
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=Exception("Restore failed")):
|
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=OSError("Restore failed")):
|
||||||
with patch("time.sleep", return_value=None):
|
with patch("time.sleep", return_value=None):
|
||||||
from haproxy_mcp.tools.configuration import register_config_tools
|
from haproxy_mcp.tools.configuration import register_config_tools
|
||||||
mcp = MagicMock()
|
mcp = MagicMock()
|
||||||
|
|||||||
@@ -756,8 +756,8 @@ class TestHaproxyAddServersRollback:
|
|||||||
|
|
||||||
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
|
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
|
||||||
if slot == 2:
|
if slot == 2:
|
||||||
# Simulate unexpected error (not HaproxyError)
|
# Simulate unexpected error (IOError is caught by the exception handler)
|
||||||
raise RuntimeError("Unexpected system error")
|
raise IOError("Unexpected system error")
|
||||||
configured_slots.append(slot)
|
configured_slots.append(slot)
|
||||||
return f"{server_prefix}_{slot}"
|
return f"{server_prefix}_{slot}"
|
||||||
|
|
||||||
@@ -807,7 +807,7 @@ class TestHaproxyAddServersRollback:
|
|||||||
|
|
||||||
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
|
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
|
||||||
if slot == 2:
|
if slot == 2:
|
||||||
raise RuntimeError("Unexpected error")
|
raise OSError("Unexpected error")
|
||||||
return f"{server_prefix}_{slot}"
|
return f"{server_prefix}_{slot}"
|
||||||
|
|
||||||
def mock_remove_server_from_config(domain, slot):
|
def mock_remove_server_from_config(domain, slot):
|
||||||
|
|||||||
Reference in New Issue
Block a user