fix: Final round of improvements (pool allocation locking, disk-first server removal, input validation)

HIGH PRIORITY:
1. Pool allocation race condition
   - Hold an exclusive file lock ({MAP_FILE}.lock) around the entire pool allocation in haproxy_add_domain
   - Prevents concurrent calls from being assigned the same pool

2. haproxy_remove_server - disk-first pattern
   - Remove from config FIRST, then update HAProxy
   - Roll back the config (re-add the saved server entry) if the HAProxy command fails

3. Wildcard domain prefix validation
   - Reject domains starting with '.' (wildcard entries are added automatically)
   - Prevents a double-dot prefix like '..domain.com'

MEDIUM PRIORITY:
4. Variable shadowing fix
   - Rename state_output to servers_state in haproxy_set_domain_state

5. JSON size limit
   - Add MAX_SERVERS_JSON_SIZE = 10000 limit for haproxy_add_servers (checked with len() on the servers string before JSON parsing)

6. Remove get_server_suffixes
   - Delete unused abstraction layer
   - Inline logic in restore_servers_from_config and haproxy_add_domain

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
kaffa
2026-02-01 14:15:00 +00:00
parent bdc1f8a279
commit 913ba0fdca

View File

@@ -82,6 +82,7 @@ STATE_MIN_COLUMNS = 19 # Minimum columns in HAProxy server state output
SOCKET_TIMEOUT = 5 # seconds for HAProxy socket connection SOCKET_TIMEOUT = 5 # seconds for HAProxy socket connection
SOCKET_RECV_TIMEOUT = 30 # seconds for HAProxy socket recv loop SOCKET_RECV_TIMEOUT = 30 # seconds for HAProxy socket recv loop
MAX_BULK_SERVERS = 10 # Max servers per bulk add call MAX_BULK_SERVERS = 10 # Max servers per bulk add call
MAX_SERVERS_JSON_SIZE = 10000 # Max size of servers JSON in haproxy_add_servers
class HaproxyError(Exception): class HaproxyError(Exception):
@@ -577,16 +578,6 @@ def remove_domain_from_config(domain: str) -> None:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
def get_server_suffixes(http_port: int) -> List[Tuple[str, int]]:
"""Get server suffixes and ports based on port configuration.
Args:
http_port: HTTP port for backend
Returns:
List of (suffix, port) tuples - always HTTP only
"""
return [("", http_port)]
def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str: def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str:
@@ -667,13 +658,12 @@ def restore_servers_from_config() -> int:
continue continue
try: try:
http_port = int(server_info.get("http_port", 80)) port = int(server_info.get("http_port", 80))
except (ValueError, TypeError): except (ValueError, TypeError):
logger.warning("Invalid port for %s slot %d, skipping", domain, slot) logger.warning("Invalid port for %s slot %d, skipping", domain, slot)
continue continue
for suffix, port in get_server_suffixes(http_port): server = f"{server_prefix}_{slot}"
server = f"{server_prefix}{suffix}_{slot}"
try: try:
haproxy_cmd_checked(f"set server {backend}/{server} addr {ip} port {port}") haproxy_cmd_checked(f"set server {backend}/{server} addr {ip} port {port}")
haproxy_cmd_checked(f"set server {backend}/{server} state ready") haproxy_cmd_checked(f"set server {backend}/{server} state ready")
@@ -769,6 +759,8 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str:
haproxy_add_domain("api.example.com", ip="10.0.0.1", http_port=8080) haproxy_add_domain("api.example.com", ip="10.0.0.1", http_port=8080)
""" """
# Validate inputs # Validate inputs
if domain.startswith("."):
return "Error: Domain cannot start with '.' (wildcard entries are added automatically)"
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
if not validate_ip(ip, allow_empty=True): if not validate_ip(ip, allow_empty=True):
@@ -776,6 +768,11 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str:
if not (1 <= http_port <= 65535): if not (1 <= http_port <= 65535):
return "Error: Port must be between 1 and 65535" return "Error: Port must be between 1 and 65535"
# Use file locking for the entire pool allocation operation
lock_path = f"{MAP_FILE}.lock"
with open(lock_path, 'w') as lock_file:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
try:
# Read map contents once for both existence check and pool lookup # Read map contents once for both existence check and pool lookup
entries = get_map_contents() entries = get_map_contents()
@@ -829,9 +826,8 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str:
add_server_to_config(domain, 1, ip, http_port) add_server_to_config(domain, 1, ip, http_port)
try: try:
for suffix, port in get_server_suffixes(http_port): server = f"{pool}_1"
server = f"{pool}{suffix}_1" haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}")
haproxy_cmd(f"set server {pool}/{server} addr {ip} port {port}")
haproxy_cmd(f"set server {pool}/{server} state ready") haproxy_cmd(f"set server {pool}/{server} state ready")
except HaproxyError as e: except HaproxyError as e:
# Rollback server config on failure # Rollback server config on failure
@@ -844,6 +840,8 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str:
except HaproxyError as e: except HaproxyError as e:
return f"Error: {e}" return f"Error: {e}"
finally:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
@mcp.tool() @mcp.tool()
@@ -1053,6 +1051,10 @@ def haproxy_add_servers(domain: str, servers: str) -> str:
if not validate_domain(domain): if not validate_domain(domain):
return "Error: Invalid domain format" return "Error: Invalid domain format"
# Check JSON size before parsing
if len(servers) > MAX_SERVERS_JSON_SIZE:
return f"Error: servers JSON exceeds maximum size ({MAX_SERVERS_JSON_SIZE} bytes)"
# Parse JSON array # Parse JSON array
try: try:
server_list = json.loads(servers) server_list = json.loads(servers)
@@ -1199,16 +1201,25 @@ def haproxy_remove_server(domain: str, slot: int) -> str:
try: try:
backend, server_prefix = get_backend_and_prefix(domain) backend, server_prefix = get_backend_and_prefix(domain)
# Get current server info for potential rollback
config = load_servers_config()
old_config = config.get(domain, {}).get(str(slot), {})
# Remove from persistent config FIRST (disk-first pattern)
remove_server_from_config(domain, slot)
try:
# HTTP only - single server per slot # HTTP only - single server per slot
server = f"{server_prefix}_{slot}" server = f"{server_prefix}_{slot}"
haproxy_cmd_checked(f"set server {backend}/{server} state maint") haproxy_cmd_checked(f"set server {backend}/{server} state maint")
haproxy_cmd_checked(f"set server {backend}/{server} addr 0.0.0.0 port 0") haproxy_cmd_checked(f"set server {backend}/{server} addr 0.0.0.0 port 0")
# Remove from persistent config
remove_server_from_config(domain, slot)
return f"Removed server at slot {slot} from {domain} ({backend})" return f"Removed server at slot {slot} from {domain} ({backend})"
except (HaproxyError, ValueError, IOError) as e: except HaproxyError as e:
# Rollback: re-add config if HAProxy command failed
if old_config:
add_server_to_config(domain, slot, old_config.get("ip", ""), old_config.get("http_port", 80))
return f"Error: {e}"
except (ValueError, IOError) as e:
return f"Error: {e}" return f"Error: {e}"
@@ -1245,14 +1256,14 @@ def haproxy_set_domain_state(domain: str, state: str) -> str:
# Get active servers for this domain # Get active servers for this domain
try: try:
state_output = haproxy_cmd("show servers state") servers_state = haproxy_cmd("show servers state")
except HaproxyError as e: except HaproxyError as e:
return f"Error: {e}" return f"Error: {e}"
changed = [] changed = []
errors = [] errors = []
for line in state_output.split("\n"): for line in servers_state.split("\n"):
parts = line.split() parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend: if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
server_name = parts[StateField.SRV_NAME] server_name = parts[StateField.SRV_NAME]