fix: Final round of improvements
HIGH PRIORITY: 1. Pool allocation race condition - Add file locking around entire pool allocation in haproxy_add_domain - Prevents concurrent calls from getting same pool 2. haproxy_remove_server - disk-first pattern - Remove from config FIRST, then update HAProxy - Rollback config on HAProxy failure 3. Wildcard domain prefix validation - Reject domains starting with '.' - Prevents double-prefix like '..domain.com' MEDIUM PRIORITY: 4. Variable shadowing fix - Rename state_output to servers_state in haproxy_set_domain_state 5. JSON size limit - Add MAX_SERVERS_JSON_SIZE = 10000 limit for haproxy_add_servers 6. Remove get_server_suffixes - Delete unused abstraction layer - Inline logic in restore_servers_from_config and haproxy_add_domain Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -82,6 +82,7 @@ STATE_MIN_COLUMNS = 19 # Minimum columns in HAProxy server state output
|
||||
SOCKET_TIMEOUT = 5 # seconds for HAProxy socket connection
|
||||
SOCKET_RECV_TIMEOUT = 30 # seconds for HAProxy socket recv loop
|
||||
MAX_BULK_SERVERS = 10 # Max servers per bulk add call
|
||||
MAX_SERVERS_JSON_SIZE = 10000 # Max size of servers JSON in haproxy_add_servers
|
||||
|
||||
|
||||
class HaproxyError(Exception):
|
||||
@@ -577,16 +578,6 @@ def remove_domain_from_config(domain: str) -> None:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
|
||||
|
||||
|
||||
def get_server_suffixes(http_port: int) -> List[Tuple[str, int]]:
|
||||
"""Get server suffixes and ports based on port configuration.
|
||||
|
||||
Args:
|
||||
http_port: HTTP port for backend
|
||||
|
||||
Returns:
|
||||
List of (suffix, port) tuples - always HTTP only
|
||||
"""
|
||||
return [("", http_port)]
|
||||
|
||||
|
||||
def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str:
|
||||
@@ -667,13 +658,12 @@ def restore_servers_from_config() -> int:
|
||||
continue
|
||||
|
||||
try:
|
||||
http_port = int(server_info.get("http_port", 80))
|
||||
port = int(server_info.get("http_port", 80))
|
||||
except (ValueError, TypeError):
|
||||
logger.warning("Invalid port for %s slot %d, skipping", domain, slot)
|
||||
continue
|
||||
|
||||
for suffix, port in get_server_suffixes(http_port):
|
||||
server = f"{server_prefix}{suffix}_{slot}"
|
||||
server = f"{server_prefix}_{slot}"
|
||||
try:
|
||||
haproxy_cmd_checked(f"set server {backend}/{server} addr {ip} port {port}")
|
||||
haproxy_cmd_checked(f"set server {backend}/{server} state ready")
|
||||
@@ -769,6 +759,8 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str:
|
||||
haproxy_add_domain("api.example.com", ip="10.0.0.1", http_port=8080)
|
||||
"""
|
||||
# Validate inputs
|
||||
if domain.startswith("."):
|
||||
return "Error: Domain cannot start with '.' (wildcard entries are added automatically)"
|
||||
if not validate_domain(domain):
|
||||
return "Error: Invalid domain format"
|
||||
if not validate_ip(ip, allow_empty=True):
|
||||
@@ -776,6 +768,11 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str:
|
||||
if not (1 <= http_port <= 65535):
|
||||
return "Error: Port must be between 1 and 65535"
|
||||
|
||||
# Use file locking for the entire pool allocation operation
|
||||
lock_path = f"{MAP_FILE}.lock"
|
||||
with open(lock_path, 'w') as lock_file:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
|
||||
try:
|
||||
# Read map contents once for both existence check and pool lookup
|
||||
entries = get_map_contents()
|
||||
|
||||
@@ -829,9 +826,8 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str:
|
||||
add_server_to_config(domain, 1, ip, http_port)
|
||||
|
||||
try:
|
||||
for suffix, port in get_server_suffixes(http_port):
|
||||
server = f"{pool}{suffix}_1"
|
||||
haproxy_cmd(f"set server {pool}/{server} addr {ip} port {port}")
|
||||
server = f"{pool}_1"
|
||||
haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}")
|
||||
haproxy_cmd(f"set server {pool}/{server} state ready")
|
||||
except HaproxyError as e:
|
||||
# Rollback server config on failure
|
||||
@@ -844,6 +840,8 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str:
|
||||
|
||||
except HaproxyError as e:
|
||||
return f"Error: {e}"
|
||||
finally:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@@ -1053,6 +1051,10 @@ def haproxy_add_servers(domain: str, servers: str) -> str:
|
||||
if not validate_domain(domain):
|
||||
return "Error: Invalid domain format"
|
||||
|
||||
# Check JSON size before parsing
|
||||
if len(servers) > MAX_SERVERS_JSON_SIZE:
|
||||
return f"Error: servers JSON exceeds maximum size ({MAX_SERVERS_JSON_SIZE} bytes)"
|
||||
|
||||
# Parse JSON array
|
||||
try:
|
||||
server_list = json.loads(servers)
|
||||
@@ -1199,16 +1201,25 @@ def haproxy_remove_server(domain: str, slot: int) -> str:
|
||||
try:
|
||||
backend, server_prefix = get_backend_and_prefix(domain)
|
||||
|
||||
# Get current server info for potential rollback
|
||||
config = load_servers_config()
|
||||
old_config = config.get(domain, {}).get(str(slot), {})
|
||||
|
||||
# Remove from persistent config FIRST (disk-first pattern)
|
||||
remove_server_from_config(domain, slot)
|
||||
|
||||
try:
|
||||
# HTTP only - single server per slot
|
||||
server = f"{server_prefix}_{slot}"
|
||||
haproxy_cmd_checked(f"set server {backend}/{server} state maint")
|
||||
haproxy_cmd_checked(f"set server {backend}/{server} addr 0.0.0.0 port 0")
|
||||
|
||||
# Remove from persistent config
|
||||
remove_server_from_config(domain, slot)
|
||||
|
||||
return f"Removed server at slot {slot} from {domain} ({backend})"
|
||||
except (HaproxyError, ValueError, IOError) as e:
|
||||
except HaproxyError as e:
|
||||
# Rollback: re-add config if HAProxy command failed
|
||||
if old_config:
|
||||
add_server_to_config(domain, slot, old_config.get("ip", ""), old_config.get("http_port", 80))
|
||||
return f"Error: {e}"
|
||||
except (ValueError, IOError) as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
|
||||
@@ -1245,14 +1256,14 @@ def haproxy_set_domain_state(domain: str, state: str) -> str:
|
||||
|
||||
# Get active servers for this domain
|
||||
try:
|
||||
state_output = haproxy_cmd("show servers state")
|
||||
servers_state = haproxy_cmd("show servers state")
|
||||
except HaproxyError as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
changed = []
|
||||
errors = []
|
||||
|
||||
for line in state_output.split("\n"):
|
||||
for line in servers_state.split("\n"):
|
||||
parts = line.split()
|
||||
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
|
||||
server_name = parts[StateField.SRV_NAME]
|
||||
|
||||
Reference in New Issue
Block a user