diff --git a/mcp/server.py b/mcp/server.py index fe129af..e1f96d4 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -24,6 +24,7 @@ import re import json import logging import os +import select import tempfile import time import fcntl @@ -60,6 +61,9 @@ SERVERS_FILE: str = os.getenv("HAPROXY_SERVERS_FILE", "/opt/haproxy/conf/servers POOL_COUNT: int = int(os.getenv("HAPROXY_POOL_COUNT", "100")) MAX_SLOTS: int = int(os.getenv("HAPROXY_MAX_SLOTS", "10")) +# Container configuration +HAPROXY_CONTAINER: str = os.getenv("HAPROXY_CONTAINER", "haproxy") + # Validation patterns - compiled once for performance DOMAIN_PATTERN = re.compile( r'^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?' @@ -77,6 +81,7 @@ SUBPROCESS_TIMEOUT = 30 # seconds STARTUP_RETRY_COUNT = 10 # HAProxy ready check retries STATE_MIN_COLUMNS = 19 # Minimum columns in HAProxy server state output SOCKET_TIMEOUT = 5 # seconds for HAProxy socket connection +SOCKET_RECV_TIMEOUT = 30 # seconds for HAProxy socket recv loop class HaproxyError(Exception): @@ -119,7 +124,7 @@ def haproxy_cmd(command: str) -> str: The response from HAProxy Raises: - HaproxyError: If connection fails or response exceeds size limit + HaproxyError: If connection fails, times out, or response exceeds size limit """ try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -127,14 +132,30 @@ def haproxy_cmd(command: str) -> str: s.connect(HAPROXY_SOCKET) s.sendall(f"{command}\n".encode()) s.shutdown(socket.SHUT_WR) + + # Set socket to non-blocking for select-based recv loop + s.setblocking(False) response = b"" + start_time = time.time() + while True: - data = s.recv(8192) - if not data: - break - response += data - if len(response) > MAX_RESPONSE_SIZE: - raise HaproxyError(f"Response exceeded {MAX_RESPONSE_SIZE} bytes limit") + # Check for overall timeout + elapsed = time.time() - start_time + if elapsed >= SOCKET_RECV_TIMEOUT: + raise HaproxyError(f"Response timeout after {SOCKET_RECV_TIMEOUT} seconds") + + # Wait for data with timeout (remaining time) + remaining = SOCKET_RECV_TIMEOUT - elapsed + ready, _, _ = select.select([s], [], [], min(remaining, 1.0)) + + if ready: + data = s.recv(8192) + if not data: + break + response += data + if len(response) > MAX_RESPONSE_SIZE: + raise HaproxyError(f"Response exceeded {MAX_RESPONSE_SIZE} bytes limit") + return response.decode().strip() except socket.timeout: raise HaproxyError("Connection timeout") @@ -156,14 +177,14 @@ def reload_haproxy() -> Tuple[bool, str]: """ try: validate = subprocess.run( - ["podman", "exec", "haproxy", "haproxy", "-c", "-f", "/usr/local/etc/haproxy/haproxy.cfg"], + ["podman", "exec", HAPROXY_CONTAINER, "haproxy", "-c", "-f", "/usr/local/etc/haproxy/haproxy.cfg"], capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT ) if validate.returncode != 0: return False, f"Config validation failed:\n{validate.stderr}" result = subprocess.run( - ["podman", "kill", "--signal", "USR2", "haproxy"], + ["podman", "kill", "--signal", "USR2", HAPROXY_CONTAINER], capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT ) if result.returncode != 0: @@ -595,11 +616,16 @@ def restore_servers_from_config() -> int: @mcp.tool() -def haproxy_list_domains() -> str: +def haproxy_list_domains(include_wildcards: bool = False) -> str: """List all configured domains with their backend servers. Shows all domains mapped in HAProxy with their pool backend and configured servers. + Args: + include_wildcards: If True, also show wildcard domain mappings (entries starting + with '.', e.g., '.example.com' which matches '*.example.com'). + Default is False to show only explicit domain mappings. + Returns: List of domains in format: domain -> pool_N (pool): server=ip:port @@ -607,6 +633,10 @@ def haproxy_list_domains() -> str: # Output: # • api.example.com -> pool_1 (pool): pool_1_1=10.0.0.1:8080, pool_1_2=10.0.0.2:8080 # • web.example.com -> pool_2 (pool): pool_2_1=10.0.0.3:80 + + # With include_wildcards=True: + # • api.example.com -> pool_1 (pool): pool_1_1=10.0.0.1:8080 + # • .api.example.com -> pool_1 (wildcard): pool_1_1=10.0.0.1:8080 """ try: domains = [] @@ -624,16 +654,22 @@ def haproxy_list_domains() -> str: f"{parts[StateField.SRV_NAME]}={parts[StateField.SRV_ADDR]}:{parts[StateField.SRV_PORT]}" ) - # Read from domains.map (skip wildcard entries starting with .) + # Read from domains.map seen_domains: Set[str] = set() for domain, backend in get_map_contents(): - if domain.startswith("."): + # Skip wildcard entries unless explicitly requested + if domain.startswith(".") and not include_wildcards: continue if domain in seen_domains: continue seen_domains.add(domain) servers = server_map.get(backend, ["(none)"]) - backend_type = "pool" if backend.startswith("pool_") else "static" + if domain.startswith("."): + backend_type = "wildcard" + elif backend.startswith("pool_"): + backend_type = "pool" + else: + backend_type = "static" domains.append(f"• {domain} -> {backend} ({backend_type}): {', '.join(servers)}") return "\n".join(domains) if domains else "No domains configured" @@ -755,8 +791,12 @@ def haproxy_remove_domain(domain: str) -> str: try: haproxy_cmd(f"set server {backend}/{server} state maint") haproxy_cmd(f"set server {backend}/{server} addr 0.0.0.0 port 0") - except HaproxyError: - pass # Ignore errors for individual servers + except HaproxyError as e: + logger.warning( + "Failed to clear server %s/%s for domain %s: %s", + backend, server, domain, e + ) + # Continue with remaining cleanup return f"Domain {domain} removed from {backend}" @@ -863,6 +903,138 @@ def haproxy_add_server(domain: str, slot: int, ip: str, http_port: int = 80) -> return f"Error: {e}" +@mcp.tool() +def haproxy_add_servers(domain: str, servers: str) -> str: + """Add multiple servers to a domain's backend at once. + + More efficient than calling haproxy_add_server multiple times. + All servers are validated before any are added. + + Args: + domain: The domain name to add servers to + servers: JSON array of server configs. Each object can have: + - slot (required): Server slot number (1-10) + - ip (required): IP address of the server + - http_port (optional): HTTP port (default: 80) + Example: '[{"slot":1,"ip":"10.0.0.1","http_port":80},{"slot":2,"ip":"10.0.0.2"}]' + + Returns: + Summary of added servers or errors for each failed server + + Example: + haproxy_add_servers("api.example.com", '[{"slot":1,"ip":"10.0.0.1"},{"slot":2,"ip":"10.0.0.2"}]') + # Output: Added 2 servers to api.example.com (pool_1): + # • slot 1: 10.0.0.1:80 + # • slot 2: 10.0.0.2:80 + """ + if not validate_domain(domain): + return "Error: Invalid domain format" + + # Parse JSON array + try: + server_list = json.loads(servers) + except json.JSONDecodeError as e: + return f"Error: Invalid JSON - {e}" + + if not isinstance(server_list, list): + return "Error: servers must be a JSON array" + + if not server_list: + return "Error: servers array is empty" + + # Validate all servers first before adding any + validated_servers = [] + validation_errors = [] + + for i, srv in enumerate(server_list): + if not isinstance(srv, dict): + validation_errors.append(f"Server {i+1}: must be an object") + continue + + # Extract and validate slot + slot = srv.get("slot") + if slot is None: + validation_errors.append(f"Server {i+1}: missing 'slot' field") + continue + try: + slot = int(slot) + except (ValueError, TypeError): + validation_errors.append(f"Server {i+1}: slot must be an integer") + continue + if not (1 <= slot <= MAX_SLOTS): + validation_errors.append(f"Server {i+1}: slot must be between 1 and {MAX_SLOTS}") + continue + + # Extract and validate IP + ip = srv.get("ip") + if not ip: + validation_errors.append(f"Server {i+1}: missing 'ip' field") + continue + if not validate_ip(ip): + validation_errors.append(f"Server {i+1}: invalid IP address '{ip}'") + continue + + # Extract and validate port + http_port = srv.get("http_port", 80) + try: + http_port = int(http_port) + except (ValueError, TypeError): + validation_errors.append(f"Server {i+1}: http_port must be an integer") + continue + if not (1 <= http_port <= 65535): + validation_errors.append(f"Server {i+1}: port must be between 1 and 65535") + continue + + validated_servers.append({"slot": slot, "ip": ip, "http_port": http_port}) + + # Return validation errors if any + if validation_errors: + return "Validation errors:\n" + "\n".join(f" • {e}" for e in validation_errors) + + # Check for duplicate slots + slots = [s["slot"] for s in validated_servers] + if len(slots) != len(set(slots)): + return "Error: Duplicate slot numbers in servers array" + + # Get backend info + try: + backend, server_prefix = get_backend_and_prefix(domain) + except ValueError as e: + return f"Error: {e}" + + # Add all servers + added = [] + errors = [] + + for srv in validated_servers: + slot = srv["slot"] + ip = srv["ip"] + http_port = srv["http_port"] + + try: + for suffix, port in get_server_suffixes(http_port): + server = f"{server_prefix}{suffix}_{slot}" + haproxy_cmd(f"set server {backend}/{server} addr {ip} port {port}") + haproxy_cmd(f"set server {backend}/{server} state ready") + + # Save to persistent config + add_server_to_config(domain, slot, ip, http_port) + added.append(f"slot {slot}: {ip}:{http_port}") + except (HaproxyError, IOError) as e: + errors.append(f"slot {slot}: {e}") + + # Build result message + result_parts = [] + if added: + result_parts.append(f"Added {len(added)} servers to {domain} ({backend}):") + result_parts.extend(f" • {s}" for s in added) + if errors: + result_parts.append(f"Failed to add {len(errors)} servers:") + result_parts.extend(f" • {e}" for e in errors) + + return "\n".join(result_parts) if result_parts else "No servers added" + + @mcp.tool() def haproxy_remove_server(domain: str, slot: int) -> str: """Remove a server from a domain's backend at specified slot. @@ -1293,7 +1465,7 @@ def haproxy_check_config() -> str: """ try: result = subprocess.run( - ["podman", "exec", "haproxy", "haproxy", "-c", "-f", "/usr/local/etc/haproxy/haproxy.cfg"], + ["podman", "exec", HAPROXY_CONTAINER, "haproxy", "-c", "-f", "/usr/local/etc/haproxy/haproxy.cfg"], capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT ) if result.returncode == 0: