Files
kappa 2e22a5d5a8 Add haproxy_cleanup_wildcards tool to remove subdomain wildcard entries
New db_remove_wildcard() for surgical wildcard-only deletion.
New haproxy_cleanup_wildcards tool scans all wildcard entries and removes
those belonging to subdomains (e.g., *.nocodb.inouter.com) while keeping
base domain wildcards (e.g., *.anvil.it.com).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 20:36:59 +09:00

645 lines
20 KiB
Python

"""SQLite database operations for HAProxy MCP Server.
Single source of truth for all persistent configuration data.
Replaces JSON files (servers.json, certificates.json) and map files
(domains.map, wildcards.map) with a single SQLite database.
Map files are generated from the database via sync_map_files() for
HAProxy to consume directly.
"""
import json
import os
import sqlite3
import threading
from typing import Any, Optional
from .config import (
DB_FILE,
REMOTE_DB_FILE,
MAP_FILE,
WILDCARDS_MAP_FILE,
SERVERS_FILE,
CERTS_FILE,
POOL_COUNT,
REMOTE_MODE,
logger,
)
SCHEMA_VERSION = 1
_local = threading.local()
def get_connection() -> sqlite3.Connection:
    """Return the calling thread's cached SQLite connection, creating it lazily.

    Each thread gets its own handle via a ``threading.local``. New
    connections enable WAL journaling, a 5s busy timeout, and foreign-key
    enforcement, and use ``sqlite3.Row`` so columns are name-addressable.

    Returns:
        sqlite3.Connection configured for concurrent access.
    """
    cached = getattr(_local, "connection", None)
    if cached is None:
        cached = sqlite3.connect(DB_FILE, timeout=10)
        cached.execute("PRAGMA journal_mode=WAL")
        cached.execute("PRAGMA busy_timeout=5000")
        cached.execute("PRAGMA foreign_keys=ON")
        cached.row_factory = sqlite3.Row
        _local.connection = cached
    return cached
def close_connection() -> None:
    """Close and discard the current thread's SQLite connection, if one exists."""
    conn = getattr(_local, "connection", None)
    if conn is None:
        return
    conn.close()
    _local.connection = None
def init_db() -> None:
    """Initialize database schema and run migration if needed.

    In REMOTE_MODE, tries to download existing DB from the remote host first.
    If no remote DB exists, creates a new one and migrates from JSON files.
    """
    # Ensure parent directory exists for the database file
    db_dir = os.path.dirname(DB_FILE)
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)
    # In REMOTE_MODE, try to restore DB from remote host
    if REMOTE_MODE:
        _try_download_remote_db()
    conn = get_connection()
    cur = conn.cursor()
    # Idempotent DDL: every statement uses IF NOT EXISTS, so re-running is safe.
    cur.executescript("""
    CREATE TABLE IF NOT EXISTS domains (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        domain TEXT NOT NULL UNIQUE,
        backend TEXT NOT NULL,
        is_wildcard INTEGER NOT NULL DEFAULT 0,
        shares_with TEXT DEFAULT NULL,
        created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
        updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
    );
    CREATE INDEX IF NOT EXISTS idx_domains_backend ON domains(backend);
    CREATE INDEX IF NOT EXISTS idx_domains_shares_with ON domains(shares_with);
    CREATE TABLE IF NOT EXISTS servers (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        domain TEXT NOT NULL,
        slot INTEGER NOT NULL,
        ip TEXT NOT NULL,
        http_port INTEGER NOT NULL DEFAULT 80,
        created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
        updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
        UNIQUE(domain, slot)
    );
    CREATE INDEX IF NOT EXISTS idx_servers_domain ON servers(domain);
    CREATE TABLE IF NOT EXISTS certificates (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        domain TEXT NOT NULL UNIQUE,
        created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
        updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
    );
    CREATE TABLE IF NOT EXISTS schema_version (
        version INTEGER NOT NULL,
        applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
    );
    """)
    conn.commit()
    # Check schema version; MAX(version) yields a single row with NULL on a
    # brand-new database, which we treat as version 0.
    cur.execute("SELECT MAX(version) FROM schema_version")
    row = cur.fetchone()
    current_version = row[0] if row[0] is not None else 0
    if current_version < SCHEMA_VERSION:
        # First-time setup: try migrating from JSON files
        if current_version == 0:
            migrate_from_json()
        cur.execute("INSERT INTO schema_version (version) VALUES (?)", (SCHEMA_VERSION,))
        conn.commit()
        # Upload newly created DB to remote for persistence
        if REMOTE_MODE:
            sync_db_to_remote()
    logger.info("Database initialized (schema v%d)", SCHEMA_VERSION)
def _try_download_remote_db() -> None:
    """Attempt to restore the database from the remote host.

    When a DB exists at REMOTE_DB_FILE it is downloaded to the local
    DB_FILE path; otherwise nothing happens and init_db() will create
    a fresh database.
    """
    from .ssh_ops import remote_download_file, remote_file_exists

    if not remote_file_exists(REMOTE_DB_FILE):
        logger.info("No remote DB found at %s, will create new", REMOTE_DB_FILE)
        return
    if remote_download_file(REMOTE_DB_FILE, DB_FILE):
        logger.info("Downloaded remote DB from %s", REMOTE_DB_FILE)
    else:
        logger.warning("Failed to download remote DB, will create new")
def migrate_from_json() -> None:
    """Migrate data from JSON/map files to SQLite.

    Reads domains.map, wildcards.map, servers.json, and certificates.json,
    imports their data into the database within a single transaction,
    then renames the original files to .bak.
    This function is idempotent (uses INSERT OR IGNORE).
    """
    conn = get_connection()
    # Collect data from existing files
    map_entries: list[tuple[str, str]] = []  # (domain, backend) pairs
    servers_config: dict[str, Any] = {}  # legacy servers.json payload
    cert_domains: list[str] = []  # domains listed in certificates.json
    # Read map files: each data line is "domain backend"
    for map_path in [MAP_FILE, WILDCARDS_MAP_FILE]:
        try:
            content = _read_file_for_migration(map_path)
            for line in content.splitlines():
                line = line.strip()
                if not line or line.startswith("#"):
                    # Skip blank lines and comments
                    continue
                parts = line.split()
                if len(parts) >= 2:
                    map_entries.append((parts[0], parts[1]))
        except FileNotFoundError:
            logger.debug("Map file not found for migration: %s", map_path)
    # Read servers.json
    try:
        content = _read_file_for_migration(SERVERS_FILE)
        servers_config = json.loads(content)
    except FileNotFoundError:
        logger.debug("Servers config not found for migration: %s", SERVERS_FILE)
    except json.JSONDecodeError as e:
        logger.warning("Corrupt servers config during migration: %s", e)
    # Read certificates.json
    try:
        content = _read_file_for_migration(CERTS_FILE)
        data = json.loads(content)
        cert_domains = data.get("domains", [])
    except FileNotFoundError:
        logger.debug("Certs config not found for migration: %s", CERTS_FILE)
    except json.JSONDecodeError as e:
        logger.warning("Corrupt certs config during migration: %s", e)
    if not map_entries and not servers_config and not cert_domains:
        logger.info("No existing data to migrate")
        return
    # Import into database in a single transaction
    try:
        # "with conn" runs all inserts in one transaction and rolls back on error.
        with conn:
            # Import domains from map files; a leading dot marks a wildcard entry
            for domain, backend in map_entries:
                is_wildcard = 1 if domain.startswith(".") else 0
                conn.execute(
                    "INSERT OR IGNORE INTO domains (domain, backend, is_wildcard) VALUES (?, ?, ?)",
                    (domain, backend, is_wildcard),
                )
            # Import servers and shares_with from servers.json
            for domain, slots in servers_config.items():
                shares_with = slots.get("_shares")
                if shares_with:
                    # Update domain's shares_with field
                    conn.execute(
                        "UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0",
                        (shares_with, domain),
                    )
                    # Shared domains carry no server slots of their own
                    continue
                for slot_str, server_info in slots.items():
                    if slot_str.startswith("_"):
                        # Metadata keys (e.g. "_shares"), not slot numbers
                        continue
                    try:
                        slot = int(slot_str)
                        ip = server_info.get("ip", "")
                        http_port = int(server_info.get("http_port", 80))
                        if ip:
                            conn.execute(
                                "INSERT OR IGNORE INTO servers (domain, slot, ip, http_port) VALUES (?, ?, ?, ?)",
                                (domain, slot, ip, http_port),
                            )
                    except (ValueError, TypeError) as e:
                        logger.warning("Skipping invalid server entry %s/%s: %s", domain, slot_str, e)
            # Import certificates
            for cert_domain in cert_domains:
                conn.execute(
                    "INSERT OR IGNORE INTO certificates (domain) VALUES (?)",
                    (cert_domain,),
                )
        # Summarize what was imported (map entries may repeat a domain name)
        migrated_domains = len(set(d for d, _ in map_entries))
        migrated_servers = sum(
            1 for d, slots in servers_config.items()
            if "_shares" not in slots
            for s in slots
            if not s.startswith("_")
        )
        migrated_certs = len(cert_domains)
        logger.info(
            "Migrated from JSON: %d domain entries, %d servers, %d certificates",
            migrated_domains, migrated_servers, migrated_certs,
        )
        # Rename original files to .bak
        for path in [SERVERS_FILE, CERTS_FILE]:
            _backup_file(path)
    except sqlite3.Error as e:
        logger.error("Migration failed: %s", e)
        raise
def _read_file_for_migration(file_path: str) -> str:
    """Return the contents of *file_path* (remote or local, no locking needed)."""
    if not REMOTE_MODE:
        with open(file_path, "r", encoding="utf-8") as fh:
            return fh.read()
    from .ssh_ops import remote_read_file
    return remote_read_file(file_path)
def _backup_file(file_path: str) -> None:
    """Rename *file_path* to ``<file_path>.bak`` if it exists (local mode only)."""
    if REMOTE_MODE:
        # Remote files are left untouched.
        return
    bak_path = f"{file_path}.bak"
    try:
        if os.path.exists(file_path):
            os.rename(file_path, bak_path)
            logger.info("Backed up %s -> %s", file_path, bak_path)
    except OSError as e:
        logger.warning("Failed to backup %s: %s", file_path, e)
# --- Domain data access ---
def db_get_map_contents() -> list[tuple[str, str]]:
    """Return every (domain, backend) pair, wildcards included, sorted by domain.

    Returns:
        List of (domain, backend) tuples including wildcards.
    """
    rows = get_connection().execute(
        "SELECT domain, backend FROM domains ORDER BY domain"
    ).fetchall()
    return [(r["domain"], r["backend"]) for r in rows]
def db_get_domain_backend(domain: str) -> Optional[str]:
    """Return the backend assigned to *domain*, or None when unknown.

    Args:
        domain: Domain name to look up.

    Returns:
        Backend name if found, None otherwise.
    """
    row = get_connection().execute(
        "SELECT backend FROM domains WHERE domain = ?", (domain,)
    ).fetchone()
    return None if row is None else row["backend"]
def db_add_domain(domain: str, backend: str, is_wildcard: bool = False,
                  shares_with: Optional[str] = None) -> None:
    """Insert or update a domain mapping (upsert keyed on the domain name).

    Args:
        domain: Domain name (e.g., "example.com" or ".example.com").
        backend: Backend pool name (e.g., "pool_5").
        is_wildcard: Whether this is a wildcard entry.
        shares_with: Parent domain if sharing a pool.
    """
    params = (domain, backend, 1 if is_wildcard else 0, shares_with)
    conn = get_connection()
    conn.execute(
        """INSERT INTO domains (domain, backend, is_wildcard, shares_with)
        VALUES (?, ?, ?, ?)
        ON CONFLICT(domain) DO UPDATE SET
        backend = excluded.backend,
        shares_with = excluded.shares_with,
        updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""",
        params,
    )
    conn.commit()
def db_remove_domain(domain: str) -> None:
    """Delete both the exact entry and its wildcard twin for *domain*.

    Args:
        domain: Base domain name (without leading dot).
    """
    connection = get_connection()
    with connection:  # commits on success
        connection.execute(
            "DELETE FROM domains WHERE domain = ? OR domain = ?",
            (domain, f".{domain}"),
        )
def db_remove_wildcard(domain: str) -> bool:
    """Delete only the wildcard row for *domain*, leaving the exact entry alone.

    Args:
        domain: Base domain name (without leading dot, e.g., "nocodb.inouter.com").

    Returns:
        True if a wildcard entry was deleted, False if not found.
    """
    conn = get_connection()
    cursor = conn.execute(
        "DELETE FROM domains WHERE domain = ? AND is_wildcard = 1",
        (f".{domain}",),
    )
    conn.commit()
    return cursor.rowcount > 0
def db_find_available_pool() -> Optional[str]:
    """Return the lowest-numbered pool not assigned to any non-wildcard domain.

    Returns:
        Pool name (e.g., "pool_5") or None if all pools are in use.
    """
    rows = get_connection().execute(
        "SELECT DISTINCT backend FROM domains WHERE backend LIKE 'pool_%' AND is_wildcard = 0"
    ).fetchall()
    taken = {r["backend"] for r in rows}
    return next(
        (f"pool_{n}" for n in range(1, POOL_COUNT + 1) if f"pool_{n}" not in taken),
        None,
    )
def db_get_domains_sharing_pool(pool: str) -> list[str]:
    """List every non-wildcard domain assigned to *pool*.

    Args:
        pool: Pool name (e.g., "pool_5").

    Returns:
        List of domain names.
    """
    rows = get_connection().execute(
        "SELECT domain FROM domains WHERE backend = ? AND is_wildcard = 0",
        (pool,),
    ).fetchall()
    return [r["domain"] for r in rows]
# --- Server data access ---
def db_load_servers_config() -> dict[str, Any]:
    """Assemble the legacy servers.json-shaped dict from the database.

    Returns:
        Dictionary compatible with the old servers.json format:
        {"domain": {"slot": {"ip": "...", "http_port": N}, "_shares": "parent"}}
    """
    conn = get_connection()
    config: dict[str, Any] = {}
    # Server rows become per-domain slot dicts
    server_rows = conn.execute(
        "SELECT domain, slot, ip, http_port FROM servers ORDER BY domain, slot"
    ).fetchall()
    for row in server_rows:
        entry = config.setdefault(row["domain"], {})
        entry[str(row["slot"])] = {
            "ip": row["ip"],
            "http_port": row["http_port"],
        }
    # Shared-pool references become "_shares" markers
    share_rows = conn.execute(
        "SELECT domain, shares_with FROM domains WHERE shares_with IS NOT NULL AND is_wildcard = 0"
    ).fetchall()
    for row in share_rows:
        config.setdefault(row["domain"], {})["_shares"] = row["shares_with"]
    return config
def db_add_server(domain: str, slot: int, ip: str, http_port: int) -> None:
    """Insert a server row, refreshing ip/http_port when (domain, slot) exists.

    Args:
        domain: Domain name.
        slot: Server slot number.
        ip: Server IP address.
        http_port: HTTP port.
    """
    params = (domain, slot, ip, http_port)
    conn = get_connection()
    conn.execute(
        """INSERT INTO servers (domain, slot, ip, http_port)
        VALUES (?, ?, ?, ?)
        ON CONFLICT(domain, slot) DO UPDATE SET
        ip = excluded.ip,
        http_port = excluded.http_port,
        updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')""",
        params,
    )
    conn.commit()
def db_remove_server(domain: str, slot: int) -> None:
    """Delete the server row identified by (domain, slot), if present.

    Args:
        domain: Domain name.
        slot: Server slot number.
    """
    connection = get_connection()
    with connection:  # commits on success
        connection.execute(
            "DELETE FROM servers WHERE domain = ? AND slot = ?",
            (domain, slot),
        )
def db_remove_domain_servers(domain: str) -> None:
    """Delete every server row belonging to *domain*.

    Args:
        domain: Domain name whose slots should all be dropped.
    """
    connection = get_connection()
    with connection:  # commits on success
        connection.execute("DELETE FROM servers WHERE domain = ?", (domain,))
def db_add_shared_domain(domain: str, shares_with: str) -> None:
    """Record that *domain* shares a pool with *shares_with*.

    Sets the shares_with column on the non-wildcard domains row.

    Args:
        domain: The new domain.
        shares_with: The existing domain to share with.
    """
    connection = get_connection()
    with connection:  # commits on success
        connection.execute(
            "UPDATE domains SET shares_with = ? WHERE domain = ? AND is_wildcard = 0",
            (shares_with, domain),
        )
def db_get_shared_domain(domain: str) -> Optional[str]:
    """Return the parent domain that *domain* shares a pool with, if any.

    Args:
        domain: Domain name.

    Returns:
        Parent domain name if sharing, None otherwise (including when the
        domain is unknown or its shares_with column is empty/NULL).
    """
    row = get_connection().execute(
        "SELECT shares_with FROM domains WHERE domain = ? AND is_wildcard = 0",
        (domain,),
    ).fetchone()
    if row is None:
        return None
    return row["shares_with"] or None
def db_is_shared_domain(domain: str) -> bool:
    """Tell whether *domain* is riding on another domain's pool.

    Args:
        domain: Domain name.

    Returns:
        True if domain has a shares_with reference.
    """
    parent = db_get_shared_domain(domain)
    return parent is not None
# --- Certificate data access ---
def db_load_certs() -> list[str]:
    """Return the sorted list of domain names that have certificates."""
    rows = get_connection().execute(
        "SELECT domain FROM certificates ORDER BY domain"
    ).fetchall()
    return [r["domain"] for r in rows]
def db_add_cert(domain: str) -> None:
    """Record *domain* in the certificate list (no-op if already present).

    Args:
        domain: Domain name.
    """
    connection = get_connection()
    with connection:  # commits on success
        connection.execute(
            "INSERT OR IGNORE INTO certificates (domain) VALUES (?)",
            (domain,),
        )
def db_remove_cert(domain: str) -> None:
    """Drop *domain* from the certificate list.

    Args:
        domain: Domain name.
    """
    connection = get_connection()
    with connection:  # commits on success
        connection.execute("DELETE FROM certificates WHERE domain = ?", (domain,))
# --- Map file synchronization ---
def sync_map_files() -> None:
    """Rewrite both HAProxy map files from the current database state.

    domains.map receives exact matches and wildcards.map receives wildcard
    entries; both files are replaced atomically.
    """
    from .file_ops import atomic_write_file

    conn = get_connection()
    exact_entries = conn.execute(
        "SELECT domain, backend FROM domains WHERE is_wildcard = 0 ORDER BY domain"
    ).fetchall()
    wildcard_entries = conn.execute(
        "SELECT domain, backend FROM domains WHERE is_wildcard = 1 ORDER BY domain"
    ).fetchall()

    # Exact-match map for HAProxy's map_str
    exact_header = (
        "# Exact Domain to Backend mapping (for map_str)\n"
        "# Format: domain backend_name\n"
        "# Uses ebtree for O(log n) lookup performance\n\n"
    )
    exact_body = "".join(f"{r['domain']} {r['backend']}\n" for r in exact_entries)
    atomic_write_file(MAP_FILE, exact_header + exact_body)

    # Suffix-match map for HAProxy's map_dom
    wildcard_header = (
        "# Wildcard Domain to Backend mapping (for map_dom)\n"
        "# Format: .domain.com backend_name (matches *.domain.com)\n"
        "# Uses map_dom for suffix matching\n\n"
    )
    wildcard_body = "".join(f"{r['domain']} {r['backend']}\n" for r in wildcard_entries)
    atomic_write_file(WILDCARDS_MAP_FILE, wildcard_header + wildcard_body)

    logger.debug("Map files synced: %d exact, %d wildcard entries",
                 len(exact_entries), len(wildcard_entries))
def sync_db_to_remote() -> None:
    """Upload the local SQLite DB to the remote host for persistence.

    Checkpoints the WAL first to merge all changes into the main DB file,
    then uploads via SCP. Failures are logged as warnings rather than
    raised — callers treat the remote copy as best-effort. No-op in
    local (non-remote) mode.
    """
    if not REMOTE_MODE:
        return
    from .ssh_ops import remote_upload_file
    try:
        # Merge WAL into main DB file before upload
        conn = get_connection()
        conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")
        remote_upload_file(DB_FILE, REMOTE_DB_FILE)
        logger.debug("Synced DB to remote: %s", REMOTE_DB_FILE)
    except (OSError, sqlite3.Error) as e:
        # OSError already covers IOError (an alias since Python 3.3).
        # sqlite3.Error is included so a failed WAL checkpoint is also
        # treated as best-effort instead of crashing the caller.
        logger.warning("Failed to sync DB to remote: %s", e)