fix: Final round of improvements
HIGH PRIORITY: 1. Pool allocation race condition - Add file locking around entire pool allocation in haproxy_add_domain - Prevents concurrent calls from getting same pool 2. haproxy_remove_server - disk-first pattern - Remove from config FIRST, then update HAProxy - Rollback config on HAProxy failure 3. Wildcard domain prefix validation - Reject domains starting with '.' - Prevents double-prefix like '..domain.com' MEDIUM PRIORITY: 4. Variable shadowing fix - Rename state_output to servers_state in haproxy_set_domain_state 5. JSON size limit - Add MAX_SERVERS_JSON_SIZE = 10000 limit for haproxy_add_servers 6. Remove get_server_suffixes - Delete unused abstraction layer - Inline logic in restore_servers_from_config and haproxy_add_domain Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -82,6 +82,7 @@ STATE_MIN_COLUMNS = 19 # Minimum columns in HAProxy server state output
|
|||||||
SOCKET_TIMEOUT = 5 # seconds for HAProxy socket connection
|
SOCKET_TIMEOUT = 5 # seconds for HAProxy socket connection
|
||||||
SOCKET_RECV_TIMEOUT = 30 # seconds for HAProxy socket recv loop
|
SOCKET_RECV_TIMEOUT = 30 # seconds for HAProxy socket recv loop
|
||||||
MAX_BULK_SERVERS = 10 # Max servers per bulk add call
|
MAX_BULK_SERVERS = 10 # Max servers per bulk add call
|
||||||
|
MAX_SERVERS_JSON_SIZE = 10000 # Max size of servers JSON in haproxy_add_servers
|
||||||
|
|
||||||
|
|
||||||
class HaproxyError(Exception):
|
class HaproxyError(Exception):
|
||||||
@@ -577,16 +578,6 @@ def remove_domain_from_config(domain: str) -> None:
|
|||||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
|
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
|
||||||
|
|
||||||
|
|
||||||
def get_server_suffixes(http_port: int) -> List[Tuple[str, int]]:
|
|
||||||
"""Get server suffixes and ports based on port configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
http_port: HTTP port for backend
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of (suffix, port) tuples - always HTTP only
|
|
||||||
"""
|
|
||||||
return [("", http_port)]
|
|
||||||
|
|
||||||
|
|
||||||
def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str:
|
def configure_server_slot(backend: str, server_prefix: str, slot: int, ip: str, http_port: int) -> str:
|
||||||
@@ -667,13 +658,12 @@ def restore_servers_from_config() -> int:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
http_port = int(server_info.get("http_port", 80))
|
port = int(server_info.get("http_port", 80))
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
logger.warning("Invalid port for %s slot %d, skipping", domain, slot)
|
logger.warning("Invalid port for %s slot %d, skipping", domain, slot)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for suffix, port in get_server_suffixes(http_port):
|
server = f"{server_prefix}_{slot}"
|
||||||
server = f"{server_prefix}{suffix}_{slot}"
|
|
||||||
try:
|
try:
|
||||||
haproxy_cmd_checked(f"set server {backend}/{server} addr {ip} port {port}")
|
haproxy_cmd_checked(f"set server {backend}/{server} addr {ip} port {port}")
|
||||||
haproxy_cmd_checked(f"set server {backend}/{server} state ready")
|
haproxy_cmd_checked(f"set server {backend}/{server} state ready")
|
||||||
@@ -769,6 +759,8 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str:
|
|||||||
haproxy_add_domain("api.example.com", ip="10.0.0.1", http_port=8080)
|
haproxy_add_domain("api.example.com", ip="10.0.0.1", http_port=8080)
|
||||||
"""
|
"""
|
||||||
# Validate inputs
|
# Validate inputs
|
||||||
|
if domain.startswith("."):
|
||||||
|
return "Error: Domain cannot start with '.' (wildcard entries are added automatically)"
|
||||||
if not validate_domain(domain):
|
if not validate_domain(domain):
|
||||||
return "Error: Invalid domain format"
|
return "Error: Invalid domain format"
|
||||||
if not validate_ip(ip, allow_empty=True):
|
if not validate_ip(ip, allow_empty=True):
|
||||||
@@ -776,6 +768,11 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str:
|
|||||||
if not (1 <= http_port <= 65535):
|
if not (1 <= http_port <= 65535):
|
||||||
return "Error: Port must be between 1 and 65535"
|
return "Error: Port must be between 1 and 65535"
|
||||||
|
|
||||||
|
# Use file locking for the entire pool allocation operation
|
||||||
|
lock_path = f"{MAP_FILE}.lock"
|
||||||
|
with open(lock_path, 'w') as lock_file:
|
||||||
|
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
|
||||||
|
try:
|
||||||
# Read map contents once for both existence check and pool lookup
|
# Read map contents once for both existence check and pool lookup
|
||||||
entries = get_map_contents()
|
entries = get_map_contents()
|
||||||
|
|
||||||
@@ -829,9 +826,8 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str:
|
|||||||
add_server_to_config(domain, 1, ip, http_port)
|
add_server_to_config(domain, 1, ip, http_port)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for suffix, port in get_server_suffixes(http_port):
|
server = f"{pool}_1"
|
||||||
server = f"{pool}{suffix}_1"
|
haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}")
|
||||||
haproxy_cmd(f"set server {pool}/{server} addr {ip} port {port}")
|
|
||||||
haproxy_cmd(f"set server {pool}/{server} state ready")
|
haproxy_cmd(f"set server {pool}/{server} state ready")
|
||||||
except HaproxyError as e:
|
except HaproxyError as e:
|
||||||
# Rollback server config on failure
|
# Rollback server config on failure
|
||||||
@@ -844,6 +840,8 @@ def haproxy_add_domain(domain: str, ip: str = "", http_port: int = 80) -> str:
|
|||||||
|
|
||||||
except HaproxyError as e:
|
except HaproxyError as e:
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
|
finally:
|
||||||
|
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
@@ -1053,6 +1051,10 @@ def haproxy_add_servers(domain: str, servers: str) -> str:
|
|||||||
if not validate_domain(domain):
|
if not validate_domain(domain):
|
||||||
return "Error: Invalid domain format"
|
return "Error: Invalid domain format"
|
||||||
|
|
||||||
|
# Check JSON size before parsing
|
||||||
|
if len(servers) > MAX_SERVERS_JSON_SIZE:
|
||||||
|
return f"Error: servers JSON exceeds maximum size ({MAX_SERVERS_JSON_SIZE} bytes)"
|
||||||
|
|
||||||
# Parse JSON array
|
# Parse JSON array
|
||||||
try:
|
try:
|
||||||
server_list = json.loads(servers)
|
server_list = json.loads(servers)
|
||||||
@@ -1199,16 +1201,25 @@ def haproxy_remove_server(domain: str, slot: int) -> str:
|
|||||||
try:
|
try:
|
||||||
backend, server_prefix = get_backend_and_prefix(domain)
|
backend, server_prefix = get_backend_and_prefix(domain)
|
||||||
|
|
||||||
|
# Get current server info for potential rollback
|
||||||
|
config = load_servers_config()
|
||||||
|
old_config = config.get(domain, {}).get(str(slot), {})
|
||||||
|
|
||||||
|
# Remove from persistent config FIRST (disk-first pattern)
|
||||||
|
remove_server_from_config(domain, slot)
|
||||||
|
|
||||||
|
try:
|
||||||
# HTTP only - single server per slot
|
# HTTP only - single server per slot
|
||||||
server = f"{server_prefix}_{slot}"
|
server = f"{server_prefix}_{slot}"
|
||||||
haproxy_cmd_checked(f"set server {backend}/{server} state maint")
|
haproxy_cmd_checked(f"set server {backend}/{server} state maint")
|
||||||
haproxy_cmd_checked(f"set server {backend}/{server} addr 0.0.0.0 port 0")
|
haproxy_cmd_checked(f"set server {backend}/{server} addr 0.0.0.0 port 0")
|
||||||
|
|
||||||
# Remove from persistent config
|
|
||||||
remove_server_from_config(domain, slot)
|
|
||||||
|
|
||||||
return f"Removed server at slot {slot} from {domain} ({backend})"
|
return f"Removed server at slot {slot} from {domain} ({backend})"
|
||||||
except (HaproxyError, ValueError, IOError) as e:
|
except HaproxyError as e:
|
||||||
|
# Rollback: re-add config if HAProxy command failed
|
||||||
|
if old_config:
|
||||||
|
add_server_to_config(domain, slot, old_config.get("ip", ""), old_config.get("http_port", 80))
|
||||||
|
return f"Error: {e}"
|
||||||
|
except (ValueError, IOError) as e:
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
|
||||||
@@ -1245,14 +1256,14 @@ def haproxy_set_domain_state(domain: str, state: str) -> str:
|
|||||||
|
|
||||||
# Get active servers for this domain
|
# Get active servers for this domain
|
||||||
try:
|
try:
|
||||||
state_output = haproxy_cmd("show servers state")
|
servers_state = haproxy_cmd("show servers state")
|
||||||
except HaproxyError as e:
|
except HaproxyError as e:
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
|
|
||||||
changed = []
|
changed = []
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
for line in state_output.split("\n"):
|
for line in servers_state.split("\n"):
|
||||||
parts = line.split()
|
parts = line.split()
|
||||||
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
|
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
|
||||||
server_name = parts[StateField.SRV_NAME]
|
server_name = parts[StateField.SRV_NAME]
|
||||||
|
|||||||
Reference in New Issue
Block a user