"""Unit tests for db module (SQLite database operations)."""

import json
import os
import sqlite3
from unittest.mock import patch

import pytest

from haproxy_mcp.db import (
    get_connection,
    close_connection,
    init_db,
    migrate_from_json,
    db_get_map_contents,
    db_get_domain_backend,
    db_add_domain,
    db_remove_domain,
    db_find_available_pool,
    db_get_domains_sharing_pool,
    db_load_servers_config,
    db_add_server,
    db_remove_server,
    db_remove_domain_servers,
    db_add_shared_domain,
    db_get_shared_domain,
    db_is_shared_domain,
    db_load_certs,
    db_add_cert,
    db_remove_cert,
    sync_map_files,
    SCHEMA_VERSION,
)


def _read_map_entries(path):
    """Return the stripped, non-blank, non-comment lines of a map file."""
    with open(path) as f:
        stripped = (line.strip() for line in f)
        # Header/comment lines start with "#"; strip first so indented
        # comments are filtered too.
        return [line for line in stripped if line and not line.startswith("#")]


class TestConnectionManagement:
    """Tests for database connection management."""

    def test_get_connection(self, patch_config_paths):
        """get_connection returns a usable connection in WAL journal mode."""
        conn = get_connection()
        assert conn is not None
        # Verify WAL mode
        result = conn.execute("PRAGMA journal_mode").fetchone()
        assert result[0] == "wal"

    def test_connection_is_thread_local(self, patch_config_paths):
        """Repeated calls from the same thread return the same connection."""
        conn1 = get_connection()
        conn2 = get_connection()
        assert conn1 is conn2

    def test_close_connection(self, patch_config_paths):
        """close_connection makes the next get_connection open a new one."""
        conn1 = get_connection()
        close_connection()
        conn2 = get_connection()
        assert conn1 is not conn2


class TestInitDb:
    """Tests for database initialization."""

    def test_creates_tables(self, patch_config_paths):
        """init_db creates all required tables."""
        conn = get_connection()
        # Check tables exist
        tables = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
        ).fetchall()
        table_names = [t["name"] for t in tables]
        assert "domains" in table_names
        assert "servers" in table_names
        assert "certificates" in table_names
        assert "schema_version" in table_names

    def test_schema_version_recorded(self, patch_config_paths):
        """The current SCHEMA_VERSION is recorded in schema_version."""
        conn = get_connection()
        cur = conn.execute("SELECT MAX(version) FROM schema_version")
        version = cur.fetchone()[0]
        assert version == SCHEMA_VERSION

    def test_idempotent(self, patch_config_paths):
        """Calling init_db twice is safe and keeps the schema version."""
        # init_db is already called by patch_config_paths
        # Calling it again should not raise
        init_db()
        conn = get_connection()
        count = conn.execute("SELECT COUNT(*) FROM schema_version").fetchone()[0]
        # May have 1 or 2 entries but should not fail
        assert count >= 1
        # Re-running init must not change the recorded schema version.
        version = conn.execute(
            "SELECT MAX(version) FROM schema_version"
        ).fetchone()[0]
        assert version == SCHEMA_VERSION


class TestMigrateFromJson:
    """Tests for JSON to SQLite migration."""

    def test_migrate_map_files(self, patch_config_paths):
        """Migrate domain entries from map files."""
        # Write test map files
        with open(patch_config_paths["map_file"], "w") as f:
            f.write("example.com pool_1\n")
            f.write("api.example.com pool_2\n")
        with open(patch_config_paths["wildcards_file"], "w") as f:
            f.write(".example.com pool_1\n")

        migrate_from_json()

        entries = db_get_map_contents()
        assert ("example.com", "pool_1") in entries
        assert ("api.example.com", "pool_2") in entries
        assert (".example.com", "pool_1") in entries

    def test_migrate_servers_json(self, patch_config_paths):
        """Migrate server entries from servers.json."""
        config = {
            "example.com": {
                "1": {"ip": "10.0.0.1", "http_port": 80},
                "2": {"ip": "10.0.0.2", "http_port": 8080},
            }
        }
        with open(patch_config_paths["servers_file"], "w") as f:
            json.dump(config, f)

        migrate_from_json()

        result = db_load_servers_config()
        assert "example.com" in result
        assert result["example.com"]["1"]["ip"] == "10.0.0.1"
        assert result["example.com"]["2"]["http_port"] == 8080

    def test_migrate_shared_domains(self, patch_config_paths):
        """Migrate shared domain references."""
        # First add the domain to map so it exists in DB
        with open(patch_config_paths["map_file"], "w") as f:
            f.write("example.com pool_1\n")
            f.write("www.example.com pool_1\n")

        config = {
            "www.example.com": {"_shares": "example.com"},
        }
        with open(patch_config_paths["servers_file"], "w") as f:
            json.dump(config, f)

        migrate_from_json()

        assert db_get_shared_domain("www.example.com") == "example.com"

    def test_migrate_certificates(self, patch_config_paths):
        """Migrate certificate entries."""
        with open(patch_config_paths["certs_file"], "w") as f:
            json.dump({"domains": ["example.com", "api.example.com"]}, f)

        migrate_from_json()

        certs = db_load_certs()
        assert "example.com" in certs
        assert "api.example.com" in certs

    def test_migrate_idempotent(self, patch_config_paths):
        """Migration is idempotent (INSERT OR IGNORE)."""
        with open(patch_config_paths["map_file"], "w") as f:
            f.write("example.com pool_1\n")
        with open(patch_config_paths["servers_file"], "w") as f:
            json.dump({"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}, f)

        migrate_from_json()
        migrate_from_json()  # Should not fail

        entries = db_get_map_contents()
        assert len([d for d, _ in entries if d == "example.com"]) == 1

    def test_migrate_empty_files(self, patch_config_paths):
        """Migration with no source files succeeds and leaves the DB empty."""
        os.unlink(patch_config_paths["map_file"])
        os.unlink(patch_config_paths["wildcards_file"])
        os.unlink(patch_config_paths["servers_file"])
        os.unlink(patch_config_paths["certs_file"])

        migrate_from_json()  # Should not fail

        # Nothing to migrate, so every table must still be empty.
        assert list(db_get_map_contents()) == []
        assert db_load_servers_config() == {}
        assert db_load_certs() == []

    def test_backup_files_after_migration(self, patch_config_paths):
        """Original JSON files are backed up after migration."""
        with open(patch_config_paths["servers_file"], "w") as f:
            json.dump({"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}, f)
        with open(patch_config_paths["certs_file"], "w") as f:
            json.dump({"domains": ["example.com"]}, f)
        # Write map file so migration has data
        with open(patch_config_paths["map_file"], "w") as f:
            f.write("example.com pool_1\n")

        migrate_from_json()

        assert os.path.exists(f"{patch_config_paths['servers_file']}.bak")
        assert os.path.exists(f"{patch_config_paths['certs_file']}.bak")


class TestDomainOperations:
    """Tests for domain data access functions."""

    def test_add_and_get_domain(self, patch_config_paths):
        """Add domain and retrieve its backend."""
        db_add_domain("example.com", "pool_1")
        assert db_get_domain_backend("example.com") == "pool_1"

    def test_get_nonexistent_domain(self, patch_config_paths):
        """Non-existent domain returns None."""
        assert db_get_domain_backend("nonexistent.com") is None

    def test_remove_domain(self, patch_config_paths):
        """Remove domain removes both exact and wildcard."""
        db_add_domain("example.com", "pool_1")
        db_add_domain(".example.com", "pool_1", is_wildcard=True)

        db_remove_domain("example.com")

        assert db_get_domain_backend("example.com") is None
        entries = db_get_map_contents()
        assert (".example.com", "pool_1") not in entries

    def test_find_available_pool(self, patch_config_paths):
        """Find first available pool."""
        db_add_domain("a.com", "pool_1")
        db_add_domain("b.com", "pool_2")

        pool = db_find_available_pool()
        assert pool == "pool_3"

    def test_find_available_pool_empty(self, patch_config_paths):
        """First pool available when none used."""
        pool = db_find_available_pool()
        assert pool == "pool_1"

    def test_get_domains_sharing_pool(self, patch_config_paths):
        """Get non-wildcard domains using a pool."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("www.example.com", "pool_1")
        db_add_domain(".example.com", "pool_1", is_wildcard=True)

        domains = db_get_domains_sharing_pool("pool_1")
        assert "example.com" in domains
        assert "www.example.com" in domains
        assert ".example.com" not in domains

    def test_update_domain_backend(self, patch_config_paths):
        """Updating a domain changes its backend."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("example.com", "pool_5")  # Update

        assert db_get_domain_backend("example.com") == "pool_5"


class TestServerOperations:
    """Tests for server data access functions."""

    def test_add_and_load_server(self, patch_config_paths):
        """Add server and load config."""
        db_add_server("example.com", 1, "10.0.0.1", 80)

        config = db_load_servers_config()
        assert config["example.com"]["1"]["ip"] == "10.0.0.1"
        assert config["example.com"]["1"]["http_port"] == 80

    def test_update_server(self, patch_config_paths):
        """Update existing server slot."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        db_add_server("example.com", 1, "10.0.0.99", 8080)

        config = db_load_servers_config()
        assert config["example.com"]["1"]["ip"] == "10.0.0.99"
        assert config["example.com"]["1"]["http_port"] == 8080

    def test_remove_server(self, patch_config_paths):
        """Remove a server slot."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        db_add_server("example.com", 2, "10.0.0.2", 80)

        db_remove_server("example.com", 1)

        config = db_load_servers_config()
        assert "1" not in config.get("example.com", {})
        assert "2" in config["example.com"]

    def test_remove_domain_servers(self, patch_config_paths):
        """Remove all servers for a domain, leaving other domains intact."""
        db_add_server("example.com", 1, "10.0.0.1", 80)
        db_add_server("example.com", 2, "10.0.0.2", 80)
        db_add_server("other.com", 1, "10.0.0.3", 80)

        db_remove_domain_servers("example.com")

        config = db_load_servers_config()
        # Every slot of the domain must be gone, not only slot "1".
        assert not config.get("example.com")
        assert config["other.com"]["1"]["ip"] == "10.0.0.3"

    def test_load_empty(self, patch_config_paths):
        """Empty database returns empty dict."""
        config = db_load_servers_config()
        assert config == {}


class TestSharedDomainOperations:
    """Tests for shared domain functions."""

    def test_add_and_get_shared(self, patch_config_paths):
        """Add shared domain reference."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("www.example.com", "pool_1")
        db_add_shared_domain("www.example.com", "example.com")

        assert db_get_shared_domain("www.example.com") == "example.com"

    def test_is_shared(self, patch_config_paths):
        """Check if domain is shared."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("www.example.com", "pool_1")
        db_add_shared_domain("www.example.com", "example.com")

        assert db_is_shared_domain("www.example.com") is True
        assert db_is_shared_domain("example.com") is False

    def test_not_shared(self, patch_config_paths):
        """Non-shared domain returns None."""
        db_add_domain("example.com", "pool_1")

        assert db_get_shared_domain("example.com") is None
        assert db_is_shared_domain("example.com") is False

    def test_shared_in_load_config(self, patch_config_paths):
        """Shared domain appears in load_servers_config."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("www.example.com", "pool_1")
        db_add_shared_domain("www.example.com", "example.com")

        config = db_load_servers_config()
        assert config["www.example.com"]["_shares"] == "example.com"


class TestCertificateOperations:
    """Tests for certificate data access functions."""

    def test_add_and_load_cert(self, patch_config_paths):
        """Add and load certificate."""
        db_add_cert("example.com")

        certs = db_load_certs()
        assert "example.com" in certs

    def test_add_duplicate(self, patch_config_paths):
        """Adding duplicate cert is a no-op."""
        db_add_cert("example.com")
        db_add_cert("example.com")

        certs = db_load_certs()
        assert certs.count("example.com") == 1

    def test_remove_cert(self, patch_config_paths):
        """Remove a certificate."""
        db_add_cert("example.com")
        db_add_cert("other.com")

        db_remove_cert("example.com")

        certs = db_load_certs()
        assert "example.com" not in certs
        assert "other.com" in certs

    def test_load_empty(self, patch_config_paths):
        """Empty database returns empty list."""
        certs = db_load_certs()
        assert certs == []

    def test_sorted_output(self, patch_config_paths):
        """Certificates are returned sorted."""
        db_add_cert("z.com")
        db_add_cert("a.com")
        db_add_cert("m.com")

        certs = db_load_certs()
        assert certs == ["a.com", "m.com", "z.com"]


class TestSyncMapFiles:
    """Tests for sync_map_files function."""

    def test_sync_exact_domains(self, patch_config_paths):
        """Sync writes exact domains to domains.map."""
        db_add_domain("example.com", "pool_1")
        db_add_domain("api.example.com", "pool_2")

        sync_map_files()

        entries = _read_map_entries(patch_config_paths["map_file"])
        assert "example.com pool_1" in entries
        assert "api.example.com pool_2" in entries

    def test_sync_wildcards(self, patch_config_paths):
        """Sync writes wildcards to wildcards.map."""
        db_add_domain(".example.com", "pool_1", is_wildcard=True)

        sync_map_files()

        entries = _read_map_entries(patch_config_paths["wildcards_file"])
        assert ".example.com pool_1" in entries

    def test_sync_empty(self, patch_config_paths):
        """Sync with no domains writes headers only."""
        sync_map_files()

        with open(patch_config_paths["map_file"]) as f:
            content = f.read()
        assert "Exact Domain" in content
        # No domain entries
        assert _read_map_entries(patch_config_paths["map_file"]) == []

    def test_sync_sorted(self, patch_config_paths):
        """Sync output is sorted and contains exactly the stored domains."""
        db_add_domain("z.com", "pool_3")
        db_add_domain("a.com", "pool_1")

        sync_map_files()

        entries = _read_map_entries(patch_config_paths["map_file"])
        assert entries == ["a.com pool_1", "z.com pool_3"]