- Add file_lock context manager to eliminate duplicate locking patterns - Add ValidationError, ConfigurationError, CertificateError exceptions - Improve rollback logic in haproxy_add_servers (track successful ops only) - Decompose haproxy_add_domain into smaller helper functions - Consolidate certificate constants (CERTS_DIR, ACME_HOME) to config.py - Enhance docstrings for internal functions and magic numbers - Add pytest framework with 48 new tests (269 -> 317 total) - Increase test coverage from 76% to 86% - servers.py: 58% -> 82% - certificates.py: 67% -> 86% - configuration.py: 69% -> 94% Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
357 lines
14 KiB
Python
357 lines
14 KiB
Python
"""Domain management tools for HAProxy MCP Server."""
|
|
|
|
import fcntl
|
|
import os
|
|
import subprocess
|
|
from typing import Annotated, Optional
|
|
|
|
from pydantic import Field
|
|
|
|
from ..config import (
|
|
MAP_FILE,
|
|
MAP_FILE_CONTAINER,
|
|
WILDCARDS_MAP_FILE_CONTAINER,
|
|
POOL_COUNT,
|
|
MAX_SLOTS,
|
|
StateField,
|
|
STATE_MIN_COLUMNS,
|
|
SUBPROCESS_TIMEOUT,
|
|
CERTS_DIR,
|
|
logger,
|
|
)
|
|
from ..exceptions import HaproxyError
|
|
from ..validation import validate_domain, validate_ip
|
|
from ..haproxy_client import haproxy_cmd
|
|
from ..file_ops import (
|
|
get_map_contents,
|
|
save_map_file,
|
|
get_domain_backend,
|
|
is_legacy_backend,
|
|
add_server_to_config,
|
|
remove_server_from_config,
|
|
remove_domain_from_config,
|
|
)
|
|
|
|
|
|
def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]:
    """Return the lowest-numbered pool backend that is not yet taken.

    Candidate names are generated in ascending order, ``pool_1`` through
    ``pool_<POOL_COUNT>``, and the first one absent from ``used_pools``
    is returned.

    Args:
        entries: List of (domain, backend) tuples from the map file.
            Kept for interface compatibility; the search itself only
            consults ``used_pools``.
        used_pools: Set of pool names already in use.

    Returns:
        A free pool name (e.g., "pool_5"), or None if every pool is in use.
    """
    candidates = (f"pool_{index}" for index in range(1, POOL_COUNT + 1))
    return next((name for name in candidates if name not in used_pools), None)
|
|
|
|
|
|
def _check_subdomain(domain: str, registered_domains: set[str]) -> tuple[bool, Optional[str]]:
|
|
"""Check if a domain is a subdomain of an existing registered domain.
|
|
|
|
For example, vault.anvil.it.com is a subdomain if anvil.it.com exists.
|
|
Subdomains should not have wildcard entries added to avoid conflicts.
|
|
|
|
Args:
|
|
domain: Domain name to check (e.g., "api.example.com").
|
|
registered_domains: Set of already registered domain names.
|
|
|
|
Returns:
|
|
Tuple of (is_subdomain, parent_domain or None).
|
|
"""
|
|
parts = domain.split(".")
|
|
for i in range(1, len(parts)):
|
|
candidate = ".".join(parts[i:])
|
|
if candidate in registered_domains:
|
|
return True, candidate
|
|
return False, None
|
|
|
|
|
|
def _update_haproxy_maps(domain: str, pool: str, is_subdomain: bool) -> None:
    """Update HAProxy maps via Runtime API.

    Uses 2-stage matching: exact domains go to domains.map,
    wildcards go to wildcards.map.

    The two additions are treated as one unit: if the wildcard entry cannot
    be added after the exact entry was, the exact entry is removed again
    (best effort) so the running HAProxy is not left with a half-registered
    domain that the caller's map-file rollback would not clean up.

    Args:
        domain: Domain name to add.
        pool: Pool backend name (e.g., "pool_5").
        is_subdomain: If True, skip adding wildcard entry.

    Raises:
        HaproxyError: If HAProxy Runtime API command fails.
    """
    haproxy_cmd(f"add map {MAP_FILE_CONTAINER} {domain} {pool}")
    if is_subdomain:
        return

    try:
        haproxy_cmd(f"add map {WILDCARDS_MAP_FILE_CONTAINER} .{domain} {pool}")
    except HaproxyError:
        # Undo the exact entry added above; a failure here is only logged so
        # the original error is the one the caller sees.
        try:
            haproxy_cmd(f"del map {MAP_FILE_CONTAINER} {domain}")
        except HaproxyError as cleanup_error:
            logger.warning(
                "Failed to remove runtime map entry for %s after wildcard add failed: %s",
                domain, cleanup_error
            )
        raise
|
|
|
|
|
|
def _rollback_domain_addition(
    domain: str,
    entries: list[tuple[str, str]]
) -> None:
    """Undo a domain addition in the map file after a failed runtime update.

    Intended for the case where the map file was already saved with the new
    domain but the HAProxy Runtime API update then failed. The file is
    rewritten without the domain's exact entry and without its wildcard
    (".<domain>") entry. A failure to write is logged, not raised.

    Args:
        domain: Domain name that was added.
        entries: Current list of map entries to rollback from.
    """
    # Both the exact name and its wildcard form have to go.
    discarded = (domain, f".{domain}")
    kept = [(name, backend) for name, backend in entries if name not in discarded]
    try:
        save_map_file(kept)
    except IOError:
        logger.error("Failed to rollback map file after HAProxy error")
|
|
|
|
|
|
def check_certificate_coverage(domain: str) -> tuple[bool, str]:
    """Check if a domain is covered by an existing certificate.

    Two lookups are made inside CERTS_DIR:

    1. ``<domain>.pem`` - an exact certificate for the domain.
    2. ``<parent>.pem`` - the certificate of the immediate parent domain,
       accepted only if its subjectAltName list contains the wildcard
       ``*.<parent>`` as a complete DNS entry.

    Args:
        domain: Domain name to check (e.g., api.example.com)

    Returns:
        Tuple of (is_covered, certificate_name or message)
    """
    not_found = (False, "No matching certificate")

    if not os.path.isdir(CERTS_DIR):
        return False, "Certificate directory not found"

    # Check for exact match first
    exact_pem = os.path.join(CERTS_DIR, f"{domain}.pem")
    if os.path.exists(exact_pem):
        return True, domain

    # Check for wildcard coverage (e.g., api.example.com covered by *.example.com)
    parts = domain.split(".")
    if len(parts) < 2:
        return not_found

    # Try parent domain (example.com for api.example.com)
    parent_domain = ".".join(parts[1:])
    parent_pem = os.path.join(CERTS_DIR, f"{parent_domain}.pem")
    if not os.path.exists(parent_pem):
        return not_found

    # Verify the certificate has wildcard SAN
    try:
        result = subprocess.run(
            ["openssl", "x509", "-in", parent_pem, "-noout", "-ext", "subjectAltName"],
            capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT
        )
    except (subprocess.TimeoutExpired, OSError) as e:
        # Best effort: an unreadable certificate is treated as "not covered".
        logger.warning("Could not inspect certificate %s: %s", parent_pem, e)
        return not_found

    if result.returncode != 0:
        logger.warning(
            "openssl exited with code %s while inspecting %s",
            result.returncode, parent_pem
        )
        return not_found

    # Compare whole SAN entries rather than searching the raw text, so that
    # "*.example.com" is not mistaken for e.g. "*.example.community".
    # DNS names are case-insensitive, hence the lower-casing on both sides.
    wildcard = f"*.{parent_domain}".lower()
    san_names = set()
    for token in result.stdout.replace("\n", ",").split(","):
        token = token.strip()
        if token.startswith("DNS:"):
            san_names.add(token[len("DNS:"):].strip().lower())

    if wildcard in san_names:
        return True, f"{parent_domain} (wildcard)"

    return not_found
|
|
|
|
|
|
def register_domain_tools(mcp):
    """Register domain management tools with MCP server.

    Defines three tools and registers each through ``mcp.tool()``:
    ``haproxy_list_domains``, ``haproxy_add_domain`` and
    ``haproxy_remove_domain``. Every tool returns a human-readable string;
    failures are reported as strings starting with "Error:" rather than
    raised.

    Args:
        mcp: Server object exposing a ``tool()`` decorator factory.
    """

    @mcp.tool()
    def haproxy_list_domains(
        include_wildcards: Annotated[bool, Field(default=False, description="Include wildcard entries (.example.com). Default: False")]
    ) -> str:
        """List all configured domains with their backend servers."""
        try:
            domains = []
            state = haproxy_cmd("show servers state")

            # Build server map from HAProxy state: backend name -> list of
            # "name=addr:port" strings. Servers still at 0.0.0.0 are skipped.
            server_map: dict[str, list] = {}
            for line in state.split("\n"):
                parts = line.split()
                if len(parts) >= STATE_MIN_COLUMNS and parts[StateField.SRV_ADDR] != "0.0.0.0":
                    backend = parts[StateField.BE_NAME]
                    if backend not in server_map:
                        server_map[backend] = []
                    server_map[backend].append(
                        f"{parts[StateField.SRV_NAME]}={parts[StateField.SRV_ADDR]}:{parts[StateField.SRV_PORT]}"
                    )

            # Read from domains.map
            seen_domains: set[str] = set()
            for domain, backend in get_map_contents():
                # Skip wildcard entries unless explicitly requested
                if domain.startswith(".") and not include_wildcards:
                    continue
                # Only the first occurrence of a domain is listed
                if domain in seen_domains:
                    continue
                seen_domains.add(domain)
                servers = server_map.get(backend, ["(none)"])
                if domain.startswith("."):
                    backend_type = "wildcard"
                elif backend.startswith("pool_"):
                    backend_type = "pool"
                else:
                    backend_type = "static"
                domains.append(f"• {domain} -> {backend} ({backend_type}): {', '.join(servers)}")

            return "\n".join(domains) if domains else "No domains configured"
        except HaproxyError as e:
            return f"Error: {e}"

    @mcp.tool()
    def haproxy_add_domain(
        domain: Annotated[str, Field(description="Domain name to add (e.g., api.example.com, example.com)")],
        ip: Annotated[str, Field(default="", description="Optional: Initial server IP. If provided, adds server to slot 1")],
        http_port: Annotated[int, Field(default=80, description="HTTP port for backend server (default: 80)")]
    ) -> str:
        """Add a new domain to HAProxy (no reload required).

        Creates domain→pool mapping. Use haproxy_add_server to add more servers later.

        Example: haproxy_add_domain("api.example.com", ip="10.0.0.1", http_port=8080)
        """
        # Validate inputs
        if domain.startswith("."):
            return "Error: Domain cannot start with '.' (wildcard entries are added automatically)"
        if not validate_domain(domain):
            return "Error: Invalid domain format"
        if not validate_ip(ip, allow_empty=True):
            return "Error: Invalid IP address format"
        if not (1 <= http_port <= 65535):
            return "Error: Port must be between 1 and 65535"

        # Use file locking for the entire pool allocation operation
        lock_path = f"{MAP_FILE}.lock"
        with open(lock_path, 'w') as lock_file:
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
            try:
                # Read map contents once for both existence check and pool lookup
                entries = get_map_contents()

                # Check if domain already exists (using cached entries)
                for domain_entry, backend in entries:
                    if domain_entry == domain:
                        return f"Error: Domain {domain} already exists (mapped to {backend})"

                # Build used pools and registered domains sets
                used_pools: set[str] = set()
                registered_domains: set[str] = set()
                for entry_domain, backend in entries:
                    if backend.startswith("pool_"):
                        used_pools.add(backend)
                    if not entry_domain.startswith("."):
                        registered_domains.add(entry_domain)

                # Find available pool
                pool = _find_available_pool(entries, used_pools)
                if not pool:
                    return f"Error: All {POOL_COUNT} pool backends are in use"

                # Check if this is a subdomain of an existing domain
                is_subdomain, parent_domain = _check_subdomain(domain, registered_domains)

                try:
                    # Save to disk first (atomic write for persistence)
                    entries.append((domain, pool))
                    if not is_subdomain:
                        entries.append((f".{domain}", pool))
                    try:
                        save_map_file(entries)
                    except IOError as e:
                        return f"Error: Failed to save map file: {e}"

                    # Update HAProxy maps via Runtime API; on failure revert
                    # the map file that was just written.
                    try:
                        _update_haproxy_maps(domain, pool, is_subdomain)
                    except HaproxyError as e:
                        _rollback_domain_addition(domain, entries)
                        return f"Error: Failed to update HAProxy map: {e}"

                    # If IP provided, add server to slot 1
                    if ip:
                        add_server_to_config(domain, 1, ip, http_port)
                        try:
                            server = f"{pool}_1"
                            haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}")
                            haproxy_cmd(f"set server {pool}/{server} state ready")
                        except HaproxyError as e:
                            # The domain mapping stays; only the server
                            # record written above is withdrawn.
                            remove_server_from_config(domain, 1)
                            return f"Domain {domain} added to {pool} but server config failed: {e}"
                        result = f"Domain {domain} added to {pool} with server {ip}:{http_port}"
                    else:
                        result = f"Domain {domain} added to {pool} (no servers configured)"

                    if is_subdomain:
                        result += f" (subdomain of {parent_domain}, no wildcard)"

                    # Check certificate coverage
                    cert_covered, cert_info = check_certificate_coverage(domain)
                    if cert_covered:
                        result += f"\nSSL: Using certificate {cert_info}"
                    else:
                        result += f"\nSSL: No certificate found. Use haproxy_issue_cert(\"{domain}\") to issue one."

                    return result

                except HaproxyError as e:
                    return f"Error: {e}"
            finally:
                fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)

    @mcp.tool()
    def haproxy_remove_domain(
        domain: Annotated[str, Field(description="Domain name to remove (e.g., api.example.com)")]
    ) -> str:
        """Remove a domain from HAProxy (no reload required)."""
        if not validate_domain(domain):
            return "Error: Invalid domain format"

        # Look up the domain in the map
        backend = get_domain_backend(domain)
        if not backend:
            return f"Error: Domain {domain} not found"

        # Check if this is a legacy backend (not a pool)
        if is_legacy_backend(backend):
            return f"Error: Cannot remove legacy domain {domain} (uses static backend {backend})"

        try:
            # Save to disk first (atomic write for persistence).
            # The read-filter-write of the map file is done under the same
            # lock haproxy_add_domain takes, so a concurrent add and remove
            # cannot overwrite each other's changes.
            lock_path = f"{MAP_FILE}.lock"
            with open(lock_path, 'w') as lock_file:
                fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
                try:
                    entries = get_map_contents()
                    new_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
                    save_map_file(new_entries)
                finally:
                    fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)

            # Remove from persistent server config
            remove_domain_from_config(domain)

            # Clear map entries via Runtime API (immediate effect)
            # 2-stage matching: exact from domains.map, wildcard from wildcards.map
            haproxy_cmd(f"del map {MAP_FILE_CONTAINER} {domain}")
            try:
                haproxy_cmd(f"del map {WILDCARDS_MAP_FILE_CONTAINER} .{domain}")
            except HaproxyError as e:
                # A failed wildcard delete is only logged, not reported as
                # an error to the caller.
                logger.warning("Failed to remove wildcard entry for %s: %s", domain, e)

            # Disable all servers in the pool (reset to 0.0.0.0:0)
            for slot in range(1, MAX_SLOTS + 1):
                server = f"{backend}_{slot}"
                try:
                    haproxy_cmd(f"set server {backend}/{server} state maint")
                    haproxy_cmd(f"set server {backend}/{server} addr 0.0.0.0 port 0")
                except HaproxyError as e:
                    logger.warning(
                        "Failed to clear server %s/%s for domain %s: %s",
                        backend, server, domain, e
                    )
                    # Continue with remaining cleanup

            return f"Domain {domain} removed from {backend}"

        except IOError as e:
            return f"Error: Failed to update map file: {e}"
        except HaproxyError as e:
            return f"Error: {e}"
|