From 913ba0fdcaea0ac95f0a4a54667a61a0295f27a2 Mon Sep 17 00:00:00 2001 From: kaffa Date: Sun, 1 Feb 2026 14:15:00 +0000 Subject: [PATCH] fix: Final round of improvements HIGH PRIORITY: 1. Pool allocation race condition - Add file locking around entire pool allocation in haproxy_add_domain - Prevents concurrent calls from getting same pool 2. haproxy_remove_server - disk-first pattern - Remove from config FIRST, then update HAProxy - Rollback config on HAProxy failure 3. Wildcard domain prefix validation - Reject domains starting with '.' - Prevents double-prefix like '..domain.com' MEDIUM PRIORITY: 4. Variable shadowing fix - Rename state_output to servers_state in haproxy_set_domain_state 5. JSON size limit - Add MAX_SERVERS_JSON_SIZE = 10000 limit for haproxy_add_servers 6. Remove get_server_suffixes - Delete unused abstraction layer - Inline logic in restore_servers_from_config and haproxy_add_domain Co-Authored-By: Claude Opus 4.5 --- mcp/server.py | 191 ++++++++++++++++++++++++++------------------------ 1 file changed, 101 insertions(+), 90 deletions(-) diff --git a/mcp/server.py b/mcp/server.py index c50bc52..25ab6c6 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -82,6 +82,7 @@ STATE_MIN_COLUMNS = 19 # Minimum columns in HAProxy server state output SOCKET_TIMEOUT = 5 # seconds for HAProxy socket connection SOCKET_RECV_TIMEOUT = 30 # seconds for HAProxy socket recv loop MAX_BULK_SERVERS = 10 # Max servers per bulk add call +MAX_SERVERS_JSON_SIZE = 10000 # Max size of servers JSON in haproxy_add_servers class HaproxyError(Exception): @@ -577,16 +578,6 @@ def remove_domain_from_config(domain: str) -> None: fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) -def get_server_suffixes(http_port: int) -> List[Tuple[str, int]]: - """Get server suffixes and ports based on port configuration. - - Args: - http_port: HTTP port for backend - - Returns: - List of (suffix, port) tuples - always HTTP only - """ - return [("", http_port)] def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str: @@ -667,19 +658,18 @@ def restore_servers_from_config() -> int: continue try: - http_port = int(server_info.get("http_port", 80)) + port = int(server_info.get("http_port", 80)) except (ValueError, TypeError): logger.warning("Invalid port for %s slot %d, skipping", domain, slot) continue - for suffix, port in get_server_suffixes(http_port): - server = f"{server_prefix}{suffix}_{slot}" - try: - haproxy_cmd_checked(f"set server {backend}/{server} addr {ip} port {port}") - haproxy_cmd_checked(f"set server {backend}/{server} state ready") - restored += 1 - except HaproxyError as e: - logger.warning("Failed to restore %s/%s: %s", backend, server, e) + server = f"{server_prefix}_{slot}" + try: + haproxy_cmd_checked(f"set server {backend}/{server} addr {ip} port {port}") + haproxy_cmd_checked(f"set server {backend}/{server} state ready") + restored += 1 + except HaproxyError as e: + logger.warning("Failed to restore %s/%s: %s", backend, server, e) return restored @@ -769,6 +759,8 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str: haproxy_add_domain("api.example.com", ip="10.0.0.1", http_port=8080) """ # Validate inputs + if domain.startswith("."): + return "Error: Domain cannot start with '.' (wildcard entries are added automatically)" if not validate_domain(domain): return "Error: Invalid domain format" if not validate_ip(ip, allow_empty=True): @@ -776,74 +768,80 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str: if not (1 <= http_port <= 65535): return "Error: Port must be between 1 and 65535" - # Read map contents once for both existence check and pool lookup - entries = get_map_contents() - - # Check if domain already exists (using cached entries) - for domain_entry, backend in entries: - if domain_entry == domain: - return f"Error: Domain {domain} already exists (mapped to {backend})" - - # Find available pool (using cached entries) - used_pools: Set[str] = set() - for _, backend in entries: - if backend.startswith("pool_"): - used_pools.add(backend) - - pool = None - for i in range(1, POOL_COUNT + 1): - pool_name = f"pool_{i}" - if pool_name not in used_pools: - pool = pool_name - break - if not pool: - return f"Error: All {POOL_COUNT} pool backends are in use" - - try: - # Save to disk first (atomic write for persistence) - # If HAProxy update fails after this, state will be correct on restart - # Note: We already have 'entries' from the map contents read above - entries.append((domain, pool)) - entries.append((f".{domain}", pool)) + # Use file locking for the entire pool allocation operation + lock_path = f"{MAP_FILE}.lock" + with open(lock_path, 'w') as lock_file: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) try: - save_map_file(entries) - except IOError as e: - return f"Error: Failed to save map file: {e}" + # Read map contents once for both existence check and pool lookup + entries = get_map_contents() - # Then update HAProxy map via Runtime API - try: - haproxy_cmd(f"add map {MAP_FILE_CONTAINER} {domain} {pool}") - haproxy_cmd(f"add map {MAP_FILE_CONTAINER} .{domain} {pool}") - except HaproxyError as e: - # Rollback: remove the domain we just added from entries and re-save - rollback_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"] - try: - save_map_file(rollback_entries) - except IOError: - logger.error("Failed to rollback map file after HAProxy error") - return f"Error: Failed to update HAProxy map: {e}" + # Check if domain already exists (using cached entries) + for domain_entry, backend in entries: + if domain_entry == domain: + return f"Error: Domain {domain} already exists (mapped to {backend})" - # If IP provided, add server to slot 1 - if ip: - # Save server config to disk first - add_server_to_config(domain, 1, ip, http_port) + # Find available pool (using cached entries) + used_pools: Set[str] = set() + for _, backend in entries: + if backend.startswith("pool_"): + used_pools.add(backend) + + pool = None + for i in range(1, POOL_COUNT + 1): + pool_name = f"pool_{i}" + if pool_name not in used_pools: + pool = pool_name + break + if not pool: + return f"Error: All {POOL_COUNT} pool backends are in use" try: - for suffix, port in get_server_suffixes(http_port): - server = f"{pool}{suffix}_1" - haproxy_cmd(f"set server {pool}/{server} addr {ip} port {port}") - haproxy_cmd(f"set server {pool}/{server} state ready") + # Save to disk first (atomic write for persistence) + # If HAProxy update fails after this, state will be correct on restart + # Note: We already have 'entries' from the map contents read above + entries.append((domain, pool)) + entries.append((f".{domain}", pool)) + try: + save_map_file(entries) + except IOError as e: + return f"Error: Failed to save map file: {e}" + + # Then update HAProxy map via Runtime API + try: + haproxy_cmd(f"add map {MAP_FILE_CONTAINER} {domain} {pool}") + haproxy_cmd(f"add map {MAP_FILE_CONTAINER} .{domain} {pool}") + except HaproxyError as e: + # Rollback: remove the domain we just added from entries and re-save + rollback_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"] + try: + save_map_file(rollback_entries) + except IOError: + logger.error("Failed to rollback map file after HAProxy error") + return f"Error: Failed to update HAProxy map: {e}" + + # If IP provided, add server to slot 1 + if ip: + # Save server config to disk first + add_server_to_config(domain, 1, ip, http_port) + + try: + server = f"{pool}_1" + haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}") + haproxy_cmd(f"set server {pool}/{server} state ready") + except HaproxyError as e: + # Rollback server config on failure + remove_server_from_config(domain, 1) + return f"Domain {domain} added to {pool} but server config failed: {e}" + + return f"Domain {domain} added to {pool} with server {ip}:{http_port}" + + return f"Domain {domain} added to {pool} (no servers configured)" + except HaproxyError as e: - # Rollback server config on failure - remove_server_from_config(domain, 1) - return f"Domain {domain} added to {pool} but server config failed: {e}" - - return f"Domain {domain} added to {pool} with server {ip}:{http_port}" - - return f"Domain {domain} added to {pool} (no servers configured)" - - except HaproxyError as e: - return f"Error: {e}" + return f"Error: {e}" + finally: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) @mcp.tool() @@ -1053,6 +1051,10 @@ def haproxy_add_servers(domain: str, servers: str) -> str: if not validate_domain(domain): return "Error: Invalid domain format" + # Check JSON size before parsing + if len(servers) > MAX_SERVERS_JSON_SIZE: + return f"Error: servers JSON exceeds maximum size ({MAX_SERVERS_JSON_SIZE} bytes)" + # Parse JSON array try: server_list = json.loads(servers) @@ -1199,16 +1201,25 @@ def haproxy_remove_server(domain: str, slot: int) -> str: try: backend, server_prefix = get_backend_and_prefix(domain) - # HTTP only - single server per slot - server = f"{server_prefix}_{slot}" - haproxy_cmd_checked(f"set server {backend}/{server} state maint") - haproxy_cmd_checked(f"set server {backend}/{server} addr 0.0.0.0 port 0") + # Get current server info for potential rollback + config = load_servers_config() + old_config = config.get(domain, {}).get(str(slot), {}) - # Remove from persistent config + # Remove from persistent config FIRST (disk-first pattern) remove_server_from_config(domain, slot) - return f"Removed server at slot {slot} from {domain} ({backend})" - except (HaproxyError, ValueError, IOError) as e: + try: + # HTTP only - single server per slot + server = f"{server_prefix}_{slot}" + haproxy_cmd_checked(f"set server {backend}/{server} state maint") + haproxy_cmd_checked(f"set server {backend}/{server} addr 0.0.0.0 port 0") + return f"Removed server at slot {slot} from {domain} ({backend})" + except HaproxyError as e: + # Rollback: re-add config if HAProxy command failed + if old_config: + add_server_to_config(domain, slot, old_config.get("ip", ""), old_config.get("http_port", 80)) + return f"Error: {e}" + except (ValueError, IOError) as e: return f"Error: {e}" @@ -1245,14 +1256,14 @@ def haproxy_set_domain_state(domain: str, state: str) -> str: # Get active servers for this domain try: - state_output = haproxy_cmd("show servers state") + servers_state = haproxy_cmd("show servers state") except HaproxyError as e: return f"Error: {e}" changed = [] errors = [] - for line in state_output.split("\n"): + for line in servers_state.split("\n"): parts = line.split() if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend: server_name = parts[StateField.SRV_NAME]