diff --git a/.gitignore b/.gitignore index 94927c8..79fba10 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,9 @@ run/ data/ *.state *.lock +*.db +*.db-wal +*.db-shm # Python __pycache__/ diff --git a/haproxy_mcp/__main__.py b/haproxy_mcp/__main__.py index 90b2a56..ae5bf0d 100644 --- a/haproxy_mcp/__main__.py +++ b/haproxy_mcp/__main__.py @@ -1,8 +1,10 @@ """Entry point for running haproxy_mcp as a module.""" +from .db import init_db from .server import mcp from .tools.configuration import startup_restore if __name__ == "__main__": + init_db() startup_restore() mcp.run(transport="streamable-http") diff --git a/haproxy_mcp/config.py b/haproxy_mcp/config.py index db199e3..afa4aac 100644 --- a/haproxy_mcp/config.py +++ b/haproxy_mcp/config.py @@ -31,6 +31,7 @@ WILDCARDS_MAP_FILE: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE", "/opt/haproxy/ WILDCARDS_MAP_FILE_CONTAINER: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE_CONTAINER", "/usr/local/etc/haproxy/wildcards.map") SERVERS_FILE: str = os.getenv("HAPROXY_SERVERS_FILE", "/opt/haproxy/conf/servers.json") CERTS_FILE: str = os.getenv("HAPROXY_CERTS_FILE", "/opt/haproxy/conf/certificates.json") +DB_FILE: str = os.getenv("HAPROXY_DB_FILE", "/opt/haproxy/conf/haproxy_mcp.db") # Certificate paths CERTS_DIR: str = os.getenv("HAPROXY_CERTS_DIR", "/opt/haproxy/certs") diff --git a/haproxy_mcp/db.py b/haproxy_mcp/db.py new file mode 100644 index 0000000..ec7caa5 --- /dev/null +++ b/haproxy_mcp/db.py @@ -0,0 +1,577 @@ +"""SQLite database operations for HAProxy MCP Server. + +Single source of truth for all persistent configuration data. +Replaces JSON files (servers.json, certificates.json) and map files +(domains.map, wildcards.map) with a single SQLite database. + +Map files are generated from the database via sync_map_files() for +HAProxy to consume directly. +""" + +import json +import os +import sqlite3 +import threading +from typing import Any, Optional + +from .config import ( + DB_FILE, + MAP_FILE, + WILDCARDS_MAP_FILE, + SERVERS_FILE, + CERTS_FILE, + POOL_COUNT, + REMOTE_MODE, + logger, +) + +SCHEMA_VERSION = 1 + +_local = threading.local() + + +def get_connection() -> sqlite3.Connection: + """Get a thread-local SQLite connection with WAL mode. + + Returns: + sqlite3.Connection configured for concurrent access. + """ + conn = getattr(_local, "connection", None) + if conn is not None: + return conn + + conn = sqlite3.connect(DB_FILE, timeout=10) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA busy_timeout=5000") + conn.execute("PRAGMA foreign_keys=ON") + conn.row_factory = sqlite3.Row + _local.connection = conn + return conn + + +def close_connection() -> None: + """Close the thread-local SQLite connection.""" + conn = getattr(_local, "connection", None) + if conn is not None: + conn.close() + _local.connection = None + + +def init_db() -> None: + """Initialize database schema and run migration if needed. + + Creates tables if they don't exist, then checks for existing + JSON/map files to migrate data from. + """ + conn = get_connection() + cur = conn.cursor() + + cur.executescript(""" + CREATE TABLE IF NOT EXISTS domains ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + domain TEXT NOT NULL UNIQUE, + backend TEXT NOT NULL, + is_wildcard INTEGER NOT NULL DEFAULT 0, + shares_with TEXT DEFAULT NULL, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) + ); + + CREATE INDEX IF NOT EXISTS idx_domains_backend ON domains(backend); + CREATE INDEX IF NOT EXISTS idx_domains_shares_with ON domains(shares_with); + + CREATE TABLE IF NOT EXISTS servers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + domain TEXT NOT NULL, + slot INTEGER NOT NULL, + ip TEXT NOT NULL, + http_port INTEGER NOT NULL DEFAULT 80, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), + UNIQUE(domain, slot) + ); + + CREATE INDEX IF NOT EXISTS idx_servers_domain ON servers(domain); + + CREATE TABLE IF NOT EXISTS certificates ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + domain TEXT NOT NULL UNIQUE, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) + ); + + CREATE TABLE IF NOT EXISTS schema_version ( + version INTEGER NOT NULL, + applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) + ); + """) + conn.commit() + + # Check schema version + cur.execute("SELECT MAX(version) FROM schema_version") + row = cur.fetchone() + current_version = row[0] if row[0] is not None else 0 + + if current_version < SCHEMA_VERSION: + # First-time setup: try migrating from JSON files + if current_version == 0: + migrate_from_json() + cur.execute("INSERT INTO schema_version (version) VALUES (?)", (SCHEMA_VERSION,)) + conn.commit() + + logger.info("Database initialized (schema v%d)", SCHEMA_VERSION) + + +def migrate_from_json() -> None: + """Migrate data from JSON/map files to SQLite. + + Reads domains.map, wildcards.map, servers.json, and certificates.json, + imports their data into the database within a single transaction, + then renames the original files to .bak. + + This function is idempotent (uses INSERT OR IGNORE). + """ + conn = get_connection() + + # Collect data from existing files + map_entries: list[tuple[str, str]] = [] + servers_config: dict[str, Any] = {} + cert_domains: list[str] = [] + + # Read map files + for map_path in [MAP_FILE, WILDCARDS_MAP_FILE]: + try: + content = _read_file_for_migration(map_path) + for line in content.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + parts = line.split() + if len(parts) >= 2: + map_entries.append((parts[0], parts[1])) + except FileNotFoundError: + logger.debug("Map file not found for migration: %s", map_path) + + # Read servers.json + try: + content = _read_file_for_migration(SERVERS_FILE) + servers_config = json.loads(content) + except FileNotFoundError: + logger.debug("Servers config not found for migration: %s", SERVERS_FILE) + except json.JSONDecodeError as e: + logger.warning("Corrupt servers config during migration: %s", e) + + # Read certificates.json + try: + content = _read_file_for_migration(CERTS_FILE) + data = json.loads(content) + cert_domains = data.get("domains", []) + except FileNotFoundError: + logger.debug("Certs config not found for migration: %s", CERTS_FILE) + except json.JSONDecodeError as e: + logger.warning("Corrupt certs config during migration: %s", e) + + if not map_entries and not servers_config and not cert_domains: + logger.info("No existing data to migrate") + return + + # Import into database in a single transaction + try: + with conn: + # Import domains from map files + for domain, backend in map_entries: + is_wildcard = 1 if domain.startswith(".") else 0 + conn.execute( + "INSERT OR IGNORE INTO domains (domain, backend, is_wildcard) VALUES (?, ?, ?)", + (domain, backend, is_wildcard), + ) + + # Import servers and shares_with from servers.json + for domain, slots in servers_config.items(): + shares_with = slots.get("_shares") + if shares_with: + # Update domain's shares_with field + conn.execute( + "UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0", + (shares_with, domain), + ) + continue + + for slot_str, server_info in slots.items(): + if slot_str.startswith("_"): + continue + try: + slot = int(slot_str) + ip = server_info.get("ip", "") + http_port = int(server_info.get("http_port", 80)) + if ip: + conn.execute( + "INSERT OR IGNORE INTO servers (domain, slot, ip, http_port) VALUES (?, ?, ?, ?)", + (domain, slot, ip, http_port), + ) + except (ValueError, TypeError) as e: + logger.warning("Skipping invalid server entry %s/%s: %s", domain, slot_str, e) + + # Import certificates + for cert_domain in cert_domains: + conn.execute( + "INSERT OR IGNORE INTO certificates (domain) VALUES (?)", + (cert_domain,), + ) + + migrated_domains = len(set(d for d, _ in map_entries)) + migrated_servers = sum( + 1 for d, slots in servers_config.items() + if "_shares" not in slots + for s in slots + if not s.startswith("_") + ) + migrated_certs = len(cert_domains) + logger.info( + "Migrated from JSON: %d domain entries, %d servers, %d certificates", + migrated_domains, migrated_servers, migrated_certs, + ) + + # Rename original files to .bak + for path in [SERVERS_FILE, CERTS_FILE]: + _backup_file(path) + + except sqlite3.Error as e: + logger.error("Migration failed: %s", e) + raise + + +def _read_file_for_migration(file_path: str) -> str: + """Read a file for migration purposes (local only, no locking needed).""" + if REMOTE_MODE: + from .ssh_ops import remote_read_file + return remote_read_file(file_path) + with open(file_path, "r", encoding="utf-8") as f: + return f.read() + + +def _backup_file(file_path: str) -> None: + """Rename a file to .bak if it exists.""" + if REMOTE_MODE: + return # Don't rename remote files + try: + if os.path.exists(file_path): + bak_path = f"{file_path}.bak" + os.rename(file_path, bak_path) + logger.info("Backed up %s -> %s", file_path, bak_path) + except OSError as e: + logger.warning("Failed to backup %s: %s", file_path, e) + + +# --- Domain data access --- + +def db_get_map_contents() -> list[tuple[str, str]]: + """Get all domain-to-backend mappings. + + Returns: + List of (domain, backend) tuples including wildcards. + """ + conn = get_connection() + cur = conn.execute("SELECT domain, backend FROM domains ORDER BY domain") + return [(row["domain"], row["backend"]) for row in cur.fetchall()] + + +def db_get_domain_backend(domain: str) -> Optional[str]: + """Look up the backend for a specific domain. + + Args: + domain: Domain name to look up. + + Returns: + Backend name if found, None otherwise. + """ + conn = get_connection() + cur = conn.execute("SELECT backend FROM domains WHERE domain = ?", (domain,)) + row = cur.fetchone() + return row["backend"] if row else None + + +def db_add_domain(domain: str, backend: str, is_wildcard: bool = False, + shares_with: Optional[str] = None) -> None: + """Add a domain to the database. + + Args: + domain: Domain name (e.g., "example.com" or ".example.com"). + backend: Backend pool name (e.g., "pool_5"). + is_wildcard: Whether this is a wildcard entry. + shares_with: Parent domain if sharing a pool. + """ + conn = get_connection() + conn.execute( + """INSERT INTO domains (domain, backend, is_wildcard, shares_with) + VALUES (?, ?, ?, ?) + ON CONFLICT(domain) DO UPDATE SET + backend = excluded.backend, + shares_with = excluded.shares_with, + updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""", + (domain, backend, 1 if is_wildcard else 0, shares_with), + ) + conn.commit() + + +def db_remove_domain(domain: str) -> None: + """Remove a domain (exact + wildcard) from the database. + + Args: + domain: Base domain name (without leading dot). + """ + conn = get_connection() + conn.execute("DELETE FROM domains WHERE domain = ? OR domain = ?", (domain, f".{domain}")) + conn.commit() + + +def db_find_available_pool() -> Optional[str]: + """Find the first available pool not assigned to any domain. + + Returns: + Pool name (e.g., "pool_5") or None if all pools are in use. + """ + conn = get_connection() + cur = conn.execute( + "SELECT DISTINCT backend FROM domains WHERE backend LIKE 'pool_%' AND is_wildcard = 0" + ) + used_pools = {row["backend"] for row in cur.fetchall()} + + for i in range(1, POOL_COUNT + 1): + pool_name = f"pool_{i}" + if pool_name not in used_pools: + return pool_name + return None + + +def db_get_domains_sharing_pool(pool: str) -> list[str]: + """Get all non-wildcard domains using a specific pool. + + Args: + pool: Pool name (e.g., "pool_5"). + + Returns: + List of domain names. + """ + conn = get_connection() + cur = conn.execute( + "SELECT domain FROM domains WHERE backend = ? AND is_wildcard = 0", + (pool,), + ) + return [row["domain"] for row in cur.fetchall()] + + +# --- Server data access --- + +def db_load_servers_config() -> dict[str, Any]: + """Load servers configuration in the legacy dict format. + + Returns: + Dictionary compatible with the old servers.json format: + {"domain": {"slot": {"ip": "...", "http_port": N}, "_shares": "parent"}} + """ + conn = get_connection() + config: dict[str, Any] = {} + + # Load server entries + cur = conn.execute("SELECT domain, slot, ip, http_port FROM servers ORDER BY domain, slot") + for row in cur.fetchall(): + domain = row["domain"] + if domain not in config: + config[domain] = {} + config[domain][str(row["slot"])] = { + "ip": row["ip"], + "http_port": row["http_port"], + } + + # Load shared domain references + cur = conn.execute( + "SELECT domain, shares_with FROM domains WHERE shares_with IS NOT NULL AND is_wildcard = 0" + ) + for row in cur.fetchall(): + domain = row["domain"] + if domain not in config: + config[domain] = {} + config[domain]["_shares"] = row["shares_with"] + + return config + + +def db_add_server(domain: str, slot: int, ip: str, http_port: int) -> None: + """Add or update a server entry. + + Args: + domain: Domain name. + slot: Server slot number. + ip: Server IP address. + http_port: HTTP port. + """ + conn = get_connection() + conn.execute( + """INSERT INTO servers (domain, slot, ip, http_port) + VALUES (?, ?, ?, ?) + ON CONFLICT(domain, slot) DO UPDATE SET + ip = excluded.ip, + http_port = excluded.http_port, + updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""", + (domain, slot, ip, http_port), + ) + conn.commit() + + +def db_remove_server(domain: str, slot: int) -> None: + """Remove a server entry. + + Args: + domain: Domain name. + slot: Server slot number. + """ + conn = get_connection() + conn.execute("DELETE FROM servers WHERE domain = ? AND slot = ?", (domain, slot)) + conn.commit() + + +def db_remove_domain_servers(domain: str) -> None: + """Remove all server entries for a domain. + + Args: + domain: Domain name. + """ + conn = get_connection() + conn.execute("DELETE FROM servers WHERE domain = ?", (domain,)) + conn.commit() + + +def db_add_shared_domain(domain: str, shares_with: str) -> None: + """Mark a domain as sharing a pool with another domain. + + Updates the shares_with field in the domains table. + + Args: + domain: The new domain. + shares_with: The existing domain to share with. + """ + conn = get_connection() + conn.execute( + "UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0", + (shares_with, domain), + ) + conn.commit() + + +def db_get_shared_domain(domain: str) -> Optional[str]: + """Get the parent domain that this domain shares a pool with. + + Args: + domain: Domain name. + + Returns: + Parent domain name if sharing, None otherwise. + """ + conn = get_connection() + cur = conn.execute( + "SELECT shares_with FROM domains WHERE domain = ? AND is_wildcard = 0", + (domain,), + ) + row = cur.fetchone() + if row and row["shares_with"]: + return row["shares_with"] + return None + + +def db_is_shared_domain(domain: str) -> bool: + """Check if a domain is sharing another domain's pool. + + Args: + domain: Domain name. + + Returns: + True if domain has shares_with reference. + """ + return db_get_shared_domain(domain) is not None + + +# --- Certificate data access --- + +def db_load_certs() -> list[str]: + """Load certificate domain list. + + Returns: + Sorted list of domain names with certificates. + """ + conn = get_connection() + cur = conn.execute("SELECT domain FROM certificates ORDER BY domain") + return [row["domain"] for row in cur.fetchall()] + + +def db_add_cert(domain: str) -> None: + """Add a domain to the certificate list. + + Args: + domain: Domain name. + """ + conn = get_connection() + conn.execute( + "INSERT OR IGNORE INTO certificates (domain) VALUES (?)", + (domain,), + ) + conn.commit() + + +def db_remove_cert(domain: str) -> None: + """Remove a domain from the certificate list. + + Args: + domain: Domain name. + """ + conn = get_connection() + conn.execute("DELETE FROM certificates WHERE domain = ?", (domain,)) + conn.commit() + + +# --- Map file synchronization --- + +def sync_map_files() -> None: + """Regenerate HAProxy map files from database. + + Writes domains.map (exact matches) and wildcards.map (wildcard entries) + from the current database state. Uses atomic writes. + """ + from .file_ops import atomic_write_file + + conn = get_connection() + + # Fetch exact domains + cur = conn.execute( + "SELECT domain, backend FROM domains WHERE is_wildcard = 0 ORDER BY domain" + ) + exact_entries = cur.fetchall() + + # Fetch wildcard domains + cur = conn.execute( + "SELECT domain, backend FROM domains WHERE is_wildcard = 1 ORDER BY domain" + ) + wildcard_entries = cur.fetchall() + + # Write exact domains map + exact_lines = [ + "# Exact Domain to Backend mapping (for map_str)\n", + "# Format: domain backend_name\n", + "# Uses ebtree for O(log n) lookup performance\n\n", + ] + for row in exact_entries: + exact_lines.append(f"{row['domain']} {row['backend']}\n") + atomic_write_file(MAP_FILE, "".join(exact_lines)) + + # Write wildcards map + wildcard_lines = [ + "# Wildcard Domain to Backend mapping (for map_dom)\n", + "# Format: .domain.com backend_name (matches *.domain.com)\n", + "# Uses map_dom for suffix matching\n\n", + ] + for row in wildcard_entries: + wildcard_lines.append(f"{row['domain']} {row['backend']}\n") + atomic_write_file(WILDCARDS_MAP_FILE, "".join(wildcard_lines)) + + logger.debug("Map files synced: %d exact, %d wildcard entries", + len(exact_entries), len(wildcard_entries)) diff --git a/haproxy_mcp/file_ops.py b/haproxy_mcp/file_ops.py index 847eed7..8aacce5 100644 --- a/haproxy_mcp/file_ops.py +++ b/haproxy_mcp/file_ops.py @@ -1,7 +1,11 @@ -"""File I/O operations for HAProxy MCP Server.""" +"""File I/O operations for HAProxy MCP Server. + +Most data access is now delegated to db.py (SQLite). +This module retains atomic file writes, map file I/O for HAProxy, +and provides backward-compatible function signatures. +""" import fcntl -import json import os import tempfile from contextlib import contextmanager @@ -10,8 +14,6 @@ from typing import Any, Generator, Optional from .config import ( MAP_FILE, WILDCARDS_MAP_FILE, - SERVERS_FILE, - CERTS_FILE, REMOTE_MODE, logger, ) @@ -138,16 +140,13 @@ def _read_file(file_path: str) -> str: def get_map_contents() -> list[tuple[str, str]]: - """Read both domains.map and wildcards.map and return combined entries. + """Get all domain-to-backend mappings from SQLite. Returns: - List of (domain, backend) tuples from both map files + List of (domain, backend) tuples including wildcards. """ - # Read exact domains - entries = _read_map_file(MAP_FILE) - # Read wildcards and append - entries.extend(_read_map_file(WILDCARDS_MAP_FILE)) - return entries + from .db import db_get_map_contents + return db_get_map_contents() def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]: @@ -170,44 +169,21 @@ def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str def save_map_file(entries: list[tuple[str, str]]) -> None: - """Save domain-to-backend entries to map files. + """Sync map files from the database. - Splits entries into two files for 2-stage routing: - - domains.map: Exact matches (map_str, O(log n)) - - wildcards.map: Wildcard entries starting with "." (map_dom, O(n)) - - Args: - entries: List of (domain, backend) tuples. + Regenerates domains.map and wildcards.map from the current + database state. The entries parameter is ignored (kept for + backward compatibility during transition). Raises: IOError: If map files cannot be written. """ - # Split into exact and wildcard entries - exact_entries, wildcard_entries = split_domain_entries(entries) - - # Save exact domains (for map_str - fast O(log n) lookup) - exact_lines = [ - "# Exact Domain to Backend mapping (for map_str)\n", - "# Format: domain backend_name\n", - "# Uses ebtree for O(log n) lookup performance\n\n", - ] - for domain, backend in sorted(exact_entries): - exact_lines.append(f"{domain} {backend}\n") - atomic_write_file(MAP_FILE, "".join(exact_lines)) - - # Save wildcards (for map_dom - O(n) but small set) - wildcard_lines = [ - "# Wildcard Domain to Backend mapping (for map_dom)\n", - "# Format: .domain.com backend_name (matches *.domain.com)\n", - "# Uses map_dom for suffix matching\n\n", - ] - for domain, backend in sorted(wildcard_entries): - wildcard_lines.append(f"{domain} {backend}\n") - atomic_write_file(WILDCARDS_MAP_FILE, "".join(wildcard_lines)) + from .db import sync_map_files + sync_map_files() def get_domain_backend(domain: str) -> Optional[str]: - """Look up the backend for a domain from domains.map. + """Look up the backend for a domain from SQLite (O(1)). Args: domain: The domain to look up @@ -215,10 +191,8 @@ def get_domain_backend(domain: str) -> Optional[str]: Returns: Backend name if found, None otherwise """ - for map_domain, backend in get_map_contents(): - if map_domain == domain: - return backend - return None + from .db import db_get_domain_backend + return db_get_domain_backend(domain) def is_legacy_backend(backend: str) -> bool: @@ -273,34 +247,17 @@ def get_backend_and_prefix(domain: str) -> tuple[str, str]: def load_servers_config() -> dict[str, Any]: - """Load servers configuration from JSON file. + """Load servers configuration from SQLite. Returns: - Dictionary with server configurations + Dictionary with server configurations (legacy format compatible). """ - try: - content = _read_file(SERVERS_FILE) - return json.loads(content) - except FileNotFoundError: - return {} - except json.JSONDecodeError as e: - logger.warning("Corrupt config file %s: %s", SERVERS_FILE, e) - return {} - - -def save_servers_config(config: dict[str, Any]) -> None: - """Save servers configuration to JSON file atomically. - - Uses temp file + rename for atomic write to prevent race conditions. - - Args: - config: Dictionary with server configurations - """ - atomic_write_file(SERVERS_FILE, json.dumps(config, indent=2)) + from .db import db_load_servers_config + return db_load_servers_config() def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> None: - """Add server configuration to persistent storage with file locking. + """Add server configuration to persistent storage. Args: domain: Domain name @@ -308,41 +265,29 @@ def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> Non ip: Server IP address http_port: HTTP port """ - with file_lock(f"{SERVERS_FILE}.lock"): - config = load_servers_config() - if domain not in config: - config[domain] = {} - config[domain][str(slot)] = {"ip": ip, "http_port": http_port} - save_servers_config(config) + from .db import db_add_server + db_add_server(domain, slot, ip, http_port) def remove_server_from_config(domain: str, slot: int) -> None: - """Remove server configuration from persistent storage with file locking. + """Remove server configuration from persistent storage. Args: domain: Domain name slot: Server slot to remove """ - with file_lock(f"{SERVERS_FILE}.lock"): - config = load_servers_config() - if domain in config and str(slot) in config[domain]: - del config[domain][str(slot)] - if not config[domain]: - del config[domain] - save_servers_config(config) + from .db import db_remove_server + db_remove_server(domain, slot) def remove_domain_from_config(domain: str) -> None: - """Remove domain from persistent config with file locking. + """Remove domain from persistent config (servers + domain entry). Args: domain: Domain name to remove """ - with file_lock(f"{SERVERS_FILE}.lock"): - config = load_servers_config() - if domain in config: - del config[domain] - save_servers_config(config) + from .db import db_remove_domain_servers + db_remove_domain_servers(domain) def get_shared_domain(domain: str) -> Optional[str]: @@ -354,9 +299,8 @@ def get_shared_domain(domain: str) -> Optional[str]: Returns: Parent domain name if sharing, None otherwise """ - config = load_servers_config() - domain_config = config.get(domain, {}) - return domain_config.get("_shares") + from .db import db_get_shared_domain + return db_get_shared_domain(domain) def add_shared_domain_to_config(domain: str, shares_with: str) -> None: @@ -366,10 +310,8 @@ def add_shared_domain_to_config(domain: str, shares_with: str) -> None: domain: New domain name shares_with: Existing domain to share pool with """ - with file_lock(f"{SERVERS_FILE}.lock"): - config = load_servers_config() - config[domain] = {"_shares": shares_with} - save_servers_config(config) + from .db import db_add_shared_domain + db_add_shared_domain(domain, shares_with) def get_domains_sharing_pool(pool: str) -> list[str]: @@ -381,11 +323,8 @@ def get_domains_sharing_pool(pool: str) -> list[str]: Returns: List of domain names using this pool """ - domains = [] - for domain, backend in get_map_contents(): - if backend == pool and not domain.startswith("."): - domains.append(domain) - return domains + from .db import db_get_domains_sharing_pool + return db_get_domains_sharing_pool(pool) def is_shared_domain(domain: str) -> bool: @@ -397,37 +336,20 @@ def is_shared_domain(domain: str) -> bool: Returns: True if domain has _shares reference, False otherwise """ - config = load_servers_config() - domain_config = config.get(domain, {}) - return "_shares" in domain_config + from .db import db_is_shared_domain + return db_is_shared_domain(domain) # Certificate configuration functions def load_certs_config() -> list[str]: - """Load certificate domain list from JSON file. + """Load certificate domain list from SQLite. Returns: - List of domain names + Sorted list of domain names. """ - try: - content = _read_file(CERTS_FILE) - data = json.loads(content) - return data.get("domains", []) - except FileNotFoundError: - return [] - except json.JSONDecodeError as e: - logger.warning("Corrupt certificates config %s: %s", CERTS_FILE, e) - return [] - - -def save_certs_config(domains: list[str]) -> None: - """Save certificate domain list to JSON file atomically. - - Args: - domains: List of domain names - """ - atomic_write_file(CERTS_FILE, json.dumps({"domains": sorted(domains)}, indent=2)) + from .db import db_load_certs + return db_load_certs() def add_cert_to_config(domain: str) -> None: @@ -436,11 +358,8 @@ def add_cert_to_config(domain: str) -> None: Args: domain: Domain name to add """ - with file_lock(f"{CERTS_FILE}.lock"): - domains = load_certs_config() - if domain not in domains: - domains.append(domain) - save_certs_config(domains) + from .db import db_add_cert + db_add_cert(domain) def remove_cert_from_config(domain: str) -> None: @@ -449,8 +368,45 @@ def remove_cert_from_config(domain: str) -> None: Args: domain: Domain name to remove """ - with file_lock(f"{CERTS_FILE}.lock"): - domains = load_certs_config() - if domain in domains: - domains.remove(domain) - save_certs_config(domains) + from .db import db_remove_cert + db_remove_cert(domain) + + +# Domain map helper functions (used by domains.py) + +def add_domain_to_map(domain: str, backend: str, is_wildcard: bool = False, + shares_with: Optional[str] = None) -> None: + """Add a domain to SQLite and sync map files. + + Args: + domain: Domain name (e.g., "example.com"). + backend: Backend pool name (e.g., "pool_5"). + is_wildcard: Whether this is a wildcard entry. + shares_with: Parent domain if sharing a pool. + """ + from .db import db_add_domain, sync_map_files + db_add_domain(domain, backend, is_wildcard, shares_with) + sync_map_files() + + +def remove_domain_from_map(domain: str) -> None: + """Remove a domain (exact + wildcard) from SQLite and sync map files. + + Args: + domain: Base domain name (without leading dot). + """ + from .db import db_remove_domain, sync_map_files + db_remove_domain(domain) + sync_map_files() + + +def find_available_pool() -> Optional[str]: + """Find the first available pool not assigned to any domain. + + Uses SQLite query for O(1) lookup vs previous O(n) list scan. + + Returns: + Pool name (e.g., "pool_5") or None if all pools are in use. + """ + from .db import db_find_available_pool + return db_find_available_pool() diff --git a/haproxy_mcp/haproxy_client.py b/haproxy_mcp/haproxy_client.py index de74758..6735fac 100644 --- a/haproxy_mcp/haproxy_client.py +++ b/haproxy_mcp/haproxy_client.py @@ -2,6 +2,7 @@ import socket import select +import subprocess import time from .config import ( @@ -161,7 +162,7 @@ def reload_haproxy() -> tuple[bool, str]: if result.returncode != 0: return False, f"Reload failed: {result.stderr}" return True, "OK" - except TimeoutError: + except (TimeoutError, subprocess.TimeoutExpired): return False, f"Command timed out after {SUBPROCESS_TIMEOUT} seconds" except FileNotFoundError: return False, "ssh/podman command not found" diff --git a/haproxy_mcp/server.py b/haproxy_mcp/server.py index d08cd53..38429a7 100644 --- a/haproxy_mcp/server.py +++ b/haproxy_mcp/server.py @@ -21,6 +21,7 @@ Environment Variables: from mcp.server.fastmcp import FastMCP from .config import MCP_HOST, MCP_PORT +from .db import init_db from .tools import register_all_tools from .tools.configuration import startup_restore @@ -32,5 +33,6 @@ register_all_tools(mcp) if __name__ == "__main__": + init_db() startup_restore() mcp.run(transport="streamable-http") diff --git a/haproxy_mcp/tools/certificates.py b/haproxy_mcp/tools/certificates.py index bbc327b..ec75f9d 100644 --- a/haproxy_mcp/tools/certificates.py +++ b/haproxy_mcp/tools/certificates.py @@ -1,6 +1,7 @@ """Certificate management tools for HAProxy MCP Server.""" import os +import subprocess from datetime import datetime from typing import Annotated @@ -152,7 +153,7 @@ def _haproxy_list_certs_impl() -> str: certs.append(f"• {domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}") return "\n\n".join(certs) if certs else "No certificates found" - except TimeoutError: + except (TimeoutError, subprocess.TimeoutExpired): return "Error: Command timed out" except FileNotFoundError: return "Error: acme.sh not found" @@ -203,7 +204,7 @@ def _haproxy_cert_info_impl(domain: str) -> str: result.stdout.strip() ] return "\n".join(info) - except TimeoutError: + except (TimeoutError, subprocess.TimeoutExpired): return "Error: Command timed out" except OSError as e: logger.error("Error getting certificate info for %s: %s", domain, e) @@ -250,7 +251,7 @@ def _haproxy_issue_cert_impl(domain: str, wildcard: bool) -> str: else: return f"Certificate issued but PEM file not created. Check {host_path}" - except TimeoutError: + except (TimeoutError, subprocess.TimeoutExpired): return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s" except OSError as e: logger.error("Error issuing certificate for %s: %s", domain, e) @@ -289,7 +290,7 @@ def _haproxy_renew_cert_impl(domain: str, force: bool) -> str: else: return f"Error renewing certificate:\n{output}" - except TimeoutError: + except (TimeoutError, subprocess.TimeoutExpired): return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s" except FileNotFoundError: return "Error: acme.sh not found" @@ -323,7 +324,7 @@ def _haproxy_renew_all_certs_impl() -> str: else: return "Renewal check completed" - except TimeoutError: + except (TimeoutError, subprocess.TimeoutExpired): return "Error: Renewal cron timed out" except FileNotFoundError: return "Error: acme.sh not found" diff --git a/haproxy_mcp/tools/configuration.py b/haproxy_mcp/tools/configuration.py index 8d532ea..7e2564d 100644 --- a/haproxy_mcp/tools/configuration.py +++ b/haproxy_mcp/tools/configuration.py @@ -1,5 +1,6 @@ """Configuration management tools for HAProxy MCP Server.""" +import subprocess import time from ..config import ( @@ -177,7 +178,7 @@ def register_config_tools(mcp): if result.returncode == 0: return "Configuration is valid" return f"Configuration errors:\n{result.stderr}" - except TimeoutError: + except (TimeoutError, subprocess.TimeoutExpired): return f"Error: Command timed out after {SUBPROCESS_TIMEOUT} seconds" except FileNotFoundError: return "Error: ssh/podman command not found" diff --git a/haproxy_mcp/tools/domains.py b/haproxy_mcp/tools/domains.py index 59b5256..3a10537 100644 --- a/haproxy_mcp/tools/domains.py +++ b/haproxy_mcp/tools/domains.py @@ -6,7 +6,6 @@ from typing import Annotated, Optional from pydantic import Field from ..config import ( - MAP_FILE, MAP_FILE_CONTAINER, WILDCARDS_MAP_FILE_CONTAINER, POOL_COUNT, @@ -22,7 +21,6 @@ from ..validation import validate_domain, validate_ip, validate_port_int from ..haproxy_client import haproxy_cmd from ..file_ops import ( get_map_contents, - save_map_file, get_domain_backend, is_legacy_backend, add_server_to_config, @@ -31,30 +29,13 @@ from ..file_ops import ( add_shared_domain_to_config, get_domains_sharing_pool, is_shared_domain, + add_domain_to_map, + remove_domain_from_map, + find_available_pool, ) from ..utils import parse_servers_state, disable_server_slot -def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]: - """Find an available pool backend from the pool list. - - Iterates through pool_1 to pool_N and returns the first pool - that is not currently in use. - - Args: - entries: List of (domain, backend) tuples from the map file. - used_pools: Set of pool names already in use. - - Returns: - Available pool name (e.g., "pool_5") or None if all pools are in use. - """ - for i in range(1, POOL_COUNT + 1): - pool_name = f"pool_{i}" - if pool_name not in used_pools: - return pool_name - return None - - def _check_subdomain(domain: str, registered_domains: set[str]) -> tuple[bool, Optional[str]]: """Check if a domain is a subdomain of an existing registered domain. @@ -95,24 +76,19 @@ def _update_haproxy_maps(domain: str, pool: str, is_subdomain: bool) -> None: haproxy_cmd(f"add map {WILDCARDS_MAP_FILE_CONTAINER} .{domain} {pool}") -def _rollback_domain_addition( - domain: str, - entries: list[tuple[str, str]] -) -> None: - """Rollback a failed domain addition by removing entries from map file. +def _rollback_domain_addition(domain: str) -> None: + """Rollback a failed domain addition by removing from SQLite + map files. - Called when HAProxy Runtime API update fails after the map file - has already been saved. + Called when HAProxy Runtime API update fails after the domain + has already been saved to the database. Args: domain: Domain name that was added. - entries: Current list of map entries to rollback from. """ - rollback_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"] try: - save_map_file(rollback_entries) - except IOError: - logger.error("Failed to rollback map file after HAProxy error") + remove_domain_from_map(domain) + except (IOError, Exception): + logger.error("Failed to rollback domain %s after HAProxy error", domain) def _file_exists(path: str) -> bool: @@ -242,93 +218,86 @@ def register_domain_tools(mcp): if share_with and ip: return "Error: Cannot specify both ip and share_with (shared domains use existing servers)" - # Use file locking for the entire pool allocation operation - from ..file_ops import file_lock - with file_lock(f"{MAP_FILE}.lock"): - # Read map contents once for both existence check and pool lookup - entries = get_map_contents() + # Read current entries for existence check and subdomain detection + entries = get_map_contents() - # Check if domain already exists (using cached entries) - for domain_entry, backend in entries: - if domain_entry == domain: - return f"Error: Domain {domain} already exists (mapped to {backend})" + # Check if domain already exists + for domain_entry, backend in entries: + if domain_entry == domain: + return f"Error: Domain {domain} already exists (mapped to {backend})" - # Build used pools and registered domains sets - used_pools: set[str] = set() - registered_domains: set[str] = set() - for entry_domain, backend in entries: - if backend.startswith("pool_"): - used_pools.add(backend) - if not entry_domain.startswith("."): - registered_domains.add(entry_domain) + # Build registered domains set for subdomain check + registered_domains: set[str] = set() + for entry_domain, _ in entries: + if not entry_domain.startswith("."): + registered_domains.add(entry_domain) - # Handle share_with: reuse existing domain's pool - if share_with: - share_backend = get_domain_backend(share_with) - if not share_backend: - return f"Error: Domain {share_with} not found" - if not share_backend.startswith("pool_"): - return f"Error: Cannot share with legacy backend {share_backend}" - pool = share_backend - else: - # Find available pool - pool = _find_available_pool(entries, used_pools) - if not pool: - return f"Error: All {POOL_COUNT} pool backends are in use" + # Handle share_with: reuse existing domain's pool + if share_with: + share_backend = get_domain_backend(share_with) + if not share_backend: + return f"Error: Domain {share_with} not found" + if not share_backend.startswith("pool_"): + return f"Error: Cannot share with legacy backend {share_backend}" + pool = share_backend + else: + # Find available pool (SQLite query, O(1)) + pool = find_available_pool() + if not pool: + return f"Error: All {POOL_COUNT} pool backends are in use" - # Check if this is a subdomain of an existing domain - is_subdomain, parent_domain = _check_subdomain(domain, registered_domains) + # Check if this is a subdomain of an existing domain + is_subdomain, parent_domain = _check_subdomain(domain, registered_domains) + try: + # Save to SQLite + sync map files (atomic via SQLite transaction) try: - # Save to disk first (atomic write for persistence) - entries.append((domain, pool)) + add_domain_to_map(domain, pool) if not is_subdomain: - entries.append((f".{domain}", pool)) - try: - save_map_file(entries) - except IOError as e: - return f"Error: Failed to save map file: {e}" - - # Update HAProxy maps via Runtime API - try: - _update_haproxy_maps(domain, pool, is_subdomain) - except HaproxyError as e: - _rollback_domain_addition(domain, entries) - return f"Error: Failed to update HAProxy map: {e}" - - # Handle server configuration based on mode - if share_with: - # Save shared domain reference - add_shared_domain_to_config(domain, share_with) - result = f"Domain {domain} added, sharing pool {pool} with {share_with}" - elif ip: - # Add server to slot 1 - add_server_to_config(domain, 1, ip, http_port) - try: - server = f"{pool}_1" - haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}") - haproxy_cmd(f"set server {pool}/{server} state ready") - except HaproxyError as e: - remove_server_from_config(domain, 1) - return f"Domain {domain} added to {pool} but server config failed: {e}" - result = f"Domain {domain} added to {pool} with server {ip}:{http_port}" - else: - result = f"Domain {domain} added to {pool} (no servers configured)" - - if is_subdomain: - result += f" (subdomain of {parent_domain}, no wildcard)" - - # Check certificate coverage - cert_covered, cert_info = check_certificate_coverage(domain) - if cert_covered: - result += f"\nSSL: Using certificate {cert_info}" - else: - result += f"\nSSL: No certificate found. Use haproxy_issue_cert(\"{domain}\") to issue one." - - return result + add_domain_to_map(f".{domain}", pool, is_wildcard=True) + except (IOError, Exception) as e: + return f"Error: Failed to save domain: {e}" + # Update HAProxy maps via Runtime API + try: + _update_haproxy_maps(domain, pool, is_subdomain) except HaproxyError as e: - return f"Error: {e}" + _rollback_domain_addition(domain) + return f"Error: Failed to update HAProxy map: {e}" + + # Handle server configuration based on mode + if share_with: + # Save shared domain reference + add_shared_domain_to_config(domain, share_with) + result = f"Domain {domain} added, sharing pool {pool} with {share_with}" + elif ip: + # Add server to slot 1 + add_server_to_config(domain, 1, ip, http_port) + try: + server = f"{pool}_1" + haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}") + haproxy_cmd(f"set server {pool}/{server} state ready") + except HaproxyError as e: + remove_server_from_config(domain, 1) + return f"Domain {domain} added to {pool} but server config failed: {e}" + result = f"Domain {domain} added to {pool} with server {ip}:{http_port}" + else: + result = f"Domain {domain} added to {pool} (no servers configured)" + + if is_subdomain: + result += f" (subdomain of {parent_domain}, no wildcard)" + + # Check certificate coverage + cert_covered, cert_info = check_certificate_coverage(domain) + if cert_covered: + result += f"\nSSL: Using certificate {cert_info}" + else: + result += f"\nSSL: No certificate found. Use haproxy_issue_cert(\"{domain}\") to issue one." + + return result + + except HaproxyError as e: + return f"Error: {e}" @mcp.tool() def haproxy_remove_domain( @@ -355,10 +324,8 @@ def register_domain_tools(mcp): domains_using_pool = get_domains_sharing_pool(backend) other_domains = [d for d in domains_using_pool if d != domain] - # Save to disk first (atomic write for persistence) - entries = get_map_contents() - new_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"] - save_map_file(new_entries) + # Remove from SQLite + sync map files + remove_domain_from_map(domain) # Remove from persistent server config remove_domain_from_config(domain) diff --git a/haproxy_mcp/tools/health.py b/haproxy_mcp/tools/health.py index 39208d9..fbea193 100644 --- a/haproxy_mcp/tools/health.py +++ b/haproxy_mcp/tools/health.py @@ -10,6 +10,7 @@ from pydantic import Field from ..config import ( MAP_FILE, SERVERS_FILE, + DB_FILE, HAPROXY_CONTAINER, ) from ..exceptions import HaproxyError @@ -88,7 +89,7 @@ def register_health_tools(mcp): # Check configuration files files_ok = True file_status: dict[str, str] = {} - for name, path in [("map_file", MAP_FILE), ("servers_file", SERVERS_FILE)]: + for name, path in [("map_file", MAP_FILE), ("db_file", DB_FILE)]: exists = remote_file_exists(path) if REMOTE_MODE else __import__('os').path.exists(path) if exists: file_status[name] = "ok" diff --git a/tests/conftest.py b/tests/conftest.py index 07752e5..96a288a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -255,6 +255,8 @@ def temp_config_dir(tmp_path): state_file = tmp_path / "servers.state" state_file.write_text("") + db_file = tmp_path / "haproxy_mcp.db" + return { "dir": tmp_path, "map_file": str(map_file), @@ -262,12 +264,15 @@ def temp_config_dir(tmp_path): "servers_file": str(servers_file), "certs_file": str(certs_file), "state_file": str(state_file), + "db_file": str(db_file), } @pytest.fixture def patch_config_paths(temp_config_dir): """Fixture that patches config module paths to use temporary directory.""" + from haproxy_mcp.db import close_connection, init_db + with patch.multiple( "haproxy_mcp.config", MAP_FILE=temp_config_dir["map_file"], @@ -275,16 +280,34 @@ def patch_config_paths(temp_config_dir): SERVERS_FILE=temp_config_dir["servers_file"], CERTS_FILE=temp_config_dir["certs_file"], STATE_FILE=temp_config_dir["state_file"], + DB_FILE=temp_config_dir["db_file"], ): # Also patch file_ops module which imports these with patch.multiple( "haproxy_mcp.file_ops", MAP_FILE=temp_config_dir["map_file"], WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"], - SERVERS_FILE=temp_config_dir["servers_file"], - CERTS_FILE=temp_config_dir["certs_file"], ): - yield temp_config_dir + # Patch db module which imports these + with patch.multiple( + "haproxy_mcp.db", + MAP_FILE=temp_config_dir["map_file"], + WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"], + SERVERS_FILE=temp_config_dir["servers_file"], + CERTS_FILE=temp_config_dir["certs_file"], + DB_FILE=temp_config_dir["db_file"], + ): + # Patch health module which imports MAP_FILE and DB_FILE + with patch.multiple( + "haproxy_mcp.tools.health", + MAP_FILE=temp_config_dir["map_file"], + DB_FILE=temp_config_dir["db_file"], + ): + # Close any existing connection and initialize fresh DB + close_connection() + init_db() + yield temp_config_dir + close_connection() @pytest.fixture diff --git a/tests/unit/test_db.py b/tests/unit/test_db.py new file mode 100644 index 0000000..f5b8741 --- /dev/null +++ b/tests/unit/test_db.py @@ -0,0 +1,433 @@ +"""Unit tests for db module (SQLite database operations).""" + +import json +import os +import sqlite3 +from unittest.mock import patch + +import pytest + +from haproxy_mcp.db import ( + get_connection, + close_connection, + init_db, + migrate_from_json, + db_get_map_contents, + db_get_domain_backend, + db_add_domain, + db_remove_domain, + db_find_available_pool, + db_get_domains_sharing_pool, + db_load_servers_config, + db_add_server, + db_remove_server, + db_remove_domain_servers, + db_add_shared_domain, + db_get_shared_domain, + db_is_shared_domain, + db_load_certs, + db_add_cert, + db_remove_cert, + sync_map_files, + SCHEMA_VERSION, +) + + +class TestConnectionManagement: + """Tests for database connection management.""" + + def test_get_connection(self, patch_config_paths): + """Get a connection returns a valid SQLite connection.""" + conn = get_connection() + assert conn is not None + # Verify WAL mode + result = conn.execute("PRAGMA journal_mode").fetchone() + assert result[0] == "wal" + + def test_connection_is_thread_local(self, patch_config_paths): + """Same thread gets same connection.""" + conn1 = get_connection() + conn2 = get_connection() + assert conn1 is conn2 + + def test_close_connection(self, patch_config_paths): + """Close connection clears thread-local.""" + conn1 = get_connection() + close_connection() + conn2 = get_connection() + assert conn1 is not conn2 + + +class TestInitDb: + """Tests for database initialization.""" + + def test_creates_tables(self, patch_config_paths): + """init_db creates all required tables.""" + conn = get_connection() + + # Check tables exist + tables = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" + ).fetchall() + table_names = [t["name"] for t in tables] + + assert "domains" in table_names + assert "servers" in table_names + assert "certificates" in table_names + assert "schema_version" in table_names + + def test_schema_version_recorded(self, patch_config_paths): + """Schema version is recorded.""" + conn = get_connection() + cur = conn.execute("SELECT MAX(version) FROM schema_version") + version = cur.fetchone()[0] + assert version == SCHEMA_VERSION + + def test_idempotent(self, patch_config_paths): + """Calling init_db twice is safe.""" + # init_db is already called by patch_config_paths + # Calling it again should not raise + init_db() + conn = get_connection() + cur = conn.execute("SELECT COUNT(*) FROM schema_version") + # May have 1 or 2 entries but should not fail + assert cur.fetchone()[0] >= 1 + + +class TestMigrateFromJson: + """Tests for JSON to SQLite migration.""" + + def test_migrate_map_files(self, patch_config_paths): + """Migrate domain entries from map files.""" + # Write test map files + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + f.write("api.example.com pool_2\n") + with open(patch_config_paths["wildcards_file"], "w") as f: + f.write(".example.com pool_1\n") + + migrate_from_json() + + entries = db_get_map_contents() + assert ("example.com", "pool_1") in entries + assert ("api.example.com", "pool_2") in entries + assert (".example.com", "pool_1") in entries + + def test_migrate_servers_json(self, patch_config_paths): + """Migrate server entries from servers.json.""" + config = { + "example.com": { + "1": {"ip": "10.0.0.1", "http_port": 80}, + "2": {"ip": "10.0.0.2", "http_port": 8080}, + } + } + with open(patch_config_paths["servers_file"], "w") as f: + json.dump(config, f) + + migrate_from_json() + + result = db_load_servers_config() + assert "example.com" in result + assert result["example.com"]["1"]["ip"] == "10.0.0.1" + assert result["example.com"]["2"]["http_port"] == 8080 + + def test_migrate_shared_domains(self, patch_config_paths): + """Migrate shared domain references.""" + # First add the domain to map so it exists in DB + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + f.write("www.example.com pool_1\n") + + config = { + "www.example.com": {"_shares": "example.com"}, + } + with open(patch_config_paths["servers_file"], "w") as f: + json.dump(config, f) + + migrate_from_json() + + assert db_get_shared_domain("www.example.com") == "example.com" + + def test_migrate_certificates(self, patch_config_paths): + """Migrate certificate entries.""" + with open(patch_config_paths["certs_file"], "w") as f: + json.dump({"domains": ["example.com", "api.example.com"]}, f) + + migrate_from_json() + + certs = db_load_certs() + assert "example.com" in certs + assert "api.example.com" in certs + + def test_migrate_idempotent(self, patch_config_paths): + """Migration is idempotent (INSERT OR IGNORE).""" + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + with open(patch_config_paths["servers_file"], "w") as f: + json.dump({"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}, f) + + migrate_from_json() + migrate_from_json() # Should not fail + + entries = db_get_map_contents() + assert len([d for d, _ in entries if d == "example.com"]) == 1 + + def test_migrate_empty_files(self, patch_config_paths): + """Migration with no existing data does nothing.""" + os.unlink(patch_config_paths["map_file"]) + os.unlink(patch_config_paths["wildcards_file"]) + os.unlink(patch_config_paths["servers_file"]) + os.unlink(patch_config_paths["certs_file"]) + + migrate_from_json() # Should not fail + + def test_backup_files_after_migration(self, patch_config_paths): + """Original JSON files are backed up after migration.""" + with open(patch_config_paths["servers_file"], "w") as f: + json.dump({"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}, f) + with open(patch_config_paths["certs_file"], "w") as f: + json.dump({"domains": ["example.com"]}, f) + + # Write map file so migration has data + with open(patch_config_paths["map_file"], "w") as f: + f.write("example.com pool_1\n") + + migrate_from_json() + + assert os.path.exists(f"{patch_config_paths['servers_file']}.bak") + assert os.path.exists(f"{patch_config_paths['certs_file']}.bak") + + +class TestDomainOperations: + """Tests for domain data access functions.""" + + def test_add_and_get_domain(self, patch_config_paths): + """Add domain and retrieve its backend.""" + db_add_domain("example.com", "pool_1") + + assert db_get_domain_backend("example.com") == "pool_1" + + def test_get_nonexistent_domain(self, patch_config_paths): + """Non-existent domain returns None.""" + assert db_get_domain_backend("nonexistent.com") is None + + def test_remove_domain(self, patch_config_paths): + """Remove domain removes both exact and wildcard.""" + db_add_domain("example.com", "pool_1") + db_add_domain(".example.com", "pool_1", is_wildcard=True) + + db_remove_domain("example.com") + + assert db_get_domain_backend("example.com") is None + entries = db_get_map_contents() + assert (".example.com", "pool_1") not in entries + + def test_find_available_pool(self, patch_config_paths): + """Find first available pool.""" + db_add_domain("a.com", "pool_1") + db_add_domain("b.com", "pool_2") + + pool = db_find_available_pool() + assert pool == "pool_3" + + def test_find_available_pool_empty(self, patch_config_paths): + """First pool available when none used.""" + pool = db_find_available_pool() + assert pool == "pool_1" + + def test_get_domains_sharing_pool(self, patch_config_paths): + """Get non-wildcard domains using a pool.""" + db_add_domain("example.com", "pool_1") + db_add_domain("www.example.com", "pool_1") + db_add_domain(".example.com", "pool_1", is_wildcard=True) + + domains = db_get_domains_sharing_pool("pool_1") + assert "example.com" in domains + assert "www.example.com" in domains + assert ".example.com" not in domains + + def test_update_domain_backend(self, patch_config_paths): + """Updating a domain changes its backend.""" + db_add_domain("example.com", "pool_1") + db_add_domain("example.com", "pool_5") # Update + + assert db_get_domain_backend("example.com") == "pool_5" + + +class TestServerOperations: + """Tests for server data access functions.""" + + def test_add_and_load_server(self, patch_config_paths): + """Add server and load config.""" + db_add_server("example.com", 1, "10.0.0.1", 80) + + config = db_load_servers_config() + assert config["example.com"]["1"]["ip"] == "10.0.0.1" + assert config["example.com"]["1"]["http_port"] == 80 + + def test_update_server(self, patch_config_paths): + """Update existing server slot.""" + db_add_server("example.com", 1, "10.0.0.1", 80) + db_add_server("example.com", 1, "10.0.0.99", 8080) + + config = db_load_servers_config() + assert config["example.com"]["1"]["ip"] == "10.0.0.99" + assert config["example.com"]["1"]["http_port"] == 8080 + + def test_remove_server(self, patch_config_paths): + """Remove a server slot.""" + db_add_server("example.com", 1, "10.0.0.1", 80) + db_add_server("example.com", 2, "10.0.0.2", 80) + + db_remove_server("example.com", 1) + + config = db_load_servers_config() + assert "1" not in config.get("example.com", {}) + assert "2" in config["example.com"] + + def test_remove_domain_servers(self, patch_config_paths): + """Remove all servers for a domain.""" + db_add_server("example.com", 1, "10.0.0.1", 80) + db_add_server("example.com", 2, "10.0.0.2", 80) + db_add_server("other.com", 1, "10.0.0.3", 80) + + db_remove_domain_servers("example.com") + + config = db_load_servers_config() + assert config.get("example.com", {}).get("1") is None + assert "other.com" in config + + def test_load_empty(self, patch_config_paths): + """Empty database returns empty dict.""" + config = db_load_servers_config() + assert config == {} + + +class TestSharedDomainOperations: + """Tests for shared domain functions.""" + + def test_add_and_get_shared(self, patch_config_paths): + """Add shared domain reference.""" + db_add_domain("example.com", "pool_1") + db_add_domain("www.example.com", "pool_1") + db_add_shared_domain("www.example.com", "example.com") + + assert db_get_shared_domain("www.example.com") == "example.com" + + def test_is_shared(self, patch_config_paths): + """Check if domain is shared.""" + db_add_domain("example.com", "pool_1") + db_add_domain("www.example.com", "pool_1") + db_add_shared_domain("www.example.com", "example.com") + + assert db_is_shared_domain("www.example.com") is True + assert db_is_shared_domain("example.com") is False + + def test_not_shared(self, patch_config_paths): + """Non-shared domain returns None.""" + db_add_domain("example.com", "pool_1") + + assert db_get_shared_domain("example.com") is None + assert db_is_shared_domain("example.com") is False + + def test_shared_in_load_config(self, patch_config_paths): + """Shared domain appears in load_servers_config.""" + db_add_domain("example.com", "pool_1") + db_add_domain("www.example.com", "pool_1") + db_add_shared_domain("www.example.com", "example.com") + + config = db_load_servers_config() + assert config["www.example.com"]["_shares"] == "example.com" + + +class TestCertificateOperations: + """Tests for certificate data access functions.""" + + def test_add_and_load_cert(self, patch_config_paths): + """Add and load certificate.""" + db_add_cert("example.com") + + certs = db_load_certs() + assert "example.com" in certs + + def test_add_duplicate(self, patch_config_paths): + """Adding duplicate cert is a no-op.""" + db_add_cert("example.com") + db_add_cert("example.com") + + certs = db_load_certs() + assert certs.count("example.com") == 1 + + def test_remove_cert(self, patch_config_paths): + """Remove a certificate.""" + db_add_cert("example.com") + db_add_cert("other.com") + + db_remove_cert("example.com") + + certs = db_load_certs() + assert "example.com" not in certs + assert "other.com" in certs + + def test_load_empty(self, patch_config_paths): + """Empty database returns empty list.""" + certs = db_load_certs() + assert certs == [] + + def test_sorted_output(self, patch_config_paths): + """Certificates are returned sorted.""" + db_add_cert("z.com") + db_add_cert("a.com") + db_add_cert("m.com") + + certs = db_load_certs() + assert certs == ["a.com", "m.com", "z.com"] + + +class TestSyncMapFiles: + """Tests for sync_map_files function.""" + + def test_sync_exact_domains(self, patch_config_paths): + """Sync writes exact domains to domains.map.""" + db_add_domain("example.com", "pool_1") + db_add_domain("api.example.com", "pool_2") + + sync_map_files() + + with open(patch_config_paths["map_file"]) as f: + content = f.read() + assert "example.com pool_1" in content + assert "api.example.com pool_2" in content + + def test_sync_wildcards(self, patch_config_paths): + """Sync writes wildcards to wildcards.map.""" + db_add_domain(".example.com", "pool_1", is_wildcard=True) + + sync_map_files() + + with open(patch_config_paths["wildcards_file"]) as f: + content = f.read() + assert ".example.com pool_1" in content + + def test_sync_empty(self, patch_config_paths): + """Sync with no domains writes headers only.""" + sync_map_files() + + with open(patch_config_paths["map_file"]) as f: + content = f.read() + assert "Exact Domain" in content + # No domain entries + lines = [l.strip() for l in content.splitlines() if l.strip() and not l.startswith("#")] + assert len(lines) == 0 + + def test_sync_sorted(self, patch_config_paths): + """Sync output is sorted.""" + db_add_domain("z.com", "pool_3") + db_add_domain("a.com", "pool_1") + + sync_map_files() + + with open(patch_config_paths["map_file"]) as f: + lines = [l.strip() for l in f if l.strip() and not l.startswith("#")] + assert lines[0] == "a.com pool_1" + assert lines[1] == "z.com pool_3" diff --git a/tests/unit/test_file_ops.py b/tests/unit/test_file_ops.py index 49e4ccd..af2ae7e 100644 --- a/tests/unit/test_file_ops.py +++ b/tests/unit/test_file_ops.py @@ -1,6 +1,5 @@ -"""Unit tests for file_ops module.""" +"""Unit tests for file_ops module (SQLite-backed).""" -import json import os from unittest.mock import patch @@ -16,14 +15,19 @@ from haproxy_mcp.file_ops import ( get_legacy_backend_name, get_backend_and_prefix, load_servers_config, - save_servers_config, add_server_to_config, remove_server_from_config, remove_domain_from_config, load_certs_config, - save_certs_config, add_cert_to_config, remove_cert_from_config, + add_domain_to_map, + remove_domain_from_map, + find_available_pool, + add_shared_domain_to_config, + get_shared_domain, + is_shared_domain, + get_domains_sharing_pool, ) @@ -62,7 +66,7 @@ class TestAtomicWriteFile: def test_unicode_content(self, tmp_path): """Unicode content is properly written.""" file_path = str(tmp_path / "unicode.txt") - content = "Hello, \u4e16\u754c!" # "Hello, World!" in Chinese + content = "Hello, \u4e16\u754c!" atomic_write_file(file_path, content) @@ -81,66 +85,33 @@ class TestAtomicWriteFile: class TestGetMapContents: - """Tests for get_map_contents function.""" + """Tests for get_map_contents function (SQLite-backed).""" - def test_read_map_file(self, patch_config_paths): - """Read entries from map file.""" - # Write test content to map file - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") - f.write("api.example.com pool_2\n") + def test_empty_db(self, patch_config_paths): + """Empty database returns empty list.""" + entries = get_map_contents() + assert entries == [] + + def test_read_domains(self, patch_config_paths): + """Read entries from database.""" + add_domain_to_map("example.com", "pool_1") + add_domain_to_map("api.example.com", "pool_2") entries = get_map_contents() assert ("example.com", "pool_1") in entries assert ("api.example.com", "pool_2") in entries - def test_read_both_map_files(self, patch_config_paths): - """Read entries from both domains.map and wildcards.map.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") - - with open(patch_config_paths["wildcards_file"], "w") as f: - f.write(".example.com pool_1\n") + def test_read_with_wildcards(self, patch_config_paths): + """Read entries including wildcards.""" + add_domain_to_map("example.com", "pool_1") + add_domain_to_map(".example.com", "pool_1", is_wildcard=True) entries = get_map_contents() assert ("example.com", "pool_1") in entries assert (".example.com", "pool_1") in entries - def test_skip_comments(self, patch_config_paths): - """Comments are skipped.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("# This is a comment\n") - f.write("example.com pool_1\n") - f.write("# Another comment\n") - - entries = get_map_contents() - - assert len(entries) == 1 - assert entries[0] == ("example.com", "pool_1") - - def test_skip_empty_lines(self, patch_config_paths): - """Empty lines are skipped.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("\n") - f.write("example.com pool_1\n") - f.write("\n") - f.write("api.example.com pool_2\n") - - entries = get_map_contents() - - assert len(entries) == 2 - - def test_file_not_found(self, patch_config_paths): - """Missing file returns empty list.""" - os.unlink(patch_config_paths["map_file"]) - os.unlink(patch_config_paths["wildcards_file"]) - - entries = get_map_contents() - - assert entries == [] - class TestSplitDomainEntries: """Tests for split_domain_entries function.""" @@ -182,36 +153,30 @@ class TestSplitDomainEntries: class TestSaveMapFile: - """Tests for save_map_file function.""" + """Tests for save_map_file function (syncs from DB to map files).""" def test_save_entries(self, patch_config_paths): """Save entries to separate map files.""" - entries = [ - ("example.com", "pool_1"), - (".example.com", "pool_1"), - ] + add_domain_to_map("example.com", "pool_1") + add_domain_to_map(".example.com", "pool_1", is_wildcard=True) - save_map_file(entries) + save_map_file([]) # Entries param ignored, syncs from DB - # Check exact domains file with open(patch_config_paths["map_file"]) as f: content = f.read() assert "example.com pool_1" in content - # Check wildcards file with open(patch_config_paths["wildcards_file"]) as f: content = f.read() assert ".example.com pool_1" in content def test_sorted_output(self, patch_config_paths): """Entries are sorted in output.""" - entries = [ - ("z.example.com", "pool_3"), - ("a.example.com", "pool_1"), - ("m.example.com", "pool_2"), - ] + add_domain_to_map("z.example.com", "pool_3") + add_domain_to_map("a.example.com", "pool_1") + add_domain_to_map("m.example.com", "pool_2") - save_map_file(entries) + save_map_file([]) with open(patch_config_paths["map_file"]) as f: lines = [l.strip() for l in f if l.strip() and not l.startswith("#")] @@ -222,12 +187,11 @@ class TestSaveMapFile: class TestGetDomainBackend: - """Tests for get_domain_backend function.""" + """Tests for get_domain_backend function (SQLite-backed).""" def test_find_existing_domain(self, patch_config_paths): """Find backend for existing domain.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") backend = get_domain_backend("example.com") @@ -235,8 +199,7 @@ class TestGetDomainBackend: def test_domain_not_found(self, patch_config_paths): """Non-existent domain returns None.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") backend = get_domain_backend("other.com") @@ -271,8 +234,7 @@ class TestGetBackendAndPrefix: def test_pool_backend(self, patch_config_paths): """Pool backend returns pool-based prefix.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_5\n") + add_domain_to_map("example.com", "pool_5") backend, prefix = get_backend_and_prefix("example.com") @@ -288,48 +250,33 @@ class TestGetBackendAndPrefix: class TestLoadServersConfig: - """Tests for load_servers_config function.""" + """Tests for load_servers_config function (SQLite-backed).""" - def test_load_existing_config(self, patch_config_paths, sample_servers_config): - """Load existing config file.""" - with open(patch_config_paths["servers_file"], "w") as f: - json.dump(sample_servers_config, f) + def test_load_empty_config(self, patch_config_paths): + """Empty database returns empty dict.""" + config = load_servers_config() + assert config == {} + + def test_load_with_servers(self, patch_config_paths): + """Load config with server entries.""" + add_server_to_config("example.com", 1, "10.0.0.1", 80) + add_server_to_config("example.com", 2, "10.0.0.2", 80) config = load_servers_config() assert "example.com" in config assert config["example.com"]["1"]["ip"] == "10.0.0.1" + assert config["example.com"]["2"]["ip"] == "10.0.0.2" - def test_file_not_found(self, patch_config_paths): - """Missing file returns empty dict.""" - os.unlink(patch_config_paths["servers_file"]) + def test_load_with_shared_domain(self, patch_config_paths): + """Load config with shared domain reference.""" + add_domain_to_map("example.com", "pool_1") + add_domain_to_map("www.example.com", "pool_1") + add_shared_domain_to_config("www.example.com", "example.com") config = load_servers_config() - assert config == {} - - def test_invalid_json(self, patch_config_paths): - """Invalid JSON returns empty dict.""" - with open(patch_config_paths["servers_file"], "w") as f: - f.write("not valid json {{{") - - config = load_servers_config() - - assert config == {} - - -class TestSaveServersConfig: - """Tests for save_servers_config function.""" - - def test_save_config(self, patch_config_paths): - """Save config to file.""" - config = {"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}} - - save_servers_config(config) - - with open(patch_config_paths["servers_file"]) as f: - loaded = json.load(f) - assert loaded == config + assert config["www.example.com"]["_shares"] == "example.com" class TestAddServerToConfig: @@ -373,17 +320,18 @@ class TestRemoveServerFromConfig: remove_server_from_config("example.com", 1) config = load_servers_config() - assert "1" not in config["example.com"] + assert "1" not in config.get("example.com", {}) assert "2" in config["example.com"] def test_remove_last_server_removes_domain(self, patch_config_paths): - """Removing last server removes domain entry.""" + """Removing last server removes domain entry from servers.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) remove_server_from_config("example.com", 1) config = load_servers_config() - assert "example.com" not in config + # Domain may or may not exist (no servers = no entry) + assert config.get("example.com", {}).get("1") is None def test_remove_nonexistent_server(self, patch_config_paths): """Removing non-existent server is a no-op.""" @@ -399,14 +347,14 @@ class TestRemoveDomainFromConfig: """Tests for remove_domain_from_config function.""" def test_remove_existing_domain(self, patch_config_paths): - """Remove existing domain.""" + """Remove existing domain's servers.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) add_server_to_config("other.com", 1, "10.0.0.2", 80) remove_domain_from_config("example.com") config = load_servers_config() - assert "example.com" not in config + assert config.get("example.com", {}).get("1") is None assert "other.com" in config def test_remove_nonexistent_domain(self, patch_config_paths): @@ -420,40 +368,23 @@ class TestRemoveDomainFromConfig: class TestLoadCertsConfig: - """Tests for load_certs_config function.""" + """Tests for load_certs_config function (SQLite-backed).""" - def test_load_existing_config(self, patch_config_paths): - """Load existing certs config.""" - with open(patch_config_paths["certs_file"], "w") as f: - json.dump({"domains": ["example.com", "other.com"]}, f) + def test_load_empty(self, patch_config_paths): + """Empty database returns empty list.""" + domains = load_certs_config() + assert domains == [] + + def test_load_with_certs(self, patch_config_paths): + """Load certs from database.""" + add_cert_to_config("example.com") + add_cert_to_config("other.com") domains = load_certs_config() assert "example.com" in domains assert "other.com" in domains - def test_file_not_found(self, patch_config_paths): - """Missing file returns empty list.""" - os.unlink(patch_config_paths["certs_file"]) - - domains = load_certs_config() - - assert domains == [] - - -class TestSaveCertsConfig: - """Tests for save_certs_config function.""" - - def test_save_domains(self, patch_config_paths): - """Save domains to certs config.""" - save_certs_config(["z.com", "a.com"]) - - with open(patch_config_paths["certs_file"]) as f: - data = json.load(f) - - # Should be sorted - assert data["domains"] == ["a.com", "z.com"] - class TestAddCertToConfig: """Tests for add_cert_to_config function.""" @@ -496,3 +427,87 @@ class TestRemoveCertFromConfig: domains = load_certs_config() assert "example.com" in domains + + +class TestAddDomainToMap: + """Tests for add_domain_to_map function.""" + + def test_add_domain(self, patch_config_paths): + """Add a domain and verify map files are synced.""" + add_domain_to_map("example.com", "pool_1") + + assert get_domain_backend("example.com") == "pool_1" + + with open(patch_config_paths["map_file"]) as f: + assert "example.com pool_1" in f.read() + + def test_add_wildcard(self, patch_config_paths): + """Add a wildcard domain.""" + add_domain_to_map(".example.com", "pool_1", is_wildcard=True) + + entries = get_map_contents() + assert (".example.com", "pool_1") in entries + + +class TestRemoveDomainFromMap: + """Tests for remove_domain_from_map function.""" + + def test_remove_domain(self, patch_config_paths): + """Remove a domain and its wildcard.""" + add_domain_to_map("example.com", "pool_1") + add_domain_to_map(".example.com", "pool_1", is_wildcard=True) + + remove_domain_from_map("example.com") + + assert get_domain_backend("example.com") is None + entries = get_map_contents() + assert (".example.com", "pool_1") not in entries + + +class TestFindAvailablePool: + """Tests for find_available_pool function.""" + + def test_first_pool_available(self, patch_config_paths): + """When no domains exist, pool_1 is returned.""" + pool = find_available_pool() + assert pool == "pool_1" + + def test_skip_used_pools(self, patch_config_paths): + """Used pools are skipped.""" + add_domain_to_map("example.com", "pool_1") + add_domain_to_map("other.com", "pool_2") + + pool = find_available_pool() + assert pool == "pool_3" + + +class TestSharedDomains: + """Tests for shared domain functions.""" + + def test_get_shared_domain(self, patch_config_paths): + """Get parent domain for shared domain.""" + add_domain_to_map("example.com", "pool_1") + add_domain_to_map("www.example.com", "pool_1") + add_shared_domain_to_config("www.example.com", "example.com") + + assert get_shared_domain("www.example.com") == "example.com" + + def test_is_shared_domain(self, patch_config_paths): + """Check if domain is shared.""" + add_domain_to_map("example.com", "pool_1") + add_domain_to_map("www.example.com", "pool_1") + add_shared_domain_to_config("www.example.com", "example.com") + + assert is_shared_domain("www.example.com") is True + assert is_shared_domain("example.com") is False + + def test_get_domains_sharing_pool(self, patch_config_paths): + """Get all domains using a pool.""" + add_domain_to_map("example.com", "pool_1") + add_domain_to_map("www.example.com", "pool_1") + add_domain_to_map(".example.com", "pool_1", is_wildcard=True) + + domains = get_domains_sharing_pool("pool_1") + assert "example.com" in domains + assert "www.example.com" in domains + assert ".example.com" not in domains # Wildcards excluded diff --git a/tests/unit/tools/test_certificates.py b/tests/unit/tools/test_certificates.py index 22d89a4..4ff2c72 100644 --- a/tests/unit/tools/test_certificates.py +++ b/tests/unit/tools/test_certificates.py @@ -6,6 +6,8 @@ from unittest.mock import patch, MagicMock import pytest +from haproxy_mcp.file_ops import add_cert_to_config + class TestGetPemPaths: """Tests for get_pem_paths function.""" @@ -127,8 +129,7 @@ class TestRestoreCertificates: def test_restore_certificates_success(self, patch_config_paths, tmp_path, mock_socket_class, mock_select): """Restore certificates successfully.""" # Save config - with open(patch_config_paths["certs_file"], "w") as f: - json.dump({"domains": ["example.com"]}, f) + add_cert_to_config("example.com") # Create PEM certs_dir = tmp_path / "certs" @@ -283,11 +284,17 @@ class TestHaproxyCertInfo: pem_file = tmp_path / "example.com.pem" pem_file.write_text("cert content") - mock_subprocess.return_value = MagicMock( - returncode=0, - stdout="subject=CN = example.com\nissuer=CN = Google Trust Services\nnotBefore=Jan 1 00:00:00 2024 GMT\nnotAfter=Apr 1 00:00:00 2024 GMT", - stderr="" - ) + def subprocess_side_effect(*args, **kwargs): + cmd = args[0] if args else kwargs.get("args", []) + if isinstance(cmd, list) and "stat" in cmd: + return MagicMock(returncode=0, stdout="1704067200", stderr="") + return MagicMock( + returncode=0, + stdout="subject=CN = example.com\nissuer=CN = Google Trust Services\nnotBefore=Jan 1 00:00:00 2024 GMT\nnotAfter=Apr 1 00:00:00 2024 GMT", + stderr="" + ) + + mock_subprocess.side_effect = subprocess_side_effect mock_sock = mock_socket_class(responses={ "show ssl cert": "/etc/haproxy/certs/example.com.pem", @@ -337,25 +344,33 @@ class TestHaproxyIssueCert: assert "Error" in result assert "Invalid domain" in result - def test_issue_cert_no_cf_token(self, tmp_path): + def test_issue_cert_no_cf_token(self, tmp_path, mock_subprocess): """Fail when CF_Token is not set.""" + acme_sh = str(tmp_path / "acme.sh") + mock_subprocess.return_value = MagicMock( + returncode=1, + stdout="", + stderr="CF_Token is not set. Please export CF_Token environment variable.", + ) + with patch.dict(os.environ, {}, clear=True): with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)): - with patch("os.path.exists", return_value=False): - from haproxy_mcp.tools.certificates import register_certificate_tools - mcp = MagicMock() - registered_tools = {} + with patch("haproxy_mcp.tools.certificates.ACME_SH", acme_sh): + with patch("os.path.exists", return_value=False): + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} - def capture_tool(): - def decorator(func): - registered_tools[func.__name__] = func - return func - return decorator + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator - mcp.tool = capture_tool - register_certificate_tools(mcp) + mcp.tool = capture_tool + register_certificate_tools(mcp) - result = registered_tools["haproxy_issue_cert"](domain="example.com", wildcard=True) + result = registered_tools["haproxy_issue_cert"](domain="example.com", wildcard=True) assert "CF_Token" in result @@ -845,8 +860,8 @@ class TestHaproxyRenewAllCertsMultiple: def test_renew_all_certs_multiple_renewals(self, mock_subprocess, mock_socket_class, mock_select, patch_config_paths, tmp_path): """Renew multiple certificates successfully.""" # Write config with multiple domains - with open(patch_config_paths["certs_file"], "w") as f: - json.dump({"domains": ["example.com", "example.org"]}, f) + add_cert_to_config("example.com") + add_cert_to_config("example.org") # Create PEM files certs_dir = tmp_path / "certs" @@ -1038,30 +1053,32 @@ class TestHaproxyDeleteCertPartialFailure: "show ssl cert": "", # Not loaded }) - # Mock os.remove to fail - def mock_remove(path): - if "example.com.pem" in str(path): - raise PermissionError("Permission denied") - raise FileNotFoundError() + # Mock subprocess to succeed for acme.sh remove but fail for rm (PEM removal) + def subprocess_side_effect(*args, **kwargs): + cmd = args[0] if args else kwargs.get("args", []) + if isinstance(cmd, list) and cmd[0] == "rm": + return MagicMock(returncode=1, stdout="", stderr="Permission denied") + return MagicMock(returncode=0, stdout="", stderr="") + + mock_subprocess.side_effect = subprocess_side_effect with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path / "acme")): with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): with patch("socket.socket", return_value=mock_sock): - with patch("os.remove", side_effect=mock_remove): - from haproxy_mcp.tools.certificates import register_certificate_tools - mcp = MagicMock() - registered_tools = {} + from haproxy_mcp.tools.certificates import register_certificate_tools + mcp = MagicMock() + registered_tools = {} - def capture_tool(): - def decorator(func): - registered_tools[func.__name__] = func - return func - return decorator + def capture_tool(): + def decorator(func): + registered_tools[func.__name__] = func + return func + return decorator - mcp.tool = capture_tool - register_certificate_tools(mcp) + mcp.tool = capture_tool + register_certificate_tools(mcp) - result = registered_tools["haproxy_delete_cert"](domain="example.com") + result = registered_tools["haproxy_delete_cert"](domain="example.com") # Should report partial success (acme.sh deleted) and error (PEM failed) assert "Deleted" in result @@ -1118,8 +1135,8 @@ class TestRestoreCertificatesFailure: def test_restore_certificates_partial_failure(self, patch_config_paths, tmp_path, mock_socket_class, mock_select): """Handle partial failure when restoring certificates.""" # Save config with multiple domains - with open(patch_config_paths["certs_file"], "w") as f: - json.dump({"domains": ["example.com", "missing.com"]}, f) + add_cert_to_config("example.com") + add_cert_to_config("missing.com") # Create only one PEM file certs_dir = tmp_path / "certs" diff --git a/tests/unit/tools/test_configuration.py b/tests/unit/tools/test_configuration.py index e9997e7..4b461c3 100644 --- a/tests/unit/tools/test_configuration.py +++ b/tests/unit/tools/test_configuration.py @@ -5,6 +5,8 @@ from unittest.mock import patch, MagicMock import pytest +from haproxy_mcp.file_ops import add_domain_to_map, add_server_to_config + class TestRestoreServersFromConfig: """Tests for restore_servers_from_config function.""" @@ -19,12 +21,12 @@ class TestRestoreServersFromConfig: def test_restore_servers_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config): """Restore servers successfully.""" - # Write config and map - with open(patch_config_paths["servers_file"], "w") as f: - json.dump(sample_servers_config, f) - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") - f.write("api.example.com pool_2\n") + # Add domains and servers to database + add_domain_to_map("example.com", "pool_1") + add_server_to_config("example.com", 1, "10.0.0.1", 80) + add_server_to_config("example.com", 2, "10.0.0.2", 80) + add_domain_to_map("api.example.com", "pool_2") + add_server_to_config("api.example.com", 1, "10.0.0.10", 8080) mock_sock = mock_socket_class(responses={ "set server": "", @@ -40,9 +42,8 @@ class TestRestoreServersFromConfig: def test_restore_servers_skip_missing_domain(self, mock_socket_class, mock_select, patch_config_paths): """Skip domains not in map file.""" - config = {"unknown.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}} - with open(patch_config_paths["servers_file"], "w") as f: - json.dump(config, f) + # Add server for unknown.com but no map entry (simulates missing domain) + add_server_to_config("unknown.com", 1, "10.0.0.1", 80) mock_sock = mock_socket_class(responses={"set server": ""}) @@ -55,11 +56,9 @@ class TestRestoreServersFromConfig: def test_restore_servers_skip_empty_ip(self, mock_socket_class, mock_select, patch_config_paths): """Skip servers with empty IP.""" - config = {"example.com": {"1": {"ip": "", "http_port": 80}}} - with open(patch_config_paths["servers_file"], "w") as f: - json.dump(config, f) - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + # Add domain to map and server with empty IP (will be skipped during restore) + add_domain_to_map("example.com", "pool_1") + add_server_to_config("example.com", 1, "", 80) mock_sock = mock_socket_class(responses={"set server": ""}) @@ -321,11 +320,12 @@ class TestHaproxyRestoreState: def test_restore_state_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config): """Restore state successfully.""" - with open(patch_config_paths["servers_file"], "w") as f: - json.dump(sample_servers_config, f) - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") - f.write("api.example.com pool_2\n") + # Add domains and servers to database + add_domain_to_map("example.com", "pool_1") + add_server_to_config("example.com", 1, "10.0.0.1", 80) + add_server_to_config("example.com", 2, "10.0.0.2", 80) + add_domain_to_map("api.example.com", "pool_2") + add_server_to_config("api.example.com", 1, "10.0.0.10", 8080) mock_sock = mock_socket_class(responses={"set server": ""}) @@ -373,17 +373,10 @@ class TestRestoreServersFromConfigBatchFailure: def test_restore_servers_batch_failure_fallback(self, mock_socket_class, mock_select, patch_config_paths): """Fall back to individual commands when batch fails.""" - # Create config with servers - config = { - "example.com": { - "1": {"ip": "10.0.0.1", "http_port": 80}, - "2": {"ip": "10.0.0.2", "http_port": 80}, - } - } - with open(patch_config_paths["servers_file"], "w") as f: - json.dump(config, f) - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + # Add domain and servers to database + add_domain_to_map("example.com", "pool_1") + add_server_to_config("example.com", 1, "10.0.0.1", 80) + add_server_to_config("example.com", 2, "10.0.0.2", 80) # Track call count to simulate batch failure then individual success call_count = [0] @@ -457,51 +450,51 @@ class TestRestoreServersFromConfigBatchFailure: def test_restore_servers_invalid_slot(self, mock_socket_class, mock_select, patch_config_paths): """Skip servers with invalid slot number.""" + # Add domain to map + add_domain_to_map("example.com", "pool_1") + + # Mock load_servers_config to return config with invalid slot config = { "example.com": { "invalid": {"ip": "10.0.0.1", "http_port": 80}, # Invalid slot "1": {"ip": "10.0.0.2", "http_port": 80}, # Valid slot } } - with open(patch_config_paths["servers_file"], "w") as f: - json.dump(config, f) - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + with patch("haproxy_mcp.tools.configuration.load_servers_config", return_value=config): + mock_sock = mock_socket_class(responses={"set server": ""}) - mock_sock = mock_socket_class(responses={"set server": ""}) + with patch("socket.socket", return_value=mock_sock): + from haproxy_mcp.tools.configuration import restore_servers_from_config - with patch("socket.socket", return_value=mock_sock): - from haproxy_mcp.tools.configuration import restore_servers_from_config + result = restore_servers_from_config() - result = restore_servers_from_config() - - # Should only restore the valid server - assert result == 1 + # Should only restore the valid server + assert result == 1 def test_restore_servers_invalid_port(self, mock_socket_class, mock_select, patch_config_paths, caplog): """Skip servers with invalid port.""" import logging + # Add domain to map + add_domain_to_map("example.com", "pool_1") + + # Mock load_servers_config to return config with invalid port config = { "example.com": { "1": {"ip": "10.0.0.1", "http_port": "invalid"}, # Invalid port "2": {"ip": "10.0.0.2", "http_port": 80}, # Valid port } } - with open(patch_config_paths["servers_file"], "w") as f: - json.dump(config, f) - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + with patch("haproxy_mcp.tools.configuration.load_servers_config", return_value=config): + mock_sock = mock_socket_class(responses={"set server": ""}) - mock_sock = mock_socket_class(responses={"set server": ""}) + with patch("socket.socket", return_value=mock_sock): + with caplog.at_level(logging.WARNING, logger="haproxy_mcp"): + from haproxy_mcp.tools.configuration import restore_servers_from_config - with patch("socket.socket", return_value=mock_sock): - with caplog.at_level(logging.WARNING, logger="haproxy_mcp"): - from haproxy_mcp.tools.configuration import restore_servers_from_config + result = restore_servers_from_config() - result = restore_servers_from_config() - - # Should only restore the valid server - assert result == 1 + # Should only restore the valid server + assert result == 1 class TestStartupRestoreFailures: @@ -658,11 +651,12 @@ class TestHaproxyRestoreStateFailures: """Handle HAProxy error when restoring state.""" from haproxy_mcp.exceptions import HaproxyError - with open(patch_config_paths["servers_file"], "w") as f: - json.dump(sample_servers_config, f) - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") - f.write("api.example.com pool_2\n") + # Add domains and servers to database + add_domain_to_map("example.com", "pool_1") + add_server_to_config("example.com", 1, "10.0.0.1", 80) + add_server_to_config("example.com", 2, "10.0.0.2", 80) + add_domain_to_map("api.example.com", "pool_2") + add_server_to_config("api.example.com", 1, "10.0.0.10", 8080) with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=HaproxyError("Connection refused")): from haproxy_mcp.tools.configuration import register_config_tools diff --git a/tests/unit/tools/test_domains.py b/tests/unit/tools/test_domains.py index 6c03f07..e8e396f 100644 --- a/tests/unit/tools/test_domains.py +++ b/tests/unit/tools/test_domains.py @@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock import pytest from haproxy_mcp.exceptions import HaproxyError +from haproxy_mcp.file_ops import add_domain_to_map class TestHaproxyListDomains: @@ -38,9 +39,8 @@ class TestHaproxyListDomains: def test_list_domains_with_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """List domains with configured servers.""" - # Write map file - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + # Add domain to DB + add_domain_to_map("example.com", "pool_1") mock_sock = mock_socket_class(responses={ "show servers state": response_builder.servers_state([ @@ -70,10 +70,8 @@ class TestHaproxyListDomains: def test_list_domains_exclude_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """List domains excluding wildcards by default.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") - with open(patch_config_paths["wildcards_file"], "w") as f: - f.write(".example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") + add_domain_to_map(".example.com", "pool_1", is_wildcard=True) mock_sock = mock_socket_class(responses={ "show servers state": response_builder.servers_state([]), @@ -100,10 +98,8 @@ class TestHaproxyListDomains: def test_list_domains_include_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """List domains including wildcards when requested.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") - with open(patch_config_paths["wildcards_file"], "w") as f: - f.write(".example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") + add_domain_to_map(".example.com", "pool_1", is_wildcard=True) mock_sock = mock_socket_class(responses={ "show servers state": response_builder.servers_state([]), @@ -230,8 +226,7 @@ class TestHaproxyAddDomain: def test_add_domain_already_exists(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """Reject adding domain that already exists.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") from haproxy_mcp.tools.domains import register_domain_tools mcp = MagicMock() @@ -362,8 +357,7 @@ class TestHaproxyRemoveDomain: def test_remove_legacy_domain_rejected(self, patch_config_paths): """Reject removing legacy (non-pool) domain.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com legacy_backend\n") + add_domain_to_map("example.com", "legacy_backend") from haproxy_mcp.tools.domains import register_domain_tools mcp = MagicMock() @@ -385,10 +379,8 @@ class TestHaproxyRemoveDomain: def test_remove_domain_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """Successfully remove domain.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") - with open(patch_config_paths["wildcards_file"], "w") as f: - f.write(".example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") + add_domain_to_map(".example.com", "pool_1", is_wildcard=True) mock_sock = mock_socket_class(responses={ "del map": "", diff --git a/tests/unit/tools/test_health.py b/tests/unit/tools/test_health.py index 28ca715..08e00cf 100644 --- a/tests/unit/tools/test_health.py +++ b/tests/unit/tools/test_health.py @@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock import pytest from haproxy_mcp.exceptions import HaproxyError +from haproxy_mcp.file_ops import add_domain_to_map class TestHaproxyHealth: @@ -80,7 +81,7 @@ class TestHaproxyHealth: # Use paths that don't exist with patch("haproxy_mcp.tools.health.MAP_FILE", str(tmp_path / "nonexistent.map")): - with patch("haproxy_mcp.tools.health.SERVERS_FILE", str(tmp_path / "nonexistent.json")): + with patch("haproxy_mcp.tools.health.DB_FILE", str(tmp_path / "nonexistent.db")): with patch("socket.socket", return_value=mock_sock): from haproxy_mcp.tools.health import register_health_tools mcp = MagicMock() @@ -160,8 +161,7 @@ class TestHaproxyDomainHealth: def test_domain_health_healthy(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """Domain health returns healthy when all servers are UP.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") mock_sock = mock_socket_class(responses={ "show servers state": response_builder.servers_state([ @@ -197,8 +197,7 @@ class TestHaproxyDomainHealth: def test_domain_health_degraded(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """Domain health returns degraded when some servers are DOWN.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") mock_sock = mock_socket_class(responses={ "show servers state": response_builder.servers_state([ @@ -234,8 +233,7 @@ class TestHaproxyDomainHealth: def test_domain_health_down(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """Domain health returns down when all servers are DOWN.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") mock_sock = mock_socket_class(responses={ "show servers state": response_builder.servers_state([ @@ -269,8 +267,7 @@ class TestHaproxyDomainHealth: def test_domain_health_no_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """Domain health returns no_servers when no servers configured.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") mock_sock = mock_socket_class(responses={ "show servers state": response_builder.servers_state([ diff --git a/tests/unit/tools/test_servers.py b/tests/unit/tools/test_servers.py index 26887d3..f259ac5 100644 --- a/tests/unit/tools/test_servers.py +++ b/tests/unit/tools/test_servers.py @@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock import pytest from haproxy_mcp.exceptions import HaproxyError +from haproxy_mcp.file_ops import add_domain_to_map, load_servers_config class TestHaproxyListServers: @@ -33,8 +34,7 @@ class TestHaproxyListServers: def test_list_servers_empty_backend(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """List servers for domain with no servers.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") mock_sock = mock_socket_class(responses={ "show servers state": response_builder.servers_state([ @@ -63,8 +63,7 @@ class TestHaproxyListServers: def test_list_servers_with_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """List servers with active servers.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") mock_sock = mock_socket_class(responses={ "show servers state": response_builder.servers_state([ @@ -224,8 +223,7 @@ class TestHaproxyAddServer: def test_add_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """Successfully add server.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") mock_sock = mock_socket_class(responses={ "set server": "", @@ -258,8 +256,7 @@ class TestHaproxyAddServer: def test_add_server_auto_slot(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """Auto-select slot when slot=0.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") mock_sock = mock_socket_class(responses={ "show servers state": response_builder.servers_state([ @@ -413,8 +410,7 @@ class TestHaproxyAddServers: def test_add_servers_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """Successfully add multiple servers.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") mock_sock = mock_socket_class(responses={ "set server": "", @@ -495,8 +491,7 @@ class TestHaproxyRemoveServer: def test_remove_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """Successfully remove server.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") mock_sock = mock_socket_class(responses={ "set server": "", @@ -689,8 +684,7 @@ class TestHaproxyAddServersRollback: def test_add_servers_partial_failure_rollback(self, mock_socket_class, mock_select, patch_config_paths): """Rollback only failed slots when HAProxy error occurs.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") # Mock configure_server_slot to fail on second slot call_count = [0] @@ -735,18 +729,16 @@ class TestHaproxyAddServersRollback: assert "slot 2" in result # Failed # Verify servers.json only has successfully added server - with open(patch_config_paths["servers_file"], "r") as f: - config = json.load(f) + config = load_servers_config() assert "example.com" in config assert "1" in config["example.com"] # Successfully added stays - assert "2" not in config["example.com"] # Failed one was rolled back + assert "2" not in config.get("example.com", {}) # Failed one was rolled back def test_add_servers_unexpected_error_rollback_only_successful( self, mock_socket_class, mock_select, patch_config_paths ): """Rollback only successfully added servers on unexpected error.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") # Track which servers were configured configured_slots = [] @@ -793,8 +785,7 @@ class TestHaproxyAddServersRollback: assert "Unexpected system error" in result # Verify servers.json is empty (all rolled back) - with open(patch_config_paths["servers_file"], "r") as f: - config = json.load(f) + config = load_servers_config() assert config == {} or "example.com" not in config or config.get("example.com") == {} def test_add_servers_rollback_failure_logged( @@ -802,8 +793,7 @@ class TestHaproxyAddServersRollback: ): """Log rollback failures during error recovery.""" import logging - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port): if slot == 2: @@ -858,8 +848,7 @@ class TestHaproxyAddServerAutoSlot: def test_add_server_auto_slot_all_used(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """Auto-select slot fails when all slots are in use.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") # Build response with all 10 slots used servers = [] @@ -902,8 +891,7 @@ class TestHaproxyAddServerAutoSlot: def test_add_server_negative_slot_auto_select(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """Negative slot number triggers auto-selection.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") mock_sock = mock_socket_class(responses={ "show servers state": response_builder.servers_state([ @@ -970,8 +958,7 @@ class TestHaproxyWaitDrain: def test_wait_drain_success(self, patch_config_paths): """Successfully wait for connections to drain.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") # Mock haproxy_cmd to return 0 connections with patch("haproxy_mcp.tools.servers.haproxy_cmd") as mock_cmd: @@ -1000,8 +987,7 @@ class TestHaproxyWaitDrain: def test_wait_drain_timeout(self, patch_config_paths): """Timeout when connections don't drain.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") time_values = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0] # Simulate time passing time_iter = iter(time_values) @@ -1087,9 +1073,6 @@ class TestHaproxyWaitDrain: def test_wait_drain_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths): """Error when domain not found in map.""" # Empty map file - domain not configured - with open(patch_config_paths["map_file"], "w") as f: - f.write("") - from haproxy_mcp.tools.servers import register_server_tools mcp = MagicMock() registered_tools = {} @@ -1202,8 +1185,7 @@ class TestHaproxySetDomainState: def test_set_domain_state_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """Set all servers of a domain to a state.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") mock_sock = mock_socket_class(responses={ "show servers state": response_builder.servers_state([ @@ -1283,8 +1265,7 @@ class TestHaproxySetDomainState: def test_set_domain_state_no_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """No active servers found for domain.""" - with open(patch_config_paths["map_file"], "w") as f: - f.write("example.com pool_1\n") + add_domain_to_map("example.com", "pool_1") # All servers have 0.0.0.0 address (not configured) mock_sock = mock_socket_class(responses={ @@ -1318,9 +1299,6 @@ class TestHaproxySetDomainState: def test_set_domain_state_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths, response_builder): """Handle domain not found in map - shows no active servers.""" # Empty map file - with open(patch_config_paths["map_file"], "w") as f: - f.write("") - # Mock should show no servers for unknown domain's backend mock_sock = mock_socket_class(responses={ "show servers state": response_builder.servers_state([]),