refactor: migrate data storage from JSON/map files to SQLite
Replace servers.json, certificates.json, and map file parsing with SQLite (WAL mode) as single source of truth. HAProxy map files are now generated from SQLite via sync_map_files(). Key changes: - Add db.py with schema, connection management, and JSON migration - Add DB_FILE config constant - Delegate file_ops.py functions to db.py - Refactor domains.py to use file_ops instead of direct list manipulation - Fix subprocess.TimeoutExpired not caught (doesn't inherit TimeoutError) - Add DB health check in health.py - Init DB on startup in server.py and __main__.py - Update all 359 tests to use SQLite-backed functions Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -16,6 +16,9 @@ run/
|
|||||||
data/
|
data/
|
||||||
*.state
|
*.state
|
||||||
*.lock
|
*.lock
|
||||||
|
*.db
|
||||||
|
*.db-wal
|
||||||
|
*.db-shm
|
||||||
|
|
||||||
# Python
|
# Python
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
"""Entry point for running haproxy_mcp as a module."""
|
"""Entry point for running haproxy_mcp as a module."""
|
||||||
|
|
||||||
|
from .db import init_db
|
||||||
from .server import mcp
|
from .server import mcp
|
||||||
from .tools.configuration import startup_restore
|
from .tools.configuration import startup_restore
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
init_db()
|
||||||
startup_restore()
|
startup_restore()
|
||||||
mcp.run(transport="streamable-http")
|
mcp.run(transport="streamable-http")
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ WILDCARDS_MAP_FILE: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE", "/opt/haproxy/
|
|||||||
WILDCARDS_MAP_FILE_CONTAINER: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE_CONTAINER", "/usr/local/etc/haproxy/wildcards.map")
|
WILDCARDS_MAP_FILE_CONTAINER: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE_CONTAINER", "/usr/local/etc/haproxy/wildcards.map")
|
||||||
SERVERS_FILE: str = os.getenv("HAPROXY_SERVERS_FILE", "/opt/haproxy/conf/servers.json")
|
SERVERS_FILE: str = os.getenv("HAPROXY_SERVERS_FILE", "/opt/haproxy/conf/servers.json")
|
||||||
CERTS_FILE: str = os.getenv("HAPROXY_CERTS_FILE", "/opt/haproxy/conf/certificates.json")
|
CERTS_FILE: str = os.getenv("HAPROXY_CERTS_FILE", "/opt/haproxy/conf/certificates.json")
|
||||||
|
DB_FILE: str = os.getenv("HAPROXY_DB_FILE", "/opt/haproxy/conf/haproxy_mcp.db")
|
||||||
|
|
||||||
# Certificate paths
|
# Certificate paths
|
||||||
CERTS_DIR: str = os.getenv("HAPROXY_CERTS_DIR", "/opt/haproxy/certs")
|
CERTS_DIR: str = os.getenv("HAPROXY_CERTS_DIR", "/opt/haproxy/certs")
|
||||||
|
|||||||
577
haproxy_mcp/db.py
Normal file
577
haproxy_mcp/db.py
Normal file
@@ -0,0 +1,577 @@
|
|||||||
|
"""SQLite database operations for HAProxy MCP Server.
|
||||||
|
|
||||||
|
Single source of truth for all persistent configuration data.
|
||||||
|
Replaces JSON files (servers.json, certificates.json) and map files
|
||||||
|
(domains.map, wildcards.map) with a single SQLite database.
|
||||||
|
|
||||||
|
Map files are generated from the database via sync_map_files() for
|
||||||
|
HAProxy to consume directly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from .config import (
|
||||||
|
DB_FILE,
|
||||||
|
MAP_FILE,
|
||||||
|
WILDCARDS_MAP_FILE,
|
||||||
|
SERVERS_FILE,
|
||||||
|
CERTS_FILE,
|
||||||
|
POOL_COUNT,
|
||||||
|
REMOTE_MODE,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
SCHEMA_VERSION = 1
|
||||||
|
|
||||||
|
_local = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
def get_connection() -> sqlite3.Connection:
|
||||||
|
"""Get a thread-local SQLite connection with WAL mode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sqlite3.Connection configured for concurrent access.
|
||||||
|
"""
|
||||||
|
conn = getattr(_local, "connection", None)
|
||||||
|
if conn is not None:
|
||||||
|
return conn
|
||||||
|
|
||||||
|
conn = sqlite3.connect(DB_FILE, timeout=10)
|
||||||
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
conn.execute("PRAGMA busy_timeout=5000")
|
||||||
|
conn.execute("PRAGMA foreign_keys=ON")
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
_local.connection = conn
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
def close_connection() -> None:
|
||||||
|
"""Close the thread-local SQLite connection."""
|
||||||
|
conn = getattr(_local, "connection", None)
|
||||||
|
if conn is not None:
|
||||||
|
conn.close()
|
||||||
|
_local.connection = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_db() -> None:
|
||||||
|
"""Initialize database schema and run migration if needed.
|
||||||
|
|
||||||
|
Creates tables if they don't exist, then checks for existing
|
||||||
|
JSON/map files to migrate data from.
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
cur = conn.cursor()
|
||||||
|
|
||||||
|
cur.executescript("""
|
||||||
|
CREATE TABLE IF NOT EXISTS domains (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
domain TEXT NOT NULL UNIQUE,
|
||||||
|
backend TEXT NOT NULL,
|
||||||
|
is_wildcard INTEGER NOT NULL DEFAULT 0,
|
||||||
|
shares_with TEXT DEFAULT NULL,
|
||||||
|
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||||
|
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_domains_backend ON domains(backend);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_domains_shares_with ON domains(shares_with);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS servers (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
domain TEXT NOT NULL,
|
||||||
|
slot INTEGER NOT NULL,
|
||||||
|
ip TEXT NOT NULL,
|
||||||
|
http_port INTEGER NOT NULL DEFAULT 80,
|
||||||
|
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||||
|
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||||
|
UNIQUE(domain, slot)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_servers_domain ON servers(domain);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS certificates (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
domain TEXT NOT NULL UNIQUE,
|
||||||
|
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||||
|
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS schema_version (
|
||||||
|
version INTEGER NOT NULL,
|
||||||
|
applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
# Check schema version
|
||||||
|
cur.execute("SELECT MAX(version) FROM schema_version")
|
||||||
|
row = cur.fetchone()
|
||||||
|
current_version = row[0] if row[0] is not None else 0
|
||||||
|
|
||||||
|
if current_version < SCHEMA_VERSION:
|
||||||
|
# First-time setup: try migrating from JSON files
|
||||||
|
if current_version == 0:
|
||||||
|
migrate_from_json()
|
||||||
|
cur.execute("INSERT INTO schema_version (version) VALUES (?)", (SCHEMA_VERSION,))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
logger.info("Database initialized (schema v%d)", SCHEMA_VERSION)
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_from_json() -> None:
|
||||||
|
"""Migrate data from JSON/map files to SQLite.
|
||||||
|
|
||||||
|
Reads domains.map, wildcards.map, servers.json, and certificates.json,
|
||||||
|
imports their data into the database within a single transaction,
|
||||||
|
then renames the original files to .bak.
|
||||||
|
|
||||||
|
This function is idempotent (uses INSERT OR IGNORE).
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
|
||||||
|
# Collect data from existing files
|
||||||
|
map_entries: list[tuple[str, str]] = []
|
||||||
|
servers_config: dict[str, Any] = {}
|
||||||
|
cert_domains: list[str] = []
|
||||||
|
|
||||||
|
# Read map files
|
||||||
|
for map_path in [MAP_FILE, WILDCARDS_MAP_FILE]:
|
||||||
|
try:
|
||||||
|
content = _read_file_for_migration(map_path)
|
||||||
|
for line in content.splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#"):
|
||||||
|
continue
|
||||||
|
parts = line.split()
|
||||||
|
if len(parts) >= 2:
|
||||||
|
map_entries.append((parts[0], parts[1]))
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.debug("Map file not found for migration: %s", map_path)
|
||||||
|
|
||||||
|
# Read servers.json
|
||||||
|
try:
|
||||||
|
content = _read_file_for_migration(SERVERS_FILE)
|
||||||
|
servers_config = json.loads(content)
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.debug("Servers config not found for migration: %s", SERVERS_FILE)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning("Corrupt servers config during migration: %s", e)
|
||||||
|
|
||||||
|
# Read certificates.json
|
||||||
|
try:
|
||||||
|
content = _read_file_for_migration(CERTS_FILE)
|
||||||
|
data = json.loads(content)
|
||||||
|
cert_domains = data.get("domains", [])
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.debug("Certs config not found for migration: %s", CERTS_FILE)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning("Corrupt certs config during migration: %s", e)
|
||||||
|
|
||||||
|
if not map_entries and not servers_config and not cert_domains:
|
||||||
|
logger.info("No existing data to migrate")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Import into database in a single transaction
|
||||||
|
try:
|
||||||
|
with conn:
|
||||||
|
# Import domains from map files
|
||||||
|
for domain, backend in map_entries:
|
||||||
|
is_wildcard = 1 if domain.startswith(".") else 0
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR IGNORE INTO domains (domain, backend, is_wildcard) VALUES (?, ?, ?)",
|
||||||
|
(domain, backend, is_wildcard),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import servers and shares_with from servers.json
|
||||||
|
for domain, slots in servers_config.items():
|
||||||
|
shares_with = slots.get("_shares")
|
||||||
|
if shares_with:
|
||||||
|
# Update domain's shares_with field
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0",
|
||||||
|
(shares_with, domain),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for slot_str, server_info in slots.items():
|
||||||
|
if slot_str.startswith("_"):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
slot = int(slot_str)
|
||||||
|
ip = server_info.get("ip", "")
|
||||||
|
http_port = int(server_info.get("http_port", 80))
|
||||||
|
if ip:
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR IGNORE INTO servers (domain, slot, ip, http_port) VALUES (?, ?, ?, ?)",
|
||||||
|
(domain, slot, ip, http_port),
|
||||||
|
)
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
logger.warning("Skipping invalid server entry %s/%s: %s", domain, slot_str, e)
|
||||||
|
|
||||||
|
# Import certificates
|
||||||
|
for cert_domain in cert_domains:
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR IGNORE INTO certificates (domain) VALUES (?)",
|
||||||
|
(cert_domain,),
|
||||||
|
)
|
||||||
|
|
||||||
|
migrated_domains = len(set(d for d, _ in map_entries))
|
||||||
|
migrated_servers = sum(
|
||||||
|
1 for d, slots in servers_config.items()
|
||||||
|
if "_shares" not in slots
|
||||||
|
for s in slots
|
||||||
|
if not s.startswith("_")
|
||||||
|
)
|
||||||
|
migrated_certs = len(cert_domains)
|
||||||
|
logger.info(
|
||||||
|
"Migrated from JSON: %d domain entries, %d servers, %d certificates",
|
||||||
|
migrated_domains, migrated_servers, migrated_certs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rename original files to .bak
|
||||||
|
for path in [SERVERS_FILE, CERTS_FILE]:
|
||||||
|
_backup_file(path)
|
||||||
|
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
logger.error("Migration failed: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _read_file_for_migration(file_path: str) -> str:
|
||||||
|
"""Read a file for migration purposes (local only, no locking needed)."""
|
||||||
|
if REMOTE_MODE:
|
||||||
|
from .ssh_ops import remote_read_file
|
||||||
|
return remote_read_file(file_path)
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
def _backup_file(file_path: str) -> None:
|
||||||
|
"""Rename a file to .bak if it exists."""
|
||||||
|
if REMOTE_MODE:
|
||||||
|
return # Don't rename remote files
|
||||||
|
try:
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
bak_path = f"{file_path}.bak"
|
||||||
|
os.rename(file_path, bak_path)
|
||||||
|
logger.info("Backed up %s -> %s", file_path, bak_path)
|
||||||
|
except OSError as e:
|
||||||
|
logger.warning("Failed to backup %s: %s", file_path, e)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Domain data access ---
|
||||||
|
|
||||||
|
def db_get_map_contents() -> list[tuple[str, str]]:
|
||||||
|
"""Get all domain-to-backend mappings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (domain, backend) tuples including wildcards.
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
cur = conn.execute("SELECT domain, backend FROM domains ORDER BY domain")
|
||||||
|
return [(row["domain"], row["backend"]) for row in cur.fetchall()]
|
||||||
|
|
||||||
|
|
||||||
|
def db_get_domain_backend(domain: str) -> Optional[str]:
|
||||||
|
"""Look up the backend for a specific domain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain: Domain name to look up.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Backend name if found, None otherwise.
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
cur = conn.execute("SELECT backend FROM domains WHERE domain = ?", (domain,))
|
||||||
|
row = cur.fetchone()
|
||||||
|
return row["backend"] if row else None
|
||||||
|
|
||||||
|
|
||||||
|
def db_add_domain(domain: str, backend: str, is_wildcard: bool = False,
|
||||||
|
shares_with: Optional[str] = None) -> None:
|
||||||
|
"""Add a domain to the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain: Domain name (e.g., "example.com" or ".example.com").
|
||||||
|
backend: Backend pool name (e.g., "pool_5").
|
||||||
|
is_wildcard: Whether this is a wildcard entry.
|
||||||
|
shares_with: Parent domain if sharing a pool.
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
conn.execute(
|
||||||
|
"""INSERT INTO domains (domain, backend, is_wildcard, shares_with)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
ON CONFLICT(domain) DO UPDATE SET
|
||||||
|
backend = excluded.backend,
|
||||||
|
shares_with = excluded.shares_with,
|
||||||
|
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""",
|
||||||
|
(domain, backend, 1 if is_wildcard else 0, shares_with),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def db_remove_domain(domain: str) -> None:
|
||||||
|
"""Remove a domain (exact + wildcard) from the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain: Base domain name (without leading dot).
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
conn.execute("DELETE FROM domains WHERE domain = ? OR domain = ?", (domain, f".{domain}"))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def db_find_available_pool() -> Optional[str]:
|
||||||
|
"""Find the first available pool not assigned to any domain.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Pool name (e.g., "pool_5") or None if all pools are in use.
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
cur = conn.execute(
|
||||||
|
"SELECT DISTINCT backend FROM domains WHERE backend LIKE 'pool_%' AND is_wildcard = 0"
|
||||||
|
)
|
||||||
|
used_pools = {row["backend"] for row in cur.fetchall()}
|
||||||
|
|
||||||
|
for i in range(1, POOL_COUNT + 1):
|
||||||
|
pool_name = f"pool_{i}"
|
||||||
|
if pool_name not in used_pools:
|
||||||
|
return pool_name
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def db_get_domains_sharing_pool(pool: str) -> list[str]:
|
||||||
|
"""Get all non-wildcard domains using a specific pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pool: Pool name (e.g., "pool_5").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of domain names.
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
cur = conn.execute(
|
||||||
|
"SELECT domain FROM domains WHERE backend = ? AND is_wildcard = 0",
|
||||||
|
(pool,),
|
||||||
|
)
|
||||||
|
return [row["domain"] for row in cur.fetchall()]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Server data access ---
|
||||||
|
|
||||||
|
def db_load_servers_config() -> dict[str, Any]:
|
||||||
|
"""Load servers configuration in the legacy dict format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary compatible with the old servers.json format:
|
||||||
|
{"domain": {"slot": {"ip": "...", "http_port": N}, "_shares": "parent"}}
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
config: dict[str, Any] = {}
|
||||||
|
|
||||||
|
# Load server entries
|
||||||
|
cur = conn.execute("SELECT domain, slot, ip, http_port FROM servers ORDER BY domain, slot")
|
||||||
|
for row in cur.fetchall():
|
||||||
|
domain = row["domain"]
|
||||||
|
if domain not in config:
|
||||||
|
config[domain] = {}
|
||||||
|
config[domain][str(row["slot"])] = {
|
||||||
|
"ip": row["ip"],
|
||||||
|
"http_port": row["http_port"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Load shared domain references
|
||||||
|
cur = conn.execute(
|
||||||
|
"SELECT domain, shares_with FROM domains WHERE shares_with IS NOT NULL AND is_wildcard = 0"
|
||||||
|
)
|
||||||
|
for row in cur.fetchall():
|
||||||
|
domain = row["domain"]
|
||||||
|
if domain not in config:
|
||||||
|
config[domain] = {}
|
||||||
|
config[domain]["_shares"] = row["shares_with"]
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def db_add_server(domain: str, slot: int, ip: str, http_port: int) -> None:
|
||||||
|
"""Add or update a server entry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain: Domain name.
|
||||||
|
slot: Server slot number.
|
||||||
|
ip: Server IP address.
|
||||||
|
http_port: HTTP port.
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
conn.execute(
|
||||||
|
"""INSERT INTO servers (domain, slot, ip, http_port)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
ON CONFLICT(domain, slot) DO UPDATE SET
|
||||||
|
ip = excluded.ip,
|
||||||
|
http_port = excluded.http_port,
|
||||||
|
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""",
|
||||||
|
(domain, slot, ip, http_port),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def db_remove_server(domain: str, slot: int) -> None:
|
||||||
|
"""Remove a server entry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain: Domain name.
|
||||||
|
slot: Server slot number.
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
conn.execute("DELETE FROM servers WHERE domain = ? AND slot = ?", (domain, slot))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def db_remove_domain_servers(domain: str) -> None:
|
||||||
|
"""Remove all server entries for a domain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain: Domain name.
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
conn.execute("DELETE FROM servers WHERE domain = ?", (domain,))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def db_add_shared_domain(domain: str, shares_with: str) -> None:
|
||||||
|
"""Mark a domain as sharing a pool with another domain.
|
||||||
|
|
||||||
|
Updates the shares_with field in the domains table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain: The new domain.
|
||||||
|
shares_with: The existing domain to share with.
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0",
|
||||||
|
(shares_with, domain),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def db_get_shared_domain(domain: str) -> Optional[str]:
|
||||||
|
"""Get the parent domain that this domain shares a pool with.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain: Domain name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parent domain name if sharing, None otherwise.
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
cur = conn.execute(
|
||||||
|
"SELECT shares_with FROM domains WHERE domain = ? AND is_wildcard = 0",
|
||||||
|
(domain,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
if row and row["shares_with"]:
|
||||||
|
return row["shares_with"]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def db_is_shared_domain(domain: str) -> bool:
|
||||||
|
"""Check if a domain is sharing another domain's pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain: Domain name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if domain has shares_with reference.
|
||||||
|
"""
|
||||||
|
return db_get_shared_domain(domain) is not None
|
||||||
|
|
||||||
|
|
||||||
|
# --- Certificate data access ---
|
||||||
|
|
||||||
|
def db_load_certs() -> list[str]:
|
||||||
|
"""Load certificate domain list.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sorted list of domain names with certificates.
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
cur = conn.execute("SELECT domain FROM certificates ORDER BY domain")
|
||||||
|
return [row["domain"] for row in cur.fetchall()]
|
||||||
|
|
||||||
|
|
||||||
|
def db_add_cert(domain: str) -> None:
|
||||||
|
"""Add a domain to the certificate list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain: Domain name.
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR IGNORE INTO certificates (domain) VALUES (?)",
|
||||||
|
(domain,),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def db_remove_cert(domain: str) -> None:
|
||||||
|
"""Remove a domain from the certificate list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain: Domain name.
|
||||||
|
"""
|
||||||
|
conn = get_connection()
|
||||||
|
conn.execute("DELETE FROM certificates WHERE domain = ?", (domain,))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# --- Map file synchronization ---
|
||||||
|
|
||||||
|
def sync_map_files() -> None:
|
||||||
|
"""Regenerate HAProxy map files from database.
|
||||||
|
|
||||||
|
Writes domains.map (exact matches) and wildcards.map (wildcard entries)
|
||||||
|
from the current database state. Uses atomic writes.
|
||||||
|
"""
|
||||||
|
from .file_ops import atomic_write_file
|
||||||
|
|
||||||
|
conn = get_connection()
|
||||||
|
|
||||||
|
# Fetch exact domains
|
||||||
|
cur = conn.execute(
|
||||||
|
"SELECT domain, backend FROM domains WHERE is_wildcard = 0 ORDER BY domain"
|
||||||
|
)
|
||||||
|
exact_entries = cur.fetchall()
|
||||||
|
|
||||||
|
# Fetch wildcard domains
|
||||||
|
cur = conn.execute(
|
||||||
|
"SELECT domain, backend FROM domains WHERE is_wildcard = 1 ORDER BY domain"
|
||||||
|
)
|
||||||
|
wildcard_entries = cur.fetchall()
|
||||||
|
|
||||||
|
# Write exact domains map
|
||||||
|
exact_lines = [
|
||||||
|
"# Exact Domain to Backend mapping (for map_str)\n",
|
||||||
|
"# Format: domain backend_name\n",
|
||||||
|
"# Uses ebtree for O(log n) lookup performance\n\n",
|
||||||
|
]
|
||||||
|
for row in exact_entries:
|
||||||
|
exact_lines.append(f"{row['domain']} {row['backend']}\n")
|
||||||
|
atomic_write_file(MAP_FILE, "".join(exact_lines))
|
||||||
|
|
||||||
|
# Write wildcards map
|
||||||
|
wildcard_lines = [
|
||||||
|
"# Wildcard Domain to Backend mapping (for map_dom)\n",
|
||||||
|
"# Format: .domain.com backend_name (matches *.domain.com)\n",
|
||||||
|
"# Uses map_dom for suffix matching\n\n",
|
||||||
|
]
|
||||||
|
for row in wildcard_entries:
|
||||||
|
wildcard_lines.append(f"{row['domain']} {row['backend']}\n")
|
||||||
|
atomic_write_file(WILDCARDS_MAP_FILE, "".join(wildcard_lines))
|
||||||
|
|
||||||
|
logger.debug("Map files synced: %d exact, %d wildcard entries",
|
||||||
|
len(exact_entries), len(wildcard_entries))
|
||||||
@@ -1,7 +1,11 @@
|
|||||||
"""File I/O operations for HAProxy MCP Server."""
|
"""File I/O operations for HAProxy MCP Server.
|
||||||
|
|
||||||
|
Most data access is now delegated to db.py (SQLite).
|
||||||
|
This module retains atomic file writes, map file I/O for HAProxy,
|
||||||
|
and provides backward-compatible function signatures.
|
||||||
|
"""
|
||||||
|
|
||||||
import fcntl
|
import fcntl
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
@@ -10,8 +14,6 @@ from typing import Any, Generator, Optional
|
|||||||
from .config import (
|
from .config import (
|
||||||
MAP_FILE,
|
MAP_FILE,
|
||||||
WILDCARDS_MAP_FILE,
|
WILDCARDS_MAP_FILE,
|
||||||
SERVERS_FILE,
|
|
||||||
CERTS_FILE,
|
|
||||||
REMOTE_MODE,
|
REMOTE_MODE,
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
@@ -138,16 +140,13 @@ def _read_file(file_path: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def get_map_contents() -> list[tuple[str, str]]:
|
def get_map_contents() -> list[tuple[str, str]]:
|
||||||
"""Read both domains.map and wildcards.map and return combined entries.
|
"""Get all domain-to-backend mappings from SQLite.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of (domain, backend) tuples from both map files
|
List of (domain, backend) tuples including wildcards.
|
||||||
"""
|
"""
|
||||||
# Read exact domains
|
from .db import db_get_map_contents
|
||||||
entries = _read_map_file(MAP_FILE)
|
return db_get_map_contents()
|
||||||
# Read wildcards and append
|
|
||||||
entries.extend(_read_map_file(WILDCARDS_MAP_FILE))
|
|
||||||
return entries
|
|
||||||
|
|
||||||
|
|
||||||
def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
|
def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
|
||||||
@@ -170,44 +169,21 @@ def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str
|
|||||||
|
|
||||||
|
|
||||||
def save_map_file(entries: list[tuple[str, str]]) -> None:
|
def save_map_file(entries: list[tuple[str, str]]) -> None:
|
||||||
"""Save domain-to-backend entries to map files.
|
"""Sync map files from the database.
|
||||||
|
|
||||||
Splits entries into two files for 2-stage routing:
|
Regenerates domains.map and wildcards.map from the current
|
||||||
- domains.map: Exact matches (map_str, O(log n))
|
database state. The entries parameter is ignored (kept for
|
||||||
- wildcards.map: Wildcard entries starting with "." (map_dom, O(n))
|
backward compatibility during transition).
|
||||||
|
|
||||||
Args:
|
|
||||||
entries: List of (domain, backend) tuples.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
IOError: If map files cannot be written.
|
IOError: If map files cannot be written.
|
||||||
"""
|
"""
|
||||||
# Split into exact and wildcard entries
|
from .db import sync_map_files
|
||||||
exact_entries, wildcard_entries = split_domain_entries(entries)
|
sync_map_files()
|
||||||
|
|
||||||
# Save exact domains (for map_str - fast O(log n) lookup)
|
|
||||||
exact_lines = [
|
|
||||||
"# Exact Domain to Backend mapping (for map_str)\n",
|
|
||||||
"# Format: domain backend_name\n",
|
|
||||||
"# Uses ebtree for O(log n) lookup performance\n\n",
|
|
||||||
]
|
|
||||||
for domain, backend in sorted(exact_entries):
|
|
||||||
exact_lines.append(f"{domain} {backend}\n")
|
|
||||||
atomic_write_file(MAP_FILE, "".join(exact_lines))
|
|
||||||
|
|
||||||
# Save wildcards (for map_dom - O(n) but small set)
|
|
||||||
wildcard_lines = [
|
|
||||||
"# Wildcard Domain to Backend mapping (for map_dom)\n",
|
|
||||||
"# Format: .domain.com backend_name (matches *.domain.com)\n",
|
|
||||||
"# Uses map_dom for suffix matching\n\n",
|
|
||||||
]
|
|
||||||
for domain, backend in sorted(wildcard_entries):
|
|
||||||
wildcard_lines.append(f"{domain} {backend}\n")
|
|
||||||
atomic_write_file(WILDCARDS_MAP_FILE, "".join(wildcard_lines))
|
|
||||||
|
|
||||||
|
|
||||||
def get_domain_backend(domain: str) -> Optional[str]:
|
def get_domain_backend(domain: str) -> Optional[str]:
|
||||||
"""Look up the backend for a domain from domains.map.
|
"""Look up the backend for a domain from SQLite (O(1)).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
domain: The domain to look up
|
domain: The domain to look up
|
||||||
@@ -215,10 +191,8 @@ def get_domain_backend(domain: str) -> Optional[str]:
|
|||||||
Returns:
|
Returns:
|
||||||
Backend name if found, None otherwise
|
Backend name if found, None otherwise
|
||||||
"""
|
"""
|
||||||
for map_domain, backend in get_map_contents():
|
from .db import db_get_domain_backend
|
||||||
if map_domain == domain:
|
return db_get_domain_backend(domain)
|
||||||
return backend
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def is_legacy_backend(backend: str) -> bool:
|
def is_legacy_backend(backend: str) -> bool:
|
||||||
@@ -273,34 +247,17 @@ def get_backend_and_prefix(domain: str) -> tuple[str, str]:
|
|||||||
|
|
||||||
|
|
||||||
def load_servers_config() -> dict[str, Any]:
|
def load_servers_config() -> dict[str, Any]:
|
||||||
"""Load servers configuration from JSON file.
|
"""Load servers configuration from SQLite.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with server configurations
|
Dictionary with server configurations (legacy format compatible).
|
||||||
"""
|
"""
|
||||||
try:
|
from .db import db_load_servers_config
|
||||||
content = _read_file(SERVERS_FILE)
|
return db_load_servers_config()
|
||||||
return json.loads(content)
|
|
||||||
except FileNotFoundError:
|
|
||||||
return {}
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.warning("Corrupt config file %s: %s", SERVERS_FILE, e)
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def save_servers_config(config: dict[str, Any]) -> None:
|
|
||||||
"""Save servers configuration to JSON file atomically.
|
|
||||||
|
|
||||||
Uses temp file + rename for atomic write to prevent race conditions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: Dictionary with server configurations
|
|
||||||
"""
|
|
||||||
atomic_write_file(SERVERS_FILE, json.dumps(config, indent=2))
|
|
||||||
|
|
||||||
|
|
||||||
def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> None:
|
def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> None:
|
||||||
"""Add server configuration to persistent storage with file locking.
|
"""Add server configuration to persistent storage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
domain: Domain name
|
domain: Domain name
|
||||||
@@ -308,41 +265,29 @@ def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> Non
|
|||||||
ip: Server IP address
|
ip: Server IP address
|
||||||
http_port: HTTP port
|
http_port: HTTP port
|
||||||
"""
|
"""
|
||||||
with file_lock(f"{SERVERS_FILE}.lock"):
|
from .db import db_add_server
|
||||||
config = load_servers_config()
|
db_add_server(domain, slot, ip, http_port)
|
||||||
if domain not in config:
|
|
||||||
config[domain] = {}
|
|
||||||
config[domain][str(slot)] = {"ip": ip, "http_port": http_port}
|
|
||||||
save_servers_config(config)
|
|
||||||
|
|
||||||
|
|
||||||
def remove_server_from_config(domain: str, slot: int) -> None:
|
def remove_server_from_config(domain: str, slot: int) -> None:
|
||||||
"""Remove server configuration from persistent storage with file locking.
|
"""Remove server configuration from persistent storage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
domain: Domain name
|
domain: Domain name
|
||||||
slot: Server slot to remove
|
slot: Server slot to remove
|
||||||
"""
|
"""
|
||||||
with file_lock(f"{SERVERS_FILE}.lock"):
|
from .db import db_remove_server
|
||||||
config = load_servers_config()
|
db_remove_server(domain, slot)
|
||||||
if domain in config and str(slot) in config[domain]:
|
|
||||||
del config[domain][str(slot)]
|
|
||||||
if not config[domain]:
|
|
||||||
del config[domain]
|
|
||||||
save_servers_config(config)
|
|
||||||
|
|
||||||
|
|
||||||
def remove_domain_from_config(domain: str) -> None:
|
def remove_domain_from_config(domain: str) -> None:
|
||||||
"""Remove domain from persistent config with file locking.
|
"""Remove domain from persistent config (servers + domain entry).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
domain: Domain name to remove
|
domain: Domain name to remove
|
||||||
"""
|
"""
|
||||||
with file_lock(f"{SERVERS_FILE}.lock"):
|
from .db import db_remove_domain_servers
|
||||||
config = load_servers_config()
|
db_remove_domain_servers(domain)
|
||||||
if domain in config:
|
|
||||||
del config[domain]
|
|
||||||
save_servers_config(config)
|
|
||||||
|
|
||||||
|
|
||||||
def get_shared_domain(domain: str) -> Optional[str]:
|
def get_shared_domain(domain: str) -> Optional[str]:
|
||||||
@@ -354,9 +299,8 @@ def get_shared_domain(domain: str) -> Optional[str]:
|
|||||||
Returns:
|
Returns:
|
||||||
Parent domain name if sharing, None otherwise
|
Parent domain name if sharing, None otherwise
|
||||||
"""
|
"""
|
||||||
config = load_servers_config()
|
from .db import db_get_shared_domain
|
||||||
domain_config = config.get(domain, {})
|
return db_get_shared_domain(domain)
|
||||||
return domain_config.get("_shares")
|
|
||||||
|
|
||||||
|
|
||||||
def add_shared_domain_to_config(domain: str, shares_with: str) -> None:
|
def add_shared_domain_to_config(domain: str, shares_with: str) -> None:
|
||||||
@@ -366,10 +310,8 @@ def add_shared_domain_to_config(domain: str, shares_with: str) -> None:
|
|||||||
domain: New domain name
|
domain: New domain name
|
||||||
shares_with: Existing domain to share pool with
|
shares_with: Existing domain to share pool with
|
||||||
"""
|
"""
|
||||||
with file_lock(f"{SERVERS_FILE}.lock"):
|
from .db import db_add_shared_domain
|
||||||
config = load_servers_config()
|
db_add_shared_domain(domain, shares_with)
|
||||||
config[domain] = {"_shares": shares_with}
|
|
||||||
save_servers_config(config)
|
|
||||||
|
|
||||||
|
|
||||||
def get_domains_sharing_pool(pool: str) -> list[str]:
|
def get_domains_sharing_pool(pool: str) -> list[str]:
|
||||||
@@ -381,11 +323,8 @@ def get_domains_sharing_pool(pool: str) -> list[str]:
|
|||||||
Returns:
|
Returns:
|
||||||
List of domain names using this pool
|
List of domain names using this pool
|
||||||
"""
|
"""
|
||||||
domains = []
|
from .db import db_get_domains_sharing_pool
|
||||||
for domain, backend in get_map_contents():
|
return db_get_domains_sharing_pool(pool)
|
||||||
if backend == pool and not domain.startswith("."):
|
|
||||||
domains.append(domain)
|
|
||||||
return domains
|
|
||||||
|
|
||||||
|
|
||||||
def is_shared_domain(domain: str) -> bool:
|
def is_shared_domain(domain: str) -> bool:
|
||||||
@@ -397,37 +336,20 @@ def is_shared_domain(domain: str) -> bool:
|
|||||||
Returns:
|
Returns:
|
||||||
True if domain has _shares reference, False otherwise
|
True if domain has _shares reference, False otherwise
|
||||||
"""
|
"""
|
||||||
config = load_servers_config()
|
from .db import db_is_shared_domain
|
||||||
domain_config = config.get(domain, {})
|
return db_is_shared_domain(domain)
|
||||||
return "_shares" in domain_config
|
|
||||||
|
|
||||||
|
|
||||||
# Certificate configuration functions
|
# Certificate configuration functions
|
||||||
|
|
||||||
def load_certs_config() -> list[str]:
|
def load_certs_config() -> list[str]:
|
||||||
"""Load certificate domain list from JSON file.
|
"""Load certificate domain list from SQLite.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of domain names
|
Sorted list of domain names.
|
||||||
"""
|
"""
|
||||||
try:
|
from .db import db_load_certs
|
||||||
content = _read_file(CERTS_FILE)
|
return db_load_certs()
|
||||||
data = json.loads(content)
|
|
||||||
return data.get("domains", [])
|
|
||||||
except FileNotFoundError:
|
|
||||||
return []
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.warning("Corrupt certificates config %s: %s", CERTS_FILE, e)
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def save_certs_config(domains: list[str]) -> None:
|
|
||||||
"""Save certificate domain list to JSON file atomically.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
domains: List of domain names
|
|
||||||
"""
|
|
||||||
atomic_write_file(CERTS_FILE, json.dumps({"domains": sorted(domains)}, indent=2))
|
|
||||||
|
|
||||||
|
|
||||||
def add_cert_to_config(domain: str) -> None:
|
def add_cert_to_config(domain: str) -> None:
|
||||||
@@ -436,11 +358,8 @@ def add_cert_to_config(domain: str) -> None:
|
|||||||
Args:
|
Args:
|
||||||
domain: Domain name to add
|
domain: Domain name to add
|
||||||
"""
|
"""
|
||||||
with file_lock(f"{CERTS_FILE}.lock"):
|
from .db import db_add_cert
|
||||||
domains = load_certs_config()
|
db_add_cert(domain)
|
||||||
if domain not in domains:
|
|
||||||
domains.append(domain)
|
|
||||||
save_certs_config(domains)
|
|
||||||
|
|
||||||
|
|
||||||
def remove_cert_from_config(domain: str) -> None:
|
def remove_cert_from_config(domain: str) -> None:
|
||||||
@@ -449,8 +368,45 @@ def remove_cert_from_config(domain: str) -> None:
|
|||||||
Args:
|
Args:
|
||||||
domain: Domain name to remove
|
domain: Domain name to remove
|
||||||
"""
|
"""
|
||||||
with file_lock(f"{CERTS_FILE}.lock"):
|
from .db import db_remove_cert
|
||||||
domains = load_certs_config()
|
db_remove_cert(domain)
|
||||||
if domain in domains:
|
|
||||||
domains.remove(domain)
|
|
||||||
save_certs_config(domains)
|
# Domain map helper functions (used by domains.py)
|
||||||
|
|
||||||
|
def add_domain_to_map(domain: str, backend: str, is_wildcard: bool = False,
|
||||||
|
shares_with: Optional[str] = None) -> None:
|
||||||
|
"""Add a domain to SQLite and sync map files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain: Domain name (e.g., "example.com").
|
||||||
|
backend: Backend pool name (e.g., "pool_5").
|
||||||
|
is_wildcard: Whether this is a wildcard entry.
|
||||||
|
shares_with: Parent domain if sharing a pool.
|
||||||
|
"""
|
||||||
|
from .db import db_add_domain, sync_map_files
|
||||||
|
db_add_domain(domain, backend, is_wildcard, shares_with)
|
||||||
|
sync_map_files()
|
||||||
|
|
||||||
|
|
||||||
|
def remove_domain_from_map(domain: str) -> None:
|
||||||
|
"""Remove a domain (exact + wildcard) from SQLite and sync map files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain: Base domain name (without leading dot).
|
||||||
|
"""
|
||||||
|
from .db import db_remove_domain, sync_map_files
|
||||||
|
db_remove_domain(domain)
|
||||||
|
sync_map_files()
|
||||||
|
|
||||||
|
|
||||||
|
def find_available_pool() -> Optional[str]:
|
||||||
|
"""Find the first available pool not assigned to any domain.
|
||||||
|
|
||||||
|
Uses SQLite query for O(1) lookup vs previous O(n) list scan.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Pool name (e.g., "pool_5") or None if all pools are in use.
|
||||||
|
"""
|
||||||
|
from .db import db_find_available_pool
|
||||||
|
return db_find_available_pool()
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import socket
|
import socket
|
||||||
import select
|
import select
|
||||||
|
import subprocess
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from .config import (
|
from .config import (
|
||||||
@@ -161,7 +162,7 @@ def reload_haproxy() -> tuple[bool, str]:
|
|||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
return False, f"Reload failed: {result.stderr}"
|
return False, f"Reload failed: {result.stderr}"
|
||||||
return True, "OK"
|
return True, "OK"
|
||||||
except TimeoutError:
|
except (TimeoutError, subprocess.TimeoutExpired):
|
||||||
return False, f"Command timed out after {SUBPROCESS_TIMEOUT} seconds"
|
return False, f"Command timed out after {SUBPROCESS_TIMEOUT} seconds"
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
return False, "ssh/podman command not found"
|
return False, "ssh/podman command not found"
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ Environment Variables:
|
|||||||
from mcp.server.fastmcp import FastMCP
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
from .config import MCP_HOST, MCP_PORT
|
from .config import MCP_HOST, MCP_PORT
|
||||||
|
from .db import init_db
|
||||||
from .tools import register_all_tools
|
from .tools import register_all_tools
|
||||||
from .tools.configuration import startup_restore
|
from .tools.configuration import startup_restore
|
||||||
|
|
||||||
@@ -32,5 +33,6 @@ register_all_tools(mcp)
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
init_db()
|
||||||
startup_restore()
|
startup_restore()
|
||||||
mcp.run(transport="streamable-http")
|
mcp.run(transport="streamable-http")
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Certificate management tools for HAProxy MCP Server."""
|
"""Certificate management tools for HAProxy MCP Server."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
@@ -152,7 +153,7 @@ def _haproxy_list_certs_impl() -> str:
|
|||||||
certs.append(f"• {domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}")
|
certs.append(f"• {domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}")
|
||||||
|
|
||||||
return "\n\n".join(certs) if certs else "No certificates found"
|
return "\n\n".join(certs) if certs else "No certificates found"
|
||||||
except TimeoutError:
|
except (TimeoutError, subprocess.TimeoutExpired):
|
||||||
return "Error: Command timed out"
|
return "Error: Command timed out"
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
return "Error: acme.sh not found"
|
return "Error: acme.sh not found"
|
||||||
@@ -203,7 +204,7 @@ def _haproxy_cert_info_impl(domain: str) -> str:
|
|||||||
result.stdout.strip()
|
result.stdout.strip()
|
||||||
]
|
]
|
||||||
return "\n".join(info)
|
return "\n".join(info)
|
||||||
except TimeoutError:
|
except (TimeoutError, subprocess.TimeoutExpired):
|
||||||
return "Error: Command timed out"
|
return "Error: Command timed out"
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logger.error("Error getting certificate info for %s: %s", domain, e)
|
logger.error("Error getting certificate info for %s: %s", domain, e)
|
||||||
@@ -250,7 +251,7 @@ def _haproxy_issue_cert_impl(domain: str, wildcard: bool) -> str:
|
|||||||
else:
|
else:
|
||||||
return f"Certificate issued but PEM file not created. Check {host_path}"
|
return f"Certificate issued but PEM file not created. Check {host_path}"
|
||||||
|
|
||||||
except TimeoutError:
|
except (TimeoutError, subprocess.TimeoutExpired):
|
||||||
return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s"
|
return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s"
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logger.error("Error issuing certificate for %s: %s", domain, e)
|
logger.error("Error issuing certificate for %s: %s", domain, e)
|
||||||
@@ -289,7 +290,7 @@ def _haproxy_renew_cert_impl(domain: str, force: bool) -> str:
|
|||||||
else:
|
else:
|
||||||
return f"Error renewing certificate:\n{output}"
|
return f"Error renewing certificate:\n{output}"
|
||||||
|
|
||||||
except TimeoutError:
|
except (TimeoutError, subprocess.TimeoutExpired):
|
||||||
return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s"
|
return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s"
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
return "Error: acme.sh not found"
|
return "Error: acme.sh not found"
|
||||||
@@ -323,7 +324,7 @@ def _haproxy_renew_all_certs_impl() -> str:
|
|||||||
else:
|
else:
|
||||||
return "Renewal check completed"
|
return "Renewal check completed"
|
||||||
|
|
||||||
except TimeoutError:
|
except (TimeoutError, subprocess.TimeoutExpired):
|
||||||
return "Error: Renewal cron timed out"
|
return "Error: Renewal cron timed out"
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
return "Error: acme.sh not found"
|
return "Error: acme.sh not found"
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Configuration management tools for HAProxy MCP Server."""
|
"""Configuration management tools for HAProxy MCP Server."""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from ..config import (
|
from ..config import (
|
||||||
@@ -177,7 +178,7 @@ def register_config_tools(mcp):
|
|||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
return "Configuration is valid"
|
return "Configuration is valid"
|
||||||
return f"Configuration errors:\n{result.stderr}"
|
return f"Configuration errors:\n{result.stderr}"
|
||||||
except TimeoutError:
|
except (TimeoutError, subprocess.TimeoutExpired):
|
||||||
return f"Error: Command timed out after {SUBPROCESS_TIMEOUT} seconds"
|
return f"Error: Command timed out after {SUBPROCESS_TIMEOUT} seconds"
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
return "Error: ssh/podman command not found"
|
return "Error: ssh/podman command not found"
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from typing import Annotated, Optional
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from ..config import (
|
from ..config import (
|
||||||
MAP_FILE,
|
|
||||||
MAP_FILE_CONTAINER,
|
MAP_FILE_CONTAINER,
|
||||||
WILDCARDS_MAP_FILE_CONTAINER,
|
WILDCARDS_MAP_FILE_CONTAINER,
|
||||||
POOL_COUNT,
|
POOL_COUNT,
|
||||||
@@ -22,7 +21,6 @@ from ..validation import validate_domain, validate_ip, validate_port_int
|
|||||||
from ..haproxy_client import haproxy_cmd
|
from ..haproxy_client import haproxy_cmd
|
||||||
from ..file_ops import (
|
from ..file_ops import (
|
||||||
get_map_contents,
|
get_map_contents,
|
||||||
save_map_file,
|
|
||||||
get_domain_backend,
|
get_domain_backend,
|
||||||
is_legacy_backend,
|
is_legacy_backend,
|
||||||
add_server_to_config,
|
add_server_to_config,
|
||||||
@@ -31,30 +29,13 @@ from ..file_ops import (
|
|||||||
add_shared_domain_to_config,
|
add_shared_domain_to_config,
|
||||||
get_domains_sharing_pool,
|
get_domains_sharing_pool,
|
||||||
is_shared_domain,
|
is_shared_domain,
|
||||||
|
add_domain_to_map,
|
||||||
|
remove_domain_from_map,
|
||||||
|
find_available_pool,
|
||||||
)
|
)
|
||||||
from ..utils import parse_servers_state, disable_server_slot
|
from ..utils import parse_servers_state, disable_server_slot
|
||||||
|
|
||||||
|
|
||||||
def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]:
|
|
||||||
"""Find an available pool backend from the pool list.
|
|
||||||
|
|
||||||
Iterates through pool_1 to pool_N and returns the first pool
|
|
||||||
that is not currently in use.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
entries: List of (domain, backend) tuples from the map file.
|
|
||||||
used_pools: Set of pool names already in use.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Available pool name (e.g., "pool_5") or None if all pools are in use.
|
|
||||||
"""
|
|
||||||
for i in range(1, POOL_COUNT + 1):
|
|
||||||
pool_name = f"pool_{i}"
|
|
||||||
if pool_name not in used_pools:
|
|
||||||
return pool_name
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _check_subdomain(domain: str, registered_domains: set[str]) -> tuple[bool, Optional[str]]:
|
def _check_subdomain(domain: str, registered_domains: set[str]) -> tuple[bool, Optional[str]]:
|
||||||
"""Check if a domain is a subdomain of an existing registered domain.
|
"""Check if a domain is a subdomain of an existing registered domain.
|
||||||
|
|
||||||
@@ -95,24 +76,19 @@ def _update_haproxy_maps(domain: str, pool: str, is_subdomain: bool) -> None:
|
|||||||
haproxy_cmd(f"add map {WILDCARDS_MAP_FILE_CONTAINER} .{domain} {pool}")
|
haproxy_cmd(f"add map {WILDCARDS_MAP_FILE_CONTAINER} .{domain} {pool}")
|
||||||
|
|
||||||
|
|
||||||
def _rollback_domain_addition(
|
def _rollback_domain_addition(domain: str) -> None:
|
||||||
domain: str,
|
"""Rollback a failed domain addition by removing from SQLite + map files.
|
||||||
entries: list[tuple[str, str]]
|
|
||||||
) -> None:
|
|
||||||
"""Rollback a failed domain addition by removing entries from map file.
|
|
||||||
|
|
||||||
Called when HAProxy Runtime API update fails after the map file
|
Called when HAProxy Runtime API update fails after the domain
|
||||||
has already been saved.
|
has already been saved to the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
domain: Domain name that was added.
|
domain: Domain name that was added.
|
||||||
entries: Current list of map entries to rollback from.
|
|
||||||
"""
|
"""
|
||||||
rollback_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
|
|
||||||
try:
|
try:
|
||||||
save_map_file(rollback_entries)
|
remove_domain_from_map(domain)
|
||||||
except IOError:
|
except (IOError, Exception):
|
||||||
logger.error("Failed to rollback map file after HAProxy error")
|
logger.error("Failed to rollback domain %s after HAProxy error", domain)
|
||||||
|
|
||||||
|
|
||||||
def _file_exists(path: str) -> bool:
|
def _file_exists(path: str) -> bool:
|
||||||
@@ -242,93 +218,86 @@ def register_domain_tools(mcp):
|
|||||||
if share_with and ip:
|
if share_with and ip:
|
||||||
return "Error: Cannot specify both ip and share_with (shared domains use existing servers)"
|
return "Error: Cannot specify both ip and share_with (shared domains use existing servers)"
|
||||||
|
|
||||||
# Use file locking for the entire pool allocation operation
|
# Read current entries for existence check and subdomain detection
|
||||||
from ..file_ops import file_lock
|
entries = get_map_contents()
|
||||||
with file_lock(f"{MAP_FILE}.lock"):
|
|
||||||
# Read map contents once for both existence check and pool lookup
|
|
||||||
entries = get_map_contents()
|
|
||||||
|
|
||||||
# Check if domain already exists (using cached entries)
|
# Check if domain already exists
|
||||||
for domain_entry, backend in entries:
|
for domain_entry, backend in entries:
|
||||||
if domain_entry == domain:
|
if domain_entry == domain:
|
||||||
return f"Error: Domain {domain} already exists (mapped to {backend})"
|
return f"Error: Domain {domain} already exists (mapped to {backend})"
|
||||||
|
|
||||||
# Build used pools and registered domains sets
|
# Build registered domains set for subdomain check
|
||||||
used_pools: set[str] = set()
|
registered_domains: set[str] = set()
|
||||||
registered_domains: set[str] = set()
|
for entry_domain, _ in entries:
|
||||||
for entry_domain, backend in entries:
|
if not entry_domain.startswith("."):
|
||||||
if backend.startswith("pool_"):
|
registered_domains.add(entry_domain)
|
||||||
used_pools.add(backend)
|
|
||||||
if not entry_domain.startswith("."):
|
|
||||||
registered_domains.add(entry_domain)
|
|
||||||
|
|
||||||
# Handle share_with: reuse existing domain's pool
|
# Handle share_with: reuse existing domain's pool
|
||||||
if share_with:
|
if share_with:
|
||||||
share_backend = get_domain_backend(share_with)
|
share_backend = get_domain_backend(share_with)
|
||||||
if not share_backend:
|
if not share_backend:
|
||||||
return f"Error: Domain {share_with} not found"
|
return f"Error: Domain {share_with} not found"
|
||||||
if not share_backend.startswith("pool_"):
|
if not share_backend.startswith("pool_"):
|
||||||
return f"Error: Cannot share with legacy backend {share_backend}"
|
return f"Error: Cannot share with legacy backend {share_backend}"
|
||||||
pool = share_backend
|
pool = share_backend
|
||||||
else:
|
else:
|
||||||
# Find available pool
|
# Find available pool (SQLite query, O(1))
|
||||||
pool = _find_available_pool(entries, used_pools)
|
pool = find_available_pool()
|
||||||
if not pool:
|
if not pool:
|
||||||
return f"Error: All {POOL_COUNT} pool backends are in use"
|
return f"Error: All {POOL_COUNT} pool backends are in use"
|
||||||
|
|
||||||
# Check if this is a subdomain of an existing domain
|
# Check if this is a subdomain of an existing domain
|
||||||
is_subdomain, parent_domain = _check_subdomain(domain, registered_domains)
|
is_subdomain, parent_domain = _check_subdomain(domain, registered_domains)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Save to SQLite + sync map files (atomic via SQLite transaction)
|
||||||
try:
|
try:
|
||||||
# Save to disk first (atomic write for persistence)
|
add_domain_to_map(domain, pool)
|
||||||
entries.append((domain, pool))
|
|
||||||
if not is_subdomain:
|
if not is_subdomain:
|
||||||
entries.append((f".{domain}", pool))
|
add_domain_to_map(f".{domain}", pool, is_wildcard=True)
|
||||||
try:
|
except (IOError, Exception) as e:
|
||||||
save_map_file(entries)
|
return f"Error: Failed to save domain: {e}"
|
||||||
except IOError as e:
|
|
||||||
return f"Error: Failed to save map file: {e}"
|
|
||||||
|
|
||||||
# Update HAProxy maps via Runtime API
|
|
||||||
try:
|
|
||||||
_update_haproxy_maps(domain, pool, is_subdomain)
|
|
||||||
except HaproxyError as e:
|
|
||||||
_rollback_domain_addition(domain, entries)
|
|
||||||
return f"Error: Failed to update HAProxy map: {e}"
|
|
||||||
|
|
||||||
# Handle server configuration based on mode
|
|
||||||
if share_with:
|
|
||||||
# Save shared domain reference
|
|
||||||
add_shared_domain_to_config(domain, share_with)
|
|
||||||
result = f"Domain {domain} added, sharing pool {pool} with {share_with}"
|
|
||||||
elif ip:
|
|
||||||
# Add server to slot 1
|
|
||||||
add_server_to_config(domain, 1, ip, http_port)
|
|
||||||
try:
|
|
||||||
server = f"{pool}_1"
|
|
||||||
haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}")
|
|
||||||
haproxy_cmd(f"set server {pool}/{server} state ready")
|
|
||||||
except HaproxyError as e:
|
|
||||||
remove_server_from_config(domain, 1)
|
|
||||||
return f"Domain {domain} added to {pool} but server config failed: {e}"
|
|
||||||
result = f"Domain {domain} added to {pool} with server {ip}:{http_port}"
|
|
||||||
else:
|
|
||||||
result = f"Domain {domain} added to {pool} (no servers configured)"
|
|
||||||
|
|
||||||
if is_subdomain:
|
|
||||||
result += f" (subdomain of {parent_domain}, no wildcard)"
|
|
||||||
|
|
||||||
# Check certificate coverage
|
|
||||||
cert_covered, cert_info = check_certificate_coverage(domain)
|
|
||||||
if cert_covered:
|
|
||||||
result += f"\nSSL: Using certificate {cert_info}"
|
|
||||||
else:
|
|
||||||
result += f"\nSSL: No certificate found. Use haproxy_issue_cert(\"{domain}\") to issue one."
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
# Update HAProxy maps via Runtime API
|
||||||
|
try:
|
||||||
|
_update_haproxy_maps(domain, pool, is_subdomain)
|
||||||
except HaproxyError as e:
|
except HaproxyError as e:
|
||||||
return f"Error: {e}"
|
_rollback_domain_addition(domain)
|
||||||
|
return f"Error: Failed to update HAProxy map: {e}"
|
||||||
|
|
||||||
|
# Handle server configuration based on mode
|
||||||
|
if share_with:
|
||||||
|
# Save shared domain reference
|
||||||
|
add_shared_domain_to_config(domain, share_with)
|
||||||
|
result = f"Domain {domain} added, sharing pool {pool} with {share_with}"
|
||||||
|
elif ip:
|
||||||
|
# Add server to slot 1
|
||||||
|
add_server_to_config(domain, 1, ip, http_port)
|
||||||
|
try:
|
||||||
|
server = f"{pool}_1"
|
||||||
|
haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}")
|
||||||
|
haproxy_cmd(f"set server {pool}/{server} state ready")
|
||||||
|
except HaproxyError as e:
|
||||||
|
remove_server_from_config(domain, 1)
|
||||||
|
return f"Domain {domain} added to {pool} but server config failed: {e}"
|
||||||
|
result = f"Domain {domain} added to {pool} with server {ip}:{http_port}"
|
||||||
|
else:
|
||||||
|
result = f"Domain {domain} added to {pool} (no servers configured)"
|
||||||
|
|
||||||
|
if is_subdomain:
|
||||||
|
result += f" (subdomain of {parent_domain}, no wildcard)"
|
||||||
|
|
||||||
|
# Check certificate coverage
|
||||||
|
cert_covered, cert_info = check_certificate_coverage(domain)
|
||||||
|
if cert_covered:
|
||||||
|
result += f"\nSSL: Using certificate {cert_info}"
|
||||||
|
else:
|
||||||
|
result += f"\nSSL: No certificate found. Use haproxy_issue_cert(\"{domain}\") to issue one."
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HaproxyError as e:
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def haproxy_remove_domain(
|
def haproxy_remove_domain(
|
||||||
@@ -355,10 +324,8 @@ def register_domain_tools(mcp):
|
|||||||
domains_using_pool = get_domains_sharing_pool(backend)
|
domains_using_pool = get_domains_sharing_pool(backend)
|
||||||
other_domains = [d for d in domains_using_pool if d != domain]
|
other_domains = [d for d in domains_using_pool if d != domain]
|
||||||
|
|
||||||
# Save to disk first (atomic write for persistence)
|
# Remove from SQLite + sync map files
|
||||||
entries = get_map_contents()
|
remove_domain_from_map(domain)
|
||||||
new_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
|
|
||||||
save_map_file(new_entries)
|
|
||||||
|
|
||||||
# Remove from persistent server config
|
# Remove from persistent server config
|
||||||
remove_domain_from_config(domain)
|
remove_domain_from_config(domain)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from pydantic import Field
|
|||||||
from ..config import (
|
from ..config import (
|
||||||
MAP_FILE,
|
MAP_FILE,
|
||||||
SERVERS_FILE,
|
SERVERS_FILE,
|
||||||
|
DB_FILE,
|
||||||
HAPROXY_CONTAINER,
|
HAPROXY_CONTAINER,
|
||||||
)
|
)
|
||||||
from ..exceptions import HaproxyError
|
from ..exceptions import HaproxyError
|
||||||
@@ -88,7 +89,7 @@ def register_health_tools(mcp):
|
|||||||
# Check configuration files
|
# Check configuration files
|
||||||
files_ok = True
|
files_ok = True
|
||||||
file_status: dict[str, str] = {}
|
file_status: dict[str, str] = {}
|
||||||
for name, path in [("map_file", MAP_FILE), ("servers_file", SERVERS_FILE)]:
|
for name, path in [("map_file", MAP_FILE), ("db_file", DB_FILE)]:
|
||||||
exists = remote_file_exists(path) if REMOTE_MODE else __import__('os').path.exists(path)
|
exists = remote_file_exists(path) if REMOTE_MODE else __import__('os').path.exists(path)
|
||||||
if exists:
|
if exists:
|
||||||
file_status[name] = "ok"
|
file_status[name] = "ok"
|
||||||
|
|||||||
@@ -255,6 +255,8 @@ def temp_config_dir(tmp_path):
|
|||||||
state_file = tmp_path / "servers.state"
|
state_file = tmp_path / "servers.state"
|
||||||
state_file.write_text("")
|
state_file.write_text("")
|
||||||
|
|
||||||
|
db_file = tmp_path / "haproxy_mcp.db"
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"dir": tmp_path,
|
"dir": tmp_path,
|
||||||
"map_file": str(map_file),
|
"map_file": str(map_file),
|
||||||
@@ -262,12 +264,15 @@ def temp_config_dir(tmp_path):
|
|||||||
"servers_file": str(servers_file),
|
"servers_file": str(servers_file),
|
||||||
"certs_file": str(certs_file),
|
"certs_file": str(certs_file),
|
||||||
"state_file": str(state_file),
|
"state_file": str(state_file),
|
||||||
|
"db_file": str(db_file),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def patch_config_paths(temp_config_dir):
|
def patch_config_paths(temp_config_dir):
|
||||||
"""Fixture that patches config module paths to use temporary directory."""
|
"""Fixture that patches config module paths to use temporary directory."""
|
||||||
|
from haproxy_mcp.db import close_connection, init_db
|
||||||
|
|
||||||
with patch.multiple(
|
with patch.multiple(
|
||||||
"haproxy_mcp.config",
|
"haproxy_mcp.config",
|
||||||
MAP_FILE=temp_config_dir["map_file"],
|
MAP_FILE=temp_config_dir["map_file"],
|
||||||
@@ -275,16 +280,34 @@ def patch_config_paths(temp_config_dir):
|
|||||||
SERVERS_FILE=temp_config_dir["servers_file"],
|
SERVERS_FILE=temp_config_dir["servers_file"],
|
||||||
CERTS_FILE=temp_config_dir["certs_file"],
|
CERTS_FILE=temp_config_dir["certs_file"],
|
||||||
STATE_FILE=temp_config_dir["state_file"],
|
STATE_FILE=temp_config_dir["state_file"],
|
||||||
|
DB_FILE=temp_config_dir["db_file"],
|
||||||
):
|
):
|
||||||
# Also patch file_ops module which imports these
|
# Also patch file_ops module which imports these
|
||||||
with patch.multiple(
|
with patch.multiple(
|
||||||
"haproxy_mcp.file_ops",
|
"haproxy_mcp.file_ops",
|
||||||
MAP_FILE=temp_config_dir["map_file"],
|
MAP_FILE=temp_config_dir["map_file"],
|
||||||
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
|
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
|
||||||
SERVERS_FILE=temp_config_dir["servers_file"],
|
|
||||||
CERTS_FILE=temp_config_dir["certs_file"],
|
|
||||||
):
|
):
|
||||||
yield temp_config_dir
|
# Patch db module which imports these
|
||||||
|
with patch.multiple(
|
||||||
|
"haproxy_mcp.db",
|
||||||
|
MAP_FILE=temp_config_dir["map_file"],
|
||||||
|
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
|
||||||
|
SERVERS_FILE=temp_config_dir["servers_file"],
|
||||||
|
CERTS_FILE=temp_config_dir["certs_file"],
|
||||||
|
DB_FILE=temp_config_dir["db_file"],
|
||||||
|
):
|
||||||
|
# Patch health module which imports MAP_FILE and DB_FILE
|
||||||
|
with patch.multiple(
|
||||||
|
"haproxy_mcp.tools.health",
|
||||||
|
MAP_FILE=temp_config_dir["map_file"],
|
||||||
|
DB_FILE=temp_config_dir["db_file"],
|
||||||
|
):
|
||||||
|
# Close any existing connection and initialize fresh DB
|
||||||
|
close_connection()
|
||||||
|
init_db()
|
||||||
|
yield temp_config_dir
|
||||||
|
close_connection()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
433
tests/unit/test_db.py
Normal file
433
tests/unit/test_db.py
Normal file
@@ -0,0 +1,433 @@
|
|||||||
|
"""Unit tests for db module (SQLite database operations)."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from haproxy_mcp.db import (
|
||||||
|
get_connection,
|
||||||
|
close_connection,
|
||||||
|
init_db,
|
||||||
|
migrate_from_json,
|
||||||
|
db_get_map_contents,
|
||||||
|
db_get_domain_backend,
|
||||||
|
db_add_domain,
|
||||||
|
db_remove_domain,
|
||||||
|
db_find_available_pool,
|
||||||
|
db_get_domains_sharing_pool,
|
||||||
|
db_load_servers_config,
|
||||||
|
db_add_server,
|
||||||
|
db_remove_server,
|
||||||
|
db_remove_domain_servers,
|
||||||
|
db_add_shared_domain,
|
||||||
|
db_get_shared_domain,
|
||||||
|
db_is_shared_domain,
|
||||||
|
db_load_certs,
|
||||||
|
db_add_cert,
|
||||||
|
db_remove_cert,
|
||||||
|
sync_map_files,
|
||||||
|
SCHEMA_VERSION,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnectionManagement:
|
||||||
|
"""Tests for database connection management."""
|
||||||
|
|
||||||
|
def test_get_connection(self, patch_config_paths):
|
||||||
|
"""Get a connection returns a valid SQLite connection."""
|
||||||
|
conn = get_connection()
|
||||||
|
assert conn is not None
|
||||||
|
# Verify WAL mode
|
||||||
|
result = conn.execute("PRAGMA journal_mode").fetchone()
|
||||||
|
assert result[0] == "wal"
|
||||||
|
|
||||||
|
def test_connection_is_thread_local(self, patch_config_paths):
|
||||||
|
"""Same thread gets same connection."""
|
||||||
|
conn1 = get_connection()
|
||||||
|
conn2 = get_connection()
|
||||||
|
assert conn1 is conn2
|
||||||
|
|
||||||
|
def test_close_connection(self, patch_config_paths):
|
||||||
|
"""Close connection clears thread-local."""
|
||||||
|
conn1 = get_connection()
|
||||||
|
close_connection()
|
||||||
|
conn2 = get_connection()
|
||||||
|
assert conn1 is not conn2
|
||||||
|
|
||||||
|
|
||||||
|
class TestInitDb:
|
||||||
|
"""Tests for database initialization."""
|
||||||
|
|
||||||
|
def test_creates_tables(self, patch_config_paths):
|
||||||
|
"""init_db creates all required tables."""
|
||||||
|
conn = get_connection()
|
||||||
|
|
||||||
|
# Check tables exist
|
||||||
|
tables = conn.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
|
||||||
|
).fetchall()
|
||||||
|
table_names = [t["name"] for t in tables]
|
||||||
|
|
||||||
|
assert "domains" in table_names
|
||||||
|
assert "servers" in table_names
|
||||||
|
assert "certificates" in table_names
|
||||||
|
assert "schema_version" in table_names
|
||||||
|
|
||||||
|
def test_schema_version_recorded(self, patch_config_paths):
|
||||||
|
"""Schema version is recorded."""
|
||||||
|
conn = get_connection()
|
||||||
|
cur = conn.execute("SELECT MAX(version) FROM schema_version")
|
||||||
|
version = cur.fetchone()[0]
|
||||||
|
assert version == SCHEMA_VERSION
|
||||||
|
|
||||||
|
def test_idempotent(self, patch_config_paths):
|
||||||
|
"""Calling init_db twice is safe."""
|
||||||
|
# init_db is already called by patch_config_paths
|
||||||
|
# Calling it again should not raise
|
||||||
|
init_db()
|
||||||
|
conn = get_connection()
|
||||||
|
cur = conn.execute("SELECT COUNT(*) FROM schema_version")
|
||||||
|
# May have 1 or 2 entries but should not fail
|
||||||
|
assert cur.fetchone()[0] >= 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestMigrateFromJson:
|
||||||
|
"""Tests for JSON to SQLite migration."""
|
||||||
|
|
||||||
|
def test_migrate_map_files(self, patch_config_paths):
|
||||||
|
"""Migrate domain entries from map files."""
|
||||||
|
# Write test map files
|
||||||
|
with open(patch_config_paths["map_file"], "w") as f:
|
||||||
|
f.write("example.com pool_1\n")
|
||||||
|
f.write("api.example.com pool_2\n")
|
||||||
|
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||||
|
f.write(".example.com pool_1\n")
|
||||||
|
|
||||||
|
migrate_from_json()
|
||||||
|
|
||||||
|
entries = db_get_map_contents()
|
||||||
|
assert ("example.com", "pool_1") in entries
|
||||||
|
assert ("api.example.com", "pool_2") in entries
|
||||||
|
assert (".example.com", "pool_1") in entries
|
||||||
|
|
||||||
|
def test_migrate_servers_json(self, patch_config_paths):
|
||||||
|
"""Migrate server entries from servers.json."""
|
||||||
|
config = {
|
||||||
|
"example.com": {
|
||||||
|
"1": {"ip": "10.0.0.1", "http_port": 80},
|
||||||
|
"2": {"ip": "10.0.0.2", "http_port": 8080},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
with open(patch_config_paths["servers_file"], "w") as f:
|
||||||
|
json.dump(config, f)
|
||||||
|
|
||||||
|
migrate_from_json()
|
||||||
|
|
||||||
|
result = db_load_servers_config()
|
||||||
|
assert "example.com" in result
|
||||||
|
assert result["example.com"]["1"]["ip"] == "10.0.0.1"
|
||||||
|
assert result["example.com"]["2"]["http_port"] == 8080
|
||||||
|
|
||||||
|
def test_migrate_shared_domains(self, patch_config_paths):
|
||||||
|
"""Migrate shared domain references."""
|
||||||
|
# First add the domain to map so it exists in DB
|
||||||
|
with open(patch_config_paths["map_file"], "w") as f:
|
||||||
|
f.write("example.com pool_1\n")
|
||||||
|
f.write("www.example.com pool_1\n")
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"www.example.com": {"_shares": "example.com"},
|
||||||
|
}
|
||||||
|
with open(patch_config_paths["servers_file"], "w") as f:
|
||||||
|
json.dump(config, f)
|
||||||
|
|
||||||
|
migrate_from_json()
|
||||||
|
|
||||||
|
assert db_get_shared_domain("www.example.com") == "example.com"
|
||||||
|
|
||||||
|
def test_migrate_certificates(self, patch_config_paths):
|
||||||
|
"""Migrate certificate entries."""
|
||||||
|
with open(patch_config_paths["certs_file"], "w") as f:
|
||||||
|
json.dump({"domains": ["example.com", "api.example.com"]}, f)
|
||||||
|
|
||||||
|
migrate_from_json()
|
||||||
|
|
||||||
|
certs = db_load_certs()
|
||||||
|
assert "example.com" in certs
|
||||||
|
assert "api.example.com" in certs
|
||||||
|
|
||||||
|
def test_migrate_idempotent(self, patch_config_paths):
|
||||||
|
"""Migration is idempotent (INSERT OR IGNORE)."""
|
||||||
|
with open(patch_config_paths["map_file"], "w") as f:
|
||||||
|
f.write("example.com pool_1\n")
|
||||||
|
with open(patch_config_paths["servers_file"], "w") as f:
|
||||||
|
json.dump({"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}, f)
|
||||||
|
|
||||||
|
migrate_from_json()
|
||||||
|
migrate_from_json() # Should not fail
|
||||||
|
|
||||||
|
entries = db_get_map_contents()
|
||||||
|
assert len([d for d, _ in entries if d == "example.com"]) == 1
|
||||||
|
|
||||||
|
def test_migrate_empty_files(self, patch_config_paths):
|
||||||
|
"""Migration with no existing data does nothing."""
|
||||||
|
os.unlink(patch_config_paths["map_file"])
|
||||||
|
os.unlink(patch_config_paths["wildcards_file"])
|
||||||
|
os.unlink(patch_config_paths["servers_file"])
|
||||||
|
os.unlink(patch_config_paths["certs_file"])
|
||||||
|
|
||||||
|
migrate_from_json() # Should not fail
|
||||||
|
|
||||||
|
def test_backup_files_after_migration(self, patch_config_paths):
|
||||||
|
"""Original JSON files are backed up after migration."""
|
||||||
|
with open(patch_config_paths["servers_file"], "w") as f:
|
||||||
|
json.dump({"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}, f)
|
||||||
|
with open(patch_config_paths["certs_file"], "w") as f:
|
||||||
|
json.dump({"domains": ["example.com"]}, f)
|
||||||
|
|
||||||
|
# Write map file so migration has data
|
||||||
|
with open(patch_config_paths["map_file"], "w") as f:
|
||||||
|
f.write("example.com pool_1\n")
|
||||||
|
|
||||||
|
migrate_from_json()
|
||||||
|
|
||||||
|
assert os.path.exists(f"{patch_config_paths['servers_file']}.bak")
|
||||||
|
assert os.path.exists(f"{patch_config_paths['certs_file']}.bak")
|
||||||
|
|
||||||
|
|
||||||
|
class TestDomainOperations:
|
||||||
|
"""Tests for domain data access functions."""
|
||||||
|
|
||||||
|
def test_add_and_get_domain(self, patch_config_paths):
|
||||||
|
"""Add domain and retrieve its backend."""
|
||||||
|
db_add_domain("example.com", "pool_1")
|
||||||
|
|
||||||
|
assert db_get_domain_backend("example.com") == "pool_1"
|
||||||
|
|
||||||
|
def test_get_nonexistent_domain(self, patch_config_paths):
|
||||||
|
"""Non-existent domain returns None."""
|
||||||
|
assert db_get_domain_backend("nonexistent.com") is None
|
||||||
|
|
||||||
|
def test_remove_domain(self, patch_config_paths):
|
||||||
|
"""Remove domain removes both exact and wildcard."""
|
||||||
|
db_add_domain("example.com", "pool_1")
|
||||||
|
db_add_domain(".example.com", "pool_1", is_wildcard=True)
|
||||||
|
|
||||||
|
db_remove_domain("example.com")
|
||||||
|
|
||||||
|
assert db_get_domain_backend("example.com") is None
|
||||||
|
entries = db_get_map_contents()
|
||||||
|
assert (".example.com", "pool_1") not in entries
|
||||||
|
|
||||||
|
def test_find_available_pool(self, patch_config_paths):
|
||||||
|
"""Find first available pool."""
|
||||||
|
db_add_domain("a.com", "pool_1")
|
||||||
|
db_add_domain("b.com", "pool_2")
|
||||||
|
|
||||||
|
pool = db_find_available_pool()
|
||||||
|
assert pool == "pool_3"
|
||||||
|
|
||||||
|
def test_find_available_pool_empty(self, patch_config_paths):
|
||||||
|
"""First pool available when none used."""
|
||||||
|
pool = db_find_available_pool()
|
||||||
|
assert pool == "pool_1"
|
||||||
|
|
||||||
|
def test_get_domains_sharing_pool(self, patch_config_paths):
|
||||||
|
"""Get non-wildcard domains using a pool."""
|
||||||
|
db_add_domain("example.com", "pool_1")
|
||||||
|
db_add_domain("www.example.com", "pool_1")
|
||||||
|
db_add_domain(".example.com", "pool_1", is_wildcard=True)
|
||||||
|
|
||||||
|
domains = db_get_domains_sharing_pool("pool_1")
|
||||||
|
assert "example.com" in domains
|
||||||
|
assert "www.example.com" in domains
|
||||||
|
assert ".example.com" not in domains
|
||||||
|
|
||||||
|
def test_update_domain_backend(self, patch_config_paths):
|
||||||
|
"""Updating a domain changes its backend."""
|
||||||
|
db_add_domain("example.com", "pool_1")
|
||||||
|
db_add_domain("example.com", "pool_5") # Update
|
||||||
|
|
||||||
|
assert db_get_domain_backend("example.com") == "pool_5"
|
||||||
|
|
||||||
|
|
||||||
|
class TestServerOperations:
|
||||||
|
"""Tests for server data access functions."""
|
||||||
|
|
||||||
|
def test_add_and_load_server(self, patch_config_paths):
|
||||||
|
"""Add server and load config."""
|
||||||
|
db_add_server("example.com", 1, "10.0.0.1", 80)
|
||||||
|
|
||||||
|
config = db_load_servers_config()
|
||||||
|
assert config["example.com"]["1"]["ip"] == "10.0.0.1"
|
||||||
|
assert config["example.com"]["1"]["http_port"] == 80
|
||||||
|
|
||||||
|
def test_update_server(self, patch_config_paths):
|
||||||
|
"""Update existing server slot."""
|
||||||
|
db_add_server("example.com", 1, "10.0.0.1", 80)
|
||||||
|
db_add_server("example.com", 1, "10.0.0.99", 8080)
|
||||||
|
|
||||||
|
config = db_load_servers_config()
|
||||||
|
assert config["example.com"]["1"]["ip"] == "10.0.0.99"
|
||||||
|
assert config["example.com"]["1"]["http_port"] == 8080
|
||||||
|
|
||||||
|
def test_remove_server(self, patch_config_paths):
|
||||||
|
"""Remove a server slot."""
|
||||||
|
db_add_server("example.com", 1, "10.0.0.1", 80)
|
||||||
|
db_add_server("example.com", 2, "10.0.0.2", 80)
|
||||||
|
|
||||||
|
db_remove_server("example.com", 1)
|
||||||
|
|
||||||
|
config = db_load_servers_config()
|
||||||
|
assert "1" not in config.get("example.com", {})
|
||||||
|
assert "2" in config["example.com"]
|
||||||
|
|
||||||
|
def test_remove_domain_servers(self, patch_config_paths):
|
||||||
|
"""Remove all servers for a domain."""
|
||||||
|
db_add_server("example.com", 1, "10.0.0.1", 80)
|
||||||
|
db_add_server("example.com", 2, "10.0.0.2", 80)
|
||||||
|
db_add_server("other.com", 1, "10.0.0.3", 80)
|
||||||
|
|
||||||
|
db_remove_domain_servers("example.com")
|
||||||
|
|
||||||
|
config = db_load_servers_config()
|
||||||
|
assert config.get("example.com", {}).get("1") is None
|
||||||
|
assert "other.com" in config
|
||||||
|
|
||||||
|
def test_load_empty(self, patch_config_paths):
|
||||||
|
"""Empty database returns empty dict."""
|
||||||
|
config = db_load_servers_config()
|
||||||
|
assert config == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestSharedDomainOperations:
|
||||||
|
"""Tests for shared domain functions."""
|
||||||
|
|
||||||
|
def test_add_and_get_shared(self, patch_config_paths):
|
||||||
|
"""Add shared domain reference."""
|
||||||
|
db_add_domain("example.com", "pool_1")
|
||||||
|
db_add_domain("www.example.com", "pool_1")
|
||||||
|
db_add_shared_domain("www.example.com", "example.com")
|
||||||
|
|
||||||
|
assert db_get_shared_domain("www.example.com") == "example.com"
|
||||||
|
|
||||||
|
def test_is_shared(self, patch_config_paths):
|
||||||
|
"""Check if domain is shared."""
|
||||||
|
db_add_domain("example.com", "pool_1")
|
||||||
|
db_add_domain("www.example.com", "pool_1")
|
||||||
|
db_add_shared_domain("www.example.com", "example.com")
|
||||||
|
|
||||||
|
assert db_is_shared_domain("www.example.com") is True
|
||||||
|
assert db_is_shared_domain("example.com") is False
|
||||||
|
|
||||||
|
def test_not_shared(self, patch_config_paths):
|
||||||
|
"""Non-shared domain returns None."""
|
||||||
|
db_add_domain("example.com", "pool_1")
|
||||||
|
|
||||||
|
assert db_get_shared_domain("example.com") is None
|
||||||
|
assert db_is_shared_domain("example.com") is False
|
||||||
|
|
||||||
|
def test_shared_in_load_config(self, patch_config_paths):
|
||||||
|
"""Shared domain appears in load_servers_config."""
|
||||||
|
db_add_domain("example.com", "pool_1")
|
||||||
|
db_add_domain("www.example.com", "pool_1")
|
||||||
|
db_add_shared_domain("www.example.com", "example.com")
|
||||||
|
|
||||||
|
config = db_load_servers_config()
|
||||||
|
assert config["www.example.com"]["_shares"] == "example.com"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCertificateOperations:
|
||||||
|
"""Tests for certificate data access functions."""
|
||||||
|
|
||||||
|
def test_add_and_load_cert(self, patch_config_paths):
|
||||||
|
"""Add and load certificate."""
|
||||||
|
db_add_cert("example.com")
|
||||||
|
|
||||||
|
certs = db_load_certs()
|
||||||
|
assert "example.com" in certs
|
||||||
|
|
||||||
|
def test_add_duplicate(self, patch_config_paths):
|
||||||
|
"""Adding duplicate cert is a no-op."""
|
||||||
|
db_add_cert("example.com")
|
||||||
|
db_add_cert("example.com")
|
||||||
|
|
||||||
|
certs = db_load_certs()
|
||||||
|
assert certs.count("example.com") == 1
|
||||||
|
|
||||||
|
def test_remove_cert(self, patch_config_paths):
|
||||||
|
"""Remove a certificate."""
|
||||||
|
db_add_cert("example.com")
|
||||||
|
db_add_cert("other.com")
|
||||||
|
|
||||||
|
db_remove_cert("example.com")
|
||||||
|
|
||||||
|
certs = db_load_certs()
|
||||||
|
assert "example.com" not in certs
|
||||||
|
assert "other.com" in certs
|
||||||
|
|
||||||
|
def test_load_empty(self, patch_config_paths):
|
||||||
|
"""Empty database returns empty list."""
|
||||||
|
certs = db_load_certs()
|
||||||
|
assert certs == []
|
||||||
|
|
||||||
|
def test_sorted_output(self, patch_config_paths):
|
||||||
|
"""Certificates are returned sorted."""
|
||||||
|
db_add_cert("z.com")
|
||||||
|
db_add_cert("a.com")
|
||||||
|
db_add_cert("m.com")
|
||||||
|
|
||||||
|
certs = db_load_certs()
|
||||||
|
assert certs == ["a.com", "m.com", "z.com"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncMapFiles:
|
||||||
|
"""Tests for sync_map_files function."""
|
||||||
|
|
||||||
|
def test_sync_exact_domains(self, patch_config_paths):
|
||||||
|
"""Sync writes exact domains to domains.map."""
|
||||||
|
db_add_domain("example.com", "pool_1")
|
||||||
|
db_add_domain("api.example.com", "pool_2")
|
||||||
|
|
||||||
|
sync_map_files()
|
||||||
|
|
||||||
|
with open(patch_config_paths["map_file"]) as f:
|
||||||
|
content = f.read()
|
||||||
|
assert "example.com pool_1" in content
|
||||||
|
assert "api.example.com pool_2" in content
|
||||||
|
|
||||||
|
def test_sync_wildcards(self, patch_config_paths):
|
||||||
|
"""Sync writes wildcards to wildcards.map."""
|
||||||
|
db_add_domain(".example.com", "pool_1", is_wildcard=True)
|
||||||
|
|
||||||
|
sync_map_files()
|
||||||
|
|
||||||
|
with open(patch_config_paths["wildcards_file"]) as f:
|
||||||
|
content = f.read()
|
||||||
|
assert ".example.com pool_1" in content
|
||||||
|
|
||||||
|
def test_sync_empty(self, patch_config_paths):
|
||||||
|
"""Sync with no domains writes headers only."""
|
||||||
|
sync_map_files()
|
||||||
|
|
||||||
|
with open(patch_config_paths["map_file"]) as f:
|
||||||
|
content = f.read()
|
||||||
|
assert "Exact Domain" in content
|
||||||
|
# No domain entries
|
||||||
|
lines = [l.strip() for l in content.splitlines() if l.strip() and not l.startswith("#")]
|
||||||
|
assert len(lines) == 0
|
||||||
|
|
||||||
|
def test_sync_sorted(self, patch_config_paths):
|
||||||
|
"""Sync output is sorted."""
|
||||||
|
db_add_domain("z.com", "pool_3")
|
||||||
|
db_add_domain("a.com", "pool_1")
|
||||||
|
|
||||||
|
sync_map_files()
|
||||||
|
|
||||||
|
with open(patch_config_paths["map_file"]) as f:
|
||||||
|
lines = [l.strip() for l in f if l.strip() and not l.startswith("#")]
|
||||||
|
assert lines[0] == "a.com pool_1"
|
||||||
|
assert lines[1] == "z.com pool_3"
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Unit tests for file_ops module."""
|
"""Unit tests for file_ops module (SQLite-backed)."""
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@@ -16,14 +15,19 @@ from haproxy_mcp.file_ops import (
|
|||||||
get_legacy_backend_name,
|
get_legacy_backend_name,
|
||||||
get_backend_and_prefix,
|
get_backend_and_prefix,
|
||||||
load_servers_config,
|
load_servers_config,
|
||||||
save_servers_config,
|
|
||||||
add_server_to_config,
|
add_server_to_config,
|
||||||
remove_server_from_config,
|
remove_server_from_config,
|
||||||
remove_domain_from_config,
|
remove_domain_from_config,
|
||||||
load_certs_config,
|
load_certs_config,
|
||||||
save_certs_config,
|
|
||||||
add_cert_to_config,
|
add_cert_to_config,
|
||||||
remove_cert_from_config,
|
remove_cert_from_config,
|
||||||
|
add_domain_to_map,
|
||||||
|
remove_domain_from_map,
|
||||||
|
find_available_pool,
|
||||||
|
add_shared_domain_to_config,
|
||||||
|
get_shared_domain,
|
||||||
|
is_shared_domain,
|
||||||
|
get_domains_sharing_pool,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -62,7 +66,7 @@ class TestAtomicWriteFile:
|
|||||||
def test_unicode_content(self, tmp_path):
|
def test_unicode_content(self, tmp_path):
|
||||||
"""Unicode content is properly written."""
|
"""Unicode content is properly written."""
|
||||||
file_path = str(tmp_path / "unicode.txt")
|
file_path = str(tmp_path / "unicode.txt")
|
||||||
content = "Hello, \u4e16\u754c!" # "Hello, World!" in Chinese
|
content = "Hello, \u4e16\u754c!"
|
||||||
|
|
||||||
atomic_write_file(file_path, content)
|
atomic_write_file(file_path, content)
|
||||||
|
|
||||||
@@ -81,66 +85,33 @@ class TestAtomicWriteFile:
|
|||||||
|
|
||||||
|
|
||||||
class TestGetMapContents:
|
class TestGetMapContents:
|
||||||
"""Tests for get_map_contents function."""
|
"""Tests for get_map_contents function (SQLite-backed)."""
|
||||||
|
|
||||||
def test_read_map_file(self, patch_config_paths):
|
def test_empty_db(self, patch_config_paths):
|
||||||
"""Read entries from map file."""
|
"""Empty database returns empty list."""
|
||||||
# Write test content to map file
|
entries = get_map_contents()
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
assert entries == []
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
f.write("api.example.com pool_2\n")
|
def test_read_domains(self, patch_config_paths):
|
||||||
|
"""Read entries from database."""
|
||||||
|
add_domain_to_map("example.com", "pool_1")
|
||||||
|
add_domain_to_map("api.example.com", "pool_2")
|
||||||
|
|
||||||
entries = get_map_contents()
|
entries = get_map_contents()
|
||||||
|
|
||||||
assert ("example.com", "pool_1") in entries
|
assert ("example.com", "pool_1") in entries
|
||||||
assert ("api.example.com", "pool_2") in entries
|
assert ("api.example.com", "pool_2") in entries
|
||||||
|
|
||||||
def test_read_both_map_files(self, patch_config_paths):
|
def test_read_with_wildcards(self, patch_config_paths):
|
||||||
"""Read entries from both domains.map and wildcards.map."""
|
"""Read entries including wildcards."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||||
|
|
||||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
|
||||||
f.write(".example.com pool_1\n")
|
|
||||||
|
|
||||||
entries = get_map_contents()
|
entries = get_map_contents()
|
||||||
|
|
||||||
assert ("example.com", "pool_1") in entries
|
assert ("example.com", "pool_1") in entries
|
||||||
assert (".example.com", "pool_1") in entries
|
assert (".example.com", "pool_1") in entries
|
||||||
|
|
||||||
def test_skip_comments(self, patch_config_paths):
|
|
||||||
"""Comments are skipped."""
|
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
|
||||||
f.write("# This is a comment\n")
|
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
f.write("# Another comment\n")
|
|
||||||
|
|
||||||
entries = get_map_contents()
|
|
||||||
|
|
||||||
assert len(entries) == 1
|
|
||||||
assert entries[0] == ("example.com", "pool_1")
|
|
||||||
|
|
||||||
def test_skip_empty_lines(self, patch_config_paths):
|
|
||||||
"""Empty lines are skipped."""
|
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
|
||||||
f.write("\n")
|
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
f.write("\n")
|
|
||||||
f.write("api.example.com pool_2\n")
|
|
||||||
|
|
||||||
entries = get_map_contents()
|
|
||||||
|
|
||||||
assert len(entries) == 2
|
|
||||||
|
|
||||||
def test_file_not_found(self, patch_config_paths):
|
|
||||||
"""Missing file returns empty list."""
|
|
||||||
os.unlink(patch_config_paths["map_file"])
|
|
||||||
os.unlink(patch_config_paths["wildcards_file"])
|
|
||||||
|
|
||||||
entries = get_map_contents()
|
|
||||||
|
|
||||||
assert entries == []
|
|
||||||
|
|
||||||
|
|
||||||
class TestSplitDomainEntries:
|
class TestSplitDomainEntries:
|
||||||
"""Tests for split_domain_entries function."""
|
"""Tests for split_domain_entries function."""
|
||||||
@@ -182,36 +153,30 @@ class TestSplitDomainEntries:
|
|||||||
|
|
||||||
|
|
||||||
class TestSaveMapFile:
|
class TestSaveMapFile:
|
||||||
"""Tests for save_map_file function."""
|
"""Tests for save_map_file function (syncs from DB to map files)."""
|
||||||
|
|
||||||
def test_save_entries(self, patch_config_paths):
|
def test_save_entries(self, patch_config_paths):
|
||||||
"""Save entries to separate map files."""
|
"""Save entries to separate map files."""
|
||||||
entries = [
|
add_domain_to_map("example.com", "pool_1")
|
||||||
("example.com", "pool_1"),
|
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||||
(".example.com", "pool_1"),
|
|
||||||
]
|
|
||||||
|
|
||||||
save_map_file(entries)
|
save_map_file([]) # Entries param ignored, syncs from DB
|
||||||
|
|
||||||
# Check exact domains file
|
|
||||||
with open(patch_config_paths["map_file"]) as f:
|
with open(patch_config_paths["map_file"]) as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
assert "example.com pool_1" in content
|
assert "example.com pool_1" in content
|
||||||
|
|
||||||
# Check wildcards file
|
|
||||||
with open(patch_config_paths["wildcards_file"]) as f:
|
with open(patch_config_paths["wildcards_file"]) as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
assert ".example.com pool_1" in content
|
assert ".example.com pool_1" in content
|
||||||
|
|
||||||
def test_sorted_output(self, patch_config_paths):
|
def test_sorted_output(self, patch_config_paths):
|
||||||
"""Entries are sorted in output."""
|
"""Entries are sorted in output."""
|
||||||
entries = [
|
add_domain_to_map("z.example.com", "pool_3")
|
||||||
("z.example.com", "pool_3"),
|
add_domain_to_map("a.example.com", "pool_1")
|
||||||
("a.example.com", "pool_1"),
|
add_domain_to_map("m.example.com", "pool_2")
|
||||||
("m.example.com", "pool_2"),
|
|
||||||
]
|
|
||||||
|
|
||||||
save_map_file(entries)
|
save_map_file([])
|
||||||
|
|
||||||
with open(patch_config_paths["map_file"]) as f:
|
with open(patch_config_paths["map_file"]) as f:
|
||||||
lines = [l.strip() for l in f if l.strip() and not l.startswith("#")]
|
lines = [l.strip() for l in f if l.strip() and not l.startswith("#")]
|
||||||
@@ -222,12 +187,11 @@ class TestSaveMapFile:
|
|||||||
|
|
||||||
|
|
||||||
class TestGetDomainBackend:
|
class TestGetDomainBackend:
|
||||||
"""Tests for get_domain_backend function."""
|
"""Tests for get_domain_backend function (SQLite-backed)."""
|
||||||
|
|
||||||
def test_find_existing_domain(self, patch_config_paths):
|
def test_find_existing_domain(self, patch_config_paths):
|
||||||
"""Find backend for existing domain."""
|
"""Find backend for existing domain."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
backend = get_domain_backend("example.com")
|
backend = get_domain_backend("example.com")
|
||||||
|
|
||||||
@@ -235,8 +199,7 @@ class TestGetDomainBackend:
|
|||||||
|
|
||||||
def test_domain_not_found(self, patch_config_paths):
|
def test_domain_not_found(self, patch_config_paths):
|
||||||
"""Non-existent domain returns None."""
|
"""Non-existent domain returns None."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
backend = get_domain_backend("other.com")
|
backend = get_domain_backend("other.com")
|
||||||
|
|
||||||
@@ -271,8 +234,7 @@ class TestGetBackendAndPrefix:
|
|||||||
|
|
||||||
def test_pool_backend(self, patch_config_paths):
|
def test_pool_backend(self, patch_config_paths):
|
||||||
"""Pool backend returns pool-based prefix."""
|
"""Pool backend returns pool-based prefix."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_5")
|
||||||
f.write("example.com pool_5\n")
|
|
||||||
|
|
||||||
backend, prefix = get_backend_and_prefix("example.com")
|
backend, prefix = get_backend_and_prefix("example.com")
|
||||||
|
|
||||||
@@ -288,48 +250,33 @@ class TestGetBackendAndPrefix:
|
|||||||
|
|
||||||
|
|
||||||
class TestLoadServersConfig:
|
class TestLoadServersConfig:
|
||||||
"""Tests for load_servers_config function."""
|
"""Tests for load_servers_config function (SQLite-backed)."""
|
||||||
|
|
||||||
def test_load_existing_config(self, patch_config_paths, sample_servers_config):
|
def test_load_empty_config(self, patch_config_paths):
|
||||||
"""Load existing config file."""
|
"""Empty database returns empty dict."""
|
||||||
with open(patch_config_paths["servers_file"], "w") as f:
|
config = load_servers_config()
|
||||||
json.dump(sample_servers_config, f)
|
assert config == {}
|
||||||
|
|
||||||
|
def test_load_with_servers(self, patch_config_paths):
|
||||||
|
"""Load config with server entries."""
|
||||||
|
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||||
|
add_server_to_config("example.com", 2, "10.0.0.2", 80)
|
||||||
|
|
||||||
config = load_servers_config()
|
config = load_servers_config()
|
||||||
|
|
||||||
assert "example.com" in config
|
assert "example.com" in config
|
||||||
assert config["example.com"]["1"]["ip"] == "10.0.0.1"
|
assert config["example.com"]["1"]["ip"] == "10.0.0.1"
|
||||||
|
assert config["example.com"]["2"]["ip"] == "10.0.0.2"
|
||||||
|
|
||||||
def test_file_not_found(self, patch_config_paths):
|
def test_load_with_shared_domain(self, patch_config_paths):
|
||||||
"""Missing file returns empty dict."""
|
"""Load config with shared domain reference."""
|
||||||
os.unlink(patch_config_paths["servers_file"])
|
add_domain_to_map("example.com", "pool_1")
|
||||||
|
add_domain_to_map("www.example.com", "pool_1")
|
||||||
|
add_shared_domain_to_config("www.example.com", "example.com")
|
||||||
|
|
||||||
config = load_servers_config()
|
config = load_servers_config()
|
||||||
|
|
||||||
assert config == {}
|
assert config["www.example.com"]["_shares"] == "example.com"
|
||||||
|
|
||||||
def test_invalid_json(self, patch_config_paths):
|
|
||||||
"""Invalid JSON returns empty dict."""
|
|
||||||
with open(patch_config_paths["servers_file"], "w") as f:
|
|
||||||
f.write("not valid json {{{")
|
|
||||||
|
|
||||||
config = load_servers_config()
|
|
||||||
|
|
||||||
assert config == {}
|
|
||||||
|
|
||||||
|
|
||||||
class TestSaveServersConfig:
|
|
||||||
"""Tests for save_servers_config function."""
|
|
||||||
|
|
||||||
def test_save_config(self, patch_config_paths):
|
|
||||||
"""Save config to file."""
|
|
||||||
config = {"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}
|
|
||||||
|
|
||||||
save_servers_config(config)
|
|
||||||
|
|
||||||
with open(patch_config_paths["servers_file"]) as f:
|
|
||||||
loaded = json.load(f)
|
|
||||||
assert loaded == config
|
|
||||||
|
|
||||||
|
|
||||||
class TestAddServerToConfig:
|
class TestAddServerToConfig:
|
||||||
@@ -373,17 +320,18 @@ class TestRemoveServerFromConfig:
|
|||||||
remove_server_from_config("example.com", 1)
|
remove_server_from_config("example.com", 1)
|
||||||
|
|
||||||
config = load_servers_config()
|
config = load_servers_config()
|
||||||
assert "1" not in config["example.com"]
|
assert "1" not in config.get("example.com", {})
|
||||||
assert "2" in config["example.com"]
|
assert "2" in config["example.com"]
|
||||||
|
|
||||||
def test_remove_last_server_removes_domain(self, patch_config_paths):
|
def test_remove_last_server_removes_domain(self, patch_config_paths):
|
||||||
"""Removing last server removes domain entry."""
|
"""Removing last server removes domain entry from servers."""
|
||||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||||
|
|
||||||
remove_server_from_config("example.com", 1)
|
remove_server_from_config("example.com", 1)
|
||||||
|
|
||||||
config = load_servers_config()
|
config = load_servers_config()
|
||||||
assert "example.com" not in config
|
# Domain may or may not exist (no servers = no entry)
|
||||||
|
assert config.get("example.com", {}).get("1") is None
|
||||||
|
|
||||||
def test_remove_nonexistent_server(self, patch_config_paths):
|
def test_remove_nonexistent_server(self, patch_config_paths):
|
||||||
"""Removing non-existent server is a no-op."""
|
"""Removing non-existent server is a no-op."""
|
||||||
@@ -399,14 +347,14 @@ class TestRemoveDomainFromConfig:
|
|||||||
"""Tests for remove_domain_from_config function."""
|
"""Tests for remove_domain_from_config function."""
|
||||||
|
|
||||||
def test_remove_existing_domain(self, patch_config_paths):
|
def test_remove_existing_domain(self, patch_config_paths):
|
||||||
"""Remove existing domain."""
|
"""Remove existing domain's servers."""
|
||||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||||
add_server_to_config("other.com", 1, "10.0.0.2", 80)
|
add_server_to_config("other.com", 1, "10.0.0.2", 80)
|
||||||
|
|
||||||
remove_domain_from_config("example.com")
|
remove_domain_from_config("example.com")
|
||||||
|
|
||||||
config = load_servers_config()
|
config = load_servers_config()
|
||||||
assert "example.com" not in config
|
assert config.get("example.com", {}).get("1") is None
|
||||||
assert "other.com" in config
|
assert "other.com" in config
|
||||||
|
|
||||||
def test_remove_nonexistent_domain(self, patch_config_paths):
|
def test_remove_nonexistent_domain(self, patch_config_paths):
|
||||||
@@ -420,40 +368,23 @@ class TestRemoveDomainFromConfig:
|
|||||||
|
|
||||||
|
|
||||||
class TestLoadCertsConfig:
|
class TestLoadCertsConfig:
|
||||||
"""Tests for load_certs_config function."""
|
"""Tests for load_certs_config function (SQLite-backed)."""
|
||||||
|
|
||||||
def test_load_existing_config(self, patch_config_paths):
|
def test_load_empty(self, patch_config_paths):
|
||||||
"""Load existing certs config."""
|
"""Empty database returns empty list."""
|
||||||
with open(patch_config_paths["certs_file"], "w") as f:
|
domains = load_certs_config()
|
||||||
json.dump({"domains": ["example.com", "other.com"]}, f)
|
assert domains == []
|
||||||
|
|
||||||
|
def test_load_with_certs(self, patch_config_paths):
|
||||||
|
"""Load certs from database."""
|
||||||
|
add_cert_to_config("example.com")
|
||||||
|
add_cert_to_config("other.com")
|
||||||
|
|
||||||
domains = load_certs_config()
|
domains = load_certs_config()
|
||||||
|
|
||||||
assert "example.com" in domains
|
assert "example.com" in domains
|
||||||
assert "other.com" in domains
|
assert "other.com" in domains
|
||||||
|
|
||||||
def test_file_not_found(self, patch_config_paths):
|
|
||||||
"""Missing file returns empty list."""
|
|
||||||
os.unlink(patch_config_paths["certs_file"])
|
|
||||||
|
|
||||||
domains = load_certs_config()
|
|
||||||
|
|
||||||
assert domains == []
|
|
||||||
|
|
||||||
|
|
||||||
class TestSaveCertsConfig:
|
|
||||||
"""Tests for save_certs_config function."""
|
|
||||||
|
|
||||||
def test_save_domains(self, patch_config_paths):
|
|
||||||
"""Save domains to certs config."""
|
|
||||||
save_certs_config(["z.com", "a.com"])
|
|
||||||
|
|
||||||
with open(patch_config_paths["certs_file"]) as f:
|
|
||||||
data = json.load(f)
|
|
||||||
|
|
||||||
# Should be sorted
|
|
||||||
assert data["domains"] == ["a.com", "z.com"]
|
|
||||||
|
|
||||||
|
|
||||||
class TestAddCertToConfig:
|
class TestAddCertToConfig:
|
||||||
"""Tests for add_cert_to_config function."""
|
"""Tests for add_cert_to_config function."""
|
||||||
@@ -496,3 +427,87 @@ class TestRemoveCertFromConfig:
|
|||||||
|
|
||||||
domains = load_certs_config()
|
domains = load_certs_config()
|
||||||
assert "example.com" in domains
|
assert "example.com" in domains
|
||||||
|
|
||||||
|
|
||||||
|
class TestAddDomainToMap:
|
||||||
|
"""Tests for add_domain_to_map function."""
|
||||||
|
|
||||||
|
def test_add_domain(self, patch_config_paths):
|
||||||
|
"""Add a domain and verify map files are synced."""
|
||||||
|
add_domain_to_map("example.com", "pool_1")
|
||||||
|
|
||||||
|
assert get_domain_backend("example.com") == "pool_1"
|
||||||
|
|
||||||
|
with open(patch_config_paths["map_file"]) as f:
|
||||||
|
assert "example.com pool_1" in f.read()
|
||||||
|
|
||||||
|
def test_add_wildcard(self, patch_config_paths):
|
||||||
|
"""Add a wildcard domain."""
|
||||||
|
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||||
|
|
||||||
|
entries = get_map_contents()
|
||||||
|
assert (".example.com", "pool_1") in entries
|
||||||
|
|
||||||
|
|
||||||
|
class TestRemoveDomainFromMap:
|
||||||
|
"""Tests for remove_domain_from_map function."""
|
||||||
|
|
||||||
|
def test_remove_domain(self, patch_config_paths):
|
||||||
|
"""Remove a domain and its wildcard."""
|
||||||
|
add_domain_to_map("example.com", "pool_1")
|
||||||
|
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||||
|
|
||||||
|
remove_domain_from_map("example.com")
|
||||||
|
|
||||||
|
assert get_domain_backend("example.com") is None
|
||||||
|
entries = get_map_contents()
|
||||||
|
assert (".example.com", "pool_1") not in entries
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindAvailablePool:
|
||||||
|
"""Tests for find_available_pool function."""
|
||||||
|
|
||||||
|
def test_first_pool_available(self, patch_config_paths):
|
||||||
|
"""When no domains exist, pool_1 is returned."""
|
||||||
|
pool = find_available_pool()
|
||||||
|
assert pool == "pool_1"
|
||||||
|
|
||||||
|
def test_skip_used_pools(self, patch_config_paths):
|
||||||
|
"""Used pools are skipped."""
|
||||||
|
add_domain_to_map("example.com", "pool_1")
|
||||||
|
add_domain_to_map("other.com", "pool_2")
|
||||||
|
|
||||||
|
pool = find_available_pool()
|
||||||
|
assert pool == "pool_3"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSharedDomains:
|
||||||
|
"""Tests for shared domain functions."""
|
||||||
|
|
||||||
|
def test_get_shared_domain(self, patch_config_paths):
|
||||||
|
"""Get parent domain for shared domain."""
|
||||||
|
add_domain_to_map("example.com", "pool_1")
|
||||||
|
add_domain_to_map("www.example.com", "pool_1")
|
||||||
|
add_shared_domain_to_config("www.example.com", "example.com")
|
||||||
|
|
||||||
|
assert get_shared_domain("www.example.com") == "example.com"
|
||||||
|
|
||||||
|
def test_is_shared_domain(self, patch_config_paths):
|
||||||
|
"""Check if domain is shared."""
|
||||||
|
add_domain_to_map("example.com", "pool_1")
|
||||||
|
add_domain_to_map("www.example.com", "pool_1")
|
||||||
|
add_shared_domain_to_config("www.example.com", "example.com")
|
||||||
|
|
||||||
|
assert is_shared_domain("www.example.com") is True
|
||||||
|
assert is_shared_domain("example.com") is False
|
||||||
|
|
||||||
|
def test_get_domains_sharing_pool(self, patch_config_paths):
|
||||||
|
"""Get all domains using a pool."""
|
||||||
|
add_domain_to_map("example.com", "pool_1")
|
||||||
|
add_domain_to_map("www.example.com", "pool_1")
|
||||||
|
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||||
|
|
||||||
|
domains = get_domains_sharing_pool("pool_1")
|
||||||
|
assert "example.com" in domains
|
||||||
|
assert "www.example.com" in domains
|
||||||
|
assert ".example.com" not in domains # Wildcards excluded
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ from unittest.mock import patch, MagicMock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from haproxy_mcp.file_ops import add_cert_to_config
|
||||||
|
|
||||||
|
|
||||||
class TestGetPemPaths:
|
class TestGetPemPaths:
|
||||||
"""Tests for get_pem_paths function."""
|
"""Tests for get_pem_paths function."""
|
||||||
@@ -127,8 +129,7 @@ class TestRestoreCertificates:
|
|||||||
def test_restore_certificates_success(self, patch_config_paths, tmp_path, mock_socket_class, mock_select):
|
def test_restore_certificates_success(self, patch_config_paths, tmp_path, mock_socket_class, mock_select):
|
||||||
"""Restore certificates successfully."""
|
"""Restore certificates successfully."""
|
||||||
# Save config
|
# Save config
|
||||||
with open(patch_config_paths["certs_file"], "w") as f:
|
add_cert_to_config("example.com")
|
||||||
json.dump({"domains": ["example.com"]}, f)
|
|
||||||
|
|
||||||
# Create PEM
|
# Create PEM
|
||||||
certs_dir = tmp_path / "certs"
|
certs_dir = tmp_path / "certs"
|
||||||
@@ -283,11 +284,17 @@ class TestHaproxyCertInfo:
|
|||||||
pem_file = tmp_path / "example.com.pem"
|
pem_file = tmp_path / "example.com.pem"
|
||||||
pem_file.write_text("cert content")
|
pem_file.write_text("cert content")
|
||||||
|
|
||||||
mock_subprocess.return_value = MagicMock(
|
def subprocess_side_effect(*args, **kwargs):
|
||||||
returncode=0,
|
cmd = args[0] if args else kwargs.get("args", [])
|
||||||
stdout="subject=CN = example.com\nissuer=CN = Google Trust Services\nnotBefore=Jan 1 00:00:00 2024 GMT\nnotAfter=Apr 1 00:00:00 2024 GMT",
|
if isinstance(cmd, list) and "stat" in cmd:
|
||||||
stderr=""
|
return MagicMock(returncode=0, stdout="1704067200", stderr="")
|
||||||
)
|
return MagicMock(
|
||||||
|
returncode=0,
|
||||||
|
stdout="subject=CN = example.com\nissuer=CN = Google Trust Services\nnotBefore=Jan 1 00:00:00 2024 GMT\nnotAfter=Apr 1 00:00:00 2024 GMT",
|
||||||
|
stderr=""
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_subprocess.side_effect = subprocess_side_effect
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"show ssl cert": "/etc/haproxy/certs/example.com.pem",
|
"show ssl cert": "/etc/haproxy/certs/example.com.pem",
|
||||||
@@ -337,25 +344,33 @@ class TestHaproxyIssueCert:
|
|||||||
assert "Error" in result
|
assert "Error" in result
|
||||||
assert "Invalid domain" in result
|
assert "Invalid domain" in result
|
||||||
|
|
||||||
def test_issue_cert_no_cf_token(self, tmp_path):
|
def test_issue_cert_no_cf_token(self, tmp_path, mock_subprocess):
|
||||||
"""Fail when CF_Token is not set."""
|
"""Fail when CF_Token is not set."""
|
||||||
|
acme_sh = str(tmp_path / "acme.sh")
|
||||||
|
mock_subprocess.return_value = MagicMock(
|
||||||
|
returncode=1,
|
||||||
|
stdout="",
|
||||||
|
stderr="CF_Token is not set. Please export CF_Token environment variable.",
|
||||||
|
)
|
||||||
|
|
||||||
with patch.dict(os.environ, {}, clear=True):
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)):
|
with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)):
|
||||||
with patch("os.path.exists", return_value=False):
|
with patch("haproxy_mcp.tools.certificates.ACME_SH", acme_sh):
|
||||||
from haproxy_mcp.tools.certificates import register_certificate_tools
|
with patch("os.path.exists", return_value=False):
|
||||||
mcp = MagicMock()
|
from haproxy_mcp.tools.certificates import register_certificate_tools
|
||||||
registered_tools = {}
|
mcp = MagicMock()
|
||||||
|
registered_tools = {}
|
||||||
|
|
||||||
def capture_tool():
|
def capture_tool():
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
registered_tools[func.__name__] = func
|
registered_tools[func.__name__] = func
|
||||||
return func
|
return func
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
mcp.tool = capture_tool
|
mcp.tool = capture_tool
|
||||||
register_certificate_tools(mcp)
|
register_certificate_tools(mcp)
|
||||||
|
|
||||||
result = registered_tools["haproxy_issue_cert"](domain="example.com", wildcard=True)
|
result = registered_tools["haproxy_issue_cert"](domain="example.com", wildcard=True)
|
||||||
|
|
||||||
assert "CF_Token" in result
|
assert "CF_Token" in result
|
||||||
|
|
||||||
@@ -845,8 +860,8 @@ class TestHaproxyRenewAllCertsMultiple:
|
|||||||
def test_renew_all_certs_multiple_renewals(self, mock_subprocess, mock_socket_class, mock_select, patch_config_paths, tmp_path):
|
def test_renew_all_certs_multiple_renewals(self, mock_subprocess, mock_socket_class, mock_select, patch_config_paths, tmp_path):
|
||||||
"""Renew multiple certificates successfully."""
|
"""Renew multiple certificates successfully."""
|
||||||
# Write config with multiple domains
|
# Write config with multiple domains
|
||||||
with open(patch_config_paths["certs_file"], "w") as f:
|
add_cert_to_config("example.com")
|
||||||
json.dump({"domains": ["example.com", "example.org"]}, f)
|
add_cert_to_config("example.org")
|
||||||
|
|
||||||
# Create PEM files
|
# Create PEM files
|
||||||
certs_dir = tmp_path / "certs"
|
certs_dir = tmp_path / "certs"
|
||||||
@@ -1038,30 +1053,32 @@ class TestHaproxyDeleteCertPartialFailure:
|
|||||||
"show ssl cert": "", # Not loaded
|
"show ssl cert": "", # Not loaded
|
||||||
})
|
})
|
||||||
|
|
||||||
# Mock os.remove to fail
|
# Mock subprocess to succeed for acme.sh remove but fail for rm (PEM removal)
|
||||||
def mock_remove(path):
|
def subprocess_side_effect(*args, **kwargs):
|
||||||
if "example.com.pem" in str(path):
|
cmd = args[0] if args else kwargs.get("args", [])
|
||||||
raise PermissionError("Permission denied")
|
if isinstance(cmd, list) and cmd[0] == "rm":
|
||||||
raise FileNotFoundError()
|
return MagicMock(returncode=1, stdout="", stderr="Permission denied")
|
||||||
|
return MagicMock(returncode=0, stdout="", stderr="")
|
||||||
|
|
||||||
|
mock_subprocess.side_effect = subprocess_side_effect
|
||||||
|
|
||||||
with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path / "acme")):
|
with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path / "acme")):
|
||||||
with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)):
|
with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)):
|
||||||
with patch("socket.socket", return_value=mock_sock):
|
with patch("socket.socket", return_value=mock_sock):
|
||||||
with patch("os.remove", side_effect=mock_remove):
|
from haproxy_mcp.tools.certificates import register_certificate_tools
|
||||||
from haproxy_mcp.tools.certificates import register_certificate_tools
|
mcp = MagicMock()
|
||||||
mcp = MagicMock()
|
registered_tools = {}
|
||||||
registered_tools = {}
|
|
||||||
|
|
||||||
def capture_tool():
|
def capture_tool():
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
registered_tools[func.__name__] = func
|
registered_tools[func.__name__] = func
|
||||||
return func
|
return func
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
mcp.tool = capture_tool
|
mcp.tool = capture_tool
|
||||||
register_certificate_tools(mcp)
|
register_certificate_tools(mcp)
|
||||||
|
|
||||||
result = registered_tools["haproxy_delete_cert"](domain="example.com")
|
result = registered_tools["haproxy_delete_cert"](domain="example.com")
|
||||||
|
|
||||||
# Should report partial success (acme.sh deleted) and error (PEM failed)
|
# Should report partial success (acme.sh deleted) and error (PEM failed)
|
||||||
assert "Deleted" in result
|
assert "Deleted" in result
|
||||||
@@ -1118,8 +1135,8 @@ class TestRestoreCertificatesFailure:
|
|||||||
def test_restore_certificates_partial_failure(self, patch_config_paths, tmp_path, mock_socket_class, mock_select):
|
def test_restore_certificates_partial_failure(self, patch_config_paths, tmp_path, mock_socket_class, mock_select):
|
||||||
"""Handle partial failure when restoring certificates."""
|
"""Handle partial failure when restoring certificates."""
|
||||||
# Save config with multiple domains
|
# Save config with multiple domains
|
||||||
with open(patch_config_paths["certs_file"], "w") as f:
|
add_cert_to_config("example.com")
|
||||||
json.dump({"domains": ["example.com", "missing.com"]}, f)
|
add_cert_to_config("missing.com")
|
||||||
|
|
||||||
# Create only one PEM file
|
# Create only one PEM file
|
||||||
certs_dir = tmp_path / "certs"
|
certs_dir = tmp_path / "certs"
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ from unittest.mock import patch, MagicMock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from haproxy_mcp.file_ops import add_domain_to_map, add_server_to_config
|
||||||
|
|
||||||
|
|
||||||
class TestRestoreServersFromConfig:
|
class TestRestoreServersFromConfig:
|
||||||
"""Tests for restore_servers_from_config function."""
|
"""Tests for restore_servers_from_config function."""
|
||||||
@@ -19,12 +21,12 @@ class TestRestoreServersFromConfig:
|
|||||||
|
|
||||||
def test_restore_servers_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config):
|
def test_restore_servers_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config):
|
||||||
"""Restore servers successfully."""
|
"""Restore servers successfully."""
|
||||||
# Write config and map
|
# Add domains and servers to database
|
||||||
with open(patch_config_paths["servers_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
json.dump(sample_servers_config, f)
|
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_server_to_config("example.com", 2, "10.0.0.2", 80)
|
||||||
f.write("example.com pool_1\n")
|
add_domain_to_map("api.example.com", "pool_2")
|
||||||
f.write("api.example.com pool_2\n")
|
add_server_to_config("api.example.com", 1, "10.0.0.10", 8080)
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"set server": "",
|
"set server": "",
|
||||||
@@ -40,9 +42,8 @@ class TestRestoreServersFromConfig:
|
|||||||
|
|
||||||
def test_restore_servers_skip_missing_domain(self, mock_socket_class, mock_select, patch_config_paths):
|
def test_restore_servers_skip_missing_domain(self, mock_socket_class, mock_select, patch_config_paths):
|
||||||
"""Skip domains not in map file."""
|
"""Skip domains not in map file."""
|
||||||
config = {"unknown.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}
|
# Add server for unknown.com but no map entry (simulates missing domain)
|
||||||
with open(patch_config_paths["servers_file"], "w") as f:
|
add_server_to_config("unknown.com", 1, "10.0.0.1", 80)
|
||||||
json.dump(config, f)
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||||
|
|
||||||
@@ -55,11 +56,9 @@ class TestRestoreServersFromConfig:
|
|||||||
|
|
||||||
def test_restore_servers_skip_empty_ip(self, mock_socket_class, mock_select, patch_config_paths):
|
def test_restore_servers_skip_empty_ip(self, mock_socket_class, mock_select, patch_config_paths):
|
||||||
"""Skip servers with empty IP."""
|
"""Skip servers with empty IP."""
|
||||||
config = {"example.com": {"1": {"ip": "", "http_port": 80}}}
|
# Add domain to map and server with empty IP (will be skipped during restore)
|
||||||
with open(patch_config_paths["servers_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
json.dump(config, f)
|
add_server_to_config("example.com", 1, "", 80)
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||||
|
|
||||||
@@ -321,11 +320,12 @@ class TestHaproxyRestoreState:
|
|||||||
|
|
||||||
def test_restore_state_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config):
|
def test_restore_state_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config):
|
||||||
"""Restore state successfully."""
|
"""Restore state successfully."""
|
||||||
with open(patch_config_paths["servers_file"], "w") as f:
|
# Add domains and servers to database
|
||||||
json.dump(sample_servers_config, f)
|
add_domain_to_map("example.com", "pool_1")
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||||
f.write("example.com pool_1\n")
|
add_server_to_config("example.com", 2, "10.0.0.2", 80)
|
||||||
f.write("api.example.com pool_2\n")
|
add_domain_to_map("api.example.com", "pool_2")
|
||||||
|
add_server_to_config("api.example.com", 1, "10.0.0.10", 8080)
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||||
|
|
||||||
@@ -373,17 +373,10 @@ class TestRestoreServersFromConfigBatchFailure:
|
|||||||
|
|
||||||
def test_restore_servers_batch_failure_fallback(self, mock_socket_class, mock_select, patch_config_paths):
|
def test_restore_servers_batch_failure_fallback(self, mock_socket_class, mock_select, patch_config_paths):
|
||||||
"""Fall back to individual commands when batch fails."""
|
"""Fall back to individual commands when batch fails."""
|
||||||
# Create config with servers
|
# Add domain and servers to database
|
||||||
config = {
|
add_domain_to_map("example.com", "pool_1")
|
||||||
"example.com": {
|
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||||
"1": {"ip": "10.0.0.1", "http_port": 80},
|
add_server_to_config("example.com", 2, "10.0.0.2", 80)
|
||||||
"2": {"ip": "10.0.0.2", "http_port": 80},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
with open(patch_config_paths["servers_file"], "w") as f:
|
|
||||||
json.dump(config, f)
|
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
# Track call count to simulate batch failure then individual success
|
# Track call count to simulate batch failure then individual success
|
||||||
call_count = [0]
|
call_count = [0]
|
||||||
@@ -457,51 +450,51 @@ class TestRestoreServersFromConfigBatchFailure:
|
|||||||
|
|
||||||
def test_restore_servers_invalid_slot(self, mock_socket_class, mock_select, patch_config_paths):
|
def test_restore_servers_invalid_slot(self, mock_socket_class, mock_select, patch_config_paths):
|
||||||
"""Skip servers with invalid slot number."""
|
"""Skip servers with invalid slot number."""
|
||||||
|
# Add domain to map
|
||||||
|
add_domain_to_map("example.com", "pool_1")
|
||||||
|
|
||||||
|
# Mock load_servers_config to return config with invalid slot
|
||||||
config = {
|
config = {
|
||||||
"example.com": {
|
"example.com": {
|
||||||
"invalid": {"ip": "10.0.0.1", "http_port": 80}, # Invalid slot
|
"invalid": {"ip": "10.0.0.1", "http_port": 80}, # Invalid slot
|
||||||
"1": {"ip": "10.0.0.2", "http_port": 80}, # Valid slot
|
"1": {"ip": "10.0.0.2", "http_port": 80}, # Valid slot
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
with open(patch_config_paths["servers_file"], "w") as f:
|
with patch("haproxy_mcp.tools.configuration.load_servers_config", return_value=config):
|
||||||
json.dump(config, f)
|
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
with patch("socket.socket", return_value=mock_sock):
|
||||||
|
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
||||||
|
|
||||||
with patch("socket.socket", return_value=mock_sock):
|
result = restore_servers_from_config()
|
||||||
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
|
||||||
|
|
||||||
result = restore_servers_from_config()
|
# Should only restore the valid server
|
||||||
|
assert result == 1
|
||||||
# Should only restore the valid server
|
|
||||||
assert result == 1
|
|
||||||
|
|
||||||
def test_restore_servers_invalid_port(self, mock_socket_class, mock_select, patch_config_paths, caplog):
|
def test_restore_servers_invalid_port(self, mock_socket_class, mock_select, patch_config_paths, caplog):
|
||||||
"""Skip servers with invalid port."""
|
"""Skip servers with invalid port."""
|
||||||
import logging
|
import logging
|
||||||
|
# Add domain to map
|
||||||
|
add_domain_to_map("example.com", "pool_1")
|
||||||
|
|
||||||
|
# Mock load_servers_config to return config with invalid port
|
||||||
config = {
|
config = {
|
||||||
"example.com": {
|
"example.com": {
|
||||||
"1": {"ip": "10.0.0.1", "http_port": "invalid"}, # Invalid port
|
"1": {"ip": "10.0.0.1", "http_port": "invalid"}, # Invalid port
|
||||||
"2": {"ip": "10.0.0.2", "http_port": 80}, # Valid port
|
"2": {"ip": "10.0.0.2", "http_port": 80}, # Valid port
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
with open(patch_config_paths["servers_file"], "w") as f:
|
with patch("haproxy_mcp.tools.configuration.load_servers_config", return_value=config):
|
||||||
json.dump(config, f)
|
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
with patch("socket.socket", return_value=mock_sock):
|
||||||
|
with caplog.at_level(logging.WARNING, logger="haproxy_mcp"):
|
||||||
|
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
||||||
|
|
||||||
with patch("socket.socket", return_value=mock_sock):
|
result = restore_servers_from_config()
|
||||||
with caplog.at_level(logging.WARNING, logger="haproxy_mcp"):
|
|
||||||
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
|
||||||
|
|
||||||
result = restore_servers_from_config()
|
# Should only restore the valid server
|
||||||
|
assert result == 1
|
||||||
# Should only restore the valid server
|
|
||||||
assert result == 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestStartupRestoreFailures:
|
class TestStartupRestoreFailures:
|
||||||
@@ -658,11 +651,12 @@ class TestHaproxyRestoreStateFailures:
|
|||||||
"""Handle HAProxy error when restoring state."""
|
"""Handle HAProxy error when restoring state."""
|
||||||
from haproxy_mcp.exceptions import HaproxyError
|
from haproxy_mcp.exceptions import HaproxyError
|
||||||
|
|
||||||
with open(patch_config_paths["servers_file"], "w") as f:
|
# Add domains and servers to database
|
||||||
json.dump(sample_servers_config, f)
|
add_domain_to_map("example.com", "pool_1")
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||||
f.write("example.com pool_1\n")
|
add_server_to_config("example.com", 2, "10.0.0.2", 80)
|
||||||
f.write("api.example.com pool_2\n")
|
add_domain_to_map("api.example.com", "pool_2")
|
||||||
|
add_server_to_config("api.example.com", 1, "10.0.0.10", 8080)
|
||||||
|
|
||||||
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=HaproxyError("Connection refused")):
|
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=HaproxyError("Connection refused")):
|
||||||
from haproxy_mcp.tools.configuration import register_config_tools
|
from haproxy_mcp.tools.configuration import register_config_tools
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from haproxy_mcp.exceptions import HaproxyError
|
from haproxy_mcp.exceptions import HaproxyError
|
||||||
|
from haproxy_mcp.file_ops import add_domain_to_map
|
||||||
|
|
||||||
|
|
||||||
class TestHaproxyListDomains:
|
class TestHaproxyListDomains:
|
||||||
@@ -38,9 +39,8 @@ class TestHaproxyListDomains:
|
|||||||
|
|
||||||
def test_list_domains_with_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_list_domains_with_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""List domains with configured servers."""
|
"""List domains with configured servers."""
|
||||||
# Write map file
|
# Add domain to DB
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"show servers state": response_builder.servers_state([
|
"show servers state": response_builder.servers_state([
|
||||||
@@ -70,10 +70,8 @@ class TestHaproxyListDomains:
|
|||||||
|
|
||||||
def test_list_domains_exclude_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_list_domains_exclude_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""List domains excluding wildcards by default."""
|
"""List domains excluding wildcards by default."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
|
||||||
f.write(".example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"show servers state": response_builder.servers_state([]),
|
"show servers state": response_builder.servers_state([]),
|
||||||
@@ -100,10 +98,8 @@ class TestHaproxyListDomains:
|
|||||||
|
|
||||||
def test_list_domains_include_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_list_domains_include_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""List domains including wildcards when requested."""
|
"""List domains including wildcards when requested."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
|
||||||
f.write(".example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"show servers state": response_builder.servers_state([]),
|
"show servers state": response_builder.servers_state([]),
|
||||||
@@ -230,8 +226,7 @@ class TestHaproxyAddDomain:
|
|||||||
|
|
||||||
def test_add_domain_already_exists(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_add_domain_already_exists(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""Reject adding domain that already exists."""
|
"""Reject adding domain that already exists."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
from haproxy_mcp.tools.domains import register_domain_tools
|
from haproxy_mcp.tools.domains import register_domain_tools
|
||||||
mcp = MagicMock()
|
mcp = MagicMock()
|
||||||
@@ -362,8 +357,7 @@ class TestHaproxyRemoveDomain:
|
|||||||
|
|
||||||
def test_remove_legacy_domain_rejected(self, patch_config_paths):
|
def test_remove_legacy_domain_rejected(self, patch_config_paths):
|
||||||
"""Reject removing legacy (non-pool) domain."""
|
"""Reject removing legacy (non-pool) domain."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "legacy_backend")
|
||||||
f.write("example.com legacy_backend\n")
|
|
||||||
|
|
||||||
from haproxy_mcp.tools.domains import register_domain_tools
|
from haproxy_mcp.tools.domains import register_domain_tools
|
||||||
mcp = MagicMock()
|
mcp = MagicMock()
|
||||||
@@ -385,10 +379,8 @@ class TestHaproxyRemoveDomain:
|
|||||||
|
|
||||||
def test_remove_domain_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_remove_domain_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""Successfully remove domain."""
|
"""Successfully remove domain."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
|
||||||
f.write(".example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"del map": "",
|
"del map": "",
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from haproxy_mcp.exceptions import HaproxyError
|
from haproxy_mcp.exceptions import HaproxyError
|
||||||
|
from haproxy_mcp.file_ops import add_domain_to_map
|
||||||
|
|
||||||
|
|
||||||
class TestHaproxyHealth:
|
class TestHaproxyHealth:
|
||||||
@@ -80,7 +81,7 @@ class TestHaproxyHealth:
|
|||||||
|
|
||||||
# Use paths that don't exist
|
# Use paths that don't exist
|
||||||
with patch("haproxy_mcp.tools.health.MAP_FILE", str(tmp_path / "nonexistent.map")):
|
with patch("haproxy_mcp.tools.health.MAP_FILE", str(tmp_path / "nonexistent.map")):
|
||||||
with patch("haproxy_mcp.tools.health.SERVERS_FILE", str(tmp_path / "nonexistent.json")):
|
with patch("haproxy_mcp.tools.health.DB_FILE", str(tmp_path / "nonexistent.db")):
|
||||||
with patch("socket.socket", return_value=mock_sock):
|
with patch("socket.socket", return_value=mock_sock):
|
||||||
from haproxy_mcp.tools.health import register_health_tools
|
from haproxy_mcp.tools.health import register_health_tools
|
||||||
mcp = MagicMock()
|
mcp = MagicMock()
|
||||||
@@ -160,8 +161,7 @@ class TestHaproxyDomainHealth:
|
|||||||
|
|
||||||
def test_domain_health_healthy(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_domain_health_healthy(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""Domain health returns healthy when all servers are UP."""
|
"""Domain health returns healthy when all servers are UP."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"show servers state": response_builder.servers_state([
|
"show servers state": response_builder.servers_state([
|
||||||
@@ -197,8 +197,7 @@ class TestHaproxyDomainHealth:
|
|||||||
|
|
||||||
def test_domain_health_degraded(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_domain_health_degraded(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""Domain health returns degraded when some servers are DOWN."""
|
"""Domain health returns degraded when some servers are DOWN."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"show servers state": response_builder.servers_state([
|
"show servers state": response_builder.servers_state([
|
||||||
@@ -234,8 +233,7 @@ class TestHaproxyDomainHealth:
|
|||||||
|
|
||||||
def test_domain_health_down(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_domain_health_down(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""Domain health returns down when all servers are DOWN."""
|
"""Domain health returns down when all servers are DOWN."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"show servers state": response_builder.servers_state([
|
"show servers state": response_builder.servers_state([
|
||||||
@@ -269,8 +267,7 @@ class TestHaproxyDomainHealth:
|
|||||||
|
|
||||||
def test_domain_health_no_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_domain_health_no_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""Domain health returns no_servers when no servers configured."""
|
"""Domain health returns no_servers when no servers configured."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"show servers state": response_builder.servers_state([
|
"show servers state": response_builder.servers_state([
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from haproxy_mcp.exceptions import HaproxyError
|
from haproxy_mcp.exceptions import HaproxyError
|
||||||
|
from haproxy_mcp.file_ops import add_domain_to_map, load_servers_config
|
||||||
|
|
||||||
|
|
||||||
class TestHaproxyListServers:
|
class TestHaproxyListServers:
|
||||||
@@ -33,8 +34,7 @@ class TestHaproxyListServers:
|
|||||||
|
|
||||||
def test_list_servers_empty_backend(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_list_servers_empty_backend(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""List servers for domain with no servers."""
|
"""List servers for domain with no servers."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"show servers state": response_builder.servers_state([
|
"show servers state": response_builder.servers_state([
|
||||||
@@ -63,8 +63,7 @@ class TestHaproxyListServers:
|
|||||||
|
|
||||||
def test_list_servers_with_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_list_servers_with_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""List servers with active servers."""
|
"""List servers with active servers."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"show servers state": response_builder.servers_state([
|
"show servers state": response_builder.servers_state([
|
||||||
@@ -224,8 +223,7 @@ class TestHaproxyAddServer:
|
|||||||
|
|
||||||
def test_add_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_add_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""Successfully add server."""
|
"""Successfully add server."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"set server": "",
|
"set server": "",
|
||||||
@@ -258,8 +256,7 @@ class TestHaproxyAddServer:
|
|||||||
|
|
||||||
def test_add_server_auto_slot(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_add_server_auto_slot(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""Auto-select slot when slot=0."""
|
"""Auto-select slot when slot=0."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"show servers state": response_builder.servers_state([
|
"show servers state": response_builder.servers_state([
|
||||||
@@ -413,8 +410,7 @@ class TestHaproxyAddServers:
|
|||||||
|
|
||||||
def test_add_servers_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_add_servers_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""Successfully add multiple servers."""
|
"""Successfully add multiple servers."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"set server": "",
|
"set server": "",
|
||||||
@@ -495,8 +491,7 @@ class TestHaproxyRemoveServer:
|
|||||||
|
|
||||||
def test_remove_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_remove_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""Successfully remove server."""
|
"""Successfully remove server."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"set server": "",
|
"set server": "",
|
||||||
@@ -689,8 +684,7 @@ class TestHaproxyAddServersRollback:
|
|||||||
|
|
||||||
def test_add_servers_partial_failure_rollback(self, mock_socket_class, mock_select, patch_config_paths):
|
def test_add_servers_partial_failure_rollback(self, mock_socket_class, mock_select, patch_config_paths):
|
||||||
"""Rollback only failed slots when HAProxy error occurs."""
|
"""Rollback only failed slots when HAProxy error occurs."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
# Mock configure_server_slot to fail on second slot
|
# Mock configure_server_slot to fail on second slot
|
||||||
call_count = [0]
|
call_count = [0]
|
||||||
@@ -735,18 +729,16 @@ class TestHaproxyAddServersRollback:
|
|||||||
assert "slot 2" in result # Failed
|
assert "slot 2" in result # Failed
|
||||||
|
|
||||||
# Verify servers.json only has successfully added server
|
# Verify servers.json only has successfully added server
|
||||||
with open(patch_config_paths["servers_file"], "r") as f:
|
config = load_servers_config()
|
||||||
config = json.load(f)
|
|
||||||
assert "example.com" in config
|
assert "example.com" in config
|
||||||
assert "1" in config["example.com"] # Successfully added stays
|
assert "1" in config["example.com"] # Successfully added stays
|
||||||
assert "2" not in config["example.com"] # Failed one was rolled back
|
assert "2" not in config.get("example.com", {}) # Failed one was rolled back
|
||||||
|
|
||||||
def test_add_servers_unexpected_error_rollback_only_successful(
|
def test_add_servers_unexpected_error_rollback_only_successful(
|
||||||
self, mock_socket_class, mock_select, patch_config_paths
|
self, mock_socket_class, mock_select, patch_config_paths
|
||||||
):
|
):
|
||||||
"""Rollback only successfully added servers on unexpected error."""
|
"""Rollback only successfully added servers on unexpected error."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
# Track which servers were configured
|
# Track which servers were configured
|
||||||
configured_slots = []
|
configured_slots = []
|
||||||
@@ -793,8 +785,7 @@ class TestHaproxyAddServersRollback:
|
|||||||
assert "Unexpected system error" in result
|
assert "Unexpected system error" in result
|
||||||
|
|
||||||
# Verify servers.json is empty (all rolled back)
|
# Verify servers.json is empty (all rolled back)
|
||||||
with open(patch_config_paths["servers_file"], "r") as f:
|
config = load_servers_config()
|
||||||
config = json.load(f)
|
|
||||||
assert config == {} or "example.com" not in config or config.get("example.com") == {}
|
assert config == {} or "example.com" not in config or config.get("example.com") == {}
|
||||||
|
|
||||||
def test_add_servers_rollback_failure_logged(
|
def test_add_servers_rollback_failure_logged(
|
||||||
@@ -802,8 +793,7 @@ class TestHaproxyAddServersRollback:
|
|||||||
):
|
):
|
||||||
"""Log rollback failures during error recovery."""
|
"""Log rollback failures during error recovery."""
|
||||||
import logging
|
import logging
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
|
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
|
||||||
if slot == 2:
|
if slot == 2:
|
||||||
@@ -858,8 +848,7 @@ class TestHaproxyAddServerAutoSlot:
|
|||||||
|
|
||||||
def test_add_server_auto_slot_all_used(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_add_server_auto_slot_all_used(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""Auto-select slot fails when all slots are in use."""
|
"""Auto-select slot fails when all slots are in use."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
# Build response with all 10 slots used
|
# Build response with all 10 slots used
|
||||||
servers = []
|
servers = []
|
||||||
@@ -902,8 +891,7 @@ class TestHaproxyAddServerAutoSlot:
|
|||||||
|
|
||||||
def test_add_server_negative_slot_auto_select(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_add_server_negative_slot_auto_select(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""Negative slot number triggers auto-selection."""
|
"""Negative slot number triggers auto-selection."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"show servers state": response_builder.servers_state([
|
"show servers state": response_builder.servers_state([
|
||||||
@@ -970,8 +958,7 @@ class TestHaproxyWaitDrain:
|
|||||||
|
|
||||||
def test_wait_drain_success(self, patch_config_paths):
|
def test_wait_drain_success(self, patch_config_paths):
|
||||||
"""Successfully wait for connections to drain."""
|
"""Successfully wait for connections to drain."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
# Mock haproxy_cmd to return 0 connections
|
# Mock haproxy_cmd to return 0 connections
|
||||||
with patch("haproxy_mcp.tools.servers.haproxy_cmd") as mock_cmd:
|
with patch("haproxy_mcp.tools.servers.haproxy_cmd") as mock_cmd:
|
||||||
@@ -1000,8 +987,7 @@ class TestHaproxyWaitDrain:
|
|||||||
|
|
||||||
def test_wait_drain_timeout(self, patch_config_paths):
|
def test_wait_drain_timeout(self, patch_config_paths):
|
||||||
"""Timeout when connections don't drain."""
|
"""Timeout when connections don't drain."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
time_values = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0] # Simulate time passing
|
time_values = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0] # Simulate time passing
|
||||||
time_iter = iter(time_values)
|
time_iter = iter(time_values)
|
||||||
@@ -1087,9 +1073,6 @@ class TestHaproxyWaitDrain:
|
|||||||
def test_wait_drain_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths):
|
def test_wait_drain_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths):
|
||||||
"""Error when domain not found in map."""
|
"""Error when domain not found in map."""
|
||||||
# Empty map file - domain not configured
|
# Empty map file - domain not configured
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
|
||||||
f.write("")
|
|
||||||
|
|
||||||
from haproxy_mcp.tools.servers import register_server_tools
|
from haproxy_mcp.tools.servers import register_server_tools
|
||||||
mcp = MagicMock()
|
mcp = MagicMock()
|
||||||
registered_tools = {}
|
registered_tools = {}
|
||||||
@@ -1202,8 +1185,7 @@ class TestHaproxySetDomainState:
|
|||||||
|
|
||||||
def test_set_domain_state_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_set_domain_state_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""Set all servers of a domain to a state."""
|
"""Set all servers of a domain to a state."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"show servers state": response_builder.servers_state([
|
"show servers state": response_builder.servers_state([
|
||||||
@@ -1283,8 +1265,7 @@ class TestHaproxySetDomainState:
|
|||||||
|
|
||||||
def test_set_domain_state_no_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_set_domain_state_no_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""No active servers found for domain."""
|
"""No active servers found for domain."""
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
add_domain_to_map("example.com", "pool_1")
|
||||||
f.write("example.com pool_1\n")
|
|
||||||
|
|
||||||
# All servers have 0.0.0.0 address (not configured)
|
# All servers have 0.0.0.0 address (not configured)
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
@@ -1318,9 +1299,6 @@ class TestHaproxySetDomainState:
|
|||||||
def test_set_domain_state_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
def test_set_domain_state_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||||
"""Handle domain not found in map - shows no active servers."""
|
"""Handle domain not found in map - shows no active servers."""
|
||||||
# Empty map file
|
# Empty map file
|
||||||
with open(patch_config_paths["map_file"], "w") as f:
|
|
||||||
f.write("")
|
|
||||||
|
|
||||||
# Mock should show no servers for unknown domain's backend
|
# Mock should show no servers for unknown domain's backend
|
||||||
mock_sock = mock_socket_class(responses={
|
mock_sock = mock_socket_class(responses={
|
||||||
"show servers state": response_builder.servers_state([]),
|
"show servers state": response_builder.servers_state([]),
|
||||||
|
|||||||
Reference in New Issue
Block a user