"""SQLite database operations for HAProxy MCP Server. Single source of truth for all persistent configuration data. Replaces JSON files (servers.json, certificates.json) and map files (domains.map, wildcards.map) with a single SQLite database. Map files are generated from the database via sync_map_files() for HAProxy to consume directly. """ import json import os import sqlite3 import threading from typing import Any, Optional from .config import ( DB_FILE, MAP_FILE, WILDCARDS_MAP_FILE, SERVERS_FILE, CERTS_FILE, POOL_COUNT, REMOTE_MODE, logger, ) SCHEMA_VERSION = 1 _local = threading.local() def get_connection() -> sqlite3.Connection: """Get a thread-local SQLite connection with WAL mode. Returns: sqlite3.Connection configured for concurrent access. """ conn = getattr(_local, "connection", None) if conn is not None: return conn conn = sqlite3.connect(DB_FILE, timeout=10) conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA busy_timeout=5000") conn.execute("PRAGMA foreign_keys=ON") conn.row_factory = sqlite3.Row _local.connection = conn return conn def close_connection() -> None: """Close the thread-local SQLite connection.""" conn = getattr(_local, "connection", None) if conn is not None: conn.close() _local.connection = None def init_db() -> None: """Initialize database schema and run migration if needed. Creates tables if they don't exist, then checks for existing JSON/map files to migrate data from. """ # Ensure parent directory exists for the database file db_dir = os.path.dirname(DB_FILE) if db_dir: os.makedirs(db_dir, exist_ok=True) conn = get_connection() cur = conn.cursor() cur.executescript(""" CREATE TABLE IF NOT EXISTS domains ( id INTEGER PRIMARY KEY AUTOINCREMENT, domain TEXT NOT NULL UNIQUE, backend TEXT NOT NULL, is_wildcard INTEGER NOT NULL DEFAULT 0, shares_with TEXT DEFAULT NULL, created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) ); CREATE INDEX IF NOT EXISTS idx_domains_backend ON domains(backend); CREATE INDEX IF NOT EXISTS idx_domains_shares_with ON domains(shares_with); CREATE TABLE IF NOT EXISTS servers ( id INTEGER PRIMARY KEY AUTOINCREMENT, domain TEXT NOT NULL, slot INTEGER NOT NULL, ip TEXT NOT NULL, http_port INTEGER NOT NULL DEFAULT 80, created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), UNIQUE(domain, slot) ); CREATE INDEX IF NOT EXISTS idx_servers_domain ON servers(domain); CREATE TABLE IF NOT EXISTS certificates ( id INTEGER PRIMARY KEY AUTOINCREMENT, domain TEXT NOT NULL UNIQUE, created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) ); CREATE TABLE IF NOT EXISTS schema_version ( version INTEGER NOT NULL, applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) ); """) conn.commit() # Check schema version cur.execute("SELECT MAX(version) FROM schema_version") row = cur.fetchone() current_version = row[0] if row[0] is not None else 0 if current_version < SCHEMA_VERSION: # First-time setup: try migrating from JSON files if current_version == 0: migrate_from_json() cur.execute("INSERT INTO schema_version (version) VALUES (?)", (SCHEMA_VERSION,)) conn.commit() logger.info("Database initialized (schema v%d)", SCHEMA_VERSION) def migrate_from_json() -> None: """Migrate data from JSON/map files to SQLite. Reads domains.map, wildcards.map, servers.json, and certificates.json, imports their data into the database within a single transaction, then renames the original files to .bak. This function is idempotent (uses INSERT OR IGNORE). """ conn = get_connection() # Collect data from existing files map_entries: list[tuple[str, str]] = [] servers_config: dict[str, Any] = {} cert_domains: list[str] = [] # Read map files for map_path in [MAP_FILE, WILDCARDS_MAP_FILE]: try: content = _read_file_for_migration(map_path) for line in content.splitlines(): line = line.strip() if not line or line.startswith("#"): continue parts = line.split() if len(parts) >= 2: map_entries.append((parts[0], parts[1])) except FileNotFoundError: logger.debug("Map file not found for migration: %s", map_path) # Read servers.json try: content = _read_file_for_migration(SERVERS_FILE) servers_config = json.loads(content) except FileNotFoundError: logger.debug("Servers config not found for migration: %s", SERVERS_FILE) except json.JSONDecodeError as e: logger.warning("Corrupt servers config during migration: %s", e) # Read certificates.json try: content = _read_file_for_migration(CERTS_FILE) data = json.loads(content) cert_domains = data.get("domains", []) except FileNotFoundError: logger.debug("Certs config not found for migration: %s", CERTS_FILE) except json.JSONDecodeError as e: logger.warning("Corrupt certs config during migration: %s", e) if not map_entries and not servers_config and not cert_domains: logger.info("No existing data to migrate") return # Import into database in a single transaction try: with conn: # Import domains from map files for domain, backend in map_entries: is_wildcard = 1 if domain.startswith(".") else 0 conn.execute( "INSERT OR IGNORE INTO domains (domain, backend, is_wildcard) VALUES (?, ?, ?)", (domain, backend, is_wildcard), ) # Import servers and shares_with from servers.json for domain, slots in servers_config.items(): shares_with = slots.get("_shares") if shares_with: # Update domain's shares_with field conn.execute( "UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0", (shares_with, domain), ) continue for slot_str, server_info in slots.items(): if slot_str.startswith("_"): continue try: slot = int(slot_str) ip = server_info.get("ip", "") http_port = int(server_info.get("http_port", 80)) if ip: conn.execute( "INSERT OR IGNORE INTO servers (domain, slot, ip, http_port) VALUES (?, ?, ?, ?)", (domain, slot, ip, http_port), ) except (ValueError, TypeError) as e: logger.warning("Skipping invalid server entry %s/%s: %s", domain, slot_str, e) # Import certificates for cert_domain in cert_domains: conn.execute( "INSERT OR IGNORE INTO certificates (domain) VALUES (?)", (cert_domain,), ) migrated_domains = len(set(d for d, _ in map_entries)) migrated_servers = sum( 1 for d, slots in servers_config.items() if "_shares" not in slots for s in slots if not s.startswith("_") ) migrated_certs = len(cert_domains) logger.info( "Migrated from JSON: %d domain entries, %d servers, %d certificates", migrated_domains, migrated_servers, migrated_certs, ) # Rename original files to .bak for path in [SERVERS_FILE, CERTS_FILE]: _backup_file(path) except sqlite3.Error as e: logger.error("Migration failed: %s", e) raise def _read_file_for_migration(file_path: str) -> str: """Read a file for migration purposes (local only, no locking needed).""" if REMOTE_MODE: from .ssh_ops import remote_read_file return remote_read_file(file_path) with open(file_path, "r", encoding="utf-8") as f: return f.read() def _backup_file(file_path: str) -> None: """Rename a file to .bak if it exists.""" if REMOTE_MODE: return # Don't rename remote files try: if os.path.exists(file_path): bak_path = f"{file_path}.bak" os.rename(file_path, bak_path) logger.info("Backed up %s -> %s", file_path, bak_path) except OSError as e: logger.warning("Failed to backup %s: %s", file_path, e) # --- Domain data access --- def db_get_map_contents() -> list[tuple[str, str]]: """Get all domain-to-backend mappings. Returns: List of (domain, backend) tuples including wildcards. """ conn = get_connection() cur = conn.execute("SELECT domain, backend FROM domains ORDER BY domain") return [(row["domain"], row["backend"]) for row in cur.fetchall()] def db_get_domain_backend(domain: str) -> Optional[str]: """Look up the backend for a specific domain. Args: domain: Domain name to look up. Returns: Backend name if found, None otherwise. """ conn = get_connection() cur = conn.execute("SELECT backend FROM domains WHERE domain = ?", (domain,)) row = cur.fetchone() return row["backend"] if row else None def db_add_domain(domain: str, backend: str, is_wildcard: bool = False, shares_with: Optional[str] = None) -> None: """Add a domain to the database. Args: domain: Domain name (e.g., "example.com" or ".example.com"). backend: Backend pool name (e.g., "pool_5"). is_wildcard: Whether this is a wildcard entry. shares_with: Parent domain if sharing a pool. """ conn = get_connection() conn.execute( """INSERT INTO domains (domain, backend, is_wildcard, shares_with) VALUES (?, ?, ?, ?) ON CONFLICT(domain) DO UPDATE SET backend = excluded.backend, shares_with = excluded.shares_with, updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""", (domain, backend, 1 if is_wildcard else 0, shares_with), ) conn.commit() def db_remove_domain(domain: str) -> None: """Remove a domain (exact + wildcard) from the database. Args: domain: Base domain name (without leading dot). """ conn = get_connection() conn.execute("DELETE FROM domains WHERE domain = ? OR domain = ?", (domain, f".{domain}")) conn.commit() def db_find_available_pool() -> Optional[str]: """Find the first available pool not assigned to any domain. Returns: Pool name (e.g., "pool_5") or None if all pools are in use. """ conn = get_connection() cur = conn.execute( "SELECT DISTINCT backend FROM domains WHERE backend LIKE 'pool_%' AND is_wildcard = 0" ) used_pools = {row["backend"] for row in cur.fetchall()} for i in range(1, POOL_COUNT + 1): pool_name = f"pool_{i}" if pool_name not in used_pools: return pool_name return None def db_get_domains_sharing_pool(pool: str) -> list[str]: """Get all non-wildcard domains using a specific pool. Args: pool: Pool name (e.g., "pool_5"). Returns: List of domain names. """ conn = get_connection() cur = conn.execute( "SELECT domain FROM domains WHERE backend = ? AND is_wildcard = 0", (pool,), ) return [row["domain"] for row in cur.fetchall()] # --- Server data access --- def db_load_servers_config() -> dict[str, Any]: """Load servers configuration in the legacy dict format. Returns: Dictionary compatible with the old servers.json format: {"domain": {"slot": {"ip": "...", "http_port": N}, "_shares": "parent"}} """ conn = get_connection() config: dict[str, Any] = {} # Load server entries cur = conn.execute("SELECT domain, slot, ip, http_port FROM servers ORDER BY domain, slot") for row in cur.fetchall(): domain = row["domain"] if domain not in config: config[domain] = {} config[domain][str(row["slot"])] = { "ip": row["ip"], "http_port": row["http_port"], } # Load shared domain references cur = conn.execute( "SELECT domain, shares_with FROM domains WHERE shares_with IS NOT NULL AND is_wildcard = 0" ) for row in cur.fetchall(): domain = row["domain"] if domain not in config: config[domain] = {} config[domain]["_shares"] = row["shares_with"] return config def db_add_server(domain: str, slot: int, ip: str, http_port: int) -> None: """Add or update a server entry. Args: domain: Domain name. slot: Server slot number. ip: Server IP address. http_port: HTTP port. """ conn = get_connection() conn.execute( """INSERT INTO servers (domain, slot, ip, http_port) VALUES (?, ?, ?, ?) ON CONFLICT(domain, slot) DO UPDATE SET ip = excluded.ip, http_port = excluded.http_port, updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""", (domain, slot, ip, http_port), ) conn.commit() def db_remove_server(domain: str, slot: int) -> None: """Remove a server entry. Args: domain: Domain name. slot: Server slot number. """ conn = get_connection() conn.execute("DELETE FROM servers WHERE domain = ? AND slot = ?", (domain, slot)) conn.commit() def db_remove_domain_servers(domain: str) -> None: """Remove all server entries for a domain. Args: domain: Domain name. """ conn = get_connection() conn.execute("DELETE FROM servers WHERE domain = ?", (domain,)) conn.commit() def db_add_shared_domain(domain: str, shares_with: str) -> None: """Mark a domain as sharing a pool with another domain. Updates the shares_with field in the domains table. Args: domain: The new domain. shares_with: The existing domain to share with. """ conn = get_connection() conn.execute( "UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0", (shares_with, domain), ) conn.commit() def db_get_shared_domain(domain: str) -> Optional[str]: """Get the parent domain that this domain shares a pool with. Args: domain: Domain name. Returns: Parent domain name if sharing, None otherwise. """ conn = get_connection() cur = conn.execute( "SELECT shares_with FROM domains WHERE domain = ? AND is_wildcard = 0", (domain,), ) row = cur.fetchone() if row and row["shares_with"]: return row["shares_with"] return None def db_is_shared_domain(domain: str) -> bool: """Check if a domain is sharing another domain's pool. Args: domain: Domain name. Returns: True if domain has shares_with reference. """ return db_get_shared_domain(domain) is not None # --- Certificate data access --- def db_load_certs() -> list[str]: """Load certificate domain list. Returns: Sorted list of domain names with certificates. """ conn = get_connection() cur = conn.execute("SELECT domain FROM certificates ORDER BY domain") return [row["domain"] for row in cur.fetchall()] def db_add_cert(domain: str) -> None: """Add a domain to the certificate list. Args: domain: Domain name. """ conn = get_connection() conn.execute( "INSERT OR IGNORE INTO certificates (domain) VALUES (?)", (domain,), ) conn.commit() def db_remove_cert(domain: str) -> None: """Remove a domain from the certificate list. Args: domain: Domain name. """ conn = get_connection() conn.execute("DELETE FROM certificates WHERE domain = ?", (domain,)) conn.commit() # --- Map file synchronization --- def sync_map_files() -> None: """Regenerate HAProxy map files from database. Writes domains.map (exact matches) and wildcards.map (wildcard entries) from the current database state. Uses atomic writes. """ from .file_ops import atomic_write_file conn = get_connection() # Fetch exact domains cur = conn.execute( "SELECT domain, backend FROM domains WHERE is_wildcard = 0 ORDER BY domain" ) exact_entries = cur.fetchall() # Fetch wildcard domains cur = conn.execute( "SELECT domain, backend FROM domains WHERE is_wildcard = 1 ORDER BY domain" ) wildcard_entries = cur.fetchall() # Write exact domains map exact_lines = [ "# Exact Domain to Backend mapping (for map_str)\n", "# Format: domain backend_name\n", "# Uses ebtree for O(log n) lookup performance\n\n", ] for row in exact_entries: exact_lines.append(f"{row['domain']} {row['backend']}\n") atomic_write_file(MAP_FILE, "".join(exact_lines)) # Write wildcards map wildcard_lines = [ "# Wildcard Domain to Backend mapping (for map_dom)\n", "# Format: .domain.com backend_name (matches *.domain.com)\n", "# Uses map_dom for suffix matching\n\n", ] for row in wildcard_entries: wildcard_lines.append(f"{row['domain']} {row['backend']}\n") atomic_write_file(WILDCARDS_MAP_FILE, "".join(wildcard_lines)) logger.debug("Map files synced: %d exact, %d wildcard entries", len(exact_entries), len(wildcard_entries))