Files
haproxy-mcp/tests/unit/test_db.py
kappa cf554f3f89 refactor: migrate data storage from JSON/map files to SQLite
Replace servers.json, certificates.json, and map file parsing with
SQLite (WAL mode) as single source of truth. HAProxy map files are
now generated from SQLite via sync_map_files().

Key changes:
- Add db.py with schema, connection management, and JSON migration
- Add DB_FILE config constant
- Delegate file_ops.py functions to db.py
- Refactor domains.py to use file_ops instead of direct list manipulation
- Fix subprocess.TimeoutExpired not caught (doesn't inherit TimeoutError)
- Add DB health check in health.py
- Init DB on startup in server.py and __main__.py
- Update all 359 tests to use SQLite-backed functions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 11:07:29 +09:00

434 lines
15 KiB
Python

"""Unit tests for db module (SQLite database operations)."""
import json
import os
import sqlite3
from unittest.mock import patch
import pytest
from haproxy_mcp.db import (
get_connection,
close_connection,
init_db,
migrate_from_json,
db_get_map_contents,
db_get_domain_backend,
db_add_domain,
db_remove_domain,
db_find_available_pool,
db_get_domains_sharing_pool,
db_load_servers_config,
db_add_server,
db_remove_server,
db_remove_domain_servers,
db_add_shared_domain,
db_get_shared_domain,
db_is_shared_domain,
db_load_certs,
db_add_cert,
db_remove_cert,
sync_map_files,
SCHEMA_VERSION,
)
class TestConnectionManagement:
    """Lifecycle checks for get_connection() and close_connection()."""

    def test_get_connection(self, patch_config_paths):
        """A connection object is returned and its journal mode is WAL."""
        connection = get_connection()
        assert connection is not None

        # PRAGMA journal_mode reports the active mode in the first column.
        journal_mode = connection.execute("PRAGMA journal_mode").fetchone()[0]
        assert journal_mode == "wal"

    def test_connection_is_thread_local(self, patch_config_paths):
        """Two consecutive calls from the same thread yield the same object."""
        first = get_connection()
        second = get_connection()
        assert first is second

    def test_close_connection(self, patch_config_paths):
        """After close_connection(), get_connection() hands out a new object."""
        before_close = get_connection()
        close_connection()
        after_close = get_connection()
        # before_close stays referenced so the identity comparison is meaningful.
        assert before_close is not after_close
class TestInitDb:
    """Schema-level checks for init_db()."""

    def test_creates_tables(self, patch_config_paths):
        """Every table the module relies on is present in sqlite_master."""
        rows = get_connection().execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
        ).fetchall()
        present = [row["name"] for row in rows]

        for required in ("domains", "servers", "certificates", "schema_version"):
            assert required in present

    def test_schema_version_recorded(self, patch_config_paths):
        """The highest stored schema version equals SCHEMA_VERSION."""
        cursor = get_connection().execute(
            "SELECT MAX(version) FROM schema_version"
        )
        stored_version = cursor.fetchone()[0]
        assert stored_version == SCHEMA_VERSION

    def test_idempotent(self, patch_config_paths):
        """An extra init_db() call must not raise."""
        # Per the original note, the patch_config_paths fixture has already
        # run init_db once (fixture not visible here), so this is a repeat call.
        init_db()

        cursor = get_connection().execute("SELECT COUNT(*) FROM schema_version")
        row_count = cursor.fetchone()[0]
        # One or two version rows are both acceptable; zero is not.
        assert row_count >= 1
class TestMigrateFromJson:
    """Checks for migrate_from_json(): legacy map/JSON files into SQLite."""

    def test_migrate_map_files(self, patch_config_paths):
        """Exact and wildcard map entries are both visible after migration."""
        with open(patch_config_paths["map_file"], "w") as exact_map:
            exact_map.write("example.com pool_1\n")
            exact_map.write("api.example.com pool_2\n")
        with open(patch_config_paths["wildcards_file"], "w") as wildcard_map:
            wildcard_map.write(".example.com pool_1\n")

        migrate_from_json()

        migrated = db_get_map_contents()
        expected_pairs = (
            ("example.com", "pool_1"),
            ("api.example.com", "pool_2"),
            (".example.com", "pool_1"),
        )
        for pair in expected_pairs:
            assert pair in migrated

    def test_migrate_servers_json(self, patch_config_paths):
        """Server slots written to servers.json are loadable from the DB."""
        legacy_servers = {
            "example.com": {
                "1": {"ip": "10.0.0.1", "http_port": 80},
                "2": {"ip": "10.0.0.2", "http_port": 8080},
            }
        }
        with open(patch_config_paths["servers_file"], "w") as servers_json:
            json.dump(legacy_servers, servers_json)

        migrate_from_json()

        loaded = db_load_servers_config()
        assert "example.com" in loaded
        domain_slots = loaded["example.com"]
        assert domain_slots["1"]["ip"] == "10.0.0.1"
        assert domain_slots["2"]["http_port"] == 8080

    def test_migrate_shared_domains(self, patch_config_paths):
        """A "_shares" entry in servers.json becomes a shared-domain link."""
        # Both domains are listed in the map file so they exist in the DB.
        with open(patch_config_paths["map_file"], "w") as exact_map:
            exact_map.write("example.com pool_1\n")
            exact_map.write("www.example.com pool_1\n")
        legacy_servers = {
            "www.example.com": {"_shares": "example.com"},
        }
        with open(patch_config_paths["servers_file"], "w") as servers_json:
            json.dump(legacy_servers, servers_json)

        migrate_from_json()

        assert db_get_shared_domain("www.example.com") == "example.com"

    def test_migrate_certificates(self, patch_config_paths):
        """Domains listed in the certificates file are returned by db_load_certs()."""
        legacy_certs = {"domains": ["example.com", "api.example.com"]}
        with open(patch_config_paths["certs_file"], "w") as certs_json:
            json.dump(legacy_certs, certs_json)

        migrate_from_json()

        loaded = db_load_certs()
        for domain in ("example.com", "api.example.com"):
            assert domain in loaded

    def test_migrate_idempotent(self, patch_config_paths):
        """Running the migration twice neither raises nor duplicates a domain."""
        with open(patch_config_paths["map_file"], "w") as exact_map:
            exact_map.write("example.com pool_1\n")
        with open(patch_config_paths["servers_file"], "w") as servers_json:
            json.dump(
                {"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}},
                servers_json,
            )

        migrate_from_json()
        migrate_from_json()  # second run must be harmless

        matching = [
            domain
            for domain, _backend in db_get_map_contents()
            if domain == "example.com"
        ]
        assert len(matching) == 1

    def test_migrate_empty_files(self, patch_config_paths):
        """Migration must not raise when none of the legacy files exist."""
        for key in ("map_file", "wildcards_file", "servers_file", "certs_file"):
            os.unlink(patch_config_paths[key])

        migrate_from_json()  # passing means no exception was raised

    def test_backup_files_after_migration(self, patch_config_paths):
        """servers/certs JSON files each get a ".bak" sibling after migration."""
        with open(patch_config_paths["servers_file"], "w") as servers_json:
            json.dump(
                {"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}},
                servers_json,
            )
        with open(patch_config_paths["certs_file"], "w") as certs_json:
            json.dump({"domains": ["example.com"]}, certs_json)
        # A map entry is written too so the migration has domain data.
        with open(patch_config_paths["map_file"], "w") as exact_map:
            exact_map.write("example.com pool_1\n")

        migrate_from_json()

        for key in ("servers_file", "certs_file"):
            assert os.path.exists(f"{patch_config_paths[key]}.bak")
class TestDomainOperations:
    """Checks for the domain-to-backend accessors in the db module."""

    def test_add_and_get_domain(self, patch_config_paths):
        """A stored domain resolves to the backend it was added with."""
        db_add_domain("example.com", "pool_1")
        backend = db_get_domain_backend("example.com")
        assert backend == "pool_1"

    def test_get_nonexistent_domain(self, patch_config_paths):
        """Looking up an unknown domain gives None."""
        backend = db_get_domain_backend("nonexistent.com")
        assert backend is None

    def test_remove_domain(self, patch_config_paths):
        """Removing a domain drops its exact entry and its wildcard entry."""
        db_add_domain("example.com", "pool_1")
        db_add_domain(".example.com", "pool_1", is_wildcard=True)

        db_remove_domain("example.com")

        assert db_get_domain_backend("example.com") is None
        remaining = db_get_map_contents()
        assert (".example.com", "pool_1") not in remaining

    def test_find_available_pool(self, patch_config_paths):
        """With pool_1 and pool_2 taken, pool_3 is offered."""
        for domain, backend in (("a.com", "pool_1"), ("b.com", "pool_2")):
            db_add_domain(domain, backend)

        assert db_find_available_pool() == "pool_3"

    def test_find_available_pool_empty(self, patch_config_paths):
        """With no domains stored, pool_1 is offered."""
        assert db_find_available_pool() == "pool_1"

    def test_get_domains_sharing_pool(self, patch_config_paths):
        """Only non-wildcard domains are listed for a pool."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("www.example.com", "pool_1")
        db_add_domain(".example.com", "pool_1", is_wildcard=True)

        sharing = db_get_domains_sharing_pool("pool_1")

        assert "example.com" in sharing
        assert "www.example.com" in sharing
        assert ".example.com" not in sharing

    def test_update_domain_backend(self, patch_config_paths):
        """Adding an existing domain again replaces its backend."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("example.com", "pool_5")  # second add acts as an update

        backend = db_get_domain_backend("example.com")
        assert backend == "pool_5"
class TestServerOperations:
    """Tests for server data access functions."""

    def test_add_and_load_server(self, patch_config_paths):
        """A server added to a slot is returned under that slot's string key."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        config = db_load_servers_config()
        assert config["example.com"]["1"]["ip"] == "10.0.0.1"
        assert config["example.com"]["1"]["http_port"] == 80

    def test_update_server(self, patch_config_paths):
        """Adding to an occupied slot overwrites its ip and port."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        db_add_server("example.com", 1, "10.0.0.99", 8080)
        config = db_load_servers_config()
        assert config["example.com"]["1"]["ip"] == "10.0.0.99"
        assert config["example.com"]["1"]["http_port"] == 8080

    def test_remove_server(self, patch_config_paths):
        """Removing one slot leaves the domain's other slot in place."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        db_add_server("example.com", 2, "10.0.0.2", 80)
        db_remove_server("example.com", 1)
        config = db_load_servers_config()
        assert "1" not in config.get("example.com", {})
        assert "2" in config["example.com"]

    def test_remove_domain_servers(self, patch_config_paths):
        """Removing a domain's servers drops every slot, for that domain only."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        db_add_server("example.com", 2, "10.0.0.2", 80)
        db_add_server("other.com", 1, "10.0.0.3", 80)
        db_remove_domain_servers("example.com")
        config = db_load_servers_config()
        remaining = config.get("example.com", {})
        # Verify every slot that was added, not just the first one; otherwise
        # a partial delete that leaves slot 2 behind would go unnoticed.
        assert remaining.get("1") is None
        assert remaining.get("2") is None
        assert "other.com" in config

    def test_load_empty(self, patch_config_paths):
        """With no servers stored, the loaded config is an empty dict."""
        config = db_load_servers_config()
        assert config == {}
class TestSharedDomainOperations:
    """Checks for the shared-domain link helpers in the db module."""

    @staticmethod
    def _link_www_to_apex():
        """Store both domains on pool_1 and mark www as sharing the apex."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("www.example.com", "pool_1")
        db_add_shared_domain("www.example.com", "example.com")

    def test_add_and_get_shared(self, patch_config_paths):
        """The link target is returned for the sharing domain."""
        self._link_www_to_apex()
        assert db_get_shared_domain("www.example.com") == "example.com"

    def test_is_shared(self, patch_config_paths):
        """Only the sharing side of the link reports as shared."""
        self._link_www_to_apex()
        assert db_is_shared_domain("www.example.com") is True
        assert db_is_shared_domain("example.com") is False

    def test_not_shared(self, patch_config_paths):
        """A domain without a link has no target and is not shared."""
        db_add_domain("example.com", "pool_1")
        assert db_get_shared_domain("example.com") is None
        assert db_is_shared_domain("example.com") is False

    def test_shared_in_load_config(self, patch_config_paths):
        """The link is exposed as a "_shares" key in the loaded servers config."""
        self._link_www_to_apex()
        loaded = db_load_servers_config()
        assert loaded["www.example.com"]["_shares"] == "example.com"
class TestCertificateOperations:
    """Checks for the certificate accessors in the db module."""

    def test_add_and_load_cert(self, patch_config_paths):
        """A stored certificate domain is returned by db_load_certs()."""
        db_add_cert("example.com")
        assert "example.com" in db_load_certs()

    def test_add_duplicate(self, patch_config_paths):
        """Storing the same domain twice leaves exactly one entry."""
        for _ in range(2):
            db_add_cert("example.com")
        loaded = db_load_certs()
        assert loaded.count("example.com") == 1

    def test_remove_cert(self, patch_config_paths):
        """Removing one domain leaves the other untouched."""
        db_add_cert("example.com")
        db_add_cert("other.com")

        db_remove_cert("example.com")

        loaded = db_load_certs()
        assert "example.com" not in loaded
        assert "other.com" in loaded

    def test_load_empty(self, patch_config_paths):
        """With nothing stored, an empty list is returned."""
        loaded = db_load_certs()
        assert loaded == []

    def test_sorted_output(self, patch_config_paths):
        """Domains come back in ascending order regardless of insert order."""
        for domain in ("z.com", "a.com", "m.com"):
            db_add_cert(domain)
        loaded = db_load_certs()
        assert loaded == ["a.com", "m.com", "z.com"]
class TestSyncMapFiles:
    """Checks for sync_map_files(), which writes the map files from the DB."""

    def test_sync_exact_domains(self, patch_config_paths):
        """Exact domains end up in the exact-domain map file."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("api.example.com", "pool_2")

        sync_map_files()

        with open(patch_config_paths["map_file"]) as map_handle:
            written = map_handle.read()
        for expected_line in ("example.com pool_1", "api.example.com pool_2"):
            assert expected_line in written

    def test_sync_wildcards(self, patch_config_paths):
        """Wildcard domains end up in the wildcard map file."""
        db_add_domain(".example.com", "pool_1", is_wildcard=True)

        sync_map_files()

        with open(patch_config_paths["wildcards_file"]) as wildcard_handle:
            written = wildcard_handle.read()
        assert ".example.com pool_1" in written

    def test_sync_empty(self, patch_config_paths):
        """With no domains, the map file holds its header and no entries."""
        sync_map_files()

        with open(patch_config_paths["map_file"]) as map_handle:
            written = map_handle.read()
        assert "Exact Domain" in written

        # Anything that is neither blank nor a "#" comment counts as an entry.
        entries = [
            line.strip()
            for line in written.splitlines()
            if line.strip() and not line.startswith("#")
        ]
        assert len(entries) == 0

    def test_sync_sorted(self, patch_config_paths):
        """Entries are written in ascending domain order."""
        db_add_domain("z.com", "pool_3")
        db_add_domain("a.com", "pool_1")

        sync_map_files()

        with open(patch_config_paths["map_file"]) as map_handle:
            entries = [
                line.strip()
                for line in map_handle
                if line.strip() and not line.startswith("#")
            ]
        assert entries[0] == "a.com pool_1"
        assert entries[1] == "z.com pool_3"