refactor: migrate data storage from JSON/map files to SQLite
Replace servers.json, certificates.json, and map file parsing with SQLite (WAL mode) as single source of truth. HAProxy map files are now generated from SQLite via sync_map_files(). Key changes: - Add db.py with schema, connection management, and JSON migration - Add DB_FILE config constant - Delegate file_ops.py functions to db.py - Refactor domains.py to use file_ops instead of direct list manipulation - Fix subprocess.TimeoutExpired not caught (doesn't inherit TimeoutError) - Add DB health check in health.py - Init DB on startup in server.py and __main__.py - Update all 359 tests to use SQLite-backed functions Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
"""Entry point for running haproxy_mcp as a module."""
|
||||
|
||||
from .db import init_db
|
||||
from .server import mcp
|
||||
from .tools.configuration import startup_restore
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_db()
|
||||
startup_restore()
|
||||
mcp.run(transport="streamable-http")
|
||||
|
||||
@@ -31,6 +31,7 @@ WILDCARDS_MAP_FILE: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE", "/opt/haproxy/
|
||||
WILDCARDS_MAP_FILE_CONTAINER: str = os.getenv("HAPROXY_WILDCARDS_MAP_FILE_CONTAINER", "/usr/local/etc/haproxy/wildcards.map")
|
||||
SERVERS_FILE: str = os.getenv("HAPROXY_SERVERS_FILE", "/opt/haproxy/conf/servers.json")
|
||||
CERTS_FILE: str = os.getenv("HAPROXY_CERTS_FILE", "/opt/haproxy/conf/certificates.json")
|
||||
DB_FILE: str = os.getenv("HAPROXY_DB_FILE", "/opt/haproxy/conf/haproxy_mcp.db")
|
||||
|
||||
# Certificate paths
|
||||
CERTS_DIR: str = os.getenv("HAPROXY_CERTS_DIR", "/opt/haproxy/certs")
|
||||
|
||||
577
haproxy_mcp/db.py
Normal file
577
haproxy_mcp/db.py
Normal file
@@ -0,0 +1,577 @@
|
||||
"""SQLite database operations for HAProxy MCP Server.
|
||||
|
||||
Single source of truth for all persistent configuration data.
|
||||
Replaces JSON files (servers.json, certificates.json) and map files
|
||||
(domains.map, wildcards.map) with a single SQLite database.
|
||||
|
||||
Map files are generated from the database via sync_map_files() for
|
||||
HAProxy to consume directly.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
|
||||
from .config import (
|
||||
DB_FILE,
|
||||
MAP_FILE,
|
||||
WILDCARDS_MAP_FILE,
|
||||
SERVERS_FILE,
|
||||
CERTS_FILE,
|
||||
POOL_COUNT,
|
||||
REMOTE_MODE,
|
||||
logger,
|
||||
)
|
||||
|
||||
SCHEMA_VERSION = 1
|
||||
|
||||
_local = threading.local()
|
||||
|
||||
|
||||
def get_connection() -> sqlite3.Connection:
|
||||
"""Get a thread-local SQLite connection with WAL mode.
|
||||
|
||||
Returns:
|
||||
sqlite3.Connection configured for concurrent access.
|
||||
"""
|
||||
conn = getattr(_local, "connection", None)
|
||||
if conn is not None:
|
||||
return conn
|
||||
|
||||
conn = sqlite3.connect(DB_FILE, timeout=10)
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
conn.row_factory = sqlite3.Row
|
||||
_local.connection = conn
|
||||
return conn
|
||||
|
||||
|
||||
def close_connection() -> None:
|
||||
"""Close the thread-local SQLite connection."""
|
||||
conn = getattr(_local, "connection", None)
|
||||
if conn is not None:
|
||||
conn.close()
|
||||
_local.connection = None
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
"""Initialize database schema and run migration if needed.
|
||||
|
||||
Creates tables if they don't exist, then checks for existing
|
||||
JSON/map files to migrate data from.
|
||||
"""
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
|
||||
cur.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS domains (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
domain TEXT NOT NULL UNIQUE,
|
||||
backend TEXT NOT NULL,
|
||||
is_wildcard INTEGER NOT NULL DEFAULT 0,
|
||||
shares_with TEXT DEFAULT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_domains_backend ON domains(backend);
|
||||
CREATE INDEX IF NOT EXISTS idx_domains_shares_with ON domains(shares_with);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS servers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
domain TEXT NOT NULL,
|
||||
slot INTEGER NOT NULL,
|
||||
ip TEXT NOT NULL,
|
||||
http_port INTEGER NOT NULL DEFAULT 80,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||
UNIQUE(domain, slot)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_servers_domain ON servers(domain);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS certificates (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
domain TEXT NOT NULL UNIQUE,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
version INTEGER NOT NULL,
|
||||
applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||
);
|
||||
""")
|
||||
conn.commit()
|
||||
|
||||
# Check schema version
|
||||
cur.execute("SELECT MAX(version) FROM schema_version")
|
||||
row = cur.fetchone()
|
||||
current_version = row[0] if row[0] is not None else 0
|
||||
|
||||
if current_version < SCHEMA_VERSION:
|
||||
# First-time setup: try migrating from JSON files
|
||||
if current_version == 0:
|
||||
migrate_from_json()
|
||||
cur.execute("INSERT INTO schema_version (version) VALUES (?)", (SCHEMA_VERSION,))
|
||||
conn.commit()
|
||||
|
||||
logger.info("Database initialized (schema v%d)", SCHEMA_VERSION)
|
||||
|
||||
|
||||
def migrate_from_json() -> None:
|
||||
"""Migrate data from JSON/map files to SQLite.
|
||||
|
||||
Reads domains.map, wildcards.map, servers.json, and certificates.json,
|
||||
imports their data into the database within a single transaction,
|
||||
then renames the original files to .bak.
|
||||
|
||||
This function is idempotent (uses INSERT OR IGNORE).
|
||||
"""
|
||||
conn = get_connection()
|
||||
|
||||
# Collect data from existing files
|
||||
map_entries: list[tuple[str, str]] = []
|
||||
servers_config: dict[str, Any] = {}
|
||||
cert_domains: list[str] = []
|
||||
|
||||
# Read map files
|
||||
for map_path in [MAP_FILE, WILDCARDS_MAP_FILE]:
|
||||
try:
|
||||
content = _read_file_for_migration(map_path)
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
parts = line.split()
|
||||
if len(parts) >= 2:
|
||||
map_entries.append((parts[0], parts[1]))
|
||||
except FileNotFoundError:
|
||||
logger.debug("Map file not found for migration: %s", map_path)
|
||||
|
||||
# Read servers.json
|
||||
try:
|
||||
content = _read_file_for_migration(SERVERS_FILE)
|
||||
servers_config = json.loads(content)
|
||||
except FileNotFoundError:
|
||||
logger.debug("Servers config not found for migration: %s", SERVERS_FILE)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning("Corrupt servers config during migration: %s", e)
|
||||
|
||||
# Read certificates.json
|
||||
try:
|
||||
content = _read_file_for_migration(CERTS_FILE)
|
||||
data = json.loads(content)
|
||||
cert_domains = data.get("domains", [])
|
||||
except FileNotFoundError:
|
||||
logger.debug("Certs config not found for migration: %s", CERTS_FILE)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning("Corrupt certs config during migration: %s", e)
|
||||
|
||||
if not map_entries and not servers_config and not cert_domains:
|
||||
logger.info("No existing data to migrate")
|
||||
return
|
||||
|
||||
# Import into database in a single transaction
|
||||
try:
|
||||
with conn:
|
||||
# Import domains from map files
|
||||
for domain, backend in map_entries:
|
||||
is_wildcard = 1 if domain.startswith(".") else 0
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO domains (domain, backend, is_wildcard) VALUES (?, ?, ?)",
|
||||
(domain, backend, is_wildcard),
|
||||
)
|
||||
|
||||
# Import servers and shares_with from servers.json
|
||||
for domain, slots in servers_config.items():
|
||||
shares_with = slots.get("_shares")
|
||||
if shares_with:
|
||||
# Update domain's shares_with field
|
||||
conn.execute(
|
||||
"UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0",
|
||||
(shares_with, domain),
|
||||
)
|
||||
continue
|
||||
|
||||
for slot_str, server_info in slots.items():
|
||||
if slot_str.startswith("_"):
|
||||
continue
|
||||
try:
|
||||
slot = int(slot_str)
|
||||
ip = server_info.get("ip", "")
|
||||
http_port = int(server_info.get("http_port", 80))
|
||||
if ip:
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO servers (domain, slot, ip, http_port) VALUES (?, ?, ?, ?)",
|
||||
(domain, slot, ip, http_port),
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning("Skipping invalid server entry %s/%s: %s", domain, slot_str, e)
|
||||
|
||||
# Import certificates
|
||||
for cert_domain in cert_domains:
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO certificates (domain) VALUES (?)",
|
||||
(cert_domain,),
|
||||
)
|
||||
|
||||
migrated_domains = len(set(d for d, _ in map_entries))
|
||||
migrated_servers = sum(
|
||||
1 for d, slots in servers_config.items()
|
||||
if "_shares" not in slots
|
||||
for s in slots
|
||||
if not s.startswith("_")
|
||||
)
|
||||
migrated_certs = len(cert_domains)
|
||||
logger.info(
|
||||
"Migrated from JSON: %d domain entries, %d servers, %d certificates",
|
||||
migrated_domains, migrated_servers, migrated_certs,
|
||||
)
|
||||
|
||||
# Rename original files to .bak
|
||||
for path in [SERVERS_FILE, CERTS_FILE]:
|
||||
_backup_file(path)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
logger.error("Migration failed: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
def _read_file_for_migration(file_path: str) -> str:
|
||||
"""Read a file for migration purposes (local only, no locking needed)."""
|
||||
if REMOTE_MODE:
|
||||
from .ssh_ops import remote_read_file
|
||||
return remote_read_file(file_path)
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def _backup_file(file_path: str) -> None:
|
||||
"""Rename a file to .bak if it exists."""
|
||||
if REMOTE_MODE:
|
||||
return # Don't rename remote files
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
bak_path = f"{file_path}.bak"
|
||||
os.rename(file_path, bak_path)
|
||||
logger.info("Backed up %s -> %s", file_path, bak_path)
|
||||
except OSError as e:
|
||||
logger.warning("Failed to backup %s: %s", file_path, e)
|
||||
|
||||
|
||||
# --- Domain data access ---
|
||||
|
||||
def db_get_map_contents() -> list[tuple[str, str]]:
|
||||
"""Get all domain-to-backend mappings.
|
||||
|
||||
Returns:
|
||||
List of (domain, backend) tuples including wildcards.
|
||||
"""
|
||||
conn = get_connection()
|
||||
cur = conn.execute("SELECT domain, backend FROM domains ORDER BY domain")
|
||||
return [(row["domain"], row["backend"]) for row in cur.fetchall()]
|
||||
|
||||
|
||||
def db_get_domain_backend(domain: str) -> Optional[str]:
|
||||
"""Look up the backend for a specific domain.
|
||||
|
||||
Args:
|
||||
domain: Domain name to look up.
|
||||
|
||||
Returns:
|
||||
Backend name if found, None otherwise.
|
||||
"""
|
||||
conn = get_connection()
|
||||
cur = conn.execute("SELECT backend FROM domains WHERE domain = ?", (domain,))
|
||||
row = cur.fetchone()
|
||||
return row["backend"] if row else None
|
||||
|
||||
|
||||
def db_add_domain(domain: str, backend: str, is_wildcard: bool = False,
|
||||
shares_with: Optional[str] = None) -> None:
|
||||
"""Add a domain to the database.
|
||||
|
||||
Args:
|
||||
domain: Domain name (e.g., "example.com" or ".example.com").
|
||||
backend: Backend pool name (e.g., "pool_5").
|
||||
is_wildcard: Whether this is a wildcard entry.
|
||||
shares_with: Parent domain if sharing a pool.
|
||||
"""
|
||||
conn = get_connection()
|
||||
conn.execute(
|
||||
"""INSERT INTO domains (domain, backend, is_wildcard, shares_with)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(domain) DO UPDATE SET
|
||||
backend = excluded.backend,
|
||||
shares_with = excluded.shares_with,
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""",
|
||||
(domain, backend, 1 if is_wildcard else 0, shares_with),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def db_remove_domain(domain: str) -> None:
|
||||
"""Remove a domain (exact + wildcard) from the database.
|
||||
|
||||
Args:
|
||||
domain: Base domain name (without leading dot).
|
||||
"""
|
||||
conn = get_connection()
|
||||
conn.execute("DELETE FROM domains WHERE domain = ? OR domain = ?", (domain, f".{domain}"))
|
||||
conn.commit()
|
||||
|
||||
|
||||
def db_find_available_pool() -> Optional[str]:
|
||||
"""Find the first available pool not assigned to any domain.
|
||||
|
||||
Returns:
|
||||
Pool name (e.g., "pool_5") or None if all pools are in use.
|
||||
"""
|
||||
conn = get_connection()
|
||||
cur = conn.execute(
|
||||
"SELECT DISTINCT backend FROM domains WHERE backend LIKE 'pool_%' AND is_wildcard = 0"
|
||||
)
|
||||
used_pools = {row["backend"] for row in cur.fetchall()}
|
||||
|
||||
for i in range(1, POOL_COUNT + 1):
|
||||
pool_name = f"pool_{i}"
|
||||
if pool_name not in used_pools:
|
||||
return pool_name
|
||||
return None
|
||||
|
||||
|
||||
def db_get_domains_sharing_pool(pool: str) -> list[str]:
|
||||
"""Get all non-wildcard domains using a specific pool.
|
||||
|
||||
Args:
|
||||
pool: Pool name (e.g., "pool_5").
|
||||
|
||||
Returns:
|
||||
List of domain names.
|
||||
"""
|
||||
conn = get_connection()
|
||||
cur = conn.execute(
|
||||
"SELECT domain FROM domains WHERE backend = ? AND is_wildcard = 0",
|
||||
(pool,),
|
||||
)
|
||||
return [row["domain"] for row in cur.fetchall()]
|
||||
|
||||
|
||||
# --- Server data access ---
|
||||
|
||||
def db_load_servers_config() -> dict[str, Any]:
|
||||
"""Load servers configuration in the legacy dict format.
|
||||
|
||||
Returns:
|
||||
Dictionary compatible with the old servers.json format:
|
||||
{"domain": {"slot": {"ip": "...", "http_port": N}, "_shares": "parent"}}
|
||||
"""
|
||||
conn = get_connection()
|
||||
config: dict[str, Any] = {}
|
||||
|
||||
# Load server entries
|
||||
cur = conn.execute("SELECT domain, slot, ip, http_port FROM servers ORDER BY domain, slot")
|
||||
for row in cur.fetchall():
|
||||
domain = row["domain"]
|
||||
if domain not in config:
|
||||
config[domain] = {}
|
||||
config[domain][str(row["slot"])] = {
|
||||
"ip": row["ip"],
|
||||
"http_port": row["http_port"],
|
||||
}
|
||||
|
||||
# Load shared domain references
|
||||
cur = conn.execute(
|
||||
"SELECT domain, shares_with FROM domains WHERE shares_with IS NOT NULL AND is_wildcard = 0"
|
||||
)
|
||||
for row in cur.fetchall():
|
||||
domain = row["domain"]
|
||||
if domain not in config:
|
||||
config[domain] = {}
|
||||
config[domain]["_shares"] = row["shares_with"]
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def db_add_server(domain: str, slot: int, ip: str, http_port: int) -> None:
|
||||
"""Add or update a server entry.
|
||||
|
||||
Args:
|
||||
domain: Domain name.
|
||||
slot: Server slot number.
|
||||
ip: Server IP address.
|
||||
http_port: HTTP port.
|
||||
"""
|
||||
conn = get_connection()
|
||||
conn.execute(
|
||||
"""INSERT INTO servers (domain, slot, ip, http_port)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(domain, slot) DO UPDATE SET
|
||||
ip = excluded.ip,
|
||||
http_port = excluded.http_port,
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""",
|
||||
(domain, slot, ip, http_port),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def db_remove_server(domain: str, slot: int) -> None:
|
||||
"""Remove a server entry.
|
||||
|
||||
Args:
|
||||
domain: Domain name.
|
||||
slot: Server slot number.
|
||||
"""
|
||||
conn = get_connection()
|
||||
conn.execute("DELETE FROM servers WHERE domain = ? AND slot = ?", (domain, slot))
|
||||
conn.commit()
|
||||
|
||||
|
||||
def db_remove_domain_servers(domain: str) -> None:
|
||||
"""Remove all server entries for a domain.
|
||||
|
||||
Args:
|
||||
domain: Domain name.
|
||||
"""
|
||||
conn = get_connection()
|
||||
conn.execute("DELETE FROM servers WHERE domain = ?", (domain,))
|
||||
conn.commit()
|
||||
|
||||
|
||||
def db_add_shared_domain(domain: str, shares_with: str) -> None:
|
||||
"""Mark a domain as sharing a pool with another domain.
|
||||
|
||||
Updates the shares_with field in the domains table.
|
||||
|
||||
Args:
|
||||
domain: The new domain.
|
||||
shares_with: The existing domain to share with.
|
||||
"""
|
||||
conn = get_connection()
|
||||
conn.execute(
|
||||
"UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0",
|
||||
(shares_with, domain),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def db_get_shared_domain(domain: str) -> Optional[str]:
|
||||
"""Get the parent domain that this domain shares a pool with.
|
||||
|
||||
Args:
|
||||
domain: Domain name.
|
||||
|
||||
Returns:
|
||||
Parent domain name if sharing, None otherwise.
|
||||
"""
|
||||
conn = get_connection()
|
||||
cur = conn.execute(
|
||||
"SELECT shares_with FROM domains WHERE domain = ? AND is_wildcard = 0",
|
||||
(domain,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row and row["shares_with"]:
|
||||
return row["shares_with"]
|
||||
return None
|
||||
|
||||
|
||||
def db_is_shared_domain(domain: str) -> bool:
|
||||
"""Check if a domain is sharing another domain's pool.
|
||||
|
||||
Args:
|
||||
domain: Domain name.
|
||||
|
||||
Returns:
|
||||
True if domain has shares_with reference.
|
||||
"""
|
||||
return db_get_shared_domain(domain) is not None
|
||||
|
||||
|
||||
# --- Certificate data access ---
|
||||
|
||||
def db_load_certs() -> list[str]:
|
||||
"""Load certificate domain list.
|
||||
|
||||
Returns:
|
||||
Sorted list of domain names with certificates.
|
||||
"""
|
||||
conn = get_connection()
|
||||
cur = conn.execute("SELECT domain FROM certificates ORDER BY domain")
|
||||
return [row["domain"] for row in cur.fetchall()]
|
||||
|
||||
|
||||
def db_add_cert(domain: str) -> None:
|
||||
"""Add a domain to the certificate list.
|
||||
|
||||
Args:
|
||||
domain: Domain name.
|
||||
"""
|
||||
conn = get_connection()
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO certificates (domain) VALUES (?)",
|
||||
(domain,),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def db_remove_cert(domain: str) -> None:
|
||||
"""Remove a domain from the certificate list.
|
||||
|
||||
Args:
|
||||
domain: Domain name.
|
||||
"""
|
||||
conn = get_connection()
|
||||
conn.execute("DELETE FROM certificates WHERE domain = ?", (domain,))
|
||||
conn.commit()
|
||||
|
||||
|
||||
# --- Map file synchronization ---
|
||||
|
||||
def sync_map_files() -> None:
|
||||
"""Regenerate HAProxy map files from database.
|
||||
|
||||
Writes domains.map (exact matches) and wildcards.map (wildcard entries)
|
||||
from the current database state. Uses atomic writes.
|
||||
"""
|
||||
from .file_ops import atomic_write_file
|
||||
|
||||
conn = get_connection()
|
||||
|
||||
# Fetch exact domains
|
||||
cur = conn.execute(
|
||||
"SELECT domain, backend FROM domains WHERE is_wildcard = 0 ORDER BY domain"
|
||||
)
|
||||
exact_entries = cur.fetchall()
|
||||
|
||||
# Fetch wildcard domains
|
||||
cur = conn.execute(
|
||||
"SELECT domain, backend FROM domains WHERE is_wildcard = 1 ORDER BY domain"
|
||||
)
|
||||
wildcard_entries = cur.fetchall()
|
||||
|
||||
# Write exact domains map
|
||||
exact_lines = [
|
||||
"# Exact Domain to Backend mapping (for map_str)\n",
|
||||
"# Format: domain backend_name\n",
|
||||
"# Uses ebtree for O(log n) lookup performance\n\n",
|
||||
]
|
||||
for row in exact_entries:
|
||||
exact_lines.append(f"{row['domain']} {row['backend']}\n")
|
||||
atomic_write_file(MAP_FILE, "".join(exact_lines))
|
||||
|
||||
# Write wildcards map
|
||||
wildcard_lines = [
|
||||
"# Wildcard Domain to Backend mapping (for map_dom)\n",
|
||||
"# Format: .domain.com backend_name (matches *.domain.com)\n",
|
||||
"# Uses map_dom for suffix matching\n\n",
|
||||
]
|
||||
for row in wildcard_entries:
|
||||
wildcard_lines.append(f"{row['domain']} {row['backend']}\n")
|
||||
atomic_write_file(WILDCARDS_MAP_FILE, "".join(wildcard_lines))
|
||||
|
||||
logger.debug("Map files synced: %d exact, %d wildcard entries",
|
||||
len(exact_entries), len(wildcard_entries))
|
||||
@@ -1,7 +1,11 @@
|
||||
"""File I/O operations for HAProxy MCP Server."""
|
||||
"""File I/O operations for HAProxy MCP Server.
|
||||
|
||||
Most data access is now delegated to db.py (SQLite).
|
||||
This module retains atomic file writes, map file I/O for HAProxy,
|
||||
and provides backward-compatible function signatures.
|
||||
"""
|
||||
|
||||
import fcntl
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
@@ -10,8 +14,6 @@ from typing import Any, Generator, Optional
|
||||
from .config import (
|
||||
MAP_FILE,
|
||||
WILDCARDS_MAP_FILE,
|
||||
SERVERS_FILE,
|
||||
CERTS_FILE,
|
||||
REMOTE_MODE,
|
||||
logger,
|
||||
)
|
||||
@@ -138,16 +140,13 @@ def _read_file(file_path: str) -> str:
|
||||
|
||||
|
||||
def get_map_contents() -> list[tuple[str, str]]:
|
||||
"""Read both domains.map and wildcards.map and return combined entries.
|
||||
"""Get all domain-to-backend mappings from SQLite.
|
||||
|
||||
Returns:
|
||||
List of (domain, backend) tuples from both map files
|
||||
List of (domain, backend) tuples including wildcards.
|
||||
"""
|
||||
# Read exact domains
|
||||
entries = _read_map_file(MAP_FILE)
|
||||
# Read wildcards and append
|
||||
entries.extend(_read_map_file(WILDCARDS_MAP_FILE))
|
||||
return entries
|
||||
from .db import db_get_map_contents
|
||||
return db_get_map_contents()
|
||||
|
||||
|
||||
def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
|
||||
@@ -170,44 +169,21 @@ def split_domain_entries(entries: list[tuple[str, str]]) -> tuple[list[tuple[str
|
||||
|
||||
|
||||
def save_map_file(entries: list[tuple[str, str]]) -> None:
|
||||
"""Save domain-to-backend entries to map files.
|
||||
"""Sync map files from the database.
|
||||
|
||||
Splits entries into two files for 2-stage routing:
|
||||
- domains.map: Exact matches (map_str, O(log n))
|
||||
- wildcards.map: Wildcard entries starting with "." (map_dom, O(n))
|
||||
|
||||
Args:
|
||||
entries: List of (domain, backend) tuples.
|
||||
Regenerates domains.map and wildcards.map from the current
|
||||
database state. The entries parameter is ignored (kept for
|
||||
backward compatibility during transition).
|
||||
|
||||
Raises:
|
||||
IOError: If map files cannot be written.
|
||||
"""
|
||||
# Split into exact and wildcard entries
|
||||
exact_entries, wildcard_entries = split_domain_entries(entries)
|
||||
|
||||
# Save exact domains (for map_str - fast O(log n) lookup)
|
||||
exact_lines = [
|
||||
"# Exact Domain to Backend mapping (for map_str)\n",
|
||||
"# Format: domain backend_name\n",
|
||||
"# Uses ebtree for O(log n) lookup performance\n\n",
|
||||
]
|
||||
for domain, backend in sorted(exact_entries):
|
||||
exact_lines.append(f"{domain} {backend}\n")
|
||||
atomic_write_file(MAP_FILE, "".join(exact_lines))
|
||||
|
||||
# Save wildcards (for map_dom - O(n) but small set)
|
||||
wildcard_lines = [
|
||||
"# Wildcard Domain to Backend mapping (for map_dom)\n",
|
||||
"# Format: .domain.com backend_name (matches *.domain.com)\n",
|
||||
"# Uses map_dom for suffix matching\n\n",
|
||||
]
|
||||
for domain, backend in sorted(wildcard_entries):
|
||||
wildcard_lines.append(f"{domain} {backend}\n")
|
||||
atomic_write_file(WILDCARDS_MAP_FILE, "".join(wildcard_lines))
|
||||
from .db import sync_map_files
|
||||
sync_map_files()
|
||||
|
||||
|
||||
def get_domain_backend(domain: str) -> Optional[str]:
|
||||
"""Look up the backend for a domain from domains.map.
|
||||
"""Look up the backend for a domain from SQLite (O(1)).
|
||||
|
||||
Args:
|
||||
domain: The domain to look up
|
||||
@@ -215,10 +191,8 @@ def get_domain_backend(domain: str) -> Optional[str]:
|
||||
Returns:
|
||||
Backend name if found, None otherwise
|
||||
"""
|
||||
for map_domain, backend in get_map_contents():
|
||||
if map_domain == domain:
|
||||
return backend
|
||||
return None
|
||||
from .db import db_get_domain_backend
|
||||
return db_get_domain_backend(domain)
|
||||
|
||||
|
||||
def is_legacy_backend(backend: str) -> bool:
|
||||
@@ -273,34 +247,17 @@ def get_backend_and_prefix(domain: str) -> tuple[str, str]:
|
||||
|
||||
|
||||
def load_servers_config() -> dict[str, Any]:
|
||||
"""Load servers configuration from JSON file.
|
||||
"""Load servers configuration from SQLite.
|
||||
|
||||
Returns:
|
||||
Dictionary with server configurations
|
||||
Dictionary with server configurations (legacy format compatible).
|
||||
"""
|
||||
try:
|
||||
content = _read_file(SERVERS_FILE)
|
||||
return json.loads(content)
|
||||
except FileNotFoundError:
|
||||
return {}
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning("Corrupt config file %s: %s", SERVERS_FILE, e)
|
||||
return {}
|
||||
|
||||
|
||||
def save_servers_config(config: dict[str, Any]) -> None:
|
||||
"""Save servers configuration to JSON file atomically.
|
||||
|
||||
Uses temp file + rename for atomic write to prevent race conditions.
|
||||
|
||||
Args:
|
||||
config: Dictionary with server configurations
|
||||
"""
|
||||
atomic_write_file(SERVERS_FILE, json.dumps(config, indent=2))
|
||||
from .db import db_load_servers_config
|
||||
return db_load_servers_config()
|
||||
|
||||
|
||||
def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> None:
|
||||
"""Add server configuration to persistent storage with file locking.
|
||||
"""Add server configuration to persistent storage.
|
||||
|
||||
Args:
|
||||
domain: Domain name
|
||||
@@ -308,41 +265,29 @@ def add_server_to_config(domain: str, slot: int, ip: str, http_port: int) -> Non
|
||||
ip: Server IP address
|
||||
http_port: HTTP port
|
||||
"""
|
||||
with file_lock(f"{SERVERS_FILE}.lock"):
|
||||
config = load_servers_config()
|
||||
if domain not in config:
|
||||
config[domain] = {}
|
||||
config[domain][str(slot)] = {"ip": ip, "http_port": http_port}
|
||||
save_servers_config(config)
|
||||
from .db import db_add_server
|
||||
db_add_server(domain, slot, ip, http_port)
|
||||
|
||||
|
||||
def remove_server_from_config(domain: str, slot: int) -> None:
|
||||
"""Remove server configuration from persistent storage with file locking.
|
||||
"""Remove server configuration from persistent storage.
|
||||
|
||||
Args:
|
||||
domain: Domain name
|
||||
slot: Server slot to remove
|
||||
"""
|
||||
with file_lock(f"{SERVERS_FILE}.lock"):
|
||||
config = load_servers_config()
|
||||
if domain in config and str(slot) in config[domain]:
|
||||
del config[domain][str(slot)]
|
||||
if not config[domain]:
|
||||
del config[domain]
|
||||
save_servers_config(config)
|
||||
from .db import db_remove_server
|
||||
db_remove_server(domain, slot)
|
||||
|
||||
|
||||
def remove_domain_from_config(domain: str) -> None:
|
||||
"""Remove domain from persistent config with file locking.
|
||||
"""Remove domain from persistent config (servers + domain entry).
|
||||
|
||||
Args:
|
||||
domain: Domain name to remove
|
||||
"""
|
||||
with file_lock(f"{SERVERS_FILE}.lock"):
|
||||
config = load_servers_config()
|
||||
if domain in config:
|
||||
del config[domain]
|
||||
save_servers_config(config)
|
||||
from .db import db_remove_domain_servers
|
||||
db_remove_domain_servers(domain)
|
||||
|
||||
|
||||
def get_shared_domain(domain: str) -> Optional[str]:
|
||||
@@ -354,9 +299,8 @@ def get_shared_domain(domain: str) -> Optional[str]:
|
||||
Returns:
|
||||
Parent domain name if sharing, None otherwise
|
||||
"""
|
||||
config = load_servers_config()
|
||||
domain_config = config.get(domain, {})
|
||||
return domain_config.get("_shares")
|
||||
from .db import db_get_shared_domain
|
||||
return db_get_shared_domain(domain)
|
||||
|
||||
|
||||
def add_shared_domain_to_config(domain: str, shares_with: str) -> None:
|
||||
@@ -366,10 +310,8 @@ def add_shared_domain_to_config(domain: str, shares_with: str) -> None:
|
||||
domain: New domain name
|
||||
shares_with: Existing domain to share pool with
|
||||
"""
|
||||
with file_lock(f"{SERVERS_FILE}.lock"):
|
||||
config = load_servers_config()
|
||||
config[domain] = {"_shares": shares_with}
|
||||
save_servers_config(config)
|
||||
from .db import db_add_shared_domain
|
||||
db_add_shared_domain(domain, shares_with)
|
||||
|
||||
|
||||
def get_domains_sharing_pool(pool: str) -> list[str]:
|
||||
@@ -381,11 +323,8 @@ def get_domains_sharing_pool(pool: str) -> list[str]:
|
||||
Returns:
|
||||
List of domain names using this pool
|
||||
"""
|
||||
domains = []
|
||||
for domain, backend in get_map_contents():
|
||||
if backend == pool and not domain.startswith("."):
|
||||
domains.append(domain)
|
||||
return domains
|
||||
from .db import db_get_domains_sharing_pool
|
||||
return db_get_domains_sharing_pool(pool)
|
||||
|
||||
|
||||
def is_shared_domain(domain: str) -> bool:
|
||||
@@ -397,37 +336,20 @@ def is_shared_domain(domain: str) -> bool:
|
||||
Returns:
|
||||
True if domain has _shares reference, False otherwise
|
||||
"""
|
||||
config = load_servers_config()
|
||||
domain_config = config.get(domain, {})
|
||||
return "_shares" in domain_config
|
||||
from .db import db_is_shared_domain
|
||||
return db_is_shared_domain(domain)
|
||||
|
||||
|
||||
# Certificate configuration functions
|
||||
|
||||
def load_certs_config() -> list[str]:
|
||||
"""Load certificate domain list from JSON file.
|
||||
"""Load certificate domain list from SQLite.
|
||||
|
||||
Returns:
|
||||
List of domain names
|
||||
Sorted list of domain names.
|
||||
"""
|
||||
try:
|
||||
content = _read_file(CERTS_FILE)
|
||||
data = json.loads(content)
|
||||
return data.get("domains", [])
|
||||
except FileNotFoundError:
|
||||
return []
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning("Corrupt certificates config %s: %s", CERTS_FILE, e)
|
||||
return []
|
||||
|
||||
|
||||
def save_certs_config(domains: list[str]) -> None:
|
||||
"""Save certificate domain list to JSON file atomically.
|
||||
|
||||
Args:
|
||||
domains: List of domain names
|
||||
"""
|
||||
atomic_write_file(CERTS_FILE, json.dumps({"domains": sorted(domains)}, indent=2))
|
||||
from .db import db_load_certs
|
||||
return db_load_certs()
|
||||
|
||||
|
||||
def add_cert_to_config(domain: str) -> None:
|
||||
@@ -436,11 +358,8 @@ def add_cert_to_config(domain: str) -> None:
|
||||
Args:
|
||||
domain: Domain name to add
|
||||
"""
|
||||
with file_lock(f"{CERTS_FILE}.lock"):
|
||||
domains = load_certs_config()
|
||||
if domain not in domains:
|
||||
domains.append(domain)
|
||||
save_certs_config(domains)
|
||||
from .db import db_add_cert
|
||||
db_add_cert(domain)
|
||||
|
||||
|
||||
def remove_cert_from_config(domain: str) -> None:
|
||||
@@ -449,8 +368,45 @@ def remove_cert_from_config(domain: str) -> None:
|
||||
Args:
|
||||
domain: Domain name to remove
|
||||
"""
|
||||
with file_lock(f"{CERTS_FILE}.lock"):
|
||||
domains = load_certs_config()
|
||||
if domain in domains:
|
||||
domains.remove(domain)
|
||||
save_certs_config(domains)
|
||||
from .db import db_remove_cert
|
||||
db_remove_cert(domain)
|
||||
|
||||
|
||||
# Domain map helper functions (used by domains.py)
|
||||
|
||||
def add_domain_to_map(domain: str, backend: str, is_wildcard: bool = False,
|
||||
shares_with: Optional[str] = None) -> None:
|
||||
"""Add a domain to SQLite and sync map files.
|
||||
|
||||
Args:
|
||||
domain: Domain name (e.g., "example.com").
|
||||
backend: Backend pool name (e.g., "pool_5").
|
||||
is_wildcard: Whether this is a wildcard entry.
|
||||
shares_with: Parent domain if sharing a pool.
|
||||
"""
|
||||
from .db import db_add_domain, sync_map_files
|
||||
db_add_domain(domain, backend, is_wildcard, shares_with)
|
||||
sync_map_files()
|
||||
|
||||
|
||||
def remove_domain_from_map(domain: str) -> None:
|
||||
"""Remove a domain (exact + wildcard) from SQLite and sync map files.
|
||||
|
||||
Args:
|
||||
domain: Base domain name (without leading dot).
|
||||
"""
|
||||
from .db import db_remove_domain, sync_map_files
|
||||
db_remove_domain(domain)
|
||||
sync_map_files()
|
||||
|
||||
|
||||
def find_available_pool() -> Optional[str]:
|
||||
"""Find the first available pool not assigned to any domain.
|
||||
|
||||
Uses SQLite query for O(1) lookup vs previous O(n) list scan.
|
||||
|
||||
Returns:
|
||||
Pool name (e.g., "pool_5") or None if all pools are in use.
|
||||
"""
|
||||
from .db import db_find_available_pool
|
||||
return db_find_available_pool()
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import socket
|
||||
import select
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
from .config import (
|
||||
@@ -161,7 +162,7 @@ def reload_haproxy() -> tuple[bool, str]:
|
||||
if result.returncode != 0:
|
||||
return False, f"Reload failed: {result.stderr}"
|
||||
return True, "OK"
|
||||
except TimeoutError:
|
||||
except (TimeoutError, subprocess.TimeoutExpired):
|
||||
return False, f"Command timed out after {SUBPROCESS_TIMEOUT} seconds"
|
||||
except FileNotFoundError:
|
||||
return False, "ssh/podman command not found"
|
||||
|
||||
@@ -21,6 +21,7 @@ Environment Variables:
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from .config import MCP_HOST, MCP_PORT
|
||||
from .db import init_db
|
||||
from .tools import register_all_tools
|
||||
from .tools.configuration import startup_restore
|
||||
|
||||
@@ -32,5 +33,6 @@ register_all_tools(mcp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_db()
|
||||
startup_restore()
|
||||
mcp.run(transport="streamable-http")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Certificate management tools for HAProxy MCP Server."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
@@ -152,7 +153,7 @@ def _haproxy_list_certs_impl() -> str:
|
||||
certs.append(f"• {domain} ({ca})\n Created: {created}\n Renew: {renew}\n Status: {status}")
|
||||
|
||||
return "\n\n".join(certs) if certs else "No certificates found"
|
||||
except TimeoutError:
|
||||
except (TimeoutError, subprocess.TimeoutExpired):
|
||||
return "Error: Command timed out"
|
||||
except FileNotFoundError:
|
||||
return "Error: acme.sh not found"
|
||||
@@ -203,7 +204,7 @@ def _haproxy_cert_info_impl(domain: str) -> str:
|
||||
result.stdout.strip()
|
||||
]
|
||||
return "\n".join(info)
|
||||
except TimeoutError:
|
||||
except (TimeoutError, subprocess.TimeoutExpired):
|
||||
return "Error: Command timed out"
|
||||
except OSError as e:
|
||||
logger.error("Error getting certificate info for %s: %s", domain, e)
|
||||
@@ -250,7 +251,7 @@ def _haproxy_issue_cert_impl(domain: str, wildcard: bool) -> str:
|
||||
else:
|
||||
return f"Certificate issued but PEM file not created. Check {host_path}"
|
||||
|
||||
except TimeoutError:
|
||||
except (TimeoutError, subprocess.TimeoutExpired):
|
||||
return f"Error: Certificate issuance timed out after {CERT_TIMEOUT}s"
|
||||
except OSError as e:
|
||||
logger.error("Error issuing certificate for %s: %s", domain, e)
|
||||
@@ -289,7 +290,7 @@ def _haproxy_renew_cert_impl(domain: str, force: bool) -> str:
|
||||
else:
|
||||
return f"Error renewing certificate:\n{output}"
|
||||
|
||||
except TimeoutError:
|
||||
except (TimeoutError, subprocess.TimeoutExpired):
|
||||
return f"Error: Certificate renewal timed out after {CERT_TIMEOUT}s"
|
||||
except FileNotFoundError:
|
||||
return "Error: acme.sh not found"
|
||||
@@ -323,7 +324,7 @@ def _haproxy_renew_all_certs_impl() -> str:
|
||||
else:
|
||||
return "Renewal check completed"
|
||||
|
||||
except TimeoutError:
|
||||
except (TimeoutError, subprocess.TimeoutExpired):
|
||||
return "Error: Renewal cron timed out"
|
||||
except FileNotFoundError:
|
||||
return "Error: acme.sh not found"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Configuration management tools for HAProxy MCP Server."""
|
||||
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
from ..config import (
|
||||
@@ -177,7 +178,7 @@ def register_config_tools(mcp):
|
||||
if result.returncode == 0:
|
||||
return "Configuration is valid"
|
||||
return f"Configuration errors:\n{result.stderr}"
|
||||
except TimeoutError:
|
||||
except (TimeoutError, subprocess.TimeoutExpired):
|
||||
return f"Error: Command timed out after {SUBPROCESS_TIMEOUT} seconds"
|
||||
except FileNotFoundError:
|
||||
return "Error: ssh/podman command not found"
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Annotated, Optional
|
||||
from pydantic import Field
|
||||
|
||||
from ..config import (
|
||||
MAP_FILE,
|
||||
MAP_FILE_CONTAINER,
|
||||
WILDCARDS_MAP_FILE_CONTAINER,
|
||||
POOL_COUNT,
|
||||
@@ -22,7 +21,6 @@ from ..validation import validate_domain, validate_ip, validate_port_int
|
||||
from ..haproxy_client import haproxy_cmd
|
||||
from ..file_ops import (
|
||||
get_map_contents,
|
||||
save_map_file,
|
||||
get_domain_backend,
|
||||
is_legacy_backend,
|
||||
add_server_to_config,
|
||||
@@ -31,30 +29,13 @@ from ..file_ops import (
|
||||
add_shared_domain_to_config,
|
||||
get_domains_sharing_pool,
|
||||
is_shared_domain,
|
||||
add_domain_to_map,
|
||||
remove_domain_from_map,
|
||||
find_available_pool,
|
||||
)
|
||||
from ..utils import parse_servers_state, disable_server_slot
|
||||
|
||||
|
||||
def _find_available_pool(entries: list[tuple[str, str]], used_pools: set[str]) -> Optional[str]:
|
||||
"""Find an available pool backend from the pool list.
|
||||
|
||||
Iterates through pool_1 to pool_N and returns the first pool
|
||||
that is not currently in use.
|
||||
|
||||
Args:
|
||||
entries: List of (domain, backend) tuples from the map file.
|
||||
used_pools: Set of pool names already in use.
|
||||
|
||||
Returns:
|
||||
Available pool name (e.g., "pool_5") or None if all pools are in use.
|
||||
"""
|
||||
for i in range(1, POOL_COUNT + 1):
|
||||
pool_name = f"pool_{i}"
|
||||
if pool_name not in used_pools:
|
||||
return pool_name
|
||||
return None
|
||||
|
||||
|
||||
def _check_subdomain(domain: str, registered_domains: set[str]) -> tuple[bool, Optional[str]]:
|
||||
"""Check if a domain is a subdomain of an existing registered domain.
|
||||
|
||||
@@ -95,24 +76,19 @@ def _update_haproxy_maps(domain: str, pool: str, is_subdomain: bool) -> None:
|
||||
haproxy_cmd(f"add map {WILDCARDS_MAP_FILE_CONTAINER} .{domain} {pool}")
|
||||
|
||||
|
||||
def _rollback_domain_addition(
|
||||
domain: str,
|
||||
entries: list[tuple[str, str]]
|
||||
) -> None:
|
||||
"""Rollback a failed domain addition by removing entries from map file.
|
||||
def _rollback_domain_addition(domain: str) -> None:
|
||||
"""Rollback a failed domain addition by removing from SQLite + map files.
|
||||
|
||||
Called when HAProxy Runtime API update fails after the map file
|
||||
has already been saved.
|
||||
Called when HAProxy Runtime API update fails after the domain
|
||||
has already been saved to the database.
|
||||
|
||||
Args:
|
||||
domain: Domain name that was added.
|
||||
entries: Current list of map entries to rollback from.
|
||||
"""
|
||||
rollback_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
|
||||
try:
|
||||
save_map_file(rollback_entries)
|
||||
except IOError:
|
||||
logger.error("Failed to rollback map file after HAProxy error")
|
||||
remove_domain_from_map(domain)
|
||||
except (IOError, Exception):
|
||||
logger.error("Failed to rollback domain %s after HAProxy error", domain)
|
||||
|
||||
|
||||
def _file_exists(path: str) -> bool:
|
||||
@@ -242,93 +218,86 @@ def register_domain_tools(mcp):
|
||||
if share_with and ip:
|
||||
return "Error: Cannot specify both ip and share_with (shared domains use existing servers)"
|
||||
|
||||
# Use file locking for the entire pool allocation operation
|
||||
from ..file_ops import file_lock
|
||||
with file_lock(f"{MAP_FILE}.lock"):
|
||||
# Read map contents once for both existence check and pool lookup
|
||||
entries = get_map_contents()
|
||||
# Read current entries for existence check and subdomain detection
|
||||
entries = get_map_contents()
|
||||
|
||||
# Check if domain already exists (using cached entries)
|
||||
for domain_entry, backend in entries:
|
||||
if domain_entry == domain:
|
||||
return f"Error: Domain {domain} already exists (mapped to {backend})"
|
||||
# Check if domain already exists
|
||||
for domain_entry, backend in entries:
|
||||
if domain_entry == domain:
|
||||
return f"Error: Domain {domain} already exists (mapped to {backend})"
|
||||
|
||||
# Build used pools and registered domains sets
|
||||
used_pools: set[str] = set()
|
||||
registered_domains: set[str] = set()
|
||||
for entry_domain, backend in entries:
|
||||
if backend.startswith("pool_"):
|
||||
used_pools.add(backend)
|
||||
if not entry_domain.startswith("."):
|
||||
registered_domains.add(entry_domain)
|
||||
# Build registered domains set for subdomain check
|
||||
registered_domains: set[str] = set()
|
||||
for entry_domain, _ in entries:
|
||||
if not entry_domain.startswith("."):
|
||||
registered_domains.add(entry_domain)
|
||||
|
||||
# Handle share_with: reuse existing domain's pool
|
||||
if share_with:
|
||||
share_backend = get_domain_backend(share_with)
|
||||
if not share_backend:
|
||||
return f"Error: Domain {share_with} not found"
|
||||
if not share_backend.startswith("pool_"):
|
||||
return f"Error: Cannot share with legacy backend {share_backend}"
|
||||
pool = share_backend
|
||||
else:
|
||||
# Find available pool
|
||||
pool = _find_available_pool(entries, used_pools)
|
||||
if not pool:
|
||||
return f"Error: All {POOL_COUNT} pool backends are in use"
|
||||
# Handle share_with: reuse existing domain's pool
|
||||
if share_with:
|
||||
share_backend = get_domain_backend(share_with)
|
||||
if not share_backend:
|
||||
return f"Error: Domain {share_with} not found"
|
||||
if not share_backend.startswith("pool_"):
|
||||
return f"Error: Cannot share with legacy backend {share_backend}"
|
||||
pool = share_backend
|
||||
else:
|
||||
# Find available pool (SQLite query, O(1))
|
||||
pool = find_available_pool()
|
||||
if not pool:
|
||||
return f"Error: All {POOL_COUNT} pool backends are in use"
|
||||
|
||||
# Check if this is a subdomain of an existing domain
|
||||
is_subdomain, parent_domain = _check_subdomain(domain, registered_domains)
|
||||
# Check if this is a subdomain of an existing domain
|
||||
is_subdomain, parent_domain = _check_subdomain(domain, registered_domains)
|
||||
|
||||
try:
|
||||
# Save to SQLite + sync map files (atomic via SQLite transaction)
|
||||
try:
|
||||
# Save to disk first (atomic write for persistence)
|
||||
entries.append((domain, pool))
|
||||
add_domain_to_map(domain, pool)
|
||||
if not is_subdomain:
|
||||
entries.append((f".{domain}", pool))
|
||||
try:
|
||||
save_map_file(entries)
|
||||
except IOError as e:
|
||||
return f"Error: Failed to save map file: {e}"
|
||||
|
||||
# Update HAProxy maps via Runtime API
|
||||
try:
|
||||
_update_haproxy_maps(domain, pool, is_subdomain)
|
||||
except HaproxyError as e:
|
||||
_rollback_domain_addition(domain, entries)
|
||||
return f"Error: Failed to update HAProxy map: {e}"
|
||||
|
||||
# Handle server configuration based on mode
|
||||
if share_with:
|
||||
# Save shared domain reference
|
||||
add_shared_domain_to_config(domain, share_with)
|
||||
result = f"Domain {domain} added, sharing pool {pool} with {share_with}"
|
||||
elif ip:
|
||||
# Add server to slot 1
|
||||
add_server_to_config(domain, 1, ip, http_port)
|
||||
try:
|
||||
server = f"{pool}_1"
|
||||
haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}")
|
||||
haproxy_cmd(f"set server {pool}/{server} state ready")
|
||||
except HaproxyError as e:
|
||||
remove_server_from_config(domain, 1)
|
||||
return f"Domain {domain} added to {pool} but server config failed: {e}"
|
||||
result = f"Domain {domain} added to {pool} with server {ip}:{http_port}"
|
||||
else:
|
||||
result = f"Domain {domain} added to {pool} (no servers configured)"
|
||||
|
||||
if is_subdomain:
|
||||
result += f" (subdomain of {parent_domain}, no wildcard)"
|
||||
|
||||
# Check certificate coverage
|
||||
cert_covered, cert_info = check_certificate_coverage(domain)
|
||||
if cert_covered:
|
||||
result += f"\nSSL: Using certificate {cert_info}"
|
||||
else:
|
||||
result += f"\nSSL: No certificate found. Use haproxy_issue_cert(\"{domain}\") to issue one."
|
||||
|
||||
return result
|
||||
add_domain_to_map(f".{domain}", pool, is_wildcard=True)
|
||||
except (IOError, Exception) as e:
|
||||
return f"Error: Failed to save domain: {e}"
|
||||
|
||||
# Update HAProxy maps via Runtime API
|
||||
try:
|
||||
_update_haproxy_maps(domain, pool, is_subdomain)
|
||||
except HaproxyError as e:
|
||||
return f"Error: {e}"
|
||||
_rollback_domain_addition(domain)
|
||||
return f"Error: Failed to update HAProxy map: {e}"
|
||||
|
||||
# Handle server configuration based on mode
|
||||
if share_with:
|
||||
# Save shared domain reference
|
||||
add_shared_domain_to_config(domain, share_with)
|
||||
result = f"Domain {domain} added, sharing pool {pool} with {share_with}"
|
||||
elif ip:
|
||||
# Add server to slot 1
|
||||
add_server_to_config(domain, 1, ip, http_port)
|
||||
try:
|
||||
server = f"{pool}_1"
|
||||
haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}")
|
||||
haproxy_cmd(f"set server {pool}/{server} state ready")
|
||||
except HaproxyError as e:
|
||||
remove_server_from_config(domain, 1)
|
||||
return f"Domain {domain} added to {pool} but server config failed: {e}"
|
||||
result = f"Domain {domain} added to {pool} with server {ip}:{http_port}"
|
||||
else:
|
||||
result = f"Domain {domain} added to {pool} (no servers configured)"
|
||||
|
||||
if is_subdomain:
|
||||
result += f" (subdomain of {parent_domain}, no wildcard)"
|
||||
|
||||
# Check certificate coverage
|
||||
cert_covered, cert_info = check_certificate_coverage(domain)
|
||||
if cert_covered:
|
||||
result += f"\nSSL: Using certificate {cert_info}"
|
||||
else:
|
||||
result += f"\nSSL: No certificate found. Use haproxy_issue_cert(\"{domain}\") to issue one."
|
||||
|
||||
return result
|
||||
|
||||
except HaproxyError as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
@mcp.tool()
|
||||
def haproxy_remove_domain(
|
||||
@@ -355,10 +324,8 @@ def register_domain_tools(mcp):
|
||||
domains_using_pool = get_domains_sharing_pool(backend)
|
||||
other_domains = [d for d in domains_using_pool if d != domain]
|
||||
|
||||
# Save to disk first (atomic write for persistence)
|
||||
entries = get_map_contents()
|
||||
new_entries = [(d, b) for d, b in entries if d != domain and d != f".{domain}"]
|
||||
save_map_file(new_entries)
|
||||
# Remove from SQLite + sync map files
|
||||
remove_domain_from_map(domain)
|
||||
|
||||
# Remove from persistent server config
|
||||
remove_domain_from_config(domain)
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import Field
|
||||
from ..config import (
|
||||
MAP_FILE,
|
||||
SERVERS_FILE,
|
||||
DB_FILE,
|
||||
HAPROXY_CONTAINER,
|
||||
)
|
||||
from ..exceptions import HaproxyError
|
||||
@@ -88,7 +89,7 @@ def register_health_tools(mcp):
|
||||
# Check configuration files
|
||||
files_ok = True
|
||||
file_status: dict[str, str] = {}
|
||||
for name, path in [("map_file", MAP_FILE), ("servers_file", SERVERS_FILE)]:
|
||||
for name, path in [("map_file", MAP_FILE), ("db_file", DB_FILE)]:
|
||||
exists = remote_file_exists(path) if REMOTE_MODE else __import__('os').path.exists(path)
|
||||
if exists:
|
||||
file_status[name] = "ok"
|
||||
|
||||
Reference in New Issue
Block a user