refactor: migrate data storage from JSON/map files to SQLite
Replace servers.json, certificates.json, and map file parsing with SQLite (WAL mode) as single source of truth. HAProxy map files are now generated from SQLite via sync_map_files(). Key changes: - Add db.py with schema, connection management, and JSON migration - Add DB_FILE config constant - Delegate file_ops.py functions to db.py - Refactor domains.py to use file_ops instead of direct list manipulation - Fix subprocess.TimeoutExpired not caught (doesn't inherit TimeoutError) - Add DB health check in health.py - Init DB on startup in server.py and __main__.py - Update all 359 tests to use SQLite-backed functions Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
433
tests/unit/test_db.py
Normal file
433
tests/unit/test_db.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""Unit tests for db module (SQLite database operations)."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.db import (
|
||||
get_connection,
|
||||
close_connection,
|
||||
init_db,
|
||||
migrate_from_json,
|
||||
db_get_map_contents,
|
||||
db_get_domain_backend,
|
||||
db_add_domain,
|
||||
db_remove_domain,
|
||||
db_find_available_pool,
|
||||
db_get_domains_sharing_pool,
|
||||
db_load_servers_config,
|
||||
db_add_server,
|
||||
db_remove_server,
|
||||
db_remove_domain_servers,
|
||||
db_add_shared_domain,
|
||||
db_get_shared_domain,
|
||||
db_is_shared_domain,
|
||||
db_load_certs,
|
||||
db_add_cert,
|
||||
db_remove_cert,
|
||||
sync_map_files,
|
||||
SCHEMA_VERSION,
|
||||
)
|
||||
|
||||
|
||||
class TestConnectionManagement:
|
||||
"""Tests for database connection management."""
|
||||
|
||||
def test_get_connection(self, patch_config_paths):
|
||||
"""Get a connection returns a valid SQLite connection."""
|
||||
conn = get_connection()
|
||||
assert conn is not None
|
||||
# Verify WAL mode
|
||||
result = conn.execute("PRAGMA journal_mode").fetchone()
|
||||
assert result[0] == "wal"
|
||||
|
||||
def test_connection_is_thread_local(self, patch_config_paths):
|
||||
"""Same thread gets same connection."""
|
||||
conn1 = get_connection()
|
||||
conn2 = get_connection()
|
||||
assert conn1 is conn2
|
||||
|
||||
def test_close_connection(self, patch_config_paths):
|
||||
"""Close connection clears thread-local."""
|
||||
conn1 = get_connection()
|
||||
close_connection()
|
||||
conn2 = get_connection()
|
||||
assert conn1 is not conn2
|
||||
|
||||
|
||||
class TestInitDb:
|
||||
"""Tests for database initialization."""
|
||||
|
||||
def test_creates_tables(self, patch_config_paths):
|
||||
"""init_db creates all required tables."""
|
||||
conn = get_connection()
|
||||
|
||||
# Check tables exist
|
||||
tables = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
|
||||
).fetchall()
|
||||
table_names = [t["name"] for t in tables]
|
||||
|
||||
assert "domains" in table_names
|
||||
assert "servers" in table_names
|
||||
assert "certificates" in table_names
|
||||
assert "schema_version" in table_names
|
||||
|
||||
def test_schema_version_recorded(self, patch_config_paths):
|
||||
"""Schema version is recorded."""
|
||||
conn = get_connection()
|
||||
cur = conn.execute("SELECT MAX(version) FROM schema_version")
|
||||
version = cur.fetchone()[0]
|
||||
assert version == SCHEMA_VERSION
|
||||
|
||||
def test_idempotent(self, patch_config_paths):
|
||||
"""Calling init_db twice is safe."""
|
||||
# init_db is already called by patch_config_paths
|
||||
# Calling it again should not raise
|
||||
init_db()
|
||||
conn = get_connection()
|
||||
cur = conn.execute("SELECT COUNT(*) FROM schema_version")
|
||||
# May have 1 or 2 entries but should not fail
|
||||
assert cur.fetchone()[0] >= 1
|
||||
|
||||
|
||||
class TestMigrateFromJson:
|
||||
"""Tests for JSON to SQLite migration."""
|
||||
|
||||
def test_migrate_map_files(self, patch_config_paths):
|
||||
"""Migrate domain entries from map files."""
|
||||
# Write test map files
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("api.example.com pool_2\n")
|
||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||
f.write(".example.com pool_1\n")
|
||||
|
||||
migrate_from_json()
|
||||
|
||||
entries = db_get_map_contents()
|
||||
assert ("example.com", "pool_1") in entries
|
||||
assert ("api.example.com", "pool_2") in entries
|
||||
assert (".example.com", "pool_1") in entries
|
||||
|
||||
def test_migrate_servers_json(self, patch_config_paths):
|
||||
"""Migrate server entries from servers.json."""
|
||||
config = {
|
||||
"example.com": {
|
||||
"1": {"ip": "10.0.0.1", "http_port": 80},
|
||||
"2": {"ip": "10.0.0.2", "http_port": 8080},
|
||||
}
|
||||
}
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
migrate_from_json()
|
||||
|
||||
result = db_load_servers_config()
|
||||
assert "example.com" in result
|
||||
assert result["example.com"]["1"]["ip"] == "10.0.0.1"
|
||||
assert result["example.com"]["2"]["http_port"] == 8080
|
||||
|
||||
def test_migrate_shared_domains(self, patch_config_paths):
|
||||
"""Migrate shared domain references."""
|
||||
# First add the domain to map so it exists in DB
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("www.example.com pool_1\n")
|
||||
|
||||
config = {
|
||||
"www.example.com": {"_shares": "example.com"},
|
||||
}
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
migrate_from_json()
|
||||
|
||||
assert db_get_shared_domain("www.example.com") == "example.com"
|
||||
|
||||
def test_migrate_certificates(self, patch_config_paths):
|
||||
"""Migrate certificate entries."""
|
||||
with open(patch_config_paths["certs_file"], "w") as f:
|
||||
json.dump({"domains": ["example.com", "api.example.com"]}, f)
|
||||
|
||||
migrate_from_json()
|
||||
|
||||
certs = db_load_certs()
|
||||
assert "example.com" in certs
|
||||
assert "api.example.com" in certs
|
||||
|
||||
def test_migrate_idempotent(self, patch_config_paths):
|
||||
"""Migration is idempotent (INSERT OR IGNORE)."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump({"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}, f)
|
||||
|
||||
migrate_from_json()
|
||||
migrate_from_json() # Should not fail
|
||||
|
||||
entries = db_get_map_contents()
|
||||
assert len([d for d, _ in entries if d == "example.com"]) == 1
|
||||
|
||||
def test_migrate_empty_files(self, patch_config_paths):
|
||||
"""Migration with no existing data does nothing."""
|
||||
os.unlink(patch_config_paths["map_file"])
|
||||
os.unlink(patch_config_paths["wildcards_file"])
|
||||
os.unlink(patch_config_paths["servers_file"])
|
||||
os.unlink(patch_config_paths["certs_file"])
|
||||
|
||||
migrate_from_json() # Should not fail
|
||||
|
||||
def test_backup_files_after_migration(self, patch_config_paths):
|
||||
"""Original JSON files are backed up after migration."""
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump({"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}, f)
|
||||
with open(patch_config_paths["certs_file"], "w") as f:
|
||||
json.dump({"domains": ["example.com"]}, f)
|
||||
|
||||
# Write map file so migration has data
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
migrate_from_json()
|
||||
|
||||
assert os.path.exists(f"{patch_config_paths['servers_file']}.bak")
|
||||
assert os.path.exists(f"{patch_config_paths['certs_file']}.bak")
|
||||
|
||||
|
||||
class TestDomainOperations:
|
||||
"""Tests for domain data access functions."""
|
||||
|
||||
def test_add_and_get_domain(self, patch_config_paths):
|
||||
"""Add domain and retrieve its backend."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
|
||||
assert db_get_domain_backend("example.com") == "pool_1"
|
||||
|
||||
def test_get_nonexistent_domain(self, patch_config_paths):
|
||||
"""Non-existent domain returns None."""
|
||||
assert db_get_domain_backend("nonexistent.com") is None
|
||||
|
||||
def test_remove_domain(self, patch_config_paths):
|
||||
"""Remove domain removes both exact and wildcard."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
db_add_domain(".example.com", "pool_1", is_wildcard=True)
|
||||
|
||||
db_remove_domain("example.com")
|
||||
|
||||
assert db_get_domain_backend("example.com") is None
|
||||
entries = db_get_map_contents()
|
||||
assert (".example.com", "pool_1") not in entries
|
||||
|
||||
def test_find_available_pool(self, patch_config_paths):
|
||||
"""Find first available pool."""
|
||||
db_add_domain("a.com", "pool_1")
|
||||
db_add_domain("b.com", "pool_2")
|
||||
|
||||
pool = db_find_available_pool()
|
||||
assert pool == "pool_3"
|
||||
|
||||
def test_find_available_pool_empty(self, patch_config_paths):
|
||||
"""First pool available when none used."""
|
||||
pool = db_find_available_pool()
|
||||
assert pool == "pool_1"
|
||||
|
||||
def test_get_domains_sharing_pool(self, patch_config_paths):
|
||||
"""Get non-wildcard domains using a pool."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
db_add_domain("www.example.com", "pool_1")
|
||||
db_add_domain(".example.com", "pool_1", is_wildcard=True)
|
||||
|
||||
domains = db_get_domains_sharing_pool("pool_1")
|
||||
assert "example.com" in domains
|
||||
assert "www.example.com" in domains
|
||||
assert ".example.com" not in domains
|
||||
|
||||
def test_update_domain_backend(self, patch_config_paths):
|
||||
"""Updating a domain changes its backend."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
db_add_domain("example.com", "pool_5") # Update
|
||||
|
||||
assert db_get_domain_backend("example.com") == "pool_5"
|
||||
|
||||
|
||||
class TestServerOperations:
|
||||
"""Tests for server data access functions."""
|
||||
|
||||
def test_add_and_load_server(self, patch_config_paths):
|
||||
"""Add server and load config."""
|
||||
db_add_server("example.com", 1, "10.0.0.1", 80)
|
||||
|
||||
config = db_load_servers_config()
|
||||
assert config["example.com"]["1"]["ip"] == "10.0.0.1"
|
||||
assert config["example.com"]["1"]["http_port"] == 80
|
||||
|
||||
def test_update_server(self, patch_config_paths):
|
||||
"""Update existing server slot."""
|
||||
db_add_server("example.com", 1, "10.0.0.1", 80)
|
||||
db_add_server("example.com", 1, "10.0.0.99", 8080)
|
||||
|
||||
config = db_load_servers_config()
|
||||
assert config["example.com"]["1"]["ip"] == "10.0.0.99"
|
||||
assert config["example.com"]["1"]["http_port"] == 8080
|
||||
|
||||
def test_remove_server(self, patch_config_paths):
|
||||
"""Remove a server slot."""
|
||||
db_add_server("example.com", 1, "10.0.0.1", 80)
|
||||
db_add_server("example.com", 2, "10.0.0.2", 80)
|
||||
|
||||
db_remove_server("example.com", 1)
|
||||
|
||||
config = db_load_servers_config()
|
||||
assert "1" not in config.get("example.com", {})
|
||||
assert "2" in config["example.com"]
|
||||
|
||||
def test_remove_domain_servers(self, patch_config_paths):
|
||||
"""Remove all servers for a domain."""
|
||||
db_add_server("example.com", 1, "10.0.0.1", 80)
|
||||
db_add_server("example.com", 2, "10.0.0.2", 80)
|
||||
db_add_server("other.com", 1, "10.0.0.3", 80)
|
||||
|
||||
db_remove_domain_servers("example.com")
|
||||
|
||||
config = db_load_servers_config()
|
||||
assert config.get("example.com", {}).get("1") is None
|
||||
assert "other.com" in config
|
||||
|
||||
def test_load_empty(self, patch_config_paths):
|
||||
"""Empty database returns empty dict."""
|
||||
config = db_load_servers_config()
|
||||
assert config == {}
|
||||
|
||||
|
||||
class TestSharedDomainOperations:
|
||||
"""Tests for shared domain functions."""
|
||||
|
||||
def test_add_and_get_shared(self, patch_config_paths):
|
||||
"""Add shared domain reference."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
db_add_domain("www.example.com", "pool_1")
|
||||
db_add_shared_domain("www.example.com", "example.com")
|
||||
|
||||
assert db_get_shared_domain("www.example.com") == "example.com"
|
||||
|
||||
def test_is_shared(self, patch_config_paths):
|
||||
"""Check if domain is shared."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
db_add_domain("www.example.com", "pool_1")
|
||||
db_add_shared_domain("www.example.com", "example.com")
|
||||
|
||||
assert db_is_shared_domain("www.example.com") is True
|
||||
assert db_is_shared_domain("example.com") is False
|
||||
|
||||
def test_not_shared(self, patch_config_paths):
|
||||
"""Non-shared domain returns None."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
|
||||
assert db_get_shared_domain("example.com") is None
|
||||
assert db_is_shared_domain("example.com") is False
|
||||
|
||||
def test_shared_in_load_config(self, patch_config_paths):
|
||||
"""Shared domain appears in load_servers_config."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
db_add_domain("www.example.com", "pool_1")
|
||||
db_add_shared_domain("www.example.com", "example.com")
|
||||
|
||||
config = db_load_servers_config()
|
||||
assert config["www.example.com"]["_shares"] == "example.com"
|
||||
|
||||
|
||||
class TestCertificateOperations:
|
||||
"""Tests for certificate data access functions."""
|
||||
|
||||
def test_add_and_load_cert(self, patch_config_paths):
|
||||
"""Add and load certificate."""
|
||||
db_add_cert("example.com")
|
||||
|
||||
certs = db_load_certs()
|
||||
assert "example.com" in certs
|
||||
|
||||
def test_add_duplicate(self, patch_config_paths):
|
||||
"""Adding duplicate cert is a no-op."""
|
||||
db_add_cert("example.com")
|
||||
db_add_cert("example.com")
|
||||
|
||||
certs = db_load_certs()
|
||||
assert certs.count("example.com") == 1
|
||||
|
||||
def test_remove_cert(self, patch_config_paths):
|
||||
"""Remove a certificate."""
|
||||
db_add_cert("example.com")
|
||||
db_add_cert("other.com")
|
||||
|
||||
db_remove_cert("example.com")
|
||||
|
||||
certs = db_load_certs()
|
||||
assert "example.com" not in certs
|
||||
assert "other.com" in certs
|
||||
|
||||
def test_load_empty(self, patch_config_paths):
|
||||
"""Empty database returns empty list."""
|
||||
certs = db_load_certs()
|
||||
assert certs == []
|
||||
|
||||
def test_sorted_output(self, patch_config_paths):
|
||||
"""Certificates are returned sorted."""
|
||||
db_add_cert("z.com")
|
||||
db_add_cert("a.com")
|
||||
db_add_cert("m.com")
|
||||
|
||||
certs = db_load_certs()
|
||||
assert certs == ["a.com", "m.com", "z.com"]
|
||||
|
||||
|
||||
class TestSyncMapFiles:
|
||||
"""Tests for sync_map_files function."""
|
||||
|
||||
def test_sync_exact_domains(self, patch_config_paths):
|
||||
"""Sync writes exact domains to domains.map."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
db_add_domain("api.example.com", "pool_2")
|
||||
|
||||
sync_map_files()
|
||||
|
||||
with open(patch_config_paths["map_file"]) as f:
|
||||
content = f.read()
|
||||
assert "example.com pool_1" in content
|
||||
assert "api.example.com pool_2" in content
|
||||
|
||||
def test_sync_wildcards(self, patch_config_paths):
|
||||
"""Sync writes wildcards to wildcards.map."""
|
||||
db_add_domain(".example.com", "pool_1", is_wildcard=True)
|
||||
|
||||
sync_map_files()
|
||||
|
||||
with open(patch_config_paths["wildcards_file"]) as f:
|
||||
content = f.read()
|
||||
assert ".example.com pool_1" in content
|
||||
|
||||
def test_sync_empty(self, patch_config_paths):
|
||||
"""Sync with no domains writes headers only."""
|
||||
sync_map_files()
|
||||
|
||||
with open(patch_config_paths["map_file"]) as f:
|
||||
content = f.read()
|
||||
assert "Exact Domain" in content
|
||||
# No domain entries
|
||||
lines = [l.strip() for l in content.splitlines() if l.strip() and not l.startswith("#")]
|
||||
assert len(lines) == 0
|
||||
|
||||
def test_sync_sorted(self, patch_config_paths):
|
||||
"""Sync output is sorted."""
|
||||
db_add_domain("z.com", "pool_3")
|
||||
db_add_domain("a.com", "pool_1")
|
||||
|
||||
sync_map_files()
|
||||
|
||||
with open(patch_config_paths["map_file"]) as f:
|
||||
lines = [l.strip() for l in f if l.strip() and not l.startswith("#")]
|
||||
assert lines[0] == "a.com pool_1"
|
||||
assert lines[1] == "z.com pool_3"
|
||||
Reference in New Issue
Block a user