From 61dd4a69fc9e7779c15bc4f0701c75788744ea42 Mon Sep 17 00:00:00 2001 From: kaffa Date: Sun, 1 Feb 2026 12:48:49 +0000 Subject: [PATCH] Improve code quality based on code review Major improvements: - Atomic file writes using temp file + rename pattern - Structured logging with logging module (replaces print) - StateField class for HAProxy state field indices - Helper function get_backend_and_prefix() to reduce duplication - Consistent exception chaining with 'from e' - Proper fd/temp_path tracking to prevent resource leaks - Added IOError handling in server management functions Technical changes: - save_map_file, save_servers_config, haproxy_save_state now use atomic writes with tempfile.mkstemp() + os.rename() - Standardized on 'set server state ready' (was 'enable server') - All magic numbers for state parsing replaced with StateField class Co-Authored-By: Claude Opus 4.5 --- mcp/server.py | 239 +++++++++++++++++++++++++++++++------------------- 1 file changed, 148 insertions(+), 91 deletions(-) diff --git a/mcp/server.py b/mcp/server.py index 6fd379d..6763382 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -10,12 +10,22 @@ import socket import subprocess import re import json -import sys +import logging +import os +import tempfile import time import fcntl -from typing import Dict, Generator, List, Optional, Set, Tuple +from typing import Any, Dict, Generator, List, Optional, Set, Tuple from mcp.server.fastmcp import FastMCP +# Configure structured logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) +logger = logging.getLogger(__name__) + mcp = FastMCP("haproxy", host="0.0.0.0", port=8000) # Constants @@ -63,6 +73,19 @@ class StatField: CHECK_STATUS = 36 # Check status +# Field indices for HAProxy server state (show servers state command) +class StateField: + """HAProxy server state field indices.""" + BE_ID = 0 # Backend ID + BE_NAME = 1 # Backend name + SRV_ID = 2 # Server ID + SRV_NAME = 3 # Server name + SRV_ADDR = 4 # Server address + SRV_OP_STATE = 5 # Operational state + SRV_ADMIN_STATE = 6 # Admin state + SRV_PORT = 18 # Server port + + def haproxy_cmd(command: str) -> str: """Send command to HAProxy Runtime API. @@ -99,7 +122,7 @@ def haproxy_cmd(command: str) -> str: except HaproxyError: raise except Exception as e: - raise HaproxyError(str(e)) + raise HaproxyError(str(e)) from e def reload_haproxy() -> Tuple[bool, str]: @@ -323,7 +346,9 @@ def get_legacy_backend_name(domain: str) -> str: def save_map_file(entries: List[Tuple[str, str]]) -> None: - """Save entries to domains.map file with file locking. + """Save entries to domains.map file atomically. + + Uses temp file + rename for atomic write to prevent race conditions. Args: entries: List of (domain, backend) tuples to write @@ -331,25 +356,36 @@ def save_map_file(entries: List[Tuple[str, str]]) -> None: Raises: IOError: If the file cannot be written """ - with open(MAP_FILE, "w", encoding="utf-8") as f: - try: - fcntl.flock(f.fileno(), fcntl.LOCK_EX) - except OSError: - pass # Continue without lock if not supported - try: + dir_path = os.path.dirname(MAP_FILE) + fd = None + temp_path = None + try: + fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.domains.map.') + with os.fdopen(fd, 'w', encoding='utf-8') as f: + fd = None # fd is now owned by the file object f.write("# Domain to Backend mapping\n") f.write("# Format: domain backend_name\n") f.write("# Wildcard: .domain.com matches *.domain.com\n\n") for domain, backend in entries: f.write(f"{domain} {backend}\n") - finally: + os.rename(temp_path, MAP_FILE) + temp_path = None # Rename succeeded, don't unlink + except OSError as e: + raise IOError(f"Failed to save map file: {e}") from e + finally: + if fd is not None: try: - fcntl.flock(f.fileno(), fcntl.LOCK_UN) + os.close(fd) + except OSError: + pass + if temp_path is not None: + try: + os.unlink(temp_path) except OSError: pass -def load_servers_config() -> Dict: +def load_servers_config() -> Dict[str, Any]: """Load servers configuration from JSON file with file locking. Returns: @@ -360,7 +396,7 @@ def load_servers_config() -> Dict: try: fcntl.flock(f.fileno(), fcntl.LOCK_SH) except OSError: - pass # Continue without lock if not supported + logger.debug("File locking not supported for %s", SERVERS_FILE) try: return json.load(f) finally: @@ -371,26 +407,39 @@ def load_servers_config() -> Dict: except FileNotFoundError: return {} except json.JSONDecodeError as e: - print(f"Warning: Corrupt config file {SERVERS_FILE}: {e}", file=sys.stderr) + logger.warning("Corrupt config file %s: %s", SERVERS_FILE, e) return {} -def save_servers_config(config: Dict) -> None: - """Save servers configuration to JSON file with file locking. +def save_servers_config(config: Dict[str, Any]) -> None: + """Save servers configuration to JSON file atomically. + + Uses temp file + rename for atomic write to prevent race conditions. Args: config: Dictionary with server configurations """ - with open(SERVERS_FILE, "w", encoding="utf-8") as f: - try: - fcntl.flock(f.fileno(), fcntl.LOCK_EX) - except OSError: - pass # Continue without lock if not supported - try: + dir_path = os.path.dirname(SERVERS_FILE) + fd = None + temp_path = None + try: + fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.servers.json.') + with os.fdopen(fd, 'w', encoding='utf-8') as f: + fd = None # fd is now owned by the file object json.dump(config, f, indent=2) - finally: + os.rename(temp_path, SERVERS_FILE) + temp_path = None # Rename succeeded, don't unlink + except OSError as e: + raise IOError(f"Failed to save servers config: {e}") from e + finally: + if fd is not None: try: - fcntl.flock(f.fileno(), fcntl.LOCK_UN) + os.close(fd) + except OSError: + pass + if temp_path is not None: + try: + os.unlink(temp_path) except OSError: pass @@ -453,6 +502,30 @@ def get_server_suffixes(http_port: int) -> List[Tuple[str, int]]: return [("", http_port)] +def get_backend_and_prefix(domain: str) -> Tuple[str, str]: + """Look up backend and determine server name prefix for a domain. + + Args: + domain: The domain name to look up + + Returns: + Tuple of (backend_name, server_prefix) + + Raises: + ValueError: If domain cannot be mapped to a valid backend + """ + backend = get_domain_backend(domain) + if not backend: + backend = get_legacy_backend_name(domain) + + if backend.startswith("pool_"): + server_prefix = backend + else: + server_prefix = domain_to_backend(domain) + + return backend, server_prefix + + def restore_servers_from_config() -> int: """Restore all servers from configuration file. @@ -468,19 +541,16 @@ def restore_servers_from_config() -> int: continue try: - if backend.startswith("pool_"): - server_prefix = backend - else: - server_prefix = domain_to_backend(domain) + _, server_prefix = get_backend_and_prefix(domain) except ValueError as e: - print(f"Warning: Invalid domain '{domain}': {e}", file=sys.stderr) + logger.warning("Invalid domain '%s': %s", domain, e) continue for slot_str, server_info in slots.items(): try: slot = int(slot_str) except ValueError: - print(f"Warning: Invalid slot '{slot_str}' for {domain}, skipping", file=sys.stderr) + logger.warning("Invalid slot '%s' for %s, skipping", slot_str, domain) continue ip = server_info.get("ip", "") @@ -490,7 +560,7 @@ def restore_servers_from_config() -> int: try: http_port = int(server_info.get("http_port", 80)) except (ValueError, TypeError): - print(f"Warning: Invalid port for {domain} slot {slot}, skipping", file=sys.stderr) + logger.warning("Invalid port for %s slot %d, skipping", domain, slot) continue try: @@ -500,7 +570,7 @@ def restore_servers_from_config() -> int: haproxy_cmd(f"set server {backend}/{server} state ready") restored += 1 except HaproxyError as e: - print(f"Warning: Failed to restore {domain} slot {slot}: {e}", file=sys.stderr) + logger.warning("Failed to restore %s slot %d: %s", domain, slot, e) return restored @@ -520,11 +590,13 @@ def haproxy_list_domains() -> str: server_map: Dict[str, list] = {} for line in state.split("\n"): parts = line.split() - if len(parts) >= STATE_MIN_COLUMNS and parts[4] != "0.0.0.0": - backend = parts[1] + if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.SRV_ADDR] != "0.0.0.0": + backend = parts[StateField.BE_NAME] if backend not in server_map: server_map[backend] = [] - server_map[backend].append(f"{parts[3]}={parts[4]}:{parts[18]}") + server_map[backend].append( + f"{parts[StateField.SRV_NAME]}={parts[StateField.SRV_ADDR]}:{parts[StateField.SRV_PORT]}" + ) # Read from domains.map (skip wildcard entries starting with .) seen_domains: Set[str] = set() @@ -670,20 +742,18 @@ def haproxy_list_servers(domain: str) -> str: return "Error: Invalid domain format" try: - # Look up backend from map - backend = get_domain_backend(domain) - if not backend: - # Fall back to legacy naming convention - backend = get_legacy_backend_name(domain) - + backend, _ = get_backend_and_prefix(domain) servers = [] state = haproxy_cmd("show servers state") for line in state.split("\n"): parts = line.split() - if len(parts) >= STATE_MIN_COLUMNS and parts[1] == backend: - status = "active" if parts[4] != "0.0.0.0" else "disabled" - servers.append(f"• {parts[3]}: {parts[4]}:{parts[18]} ({status})") + if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend: + addr = parts[StateField.SRV_ADDR] + status = "active" if addr != "0.0.0.0" else "disabled" + servers.append( + f"• {parts[StateField.SRV_NAME]}: {addr}:{parts[StateField.SRV_PORT]} ({status})" + ) if not servers: return f"Backend {backend} not found" @@ -718,19 +788,7 @@ def haproxy_add_server(domain: str, slot: int, ip: str, http_port: int = 80) -> return "Error: Port must be between 1 and 65535" try: - # Look up backend from map - backend = get_domain_backend(domain) - if not backend: - # Fall back to legacy naming convention - backend = get_legacy_backend_name(domain) - - # Determine server name prefix based on backend type - if backend.startswith("pool_"): - # Pool backends use pool_N_slot naming - server_prefix = backend - else: - # Legacy backends use domain-based naming - server_prefix = domain_to_backend(domain) + backend, server_prefix = get_backend_and_prefix(domain) results = [] for suffix, port in get_server_suffixes(http_port): @@ -743,7 +801,7 @@ def haproxy_add_server(domain: str, slot: int, ip: str, http_port: int = 80) -> add_server_to_config(domain, slot, ip, http_port) return f"Added to {domain} ({backend}) slot {slot}:\n" + "\n".join(results) - except (HaproxyError, ValueError) as e: + except (HaproxyError, ValueError, IOError) as e: return f"Error: {e}" @@ -764,19 +822,7 @@ def haproxy_remove_server(domain: str, slot: int) -> str: return f"Error: Slot must be between 1 and {MAX_SLOTS}" try: - # Look up backend from map - backend = get_domain_backend(domain) - if not backend: - # Fall back to legacy naming convention - backend = get_legacy_backend_name(domain) - - # Determine server name prefix based on backend type - if backend.startswith("pool_"): - # Pool backends use pool_N_slot naming - server_prefix = backend - else: - # Legacy backends use domain-based naming - server_prefix = domain_to_backend(domain) + backend, server_prefix = get_backend_and_prefix(domain) # HTTP only - single server per slot server = f"{server_prefix}_{slot}" @@ -787,7 +833,7 @@ def haproxy_remove_server(domain: str, slot: int) -> str: remove_server_from_config(domain, slot) return f"Removed server at slot {slot} from {domain} ({backend})" - except (HaproxyError, ValueError) as e: + except (HaproxyError, ValueError, IOError) as e: return f"Error: {e}" @@ -999,30 +1045,41 @@ def haproxy_check_config() -> str: @mcp.tool() def haproxy_save_state() -> str: - """Save current server state to disk. + """Save current server state to disk atomically. Returns: Success message or error description """ try: state = haproxy_cmd("show servers state") - with open(STATE_FILE, "w", encoding="utf-8") as f: - try: - fcntl.flock(f.fileno(), fcntl.LOCK_EX) - except OSError: - pass # Continue without lock if not supported - try: + dir_path = os.path.dirname(STATE_FILE) + fd = None + temp_path = None + try: + fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.servers.state.') + with os.fdopen(fd, 'w', encoding='utf-8') as f: + fd = None # fd is now owned by the file object f.write(state) - finally: + os.rename(temp_path, STATE_FILE) + temp_path = None # Rename succeeded, don't unlink + except OSError as e: + raise IOError(f"Failed to save state: {e}") from e + finally: + if fd is not None: try: - fcntl.flock(f.fileno(), fcntl.LOCK_UN) + os.close(fd) + except OSError: + pass + if temp_path is not None: + try: + os.unlink(temp_path) except OSError: pass return "Server state saved" except HaproxyError as e: return f"Error: {e}" except IOError as e: - return f"Error: Failed to save state: {e}" + return f"Error: {e}" @mcp.tool() @@ -1051,10 +1108,10 @@ def haproxy_restore_state() -> str: for line in state.split("\n"): parts = line.split() if len(parts) >= STATE_MIN_COLUMNS and not line.startswith("#"): - backend = parts[1] - server = parts[3] - addr = parts[4] - port = parts[18] + backend = parts[StateField.BE_NAME] + server = parts[StateField.SRV_NAME] + addr = parts[StateField.SRV_ADDR] + port = parts[StateField.SRV_PORT] # Skip disabled servers if addr == "0.0.0.0": @@ -1071,7 +1128,7 @@ def haproxy_restore_state() -> str: continue haproxy_cmd(f"set server {backend}/{server} addr {addr} port {port}") - haproxy_cmd(f"enable server {backend}/{server}") + haproxy_cmd(f"set server {backend}/{server} state ready") restored += 1 result = f"Server state restored ({restored} servers)" @@ -1094,15 +1151,15 @@ def startup_restore() -> None: except HaproxyError: time.sleep(1) else: - print("Warning: HAProxy not ready, skipping restore", file=sys.stderr) + logger.warning("HAProxy not ready, skipping restore") return try: count = restore_servers_from_config() if count > 0: - print(f"Restored {count} servers from config", file=sys.stderr) + logger.info("Restored %d servers from config", count) except (HaproxyError, OSError, ValueError, json.JSONDecodeError) as e: - print(f"Warning: Failed to restore servers: {e}", file=sys.stderr) + logger.warning("Failed to restore servers: %s", e) if __name__ == "__main__":