fix: Final round of improvements (pool allocation locking, disk-first server removal, input validation)

HIGH PRIORITY:
1. Pool allocation race condition
   - Add file locking around entire pool allocation in haproxy_add_domain
   - Prevents concurrent calls from getting same pool

2. haproxy_remove_server - disk-first pattern
   - Remove from config FIRST, then update HAProxy
   - Rollback config on HAProxy failure

3. Wildcard domain prefix validation
   - Reject domains starting with '.'
   - Prevents double-prefix like '..domain.com', since haproxy_add_domain adds the '.<domain>' wildcard entry automatically

MEDIUM PRIORITY:
4. Variable shadowing fix
   - Rename state_output to servers_state in haproxy_set_domain_state

5. JSON size limit
   - Add MAX_SERVERS_JSON_SIZE = 10000 limit on the length of the servers JSON string in haproxy_add_servers, checked before parsing

6. Remove get_server_suffixes
   - Delete unnecessary abstraction layer (it always returned a single HTTP-only (suffix, port) entry)
   - Inline logic in restore_servers_from_config and haproxy_add_domain

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
kaffa
2026-02-01 14:15:00 +00:00
parent bdc1f8a279
commit 913ba0fdca

View File

@@ -82,6 +82,7 @@ STATE_MIN_COLUMNS = 19 # Minimum columns in HAProxy server state output
SOCKET_TIMEOUT = 5 # seconds for HAProxy socket connection SOCKET_TIMEOUT = 5 # seconds for HAProxy socket connection
SOCKET_RECV_TIMEOUT = 30 # seconds for HAProxy socket recv loop SOCKET_RECV_TIMEOUT = 30 # seconds for HAProxy socket recv loop
MAX_BULK_SERVERS = 10 # Max servers per bulk add call MAX_BULK_SERVERS = 10 # Max servers per bulk add call
MAX_SERVERS_JSON_SIZE = 10000 # Max size of servers JSON in haproxy_add_servers
class HaproxyError(Exception): class HaproxyError(Exception):
@@ -577,16 +578,6 @@ def remove_domain_from_config(domain: str) -> None:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
def get_server_suffixes(http_port: int) -> List[Tuple[str, int]]:
"""Get server suffixes and ports based on port configuration.
Args:
http_port: HTTP port for backend
Returns:
List of (suffix, port) tuples - always HTTP only
"""
return [("", http_port)]
def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str: def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str:
@@ -667,19 +658,18 @@ def restore_servers_from_config() -> int:
continue continue
try: try:
http_port = int(server_info.get("http_port", 80)) port = int(server_info.get("http_port", 80))
except (ValueError, TypeError): except (ValueError, TypeError):
logger.warning("Invalid port for %s slot %d, skipping", domain, slot) logger.warning("Invalid port for %s slot %d, skipping", domain, slot)
continue continue
for suffix, port in get_server_suffixes(http_port): server = f"{server_prefix}_{slot}"
server = f"{server_prefix}{suffix}_{slot}" try:
try: haproxy_cmd_checked(f"set server {backend}/{server} addr {ip} port {port}")
haproxy_cmd_checked(f"set server {backend}/{server} addr {ip} port {port}") haproxy_cmd_checked(f"set server {backend}/{server} state ready")
haproxy_cmd_checked(f"set server {backend}/{server} state ready") restored += 1
restored += 1 except HaproxyError as e:
except HaproxyError as e: logger.warning("Failed to restore %s/%s: %s", backend, server, e)
logger.warning("Failed to restore %s/%s: %s", backend, server, e)
return restored return restored
@@ -769,6 +759,8 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str:
haproxy_add_domain("api.example.com", ip="10.0.0.1", http_port=8080) haproxy_add_domain("api.example.com", ip="10.0.0.1", http_port=8080)
""" """
# Validate inputs # Validate inputs
if domain.startswith("."):
return "Error: Domain cannot start with '.' (wildcard entries are added automatically)"
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
if not validate_ip(ip, allow_empty=True): if not validate_ip(ip, allow_empty=True):
@@ -776,74 +768,80 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str:
if not (1 <= http_port <= 65535): if not (1 <= http_port <= 65535):
return "Error: Port must be between 1 and 65535" return "Error: Port must be between 1 and 65535"
# Read map contents once for both existence check and pool lookup # Use file locking for the entire pool allocation operation
entries = get_map_contents() lock_path = f"{MAP_FILE}.lock"
with open(lock_path, 'w') as lock_file:
# Check if domain already exists (using cached entries) fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
for domain_entry, backend in entries:
if domain_entry == domain:
return f"Error: Domain {domain} already exists (mapped to {backend})"
# Find available pool (using cached entries)
used_pools: Set[str] = set()
for _, backend in entries:
if backend.startswith("pool_"):
used_pools.add(backend)
pool = None
for i in range(1, POOL_COUNT + 1):
pool_name = f"pool_{i}"
if pool_name not in used_pools:
pool = pool_name
break
if not pool:
return f"Error: All {POOL_COUNT} pool backends are in use"
try:
# Save to disk first (atomic write for persistence)
# If HAProxy update fails after this, state will be correct on restart
# Note: We already have 'entries' from the map contents read above
entries.append((domain, pool))
entries.append((f".{domain}", pool))
try: try:
save_map_file(entries) # Read map contents once for both existence check and pool lookup
except IOError as e: entries = get_map_contents()
return f"Error: Failed to save map file: {e}"
# Then update HAProxy map via Runtime API # Check if domain already exists (using cached entries)
try: for domain_entry, backend in entries:
haproxy_cmd(f"add map {MAP_FILE_CONTAINER} {domain} {pool}") if domain_entry == domain:
haproxy_cmd(f"add map {MAP_FILE_CONTAINER} .{domain} {pool}") return f"Error: Domain {domain} already exists (mapped to {backend})"
except HaproxyError as e:
# Rollback: remove the domain we just added from entries and re-save
rollback_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
try:
save_map_file(rollback_entries)
except IOError:
logger.error("Failed to rollback map file after HAProxy error")
return f"Error: Failed to update HAProxy map: {e}"
# If IP provided, add server to slot 1 # Find available pool (using cached entries)
if ip: used_pools: Set[str] = set()
# Save server config to disk first for _, backend in entries:
add_server_to_config(domain, 1, ip, http_port) if backend.startswith("pool_"):
used_pools.add(backend)
pool = None
for i in range(1, POOL_COUNT + 1):
pool_name = f"pool_{i}"
if pool_name not in used_pools:
pool = pool_name
break
if not pool:
return f"Error: All {POOL_COUNT} pool backends are in use"
try: try:
for suffix, port in get_server_suffixes(http_port): # Save to disk first (atomic write for persistence)
server = f"{pool}{suffix}_1" # If HAProxy update fails after this, state will be correct on restart
haproxy_cmd(f"set server {pool}/{server} addr {ip} port {port}") # Note: We already have 'entries' from the map contents read above
haproxy_cmd(f"set server {pool}/{server} state ready") entries.append((domain, pool))
entries.append((f".{domain}", pool))
try:
save_map_file(entries)
except IOError as e:
return f"Error: Failed to save map file: {e}"
# Then update HAProxy map via Runtime API
try:
haproxy_cmd(f"add map {MAP_FILE_CONTAINER} {domain} {pool}")
haproxy_cmd(f"add map {MAP_FILE_CONTAINER} .{domain} {pool}")
except HaproxyError as e:
# Rollback: remove the domain we just added from entries and re-save
rollback_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
try:
save_map_file(rollback_entries)
except IOError:
logger.error("Failed to rollback map file after HAProxy error")
return f"Error: Failed to update HAProxy map: {e}"
# If IP provided, add server to slot 1
if ip:
# Save server config to disk first
add_server_to_config(domain, 1, ip, http_port)
try:
server = f"{pool}_1"
haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}")
haproxy_cmd(f"set server {pool}/{server} state ready")
except HaproxyError as e:
# Rollback server config on failure
remove_server_from_config(domain, 1)
return f"Domain {domain} added to {pool} but server config failed: {e}"
return f"Domain {domain} added to {pool} with server {ip}:{http_port}"
return f"Domain {domain} added to {pool} (no servers configured)"
except HaproxyError as e: except HaproxyError as e:
# Rollback server config on failure return f"Error: {e}"
remove_server_from_config(domain, 1) finally:
return f"Domain {domain} added to {pool} but server config failed: {e}" fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
return f"Domain {domain} added to {pool} with server {ip}:{http_port}"
return f"Domain {domain} added to {pool} (no servers configured)"
except HaproxyError as e:
return f"Error: {e}"
@mcp.tool() @mcp.tool()
@@ -1053,6 +1051,10 @@ def haproxy_add_servers(domain: str, servers: str) -> str:
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
# Check JSON size before parsing
if len(servers) > MAX_SERVERS_JSON_SIZE:
return f"Error: servers JSON exceeds maximum size ({MAX_SERVERS_JSON_SIZE} bytes)"
# Parse JSON array # Parse JSON array
try: try:
server_list = json.loads(servers) server_list = json.loads(servers)
@@ -1199,16 +1201,25 @@ def haproxy_remove_server(domain: str, slot: int) -> str:
try: try:
backend, server_prefix = get_backend_and_prefix(domain) backend, server_prefix = get_backend_and_prefix(domain)
# HTTP only - single server per slot # Get current server info for potential rollback
server = f"{server_prefix}_{slot}" config = load_servers_config()
haproxy_cmd_checked(f"set server {backend}/{server} state maint") old_config = config.get(domain, {}).get(str(slot), {})
haproxy_cmd_checked(f"set server {backend}/{server} addr 0.0.0.0 port 0")
# Remove from persistent config # Remove from persistent config FIRST (disk-first pattern)
remove_server_from_config(domain, slot) remove_server_from_config(domain, slot)
return f"Removed server at slot {slot} from {domain} ({backend})" try:
except (HaproxyError, ValueError, IOError) as e: # HTTP only - single server per slot
server = f"{server_prefix}_{slot}"
haproxy_cmd_checked(f"set server {backend}/{server} state maint")
haproxy_cmd_checked(f"set server {backend}/{server} addr 0.0.0.0 port 0")
return f"Removed server at slot {slot} from {domain} ({backend})"
except HaproxyError as e:
# Rollback: re-add config if HAProxy command failed
if old_config:
add_server_to_config(domain, slot, old_config.get("ip", ""), old_config.get("http_port", 80))
return f"Error: {e}"
except (ValueError, IOError) as e:
return f"Error: {e}" return f"Error: {e}"
@@ -1245,14 +1256,14 @@ def haproxy_set_domain_state(domain: str, state: str) -> str:
# Get active servers for this domain # Get active servers for this domain
try: try:
state_output = haproxy_cmd("show servers state") servers_state = haproxy_cmd("show servers state")
except HaproxyError as e: except HaproxyError as e:
return f"Error: {e}" return f"Error: {e}"
changed = [] changed = []
errors = [] errors = []
for line in state_output.split("\n"): for line in servers_state.split("\n"):
parts = line.split() parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend: if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
server_name = parts[StateField.SRV_NAME] server_name = parts[StateField.SRV_NAME]