Improve code quality based on code review
Major improvements: - Atomic file writes using temp file + rename pattern - Structured logging with logging module (replaces print) - StateField class for HAProxy state field indices - Helper function get_backend_and_prefix() to reduce duplication - Consistent exception chaining with 'from e' - Proper fd/temp_path tracking to prevent resource leaks - Added IOError handling in server management functions Technical changes: - save_map_file, save_servers_config, haproxy_save_state now use atomic writes with tempfile.mkstemp() + os.rename() - Standardized on 'set server state ready' (was 'enable server') - All magic numbers for state parsing replaced with StateField class Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
227
mcp/server.py
227
mcp/server.py
@@ -10,12 +10,22 @@ import socket
|
||||
import subprocess
|
||||
import re
|
||||
import json
|
||||
import sys
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
import fcntl
|
||||
from typing import Dict, Generator, List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, Generator, List, Optional, Set, Tuple
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
# Configure structured logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s [%(levelname)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
mcp = FastMCP("haproxy", host="0.0.0.0", port=8000)
|
||||
|
||||
# Constants
|
||||
@@ -63,6 +73,19 @@ class StatField:
|
||||
CHECK_STATUS = 36 # Check status
|
||||
|
||||
|
||||
# Field indices for HAProxy server state (show servers state command)
|
||||
class StateField:
|
||||
"""HAProxy server state field indices."""
|
||||
BE_ID = 0 # Backend ID
|
||||
BE_NAME = 1 # Backend name
|
||||
SRV_ID = 2 # Server ID
|
||||
SRV_NAME = 3 # Server name
|
||||
SRV_ADDR = 4 # Server address
|
||||
SRV_OP_STATE = 5 # Operational state
|
||||
SRV_ADMIN_STATE = 6 # Admin state
|
||||
SRV_PORT = 18 # Server port
|
||||
|
||||
|
||||
def haproxy_cmd(command: str) -> str:
|
||||
"""Send command to HAProxy Runtime API.
|
||||
|
||||
@@ -99,7 +122,7 @@ def haproxy_cmd(command: str) -> str:
|
||||
except HaproxyError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HaproxyError(str(e))
|
||||
raise HaproxyError(str(e)) from e
|
||||
|
||||
|
||||
def reload_haproxy() -> Tuple[bool, str]:
|
||||
@@ -323,7 +346,9 @@ def get_legacy_backend_name(domain: str) -> str:
|
||||
|
||||
|
||||
def save_map_file(entries: List[Tuple[str, str]]) -> None:
|
||||
"""Save entries to domains.map file with file locking.
|
||||
"""Save entries to domains.map file atomically.
|
||||
|
||||
Uses temp file + rename for atomic write to prevent race conditions.
|
||||
|
||||
Args:
|
||||
entries: List of (domain, backend) tuples to write
|
||||
@@ -331,25 +356,36 @@ def save_map_file(entries: List[Tuple[str, str]]) -> None:
|
||||
Raises:
|
||||
IOError: If the file cannot be written
|
||||
"""
|
||||
with open(MAP_FILE, "w", encoding="utf-8") as f:
|
||||
try:
|
||||
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
|
||||
except OSError:
|
||||
pass # Continue without lock if not supported
|
||||
dir_path = os.path.dirname(MAP_FILE)
|
||||
fd = None
|
||||
temp_path = None
|
||||
try:
|
||||
fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.domains.map.')
|
||||
with os.fdopen(fd, 'w', encoding='utf-8') as f:
|
||||
fd = None # fd is now owned by the file object
|
||||
f.write("# Domain to Backend mapping\n")
|
||||
f.write("# Format: domain backend_name\n")
|
||||
f.write("# Wildcard: .domain.com matches *.domain.com\n\n")
|
||||
for domain, backend in entries:
|
||||
f.write(f"{domain} {backend}\n")
|
||||
os.rename(temp_path, MAP_FILE)
|
||||
temp_path = None # Rename succeeded, don't unlink
|
||||
except OSError as e:
|
||||
raise IOError(f"Failed to save map file: {e}") from e
|
||||
finally:
|
||||
if fd is not None:
|
||||
try:
|
||||
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
|
||||
os.close(fd)
|
||||
except OSError:
|
||||
pass
|
||||
if temp_path is not None:
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def load_servers_config() -> Dict:
|
||||
def load_servers_config() -> Dict[str, Any]:
|
||||
"""Load servers configuration from JSON file with file locking.
|
||||
|
||||
Returns:
|
||||
@@ -360,7 +396,7 @@ def load_servers_config() -> Dict:
|
||||
try:
|
||||
fcntl.flock(f.fileno(), fcntl.LOCK_SH)
|
||||
except OSError:
|
||||
pass # Continue without lock if not supported
|
||||
logger.debug("File locking not supported for %s", SERVERS_FILE)
|
||||
try:
|
||||
return json.load(f)
|
||||
finally:
|
||||
@@ -371,26 +407,39 @@ def load_servers_config() -> Dict:
|
||||
except FileNotFoundError:
|
||||
return {}
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Warning: Corrupt config file {SERVERS_FILE}: {e}", file=sys.stderr)
|
||||
logger.warning("Corrupt config file %s: %s", SERVERS_FILE, e)
|
||||
return {}
|
||||
|
||||
|
||||
def save_servers_config(config: Dict) -> None:
|
||||
"""Save servers configuration to JSON file with file locking.
|
||||
def save_servers_config(config: Dict[str, Any]) -> None:
|
||||
"""Save servers configuration to JSON file atomically.
|
||||
|
||||
Uses temp file + rename for atomic write to prevent race conditions.
|
||||
|
||||
Args:
|
||||
config: Dictionary with server configurations
|
||||
"""
|
||||
with open(SERVERS_FILE, "w", encoding="utf-8") as f:
|
||||
try:
|
||||
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
|
||||
except OSError:
|
||||
pass # Continue without lock if not supported
|
||||
dir_path = os.path.dirname(SERVERS_FILE)
|
||||
fd = None
|
||||
temp_path = None
|
||||
try:
|
||||
fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.servers.json.')
|
||||
with os.fdopen(fd, 'w', encoding='utf-8') as f:
|
||||
fd = None # fd is now owned by the file object
|
||||
json.dump(config, f, indent=2)
|
||||
os.rename(temp_path, SERVERS_FILE)
|
||||
temp_path = None # Rename succeeded, don't unlink
|
||||
except OSError as e:
|
||||
raise IOError(f"Failed to save servers config: {e}") from e
|
||||
finally:
|
||||
if fd is not None:
|
||||
try:
|
||||
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
|
||||
os.close(fd)
|
||||
except OSError:
|
||||
pass
|
||||
if temp_path is not None:
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@@ -453,6 +502,30 @@ def get_server_suffixes(http_port: int) -> List[Tuple[str, int]]:
|
||||
return [("", http_port)]
|
||||
|
||||
|
||||
def get_backend_and_prefix(domain: str) -> Tuple[str, str]:
|
||||
"""Look up backend and determine server name prefix for a domain.
|
||||
|
||||
Args:
|
||||
domain: The domain name to look up
|
||||
|
||||
Returns:
|
||||
Tuple of (backend_name, server_prefix)
|
||||
|
||||
Raises:
|
||||
ValueError: If domain cannot be mapped to a valid backend
|
||||
"""
|
||||
backend = get_domain_backend(domain)
|
||||
if not backend:
|
||||
backend = get_legacy_backend_name(domain)
|
||||
|
||||
if backend.startswith("pool_"):
|
||||
server_prefix = backend
|
||||
else:
|
||||
server_prefix = domain_to_backend(domain)
|
||||
|
||||
return backend, server_prefix
|
||||
|
||||
|
||||
def restore_servers_from_config() -> int:
|
||||
"""Restore all servers from configuration file.
|
||||
|
||||
@@ -468,19 +541,16 @@ def restore_servers_from_config() -> int:
|
||||
continue
|
||||
|
||||
try:
|
||||
if backend.startswith("pool_"):
|
||||
server_prefix = backend
|
||||
else:
|
||||
server_prefix = domain_to_backend(domain)
|
||||
_, server_prefix = get_backend_and_prefix(domain)
|
||||
except ValueError as e:
|
||||
print(f"Warning: Invalid domain '{domain}': {e}", file=sys.stderr)
|
||||
logger.warning("Invalid domain '%s': %s", domain, e)
|
||||
continue
|
||||
|
||||
for slot_str, server_info in slots.items():
|
||||
try:
|
||||
slot = int(slot_str)
|
||||
except ValueError:
|
||||
print(f"Warning: Invalid slot '{slot_str}' for {domain}, skipping", file=sys.stderr)
|
||||
logger.warning("Invalid slot '%s' for %s, skipping", slot_str, domain)
|
||||
continue
|
||||
|
||||
ip = server_info.get("ip", "")
|
||||
@@ -490,7 +560,7 @@ def restore_servers_from_config() -> int:
|
||||
try:
|
||||
http_port = int(server_info.get("http_port", 80))
|
||||
except (ValueError, TypeError):
|
||||
print(f"Warning: Invalid port for {domain} slot {slot}, skipping", file=sys.stderr)
|
||||
logger.warning("Invalid port for %s slot %d, skipping", domain, slot)
|
||||
continue
|
||||
|
||||
try:
|
||||
@@ -500,7 +570,7 @@ def restore_servers_from_config() -> int:
|
||||
haproxy_cmd(f"set server {backend}/{server} state ready")
|
||||
restored += 1
|
||||
except HaproxyError as e:
|
||||
print(f"Warning: Failed to restore {domain} slot {slot}: {e}", file=sys.stderr)
|
||||
logger.warning("Failed to restore %s slot %d: %s", domain, slot, e)
|
||||
|
||||
return restored
|
||||
|
||||
@@ -520,11 +590,13 @@ def haproxy_list_domains() -> str:
|
||||
server_map: Dict[str, list] = {}
|
||||
for line in state.split("\n"):
|
||||
parts = line.split()
|
||||
if len(parts) >= STATE_MIN_COLUMNS and parts[4] != "0.0.0.0":
|
||||
backend = parts[1]
|
||||
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.SRV_ADDR] != "0.0.0.0":
|
||||
backend = parts[StateField.BE_NAME]
|
||||
if backend not in server_map:
|
||||
server_map[backend] = []
|
||||
server_map[backend].append(f"{parts[3]}={parts[4]}:{parts[18]}")
|
||||
server_map[backend].append(
|
||||
f"{parts[StateField.SRV_NAME]}={parts[StateField.SRV_ADDR]}:{parts[StateField.SRV_PORT]}"
|
||||
)
|
||||
|
||||
# Read from domains.map (skip wildcard entries starting with .)
|
||||
seen_domains: Set[str] = set()
|
||||
@@ -670,20 +742,18 @@ def haproxy_list_servers(domain: str) -> str:
|
||||
return "Error: Invalid domain format"
|
||||
|
||||
try:
|
||||
# Look up backend from map
|
||||
backend = get_domain_backend(domain)
|
||||
if not backend:
|
||||
# Fall back to legacy naming convention
|
||||
backend = get_legacy_backend_name(domain)
|
||||
|
||||
backend, _ = get_backend_and_prefix(domain)
|
||||
servers = []
|
||||
state = haproxy_cmd("show servers state")
|
||||
|
||||
for line in state.split("\n"):
|
||||
parts = line.split()
|
||||
if len(parts) >= STATE_MIN_COLUMNS and parts[1] == backend:
|
||||
status = "active" if parts[4] != "0.0.0.0" else "disabled"
|
||||
servers.append(f"• {parts[3]}: {parts[4]}:{parts[18]} ({status})")
|
||||
if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.BE_NAME] == backend:
|
||||
addr = parts[StateField.SRV_ADDR]
|
||||
status = "active" if addr != "0.0.0.0" else "disabled"
|
||||
servers.append(
|
||||
f"• {parts[StateField.SRV_NAME]}: {addr}:{parts[StateField.SRV_PORT]} ({status})"
|
||||
)
|
||||
|
||||
if not servers:
|
||||
return f"Backend {backend} not found"
|
||||
@@ -718,19 +788,7 @@ def haproxy_add_server(domain: str, slot: int, ip: str, http_port: int = 80) ->
|
||||
return "Error: Port must be between 1 and 65535"
|
||||
|
||||
try:
|
||||
# Look up backend from map
|
||||
backend = get_domain_backend(domain)
|
||||
if not backend:
|
||||
# Fall back to legacy naming convention
|
||||
backend = get_legacy_backend_name(domain)
|
||||
|
||||
# Determine server name prefix based on backend type
|
||||
if backend.startswith("pool_"):
|
||||
# Pool backends use pool_N_slot naming
|
||||
server_prefix = backend
|
||||
else:
|
||||
# Legacy backends use domain-based naming
|
||||
server_prefix = domain_to_backend(domain)
|
||||
backend, server_prefix = get_backend_and_prefix(domain)
|
||||
|
||||
results = []
|
||||
for suffix, port in get_server_suffixes(http_port):
|
||||
@@ -743,7 +801,7 @@ def haproxy_add_server(domain: str, slot: int, ip: str, http_port: int = 80) ->
|
||||
add_server_to_config(domain, slot, ip, http_port)
|
||||
|
||||
return f"Added to {domain} ({backend}) slot {slot}:\n" + "\n".join(results)
|
||||
except (HaproxyError, ValueError) as e:
|
||||
except (HaproxyError, ValueError, IOError) as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
|
||||
@@ -764,19 +822,7 @@ def haproxy_remove_server(domain: str, slot: int) -> str:
|
||||
return f"Error: Slot must be between 1 and {MAX_SLOTS}"
|
||||
|
||||
try:
|
||||
# Look up backend from map
|
||||
backend = get_domain_backend(domain)
|
||||
if not backend:
|
||||
# Fall back to legacy naming convention
|
||||
backend = get_legacy_backend_name(domain)
|
||||
|
||||
# Determine server name prefix based on backend type
|
||||
if backend.startswith("pool_"):
|
||||
# Pool backends use pool_N_slot naming
|
||||
server_prefix = backend
|
||||
else:
|
||||
# Legacy backends use domain-based naming
|
||||
server_prefix = domain_to_backend(domain)
|
||||
backend, server_prefix = get_backend_and_prefix(domain)
|
||||
|
||||
# HTTP only - single server per slot
|
||||
server = f"{server_prefix}_{slot}"
|
||||
@@ -787,7 +833,7 @@ def haproxy_remove_server(domain: str, slot: int) -> str:
|
||||
remove_server_from_config(domain, slot)
|
||||
|
||||
return f"Removed server at slot {slot} from {domain} ({backend})"
|
||||
except (HaproxyError, ValueError) as e:
|
||||
except (HaproxyError, ValueError, IOError) as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
|
||||
@@ -999,30 +1045,41 @@ def haproxy_check_config() -> str:
|
||||
|
||||
@mcp.tool()
|
||||
def haproxy_save_state() -> str:
|
||||
"""Save current server state to disk.
|
||||
"""Save current server state to disk atomically.
|
||||
|
||||
Returns:
|
||||
Success message or error description
|
||||
"""
|
||||
try:
|
||||
state = haproxy_cmd("show servers state")
|
||||
with open(STATE_FILE, "w", encoding="utf-8") as f:
|
||||
try:
|
||||
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
|
||||
except OSError:
|
||||
pass # Continue without lock if not supported
|
||||
dir_path = os.path.dirname(STATE_FILE)
|
||||
fd = None
|
||||
temp_path = None
|
||||
try:
|
||||
fd, temp_path = tempfile.mkstemp(dir=dir_path, prefix='.servers.state.')
|
||||
with os.fdopen(fd, 'w', encoding='utf-8') as f:
|
||||
fd = None # fd is now owned by the file object
|
||||
f.write(state)
|
||||
os.rename(temp_path, STATE_FILE)
|
||||
temp_path = None # Rename succeeded, don't unlink
|
||||
except OSError as e:
|
||||
raise IOError(f"Failed to save state: {e}") from e
|
||||
finally:
|
||||
if fd is not None:
|
||||
try:
|
||||
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
|
||||
os.close(fd)
|
||||
except OSError:
|
||||
pass
|
||||
if temp_path is not None:
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
return "Server state saved"
|
||||
except HaproxyError as e:
|
||||
return f"Error: {e}"
|
||||
except IOError as e:
|
||||
return f"Error: Failed to save state: {e}"
|
||||
return f"Error: {e}"
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@@ -1051,10 +1108,10 @@ def haproxy_restore_state() -> str:
|
||||
for line in state.split("\n"):
|
||||
parts = line.split()
|
||||
if len(parts) >= STATE_MIN_COLUMNS and not line.startswith("#"):
|
||||
backend = parts[1]
|
||||
server = parts[3]
|
||||
addr = parts[4]
|
||||
port = parts[18]
|
||||
backend = parts[StateField.BE_NAME]
|
||||
server = parts[StateField.SRV_NAME]
|
||||
addr = parts[StateField.SRV_ADDR]
|
||||
port = parts[StateField.SRV_PORT]
|
||||
|
||||
# Skip disabled servers
|
||||
if addr == "0.0.0.0":
|
||||
@@ -1071,7 +1128,7 @@ def haproxy_restore_state() -> str:
|
||||
continue
|
||||
|
||||
haproxy_cmd(f"set server {backend}/{server} addr {addr} port {port}")
|
||||
haproxy_cmd(f"enable server {backend}/{server}")
|
||||
haproxy_cmd(f"set server {backend}/{server} state ready")
|
||||
restored += 1
|
||||
|
||||
result = f"Server state restored ({restored} servers)"
|
||||
@@ -1094,15 +1151,15 @@ def startup_restore() -> None:
|
||||
except HaproxyError:
|
||||
time.sleep(1)
|
||||
else:
|
||||
print("Warning: HAProxy not ready, skipping restore", file=sys.stderr)
|
||||
logger.warning("HAProxy not ready, skipping restore")
|
||||
return
|
||||
|
||||
try:
|
||||
count = restore_servers_from_config()
|
||||
if count > 0:
|
||||
print(f"Restored {count} servers from config", file=sys.stderr)
|
||||
logger.info("Restored %d servers from config", count)
|
||||
except (HaproxyError, OSError, ValueError, json.JSONDecodeError) as e:
|
||||
print(f"Warning: Failed to restore servers: {e}", file=sys.stderr)
|
||||
logger.warning("Failed to restore servers: %s", e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user