Improve code quality based on code review

Major improvements:
- Atomic file writes using temp file + rename pattern
- Structured logging with logging module (replaces print)
- StateField class for HAProxy state field indices
- Helper function get_backend_and_prefix() to reduce duplication
- Consistent exception chaining with 'from e'
- Proper fd/temp_path tracking to prevent resource leaks
- Added IOError handling in server management functions

Technical changes:
- save_map_file, save_servers_config, haproxy_save_state now use
  atomic writes with tempfile.mkstemp() + os.rename()
- Standardized on 'set server state ready' (was 'enable server')
- All magic numbers for state parsing replaced with StateField class

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
kaffa
2026-02-01 12:48:49 +00:00
parent 196374e70c
commit 61dd4a69fc

View File

@@ -10,12 +10,22 @@ import socket
import subprocess
import re
import json
import sys
import logging
import os
import tempfile
import time
import fcntl
from typing import Dict, Generator, List, Optional, Set, Tuple
from typing import Any, Dict, Generator, List, Optional, Set, Tuple
from mcp.server.fastmcp import FastMCP
# Configure structured logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
mcp = FastMCP("haproxy", host="0.0.0.0", port=8000)
# Constants
@@ -63,6 +73,19 @@ class StatField:
CHECK_STATUS = 36 # Check status
# Field indices for HAProxy server state (show servers state command)
class StateField:
"""HAProxy server state field indices."""
BE_ID = 0 # Backend ID
BE_NAME = 1 # Backend name
SRV_ID = 2 # Server ID
SRV_NAME = 3 # Server name
SRV_ADDR = 4 # Server address
SRV_OP_STATE = 5 # Operational state
SRV_ADMIN_STATE = 6 # Admin state
SRV_PORT = 18 # Server port
def haproxy_cmd(command: str) -> str:
"""Send command to HAProxy Runtime API.
@@ -99,7 +122,7 @@ def haproxy_cmd(command: str) -> str:
except HaproxyError:
raise
except Exception as e:
raise HaproxyError(str(e))
raise HaproxyError(str(e)) from e
def reload_haproxy() -> Tuple[bool, str]:
@@ -323,7 +346,9 @@ def get_legacy_backend_name(domain: str) -> str:
def save_map_file(entries: List[Tuple[str, str]]) -> None:
"""Save entries to domains.map file with file locking.
"""Save entries to domains.map file atomically.
Uses temp file + rename for atomic write to prevent race conditions.
Args:
entries: List of (domain, backend) tuples to write
@@ -331,25 +356,36 @@ def save_map_file(entries: List[Tuple[str, str]]) -> None:
Raises:
IOError: If the file cannot be written
"""
with open(MAP_FILE, "w", encoding="utf-8") as f:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
except OSError:
pass # Continue without lock if not supported
try:
dir_path = os.path.dirname(MAP_FILE)
fd = None
temp_path = None
try:
fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.domains.map.')
with os.fdopen(fd, 'w', encoding='utf-8') as f:
fd = None # fd is now owned by the file object
f.write("# Domain to Backend mapping\n")
f.write("# Format: domain backend_name\n")
f.write("# Wildcard: .domain.com matches *.domain.com\n\n")
for domain, backend in entries:
f.write(f"{domain} {backend}\n")
finally:
os.rename(temp_path, MAP_FILE)
temp_path = None # Rename succeeded, don't unlink
except OSError as e:
raise IOError(f"Failed to save map file: {e}") from e
finally:
if fd is not None:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
os.close(fd)
except OSError:
pass
if temp_path is not None:
try:
os.unlink(temp_path)
except OSError:
pass
def load_servers_config() -> Dict:
def load_servers_config() -> Dict[str, Any]:
"""Load servers configuration from JSON file with file locking.
Returns:
@@ -360,7 +396,7 @@ def load_servers_config() -> Dict:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_SH)
except OSError:
pass # Continue without lock if not supported
logger.debug("File locking not supported for %s", SERVERS_FILE)
try:
return json.load(f)
finally:
@@ -371,26 +407,39 @@ def load_servers_config() -> Dict:
except FileNotFoundError:
return {}
except json.JSONDecodeError as e:
print(f"Warning: Corrupt config file {SERVERS_FILE}: {e}", file=sys.stderr)
logger.warning("Corrupt config file %s: %s", SERVERS_FILE, e)
return {}
def save_servers_config(config: Dict) -> None:
"""Save servers configuration to JSON file with file locking.
def save_servers_config(config: Dict[str, Any]) -> None:
"""Save servers configuration to JSON file atomically.
Uses temp file + rename for atomic write to prevent race conditions.
Args:
config: Dictionary with server configurations
"""
with open(SERVERS_FILE, "w", encoding="utf-8") as f:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
except OSError:
pass # Continue without lock if not supported
try:
dir_path = os.path.dirname(SERVERS_FILE)
fd = None
temp_path = None
try:
fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.servers.json.')
with os.fdopen(fd, 'w', encoding='utf-8') as f:
fd = None # fd is now owned by the file object
json.dump(config, f, indent=2)
finally:
os.rename(temp_path, SERVERS_FILE)
temp_path = None # Rename succeeded, don't unlink
except OSError as e:
raise IOError(f"Failed to save servers config: {e}") from e
finally:
if fd is not None:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
os.close(fd)
except OSError:
pass
if temp_path is not None:
try:
os.unlink(temp_path)
except OSError:
pass
@@ -453,6 +502,30 @@ def get_server_suffixes(http_port: int) -> List[Tuple[str, int]]:
return [("", http_port)]
def get_backend_and_prefix(domain: str) -> Tuple[str, str]:
"""Look up backend and determine server name prefix for a domain.
Args:
domain: The domain name to look up
Returns:
Tuple of (backend_name, server_prefix)
Raises:
ValueError: If domain cannot be mapped to a valid backend
"""
backend = get_domain_backend(domain)
if not backend:
backend = get_legacy_backend_name(domain)
if backend.startswith("pool_"):
server_prefix = backend
else:
server_prefix = domain_to_backend(domain)
return backend, server_prefix
def restore_servers_from_config() -> int:
"""Restore all servers from configuration file.
@@ -468,19 +541,16 @@ def restore_servers_from_config() -> int:
continue
try:
if backend.startswith("pool_"):
server_prefix = backend
else:
server_prefix = domain_to_backend(domain)
_, server_prefix = get_backend_and_prefix(domain)
except ValueError as e:
print(f"Warning: Invalid domain '{domain}': {e}", file=sys.stderr)
logger.warning("Invalid domain '%s': %s", domain, e)
continue
for slot_str, server_info in slots.items():
try:
slot = int(slot_str)
except ValueError:
print(f"Warning: Invalid slot '{slot_str}' for {domain}, skipping", file=sys.stderr)
logger.warning("Invalid slot '%s' for %s, skipping", slot_str, domain)
continue
ip = server_info.get("ip", "")
@@ -490,7 +560,7 @@ def restore_servers_from_config() -> int:
try:
http_port = int(server_info.get("http_port", 80))
except (ValueError, TypeError):
print(f"Warning: Invalid port for {domain} slot {slot}, skipping", file=sys.stderr)
logger.warning("Invalid port for %s slot %d, skipping", domain, slot)
continue
try:
@@ -500,7 +570,7 @@ def restore_servers_from_config() -> int:
haproxy_cmd(f"set server {backend}/{server} state ready")
restored += 1
except HaproxyError as e:
print(f"Warning: Failed to restore {domain} slot {slot}: {e}", file=sys.stderr)
logger.warning("Failed to restore %s slot %d: %s", domain, slot, e)
return restored
@@ -520,11 +590,13 @@ def haproxy_list_domains() -> str:
server_map: Dict[str, list] = {}
for line in state.split("\n"):
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[4] != "0.0.0.0":
backend = parts[1]
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.SRV_ADDR] != "0.0.0.0":
backend = parts[StateField.BE_NAME]
if backend not in server_map:
server_map[backend] = []
server_map[backend].append(f"{parts[3]}={parts[4]}:{parts[18]}")
server_map[backend].append(
f"{parts[StateField.SRV_NAME]}={parts[StateField.SRV_ADDR]}:{parts[StateField.SRV_PORT]}"
)
# Read from domains.map (skip wildcard entries starting with .)
seen_domains: Set[str] = set()
@@ -670,20 +742,18 @@ def haproxy_list_servers(domain: str) -> str:
return "Error: Invalid domain format"
try:
# Look up backend from map
backend = get_domain_backend(domain)
if not backend:
# Fall back to legacy naming convention
backend = get_legacy_backend_name(domain)
backend, _ = get_backend_and_prefix(domain)
servers = []
state = haproxy_cmd("show servers state")
for line in state.split("\n"):
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and parts[1] == backend:
status = "active" if parts[4] != "0.0.0.0" else "disabled"
servers.append(f"{parts[3]}: {parts[4]}:{parts[18]} ({status})")
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
addr = parts[StateField.SRV_ADDR]
status = "active" if addr != "0.0.0.0" else "disabled"
servers.append(
f"{parts[StateField.SRV_NAME]}: {addr}:{parts[StateField.SRV_PORT]} ({status})"
)
if not servers:
return f"Backend {backend} not found"
@@ -718,19 +788,7 @@ def haproxy_add_server(domain: str, slot: int, ip: str, http_port: int = 80) ->
return "Error: Port must be between 1 and 65535"
try:
# Look up backend from map
backend = get_domain_backend(domain)
if not backend:
# Fall back to legacy naming convention
backend = get_legacy_backend_name(domain)
# Determine server name prefix based on backend type
if backend.startswith("pool_"):
# Pool backends use pool_N_slot naming
server_prefix = backend
else:
# Legacy backends use domain-based naming
server_prefix = domain_to_backend(domain)
backend, server_prefix = get_backend_and_prefix(domain)
results = []
for suffix, port in get_server_suffixes(http_port):
@@ -743,7 +801,7 @@ def haproxy_add_server(domain: str, slot: int, ip: str, http_port: int = 80) ->
add_server_to_config(domain, slot, ip, http_port)
return f"Added to {domain} ({backend}) slot {slot}:\n" + "\n".join(results)
except (HaproxyError, ValueError) as e:
except (HaproxyError, ValueError, IOError) as e:
return f"Error: {e}"
@@ -764,19 +822,7 @@ def haproxy_remove_server(domain: str, slot: int) -> str:
return f"Error: Slot must be between 1 and {MAX_SLOTS}"
try:
# Look up backend from map
backend = get_domain_backend(domain)
if not backend:
# Fall back to legacy naming convention
backend = get_legacy_backend_name(domain)
# Determine server name prefix based on backend type
if backend.startswith("pool_"):
# Pool backends use pool_N_slot naming
server_prefix = backend
else:
# Legacy backends use domain-based naming
server_prefix = domain_to_backend(domain)
backend, server_prefix = get_backend_and_prefix(domain)
# HTTP only - single server per slot
server = f"{server_prefix}_{slot}"
@@ -787,7 +833,7 @@ def haproxy_remove_server(domain: str, slot: int) -> str:
remove_server_from_config(domain, slot)
return f"Removed server at slot {slot} from {domain} ({backend})"
except (HaproxyError, ValueError) as e:
except (HaproxyError, ValueError, IOError) as e:
return f"Error: {e}"
@@ -999,30 +1045,41 @@ def haproxy_check_config() -> str:
@mcp.tool()
def haproxy_save_state() -> str:
"""Save current server state to disk.
"""Save current server state to disk atomically.
Returns:
Success message or error description
"""
try:
state = haproxy_cmd("show servers state")
with open(STATE_FILE, "w", encoding="utf-8") as f:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
except OSError:
pass # Continue without lock if not supported
try:
dir_path = os.path.dirname(STATE_FILE)
fd = None
temp_path = None
try:
fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.servers.state.')
with os.fdopen(fd, 'w', encoding='utf-8') as f:
fd = None # fd is now owned by the file object
f.write(state)
finally:
os.rename(temp_path, STATE_FILE)
temp_path = None # Rename succeeded, don't unlink
except OSError as e:
raise IOError(f"Failed to save state: {e}") from e
finally:
if fd is not None:
try:
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
os.close(fd)
except OSError:
pass
if temp_path is not None:
try:
os.unlink(temp_path)
except OSError:
pass
return "Server state saved"
except HaproxyError as e:
return f"Error: {e}"
except IOError as e:
return f"Error: Failed to save state: {e}"
return f"Error: {e}"
@mcp.tool()
@@ -1051,10 +1108,10 @@ def haproxy_restore_state() -> str:
for line in state.split("\n"):
parts = line.split()
if len(parts) >= STATE_MIN_COLUMNS and not line.startswith("#"):
backend = parts[1]
server = parts[3]
addr = parts[4]
port = parts[18]
backend = parts[StateField.BE_NAME]
server = parts[StateField.SRV_NAME]
addr = parts[StateField.SRV_ADDR]
port = parts[StateField.SRV_PORT]
# Skip disabled servers
if addr == "0.0.0.0":
@@ -1071,7 +1128,7 @@ def haproxy_restore_state() -> str:
continue
haproxy_cmd(f"set server {backend}/{server} addr {addr} port {port}")
haproxy_cmd(f"enable server {backend}/{server}")
haproxy_cmd(f"set server {backend}/{server} state ready")
restored += 1
result = f"Server state restored ({restored} servers)"
@@ -1094,15 +1151,15 @@ def startup_restore() -> None:
except HaproxyError:
time.sleep(1)
else:
print("Warning: HAProxy not ready, skipping restore", file=sys.stderr)
logger.warning("HAProxy not ready, skipping restore")
return
try:
count = restore_servers_from_config()
if count > 0:
print(f"Restored {count} servers from config", file=sys.stderr)
logger.info("Restored %d servers from config", count)
except (HaproxyError, OSError, ValueError, json.JSONDecodeError) as e:
print(f"Warning: Failed to restore servers: {e}", file=sys.stderr)
logger.warning("Failed to restore servers: %s", e)
if __name__ == "__main__":