refactor: migrate data storage from JSON/map files to SQLite

Replace servers.json, certificates.json, and map file parsing with
SQLite (WAL mode) as single source of truth. HAProxy map files are
now generated from SQLite via sync_map_files().

Key changes:
- Add db.py with schema, connection management, and JSON migration
- Add DB_FILE config constant
- Delegate file_ops.py functions to db.py
- Refactor domains.py to use file_ops instead of direct list manipulation
- Fix subprocess.TimeoutExpired not caught (doesn't inherit TimeoutError)
- Add DB health check in health.py
- Init DB on startup in server.py and __main__.py
- Update all 359 tests to use SQLite-backed functions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
kappa
2026-02-08 11:07:29 +09:00
parent 05bff61b85
commit cf554f3f89
19 changed files with 1525 additions and 564 deletions

3
.gitignore vendored
View File

@@ -16,6 +16,9 @@ run/
data/ data/
*.state *.state
*.lock *.lock
*.db
*.db-wal
*.db-shm
# Python # Python
__pycache__/ __pycache__/

View File

@@ -1,8 +1,10 @@
"""Entry point for running haproxy_mcp as a module.""" """Entry point for running haproxy_mcp as a module."""
from .db import init_db
from .server import mcp from .server import mcp
from .tools.configuration import startup_restore from .tools.configuration import startup_restore
if __name__ == "__main__": if __name__ == "__main__":
init_db()
startup_restore() startup_restore()
mcp.run(transport="streamable-http") mcp.run(transport="streamable-http")

View File

@@ -31,6 +31,7 @@ WILDCARDS_MAP_FILE: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE", "/opt/haproxy/
WILDCARDS_MAP_FILE_CONTAINER: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE_CONTAINER", "/usr/local/etc/haproxy/wildcards.map") WILDCARDS_MAP_FILE_CONTAINER: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE_CONTAINER", "/usr/local/etc/haproxy/wildcards.map")
SERVERS_FILE: str = os.getenv("HAPROXY_SERVERS_FILE", "/opt/haproxy/conf/servers.json") SERVERS_FILE: str = os.getenv("HAPROXY_SERVERS_FILE", "/opt/haproxy/conf/servers.json")
CERTS_FILE: str = os.getenv("HAPROXY_CERTS_FILE", "/opt/haproxy/conf/certificates.json") CERTS_FILE: str = os.getenv("HAPROXY_CERTS_FILE", "/opt/haproxy/conf/certificates.json")
DB_FILE: str = os.getenv("HAPROXY_DB_FILE", "/opt/haproxy/conf/haproxy_mcp.db")
# Certificate paths # Certificate paths
CERTS_DIR: str = os.getenv("HAPROXY_CERTS_DIR", "/opt/haproxy/certs") CERTS_DIR: str = os.getenv("HAPROXY_CERTS_DIR", "/opt/haproxy/certs")

577
haproxy_mcp/db.py Normal file
View File

@@ -0,0 +1,577 @@
"""SQLite database operations for HAProxy MCP Server.
Single source of truth for all persistent configuration data.
Replaces JSON files (servers.json, certificates.json) and map files
(domains.map, wildcards.map) with a single SQLite database.
Map files are generated from the database via sync_map_files() for
HAProxy to consume directly.
"""
import json
import os
import sqlite3
import threading
from typing import Any, Optional
from .config import (
DB_FILE,
MAP_FILE,
WILDCARDS_MAP_FILE,
SERVERS_FILE,
CERTS_FILE,
POOL_COUNT,
REMOTE_MODE,
logger,
)
SCHEMA_VERSION = 1
_local = threading.local()
def get_connection() -> sqlite3.Connection:
    """Return this thread's SQLite connection, creating it on first use.

    The connection is cached in thread-local storage and configured for
    concurrent access (WAL journal, busy timeout, enforced foreign keys).

    Returns:
        sqlite3.Connection configured for concurrent access.
    """
    cached = getattr(_local, "connection", None)
    if cached is None:
        cached = sqlite3.connect(DB_FILE, timeout=10)
        for pragma in (
            "PRAGMA journal_mode=WAL",
            "PRAGMA busy_timeout=5000",
            "PRAGMA foreign_keys=ON",
        ):
            cached.execute(pragma)
        # Row factory lets callers access columns by name.
        cached.row_factory = sqlite3.Row
        _local.connection = cached
    return cached
def close_connection() -> None:
    """Close and discard the thread-local SQLite connection, if any."""
    conn = getattr(_local, "connection", None)
    if conn is None:
        return
    conn.close()
    _local.connection = None
def init_db() -> None:
    """Initialize database schema and run migration if needed.

    Creates tables if they don't exist, then checks for existing
    JSON/map files to migrate data from.

    Safe to call repeatedly: all DDL uses IF NOT EXISTS and the
    schema_version table gates the one-time JSON migration.
    """
    conn = get_connection()
    cur = conn.cursor()
    # executescript runs the whole DDL batch; every statement is
    # idempotent (IF NOT EXISTS), so re-running is harmless.
    cur.executescript("""
    CREATE TABLE IF NOT EXISTS domains (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        domain TEXT NOT NULL UNIQUE,
        backend TEXT NOT NULL,
        is_wildcard INTEGER NOT NULL DEFAULT 0,
        shares_with TEXT DEFAULT NULL,
        created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
        updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
    );
    CREATE INDEX IF NOT EXISTS idx_domains_backend ON domains(backend);
    CREATE INDEX IF NOT EXISTS idx_domains_shares_with ON domains(shares_with);
    CREATE TABLE IF NOT EXISTS servers (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        domain TEXT NOT NULL,
        slot INTEGER NOT NULL,
        ip TEXT NOT NULL,
        http_port INTEGER NOT NULL DEFAULT 80,
        created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
        updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
        UNIQUE(domain, slot)
    );
    CREATE INDEX IF NOT EXISTS idx_servers_domain ON servers(domain);
    CREATE TABLE IF NOT EXISTS certificates (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        domain TEXT NOT NULL UNIQUE,
        created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
        updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
    );
    CREATE TABLE IF NOT EXISTS schema_version (
        version INTEGER NOT NULL,
        applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
    );
    """)
    conn.commit()
    # Check schema version
    cur.execute("SELECT MAX(version) FROM schema_version")
    row = cur.fetchone()
    # MAX() always yields exactly one row; its value is NULL on an
    # empty table, which we treat as "no schema applied yet".
    current_version = row[0] if row[0] is not None else 0
    if current_version < SCHEMA_VERSION:
        # First-time setup: try migrating from JSON files
        if current_version == 0:
            migrate_from_json()
        # Record the applied version only after migration succeeded, so
        # a failed migration is retried on the next startup.
        cur.execute("INSERT INTO schema_version (version) VALUES (?)", (SCHEMA_VERSION,))
        conn.commit()
    logger.info("Database initialized (schema v%d)", SCHEMA_VERSION)
def migrate_from_json() -> None:
    """Migrate data from JSON/map files to SQLite.

    Reads domains.map, wildcards.map, servers.json, and certificates.json,
    imports their data into the database within a single transaction,
    then renames the original files to .bak.

    This function is idempotent (uses INSERT OR IGNORE), so rerunning it
    after a partial failure cannot duplicate rows.

    Raises:
        sqlite3.Error: If the import transaction fails (after rollback).
    """
    conn = get_connection()
    # Collect data from existing files
    map_entries: list[tuple[str, str]] = []
    servers_config: dict[str, Any] = {}
    cert_domains: list[str] = []
    # Read map files
    for map_path in [MAP_FILE, WILDCARDS_MAP_FILE]:
        try:
            content = _read_file_for_migration(map_path)
            for line in content.splitlines():
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                parts = line.split()
                if len(parts) >= 2:
                    map_entries.append((parts[0], parts[1]))
        except FileNotFoundError:
            logger.debug("Map file not found for migration: %s", map_path)
    # Read servers.json
    try:
        content = _read_file_for_migration(SERVERS_FILE)
        loaded = json.loads(content)
        # Fix: tolerate a malformed top-level value (e.g. a JSON list);
        # the import loop below requires {domain: {slot: {...}}} and
        # would otherwise crash on .items()/.get().
        if isinstance(loaded, dict):
            servers_config = loaded
        else:
            logger.warning("Corrupt servers config during migration: %s",
                           "top-level value is not an object")
    except FileNotFoundError:
        logger.debug("Servers config not found for migration: %s", SERVERS_FILE)
    except json.JSONDecodeError as e:
        logger.warning("Corrupt servers config during migration: %s", e)
    # Read certificates.json
    try:
        content = _read_file_for_migration(CERTS_FILE)
        data = json.loads(content)
        # Fix: data.get would raise AttributeError on a non-object value.
        if isinstance(data, dict):
            cert_domains = data.get("domains", [])
        else:
            logger.warning("Corrupt certs config during migration: %s",
                           "top-level value is not an object")
    except FileNotFoundError:
        logger.debug("Certs config not found for migration: %s", CERTS_FILE)
    except json.JSONDecodeError as e:
        logger.warning("Corrupt certs config during migration: %s", e)
    if not map_entries and not servers_config and not cert_domains:
        logger.info("No existing data to migrate")
        return
    # Import into database in a single transaction; `with conn:` commits
    # on success and rolls back on any exception.
    try:
        with conn:
            # Import domains from map files
            for domain, backend in map_entries:
                # Wildcard map entries are written with a leading dot.
                is_wildcard = 1 if domain.startswith(".") else 0
                conn.execute(
                    "INSERT OR IGNORE INTO domains (domain, backend, is_wildcard) VALUES (?, ?, ?)",
                    (domain, backend, is_wildcard),
                )
            # Import servers and shares_with from servers.json
            for domain, slots in servers_config.items():
                # Fix: a non-dict slot table (corrupt JSON) must not abort
                # the whole migration with an AttributeError on .get().
                if not isinstance(slots, dict):
                    logger.warning("Skipping invalid server entry %s/%s: %s",
                                   domain, slots, "not an object")
                    continue
                shares_with = slots.get("_shares")
                if shares_with:
                    # A shared domain has no servers of its own; it only
                    # points at the parent domain's pool.
                    conn.execute(
                        "UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0",
                        (shares_with, domain),
                    )
                    continue
                for slot_str, server_info in slots.items():
                    if slot_str.startswith("_"):
                        continue  # metadata keys, not slot numbers
                    try:
                        slot = int(slot_str)
                        ip = server_info.get("ip", "")
                        http_port = int(server_info.get("http_port", 80))
                        if ip:
                            conn.execute(
                                "INSERT OR IGNORE INTO servers (domain, slot, ip, http_port) VALUES (?, ?, ?, ?)",
                                (domain, slot, ip, http_port),
                            )
                    # Fix: also catch AttributeError so a non-dict slot value
                    # (server_info.get on a string/number) skips the entry
                    # instead of aborting the migration.
                    except (ValueError, TypeError, AttributeError) as e:
                        logger.warning("Skipping invalid server entry %s/%s: %s", domain, slot_str, e)
            # Import certificates
            for cert_domain in cert_domains:
                conn.execute(
                    "INSERT OR IGNORE INTO certificates (domain) VALUES (?)",
                    (cert_domain,),
                )
        migrated_domains = len(set(d for d, _ in map_entries))
        migrated_servers = sum(
            1 for d, slots in servers_config.items()
            if isinstance(slots, dict) and "_shares" not in slots
            for s in slots
            if not s.startswith("_")
        )
        migrated_certs = len(cert_domains)
        logger.info(
            "Migrated from JSON: %d domain entries, %d servers, %d certificates",
            migrated_domains, migrated_servers, migrated_certs,
        )
        # Rename original files to .bak so the migration never reruns on
        # stale data. Map files are regenerated from SQLite, so they stay.
        for path in [SERVERS_FILE, CERTS_FILE]:
            _backup_file(path)
    except sqlite3.Error as e:
        logger.error("Migration failed: %s", e)
        raise
def _read_file_for_migration(file_path: str) -> str:
    """Return the raw contents of *file_path* for one-shot migration.

    No locking is taken: migration runs once at startup before concurrent
    writers exist. In remote mode the file is fetched over SSH instead.
    """
    if REMOTE_MODE:
        from .ssh_ops import remote_read_file
        return remote_read_file(file_path)
    with open(file_path, "r", encoding="utf-8") as handle:
        return handle.read()
def _backup_file(file_path: str) -> None:
    """Rename a local file to ``<file_path>.bak``; no-op if it is absent."""
    if REMOTE_MODE:
        return  # Don't rename remote files
    if not os.path.exists(file_path):
        return
    bak_path = f"{file_path}.bak"
    try:
        os.rename(file_path, bak_path)
    except OSError as exc:
        # Best effort: a failed backup is logged but never fatal.
        logger.warning("Failed to backup %s: %s", file_path, exc)
    else:
        logger.info("Backed up %s -> %s", file_path, bak_path)
# --- Domain data access ---
def db_get_map_contents() -> list[tuple[str, str]]:
    """Return every domain-to-backend mapping, wildcards included.

    Returns:
        List of (domain, backend) tuples ordered by domain name.
    """
    rows = get_connection().execute(
        "SELECT domain, backend FROM domains ORDER BY domain"
    ).fetchall()
    return [(record["domain"], record["backend"]) for record in rows]
def db_get_domain_backend(domain: str) -> Optional[str]:
    """Return the backend assigned to *domain*, or None when unmapped.

    Args:
        domain: Domain name to look up.

    Returns:
        Backend name if found, None otherwise.
    """
    row = get_connection().execute(
        "SELECT backend FROM domains WHERE domain = ?", (domain,)
    ).fetchone()
    return None if row is None else row["backend"]
def db_add_domain(domain: str, backend: str, is_wildcard: bool = False,
                  shares_with: Optional[str] = None) -> None:
    """Insert or update a domain mapping (upsert keyed on domain name).

    Args:
        domain: Domain name (e.g., "example.com" or ".example.com").
        backend: Backend pool name (e.g., "pool_5").
        is_wildcard: Whether this is a wildcard entry.
        shares_with: Parent domain if sharing a pool.
    """
    db = get_connection()
    params = (domain, backend, 1 if is_wildcard else 0, shares_with)
    db.execute(
        """INSERT INTO domains (domain, backend, is_wildcard, shares_with)
        VALUES (?, ?, ?, ?)
        ON CONFLICT(domain) DO UPDATE SET
        backend = excluded.backend,
        shares_with = excluded.shares_with,
        updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""",
        params,
    )
    db.commit()
def db_remove_domain(domain: str) -> None:
    """Delete both the exact and wildcard rows for *domain*.

    Args:
        domain: Base domain name (without leading dot).
    """
    db = get_connection()
    # One statement removes "example.com" and ".example.com" together.
    db.execute(
        "DELETE FROM domains WHERE domain = ? OR domain = ?",
        (domain, f".{domain}"),
    )
    db.commit()
def db_find_available_pool() -> Optional[str]:
    """Return the lowest-numbered pool not assigned to any exact domain.

    Returns:
        Pool name (e.g., "pool_5"), or None when every pool is in use.
    """
    cur = get_connection().execute(
        "SELECT DISTINCT backend FROM domains WHERE backend LIKE 'pool_%' AND is_wildcard = 0"
    )
    taken = {row["backend"] for row in cur.fetchall()}
    candidates = (f"pool_{i}" for i in range(1, POOL_COUNT + 1))
    return next((name for name in candidates if name not in taken), None)
def db_get_domains_sharing_pool(pool: str) -> list[str]:
    """List every exact (non-wildcard) domain routed to *pool*.

    Args:
        pool: Pool name (e.g., "pool_5").

    Returns:
        List of domain names.
    """
    rows = get_connection().execute(
        "SELECT domain FROM domains WHERE backend = ? AND is_wildcard = 0",
        (pool,),
    ).fetchall()
    return [record["domain"] for record in rows]
# --- Server data access ---
def db_load_servers_config() -> dict[str, Any]:
    """Build the legacy servers.json-shaped dict from the database.

    Returns:
        Dictionary compatible with the old servers.json format:
        {"domain": {"slot": {"ip": "...", "http_port": N}, "_shares": "parent"}}
    """
    conn = get_connection()
    config: dict[str, Any] = {}
    # Server rows become per-domain slot dicts keyed by stringified slot.
    server_rows = conn.execute(
        "SELECT domain, slot, ip, http_port FROM servers ORDER BY domain, slot"
    )
    for row in server_rows.fetchall():
        slots = config.setdefault(row["domain"], {})
        slots[str(row["slot"])] = {
            "ip": row["ip"],
            "http_port": row["http_port"],
        }
    # Shared-pool references are stored on the domains table, not servers.
    share_rows = conn.execute(
        "SELECT domain, shares_with FROM domains WHERE shares_with IS NOT NULL AND is_wildcard = 0"
    )
    for row in share_rows.fetchall():
        config.setdefault(row["domain"], {})["_shares"] = row["shares_with"]
    return config
def db_add_server(domain: str, slot: int, ip: str, http_port: int) -> None:
    """Insert a server row, updating it if (domain, slot) already exists.

    Args:
        domain: Domain name.
        slot: Server slot number.
        ip: Server IP address.
        http_port: HTTP port.
    """
    db = get_connection()
    values = (domain, slot, ip, http_port)
    db.execute(
        """INSERT INTO servers (domain, slot, ip, http_port)
        VALUES (?, ?, ?, ?)
        ON CONFLICT(domain, slot) DO UPDATE SET
        ip = excluded.ip,
        http_port = excluded.http_port,
        updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""",
        values,
    )
    db.commit()
def db_remove_server(domain: str, slot: int) -> None:
    """Delete the server row identified by (domain, slot), if present.

    Args:
        domain: Domain name.
        slot: Server slot number.
    """
    db = get_connection()
    db.execute("DELETE FROM servers WHERE domain = ? AND slot = ?", (domain, slot))
    db.commit()
def db_remove_domain_servers(domain: str) -> None:
    """Delete every server row belonging to *domain*.

    Args:
        domain: Domain name.
    """
    db = get_connection()
    db.execute("DELETE FROM servers WHERE domain = ?", (domain,))
    db.commit()
def db_add_shared_domain(domain: str, shares_with: str) -> None:
    """Record that *domain* shares *shares_with*'s backend pool.

    Only the shares_with column of the exact-match row is touched;
    wildcard rows are never marked as sharing.

    Args:
        domain: The new domain.
        shares_with: The existing domain to share with.
    """
    db = get_connection()
    db.execute(
        "UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0",
        (shares_with, domain),
    )
    db.commit()
def db_get_shared_domain(domain: str) -> Optional[str]:
    """Return the parent domain that *domain* shares a pool with.

    Args:
        domain: Domain name.

    Returns:
        Parent domain name if sharing, None otherwise.
    """
    row = get_connection().execute(
        "SELECT shares_with FROM domains WHERE domain = ? AND is_wildcard = 0",
        (domain,),
    ).fetchone()
    if row is None:
        return None
    # A NULL or empty shares_with column counts as "not sharing".
    return row["shares_with"] or None
def db_is_shared_domain(domain: str) -> bool:
    """True when *domain* is piggybacking on another domain's pool.

    Args:
        domain: Domain name.
    """
    return bool(db_get_shared_domain(domain))
# --- Certificate data access ---
def db_load_certs() -> list[str]:
    """Return all certificate domains, sorted by name.

    Returns:
        Sorted list of domain names with certificates.
    """
    cur = get_connection().execute("SELECT domain FROM certificates ORDER BY domain")
    return [record["domain"] for record in cur.fetchall()]
def db_add_cert(domain: str) -> None:
    """Register *domain* as having a certificate (no-op if already listed).

    Args:
        domain: Domain name.
    """
    db = get_connection()
    db.execute(
        "INSERT OR IGNORE INTO certificates (domain) VALUES (?)",
        (domain,),
    )
    db.commit()
def db_remove_cert(domain: str) -> None:
    """Drop *domain* from the certificate list (no-op if absent).

    Args:
        domain: Domain name.
    """
    db = get_connection()
    db.execute("DELETE FROM certificates WHERE domain = ?", (domain,))
    db.commit()
# --- Map file synchronization ---
def sync_map_files() -> None:
    """Regenerate HAProxy map files from the database.

    Writes domains.map (exact matches) and wildcards.map (wildcard
    entries) from current database state. Atomic writes ensure HAProxy
    never reads a half-written map file.
    """
    from .file_ops import atomic_write_file
    conn = get_connection()
    exact_rows = conn.execute(
        "SELECT domain, backend FROM domains WHERE is_wildcard = 0 ORDER BY domain"
    ).fetchall()
    wildcard_rows = conn.execute(
        "SELECT domain, backend FROM domains WHERE is_wildcard = 1 ORDER BY domain"
    ).fetchall()

    def _render(header: list[str], rows) -> str:
        # Header comment lines first, then one "domain backend" line per row.
        body = list(header)
        body.extend(f"{row['domain']} {row['backend']}\n" for row in rows)
        return "".join(body)

    atomic_write_file(MAP_FILE, _render([
        "# Exact Domain to Backend mapping (for map_str)\n",
        "# Format: domain backend_name\n",
        "# Uses ebtree for O(log n) lookup performance\n\n",
    ], exact_rows))
    atomic_write_file(WILDCARDS_MAP_FILE, _render([
        "# Wildcard Domain to Backend mapping (for map_dom)\n",
        "# Format: .domain.com backend_name (matches *.domain.com)\n",
        "# Uses map_dom for suffix matching\n\n",
    ], wildcard_rows))
    logger.debug("Map files synced: %d exact, %d wildcard entries",
                 len(exact_rows), len(wildcard_rows))

View File

@@ -1,7 +1,11 @@
"""File I/O operations for HAProxy MCP Server.""" """File I/O operations for HAProxy MCP Server.
Most data access is now delegated to db.py (SQLite).
This module retains atomic file writes, map file I/O for HAProxy,
and provides backward-compatible function signatures.
"""
import fcntl import fcntl
import json
import os import os
import tempfile import tempfile
from contextlib import contextmanager from contextlib import contextmanager
@@ -10,8 +14,6 @@ from typing import Any, Generator, Optional
from .config import ( from .config import (
MAP_FILE, MAP_FILE,
WILDCARDS_MAP_FILE, WILDCARDS_MAP_FILE,
SERVERS_FILE,
CERTS_FILE,
REMOTE_MODE, REMOTE_MODE,
logger, logger,
) )
@@ -138,16 +140,13 @@ def _read_file(file_path: str) -> str:
def get_map_contents() -> list[tuple[str, str]]: def get_map_contents() -> list[tuple[str, str]]:
"""Read both domains.map and wildcards.map and return combined entries. """Get all domain-to-backend mappings from SQLite.
Returns: Returns:
List of (domain, backend) tuples from both map files List of (domain, backend) tuples including wildcards.
""" """
# Read exact domains from .db import db_get_map_contents
entries = _read_map_file(MAP_FILE) return db_get_map_contents()
# Read wildcards and append
entries.extend(_read_map_file(WILDCARDS_MAP_FILE))
return entries
def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]: def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
@@ -170,44 +169,21 @@ def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str
def save_map_file(entries: list[tuple[str, str]]) -> None: def save_map_file(entries: list[tuple[str, str]]) -> None:
"""Save domain-to-backend entries to map files. """Sync map files from the database.
Splits entries into two files for 2-stage routing: Regenerates domains.map and wildcards.map from the current
- domains.map: Exact matches (map_str, O(log n)) database state. The entries parameter is ignored (kept for
- wildcards.map: Wildcard entries starting with "." (map_dom, O(n)) backward compatibility during transition).
Args:
entries: List of (domain, backend) tuples.
Raises: Raises:
IOError: If map files cannot be written. IOError: If map files cannot be written.
""" """
# Split into exact and wildcard entries from .db import sync_map_files
exact_entries, wildcard_entries = split_domain_entries(entries) sync_map_files()
# Save exact domains (for map_str - fast O(log n) lookup)
exact_lines = [
"# Exact Domain to Backend mapping (for map_str)\n",
"# Format: domain backend_name\n",
"# Uses ebtree for O(log n) lookup performance\n\n",
]
for domain, backend in sorted(exact_entries):
exact_lines.append(f"{domain} {backend}\n")
atomic_write_file(MAP_FILE, "".join(exact_lines))
# Save wildcards (for map_dom - O(n) but small set)
wildcard_lines = [
"# Wildcard Domain to Backend mapping (for map_dom)\n",
"# Format: .domain.com backend_name (matches *.domain.com)\n",
"# Uses map_dom for suffix matching\n\n",
]
for domain, backend in sorted(wildcard_entries):
wildcard_lines.append(f"{domain} {backend}\n")
atomic_write_file(WILDCARDS_MAP_FILE, "".join(wildcard_lines))
def get_domain_backend(domain: str) -> Optional[str]: def get_domain_backend(domain: str) -> Optional[str]:
"""Look up the backend for a domain from domains.map. """Look up the backend for a domain from SQLite (O(1)).
Args: Args:
domain: The domain to look up domain: The domain to look up
@@ -215,10 +191,8 @@ def get_domain_backend(domain: str) -> Optional[str]:
Returns: Returns:
Backend name if found, None otherwise Backend name if found, None otherwise
""" """
for map_domain, backend in get_map_contents(): from .db import db_get_domain_backend
if map_domain == domain: return db_get_domain_backend(domain)
return backend
return None
def is_legacy_backend(backend: str) -> bool: def is_legacy_backend(backend: str) -> bool:
@@ -273,34 +247,17 @@ def get_backend_and_prefix(domain: str) -> tuple[str, str]:
def load_servers_config() -> dict[str, Any]: def load_servers_config() -> dict[str, Any]:
"""Load servers configuration from JSON file. """Load servers configuration from SQLite.
Returns: Returns:
Dictionary with server configurations Dictionary with server configurations (legacy format compatible).
""" """
try: from .db import db_load_servers_config
content = _read_file(SERVERS_FILE) return db_load_servers_config()
return json.loads(content)
except FileNotFoundError:
return {}
except json.JSONDecodeError as e:
logger.warning("Corrupt config file %s: %s", SERVERS_FILE, e)
return {}
def save_servers_config(config: dict[str, Any]) -> None:
"""Save servers configuration to JSON file atomically.
Uses temp file + rename for atomic write to prevent race conditions.
Args:
config: Dictionary with server configurations
"""
atomic_write_file(SERVERS_FILE, json.dumps(config, indent=2))
def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> None: def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> None:
"""Add server configuration to persistent storage with file locking. """Add server configuration to persistent storage.
Args: Args:
domain: Domain name domain: Domain name
@@ -308,41 +265,29 @@ def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> Non
ip: Server IP address ip: Server IP address
http_port: HTTP port http_port: HTTP port
""" """
with file_lock(f"{SERVERS_FILE}.lock"): from .db import db_add_server
config = load_servers_config() db_add_server(domain, slot, ip, http_port)
if domain not in config:
config[domain] = {}
config[domain][str(slot)] = {"ip": ip, "http_port": http_port}
save_servers_config(config)
def remove_server_from_config(domain: str, slot: int) -> None: def remove_server_from_config(domain: str, slot: int) -> None:
"""Remove server configuration from persistent storage with file locking. """Remove server configuration from persistent storage.
Args: Args:
domain: Domain name domain: Domain name
slot: Server slot to remove slot: Server slot to remove
""" """
with file_lock(f"{SERVERS_FILE}.lock"): from .db import db_remove_server
config = load_servers_config() db_remove_server(domain, slot)
if domain in config and str(slot) in config[domain]:
del config[domain][str(slot)]
if not config[domain]:
del config[domain]
save_servers_config(config)
def remove_domain_from_config(domain: str) -> None: def remove_domain_from_config(domain: str) -> None:
"""Remove domain from persistent config with file locking. """Remove domain from persistent config (servers + domain entry).
Args: Args:
domain: Domain name to remove domain: Domain name to remove
""" """
with file_lock(f"{SERVERS_FILE}.lock"): from .db import db_remove_domain_servers
config = load_servers_config() db_remove_domain_servers(domain)
if domain in config:
del config[domain]
save_servers_config(config)
def get_shared_domain(domain: str) -> Optional[str]: def get_shared_domain(domain: str) -> Optional[str]:
@@ -354,9 +299,8 @@ def get_shared_domain(domain: str) -> Optional[str]:
Returns: Returns:
Parent domain name if sharing, None otherwise Parent domain name if sharing, None otherwise
""" """
config = load_servers_config() from .db import db_get_shared_domain
domain_config = config.get(domain, {}) return db_get_shared_domain(domain)
return domain_config.get("_shares")
def add_shared_domain_to_config(domain: str, shares_with: str) -> None: def add_shared_domain_to_config(domain: str, shares_with: str) -> None:
@@ -366,10 +310,8 @@ def add_shared_domain_to_config(domain: str, shares_with: str) -> None:
domain: New domain name domain: New domain name
shares_with: Existing domain to share pool with shares_with: Existing domain to share pool with
""" """
with file_lock(f"{SERVERS_FILE}.lock"): from .db import db_add_shared_domain
config = load_servers_config() db_add_shared_domain(domain, shares_with)
config[domain] = {"_shares": shares_with}
save_servers_config(config)
def get_domains_sharing_pool(pool: str) -> list[str]: def get_domains_sharing_pool(pool: str) -> list[str]:
@@ -381,11 +323,8 @@ def get_domains_sharing_pool(pool: str) -> list[str]:
Returns: Returns:
List of domain names using this pool List of domain names using this pool
""" """
domains = [] from .db import db_get_domains_sharing_pool
for domain, backend in get_map_contents(): return db_get_domains_sharing_pool(pool)
if backend == pool and not domain.startswith("."):
domains.append(domain)
return domains
def is_shared_domain(domain: str) -> bool: def is_shared_domain(domain: str) -> bool:
@@ -397,37 +336,20 @@ def is_shared_domain(domain: str) -> bool:
Returns: Returns:
True if domain has _shares reference, False otherwise True if domain has _shares reference, False otherwise
""" """
config = load_servers_config() from .db import db_is_shared_domain
domain_config = config.get(domain, {}) return db_is_shared_domain(domain)
return "_shares" in domain_config
# Certificate configuration functions # Certificate configuration functions
def load_certs_config() -> list[str]: def load_certs_config() -> list[str]:
"""Load certificate domain list from JSON file. """Load certificate domain list from SQLite.
Returns: Returns:
List of domain names Sorted list of domain names.
""" """
try: from .db import db_load_certs
content = _read_file(CERTS_FILE) return db_load_certs()
data = json.loads(content)
return data.get("domains", [])
except FileNotFoundError:
return []
except json.JSONDecodeError as e:
logger.warning("Corrupt certificates config %s: %s", CERTS_FILE, e)
return []
def save_certs_config(domains: list[str]) -> None:
"""Save certificate domain list to JSON file atomically.
Args:
domains: List of domain names
"""
atomic_write_file(CERTS_FILE, json.dumps({"domains": sorted(domains)}, indent=2))
def add_cert_to_config(domain: str) -> None: def add_cert_to_config(domain: str) -> None:
@@ -436,11 +358,8 @@ def add_cert_to_config(domain: str) -> None:
Args: Args:
domain: Domain name to add domain: Domain name to add
""" """
with file_lock(f"{CERTS_FILE}.lock"): from .db import db_add_cert
domains = load_certs_config() db_add_cert(domain)
if domain not in domains:
domains.append(domain)
save_certs_config(domains)
def remove_cert_from_config(domain: str) -> None: def remove_cert_from_config(domain: str) -> None:
@@ -449,8 +368,45 @@ def remove_cert_from_config(domain: str) -> None:
Args: Args:
domain: Domain name to remove domain: Domain name to remove
""" """
with file_lock(f"{CERTS_FILE}.lock"): from .db import db_remove_cert
domains = load_certs_config() db_remove_cert(domain)
if domain in domains:
domains.remove(domain)
save_certs_config(domains) # Domain map helper functions (used by domains.py)
def add_domain_to_map(domain: str, backend: str, is_wildcard: bool = False,
                      shares_with: Optional[str] = None) -> None:
    """Persist a domain mapping in SQLite, then regenerate the map files.

    Args:
        domain: Domain name (e.g., "example.com").
        backend: Backend pool name (e.g., "pool_5").
        is_wildcard: Whether this is a wildcard entry.
        shares_with: Parent domain if sharing a pool.
    """
    from .db import db_add_domain, sync_map_files
    db_add_domain(domain, backend, is_wildcard, shares_with)
    sync_map_files()
def remove_domain_from_map(domain: str) -> None:
    """Delete a domain (exact + wildcard) from SQLite, then resync maps.

    Args:
        domain: Base domain name (without leading dot).
    """
    from .db import db_remove_domain, sync_map_files
    db_remove_domain(domain)
    sync_map_files()
def find_available_pool() -> Optional[str]:
    """Find the first available pool not assigned to any domain.

    Delegates to the SQLite-backed lookup in db.py.

    Returns:
        Pool name (e.g., "pool_5") or None if all pools are in use.
    """
    from .db import db_find_available_pool
    return db_find_available_pool()

View File

@@ -2,6 +2,7 @@
import socket import socket
import select import select
import subprocess
import time import time
from .config import ( from .config import (
@@ -161,7 +162,7 @@ def reload_haproxy() -> tuple[bool, str]:
if result.returncode != 0: if result.returncode != 0:
return False, f"Reload failed: {result.stderr}" return False, f"Reload failed: {result.stderr}"
return True, "OK" return True, "OK"
except TimeoutError: except (TimeoutError, subprocess.TimeoutExpired):
return False, f"Command timed out after {SUBPROCESS_TIMEOUT} seconds" return False, f"Command timed out after {SUBPROCESS_TIMEOUT} seconds"
except FileNotFoundError: except FileNotFoundError:
return False, "ssh/podman command not found" return False, "ssh/podman command not found"

View File

@@ -21,6 +21,7 @@ Environment Variables:
from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp import FastMCP
from .config import MCP_HOST, MCP_PORT from .config import MCP_HOST, MCP_PORT
from .db import init_db
from .tools import register_all_tools from .tools import register_all_tools
from .tools.configuration import startup_restore from .tools.configuration import startup_restore
@@ -32,5 +33,6 @@ register_all_tools(mcp)
if __name__ == "__main__": if __name__ == "__main__":
init_db()
startup_restore() startup_restore()
mcp.run(transport="streamable-http") mcp.run(transport="streamable-http")

View File

@@ -1,6 +1,7 @@
"""Certificate management tools for HAProxy MCP Server.""" """Certificate management tools for HAProxy MCP Server."""
import os import os
import subprocess
from datetime import datetime from datetime import datetime
from typing import Annotated from typing import Annotated
@@ -152,7 +153,7 @@ def _haproxy_list_certs_impl() -> str:
certs.append(f"{domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}") certs.append(f"{domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}")
return "\n\n".join(certs) if certs else "No certificates found" return "\n\n".join(certs) if certs else "No certificates found"
except TimeoutError: except (TimeoutError, subprocess.TimeoutExpired):
return "Error: Command timed out" return "Error: Command timed out"
except FileNotFoundError: except FileNotFoundError:
return "Error: acme.sh not found" return "Error: acme.sh not found"
@@ -203,7 +204,7 @@ def _haproxy_cert_info_impl(domain: str) -> str:
result.stdout.strip() result.stdout.strip()
] ]
return "\n".join(info) return "\n".join(info)
except TimeoutError: except (TimeoutError, subprocess.TimeoutExpired):
return "Error: Command timed out" return "Error: Command timed out"
except OSError as e: except OSError as e:
logger.error("Error getting certificate info for %s: %s", domain, e) logger.error("Error getting certificate info for %s: %s", domain, e)
@@ -250,7 +251,7 @@ def _haproxy_issue_cert_impl(domain: str, wildcard: bool) -> str:
else: else:
return f"Certificate issued but PEM file not created. Check {host_path}" return f"Certificate issued but PEM file not created. Check {host_path}"
except TimeoutError: except (TimeoutError, subprocess.TimeoutExpired):
return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s" return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s"
except OSError as e: except OSError as e:
logger.error("Error issuing certificate for %s: %s", domain, e) logger.error("Error issuing certificate for %s: %s", domain, e)
@@ -289,7 +290,7 @@ def _haproxy_renew_cert_impl(domain: str, force: bool) -> str:
else: else:
return f"Error renewing certificate:\n{output}" return f"Error renewing certificate:\n{output}"
except TimeoutError: except (TimeoutError, subprocess.TimeoutExpired):
return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s" return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s"
except FileNotFoundError: except FileNotFoundError:
return "Error: acme.sh not found" return "Error: acme.sh not found"
@@ -323,7 +324,7 @@ def _haproxy_renew_all_certs_impl() -> str:
else: else:
return "Renewal check completed" return "Renewal check completed"
except TimeoutError: except (TimeoutError, subprocess.TimeoutExpired):
return "Error: Renewal cron timed out" return "Error: Renewal cron timed out"
except FileNotFoundError: except FileNotFoundError:
return "Error: acme.sh not found" return "Error: acme.sh not found"

View File

@@ -1,5 +1,6 @@
"""Configuration management tools for HAProxy MCP Server.""" """Configuration management tools for HAProxy MCP Server."""
import subprocess
import time import time
from ..config import ( from ..config import (
@@ -177,7 +178,7 @@ def register_config_tools(mcp):
if result.returncode == 0: if result.returncode == 0:
return "Configuration is valid" return "Configuration is valid"
return f"Configuration errors:\n{result.stderr}" return f"Configuration errors:\n{result.stderr}"
except TimeoutError: except (TimeoutError, subprocess.TimeoutExpired):
return f"Error: Command timed out after {SUBPROCESS_TIMEOUT} seconds" return f"Error: Command timed out after {SUBPROCESS_TIMEOUT} seconds"
except FileNotFoundError: except FileNotFoundError:
return "Error: ssh/podman command not found" return "Error: ssh/podman command not found"

View File

@@ -6,7 +6,6 @@ from typing import Annotated, Optional
from pydantic import Field from pydantic import Field
from ..config import ( from ..config import (
MAP_FILE,
MAP_FILE_CONTAINER, MAP_FILE_CONTAINER,
WILDCARDS_MAP_FILE_CONTAINER, WILDCARDS_MAP_FILE_CONTAINER,
POOL_COUNT, POOL_COUNT,
@@ -22,7 +21,6 @@ from ..validation import validate_domain, validate_ip, validate_port_int
from ..haproxy_client import haproxy_cmd from ..haproxy_client import haproxy_cmd
from ..file_ops import ( from ..file_ops import (
get_map_contents, get_map_contents,
save_map_file,
get_domain_backend, get_domain_backend,
is_legacy_backend, is_legacy_backend,
add_server_to_config, add_server_to_config,
@@ -31,30 +29,13 @@ from ..file_ops import (
add_shared_domain_to_config, add_shared_domain_to_config,
get_domains_sharing_pool, get_domains_sharing_pool,
is_shared_domain, is_shared_domain,
add_domain_to_map,
remove_domain_from_map,
find_available_pool,
) )
from ..utils import parse_servers_state, disable_server_slot from ..utils import parse_servers_state, disable_server_slot
def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]:
"""Find an available pool backend from the pool list.
Iterates through pool_1 to pool_N and returns the first pool
that is not currently in use.
Args:
entries: List of (domain, backend) tuples from the map file.
used_pools: Set of pool names already in use.
Returns:
Available pool name (e.g., "pool_5") or None if all pools are in use.
"""
for i in range(1, POOL_COUNT + 1):
pool_name = f"pool_{i}"
if pool_name not in used_pools:
return pool_name
return None
def _check_subdomain(domain: str, registered_domains: set[str]) -> tuple[bool, Optional[str]]: def _check_subdomain(domain: str, registered_domains: set[str]) -> tuple[bool, Optional[str]]:
"""Check if a domain is a subdomain of an existing registered domain. """Check if a domain is a subdomain of an existing registered domain.
@@ -95,24 +76,19 @@ def _update_haproxy_maps(domain: str, pool: str, is_subdomain: bool) -> None:
haproxy_cmd(f"add map {WILDCARDS_MAP_FILE_CONTAINER} .{domain} {pool}") haproxy_cmd(f"add map {WILDCARDS_MAP_FILE_CONTAINER} .{domain} {pool}")
def _rollback_domain_addition( def _rollback_domain_addition(domain: str) -> None:
domain: str, """Rollback a failed domain addition by removing from SQLite + map files.
entries: list[tuple[str, str]]
) -> None:
"""Rollback a failed domain addition by removing entries from map file.
Called when HAProxy Runtime API update fails after the map file Called when HAProxy Runtime API update fails after the domain
has already been saved. has already been saved to the database.
Args: Args:
domain: Domain name that was added. domain: Domain name that was added.
entries: Current list of map entries to rollback from.
""" """
rollback_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
try: try:
save_map_file(rollback_entries) remove_domain_from_map(domain)
except IOError: except (IOError, Exception):
logger.error("Failed to rollback map file after HAProxy error") logger.error("Failed to rollback domain %s after HAProxy error", domain)
def _file_exists(path: str) -> bool: def _file_exists(path: str) -> bool:
@@ -242,23 +218,17 @@ def register_domain_tools(mcp):
if share_with and ip: if share_with and ip:
return "Error: Cannot specify both ip and share_with (shared domains use existing servers)" return "Error: Cannot specify both ip and share_with (shared domains use existing servers)"
# Use file locking for the entire pool allocation operation # Read current entries for existence check and subdomain detection
from ..file_ops import file_lock
with file_lock(f"{MAP_FILE}.lock"):
# Read map contents once for both existence check and pool lookup
entries = get_map_contents() entries = get_map_contents()
# Check if domain already exists (using cached entries) # Check if domain already exists
for domain_entry, backend in entries: for domain_entry, backend in entries:
if domain_entry == domain: if domain_entry == domain:
return f"Error: Domain {domain} already exists (mapped to {backend})" return f"Error: Domain {domain} already exists (mapped to {backend})"
# Build used pools and registered domains sets # Build registered domains set for subdomain check
used_pools: set[str] = set()
registered_domains: set[str] = set() registered_domains: set[str] = set()
for entry_domain, backend in entries: for entry_domain, _ in entries:
if backend.startswith("pool_"):
used_pools.add(backend)
if not entry_domain.startswith("."): if not entry_domain.startswith("."):
registered_domains.add(entry_domain) registered_domains.add(entry_domain)
@@ -271,8 +241,8 @@ def register_domain_tools(mcp):
return f"Error: Cannot share with legacy backend {share_backend}" return f"Error: Cannot share with legacy backend {share_backend}"
pool = share_backend pool = share_backend
else: else:
# Find available pool # Find available pool (SQLite query, O(1))
pool = _find_available_pool(entries, used_pools) pool = find_available_pool()
if not pool: if not pool:
return f"Error: All {POOL_COUNT} pool backends are in use" return f"Error: All {POOL_COUNT} pool backends are in use"
@@ -280,20 +250,19 @@ def register_domain_tools(mcp):
is_subdomain, parent_domain = _check_subdomain(domain, registered_domains) is_subdomain, parent_domain = _check_subdomain(domain, registered_domains)
try: try:
# Save to disk first (atomic write for persistence) # Save to SQLite + sync map files (atomic via SQLite transaction)
entries.append((domain, pool))
if not is_subdomain:
entries.append((f".{domain}", pool))
try: try:
save_map_file(entries) add_domain_to_map(domain, pool)
except IOError as e: if not is_subdomain:
return f"Error: Failed to save map file: {e}" add_domain_to_map(f".{domain}", pool, is_wildcard=True)
except (IOError, Exception) as e:
return f"Error: Failed to save domain: {e}"
# Update HAProxy maps via Runtime API # Update HAProxy maps via Runtime API
try: try:
_update_haproxy_maps(domain, pool, is_subdomain) _update_haproxy_maps(domain, pool, is_subdomain)
except HaproxyError as e: except HaproxyError as e:
_rollback_domain_addition(domain, entries) _rollback_domain_addition(domain)
return f"Error: Failed to update HAProxy map: {e}" return f"Error: Failed to update HAProxy map: {e}"
# Handle server configuration based on mode # Handle server configuration based on mode
@@ -355,10 +324,8 @@ def register_domain_tools(mcp):
domains_using_pool = get_domains_sharing_pool(backend) domains_using_pool = get_domains_sharing_pool(backend)
other_domains = [d for d in domains_using_pool if d != domain] other_domains = [d for d in domains_using_pool if d != domain]
# Save to disk first (atomic write for persistence) # Remove from SQLite + sync map files
entries = get_map_contents() remove_domain_from_map(domain)
new_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
save_map_file(new_entries)
# Remove from persistent server config # Remove from persistent server config
remove_domain_from_config(domain) remove_domain_from_config(domain)

View File

@@ -10,6 +10,7 @@ from pydantic import Field
from ..config import ( from ..config import (
MAP_FILE, MAP_FILE,
SERVERS_FILE, SERVERS_FILE,
DB_FILE,
HAPROXY_CONTAINER, HAPROXY_CONTAINER,
) )
from ..exceptions import HaproxyError from ..exceptions import HaproxyError
@@ -88,7 +89,7 @@ def register_health_tools(mcp):
# Check configuration files # Check configuration files
files_ok = True files_ok = True
file_status: dict[str, str] = {} file_status: dict[str, str] = {}
for name, path in [("map_file", MAP_FILE), ("servers_file", SERVERS_FILE)]: for name, path in [("map_file", MAP_FILE), ("db_file", DB_FILE)]:
exists = remote_file_exists(path) if REMOTE_MODE else __import__('os').path.exists(path) exists = remote_file_exists(path) if REMOTE_MODE else __import__('os').path.exists(path)
if exists: if exists:
file_status[name] = "ok" file_status[name] = "ok"

View File

@@ -255,6 +255,8 @@ def temp_config_dir(tmp_path):
state_file = tmp_path / "servers.state" state_file = tmp_path / "servers.state"
state_file.write_text("") state_file.write_text("")
db_file = tmp_path / "haproxy_mcp.db"
return { return {
"dir": tmp_path, "dir": tmp_path,
"map_file": str(map_file), "map_file": str(map_file),
@@ -262,12 +264,15 @@ def temp_config_dir(tmp_path):
"servers_file": str(servers_file), "servers_file": str(servers_file),
"certs_file": str(certs_file), "certs_file": str(certs_file),
"state_file": str(state_file), "state_file": str(state_file),
"db_file": str(db_file),
} }
@pytest.fixture @pytest.fixture
def patch_config_paths(temp_config_dir): def patch_config_paths(temp_config_dir):
"""Fixture that patches config module paths to use temporary directory.""" """Fixture that patches config module paths to use temporary directory."""
from haproxy_mcp.db import close_connection, init_db
with patch.multiple( with patch.multiple(
"haproxy_mcp.config", "haproxy_mcp.config",
MAP_FILE=temp_config_dir["map_file"], MAP_FILE=temp_config_dir["map_file"],
@@ -275,16 +280,34 @@ def patch_config_paths(temp_config_dir):
SERVERS_FILE=temp_config_dir["servers_file"], SERVERS_FILE=temp_config_dir["servers_file"],
CERTS_FILE=temp_config_dir["certs_file"], CERTS_FILE=temp_config_dir["certs_file"],
STATE_FILE=temp_config_dir["state_file"], STATE_FILE=temp_config_dir["state_file"],
DB_FILE=temp_config_dir["db_file"],
): ):
# Also patch file_ops module which imports these # Also patch file_ops module which imports these
with patch.multiple( with patch.multiple(
"haproxy_mcp.file_ops", "haproxy_mcp.file_ops",
MAP_FILE=temp_config_dir["map_file"], MAP_FILE=temp_config_dir["map_file"],
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"], WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
):
# Patch db module which imports these
with patch.multiple(
"haproxy_mcp.db",
MAP_FILE=temp_config_dir["map_file"],
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
SERVERS_FILE=temp_config_dir["servers_file"], SERVERS_FILE=temp_config_dir["servers_file"],
CERTS_FILE=temp_config_dir["certs_file"], CERTS_FILE=temp_config_dir["certs_file"],
DB_FILE=temp_config_dir["db_file"],
): ):
# Patch health module which imports MAP_FILE and DB_FILE
with patch.multiple(
"haproxy_mcp.tools.health",
MAP_FILE=temp_config_dir["map_file"],
DB_FILE=temp_config_dir["db_file"],
):
# Close any existing connection and initialize fresh DB
close_connection()
init_db()
yield temp_config_dir yield temp_config_dir
close_connection()
@pytest.fixture @pytest.fixture

433
tests/unit/test_db.py Normal file
View File

@@ -0,0 +1,433 @@
"""Unit tests for db module (SQLite database operations)."""
import json
import os
import sqlite3
from unittest.mock import patch
import pytest
from haproxy_mcp.db import (
get_connection,
close_connection,
init_db,
migrate_from_json,
db_get_map_contents,
db_get_domain_backend,
db_add_domain,
db_remove_domain,
db_find_available_pool,
db_get_domains_sharing_pool,
db_load_servers_config,
db_add_server,
db_remove_server,
db_remove_domain_servers,
db_add_shared_domain,
db_get_shared_domain,
db_is_shared_domain,
db_load_certs,
db_add_cert,
db_remove_cert,
sync_map_files,
SCHEMA_VERSION,
)
class TestConnectionManagement:
    """Tests for database connection management."""

    def test_get_connection(self, patch_config_paths):
        """get_connection returns a live SQLite handle in WAL mode."""
        db = get_connection()
        assert db is not None
        # WAL journaling must be active on every connection the module hands out.
        mode = db.execute("PRAGMA journal_mode").fetchone()[0]
        assert mode == "wal"

    def test_connection_is_thread_local(self, patch_config_paths):
        """Repeated calls from one thread reuse the same connection object."""
        first, second = get_connection(), get_connection()
        assert first is second

    def test_close_connection(self, patch_config_paths):
        """close_connection drops the cached handle; the next call opens a new one."""
        before = get_connection()
        close_connection()
        assert get_connection() is not before
class TestInitDb:
    """Tests for database initialization."""

    def test_creates_tables(self, patch_config_paths):
        """init_db creates all required tables."""
        db = get_connection()
        rows = db.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
        ).fetchall()
        names = {row["name"] for row in rows}
        # All four tables must exist after initialization.
        assert "domains" in names
        assert "servers" in names
        assert "certificates" in names
        assert "schema_version" in names

    def test_schema_version_recorded(self, patch_config_paths):
        """The current schema version is recorded in schema_version."""
        db = get_connection()
        (version,) = db.execute("SELECT MAX(version) FROM schema_version").fetchone()
        assert version == SCHEMA_VERSION

    def test_idempotent(self, patch_config_paths):
        """A second init_db call must be a safe no-op."""
        # patch_config_paths already ran init_db once; running it again must not raise.
        init_db()
        db = get_connection()
        (count,) = db.execute("SELECT COUNT(*) FROM schema_version").fetchone()
        # One or two version rows are both acceptable; the call just must not fail.
        assert count >= 1
class TestMigrateFromJson:
    """Tests for JSON to SQLite migration."""

    def test_migrate_map_files(self, patch_config_paths):
        """Domain entries are imported from both map files."""
        with open(patch_config_paths["map_file"], "w") as fh:
            fh.write("example.com pool_1\napi.example.com pool_2\n")
        with open(patch_config_paths["wildcards_file"], "w") as fh:
            fh.write(".example.com pool_1\n")

        migrate_from_json()

        migrated = db_get_map_contents()
        for expected in (
            ("example.com", "pool_1"),
            ("api.example.com", "pool_2"),
            (".example.com", "pool_1"),
        ):
            assert expected in migrated

    def test_migrate_servers_json(self, patch_config_paths):
        """Server slots are imported from servers.json."""
        servers = {
            "example.com": {
                "1": {"ip": "10.0.0.1", "http_port": 80},
                "2": {"ip": "10.0.0.2", "http_port": 8080},
            }
        }
        with open(patch_config_paths["servers_file"], "w") as fh:
            json.dump(servers, fh)

        migrate_from_json()

        loaded = db_load_servers_config()
        assert "example.com" in loaded
        assert loaded["example.com"]["1"]["ip"] == "10.0.0.1"
        assert loaded["example.com"]["2"]["http_port"] == 8080

    def test_migrate_shared_domains(self, patch_config_paths):
        """_shares references in servers.json become shared-domain rows."""
        # The domains must be present in the map file so the migration
        # creates them in the DB before linking the share.
        with open(patch_config_paths["map_file"], "w") as fh:
            fh.write("example.com pool_1\nwww.example.com pool_1\n")
        with open(patch_config_paths["servers_file"], "w") as fh:
            json.dump({"www.example.com": {"_shares": "example.com"}}, fh)

        migrate_from_json()

        assert db_get_shared_domain("www.example.com") == "example.com"

    def test_migrate_certificates(self, patch_config_paths):
        """Certificate domains are imported from certificates.json."""
        with open(patch_config_paths["certs_file"], "w") as fh:
            json.dump({"domains": ["example.com", "api.example.com"]}, fh)

        migrate_from_json()

        migrated = db_load_certs()
        assert "example.com" in migrated
        assert "api.example.com" in migrated

    def test_migrate_idempotent(self, patch_config_paths):
        """Running the migration twice must not duplicate rows."""
        with open(patch_config_paths["map_file"], "w") as fh:
            fh.write("example.com pool_1\n")
        with open(patch_config_paths["servers_file"], "w") as fh:
            json.dump({"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}, fh)

        migrate_from_json()
        migrate_from_json()  # Second run relies on INSERT OR IGNORE semantics.

        duplicates = [d for d, _ in db_get_map_contents() if d == "example.com"]
        assert len(duplicates) == 1

    def test_migrate_empty_files(self, patch_config_paths):
        """Migration tolerates every source file being absent."""
        for key in ("map_file", "wildcards_file", "servers_file", "certs_file"):
            os.unlink(patch_config_paths[key])
        migrate_from_json()  # Must not raise.

    def test_backup_files_after_migration(self, patch_config_paths):
        """Successful migration leaves *.bak copies of the JSON sources."""
        with open(patch_config_paths["servers_file"], "w") as fh:
            json.dump({"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}, fh)
        with open(patch_config_paths["certs_file"], "w") as fh:
            json.dump({"domains": ["example.com"]}, fh)
        # Provide map data so the migration actually runs.
        with open(patch_config_paths["map_file"], "w") as fh:
            fh.write("example.com pool_1\n")

        migrate_from_json()

        assert os.path.exists(patch_config_paths["servers_file"] + ".bak")
        assert os.path.exists(patch_config_paths["certs_file"] + ".bak")
class TestDomainOperations:
    """Tests for domain data access functions."""

    def test_add_and_get_domain(self, patch_config_paths):
        """A stored domain can be looked up by name."""
        db_add_domain("example.com", "pool_1")
        assert db_get_domain_backend("example.com") == "pool_1"

    def test_get_nonexistent_domain(self, patch_config_paths):
        """Lookup of an unknown domain yields None."""
        assert db_get_domain_backend("nonexistent.com") is None

    def test_remove_domain(self, patch_config_paths):
        """Removal deletes the exact entry and its wildcard twin."""
        db_add_domain("example.com", "pool_1")
        db_add_domain(".example.com", "pool_1", is_wildcard=True)

        db_remove_domain("example.com")

        assert db_get_domain_backend("example.com") is None
        assert (".example.com", "pool_1") not in db_get_map_contents()

    def test_find_available_pool(self, patch_config_paths):
        """The lowest-numbered unused pool is returned."""
        db_add_domain("a.com", "pool_1")
        db_add_domain("b.com", "pool_2")
        assert db_find_available_pool() == "pool_3"

    def test_find_available_pool_empty(self, patch_config_paths):
        """With no domains stored, pool_1 is the first available."""
        assert db_find_available_pool() == "pool_1"

    def test_get_domains_sharing_pool(self, patch_config_paths):
        """Only non-wildcard domains on the pool are reported."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("www.example.com", "pool_1")
        db_add_domain(".example.com", "pool_1", is_wildcard=True)

        sharing = db_get_domains_sharing_pool("pool_1")

        assert "example.com" in sharing
        assert "www.example.com" in sharing
        assert ".example.com" not in sharing

    def test_update_domain_backend(self, patch_config_paths):
        """Re-adding a domain overwrites its backend."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("example.com", "pool_5")  # Acts as an update.
        assert db_get_domain_backend("example.com") == "pool_5"
class TestServerOperations:
    """Tests for server data access functions."""

    def test_add_and_load_server(self, patch_config_paths):
        """A stored server slot round-trips through load."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        slot = db_load_servers_config()["example.com"]["1"]
        assert slot["ip"] == "10.0.0.1"
        assert slot["http_port"] == 80

    def test_update_server(self, patch_config_paths):
        """Re-adding a slot replaces its ip and port."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        db_add_server("example.com", 1, "10.0.0.99", 8080)
        slot = db_load_servers_config()["example.com"]["1"]
        assert slot["ip"] == "10.0.0.99"
        assert slot["http_port"] == 8080

    def test_remove_server(self, patch_config_paths):
        """Removing one slot leaves the other slots intact."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        db_add_server("example.com", 2, "10.0.0.2", 80)

        db_remove_server("example.com", 1)

        loaded = db_load_servers_config()
        assert "1" not in loaded.get("example.com", {})
        assert "2" in loaded["example.com"]

    def test_remove_domain_servers(self, patch_config_paths):
        """Dropping a domain's servers does not touch other domains."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        db_add_server("example.com", 2, "10.0.0.2", 80)
        db_add_server("other.com", 1, "10.0.0.3", 80)

        db_remove_domain_servers("example.com")

        loaded = db_load_servers_config()
        assert loaded.get("example.com", {}).get("1") is None
        assert "other.com" in loaded

    def test_load_empty(self, patch_config_paths):
        """An empty database loads as an empty dict."""
        assert db_load_servers_config() == {}
class TestSharedDomainOperations:
    """Tests for shared domain functions."""

    def test_add_and_get_shared(self, patch_config_paths):
        """A shared-domain reference round-trips."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("www.example.com", "pool_1")
        db_add_shared_domain("www.example.com", "example.com")
        assert db_get_shared_domain("www.example.com") == "example.com"

    def test_is_shared(self, patch_config_paths):
        """is_shared distinguishes the alias from its target."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("www.example.com", "pool_1")
        db_add_shared_domain("www.example.com", "example.com")
        assert db_is_shared_domain("www.example.com") is True
        assert db_is_shared_domain("example.com") is False

    def test_not_shared(self, patch_config_paths):
        """A plain domain has no share target."""
        db_add_domain("example.com", "pool_1")
        assert db_get_shared_domain("example.com") is None
        assert db_is_shared_domain("example.com") is False

    def test_shared_in_load_config(self, patch_config_paths):
        """Shared domains surface as a _shares key in the server config."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("www.example.com", "pool_1")
        db_add_shared_domain("www.example.com", "example.com")
        loaded = db_load_servers_config()
        assert loaded["www.example.com"]["_shares"] == "example.com"
class TestCertificateOperations:
    """Tests for certificate data access functions."""

    def test_add_and_load_cert(self, patch_config_paths):
        """A stored certificate appears in the load result."""
        db_add_cert("example.com")
        assert "example.com" in db_load_certs()

    def test_add_duplicate(self, patch_config_paths):
        """Inserting the same domain twice keeps a single row."""
        db_add_cert("example.com")
        db_add_cert("example.com")
        assert db_load_certs().count("example.com") == 1

    def test_remove_cert(self, patch_config_paths):
        """Removal deletes only the named certificate."""
        db_add_cert("example.com")
        db_add_cert("other.com")

        db_remove_cert("example.com")

        remaining = db_load_certs()
        assert "example.com" not in remaining
        assert "other.com" in remaining

    def test_load_empty(self, patch_config_paths):
        """An empty database yields an empty list."""
        assert db_load_certs() == []

    def test_sorted_output(self, patch_config_paths):
        """Certificates come back in sorted order."""
        for name in ("z.com", "a.com", "m.com"):
            db_add_cert(name)
        assert db_load_certs() == ["a.com", "m.com", "z.com"]
class TestSyncMapFiles:
    """Tests for sync_map_files function."""

    def test_sync_exact_domains(self, patch_config_paths):
        """Exact domains are written to domains.map."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("api.example.com", "pool_2")

        sync_map_files()

        with open(patch_config_paths["map_file"]) as fh:
            body = fh.read()
        assert "example.com pool_1" in body
        assert "api.example.com pool_2" in body

    def test_sync_wildcards(self, patch_config_paths):
        """Wildcard domains are written to wildcards.map."""
        db_add_domain(".example.com", "pool_1", is_wildcard=True)

        sync_map_files()

        with open(patch_config_paths["wildcards_file"]) as fh:
            assert ".example.com pool_1" in fh.read()

    def test_sync_empty(self, patch_config_paths):
        """With no domains only the header comments are written."""
        sync_map_files()

        with open(patch_config_paths["map_file"]) as fh:
            body = fh.read()
        assert "Exact Domain" in body
        # Everything besides comments and blank lines must be absent.
        entries = [
            line.strip()
            for line in body.splitlines()
            if line.strip() and not line.startswith("#")
        ]
        assert len(entries) == 0

    def test_sync_sorted(self, patch_config_paths):
        """Map entries are emitted in sorted domain order."""
        db_add_domain("z.com", "pool_3")
        db_add_domain("a.com", "pool_1")

        sync_map_files()

        with open(patch_config_paths["map_file"]) as fh:
            entries = [
                line.strip() for line in fh if line.strip() and not line.startswith("#")
            ]
        assert entries[0] == "a.com pool_1"
        assert entries[1] == "z.com pool_3"

View File

@@ -1,6 +1,5 @@
"""Unit tests for file_ops module.""" """Unit tests for file_ops module (SQLite-backed)."""
import json
import os import os
from unittest.mock import patch from unittest.mock import patch
@@ -16,14 +15,19 @@ from haproxy_mcp.file_ops import (
get_legacy_backend_name, get_legacy_backend_name,
get_backend_and_prefix, get_backend_and_prefix,
load_servers_config, load_servers_config,
save_servers_config,
add_server_to_config, add_server_to_config,
remove_server_from_config, remove_server_from_config,
remove_domain_from_config, remove_domain_from_config,
load_certs_config, load_certs_config,
save_certs_config,
add_cert_to_config, add_cert_to_config,
remove_cert_from_config, remove_cert_from_config,
add_domain_to_map,
remove_domain_from_map,
find_available_pool,
add_shared_domain_to_config,
get_shared_domain,
is_shared_domain,
get_domains_sharing_pool,
) )
@@ -62,7 +66,7 @@ class TestAtomicWriteFile:
def test_unicode_content(self, tmp_path): def test_unicode_content(self, tmp_path):
"""Unicode content is properly written.""" """Unicode content is properly written."""
file_path = str(tmp_path / "unicode.txt") file_path = str(tmp_path / "unicode.txt")
content = "Hello, \u4e16\u754c!" # "Hello, World!" in Chinese content = "Hello, \u4e16\u754c!"
atomic_write_file(file_path, content) atomic_write_file(file_path, content)
@@ -81,66 +85,33 @@ class TestAtomicWriteFile:
class TestGetMapContents: class TestGetMapContents:
"""Tests for get_map_contents function.""" """Tests for get_map_contents function (SQLite-backed)."""
def test_read_map_file(self, patch_config_paths): def test_empty_db(self, patch_config_paths):
"""Read entries from map file.""" """Empty database returns empty list."""
# Write test content to map file entries = get_map_contents()
with open(patch_config_paths["map_file"], "w") as f: assert entries == []
f.write("example.com pool_1\n")
f.write("api.example.com pool_2\n") def test_read_domains(self, patch_config_paths):
"""Read entries from database."""
add_domain_to_map("example.com", "pool_1")
add_domain_to_map("api.example.com", "pool_2")
entries = get_map_contents() entries = get_map_contents()
assert ("example.com", "pool_1") in entries assert ("example.com", "pool_1") in entries
assert ("api.example.com", "pool_2") in entries assert ("api.example.com", "pool_2") in entries
def test_read_both_map_files(self, patch_config_paths): def test_read_with_wildcards(self, patch_config_paths):
"""Read entries from both domains.map and wildcards.map.""" """Read entries including wildcards."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n") add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
with open(patch_config_paths["wildcards_file"], "w") as f:
f.write(".example.com pool_1\n")
entries = get_map_contents() entries = get_map_contents()
assert ("example.com", "pool_1") in entries assert ("example.com", "pool_1") in entries
assert (".example.com", "pool_1") in entries assert (".example.com", "pool_1") in entries
def test_skip_comments(self, patch_config_paths):
"""Comments are skipped."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("# This is a comment\n")
f.write("example.com pool_1\n")
f.write("# Another comment\n")
entries = get_map_contents()
assert len(entries) == 1
assert entries[0] == ("example.com", "pool_1")
def test_skip_empty_lines(self, patch_config_paths):
"""Empty lines are skipped."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("\n")
f.write("example.com pool_1\n")
f.write("\n")
f.write("api.example.com pool_2\n")
entries = get_map_contents()
assert len(entries) == 2
def test_file_not_found(self, patch_config_paths):
"""Missing file returns empty list."""
os.unlink(patch_config_paths["map_file"])
os.unlink(patch_config_paths["wildcards_file"])
entries = get_map_contents()
assert entries == []
class TestSplitDomainEntries: class TestSplitDomainEntries:
"""Tests for split_domain_entries function.""" """Tests for split_domain_entries function."""
@@ -182,36 +153,30 @@ class TestSplitDomainEntries:
class TestSaveMapFile: class TestSaveMapFile:
"""Tests for save_map_file function.""" """Tests for save_map_file function (syncs from DB to map files)."""
def test_save_entries(self, patch_config_paths): def test_save_entries(self, patch_config_paths):
"""Save entries to separate map files.""" """Save entries to separate map files."""
entries = [ add_domain_to_map("example.com", "pool_1")
("example.com", "pool_1"), add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
(".example.com", "pool_1"),
]
save_map_file(entries) save_map_file([]) # Entries param ignored, syncs from DB
# Check exact domains file
with open(patch_config_paths["map_file"]) as f: with open(patch_config_paths["map_file"]) as f:
content = f.read() content = f.read()
assert "example.com pool_1" in content assert "example.com pool_1" in content
# Check wildcards file
with open(patch_config_paths["wildcards_file"]) as f: with open(patch_config_paths["wildcards_file"]) as f:
content = f.read() content = f.read()
assert ".example.com pool_1" in content assert ".example.com pool_1" in content
def test_sorted_output(self, patch_config_paths): def test_sorted_output(self, patch_config_paths):
"""Entries are sorted in output.""" """Entries are sorted in output."""
entries = [ add_domain_to_map("z.example.com", "pool_3")
("z.example.com", "pool_3"), add_domain_to_map("a.example.com", "pool_1")
("a.example.com", "pool_1"), add_domain_to_map("m.example.com", "pool_2")
("m.example.com", "pool_2"),
]
save_map_file(entries) save_map_file([])
with open(patch_config_paths["map_file"]) as f: with open(patch_config_paths["map_file"]) as f:
lines = [l.strip() for l in f if l.strip() and not l.startswith("#")] lines = [l.strip() for l in f if l.strip() and not l.startswith("#")]
@@ -222,12 +187,11 @@ class TestSaveMapFile:
class TestGetDomainBackend: class TestGetDomainBackend:
"""Tests for get_domain_backend function.""" """Tests for get_domain_backend function (SQLite-backed)."""
def test_find_existing_domain(self, patch_config_paths): def test_find_existing_domain(self, patch_config_paths):
"""Find backend for existing domain.""" """Find backend for existing domain."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
backend = get_domain_backend("example.com") backend = get_domain_backend("example.com")
@@ -235,8 +199,7 @@ class TestGetDomainBackend:
def test_domain_not_found(self, patch_config_paths): def test_domain_not_found(self, patch_config_paths):
"""Non-existent domain returns None.""" """Non-existent domain returns None."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
backend = get_domain_backend("other.com") backend = get_domain_backend("other.com")
@@ -271,8 +234,7 @@ class TestGetBackendAndPrefix:
def test_pool_backend(self, patch_config_paths): def test_pool_backend(self, patch_config_paths):
"""Pool backend returns pool-based prefix.""" """Pool backend returns pool-based prefix."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_5")
f.write("example.com pool_5\n")
backend, prefix = get_backend_and_prefix("example.com") backend, prefix = get_backend_and_prefix("example.com")
@@ -288,48 +250,33 @@ class TestGetBackendAndPrefix:
class TestLoadServersConfig: class TestLoadServersConfig:
"""Tests for load_servers_config function.""" """Tests for load_servers_config function (SQLite-backed)."""
def test_load_existing_config(self, patch_config_paths, sample_servers_config): def test_load_empty_config(self, patch_config_paths):
"""Load existing config file.""" """Empty database returns empty dict."""
with open(patch_config_paths["servers_file"], "w") as f: config = load_servers_config()
json.dump(sample_servers_config, f) assert config == {}
def test_load_with_servers(self, patch_config_paths):
"""Load config with server entries."""
add_server_to_config("example.com", 1, "10.0.0.1", 80)
add_server_to_config("example.com", 2, "10.0.0.2", 80)
config = load_servers_config() config = load_servers_config()
assert "example.com" in config assert "example.com" in config
assert config["example.com"]["1"]["ip"] == "10.0.0.1" assert config["example.com"]["1"]["ip"] == "10.0.0.1"
assert config["example.com"]["2"]["ip"] == "10.0.0.2"
def test_file_not_found(self, patch_config_paths): def test_load_with_shared_domain(self, patch_config_paths):
"""Missing file returns empty dict.""" """Load config with shared domain reference."""
os.unlink(patch_config_paths["servers_file"]) add_domain_to_map("example.com", "pool_1")
add_domain_to_map("www.example.com", "pool_1")
add_shared_domain_to_config("www.example.com", "example.com")
config = load_servers_config() config = load_servers_config()
assert config == {} assert config["www.example.com"]["_shares"] == "example.com"
def test_invalid_json(self, patch_config_paths):
"""Invalid JSON returns empty dict."""
with open(patch_config_paths["servers_file"], "w") as f:
f.write("not valid json {{{")
config = load_servers_config()
assert config == {}
class TestSaveServersConfig:
"""Tests for save_servers_config function."""
def test_save_config(self, patch_config_paths):
"""Save config to file."""
config = {"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}
save_servers_config(config)
with open(patch_config_paths["servers_file"]) as f:
loaded = json.load(f)
assert loaded == config
class TestAddServerToConfig: class TestAddServerToConfig:
@@ -373,17 +320,18 @@ class TestRemoveServerFromConfig:
remove_server_from_config("example.com", 1) remove_server_from_config("example.com", 1)
config = load_servers_config() config = load_servers_config()
assert "1" not in config["example.com"] assert "1" not in config.get("example.com", {})
assert "2" in config["example.com"] assert "2" in config["example.com"]
def test_remove_last_server_removes_domain(self, patch_config_paths): def test_remove_last_server_removes_domain(self, patch_config_paths):
"""Removing last server removes domain entry.""" """Removing last server removes domain entry from servers."""
add_server_to_config("example.com", 1, "10.0.0.1", 80) add_server_to_config("example.com", 1, "10.0.0.1", 80)
remove_server_from_config("example.com", 1) remove_server_from_config("example.com", 1)
config = load_servers_config() config = load_servers_config()
assert "example.com" not in config # Domain may or may not exist (no servers = no entry)
assert config.get("example.com", {}).get("1") is None
def test_remove_nonexistent_server(self, patch_config_paths): def test_remove_nonexistent_server(self, patch_config_paths):
"""Removing non-existent server is a no-op.""" """Removing non-existent server is a no-op."""
@@ -399,14 +347,14 @@ class TestRemoveDomainFromConfig:
"""Tests for remove_domain_from_config function.""" """Tests for remove_domain_from_config function."""
def test_remove_existing_domain(self, patch_config_paths): def test_remove_existing_domain(self, patch_config_paths):
"""Remove existing domain.""" """Remove existing domain's servers."""
add_server_to_config("example.com", 1, "10.0.0.1", 80) add_server_to_config("example.com", 1, "10.0.0.1", 80)
add_server_to_config("other.com", 1, "10.0.0.2", 80) add_server_to_config("other.com", 1, "10.0.0.2", 80)
remove_domain_from_config("example.com") remove_domain_from_config("example.com")
config = load_servers_config() config = load_servers_config()
assert "example.com" not in config assert config.get("example.com", {}).get("1") is None
assert "other.com" in config assert "other.com" in config
def test_remove_nonexistent_domain(self, patch_config_paths): def test_remove_nonexistent_domain(self, patch_config_paths):
@@ -420,40 +368,23 @@ class TestRemoveDomainFromConfig:
class TestLoadCertsConfig: class TestLoadCertsConfig:
"""Tests for load_certs_config function.""" """Tests for load_certs_config function (SQLite-backed)."""
def test_load_existing_config(self, patch_config_paths): def test_load_empty(self, patch_config_paths):
"""Load existing certs config.""" """Empty database returns empty list."""
with open(patch_config_paths["certs_file"], "w") as f: domains = load_certs_config()
json.dump({"domains": ["example.com", "other.com"]}, f) assert domains == []
def test_load_with_certs(self, patch_config_paths):
"""Load certs from database."""
add_cert_to_config("example.com")
add_cert_to_config("other.com")
domains = load_certs_config() domains = load_certs_config()
assert "example.com" in domains assert "example.com" in domains
assert "other.com" in domains assert "other.com" in domains
def test_file_not_found(self, patch_config_paths):
"""Missing file returns empty list."""
os.unlink(patch_config_paths["certs_file"])
domains = load_certs_config()
assert domains == []
class TestSaveCertsConfig:
"""Tests for save_certs_config function."""
def test_save_domains(self, patch_config_paths):
"""Save domains to certs config."""
save_certs_config(["z.com", "a.com"])
with open(patch_config_paths["certs_file"]) as f:
data = json.load(f)
# Should be sorted
assert data["domains"] == ["a.com", "z.com"]
class TestAddCertToConfig: class TestAddCertToConfig:
"""Tests for add_cert_to_config function.""" """Tests for add_cert_to_config function."""
@@ -496,3 +427,87 @@ class TestRemoveCertFromConfig:
domains = load_certs_config() domains = load_certs_config()
assert "example.com" in domains assert "example.com" in domains
class TestAddDomainToMap:
"""Tests for add_domain_to_map function."""
def test_add_domain(self, patch_config_paths):
"""Add a domain and verify map files are synced."""
add_domain_to_map("example.com", "pool_1")
assert get_domain_backend("example.com") == "pool_1"
with open(patch_config_paths["map_file"]) as f:
assert "example.com pool_1" in f.read()
def test_add_wildcard(self, patch_config_paths):
"""Add a wildcard domain."""
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
entries = get_map_contents()
assert (".example.com", "pool_1") in entries
class TestRemoveDomainFromMap:
"""Tests for remove_domain_from_map function."""
def test_remove_domain(self, patch_config_paths):
"""Remove a domain and its wildcard."""
add_domain_to_map("example.com", "pool_1")
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
remove_domain_from_map("example.com")
assert get_domain_backend("example.com") is None
entries = get_map_contents()
assert (".example.com", "pool_1") not in entries
class TestFindAvailablePool:
"""Tests for find_available_pool function."""
def test_first_pool_available(self, patch_config_paths):
"""When no domains exist, pool_1 is returned."""
pool = find_available_pool()
assert pool == "pool_1"
def test_skip_used_pools(self, patch_config_paths):
"""Used pools are skipped."""
add_domain_to_map("example.com", "pool_1")
add_domain_to_map("other.com", "pool_2")
pool = find_available_pool()
assert pool == "pool_3"
class TestSharedDomains:
"""Tests for shared domain functions."""
def test_get_shared_domain(self, patch_config_paths):
"""Get parent domain for shared domain."""
add_domain_to_map("example.com", "pool_1")
add_domain_to_map("www.example.com", "pool_1")
add_shared_domain_to_config("www.example.com", "example.com")
assert get_shared_domain("www.example.com") == "example.com"
def test_is_shared_domain(self, patch_config_paths):
"""Check if domain is shared."""
add_domain_to_map("example.com", "pool_1")
add_domain_to_map("www.example.com", "pool_1")
add_shared_domain_to_config("www.example.com", "example.com")
assert is_shared_domain("www.example.com") is True
assert is_shared_domain("example.com") is False
def test_get_domains_sharing_pool(self, patch_config_paths):
"""Get all domains using a pool."""
add_domain_to_map("example.com", "pool_1")
add_domain_to_map("www.example.com", "pool_1")
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
domains = get_domains_sharing_pool("pool_1")
assert "example.com" in domains
assert "www.example.com" in domains
assert ".example.com" not in domains # Wildcards excluded

View File

@@ -6,6 +6,8 @@ from unittest.mock import patch, MagicMock
import pytest import pytest
from haproxy_mcp.file_ops import add_cert_to_config
class TestGetPemPaths: class TestGetPemPaths:
"""Tests for get_pem_paths function.""" """Tests for get_pem_paths function."""
@@ -127,8 +129,7 @@ class TestRestoreCertificates:
def test_restore_certificates_success(self, patch_config_paths, tmp_path, mock_socket_class, mock_select): def test_restore_certificates_success(self, patch_config_paths, tmp_path, mock_socket_class, mock_select):
"""Restore certificates successfully.""" """Restore certificates successfully."""
# Save config # Save config
with open(patch_config_paths["certs_file"], "w") as f: add_cert_to_config("example.com")
json.dump({"domains": ["example.com"]}, f)
# Create PEM # Create PEM
certs_dir = tmp_path / "certs" certs_dir = tmp_path / "certs"
@@ -283,12 +284,18 @@ class TestHaproxyCertInfo:
pem_file = tmp_path / "example.com.pem" pem_file = tmp_path / "example.com.pem"
pem_file.write_text("cert content") pem_file.write_text("cert content")
mock_subprocess.return_value = MagicMock( def subprocess_side_effect(*args, **kwargs):
cmd = args[0] if args else kwargs.get("args", [])
if isinstance(cmd, list) and "stat" in cmd:
return MagicMock(returncode=0, stdout="1704067200", stderr="")
return MagicMock(
returncode=0, returncode=0,
stdout="subject=CN = example.com\nissuer=CN = Google Trust Services\nnotBefore=Jan 1 00:00:00 2024 GMT\nnotAfter=Apr 1 00:00:00 2024 GMT", stdout="subject=CN = example.com\nissuer=CN = Google Trust Services\nnotBefore=Jan 1 00:00:00 2024 GMT\nnotAfter=Apr 1 00:00:00 2024 GMT",
stderr="" stderr=""
) )
mock_subprocess.side_effect = subprocess_side_effect
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"show ssl cert": "/etc/haproxy/certs/example.com.pem", "show ssl cert": "/etc/haproxy/certs/example.com.pem",
}) })
@@ -337,10 +344,18 @@ class TestHaproxyIssueCert:
assert "Error" in result assert "Error" in result
assert "Invalid domain" in result assert "Invalid domain" in result
def test_issue_cert_no_cf_token(self, tmp_path): def test_issue_cert_no_cf_token(self, tmp_path, mock_subprocess):
"""Fail when CF_Token is not set.""" """Fail when CF_Token is not set."""
acme_sh = str(tmp_path / "acme.sh")
mock_subprocess.return_value = MagicMock(
returncode=1,
stdout="",
stderr="CF_Token is not set. Please export CF_Token environment variable.",
)
with patch.dict(os.environ, {}, clear=True): with patch.dict(os.environ, {}, clear=True):
with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)): with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)):
with patch("haproxy_mcp.tools.certificates.ACME_SH", acme_sh):
with patch("os.path.exists", return_value=False): with patch("os.path.exists", return_value=False):
from haproxy_mcp.tools.certificates import register_certificate_tools from haproxy_mcp.tools.certificates import register_certificate_tools
mcp = MagicMock() mcp = MagicMock()
@@ -845,8 +860,8 @@ class TestHaproxyRenewAllCertsMultiple:
def test_renew_all_certs_multiple_renewals(self, mock_subprocess, mock_socket_class, mock_select, patch_config_paths, tmp_path): def test_renew_all_certs_multiple_renewals(self, mock_subprocess, mock_socket_class, mock_select, patch_config_paths, tmp_path):
"""Renew multiple certificates successfully.""" """Renew multiple certificates successfully."""
# Write config with multiple domains # Write config with multiple domains
with open(patch_config_paths["certs_file"], "w") as f: add_cert_to_config("example.com")
json.dump({"domains": ["example.com", "example.org"]}, f) add_cert_to_config("example.org")
# Create PEM files # Create PEM files
certs_dir = tmp_path / "certs" certs_dir = tmp_path / "certs"
@@ -1038,16 +1053,18 @@ class TestHaproxyDeleteCertPartialFailure:
"show ssl cert": "", # Not loaded "show ssl cert": "", # Not loaded
}) })
# Mock os.remove to fail # Mock subprocess to succeed for acme.sh remove but fail for rm (PEM removal)
def mock_remove(path): def subprocess_side_effect(*args, **kwargs):
if "example.com.pem" in str(path): cmd = args[0] if args else kwargs.get("args", [])
raise PermissionError("Permission denied") if isinstance(cmd, list) and cmd[0] == "rm":
raise FileNotFoundError() return MagicMock(returncode=1, stdout="", stderr="Permission denied")
return MagicMock(returncode=0, stdout="", stderr="")
mock_subprocess.side_effect = subprocess_side_effect
with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path / "acme")): with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path / "acme")):
with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)): with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)):
with patch("socket.socket", return_value=mock_sock): with patch("socket.socket", return_value=mock_sock):
with patch("os.remove", side_effect=mock_remove):
from haproxy_mcp.tools.certificates import register_certificate_tools from haproxy_mcp.tools.certificates import register_certificate_tools
mcp = MagicMock() mcp = MagicMock()
registered_tools = {} registered_tools = {}
@@ -1118,8 +1135,8 @@ class TestRestoreCertificatesFailure:
def test_restore_certificates_partial_failure(self, patch_config_paths, tmp_path, mock_socket_class, mock_select): def test_restore_certificates_partial_failure(self, patch_config_paths, tmp_path, mock_socket_class, mock_select):
"""Handle partial failure when restoring certificates.""" """Handle partial failure when restoring certificates."""
# Save config with multiple domains # Save config with multiple domains
with open(patch_config_paths["certs_file"], "w") as f: add_cert_to_config("example.com")
json.dump({"domains": ["example.com", "missing.com"]}, f) add_cert_to_config("missing.com")
# Create only one PEM file # Create only one PEM file
certs_dir = tmp_path / "certs" certs_dir = tmp_path / "certs"

View File

@@ -5,6 +5,8 @@ from unittest.mock import patch, MagicMock
import pytest import pytest
from haproxy_mcp.file_ops import add_domain_to_map, add_server_to_config
class TestRestoreServersFromConfig: class TestRestoreServersFromConfig:
"""Tests for restore_servers_from_config function.""" """Tests for restore_servers_from_config function."""
@@ -19,12 +21,12 @@ class TestRestoreServersFromConfig:
def test_restore_servers_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config): def test_restore_servers_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config):
"""Restore servers successfully.""" """Restore servers successfully."""
# Write config and map # Add domains and servers to database
with open(patch_config_paths["servers_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
json.dump(sample_servers_config, f) add_server_to_config("example.com", 1, "10.0.0.1", 80)
with open(patch_config_paths["map_file"], "w") as f: add_server_to_config("example.com", 2, "10.0.0.2", 80)
f.write("example.com pool_1\n") add_domain_to_map("api.example.com", "pool_2")
f.write("api.example.com pool_2\n") add_server_to_config("api.example.com", 1, "10.0.0.10", 8080)
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"set server": "", "set server": "",
@@ -40,9 +42,8 @@ class TestRestoreServersFromConfig:
def test_restore_servers_skip_missing_domain(self, mock_socket_class, mock_select, patch_config_paths): def test_restore_servers_skip_missing_domain(self, mock_socket_class, mock_select, patch_config_paths):
"""Skip domains not in map file.""" """Skip domains not in map file."""
config = {"unknown.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}} # Add server for unknown.com but no map entry (simulates missing domain)
with open(patch_config_paths["servers_file"], "w") as f: add_server_to_config("unknown.com", 1, "10.0.0.1", 80)
json.dump(config, f)
mock_sock = mock_socket_class(responses={"set server": ""}) mock_sock = mock_socket_class(responses={"set server": ""})
@@ -55,11 +56,9 @@ class TestRestoreServersFromConfig:
def test_restore_servers_skip_empty_ip(self, mock_socket_class, mock_select, patch_config_paths): def test_restore_servers_skip_empty_ip(self, mock_socket_class, mock_select, patch_config_paths):
"""Skip servers with empty IP.""" """Skip servers with empty IP."""
config = {"example.com": {"1": {"ip": "", "http_port": 80}}} # Add domain to map and server with empty IP (will be skipped during restore)
with open(patch_config_paths["servers_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
json.dump(config, f) add_server_to_config("example.com", 1, "", 80)
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={"set server": ""}) mock_sock = mock_socket_class(responses={"set server": ""})
@@ -321,11 +320,12 @@ class TestHaproxyRestoreState:
def test_restore_state_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config): def test_restore_state_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config):
"""Restore state successfully.""" """Restore state successfully."""
with open(patch_config_paths["servers_file"], "w") as f: # Add domains and servers to database
json.dump(sample_servers_config, f) add_domain_to_map("example.com", "pool_1")
with open(patch_config_paths["map_file"], "w") as f: add_server_to_config("example.com", 1, "10.0.0.1", 80)
f.write("example.com pool_1\n") add_server_to_config("example.com", 2, "10.0.0.2", 80)
f.write("api.example.com pool_2\n") add_domain_to_map("api.example.com", "pool_2")
add_server_to_config("api.example.com", 1, "10.0.0.10", 8080)
mock_sock = mock_socket_class(responses={"set server": ""}) mock_sock = mock_socket_class(responses={"set server": ""})
@@ -373,17 +373,10 @@ class TestRestoreServersFromConfigBatchFailure:
def test_restore_servers_batch_failure_fallback(self, mock_socket_class, mock_select, patch_config_paths): def test_restore_servers_batch_failure_fallback(self, mock_socket_class, mock_select, patch_config_paths):
"""Fall back to individual commands when batch fails.""" """Fall back to individual commands when batch fails."""
# Create config with servers # Add domain and servers to database
config = { add_domain_to_map("example.com", "pool_1")
"example.com": { add_server_to_config("example.com", 1, "10.0.0.1", 80)
"1": {"ip": "10.0.0.1", "http_port": 80}, add_server_to_config("example.com", 2, "10.0.0.2", 80)
"2": {"ip": "10.0.0.2", "http_port": 80},
}
}
with open(patch_config_paths["servers_file"], "w") as f:
json.dump(config, f)
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
# Track call count to simulate batch failure then individual success # Track call count to simulate batch failure then individual success
call_count = [0] call_count = [0]
@@ -457,17 +450,17 @@ class TestRestoreServersFromConfigBatchFailure:
def test_restore_servers_invalid_slot(self, mock_socket_class, mock_select, patch_config_paths): def test_restore_servers_invalid_slot(self, mock_socket_class, mock_select, patch_config_paths):
"""Skip servers with invalid slot number.""" """Skip servers with invalid slot number."""
# Add domain to map
add_domain_to_map("example.com", "pool_1")
# Mock load_servers_config to return config with invalid slot
config = { config = {
"example.com": { "example.com": {
"invalid": {"ip": "10.0.0.1", "http_port": 80}, # Invalid slot "invalid": {"ip": "10.0.0.1", "http_port": 80}, # Invalid slot
"1": {"ip": "10.0.0.2", "http_port": 80}, # Valid slot "1": {"ip": "10.0.0.2", "http_port": 80}, # Valid slot
} }
} }
with open(patch_config_paths["servers_file"], "w") as f: with patch("haproxy_mcp.tools.configuration.load_servers_config", return_value=config):
json.dump(config, f)
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={"set server": ""}) mock_sock = mock_socket_class(responses={"set server": ""})
with patch("socket.socket", return_value=mock_sock): with patch("socket.socket", return_value=mock_sock):
@@ -481,17 +474,17 @@ class TestRestoreServersFromConfigBatchFailure:
def test_restore_servers_invalid_port(self, mock_socket_class, mock_select, patch_config_paths, caplog): def test_restore_servers_invalid_port(self, mock_socket_class, mock_select, patch_config_paths, caplog):
"""Skip servers with invalid port.""" """Skip servers with invalid port."""
import logging import logging
# Add domain to map
add_domain_to_map("example.com", "pool_1")
# Mock load_servers_config to return config with invalid port
config = { config = {
"example.com": { "example.com": {
"1": {"ip": "10.0.0.1", "http_port": "invalid"}, # Invalid port "1": {"ip": "10.0.0.1", "http_port": "invalid"}, # Invalid port
"2": {"ip": "10.0.0.2", "http_port": 80}, # Valid port "2": {"ip": "10.0.0.2", "http_port": 80}, # Valid port
} }
} }
with open(patch_config_paths["servers_file"], "w") as f: with patch("haproxy_mcp.tools.configuration.load_servers_config", return_value=config):
json.dump(config, f)
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={"set server": ""}) mock_sock = mock_socket_class(responses={"set server": ""})
with patch("socket.socket", return_value=mock_sock): with patch("socket.socket", return_value=mock_sock):
@@ -658,11 +651,12 @@ class TestHaproxyRestoreStateFailures:
"""Handle HAProxy error when restoring state.""" """Handle HAProxy error when restoring state."""
from haproxy_mcp.exceptions import HaproxyError from haproxy_mcp.exceptions import HaproxyError
with open(patch_config_paths["servers_file"], "w") as f: # Add domains and servers to database
json.dump(sample_servers_config, f) add_domain_to_map("example.com", "pool_1")
with open(patch_config_paths["map_file"], "w") as f: add_server_to_config("example.com", 1, "10.0.0.1", 80)
f.write("example.com pool_1\n") add_server_to_config("example.com", 2, "10.0.0.2", 80)
f.write("api.example.com pool_2\n") add_domain_to_map("api.example.com", "pool_2")
add_server_to_config("api.example.com", 1, "10.0.0.10", 8080)
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=HaproxyError("Connection refused")): with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=HaproxyError("Connection refused")):
from haproxy_mcp.tools.configuration import register_config_tools from haproxy_mcp.tools.configuration import register_config_tools

View File

@@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock
import pytest import pytest
from haproxy_mcp.exceptions import HaproxyError from haproxy_mcp.exceptions import HaproxyError
from haproxy_mcp.file_ops import add_domain_to_map
class TestHaproxyListDomains: class TestHaproxyListDomains:
@@ -38,9 +39,8 @@ class TestHaproxyListDomains:
def test_list_domains_with_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_list_domains_with_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""List domains with configured servers.""" """List domains with configured servers."""
# Write map file # Add domain to DB
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([ "show servers state": response_builder.servers_state([
@@ -70,10 +70,8 @@ class TestHaproxyListDomains:
def test_list_domains_exclude_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_list_domains_exclude_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""List domains excluding wildcards by default.""" """List domains excluding wildcards by default."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n") add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
with open(patch_config_paths["wildcards_file"], "w") as f:
f.write(".example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([]), "show servers state": response_builder.servers_state([]),
@@ -100,10 +98,8 @@ class TestHaproxyListDomains:
def test_list_domains_include_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_list_domains_include_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""List domains including wildcards when requested.""" """List domains including wildcards when requested."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n") add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
with open(patch_config_paths["wildcards_file"], "w") as f:
f.write(".example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([]), "show servers state": response_builder.servers_state([]),
@@ -230,8 +226,7 @@ class TestHaproxyAddDomain:
def test_add_domain_already_exists(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_add_domain_already_exists(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Reject adding domain that already exists.""" """Reject adding domain that already exists."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
from haproxy_mcp.tools.domains import register_domain_tools from haproxy_mcp.tools.domains import register_domain_tools
mcp = MagicMock() mcp = MagicMock()
@@ -362,8 +357,7 @@ class TestHaproxyRemoveDomain:
def test_remove_legacy_domain_rejected(self, patch_config_paths): def test_remove_legacy_domain_rejected(self, patch_config_paths):
"""Reject removing legacy (non-pool) domain.""" """Reject removing legacy (non-pool) domain."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "legacy_backend")
f.write("example.com legacy_backend\n")
from haproxy_mcp.tools.domains import register_domain_tools from haproxy_mcp.tools.domains import register_domain_tools
mcp = MagicMock() mcp = MagicMock()
@@ -385,10 +379,8 @@ class TestHaproxyRemoveDomain:
def test_remove_domain_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_remove_domain_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Successfully remove domain.""" """Successfully remove domain."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n") add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
with open(patch_config_paths["wildcards_file"], "w") as f:
f.write(".example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"del map": "", "del map": "",

View File

@@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock
import pytest import pytest
from haproxy_mcp.exceptions import HaproxyError from haproxy_mcp.exceptions import HaproxyError
from haproxy_mcp.file_ops import add_domain_to_map
class TestHaproxyHealth: class TestHaproxyHealth:
@@ -80,7 +81,7 @@ class TestHaproxyHealth:
# Use paths that don't exist # Use paths that don't exist
with patch("haproxy_mcp.tools.health.MAP_FILE", str(tmp_path / "nonexistent.map")): with patch("haproxy_mcp.tools.health.MAP_FILE", str(tmp_path / "nonexistent.map")):
with patch("haproxy_mcp.tools.health.SERVERS_FILE", str(tmp_path / "nonexistent.json")): with patch("haproxy_mcp.tools.health.DB_FILE", str(tmp_path / "nonexistent.db")):
with patch("socket.socket", return_value=mock_sock): with patch("socket.socket", return_value=mock_sock):
from haproxy_mcp.tools.health import register_health_tools from haproxy_mcp.tools.health import register_health_tools
mcp = MagicMock() mcp = MagicMock()
@@ -160,8 +161,7 @@ class TestHaproxyDomainHealth:
def test_domain_health_healthy(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_domain_health_healthy(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Domain health returns healthy when all servers are UP.""" """Domain health returns healthy when all servers are UP."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([ "show servers state": response_builder.servers_state([
@@ -197,8 +197,7 @@ class TestHaproxyDomainHealth:
def test_domain_health_degraded(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_domain_health_degraded(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Domain health returns degraded when some servers are DOWN.""" """Domain health returns degraded when some servers are DOWN."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([ "show servers state": response_builder.servers_state([
@@ -234,8 +233,7 @@ class TestHaproxyDomainHealth:
def test_domain_health_down(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_domain_health_down(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Domain health returns down when all servers are DOWN.""" """Domain health returns down when all servers are DOWN."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([ "show servers state": response_builder.servers_state([
@@ -269,8 +267,7 @@ class TestHaproxyDomainHealth:
def test_domain_health_no_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_domain_health_no_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Domain health returns no_servers when no servers configured.""" """Domain health returns no_servers when no servers configured."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([ "show servers state": response_builder.servers_state([

View File

@@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock
import pytest import pytest
from haproxy_mcp.exceptions import HaproxyError from haproxy_mcp.exceptions import HaproxyError
from haproxy_mcp.file_ops import add_domain_to_map, load_servers_config
class TestHaproxyListServers: class TestHaproxyListServers:
@@ -33,8 +34,7 @@ class TestHaproxyListServers:
def test_list_servers_empty_backend(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_list_servers_empty_backend(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""List servers for domain with no servers.""" """List servers for domain with no servers."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([ "show servers state": response_builder.servers_state([
@@ -63,8 +63,7 @@ class TestHaproxyListServers:
def test_list_servers_with_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_list_servers_with_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""List servers with active servers.""" """List servers with active servers."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([ "show servers state": response_builder.servers_state([
@@ -224,8 +223,7 @@ class TestHaproxyAddServer:
def test_add_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_add_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Successfully add server.""" """Successfully add server."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"set server": "", "set server": "",
@@ -258,8 +256,7 @@ class TestHaproxyAddServer:
def test_add_server_auto_slot(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_add_server_auto_slot(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Auto-select slot when slot=0.""" """Auto-select slot when slot=0."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([ "show servers state": response_builder.servers_state([
@@ -413,8 +410,7 @@ class TestHaproxyAddServers:
def test_add_servers_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_add_servers_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Successfully add multiple servers.""" """Successfully add multiple servers."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"set server": "", "set server": "",
@@ -495,8 +491,7 @@ class TestHaproxyRemoveServer:
def test_remove_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_remove_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Successfully remove server.""" """Successfully remove server."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"set server": "", "set server": "",
@@ -689,8 +684,7 @@ class TestHaproxyAddServersRollback:
def test_add_servers_partial_failure_rollback(self, mock_socket_class, mock_select, patch_config_paths): def test_add_servers_partial_failure_rollback(self, mock_socket_class, mock_select, patch_config_paths):
"""Rollback only failed slots when HAProxy error occurs.""" """Rollback only failed slots when HAProxy error occurs."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
# Mock configure_server_slot to fail on second slot # Mock configure_server_slot to fail on second slot
call_count = [0] call_count = [0]
@@ -735,18 +729,16 @@ class TestHaproxyAddServersRollback:
assert "slot 2" in result # Failed assert "slot 2" in result # Failed
# Verify servers.json only has successfully added server # Verify servers.json only has successfully added server
with open(patch_config_paths["servers_file"], "r") as f: config = load_servers_config()
config = json.load(f)
assert "example.com" in config assert "example.com" in config
assert "1" in config["example.com"] # Successfully added stays assert "1" in config["example.com"] # Successfully added stays
assert "2" not in config["example.com"] # Failed one was rolled back assert "2" not in config.get("example.com", {}) # Failed one was rolled back
def test_add_servers_unexpected_error_rollback_only_successful( def test_add_servers_unexpected_error_rollback_only_successful(
self, mock_socket_class, mock_select, patch_config_paths self, mock_socket_class, mock_select, patch_config_paths
): ):
"""Rollback only successfully added servers on unexpected error.""" """Rollback only successfully added servers on unexpected error."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
# Track which servers were configured # Track which servers were configured
configured_slots = [] configured_slots = []
@@ -793,8 +785,7 @@ class TestHaproxyAddServersRollback:
assert "Unexpected system error" in result assert "Unexpected system error" in result
# Verify servers.json is empty (all rolled back) # Verify servers.json is empty (all rolled back)
with open(patch_config_paths["servers_file"], "r") as f: config = load_servers_config()
config = json.load(f)
assert config == {} or "example.com" not in config or config.get("example.com") == {} assert config == {} or "example.com" not in config or config.get("example.com") == {}
def test_add_servers_rollback_failure_logged( def test_add_servers_rollback_failure_logged(
@@ -802,8 +793,7 @@ class TestHaproxyAddServersRollback:
): ):
"""Log rollback failures during error recovery.""" """Log rollback failures during error recovery."""
import logging import logging
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port): def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
if slot == 2: if slot == 2:
@@ -858,8 +848,7 @@ class TestHaproxyAddServerAutoSlot:
def test_add_server_auto_slot_all_used(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_add_server_auto_slot_all_used(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Auto-select slot fails when all slots are in use.""" """Auto-select slot fails when all slots are in use."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
# Build response with all 10 slots used # Build response with all 10 slots used
servers = [] servers = []
@@ -902,8 +891,7 @@ class TestHaproxyAddServerAutoSlot:
def test_add_server_negative_slot_auto_select(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_add_server_negative_slot_auto_select(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Negative slot number triggers auto-selection.""" """Negative slot number triggers auto-selection."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([ "show servers state": response_builder.servers_state([
@@ -970,8 +958,7 @@ class TestHaproxyWaitDrain:
def test_wait_drain_success(self, patch_config_paths): def test_wait_drain_success(self, patch_config_paths):
"""Successfully wait for connections to drain.""" """Successfully wait for connections to drain."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
# Mock haproxy_cmd to return 0 connections # Mock haproxy_cmd to return 0 connections
with patch("haproxy_mcp.tools.servers.haproxy_cmd") as mock_cmd: with patch("haproxy_mcp.tools.servers.haproxy_cmd") as mock_cmd:
@@ -1000,8 +987,7 @@ class TestHaproxyWaitDrain:
def test_wait_drain_timeout(self, patch_config_paths): def test_wait_drain_timeout(self, patch_config_paths):
"""Timeout when connections don't drain.""" """Timeout when connections don't drain."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
time_values = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0] # Simulate time passing time_values = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0] # Simulate time passing
time_iter = iter(time_values) time_iter = iter(time_values)
@@ -1087,9 +1073,6 @@ class TestHaproxyWaitDrain:
def test_wait_drain_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths): def test_wait_drain_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths):
"""Error when domain not found in map.""" """Error when domain not found in map."""
# Empty map file - domain not configured # Empty map file - domain not configured
with open(patch_config_paths["map_file"], "w") as f:
f.write("")
from haproxy_mcp.tools.servers import register_server_tools from haproxy_mcp.tools.servers import register_server_tools
mcp = MagicMock() mcp = MagicMock()
registered_tools = {} registered_tools = {}
@@ -1202,8 +1185,7 @@ class TestHaproxySetDomainState:
def test_set_domain_state_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_set_domain_state_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Set all servers of a domain to a state.""" """Set all servers of a domain to a state."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([ "show servers state": response_builder.servers_state([
@@ -1283,8 +1265,7 @@ class TestHaproxySetDomainState:
def test_set_domain_state_no_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_set_domain_state_no_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""No active servers found for domain.""" """No active servers found for domain."""
with open(patch_config_paths["map_file"], "w") as f: add_domain_to_map("example.com", "pool_1")
f.write("example.com pool_1\n")
# All servers have 0.0.0.0 address (not configured) # All servers have 0.0.0.0 address (not configured)
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
@@ -1318,9 +1299,6 @@ class TestHaproxySetDomainState:
def test_set_domain_state_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths, response_builder): def test_set_domain_state_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Handle domain not found in map - shows no active servers.""" """Handle domain not found in map - shows no active servers."""
# Empty map file # Empty map file
with open(patch_config_paths["map_file"], "w") as f:
f.write("")
# Mock should show no servers for unknown domain's backend # Mock should show no servers for unknown domain's backend
mock_sock = mock_socket_class(responses={ mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([]), "show servers state": response_builder.servers_state([]),