refactor: migrate data storage from JSON/map files to SQLite

Replace servers.json, certificates.json, and map file parsing with
SQLite (WAL mode) as single source of truth. HAProxy map files are
now generated from SQLite via sync_map_files().

Key changes:
- Add db.py with schema, connection management, and JSON migration
- Add DB_FILE config constant
- Delegate file_ops.py functions to db.py
- Refactor domains.py to use file_ops instead of direct list manipulation
- Fix subprocess.TimeoutExpired not caught (doesn't inherit TimeoutError)
- Add DB health check in health.py
- Init DB on startup in server.py and __main__.py
- Update all 359 tests to use SQLite-backed functions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
kappa
2026-02-08 11:07:29 +09:00
parent 05bff61b85
commit cf554f3f89
19 changed files with 1525 additions and 564 deletions

3
.gitignore vendored
View File

@@ -16,6 +16,9 @@ run/
data/
*.state
*.lock
*.db
*.db-wal
*.db-shm
# Python
__pycache__/

View File

@@ -1,8 +1,10 @@
"""Entry point for running haproxy_mcp as a module."""
from .db import init_db
from .server import mcp
from .tools.configuration import startup_restore
if __name__ == "__main__":
init_db()
startup_restore()
mcp.run(transport="streamable-http")

View File

@@ -31,6 +31,7 @@ WILDCARDS_MAP_FILE: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE", "/opt/haproxy/
WILDCARDS_MAP_FILE_CONTAINER: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE_CONTAINER", "/usr/local/etc/haproxy/wildcards.map")
SERVERS_FILE: str = os.getenv("HAPROXY_SERVERS_FILE", "/opt/haproxy/conf/servers.json")
CERTS_FILE: str = os.getenv("HAPROXY_CERTS_FILE", "/opt/haproxy/conf/certificates.json")
DB_FILE: str = os.getenv("HAPROXY_DB_FILE", "/opt/haproxy/conf/haproxy_mcp.db")
# Certificate paths
CERTS_DIR: str = os.getenv("HAPROXY_CERTS_DIR", "/opt/haproxy/certs")

577
haproxy_mcp/db.py Normal file
View File

@@ -0,0 +1,577 @@
"""SQLite database operations for HAProxy MCP Server.
Single source of truth for all persistent configuration data.
Replaces JSON files (servers.json, certificates.json) and map files
(domains.map, wildcards.map) with a single SQLite database.
Map files are generated from the database via sync_map_files() for
HAProxy to consume directly.
"""
import json
import os
import sqlite3
import threading
from typing import Any, Optional
from .config import (
DB_FILE,
MAP_FILE,
WILDCARDS_MAP_FILE,
SERVERS_FILE,
CERTS_FILE,
POOL_COUNT,
REMOTE_MODE,
logger,
)
SCHEMA_VERSION = 1
_local = threading.local()
def get_connection() -> sqlite3.Connection:
    """Return this thread's cached SQLite connection, creating it on demand.

    Each thread gets its own handle via a ``threading.local``. WAL
    journaling, a 5s busy timeout, and foreign-key enforcement are
    configured once, when the connection is first opened.

    Returns:
        sqlite3.Connection configured for concurrent access.
    """
    cached = getattr(_local, "connection", None)
    if cached is None:
        cached = sqlite3.connect(DB_FILE, timeout=10)
        cached.execute("PRAGMA journal_mode=WAL")
        cached.execute("PRAGMA busy_timeout=5000")
        cached.execute("PRAGMA foreign_keys=ON")
        # Row objects allow name-based column access throughout this module.
        cached.row_factory = sqlite3.Row
        _local.connection = cached
    return cached
def close_connection() -> None:
    """Close and forget the thread-local SQLite connection.

    Safe to call when no connection exists. The cached reference is
    cleared even if ``close()`` raises, so a subsequent
    ``get_connection()`` opens a fresh handle instead of reusing a
    broken one (the original cleared the cache only on a clean close).
    """
    conn = getattr(_local, "connection", None)
    if conn is not None:
        try:
            conn.close()
        finally:
            # Always drop the cached handle so the next caller reconnects.
            _local.connection = None
def init_db() -> None:
    """Initialize database schema and run migration if needed.

    Creates tables if they don't exist, then checks for existing
    JSON/map files to migrate data from.

    Idempotent: CREATE ... IF NOT EXISTS makes the schema step safe to
    repeat, and the schema_version table gates the one-time migration.
    """
    conn = get_connection()
    cur = conn.cursor()
    # Timestamps are stored as UTC ISO-8601 strings via strftime defaults.
    cur.executescript("""
        CREATE TABLE IF NOT EXISTS domains (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            domain TEXT NOT NULL UNIQUE,
            backend TEXT NOT NULL,
            is_wildcard INTEGER NOT NULL DEFAULT 0,
            shares_with TEXT DEFAULT NULL,
            created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
            updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
        );
        CREATE INDEX IF NOT EXISTS idx_domains_backend ON domains(backend);
        CREATE INDEX IF NOT EXISTS idx_domains_shares_with ON domains(shares_with);
        CREATE TABLE IF NOT EXISTS servers (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            domain TEXT NOT NULL,
            slot INTEGER NOT NULL,
            ip TEXT NOT NULL,
            http_port INTEGER NOT NULL DEFAULT 80,
            created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
            updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
            UNIQUE(domain, slot)
        );
        CREATE INDEX IF NOT EXISTS idx_servers_domain ON servers(domain);
        CREATE TABLE IF NOT EXISTS certificates (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            domain TEXT NOT NULL UNIQUE,
            created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
            updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
        );
        CREATE TABLE IF NOT EXISTS schema_version (
            version INTEGER NOT NULL,
            applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
        );
    """)
    conn.commit()
    # Check schema version
    # MAX(version) on an empty table yields a single row containing NULL,
    # which is treated as version 0 (fresh database).
    cur.execute("SELECT MAX(version) FROM schema_version")
    row = cur.fetchone()
    current_version = row[0] if row[0] is not None else 0
    if current_version < SCHEMA_VERSION:
        # First-time setup: try migrating from JSON files
        if current_version == 0:
            migrate_from_json()
        # Record the new version only after migration succeeds, so a
        # failed migration is retried on the next startup.
        cur.execute("INSERT INTO schema_version (version) VALUES (?)", (SCHEMA_VERSION,))
        conn.commit()
    logger.info("Database initialized (schema v%d)", SCHEMA_VERSION)
def migrate_from_json() -> None:
    """Migrate data from JSON/map files to SQLite.

    Reads domains.map, wildcards.map, servers.json, and certificates.json,
    imports their data into the database within a single transaction,
    then renames the original files to .bak.

    This function is idempotent (uses INSERT OR IGNORE).
    """
    conn = get_connection()
    # Collect data from existing files
    map_entries: list[tuple[str, str]] = []
    servers_config: dict[str, Any] = {}
    cert_domains: list[str] = []
    # Read map files (format: "domain backend" per line, '#' comments allowed)
    for map_path in [MAP_FILE, WILDCARDS_MAP_FILE]:
        try:
            content = _read_file_for_migration(map_path)
            for line in content.splitlines():
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                parts = line.split()
                if len(parts) >= 2:
                    # Extra columns beyond (domain, backend) are ignored.
                    map_entries.append((parts[0], parts[1]))
        except FileNotFoundError:
            logger.debug("Map file not found for migration: %s", map_path)
    # Read servers.json
    try:
        content = _read_file_for_migration(SERVERS_FILE)
        servers_config = json.loads(content)
    except FileNotFoundError:
        logger.debug("Servers config not found for migration: %s", SERVERS_FILE)
    except json.JSONDecodeError as e:
        logger.warning("Corrupt servers config during migration: %s", e)
    # Read certificates.json
    try:
        content = _read_file_for_migration(CERTS_FILE)
        data = json.loads(content)
        cert_domains = data.get("domains", [])
    except FileNotFoundError:
        logger.debug("Certs config not found for migration: %s", CERTS_FILE)
    except json.JSONDecodeError as e:
        logger.warning("Corrupt certs config during migration: %s", e)
    if not map_entries and not servers_config and not cert_domains:
        logger.info("No existing data to migrate")
        return
    # Import into database in a single transaction
    # (the `with conn:` block commits on success, rolls back on error).
    try:
        with conn:
            # Import domains from map files; a leading dot marks a wildcard.
            for domain, backend in map_entries:
                is_wildcard = 1 if domain.startswith(".") else 0
                conn.execute(
                    "INSERT OR IGNORE INTO domains (domain, backend, is_wildcard) VALUES (?, ?, ?)",
                    (domain, backend, is_wildcard),
                )
            # Import servers and shares_with from servers.json
            for domain, slots in servers_config.items():
                shares_with = slots.get("_shares")
                if shares_with:
                    # Update domain's shares_with field
                    # NOTE(review): assumes a domain entry never has BOTH
                    # "_shares" and server slots — slots are skipped here.
                    # Presumed exclusive in the legacy format; verify.
                    conn.execute(
                        "UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0",
                        (shares_with, domain),
                    )
                    continue
                for slot_str, server_info in slots.items():
                    # Underscore-prefixed keys are metadata, not slot numbers.
                    if slot_str.startswith("_"):
                        continue
                    try:
                        slot = int(slot_str)
                        ip = server_info.get("ip", "")
                        http_port = int(server_info.get("http_port", 80))
                        if ip:
                            conn.execute(
                                "INSERT OR IGNORE INTO servers (domain, slot, ip, http_port) VALUES (?, ?, ?, ?)",
                                (domain, slot, ip, http_port),
                            )
                    except (ValueError, TypeError) as e:
                        logger.warning("Skipping invalid server entry %s/%s: %s", domain, slot_str, e)
            # Import certificates
            for cert_domain in cert_domains:
                conn.execute(
                    "INSERT OR IGNORE INTO certificates (domain) VALUES (?)",
                    (cert_domain,),
                )
        # Counts below are from the parsed inputs, not the DB, so rows
        # skipped by INSERT OR IGNORE are still counted in the log line.
        migrated_domains = len(set(d for d, _ in map_entries))
        migrated_servers = sum(
            1 for d, slots in servers_config.items()
            if "_shares" not in slots
            for s in slots
            if not s.startswith("_")
        )
        migrated_certs = len(cert_domains)
        logger.info(
            "Migrated from JSON: %d domain entries, %d servers, %d certificates",
            migrated_domains, migrated_servers, migrated_certs,
        )
        # Rename original files to .bak so migration does not repeat.
        # Map files are NOT backed up: HAProxy still reads them, and they
        # are regenerated from the DB by sync_map_files().
        for path in [SERVERS_FILE, CERTS_FILE]:
            _backup_file(path)
    except sqlite3.Error as e:
        logger.error("Migration failed: %s", e)
        raise
def _read_file_for_migration(file_path: str) -> str:
    """Return the raw contents of *file_path* for a one-shot migration read.

    In remote mode the file is fetched over SSH; otherwise it is read
    locally. No locking is performed — migration runs before any
    concurrent access exists.
    """
    if not REMOTE_MODE:
        with open(file_path, "r", encoding="utf-8") as handle:
            return handle.read()
    from .ssh_ops import remote_read_file
    return remote_read_file(file_path)
def _backup_file(file_path: str) -> None:
    """Rename a local file to ``<file>.bak`` if it exists.

    Uses ``os.replace`` instead of ``os.rename`` so a leftover ``.bak``
    from a previous migration attempt is overwritten: on Windows,
    ``os.rename`` raises when the destination already exists, which
    would make a re-run of the migration fail the backup step.
    Remote files are left untouched.

    Args:
        file_path: Path of the file to back up.
    """
    if REMOTE_MODE:
        return  # Don't rename remote files
    try:
        if os.path.exists(file_path):
            bak_path = f"{file_path}.bak"
            os.replace(file_path, bak_path)
            logger.info("Backed up %s -> %s", file_path, bak_path)
    except OSError as e:
        # Best-effort: a failed backup must not abort the migration.
        logger.warning("Failed to backup %s: %s", file_path, e)
# --- Domain data access ---
def db_get_map_contents() -> list[tuple[str, str]]:
    """Return every domain-to-backend mapping, wildcards included.

    Returns:
        List of (domain, backend) tuples, ordered by domain name.
    """
    rows = get_connection().execute(
        "SELECT domain, backend FROM domains ORDER BY domain"
    ).fetchall()
    return [(entry["domain"], entry["backend"]) for entry in rows]
def db_get_domain_backend(domain: str) -> Optional[str]:
    """Return the backend mapped to *domain*, or None when unmapped.

    Args:
        domain: Domain name to look up (exact match only).

    Returns:
        Backend name if found, None otherwise.
    """
    row = get_connection().execute(
        "SELECT backend FROM domains WHERE domain = ?", (domain,)
    ).fetchone()
    return None if row is None else row["backend"]
def db_add_domain(domain: str, backend: str, is_wildcard: bool = False,
                  shares_with: Optional[str] = None) -> None:
    """Insert or update a domain mapping (upsert keyed on domain).

    On conflict the row is updated in place. Fix: ``is_wildcard`` is now
    refreshed alongside ``backend`` and ``shares_with`` — previously a
    re-add left a stale wildcard flag on the existing row.

    Args:
        domain: Domain name (e.g., "example.com" or ".example.com").
        backend: Backend pool name (e.g., "pool_5").
        is_wildcard: Whether this is a wildcard entry.
        shares_with: Parent domain if sharing a pool.
    """
    conn = get_connection()
    conn.execute(
        """INSERT INTO domains (domain, backend, is_wildcard, shares_with)
           VALUES (?, ?, ?, ?)
           ON CONFLICT(domain) DO UPDATE SET
               backend = excluded.backend,
               is_wildcard = excluded.is_wildcard,
               shares_with = excluded.shares_with,
               updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""",
        (domain, backend, 1 if is_wildcard else 0, shares_with),
    )
    conn.commit()
def db_remove_domain(domain: str) -> None:
    """Delete both the exact row and its wildcard twin for *domain*.

    Args:
        domain: Base domain name (without leading dot); the ".domain"
            wildcard variant is removed in the same statement.
    """
    conn = get_connection()
    targets = (domain, f".{domain}")
    conn.execute("DELETE FROM domains WHERE domain = ? OR domain = ?", targets)
    conn.commit()
def db_find_available_pool() -> Optional[str]:
    """Return the lowest-numbered pool not held by any exact domain.

    Wildcard rows are excluded from the "used" set; only non-wildcard
    assignments reserve a pool.

    Returns:
        Pool name (e.g., "pool_5") or None if all pools are in use.
    """
    rows = get_connection().execute(
        "SELECT DISTINCT backend FROM domains WHERE backend LIKE 'pool_%' AND is_wildcard = 0"
    ).fetchall()
    taken = {entry["backend"] for entry in rows}
    candidates = (f"pool_{n}" for n in range(1, POOL_COUNT + 1))
    return next((name for name in candidates if name not in taken), None)
def db_get_domains_sharing_pool(pool: str) -> list[str]:
    """List every exact (non-wildcard) domain mapped to *pool*.

    Args:
        pool: Pool name (e.g., "pool_5").

    Returns:
        List of domain names.
    """
    rows = get_connection().execute(
        "SELECT domain FROM domains WHERE backend = ? AND is_wildcard = 0",
        (pool,),
    ).fetchall()
    return [entry["domain"] for entry in rows]
# --- Server data access ---
def db_load_servers_config() -> dict[str, Any]:
    """Assemble the legacy servers.json-shaped dict from the database.

    Returns:
        Dictionary compatible with the old servers.json format:
        {"domain": {"slot": {"ip": "...", "http_port": N}, "_shares": "parent"}}
    """
    conn = get_connection()
    result: dict[str, Any] = {}
    # Server slots, keyed by stringified slot number to match the old JSON.
    server_rows = conn.execute(
        "SELECT domain, slot, ip, http_port FROM servers ORDER BY domain, slot"
    ).fetchall()
    for row in server_rows:
        slots = result.setdefault(row["domain"], {})
        slots[str(row["slot"])] = {
            "ip": row["ip"],
            "http_port": row["http_port"],
        }
    # Shared-pool references become "_shares" metadata keys.
    share_rows = conn.execute(
        "SELECT domain, shares_with FROM domains WHERE shares_with IS NOT NULL AND is_wildcard = 0"
    ).fetchall()
    for row in share_rows:
        result.setdefault(row["domain"], {})["_shares"] = row["shares_with"]
    return result
def db_add_server(domain: str, slot: int, ip: str, http_port: int) -> None:
    """Insert or update one server row, keyed on (domain, slot).

    Args:
        domain: Domain name.
        slot: Server slot number.
        ip: Server IP address.
        http_port: HTTP port.
    """
    params = (domain, slot, ip, http_port)
    conn = get_connection()
    conn.execute(
        """INSERT INTO servers (domain, slot, ip, http_port)
           VALUES (?, ?, ?, ?)
           ON CONFLICT(domain, slot) DO UPDATE SET
               ip = excluded.ip,
               http_port = excluded.http_port,
               updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""",
        params,
    )
    conn.commit()
def db_remove_server(domain: str, slot: int) -> None:
    """Delete a single server row.

    Args:
        domain: Domain name.
        slot: Server slot number.
    """
    conn = get_connection()
    key = (domain, slot)
    conn.execute("DELETE FROM servers WHERE domain = ? AND slot = ?", key)
    conn.commit()
def db_remove_domain_servers(domain: str) -> None:
    """Delete every server row belonging to *domain*.

    Args:
        domain: Domain name.
    """
    conn = get_connection()
    conn.execute("DELETE FROM servers WHERE domain = ?", (domain,))
    conn.commit()
def db_add_shared_domain(domain: str, shares_with: str) -> None:
    """Mark a domain as sharing a pool with another domain.

    Updates the shares_with field on the existing (non-wildcard) domain
    row. Unlike the legacy JSON code, which created the entry
    implicitly, this UPDATE is a no-op when the domain row does not
    exist — so a warning is logged when nothing matched, to surface
    ordering bugs (db_add_domain must run first).

    Args:
        domain: The new domain.
        shares_with: The existing domain to share with.
    """
    conn = get_connection()
    cur = conn.execute(
        "UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0",
        (shares_with, domain),
    )
    if cur.rowcount == 0:
        logger.warning(
            "db_add_shared_domain: no domain row for %s; shares_with not recorded",
            domain,
        )
    conn.commit()
def db_get_shared_domain(domain: str) -> Optional[str]:
    """Return the parent domain whose pool *domain* shares, if any.

    Args:
        domain: Domain name.

    Returns:
        Parent domain name if sharing, None otherwise (also None when
        the row exists but shares_with is empty/NULL).
    """
    row = get_connection().execute(
        "SELECT shares_with FROM domains WHERE domain = ? AND is_wildcard = 0",
        (domain,),
    ).fetchone()
    if row is None:
        return None
    # Falsy values (NULL, "") are normalized to None, matching callers'
    # truthiness checks.
    return row["shares_with"] or None
def db_is_shared_domain(domain: str) -> bool:
    """Report whether *domain* rides on another domain's pool.

    Args:
        domain: Domain name.

    Returns:
        True if the domain carries a shares_with reference.
    """
    parent = db_get_shared_domain(domain)
    return parent is not None
# --- Certificate data access ---
def db_load_certs() -> list[str]:
    """Return the certificate domain list, sorted by the database.

    Returns:
        Sorted list of domain names with certificates.
    """
    rows = get_connection().execute(
        "SELECT domain FROM certificates ORDER BY domain"
    ).fetchall()
    return [entry["domain"] for entry in rows]
def db_add_cert(domain: str) -> None:
    """Record that *domain* has a certificate (duplicate-safe).

    Args:
        domain: Domain name.
    """
    conn = get_connection()
    # INSERT OR IGNORE keeps this idempotent for repeat additions.
    conn.execute(
        "INSERT OR IGNORE INTO certificates (domain) VALUES (?)",
        (domain,),
    )
    conn.commit()
def db_remove_cert(domain: str) -> None:
    """Drop *domain* from the certificate list (no-op if absent).

    Args:
        domain: Domain name.
    """
    conn = get_connection()
    conn.execute("DELETE FROM certificates WHERE domain = ?", (domain,))
    conn.commit()
# --- Map file synchronization ---
def sync_map_files() -> None:
    """Regenerate the HAProxy map files from current database state.

    Writes domains.map (exact matches) and wildcards.map (leading-dot
    wildcard entries) using atomic writes, so HAProxy never observes a
    partially written file.
    """
    from .file_ops import atomic_write_file
    conn = get_connection()
    exact_rows = conn.execute(
        "SELECT domain, backend FROM domains WHERE is_wildcard = 0 ORDER BY domain"
    ).fetchall()
    wildcard_rows = conn.execute(
        "SELECT domain, backend FROM domains WHERE is_wildcard = 1 ORDER BY domain"
    ).fetchall()
    # Exact-match map (consumed via map_str).
    exact_header = (
        "# Exact Domain to Backend mapping (for map_str)\n"
        "# Format: domain backend_name\n"
        "# Uses ebtree for O(log n) lookup performance\n\n"
    )
    exact_body = "".join(
        f"{entry['domain']} {entry['backend']}\n" for entry in exact_rows
    )
    atomic_write_file(MAP_FILE, exact_header + exact_body)
    # Wildcard map (consumed via map_dom).
    wildcard_header = (
        "# Wildcard Domain to Backend mapping (for map_dom)\n"
        "# Format: .domain.com backend_name (matches *.domain.com)\n"
        "# Uses map_dom for suffix matching\n\n"
    )
    wildcard_body = "".join(
        f"{entry['domain']} {entry['backend']}\n" for entry in wildcard_rows
    )
    atomic_write_file(WILDCARDS_MAP_FILE, wildcard_header + wildcard_body)
    logger.debug("Map files synced: %d exact, %d wildcard entries",
                 len(exact_rows), len(wildcard_rows))

View File

@@ -1,7 +1,11 @@
"""File I/O operations for HAProxy MCP Server."""
"""File I/O operations for HAProxy MCP Server.
Most data access is now delegated to db.py (SQLite).
This module retains atomic file writes, map file I/O for HAProxy,
and provides backward-compatible function signatures.
"""
import fcntl
import json
import os
import tempfile
from contextlib import contextmanager
@@ -10,8 +14,6 @@ from typing import Any, Generator, Optional
from .config import (
MAP_FILE,
WILDCARDS_MAP_FILE,
SERVERS_FILE,
CERTS_FILE,
REMOTE_MODE,
logger,
)
@@ -138,16 +140,13 @@ def _read_file(file_path: str) -> str:
def get_map_contents() -> list[tuple[str, str]]:
"""Read both domains.map and wildcards.map and return combined entries.
"""Get all domain-to-backend mappings from SQLite.
Returns:
List of (domain, backend) tuples from both map files
List of (domain, backend) tuples including wildcards.
"""
# Read exact domains
entries = _read_map_file(MAP_FILE)
# Read wildcards and append
entries.extend(_read_map_file(WILDCARDS_MAP_FILE))
return entries
from .db import db_get_map_contents
return db_get_map_contents()
def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
@@ -170,44 +169,21 @@ def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str
def save_map_file(entries: list[tuple[str, str]]) -> None:
"""Save domain-to-backend entries to map files.
"""Sync map files from the database.
Splits entries into two files for 2-stage routing:
- domains.map: Exact matches (map_str, O(log n))
- wildcards.map: Wildcard entries starting with "." (map_dom, O(n))
Args:
entries: List of (domain, backend) tuples.
Regenerates domains.map and wildcards.map from the current
database state. The entries parameter is ignored (kept for
backward compatibility during transition).
Raises:
IOError: If map files cannot be written.
"""
# Split into exact and wildcard entries
exact_entries, wildcard_entries = split_domain_entries(entries)
# Save exact domains (for map_str - fast O(log n) lookup)
exact_lines = [
"# Exact Domain to Backend mapping (for map_str)\n",
"# Format: domain backend_name\n",
"# Uses ebtree for O(log n) lookup performance\n\n",
]
for domain, backend in sorted(exact_entries):
exact_lines.append(f"{domain} {backend}\n")
atomic_write_file(MAP_FILE, "".join(exact_lines))
# Save wildcards (for map_dom - O(n) but small set)
wildcard_lines = [
"# Wildcard Domain to Backend mapping (for map_dom)\n",
"# Format: .domain.com backend_name (matches *.domain.com)\n",
"# Uses map_dom for suffix matching\n\n",
]
for domain, backend in sorted(wildcard_entries):
wildcard_lines.append(f"{domain} {backend}\n")
atomic_write_file(WILDCARDS_MAP_FILE, "".join(wildcard_lines))
from .db import sync_map_files
sync_map_files()
def get_domain_backend(domain: str) -> Optional[str]:
"""Look up the backend for a domain from domains.map.
"""Look up the backend for a domain from SQLite (O(1)).
Args:
domain: The domain to look up
@@ -215,10 +191,8 @@ def get_domain_backend(domain: str) -> Optional[str]:
Returns:
Backend name if found, None otherwise
"""
for map_domain, backend in get_map_contents():
if map_domain == domain:
return backend
return None
from .db import db_get_domain_backend
return db_get_domain_backend(domain)
def is_legacy_backend(backend: str) -> bool:
@@ -273,34 +247,17 @@ def get_backend_and_prefix(domain: str) -> tuple[str, str]:
def load_servers_config() -> dict[str, Any]:
"""Load servers configuration from JSON file.
"""Load servers configuration from SQLite.
Returns:
Dictionary with server configurations
Dictionary with server configurations (legacy format compatible).
"""
try:
content = _read_file(SERVERS_FILE)
return json.loads(content)
except FileNotFoundError:
return {}
except json.JSONDecodeError as e:
logger.warning("Corrupt config file %s: %s", SERVERS_FILE, e)
return {}
def save_servers_config(config: dict[str, Any]) -> None:
"""Save servers configuration to JSON file atomically.
Uses temp file + rename for atomic write to prevent race conditions.
Args:
config: Dictionary with server configurations
"""
atomic_write_file(SERVERS_FILE, json.dumps(config, indent=2))
from .db import db_load_servers_config
return db_load_servers_config()
def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> None:
"""Add server configuration to persistent storage with file locking.
"""Add server configuration to persistent storage.
Args:
domain: Domain name
@@ -308,41 +265,29 @@ def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> Non
ip: Server IP address
http_port: HTTP port
"""
with file_lock(f"{SERVERS_FILE}.lock"):
config = load_servers_config()
if domain not in config:
config[domain] = {}
config[domain][str(slot)] = {"ip": ip, "http_port": http_port}
save_servers_config(config)
from .db import db_add_server
db_add_server(domain, slot, ip, http_port)
def remove_server_from_config(domain: str, slot: int) -> None:
"""Remove server configuration from persistent storage with file locking.
"""Remove server configuration from persistent storage.
Args:
domain: Domain name
slot: Server slot to remove
"""
with file_lock(f"{SERVERS_FILE}.lock"):
config = load_servers_config()
if domain in config and str(slot) in config[domain]:
del config[domain][str(slot)]
if not config[domain]:
del config[domain]
save_servers_config(config)
from .db import db_remove_server
db_remove_server(domain, slot)
def remove_domain_from_config(domain: str) -> None:
"""Remove domain from persistent config with file locking.
"""Remove domain from persistent config (servers + domain entry).
Args:
domain: Domain name to remove
"""
with file_lock(f"{SERVERS_FILE}.lock"):
config = load_servers_config()
if domain in config:
del config[domain]
save_servers_config(config)
from .db import db_remove_domain_servers
db_remove_domain_servers(domain)
def get_shared_domain(domain: str) -> Optional[str]:
@@ -354,9 +299,8 @@ def get_shared_domain(domain: str) -> Optional[str]:
Returns:
Parent domain name if sharing, None otherwise
"""
config = load_servers_config()
domain_config = config.get(domain, {})
return domain_config.get("_shares")
from .db import db_get_shared_domain
return db_get_shared_domain(domain)
def add_shared_domain_to_config(domain: str, shares_with: str) -> None:
@@ -366,10 +310,8 @@ def add_shared_domain_to_config(domain: str, shares_with: str) -> None:
domain: New domain name
shares_with: Existing domain to share pool with
"""
with file_lock(f"{SERVERS_FILE}.lock"):
config = load_servers_config()
config[domain] = {"_shares": shares_with}
save_servers_config(config)
from .db import db_add_shared_domain
db_add_shared_domain(domain, shares_with)
def get_domains_sharing_pool(pool: str) -> list[str]:
@@ -381,11 +323,8 @@ def get_domains_sharing_pool(pool: str) -> list[str]:
Returns:
List of domain names using this pool
"""
domains = []
for domain, backend in get_map_contents():
if backend == pool and not domain.startswith("."):
domains.append(domain)
return domains
from .db import db_get_domains_sharing_pool
return db_get_domains_sharing_pool(pool)
def is_shared_domain(domain: str) -> bool:
@@ -397,37 +336,20 @@ def is_shared_domain(domain: str) -> bool:
Returns:
True if domain has _shares reference, False otherwise
"""
config = load_servers_config()
domain_config = config.get(domain, {})
return "_shares" in domain_config
from .db import db_is_shared_domain
return db_is_shared_domain(domain)
# Certificate configuration functions
def load_certs_config() -> list[str]:
"""Load certificate domain list from JSON file.
"""Load certificate domain list from SQLite.
Returns:
List of domain names
Sorted list of domain names.
"""
try:
content = _read_file(CERTS_FILE)
data = json.loads(content)
return data.get("domains", [])
except FileNotFoundError:
return []
except json.JSONDecodeError as e:
logger.warning("Corrupt certificates config %s: %s", CERTS_FILE, e)
return []
def save_certs_config(domains: list[str]) -> None:
"""Save certificate domain list to JSON file atomically.
Args:
domains: List of domain names
"""
atomic_write_file(CERTS_FILE, json.dumps({"domains": sorted(domains)}, indent=2))
from .db import db_load_certs
return db_load_certs()
def add_cert_to_config(domain: str) -> None:
@@ -436,11 +358,8 @@ def add_cert_to_config(domain: str) -> None:
Args:
domain: Domain name to add
"""
with file_lock(f"{CERTS_FILE}.lock"):
domains = load_certs_config()
if domain not in domains:
domains.append(domain)
save_certs_config(domains)
from .db import db_add_cert
db_add_cert(domain)
def remove_cert_from_config(domain: str) -> None:
@@ -449,8 +368,45 @@ def remove_cert_from_config(domain: str) -> None:
Args:
domain: Domain name to remove
"""
with file_lock(f"{CERTS_FILE}.lock"):
domains = load_certs_config()
if domain in domains:
domains.remove(domain)
save_certs_config(domains)
from .db import db_remove_cert
db_remove_cert(domain)
# Domain map helper functions (used by domains.py)
def add_domain_to_map(domain: str, backend: str, is_wildcard: bool = False,
shares_with: Optional[str] = None) -> None:
"""Add a domain to SQLite and sync map files.
Args:
domain: Domain name (e.g., "example.com").
backend: Backend pool name (e.g., "pool_5").
is_wildcard: Whether this is a wildcard entry.
shares_with: Parent domain if sharing a pool.
"""
from .db import db_add_domain, sync_map_files
db_add_domain(domain, backend, is_wildcard, shares_with)
sync_map_files()
def remove_domain_from_map(domain: str) -> None:
"""Remove a domain (exact + wildcard) from SQLite and sync map files.
Args:
domain: Base domain name (without leading dot).
"""
from .db import db_remove_domain, sync_map_files
db_remove_domain(domain)
sync_map_files()
def find_available_pool() -> Optional[str]:
"""Find the first available pool not assigned to any domain.
Uses SQLite query for O(1) lookup vs previous O(n) list scan.
Returns:
Pool name (e.g., "pool_5") or None if all pools are in use.
"""
from .db import db_find_available_pool
return db_find_available_pool()

View File

@@ -2,6 +2,7 @@
import socket
import select
import subprocess
import time
from .config import (
@@ -161,7 +162,7 @@ def reload_haproxy() -> tuple[bool, str]:
if result.returncode != 0:
return False, f"Reload failed: {result.stderr}"
return True, "OK"
except TimeoutError:
except (TimeoutError, subprocess.TimeoutExpired):
return False, f"Command timed out after {SUBPROCESS_TIMEOUT} seconds"
except FileNotFoundError:
return False, "ssh/podman command not found"

View File

@@ -21,6 +21,7 @@ Environment Variables:
from mcp.server.fastmcp import FastMCP
from .config import MCP_HOST, MCP_PORT
from .db import init_db
from .tools import register_all_tools
from .tools.configuration import startup_restore
@@ -32,5 +33,6 @@ register_all_tools(mcp)
if __name__ == "__main__":
init_db()
startup_restore()
mcp.run(transport="streamable-http")

View File

@@ -1,6 +1,7 @@
"""Certificate management tools for HAProxy MCP Server."""
import os
import subprocess
from datetime import datetime
from typing import Annotated
@@ -152,7 +153,7 @@ def _haproxy_list_certs_impl() -> str:
certs.append(f"{domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}")
return "\n\n".join(certs) if certs else "No certificates found"
except TimeoutError:
except (TimeoutError, subprocess.TimeoutExpired):
return "Error: Command timed out"
except FileNotFoundError:
return "Error: acme.sh not found"
@@ -203,7 +204,7 @@ def _haproxy_cert_info_impl(domain: str) -> str:
result.stdout.strip()
]
return "\n".join(info)
except TimeoutError:
except (TimeoutError, subprocess.TimeoutExpired):
return "Error: Command timed out"
except OSError as e:
logger.error("Error getting certificate info for %s: %s", domain, e)
@@ -250,7 +251,7 @@ def _haproxy_issue_cert_impl(domain: str, wildcard: bool) -> str:
else:
return f"Certificate issued but PEM file not created. Check {host_path}"
except TimeoutError:
except (TimeoutError, subprocess.TimeoutExpired):
return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s"
except OSError as e:
logger.error("Error issuing certificate for %s: %s", domain, e)
@@ -289,7 +290,7 @@ def _haproxy_renew_cert_impl(domain: str, force: bool) -> str:
else:
return f"Error renewing certificate:\n{output}"
except TimeoutError:
except (TimeoutError, subprocess.TimeoutExpired):
return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s"
except FileNotFoundError:
return "Error: acme.sh not found"
@@ -323,7 +324,7 @@ def _haproxy_renew_all_certs_impl() -> str:
else:
return "Renewal check completed"
except TimeoutError:
except (TimeoutError, subprocess.TimeoutExpired):
return "Error: Renewal cron timed out"
except FileNotFoundError:
return "Error: acme.sh not found"

View File

@@ -1,5 +1,6 @@
"""Configuration management tools for HAProxy MCP Server."""
import subprocess
import time
from ..config import (
@@ -177,7 +178,7 @@ def register_config_tools(mcp):
if result.returncode == 0:
return "Configuration is valid"
return f"Configuration errors:\n{result.stderr}"
except TimeoutError:
except (TimeoutError, subprocess.TimeoutExpired):
return f"Error: Command timed out after {SUBPROCESS_TIMEOUT} seconds"
except FileNotFoundError:
return "Error: ssh/podman command not found"

View File

@@ -6,7 +6,6 @@ from typing import Annotated, Optional
from pydantic import Field
from ..config import (
MAP_FILE,
MAP_FILE_CONTAINER,
WILDCARDS_MAP_FILE_CONTAINER,
POOL_COUNT,
@@ -22,7 +21,6 @@ from ..validation import validate_domain, validate_ip, validate_port_int
from ..haproxy_client import haproxy_cmd
from ..file_ops import (
get_map_contents,
save_map_file,
get_domain_backend,
is_legacy_backend,
add_server_to_config,
@@ -31,30 +29,13 @@ from ..file_ops import (
add_shared_domain_to_config,
get_domains_sharing_pool,
is_shared_domain,
add_domain_to_map,
remove_domain_from_map,
find_available_pool,
)
from ..utils import parse_servers_state, disable_server_slot
def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]:
"""Find an available pool backend from the pool list.
Iterates through pool_1 to pool_N and returns the first pool
that is not currently in use.
Args:
entries: List of (domain, backend) tuples from the map file.
used_pools: Set of pool names already in use.
Returns:
Available pool name (e.g., "pool_5") or None if all pools are in use.
"""
for i in range(1, POOL_COUNT + 1):
pool_name = f"pool_{i}"
if pool_name not in used_pools:
return pool_name
return None
def _check_subdomain(domain: str, registered_domains: set[str]) -> tuple[bool, Optional[str]]:
"""Check if a domain is a subdomain of an existing registered domain.
@@ -95,24 +76,19 @@ def _update_haproxy_maps(domain: str, pool: str, is_subdomain: bool) -> None:
haproxy_cmd(f"add map {WILDCARDS_MAP_FILE_CONTAINER} .{domain} {pool}")
def _rollback_domain_addition(
domain: str,
entries: list[tuple[str, str]]
) -> None:
"""Rollback a failed domain addition by removing entries from map file.
def _rollback_domain_addition(domain: str) -> None:
"""Rollback a failed domain addition by removing from SQLite + map files.
Called when HAProxy Runtime API update fails after the map file
has already been saved.
Called when HAProxy Runtime API update fails after the domain
has already been saved to the database.
Args:
domain: Domain name that was added.
entries: Current list of map entries to rollback from.
"""
rollback_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
try:
save_map_file(rollback_entries)
except IOError:
logger.error("Failed to rollback map file after HAProxy error")
remove_domain_from_map(domain)
except (IOError, Exception):
logger.error("Failed to rollback domain %s after HAProxy error", domain)
def _file_exists(path: str) -> bool:
@@ -242,23 +218,17 @@ def register_domain_tools(mcp):
if share_with and ip:
return "Error: Cannot specify both ip and share_with (shared domains use existing servers)"
# Use file locking for the entire pool allocation operation
from ..file_ops import file_lock
with file_lock(f"{MAP_FILE}.lock"):
# Read map contents once for both existence check and pool lookup
# Read current entries for existence check and subdomain detection
entries = get_map_contents()
# Check if domain already exists (using cached entries)
# Check if domain already exists
for domain_entry, backend in entries:
if domain_entry == domain:
return f"Error: Domain {domain} already exists (mapped to {backend})"
# Build used pools and registered domains sets
used_pools: set[str] = set()
# Build registered domains set for subdomain check
registered_domains: set[str] = set()
for entry_domain, backend in entries:
if backend.startswith("pool_"):
used_pools.add(backend)
for entry_domain, _ in entries:
if not entry_domain.startswith("."):
registered_domains.add(entry_domain)
@@ -271,8 +241,8 @@ def register_domain_tools(mcp):
return f"Error: Cannot share with legacy backend {share_backend}"
pool = share_backend
else:
# Find available pool
pool = _find_available_pool(entries, used_pools)
# Find available pool (SQLite query, O(1))
pool = find_available_pool()
if not pool:
return f"Error: All {POOL_COUNT} pool backends are in use"
@@ -280,20 +250,19 @@ def register_domain_tools(mcp):
is_subdomain, parent_domain = _check_subdomain(domain, registered_domains)
try:
# Save to disk first (atomic write for persistence)
entries.append((domain, pool))
if not is_subdomain:
entries.append((f".{domain}", pool))
# Save to SQLite + sync map files (atomic via SQLite transaction)
try:
save_map_file(entries)
except IOError as e:
return f"Error: Failed to save map file: {e}"
add_domain_to_map(domain, pool)
if not is_subdomain:
add_domain_to_map(f".{domain}", pool, is_wildcard=True)
except (IOError, Exception) as e:
return f"Error: Failed to save domain: {e}"
# Update HAProxy maps via Runtime API
try:
_update_haproxy_maps(domain, pool, is_subdomain)
except HaproxyError as e:
_rollback_domain_addition(domain, entries)
_rollback_domain_addition(domain)
return f"Error: Failed to update HAProxy map: {e}"
# Handle server configuration based on mode
@@ -355,10 +324,8 @@ def register_domain_tools(mcp):
domains_using_pool = get_domains_sharing_pool(backend)
other_domains = [d for d in domains_using_pool if d != domain]
# Save to disk first (atomic write for persistence)
entries = get_map_contents()
new_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
save_map_file(new_entries)
# Remove from SQLite + sync map files
remove_domain_from_map(domain)
# Remove from persistent server config
remove_domain_from_config(domain)

View File

@@ -10,6 +10,7 @@ from pydantic import Field
from ..config import (
MAP_FILE,
SERVERS_FILE,
DB_FILE,
HAPROXY_CONTAINER,
)
from ..exceptions import HaproxyError
@@ -88,7 +89,7 @@ def register_health_tools(mcp):
# Check configuration files
files_ok = True
file_status: dict[str, str] = {}
for name, path in [("map_file", MAP_FILE), ("servers_file", SERVERS_FILE)]:
for name, path in [("map_file", MAP_FILE), ("db_file", DB_FILE)]:
exists = remote_file_exists(path) if REMOTE_MODE else __import__('os').path.exists(path)
if exists:
file_status[name] = "ok"

View File

@@ -255,6 +255,8 @@ def temp_config_dir(tmp_path):
state_file = tmp_path / "servers.state"
state_file.write_text("")
db_file = tmp_path / "haproxy_mcp.db"
return {
"dir": tmp_path,
"map_file": str(map_file),
@@ -262,12 +264,15 @@ def temp_config_dir(tmp_path):
"servers_file": str(servers_file),
"certs_file": str(certs_file),
"state_file": str(state_file),
"db_file": str(db_file),
}
@pytest.fixture
def patch_config_paths(temp_config_dir):
"""Fixture that patches config module paths to use temporary directory."""
from haproxy_mcp.db import close_connection, init_db
with patch.multiple(
"haproxy_mcp.config",
MAP_FILE=temp_config_dir["map_file"],
@@ -275,16 +280,34 @@ def patch_config_paths(temp_config_dir):
SERVERS_FILE=temp_config_dir["servers_file"],
CERTS_FILE=temp_config_dir["certs_file"],
STATE_FILE=temp_config_dir["state_file"],
DB_FILE=temp_config_dir["db_file"],
):
# Also patch file_ops module which imports these
with patch.multiple(
"haproxy_mcp.file_ops",
MAP_FILE=temp_config_dir["map_file"],
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
):
# Patch db module which imports these
with patch.multiple(
"haproxy_mcp.db",
MAP_FILE=temp_config_dir["map_file"],
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
SERVERS_FILE=temp_config_dir["servers_file"],
CERTS_FILE=temp_config_dir["certs_file"],
DB_FILE=temp_config_dir["db_file"],
):
# Patch health module which imports MAP_FILE and DB_FILE
with patch.multiple(
"haproxy_mcp.tools.health",
MAP_FILE=temp_config_dir["map_file"],
DB_FILE=temp_config_dir["db_file"],
):
# Close any existing connection and initialize fresh DB
close_connection()
init_db()
yield temp_config_dir
close_connection()
@pytest.fixture

433
tests/unit/test_db.py Normal file
View File

@@ -0,0 +1,433 @@
"""Unit tests for db module (SQLite database operations)."""
import json
import os
import sqlite3
from unittest.mock import patch
import pytest
from haproxy_mcp.db import (
get_connection,
close_connection,
init_db,
migrate_from_json,
db_get_map_contents,
db_get_domain_backend,
db_add_domain,
db_remove_domain,
db_find_available_pool,
db_get_domains_sharing_pool,
db_load_servers_config,
db_add_server,
db_remove_server,
db_remove_domain_servers,
db_add_shared_domain,
db_get_shared_domain,
db_is_shared_domain,
db_load_certs,
db_add_cert,
db_remove_cert,
sync_map_files,
SCHEMA_VERSION,
)
class TestConnectionManagement:
    """Tests for database connection management."""

    def test_get_connection(self, patch_config_paths):
        """get_connection returns a live SQLite handle opened in WAL mode."""
        handle = get_connection()
        assert handle is not None
        # init_db is expected to set PRAGMA journal_mode=WAL; confirm it stuck.
        mode_row = handle.execute("PRAGMA journal_mode").fetchone()
        assert mode_row[0] == "wal"

    def test_connection_is_thread_local(self, patch_config_paths):
        """Repeated calls on one thread hand back the same connection object."""
        first = get_connection()
        second = get_connection()
        assert first is second

    def test_close_connection(self, patch_config_paths):
        """close_connection drops the cached handle; the next call opens anew."""
        before = get_connection()
        close_connection()
        after = get_connection()
        assert before is not after
class TestInitDb:
    """Tests for database initialization."""

    def test_creates_tables(self, patch_config_paths):
        """init_db creates every table the schema defines."""
        conn = get_connection()
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
        ).fetchall()
        names = {row["name"] for row in rows}
        for required in ("domains", "servers", "certificates", "schema_version"):
            assert required in names

    def test_schema_version_recorded(self, patch_config_paths):
        """The current SCHEMA_VERSION is recorded in schema_version."""
        conn = get_connection()
        row = conn.execute("SELECT MAX(version) FROM schema_version").fetchone()
        assert row[0] == SCHEMA_VERSION

    def test_idempotent(self, patch_config_paths):
        """Running init_db again on an initialized database is harmless."""
        # The patch_config_paths fixture already ran init_db once; a second
        # run must not raise or wipe the version bookkeeping.
        init_db()
        conn = get_connection()
        count_row = conn.execute("SELECT COUNT(*) FROM schema_version").fetchone()
        # One or two version rows are both acceptable; zero would be a regression.
        assert count_row[0] >= 1
class TestMigrateFromJson:
    """Tests for JSON to SQLite migration.

    Each test seeds the legacy on-disk artifacts (map files, servers.json,
    certificates.json) in the temp dir provided by the patch_config_paths
    fixture, runs migrate_from_json(), and asserts the data landed in SQLite.
    """
    def test_migrate_map_files(self, patch_config_paths):
        """Migrate domain entries from map files."""
        # Write test map files (exact domains + wildcard file separately).
        with open(patch_config_paths["map_file"], "w") as f:
            f.write("example.com pool_1\n")
            f.write("api.example.com pool_2\n")
        with open(patch_config_paths["wildcards_file"], "w") as f:
            f.write(".example.com pool_1\n")
        migrate_from_json()
        # Both exact and wildcard entries should be queryable from the DB.
        entries = db_get_map_contents()
        assert ("example.com", "pool_1") in entries
        assert ("api.example.com", "pool_2") in entries
        assert (".example.com", "pool_1") in entries
    def test_migrate_servers_json(self, patch_config_paths):
        """Migrate server entries from servers.json."""
        config = {
            "example.com": {
                "1": {"ip": "10.0.0.1", "http_port": 80},
                "2": {"ip": "10.0.0.2", "http_port": 8080},
            }
        }
        with open(patch_config_paths["servers_file"], "w") as f:
            json.dump(config, f)
        migrate_from_json()
        # Slot keys remain strings ("1", "2"), matching the legacy JSON shape.
        result = db_load_servers_config()
        assert "example.com" in result
        assert result["example.com"]["1"]["ip"] == "10.0.0.1"
        assert result["example.com"]["2"]["http_port"] == 8080
    def test_migrate_shared_domains(self, patch_config_paths):
        """Migrate shared domain references."""
        # First add the domain to map so it exists in DB
        # (shared-domain rows presumably reference domain rows — the map
        # entries give the migration something to attach the link to).
        with open(patch_config_paths["map_file"], "w") as f:
            f.write("example.com pool_1\n")
            f.write("www.example.com pool_1\n")
        config = {
            "www.example.com": {"_shares": "example.com"},
        }
        with open(patch_config_paths["servers_file"], "w") as f:
            json.dump(config, f)
        migrate_from_json()
        assert db_get_shared_domain("www.example.com") == "example.com"
    def test_migrate_certificates(self, patch_config_paths):
        """Migrate certificate entries."""
        with open(patch_config_paths["certs_file"], "w") as f:
            json.dump({"domains": ["example.com", "api.example.com"]}, f)
        migrate_from_json()
        certs = db_load_certs()
        assert "example.com" in certs
        assert "api.example.com" in certs
    def test_migrate_idempotent(self, patch_config_paths):
        """Migration is idempotent (INSERT OR IGNORE)."""
        with open(patch_config_paths["map_file"], "w") as f:
            f.write("example.com pool_1\n")
        with open(patch_config_paths["servers_file"], "w") as f:
            json.dump({"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}, f)
        migrate_from_json()
        migrate_from_json()  # Should not fail
        # Running twice must not duplicate the domain row.
        entries = db_get_map_contents()
        assert len([d for d, _ in entries if d == "example.com"]) == 1
    def test_migrate_empty_files(self, patch_config_paths):
        """Migration with no existing data does nothing."""
        # NOTE(review): assumes the fixture pre-created all four files,
        # otherwise os.unlink would raise — confirm against conftest.
        os.unlink(patch_config_paths["map_file"])
        os.unlink(patch_config_paths["wildcards_file"])
        os.unlink(patch_config_paths["servers_file"])
        os.unlink(patch_config_paths["certs_file"])
        migrate_from_json()  # Should not fail
    def test_backup_files_after_migration(self, patch_config_paths):
        """Original JSON files are backed up after migration."""
        with open(patch_config_paths["servers_file"], "w") as f:
            json.dump({"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}, f)
        with open(patch_config_paths["certs_file"], "w") as f:
            json.dump({"domains": ["example.com"]}, f)
        # Write map file so migration has data
        with open(patch_config_paths["map_file"], "w") as f:
            f.write("example.com pool_1\n")
        migrate_from_json()
        # Only the JSON sources are renamed to *.bak; map files stay in place
        # (they continue to be regenerated from SQLite via sync_map_files).
        assert os.path.exists(f"{patch_config_paths['servers_file']}.bak")
        assert os.path.exists(f"{patch_config_paths['certs_file']}.bak")
class TestDomainOperations:
    """Tests for domain data access functions."""

    def test_add_and_get_domain(self, patch_config_paths):
        """A stored domain can be looked up by name."""
        db_add_domain("example.com", "pool_1")
        assert db_get_domain_backend("example.com") == "pool_1"

    def test_get_nonexistent_domain(self, patch_config_paths):
        """Looking up an unknown domain yields None."""
        assert db_get_domain_backend("nonexistent.com") is None

    def test_remove_domain(self, patch_config_paths):
        """Removal deletes the exact entry and its wildcard sibling."""
        db_add_domain("example.com", "pool_1")
        db_add_domain(".example.com", "pool_1", is_wildcard=True)
        db_remove_domain("example.com")
        assert db_get_domain_backend("example.com") is None
        assert (".example.com", "pool_1") not in db_get_map_contents()

    def test_find_available_pool(self, patch_config_paths):
        """The lowest-numbered unused pool is chosen."""
        db_add_domain("a.com", "pool_1")
        db_add_domain("b.com", "pool_2")
        assert db_find_available_pool() == "pool_3"

    def test_find_available_pool_empty(self, patch_config_paths):
        """With nothing allocated, pool_1 is the first free pool."""
        assert db_find_available_pool() == "pool_1"

    def test_get_domains_sharing_pool(self, patch_config_paths):
        """Only non-wildcard domains are reported for a pool."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("www.example.com", "pool_1")
        db_add_domain(".example.com", "pool_1", is_wildcard=True)
        sharing = db_get_domains_sharing_pool("pool_1")
        assert "example.com" in sharing
        assert "www.example.com" in sharing
        assert ".example.com" not in sharing

    def test_update_domain_backend(self, patch_config_paths):
        """Re-adding an existing domain overwrites its backend."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("example.com", "pool_5")  # upsert, not duplicate
        assert db_get_domain_backend("example.com") == "pool_5"
class TestServerOperations:
    """Tests for server data access functions."""

    def test_add_and_load_server(self, patch_config_paths):
        """An added server slot round-trips through load_servers_config."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        slot = db_load_servers_config()["example.com"]["1"]
        assert slot["ip"] == "10.0.0.1"
        assert slot["http_port"] == 80

    def test_update_server(self, patch_config_paths):
        """Re-adding a slot replaces its address and port."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        db_add_server("example.com", 1, "10.0.0.99", 8080)
        slot = db_load_servers_config()["example.com"]["1"]
        assert slot["ip"] == "10.0.0.99"
        assert slot["http_port"] == 8080

    def test_remove_server(self, patch_config_paths):
        """Removing one slot leaves the domain's other slots intact."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        db_add_server("example.com", 2, "10.0.0.2", 80)
        db_remove_server("example.com", 1)
        loaded = db_load_servers_config()
        assert "1" not in loaded.get("example.com", {})
        assert "2" in loaded["example.com"]

    def test_remove_domain_servers(self, patch_config_paths):
        """Clearing a domain removes all of its slots but no one else's."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        db_add_server("example.com", 2, "10.0.0.2", 80)
        db_add_server("other.com", 1, "10.0.0.3", 80)
        db_remove_domain_servers("example.com")
        loaded = db_load_servers_config()
        assert loaded.get("example.com", {}).get("1") is None
        assert "other.com" in loaded

    def test_load_empty(self, patch_config_paths):
        """An empty database yields an empty dict."""
        assert db_load_servers_config() == {}
class TestSharedDomainOperations:
    """Tests for shared domain functions."""

    def test_add_and_get_shared(self, patch_config_paths):
        """A shared-domain link resolves to its parent domain."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("www.example.com", "pool_1")
        db_add_shared_domain("www.example.com", "example.com")
        assert db_get_shared_domain("www.example.com") == "example.com"

    def test_is_shared(self, patch_config_paths):
        """Only the child side of a share is flagged as shared."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("www.example.com", "pool_1")
        db_add_shared_domain("www.example.com", "example.com")
        assert db_is_shared_domain("www.example.com") is True
        assert db_is_shared_domain("example.com") is False

    def test_not_shared(self, patch_config_paths):
        """A plain domain carries no share link."""
        db_add_domain("example.com", "pool_1")
        assert db_get_shared_domain("example.com") is None
        assert db_is_shared_domain("example.com") is False

    def test_shared_in_load_config(self, patch_config_paths):
        """load_servers_config exposes the link via the _shares key."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("www.example.com", "pool_1")
        db_add_shared_domain("www.example.com", "example.com")
        loaded = db_load_servers_config()
        assert loaded["www.example.com"]["_shares"] == "example.com"
class TestCertificateOperations:
    """Tests for certificate data access functions."""

    def test_add_and_load_cert(self, patch_config_paths):
        """A recorded certificate appears in the load result."""
        db_add_cert("example.com")
        assert "example.com" in db_load_certs()

    def test_add_duplicate(self, patch_config_paths):
        """Recording the same certificate twice stores it only once."""
        db_add_cert("example.com")
        db_add_cert("example.com")
        assert db_load_certs().count("example.com") == 1

    def test_remove_cert(self, patch_config_paths):
        """Removing one certificate leaves the others intact."""
        db_add_cert("example.com")
        db_add_cert("other.com")
        db_remove_cert("example.com")
        remaining = db_load_certs()
        assert "example.com" not in remaining
        assert "other.com" in remaining

    def test_load_empty(self, patch_config_paths):
        """An empty database yields an empty list."""
        assert db_load_certs() == []

    def test_sorted_output(self, patch_config_paths):
        """Certificates come back lexicographically sorted."""
        for name in ("z.com", "a.com", "m.com"):
            db_add_cert(name)
        assert db_load_certs() == ["a.com", "m.com", "z.com"]
class TestSyncMapFiles:
    """Tests for sync_map_files function."""

    def test_sync_exact_domains(self, patch_config_paths):
        """Sync writes exact (non-wildcard) domains to domains.map."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("api.example.com", "pool_2")
        sync_map_files()
        with open(patch_config_paths["map_file"]) as fh:
            text = fh.read()
        assert "example.com pool_1" in text
        assert "api.example.com pool_2" in text

    def test_sync_wildcards(self, patch_config_paths):
        """Sync writes wildcard entries to wildcards.map."""
        db_add_domain(".example.com", "pool_1", is_wildcard=True)
        sync_map_files()
        with open(patch_config_paths["wildcards_file"]) as fh:
            assert ".example.com pool_1" in fh.read()

    def test_sync_empty(self, patch_config_paths):
        """With no domains stored, only the header comment is written."""
        sync_map_files()
        with open(patch_config_paths["map_file"]) as fh:
            text = fh.read()
        assert "Exact Domain" in text
        # Every remaining line must be blank or a comment — no data rows.
        data_lines = [
            line.strip()
            for line in text.splitlines()
            if line.strip() and not line.startswith("#")
        ]
        assert data_lines == []

    def test_sync_sorted(self, patch_config_paths):
        """Generated map entries are emitted in sorted order."""
        db_add_domain("z.com", "pool_3")
        db_add_domain("a.com", "pool_1")
        sync_map_files()
        with open(patch_config_paths["map_file"]) as fh:
            data_lines = [
                line.strip()
                for line in fh
                if line.strip() and not line.startswith("#")
            ]
        assert data_lines[0] == "a.com pool_1"
        assert data_lines[1] == "z.com pool_3"

View File

@@ -1,6 +1,5 @@
"""Unit tests for file_ops module."""
"""Unit tests for file_ops module (SQLite-backed)."""
import json
import os
from unittest.mock import patch
@@ -16,14 +15,19 @@ from haproxy_mcp.file_ops import (
get_legacy_backend_name,
get_backend_and_prefix,
load_servers_config,
save_servers_config,
add_server_to_config,
remove_server_from_config,
remove_domain_from_config,
load_certs_config,
save_certs_config,
add_cert_to_config,
remove_cert_from_config,
add_domain_to_map,
remove_domain_from_map,
find_available_pool,
add_shared_domain_to_config,
get_shared_domain,
is_shared_domain,
get_domains_sharing_pool,
)
@@ -62,7 +66,7 @@ class TestAtomicWriteFile:
def test_unicode_content(self, tmp_path):
"""Unicode content is properly written."""
file_path = str(tmp_path / "unicode.txt")
content = "Hello, \u4e16\u754c!" # "Hello, World!" in Chinese
content = "Hello, \u4e16\u754c!"
atomic_write_file(file_path, content)
@@ -81,66 +85,33 @@ class TestAtomicWriteFile:
class TestGetMapContents:
"""Tests for get_map_contents function."""
"""Tests for get_map_contents function (SQLite-backed)."""
def test_read_map_file(self, patch_config_paths):
"""Read entries from map file."""
# Write test content to map file
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
f.write("api.example.com pool_2\n")
def test_empty_db(self, patch_config_paths):
"""Empty database returns empty list."""
entries = get_map_contents()
assert entries == []
def test_read_domains(self, patch_config_paths):
"""Read entries from database."""
add_domain_to_map("example.com", "pool_1")
add_domain_to_map("api.example.com", "pool_2")
entries = get_map_contents()
assert ("example.com", "pool_1") in entries
assert ("api.example.com", "pool_2") in entries
def test_read_both_map_files(self, patch_config_paths):
"""Read entries from both domains.map and wildcards.map."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
with open(patch_config_paths["wildcards_file"], "w") as f:
f.write(".example.com pool_1\n")
def test_read_with_wildcards(self, patch_config_paths):
"""Read entries including wildcards."""
add_domain_to_map("example.com", "pool_1")
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
entries = get_map_contents()
assert ("example.com", "pool_1") in entries
assert (".example.com", "pool_1") in entries
def test_skip_comments(self, patch_config_paths):
"""Comments are skipped."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("# This is a comment\n")
f.write("example.com pool_1\n")
f.write("# Another comment\n")
entries = get_map_contents()
assert len(entries) == 1
assert entries[0] == ("example.com", "pool_1")
def test_skip_empty_lines(self, patch_config_paths):
"""Empty lines are skipped."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("\n")
f.write("example.com pool_1\n")
f.write("\n")
f.write("api.example.com pool_2\n")
entries = get_map_contents()
assert len(entries) == 2
def test_file_not_found(self, patch_config_paths):
"""Missing file returns empty list."""
os.unlink(patch_config_paths["map_file"])
os.unlink(patch_config_paths["wildcards_file"])
entries = get_map_contents()
assert entries == []
class TestSplitDomainEntries:
"""Tests for split_domain_entries function."""
@@ -182,36 +153,30 @@ class TestSplitDomainEntries:
class TestSaveMapFile:
"""Tests for save_map_file function."""
"""Tests for save_map_file function (syncs from DB to map files)."""
def test_save_entries(self, patch_config_paths):
"""Save entries to separate map files."""
entries = [
("example.com", "pool_1"),
(".example.com", "pool_1"),
]
add_domain_to_map("example.com", "pool_1")
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
save_map_file(entries)
save_map_file([]) # Entries param ignored, syncs from DB
# Check exact domains file
with open(patch_config_paths["map_file"]) as f:
content = f.read()
assert "example.com pool_1" in content
# Check wildcards file
with open(patch_config_paths["wildcards_file"]) as f:
content = f.read()
assert ".example.com pool_1" in content
def test_sorted_output(self, patch_config_paths):
"""Entries are sorted in output."""
entries = [
("z.example.com", "pool_3"),
("a.example.com", "pool_1"),
("m.example.com", "pool_2"),
]
add_domain_to_map("z.example.com", "pool_3")
add_domain_to_map("a.example.com", "pool_1")
add_domain_to_map("m.example.com", "pool_2")
save_map_file(entries)
save_map_file([])
with open(patch_config_paths["map_file"]) as f:
lines = [l.strip() for l in f if l.strip() and not l.startswith("#")]
@@ -222,12 +187,11 @@ class TestSaveMapFile:
class TestGetDomainBackend:
"""Tests for get_domain_backend function."""
"""Tests for get_domain_backend function (SQLite-backed)."""
def test_find_existing_domain(self, patch_config_paths):
"""Find backend for existing domain."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
backend = get_domain_backend("example.com")
@@ -235,8 +199,7 @@ class TestGetDomainBackend:
def test_domain_not_found(self, patch_config_paths):
"""Non-existent domain returns None."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
backend = get_domain_backend("other.com")
@@ -271,8 +234,7 @@ class TestGetBackendAndPrefix:
def test_pool_backend(self, patch_config_paths):
"""Pool backend returns pool-based prefix."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_5\n")
add_domain_to_map("example.com", "pool_5")
backend, prefix = get_backend_and_prefix("example.com")
@@ -288,48 +250,33 @@ class TestGetBackendAndPrefix:
class TestLoadServersConfig:
"""Tests for load_servers_config function."""
"""Tests for load_servers_config function (SQLite-backed)."""
def test_load_existing_config(self, patch_config_paths, sample_servers_config):
"""Load existing config file."""
with open(patch_config_paths["servers_file"], "w") as f:
json.dump(sample_servers_config, f)
def test_load_empty_config(self, patch_config_paths):
"""Empty database returns empty dict."""
config = load_servers_config()
assert config == {}
def test_load_with_servers(self, patch_config_paths):
"""Load config with server entries."""
add_server_to_config("example.com", 1, "10.0.0.1", 80)
add_server_to_config("example.com", 2, "10.0.0.2", 80)
config = load_servers_config()
assert "example.com" in config
assert config["example.com"]["1"]["ip"] == "10.0.0.1"
assert config["example.com"]["2"]["ip"] == "10.0.0.2"
def test_file_not_found(self, patch_config_paths):
"""Missing file returns empty dict."""
os.unlink(patch_config_paths["servers_file"])
def test_load_with_shared_domain(self, patch_config_paths):
"""Load config with shared domain reference."""
add_domain_to_map("example.com", "pool_1")
add_domain_to_map("www.example.com", "pool_1")
add_shared_domain_to_config("www.example.com", "example.com")
config = load_servers_config()
assert config == {}
def test_invalid_json(self, patch_config_paths):
"""Invalid JSON returns empty dict."""
with open(patch_config_paths["servers_file"], "w") as f:
f.write("not valid json {{{")
config = load_servers_config()
assert config == {}
class TestSaveServersConfig:
"""Tests for save_servers_config function."""
def test_save_config(self, patch_config_paths):
"""Save config to file."""
config = {"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}
save_servers_config(config)
with open(patch_config_paths["servers_file"]) as f:
loaded = json.load(f)
assert loaded == config
assert config["www.example.com"]["_shares"] == "example.com"
class TestAddServerToConfig:
@@ -373,17 +320,18 @@ class TestRemoveServerFromConfig:
remove_server_from_config("example.com", 1)
config = load_servers_config()
assert "1" not in config["example.com"]
assert "1" not in config.get("example.com", {})
assert "2" in config["example.com"]
def test_remove_last_server_removes_domain(self, patch_config_paths):
"""Removing last server removes domain entry."""
"""Removing last server removes domain entry from servers."""
add_server_to_config("example.com", 1, "10.0.0.1", 80)
remove_server_from_config("example.com", 1)
config = load_servers_config()
assert "example.com" not in config
# Domain may or may not exist (no servers = no entry)
assert config.get("example.com", {}).get("1") is None
def test_remove_nonexistent_server(self, patch_config_paths):
"""Removing non-existent server is a no-op."""
@@ -399,14 +347,14 @@ class TestRemoveDomainFromConfig:
"""Tests for remove_domain_from_config function."""
def test_remove_existing_domain(self, patch_config_paths):
"""Remove existing domain."""
"""Remove existing domain's servers."""
add_server_to_config("example.com", 1, "10.0.0.1", 80)
add_server_to_config("other.com", 1, "10.0.0.2", 80)
remove_domain_from_config("example.com")
config = load_servers_config()
assert "example.com" not in config
assert config.get("example.com", {}).get("1") is None
assert "other.com" in config
def test_remove_nonexistent_domain(self, patch_config_paths):
@@ -420,40 +368,23 @@ class TestRemoveDomainFromConfig:
class TestLoadCertsConfig:
"""Tests for load_certs_config function."""
"""Tests for load_certs_config function (SQLite-backed)."""
def test_load_existing_config(self, patch_config_paths):
"""Load existing certs config."""
with open(patch_config_paths["certs_file"], "w") as f:
json.dump({"domains": ["example.com", "other.com"]}, f)
def test_load_empty(self, patch_config_paths):
"""Empty database returns empty list."""
domains = load_certs_config()
assert domains == []
def test_load_with_certs(self, patch_config_paths):
"""Load certs from database."""
add_cert_to_config("example.com")
add_cert_to_config("other.com")
domains = load_certs_config()
assert "example.com" in domains
assert "other.com" in domains
def test_file_not_found(self, patch_config_paths):
"""Missing file returns empty list."""
os.unlink(patch_config_paths["certs_file"])
domains = load_certs_config()
assert domains == []
class TestSaveCertsConfig:
"""Tests for save_certs_config function."""
def test_save_domains(self, patch_config_paths):
"""Save domains to certs config."""
save_certs_config(["z.com", "a.com"])
with open(patch_config_paths["certs_file"]) as f:
data = json.load(f)
# Should be sorted
assert data["domains"] == ["a.com", "z.com"]
class TestAddCertToConfig:
"""Tests for add_cert_to_config function."""
@@ -496,3 +427,87 @@ class TestRemoveCertFromConfig:
domains = load_certs_config()
assert "example.com" in domains
class TestAddDomainToMap:
"""Tests for add_domain_to_map function."""
def test_add_domain(self, patch_config_paths):
"""Add a domain and verify map files are synced."""
add_domain_to_map("example.com", "pool_1")
assert get_domain_backend("example.com") == "pool_1"
with open(patch_config_paths["map_file"]) as f:
assert "example.com pool_1" in f.read()
def test_add_wildcard(self, patch_config_paths):
"""Add a wildcard domain."""
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
entries = get_map_contents()
assert (".example.com", "pool_1") in entries
class TestRemoveDomainFromMap:
"""Tests for remove_domain_from_map function."""
def test_remove_domain(self, patch_config_paths):
"""Remove a domain and its wildcard."""
add_domain_to_map("example.com", "pool_1")
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
remove_domain_from_map("example.com")
assert get_domain_backend("example.com") is None
entries = get_map_contents()
assert (".example.com", "pool_1") not in entries
class TestFindAvailablePool:
"""Tests for find_available_pool function."""
def test_first_pool_available(self, patch_config_paths):
"""When no domains exist, pool_1 is returned."""
pool = find_available_pool()
assert pool == "pool_1"
def test_skip_used_pools(self, patch_config_paths):
"""Used pools are skipped."""
add_domain_to_map("example.com", "pool_1")
add_domain_to_map("other.com", "pool_2")
pool = find_available_pool()
assert pool == "pool_3"
class TestSharedDomains:
"""Tests for shared domain functions."""
def test_get_shared_domain(self, patch_config_paths):
"""Get parent domain for shared domain."""
add_domain_to_map("example.com", "pool_1")
add_domain_to_map("www.example.com", "pool_1")
add_shared_domain_to_config("www.example.com", "example.com")
assert get_shared_domain("www.example.com") == "example.com"
def test_is_shared_domain(self, patch_config_paths):
"""Check if domain is shared."""
add_domain_to_map("example.com", "pool_1")
add_domain_to_map("www.example.com", "pool_1")
add_shared_domain_to_config("www.example.com", "example.com")
assert is_shared_domain("www.example.com") is True
assert is_shared_domain("example.com") is False
def test_get_domains_sharing_pool(self, patch_config_paths):
"""Get all domains using a pool."""
add_domain_to_map("example.com", "pool_1")
add_domain_to_map("www.example.com", "pool_1")
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
domains = get_domains_sharing_pool("pool_1")
assert "example.com" in domains
assert "www.example.com" in domains
assert ".example.com" not in domains # Wildcards excluded

View File

@@ -6,6 +6,8 @@ from unittest.mock import patch, MagicMock
import pytest
from haproxy_mcp.file_ops import add_cert_to_config
class TestGetPemPaths:
"""Tests for get_pem_paths function."""
@@ -127,8 +129,7 @@ class TestRestoreCertificates:
def test_restore_certificates_success(self, patch_config_paths, tmp_path, mock_socket_class, mock_select):
"""Restore certificates successfully."""
# Save config
with open(patch_config_paths["certs_file"], "w") as f:
json.dump({"domains": ["example.com"]}, f)
add_cert_to_config("example.com")
# Create PEM
certs_dir = tmp_path / "certs"
@@ -283,12 +284,18 @@ class TestHaproxyCertInfo:
pem_file = tmp_path / "example.com.pem"
pem_file.write_text("cert content")
mock_subprocess.return_value = MagicMock(
def subprocess_side_effect(*args, **kwargs):
cmd = args[0] if args else kwargs.get("args", [])
if isinstance(cmd, list) and "stat" in cmd:
return MagicMock(returncode=0, stdout="1704067200", stderr="")
return MagicMock(
returncode=0,
stdout="subject=CN = example.com\nissuer=CN = Google Trust Services\nnotBefore=Jan 1 00:00:00 2024 GMT\nnotAfter=Apr 1 00:00:00 2024 GMT",
stderr=""
)
mock_subprocess.side_effect = subprocess_side_effect
mock_sock = mock_socket_class(responses={
"show ssl cert": "/etc/haproxy/certs/example.com.pem",
})
@@ -337,10 +344,18 @@ class TestHaproxyIssueCert:
assert "Error" in result
assert "Invalid domain" in result
def test_issue_cert_no_cf_token(self, tmp_path):
def test_issue_cert_no_cf_token(self, tmp_path, mock_subprocess):
"""Fail when CF_Token is not set."""
acme_sh = str(tmp_path / "acme.sh")
mock_subprocess.return_value = MagicMock(
returncode=1,
stdout="",
stderr="CF_Token is not set. Please export CF_Token environment variable.",
)
with patch.dict(os.environ, {}, clear=True):
with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)):
with patch("haproxy_mcp.tools.certificates.ACME_SH", acme_sh):
with patch("os.path.exists", return_value=False):
from haproxy_mcp.tools.certificates import register_certificate_tools
mcp = MagicMock()
@@ -845,8 +860,8 @@ class TestHaproxyRenewAllCertsMultiple:
def test_renew_all_certs_multiple_renewals(self, mock_subprocess, mock_socket_class, mock_select, patch_config_paths, tmp_path):
"""Renew multiple certificates successfully."""
# Write config with multiple domains
with open(patch_config_paths["certs_file"], "w") as f:
json.dump({"domains": ["example.com", "example.org"]}, f)
add_cert_to_config("example.com")
add_cert_to_config("example.org")
# Create PEM files
certs_dir = tmp_path / "certs"
@@ -1038,16 +1053,18 @@ class TestHaproxyDeleteCertPartialFailure:
"show ssl cert": "", # Not loaded
})
# Mock os.remove to fail
def mock_remove(path):
if "example.com.pem" in str(path):
raise PermissionError("Permission denied")
raise FileNotFoundError()
# Mock subprocess to succeed for acme.sh remove but fail for rm (PEM removal)
def subprocess_side_effect(*args, **kwargs):
cmd = args[0] if args else kwargs.get("args", [])
if isinstance(cmd, list) and cmd[0] == "rm":
return MagicMock(returncode=1, stdout="", stderr="Permission denied")
return MagicMock(returncode=0, stdout="", stderr="")
mock_subprocess.side_effect = subprocess_side_effect
with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path / "acme")):
with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)):
with patch("socket.socket", return_value=mock_sock):
with patch("os.remove", side_effect=mock_remove):
from haproxy_mcp.tools.certificates import register_certificate_tools
mcp = MagicMock()
registered_tools = {}
@@ -1118,8 +1135,8 @@ class TestRestoreCertificatesFailure:
def test_restore_certificates_partial_failure(self, patch_config_paths, tmp_path, mock_socket_class, mock_select):
"""Handle partial failure when restoring certificates."""
# Save config with multiple domains
with open(patch_config_paths["certs_file"], "w") as f:
json.dump({"domains": ["example.com", "missing.com"]}, f)
add_cert_to_config("example.com")
add_cert_to_config("missing.com")
# Create only one PEM file
certs_dir = tmp_path / "certs"

View File

@@ -5,6 +5,8 @@ from unittest.mock import patch, MagicMock
import pytest
from haproxy_mcp.file_ops import add_domain_to_map, add_server_to_config
class TestRestoreServersFromConfig:
"""Tests for restore_servers_from_config function."""
@@ -19,12 +21,12 @@ class TestRestoreServersFromConfig:
def test_restore_servers_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config):
"""Restore servers successfully."""
# Write config and map
with open(patch_config_paths["servers_file"], "w") as f:
json.dump(sample_servers_config, f)
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
f.write("api.example.com pool_2\n")
# Add domains and servers to database
add_domain_to_map("example.com", "pool_1")
add_server_to_config("example.com", 1, "10.0.0.1", 80)
add_server_to_config("example.com", 2, "10.0.0.2", 80)
add_domain_to_map("api.example.com", "pool_2")
add_server_to_config("api.example.com", 1, "10.0.0.10", 8080)
mock_sock = mock_socket_class(responses={
"set server": "",
@@ -40,9 +42,8 @@ class TestRestoreServersFromConfig:
def test_restore_servers_skip_missing_domain(self, mock_socket_class, mock_select, patch_config_paths):
"""Skip domains not in map file."""
config = {"unknown.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}
with open(patch_config_paths["servers_file"], "w") as f:
json.dump(config, f)
# Add server for unknown.com but no map entry (simulates missing domain)
add_server_to_config("unknown.com", 1, "10.0.0.1", 80)
mock_sock = mock_socket_class(responses={"set server": ""})
@@ -55,11 +56,9 @@ class TestRestoreServersFromConfig:
def test_restore_servers_skip_empty_ip(self, mock_socket_class, mock_select, patch_config_paths):
"""Skip servers with empty IP."""
config = {"example.com": {"1": {"ip": "", "http_port": 80}}}
with open(patch_config_paths["servers_file"], "w") as f:
json.dump(config, f)
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
# Add domain to map and server with empty IP (will be skipped during restore)
add_domain_to_map("example.com", "pool_1")
add_server_to_config("example.com", 1, "", 80)
mock_sock = mock_socket_class(responses={"set server": ""})
@@ -321,11 +320,12 @@ class TestHaproxyRestoreState:
def test_restore_state_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config):
"""Restore state successfully."""
with open(patch_config_paths["servers_file"], "w") as f:
json.dump(sample_servers_config, f)
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
f.write("api.example.com pool_2\n")
# Add domains and servers to database
add_domain_to_map("example.com", "pool_1")
add_server_to_config("example.com", 1, "10.0.0.1", 80)
add_server_to_config("example.com", 2, "10.0.0.2", 80)
add_domain_to_map("api.example.com", "pool_2")
add_server_to_config("api.example.com", 1, "10.0.0.10", 8080)
mock_sock = mock_socket_class(responses={"set server": ""})
@@ -373,17 +373,10 @@ class TestRestoreServersFromConfigBatchFailure:
def test_restore_servers_batch_failure_fallback(self, mock_socket_class, mock_select, patch_config_paths):
"""Fall back to individual commands when batch fails."""
# Create config with servers
config = {
"example.com": {
"1": {"ip": "10.0.0.1", "http_port": 80},
"2": {"ip": "10.0.0.2", "http_port": 80},
}
}
with open(patch_config_paths["servers_file"], "w") as f:
json.dump(config, f)
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
# Add domain and servers to database
add_domain_to_map("example.com", "pool_1")
add_server_to_config("example.com", 1, "10.0.0.1", 80)
add_server_to_config("example.com", 2, "10.0.0.2", 80)
# Track call count to simulate batch failure then individual success
call_count = [0]
@@ -457,17 +450,17 @@ class TestRestoreServersFromConfigBatchFailure:
def test_restore_servers_invalid_slot(self, mock_socket_class, mock_select, patch_config_paths):
"""Skip servers with invalid slot number."""
# Add domain to map
add_domain_to_map("example.com", "pool_1")
# Mock load_servers_config to return config with invalid slot
config = {
"example.com": {
"invalid": {"ip": "10.0.0.1", "http_port": 80}, # Invalid slot
"1": {"ip": "10.0.0.2", "http_port": 80}, # Valid slot
}
}
with open(patch_config_paths["servers_file"], "w") as f:
json.dump(config, f)
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
with patch("haproxy_mcp.tools.configuration.load_servers_config", return_value=config):
mock_sock = mock_socket_class(responses={"set server": ""})
with patch("socket.socket", return_value=mock_sock):
@@ -481,17 +474,17 @@ class TestRestoreServersFromConfigBatchFailure:
def test_restore_servers_invalid_port(self, mock_socket_class, mock_select, patch_config_paths, caplog):
"""Skip servers with invalid port."""
import logging
# Add domain to map
add_domain_to_map("example.com", "pool_1")
# Mock load_servers_config to return config with invalid port
config = {
"example.com": {
"1": {"ip": "10.0.0.1", "http_port": "invalid"}, # Invalid port
"2": {"ip": "10.0.0.2", "http_port": 80}, # Valid port
}
}
with open(patch_config_paths["servers_file"], "w") as f:
json.dump(config, f)
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
with patch("haproxy_mcp.tools.configuration.load_servers_config", return_value=config):
mock_sock = mock_socket_class(responses={"set server": ""})
with patch("socket.socket", return_value=mock_sock):
@@ -658,11 +651,12 @@ class TestHaproxyRestoreStateFailures:
"""Handle HAProxy error when restoring state."""
from haproxy_mcp.exceptions import HaproxyError
with open(patch_config_paths["servers_file"], "w") as f:
json.dump(sample_servers_config, f)
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
f.write("api.example.com pool_2\n")
# Add domains and servers to database
add_domain_to_map("example.com", "pool_1")
add_server_to_config("example.com", 1, "10.0.0.1", 80)
add_server_to_config("example.com", 2, "10.0.0.2", 80)
add_domain_to_map("api.example.com", "pool_2")
add_server_to_config("api.example.com", 1, "10.0.0.10", 8080)
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=HaproxyError("Connection refused")):
from haproxy_mcp.tools.configuration import register_config_tools

View File

@@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock
import pytest
from haproxy_mcp.exceptions import HaproxyError
from haproxy_mcp.file_ops import add_domain_to_map
class TestHaproxyListDomains:
@@ -38,9 +39,8 @@ class TestHaproxyListDomains:
def test_list_domains_with_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""List domains with configured servers."""
# Write map file
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
# Add domain to DB
add_domain_to_map("example.com", "pool_1")
mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([
@@ -70,10 +70,8 @@ class TestHaproxyListDomains:
def test_list_domains_exclude_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""List domains excluding wildcards by default."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
with open(patch_config_paths["wildcards_file"], "w") as f:
f.write(".example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([]),
@@ -100,10 +98,8 @@ class TestHaproxyListDomains:
def test_list_domains_include_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""List domains including wildcards when requested."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
with open(patch_config_paths["wildcards_file"], "w") as f:
f.write(".example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([]),
@@ -230,8 +226,7 @@ class TestHaproxyAddDomain:
def test_add_domain_already_exists(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Reject adding domain that already exists."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
from haproxy_mcp.tools.domains import register_domain_tools
mcp = MagicMock()
@@ -362,8 +357,7 @@ class TestHaproxyRemoveDomain:
def test_remove_legacy_domain_rejected(self, patch_config_paths):
"""Reject removing legacy (non-pool) domain."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com legacy_backend\n")
add_domain_to_map("example.com", "legacy_backend")
from haproxy_mcp.tools.domains import register_domain_tools
mcp = MagicMock()
@@ -385,10 +379,8 @@ class TestHaproxyRemoveDomain:
def test_remove_domain_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Successfully remove domain."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
with open(patch_config_paths["wildcards_file"], "w") as f:
f.write(".example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
mock_sock = mock_socket_class(responses={
"del map": "",

View File

@@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock
import pytest
from haproxy_mcp.exceptions import HaproxyError
from haproxy_mcp.file_ops import add_domain_to_map
class TestHaproxyHealth:
@@ -80,7 +81,7 @@ class TestHaproxyHealth:
# Use paths that don't exist
with patch("haproxy_mcp.tools.health.MAP_FILE", str(tmp_path / "nonexistent.map")):
with patch("haproxy_mcp.tools.health.SERVERS_FILE", str(tmp_path / "nonexistent.json")):
with patch("haproxy_mcp.tools.health.DB_FILE", str(tmp_path / "nonexistent.db")):
with patch("socket.socket", return_value=mock_sock):
from haproxy_mcp.tools.health import register_health_tools
mcp = MagicMock()
@@ -160,8 +161,7 @@ class TestHaproxyDomainHealth:
def test_domain_health_healthy(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Domain health returns healthy when all servers are UP."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([
@@ -197,8 +197,7 @@ class TestHaproxyDomainHealth:
def test_domain_health_degraded(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Domain health returns degraded when some servers are DOWN."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([
@@ -234,8 +233,7 @@ class TestHaproxyDomainHealth:
def test_domain_health_down(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Domain health returns down when all servers are DOWN."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([
@@ -269,8 +267,7 @@ class TestHaproxyDomainHealth:
def test_domain_health_no_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Domain health returns no_servers when no servers configured."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([

View File

@@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock
import pytest
from haproxy_mcp.exceptions import HaproxyError
from haproxy_mcp.file_ops import add_domain_to_map, load_servers_config
class TestHaproxyListServers:
@@ -33,8 +34,7 @@ class TestHaproxyListServers:
def test_list_servers_empty_backend(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""List servers for domain with no servers."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([
@@ -63,8 +63,7 @@ class TestHaproxyListServers:
def test_list_servers_with_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""List servers with active servers."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([
@@ -224,8 +223,7 @@ class TestHaproxyAddServer:
def test_add_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Successfully add server."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
mock_sock = mock_socket_class(responses={
"set server": "",
@@ -258,8 +256,7 @@ class TestHaproxyAddServer:
def test_add_server_auto_slot(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Auto-select slot when slot=0."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([
@@ -413,8 +410,7 @@ class TestHaproxyAddServers:
def test_add_servers_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Successfully add multiple servers."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
mock_sock = mock_socket_class(responses={
"set server": "",
@@ -495,8 +491,7 @@ class TestHaproxyRemoveServer:
def test_remove_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Successfully remove server."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
mock_sock = mock_socket_class(responses={
"set server": "",
@@ -689,8 +684,7 @@ class TestHaproxyAddServersRollback:
def test_add_servers_partial_failure_rollback(self, mock_socket_class, mock_select, patch_config_paths):
"""Rollback only failed slots when HAProxy error occurs."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
# Mock configure_server_slot to fail on second slot
call_count = [0]
@@ -735,18 +729,16 @@ class TestHaproxyAddServersRollback:
assert "slot 2" in result # Failed
# Verify servers.json only has successfully added server
with open(patch_config_paths["servers_file"], "r") as f:
config = json.load(f)
config = load_servers_config()
assert "example.com" in config
assert "1" in config["example.com"] # Successfully added stays
assert "2" not in config["example.com"] # Failed one was rolled back
assert "2" not in config.get("example.com", {}) # Failed one was rolled back
def test_add_servers_unexpected_error_rollback_only_successful(
self, mock_socket_class, mock_select, patch_config_paths
):
"""Rollback only successfully added servers on unexpected error."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
# Track which servers were configured
configured_slots = []
@@ -793,8 +785,7 @@ class TestHaproxyAddServersRollback:
assert "Unexpected system error" in result
# Verify servers.json is empty (all rolled back)
with open(patch_config_paths["servers_file"], "r") as f:
config = json.load(f)
config = load_servers_config()
assert config == {} or "example.com" not in config or config.get("example.com") == {}
def test_add_servers_rollback_failure_logged(
@@ -802,8 +793,7 @@ class TestHaproxyAddServersRollback:
):
"""Log rollback failures during error recovery."""
import logging
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
if slot == 2:
@@ -858,8 +848,7 @@ class TestHaproxyAddServerAutoSlot:
def test_add_server_auto_slot_all_used(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Auto-select slot fails when all slots are in use."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
# Build response with all 10 slots used
servers = []
@@ -902,8 +891,7 @@ class TestHaproxyAddServerAutoSlot:
def test_add_server_negative_slot_auto_select(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Negative slot number triggers auto-selection."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([
@@ -970,8 +958,7 @@ class TestHaproxyWaitDrain:
def test_wait_drain_success(self, patch_config_paths):
"""Successfully wait for connections to drain."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
# Mock haproxy_cmd to return 0 connections
with patch("haproxy_mcp.tools.servers.haproxy_cmd") as mock_cmd:
@@ -1000,8 +987,7 @@ class TestHaproxyWaitDrain:
def test_wait_drain_timeout(self, patch_config_paths):
"""Timeout when connections don't drain."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
time_values = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0] # Simulate time passing
time_iter = iter(time_values)
@@ -1087,9 +1073,6 @@ class TestHaproxyWaitDrain:
def test_wait_drain_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths):
"""Error when domain not found in map."""
# Empty map file - domain not configured
with open(patch_config_paths["map_file"], "w") as f:
f.write("")
from haproxy_mcp.tools.servers import register_server_tools
mcp = MagicMock()
registered_tools = {}
@@ -1202,8 +1185,7 @@ class TestHaproxySetDomainState:
def test_set_domain_state_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Set all servers of a domain to a state."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([
@@ -1283,8 +1265,7 @@ class TestHaproxySetDomainState:
def test_set_domain_state_no_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""No active servers found for domain."""
with open(patch_config_paths["map_file"], "w") as f:
f.write("example.com pool_1\n")
add_domain_to_map("example.com", "pool_1")
# All servers have 0.0.0.0 address (not configured)
mock_sock = mock_socket_class(responses={
@@ -1318,9 +1299,6 @@ class TestHaproxySetDomainState:
def test_set_domain_state_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
"""Handle domain not found in map - shows no active servers."""
# Empty map file
with open(patch_config_paths["map_file"], "w") as f:
f.write("")
# Mock should show no servers for unknown domain's backend
mock_sock = mock_socket_class(responses={
"show servers state": response_builder.servers_state([]),