refactor: migrate data storage from JSON/map files to SQLite

Replace servers.json, certificates.json, and map file parsing with
SQLite (WAL mode) as single source of truth. HAProxy map files are
now generated from SQLite via sync_map_files().

Key changes:
- Add db.py with schema, connection management, and JSON migration
- Add DB_FILE config constant
- Delegate file_ops.py functions to db.py
- Refactor domains.py to use file_ops instead of direct list manipulation
- Fix subprocess.TimeoutExpired not caught (doesn't inherit TimeoutError)
- Add DB health check in health.py
- Init DB on startup in server.py and __main__.py
- Update all 359 tests to use SQLite-backed functions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
kappa
2026-02-08 11:07:29 +09:00
parent 05bff61b85
commit cf554f3f89
19 changed files with 1525 additions and 564 deletions

View File

@@ -1,8 +1,10 @@
"""Entry point for running haproxy_mcp as a module."""
from .db import init_db
from .server import mcp
from .tools.configuration import startup_restore
if __name__ == "__main__":
init_db()
startup_restore()
mcp.run(transport="streamable-http")

View File

@@ -31,6 +31,7 @@ WILDCARDS_MAP_FILE: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE", "/opt/haproxy/
WILDCARDS_MAP_FILE_CONTAINER: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE_CONTAINER", "/usr/local/etc/haproxy/wildcards.map")
SERVERS_FILE: str = os.getenv("HAPROXY_SERVERS_FILE", "/opt/haproxy/conf/servers.json")
CERTS_FILE: str = os.getenv("HAPROXY_CERTS_FILE", "/opt/haproxy/conf/certificates.json")
DB_FILE: str = os.getenv("HAPROXY_DB_FILE", "/opt/haproxy/conf/haproxy_mcp.db")
# Certificate paths
CERTS_DIR: str = os.getenv("HAPROXY_CERTS_DIR", "/opt/haproxy/certs")

577
haproxy_mcp/db.py Normal file
View File

@@ -0,0 +1,577 @@
"""SQLite database operations for HAProxy MCP Server.
Single source of truth for all persistent configuration data.
Replaces JSON files (servers.json, certificates.json) and map files
(domains.map, wildcards.map) with a single SQLite database.
Map files are generated from the database via sync_map_files() for
HAProxy to consume directly.
"""
import json
import os
import sqlite3
import threading
from typing import Any, Optional
from .config import (
DB_FILE,
MAP_FILE,
WILDCARDS_MAP_FILE,
SERVERS_FILE,
CERTS_FILE,
POOL_COUNT,
REMOTE_MODE,
logger,
)
SCHEMA_VERSION = 1
_local = threading.local()
def get_connection() -> sqlite3.Connection:
    """Return this thread's cached SQLite connection, creating it on demand.

    Each thread gets its own connection (stored in thread-local storage),
    configured for concurrent access: WAL journaling, a 5-second busy
    timeout, and enforced foreign keys.  Rows come back as sqlite3.Row
    so columns can be read by name.

    Returns:
        sqlite3.Connection configured for concurrent access.
    """
    cached = getattr(_local, "connection", None)
    if cached is None:
        cached = sqlite3.connect(DB_FILE, timeout=10)
        cached.execute("PRAGMA journal_mode=WAL")
        cached.execute("PRAGMA busy_timeout=5000")
        cached.execute("PRAGMA foreign_keys=ON")
        cached.row_factory = sqlite3.Row
        _local.connection = cached
    return cached
def close_connection() -> None:
    """Close and discard the thread-local SQLite connection.

    Safe to call when no connection exists for this thread.  The
    thread-local reference is cleared in a finally block so that even if
    close() raises, a subsequent get_connection() opens a fresh
    connection instead of reusing a broken one.
    """
    conn = getattr(_local, "connection", None)
    if conn is not None:
        try:
            conn.close()
        finally:
            # Always drop the cached reference, even on close() failure.
            _local.connection = None
def init_db() -> None:
    """Initialize database schema and run migration if needed.

    Creates tables if they don't exist, then checks for existing
    JSON/map files to migrate data from.

    Idempotent: safe to call at every process startup.  Migration only
    runs when the recorded schema version is 0 (a brand-new database).
    """
    conn = get_connection()
    cur = conn.cursor()
    # Schema notes:
    # - domains.domain is UNIQUE: one backend per name (wildcards are
    #   stored with a leading dot, so exact and wildcard entries never collide).
    # - servers has UNIQUE(domain, slot): at most one server per slot.
    # - created_at/updated_at default to ISO-8601 UTC strings via strftime.
    cur.executescript("""
    CREATE TABLE IF NOT EXISTS domains (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        domain TEXT NOT NULL UNIQUE,
        backend TEXT NOT NULL,
        is_wildcard INTEGER NOT NULL DEFAULT 0,
        shares_with TEXT DEFAULT NULL,
        created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
        updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
    );
    CREATE INDEX IF NOT EXISTS idx_domains_backend ON domains(backend);
    CREATE INDEX IF NOT EXISTS idx_domains_shares_with ON domains(shares_with);
    CREATE TABLE IF NOT EXISTS servers (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        domain TEXT NOT NULL,
        slot INTEGER NOT NULL,
        ip TEXT NOT NULL,
        http_port INTEGER NOT NULL DEFAULT 80,
        created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
        updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
        UNIQUE(domain, slot)
    );
    CREATE INDEX IF NOT EXISTS idx_servers_domain ON servers(domain);
    CREATE TABLE IF NOT EXISTS certificates (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        domain TEXT NOT NULL UNIQUE,
        created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
        updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
    );
    CREATE TABLE IF NOT EXISTS schema_version (
        version INTEGER NOT NULL,
        applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
    );
    """)
    conn.commit()
    # Check schema version.  MAX(version) on an empty table yields a row
    # containing NULL, hence the "is not None" guard.
    cur.execute("SELECT MAX(version) FROM schema_version")
    row = cur.fetchone()
    current_version = row[0] if row[0] is not None else 0
    if current_version < SCHEMA_VERSION:
        # First-time setup: try migrating from JSON files
        if current_version == 0:
            migrate_from_json()
        cur.execute("INSERT INTO schema_version (version) VALUES (?)", (SCHEMA_VERSION,))
        conn.commit()
    logger.info("Database initialized (schema v%d)", SCHEMA_VERSION)
def migrate_from_json() -> None:
    """Migrate data from JSON/map files to SQLite.

    Reads domains.map, wildcards.map, servers.json, and certificates.json,
    imports their data into the database within a single transaction,
    then renames the original files to .bak.

    This function is idempotent (uses INSERT OR IGNORE).  Missing source
    files are skipped; corrupt JSON is logged and skipped rather than
    aborting the migration.
    """
    conn = get_connection()
    # Collect data from existing files
    map_entries: list[tuple[str, str]] = []
    servers_config: dict[str, Any] = {}
    cert_domains: list[str] = []
    # Read map files: "domain backend" pairs, ignoring blanks and comments.
    for map_path in [MAP_FILE, WILDCARDS_MAP_FILE]:
        try:
            content = _read_file_for_migration(map_path)
            for line in content.splitlines():
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                parts = line.split()
                if len(parts) >= 2:
                    map_entries.append((parts[0], parts[1]))
        except FileNotFoundError:
            logger.debug("Map file not found for migration: %s", map_path)
    # Read servers.json
    try:
        content = _read_file_for_migration(SERVERS_FILE)
        servers_config = json.loads(content)
    except FileNotFoundError:
        logger.debug("Servers config not found for migration: %s", SERVERS_FILE)
    except json.JSONDecodeError as e:
        logger.warning("Corrupt servers config during migration: %s", e)
    # Read certificates.json
    try:
        content = _read_file_for_migration(CERTS_FILE)
        data = json.loads(content)
        cert_domains = data.get("domains", [])
    except FileNotFoundError:
        logger.debug("Certs config not found for migration: %s", CERTS_FILE)
    except json.JSONDecodeError as e:
        logger.warning("Corrupt certs config during migration: %s", e)
    if not map_entries and not servers_config and not cert_domains:
        logger.info("No existing data to migrate")
        return
    # Import into database in a single transaction ("with conn" commits on
    # success, rolls back on exception).
    try:
        with conn:
            # Import domains from map files; a leading dot marks a wildcard.
            for domain, backend in map_entries:
                is_wildcard = 1 if domain.startswith(".") else 0
                conn.execute(
                    "INSERT OR IGNORE INTO domains (domain, backend, is_wildcard) VALUES (?, ?, ?)",
                    (domain, backend, is_wildcard),
                )
            # Import servers and shares_with from servers.json.  Entries with
            # a "_shares" key are pool-sharing references, not server lists.
            for domain, slots in servers_config.items():
                shares_with = slots.get("_shares")
                if shares_with:
                    # Update domain's shares_with field
                    conn.execute(
                        "UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0",
                        (shares_with, domain),
                    )
                    continue
                for slot_str, server_info in slots.items():
                    # Keys starting with "_" are metadata, not slot numbers.
                    if slot_str.startswith("_"):
                        continue
                    try:
                        slot = int(slot_str)
                        ip = server_info.get("ip", "")
                        http_port = int(server_info.get("http_port", 80))
                        if ip:
                            conn.execute(
                                "INSERT OR IGNORE INTO servers (domain, slot, ip, http_port) VALUES (?, ?, ?, ?)",
                                (domain, slot, ip, http_port),
                            )
                    except (ValueError, TypeError) as e:
                        logger.warning("Skipping invalid server entry %s/%s: %s", domain, slot_str, e)
            # Import certificates
            for cert_domain in cert_domains:
                conn.execute(
                    "INSERT OR IGNORE INTO certificates (domain) VALUES (?)",
                    (cert_domain,),
                )
        # NOTE: these counts are approximate — they count source entries
        # seen, not rows actually inserted (INSERT OR IGNORE may skip some).
        migrated_domains = len(set(d for d, _ in map_entries))
        migrated_servers = sum(
            1 for d, slots in servers_config.items()
            if "_shares" not in slots
            for s in slots
            if not s.startswith("_")
        )
        migrated_certs = len(cert_domains)
        logger.info(
            "Migrated from JSON: %d domain entries, %d servers, %d certificates",
            migrated_domains, migrated_servers, migrated_certs,
        )
        # Rename original files to .bak so migration won't re-run on them.
        for path in [SERVERS_FILE, CERTS_FILE]:
            _backup_file(path)
    except sqlite3.Error as e:
        logger.error("Migration failed: %s", e)
        raise
def _read_file_for_migration(file_path: str) -> str:
    """Return the raw contents of ``file_path`` for a one-shot migration read.

    In remote mode the file is fetched over SSH; otherwise it is read
    locally.  No locking is used since migration runs once at startup.

    Raises:
        FileNotFoundError: If the file does not exist (local mode).
    """
    if REMOTE_MODE:
        from .ssh_ops import remote_read_file
        return remote_read_file(file_path)
    with open(file_path, "r", encoding="utf-8") as handle:
        return handle.read()
def _backup_file(file_path: str) -> None:
    """Rename ``file_path`` to ``file_path.bak`` when it exists locally.

    Remote files are left untouched.  Failures are logged as warnings,
    never raised — a missed backup must not abort migration.
    """
    if REMOTE_MODE:
        # Don't rename remote files
        return
    bak_path = f"{file_path}.bak"
    try:
        if os.path.exists(file_path):
            os.rename(file_path, bak_path)
            logger.info("Backed up %s -> %s", file_path, bak_path)
    except OSError as e:
        logger.warning("Failed to backup %s: %s", file_path, e)
# --- Domain data access ---
def db_get_map_contents() -> list[tuple[str, str]]:
    """Return every (domain, backend) mapping, wildcards included.

    Returns:
        List of (domain, backend) tuples, ordered by domain name.
    """
    rows = get_connection().execute(
        "SELECT domain, backend FROM domains ORDER BY domain"
    ).fetchall()
    return [(r["domain"], r["backend"]) for r in rows]
def db_get_domain_backend(domain: str) -> Optional[str]:
    """Return the backend mapped to ``domain``, or None when unmapped.

    Args:
        domain: Domain name to look up.
    """
    row = get_connection().execute(
        "SELECT backend FROM domains WHERE domain = ?", (domain,)
    ).fetchone()
    return None if row is None else row["backend"]
def db_add_domain(domain: str, backend: str, is_wildcard: bool = False,
                  shares_with: Optional[str] = None) -> None:
    """Add a domain to the database (upsert on domain name).

    On conflict the existing row's backend and shares_with are replaced
    and updated_at is refreshed.

    NOTE(review): is_wildcard is NOT updated on conflict.  Presumably
    harmless because wildcard names carry a leading dot and so never
    collide with exact names — confirm.

    Args:
        domain: Domain name (e.g., "example.com" or ".example.com").
        backend: Backend pool name (e.g., "pool_5").
        is_wildcard: Whether this is a wildcard entry.
        shares_with: Parent domain if sharing a pool.
    """
    conn = get_connection()
    conn.execute(
        """INSERT INTO domains (domain, backend, is_wildcard, shares_with)
        VALUES (?, ?, ?, ?)
        ON CONFLICT(domain) DO UPDATE SET
            backend = excluded.backend,
            shares_with = excluded.shares_with,
            updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""",
        (domain, backend, 1 if is_wildcard else 0, shares_with),
    )
    conn.commit()
def db_remove_domain(domain: str) -> None:
    """Delete both the exact entry and its wildcard twin for a domain.

    Args:
        domain: Base domain name (without leading dot); the ".domain"
            wildcard row is removed as well.
    """
    conn = get_connection()
    conn.execute(
        "DELETE FROM domains WHERE domain = ? OR domain = ?",
        (domain, f".{domain}"),
    )
    conn.commit()
def db_find_available_pool() -> Optional[str]:
    """Return the lowest-numbered pool not assigned to any exact domain.

    Scans pool_1 .. pool_POOL_COUNT against the set of backends already
    referenced by non-wildcard domain rows.

    Returns:
        Pool name (e.g., "pool_5") or None if all pools are in use.
    """
    cur = get_connection().execute(
        "SELECT DISTINCT backend FROM domains WHERE backend LIKE 'pool_%' AND is_wildcard = 0"
    )
    taken = {r["backend"] for r in cur}
    return next(
        (f"pool_{i}" for i in range(1, POOL_COUNT + 1) if f"pool_{i}" not in taken),
        None,
    )
def db_get_domains_sharing_pool(pool: str) -> list[str]:
    """List every non-wildcard domain whose backend is ``pool``.

    Args:
        pool: Pool name (e.g., "pool_5").

    Returns:
        List of domain names.
    """
    rows = get_connection().execute(
        "SELECT domain FROM domains WHERE backend = ? AND is_wildcard = 0",
        (pool,),
    )
    return [r["domain"] for r in rows]
# --- Server data access ---
def db_load_servers_config() -> dict[str, Any]:
    """Build the legacy servers.json-shaped dict from the database.

    Returns:
        Dictionary compatible with the old servers.json format:
        {"domain": {"slot": {"ip": "...", "http_port": N}, "_shares": "parent"}}
    """
    conn = get_connection()
    config: dict[str, Any] = {}
    # Server rows become per-domain slot dicts.
    server_rows = conn.execute(
        "SELECT domain, slot, ip, http_port FROM servers ORDER BY domain, slot"
    )
    for r in server_rows:
        slots = config.setdefault(r["domain"], {})
        slots[str(r["slot"])] = {"ip": r["ip"], "http_port": r["http_port"]}
    # Pool-sharing references become "_shares" markers.
    share_rows = conn.execute(
        "SELECT domain, shares_with FROM domains WHERE shares_with IS NOT NULL AND is_wildcard = 0"
    )
    for r in share_rows:
        config.setdefault(r["domain"], {})["_shares"] = r["shares_with"]
    return config
def db_add_server(domain: str, slot: int, ip: str, http_port: int) -> None:
    """Add or update a server entry (upsert on (domain, slot)).

    On conflict the existing row's ip/http_port are replaced and
    updated_at is refreshed.

    Args:
        domain: Domain name.
        slot: Server slot number.
        ip: Server IP address.
        http_port: HTTP port.
    """
    conn = get_connection()
    conn.execute(
        """INSERT INTO servers (domain, slot, ip, http_port)
        VALUES (?, ?, ?, ?)
        ON CONFLICT(domain, slot) DO UPDATE SET
            ip = excluded.ip,
            http_port = excluded.http_port,
            updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""",
        (domain, slot, ip, http_port),
    )
    conn.commit()
def db_remove_server(domain: str, slot: int) -> None:
    """Delete one server row identified by (domain, slot).

    Args:
        domain: Domain name.
        slot: Server slot number.
    """
    connection = get_connection()
    connection.execute(
        "DELETE FROM servers WHERE domain = ? AND slot = ?",
        (domain, slot),
    )
    connection.commit()
def db_remove_domain_servers(domain: str) -> None:
    """Delete every server row belonging to ``domain``.

    Args:
        domain: Domain name.
    """
    connection = get_connection()
    connection.execute("DELETE FROM servers WHERE domain = ?", (domain,))
    connection.commit()
def db_add_shared_domain(domain: str, shares_with: str) -> None:
    """Mark a domain as sharing a pool with another domain.

    Updates the shares_with field in the domains table.

    NOTE(review): this is a bare UPDATE, so it silently does nothing if
    ``domain`` has no row in domains yet — the legacy JSON code path
    created the entry.  Presumably callers insert the domain first via
    db_add_domain; confirm against callers.

    Args:
        domain: The new domain.
        shares_with: The existing domain to share with.
    """
    conn = get_connection()
    conn.execute(
        "UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0",
        (shares_with, domain),
    )
    conn.commit()
def db_get_shared_domain(domain: str) -> Optional[str]:
    """Return the parent domain this domain shares a pool with, if any.

    Args:
        domain: Domain name.

    Returns:
        Parent domain name if sharing, None otherwise (including when
        the row exists but shares_with is empty/NULL).
    """
    row = get_connection().execute(
        "SELECT shares_with FROM domains WHERE domain = ? AND is_wildcard = 0",
        (domain,),
    ).fetchone()
    return (row["shares_with"] or None) if row else None
def db_is_shared_domain(domain: str) -> bool:
    """Tell whether ``domain`` borrows another domain's pool.

    Args:
        domain: Domain name.

    Returns:
        True if the domain carries a shares_with reference.
    """
    shared_parent = db_get_shared_domain(domain)
    return shared_parent is not None
# --- Certificate data access ---
def db_load_certs() -> list[str]:
    """Return the list of domains that have certificates, sorted by name."""
    rows = get_connection().execute(
        "SELECT domain FROM certificates ORDER BY domain"
    )
    return [r["domain"] for r in rows]
def db_add_cert(domain: str) -> None:
    """Record that ``domain`` has a certificate (no-op if already present).

    Args:
        domain: Domain name.
    """
    connection = get_connection()
    connection.execute(
        "INSERT OR IGNORE INTO certificates (domain) VALUES (?)",
        (domain,),
    )
    connection.commit()
def db_remove_cert(domain: str) -> None:
    """Drop ``domain`` from the certificate list (no-op if absent).

    Args:
        domain: Domain name.
    """
    connection = get_connection()
    connection.execute("DELETE FROM certificates WHERE domain = ?", (domain,))
    connection.commit()
# --- Map file synchronization ---
def sync_map_files() -> None:
    """Regenerate HAProxy map files from database.

    Writes domains.map (exact matches) and wildcards.map (wildcard
    entries) from the current database state.  Uses atomic writes so
    HAProxy never observes a half-written map.
    """
    from .file_ops import atomic_write_file
    conn = get_connection()

    def _write_map(path: str, query: str, header: list[str]) -> int:
        # Render one map file: fixed header followed by "domain backend" lines.
        rows = conn.execute(query).fetchall()
        body = [f"{row['domain']} {row['backend']}\n" for row in rows]
        atomic_write_file(path, "".join(header + body))
        return len(rows)

    exact_count = _write_map(
        MAP_FILE,
        "SELECT domain, backend FROM domains WHERE is_wildcard = 0 ORDER BY domain",
        [
            "# Exact Domain to Backend mapping (for map_str)\n",
            "# Format: domain backend_name\n",
            "# Uses ebtree for O(log n) lookup performance\n\n",
        ],
    )
    wildcard_count = _write_map(
        WILDCARDS_MAP_FILE,
        "SELECT domain, backend FROM domains WHERE is_wildcard = 1 ORDER BY domain",
        [
            "# Wildcard Domain to Backend mapping (for map_dom)\n",
            "# Format: .domain.com backend_name (matches *.domain.com)\n",
            "# Uses map_dom for suffix matching\n\n",
        ],
    )
    logger.debug("Map files synced: %d exact, %d wildcard entries",
                 exact_count, wildcard_count)

View File

@@ -1,7 +1,11 @@
"""File I/O operations for HAProxy MCP Server."""
"""File I/O operations for HAProxy MCP Server.
Most data access is now delegated to db.py (SQLite).
This module retains atomic file writes, map file I/O for HAProxy,
and provides backward-compatible function signatures.
"""
import fcntl
import json
import os
import tempfile
from contextlib import contextmanager
@@ -10,8 +14,6 @@ from typing import Any, Generator, Optional
from .config import (
MAP_FILE,
WILDCARDS_MAP_FILE,
SERVERS_FILE,
CERTS_FILE,
REMOTE_MODE,
logger,
)
@@ -138,16 +140,13 @@ def _read_file(file_path: str) -> str:
def get_map_contents() -> list[tuple[str, str]]:
"""Read both domains.map and wildcards.map and return combined entries.
"""Get all domain-to-backend mappings from SQLite.
Returns:
List of (domain, backend) tuples from both map files
List of (domain, backend) tuples including wildcards.
"""
# Read exact domains
entries = _read_map_file(MAP_FILE)
# Read wildcards and append
entries.extend(_read_map_file(WILDCARDS_MAP_FILE))
return entries
from .db import db_get_map_contents
return db_get_map_contents()
def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
@@ -170,44 +169,21 @@ def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str
def save_map_file(entries: list[tuple[str, str]]) -> None:
"""Save domain-to-backend entries to map files.
"""Sync map files from the database.
Splits entries into two files for 2-stage routing:
- domains.map: Exact matches (map_str, O(log n))
- wildcards.map: Wildcard entries starting with "." (map_dom, O(n))
Args:
entries: List of (domain, backend) tuples.
Regenerates domains.map and wildcards.map from the current
database state. The entries parameter is ignored (kept for
backward compatibility during transition).
Raises:
IOError: If map files cannot be written.
"""
# Split into exact and wildcard entries
exact_entries, wildcard_entries = split_domain_entries(entries)
# Save exact domains (for map_str - fast O(log n) lookup)
exact_lines = [
"# Exact Domain to Backend mapping (for map_str)\n",
"# Format: domain backend_name\n",
"# Uses ebtree for O(log n) lookup performance\n\n",
]
for domain, backend in sorted(exact_entries):
exact_lines.append(f"{domain} {backend}\n")
atomic_write_file(MAP_FILE, "".join(exact_lines))
# Save wildcards (for map_dom - O(n) but small set)
wildcard_lines = [
"# Wildcard Domain to Backend mapping (for map_dom)\n",
"# Format: .domain.com backend_name (matches *.domain.com)\n",
"# Uses map_dom for suffix matching\n\n",
]
for domain, backend in sorted(wildcard_entries):
wildcard_lines.append(f"{domain} {backend}\n")
atomic_write_file(WILDCARDS_MAP_FILE, "".join(wildcard_lines))
from .db import sync_map_files
sync_map_files()
def get_domain_backend(domain: str) -> Optional[str]:
"""Look up the backend for a domain from domains.map.
"""Look up the backend for a domain from SQLite (O(1)).
Args:
domain: The domain to look up
@@ -215,10 +191,8 @@ def get_domain_backend(domain: str) -> Optional[str]:
Returns:
Backend name if found, None otherwise
"""
for map_domain, backend in get_map_contents():
if map_domain == domain:
return backend
return None
from .db import db_get_domain_backend
return db_get_domain_backend(domain)
def is_legacy_backend(backend: str) -> bool:
@@ -273,34 +247,17 @@ def get_backend_and_prefix(domain: str) -> tuple[str, str]:
def load_servers_config() -> dict[str, Any]:
"""Load servers configuration from JSON file.
"""Load servers configuration from SQLite.
Returns:
Dictionary with server configurations
Dictionary with server configurations (legacy format compatible).
"""
try:
content = _read_file(SERVERS_FILE)
return json.loads(content)
except FileNotFoundError:
return {}
except json.JSONDecodeError as e:
logger.warning("Corrupt config file %s: %s", SERVERS_FILE, e)
return {}
def save_servers_config(config: dict[str, Any]) -> None:
"""Save servers configuration to JSON file atomically.
Uses temp file + rename for atomic write to prevent race conditions.
Args:
config: Dictionary with server configurations
"""
atomic_write_file(SERVERS_FILE, json.dumps(config, indent=2))
from .db import db_load_servers_config
return db_load_servers_config()
def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> None:
"""Add server configuration to persistent storage with file locking.
"""Add server configuration to persistent storage.
Args:
domain: Domain name
@@ -308,41 +265,29 @@ def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> Non
ip: Server IP address
http_port: HTTP port
"""
with file_lock(f"{SERVERS_FILE}.lock"):
config = load_servers_config()
if domain not in config:
config[domain] = {}
config[domain][str(slot)] = {"ip": ip, "http_port": http_port}
save_servers_config(config)
from .db import db_add_server
db_add_server(domain, slot, ip, http_port)
def remove_server_from_config(domain: str, slot: int) -> None:
"""Remove server configuration from persistent storage with file locking.
"""Remove server configuration from persistent storage.
Args:
domain: Domain name
slot: Server slot to remove
"""
with file_lock(f"{SERVERS_FILE}.lock"):
config = load_servers_config()
if domain in config and str(slot) in config[domain]:
del config[domain][str(slot)]
if not config[domain]:
del config[domain]
save_servers_config(config)
from .db import db_remove_server
db_remove_server(domain, slot)
def remove_domain_from_config(domain: str) -> None:
"""Remove domain from persistent config with file locking.
"""Remove domain from persistent config (servers + domain entry).
Args:
domain: Domain name to remove
"""
with file_lock(f"{SERVERS_FILE}.lock"):
config = load_servers_config()
if domain in config:
del config[domain]
save_servers_config(config)
from .db import db_remove_domain_servers
db_remove_domain_servers(domain)
def get_shared_domain(domain: str) -> Optional[str]:
@@ -354,9 +299,8 @@ def get_shared_domain(domain: str) -> Optional[str]:
Returns:
Parent domain name if sharing, None otherwise
"""
config = load_servers_config()
domain_config = config.get(domain, {})
return domain_config.get("_shares")
from .db import db_get_shared_domain
return db_get_shared_domain(domain)
def add_shared_domain_to_config(domain: str, shares_with: str) -> None:
@@ -366,10 +310,8 @@ def add_shared_domain_to_config(domain: str, shares_with: str) -> None:
domain: New domain name
shares_with: Existing domain to share pool with
"""
with file_lock(f"{SERVERS_FILE}.lock"):
config = load_servers_config()
config[domain] = {"_shares": shares_with}
save_servers_config(config)
from .db import db_add_shared_domain
db_add_shared_domain(domain, shares_with)
def get_domains_sharing_pool(pool: str) -> list[str]:
@@ -381,11 +323,8 @@ def get_domains_sharing_pool(pool: str) -> list[str]:
Returns:
List of domain names using this pool
"""
domains = []
for domain, backend in get_map_contents():
if backend == pool and not domain.startswith("."):
domains.append(domain)
return domains
from .db import db_get_domains_sharing_pool
return db_get_domains_sharing_pool(pool)
def is_shared_domain(domain: str) -> bool:
@@ -397,37 +336,20 @@ def is_shared_domain(domain: str) -> bool:
Returns:
True if domain has _shares reference, False otherwise
"""
config = load_servers_config()
domain_config = config.get(domain, {})
return "_shares" in domain_config
from .db import db_is_shared_domain
return db_is_shared_domain(domain)
# Certificate configuration functions
def load_certs_config() -> list[str]:
"""Load certificate domain list from JSON file.
"""Load certificate domain list from SQLite.
Returns:
List of domain names
Sorted list of domain names.
"""
try:
content = _read_file(CERTS_FILE)
data = json.loads(content)
return data.get("domains", [])
except FileNotFoundError:
return []
except json.JSONDecodeError as e:
logger.warning("Corrupt certificates config %s: %s", CERTS_FILE, e)
return []
def save_certs_config(domains: list[str]) -> None:
"""Save certificate domain list to JSON file atomically.
Args:
domains: List of domain names
"""
atomic_write_file(CERTS_FILE, json.dumps({"domains": sorted(domains)}, indent=2))
from .db import db_load_certs
return db_load_certs()
def add_cert_to_config(domain: str) -> None:
@@ -436,11 +358,8 @@ def add_cert_to_config(domain: str) -> None:
Args:
domain: Domain name to add
"""
with file_lock(f"{CERTS_FILE}.lock"):
domains = load_certs_config()
if domain not in domains:
domains.append(domain)
save_certs_config(domains)
from .db import db_add_cert
db_add_cert(domain)
def remove_cert_from_config(domain: str) -> None:
@@ -449,8 +368,45 @@ def remove_cert_from_config(domain: str) -> None:
Args:
domain: Domain name to remove
"""
with file_lock(f"{CERTS_FILE}.lock"):
domains = load_certs_config()
if domain in domains:
domains.remove(domain)
save_certs_config(domains)
from .db import db_remove_cert
db_remove_cert(domain)
# Domain map helper functions (used by domains.py)
def add_domain_to_map(domain: str, backend: str, is_wildcard: bool = False,
shares_with: Optional[str] = None) -> None:
"""Add a domain to SQLite and sync map files.
Args:
domain: Domain name (e.g., "example.com").
backend: Backend pool name (e.g., "pool_5").
is_wildcard: Whether this is a wildcard entry.
shares_with: Parent domain if sharing a pool.
"""
from .db import db_add_domain, sync_map_files
db_add_domain(domain, backend, is_wildcard, shares_with)
sync_map_files()
def remove_domain_from_map(domain: str) -> None:
"""Remove a domain (exact + wildcard) from SQLite and sync map files.
Args:
domain: Base domain name (without leading dot).
"""
from .db import db_remove_domain, sync_map_files
db_remove_domain(domain)
sync_map_files()
def find_available_pool() -> Optional[str]:
"""Find the first available pool not assigned to any domain.
Uses SQLite query for O(1) lookup vs previous O(n) list scan.
Returns:
Pool name (e.g., "pool_5") or None if all pools are in use.
"""
from .db import db_find_available_pool
return db_find_available_pool()

View File

@@ -2,6 +2,7 @@
import socket
import select
import subprocess
import time
from .config import (
@@ -161,7 +162,7 @@ def reload_haproxy() -> tuple[bool, str]:
if result.returncode != 0:
return False, f"Reload failed: {result.stderr}"
return True, "OK"
except TimeoutError:
except (TimeoutError, subprocess.TimeoutExpired):
return False, f"Command timed out after {SUBPROCESS_TIMEOUT} seconds"
except FileNotFoundError:
return False, "ssh/podman command not found"

View File

@@ -21,6 +21,7 @@ Environment Variables:
from mcp.server.fastmcp import FastMCP
from .config import MCP_HOST, MCP_PORT
from .db import init_db
from .tools import register_all_tools
from .tools.configuration import startup_restore
@@ -32,5 +33,6 @@ register_all_tools(mcp)
if __name__ == "__main__":
init_db()
startup_restore()
mcp.run(transport="streamable-http")

View File

@@ -1,6 +1,7 @@
"""Certificate management tools for HAProxy MCP Server."""
import os
import subprocess
from datetime import datetime
from typing import Annotated
@@ -152,7 +153,7 @@ def _haproxy_list_certs_impl() -> str:
certs.append(f"{domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}")
return "\n\n".join(certs) if certs else "No certificates found"
except TimeoutError:
except (TimeoutError, subprocess.TimeoutExpired):
return "Error: Command timed out"
except FileNotFoundError:
return "Error: acme.sh not found"
@@ -203,7 +204,7 @@ def _haproxy_cert_info_impl(domain: str) -> str:
result.stdout.strip()
]
return "\n".join(info)
except TimeoutError:
except (TimeoutError, subprocess.TimeoutExpired):
return "Error: Command timed out"
except OSError as e:
logger.error("Error getting certificate info for %s: %s", domain, e)
@@ -250,7 +251,7 @@ def _haproxy_issue_cert_impl(domain: str, wildcard: bool) -> str:
else:
return f"Certificate issued but PEM file not created. Check {host_path}"
except TimeoutError:
except (TimeoutError, subprocess.TimeoutExpired):
return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s"
except OSError as e:
logger.error("Error issuing certificate for %s: %s", domain, e)
@@ -289,7 +290,7 @@ def _haproxy_renew_cert_impl(domain: str, force: bool) -> str:
else:
return f"Error renewing certificate:\n{output}"
except TimeoutError:
except (TimeoutError, subprocess.TimeoutExpired):
return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s"
except FileNotFoundError:
return "Error: acme.sh not found"
@@ -323,7 +324,7 @@ def _haproxy_renew_all_certs_impl() -> str:
else:
return "Renewal check completed"
except TimeoutError:
except (TimeoutError, subprocess.TimeoutExpired):
return "Error: Renewal cron timed out"
except FileNotFoundError:
return "Error: acme.sh not found"

View File

@@ -1,5 +1,6 @@
"""Configuration management tools for HAProxy MCP Server."""
import subprocess
import time
from ..config import (
@@ -177,7 +178,7 @@ def register_config_tools(mcp):
if result.returncode == 0:
return "Configuration is valid"
return f"Configuration errors:\n{result.stderr}"
except TimeoutError:
except (TimeoutError, subprocess.TimeoutExpired):
return f"Error: Command timed out after {SUBPROCESS_TIMEOUT} seconds"
except FileNotFoundError:
return "Error: ssh/podman command not found"

View File

@@ -6,7 +6,6 @@ from typing import Annotated, Optional
from pydantic import Field
from ..config import (
MAP_FILE,
MAP_FILE_CONTAINER,
WILDCARDS_MAP_FILE_CONTAINER,
POOL_COUNT,
@@ -22,7 +21,6 @@ from ..validation import validate_domain, validate_ip, validate_port_int
from ..haproxy_client import haproxy_cmd
from ..file_ops import (
get_map_contents,
save_map_file,
get_domain_backend,
is_legacy_backend,
add_server_to_config,
@@ -31,30 +29,13 @@ from ..file_ops import (
add_shared_domain_to_config,
get_domains_sharing_pool,
is_shared_domain,
add_domain_to_map,
remove_domain_from_map,
find_available_pool,
)
from ..utils import parse_servers_state, disable_server_slot
def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]:
"""Find an available pool backend from the pool list.
Iterates through pool_1 to pool_N and returns the first pool
that is not currently in use.
Args:
entries: List of (domain, backend) tuples from the map file.
used_pools: Set of pool names already in use.
Returns:
Available pool name (e.g., "pool_5") or None if all pools are in use.
"""
for i in range(1, POOL_COUNT + 1):
pool_name = f"pool_{i}"
if pool_name not in used_pools:
return pool_name
return None
def _check_subdomain(domain: str, registered_domains: set[str]) -> tuple[bool, Optional[str]]:
"""Check if a domain is a subdomain of an existing registered domain.
@@ -95,24 +76,19 @@ def _update_haproxy_maps(domain: str, pool: str, is_subdomain: bool) -> None:
haproxy_cmd(f"add map {WILDCARDS_MAP_FILE_CONTAINER} .{domain} {pool}")
def _rollback_domain_addition(
domain: str,
entries: list[tuple[str, str]]
) -> None:
"""Rollback a failed domain addition by removing entries from map file.
def _rollback_domain_addition(domain: str) -> None:
"""Rollback a failed domain addition by removing from SQLite + map files.
Called when HAProxy Runtime API update fails after the map file
has already been saved.
Called when HAProxy Runtime API update fails after the domain
has already been saved to the database.
Args:
domain: Domain name that was added.
entries: Current list of map entries to rollback from.
"""
rollback_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
try:
save_map_file(rollback_entries)
except IOError:
logger.error("Failed to rollback map file after HAProxy error")
remove_domain_from_map(domain)
except (IOError, Exception):
logger.error("Failed to rollback domain %s after HAProxy error", domain)
def _file_exists(path: str) -> bool:
@@ -242,93 +218,86 @@ def register_domain_tools(mcp):
if share_with and ip:
return "Error: Cannot specify both ip and share_with (shared domains use existing servers)"
# Use file locking for the entire pool allocation operation
from ..file_ops import file_lock
with file_lock(f"{MAP_FILE}.lock"):
# Read map contents once for both existence check and pool lookup
entries = get_map_contents()
# Read current entries for existence check and subdomain detection
entries = get_map_contents()
# Check if domain already exists (using cached entries)
for domain_entry, backend in entries:
if domain_entry == domain:
return f"Error: Domain {domain} already exists (mapped to {backend})"
# Check if domain already exists
for domain_entry, backend in entries:
if domain_entry == domain:
return f"Error: Domain {domain} already exists (mapped to {backend})"
# Build used pools and registered domains sets
used_pools: set[str] = set()
registered_domains: set[str] = set()
for entry_domain, backend in entries:
if backend.startswith("pool_"):
used_pools.add(backend)
if not entry_domain.startswith("."):
registered_domains.add(entry_domain)
# Build registered domains set for subdomain check
registered_domains: set[str] = set()
for entry_domain, _ in entries:
if not entry_domain.startswith("."):
registered_domains.add(entry_domain)
# Handle share_with: reuse existing domain's pool
if share_with:
share_backend = get_domain_backend(share_with)
if not share_backend:
return f"Error: Domain {share_with} not found"
if not share_backend.startswith("pool_"):
return f"Error: Cannot share with legacy backend {share_backend}"
pool = share_backend
else:
# Find available pool
pool = _find_available_pool(entries, used_pools)
if not pool:
return f"Error: All {POOL_COUNT} pool backends are in use"
# Handle share_with: reuse existing domain's pool
if share_with:
share_backend = get_domain_backend(share_with)
if not share_backend:
return f"Error: Domain {share_with} not found"
if not share_backend.startswith("pool_"):
return f"Error: Cannot share with legacy backend {share_backend}"
pool = share_backend
else:
# Find available pool (SQLite query, O(1))
pool = find_available_pool()
if not pool:
return f"Error: All {POOL_COUNT} pool backends are in use"
# Check if this is a subdomain of an existing domain
is_subdomain, parent_domain = _check_subdomain(domain, registered_domains)
# Check if this is a subdomain of an existing domain
is_subdomain, parent_domain = _check_subdomain(domain, registered_domains)
try:
# Save to SQLite + sync map files (atomic via SQLite transaction)
try:
# Save to disk first (atomic write for persistence)
entries.append((domain, pool))
add_domain_to_map(domain, pool)
if not is_subdomain:
entries.append((f".{domain}", pool))
try:
save_map_file(entries)
except IOError as e:
return f"Error: Failed to save map file: {e}"
# Update HAProxy maps via Runtime API
try:
_update_haproxy_maps(domain, pool, is_subdomain)
except HaproxyError as e:
_rollback_domain_addition(domain, entries)
return f"Error: Failed to update HAProxy map: {e}"
# Handle server configuration based on mode
if share_with:
# Save shared domain reference
add_shared_domain_to_config(domain, share_with)
result = f"Domain {domain} added, sharing pool {pool} with {share_with}"
elif ip:
# Add server to slot 1
add_server_to_config(domain, 1, ip, http_port)
try:
server = f"{pool}_1"
haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}")
haproxy_cmd(f"set server {pool}/{server} state ready")
except HaproxyError as e:
remove_server_from_config(domain, 1)
return f"Domain {domain} added to {pool} but server config failed: {e}"
result = f"Domain {domain} added to {pool} with server {ip}:{http_port}"
else:
result = f"Domain {domain} added to {pool} (no servers configured)"
if is_subdomain:
result += f" (subdomain of {parent_domain}, no wildcard)"
# Check certificate coverage
cert_covered, cert_info = check_certificate_coverage(domain)
if cert_covered:
result += f"\nSSL: Using certificate {cert_info}"
else:
result += f"\nSSL: No certificate found. Use haproxy_issue_cert(\"{domain}\") to issue one."
return result
add_domain_to_map(f".{domain}", pool, is_wildcard=True)
except (IOError, Exception) as e:
return f"Error: Failed to save domain: {e}"
# Update HAProxy maps via Runtime API
try:
_update_haproxy_maps(domain, pool, is_subdomain)
except HaproxyError as e:
return f"Error: {e}"
_rollback_domain_addition(domain)
return f"Error: Failed to update HAProxy map: {e}"
# Handle server configuration based on mode
if share_with:
# Save shared domain reference
add_shared_domain_to_config(domain, share_with)
result = f"Domain {domain} added, sharing pool {pool} with {share_with}"
elif ip:
# Add server to slot 1
add_server_to_config(domain, 1, ip, http_port)
try:
server = f"{pool}_1"
haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}")
haproxy_cmd(f"set server {pool}/{server} state ready")
except HaproxyError as e:
remove_server_from_config(domain, 1)
return f"Domain {domain} added to {pool} but server config failed: {e}"
result = f"Domain {domain} added to {pool} with server {ip}:{http_port}"
else:
result = f"Domain {domain} added to {pool} (no servers configured)"
if is_subdomain:
result += f" (subdomain of {parent_domain}, no wildcard)"
# Check certificate coverage
cert_covered, cert_info = check_certificate_coverage(domain)
if cert_covered:
result += f"\nSSL: Using certificate {cert_info}"
else:
result += f"\nSSL: No certificate found. Use haproxy_issue_cert(\"{domain}\") to issue one."
return result
except HaproxyError as e:
return f"Error: {e}"
@mcp.tool()
def haproxy_remove_domain(
@@ -355,10 +324,8 @@ def register_domain_tools(mcp):
domains_using_pool = get_domains_sharing_pool(backend)
other_domains = [d for d in domains_using_pool if d != domain]
# Save to disk first (atomic write for persistence)
entries = get_map_contents()
new_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
save_map_file(new_entries)
# Remove from SQLite + sync map files
remove_domain_from_map(domain)
# Remove from persistent server config
remove_domain_from_config(domain)

View File

@@ -10,6 +10,7 @@ from pydantic import Field
from ..config import (
MAP_FILE,
SERVERS_FILE,
DB_FILE,
HAPROXY_CONTAINER,
)
from ..exceptions import HaproxyError
@@ -88,7 +89,7 @@ def register_health_tools(mcp):
# Check configuration files
files_ok = True
file_status: dict[str, str] = {}
for name, path in [("map_file", MAP_FILE), ("servers_file", SERVERS_FILE)]:
for name, path in [("map_file", MAP_FILE), ("db_file", DB_FILE)]:
exists = remote_file_exists(path) if REMOTE_MODE else __import__('os').path.exists(path)
if exists:
file_status[name] = "ok"