diff --git a/haproxy_mcp/file_ops.py b/haproxy_mcp/file_ops.py index 2da2e7d..4879309 100644 --- a/haproxy_mcp/file_ops.py +++ b/haproxy_mcp/file_ops.py @@ -94,8 +94,8 @@ def _read_map_file(file_path: str) -> list[tuple[str, str]]: with open(file_path, "r", encoding="utf-8") as f: try: fcntl.flock(f.fileno(), fcntl.LOCK_SH) - except OSError: - pass # Continue without lock if not supported + except OSError as e: + logger.debug("File locking not supported for %s: %s", file_path, e) try: for line in f: line = line.strip() @@ -107,10 +107,10 @@ def _read_map_file(file_path: str) -> list[tuple[str, str]]: finally: try: fcntl.flock(f.fileno(), fcntl.LOCK_UN) - except OSError: - pass + except OSError as e: + logger.debug("File unlock failed for %s: %s", file_path, e) except FileNotFoundError: - pass + logger.debug("Map file not found: %s", file_path) return entries @@ -420,16 +420,16 @@ def load_certs_config() -> list[str]: with open(CERTS_FILE, "r", encoding="utf-8") as f: try: fcntl.flock(f.fileno(), fcntl.LOCK_SH) - except OSError: - pass + except OSError as e: + logger.debug("File locking not supported for %s: %s", CERTS_FILE, e) try: data = json.load(f) return data.get("domains", []) finally: try: fcntl.flock(f.fileno(), fcntl.LOCK_UN) - except OSError: - pass + except OSError as e: + logger.debug("File unlock failed for %s: %s", CERTS_FILE, e) except FileNotFoundError: return [] except json.JSONDecodeError as e: diff --git a/haproxy_mcp/haproxy_client.py b/haproxy_mcp/haproxy_client.py index 981caba..36056aa 100644 --- a/haproxy_mcp/haproxy_client.py +++ b/haproxy_mcp/haproxy_client.py @@ -67,8 +67,8 @@ def haproxy_cmd(command: str) -> str: raise HaproxyError("Invalid UTF-8 in response") except HaproxyError: raise - except Exception as e: - raise HaproxyError(str(e)) from e + except (OSError, BlockingIOError, BrokenPipeError) as e: + raise HaproxyError(f"Socket error: {e}") from e def haproxy_cmd_checked(command: str) -> str: diff --git a/haproxy_mcp/tools/certificates.py b/haproxy_mcp/tools/certificates.py index a910a8f..b827a5c 100644 --- a/haproxy_mcp/tools/certificates.py +++ b/haproxy_mcp/tools/certificates.py @@ -14,6 +14,7 @@ from ..config import ( CERTS_DIR_CONTAINER, ACME_HOME, ) +from ..exceptions import HaproxyError from ..validation import validate_domain from ..haproxy_client import haproxy_cmd from ..file_ops import ( @@ -77,7 +78,11 @@ def load_cert_to_haproxy(domain: str) -> tuple[bool, str]: haproxy_cmd(f"commit ssl cert {container_path}") return True, "added" - except Exception as e: + except HaproxyError as e: + logger.error("HAProxy error loading certificate %s: %s", domain, e) + return False, str(e) + except (IOError, OSError) as e: + logger.error("File error loading certificate %s: %s", domain, e) return False, str(e) @@ -102,7 +107,8 @@ def unload_cert_from_haproxy(domain: str) -> tuple[bool, str]: haproxy_cmd(f"del ssl cert {container_path}") return True, "unloaded" - except Exception as e: + except HaproxyError as e: + logger.error("HAProxy error unloading certificate %s: %s", domain, e) return False, str(e) @@ -126,6 +132,364 @@ def restore_certificates() -> int: return restored +# ============================================================================= +# Implementation functions (module-level) +# ============================================================================= + + +def _haproxy_list_certs_impl() -> str: + """Implementation of haproxy_list_certs.""" + try: + result = subprocess.run( + [ACME_SH, "--list"], + capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT, + env={**os.environ, "HOME": os.path.expanduser("~")} + ) + if result.returncode != 0: + return f"Error: {result.stderr}" + + lines = result.stdout.strip().split("\n") + if len(lines) <= 1: + return "No certificates found" + + # Get HAProxy loaded certs + try: + haproxy_certs = haproxy_cmd("show ssl cert") + except HaproxyError as e: + logger.debug("Could not get HAProxy certs: %s", e) + haproxy_certs = "" + + # Parse and format output + certs = [] + for line in lines[1:]: # Skip header + parts = line.split() + if len(parts) >= 4: + domain = parts[0] + ca = "unknown" + created = "unknown" + renew = "unknown" + + for part in parts: + if "Google" in part or "LetsEncrypt" in part or "ZeroSSL" in part: + ca = part + elif part.endswith("Z") and "T" in part: + if created == "unknown": + created = part + else: + renew = part + + # Check deployment status + host_path, container_path = get_pem_paths(domain) + if container_path in haproxy_certs: + status = "loaded" + elif os.path.exists(host_path): + status = "file exists (not loaded)" + else: + status = "not deployed" + + certs.append(f"• {domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}") + + return "\n\n".join(certs) if certs else "No certificates found" + except subprocess.TimeoutExpired: + return "Error: Command timed out" + except FileNotFoundError: + return "Error: acme.sh not found" + except subprocess.SubprocessError as e: + logger.error("Subprocess error listing certificates: %s", e) + return f"Error: {e}" + except OSError as e: + logger.error("OS error listing certificates: %s", e) + return f"Error: {e}" + + +def _haproxy_cert_info_impl(domain: str) -> str: + """Implementation of haproxy_cert_info.""" + if not validate_domain(domain): + return "Error: Invalid domain format" + + host_path, container_path = get_pem_paths(domain) + if not os.path.exists(host_path): + return f"Error: Certificate not found for {domain}" + + try: + # Use openssl to get certificate info + result = subprocess.run( + ["openssl", "x509", "-in", host_path, "-noout", + "-subject", "-issuer", "-dates", "-ext", "subjectAltName"], + capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT + ) + if result.returncode != 0: + return f"Error reading certificate: {result.stderr}" + + # Get file info + stat = os.stat(host_path) + modified = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S") + + # Check HAProxy status + try: + haproxy_certs = haproxy_cmd("show ssl cert") + loaded = "Yes" if container_path in haproxy_certs else "No" + except HaproxyError as e: + logger.debug("Could not check HAProxy cert status: %s", e) + loaded = "Unknown" + + info = [ + f"Certificate: {domain}", + f"File: {host_path}", + f"Modified: {modified}", + f"Loaded in HAProxy: {loaded}", + "---", + result.stdout.strip() + ] + return "\n".join(info) + except subprocess.TimeoutExpired: + return "Error: Command timed out" + except (subprocess.SubprocessError, OSError) as e: + logger.error("Error getting certificate info for %s: %s", domain, e) + return f"Error: {e}" + + +def _haproxy_issue_cert_impl(domain: str, wildcard: bool) -> str: + """Implementation of haproxy_issue_cert.""" + if not validate_domain(domain): + return "Error: Invalid domain format" + + # Check if CF_Token is available + if not os.environ.get("CF_Token"): + secrets_file = os.path.expanduser("~/.secrets/cloudflare.ini") + if os.path.exists(secrets_file): + try: + with open(secrets_file) as f: + for line in f: + if "=" in line and "token" in line.lower(): + token = line.split("=", 1)[1].strip().strip('"').strip("'") + os.environ["CF_Token"] = token + break + except (IOError, OSError) as e: + logger.warning("Failed to read Cloudflare token: %s", e) + + if not os.environ.get("CF_Token"): + return "Error: CF_Token not set. Export CF_Token or add to ~/.secrets/cloudflare.ini" + + # Check if certificate already exists + cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc") + if os.path.exists(cert_dir): + return f"Error: Certificate for {domain} already exists. Use haproxy_renew_cert to renew." + + # Build acme.sh command (without reload - we'll do it via Runtime API) + host_path, _ = get_pem_paths(domain) + + # Create PEM after issuance + install_cmd = f"cat {ACME_HOME}/{domain}_ecc/fullchain.cer {ACME_HOME}/{domain}_ecc/{domain}.key > {host_path}" + + cmd = [ + ACME_SH, "--issue", + "--dns", "dns_cf", + "-d", domain + ] + + if wildcard: + cmd.extend(["-d", f"*.{domain}"]) + + cmd.extend(["--reloadcmd", install_cmd]) + + try: + logger.info("Issuing certificate for %s", domain) + result = subprocess.run( + cmd, + capture_output=True, text=True, timeout=CERT_TIMEOUT, + env={**os.environ, "HOME": os.path.expanduser("~")} + ) + + if result.returncode != 0: + error_msg = result.stderr or result.stdout + return f"Error issuing certificate:\n{error_msg}" + + # Load into HAProxy via Runtime API (zero-downtime) + if os.path.exists(host_path): + success, msg = load_cert_to_haproxy(domain) + if success: + # Save to config for persistence + add_cert_to_config(domain) + return f"Certificate issued and loaded for {domain} ({msg})" + else: + return f"Certificate issued but HAProxy loading failed: {msg}" + else: + return f"Certificate issued but PEM file not created. Check {host_path}" + + except subprocess.TimeoutExpired: + return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s" + except (subprocess.SubprocessError, OSError) as e: + logger.error("Error issuing certificate for %s: %s", domain, e) + return f"Error: {e}" + + +def _haproxy_renew_cert_impl(domain: str, force: bool) -> str: + """Implementation of haproxy_renew_cert.""" + if not validate_domain(domain): + return "Error: Invalid domain format" + + cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc") + if not os.path.exists(cert_dir): + return f"Error: No certificate found for {domain}. Use haproxy_issue_cert first." + + cmd = [ACME_SH, "--renew", "-d", domain] + if force: + cmd.append("--force") + + try: + logger.info("Renewing certificate for %s", domain) + result = subprocess.run( + cmd, + capture_output=True, text=True, timeout=CERT_TIMEOUT, + env={**os.environ, "HOME": os.path.expanduser("~")} + ) + + output = result.stdout + result.stderr + + if "Skip" in output and "Not yet due" in output: + return f"Certificate for {domain} not due for renewal. Use force=True to force renewal." + + if "Cert success" in output or result.returncode == 0: + # Reload into HAProxy via Runtime API + success, msg = load_cert_to_haproxy(domain) + if success: + # Ensure in config + add_cert_to_config(domain) + return f"Certificate renewed and reloaded for {domain} ({msg})" + else: + return f"Certificate renewed but HAProxy reload failed: {msg}" + else: + return f"Error renewing certificate:\n{output}" + + except subprocess.TimeoutExpired: + return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s" + except FileNotFoundError: + return "Error: acme.sh not found" + except (subprocess.SubprocessError, OSError) as e: + logger.error("Error renewing certificate for %s: %s", domain, e) + return f"Error: {e}" + + +def _haproxy_renew_all_certs_impl() -> str: + """Implementation of haproxy_renew_all_certs.""" + try: + logger.info("Running certificate renewal cron") + result = subprocess.run( + [ACME_SH, "--cron"], + capture_output=True, text=True, timeout=CERT_TIMEOUT * 3, + env={**os.environ, "HOME": os.path.expanduser("~")} + ) + + output = result.stdout + result.stderr + + # Count renewals + renewed = output.count("Cert success") + skipped = output.count("Skip") + + # Reload any renewed certs into HAProxy + if renewed > 0: + domains = load_certs_config() + reloaded = 0 + for domain in domains: + success, _ = load_cert_to_haproxy(domain) + if success: + reloaded += 1 + return f"Renewed {renewed} certificate(s), reloaded {reloaded} into HAProxy" + elif skipped > 0: + return f"No certificates due for renewal ({skipped} checked)" + elif result.returncode != 0: + return f"Error running renewal:\n{output}" + else: + return "Renewal check completed" + + except subprocess.TimeoutExpired: + return "Error: Renewal cron timed out" + except FileNotFoundError: + return "Error: acme.sh not found" + except (subprocess.SubprocessError, OSError) as e: + logger.error("Error running certificate renewal cron: %s", e) + return f"Error: {e}" + + +def _haproxy_delete_cert_impl(domain: str) -> str: + """Implementation of haproxy_delete_cert.""" + if not validate_domain(domain): + return "Error: Invalid domain format" + + cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc") + host_path, _ = get_pem_paths(domain) + + if not os.path.exists(cert_dir) and not os.path.exists(host_path): + return f"Error: No certificate found for {domain}" + + errors = [] + deleted = [] + + # Unload from HAProxy first (zero-downtime) + success, msg = unload_cert_from_haproxy(domain) + if success: + deleted.append(f"HAProxy ({msg})") + else: + errors.append(f"HAProxy unload: {msg}") + + # Remove from acme.sh + if os.path.exists(cert_dir): + try: + result = subprocess.run( + [ACME_SH, "--remove", "-d", domain], + capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT, + env={**os.environ, "HOME": os.path.expanduser("~")} + ) + if result.returncode == 0: + deleted.append("acme.sh") + else: + errors.append(f"acme.sh: {result.stderr}") + except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError) as e: + errors.append(f"acme.sh: {e}") + + # Remove PEM file + if os.path.exists(host_path): + try: + os.remove(host_path) + deleted.append("PEM file") + except OSError as e: + errors.append(f"PEM file: {e}") + + # Remove from config + remove_cert_from_config(domain) + + result_parts = [] + if deleted: + result_parts.append(f"Deleted: {', '.join(deleted)}") + if errors: + result_parts.append(f"Errors: {'; '.join(errors)}") + + return "\n".join(result_parts) if result_parts else f"Certificate {domain} deleted" + + +def _haproxy_load_cert_impl(domain: str) -> str: + """Implementation of haproxy_load_cert.""" + if not validate_domain(domain): + return "Error: Invalid domain format" + + host_path, _ = get_pem_paths(domain) + if not os.path.exists(host_path): + return f"Error: PEM file not found: {host_path}" + + success, msg = load_cert_to_haproxy(domain) + if success: + add_cert_to_config(domain) + return f"Certificate {domain} loaded into HAProxy ({msg})" + else: + return f"Error loading certificate: {msg}" + + +# ============================================================================= +# MCP Tool Registration +# ============================================================================= + + def register_certificate_tools(mcp): """Register certificate management tools with MCP server.""" @@ -136,62 +500,7 @@ def register_certificate_tools(mcp): Returns: List of certificates with domain, CA, created date, and renewal date """ - try: - result = subprocess.run( - [ACME_SH, "--list"], - capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT, - env={**os.environ, "HOME": os.path.expanduser("~")} - ) - if result.returncode != 0: - return f"Error: {result.stderr}" - - lines = result.stdout.strip().split("\n") - if len(lines) <= 1: - return "No certificates found" - - # Get HAProxy loaded certs - try: - haproxy_certs = haproxy_cmd("show ssl cert") - except Exception: - haproxy_certs = "" - - # Parse and format output - certs = [] - for line in lines[1:]: # Skip header - parts = line.split() - if len(parts) >= 4: - domain = parts[0] - ca = "unknown" - created = "unknown" - renew = "unknown" - - for part in parts: - if "Google" in part or "LetsEncrypt" in part or "ZeroSSL" in part: - ca = part - elif part.endswith("Z") and "T" in part: - if created == "unknown": - created = part - else: - renew = part - - # Check deployment status - host_path, container_path = get_pem_paths(domain) - if container_path in haproxy_certs: - status = "loaded" - elif os.path.exists(host_path): - status = "file exists (not loaded)" - else: - status = "not deployed" - - certs.append(f"• {domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}") - - return "\n\n".join(certs) if certs else "No certificates found" - except subprocess.TimeoutExpired: - return "Error: Command timed out" - except FileNotFoundError: - return "Error: acme.sh not found" - except Exception as e: - return f"Error: {e}" + return _haproxy_list_certs_impl() @mcp.tool() def haproxy_cert_info( @@ -201,47 +510,7 @@ def register_certificate_tools(mcp): Shows expiry date, issuer, SANs, and file paths. """ - if not validate_domain(domain): - return "Error: Invalid domain format" - - host_path, container_path = get_pem_paths(domain) - if not os.path.exists(host_path): - return f"Error: Certificate not found for {domain}" - - try: - # Use openssl to get certificate info - result = subprocess.run( - ["openssl", "x509", "-in", host_path, "-noout", - "-subject", "-issuer", "-dates", "-ext", "subjectAltName"], - capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT - ) - if result.returncode != 0: - return f"Error reading certificate: {result.stderr}" - - # Get file info - stat = os.stat(host_path) - modified = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S") - - # Check HAProxy status - try: - haproxy_certs = haproxy_cmd("show ssl cert") - loaded = "Yes" if container_path in haproxy_certs else "No" - except Exception: - loaded = "Unknown" - - info = [ - f"Certificate: {domain}", - f"File: {host_path}", - f"Modified: {modified}", - f"Loaded in HAProxy: {loaded}", - "---", - result.stdout.strip() - ] - return "\n".join(info) - except subprocess.TimeoutExpired: - return "Error: Command timed out" - except Exception as e: - return f"Error: {e}" + return _haproxy_cert_info_impl(domain) @mcp.tool() def haproxy_issue_cert( @@ -254,76 +523,7 @@ def register_certificate_tools(mcp): Example: haproxy_issue_cert("example.com", wildcard=True) """ - if not validate_domain(domain): - return "Error: Invalid domain format" - - # Check if CF_Token is available - if not os.environ.get("CF_Token"): - secrets_file = os.path.expanduser("~/.secrets/cloudflare.ini") - if os.path.exists(secrets_file): - try: - with open(secrets_file) as f: - for line in f: - if "=" in line and "token" in line.lower(): - token = line.split("=", 1)[1].strip().strip('"').strip("'") - os.environ["CF_Token"] = token - break - except Exception as e: - logger.warning("Failed to read Cloudflare token: %s", e) - - if not os.environ.get("CF_Token"): - return "Error: CF_Token not set. Export CF_Token or add to ~/.secrets/cloudflare.ini" - - # Check if certificate already exists - cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc") - if os.path.exists(cert_dir): - return f"Error: Certificate for {domain} already exists. Use haproxy_renew_cert to renew." - - # Build acme.sh command (without reload - we'll do it via Runtime API) - host_path, _ = get_pem_paths(domain) - - # Create PEM after issuance - install_cmd = f"cat {ACME_HOME}/{domain}_ecc/fullchain.cer {ACME_HOME}/{domain}_ecc/{domain}.key > {host_path}" - - cmd = [ - ACME_SH, "--issue", - "--dns", "dns_cf", - "-d", domain - ] - - if wildcard: - cmd.extend(["-d", f"*.{domain}"]) - - cmd.extend(["--reloadcmd", install_cmd]) - - try: - logger.info("Issuing certificate for %s", domain) - result = subprocess.run( - cmd, - capture_output=True, text=True, timeout=CERT_TIMEOUT, - env={**os.environ, "HOME": os.path.expanduser("~")} - ) - - if result.returncode != 0: - error_msg = result.stderr or result.stdout - return f"Error issuing certificate:\n{error_msg}" - - # Load into HAProxy via Runtime API (zero-downtime) - if os.path.exists(host_path): - success, msg = load_cert_to_haproxy(domain) - if success: - # Save to config for persistence - add_cert_to_config(domain) - return f"Certificate issued and loaded for {domain} ({msg})" - else: - return f"Certificate issued but HAProxy loading failed: {msg}" - else: - return f"Certificate issued but PEM file not created. Check {host_path}" - - except subprocess.TimeoutExpired: - return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s" - except Exception as e: - return f"Error: {e}" + return _haproxy_issue_cert_impl(domain, wildcard) @mcp.tool() def haproxy_renew_cert( @@ -336,46 +536,7 @@ def register_certificate_tools(mcp): Example: haproxy_renew_cert("example.com", force=True) """ - if not validate_domain(domain): - return "Error: Invalid domain format" - - cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc") - if not os.path.exists(cert_dir): - return f"Error: No certificate found for {domain}. Use haproxy_issue_cert first." - - cmd = [ACME_SH, "--renew", "-d", domain] - if force: - cmd.append("--force") - - try: - logger.info("Renewing certificate for %s", domain) - result = subprocess.run( - cmd, - capture_output=True, text=True, timeout=CERT_TIMEOUT, - env={**os.environ, "HOME": os.path.expanduser("~")} - ) - - output = result.stdout + result.stderr - - if "Skip" in output and "Not yet due" in output: - return f"Certificate for {domain} not due for renewal. Use force=True to force renewal." - - if "Cert success" in output or result.returncode == 0: - # Reload into HAProxy via Runtime API - success, msg = load_cert_to_haproxy(domain) - if success: - # Ensure in config - add_cert_to_config(domain) - return f"Certificate renewed and reloaded for {domain} ({msg})" - else: - return f"Certificate renewed but HAProxy reload failed: {msg}" - else: - return f"Error renewing certificate:\n{output}" - - except subprocess.TimeoutExpired: - return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s" - except Exception as e: - return f"Error: {e}" + return _haproxy_renew_cert_impl(domain, force) @mcp.tool() def haproxy_renew_all_certs() -> str: @@ -383,40 +544,7 @@ def register_certificate_tools(mcp): This runs the acme.sh cron job to check and renew all certificates. """ - try: - logger.info("Running certificate renewal cron") - result = subprocess.run( - [ACME_SH, "--cron"], - capture_output=True, text=True, timeout=CERT_TIMEOUT * 3, - env={**os.environ, "HOME": os.path.expanduser("~")} - ) - - output = result.stdout + result.stderr - - # Count renewals - renewed = output.count("Cert success") - skipped = output.count("Skip") - - # Reload any renewed certs into HAProxy - if renewed > 0: - domains = load_certs_config() - reloaded = 0 - for domain in domains: - success, _ = load_cert_to_haproxy(domain) - if success: - reloaded += 1 - return f"Renewed {renewed} certificate(s), reloaded {reloaded} into HAProxy" - elif skipped > 0: - return f"No certificates due for renewal ({skipped} checked)" - elif result.returncode != 0: - return f"Error running renewal:\n{output}" - else: - return "Renewal check completed" - - except subprocess.TimeoutExpired: - return "Error: Renewal cron timed out" - except Exception as e: - return f"Error: {e}" + return _haproxy_renew_all_certs_impl() @mcp.tool() def haproxy_delete_cert( @@ -428,58 +556,7 @@ def register_certificate_tools(mcp): Example: haproxy_delete_cert("example.com") """ - if not validate_domain(domain): - return "Error: Invalid domain format" - - cert_dir = os.path.join(ACME_HOME, f"{domain}_ecc") - host_path, _ = get_pem_paths(domain) - - if not os.path.exists(cert_dir) and not os.path.exists(host_path): - return f"Error: No certificate found for {domain}" - - errors = [] - deleted = [] - - # Unload from HAProxy first (zero-downtime) - success, msg = unload_cert_from_haproxy(domain) - if success: - deleted.append(f"HAProxy ({msg})") - else: - errors.append(f"HAProxy unload: {msg}") - - # Remove from acme.sh - if os.path.exists(cert_dir): - try: - result = subprocess.run( - [ACME_SH, "--remove", "-d", domain], - capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT, - env={**os.environ, "HOME": os.path.expanduser("~")} - ) - if result.returncode == 0: - deleted.append("acme.sh") - else: - errors.append(f"acme.sh: {result.stderr}") - except Exception as e: - errors.append(f"acme.sh: {e}") - - # Remove PEM file - if os.path.exists(host_path): - try: - os.remove(host_path) - deleted.append("PEM file") - except Exception as e: - errors.append(f"PEM file: {e}") - - # Remove from config - remove_cert_from_config(domain) - - result_parts = [] - if deleted: - result_parts.append(f"Deleted: {', '.join(deleted)}") - if errors: - result_parts.append(f"Errors: {'; '.join(errors)}") - - return "\n".join(result_parts) if result_parts else f"Certificate {domain} deleted" + return _haproxy_delete_cert_impl(domain) @mcp.tool() def haproxy_load_cert( @@ -491,16 +568,4 @@ def register_certificate_tools(mcp): Example: haproxy_load_cert("example.com") """ - if not validate_domain(domain): - return "Error: Invalid domain format" - - host_path, _ = get_pem_paths(domain) - if not os.path.exists(host_path): - return f"Error: PEM file not found: {host_path}" - - success, msg = load_cert_to_haproxy(domain) - if success: - add_cert_to_config(domain) - return f"Certificate {domain} loaded into HAProxy ({msg})" - else: - return f"Error loading certificate: {msg}" + return _haproxy_load_cert_impl(domain) diff --git a/haproxy_mcp/tools/configuration.py b/haproxy_mcp/tools/configuration.py index 8ee2eca..78330bb 100644 --- a/haproxy_mcp/tools/configuration.py +++ b/haproxy_mcp/tools/configuration.py @@ -117,7 +117,7 @@ def startup_restore() -> None: cert_count = restore_certificates() if cert_count > 0: logger.info("Restored %d certificates from config", cert_count) - except Exception as e: + except (HaproxyError, IOError, OSError, ValueError) as e: logger.warning("Failed to restore certificates: %s", e) @@ -155,7 +155,7 @@ def register_config_tools(mcp): try: restored = restore_servers_from_config() return f"HAProxy configuration reloaded successfully ({restored} servers restored)" - except Exception as e: + except (HaproxyError, IOError, OSError, ValueError) as e: logger.error("Failed to restore servers after reload: %s", e) return f"HAProxy reloaded but server restore failed: {e}" diff --git a/haproxy_mcp/tools/domains.py b/haproxy_mcp/tools/domains.py index 9a6a7ce..276a46a 100644 --- a/haproxy_mcp/tools/domains.py +++ b/haproxy_mcp/tools/domains.py @@ -13,14 +13,12 @@ from ..config import ( WILDCARDS_MAP_FILE_CONTAINER, POOL_COUNT, MAX_SLOTS, - StateField, - STATE_MIN_COLUMNS, SUBPROCESS_TIMEOUT, CERTS_DIR, logger, ) from ..exceptions import HaproxyError -from ..validation import validate_domain, validate_ip +from ..validation import validate_domain, validate_ip, validate_port_int from ..haproxy_client import haproxy_cmd from ..file_ops import ( get_map_contents, @@ -31,6 +29,7 @@ from ..file_ops import ( remove_server_from_config, remove_domain_from_config, ) +from ..utils import parse_servers_state, disable_server_slot def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]: @@ -166,18 +165,18 @@ def register_domain_tools(mcp): try: domains = [] state = haproxy_cmd("show servers state") + parsed_state = parse_servers_state(state) # Build server map from HAProxy state - server_map: dict[str, list] = {} - for line in state.split("\n"): - parts = line.split() - if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.SRV_ADDR] != "0.0.0.0": - backend = parts[StateField.BE_NAME] - if backend not in server_map: - server_map[backend] = [] - server_map[backend].append( - f"{parts[StateField.SRV_NAME]}={parts[StateField.SRV_ADDR]}:{parts[StateField.SRV_PORT]}" - ) + server_map: dict[str, list[str]] = {} + for backend, servers_dict in parsed_state.items(): + for server_name, srv_info in servers_dict.items(): + if srv_info["addr"] != "0.0.0.0": + if backend not in server_map: + server_map[backend] = [] + server_map[backend].append( + f"{server_name}={srv_info['addr']}:{srv_info['port']}" + ) # Read from domains.map seen_domains: set[str] = set() @@ -220,7 +219,7 @@ def register_domain_tools(mcp): return "Error: Invalid domain format" if not validate_ip(ip, allow_empty=True): return "Error: Invalid IP address format" - if not (1 <= http_port <= 65535): + if not validate_port_int(http_port): return "Error: Port must be between 1 and 65535" # Use file locking for the entire pool allocation operation @@ -339,8 +338,7 @@ def register_domain_tools(mcp): for slot in range(1, MAX_SLOTS + 1): server = f"{backend}_{slot}" try: - haproxy_cmd(f"set server {backend}/{server} state maint") - haproxy_cmd(f"set server {backend}/{server} addr 0.0.0.0 port 0") + disable_server_slot(backend, server) except HaproxyError as e: logger.warning( "Failed to clear server %s/%s for domain %s: %s", diff --git a/haproxy_mcp/tools/health.py b/haproxy_mcp/tools/health.py index d34b7d7..3f1cbe1 100644 --- a/haproxy_mcp/tools/health.py +++ b/haproxy_mcp/tools/health.py @@ -12,14 +12,12 @@ from ..config import ( MAP_FILE, SERVERS_FILE, HAPROXY_CONTAINER, - StateField, - STATE_MIN_COLUMNS, ) from ..exceptions import HaproxyError from ..validation import validate_domain, validate_backend_name from ..haproxy_client import haproxy_cmd from ..file_ops import get_backend_and_prefix -from ..utils import parse_stat_csv +from ..utils import parse_stat_csv, parse_servers_state def register_health_tools(mcp): @@ -83,7 +81,7 @@ def register_health_tools(mcp): except subprocess.TimeoutExpired: result["components"]["container"] = {"status": "timeout"} result["status"] = "unhealthy" - except Exception as e: + except (OSError, subprocess.SubprocessError) as e: result["components"]["container"] = {"status": "error", "error": str(e)} # Check configuration files @@ -144,34 +142,34 @@ def register_health_tools(mcp): } # Parse server state for address info - for line in state_output.split("\n"): - parts = line.split() - if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend: - server_name = parts[StateField.SRV_NAME] - addr = parts[StateField.SRV_ADDR] - port = parts[StateField.SRV_PORT] + parsed_state = parse_servers_state(state_output) + backend_servers = parsed_state.get(backend, {}) - # Skip disabled servers (0.0.0.0) - if addr == "0.0.0.0": - continue + for server_name, srv_info in backend_servers.items(): + addr = srv_info["addr"] + port = srv_info["port"] - server_info: dict[str, Any] = { - "name": server_name, - "addr": f"{addr}:{port}", - "status": "unknown" - } + # Skip disabled servers (0.0.0.0) + if addr == "0.0.0.0": + continue - # Get status from stat output - if server_name in status_map: - server_info["status"] = status_map[server_name]["status"] - server_info["check_status"] = status_map[server_name]["check_status"] - server_info["weight"] = status_map[server_name]["weight"] + server_info: dict[str, Any] = { + "name": server_name, + "addr": f"{addr}:{port}", + "status": "unknown" + } - result["servers"].append(server_info) - result["total_count"] += 1 + # Get status from stat output + if server_name in status_map: + server_info["status"] = status_map[server_name]["status"] + server_info["check_status"] = status_map[server_name]["check_status"] + server_info["weight"] = status_map[server_name]["weight"] - if server_info["status"] == "UP": - result["healthy_count"] += 1 + result["servers"].append(server_info) + result["total_count"] += 1 + + if server_info["status"] == "UP": + result["healthy_count"] += 1 # Determine overall status if result["total_count"] == 0: diff --git a/haproxy_mcp/tools/servers.py b/haproxy_mcp/tools/servers.py index 5a73f54..1dabc45 100644 --- a/haproxy_mcp/tools/servers.py +++ b/haproxy_mcp/tools/servers.py @@ -10,13 +10,11 @@ from ..config import ( MAX_SLOTS, MAX_BULK_SERVERS, MAX_SERVERS_JSON_SIZE, - StateField, StatField, - STATE_MIN_COLUMNS, logger, ) from ..exceptions import HaproxyError -from ..validation import validate_domain, validate_ip, validate_backend_name +from ..validation import validate_domain, validate_ip, validate_backend_name, validate_port_int from ..haproxy_client import haproxy_cmd, haproxy_cmd_checked, haproxy_cmd_batch from ..file_ops import ( get_backend_and_prefix, @@ -24,6 +22,7 @@ from ..file_ops import ( add_server_to_config, remove_server_from_config, ) +from ..utils import parse_servers_state, disable_server_slot def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str: @@ -51,6 +50,390 @@ def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, return server +# ============================================================================= +# Implementation functions (module-level) +# ============================================================================= + + +def _haproxy_list_servers_impl(domain: str) -> str: + """Implementation of haproxy_list_servers.""" + if not validate_domain(domain): + return "Error: Invalid domain format" + + try: + backend, _ = get_backend_and_prefix(domain) + state = haproxy_cmd("show servers state") + parsed_state = parse_servers_state(state) + backend_servers = parsed_state.get(backend, {}) + + if not backend_servers: + return f"Backend {backend} not found" + + servers = [] + for server_name, srv_info in backend_servers.items(): + addr = srv_info["addr"] + port = srv_info["port"] + status = "active" if addr != "0.0.0.0" else "disabled" + servers.append(f"• {server_name}: {addr}:{port} ({status})") + + return f"Servers for {domain} ({backend}):\n" + "\n".join(servers) + except (HaproxyError, ValueError) as e: + return f"Error: {e}" + + +def _haproxy_add_server_impl(domain: str, slot: int, ip: str, http_port: int) -> str: + """Implementation of haproxy_add_server.""" + if not validate_domain(domain): + return "Error: Invalid domain format" + if not ip: + return "Error: IP address is required" + if not validate_ip(ip): + return "Error: Invalid IP address format" + if not validate_port_int(http_port): + return "Error: Port must be between 1 and 65535" + + try: + backend, server_prefix = get_backend_and_prefix(domain) + + # Auto-select slot if slot <= 0 + if slot <= 0: + state = haproxy_cmd("show servers state") + parsed_state = parse_servers_state(state) + backend_servers = parsed_state.get(backend, {}) + used_slots: set[int] = set() + for server_name, srv_info in backend_servers.items(): + if srv_info["addr"] != "0.0.0.0": + # Extract slot number from server name (e.g., pool_1_3 -> 3) + try: + used_slots.add(int(server_name.rsplit("_", 1)[1])) + except (ValueError, IndexError): + pass + for s in range(1, MAX_SLOTS + 1): + if s not in used_slots: + slot = s + break + else: + return f"Error: No available slots (all {MAX_SLOTS} slots in use)" + elif not (1 <= slot <= MAX_SLOTS): + return f"Error: Slot must be between 1 and {MAX_SLOTS}, or 0/-1 for auto-select" + + # Save to persistent config FIRST (disk-first pattern) + add_server_to_config(domain, slot, ip, http_port) + + try: + server = configure_server_slot(backend, server_prefix, slot, ip, http_port) + return f"Added to {domain} ({backend}) slot {slot}:\n{server} → {ip}:{http_port}" + except HaproxyError as e: + # Rollback config on HAProxy failure + remove_server_from_config(domain, slot) + return f"Error: {e}" + except (ValueError, IOError) as e: + return f"Error: {e}" + + +def _haproxy_add_servers_impl(domain: str, servers: str) -> str: + """Implementation of haproxy_add_servers.""" + if not validate_domain(domain): + return "Error: Invalid domain format" + + # Check JSON size before parsing + if len(servers) > MAX_SERVERS_JSON_SIZE: + return f"Error: servers JSON exceeds maximum size ({MAX_SERVERS_JSON_SIZE} bytes)" + + # Parse JSON array + try: + server_list = json.loads(servers) + except json.JSONDecodeError as e: + return f"Error: Invalid JSON - {e}" + + if not isinstance(server_list, list): + return "Error: servers must be a JSON array" + + if not server_list: + return "Error: servers array is empty" + + if len(server_list) > MAX_BULK_SERVERS: + return f"Error: Cannot add more than {MAX_BULK_SERVERS} servers at once" + + # Validate all servers first before adding any + validated_servers = [] + validation_errors = [] + + for i, srv in enumerate(server_list): + if not isinstance(srv, dict): + validation_errors.append(f"Server {i+1}: must be an object") + continue + + # Extract and validate slot + slot = srv.get("slot") + if slot is None: + validation_errors.append(f"Server {i+1}: missing 'slot' field") + continue + try: + slot = int(slot) + except (ValueError, TypeError): + validation_errors.append(f"Server {i+1}: slot must be an integer") + continue + if not (1 <= slot <= MAX_SLOTS): + validation_errors.append(f"Server {i+1}: slot must be between 1 and {MAX_SLOTS}") + continue + + # Extract and validate IP + ip = srv.get("ip") + if not ip: + validation_errors.append(f"Server {i+1}: missing 'ip' field") + continue + if not validate_ip(ip): + validation_errors.append(f"Server {i+1}: invalid IP address '{ip}'") + continue + + # Extract and validate port + http_port = srv.get("http_port", 80) + try: + http_port = int(http_port) + except (ValueError, TypeError): + validation_errors.append(f"Server {i+1}: http_port must be an integer") + continue + if not validate_port_int(http_port): + validation_errors.append(f"Server {i+1}: port must be between 1 and 65535") + continue + + validated_servers.append({"slot": slot, "ip": ip, "http_port": http_port}) + + # Return validation errors if any + if validation_errors: + return "Validation errors:\n" + "\n".join(f" • {e}" for e in validation_errors) + + # Check for duplicate slots + slots = [s["slot"] for s in validated_servers] + if len(slots) != len(set(slots)): + return "Error: Duplicate slot numbers in servers array" + + # Get backend info + try: + backend, server_prefix = get_backend_and_prefix(domain) + except ValueError as e: + return f"Error: {e}" + + # Save ALL servers to config FIRST (disk-first pattern) + for server_config in validated_servers: + slot = server_config["slot"] + ip = server_config["ip"] + http_port = server_config["http_port"] + add_server_to_config(domain, slot, ip, http_port) + + # Then update HAProxy + added = [] + errors = [] + failed_slots = [] + successfully_added_slots = [] + + try: + for server_config in validated_servers: + slot = server_config["slot"] + ip = server_config["ip"] + http_port = server_config["http_port"] + try: + configure_server_slot(backend, server_prefix, slot, ip, http_port) + successfully_added_slots.append(slot) + added.append(f"slot {slot}: {ip}:{http_port}") + except HaproxyError as e: + failed_slots.append(slot) + errors.append(f"slot {slot}: {e}") + except (IOError, OSError, ValueError) as e: + # Rollback only successfully added configs on unexpected error + logger.error("Unexpected error during bulk server add for %s: %s", domain, e) + for slot in successfully_added_slots: + try: + remove_server_from_config(domain, slot) + except (IOError, OSError, ValueError) as rollback_error: + logger.error( + "Failed to rollback server config for %s slot %d: %s", + domain, slot, rollback_error + ) + # Also rollback configs that weren't yet processed + for server_config in validated_servers: + slot = server_config["slot"] + if slot not in successfully_added_slots: + try: + remove_server_from_config(domain, slot) + except (IOError, OSError, ValueError) as rollback_error: + logger.error( + "Failed to rollback server config for %s slot %d: %s", + domain, slot, rollback_error + ) + return f"Error: {e}" + + # Rollback failed slots from config + for slot in failed_slots: + try: + remove_server_from_config(domain, slot) + except (IOError, OSError, ValueError) as rollback_error: + logger.error( + "Failed to rollback server config for %s slot %d: %s", + domain, slot, rollback_error + ) + + # Build result message + result_parts = [] + if added: + result_parts.append(f"Added {len(added)} servers to {domain} ({backend}):") + result_parts.extend(f" • {s}" for s in added) + if errors: + result_parts.append(f"Failed to add {len(errors)} servers:") + result_parts.extend(f" • {e}" for e in errors) + + return "\n".join(result_parts) if result_parts else "No servers added" + + +def _haproxy_remove_server_impl(domain: str, slot: int) -> str: + """Implementation of haproxy_remove_server.""" + if not validate_domain(domain): + return "Error: Invalid domain format" + if not (1 <= slot <= MAX_SLOTS): + return f"Error: Slot must be between 1 and {MAX_SLOTS}" + + try: + backend, server_prefix = get_backend_and_prefix(domain) + + # Get current server info for potential rollback + config = load_servers_config() + old_config = config.get(domain, {}).get(str(slot), {}) + + # Remove from persistent config FIRST (disk-first pattern) + remove_server_from_config(domain, slot) + + try: + # HTTP only - single server per slot + server = f"{server_prefix}_{slot}" + disable_server_slot(backend, server) + return f"Removed server at slot {slot} from {domain} ({backend})" + except HaproxyError as e: + # Rollback: re-add config if HAProxy command failed + if old_config: + add_server_to_config(domain, slot, old_config.get("ip", ""), old_config.get("http_port", 80)) + return f"Error: {e}" + except (ValueError, IOError) as e: + return f"Error: {e}" + + +def _haproxy_set_domain_state_impl(domain: str, state: str) -> str: + """Implementation of haproxy_set_domain_state.""" + if not validate_domain(domain): + return "Error: Invalid domain format" + if state not in ["ready", "drain", "maint"]: + return "Error: State must be 'ready', 'drain', or 'maint'" + + try: + backend, _ = get_backend_and_prefix(domain) + except ValueError as e: + return f"Error: {e}" + + # Get active servers for this domain + try: + servers_state = haproxy_cmd("show servers state") + except HaproxyError as e: + return f"Error: {e}" + + parsed_state = parse_servers_state(servers_state) + backend_servers = parsed_state.get(backend, {}) + + changed = [] + errors = [] + + for server_name, srv_info in backend_servers.items(): + # Only change state for configured servers (not 0.0.0.0) + if srv_info["addr"] != "0.0.0.0": + try: + haproxy_cmd_checked(f"set server {backend}/{server_name} state {state}") + changed.append(server_name) + except HaproxyError as e: + errors.append(f"{server_name}: {e}") + + if not changed and not errors: + return f"No active servers found for {domain}" + + result = f"Set {len(changed)} servers to '{state}' for {domain}" + if changed: + result += ":\n" + "\n".join(f" • {s}" for s in changed) + if errors: + result += f"\n\nErrors ({len(errors)}):\n" + "\n".join(f" • {e}" for e in errors) + + return result + + +def _haproxy_wait_drain_impl(domain: str, timeout: int) -> str: + """Implementation of haproxy_wait_drain.""" + if not validate_domain(domain): + return "Error: Invalid domain format" + if not (1 <= timeout <= 300): + return "Error: Timeout must be between 1 and 300 seconds" + + try: + backend, _ = get_backend_and_prefix(domain) + except ValueError as e: + return f"Error: {e}" + + start_time = time.time() + while time.time() - start_time < timeout: + try: + stats = haproxy_cmd("show stat") + total_connections = 0 + for line in stats.split("\n"): + parts = line.split(",") + if len(parts) > StatField.SCUR and parts[0] == backend and parts[1] not in ["FRONTEND", "BACKEND", ""]: + try: + scur = int(parts[StatField.SCUR]) if parts[StatField.SCUR] else 0 + total_connections += scur + except ValueError: + pass + + if total_connections == 0: + elapsed = int(time.time() - start_time) + return f"All connections drained for {domain} ({elapsed}s)" + + time.sleep(1) + except HaproxyError as e: + return f"Error checking connections: {e}" + + return f"Timeout: Connections still active after {timeout}s" + + +def _haproxy_set_server_state_impl(backend: str, server: str, state: str) -> str: + """Implementation of haproxy_set_server_state.""" + if not validate_backend_name(backend): + return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)" + if not validate_backend_name(server): + return "Error: Invalid server name (use alphanumeric, underscore, hyphen only)" + if state not in ["ready", "drain", "maint"]: + return "Error: state must be 'ready', 'drain', or 'maint'" + try: + haproxy_cmd_checked(f"set server {backend}/{server} state {state}") + return f"Server {backend}/{server} set to {state}" + except HaproxyError as e: + return f"Error: {e}" + + +def _haproxy_set_server_weight_impl(backend: str, server: str, weight: int) -> str: + """Implementation of haproxy_set_server_weight.""" + if not validate_backend_name(backend): + return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)" + if not validate_backend_name(server): + return "Error: Invalid server name (use alphanumeric, underscore, hyphen only)" + if not (0 <= weight <= 256): + return "Error: weight must be between 0 and 256" + try: + haproxy_cmd_checked(f"set server {backend}/{server} weight {weight}") + return f"Server {backend}/{server} weight set to {weight}" + except HaproxyError as e: + return f"Error: {e}" + + +# ============================================================================= +# MCP Tool Registration +# ============================================================================= + + def register_server_tools(mcp): """Register server management tools with MCP server.""" @@ -65,29 +448,7 @@ def register_server_tools(mcp): # Output: pool_1_1: 10.0.0.1:8080 (UP) # pool_1_2: 10.0.0.2:8080 (UP) """ - if not validate_domain(domain): - return "Error: Invalid domain format" - - try: - backend, _ = get_backend_and_prefix(domain) - servers = [] - state = haproxy_cmd("show servers state") - - for line in state.split("\n"): - parts = line.split() - if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend: - addr = parts[StateField.SRV_ADDR] - status = "active" if addr != "0.0.0.0" else "disabled" - servers.append( - f"• {parts[StateField.SRV_NAME]}: {addr}:{parts[StateField.SRV_PORT]} ({status})" - ) - - if not servers: - return f"Backend {backend} not found" - - return f"Servers for {domain} ({backend}):\n" + "\n".join(servers) - except (HaproxyError, ValueError) as e: - return f"Error: {e}" + return _haproxy_list_servers_impl(domain) @mcp.tool() def haproxy_add_server( @@ -103,53 +464,7 @@ def register_server_tools(mcp): Example: haproxy_add_server("api.example.com", slot=1, ip="10.0.0.1", http_port=8080) """ - if not validate_domain(domain): - return "Error: Invalid domain format" - if not ip: - return "Error: IP address is required" - if not validate_ip(ip): - return "Error: Invalid IP address format" - if not (1 <= http_port <= 65535): - return "Error: Port must be between 1 and 65535" - - try: - backend, server_prefix = get_backend_and_prefix(domain) - - # Auto-select slot if slot <= 0 - if slot <= 0: - state = haproxy_cmd("show servers state") - used_slots: set[int] = set() - for line in state.split("\n"): - parts = line.split() - if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend: - if parts[StateField.SRV_ADDR] != "0.0.0.0": - # Extract slot number from server name (e.g., pool_1_3 -> 3) - server_name = parts[StateField.SRV_NAME] - try: - used_slots.add(int(server_name.rsplit("_", 1)[1])) - except (ValueError, IndexError): - pass - for s in range(1, MAX_SLOTS + 1): - if s not in used_slots: - slot = s - break - else: - return f"Error: No available slots (all {MAX_SLOTS} slots in use)" - elif not (1 <= slot <= MAX_SLOTS): - return f"Error: Slot must be between 1 and {MAX_SLOTS}, or 0/-1 for auto-select" - - # Save to persistent config FIRST (disk-first pattern) - add_server_to_config(domain, slot, ip, http_port) - - try: - server = configure_server_slot(backend, server_prefix, slot, ip, http_port) - return f"Added to {domain} ({backend}) slot {slot}:\n{server} → {ip}:{http_port}" - except HaproxyError as e: - # Rollback config on HAProxy failure - remove_server_from_config(domain, slot) - return f"Error: {e}" - except (ValueError, IOError) as e: - return f"Error: {e}" + return _haproxy_add_server_impl(domain, slot, ip, http_port) @mcp.tool() def haproxy_add_servers( @@ -160,156 +475,7 @@ def register_server_tools(mcp): Example: haproxy_add_servers("api.example.com", '[{"slot":1,"ip":"10.0.0.1"},{"slot":2,"ip":"10.0.0.2"}]') """ - if not validate_domain(domain): - return "Error: Invalid domain format" - - # Check JSON size before parsing - if len(servers) > MAX_SERVERS_JSON_SIZE: - return f"Error: servers JSON exceeds maximum size ({MAX_SERVERS_JSON_SIZE} bytes)" - - # Parse JSON array - try: - server_list = json.loads(servers) - except json.JSONDecodeError as e: - return f"Error: Invalid JSON - {e}" - - if not isinstance(server_list, list): - return "Error: servers must be a JSON array" - - if not server_list: - return "Error: servers array is empty" - - if len(server_list) > MAX_BULK_SERVERS: - return f"Error: Cannot add more than {MAX_BULK_SERVERS} servers at once" - - # Validate all servers first before adding any - validated_servers = [] - validation_errors = [] - - for i, srv in enumerate(server_list): - if not isinstance(srv, dict): - validation_errors.append(f"Server {i+1}: must be an object") - continue - - # Extract and validate slot - slot = srv.get("slot") - if slot is None: - validation_errors.append(f"Server {i+1}: missing 'slot' field") - continue - try: - slot = int(slot) - except (ValueError, TypeError): - validation_errors.append(f"Server {i+1}: slot must be an integer") - continue - if not (1 <= slot <= MAX_SLOTS): - validation_errors.append(f"Server {i+1}: slot must be between 1 and {MAX_SLOTS}") - continue - - # Extract and validate IP - ip = srv.get("ip") - if not ip: - validation_errors.append(f"Server {i+1}: missing 'ip' field") - continue - if not validate_ip(ip): - validation_errors.append(f"Server {i+1}: invalid IP address '{ip}'") - continue - - # Extract and validate port - http_port = srv.get("http_port", 80) - try: - http_port = int(http_port) - except (ValueError, TypeError): - validation_errors.append(f"Server {i+1}: http_port must be an integer") - continue - if not (1 <= http_port <= 65535): - validation_errors.append(f"Server {i+1}: port must be between 1 and 65535") - continue - - validated_servers.append({"slot": slot, "ip": ip, "http_port": http_port}) - - # Return validation errors if any - if validation_errors: - return "Validation errors:\n" + "\n".join(f" • {e}" for e in validation_errors) - - # Check for duplicate slots - slots = [s["slot"] for s in validated_servers] - if len(slots) != len(set(slots)): - return "Error: Duplicate slot numbers in servers array" - - # Get backend info - try: - backend, server_prefix = get_backend_and_prefix(domain) - except ValueError as e: - return f"Error: {e}" - - # Save ALL servers to config FIRST (disk-first pattern) - for server_config in validated_servers: - slot = server_config["slot"] - ip = server_config["ip"] - http_port = server_config["http_port"] - add_server_to_config(domain, slot, ip, http_port) - - # Then update HAProxy - added = [] - errors = [] - failed_slots = [] - successfully_added_slots = [] - - try: - for server_config in validated_servers: - slot = server_config["slot"] - ip = server_config["ip"] - http_port = server_config["http_port"] - try: - configure_server_slot(backend, server_prefix, slot, ip, http_port) - successfully_added_slots.append(slot) - added.append(f"slot {slot}: {ip}:{http_port}") - except HaproxyError as e: - failed_slots.append(slot) - errors.append(f"slot {slot}: {e}") - except Exception as e: - # Rollback only successfully added configs on unexpected error - for slot in successfully_added_slots: - try: - remove_server_from_config(domain, slot) - except Exception as rollback_error: - logger.error( - "Failed to rollback server config for %s slot %d: %s", - domain, slot, rollback_error - ) - # Also rollback configs that weren't yet processed - for server_config in validated_servers: - slot = server_config["slot"] - if slot not in successfully_added_slots: - try: - remove_server_from_config(domain, slot) - except Exception as rollback_error: - logger.error( - "Failed to rollback server config for %s slot %d: %s", - domain, slot, rollback_error - ) - return f"Error: {e}" - - # Rollback failed slots from config - for slot in failed_slots: - try: - remove_server_from_config(domain, slot) - except Exception as rollback_error: - logger.error( - "Failed to rollback server config for %s slot %d: %s", - domain, slot, rollback_error - ) - - # Build result message - result_parts = [] - if added: - result_parts.append(f"Added {len(added)} servers to {domain} ({backend}):") - result_parts.extend(f" • {s}" for s in added) - if errors: - result_parts.append(f"Failed to add {len(errors)} servers:") - result_parts.extend(f" • {e}" for e in errors) - - return "\n".join(result_parts) if result_parts else "No servers added" + return _haproxy_add_servers_impl(domain, servers) @mcp.tool() def haproxy_remove_server( @@ -320,37 +486,7 @@ def register_server_tools(mcp): Example: haproxy_remove_server("api.example.com", slot=2) """ - if not validate_domain(domain): - return "Error: Invalid domain format" - if not (1 <= slot <= MAX_SLOTS): - return f"Error: Slot must be between 1 and {MAX_SLOTS}" - - try: - backend, server_prefix = get_backend_and_prefix(domain) - - # Get current server info for potential rollback - config = load_servers_config() - old_config = config.get(domain, {}).get(str(slot), {}) - - # Remove from persistent config FIRST (disk-first pattern) - remove_server_from_config(domain, slot) - - try: - # HTTP only - single server per slot - server = f"{server_prefix}_{slot}" - # Batch both commands in single TCP connection - haproxy_cmd_batch([ - f"set server {backend}/{server} state maint", - f"set server {backend}/{server} addr 0.0.0.0 port 0" - ]) - return f"Removed server at slot {slot} from {domain} ({backend})" - except HaproxyError as e: - # Rollback: re-add config if HAProxy command failed - if old_config: - add_server_to_config(domain, slot, old_config.get("ip", ""), old_config.get("http_port", 80)) - return f"Error: {e}" - except (ValueError, IOError) as e: - return f"Error: {e}" + return _haproxy_remove_server_impl(domain, slot) @mcp.tool() def haproxy_set_domain_state( @@ -368,48 +504,7 @@ def register_server_tools(mcp): # Re-enable all servers after deployment haproxy_set_domain_state("api.example.com", "ready") """ - if not validate_domain(domain): - return "Error: Invalid domain format" - if state not in ["ready", "drain", "maint"]: - return "Error: State must be 'ready', 'drain', or 'maint'" - - try: - backend, _ = get_backend_and_prefix(domain) - except ValueError as e: - return f"Error: {e}" - - # Get active servers for this domain - try: - servers_state = haproxy_cmd("show servers state") - except HaproxyError as e: - return f"Error: {e}" - - changed = [] - errors = [] - - for line in servers_state.split("\n"): - parts = line.split() - if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend: - server_name = parts[StateField.SRV_NAME] - addr = parts[StateField.SRV_ADDR] - # Only change state for configured servers (not 0.0.0.0) - if addr != "0.0.0.0": - try: - haproxy_cmd_checked(f"set server {backend}/{server_name} state {state}") - changed.append(server_name) - except HaproxyError as e: - errors.append(f"{server_name}: {e}") - - if not changed and not errors: - return f"No active servers found for {domain}" - - result = f"Set {len(changed)} servers to '{state}' for {domain}" - if changed: - result += ":\n" + "\n".join(f" • {s}" for s in changed) - if errors: - result += f"\n\nErrors ({len(errors)}):\n" + "\n".join(f" • {e}" for e in errors) - - return result + return _haproxy_set_domain_state_impl(domain, state) @mcp.tool() def haproxy_wait_drain( @@ -422,39 +517,7 @@ def register_server_tools(mcp): Example: haproxy_wait_drain("api.example.com", timeout=60) """ - if not validate_domain(domain): - return "Error: Invalid domain format" - if not (1 <= timeout <= 300): - return "Error: Timeout must be between 1 and 300 seconds" - - try: - backend, _ = get_backend_and_prefix(domain) - except ValueError as e: - return f"Error: {e}" - - start_time = time.time() - while time.time() - start_time < timeout: - try: - stats = haproxy_cmd("show stat") - total_connections = 0 - for line in stats.split("\n"): - parts = line.split(",") - if len(parts) > StatField.SCUR and parts[0] == backend and parts[1] not in ["FRONTEND", "BACKEND", ""]: - try: - scur = int(parts[StatField.SCUR]) if parts[StatField.SCUR] else 0 - total_connections += scur - except ValueError: - pass - - if total_connections == 0: - elapsed = int(time.time() - start_time) - return f"All connections drained for {domain} ({elapsed}s)" - - time.sleep(1) - except HaproxyError as e: - return f"Error checking connections: {e}" - - return f"Timeout: Connections still active after {timeout}s" + return _haproxy_wait_drain_impl(domain, timeout) @mcp.tool() def haproxy_set_server_state( @@ -466,17 +529,7 @@ def register_server_tools(mcp): Example: haproxy_set_server_state("pool_1", "pool_1_2", "maint") """ - if not validate_backend_name(backend): - return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)" - if not validate_backend_name(server): - return "Error: Invalid server name (use alphanumeric, underscore, hyphen only)" - if state not in ["ready", "drain", "maint"]: - return "Error: state must be 'ready', 'drain', or 'maint'" - try: - haproxy_cmd_checked(f"set server {backend}/{server} state {state}") - return f"Server {backend}/{server} set to {state}" - except HaproxyError as e: - return f"Error: {e}" + return _haproxy_set_server_state_impl(backend, server, state) @mcp.tool() def haproxy_set_server_weight( @@ -488,14 +541,4 @@ def register_server_tools(mcp): Example: haproxy_set_server_weight("pool_1", "pool_1_1", weight=2) """ - if not validate_backend_name(backend): - return "Error: Invalid backend name (use alphanumeric, underscore, hyphen only)" - if not validate_backend_name(server): - return "Error: Invalid server name (use alphanumeric, underscore, hyphen only)" - if not (0 <= weight <= 256): - return "Error: weight must be between 0 and 256" - try: - haproxy_cmd_checked(f"set server {backend}/{server} weight {weight}") - return f"Server {backend}/{server} weight set to {weight}" - except HaproxyError as e: - return f"Error: {e}" + return _haproxy_set_server_weight_impl(backend, server, weight) diff --git a/haproxy_mcp/utils.py b/haproxy_mcp/utils.py index 96ada02..dfb5491 100644 --- a/haproxy_mcp/utils.py +++ b/haproxy_mcp/utils.py @@ -1,7 +1,58 @@ """Utility functions for HAProxy MCP Server.""" from typing import Dict, Generator -from .config import StatField +from .config import StatField, StateField, STATE_MIN_COLUMNS +from .haproxy_client import haproxy_cmd_batch + + +def parse_servers_state(state_output: str) -> dict[str, dict[str, dict[str, str]]]: + """Parse 'show servers state' output. + + Args: + state_output: Raw output from HAProxy 'show servers state' command + + Returns: + Nested dict: {backend: {server: {addr: str, port: str, state: str}}} + + Example: + state = haproxy_cmd("show servers state") + parsed = parse_servers_state(state) + # parsed["pool_1"]["pool_1_1"] == {"addr": "10.0.0.1", "port": "8080", "state": "2"} + """ + result: dict[str, dict[str, dict[str, str]]] = {} + for line in state_output.split("\n"): + parts = line.split() + if len(parts) >= STATE_MIN_COLUMNS: + backend = parts[StateField.BE_NAME] + server = parts[StateField.SRV_NAME] + addr = parts[StateField.SRV_ADDR] + port = parts[StateField.SRV_PORT] + state = parts[StateField.SRV_OP_STATE] if len(parts) > StateField.SRV_OP_STATE else "" + + if backend not in result: + result[backend] = {} + result[backend][server] = { + "addr": addr, + "port": port, + "state": state, + } + return result + + +def disable_server_slot(backend: str, server: str) -> None: + """Disable a server slot (set to maint and clear address). + + Args: + backend: Backend name (e.g., 'pool_1') + server: Server name (e.g., 'pool_1_1') + + Raises: + HaproxyError: If HAProxy command fails + """ + haproxy_cmd_batch([ + f"set server {backend}/{server} state maint", + f"set server {backend}/{server} addr 0.0.0.0 port 0" + ]) def parse_stat_csv(stat_output: str) -> Generator[Dict[str, str], None, None]: diff --git a/haproxy_mcp/validation.py b/haproxy_mcp/validation.py index 359f4f7..9749a6b 100644 --- a/haproxy_mcp/validation.py +++ b/haproxy_mcp/validation.py @@ -52,6 +52,18 @@ def validate_port(port: str) -> bool: return 1 <= port_num <= 65535 +def validate_port_int(port: int) -> bool: + """Validate port number as integer is in valid range. + + Args: + port: Port number as integer + + Returns: + True if port is valid (1-65535), False otherwise + """ + return isinstance(port, int) and 1 <= port <= 65535 + + def validate_backend_name(name: str) -> bool: """Validate backend or server name to prevent command injection. diff --git a/tests/unit/tools/test_certificates.py b/tests/unit/tools/test_certificates.py index 8c2272e..22d89a4 100644 --- a/tests/unit/tools/test_certificates.py +++ b/tests/unit/tools/test_certificates.py @@ -1078,9 +1078,10 @@ class TestLoadCertToHaproxyError: pem_file = certs_dir / "example.com.pem" pem_file.write_text("cert content") - # Mock haproxy_cmd to raise exception + # Mock haproxy_cmd to raise HaproxyError + from haproxy_mcp.exceptions import HaproxyError with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): - with patch("haproxy_mcp.tools.certificates.haproxy_cmd", side_effect=Exception("Connection failed")): + with patch("haproxy_mcp.tools.certificates.haproxy_cmd", side_effect=HaproxyError("Connection failed")): from haproxy_mcp.tools.certificates import load_cert_to_haproxy success, msg = load_cert_to_haproxy("example.com") diff --git a/tests/unit/tools/test_configuration.py b/tests/unit/tools/test_configuration.py index f2b4dc5..e9997e7 100644 --- a/tests/unit/tools/test_configuration.py +++ b/tests/unit/tools/test_configuration.py @@ -547,7 +547,7 @@ class TestStartupRestoreFailures: with patch("socket.socket", return_value=mock_sock): with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", return_value=0): - with patch("haproxy_mcp.tools.certificates.restore_certificates", side_effect=Exception("Certificate error")): + with patch("haproxy_mcp.tools.certificates.restore_certificates", side_effect=IOError("Certificate error")): with caplog.at_level(logging.WARNING, logger="haproxy_mcp"): from haproxy_mcp.tools.configuration import startup_restore @@ -598,7 +598,7 @@ class TestHaproxyReloadFailures: with patch("socket.socket", return_value=mock_sock): with patch("haproxy_mcp.haproxy_client.reload_haproxy", return_value=(True, "Reloaded")): - with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=Exception("Restore failed")): + with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=OSError("Restore failed")): with patch("time.sleep", return_value=None): from haproxy_mcp.tools.configuration import register_config_tools mcp = MagicMock() diff --git a/tests/unit/tools/test_servers.py b/tests/unit/tools/test_servers.py index f370122..26887d3 100644 --- a/tests/unit/tools/test_servers.py +++ b/tests/unit/tools/test_servers.py @@ -756,8 +756,8 @@ class TestHaproxyAddServersRollback: def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port): if slot == 2: - # Simulate unexpected error (not HaproxyError) - raise RuntimeError("Unexpected system error") + # Simulate unexpected error (IOError is caught by the exception handler) + raise IOError("Unexpected system error") configured_slots.append(slot) return f"{server_prefix}_{slot}" @@ -807,7 +807,7 @@ class TestHaproxyAddServersRollback: def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port): if slot == 2: - raise RuntimeError("Unexpected error") + raise OSError("Unexpected error") return f"{server_prefix}_{slot}" def mock_remove_server_from_config(domain, slot):