From 6bcfee519cb783b4e1aab0f31162ea1885ce3121 Mon Sep 17 00:00:00 2001 From: kaffa Date: Tue, 3 Feb 2026 12:50:00 +0900 Subject: [PATCH] refactor: Improve code quality, error handling, and test coverage - Add file_lock context manager to eliminate duplicate locking patterns - Add ValidationError, ConfigurationError, CertificateError exceptions - Improve rollback logic in haproxy_add_servers (track successful ops only) - Decompose haproxy_add_domain into smaller helper functions - Consolidate certificate constants (CERTS_DIR, ACME_HOME) to config.py - Enhance docstrings for internal functions and magic numbers - Add pytest framework with 48 new tests (269 -> 317 total) - Increase test coverage from 76% to 86% - servers.py: 58% -> 82% - certificates.py: 67% -> 86% - configuration.py: 69% -> 94% Co-Authored-By: Claude Opus 4.5 --- haproxy_mcp/config.py | 43 +- haproxy_mcp/exceptions.py | 15 + haproxy_mcp/file_ops.py | 209 ++-- haproxy_mcp/haproxy_client.py | 38 +- haproxy_mcp/pyproject.toml | 17 + haproxy_mcp/tools/certificates.py | 15 +- haproxy_mcp/tools/domains.py | 132 ++- haproxy_mcp/tools/servers.py | 32 +- haproxy_mcp/uv.lock | 216 +++- tests/__init__.py | 1 + tests/conftest.py | 345 ++++++ tests/integration/__init__.py | 1 + tests/unit/__init__.py | 1 + tests/unit/test_config.py | 198 ++++ tests/unit/test_file_ops.py | 498 +++++++++ tests/unit/test_haproxy_client.py | 279 +++++ tests/unit/test_utils.py | 130 +++ tests/unit/test_validation.py | 275 +++++ tests/unit/tools/__init__.py | 1 + tests/unit/tools/test_certificates.py | 1198 +++++++++++++++++++++ tests/unit/tools/test_configuration.py | 749 +++++++++++++ tests/unit/tools/test_domains.py | 476 +++++++++ tests/unit/tools/test_health.py | 433 ++++++++ tests/unit/tools/test_monitoring.py | 325 ++++++ tests/unit/tools/test_servers.py | 1350 ++++++++++++++++++++++++ 25 files changed, 6852 insertions(+), 125 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/integration/__init__.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_config.py create mode 100644 tests/unit/test_file_ops.py create mode 100644 tests/unit/test_haproxy_client.py create mode 100644 tests/unit/test_utils.py create mode 100644 tests/unit/test_validation.py create mode 100644 tests/unit/tools/__init__.py create mode 100644 tests/unit/tools/test_certificates.py create mode 100644 tests/unit/tools/test_configuration.py create mode 100644 tests/unit/tools/test_domains.py create mode 100644 tests/unit/tools/test_health.py create mode 100644 tests/unit/tools/test_monitoring.py create mode 100644 tests/unit/tools/test_servers.py diff --git a/haproxy_mcp/config.py b/haproxy_mcp/config.py index d4f0d6e..ffa3900 100644 --- a/haproxy_mcp/config.py +++ b/haproxy_mcp/config.py @@ -32,6 +32,11 @@ WILDCARDS_MAP_FILE_CONTAINER: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE_CONTAI SERVERS_FILE: str = os.getenv("HAPROXY_SERVERS_FILE", "/opt/haproxy/conf/servers.json") CERTS_FILE: str = os.getenv("HAPROXY_CERTS_FILE", "/opt/haproxy/conf/certificates.json") +# Certificate paths +CERTS_DIR: str = os.getenv("HAPROXY_CERTS_DIR", "/opt/haproxy/certs") +CERTS_DIR_CONTAINER: str = os.getenv("HAPROXY_CERTS_DIR_CONTAINER", "/etc/haproxy/certs") +ACME_HOME: str = os.getenv("ACME_HOME", os.path.expanduser("~/.acme.sh")) + # Pool configuration POOL_COUNT: int = int(os.getenv("HAPROXY_POOL_COUNT", "100")) MAX_SLOTS: int = int(os.getenv("HAPROXY_MAX_SLOTS", "10")) @@ -49,15 +54,37 @@ BACKEND_NAME_PATTERN = re.compile(r'^[a-zA-Z0-9_-]+$') # Pattern for converting domain to backend name NON_ALNUM_PATTERN = re.compile(r'[^a-zA-Z0-9]') -# Limits +# Limits and Constants MAX_RESPONSE_SIZE = 10 * 1024 * 1024 # 10 MB max response from HAProxy -SUBPROCESS_TIMEOUT = 30 # seconds -STARTUP_RETRY_COUNT = 10 # HAProxy ready check retries -STATE_MIN_COLUMNS = 19 # Minimum columns in HAProxy server state output -SOCKET_TIMEOUT = 5 # seconds for HAProxy socket connection -SOCKET_RECV_TIMEOUT = 30 # seconds for HAProxy socket recv loop -MAX_BULK_SERVERS = 10 # Max servers per bulk add call -MAX_SERVERS_JSON_SIZE = 10000 # Max size of servers JSON in haproxy_add_servers + +SUBPROCESS_TIMEOUT = 30 # seconds for podman exec commands (config validation, reload) + +# STARTUP_RETRY_COUNT: Number of attempts to verify HAProxy is ready on MCP startup. +# During startup, MCP needs to restore server configurations from servers.json. +# HAProxy may take a few seconds to fully initialize the Runtime API socket. +# Each retry waits 1 second, so 10 retries = max 10 seconds startup wait. +# If HAProxy isn't ready after 10 attempts, startup proceeds but logs a warning. +STARTUP_RETRY_COUNT = 10 + +# STATE_MIN_COLUMNS: Expected minimum column count in 'show servers state' output. +# HAProxy 'show servers state' returns tab-separated values with the following fields: +# 0: be_id - Backend ID +# 1: be_name - Backend name +# 2: srv_id - Server ID +# 3: srv_name - Server name +# 4: srv_addr - Server IP address +# 5: srv_op_state - Operational state (0=stopped, 1=starting, 2=running, etc.) +# 6: srv_admin_state - Admin state (0=ready, 1=drain, 2=maint, etc.) +# 7-17: Various internal state fields (weight, check info, etc.) +# 18: srv_port - Server port +# Total: 19+ columns (may increase in future HAProxy versions) +# Lines with fewer columns are invalid/incomplete and should be skipped. +STATE_MIN_COLUMNS = 19 + +SOCKET_TIMEOUT = 5 # seconds for HAProxy socket connection establishment +SOCKET_RECV_TIMEOUT = 30 # seconds for complete response (large stats output) +MAX_BULK_SERVERS = 10 # Max servers per bulk add call (prevents oversized requests) +MAX_SERVERS_JSON_SIZE = 10000 # Max size of servers JSON input (10KB, prevents abuse) # CSV field indices for HAProxy stats (show stat command) diff --git a/haproxy_mcp/exceptions.py b/haproxy_mcp/exceptions.py index 2981687..f5909bb 100644 --- a/haproxy_mcp/exceptions.py +++ b/haproxy_mcp/exceptions.py @@ -9,3 +9,18 @@ class HaproxyError(Exception): class NoAvailablePoolError(HaproxyError): """All pool backends are in use.""" pass + + +class ValidationError(HaproxyError): + """Input validation error.""" + pass + + +class ConfigurationError(HaproxyError): + """Configuration file or state error.""" + pass + + +class CertificateError(HaproxyError): + """SSL/TLS certificate operation error.""" + pass diff --git a/haproxy_mcp/file_ops.py b/haproxy_mcp/file_ops.py index b8da8fb..2da2e7d 100644 --- a/haproxy_mcp/file_ops.py +++ b/haproxy_mcp/file_ops.py @@ -4,7 +4,8 @@ import fcntl import json import os import tempfile -from typing import Any, Optional +from contextlib import contextmanager +from typing import Any, Generator, Optional from .config import ( MAP_FILE, @@ -16,6 +17,34 @@ from .config import ( from .validation import domain_to_backend +@contextmanager +def file_lock(lock_path: str) -> Generator[None, None, None]: + """Acquire exclusive file lock for atomic operations. + + This context manager provides a consistent locking mechanism for + read-modify-write operations on configuration files to prevent + race conditions during concurrent access. + + Args: + lock_path: Path to the lock file (typically config_file.lock) + + Yields: + None - the lock is held for the duration of the context + + Example: + with file_lock("/path/to/config.json.lock"): + config = load_config() + config["key"] = "value" + save_config(config) + """ + with open(lock_path, 'w') as lock_file: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + try: + yield + finally: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + + def atomic_write_file(file_path: str, content: str) -> None: """Write content to file atomically using temp file + rename. @@ -118,17 +147,61 @@ def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str def save_map_file(entries: list[tuple[str, str]]) -> None: - """Save entries to separate map files for 2-stage matching. + """Save domain-to-backend entries using 2-stage map routing architecture. - Uses 2-stage matching for performance: - - domains.map: Exact domain matches (used with map_str, O(log n)) - - wildcards.map: Wildcard entries (used with map_dom, O(n)) + This function implements HAProxy's 2-stage domain routing for optimal + performance. Entries are automatically split into two separate map files + based on whether they are exact domains or wildcard patterns. + + 2-Stage Routing Architecture: + Stage 1 - Exact Match (domains.map): + - HAProxy directive: map_str(req.hdr(host),"/path/domains.map") + - Data structure: ebtree (elastic binary tree) + - Lookup complexity: O(log n) + - Use case: Exact domain matches (e.g., "api.example.com") + + Stage 2 - Wildcard Match (wildcards.map): + - HAProxy directive: map_dom(req.hdr(host),"/path/wildcards.map") + - Data structure: Linear suffix search + - Lookup complexity: O(n) where n = number of wildcard entries + - Use case: Wildcard domains (e.g., ".example.com" matches *.example.com) + - Typically small set, so O(n) is acceptable + + Performance Characteristics: + - 1000 exact domains: ~10 comparisons (log2(1000) approx 10) + - 10 wildcard entries: 10 suffix comparisons (acceptable) + - By separating exact and wildcard entries, we avoid O(n) lookup + for the common case (exact domain match) + + HAProxy Configuration Example: + use_backend %[req.hdr(host),lower,map_str(/etc/haproxy/domains.map)] + if { req.hdr(host),lower,map_str(/etc/haproxy/domains.map) -m found } + use_backend %[req.hdr(host),lower,map_dom(/etc/haproxy/wildcards.map)] + if { req.hdr(host),lower,map_dom(/etc/haproxy/wildcards.map) -m found } Args: - entries: List of (domain, backend) tuples to write + entries: List of (domain, backend) tuples to write. + - Exact domains: "api.example.com" -> written to domains.map + - Wildcards: ".example.com" (matches *.example.com) -> written + to wildcards.map Raises: - IOError: If the file cannot be written + IOError: If either map file cannot be written. + + File Formats: + domains.map: + # Exact Domain to Backend mapping (for map_str) + api.example.com pool_1 + www.example.com pool_2 + + wildcards.map: + # Wildcard Domain to Backend mapping (for map_dom) + .example.com pool_3 # Matches *.example.com + .test.org pool_4 # Matches *.test.org + + Note: + Both files are written atomically using temp file + rename to prevent + corruption during concurrent access or system failures. """ # Split into exact and wildcard entries exact_entries, wildcard_entries = split_domain_entries(entries) @@ -170,13 +243,48 @@ def get_domain_backend(domain: str) -> Optional[str]: def is_legacy_backend(backend: str) -> bool: - """Check if backend is a legacy static backend (not a pool). + """Check if backend is a legacy static backend (not a dynamic pool). + + This function distinguishes between two backend naming conventions used + in the HAProxy MCP system: + + Pool Backends (Dynamic): + - Named: pool_1, pool_2, ..., pool_100 + - Pre-configured in haproxy.cfg with 10 server slots each + - Domains are dynamically assigned to available pools via domains.map + - Server slots configured at runtime via Runtime API + - Allows zero-reload domain management + + Legacy Backends (Static): + - Named: {domain}_backend (e.g., "api_example_com_backend") + - Defined statically in haproxy.cfg + - Requires HAProxy reload to add new backends + - Used for domains that were configured before pool-based routing Args: - backend: Backend name to check + backend: Backend name to check (e.g., "pool_5" or "api_example_com_backend"). Returns: - True if this is a legacy backend, False if it's a pool + True if this is a legacy backend (does not start with "pool_"), + False if it's a pool backend. + + Usage Scenarios: + - When listing servers: Determines server naming convention + (pool backends use pool_N_M, legacy use {domain}_M) + - When adding servers: Determines which backend configuration + approach to use + - During migration: Helps identify domains that need migration + from legacy to pool-based routing + + Examples: + >>> is_legacy_backend("pool_5") + False + >>> is_legacy_backend("pool_100") + False + >>> is_legacy_backend("api_example_com_backend") + True + >>> is_legacy_backend("myservice_backend") + True """ return not backend.startswith("pool_") @@ -263,17 +371,12 @@ def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> Non ip: Server IP address http_port: HTTP port """ - lock_path = f"{SERVERS_FILE}.lock" - with open(lock_path, 'w') as lock_file: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) - try: - config = load_servers_config() - if domain not in config: - config[domain] = {} - config[domain][str(slot)] = {"ip": ip, "http_port": http_port} - save_servers_config(config) - finally: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + with file_lock(f"{SERVERS_FILE}.lock"): + config = load_servers_config() + if domain not in config: + config[domain] = {} + config[domain][str(slot)] = {"ip": ip, "http_port": http_port} + save_servers_config(config) def remove_server_from_config(domain: str, slot: int) -> None: @@ -283,18 +386,13 @@ def remove_server_from_config(domain: str, slot: int) -> None: domain: Domain name slot: Server slot to remove """ - lock_path = f"{SERVERS_FILE}.lock" - with open(lock_path, 'w') as lock_file: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) - try: - config = load_servers_config() - if domain in config and str(slot) in config[domain]: - del config[domain][str(slot)] - if not config[domain]: - del config[domain] - save_servers_config(config) - finally: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + with file_lock(f"{SERVERS_FILE}.lock"): + config = load_servers_config() + if domain in config and str(slot) in config[domain]: + del config[domain][str(slot)] + if not config[domain]: + del config[domain] + save_servers_config(config) def remove_domain_from_config(domain: str) -> None: @@ -303,16 +401,11 @@ def remove_domain_from_config(domain: str) -> None: Args: domain: Domain name to remove """ - lock_path = f"{SERVERS_FILE}.lock" - with open(lock_path, 'w') as lock_file: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) - try: - config = load_servers_config() - if domain in config: - del config[domain] - save_servers_config(config) - finally: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + with file_lock(f"{SERVERS_FILE}.lock"): + config = load_servers_config() + if domain in config: + del config[domain] + save_servers_config(config) # Certificate configuration functions @@ -359,16 +452,11 @@ def add_cert_to_config(domain: str) -> None: Args: domain: Domain name to add """ - lock_path = f"{CERTS_FILE}.lock" - with open(lock_path, 'w') as lock_file: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) - try: - domains = load_certs_config() - if domain not in domains: - domains.append(domain) - save_certs_config(domains) - finally: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + with file_lock(f"{CERTS_FILE}.lock"): + domains = load_certs_config() + if domain not in domains: + domains.append(domain) + save_certs_config(domains) def remove_cert_from_config(domain: str) -> None: @@ -377,13 +465,8 @@ def remove_cert_from_config(domain: str) -> None: Args: domain: Domain name to remove """ - lock_path = f"{CERTS_FILE}.lock" - with open(lock_path, 'w') as lock_file: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) - try: - domains = load_certs_config() - if domain in domains: - domains.remove(domain) - save_certs_config(domains) - finally: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + with file_lock(f"{CERTS_FILE}.lock"): + domains = load_certs_config() + if domain in domains: + domains.remove(domain) + save_certs_config(domains) diff --git a/haproxy_mcp/haproxy_client.py b/haproxy_mcp/haproxy_client.py index 9063af9..981caba 100644 --- a/haproxy_mcp/haproxy_client.py +++ b/haproxy_mcp/haproxy_client.py @@ -89,13 +89,45 @@ def haproxy_cmd_checked(command: str) -> str: def _check_response_for_errors(response: str) -> None: - """Check HAProxy response for error indicators. + """Check HAProxy response for error indicators and raise if found. + + HAProxy Runtime API returns plain text responses. Success responses are + typically empty or contain requested data. Error responses contain + specific keywords that indicate the command failed. Args: - response: Response string from HAProxy + response: Response string from HAProxy Runtime API command. Raises: - HaproxyError: If response contains error indicators + HaproxyError: If response contains any error indicator keyword. + + Error Indicators: + - "No such": Resource doesn't exist (e.g., backend, server, map entry) + - "not found": Similar to "No such", resource lookup failed + - "error": General error in command execution + - "failed": Operation could not be completed + - "invalid": Malformed command or invalid parameter value + - "unknown": Unrecognized command or parameter + + Examples: + Successful responses (will NOT raise): + - "" (empty string for successful set commands) + - "1" (map entry ID after successful add) + - Server state data (for show commands) + + Error responses (WILL raise HaproxyError): + - "No such server." - Server doesn't exist in specified backend + - "No such backend." - Backend name not found + - "No such map." - Map file not loaded or doesn't exist + - "Entry not found." - Map entry lookup failed + - "Invalid server state." - Bad state value for set server state + - "unknown keyword 'xyz'" - Unrecognized command parameter + - "failed to allocate memory" - Resource allocation failure + - "'set server' expects :" - Invalid command syntax + + Note: + The check is case-insensitive to catch variations like "Error:", + "ERROR:", "error:" etc. that HAProxy may return. """ error_indicators = ["No such", "not found", "error", "failed", "invalid", "unknown"] if response: diff --git a/haproxy_mcp/pyproject.toml b/haproxy_mcp/pyproject.toml index 1195619..1f8e647 100644 --- a/haproxy_mcp/pyproject.toml +++ b/haproxy_mcp/pyproject.toml @@ -6,3 +6,20 @@ requires-python = ">=3.11" dependencies = [ "mcp[cli]>=1.0.0", ] + +[project.optional-dependencies] +test = [ + "pytest>=8.0.0", + "pytest-cov>=4.1.0", +] + +[tool.pytest.ini_options] +testpaths = ["../tests"] +pythonpath = ["."] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = "-v --tb=short" +filterwarnings = [ + "ignore::DeprecationWarning", +] diff --git a/haproxy_mcp/tools/certificates.py b/haproxy_mcp/tools/certificates.py index ebf76f9..a910a8f 100644 --- a/haproxy_mcp/tools/certificates.py +++ b/haproxy_mcp/tools/certificates.py @@ -7,7 +7,13 @@ from typing import Annotated from pydantic import Field -from ..config import logger, SUBPROCESS_TIMEOUT +from ..config import ( + logger, + SUBPROCESS_TIMEOUT, + CERTS_DIR, + CERTS_DIR_CONTAINER, + ACME_HOME, +) from ..validation import validate_domain from ..haproxy_client import haproxy_cmd from ..file_ops import ( @@ -16,11 +22,8 @@ from ..file_ops import ( remove_cert_from_config, ) -# Certificate paths -ACME_SH = os.path.expanduser("~/.acme.sh/acme.sh") -ACME_HOME = os.path.expanduser("~/.acme.sh") -CERTS_DIR = "/opt/haproxy/certs" -CERTS_DIR_CONTAINER = "/etc/haproxy/certs" +# acme.sh script path (derived from ACME_HOME) +ACME_SH = os.path.join(ACME_HOME, "acme.sh") # Longer timeout for certificate operations (ACME can be slow) CERT_TIMEOUT = 120 diff --git a/haproxy_mcp/tools/domains.py b/haproxy_mcp/tools/domains.py index 1970f64..9a6a7ce 100644 --- a/haproxy_mcp/tools/domains.py +++ b/haproxy_mcp/tools/domains.py @@ -3,7 +3,7 @@ import fcntl import os import subprocess -from typing import Annotated +from typing import Annotated, Optional from pydantic import Field @@ -16,6 +16,7 @@ from ..config import ( StateField, STATE_MIN_COLUMNS, SUBPROCESS_TIMEOUT, + CERTS_DIR, logger, ) from ..exceptions import HaproxyError @@ -31,8 +32,85 @@ from ..file_ops import ( remove_domain_from_config, ) -# Certificate paths -CERTS_DIR = "/opt/haproxy/certs" + +def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]: + """Find an available pool backend from the pool list. + + Iterates through pool_1 to pool_N and returns the first pool + that is not currently in use. + + Args: + entries: List of (domain, backend) tuples from the map file. + used_pools: Set of pool names already in use. + + Returns: + Available pool name (e.g., "pool_5") or None if all pools are in use. + """ + for i in range(1, POOL_COUNT + 1): + pool_name = f"pool_{i}" + if pool_name not in used_pools: + return pool_name + return None + + +def _check_subdomain(domain: str, registered_domains: set[str]) -> tuple[bool, Optional[str]]: + """Check if a domain is a subdomain of an existing registered domain. + + For example, vault.anvil.it.com is a subdomain if anvil.it.com exists. + Subdomains should not have wildcard entries added to avoid conflicts. + + Args: + domain: Domain name to check (e.g., "api.example.com"). + registered_domains: Set of already registered domain names. + + Returns: + Tuple of (is_subdomain, parent_domain or None). + """ + parts = domain.split(".") + for i in range(1, len(parts)): + candidate = ".".join(parts[i:]) + if candidate in registered_domains: + return True, candidate + return False, None + + +def _update_haproxy_maps(domain: str, pool: str, is_subdomain: bool) -> None: + """Update HAProxy maps via Runtime API. + + Uses 2-stage matching: exact domains go to domains.map, + wildcards go to wildcards.map. + + Args: + domain: Domain name to add. + pool: Pool backend name (e.g., "pool_5"). + is_subdomain: If True, skip adding wildcard entry. + + Raises: + HaproxyError: If HAProxy Runtime API command fails. + """ + haproxy_cmd(f"add map {MAP_FILE_CONTAINER} {domain} {pool}") + if not is_subdomain: + haproxy_cmd(f"add map {WILDCARDS_MAP_FILE_CONTAINER} .{domain} {pool}") + + +def _rollback_domain_addition( + domain: str, + entries: list[tuple[str, str]] +) -> None: + """Rollback a failed domain addition by removing entries from map file. + + Called when HAProxy Runtime API update fails after the map file + has already been saved. + + Args: + domain: Domain name that was added. + entries: Current list of map entries to rollback from. + """ + rollback_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"] + try: + save_map_file(rollback_entries) + except IOError: + logger.error("Failed to rollback map file after HAProxy error") def check_certificate_coverage(domain: str) -> tuple[bool, str]: @@ -158,41 +236,26 @@ def register_domain_tools(mcp): if domain_entry == domain: return f"Error: Domain {domain} already exists (mapped to {backend})" - # Find available pool (using cached entries) + # Build used pools and registered domains sets used_pools: set[str] = set() registered_domains: set[str] = set() for entry_domain, backend in entries: if backend.startswith("pool_"): used_pools.add(backend) - # Collect non-wildcard domains for subdomain check if not entry_domain.startswith("."): registered_domains.add(entry_domain) - pool = None - for i in range(1, POOL_COUNT + 1): - pool_name = f"pool_{i}" - if pool_name not in used_pools: - pool = pool_name - break + # Find available pool + pool = _find_available_pool(entries, used_pools) if not pool: return f"Error: All {POOL_COUNT} pool backends are in use" # Check if this is a subdomain of an existing domain - # e.g., vault.anvil.it.com is subdomain if anvil.it.com exists - is_subdomain = False - parent_domain = None - parts = domain.split(".") - for i in range(1, len(parts)): - candidate = ".".join(parts[i:]) - if candidate in registered_domains: - is_subdomain = True - parent_domain = candidate - break + is_subdomain, parent_domain = _check_subdomain(domain, registered_domains) try: # Save to disk first (atomic write for persistence) entries.append((domain, pool)) - # Only add wildcard for root domains, not subdomains if not is_subdomain: entries.append((f".{domain}", pool)) try: @@ -200,42 +263,29 @@ def register_domain_tools(mcp): except IOError as e: return f"Error: Failed to save map file: {e}" - # Then update HAProxy maps via Runtime API - # 2-stage matching: exact domains go to domains.map, wildcards go to wildcards.map + # Update HAProxy maps via Runtime API try: - haproxy_cmd(f"add map {MAP_FILE_CONTAINER} {domain} {pool}") - if not is_subdomain: - haproxy_cmd(f"add map {WILDCARDS_MAP_FILE_CONTAINER} .{domain} {pool}") + _update_haproxy_maps(domain, pool, is_subdomain) except HaproxyError as e: - # Rollback: remove the domain we just added from entries and re-save - rollback_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"] - try: - save_map_file(rollback_entries) - except IOError: - logger.error("Failed to rollback map file after HAProxy error") + _rollback_domain_addition(domain, entries) return f"Error: Failed to update HAProxy map: {e}" # If IP provided, add server to slot 1 if ip: - # Save server config to disk first add_server_to_config(domain, 1, ip, http_port) - try: server = f"{pool}_1" haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}") haproxy_cmd(f"set server {pool}/{server} state ready") except HaproxyError as e: - # Rollback server config on failure remove_server_from_config(domain, 1) return f"Domain {domain} added to {pool} but server config failed: {e}" - result = f"Domain {domain} added to {pool} with server {ip}:{http_port}" - if is_subdomain: - result += f" (subdomain of {parent_domain}, no wildcard)" else: result = f"Domain {domain} added to {pool} (no servers configured)" - if is_subdomain: - result += f" (subdomain of {parent_domain}, no wildcard)" + + if is_subdomain: + result += f" (subdomain of {parent_domain}, no wildcard)" # Check certificate coverage cert_covered, cert_info = check_certificate_coverage(domain) diff --git a/haproxy_mcp/tools/servers.py b/haproxy_mcp/tools/servers.py index 9e3db41..5a73f54 100644 --- a/haproxy_mcp/tools/servers.py +++ b/haproxy_mcp/tools/servers.py @@ -13,6 +13,7 @@ from ..config import ( StateField, StatField, STATE_MIN_COLUMNS, + logger, ) from ..exceptions import HaproxyError from ..validation import validate_domain, validate_ip, validate_backend_name @@ -252,6 +253,7 @@ def register_server_tools(mcp): added = [] errors = [] failed_slots = [] + successfully_added_slots = [] try: for server_config in validated_servers: @@ -260,19 +262,43 @@ def register_server_tools(mcp): http_port = server_config["http_port"] try: configure_server_slot(backend, server_prefix, slot, ip, http_port) + successfully_added_slots.append(slot) added.append(f"slot {slot}: {ip}:{http_port}") except HaproxyError as e: failed_slots.append(slot) errors.append(f"slot {slot}: {e}") except Exception as e: - # Rollback all saved configs on unexpected error + # Rollback only successfully added configs on unexpected error + for slot in successfully_added_slots: + try: + remove_server_from_config(domain, slot) + except Exception as rollback_error: + logger.error( + "Failed to rollback server config for %s slot %d: %s", + domain, slot, rollback_error + ) + # Also rollback configs that weren't yet processed for server_config in validated_servers: - remove_server_from_config(domain, server_config["slot"]) + slot = server_config["slot"] + if slot not in successfully_added_slots: + try: + remove_server_from_config(domain, slot) + except Exception as rollback_error: + logger.error( + "Failed to rollback server config for %s slot %d: %s", + domain, slot, rollback_error + ) return f"Error: {e}" # Rollback failed slots from config for slot in failed_slots: - remove_server_from_config(domain, slot) + try: + remove_server_from_config(domain, slot) + except Exception as rollback_error: + logger.error( + "Failed to rollback server config for %s slot %d: %s", + domain, slot, rollback_error + ) # Build result message result_parts = [] diff --git a/haproxy_mcp/uv.lock b/haproxy_mcp/uv.lock index 9d291cb..e654fbb 100644 --- a/haproxy_mcp/uv.lock +++ b/haproxy_mcp/uv.lock @@ -133,6 +133,98 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "coverage" +version = "7.13.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ad/49/349848445b0e53660e258acbcc9b0d014895b6739237920886672240f84b/coverage-7.13.2.tar.gz", hash = "sha256:044c6951ec37146b72a50cc81ef02217d27d4c3640efd2640311393cbbf143d3", size = 826523, upload-time = "2026-01-25T13:00:04.889Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6c/01/abca50583a8975bb6e1c59eff67ed8e48bb127c07dad5c28d9e96ccc09ec/coverage-7.13.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:060ebf6f2c51aff5ba38e1f43a2095e087389b1c69d559fde6049a4b0001320e", size = 218971, upload-time = "2026-01-25T12:57:36.953Z" }, + { url = "https://files.pythonhosted.org/packages/eb/0e/b6489f344d99cd1e5b4d5e1be52dfd3f8a3dc5112aa6c33948da8cabad4e/coverage-7.13.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c1ea8ca9db5e7469cd364552985e15911548ea5b69c48a17291f0cac70484b2e", size = 219473, upload-time = "2026-01-25T12:57:38.934Z" }, + { url = "https://files.pythonhosted.org/packages/17/11/db2f414915a8e4ec53f60b17956c27f21fb68fcf20f8a455ce7c2ccec638/coverage-7.13.2-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:b780090d15fd58f07cf2011943e25a5f0c1c894384b13a216b6c86c8a8a7c508", size = 249896, upload-time = "2026-01-25T12:57:40.365Z" }, + { url = "https://files.pythonhosted.org/packages/80/06/0823fe93913663c017e508e8810c998c8ebd3ec2a5a85d2c3754297bdede/coverage-7.13.2-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:88a800258d83acb803c38175b4495d293656d5fac48659c953c18e5f539a274b", size = 251810, upload-time = "2026-01-25T12:57:42.045Z" }, + { url = "https://files.pythonhosted.org/packages/61/dc/b151c3cc41b28cdf7f0166c5fa1271cbc305a8ec0124cce4b04f74791a18/coverage-7.13.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6326e18e9a553e674d948536a04a80d850a5eeefe2aae2e6d7cf05d54046c01b", size = 253920, upload-time = "2026-01-25T12:57:44.026Z" }, + { url = "https://files.pythonhosted.org/packages/2d/35/e83de0556e54a4729a2b94ea816f74ce08732e81945024adee46851c2264/coverage-7.13.2-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:59562de3f797979e1ff07c587e2ac36ba60ca59d16c211eceaa579c266c5022f", size = 250025, upload-time = "2026-01-25T12:57:45.624Z" }, + { url = "https://files.pythonhosted.org/packages/39/67/af2eb9c3926ce3ea0d58a0d2516fcbdacf7a9fc9559fe63076beaf3f2596/coverage-7.13.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:27ba1ed6f66b0e2d61bfa78874dffd4f8c3a12f8e2b5410e515ab345ba7bc9c3", size = 251612, upload-time = "2026-01-25T12:57:47.713Z" }, + { url = "https://files.pythonhosted.org/packages/26/62/5be2e25f3d6c711d23b71296f8b44c978d4c8b4e5b26871abfc164297502/coverage-7.13.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8be48da4d47cc68754ce643ea50b3234557cbefe47c2f120495e7bd0a2756f2b", size = 249670, upload-time = "2026-01-25T12:57:49.378Z" }, + { url = "https://files.pythonhosted.org/packages/b3/51/400d1b09a8344199f9b6a6fc1868005d766b7ea95e7882e494fa862ca69c/coverage-7.13.2-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:2a47a4223d3361b91176aedd9d4e05844ca67d7188456227b6bf5e436630c9a1", size = 249395, upload-time = "2026-01-25T12:57:50.86Z" }, + { url = "https://files.pythonhosted.org/packages/e0/36/f02234bc6e5230e2f0a63fd125d0a2093c73ef20fdf681c7af62a140e4e7/coverage-7.13.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c6f141b468740197d6bd38f2b26ade124363228cc3f9858bd9924ab059e00059", size = 250298, upload-time = "2026-01-25T12:57:52.287Z" }, + { url = "https://files.pythonhosted.org/packages/b0/06/713110d3dd3151b93611c9cbfc65c15b4156b44f927fced49ac0b20b32a4/coverage-7.13.2-cp311-cp311-win32.whl", hash = "sha256:89567798404af067604246e01a49ef907d112edf2b75ef814b1364d5ce267031", size = 221485, upload-time = "2026-01-25T12:57:53.876Z" }, + { url = "https://files.pythonhosted.org/packages/16/0c/3ae6255fa1ebcb7dec19c9a59e85ef5f34566d1265c70af5b2fc981da834/coverage-7.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:21dd57941804ae2ac7e921771a5e21bbf9aabec317a041d164853ad0a96ce31e", size = 222421, upload-time = "2026-01-25T12:57:55.433Z" }, + { url = "https://files.pythonhosted.org/packages/b5/37/fabc3179af4d61d89ea47bd04333fec735cd5e8b59baad44fed9fc4170d7/coverage-7.13.2-cp311-cp311-win_arm64.whl", hash = "sha256:10758e0586c134a0bafa28f2d37dd2cdb5e4a90de25c0fc0c77dabbad46eca28", size = 221088, upload-time = "2026-01-25T12:57:57.41Z" }, + { url = "https://files.pythonhosted.org/packages/46/39/e92a35f7800222d3f7b2cbb7bbc3b65672ae8d501cb31801b2d2bd7acdf1/coverage-7.13.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f106b2af193f965d0d3234f3f83fc35278c7fb935dfbde56ae2da3dd2c03b84d", size = 219142, upload-time = "2026-01-25T12:58:00.448Z" }, + { url = "https://files.pythonhosted.org/packages/45/7a/8bf9e9309c4c996e65c52a7c5a112707ecdd9fbaf49e10b5a705a402bbb4/coverage-7.13.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:78f45d21dc4d5d6bd29323f0320089ef7eae16e4bef712dff79d184fa7330af3", size = 219503, upload-time = "2026-01-25T12:58:02.451Z" }, + { url = "https://files.pythonhosted.org/packages/87/93/17661e06b7b37580923f3f12406ac91d78aeed293fb6da0b69cc7957582f/coverage-7.13.2-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:fae91dfecd816444c74531a9c3d6ded17a504767e97aa674d44f638107265b99", size = 251006, upload-time = "2026-01-25T12:58:04.059Z" }, + { url = "https://files.pythonhosted.org/packages/12/f0/f9e59fb8c310171497f379e25db060abef9fa605e09d63157eebec102676/coverage-7.13.2-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:264657171406c114787b441484de620e03d8f7202f113d62fcd3d9688baa3e6f", size = 253750, upload-time = "2026-01-25T12:58:05.574Z" }, + { url = "https://files.pythonhosted.org/packages/e5/b1/1935e31add2232663cf7edd8269548b122a7d100047ff93475dbaaae673e/coverage-7.13.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ae47d8dcd3ded0155afbb59c62bd8ab07ea0fd4902e1c40567439e6db9dcaf2f", size = 254862, upload-time = "2026-01-25T12:58:07.647Z" }, + { url = "https://files.pythonhosted.org/packages/af/59/b5e97071ec13df5f45da2b3391b6cdbec78ba20757bc92580a5b3d5fa53c/coverage-7.13.2-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:8a0b33e9fd838220b007ce8f299114d406c1e8edb21336af4c97a26ecfd185aa", size = 251420, upload-time = "2026-01-25T12:58:09.309Z" }, + { url = "https://files.pythonhosted.org/packages/3f/75/9495932f87469d013dc515fb0ce1aac5fa97766f38f6b1a1deb1ee7b7f3a/coverage-7.13.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b3becbea7f3ce9a2d4d430f223ec15888e4deb31395840a79e916368d6004cce", size = 252786, upload-time = "2026-01-25T12:58:10.909Z" }, + { url = "https://files.pythonhosted.org/packages/6a/59/af550721f0eb62f46f7b8cb7e6f1860592189267b1c411a4e3a057caacee/coverage-7.13.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:f819c727a6e6eeb8711e4ce63d78c620f69630a2e9d53bc95ca5379f57b6ba94", size = 250928, upload-time = "2026-01-25T12:58:12.449Z" }, + { url = "https://files.pythonhosted.org/packages/9b/b1/21b4445709aae500be4ab43bbcfb4e53dc0811c3396dcb11bf9f23fd0226/coverage-7.13.2-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:4f7b71757a3ab19f7ba286e04c181004c1d61be921795ee8ba6970fd0ec91da5", size = 250496, upload-time = "2026-01-25T12:58:14.047Z" }, + { url = "https://files.pythonhosted.org/packages/ba/b1/0f5d89dfe0392990e4f3980adbde3eb34885bc1effb2dc369e0bf385e389/coverage-7.13.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b7fc50d2afd2e6b4f6f2f403b70103d280a8e0cb35320cbbe6debcda02a1030b", size = 252373, upload-time = "2026-01-25T12:58:15.976Z" }, + { url = "https://files.pythonhosted.org/packages/01/c9/0cf1a6a57a9968cc049a6b896693faa523c638a5314b1fc374eb2b2ac904/coverage-7.13.2-cp312-cp312-win32.whl", hash = "sha256:292250282cf9bcf206b543d7608bda17ca6fc151f4cbae949fc7e115112fbd41", size = 221696, upload-time = "2026-01-25T12:58:17.517Z" }, + { url = "https://files.pythonhosted.org/packages/4d/05/d7540bf983f09d32803911afed135524570f8c47bb394bf6206c1dc3a786/coverage-7.13.2-cp312-cp312-win_amd64.whl", hash = "sha256:eeea10169fac01549a7921d27a3e517194ae254b542102267bef7a93ed38c40e", size = 222504, upload-time = "2026-01-25T12:58:19.115Z" }, + { url = "https://files.pythonhosted.org/packages/15/8b/1a9f037a736ced0a12aacf6330cdaad5008081142a7070bc58b0f7930cbc/coverage-7.13.2-cp312-cp312-win_arm64.whl", hash = "sha256:2a5b567f0b635b592c917f96b9a9cb3dbd4c320d03f4bf94e9084e494f2e8894", size = 221120, upload-time = "2026-01-25T12:58:21.334Z" }, + { url = "https://files.pythonhosted.org/packages/a7/f0/3d3eac7568ab6096ff23791a526b0048a1ff3f49d0e236b2af6fb6558e88/coverage-7.13.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ed75de7d1217cf3b99365d110975f83af0528c849ef5180a12fd91b5064df9d6", size = 219168, upload-time = "2026-01-25T12:58:23.376Z" }, + { url = "https://files.pythonhosted.org/packages/a3/a6/f8b5cfeddbab95fdef4dcd682d82e5dcff7a112ced57a959f89537ee9995/coverage-7.13.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:97e596de8fa9bada4d88fde64a3f4d37f1b6131e4faa32bad7808abc79887ddc", size = 219537, upload-time = "2026-01-25T12:58:24.932Z" }, + { url = "https://files.pythonhosted.org/packages/7b/e6/8d8e6e0c516c838229d1e41cadcec91745f4b1031d4db17ce0043a0423b4/coverage-7.13.2-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:68c86173562ed4413345410c9480a8d64864ac5e54a5cda236748031e094229f", size = 250528, upload-time = "2026-01-25T12:58:26.567Z" }, + { url = "https://files.pythonhosted.org/packages/8e/78/befa6640f74092b86961f957f26504c8fba3d7da57cc2ab7407391870495/coverage-7.13.2-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7be4d613638d678b2b3773b8f687537b284d7074695a43fe2fbbfc0e31ceaed1", size = 253132, upload-time = "2026-01-25T12:58:28.251Z" }, + { url = "https://files.pythonhosted.org/packages/9d/10/1630db1edd8ce675124a2ee0f7becc603d2bb7b345c2387b4b95c6907094/coverage-7.13.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d7f63ce526a96acd0e16c4af8b50b64334239550402fb1607ce6a584a6d62ce9", size = 254374, upload-time = "2026-01-25T12:58:30.294Z" }, + { url = "https://files.pythonhosted.org/packages/ed/1d/0d9381647b1e8e6d310ac4140be9c428a0277330991e0c35bdd751e338a4/coverage-7.13.2-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:406821f37f864f968e29ac14c3fccae0fec9fdeba48327f0341decf4daf92d7c", size = 250762, upload-time = "2026-01-25T12:58:32.036Z" }, + { url = "https://files.pythonhosted.org/packages/43/e4/5636dfc9a7c871ee8776af83ee33b4c26bc508ad6cee1e89b6419a366582/coverage-7.13.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ee68e5a4e3e5443623406b905db447dceddffee0dceb39f4e0cd9ec2a35004b5", size = 252502, upload-time = "2026-01-25T12:58:33.961Z" }, + { url = "https://files.pythonhosted.org/packages/02/2a/7ff2884d79d420cbb2d12fed6fff727b6d0ef27253140d3cdbbd03187ee0/coverage-7.13.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2ee0e58cca0c17dd9c6c1cdde02bb705c7b3fbfa5f3b0b5afeda20d4ebff8ef4", size = 250463, upload-time = "2026-01-25T12:58:35.529Z" }, + { url = "https://files.pythonhosted.org/packages/91/c0/ba51087db645b6c7261570400fc62c89a16278763f36ba618dc8657a187b/coverage-7.13.2-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:6e5bbb5018bf76a56aabdb64246b5288d5ae1b7d0dd4d0534fe86df2c2992d1c", size = 250288, upload-time = "2026-01-25T12:58:37.226Z" }, + { url = "https://files.pythonhosted.org/packages/03/07/44e6f428551c4d9faf63ebcefe49b30e5c89d1be96f6a3abd86a52da9d15/coverage-7.13.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a55516c68ef3e08e134e818d5e308ffa6b1337cc8b092b69b24287bf07d38e31", size = 252063, upload-time = "2026-01-25T12:58:38.821Z" }, + { url = "https://files.pythonhosted.org/packages/c2/67/35b730ad7e1859dd57e834d1bc06080d22d2f87457d53f692fce3f24a5a9/coverage-7.13.2-cp313-cp313-win32.whl", hash = "sha256:5b20211c47a8abf4abc3319d8ce2464864fa9f30c5fcaf958a3eed92f4f1fef8", size = 221716, upload-time = "2026-01-25T12:58:40.484Z" }, + { url = "https://files.pythonhosted.org/packages/0d/82/e5fcf5a97c72f45fc14829237a6550bf49d0ab882ac90e04b12a69db76b4/coverage-7.13.2-cp313-cp313-win_amd64.whl", hash = "sha256:14f500232e521201cf031549fb1ebdfc0a40f401cf519157f76c397e586c3beb", size = 222522, upload-time = "2026-01-25T12:58:43.247Z" }, + { url = "https://files.pythonhosted.org/packages/b1/f1/25d7b2f946d239dd2d6644ca2cc060d24f97551e2af13b6c24c722ae5f97/coverage-7.13.2-cp313-cp313-win_arm64.whl", hash = "sha256:9779310cb5a9778a60c899f075a8514c89fa6d10131445c2207fc893e0b14557", size = 221145, upload-time = "2026-01-25T12:58:45Z" }, + { url = "https://files.pythonhosted.org/packages/9e/f7/080376c029c8f76fadfe43911d0daffa0cbdc9f9418a0eead70c56fb7f4b/coverage-7.13.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:e64fa5a1e41ce5df6b547cbc3d3699381c9e2c2c369c67837e716ed0f549d48e", size = 219861, upload-time = "2026-01-25T12:58:46.586Z" }, + { url = "https://files.pythonhosted.org/packages/42/11/0b5e315af5ab35f4c4a70e64d3314e4eec25eefc6dec13be3a7d5ffe8ac5/coverage-7.13.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b01899e82a04085b6561eb233fd688474f57455e8ad35cd82286463ba06332b7", size = 220207, upload-time = "2026-01-25T12:58:48.277Z" }, + { url = "https://files.pythonhosted.org/packages/b2/0c/0874d0318fb1062117acbef06a09cf8b63f3060c22265adaad24b36306b7/coverage-7.13.2-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:838943bea48be0e2768b0cf7819544cdedc1bbb2f28427eabb6eb8c9eb2285d3", size = 261504, upload-time = "2026-01-25T12:58:49.904Z" }, + { url = "https://files.pythonhosted.org/packages/83/5e/1cd72c22ecb30751e43a72f40ba50fcef1b7e93e3ea823bd9feda8e51f9a/coverage-7.13.2-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:93d1d25ec2b27e90bcfef7012992d1f5121b51161b8bffcda756a816cf13c2c3", size = 263582, upload-time = "2026-01-25T12:58:51.582Z" }, + { url = "https://files.pythonhosted.org/packages/9b/da/8acf356707c7a42df4d0657020308e23e5a07397e81492640c186268497c/coverage-7.13.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:93b57142f9621b0d12349c43fc7741fe578e4bc914c1e5a54142856cfc0bf421", size = 266008, upload-time = "2026-01-25T12:58:53.234Z" }, + { url = "https://files.pythonhosted.org/packages/41/41/ea1730af99960309423c6ea8d6a4f1fa5564b2d97bd1d29dda4b42611f04/coverage-7.13.2-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f06799ae1bdfff7ccb8665d75f8291c69110ba9585253de254688aa8a1ccc6c5", size = 260762, upload-time = "2026-01-25T12:58:55.372Z" }, + { url = "https://files.pythonhosted.org/packages/22/fa/02884d2080ba71db64fdc127b311db60e01fe6ba797d9c8363725e39f4d5/coverage-7.13.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:7f9405ab4f81d490811b1d91c7a20361135a2df4c170e7f0b747a794da5b7f23", size = 263571, upload-time = "2026-01-25T12:58:57.52Z" }, + { url = "https://files.pythonhosted.org/packages/d2/6b/4083aaaeba9b3112f55ac57c2ce7001dc4d8fa3fcc228a39f09cc84ede27/coverage-7.13.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:f9ab1d5b86f8fbc97a5b3cd6280a3fd85fef3b028689d8a2c00918f0d82c728c", size = 261200, upload-time = "2026-01-25T12:58:59.255Z" }, + { url = "https://files.pythonhosted.org/packages/e9/d2/aea92fa36d61955e8c416ede9cf9bf142aa196f3aea214bb67f85235a050/coverage-7.13.2-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:f674f59712d67e841525b99e5e2b595250e39b529c3bda14764e4f625a3fa01f", size = 260095, upload-time = "2026-01-25T12:59:01.066Z" }, + { url = "https://files.pythonhosted.org/packages/0d/ae/04ffe96a80f107ea21b22b2367175c621da920063260a1c22f9452fd7866/coverage-7.13.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c6cadac7b8ace1ba9144feb1ae3cb787a6065ba6d23ffc59a934b16406c26573", size = 262284, upload-time = "2026-01-25T12:59:02.802Z" }, + { url = "https://files.pythonhosted.org/packages/1c/7a/6f354dcd7dfc41297791d6fb4e0d618acb55810bde2c1fd14b3939e05c2b/coverage-7.13.2-cp313-cp313t-win32.whl", hash = "sha256:14ae4146465f8e6e6253eba0cccd57423e598a4cb925958b240c805300918343", size = 222389, upload-time = "2026-01-25T12:59:04.563Z" }, + { url = "https://files.pythonhosted.org/packages/8d/d5/080ad292a4a3d3daf411574be0a1f56d6dee2c4fdf6b005342be9fac807f/coverage-7.13.2-cp313-cp313t-win_amd64.whl", hash = "sha256:9074896edd705a05769e3de0eac0a8388484b503b68863dd06d5e473f874fd47", size = 223450, upload-time = "2026-01-25T12:59:06.677Z" }, + { url = "https://files.pythonhosted.org/packages/88/96/df576fbacc522e9fb8d1c4b7a7fc62eb734be56e2cba1d88d2eabe08ea3f/coverage-7.13.2-cp313-cp313t-win_arm64.whl", hash = "sha256:69e526e14f3f854eda573d3cf40cffd29a1a91c684743d904c33dbdcd0e0f3e7", size = 221707, upload-time = "2026-01-25T12:59:08.363Z" }, + { url = "https://files.pythonhosted.org/packages/55/53/1da9e51a0775634b04fcc11eb25c002fc58ee4f92ce2e8512f94ac5fc5bf/coverage-7.13.2-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:387a825f43d680e7310e6f325b2167dd093bc8ffd933b83e9aa0983cf6e0a2ef", size = 219213, upload-time = "2026-01-25T12:59:11.909Z" }, + { url = "https://files.pythonhosted.org/packages/46/35/b3caac3ebbd10230fea5a33012b27d19e999a17c9285c4228b4b2e35b7da/coverage-7.13.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:f0d7fea9d8e5d778cd5a9e8fc38308ad688f02040e883cdc13311ef2748cb40f", size = 219549, upload-time = "2026-01-25T12:59:13.638Z" }, + { url = "https://files.pythonhosted.org/packages/76/9c/e1cf7def1bdc72c1907e60703983a588f9558434a2ff94615747bd73c192/coverage-7.13.2-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:e080afb413be106c95c4ee96b4fffdc9e2fa56a8bbf90b5c0918e5c4449412f5", size = 250586, upload-time = "2026-01-25T12:59:15.808Z" }, + { url = "https://files.pythonhosted.org/packages/ba/49/f54ec02ed12be66c8d8897270505759e057b0c68564a65c429ccdd1f139e/coverage-7.13.2-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:a7fc042ba3c7ce25b8a9f097eb0f32a5ce1ccdb639d9eec114e26def98e1f8a4", size = 253093, upload-time = "2026-01-25T12:59:17.491Z" }, + { url = "https://files.pythonhosted.org/packages/fb/5e/aaf86be3e181d907e23c0f61fccaeb38de8e6f6b47aed92bf57d8fc9c034/coverage-7.13.2-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d0ba505e021557f7f8173ee8cd6b926373d8653e5ff7581ae2efce1b11ef4c27", size = 254446, upload-time = "2026-01-25T12:59:19.752Z" }, + { url = "https://files.pythonhosted.org/packages/28/c8/a5fa01460e2d75b0c853b392080d6829d3ca8b5ab31e158fa0501bc7c708/coverage-7.13.2-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7de326f80e3451bd5cc7239ab46c73ddb658fe0b7649476bc7413572d36cd548", size = 250615, upload-time = "2026-01-25T12:59:21.928Z" }, + { url = "https://files.pythonhosted.org/packages/86/0b/6d56315a55f7062bb66410732c24879ccb2ec527ab6630246de5fe45a1df/coverage-7.13.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:abaea04f1e7e34841d4a7b343904a3f59481f62f9df39e2cd399d69a187a9660", size = 252452, upload-time = "2026-01-25T12:59:23.592Z" }, + { url = "https://files.pythonhosted.org/packages/30/19/9bc550363ebc6b0ea121977ee44d05ecd1e8bf79018b8444f1028701c563/coverage-7.13.2-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:9f93959ee0c604bccd8e0697be21de0887b1f73efcc3aa73a3ec0fd13feace92", size = 250418, upload-time = "2026-01-25T12:59:25.392Z" }, + { url = "https://files.pythonhosted.org/packages/1f/53/580530a31ca2f0cc6f07a8f2ab5460785b02bb11bdf815d4c4d37a4c5169/coverage-7.13.2-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:13fe81ead04e34e105bf1b3c9f9cdf32ce31736ee5d90a8d2de02b9d3e1bcb82", size = 250231, upload-time = "2026-01-25T12:59:27.888Z" }, + { url = "https://files.pythonhosted.org/packages/e2/42/dd9093f919dc3088cb472893651884bd675e3df3d38a43f9053656dca9a2/coverage-7.13.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d6d16b0f71120e365741bca2cb473ca6fe38930bc5431c5e850ba949f708f892", size = 251888, upload-time = "2026-01-25T12:59:29.636Z" }, + { url = "https://files.pythonhosted.org/packages/fa/a6/0af4053e6e819774626e133c3d6f70fae4d44884bfc4b126cb647baee8d3/coverage-7.13.2-cp314-cp314-win32.whl", hash = "sha256:9b2f4714bb7d99ba3790ee095b3b4ac94767e1347fe424278a0b10acb3ff04fe", size = 221968, upload-time = "2026-01-25T12:59:31.424Z" }, + { url = "https://files.pythonhosted.org/packages/c4/cc/5aff1e1f80d55862442855517bb8ad8ad3a68639441ff6287dde6a58558b/coverage-7.13.2-cp314-cp314-win_amd64.whl", hash = "sha256:e4121a90823a063d717a96e0a0529c727fb31ea889369a0ee3ec00ed99bf6859", size = 222783, upload-time = "2026-01-25T12:59:33.118Z" }, + { url = "https://files.pythonhosted.org/packages/de/20/09abafb24f84b3292cc658728803416c15b79f9ee5e68d25238a895b07d9/coverage-7.13.2-cp314-cp314-win_arm64.whl", hash = "sha256:6873f0271b4a15a33e7590f338d823f6f66f91ed147a03938d7ce26efd04eee6", size = 221348, upload-time = "2026-01-25T12:59:34.939Z" }, + { url = "https://files.pythonhosted.org/packages/b6/60/a3820c7232db63be060e4019017cd3426751c2699dab3c62819cdbcea387/coverage-7.13.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:f61d349f5b7cd95c34017f1927ee379bfbe9884300d74e07cf630ccf7a610c1b", size = 219950, upload-time = "2026-01-25T12:59:36.624Z" }, + { url = "https://files.pythonhosted.org/packages/fd/37/e4ef5975fdeb86b1e56db9a82f41b032e3d93a840ebaf4064f39e770d5c5/coverage-7.13.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a43d34ce714f4ca674c0d90beb760eb05aad906f2c47580ccee9da8fe8bfb417", size = 220209, upload-time = "2026-01-25T12:59:38.339Z" }, + { url = "https://files.pythonhosted.org/packages/54/df/d40e091d00c51adca1e251d3b60a8b464112efa3004949e96a74d7c19a64/coverage-7.13.2-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:bff1b04cb9d4900ce5c56c4942f047dc7efe57e2608cb7c3c8936e9970ccdbee", size = 261576, upload-time = "2026-01-25T12:59:40.446Z" }, + { url = "https://files.pythonhosted.org/packages/c5/44/5259c4bed54e3392e5c176121af9f71919d96dde853386e7730e705f3520/coverage-7.13.2-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6ae99e4560963ad8e163e819e5d77d413d331fd00566c1e0856aa252303552c1", size = 263704, upload-time = "2026-01-25T12:59:42.346Z" }, + { url = "https://files.pythonhosted.org/packages/16/bd/ae9f005827abcbe2c70157459ae86053971c9fa14617b63903abbdce26d9/coverage-7.13.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e79a8c7d461820257d9aa43716c4efc55366d7b292e46b5b37165be1d377405d", size = 266109, upload-time = "2026-01-25T12:59:44.073Z" }, + { url = "https://files.pythonhosted.org/packages/a2/c0/8e279c1c0f5b1eaa3ad9b0fb7a5637fc0379ea7d85a781c0fe0bb3cfc2ab/coverage-7.13.2-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:060ee84f6a769d40c492711911a76811b4befb6fba50abb450371abb720f5bd6", size = 260686, upload-time = "2026-01-25T12:59:45.804Z" }, + { url = "https://files.pythonhosted.org/packages/b2/47/3a8112627e9d863e7cddd72894171c929e94491a597811725befdcd76bce/coverage-7.13.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:3bca209d001fd03ea2d978f8a4985093240a355c93078aee3f799852c23f561a", size = 263568, upload-time = "2026-01-25T12:59:47.929Z" }, + { url = "https://files.pythonhosted.org/packages/92/bc/7ea367d84afa3120afc3ce6de294fd2dcd33b51e2e7fbe4bbfd200f2cb8c/coverage-7.13.2-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:6b8092aa38d72f091db61ef83cb66076f18f02da3e1a75039a4f218629600e04", size = 261174, upload-time = "2026-01-25T12:59:49.717Z" }, + { url = "https://files.pythonhosted.org/packages/33/b7/f1092dcecb6637e31cc2db099581ee5c61a17647849bae6b8261a2b78430/coverage-7.13.2-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:4a3158dc2dcce5200d91ec28cd315c999eebff355437d2765840555d765a6e5f", size = 260017, upload-time = "2026-01-25T12:59:51.463Z" }, + { url = "https://files.pythonhosted.org/packages/2b/cd/f3d07d4b95fbe1a2ef0958c15da614f7e4f557720132de34d2dc3aa7e911/coverage-7.13.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3973f353b2d70bd9796cc12f532a05945232ccae966456c8ed7034cb96bbfd6f", size = 262337, upload-time = "2026-01-25T12:59:53.407Z" }, + { url = "https://files.pythonhosted.org/packages/e0/db/b0d5b2873a07cb1e06a55d998697c0a5a540dcefbf353774c99eb3874513/coverage-7.13.2-cp314-cp314t-win32.whl", hash = "sha256:79f6506a678a59d4ded048dc72f1859ebede8ec2b9a2d509ebe161f01c2879d3", size = 222749, upload-time = "2026-01-25T12:59:56.316Z" }, + { url = "https://files.pythonhosted.org/packages/e5/2f/838a5394c082ac57d85f57f6aba53093b30d9089781df72412126505716f/coverage-7.13.2-cp314-cp314t-win_amd64.whl", hash = "sha256:196bfeabdccc5a020a57d5a368c681e3a6ceb0447d153aeccc1ab4d70a5032ba", size = 223857, upload-time = "2026-01-25T12:59:58.201Z" }, + { url = "https://files.pythonhosted.org/packages/44/d4/b608243e76ead3a4298824b50922b89ef793e50069ce30316a65c1b4d7ef/coverage-7.13.2-cp314-cp314t-win_arm64.whl", hash = "sha256:69269ab58783e090bfbf5b916ab3d188126e22d6070bbfc93098fdd474ef937c", size = 221881, upload-time = "2026-01-25T13:00:00.449Z" }, + { url = "https://files.pythonhosted.org/packages/d2/db/d291e30fdf7ea617a335531e72294e0c723356d7fdde8fba00610a76bda9/coverage-7.13.2-py3-none-any.whl", hash = "sha256:40ce1ea1e25125556d8e76bd0b61500839a07944cc287ac21d5626f3e620cad5", size = 210943, upload-time = "2026-01-25T13:00:02.388Z" }, +] + +[package.optional-dependencies] +toml = [ + { name = "tomli", marker = "python_full_version <= '3.11'" }, +] + [[package]] name = "cryptography" version = "46.0.4" @@ -209,8 +301,19 @@ dependencies = [ { name = "mcp", extra = ["cli"] }, ] +[package.optional-dependencies] +test = [ + { name = "pytest" }, + { name = "pytest-cov" }, +] + [package.metadata] -requires-dist = [{ name = "mcp", extras = ["cli"], specifier = ">=1.0.0" }] +requires-dist = [ + { name = "mcp", extras = ["cli"], specifier = ">=1.0.0" }, + { name = "pytest", marker = "extra == 'test'", specifier = ">=8.0.0" }, + { name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.1.0" }, +] +provides-extras = ["test"] [[package]] name = "httpcore" @@ -258,6 +361,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, ] +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + [[package]] name = "jsonschema" version = "4.26.0" @@ -337,6 +449,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, ] +[[package]] +name = "packaging" +version = "26.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/65/ee/299d360cdc32edc7d2cf530f3accf79c4fca01e96ffc950d8a52213bd8e4/packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4", size = 143416, upload-time = "2026-01-21T20:50:39.064Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" }, +] + +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + [[package]] name = "pycparser" version = "3.0" @@ -495,6 +625,36 @@ crypto = [ { name = "cryptography" }, ] +[[package]] +name = "pytest" +version = "9.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, +] + +[[package]] +name = "pytest-cov" +version = "7.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage", extra = ["toml"] }, + { name = "pluggy" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5e/f7/c933acc76f5208b3b00089573cf6a2bc26dc80a8aece8f52bb7d6b1855ca/pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1", size = 54328, upload-time = "2025-09-09T10:57:02.113Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, +] + [[package]] name = "python-dotenv" version = "1.2.1" @@ -702,6 +862,60 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/0d/13d1d239a25cbfb19e740db83143e95c772a1fe10202dda4b76792b114dd/starlette-0.52.1-py3-none-any.whl", hash = "sha256:0029d43eb3d273bc4f83a08720b4912ea4b071087a3b48db01b7c839f7954d74", size = 74272, upload-time = "2026-01-18T13:34:09.188Z" }, ] +[[package]] +name = "tomli" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/82/30/31573e9457673ab10aa432461bee537ce6cef177667deca369efb79df071/tomli-2.4.0.tar.gz", hash = "sha256:aa89c3f6c277dd275d8e243ad24f3b5e701491a860d5121f2cdd399fbb31fc9c", size = 17477, upload-time = "2026-01-11T11:22:38.165Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/d9/3dc2289e1f3b32eb19b9785b6a006b28ee99acb37d1d47f78d4c10e28bf8/tomli-2.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b5ef256a3fd497d4973c11bf142e9ed78b150d36f5773f1ca6088c230ffc5867", size = 153663, upload-time = "2026-01-11T11:21:45.27Z" }, + { url = "https://files.pythonhosted.org/packages/51/32/ef9f6845e6b9ca392cd3f64f9ec185cc6f09f0a2df3db08cbe8809d1d435/tomli-2.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5572e41282d5268eb09a697c89a7bee84fae66511f87533a6f88bd2f7b652da9", size = 148469, upload-time = "2026-01-11T11:21:46.873Z" }, + { url = "https://files.pythonhosted.org/packages/d6/c2/506e44cce89a8b1b1e047d64bd495c22c9f71f21e05f380f1a950dd9c217/tomli-2.4.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:551e321c6ba03b55676970b47cb1b73f14a0a4dce6a3e1a9458fd6d921d72e95", size = 236039, upload-time = "2026-01-11T11:21:48.503Z" }, + { url = "https://files.pythonhosted.org/packages/b3/40/e1b65986dbc861b7e986e8ec394598187fa8aee85b1650b01dd925ca0be8/tomli-2.4.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5e3f639a7a8f10069d0e15408c0b96a2a828cfdec6fca05296ebcdcc28ca7c76", size = 243007, upload-time = "2026-01-11T11:21:49.456Z" }, + { url = "https://files.pythonhosted.org/packages/9c/6f/6e39ce66b58a5b7ae572a0f4352ff40c71e8573633deda43f6a379d56b3e/tomli-2.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1b168f2731796b045128c45982d3a4874057626da0e2ef1fdd722848b741361d", size = 240875, upload-time = "2026-01-11T11:21:50.755Z" }, + { url = "https://files.pythonhosted.org/packages/aa/ad/cb089cb190487caa80204d503c7fd0f4d443f90b95cf4ef5cf5aa0f439b0/tomli-2.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:133e93646ec4300d651839d382d63edff11d8978be23da4cc106f5a18b7d0576", size = 246271, upload-time = "2026-01-11T11:21:51.81Z" }, + { url = "https://files.pythonhosted.org/packages/0b/63/69125220e47fd7a3a27fd0de0c6398c89432fec41bc739823bcc66506af6/tomli-2.4.0-cp311-cp311-win32.whl", hash = "sha256:b6c78bdf37764092d369722d9946cb65b8767bfa4110f902a1b2542d8d173c8a", size = 96770, upload-time = "2026-01-11T11:21:52.647Z" }, + { url = "https://files.pythonhosted.org/packages/1e/0d/a22bb6c83f83386b0008425a6cd1fa1c14b5f3dd4bad05e98cf3dbbf4a64/tomli-2.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:d3d1654e11d724760cdb37a3d7691f0be9db5fbdaef59c9f532aabf87006dbaa", size = 107626, upload-time = "2026-01-11T11:21:53.459Z" }, + { url = "https://files.pythonhosted.org/packages/2f/6d/77be674a3485e75cacbf2ddba2b146911477bd887dda9d8c9dfb2f15e871/tomli-2.4.0-cp311-cp311-win_arm64.whl", hash = "sha256:cae9c19ed12d4e8f3ebf46d1a75090e4c0dc16271c5bce1c833ac168f08fb614", size = 94842, upload-time = "2026-01-11T11:21:54.831Z" }, + { url = "https://files.pythonhosted.org/packages/3c/43/7389a1869f2f26dba52404e1ef13b4784b6b37dac93bac53457e3ff24ca3/tomli-2.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:920b1de295e72887bafa3ad9f7a792f811847d57ea6b1215154030cf131f16b1", size = 154894, upload-time = "2026-01-11T11:21:56.07Z" }, + { url = "https://files.pythonhosted.org/packages/e9/05/2f9bf110b5294132b2edf13fe6ca6ae456204f3d749f623307cbb7a946f2/tomli-2.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7d6d9a4aee98fac3eab4952ad1d73aee87359452d1c086b5ceb43ed02ddb16b8", size = 149053, upload-time = "2026-01-11T11:21:57.467Z" }, + { url = "https://files.pythonhosted.org/packages/e8/41/1eda3ca1abc6f6154a8db4d714a4d35c4ad90adc0bcf700657291593fbf3/tomli-2.4.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36b9d05b51e65b254ea6c2585b59d2c4cb91c8a3d91d0ed0f17591a29aaea54a", size = 243481, upload-time = "2026-01-11T11:21:58.661Z" }, + { url = "https://files.pythonhosted.org/packages/d2/6d/02ff5ab6c8868b41e7d4b987ce2b5f6a51d3335a70aa144edd999e055a01/tomli-2.4.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1c8a885b370751837c029ef9bc014f27d80840e48bac415f3412e6593bbc18c1", size = 251720, upload-time = "2026-01-11T11:22:00.178Z" }, + { url = "https://files.pythonhosted.org/packages/7b/57/0405c59a909c45d5b6f146107c6d997825aa87568b042042f7a9c0afed34/tomli-2.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8768715ffc41f0008abe25d808c20c3d990f42b6e2e58305d5da280ae7d1fa3b", size = 247014, upload-time = "2026-01-11T11:22:01.238Z" }, + { url = "https://files.pythonhosted.org/packages/2c/0e/2e37568edd944b4165735687cbaf2fe3648129e440c26d02223672ee0630/tomli-2.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b438885858efd5be02a9a133caf5812b8776ee0c969fea02c45e8e3f296ba51", size = 251820, upload-time = "2026-01-11T11:22:02.727Z" }, + { url = "https://files.pythonhosted.org/packages/5a/1c/ee3b707fdac82aeeb92d1a113f803cf6d0f37bdca0849cb489553e1f417a/tomli-2.4.0-cp312-cp312-win32.whl", hash = "sha256:0408e3de5ec77cc7f81960c362543cbbd91ef883e3138e81b729fc3eea5b9729", size = 97712, upload-time = "2026-01-11T11:22:03.777Z" }, + { url = "https://files.pythonhosted.org/packages/69/13/c07a9177d0b3bab7913299b9278845fc6eaaca14a02667c6be0b0a2270c8/tomli-2.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:685306e2cc7da35be4ee914fd34ab801a6acacb061b6a7abca922aaf9ad368da", size = 108296, upload-time = "2026-01-11T11:22:04.86Z" }, + { url = "https://files.pythonhosted.org/packages/18/27/e267a60bbeeee343bcc279bb9e8fbed0cbe224bc7b2a3dc2975f22809a09/tomli-2.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:5aa48d7c2356055feef06a43611fc401a07337d5b006be13a30f6c58f869e3c3", size = 94553, upload-time = "2026-01-11T11:22:05.854Z" }, + { url = "https://files.pythonhosted.org/packages/34/91/7f65f9809f2936e1f4ce6268ae1903074563603b2a2bd969ebbda802744f/tomli-2.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:84d081fbc252d1b6a982e1870660e7330fb8f90f676f6e78b052ad4e64714bf0", size = 154915, upload-time = "2026-01-11T11:22:06.703Z" }, + { url = "https://files.pythonhosted.org/packages/20/aa/64dd73a5a849c2e8f216b755599c511badde80e91e9bc2271baa7b2cdbb1/tomli-2.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9a08144fa4cba33db5255f9b74f0b89888622109bd2776148f2597447f92a94e", size = 149038, upload-time = "2026-01-11T11:22:07.56Z" }, + { url = "https://files.pythonhosted.org/packages/9e/8a/6d38870bd3d52c8d1505ce054469a73f73a0fe62c0eaf5dddf61447e32fa/tomli-2.4.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c73add4bb52a206fd0c0723432db123c0c75c280cbd67174dd9d2db228ebb1b4", size = 242245, upload-time = "2026-01-11T11:22:08.344Z" }, + { url = "https://files.pythonhosted.org/packages/59/bb/8002fadefb64ab2669e5b977df3f5e444febea60e717e755b38bb7c41029/tomli-2.4.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1fb2945cbe303b1419e2706e711b7113da57b7db31ee378d08712d678a34e51e", size = 250335, upload-time = "2026-01-11T11:22:09.951Z" }, + { url = "https://files.pythonhosted.org/packages/a5/3d/4cdb6f791682b2ea916af2de96121b3cb1284d7c203d97d92d6003e91c8d/tomli-2.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bbb1b10aa643d973366dc2cb1ad94f99c1726a02343d43cbc011edbfac579e7c", size = 245962, upload-time = "2026-01-11T11:22:11.27Z" }, + { url = "https://files.pythonhosted.org/packages/f2/4a/5f25789f9a460bd858ba9756ff52d0830d825b458e13f754952dd15fb7bb/tomli-2.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4cbcb367d44a1f0c2be408758b43e1ffb5308abe0ea222897d6bfc8e8281ef2f", size = 250396, upload-time = "2026-01-11T11:22:12.325Z" }, + { url = "https://files.pythonhosted.org/packages/aa/2f/b73a36fea58dfa08e8b3a268750e6853a6aac2a349241a905ebd86f3047a/tomli-2.4.0-cp313-cp313-win32.whl", hash = "sha256:7d49c66a7d5e56ac959cb6fc583aff0651094ec071ba9ad43df785abc2320d86", size = 97530, upload-time = "2026-01-11T11:22:13.865Z" }, + { url = "https://files.pythonhosted.org/packages/3b/af/ca18c134b5d75de7e8dc551c5234eaba2e8e951f6b30139599b53de9c187/tomli-2.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:3cf226acb51d8f1c394c1b310e0e0e61fecdd7adcb78d01e294ac297dd2e7f87", size = 108227, upload-time = "2026-01-11T11:22:15.224Z" }, + { url = "https://files.pythonhosted.org/packages/22/c3/b386b832f209fee8073c8138ec50f27b4460db2fdae9ffe022df89a57f9b/tomli-2.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:d20b797a5c1ad80c516e41bc1fb0443ddb5006e9aaa7bda2d71978346aeb9132", size = 94748, upload-time = "2026-01-11T11:22:16.009Z" }, + { url = "https://files.pythonhosted.org/packages/f3/c4/84047a97eb1004418bc10bdbcfebda209fca6338002eba2dc27cc6d13563/tomli-2.4.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:26ab906a1eb794cd4e103691daa23d95c6919cc2fa9160000ac02370cc9dd3f6", size = 154725, upload-time = "2026-01-11T11:22:17.269Z" }, + { url = "https://files.pythonhosted.org/packages/a8/5d/d39038e646060b9d76274078cddf146ced86dc2b9e8bbf737ad5983609a0/tomli-2.4.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:20cedb4ee43278bc4f2fee6cb50daec836959aadaf948db5172e776dd3d993fc", size = 148901, upload-time = "2026-01-11T11:22:18.287Z" }, + { url = "https://files.pythonhosted.org/packages/73/e5/383be1724cb30f4ce44983d249645684a48c435e1cd4f8b5cded8a816d3c/tomli-2.4.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:39b0b5d1b6dd03684b3fb276407ebed7090bbec989fa55838c98560c01113b66", size = 243375, upload-time = "2026-01-11T11:22:19.154Z" }, + { url = "https://files.pythonhosted.org/packages/31/f0/bea80c17971c8d16d3cc109dc3585b0f2ce1036b5f4a8a183789023574f2/tomli-2.4.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a26d7ff68dfdb9f87a016ecfd1e1c2bacbe3108f4e0f8bcd2228ef9a766c787d", size = 250639, upload-time = "2026-01-11T11:22:20.168Z" }, + { url = "https://files.pythonhosted.org/packages/2c/8f/2853c36abbb7608e3f945d8a74e32ed3a74ee3a1f468f1ffc7d1cb3abba6/tomli-2.4.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:20ffd184fb1df76a66e34bd1b36b4a4641bd2b82954befa32fe8163e79f1a702", size = 246897, upload-time = "2026-01-11T11:22:21.544Z" }, + { url = "https://files.pythonhosted.org/packages/49/f0/6c05e3196ed5337b9fe7ea003e95fd3819a840b7a0f2bf5a408ef1dad8ed/tomli-2.4.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:75c2f8bbddf170e8effc98f5e9084a8751f8174ea6ccf4fca5398436e0320bc8", size = 254697, upload-time = "2026-01-11T11:22:23.058Z" }, + { url = "https://files.pythonhosted.org/packages/f3/f5/2922ef29c9f2951883525def7429967fc4d8208494e5ab524234f06b688b/tomli-2.4.0-cp314-cp314-win32.whl", hash = "sha256:31d556d079d72db7c584c0627ff3a24c5d3fb4f730221d3444f3efb1b2514776", size = 98567, upload-time = "2026-01-11T11:22:24.033Z" }, + { url = "https://files.pythonhosted.org/packages/7b/31/22b52e2e06dd2a5fdbc3ee73226d763b184ff21fc24e20316a44ccc4d96b/tomli-2.4.0-cp314-cp314-win_amd64.whl", hash = "sha256:43e685b9b2341681907759cf3a04e14d7104b3580f808cfde1dfdb60ada85475", size = 108556, upload-time = "2026-01-11T11:22:25.378Z" }, + { url = "https://files.pythonhosted.org/packages/48/3d/5058dff3255a3d01b705413f64f4306a141a8fd7a251e5a495e3f192a998/tomli-2.4.0-cp314-cp314-win_arm64.whl", hash = "sha256:3d895d56bd3f82ddd6faaff993c275efc2ff38e52322ea264122d72729dca2b2", size = 96014, upload-time = "2026-01-11T11:22:26.138Z" }, + { url = "https://files.pythonhosted.org/packages/b8/4e/75dab8586e268424202d3a1997ef6014919c941b50642a1682df43204c22/tomli-2.4.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:5b5807f3999fb66776dbce568cc9a828544244a8eb84b84b9bafc080c99597b9", size = 163339, upload-time = "2026-01-11T11:22:27.143Z" }, + { url = "https://files.pythonhosted.org/packages/06/e3/b904d9ab1016829a776d97f163f183a48be6a4deb87304d1e0116a349519/tomli-2.4.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c084ad935abe686bd9c898e62a02a19abfc9760b5a79bc29644463eaf2840cb0", size = 159490, upload-time = "2026-01-11T11:22:28.399Z" }, + { url = "https://files.pythonhosted.org/packages/e3/5a/fc3622c8b1ad823e8ea98a35e3c632ee316d48f66f80f9708ceb4f2a0322/tomli-2.4.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0f2e3955efea4d1cfbcb87bc321e00dc08d2bcb737fd1d5e398af111d86db5df", size = 269398, upload-time = "2026-01-11T11:22:29.345Z" }, + { url = "https://files.pythonhosted.org/packages/fd/33/62bd6152c8bdd4c305ad9faca48f51d3acb2df1f8791b1477d46ff86e7f8/tomli-2.4.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0e0fe8a0b8312acf3a88077a0802565cb09ee34107813bba1c7cd591fa6cfc8d", size = 276515, upload-time = "2026-01-11T11:22:30.327Z" }, + { url = "https://files.pythonhosted.org/packages/4b/ff/ae53619499f5235ee4211e62a8d7982ba9e439a0fb4f2f351a93d67c1dd2/tomli-2.4.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:413540dce94673591859c4c6f794dfeaa845e98bf35d72ed59636f869ef9f86f", size = 273806, upload-time = "2026-01-11T11:22:32.56Z" }, + { url = "https://files.pythonhosted.org/packages/47/71/cbca7787fa68d4d0a9f7072821980b39fbb1b6faeb5f5cf02f4a5559fa28/tomli-2.4.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:0dc56fef0e2c1c470aeac5b6ca8cc7b640bb93e92d9803ddaf9ea03e198f5b0b", size = 281340, upload-time = "2026-01-11T11:22:33.505Z" }, + { url = "https://files.pythonhosted.org/packages/f5/00/d595c120963ad42474cf6ee7771ad0d0e8a49d0f01e29576ee9195d9ecdf/tomli-2.4.0-cp314-cp314t-win32.whl", hash = "sha256:d878f2a6707cc9d53a1be1414bbb419e629c3d6e67f69230217bb663e76b5087", size = 108106, upload-time = "2026-01-11T11:22:34.451Z" }, + { url = "https://files.pythonhosted.org/packages/de/69/9aa0c6a505c2f80e519b43764f8b4ba93b5a0bbd2d9a9de6e2b24271b9a5/tomli-2.4.0-cp314-cp314t-win_amd64.whl", hash = "sha256:2add28aacc7425117ff6364fe9e06a183bb0251b03f986df0e78e974047571fd", size = 120504, upload-time = "2026-01-11T11:22:35.764Z" }, + { url = "https://files.pythonhosted.org/packages/b3/9f/f1668c281c58cfae01482f7114a4b88d345e4c140386241a1a24dcc9e7bc/tomli-2.4.0-cp314-cp314t-win_arm64.whl", hash = "sha256:2b1e3b80e1d5e52e40e9b924ec43d81570f0e7d09d11081b797bc4692765a3d4", size = 99561, upload-time = "2026-01-11T11:22:36.624Z" }, + { url = "https://files.pythonhosted.org/packages/23/d1/136eb2cb77520a31e1f64cbae9d33ec6df0d78bdf4160398e86eec8a8754/tomli-2.4.0-py3-none-any.whl", hash = "sha256:1f776e7d669ebceb01dee46484485f43a4048746235e683bcdffacdf1fb4785a", size = 14477, upload-time = "2026-01-11T11:22:37.446Z" }, +] + [[package]] name = "typer" version = "0.21.1" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..2e0174e --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""HAProxy MCP Server test suite.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..07752e5 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,345 @@ +"""Shared pytest fixtures for HAProxy MCP Server tests.""" + +import json +import os +import sys +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +# Add the parent directory to sys.path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +class MockSocket: + """Mock socket for testing HAProxy client communication.""" + + def __init__(self, responses: dict[str, str] | None = None, default_response: str = ""): + """Initialize mock socket. + + Args: + responses: Dict mapping command prefixes to responses + default_response: Response for commands not in responses dict + """ + self.responses = responses or {} + self.default_response = default_response + self.sent_commands: list[str] = [] + self._closed = False + self._response_buffer = b"" + + def connect(self, address: tuple[str, int]) -> None: + """Mock connect - does nothing.""" + pass + + def settimeout(self, timeout: float) -> None: + """Mock settimeout - does nothing.""" + pass + + def setblocking(self, blocking: bool) -> None: + """Mock setblocking - does nothing.""" + pass + + def sendall(self, data: bytes) -> None: + """Mock sendall - stores sent command.""" + command = data.decode().strip() + self.sent_commands.append(command) + # Prepare response for this command + response = self.default_response + for prefix, resp in self.responses.items(): + if command.startswith(prefix): + response = resp + break + self._response_buffer = response.encode() + + def shutdown(self, how: int) -> None: + """Mock shutdown - does nothing.""" + pass + + def recv(self, bufsize: int) -> bytes: + """Mock recv - returns prepared response.""" + if self._response_buffer: + data = self._response_buffer[:bufsize] + self._response_buffer = self._response_buffer[bufsize:] + return data + return b"" + + def close(self) -> None: + """Mock close.""" + self._closed = True + + def fileno(self) -> int: + """Mock fileno for select.""" + return 999 + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +class HAProxyResponseBuilder: + """Helper class to build HAProxy-style responses for tests.""" + + @staticmethod + def servers_state(servers: list[dict[str, Any]]) -> str: + """Build a 'show servers state' response. + + Args: + servers: List of server dicts with keys: + - be_id: Backend ID (int) + - be_name: Backend name + - srv_id: Server ID (int) + - srv_name: Server name + - srv_addr: Server address (IP) + - srv_op_state: Operational state (int) + - srv_admin_state: Admin state (int) + - srv_port: Server port + + Returns: + HAProxy-formatted server state string + """ + lines = ["1"] # Version line + lines.append("# be_id be_name srv_id srv_name srv_addr srv_op_state srv_admin_state srv_uweight srv_iweight srv_time_since_last_change srv_check_status srv_check_result srv_check_health srv_check_state srv_agent_state bk_f_forced_id srv_f_forced_id srv_fqdn srv_port srvrecord") + for srv in servers: + # Fill in defaults for minimal state line (need 19+ columns) + line_parts = [ + str(srv.get("be_id", 1)), + srv.get("be_name", "pool_1"), + str(srv.get("srv_id", 1)), + srv.get("srv_name", "pool_1_1"), + srv.get("srv_addr", "0.0.0.0"), + str(srv.get("srv_op_state", 2)), + str(srv.get("srv_admin_state", 0)), + "1", # srv_uweight + "1", # srv_iweight + "100", # srv_time_since_last_change + "6", # srv_check_status + "3", # srv_check_result + "4", # srv_check_health + "6", # srv_check_state + "0", # srv_agent_state + "0", # bk_f_forced_id + "0", # srv_f_forced_id + "-", # srv_fqdn + str(srv.get("srv_port", 0)), # srv_port + ] + lines.append(" ".join(line_parts)) + return "\n".join(lines) + + @staticmethod + def stat_csv(entries: list[dict[str, Any]]) -> str: + """Build a 'show stat' CSV response. + + Args: + entries: List of stat dicts with keys: + - pxname: Proxy name + - svname: Server name (or FRONTEND/BACKEND) + - scur: Current sessions (optional, default 0) + - status: Status (UP/DOWN/MAINT) + - weight: Weight (optional, default 1) + - check_status: Check status (optional) + + Returns: + HAProxy-formatted CSV stat string + """ + lines = ["# pxname,svname,qcur,qmax,scur,smax,slim,stot,bin,bout,dreq,dresp,ereq,econ,eresp,wretr,wredis,status,weight,act,bck,chkfail,chkdown,lastchg,downtime,qlimit,pid,iid,sid,throttle,lbtot,tracked,type,rate,rate_lim,rate_max,check_status,check_code,check_duration,hrsp_1xx,hrsp_2xx,hrsp_3xx,hrsp_4xx,hrsp_5xx,hrsp_other,hanafail,req_rate,req_rate_max,req_tot,cli_abrt,srv_abrt,"] + for entry in entries: + # Build CSV row with proper field positions + # Fields: pxname(0), svname(1), qcur(2), qmax(3), scur(4), smax(5), slim(6), ... + # status(17), weight(18), ..., check_status(36) + row = [""] * 50 + row[0] = entry.get("pxname", "pool_1") + row[1] = entry.get("svname", "pool_1_1") + row[4] = str(entry.get("scur", 0)) # SCUR + row[5] = str(entry.get("smax", 0)) # SMAX + row[17] = entry.get("status", "UP") # STATUS + row[18] = str(entry.get("weight", 1)) # WEIGHT + row[36] = entry.get("check_status", "L4OK") # CHECK_STATUS + lines.append(",".join(row)) + return "\n".join(lines) + + @staticmethod + def info(version: str = "3.3.2", uptime: int = 3600) -> str: + """Build a 'show info' response. + + Args: + version: HAProxy version string + uptime: Uptime in seconds + + Returns: + HAProxy-formatted info string + """ + return f"""Name: HAProxy +Version: {version} +Release_date: 2024/01/01 +Nbthread: 4 +Nbproc: 1 +Process_num: 1 +Pid: 1 +Uptime: 1h0m0s +Uptime_sec: {uptime} +Memmax_MB: 0 +PoolAlloc_MB: 0 +PoolUsed_MB: 0 +PoolFailed: 0 +Ulimit-n: 200015 +Maxsock: 200015 +Maxconn: 100000 +Hard_maxconn: 100000 +CurrConns: 5 +CumConns: 1000 +CumReq: 5000""" + + @staticmethod + def map_show(entries: list[tuple[str, str]]) -> str: + """Build a 'show map' response. + + Args: + entries: List of (key, value) tuples + + Returns: + HAProxy-formatted map show string + """ + lines = [] + for i, (key, value) in enumerate(entries): + lines.append(f"0x{i:08x} {key} {value}") + return "\n".join(lines) + + +@pytest.fixture +def mock_socket_class(): + """Fixture that returns MockSocket class for custom configuration.""" + return MockSocket + + +@pytest.fixture +def response_builder(): + """Fixture that returns HAProxyResponseBuilder class.""" + return HAProxyResponseBuilder + + +@pytest.fixture +def mock_haproxy_socket(mock_socket_class, response_builder): + """Fixture providing a pre-configured mock socket with common responses.""" + responses = { + "show info": response_builder.info(), + "show servers state": response_builder.servers_state([]), + "show stat": response_builder.stat_csv([]), + "show map": response_builder.map_show([]), + "show backend": "pool_1\npool_2\npool_3", + "add map": "", + "del map": "", + "set server": "", + } + return mock_socket_class(responses=responses) + + +@pytest.fixture +def temp_config_dir(tmp_path): + """Fixture providing a temporary directory with config files.""" + # Create config files + map_file = tmp_path / "domains.map" + map_file.write_text("# Domain to Backend mapping\n") + + wildcards_file = tmp_path / "wildcards.map" + wildcards_file.write_text("# Wildcard Domain mapping\n") + + servers_file = tmp_path / "servers.json" + servers_file.write_text("{}") + + certs_file = tmp_path / "certificates.json" + certs_file.write_text('{"domains": []}') + + state_file = tmp_path / "servers.state" + state_file.write_text("") + + return { + "dir": tmp_path, + "map_file": str(map_file), + "wildcards_file": str(wildcards_file), + "servers_file": str(servers_file), + "certs_file": str(certs_file), + "state_file": str(state_file), + } + + +@pytest.fixture +def patch_config_paths(temp_config_dir): + """Fixture that patches config module paths to use temporary directory.""" + with patch.multiple( + "haproxy_mcp.config", + MAP_FILE=temp_config_dir["map_file"], + WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"], + SERVERS_FILE=temp_config_dir["servers_file"], + CERTS_FILE=temp_config_dir["certs_file"], + STATE_FILE=temp_config_dir["state_file"], + ): + # Also patch file_ops module which imports these + with patch.multiple( + "haproxy_mcp.file_ops", + MAP_FILE=temp_config_dir["map_file"], + WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"], + SERVERS_FILE=temp_config_dir["servers_file"], + CERTS_FILE=temp_config_dir["certs_file"], + ): + yield temp_config_dir + + +@pytest.fixture +def mock_subprocess(): + """Fixture that mocks subprocess.run for external command testing.""" + with patch("subprocess.run") as mock_run: + # Default to successful command + mock_run.return_value = MagicMock( + returncode=0, + stdout="", + stderr="", + ) + yield mock_run + + +@pytest.fixture +def mock_socket_module(mock_haproxy_socket): + """Fixture that patches socket module to use mock socket.""" + + def create_socket(*args, **kwargs): + return mock_haproxy_socket + + with patch("socket.socket", side_effect=create_socket): + yield mock_haproxy_socket + + +@pytest.fixture +def mock_select(): + """Fixture that patches select.select for socket recv loops.""" + with patch("select.select") as mock_sel: + # Default: socket is ready immediately + mock_sel.return_value = ([True], [], []) + yield mock_sel + + +@pytest.fixture +def sample_servers_config(): + """Sample servers.json content for testing.""" + return { + "example.com": { + "1": {"ip": "10.0.0.1", "http_port": 80}, + "2": {"ip": "10.0.0.2", "http_port": 80}, + }, + "api.example.com": { + "1": {"ip": "10.0.0.10", "http_port": 8080}, + }, + } + + +@pytest.fixture +def sample_map_entries(): + """Sample domains.map entries for testing.""" + return [ + ("example.com", "pool_1"), + (".example.com", "pool_1"), + ("api.example.com", "pool_2"), + (".api.example.com", "pool_2"), + ] diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..36770ff --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration tests for HAProxy MCP Server.""" diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..ba5e49b --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests for HAProxy MCP Server.""" diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..ec3b985 --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,198 @@ +"""Unit tests for config module.""" + +import os +import re +from unittest.mock import patch + +import pytest + +from haproxy_mcp.config import ( + DOMAIN_PATTERN, + BACKEND_NAME_PATTERN, + NON_ALNUM_PATTERN, + StatField, + StateField, + POOL_COUNT, + MAX_SLOTS, + MAX_RESPONSE_SIZE, + SOCKET_TIMEOUT, + MAX_BULK_SERVERS, +) + + +class TestDomainPattern: + """Tests for DOMAIN_PATTERN regex.""" + + def test_simple_domain(self): + """Match simple domain.""" + assert DOMAIN_PATTERN.match("example.com") is not None + + def test_subdomain(self): + """Match subdomain.""" + assert DOMAIN_PATTERN.match("api.example.com") is not None + + def test_deep_subdomain(self): + """Match deep subdomain.""" + assert DOMAIN_PATTERN.match("a.b.c.d.example.com") is not None + + def test_hyphenated_domain(self): + """Match domain with hyphens.""" + assert DOMAIN_PATTERN.match("my-api.example-site.com") is not None + + def test_numeric_labels(self): + """Match domain with numeric labels.""" + assert DOMAIN_PATTERN.match("api123.example.com") is not None + + def test_invalid_starts_with_hyphen(self): + """Reject domain starting with hyphen.""" + assert DOMAIN_PATTERN.match("-example.com") is None + + def test_invalid_ends_with_hyphen(self): + """Reject label ending with hyphen.""" + assert DOMAIN_PATTERN.match("example-.com") is None + + def test_invalid_underscore(self): + """Reject domain with underscore.""" + assert DOMAIN_PATTERN.match("my_api.example.com") is None + + +class TestBackendNamePattern: + """Tests for BACKEND_NAME_PATTERN regex.""" + + def test_pool_name(self): + """Match pool backend names.""" + assert BACKEND_NAME_PATTERN.match("pool_1") is not None + assert BACKEND_NAME_PATTERN.match("pool_100") is not None + + def test_alphanumeric(self): + """Match alphanumeric names.""" + assert BACKEND_NAME_PATTERN.match("backend123") is not None + + def test_underscore(self): + """Match names with underscores.""" + assert BACKEND_NAME_PATTERN.match("my_backend") is not None + + def test_hyphen(self): + """Match names with hyphens.""" + assert BACKEND_NAME_PATTERN.match("my-backend") is not None + + def test_mixed(self): + """Match mixed character names.""" + assert BACKEND_NAME_PATTERN.match("api_example-com_backend") is not None + + def test_invalid_dot(self): + """Reject names with dots.""" + assert BACKEND_NAME_PATTERN.match("my.backend") is None + + def test_invalid_special_chars(self): + """Reject names with special characters.""" + assert BACKEND_NAME_PATTERN.match("my@backend") is None + assert BACKEND_NAME_PATTERN.match("my/backend") is None + + +class TestNonAlnumPattern: + """Tests for NON_ALNUM_PATTERN regex.""" + + def test_replace_dots(self): + """Replace dots.""" + result = NON_ALNUM_PATTERN.sub("_", "example.com") + assert result == "example_com" + + def test_replace_hyphens(self): + """Replace hyphens.""" + result = NON_ALNUM_PATTERN.sub("_", "my-api") + assert result == "my_api" + + def test_preserve_alphanumeric(self): + """Preserve alphanumeric characters.""" + result = NON_ALNUM_PATTERN.sub("_", "abc123") + assert result == "abc123" + + def test_complex_replacement(self): + """Complex domain replacement.""" + result = NON_ALNUM_PATTERN.sub("_", "api.my-site.example.com") + assert result == "api_my_site_example_com" + + +class TestStatField: + """Tests for StatField constants.""" + + def test_field_indices(self): + """Verify stat field indices.""" + assert StatField.PXNAME == 0 + assert StatField.SVNAME == 1 + assert StatField.SCUR == 4 + assert StatField.SMAX == 6 + assert StatField.STATUS == 17 + assert StatField.WEIGHT == 18 + assert StatField.CHECK_STATUS == 36 + + +class TestStateField: + """Tests for StateField constants.""" + + def test_field_indices(self): + """Verify state field indices.""" + assert StateField.BE_ID == 0 + assert StateField.BE_NAME == 1 + assert StateField.SRV_ID == 2 + assert StateField.SRV_NAME == 3 + assert StateField.SRV_ADDR == 4 + assert StateField.SRV_OP_STATE == 5 + assert StateField.SRV_ADMIN_STATE == 6 + assert StateField.SRV_PORT == 18 + + +class TestConfigConstants: + """Tests for configuration constants.""" + + def test_pool_count(self): + """Pool count has expected value.""" + assert POOL_COUNT == 100 + + def test_max_slots(self): + """Max slots has expected value.""" + assert MAX_SLOTS == 10 + + def test_max_response_size(self): + """Max response size is reasonable.""" + assert MAX_RESPONSE_SIZE == 10 * 1024 * 1024 # 10 MB + + def test_socket_timeout(self): + """Socket timeout is reasonable.""" + assert SOCKET_TIMEOUT == 5 + + def test_max_bulk_servers(self): + """Max bulk servers is reasonable.""" + assert MAX_BULK_SERVERS == 10 + + +class TestEnvironmentVariables: + """Tests for environment variable configuration.""" + + def test_default_mcp_host(self): + """Default MCP host is 0.0.0.0.""" + # Import fresh to get defaults + with patch.dict(os.environ, {}, clear=True): + # Re-import to test defaults + from importlib import reload + import haproxy_mcp.config as config + reload(config) + # Note: Due to Python's module caching, this test verifies the + # default values are what we expect from the source code + assert config.MCP_HOST == "0.0.0.0" + + def test_default_mcp_port(self): + """Default MCP port is 8000.""" + from haproxy_mcp.config import MCP_PORT + assert MCP_PORT == 8000 + + def test_default_haproxy_host(self): + """Default HAProxy host is localhost.""" + from haproxy_mcp.config import HAPROXY_HOST + assert HAPROXY_HOST == "localhost" + + def test_default_haproxy_port(self): + """Default HAProxy port is 9999.""" + from haproxy_mcp.config import HAPROXY_PORT + assert HAPROXY_PORT == 9999 diff --git a/tests/unit/test_file_ops.py b/tests/unit/test_file_ops.py new file mode 100644 index 0000000..49e4ccd --- /dev/null +++ b/tests/unit/test_file_ops.py @@ -0,0 +1,498 @@ +"""Unit tests for file_ops module.""" + +import json +import os +from unittest.mock import patch + +import pytest + +from haproxy_mcp.file_ops import ( + atomic_write_file, + get_map_contents, + save_map_file, + get_domain_backend, + split_domain_entries, + is_legacy_backend, + get_legacy_backend_name, + get_backend_and_prefix, + load_servers_config, + save_servers_config, + add_server_to_config, + remove_server_from_config, + remove_domain_from_config, + load_certs_config, + save_certs_config, + add_cert_to_config, + remove_cert_from_config, +) + + +class TestAtomicWriteFile: + """Tests for atomic_write_file function.""" + + def test_write_new_file(self, tmp_path): + """Write to a new file.""" + file_path = str(tmp_path / "test.txt") + content = "Hello, World!" + + atomic_write_file(file_path, content) + + assert os.path.exists(file_path) + with open(file_path) as f: + assert f.read() == content + + def test_overwrite_existing_file(self, tmp_path): + """Overwrite an existing file.""" + file_path = str(tmp_path / "test.txt") + with open(file_path, "w") as f: + f.write("Old content") + + atomic_write_file(file_path, "New content") + + with open(file_path) as f: + assert f.read() == "New content" + + def test_preserves_directory(self, tmp_path): + """Writing does not create intermediate directories.""" + file_path = str(tmp_path / "subdir" / "test.txt") + + with pytest.raises(IOError): + atomic_write_file(file_path, "content") + + def test_unicode_content(self, tmp_path): + """Unicode content is properly written.""" + file_path = str(tmp_path / "unicode.txt") + content = "Hello, \u4e16\u754c!" # "Hello, World!" in Chinese + + atomic_write_file(file_path, content) + + with open(file_path, encoding="utf-8") as f: + assert f.read() == content + + def test_multiline_content(self, tmp_path): + """Multi-line content is properly written.""" + file_path = str(tmp_path / "multiline.txt") + content = "line1\nline2\nline3" + + atomic_write_file(file_path, content) + + with open(file_path) as f: + assert f.read() == content + + +class TestGetMapContents: + """Tests for get_map_contents function.""" + + def test_read_map_file(self, patch_config_paths): + """Read entries from map file.""" + # Write test content to map file + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + f.write("api.example.com pool_2\n") + + entries = get_map_contents() + + assert ("example.com", "pool_1") in entries + assert ("api.example.com", "pool_2") in entries + + def test_read_both_map_files(self, patch_config_paths): + """Read entries from both domains.map and wildcards.map.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + with open(patch_config_paths["wildcards_file"], "w") as f: + f.write(".example.com pool_1\n") + + entries = get_map_contents() + + assert ("example.com", "pool_1") in entries + assert (".example.com", "pool_1") in entries + + def test_skip_comments(self, patch_config_paths): + """Comments are skipped.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("# This is a comment\n") + f.write("example.com pool_1\n") + f.write("# Another comment\n") + + entries = get_map_contents() + + assert len(entries) == 1 + assert entries[0] == ("example.com", "pool_1") + + def test_skip_empty_lines(self, patch_config_paths): + """Empty lines are skipped.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("\n") + f.write("example.com pool_1\n") + f.write("\n") + f.write("api.example.com pool_2\n") + + entries = get_map_contents() + + assert len(entries) == 2 + + def test_file_not_found(self, patch_config_paths): + """Missing file returns empty list.""" + os.unlink(patch_config_paths["map_file"]) + os.unlink(patch_config_paths["wildcards_file"]) + + entries = get_map_contents() + + assert entries == [] + + +class TestSplitDomainEntries: + """Tests for split_domain_entries function.""" + + def test_split_entries(self): + """Split entries into exact and wildcard.""" + entries = [ + ("example.com", "pool_1"), + (".example.com", "pool_1"), + ("api.example.com", "pool_2"), + (".api.example.com", "pool_2"), + ] + + exact, wildcards = split_domain_entries(entries) + + assert len(exact) == 2 + assert len(wildcards) == 2 + assert ("example.com", "pool_1") in exact + assert (".example.com", "pool_1") in wildcards + + def test_empty_entries(self): + """Empty entries returns empty lists.""" + exact, wildcards = split_domain_entries([]) + + assert exact == [] + assert wildcards == [] + + def test_all_exact(self): + """All exact entries.""" + entries = [ + ("example.com", "pool_1"), + ("api.example.com", "pool_2"), + ] + + exact, wildcards = split_domain_entries(entries) + + assert len(exact) == 2 + assert len(wildcards) == 0 + + +class TestSaveMapFile: + """Tests for save_map_file function.""" + + def test_save_entries(self, patch_config_paths): + """Save entries to separate map files.""" + entries = [ + ("example.com", "pool_1"), + (".example.com", "pool_1"), + ] + + save_map_file(entries) + + # Check exact domains file + with open(patch_config_paths["map_file"]) as f: + content = f.read() + assert "example.com pool_1" in content + + # Check wildcards file + with open(patch_config_paths["wildcards_file"]) as f: + content = f.read() + assert ".example.com pool_1" in content + + def test_sorted_output(self, patch_config_paths): + """Entries are sorted in output.""" + entries = [ + ("z.example.com", "pool_3"), + ("a.example.com", "pool_1"), + ("m.example.com", "pool_2"), + ] + + save_map_file(entries) + + with open(patch_config_paths["map_file"]) as f: + lines = [l.strip() for l in f if l.strip() and not l.startswith("#")] + + assert lines[0] == "a.example.com pool_1" + assert lines[1] == "m.example.com pool_2" + assert lines[2] == "z.example.com pool_3" + + +class TestGetDomainBackend: + """Tests for get_domain_backend function.""" + + def test_find_existing_domain(self, patch_config_paths): + """Find backend for existing domain.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + backend = get_domain_backend("example.com") + + assert backend == "pool_1" + + def test_domain_not_found(self, patch_config_paths): + """Non-existent domain returns None.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + backend = get_domain_backend("other.com") + + assert backend is None + + +class TestIsLegacyBackend: + """Tests for is_legacy_backend function.""" + + def test_pool_backend(self): + """Pool backend is not legacy.""" + assert is_legacy_backend("pool_1") is False + assert is_legacy_backend("pool_100") is False + + def test_legacy_backend(self): + """Non-pool backend is legacy.""" + assert is_legacy_backend("api_example_com_backend") is True + assert is_legacy_backend("static_backend") is True + + +class TestGetLegacyBackendName: + """Tests for get_legacy_backend_name function.""" + + def test_convert_domain(self): + """Convert domain to legacy backend name.""" + result = get_legacy_backend_name("api.example.com") + assert result == "api_example_com_backend" + + +class TestGetBackendAndPrefix: + """Tests for get_backend_and_prefix function.""" + + def test_pool_backend(self, patch_config_paths): + """Pool backend returns pool-based prefix.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_5\n") + + backend, prefix = get_backend_and_prefix("example.com") + + assert backend == "pool_5" + assert prefix == "pool_5" + + def test_unknown_domain_uses_legacy(self, patch_config_paths): + """Unknown domain uses legacy backend naming.""" + backend, prefix = get_backend_and_prefix("unknown.com") + + assert backend == "unknown_com_backend" + assert prefix == "unknown_com" + + +class TestLoadServersConfig: + """Tests for load_servers_config function.""" + + def test_load_existing_config(self, patch_config_paths, sample_servers_config): + """Load existing config file.""" + with open(patch_config_paths["servers_file"], "w") as f: + json.dump(sample_servers_config, f) + + config = load_servers_config() + + assert "example.com" in config + assert config["example.com"]["1"]["ip"] == "10.0.0.1" + + def test_file_not_found(self, patch_config_paths): + """Missing file returns empty dict.""" + os.unlink(patch_config_paths["servers_file"]) + + config = load_servers_config() + + assert config == {} + + def test_invalid_json(self, patch_config_paths): + """Invalid JSON returns empty dict.""" + with open(patch_config_paths["servers_file"], "w") as f: + f.write("not valid json {{{") + + config = load_servers_config() + + assert config == {} + + +class TestSaveServersConfig: + """Tests for save_servers_config function.""" + + def test_save_config(self, patch_config_paths): + """Save config to file.""" + config = {"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}} + + save_servers_config(config) + + with open(patch_config_paths["servers_file"]) as f: + loaded = json.load(f) + assert loaded == config + + +class TestAddServerToConfig: + """Tests for add_server_to_config function.""" + + def test_add_to_empty_config(self, patch_config_paths): + """Add server to empty config.""" + add_server_to_config("example.com", 1, "10.0.0.1", 80) + + config = load_servers_config() + assert config["example.com"]["1"]["ip"] == "10.0.0.1" + assert config["example.com"]["1"]["http_port"] == 80 + + def test_add_to_existing_domain(self, patch_config_paths): + """Add server to domain with existing servers.""" + add_server_to_config("example.com", 1, "10.0.0.1", 80) + add_server_to_config("example.com", 2, "10.0.0.2", 80) + + config = load_servers_config() + assert "1" in config["example.com"] + assert "2" in config["example.com"] + + def test_overwrite_existing_slot(self, patch_config_paths): + """Overwrite existing slot.""" + add_server_to_config("example.com", 1, "10.0.0.1", 80) + add_server_to_config("example.com", 1, "10.0.0.99", 8080) + + config = load_servers_config() + assert config["example.com"]["1"]["ip"] == "10.0.0.99" + assert config["example.com"]["1"]["http_port"] == 8080 + + +class TestRemoveServerFromConfig: + """Tests for remove_server_from_config function.""" + + def test_remove_existing_server(self, patch_config_paths): + """Remove existing server.""" + add_server_to_config("example.com", 1, "10.0.0.1", 80) + add_server_to_config("example.com", 2, "10.0.0.2", 80) + + remove_server_from_config("example.com", 1) + + config = load_servers_config() + assert "1" not in config["example.com"] + assert "2" in config["example.com"] + + def test_remove_last_server_removes_domain(self, patch_config_paths): + """Removing last server removes domain entry.""" + add_server_to_config("example.com", 1, "10.0.0.1", 80) + + remove_server_from_config("example.com", 1) + + config = load_servers_config() + assert "example.com" not in config + + def test_remove_nonexistent_server(self, patch_config_paths): + """Removing non-existent server is a no-op.""" + add_server_to_config("example.com", 1, "10.0.0.1", 80) + + remove_server_from_config("example.com", 99) # Non-existent slot + + config = load_servers_config() + assert "1" in config["example.com"] + + +class TestRemoveDomainFromConfig: + """Tests for remove_domain_from_config function.""" + + def test_remove_existing_domain(self, patch_config_paths): + """Remove existing domain.""" + add_server_to_config("example.com", 1, "10.0.0.1", 80) + add_server_to_config("other.com", 1, "10.0.0.2", 80) + + remove_domain_from_config("example.com") + + config = load_servers_config() + assert "example.com" not in config + assert "other.com" in config + + def test_remove_nonexistent_domain(self, patch_config_paths): + """Removing non-existent domain is a no-op.""" + add_server_to_config("example.com", 1, "10.0.0.1", 80) + + remove_domain_from_config("other.com") # Non-existent + + config = load_servers_config() + assert "example.com" in config + + +class TestLoadCertsConfig: + """Tests for load_certs_config function.""" + + def test_load_existing_config(self, patch_config_paths): + """Load existing certs config.""" + with open(patch_config_paths["certs_file"], "w") as f: + json.dump({"domains": ["example.com", "other.com"]}, f) + + domains = load_certs_config() + + assert "example.com" in domains + assert "other.com" in domains + + def test_file_not_found(self, patch_config_paths): + """Missing file returns empty list.""" + os.unlink(patch_config_paths["certs_file"]) + + domains = load_certs_config() + + assert domains == [] + + +class TestSaveCertsConfig: + """Tests for save_certs_config function.""" + + def test_save_domains(self, patch_config_paths): + """Save domains to certs config.""" + save_certs_config(["z.com", "a.com"]) + + with open(patch_config_paths["certs_file"]) as f: + data = json.load(f) + + # Should be sorted + assert data["domains"] == ["a.com", "z.com"] + + +class TestAddCertToConfig: + """Tests for add_cert_to_config function.""" + + def test_add_new_cert(self, patch_config_paths): + """Add new certificate domain.""" + add_cert_to_config("example.com") + + domains = load_certs_config() + assert "example.com" in domains + + def test_add_duplicate_cert(self, patch_config_paths): + """Adding duplicate cert is a no-op.""" + add_cert_to_config("example.com") + add_cert_to_config("example.com") + + domains = load_certs_config() + assert domains.count("example.com") == 1 + + +class TestRemoveCertFromConfig: + """Tests for remove_cert_from_config function.""" + + def test_remove_existing_cert(self, patch_config_paths): + """Remove existing certificate domain.""" + add_cert_to_config("example.com") + add_cert_to_config("other.com") + + remove_cert_from_config("example.com") + + domains = load_certs_config() + assert "example.com" not in domains + assert "other.com" in domains + + def test_remove_nonexistent_cert(self, patch_config_paths): + """Removing non-existent cert is a no-op.""" + add_cert_to_config("example.com") + + remove_cert_from_config("other.com") # Non-existent + + domains = load_certs_config() + assert "example.com" in domains diff --git a/tests/unit/test_haproxy_client.py b/tests/unit/test_haproxy_client.py new file mode 100644 index 0000000..1159033 --- /dev/null +++ b/tests/unit/test_haproxy_client.py @@ -0,0 +1,279 @@ +"""Unit tests for haproxy_client module.""" + +import socket +from unittest.mock import patch, MagicMock + +import pytest + +from haproxy_mcp.haproxy_client import ( + haproxy_cmd, + haproxy_cmd_checked, + haproxy_cmd_batch, + reload_haproxy, +) +from haproxy_mcp.exceptions import HaproxyError + + +class TestHaproxyCmd: + """Tests for haproxy_cmd function.""" + + def test_successful_command(self, mock_socket_class, mock_select): + """Successful command execution returns response.""" + mock_sock = mock_socket_class( + responses={"show info": "Version: 3.3.2\nUptime_sec: 3600"} + ) + + with patch("socket.socket", return_value=mock_sock): + result = haproxy_cmd("show info") + + assert "Version: 3.3.2" in result + assert "show info" in mock_sock.sent_commands + + def test_empty_response(self, mock_socket_class, mock_select): + """Command with empty response returns empty string.""" + mock_sock = mock_socket_class(default_response="") + + with patch("socket.socket", return_value=mock_sock): + result = haproxy_cmd("set server pool_1/pool_1_1 state ready") + + assert result == "" + + def test_connection_refused_error(self, mock_select): + """Connection refused raises HaproxyError.""" + with patch("socket.socket") as mock_socket: + mock_socket.return_value.__enter__ = MagicMock(side_effect=ConnectionRefusedError()) + mock_socket.return_value.__exit__ = MagicMock(return_value=False) + + with pytest.raises(HaproxyError) as exc_info: + haproxy_cmd("show info") + + assert "Connection refused" in str(exc_info.value) + + def test_socket_timeout_error(self, mock_select): + """Socket timeout raises HaproxyError.""" + with patch("socket.socket") as mock_socket: + mock_socket.return_value.__enter__ = MagicMock(side_effect=socket.timeout()) + mock_socket.return_value.__exit__ = MagicMock(return_value=False) + + with pytest.raises(HaproxyError) as exc_info: + haproxy_cmd("show info") + + assert "timeout" in str(exc_info.value).lower() + + def test_unicode_decode_error(self, mock_socket_class, mock_select): + """Invalid UTF-8 response raises HaproxyError.""" + # Create a mock that returns invalid UTF-8 bytes + class BadUtf8Socket(mock_socket_class): + def sendall(self, data): + self.sent_commands.append(data.decode().strip()) + self._response_buffer = b"\xff\xfe" # Invalid UTF-8 + + mock_sock = BadUtf8Socket() + + with patch("socket.socket", return_value=mock_sock): + with pytest.raises(HaproxyError) as exc_info: + haproxy_cmd("show info") + + assert "UTF-8" in str(exc_info.value) + + def test_multiline_response(self, mock_socket_class, mock_select): + """Multi-line response is properly returned.""" + multi_line = "pool_1\npool_2\npool_3" + mock_sock = mock_socket_class(responses={"show backend": multi_line}) + + with patch("socket.socket", return_value=mock_sock): + result = haproxy_cmd("show backend") + + assert "pool_1" in result + assert "pool_2" in result + assert "pool_3" in result + + +class TestHaproxyCmdChecked: + """Tests for haproxy_cmd_checked function.""" + + def test_successful_command(self, mock_socket_class, mock_select): + """Successful command returns response.""" + mock_sock = mock_socket_class(responses={"set server": ""}) + + with patch("socket.socket", return_value=mock_sock): + result = haproxy_cmd_checked("set server pool_1/pool_1_1 state ready") + + assert result == "" + + def test_error_response_no_such(self, mock_socket_class, mock_select): + """Response containing 'No such' raises HaproxyError.""" + mock_sock = mock_socket_class( + responses={"set server": "No such server."} + ) + + with patch("socket.socket", return_value=mock_sock): + with pytest.raises(HaproxyError) as exc_info: + haproxy_cmd_checked("set server pool_99/pool_99_1 state ready") + + assert "No such" in str(exc_info.value) + + def test_error_response_not_found(self, mock_socket_class, mock_select): + """Response containing 'not found' raises HaproxyError.""" + mock_sock = mock_socket_class( + responses={"del map": "Backend not found."} + ) + + with patch("socket.socket", return_value=mock_sock): + with pytest.raises(HaproxyError) as exc_info: + haproxy_cmd_checked("del map /path example.com") + + assert "not found" in str(exc_info.value) + + def test_error_response_error(self, mock_socket_class, mock_select): + """Response containing 'error' raises HaproxyError.""" + mock_sock = mock_socket_class( + responses={"set server": "error: invalid state"} + ) + + with patch("socket.socket", return_value=mock_sock): + with pytest.raises(HaproxyError) as exc_info: + haproxy_cmd_checked("set server pool_1/pool_1_1 state invalid") + + assert "error" in str(exc_info.value).lower() + + def test_error_response_failed(self, mock_socket_class, mock_select): + """Response containing 'failed' raises HaproxyError.""" + mock_sock = mock_socket_class( + responses={"set server": "Command failed"} + ) + + with patch("socket.socket", return_value=mock_sock): + with pytest.raises(HaproxyError) as exc_info: + haproxy_cmd_checked("set server pool_1/pool_1_1 addr bad") + + assert "failed" in str(exc_info.value).lower() + + +class TestHaproxyCmdBatch: + """Tests for haproxy_cmd_batch function.""" + + def test_empty_commands(self): + """Empty command list returns empty list.""" + result = haproxy_cmd_batch([]) + assert result == [] + + def test_single_command(self, mock_socket_class, mock_select): + """Single command uses haproxy_cmd_checked.""" + mock_sock = mock_socket_class(responses={"set server": ""}) + + with patch("socket.socket", return_value=mock_sock): + result = haproxy_cmd_batch(["set server pool_1/pool_1_1 state ready"]) + + assert len(result) == 1 + + def test_multiple_commands(self, mock_socket_class, mock_select): + """Multiple commands are executed separately.""" + # Each command gets its own socket connection + call_count = 0 + def create_mock_socket(*args, **kwargs): + nonlocal call_count + call_count += 1 + return mock_socket_class(responses={"set server": ""}) + + with patch("socket.socket", side_effect=create_mock_socket): + result = haproxy_cmd_batch([ + "set server pool_1/pool_1_1 addr 10.0.0.1 port 80", + "set server pool_1/pool_1_1 state ready", + ]) + + assert len(result) == 2 + assert call_count == 2 # One connection per command + + def test_error_in_batch_raises(self, mock_socket_class, mock_select): + """Error in batch command raises immediately.""" + mock_sock = mock_socket_class( + responses={ + "set server pool_1/pool_1_1 addr": "", + "set server pool_1/pool_1_1 state": "No such server", + } + ) + + call_count = 0 + def create_socket(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return mock_socket_class(responses={"set server": ""}) + else: + return mock_socket_class(responses={"set server": "No such server"}) + + with patch("socket.socket", side_effect=create_socket): + with pytest.raises(HaproxyError): + haproxy_cmd_batch([ + "set server pool_1/pool_1_1 addr 10.0.0.1 port 80", + "set server pool_1/pool_1_1 state ready", + ]) + + +class TestReloadHaproxy: + """Tests for reload_haproxy function.""" + + def test_successful_reload(self, mock_subprocess): + """Successful reload returns (True, 'OK').""" + mock_subprocess.return_value = MagicMock(returncode=0, stdout="", stderr="") + + success, message = reload_haproxy() + + assert success is True + assert message == "OK" + + def test_validation_failure(self, mock_subprocess): + """Config validation failure returns (False, error).""" + mock_subprocess.return_value = MagicMock( + returncode=1, + stdout="", + stderr="[ALERT] Invalid configuration" + ) + + success, message = reload_haproxy() + + assert success is False + assert "validation failed" in message.lower() + assert "Invalid configuration" in message + + def test_reload_failure(self, mock_subprocess): + """Reload command failure returns (False, error).""" + # First call (validation) succeeds, second call (reload) fails + mock_subprocess.side_effect = [ + MagicMock(returncode=0, stdout="", stderr=""), + MagicMock(returncode=1, stdout="", stderr="Container not found"), + ] + + success, message = reload_haproxy() + + assert success is False + assert "Reload failed" in message + + def test_podman_not_found(self, mock_subprocess): + """Podman not found returns (False, error).""" + mock_subprocess.side_effect = FileNotFoundError() + + success, message = reload_haproxy() + + assert success is False + assert "podman" in message.lower() + + def test_subprocess_timeout(self, mock_subprocess): + """Subprocess timeout returns (False, error).""" + import subprocess + mock_subprocess.side_effect = subprocess.TimeoutExpired("podman", 30) + + success, message = reload_haproxy() + + assert success is False + assert "timed out" in message.lower() + + def test_os_error(self, mock_subprocess): + """OS error returns (False, error).""" + mock_subprocess.side_effect = OSError("Permission denied") + + success, message = reload_haproxy() + + assert success is False + assert "OS error" in message diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py new file mode 100644 index 0000000..1282755 --- /dev/null +++ b/tests/unit/test_utils.py @@ -0,0 +1,130 @@ +"""Unit tests for utils module.""" + +import pytest + +from haproxy_mcp.utils import parse_stat_csv + + +class TestParseStatCsv: + """Tests for parse_stat_csv function.""" + + def test_parse_valid_csv(self, response_builder): + """Parse valid HAProxy stat CSV output.""" + csv = response_builder.stat_csv([ + {"pxname": "pool_1", "svname": "pool_1_1", "scur": 5, "status": "UP", "weight": 1, "check_status": "L4OK"}, + {"pxname": "pool_1", "svname": "pool_1_2", "scur": 3, "status": "UP", "weight": 1, "check_status": "L4OK"}, + ]) + + results = list(parse_stat_csv(csv)) + + assert len(results) == 2 + assert results[0]["pxname"] == "pool_1" + assert results[0]["svname"] == "pool_1_1" + assert results[0]["scur"] == "5" + assert results[0]["status"] == "UP" + assert results[0]["weight"] == "1" + assert results[0]["check_status"] == "L4OK" + + def test_parse_empty_output(self): + """Parse empty output returns no results.""" + results = list(parse_stat_csv("")) + assert results == [] + + def test_parse_header_only(self): + """Parse output with only header returns no results.""" + csv = "# pxname,svname,qcur,qmax,scur,smax,..." + results = list(parse_stat_csv(csv)) + assert results == [] + + def test_skip_comment_lines(self): + """Comment lines are skipped.""" + csv = """# This is a comment +# Another comment +pool_1,pool_1_1,0,0,5,10,0,0,0,0,0,0,0,0,0,0,0,UP,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,L4OK,""" + + results = list(parse_stat_csv(csv)) + assert len(results) == 1 + assert results[0]["pxname"] == "pool_1" + + def test_skip_empty_lines(self): + """Empty lines are skipped.""" + csv = """ +pool_1,pool_1_1,0,0,5,10,0,0,0,0,0,0,0,0,0,0,0,UP,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,L4OK, + +pool_1,pool_1_2,0,0,3,10,0,0,0,0,0,0,0,0,0,0,0,UP,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,L4OK, +""" + + results = list(parse_stat_csv(csv)) + assert len(results) == 2 + + def test_parse_down_status(self, response_builder): + """Parse server with DOWN status.""" + csv = response_builder.stat_csv([ + {"pxname": "pool_1", "svname": "pool_1_1", "status": "DOWN", "check_status": "L4TOUT"}, + ]) + + results = list(parse_stat_csv(csv)) + + assert len(results) == 1 + assert results[0]["status"] == "DOWN" + assert results[0]["check_status"] == "L4TOUT" + + def test_parse_maint_status(self, response_builder): + """Parse server with MAINT status.""" + csv = response_builder.stat_csv([ + {"pxname": "pool_1", "svname": "pool_1_1", "status": "MAINT"}, + ]) + + results = list(parse_stat_csv(csv)) + + assert len(results) == 1 + assert results[0]["status"] == "MAINT" + + def test_parse_multiple_backends(self, response_builder): + """Parse output with multiple backends.""" + csv = response_builder.stat_csv([ + {"pxname": "pool_1", "svname": "pool_1_1", "status": "UP"}, + {"pxname": "pool_2", "svname": "pool_2_1", "status": "UP"}, + {"pxname": "pool_3", "svname": "pool_3_1", "status": "DOWN"}, + ]) + + results = list(parse_stat_csv(csv)) + + assert len(results) == 3 + assert results[0]["pxname"] == "pool_1" + assert results[1]["pxname"] == "pool_2" + assert results[2]["pxname"] == "pool_3" + + def test_parse_frontend_backend_rows(self): + """Frontend and BACKEND rows are included.""" + csv = """pool_1,FRONTEND,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,UP,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,, +pool_1,pool_1_1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,UP,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,L4OK, +pool_1,BACKEND,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,UP,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,,""" + + results = list(parse_stat_csv(csv)) + + # All rows with enough columns are returned + assert len(results) == 3 + svnames = [r["svname"] for r in results] + assert "FRONTEND" in svnames + assert "pool_1_1" in svnames + assert "BACKEND" in svnames + + def test_parse_insufficient_columns(self): + """Rows with insufficient columns are skipped.""" + csv = "pool_1,pool_1_1,0,0,5" # Only 5 columns, need more than 17 + + results = list(parse_stat_csv(csv)) + assert results == [] + + def test_generator_is_lazy(self, response_builder): + """Verify parse_stat_csv returns a generator (lazy evaluation).""" + csv = response_builder.stat_csv([ + {"pxname": "pool_1", "svname": "pool_1_1", "status": "UP"}, + ]) + + result = parse_stat_csv(csv) + + # Should return a generator, not a list + import types + assert isinstance(result, types.GeneratorType) diff --git a/tests/unit/test_validation.py b/tests/unit/test_validation.py new file mode 100644 index 0000000..86312c8 --- /dev/null +++ b/tests/unit/test_validation.py @@ -0,0 +1,275 @@ +"""Unit tests for validation module.""" + +import pytest + +from haproxy_mcp.validation import ( + validate_domain, + validate_ip, + validate_port, + validate_backend_name, + domain_to_backend, +) + + +class TestValidateDomain: + """Tests for validate_domain function.""" + + def test_valid_simple_domain(self): + """Valid simple domain.""" + assert validate_domain("example.com") is True + + def test_valid_subdomain(self): + """Valid subdomain.""" + assert validate_domain("api.example.com") is True + + def test_valid_deep_subdomain(self): + """Valid deep subdomain.""" + assert validate_domain("a.b.c.example.com") is True + + def test_valid_domain_with_numbers(self): + """Valid domain with numbers.""" + assert validate_domain("api123.example.com") is True + + def test_valid_domain_with_hyphen(self): + """Valid domain with hyphens.""" + assert validate_domain("my-api.example-site.com") is True + + def test_valid_single_char_labels(self): + """Valid domain with single character labels.""" + assert validate_domain("a.b.c") is True + + def test_valid_max_label_length(self): + """Valid domain with max label length (63 chars).""" + label = "a" * 63 + assert validate_domain(f"{label}.com") is True + + def test_invalid_empty_domain(self): + """Empty domain is invalid.""" + assert validate_domain("") is False + + def test_invalid_none_domain(self): + """None domain is invalid.""" + assert validate_domain(None) is False + + def test_invalid_starts_with_hyphen(self): + """Domain starting with hyphen is invalid.""" + assert validate_domain("-example.com") is False + + def test_invalid_ends_with_hyphen(self): + """Domain label ending with hyphen is invalid.""" + assert validate_domain("example-.com") is False + + def test_invalid_double_dot(self): + """Domain with double dot is invalid.""" + assert validate_domain("example..com") is False + + def test_invalid_starts_with_dot(self): + """Domain starting with dot is invalid.""" + assert validate_domain(".example.com") is False + + def test_invalid_special_characters(self): + """Domain with special characters is invalid.""" + assert validate_domain("example@.com") is False + assert validate_domain("example!.com") is False + assert validate_domain("example$.com") is False + + def test_invalid_underscore(self): + """Domain with underscore is invalid.""" + assert validate_domain("my_api.example.com") is False + + def test_invalid_too_long(self): + """Domain exceeding 253 chars is invalid.""" + long_domain = "a" * 254 + assert validate_domain(long_domain) is False + + def test_invalid_label_too_long(self): + """Domain label exceeding 63 chars is invalid.""" + label = "a" * 64 + assert validate_domain(f"{label}.com") is False + + def test_valid_numeric_domain(self): + """Domain with all numeric label is valid.""" + assert validate_domain("123.example.com") is True + + def test_invalid_only_dots(self): + """Domain with only dots is invalid.""" + assert validate_domain("...") is False + + +class TestValidateIP: + """Tests for validate_ip function.""" + + def test_valid_ipv4(self): + """Valid IPv4 address.""" + assert validate_ip("192.168.1.1") is True + assert validate_ip("10.0.0.1") is True + assert validate_ip("255.255.255.255") is True + assert validate_ip("0.0.0.0") is True + + def test_valid_ipv6(self): + """Valid IPv6 address.""" + assert validate_ip("::1") is True + assert validate_ip("2001:db8::1") is True + assert validate_ip("fe80::1") is True + assert validate_ip("2001:0db8:0000:0000:0000:0000:0000:0001") is True + + def test_invalid_empty_string(self): + """Empty string is invalid by default.""" + assert validate_ip("") is False + + def test_valid_empty_string_when_allowed(self): + """Empty string is valid when allow_empty=True.""" + assert validate_ip("", allow_empty=True) is True + + def test_invalid_none(self): + """None is invalid.""" + assert validate_ip(None) is False + + def test_invalid_hostname(self): + """Hostname is not a valid IP.""" + assert validate_ip("example.com") is False + + def test_invalid_ipv4_out_of_range(self): + """IPv4 with octets out of range is invalid.""" + assert validate_ip("256.1.1.1") is False + assert validate_ip("1.1.1.300") is False + + def test_invalid_ipv4_format(self): + """Invalid IPv4 format.""" + assert validate_ip("192.168.1") is False + assert validate_ip("192.168.1.1.1") is False + + def test_invalid_ipv6_format(self): + """Invalid IPv6 format.""" + assert validate_ip("2001:db8:::1") is False + assert validate_ip("gggg::1") is False + + def test_invalid_mixed_format(self): + """Mixed invalid format.""" + assert validate_ip("192.168.1.1:8080") is False + + +class TestValidatePort: + """Tests for validate_port function.""" + + def test_valid_port_min(self): + """Valid minimum port.""" + assert validate_port("1") is True + + def test_valid_port_max(self): + """Valid maximum port.""" + assert validate_port("65535") is True + + def test_valid_port_common(self): + """Valid common ports.""" + assert validate_port("80") is True + assert validate_port("443") is True + assert validate_port("8080") is True + + def test_invalid_port_zero(self): + """Port 0 is invalid.""" + assert validate_port("0") is False + + def test_invalid_port_negative(self): + """Negative port is invalid.""" + assert validate_port("-1") is False + + def test_invalid_port_too_high(self): + """Port above 65535 is invalid.""" + assert validate_port("65536") is False + + def test_invalid_port_empty(self): + """Empty port is invalid.""" + assert validate_port("") is False + + def test_invalid_port_none(self): + """None port is invalid.""" + assert validate_port(None) is False + + def test_invalid_port_not_numeric(self): + """Non-numeric port is invalid.""" + assert validate_port("abc") is False + assert validate_port("80a") is False + + def test_invalid_port_float(self): + """Float port is invalid.""" + assert validate_port("80.5") is False + + +class TestValidateBackendName: + """Tests for validate_backend_name function.""" + + def test_valid_pool_name(self): + """Valid pool backend names.""" + assert validate_backend_name("pool_1") is True + assert validate_backend_name("pool_100") is True + + def test_valid_alphanumeric(self): + """Valid alphanumeric names.""" + assert validate_backend_name("backend1") is True + assert validate_backend_name("my_backend") is True + assert validate_backend_name("my-backend") is True + + def test_valid_mixed(self): + """Valid mixed character names.""" + assert validate_backend_name("api_example_com_backend") is True + assert validate_backend_name("my-api-backend-1") is True + + def test_invalid_empty(self): + """Empty name is invalid.""" + assert validate_backend_name("") is False + + def test_invalid_none(self): + """None name is invalid.""" + assert validate_backend_name(None) is False + + def test_invalid_special_chars(self): + """Names with special characters are invalid.""" + assert validate_backend_name("backend@1") is False + assert validate_backend_name("my.backend") is False + assert validate_backend_name("my/backend") is False + assert validate_backend_name("my backend") is False + + def test_invalid_too_long(self): + """Name exceeding 255 chars is invalid.""" + long_name = "a" * 256 + assert validate_backend_name(long_name) is False + + def test_valid_max_length(self): + """Name at exactly 255 chars is valid.""" + max_name = "a" * 255 + assert validate_backend_name(max_name) is True + + +class TestDomainToBackend: + """Tests for domain_to_backend function.""" + + def test_simple_domain(self): + """Simple domain conversion.""" + assert domain_to_backend("example.com") == "example_com" + + def test_subdomain(self): + """Subdomain conversion.""" + assert domain_to_backend("api.example.com") == "api_example_com" + + def test_domain_with_hyphens(self): + """Domain with hyphens.""" + result = domain_to_backend("my-api.example.com") + assert result == "my_api_example_com" + + def test_complex_domain(self): + """Complex domain conversion.""" + result = domain_to_backend("a.b.c.example-site.com") + assert result == "a_b_c_example_site_com" + + def test_already_simple(self): + """Domain that's already mostly valid.""" + result = domain_to_backend("example123") + assert result == "example123" + + def test_invalid_result_raises(self): + """Invalid conversion result raises ValueError.""" + # This should never happen with real domains, but test the safeguard + with pytest.raises(ValueError): + # Mock a case where conversion would fail + domain_to_backend("") diff --git a/tests/unit/tools/__init__.py b/tests/unit/tools/__init__.py new file mode 100644 index 0000000..a0e7cab --- /dev/null +++ b/tests/unit/tools/__init__.py @@ -0,0 +1 @@ +"""Unit tests for HAProxy MCP tools.""" diff --git a/tests/unit/tools/test_certificates.py b/tests/unit/tools/test_certificates.py new file mode 100644 index 0000000..8c2272e --- /dev/null +++ b/tests/unit/tools/test_certificates.py @@ -0,0 +1,1198 @@ +"""Unit tests for certificate management tools.""" + +import json +import os +from unittest.mock import patch, MagicMock + +import pytest + + +class TestGetPemPaths: + """Tests for get_pem_paths function.""" + + def test_get_pem_paths(self): + """Get correct PEM paths.""" + from haproxy_mcp.tools.certificates import get_pem_paths + + host_path, container_path = get_pem_paths("example.com") + + assert host_path == "/opt/haproxy/certs/example.com.pem" + assert container_path == "/etc/haproxy/certs/example.com.pem" + + +class TestLoadCertToHaproxy: + """Tests for load_cert_to_haproxy function.""" + + def test_load_cert_file_not_found(self, tmp_path): + """Fail when PEM file doesn't exist.""" + from haproxy_mcp.tools.certificates import load_cert_to_haproxy + + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(tmp_path)): + success, msg = load_cert_to_haproxy("example.com") + + assert success is False + assert "not found" in msg.lower() + + def test_load_cert_new_cert(self, tmp_path, mock_socket_class, mock_select): + """Load new certificate into HAProxy.""" + # Create PEM file + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + pem_file = certs_dir / "example.com.pem" + pem_file.write_text("-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----") + + mock_sock = mock_socket_class(responses={ + "show ssl cert": "", # No existing cert + "new ssl cert": "", + "set ssl cert": "", + "commit ssl cert": "", + }) + + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import load_cert_to_haproxy + + success, msg = load_cert_to_haproxy("example.com") + + assert success is True + assert msg == "added" + + def test_load_cert_update_existing(self, tmp_path, mock_socket_class, mock_select): + """Update existing certificate in HAProxy.""" + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + pem_file = certs_dir / "example.com.pem" + pem_file.write_text("-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----") + + mock_sock = mock_socket_class(responses={ + "show ssl cert": "/etc/haproxy/certs/example.com.pem", + "set ssl cert": "", + "commit ssl cert": "", + }) + + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import load_cert_to_haproxy + + success, msg = load_cert_to_haproxy("example.com") + + assert success is True + assert msg == "updated" + + +class TestUnloadCertFromHaproxy: + """Tests for unload_cert_from_haproxy function.""" + + def test_unload_cert_not_loaded(self, mock_socket_class, mock_select): + """Unload cert that's not loaded.""" + mock_sock = mock_socket_class(responses={ + "show ssl cert": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import unload_cert_from_haproxy + + success, msg = unload_cert_from_haproxy("example.com") + + assert success is True + assert msg == "not loaded" + + def test_unload_cert_success(self, mock_socket_class, mock_select): + """Unload certificate successfully.""" + mock_sock = mock_socket_class(responses={ + "show ssl cert": "/etc/haproxy/certs/example.com.pem", + "del ssl cert": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import unload_cert_from_haproxy + + success, msg = unload_cert_from_haproxy("example.com") + + assert success is True + assert msg == "unloaded" + + +class TestRestoreCertificates: + """Tests for restore_certificates function.""" + + def test_restore_no_certificates(self, patch_config_paths): + """No certificates to restore.""" + from haproxy_mcp.tools.certificates import restore_certificates + + result = restore_certificates() + + assert result == 0 + + def test_restore_certificates_success(self, patch_config_paths, tmp_path, mock_socket_class, mock_select): + """Restore certificates successfully.""" + # Save config + with open(patch_config_paths["certs_file"], "w") as f: + json.dump({"domains": ["example.com"]}, f) + + # Create PEM + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + pem_file = certs_dir / "example.com.pem" + pem_file.write_text("cert content") + + mock_sock = mock_socket_class(responses={ + "show ssl cert": "", + "new ssl cert": "", + "set ssl cert": "", + "commit ssl cert": "", + }) + + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import restore_certificates + + result = restore_certificates() + + assert result == 1 + + +class TestHaproxyListCerts: + """Tests for haproxy_list_certs tool function.""" + + def test_list_certs_no_acme(self, mock_subprocess): + """acme.sh not found.""" + mock_subprocess.side_effect = FileNotFoundError() + + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_list_certs"]() + + assert "not found" in result.lower() + + def test_list_certs_empty(self, mock_subprocess): + """No certificates found.""" + mock_subprocess.return_value = MagicMock( + returncode=0, + stdout="Main_Domain KeyLength SAN_Domains CA Created Renew\n", + stderr="" + ) + + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_list_certs"]() + + assert "No certificates" in result + + def test_list_certs_success(self, mock_subprocess, mock_socket_class, mock_select, tmp_path): + """List certificates successfully.""" + mock_subprocess.return_value = MagicMock( + returncode=0, + stdout="Main_Domain KeyLength SAN_Domains CA Created Renew\nexample.com ec-256 *.example.com Google 2024-01-01T00:00:00Z 2024-03-01T00:00:00Z\n", + stderr="" + ) + + mock_sock = mock_socket_class(responses={ + "show ssl cert": "", + }) + + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(tmp_path)): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_list_certs"]() + + assert "example.com" in result + + +class TestHaproxyCertInfo: + """Tests for haproxy_cert_info tool function.""" + + def test_cert_info_invalid_domain(self): + """Reject invalid domain format.""" + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_cert_info"](domain="-invalid") + + assert "Error" in result + assert "Invalid domain" in result + + def test_cert_info_not_found(self, tmp_path): + """Certificate not found.""" + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(tmp_path)): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_cert_info"](domain="example.com") + + assert "Error" in result + assert "not found" in result.lower() + + def test_cert_info_success(self, tmp_path, mock_subprocess, mock_socket_class, mock_select): + """Get certificate info successfully.""" + # Create PEM file + pem_file = tmp_path / "example.com.pem" + pem_file.write_text("cert content") + + mock_subprocess.return_value = MagicMock( + returncode=0, + stdout="subject=CN = example.com\nissuer=CN = Google Trust Services\nnotBefore=Jan 1 00:00:00 2024 GMT\nnotAfter=Apr 1 00:00:00 2024 GMT", + stderr="" + ) + + mock_sock = mock_socket_class(responses={ + "show ssl cert": "/etc/haproxy/certs/example.com.pem", + }) + + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(tmp_path)): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_cert_info"](domain="example.com") + + assert "example.com" in result + assert "Loaded in HAProxy: Yes" in result + + +class TestHaproxyIssueCert: + """Tests for haproxy_issue_cert tool function.""" + + def test_issue_cert_invalid_domain(self): + """Reject invalid domain format.""" + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_issue_cert"](domain="-invalid", wildcard=True) + + assert "Error" in result + assert "Invalid domain" in result + + def test_issue_cert_no_cf_token(self, tmp_path): + """Fail when CF_Token is not set.""" + with patch.dict(os.environ, {}, clear=True): + with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)): + with patch("os.path.exists", return_value=False): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_issue_cert"](domain="example.com", wildcard=True) + + assert "CF_Token" in result + + def test_issue_cert_already_exists(self, tmp_path): + """Fail when certificate already exists.""" + cert_dir = tmp_path / "example.com_ecc" + cert_dir.mkdir() + + with patch.dict(os.environ, {"CF_Token": "test_token"}): + with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_issue_cert"](domain="example.com", wildcard=True) + + assert "already exists" in result + + +class TestHaproxyRenewCert: + """Tests for haproxy_renew_cert tool function.""" + + def test_renew_cert_invalid_domain(self): + """Reject invalid domain format.""" + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_renew_cert"](domain="-invalid", force=False) + + assert "Error" in result + assert "Invalid domain" in result + + def test_renew_cert_not_found(self, tmp_path): + """Fail when certificate doesn't exist.""" + with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_renew_cert"](domain="example.com", force=False) + + assert "Error" in result + assert "No certificate found" in result + + def test_renew_cert_not_due(self, tmp_path, mock_subprocess): + """Certificate not due for renewal.""" + cert_dir = tmp_path / "example.com_ecc" + cert_dir.mkdir() + + mock_subprocess.return_value = MagicMock( + returncode=0, + stdout="Skip, Next renewal time is: ...\n", + stderr="Not yet due for renewal" + ) + + with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_renew_cert"](domain="example.com", force=False) + + assert "not due for renewal" in result + + +class TestHaproxyRenewAllCerts: + """Tests for haproxy_renew_all_certs tool function.""" + + def test_renew_all_no_renewals(self, mock_subprocess): + """No certificates due for renewal.""" + mock_subprocess.return_value = MagicMock( + returncode=0, + stdout="Checking: example.com\nSkip, Next renewal time...", + stderr="" + ) + + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_renew_all_certs"]() + + assert "No certificates due" in result or "checked" in result + + +class TestHaproxyDeleteCert: + """Tests for haproxy_delete_cert tool function.""" + + def test_delete_cert_invalid_domain(self): + """Reject invalid domain format.""" + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_delete_cert"](domain="-invalid") + + assert "Error" in result + assert "Invalid domain" in result + + def test_delete_cert_not_found(self, tmp_path): + """Fail when certificate doesn't exist.""" + with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)): + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(tmp_path)): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_delete_cert"](domain="example.com") + + assert "Error" in result + assert "No certificate found" in result + + def test_delete_cert_success(self, tmp_path, mock_subprocess, mock_socket_class, mock_select, patch_config_paths): + """Delete certificate successfully.""" + # Create cert dir and PEM + cert_dir = tmp_path / "acme" / "example.com_ecc" + cert_dir.mkdir(parents=True) + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + pem_file = certs_dir / "example.com.pem" + pem_file.write_text("cert") + + mock_subprocess.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_sock = mock_socket_class(responses={ + "show ssl cert": "", + }) + + with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path / "acme")): + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_delete_cert"](domain="example.com") + + assert "Deleted" in result + + +class TestHaproxyLoadCert: + """Tests for haproxy_load_cert tool function.""" + + def test_load_cert_invalid_domain(self): + """Reject invalid domain format.""" + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_load_cert"](domain="-invalid") + + assert "Error" in result + assert "Invalid domain" in result + + def test_load_cert_not_found(self, tmp_path): + """Fail when PEM file doesn't exist.""" + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(tmp_path)): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_load_cert"](domain="example.com") + + assert "Error" in result + assert "not found" in result.lower() + + def test_load_cert_success(self, tmp_path, mock_socket_class, mock_select, patch_config_paths): + """Load certificate successfully.""" + pem_file = tmp_path / "example.com.pem" + pem_file.write_text("cert content") + + mock_sock = mock_socket_class(responses={ + "show ssl cert": "", + "new ssl cert": "", + "set ssl cert": "", + "commit ssl cert": "", + }) + + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(tmp_path)): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_load_cert"](domain="example.com") + + assert "loaded" in result.lower() + assert "example.com" in result + + +class TestHaproxyIssueCertTimeout: + """Tests for haproxy_issue_cert timeout scenarios.""" + + def test_issue_cert_acme_timeout(self, tmp_path, mock_subprocess): + """Handle acme.sh timeout during certificate issuance.""" + import subprocess + mock_subprocess.side_effect = subprocess.TimeoutExpired("acme.sh", 120) + + with patch.dict(os.environ, {"CF_Token": "test_token"}): + with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)): + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(tmp_path)): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_issue_cert"]( + domain="example.com", + wildcard=True + ) + + assert "timed out" in result.lower() + + def test_issue_cert_acme_failure(self, tmp_path, mock_subprocess): + """Handle acme.sh failure during certificate issuance.""" + mock_subprocess.return_value = MagicMock( + returncode=1, + stdout="", + stderr="DNS verification failed" + ) + + with patch.dict(os.environ, {"CF_Token": "test_token"}): + with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)): + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(tmp_path)): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_issue_cert"]( + domain="example.com", + wildcard=True + ) + + assert "Error" in result + assert "DNS verification failed" in result + + def test_issue_cert_success(self, tmp_path, mock_subprocess, mock_socket_class, mock_select, patch_config_paths): + """Successfully issue a certificate.""" + # Create certs directory + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + + mock_subprocess.return_value = MagicMock( + returncode=0, + stdout="Cert success", + stderr="" + ) + + # Create PEM file (simulating acme.sh reloadcmd) + pem_file = certs_dir / "example.com.pem" + + def create_pem_file(*args, **kwargs): + pem_file.write_text("cert content") + return MagicMock(returncode=0, stdout="", stderr="") + + mock_subprocess.side_effect = create_pem_file + + mock_sock = mock_socket_class(responses={ + "show ssl cert": "", + "new ssl cert": "", + "set ssl cert": "", + "commit ssl cert": "", + }) + + with patch.dict(os.environ, {"CF_Token": "test_token"}): + with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)): + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_issue_cert"]( + domain="example.com", + wildcard=True + ) + + assert "issued" in result.lower() or "loaded" in result.lower() + + +class TestHaproxyRenewCertTimeout: + """Tests for haproxy_renew_cert timeout scenarios.""" + + def test_renew_cert_timeout(self, tmp_path, mock_subprocess): + """Handle acme.sh timeout during certificate renewal.""" + import subprocess + cert_dir = tmp_path / "example.com_ecc" + cert_dir.mkdir() + + mock_subprocess.side_effect = subprocess.TimeoutExpired("acme.sh", 120) + + with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_renew_cert"]( + domain="example.com", + force=True + ) + + assert "timed out" in result.lower() + + def test_renew_cert_success(self, tmp_path, mock_subprocess, mock_socket_class, mock_select, patch_config_paths): + """Successfully renew a certificate.""" + cert_dir = tmp_path / "example.com_ecc" + cert_dir.mkdir() + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + pem_file = certs_dir / "example.com.pem" + pem_file.write_text("cert content") + + mock_subprocess.return_value = MagicMock( + returncode=0, + stdout="Cert success", + stderr="" + ) + + mock_sock = mock_socket_class(responses={ + "show ssl cert": "", + "new ssl cert": "", + "set ssl cert": "", + "commit ssl cert": "", + }) + + with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)): + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_renew_cert"]( + domain="example.com", + force=True + ) + + assert "renewed" in result.lower() + + +class TestHaproxyRenewAllCertsMultiple: + """Tests for haproxy_renew_all_certs with multiple certificates.""" + + def test_renew_all_certs_multiple_renewals(self, mock_subprocess, mock_socket_class, mock_select, patch_config_paths, tmp_path): + """Renew multiple certificates successfully.""" + # Write config with multiple domains + with open(patch_config_paths["certs_file"], "w") as f: + json.dump({"domains": ["example.com", "example.org"]}, f) + + # Create PEM files + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + (certs_dir / "example.com.pem").write_text("cert1") + (certs_dir / "example.org.pem").write_text("cert2") + + mock_subprocess.return_value = MagicMock( + returncode=0, + stdout="Cert success\nCert success", # Two successful renewals + stderr="" + ) + + mock_sock = mock_socket_class(responses={ + "show ssl cert": "", + "new ssl cert": "", + "set ssl cert": "", + "commit ssl cert": "", + }) + + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_renew_all_certs"]() + + assert "Renewed 2" in result + assert "reloaded" in result.lower() + + def test_renew_all_certs_timeout(self, mock_subprocess): + """Handle timeout during renewal cron.""" + import subprocess + mock_subprocess.side_effect = subprocess.TimeoutExpired("acme.sh", 360) + + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_renew_all_certs"]() + + assert "timed out" in result.lower() + + def test_renew_all_certs_error(self, mock_subprocess): + """Handle error during renewal cron.""" + mock_subprocess.return_value = MagicMock( + returncode=1, + stdout="", + stderr="ACME server error" + ) + + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_renew_all_certs"]() + + assert "Error" in result or "ACME server error" in result + + +class TestHaproxyDeleteCertPartialFailure: + """Tests for haproxy_delete_cert partial failure scenarios.""" + + def test_delete_cert_haproxy_unload_failure(self, tmp_path, mock_subprocess, mock_socket_class, mock_select, patch_config_paths): + """Handle HAProxy unload failure during certificate deletion.""" + # Create cert dir and PEM + cert_dir = tmp_path / "acme" / "example.com_ecc" + cert_dir.mkdir(parents=True) + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + pem_file = certs_dir / "example.com.pem" + pem_file.write_text("cert") + + # Mock acme.sh removal success + mock_subprocess.return_value = MagicMock(returncode=0, stdout="", stderr="") + + # Mock HAProxy to fail on unload + mock_sock = mock_socket_class(responses={ + "show ssl cert": "/etc/haproxy/certs/example.com.pem", + "del ssl cert": "error: unable to delete certificate", + }) + + with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path / "acme")): + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_delete_cert"](domain="example.com") + + # Should still delete acme.sh and PEM even if HAProxy unload fails + assert "Deleted" in result or "acme.sh" in result + + def test_delete_cert_acme_removal_failure(self, tmp_path, mock_subprocess, mock_socket_class, mock_select, patch_config_paths): + """Handle acme.sh removal failure during certificate deletion.""" + # Create cert dir and PEM + cert_dir = tmp_path / "acme" / "example.com_ecc" + cert_dir.mkdir(parents=True) + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + pem_file = certs_dir / "example.com.pem" + pem_file.write_text("cert") + + # Mock acme.sh removal failure + mock_subprocess.return_value = MagicMock( + returncode=1, + stdout="", + stderr="Failed to remove certificate" + ) + + mock_sock = mock_socket_class(responses={ + "show ssl cert": "", # Not loaded + }) + + with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path / "acme")): + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_delete_cert"](domain="example.com") + + # Should report partial success (PEM deleted) and error (acme.sh failed) + assert "Deleted" in result or "PEM" in result + assert "Errors" in result or "acme.sh" in result + + def test_delete_cert_pem_removal_failure(self, tmp_path, mock_subprocess, mock_socket_class, mock_select, patch_config_paths): + """Handle PEM file removal failure during certificate deletion.""" + # Create cert dir but make PEM read-only + cert_dir = tmp_path / "acme" / "example.com_ecc" + cert_dir.mkdir(parents=True) + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + pem_file = certs_dir / "example.com.pem" + pem_file.write_text("cert") + + # Mock acme.sh removal success + mock_subprocess.return_value = MagicMock(returncode=0, stdout="", stderr="") + + mock_sock = mock_socket_class(responses={ + "show ssl cert": "", # Not loaded + }) + + # Mock os.remove to fail + def mock_remove(path): + if "example.com.pem" in str(path): + raise PermissionError("Permission denied") + raise FileNotFoundError() + + with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path / "acme")): + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): + with patch("socket.socket", return_value=mock_sock): + with patch("os.remove", side_effect=mock_remove): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_delete_cert"](domain="example.com") + + # Should report partial success (acme.sh deleted) and error (PEM failed) + assert "Deleted" in result + assert "Errors" in result or "Permission" in result + + +class TestLoadCertToHaproxyError: + """Tests for load_cert_to_haproxy error handling.""" + + def test_load_cert_exception_handling(self, tmp_path): + """Handle exception during certificate loading.""" + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + pem_file = certs_dir / "example.com.pem" + pem_file.write_text("cert content") + + # Mock haproxy_cmd to raise exception + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): + with patch("haproxy_mcp.tools.certificates.haproxy_cmd", side_effect=Exception("Connection failed")): + from haproxy_mcp.tools.certificates import load_cert_to_haproxy + + success, msg = load_cert_to_haproxy("example.com") + + # load_cert_to_haproxy catches exceptions and returns False, msg + assert success is False + assert "Connection failed" in msg + + +class TestUnloadCertFromHaproxyError: + """Tests for unload_cert_from_haproxy error handling.""" + + def test_unload_cert_haproxy_error(self, mock_socket_class, mock_select): + """Handle HAProxy command error during certificate unloading.""" + # Mock HAProxy to return error on del ssl cert + mock_sock = mock_socket_class(responses={ + "show ssl cert": "/etc/haproxy/certs/example.com.pem", + "del ssl cert": "error: certificate in use", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import unload_cert_from_haproxy + + success, msg = unload_cert_from_haproxy("example.com") + + # unload_cert_from_haproxy catches exceptions and returns False, msg + # or may return True "unloaded" since the mock doesn't raise exception + assert success is True or "error" in msg.lower() + + +class TestRestoreCertificatesFailure: + """Tests for restore_certificates failure scenarios.""" + + def test_restore_certificates_partial_failure(self, patch_config_paths, tmp_path, mock_socket_class, mock_select): + """Handle partial failure when restoring certificates.""" + # Save config with multiple domains + with open(patch_config_paths["certs_file"], "w") as f: + json.dump({"domains": ["example.com", "missing.com"]}, f) + + # Create only one PEM file + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + (certs_dir / "example.com.pem").write_text("cert content") + + mock_sock = mock_socket_class(responses={ + "show ssl cert": "", + "new ssl cert": "", + "set ssl cert": "", + "commit ssl cert": "", + }) + + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.certificates import restore_certificates + + result = restore_certificates() + + # Should restore 1 (example.com exists), skip 1 (missing.com doesn't exist) + assert result == 1 + + +class TestHaproxyListCertsTimeout: + """Tests for haproxy_list_certs timeout scenarios.""" + + def test_list_certs_timeout(self, mock_subprocess): + """Handle timeout during certificate listing.""" + import subprocess + mock_subprocess.side_effect = subprocess.TimeoutExpired("acme.sh", 30) + + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_list_certs"]() + + assert "timed out" in result.lower() + + +class TestHaproxyCertInfoTimeout: + """Tests for haproxy_cert_info timeout scenarios.""" + + def test_cert_info_timeout(self, tmp_path, mock_subprocess): + """Handle timeout during certificate info retrieval.""" + import subprocess + pem_file = tmp_path / "example.com.pem" + pem_file.write_text("cert content") + + mock_subprocess.side_effect = subprocess.TimeoutExpired("openssl", 30) + + with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(tmp_path)): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_certificate_tools(mcp) + + result = registered_tools["haproxy_cert_info"](domain="example.com") + + assert "timed out" in result.lower() diff --git a/tests/unit/tools/test_configuration.py b/tests/unit/tools/test_configuration.py new file mode 100644 index 0000000..f2b4dc5 --- /dev/null +++ b/tests/unit/tools/test_configuration.py @@ -0,0 +1,749 @@ +"""Unit tests for configuration management tools.""" + +import json +from unittest.mock import patch, MagicMock + +import pytest + + +class TestRestoreServersFromConfig: + """Tests for restore_servers_from_config function.""" + + def test_restore_empty_config(self, patch_config_paths): + """No servers to restore when config is empty.""" + from haproxy_mcp.tools.configuration import restore_servers_from_config + + result = restore_servers_from_config() + + assert result == 0 + + def test_restore_servers_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config): + """Restore servers successfully.""" + # Write config and map + with open(patch_config_paths["servers_file"], "w") as f: + json.dump(sample_servers_config, f) + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + f.write("api.example.com pool_2\n") + + mock_sock = mock_socket_class(responses={ + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.configuration import restore_servers_from_config + + result = restore_servers_from_config() + + # example.com has 2 servers, api.example.com has 1 + assert result == 3 + + def test_restore_servers_skip_missing_domain(self, mock_socket_class, mock_select, patch_config_paths): + """Skip domains not in map file.""" + config = {"unknown.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}} + with open(patch_config_paths["servers_file"], "w") as f: + json.dump(config, f) + + mock_sock = mock_socket_class(responses={"set server": ""}) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.configuration import restore_servers_from_config + + result = restore_servers_from_config() + + assert result == 0 + + def test_restore_servers_skip_empty_ip(self, mock_socket_class, mock_select, patch_config_paths): + """Skip servers with empty IP.""" + config = {"example.com": {"1": {"ip": "", "http_port": 80}}} + with open(patch_config_paths["servers_file"], "w") as f: + json.dump(config, f) + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={"set server": ""}) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.configuration import restore_servers_from_config + + result = restore_servers_from_config() + + assert result == 0 + + +class TestStartupRestore: + """Tests for startup_restore function.""" + + def test_startup_restore_haproxy_not_ready(self, mock_select): + """Skip restore if HAProxy is not ready.""" + call_count = 0 + + def raise_error(*args, **kwargs): + nonlocal call_count + call_count += 1 + raise ConnectionRefusedError() + + with patch("socket.socket", side_effect=raise_error): + with patch("haproxy_mcp.tools.configuration.STARTUP_RETRY_COUNT", 2): + from haproxy_mcp.tools.configuration import startup_restore + + startup_restore() + + # Should have tried multiple times + assert call_count >= 2 + + def test_startup_restore_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Successfully restore servers and certificates on startup.""" + mock_sock = mock_socket_class(responses={ + "show info": response_builder.info(), + "set server": "", + "show ssl cert": "", + }) + + with patch("socket.socket", return_value=mock_sock): + with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", return_value=0): + with patch("haproxy_mcp.tools.certificates.restore_certificates", return_value=0): + from haproxy_mcp.tools.configuration import startup_restore + + startup_restore() + + # No assertions needed - just verify no exceptions + + +class TestHaproxyReload: + """Tests for haproxy_reload tool function.""" + + def test_reload_success(self, mock_socket_class, mock_select, mock_subprocess, response_builder): + """Reload HAProxy successfully.""" + mock_subprocess.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_sock = mock_socket_class(responses={ + "show info": response_builder.info(), + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", return_value=5): + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_reload"]() + + assert "reloaded successfully" in result + assert "5 servers restored" in result + + def test_reload_validation_failure(self, mock_subprocess): + """Reload fails on config validation error.""" + mock_subprocess.return_value = MagicMock( + returncode=1, + stdout="", + stderr="Configuration error" + ) + + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_reload"]() + + assert "validation failed" in result.lower() or "Configuration error" in result + + +class TestHaproxyCheckConfig: + """Tests for haproxy_check_config tool function.""" + + def test_check_config_valid(self, mock_subprocess): + """Configuration is valid.""" + mock_subprocess.return_value = MagicMock(returncode=0, stdout="", stderr="") + + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_check_config"]() + + assert "valid" in result.lower() + + def test_check_config_invalid(self, mock_subprocess): + """Configuration has errors.""" + mock_subprocess.return_value = MagicMock( + returncode=1, + stdout="", + stderr="[ALERT] Syntax error" + ) + + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_check_config"]() + + assert "error" in result.lower() + assert "Syntax error" in result + + def test_check_config_timeout(self, mock_subprocess): + """Configuration check times out.""" + import subprocess + mock_subprocess.side_effect = subprocess.TimeoutExpired("podman", 30) + + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_check_config"]() + + assert "timed out" in result.lower() + + def test_check_config_podman_not_found(self, mock_subprocess): + """Podman not found.""" + mock_subprocess.side_effect = FileNotFoundError() + + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_check_config"]() + + assert "podman" in result.lower() + assert "not found" in result.lower() + + +class TestHaproxySaveState: + """Tests for haproxy_save_state tool function.""" + + def test_save_state_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Save state successfully.""" + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([ + {"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80}, + ]), + }) + + with patch("haproxy_mcp.tools.configuration.STATE_FILE", patch_config_paths["state_file"]): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_save_state"]() + + assert "saved" in result.lower() + + def test_save_state_haproxy_error(self, mock_select): + """Handle HAProxy connection error.""" + def raise_error(*args, **kwargs): + raise ConnectionRefusedError() + + with patch("socket.socket", side_effect=raise_error): + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_save_state"]() + + assert "Error" in result + + +class TestHaproxyRestoreState: + """Tests for haproxy_restore_state tool function.""" + + def test_restore_state_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config): + """Restore state successfully.""" + with open(patch_config_paths["servers_file"], "w") as f: + json.dump(sample_servers_config, f) + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + f.write("api.example.com pool_2\n") + + mock_sock = mock_socket_class(responses={"set server": ""}) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_restore_state"]() + + assert "restored" in result.lower() + assert "3 servers" in result + + def test_restore_state_no_servers(self, patch_config_paths): + """No servers to restore.""" + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_restore_state"]() + + assert "No servers to restore" in result + + +class TestRestoreServersFromConfigBatchFailure: + """Tests for restore_servers_from_config batch failure and fallback.""" + + def test_restore_servers_batch_failure_fallback(self, mock_socket_class, mock_select, patch_config_paths): + """Fall back to individual commands when batch fails.""" + # Create config with servers + config = { + "example.com": { + "1": {"ip": "10.0.0.1", "http_port": 80}, + "2": {"ip": "10.0.0.2", "http_port": 80}, + } + } + with open(patch_config_paths["servers_file"], "w") as f: + json.dump(config, f) + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + # Track call count to simulate batch failure then individual success + call_count = [0] + + class BatchFailMockSocket: + def __init__(self): + self.sent_commands = [] + self._response_buffer = b"" + self._closed = False + + def connect(self, address): + pass + + def settimeout(self, timeout): + pass + + def setblocking(self, blocking): + pass + + def sendall(self, data): + command = data.decode().strip() + self.sent_commands.append(command) + call_count[0] += 1 + # First batch call fails (contains multiple commands) + if call_count[0] == 1 and "\n" in data.decode(): + self._response_buffer = b"error: batch command failed" + else: + self._response_buffer = b"" + + def shutdown(self, how): + pass + + def recv(self, bufsize): + if self._response_buffer: + data = self._response_buffer[:bufsize] + self._response_buffer = self._response_buffer[bufsize:] + return data + return b"" + + def close(self): + self._closed = True + + def fileno(self): + return 999 + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + mock_sock = BatchFailMockSocket() + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.configuration import restore_servers_from_config + from haproxy_mcp.exceptions import HaproxyError + + # Mock batch to raise error + with patch("haproxy_mcp.tools.configuration.haproxy_cmd_batch") as mock_batch: + # First call (batch) fails, subsequent calls succeed + mock_batch.side_effect = [ + HaproxyError("Batch failed"), # Initial batch fails + None, # Individual server 1 succeeds + None, # Individual server 2 succeeds + ] + + result = restore_servers_from_config() + + # Should have restored servers via individual commands + assert result == 2 + + def test_restore_servers_invalid_slot(self, mock_socket_class, mock_select, patch_config_paths): + """Skip servers with invalid slot number.""" + config = { + "example.com": { + "invalid": {"ip": "10.0.0.1", "http_port": 80}, # Invalid slot + "1": {"ip": "10.0.0.2", "http_port": 80}, # Valid slot + } + } + with open(patch_config_paths["servers_file"], "w") as f: + json.dump(config, f) + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={"set server": ""}) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.configuration import restore_servers_from_config + + result = restore_servers_from_config() + + # Should only restore the valid server + assert result == 1 + + def test_restore_servers_invalid_port(self, mock_socket_class, mock_select, patch_config_paths, caplog): + """Skip servers with invalid port.""" + import logging + config = { + "example.com": { + "1": {"ip": "10.0.0.1", "http_port": "invalid"}, # Invalid port + "2": {"ip": "10.0.0.2", "http_port": 80}, # Valid port + } + } + with open(patch_config_paths["servers_file"], "w") as f: + json.dump(config, f) + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={"set server": ""}) + + with patch("socket.socket", return_value=mock_sock): + with caplog.at_level(logging.WARNING, logger="haproxy_mcp"): + from haproxy_mcp.tools.configuration import restore_servers_from_config + + result = restore_servers_from_config() + + # Should only restore the valid server + assert result == 1 + + +class TestStartupRestoreFailures: + """Tests for startup_restore failure scenarios.""" + + def test_startup_restore_haproxy_timeout(self, mock_select): + """Skip restore if HAProxy doesn't become ready in time.""" + from haproxy_mcp.exceptions import HaproxyError + + # Mock haproxy_cmd to always fail + with patch("haproxy_mcp.tools.configuration.haproxy_cmd", side_effect=HaproxyError("Connection refused")): + with patch("haproxy_mcp.tools.configuration.STARTUP_RETRY_COUNT", 2): + with patch("time.sleep", return_value=None): + from haproxy_mcp.tools.configuration import startup_restore + + # Should not raise, just log warning + startup_restore() + + def test_startup_restore_server_restore_failure(self, mock_socket_class, mock_select, patch_config_paths, response_builder, caplog): + """Handle server restore failure during startup.""" + import logging + mock_sock = mock_socket_class(responses={ + "show info": response_builder.info(), + }) + + with patch("socket.socket", return_value=mock_sock): + with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=OSError("Disk error")): + with patch("haproxy_mcp.tools.certificates.restore_certificates", return_value=0): + with caplog.at_level(logging.WARNING, logger="haproxy_mcp"): + from haproxy_mcp.tools.configuration import startup_restore + + startup_restore() + + # Should have logged the failure + assert any("Failed to restore servers" in record.message for record in caplog.records) + + def test_startup_restore_certificate_failure(self, mock_socket_class, mock_select, patch_config_paths, response_builder, caplog): + """Handle certificate restore failure during startup.""" + import logging + mock_sock = mock_socket_class(responses={ + "show info": response_builder.info(), + }) + + with patch("socket.socket", return_value=mock_sock): + with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", return_value=0): + with patch("haproxy_mcp.tools.certificates.restore_certificates", side_effect=Exception("Certificate error")): + with caplog.at_level(logging.WARNING, logger="haproxy_mcp"): + from haproxy_mcp.tools.configuration import startup_restore + + startup_restore() + + # Should have logged the failure + assert any("Failed to restore certificates" in record.message for record in caplog.records) + + +class TestHaproxyReloadFailures: + """Tests for haproxy_reload failure scenarios.""" + + def test_reload_haproxy_not_responding_after_reload(self, mock_subprocess, response_builder): + """Handle HAProxy not responding after reload.""" + from haproxy_mcp.exceptions import HaproxyError + + mock_subprocess.return_value = MagicMock(returncode=0, stdout="", stderr="") + + # Mock haproxy_cmd to fail after reload + with patch("haproxy_mcp.haproxy_client.reload_haproxy", return_value=(True, "Reloaded")): + with patch("haproxy_mcp.tools.configuration.haproxy_cmd", side_effect=HaproxyError("Not responding")): + with patch("haproxy_mcp.tools.configuration.STARTUP_RETRY_COUNT", 2): + with patch("time.sleep", return_value=None): + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_reload"]() + + assert "not responding" in result.lower() + + def test_reload_server_restore_failure(self, mock_subprocess, mock_socket_class, mock_select, response_builder): + """Handle server restore failure after reload.""" + mock_subprocess.return_value = MagicMock(returncode=0, stdout="", stderr="") + + mock_sock = mock_socket_class(responses={ + "show info": response_builder.info(), + }) + + with patch("socket.socket", return_value=mock_sock): + with patch("haproxy_mcp.haproxy_client.reload_haproxy", return_value=(True, "Reloaded")): + with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=Exception("Restore failed")): + with patch("time.sleep", return_value=None): + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_reload"]() + + assert "reloaded" in result.lower() + assert "failed" in result.lower() + + +class TestHaproxySaveStateFailures: + """Tests for haproxy_save_state failure scenarios.""" + + def test_save_state_io_error(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Handle IO error when saving state.""" + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([ + {"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80}, + ]), + }) + + with patch("haproxy_mcp.tools.configuration.STATE_FILE", patch_config_paths["state_file"]): + with patch("socket.socket", return_value=mock_sock): + with patch("haproxy_mcp.tools.configuration.atomic_write_file", side_effect=IOError("Disk full")): + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_save_state"]() + + assert "Error" in result + + +class TestHaproxyRestoreStateFailures: + """Tests for haproxy_restore_state failure scenarios.""" + + def test_restore_state_haproxy_error(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config): + """Handle HAProxy error when restoring state.""" + from haproxy_mcp.exceptions import HaproxyError + + with open(patch_config_paths["servers_file"], "w") as f: + json.dump(sample_servers_config, f) + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + f.write("api.example.com pool_2\n") + + with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=HaproxyError("Connection refused")): + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_restore_state"]() + + assert "Error" in result + + def test_restore_state_os_error(self, patch_config_paths): + """Handle OS error when restoring state.""" + with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=OSError("File not found")): + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_restore_state"]() + + assert "Error" in result + + def test_restore_state_value_error(self, patch_config_paths): + """Handle ValueError when restoring state.""" + with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=ValueError("Invalid config")): + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_restore_state"]() + + assert "Error" in result + + +class TestHaproxyCheckConfigOSError: + """Tests for haproxy_check_config OS error handling.""" + + def test_check_config_os_error(self, mock_subprocess): + """Handle OS error during config check.""" + mock_subprocess.side_effect = OSError("Permission denied") + + from haproxy_mcp.tools.configuration import register_config_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_config_tools(mcp) + + result = registered_tools["haproxy_check_config"]() + + assert "Error" in result + assert "OS error" in result diff --git a/tests/unit/tools/test_domains.py b/tests/unit/tools/test_domains.py new file mode 100644 index 0000000..6c03f07 --- /dev/null +++ b/tests/unit/tools/test_domains.py @@ -0,0 +1,476 @@ +"""Unit tests for domain management tools.""" + +import json +from unittest.mock import patch, MagicMock + +import pytest + +from haproxy_mcp.exceptions import HaproxyError + + +class TestHaproxyListDomains: + """Tests for haproxy_list_domains tool function.""" + + def test_list_empty_domains(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """List domains when none configured.""" + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([]), + }) + + with patch("socket.socket", return_value=mock_sock): + # Import here to get patched config + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_list_domains"](include_wildcards=False) + + assert result == "No domains configured" + + def test_list_domains_with_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """List domains with configured servers.""" + # Write map file + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([ + {"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80}, + ]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_list_domains"](include_wildcards=False) + + assert "example.com" in result + assert "pool_1" in result + assert "10.0.0.1" in result + + def test_list_domains_exclude_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """List domains excluding wildcards by default.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + with open(patch_config_paths["wildcards_file"], "w") as f: + f.write(".example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_list_domains"](include_wildcards=False) + + assert "example.com" in result + assert ".example.com" not in result + + def test_list_domains_include_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """List domains including wildcards when requested.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + with open(patch_config_paths["wildcards_file"], "w") as f: + f.write(".example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_list_domains"](include_wildcards=True) + + assert "example.com" in result + assert ".example.com" in result + + +class TestHaproxyAddDomain: + """Tests for haproxy_add_domain tool function.""" + + def test_add_domain_invalid_format(self, patch_config_paths): + """Reject invalid domain format.""" + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_add_domain"]( + domain="-invalid.com", + ip="", + http_port=80 + ) + + assert "Error" in result + assert "Invalid domain" in result + + def test_add_domain_invalid_ip(self, patch_config_paths): + """Reject invalid IP address.""" + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_add_domain"]( + domain="example.com", + ip="not-an-ip", + http_port=80 + ) + + assert "Error" in result + assert "Invalid IP" in result + + def test_add_domain_invalid_port(self, patch_config_paths): + """Reject invalid port.""" + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_add_domain"]( + domain="example.com", + ip="10.0.0.1", + http_port=70000 + ) + + assert "Error" in result + assert "Port" in result + + def test_add_domain_starts_with_dot(self, patch_config_paths): + """Reject domain starting with dot (wildcard).""" + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_add_domain"]( + domain=".example.com", + ip="", + http_port=80 + ) + + assert "Error" in result + assert "cannot start with '.'" in result + + def test_add_domain_already_exists(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Reject adding domain that already exists.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_add_domain"]( + domain="example.com", + ip="", + http_port=80 + ) + + assert "Error" in result + assert "already exists" in result + + def test_add_domain_success_without_ip(self, mock_socket_class, mock_select, patch_config_paths, response_builder, mock_subprocess): + """Successfully add domain without IP.""" + mock_sock = mock_socket_class(responses={ + "add map": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_add_domain"]( + domain="newdomain.com", + ip="", + http_port=80 + ) + + assert "newdomain.com" in result + assert "pool_1" in result + assert "no servers configured" in result + + def test_add_domain_success_with_ip(self, mock_socket_class, mock_select, patch_config_paths, response_builder, mock_subprocess): + """Successfully add domain with IP.""" + mock_sock = mock_socket_class(responses={ + "add map": "", + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_add_domain"]( + domain="newdomain.com", + ip="10.0.0.1", + http_port=8080 + ) + + assert "newdomain.com" in result + assert "pool_1" in result + assert "10.0.0.1:8080" in result + + +class TestHaproxyRemoveDomain: + """Tests for haproxy_remove_domain tool function.""" + + def test_remove_domain_invalid_format(self, patch_config_paths): + """Reject invalid domain format.""" + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_remove_domain"](domain="-invalid.com") + + assert "Error" in result + assert "Invalid domain" in result + + def test_remove_domain_not_found(self, patch_config_paths): + """Reject removing domain that doesn't exist.""" + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_remove_domain"](domain="nonexistent.com") + + assert "Error" in result + assert "not found" in result + + def test_remove_legacy_domain_rejected(self, patch_config_paths): + """Reject removing legacy (non-pool) domain.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com legacy_backend\n") + + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_remove_domain"](domain="example.com") + + assert "Error" in result + assert "legacy" in result.lower() + + def test_remove_domain_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Successfully remove domain.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + with open(patch_config_paths["wildcards_file"], "w") as f: + f.write(".example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "del map": "", + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.domains import register_domain_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_domain_tools(mcp) + + result = registered_tools["haproxy_remove_domain"](domain="example.com") + + assert "example.com" in result + assert "removed" in result.lower() + + +class TestCheckCertificateCoverage: + """Tests for check_certificate_coverage function.""" + + def test_no_cert_directory(self, tmp_path): + """No certificate coverage when directory doesn't exist.""" + from haproxy_mcp.tools.domains import check_certificate_coverage + + with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(tmp_path / "nonexistent")): + covered, info = check_certificate_coverage("example.com") + + assert covered is False + assert "not found" in info.lower() + + def test_exact_cert_match(self, tmp_path): + """Exact certificate match.""" + from haproxy_mcp.tools.domains import check_certificate_coverage + + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + (certs_dir / "example.com.pem").write_text("cert content") + + with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(certs_dir)): + covered, info = check_certificate_coverage("example.com") + + assert covered is True + assert info == "example.com" + + def test_wildcard_cert_coverage(self, tmp_path, mock_subprocess): + """Wildcard certificate covers subdomain.""" + from haproxy_mcp.tools.domains import check_certificate_coverage + + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + (certs_dir / "example.com.pem").write_text("cert content") + + # Mock openssl output showing wildcard SAN + mock_subprocess.return_value = MagicMock( + returncode=0, + stdout="X509v3 Subject Alternative Name:\n DNS:example.com, DNS:*.example.com" + ) + + with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(certs_dir)): + covered, info = check_certificate_coverage("api.example.com") + + assert covered is True + assert "wildcard" in info + + def test_no_matching_cert(self, tmp_path): + """No matching certificate.""" + from haproxy_mcp.tools.domains import check_certificate_coverage + + certs_dir = tmp_path / "certs" + certs_dir.mkdir() + + with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(certs_dir)): + covered, info = check_certificate_coverage("example.com") + + assert covered is False + assert "No matching" in info diff --git a/tests/unit/tools/test_health.py b/tests/unit/tools/test_health.py new file mode 100644 index 0000000..28ca715 --- /dev/null +++ b/tests/unit/tools/test_health.py @@ -0,0 +1,433 @@ +"""Unit tests for health check tools.""" + +import json +from unittest.mock import patch, MagicMock + +import pytest + +from haproxy_mcp.exceptions import HaproxyError + + +class TestHaproxyHealth: + """Tests for haproxy_health tool function.""" + + def test_health_all_ok(self, mock_socket_class, mock_select, patch_config_paths, response_builder, mock_subprocess): + """Health check returns healthy when all components are OK.""" + mock_sock = mock_socket_class(responses={ + "show info": response_builder.info(version="3.3.2", uptime=3600), + }) + + mock_subprocess.return_value = MagicMock( + returncode=0, + stdout="running" + ) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.health import register_health_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_health_tools(mcp) + + result_str = registered_tools["haproxy_health"]() + result = json.loads(result_str) + + assert result["status"] == "healthy" + assert result["components"]["mcp"]["status"] == "ok" + assert result["components"]["haproxy"]["status"] == "ok" + assert result["components"]["haproxy"]["version"] == "3.3.2" + + def test_health_haproxy_error(self, mock_socket_class, mock_select, patch_config_paths, mock_subprocess): + """Health check returns degraded when HAProxy is unreachable.""" + + def raise_error(*args, **kwargs): + raise ConnectionRefusedError() + + with patch("socket.socket", side_effect=raise_error): + from haproxy_mcp.tools.health import register_health_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_health_tools(mcp) + + result_str = registered_tools["haproxy_health"]() + result = json.loads(result_str) + + assert result["status"] == "degraded" + assert result["components"]["haproxy"]["status"] == "error" + + def test_health_missing_config_files(self, mock_socket_class, mock_select, tmp_path, response_builder, mock_subprocess): + """Health check returns degraded when config files are missing.""" + mock_sock = mock_socket_class(responses={ + "show info": response_builder.info(), + }) + + mock_subprocess.return_value = MagicMock(returncode=0, stdout="running") + + # Use paths that don't exist + with patch("haproxy_mcp.tools.health.MAP_FILE", str(tmp_path / "nonexistent.map")): + with patch("haproxy_mcp.tools.health.SERVERS_FILE", str(tmp_path / "nonexistent.json")): + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.health import register_health_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_health_tools(mcp) + + result_str = registered_tools["haproxy_health"]() + result = json.loads(result_str) + + assert result["status"] == "degraded" + assert result["components"]["config_files"]["status"] == "warning" + + def test_health_container_not_running(self, mock_socket_class, mock_select, patch_config_paths, response_builder, mock_subprocess): + """Health check returns unhealthy when container is not running.""" + mock_sock = mock_socket_class(responses={ + "show info": response_builder.info(), + }) + + mock_subprocess.return_value = MagicMock( + returncode=1, + stdout="", + stderr="No such container" + ) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.health import register_health_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_health_tools(mcp) + + result_str = registered_tools["haproxy_health"]() + result = json.loads(result_str) + + assert result["status"] == "unhealthy" + assert result["components"]["container"]["status"] == "error" + + +class TestHaproxyDomainHealth: + """Tests for haproxy_domain_health tool function.""" + + def test_domain_health_invalid_domain(self, patch_config_paths): + """Reject invalid domain format.""" + from haproxy_mcp.tools.health import register_health_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_health_tools(mcp) + + result_str = registered_tools["haproxy_domain_health"](domain="-invalid") + result = json.loads(result_str) + + assert "error" in result + assert "Invalid domain" in result["error"] + + def test_domain_health_healthy(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Domain health returns healthy when all servers are UP.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([ + {"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80}, + {"be_name": "pool_1", "srv_name": "pool_1_2", "srv_addr": "10.0.0.2", "srv_port": 80}, + ]), + "show stat": response_builder.stat_csv([ + {"pxname": "pool_1", "svname": "pool_1_1", "status": "UP", "check_status": "L4OK"}, + {"pxname": "pool_1", "svname": "pool_1_2", "status": "UP", "check_status": "L4OK"}, + ]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.health import register_health_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_health_tools(mcp) + + result_str = registered_tools["haproxy_domain_health"](domain="example.com") + result = json.loads(result_str) + + assert result["status"] == "healthy" + assert result["healthy_count"] == 2 + assert result["total_count"] == 2 + + def test_domain_health_degraded(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Domain health returns degraded when some servers are DOWN.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([ + {"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80}, + {"be_name": "pool_1", "srv_name": "pool_1_2", "srv_addr": "10.0.0.2", "srv_port": 80}, + ]), + "show stat": response_builder.stat_csv([ + {"pxname": "pool_1", "svname": "pool_1_1", "status": "UP"}, + {"pxname": "pool_1", "svname": "pool_1_2", "status": "DOWN"}, + ]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.health import register_health_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_health_tools(mcp) + + result_str = registered_tools["haproxy_domain_health"](domain="example.com") + result = json.loads(result_str) + + assert result["status"] == "degraded" + assert result["healthy_count"] == 1 + assert result["total_count"] == 2 + + def test_domain_health_down(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Domain health returns down when all servers are DOWN.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([ + {"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80}, + ]), + "show stat": response_builder.stat_csv([ + {"pxname": "pool_1", "svname": "pool_1_1", "status": "DOWN"}, + ]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.health import register_health_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_health_tools(mcp) + + result_str = registered_tools["haproxy_domain_health"](domain="example.com") + result = json.loads(result_str) + + assert result["status"] == "down" + assert result["healthy_count"] == 0 + assert result["total_count"] == 1 + + def test_domain_health_no_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Domain health returns no_servers when no servers configured.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([ + {"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "0.0.0.0", "srv_port": 0}, + ]), + "show stat": response_builder.stat_csv([]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.health import register_health_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_health_tools(mcp) + + result_str = registered_tools["haproxy_domain_health"](domain="example.com") + result = json.loads(result_str) + + assert result["status"] == "no_servers" + assert result["total_count"] == 0 + + +class TestHaproxyGetServerHealth: + """Tests for haproxy_get_server_health tool function.""" + + def test_get_server_health_invalid_backend(self, patch_config_paths): + """Reject invalid backend name.""" + from haproxy_mcp.tools.health import register_health_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_health_tools(mcp) + + result = registered_tools["haproxy_get_server_health"](backend="invalid@name") + + assert "Error" in result + assert "Invalid backend" in result + + def test_get_server_health_all_backends(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Get health for all backends.""" + mock_sock = mock_socket_class(responses={ + "show stat": response_builder.stat_csv([ + {"pxname": "pool_1", "svname": "pool_1_1", "status": "UP", "weight": 1, "check_status": "L4OK"}, + {"pxname": "pool_2", "svname": "pool_2_1", "status": "DOWN", "weight": 1, "check_status": "L4TOUT"}, + ]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.health import register_health_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_health_tools(mcp) + + result = registered_tools["haproxy_get_server_health"](backend="") + + assert "pool_1" in result + assert "pool_2" in result + assert "UP" in result + assert "DOWN" in result + + def test_get_server_health_filter_backend(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Get health for specific backend.""" + mock_sock = mock_socket_class(responses={ + "show stat": response_builder.stat_csv([ + {"pxname": "pool_1", "svname": "pool_1_1", "status": "UP"}, + {"pxname": "pool_1", "svname": "pool_1_2", "status": "UP"}, + {"pxname": "pool_2", "svname": "pool_2_1", "status": "DOWN"}, + ]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.health import register_health_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_health_tools(mcp) + + result = registered_tools["haproxy_get_server_health"](backend="pool_1") + + assert "pool_1" in result + assert "pool_2" not in result + + def test_get_server_health_no_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """No servers returns appropriate message.""" + mock_sock = mock_socket_class(responses={ + "show stat": response_builder.stat_csv([ + {"pxname": "pool_1", "svname": "FRONTEND", "status": "OPEN"}, + {"pxname": "pool_1", "svname": "BACKEND", "status": "UP"}, + ]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.health import register_health_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_health_tools(mcp) + + result = registered_tools["haproxy_get_server_health"](backend="") + + assert "No servers found" in result + + def test_get_server_health_haproxy_error(self, mock_socket_class, mock_select, patch_config_paths): + """HAProxy error returns error message.""" + def raise_error(*args, **kwargs): + raise ConnectionRefusedError() + + with patch("socket.socket", side_effect=raise_error): + from haproxy_mcp.tools.health import register_health_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_health_tools(mcp) + + result = registered_tools["haproxy_get_server_health"](backend="") + + assert "Error" in result diff --git a/tests/unit/tools/test_monitoring.py b/tests/unit/tools/test_monitoring.py new file mode 100644 index 0000000..d148835 --- /dev/null +++ b/tests/unit/tools/test_monitoring.py @@ -0,0 +1,325 @@ +"""Unit tests for monitoring tools.""" + +from unittest.mock import patch, MagicMock + +import pytest + + +class TestHaproxyStats: + """Tests for haproxy_stats tool function.""" + + def test_stats_success(self, mock_socket_class, mock_select, response_builder): + """Get HAProxy stats successfully.""" + mock_sock = mock_socket_class(responses={ + "show info": response_builder.info(version="3.3.2", uptime=3600), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.monitoring import register_monitoring_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_monitoring_tools(mcp) + + result = registered_tools["haproxy_stats"]() + + assert "Version" in result + assert "3.3.2" in result + + def test_stats_haproxy_error(self, mock_select): + """Handle HAProxy connection error.""" + def raise_error(*args, **kwargs): + raise ConnectionRefusedError() + + with patch("socket.socket", side_effect=raise_error): + from haproxy_mcp.tools.monitoring import register_monitoring_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_monitoring_tools(mcp) + + result = registered_tools["haproxy_stats"]() + + assert "Error" in result + + +class TestHaproxyBackends: + """Tests for haproxy_backends tool function.""" + + def test_backends_success(self, mock_socket_class, mock_select): + """List backends successfully.""" + mock_sock = mock_socket_class(responses={ + "show backend": "pool_1\npool_2\npool_3", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.monitoring import register_monitoring_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_monitoring_tools(mcp) + + result = registered_tools["haproxy_backends"]() + + assert "Backends" in result + assert "pool_1" in result + assert "pool_2" in result + assert "pool_3" in result + + def test_backends_haproxy_error(self, mock_select): + """Handle HAProxy connection error.""" + def raise_error(*args, **kwargs): + raise ConnectionRefusedError() + + with patch("socket.socket", side_effect=raise_error): + from haproxy_mcp.tools.monitoring import register_monitoring_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_monitoring_tools(mcp) + + result = registered_tools["haproxy_backends"]() + + assert "Error" in result + + +class TestHaproxyListFrontends: + """Tests for haproxy_list_frontends tool function.""" + + def test_list_frontends_success(self, mock_socket_class, mock_select, response_builder): + """List frontends successfully.""" + mock_sock = mock_socket_class(responses={ + "show stat": response_builder.stat_csv([ + {"pxname": "http_front", "svname": "FRONTEND", "status": "OPEN", "scur": 10}, + {"pxname": "https_front", "svname": "FRONTEND", "status": "OPEN", "scur": 50}, + {"pxname": "pool_1", "svname": "pool_1_1", "status": "UP", "scur": 5}, + ]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.monitoring import register_monitoring_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_monitoring_tools(mcp) + + result = registered_tools["haproxy_list_frontends"]() + + assert "Frontends" in result + assert "http_front" in result + assert "https_front" in result + # pool_1 is not a FRONTEND + assert "pool_1_1" not in result + + def test_list_frontends_no_frontends(self, mock_socket_class, mock_select, response_builder): + """No frontends found.""" + mock_sock = mock_socket_class(responses={ + "show stat": response_builder.stat_csv([ + {"pxname": "pool_1", "svname": "pool_1_1", "status": "UP"}, + ]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.monitoring import register_monitoring_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_monitoring_tools(mcp) + + result = registered_tools["haproxy_list_frontends"]() + + assert "No frontends found" in result + + def test_list_frontends_haproxy_error(self, mock_select): + """Handle HAProxy connection error.""" + def raise_error(*args, **kwargs): + raise ConnectionRefusedError() + + with patch("socket.socket", side_effect=raise_error): + from haproxy_mcp.tools.monitoring import register_monitoring_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_monitoring_tools(mcp) + + result = registered_tools["haproxy_list_frontends"]() + + assert "Error" in result + + +class TestHaproxyGetConnections: + """Tests for haproxy_get_connections tool function.""" + + def test_get_connections_all_backends(self, mock_socket_class, mock_select, response_builder): + """Get connections for all backends.""" + mock_sock = mock_socket_class(responses={ + "show stat": response_builder.stat_csv([ + {"pxname": "pool_1", "svname": "FRONTEND", "status": "OPEN", "scur": 10, "smax": 100}, + {"pxname": "pool_1", "svname": "pool_1_1", "status": "UP", "scur": 5}, + {"pxname": "pool_1", "svname": "BACKEND", "status": "UP", "scur": 10, "smax": 100}, + ]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.monitoring import register_monitoring_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_monitoring_tools(mcp) + + result = registered_tools["haproxy_get_connections"](backend="") + + assert "pool_1" in result + assert "FRONTEND" in result or "BACKEND" in result + assert "connections" in result + + def test_get_connections_filter_backend(self, mock_socket_class, mock_select, response_builder): + """Filter connections by backend.""" + mock_sock = mock_socket_class(responses={ + "show stat": response_builder.stat_csv([ + {"pxname": "pool_1", "svname": "FRONTEND", "status": "OPEN", "scur": 10, "smax": 100}, + {"pxname": "pool_2", "svname": "FRONTEND", "status": "OPEN", "scur": 20, "smax": 200}, + ]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.monitoring import register_monitoring_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_monitoring_tools(mcp) + + result = registered_tools["haproxy_get_connections"](backend="pool_1") + + assert "pool_1" in result + assert "pool_2" not in result + + def test_get_connections_invalid_backend(self): + """Reject invalid backend name.""" + from haproxy_mcp.tools.monitoring import register_monitoring_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_monitoring_tools(mcp) + + result = registered_tools["haproxy_get_connections"](backend="invalid@name") + + assert "Error" in result + assert "Invalid backend" in result + + def test_get_connections_no_data(self, mock_socket_class, mock_select, response_builder): + """No connection data found.""" + mock_sock = mock_socket_class(responses={ + "show stat": response_builder.stat_csv([]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.monitoring import register_monitoring_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_monitoring_tools(mcp) + + result = registered_tools["haproxy_get_connections"](backend="") + + assert "No connection data" in result + + def test_get_connections_haproxy_error(self, mock_select): + """Handle HAProxy connection error.""" + def raise_error(*args, **kwargs): + raise ConnectionRefusedError() + + with patch("socket.socket", side_effect=raise_error): + from haproxy_mcp.tools.monitoring import register_monitoring_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_monitoring_tools(mcp) + + result = registered_tools["haproxy_get_connections"](backend="") + + assert "Error" in result diff --git a/tests/unit/tools/test_servers.py b/tests/unit/tools/test_servers.py new file mode 100644 index 0000000..f370122 --- /dev/null +++ b/tests/unit/tools/test_servers.py @@ -0,0 +1,1350 @@ +"""Unit tests for server management tools.""" + +import json +from unittest.mock import patch, MagicMock + +import pytest + +from haproxy_mcp.exceptions import HaproxyError + + +class TestHaproxyListServers: + """Tests for haproxy_list_servers tool function.""" + + def test_list_servers_invalid_domain(self, patch_config_paths): + """Reject invalid domain format.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_list_servers"](domain="-invalid") + + assert "Error" in result + assert "Invalid domain" in result + + def test_list_servers_empty_backend(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """List servers for domain with no servers.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([ + {"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "0.0.0.0", "srv_port": 0}, + ]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_list_servers"](domain="example.com") + + assert "pool_1" in result + assert "disabled" in result + + def test_list_servers_with_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """List servers with active servers.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([ + {"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80}, + {"be_name": "pool_1", "srv_name": "pool_1_2", "srv_addr": "10.0.0.2", "srv_port": 80}, + ]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_list_servers"](domain="example.com") + + assert "10.0.0.1" in result + assert "10.0.0.2" in result + assert "active" in result + + +class TestHaproxyAddServer: + """Tests for haproxy_add_server tool function.""" + + def test_add_server_invalid_domain(self, patch_config_paths): + """Reject invalid domain format.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_server"]( + domain="-invalid", + slot=1, + ip="10.0.0.1", + http_port=80 + ) + + assert "Error" in result + assert "Invalid domain" in result + + def test_add_server_empty_ip(self, patch_config_paths): + """Reject empty IP address.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_server"]( + domain="example.com", + slot=1, + ip="", + http_port=80 + ) + + assert "Error" in result + assert "IP address is required" in result + + def test_add_server_invalid_ip(self, patch_config_paths): + """Reject invalid IP address.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_server"]( + domain="example.com", + slot=1, + ip="not-an-ip", + http_port=80 + ) + + assert "Error" in result + assert "Invalid IP" in result + + def test_add_server_invalid_port(self, patch_config_paths): + """Reject invalid port.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_server"]( + domain="example.com", + slot=1, + ip="10.0.0.1", + http_port=70000 + ) + + assert "Error" in result + assert "Port" in result + + def test_add_server_invalid_slot(self, patch_config_paths): + """Reject invalid slot number.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_server"]( + domain="example.com", + slot=99, # > MAX_SLOTS + ip="10.0.0.1", + http_port=80 + ) + + assert "Error" in result + assert "Slot" in result + + def test_add_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Successfully add server.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_server"]( + domain="example.com", + slot=1, + ip="10.0.0.1", + http_port=8080 + ) + + assert "example.com" in result + assert "slot 1" in result + assert "10.0.0.1:8080" in result + + def test_add_server_auto_slot(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Auto-select slot when slot=0.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([ + {"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80}, + {"be_name": "pool_1", "srv_name": "pool_1_2", "srv_addr": "0.0.0.0", "srv_port": 0}, + ]), + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_server"]( + domain="example.com", + slot=0, # Auto-select + ip="10.0.0.2", + http_port=80 + ) + + assert "slot 2" in result # First available slot + + +class TestHaproxyAddServers: + """Tests for haproxy_add_servers tool function.""" + + def test_add_servers_invalid_domain(self, patch_config_paths): + """Reject invalid domain format.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_servers"]( + domain="-invalid", + servers='[{"slot": 1, "ip": "10.0.0.1"}]' + ) + + assert "Error" in result + assert "Invalid domain" in result + + def test_add_servers_invalid_json(self, patch_config_paths): + """Reject invalid JSON.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_servers"]( + domain="example.com", + servers='not valid json' + ) + + assert "Error" in result + assert "Invalid JSON" in result + + def test_add_servers_not_array(self, patch_config_paths): + """Reject non-array JSON.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_servers"]( + domain="example.com", + servers='{"slot": 1, "ip": "10.0.0.1"}' + ) + + assert "Error" in result + assert "must be a JSON array" in result + + def test_add_servers_empty_array(self, patch_config_paths): + """Reject empty array.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_servers"]( + domain="example.com", + servers='[]' + ) + + assert "Error" in result + assert "empty" in result + + def test_add_servers_duplicate_slots(self, patch_config_paths): + """Reject duplicate slot numbers.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_servers"]( + domain="example.com", + servers='[{"slot": 1, "ip": "10.0.0.1"}, {"slot": 1, "ip": "10.0.0.2"}]' + ) + + assert "Error" in result + assert "Duplicate" in result + + def test_add_servers_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Successfully add multiple servers.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_servers"]( + domain="example.com", + servers='[{"slot": 1, "ip": "10.0.0.1"}, {"slot": 2, "ip": "10.0.0.2"}]' + ) + + assert "Added 2 servers" in result + assert "slot 1" in result + assert "slot 2" in result + + +class TestHaproxyRemoveServer: + """Tests for haproxy_remove_server tool function.""" + + def test_remove_server_invalid_domain(self, patch_config_paths): + """Reject invalid domain format.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_remove_server"]( + domain="-invalid", + slot=1 + ) + + assert "Error" in result + assert "Invalid domain" in result + + def test_remove_server_invalid_slot(self, patch_config_paths): + """Reject invalid slot number.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_remove_server"]( + domain="example.com", + slot=99 + ) + + assert "Error" in result + assert "Slot" in result + + def test_remove_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Successfully remove server.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_remove_server"]( + domain="example.com", + slot=1 + ) + + assert "Removed" in result + assert "slot 1" in result + + +class TestHaproxySetServerState: + """Tests for haproxy_set_server_state tool function.""" + + def test_set_state_invalid_backend(self, patch_config_paths): + """Reject invalid backend name.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_set_server_state"]( + backend="invalid@backend", + server="pool_1_1", + state="ready" + ) + + assert "Error" in result + assert "Invalid backend" in result + + def test_set_state_invalid_state(self, patch_config_paths): + """Reject invalid state.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_set_server_state"]( + backend="pool_1", + server="pool_1_1", + state="invalid" + ) + + assert "Error" in result + assert "state must be" in result + + def test_set_state_success(self, mock_socket_class, mock_select, patch_config_paths): + """Successfully set server state.""" + mock_sock = mock_socket_class(responses={ + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_set_server_state"]( + backend="pool_1", + server="pool_1_1", + state="maint" + ) + + assert "pool_1/pool_1_1" in result + assert "maint" in result + + +class TestHaproxySetServerWeight: + """Tests for haproxy_set_server_weight tool function.""" + + def test_set_weight_invalid_weight(self, patch_config_paths): + """Reject invalid weight.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_set_server_weight"]( + backend="pool_1", + server="pool_1_1", + weight=300 # > 256 + ) + + assert "Error" in result + assert "weight" in result + + def test_set_weight_success(self, mock_socket_class, mock_select, patch_config_paths): + """Successfully set server weight.""" + mock_sock = mock_socket_class(responses={ + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_set_server_weight"]( + backend="pool_1", + server="pool_1_1", + weight=2 + ) + + assert "pool_1/pool_1_1" in result + assert "2" in result + + +class TestConfigureServerSlot: + """Tests for configure_server_slot helper function.""" + + def test_configure_slot(self, mock_socket_class, mock_select): + """Configure server slot sends correct commands.""" + mock_sock = mock_socket_class(responses={ + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import configure_server_slot + + result = configure_server_slot("pool_1", "pool_1", 1, "10.0.0.1", 8080) + + assert result == "pool_1_1" + # Verify commands were sent + assert len(mock_sock.sent_commands) == 2 + assert "addr 10.0.0.1 port 8080" in mock_sock.sent_commands[0] + assert "state ready" in mock_sock.sent_commands[1] + + +class TestHaproxyAddServersRollback: + """Tests for haproxy_add_servers rollback functionality.""" + + def test_add_servers_partial_failure_rollback(self, mock_socket_class, mock_select, patch_config_paths): + """Rollback only failed slots when HAProxy error occurs.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + # Mock configure_server_slot to fail on second slot + call_count = [0] + + def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port): + call_count[0] += 1 + if slot == 2: + raise HaproxyError("HAProxy command failed: server not found") + return f"{server_prefix}_{slot}" + + mock_sock = mock_socket_class(responses={ + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + with patch( + "haproxy_mcp.tools.servers.configure_server_slot", + side_effect=mock_configure_server_slot + ): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_servers"]( + domain="example.com", + servers='[{"slot": 1, "ip": "10.0.0.1"}, {"slot": 2, "ip": "10.0.0.2"}]' + ) + + # First server should be added, second should fail + assert "Added 1 server" in result + assert "Failed to add 1 server" in result + assert "slot 1" in result # Successfully added + assert "slot 2" in result # Failed + + # Verify servers.json only has successfully added server + with open(patch_config_paths["servers_file"], "r") as f: + config = json.load(f) + assert "example.com" in config + assert "1" in config["example.com"] # Successfully added stays + assert "2" not in config["example.com"] # Failed one was rolled back + + def test_add_servers_unexpected_error_rollback_only_successful( + self, mock_socket_class, mock_select, patch_config_paths + ): + """Rollback only successfully added servers on unexpected error.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + # Track which servers were configured + configured_slots = [] + + # Mock socket that succeeds first then throws unexpected error + original_configure = None + + def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port): + if slot == 2: + # Simulate unexpected error (not HaproxyError) + raise RuntimeError("Unexpected system error") + configured_slots.append(slot) + return f"{server_prefix}_{slot}" + + mock_sock = mock_socket_class(responses={ + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + with patch( + "haproxy_mcp.tools.servers.configure_server_slot", + side_effect=mock_configure_server_slot + ): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_servers"]( + domain="example.com", + servers='[{"slot": 1, "ip": "10.0.0.1"}, {"slot": 2, "ip": "10.0.0.2"}, {"slot": 3, "ip": "10.0.0.3"}]' + ) + + # Should return error + assert "Error" in result + assert "Unexpected system error" in result + + # Verify servers.json is empty (all rolled back) + with open(patch_config_paths["servers_file"], "r") as f: + config = json.load(f) + assert config == {} or "example.com" not in config or config.get("example.com") == {} + + def test_add_servers_rollback_failure_logged( + self, mock_socket_class, mock_select, patch_config_paths, caplog + ): + """Log rollback failures during error recovery.""" + import logging + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port): + if slot == 2: + raise RuntimeError("Unexpected error") + return f"{server_prefix}_{slot}" + + def mock_remove_server_from_config(domain, slot): + raise IOError("Disk full") + + mock_sock = mock_socket_class(responses={ + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + with patch( + "haproxy_mcp.tools.servers.configure_server_slot", + side_effect=mock_configure_server_slot + ): + with patch( + "haproxy_mcp.tools.servers.remove_server_from_config", + side_effect=mock_remove_server_from_config + ): + with caplog.at_level(logging.ERROR, logger="haproxy_mcp"): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_servers"]( + domain="example.com", + servers='[{"slot": 1, "ip": "10.0.0.1"}, {"slot": 2, "ip": "10.0.0.2"}]' + ) + + # Should return error from the original failure + assert "Error" in result + assert "Unexpected error" in result + + # Should have logged rollback failures + assert any("Failed to rollback" in record.message for record in caplog.records) + + +class TestHaproxyAddServerAutoSlot: + """Additional tests for auto-slot selection in haproxy_add_server.""" + + def test_add_server_auto_slot_all_used(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Auto-select slot fails when all slots are in use.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + # Build response with all 10 slots used + servers = [] + for i in range(1, 11): # MAX_SLOTS = 10 + servers.append({ + "be_name": "pool_1", + "srv_name": f"pool_1_{i}", + "srv_addr": f"10.0.0.{i}", + "srv_port": 80 + }) + + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state(servers), + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_server"]( + domain="example.com", + slot=0, # Auto-select + ip="10.0.0.100", + http_port=80 + ) + + assert "Error" in result + assert "No available slots" in result + + def test_add_server_negative_slot_auto_select(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Negative slot number triggers auto-selection.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([ + {"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "0.0.0.0", "srv_port": 0}, + ]), + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_add_server"]( + domain="example.com", + slot=-1, # Negative triggers auto-select + ip="10.0.0.1", + http_port=80 + ) + + assert "slot 1" in result # First available slot + + +class TestHaproxyAddServersPartialFailure: + """Tests for partial failure scenarios in haproxy_add_servers.""" + + def test_add_servers_validation_error_per_server(self, patch_config_paths): + """Handle validation errors for individual servers.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + # Mix of valid and invalid servers + result = registered_tools["haproxy_add_servers"]( + domain="example.com", + servers='[{"slot": 1, "ip": "10.0.0.1"}, {"slot": "invalid"}, {"slot": 2}]' + ) + + assert "Validation errors" in result + assert "Server 2" in result # Invalid slot type + assert "Server 3" in result # Missing IP + + +class TestHaproxyWaitDrain: + """Tests for haproxy_wait_drain tool function.""" + + def test_wait_drain_success(self, patch_config_paths): + """Successfully wait for connections to drain.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + # Mock haproxy_cmd to return 0 connections + with patch("haproxy_mcp.tools.servers.haproxy_cmd") as mock_cmd: + mock_cmd.return_value = "# pxname,svname,qcur,qmax,scur,smax\npool_1,pool_1_1,0,0,0,0\n" + # Provide enough time values: start_time, loop check, elapsed calculation + with patch("time.time", side_effect=[0, 0.1, 0.2]): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_wait_drain"]( + domain="example.com", + timeout=30 + ) + + assert "drained" in result.lower() + + def test_wait_drain_timeout(self, patch_config_paths): + """Timeout when connections don't drain.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + time_values = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0] # Simulate time passing + time_iter = iter(time_values) + + # Mock haproxy_cmd to return active connections + with patch("haproxy_mcp.tools.servers.haproxy_cmd") as mock_cmd: + mock_cmd.return_value = "# pxname,svname,qcur,qmax,scur,smax\npool_1,pool_1_1,0,0,5,10\n" + with patch("time.time", side_effect=lambda: next(time_iter)): + with patch("time.sleep", return_value=None): # Don't actually sleep + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_wait_drain"]( + domain="example.com", + timeout=2 # Short timeout + ) + + assert "Timeout" in result + + def test_wait_drain_invalid_domain(self, patch_config_paths): + """Reject invalid domain format.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_wait_drain"]( + domain="-invalid", + timeout=30 + ) + + assert "Error" in result + assert "Invalid domain" in result + + def test_wait_drain_invalid_timeout(self, patch_config_paths): + """Reject invalid timeout values.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + # Test timeout > 300 + result = registered_tools["haproxy_wait_drain"]( + domain="example.com", + timeout=500 + ) + assert "Error" in result + assert "Timeout must be" in result + + # Test timeout < 1 + result = registered_tools["haproxy_wait_drain"]( + domain="example.com", + timeout=0 + ) + assert "Error" in result + + def test_wait_drain_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths): + """Error when domain not found in map.""" + # Empty map file - domain not configured + with open(patch_config_paths["map_file"], "w") as f: + f.write("") + + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_wait_drain"]( + domain="unknown.com", + timeout=30 + ) + + assert "Error" in result + + +class TestHaproxySetServerWeightBoundary: + """Tests for haproxy_set_server_weight boundary values.""" + + def test_set_weight_zero(self, mock_socket_class, mock_select, patch_config_paths): + """Set server weight to 0 (disabled).""" + mock_sock = mock_socket_class(responses={ + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_set_server_weight"]( + backend="pool_1", + server="pool_1_1", + weight=0 + ) + + assert "pool_1/pool_1_1" in result + assert "0" in result + + def test_set_weight_max(self, mock_socket_class, mock_select, patch_config_paths): + """Set server weight to maximum 256.""" + mock_sock = mock_socket_class(responses={ + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_set_server_weight"]( + backend="pool_1", + server="pool_1_1", + weight=256 + ) + + assert "pool_1/pool_1_1" in result + assert "256" in result + + def test_set_weight_negative(self, patch_config_paths): + """Reject negative weight.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_set_server_weight"]( + backend="pool_1", + server="pool_1_1", + weight=-1 + ) + + assert "Error" in result + assert "weight" in result + + +class TestHaproxySetDomainState: + """Tests for haproxy_set_domain_state tool function.""" + + def test_set_domain_state_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Set all servers of a domain to a state.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([ + {"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80}, + {"be_name": "pool_1", "srv_name": "pool_1_2", "srv_addr": "10.0.0.2", "srv_port": 80}, + ]), + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_set_domain_state"]( + domain="example.com", + state="maint" + ) + + assert "Set 2 servers" in result + assert "maint" in result + + def test_set_domain_state_invalid_domain(self, patch_config_paths): + """Reject invalid domain format.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_set_domain_state"]( + domain="-invalid", + state="ready" + ) + + assert "Error" in result + assert "Invalid domain" in result + + def test_set_domain_state_invalid_state(self, patch_config_paths): + """Reject invalid state value.""" + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_set_domain_state"]( + domain="example.com", + state="invalid" + ) + + assert "Error" in result + assert "must be" in result + + def test_set_domain_state_no_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """No active servers found for domain.""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + # All servers have 0.0.0.0 address (not configured) + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([ + {"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "0.0.0.0", "srv_port": 0}, + ]), + "set server": "", + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_set_domain_state"]( + domain="example.com", + state="ready" + ) + + assert "No active servers found" in result + + def test_set_domain_state_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths, response_builder): + """Handle domain not found in map - shows no active servers.""" + # Empty map file + with open(patch_config_paths["map_file"], "w") as f: + f.write("") + + # Mock should show no servers for unknown domain's backend + mock_sock = mock_socket_class(responses={ + "show servers state": response_builder.servers_state([]), + }) + + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.servers import register_server_tools + mcp = MagicMock() + registered_tools = {} + + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator + + mcp.tool = capture_tool + register_server_tools(mcp) + + result = registered_tools["haproxy_set_domain_state"]( + domain="unknown.com", + state="ready" + ) + + # When domain is not in map, get_backend_and_prefix raises ValueError + # which is caught and returns Error + assert "Error" in result or "No active servers" in result