refactor: Improve code quality, error handling, and test coverage
- Add file_lock context manager to eliminate duplicate locking patterns - Add ValidationError, ConfigurationError, CertificateError exceptions - Improve rollback logic in haproxy_add_servers (track successful ops only) - Decompose haproxy_add_domain into smaller helper functions - Consolidate certificate constants (CERTS_DIR, ACME_HOME) to config.py - Enhance docstrings for internal functions and magic numbers - Add pytest framework with 48 new tests (269 -> 317 total) - Increase test coverage from 76% to 86% - servers.py: 58% -> 82% - certificates.py: 67% -> 86% - configuration.py: 69% -> 94% Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -7,7 +7,13 @@ from typing import Annotated
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ..config import logger, SUBPROCESS_TIMEOUT
|
||||
from ..config import (
|
||||
logger,
|
||||
SUBPROCESS_TIMEOUT,
|
||||
CERTS_DIR,
|
||||
CERTS_DIR_CONTAINER,
|
||||
ACME_HOME,
|
||||
)
|
||||
from ..validation import validate_domain
|
||||
from ..haproxy_client import haproxy_cmd
|
||||
from ..file_ops import (
|
||||
@@ -16,11 +22,8 @@ from ..file_ops import (
|
||||
remove_cert_from_config,
|
||||
)
|
||||
|
||||
# Certificate paths
|
||||
ACME_SH = os.path.expanduser("~/.acme.sh/acme.sh")
|
||||
ACME_HOME = os.path.expanduser("~/.acme.sh")
|
||||
CERTS_DIR = "/opt/haproxy/certs"
|
||||
CERTS_DIR_CONTAINER = "/etc/haproxy/certs"
|
||||
# acme.sh script path (derived from ACME_HOME)
|
||||
ACME_SH = os.path.join(ACME_HOME, "acme.sh")
|
||||
|
||||
# Longer timeout for certificate operations (ACME can be slow)
|
||||
CERT_TIMEOUT = 120
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import fcntl
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -16,6 +16,7 @@ from ..config import (
|
||||
StateField,
|
||||
STATE_MIN_COLUMNS,
|
||||
SUBPROCESS_TIMEOUT,
|
||||
CERTS_DIR,
|
||||
logger,
|
||||
)
|
||||
from ..exceptions import HaproxyError
|
||||
@@ -31,8 +32,85 @@ from ..file_ops import (
|
||||
remove_domain_from_config,
|
||||
)
|
||||
|
||||
# Certificate paths
|
||||
CERTS_DIR = "/opt/haproxy/certs"
|
||||
|
||||
def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]:
|
||||
"""Find an available pool backend from the pool list.
|
||||
|
||||
Iterates through pool_1 to pool_N and returns the first pool
|
||||
that is not currently in use.
|
||||
|
||||
Args:
|
||||
entries: List of (domain, backend) tuples from the map file.
|
||||
used_pools: Set of pool names already in use.
|
||||
|
||||
Returns:
|
||||
Available pool name (e.g., "pool_5") or None if all pools are in use.
|
||||
"""
|
||||
for i in range(1, POOL_COUNT + 1):
|
||||
pool_name = f"pool_{i}"
|
||||
if pool_name not in used_pools:
|
||||
return pool_name
|
||||
return None
|
||||
|
||||
|
||||
def _check_subdomain(domain: str, registered_domains: set[str]) -> tuple[bool, Optional[str]]:
|
||||
"""Check if a domain is a subdomain of an existing registered domain.
|
||||
|
||||
For example, vault.anvil.it.com is a subdomain if anvil.it.com exists.
|
||||
Subdomains should not have wildcard entries added to avoid conflicts.
|
||||
|
||||
Args:
|
||||
domain: Domain name to check (e.g., "api.example.com").
|
||||
registered_domains: Set of already registered domain names.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_subdomain, parent_domain or None).
|
||||
"""
|
||||
parts = domain.split(".")
|
||||
for i in range(1, len(parts)):
|
||||
candidate = ".".join(parts[i:])
|
||||
if candidate in registered_domains:
|
||||
return True, candidate
|
||||
return False, None
|
||||
|
||||
|
||||
def _update_haproxy_maps(domain: str, pool: str, is_subdomain: bool) -> None:
|
||||
"""Update HAProxy maps via Runtime API.
|
||||
|
||||
Uses 2-stage matching: exact domains go to domains.map,
|
||||
wildcards go to wildcards.map.
|
||||
|
||||
Args:
|
||||
domain: Domain name to add.
|
||||
pool: Pool backend name (e.g., "pool_5").
|
||||
is_subdomain: If True, skip adding wildcard entry.
|
||||
|
||||
Raises:
|
||||
HaproxyError: If HAProxy Runtime API command fails.
|
||||
"""
|
||||
haproxy_cmd(f"add map {MAP_FILE_CONTAINER} {domain} {pool}")
|
||||
if not is_subdomain:
|
||||
haproxy_cmd(f"add map {WILDCARDS_MAP_FILE_CONTAINER} .{domain} {pool}")
|
||||
|
||||
|
||||
def _rollback_domain_addition(
|
||||
domain: str,
|
||||
entries: list[tuple[str, str]]
|
||||
) -> None:
|
||||
"""Rollback a failed domain addition by removing entries from map file.
|
||||
|
||||
Called when HAProxy Runtime API update fails after the map file
|
||||
has already been saved.
|
||||
|
||||
Args:
|
||||
domain: Domain name that was added.
|
||||
entries: Current list of map entries to rollback from.
|
||||
"""
|
||||
rollback_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
|
||||
try:
|
||||
save_map_file(rollback_entries)
|
||||
except IOError:
|
||||
logger.error("Failed to rollback map file after HAProxy error")
|
||||
|
||||
|
||||
def check_certificate_coverage(domain: str) -> tuple[bool, str]:
|
||||
@@ -158,41 +236,26 @@ def register_domain_tools(mcp):
|
||||
if domain_entry == domain:
|
||||
return f"Error: Domain {domain} already exists (mapped to {backend})"
|
||||
|
||||
# Find available pool (using cached entries)
|
||||
# Build used pools and registered domains sets
|
||||
used_pools: set[str] = set()
|
||||
registered_domains: set[str] = set()
|
||||
for entry_domain, backend in entries:
|
||||
if backend.startswith("pool_"):
|
||||
used_pools.add(backend)
|
||||
# Collect non-wildcard domains for subdomain check
|
||||
if not entry_domain.startswith("."):
|
||||
registered_domains.add(entry_domain)
|
||||
|
||||
pool = None
|
||||
for i in range(1, POOL_COUNT + 1):
|
||||
pool_name = f"pool_{i}"
|
||||
if pool_name not in used_pools:
|
||||
pool = pool_name
|
||||
break
|
||||
# Find available pool
|
||||
pool = _find_available_pool(entries, used_pools)
|
||||
if not pool:
|
||||
return f"Error: All {POOL_COUNT} pool backends are in use"
|
||||
|
||||
# Check if this is a subdomain of an existing domain
|
||||
# e.g., vault.anvil.it.com is subdomain if anvil.it.com exists
|
||||
is_subdomain = False
|
||||
parent_domain = None
|
||||
parts = domain.split(".")
|
||||
for i in range(1, len(parts)):
|
||||
candidate = ".".join(parts[i:])
|
||||
if candidate in registered_domains:
|
||||
is_subdomain = True
|
||||
parent_domain = candidate
|
||||
break
|
||||
is_subdomain, parent_domain = _check_subdomain(domain, registered_domains)
|
||||
|
||||
try:
|
||||
# Save to disk first (atomic write for persistence)
|
||||
entries.append((domain, pool))
|
||||
# Only add wildcard for root domains, not subdomains
|
||||
if not is_subdomain:
|
||||
entries.append((f".{domain}", pool))
|
||||
try:
|
||||
@@ -200,42 +263,29 @@ def register_domain_tools(mcp):
|
||||
except IOError as e:
|
||||
return f"Error: Failed to save map file: {e}"
|
||||
|
||||
# Then update HAProxy maps via Runtime API
|
||||
# 2-stage matching: exact domains go to domains.map, wildcards go to wildcards.map
|
||||
# Update HAProxy maps via Runtime API
|
||||
try:
|
||||
haproxy_cmd(f"add map {MAP_FILE_CONTAINER} {domain} {pool}")
|
||||
if not is_subdomain:
|
||||
haproxy_cmd(f"add map {WILDCARDS_MAP_FILE_CONTAINER} .{domain} {pool}")
|
||||
_update_haproxy_maps(domain, pool, is_subdomain)
|
||||
except HaproxyError as e:
|
||||
# Rollback: remove the domain we just added from entries and re-save
|
||||
rollback_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
|
||||
try:
|
||||
save_map_file(rollback_entries)
|
||||
except IOError:
|
||||
logger.error("Failed to rollback map file after HAProxy error")
|
||||
_rollback_domain_addition(domain, entries)
|
||||
return f"Error: Failed to update HAProxy map: {e}"
|
||||
|
||||
# If IP provided, add server to slot 1
|
||||
if ip:
|
||||
# Save server config to disk first
|
||||
add_server_to_config(domain, 1, ip, http_port)
|
||||
|
||||
try:
|
||||
server = f"{pool}_1"
|
||||
haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}")
|
||||
haproxy_cmd(f"set server {pool}/{server} state ready")
|
||||
except HaproxyError as e:
|
||||
# Rollback server config on failure
|
||||
remove_server_from_config(domain, 1)
|
||||
return f"Domain {domain} added to {pool} but server config failed: {e}"
|
||||
|
||||
result = f"Domain {domain} added to {pool} with server {ip}:{http_port}"
|
||||
if is_subdomain:
|
||||
result += f" (subdomain of {parent_domain}, no wildcard)"
|
||||
else:
|
||||
result = f"Domain {domain} added to {pool} (no servers configured)"
|
||||
if is_subdomain:
|
||||
result += f" (subdomain of {parent_domain}, no wildcard)"
|
||||
|
||||
if is_subdomain:
|
||||
result += f" (subdomain of {parent_domain}, no wildcard)"
|
||||
|
||||
# Check certificate coverage
|
||||
cert_covered, cert_info = check_certificate_coverage(domain)
|
||||
|
||||
@@ -13,6 +13,7 @@ from ..config import (
|
||||
StateField,
|
||||
StatField,
|
||||
STATE_MIN_COLUMNS,
|
||||
logger,
|
||||
)
|
||||
from ..exceptions import HaproxyError
|
||||
from ..validation import validate_domain, validate_ip, validate_backend_name
|
||||
@@ -252,6 +253,7 @@ def register_server_tools(mcp):
|
||||
added = []
|
||||
errors = []
|
||||
failed_slots = []
|
||||
successfully_added_slots = []
|
||||
|
||||
try:
|
||||
for server_config in validated_servers:
|
||||
@@ -260,19 +262,43 @@ def register_server_tools(mcp):
|
||||
http_port = server_config["http_port"]
|
||||
try:
|
||||
configure_server_slot(backend, server_prefix, slot, ip, http_port)
|
||||
successfully_added_slots.append(slot)
|
||||
added.append(f"slot {slot}: {ip}:{http_port}")
|
||||
except HaproxyError as e:
|
||||
failed_slots.append(slot)
|
||||
errors.append(f"slot {slot}: {e}")
|
||||
except Exception as e:
|
||||
# Rollback all saved configs on unexpected error
|
||||
# Rollback only successfully added configs on unexpected error
|
||||
for slot in successfully_added_slots:
|
||||
try:
|
||||
remove_server_from_config(domain, slot)
|
||||
except Exception as rollback_error:
|
||||
logger.error(
|
||||
"Failed to rollback server config for %s slot %d: %s",
|
||||
domain, slot, rollback_error
|
||||
)
|
||||
# Also rollback configs that weren't yet processed
|
||||
for server_config in validated_servers:
|
||||
remove_server_from_config(domain, server_config["slot"])
|
||||
slot = server_config["slot"]
|
||||
if slot not in successfully_added_slots:
|
||||
try:
|
||||
remove_server_from_config(domain, slot)
|
||||
except Exception as rollback_error:
|
||||
logger.error(
|
||||
"Failed to rollback server config for %s slot %d: %s",
|
||||
domain, slot, rollback_error
|
||||
)
|
||||
return f"Error: {e}"
|
||||
|
||||
# Rollback failed slots from config
|
||||
for slot in failed_slots:
|
||||
remove_server_from_config(domain, slot)
|
||||
try:
|
||||
remove_server_from_config(domain, slot)
|
||||
except Exception as rollback_error:
|
||||
logger.error(
|
||||
"Failed to rollback server config for %s slot %d: %s",
|
||||
domain, slot, rollback_error
|
||||
)
|
||||
|
||||
# Build result message
|
||||
result_parts = []
|
||||
|
||||
Reference in New Issue
Block a user