Improve code quality based on code review

Major improvements:
- Atomic file writes using temp file + rename pattern
- Structured logging with logging module (replaces print)
- StateField class for HAProxy state field indices
- Helper function get_backend_and_prefix() to reduce duplication
- Consistent exception chaining with 'from e'
- Proper fd/temp_path tracking to prevent resource leaks
- Added IOError handling in server management functions

Technical changes:
- save_map_file, save_servers_config, haproxy_save_state now use
  atomic writes with tempfile.mkstemp() + os.rename()
- Standardized on 'set server state ready' (was 'enable server')
- All magic numbers for state parsing replaced with StateField class

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
kaffa
2026-02-01 12:48:49 +00:00
parent 196374e70c
commit 61dd4a69fc

View File

@@ -10,12 +10,22 @@ import socket
import subprocess import subprocess
import re import re
import json import json
import sys import logging
import os
import tempfile
import time import time
import fcntl import fcntl
from typing import Dict, Generator, List, Optional, Set, Tuple from typing import Any, Dict, Generator, List, Optional, Set, Tuple
from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp import FastMCP
# Configure structured logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
mcp = FastMCP("haproxy", host="0.0.0.0", port=8000) mcp = FastMCP("haproxy", host="0.0.0.0", port=8000)
# Constants # Constants
@@ -63,6 +73,19 @@ class StatField:
CHECK_STATUS = 36 # Check status CHECK_STATUS = 36 # Check status
# Field indices for HAProxy server state (show servers state command)
class StateField:
"""HAProxy server state field indices."""
BE_ID = 0 # Backend ID
BE_NAME = 1 # Backend name
SRV_ID = 2 # Server ID
SRV_NAME = 3 # Server name
SRV_ADDR = 4 # Server address
SRV_OP_STATE = 5 # Operational state
SRV_ADMIN_STATE = 6 # Admin state
SRV_PORT = 18 # Server port
def haproxy_cmd(command: str) -> str: def haproxy_cmd(command: str) -> str:
"""Send command to HAProxy Runtime API. """Send command to HAProxy Runtime API.
@@ -99,7 +122,7 @@ def haproxy_cmd(command: str) -> str:
except HaproxyError: except HaproxyError:
raise raise
except Exception as e: except Exception as e:
raise HaproxyError(str(e)) raise HaproxyError(str(e)) from e
def reload_haproxy() -> Tuple[bool, str]: def reload_haproxy() -> Tuple[bool, str]:
@@ -323,7 +346,9 @@ def get_legacy_backend_name(domain: str) -> str:
def save_map_file(entries: List[Tuple[str, str]]) -> None: def save_map_file(entries: List[Tuple[str, str]]) -> None:
"""Save entries to domains.map file with file locking. """Save entries to domains.map file atomically.
Uses temp file + rename for atomic write to prevent race conditions.
Args: Args:
entries: List of (domain, backend) tuples to write entries: List of (domain, backend) tuples to write
@@ -331,25 +356,36 @@ def save_map_file(entries: List[Tuple[str, str]]) -> None:
Raises: Raises:
IOError: If the file cannot be written IOError: If the file cannot be written
""" """
with open(MAP_FILE, "w", encoding="utf-8") as f: dir_path = os.path.dirname(MAP_FILE)
try: fd = None
fcntl.flock(f.fileno(), fcntl.LOCK_EX) temp_path = None
except OSError:
pass # Continue without lock if not supported
try: try:
fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.domains.map.')
with os.fdopen(fd, 'w', encoding='utf-8') as f:
fd = None # fd is now owned by the file object
f.write("# Domain to Backend mapping\n") f.write("# Domain to Backend mapping\n")
f.write("# Format: domain backend_name\n") f.write("# Format: domain backend_name\n")
f.write("# Wildcard: .domain.com matches *.domain.com\n\n") f.write("# Wildcard: .domain.com matches *.domain.com\n\n")
for domain, backend in entries: for domain, backend in entries:
f.write(f"{domain} {backend}\n") f.write(f"{domain} {backend}\n")
os.rename(temp_path, MAP_FILE)
temp_path = None # Rename succeeded, don't unlink
except OSError as e:
raise IOError(f"Failed to save map file: {e}") from e
finally: finally:
if fd is not None:
try: try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN) os.close(fd)
except OSError:
pass
if temp_path is not None:
try:
os.unlink(temp_path)
except OSError: except OSError:
pass pass
def load_servers_config() -> Dict: def load_servers_config() -> Dict[str, Any]:
"""Load servers configuration from JSON file with file locking. """Load servers configuration from JSON file with file locking.
Returns: Returns:
@@ -360,7 +396,7 @@ def load_servers_config() -> Dict:
try: try:
fcntl.flock(f.fileno(), fcntl.LOCK_SH) fcntl.flock(f.fileno(), fcntl.LOCK_SH)
except OSError: except OSError:
pass # Continue without lock if not supported logger.debug("File locking not supported for %s", SERVERS_FILE)
try: try:
return json.load(f) return json.load(f)
finally: finally:
@@ -371,26 +407,39 @@ def load_servers_config() -> Dict:
except FileNotFoundError: except FileNotFoundError:
return {} return {}
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
print(f"Warning: Corrupt config file {SERVERS_FILE}: {e}", file=sys.stderr) logger.warning("Corrupt config file %s: %s", SERVERS_FILE, e)
return {} return {}
def save_servers_config(config: Dict) -> None: def save_servers_config(config: Dict[str, Any]) -> None:
"""Save servers configuration to JSON file with file locking. """Save servers configuration to JSON file atomically.
Uses temp file + rename for atomic write to prevent race conditions.
Args: Args:
config: Dictionary with server configurations config: Dictionary with server configurations
""" """
with open(SERVERS_FILE, "w", encoding="utf-8") as f: dir_path = os.path.dirname(SERVERS_FILE)
try: fd = None
fcntl.flock(f.fileno(), fcntl.LOCK_EX) temp_path = None
except OSError:
pass # Continue without lock if not supported
try: try:
fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.servers.json.')
with os.fdopen(fd, 'w', encoding='utf-8') as f:
fd = None # fd is now owned by the file object
json.dump(config, f, indent=2) json.dump(config, f, indent=2)
os.rename(temp_path, SERVERS_FILE)
temp_path = None # Rename succeeded, don't unlink
except OSError as e:
raise IOError(f"Failed to save servers config: {e}") from e
finally: finally:
if fd is not None:
try: try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN) os.close(fd)
except OSError:
pass
if temp_path is not None:
try:
os.unlink(temp_path)
except OSError: except OSError:
pass pass
@@ -453,6 +502,30 @@ def get_server_suffixes(http_port: int) -> List[Tuple[str, int]]:
return [("", http_port)] return [("", http_port)]
def get_backend_and_prefix(domain: str) -> Tuple[str, str]:
"""Look up backend and determine server name prefix for a domain.
Args:
domain: The domain name to look up
Returns:
Tuple of (backend_name, server_prefix)
Raises:
ValueError: If domain cannot be mapped to a valid backend
"""
backend = get_domain_backend(domain)
if not backend:
backend = get_legacy_backend_name(domain)
if backend.startswith("pool_"):
server_prefix = backend
else:
server_prefix = domain_to_backend(domain)
return backend, server_prefix
def restore_servers_from_config() -> int: def restore_servers_from_config() -> int:
"""Restore all servers from configuration file. """Restore all servers from configuration file.
@@ -468,19 +541,16 @@ def restore_servers_from_config() -> int:
continue continue
try: try:
if backend.startswith("pool_"): _, server_prefix = get_backend_and_prefix(domain)
server_prefix = backend
else:
server_prefix = domain_to_backend(domain)
except ValueError as e: except ValueError as e:
print(f"Warning: Invalid domain '{domain}': {e}", file=sys.stderr) logger.warning("Invalid domain '%s': %s", domain, e)
continue continue
for slot_str, server_info in slots.items(): for slot_str, server_info in slots.items():
try: try:
slot = int(slot_str) slot = int(slot_str)
except ValueError: except ValueError:
print(f"Warning: Invalid slot '{slot_str}' for {domain}, skipping", file=sys.stderr) logger.warning("Invalid slot '%s' for %s, skipping", slot_str, domain)
continue continue
ip = server_info.get("ip", "") ip = server_info.get("ip", "")
@@ -490,7 +560,7 @@ def restore_servers_from_config() -> int:
try: try:
http_port = int(server_info.get("http_port", 80)) http_port = int(server_info.get("http_port", 80))
except (ValueError, TypeError): except (ValueError, TypeError):
print(f"Warning: Invalid port for {domain} slot {slot}, skipping", file=sys.stderr) logger.warning("Invalid port for %s slot %d, skipping", domain, slot)
continue continue
try: try:
@@ -500,7 +570,7 @@ def restore_servers_from_config() -> int:
haproxy_cmd(f"set server {backend}/{server} state ready") haproxy_cmd(f"set server {backend}/{server} state ready")
restored += 1 restored += 1
except HaproxyError as e: except HaproxyError as e:
print(f"Warning: Failed to restore {domain} slot {slot}: {e}", file=sys.stderr) logger.warning("Failed to restore %s slot %d: %s", domain, slot, e)
return restored return restored
@@ -520,11 +590,13 @@ def haproxy_list_domains() -> str:
server_map: Dict[str, list] = {} server_map: Dict[str, list] = {}
for line in state.split("\n"): for line in state.split("\n"):
parts = line.split() parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[4] != "0.0.0.0": if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.SRV_ADDR] != "0.0.0.0":
backend = parts[1] backend = parts[StateField.BE_NAME]
if backend not in server_map: if backend not in server_map:
server_map[backend] = [] server_map[backend] = []
server_map[backend].append(f"{parts[3]}={parts[4]}:{parts[18]}") server_map[backend].append(
f"{parts[StateField.SRV_NAME]}={parts[StateField.SRV_ADDR]}:{parts[StateField.SRV_PORT]}"
)
# Read from domains.map (skip wildcard entries starting with .) # Read from domains.map (skip wildcard entries starting with .)
seen_domains: Set[str] = set() seen_domains: Set[str] = set()
@@ -670,20 +742,18 @@ def haproxy_list_servers(domain: str) -> str:
return "Error: Invalid domain format" return "Error: Invalid domain format"
try: try:
# Look up backend from map backend, _ = get_backend_and_prefix(domain)
backend = get_domain_backend(domain)
if not backend:
# Fall back to legacy naming convention
backend = get_legacy_backend_name(domain)
servers = [] servers = []
state = haproxy_cmd("show servers state") state = haproxy_cmd("show servers state")
for line in state.split("\n"): for line in state.split("\n"):
parts = line.split() parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[1] == backend: if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
status = "active" if parts[4] != "0.0.0.0" else "disabled" addr = parts[StateField.SRV_ADDR]
servers.append(f"{parts[3]}: {parts[4]}:{parts[18]} ({status})") status = "active" if addr != "0.0.0.0" else "disabled"
servers.append(
f"{parts[StateField.SRV_NAME]}: {addr}:{parts[StateField.SRV_PORT]} ({status})"
)
if not servers: if not servers:
return f"Backend {backend} not found" return f"Backend {backend} not found"
@@ -718,19 +788,7 @@ def haproxy_add_server(domain: str, slot: int, ip: str, http_port: int = 80) ->
return "Error: Port must be between 1 and 65535" return "Error: Port must be between 1 and 65535"
try: try:
# Look up backend from map backend, server_prefix = get_backend_and_prefix(domain)
backend = get_domain_backend(domain)
if not backend:
# Fall back to legacy naming convention
backend = get_legacy_backend_name(domain)
# Determine server name prefix based on backend type
if backend.startswith("pool_"):
# Pool backends use pool_N_slot naming
server_prefix = backend
else:
# Legacy backends use domain-based naming
server_prefix = domain_to_backend(domain)
results = [] results = []
for suffix, port in get_server_suffixes(http_port): for suffix, port in get_server_suffixes(http_port):
@@ -743,7 +801,7 @@ def haproxy_add_server(domain: str, slot: int, ip: str, http_port: int = 80) ->
add_server_to_config(domain, slot, ip, http_port) add_server_to_config(domain, slot, ip, http_port)
return f"Added to {domain} ({backend}) slot {slot}:\n" + "\n".join(results) return f"Added to {domain} ({backend}) slot {slot}:\n" + "\n".join(results)
except (HaproxyError, ValueError) as e: except (HaproxyError, ValueError, IOError) as e:
return f"Error: {e}" return f"Error: {e}"
@@ -764,19 +822,7 @@ def haproxy_remove_server(domain: str, slot: int) -> str:
return f"Error: Slot must be between 1 and {MAX_SLOTS}" return f"Error: Slot must be between 1 and {MAX_SLOTS}"
try: try:
# Look up backend from map backend, server_prefix = get_backend_and_prefix(domain)
backend = get_domain_backend(domain)
if not backend:
# Fall back to legacy naming convention
backend = get_legacy_backend_name(domain)
# Determine server name prefix based on backend type
if backend.startswith("pool_"):
# Pool backends use pool_N_slot naming
server_prefix = backend
else:
# Legacy backends use domain-based naming
server_prefix = domain_to_backend(domain)
# HTTP only - single server per slot # HTTP only - single server per slot
server = f"{server_prefix}_{slot}" server = f"{server_prefix}_{slot}"
@@ -787,7 +833,7 @@ def haproxy_remove_server(domain: str, slot: int) -> str:
remove_server_from_config(domain, slot) remove_server_from_config(domain, slot)
return f"Removed server at slot {slot} from {domain} ({backend})" return f"Removed server at slot {slot} from {domain} ({backend})"
except (HaproxyError, ValueError) as e: except (HaproxyError, ValueError, IOError) as e:
return f"Error: {e}" return f"Error: {e}"
@@ -999,30 +1045,41 @@ def haproxy_check_config() -> str:
@mcp.tool() @mcp.tool()
def haproxy_save_state() -> str: def haproxy_save_state() -> str:
"""Save current server state to disk. """Save current server state to disk atomically.
Returns: Returns:
Success message or error description Success message or error description
""" """
try: try:
state = haproxy_cmd("show servers state") state = haproxy_cmd("show servers state")
with open(STATE_FILE, "w", encoding="utf-8") as f: dir_path = os.path.dirname(STATE_FILE)
try: fd = None
fcntl.flock(f.fileno(), fcntl.LOCK_EX) temp_path = None
except OSError:
pass # Continue without lock if not supported
try: try:
fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.servers.state.')
with os.fdopen(fd, 'w', encoding='utf-8') as f:
fd = None # fd is now owned by the file object
f.write(state) f.write(state)
os.rename(temp_path, STATE_FILE)
temp_path = None # Rename succeeded, don't unlink
except OSError as e:
raise IOError(f"Failed to save state: {e}") from e
finally: finally:
if fd is not None:
try: try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN) os.close(fd)
except OSError:
pass
if temp_path is not None:
try:
os.unlink(temp_path)
except OSError: except OSError:
pass pass
return "Server state saved" return "Server state saved"
except HaproxyError as e: except HaproxyError as e:
return f"Error: {e}" return f"Error: {e}"
except IOError as e: except IOError as e:
return f"Error: Failed to save state: {e}" return f"Error: {e}"
@mcp.tool() @mcp.tool()
@@ -1051,10 +1108,10 @@ def haproxy_restore_state() -> str:
for line in state.split("\n"): for line in state.split("\n"):
parts = line.split() parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and not line.startswith("#"): if len(parts) >= STATE_MIN_COLUMNS and not line.startswith("#"):
backend = parts[1] backend = parts[StateField.BE_NAME]
server = parts[3] server = parts[StateField.SRV_NAME]
addr = parts[4] addr = parts[StateField.SRV_ADDR]
port = parts[18] port = parts[StateField.SRV_PORT]
# Skip disabled servers # Skip disabled servers
if addr == "0.0.0.0": if addr == "0.0.0.0":
@@ -1071,7 +1128,7 @@ def haproxy_restore_state() -> str:
continue continue
haproxy_cmd(f"set server {backend}/{server} addr {addr} port {port}") haproxy_cmd(f"set server {backend}/{server} addr {addr} port {port}")
haproxy_cmd(f"enable server {backend}/{server}") haproxy_cmd(f"set server {backend}/{server} state ready")
restored += 1 restored += 1
result = f"Server state restored ({restored} servers)" result = f"Server state restored ({restored} servers)"
@@ -1094,15 +1151,15 @@ def startup_restore() -> None:
except HaproxyError: except HaproxyError:
time.sleep(1) time.sleep(1)
else: else:
print("Warning: HAProxy not ready, skipping restore", file=sys.stderr) logger.warning("HAProxy not ready, skipping restore")
return return
try: try:
count = restore_servers_from_config() count = restore_servers_from_config()
if count > 0: if count > 0:
print(f"Restored {count} servers from config", file=sys.stderr) logger.info("Restored %d servers from config", count)
except (HaproxyError, OSError, ValueError, json.JSONDecodeError) as e: except (HaproxyError, OSError, ValueError, json.JSONDecodeError) as e:
print(f"Warning: Failed to restore servers: {e}", file=sys.stderr) logger.warning("Failed to restore servers: %s", e)
if __name__ == "__main__": if __name__ == "__main__":