Replace servers.json, certificates.json, and map-file parsing with SQLite (WAL mode) as the single source of truth. HAProxy map files are now generated from SQLite via sync_map_files().

Key changes:
- Add db.py with schema, connection management, and JSON migration
- Add DB_FILE config constant
- Delegate file_ops.py functions to db.py
- Refactor domains.py to use file_ops instead of direct list manipulation
- Fix subprocess.TimeoutExpired not being caught (it does not inherit from TimeoutError)
- Add DB health check in health.py
- Initialize the DB on startup in server.py and __main__.py
- Update all 359 tests to use SQLite-backed functions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
514 lines · 17 KiB · Python
"""Unit tests for file_ops module (SQLite-backed)."""
|
|
|
|
import os
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from haproxy_mcp.file_ops import (
|
|
atomic_write_file,
|
|
get_map_contents,
|
|
save_map_file,
|
|
get_domain_backend,
|
|
split_domain_entries,
|
|
is_legacy_backend,
|
|
get_legacy_backend_name,
|
|
get_backend_and_prefix,
|
|
load_servers_config,
|
|
add_server_to_config,
|
|
remove_server_from_config,
|
|
remove_domain_from_config,
|
|
load_certs_config,
|
|
add_cert_to_config,
|
|
remove_cert_from_config,
|
|
add_domain_to_map,
|
|
remove_domain_from_map,
|
|
find_available_pool,
|
|
add_shared_domain_to_config,
|
|
get_shared_domain,
|
|
is_shared_domain,
|
|
get_domains_sharing_pool,
|
|
)
|
|
|
|
|
|
class TestAtomicWriteFile:
    """Tests for atomic_write_file function."""

    def test_write_new_file(self, tmp_path):
        """Write to a new file."""
        file_path = str(tmp_path / "test.txt")
        content = "Hello, World!"

        atomic_write_file(file_path, content)

        assert os.path.exists(file_path)
        with open(file_path, encoding="utf-8") as f:
            assert f.read() == content

    def test_overwrite_existing_file(self, tmp_path):
        """Overwrite an existing file."""
        file_path = str(tmp_path / "test.txt")
        with open(file_path, "w", encoding="utf-8") as f:
            f.write("Old content")

        atomic_write_file(file_path, "New content")

        with open(file_path, encoding="utf-8") as f:
            assert f.read() == "New content"

    def test_preserves_directory(self, tmp_path):
        """Writing into a missing directory fails and does not create it."""
        missing_dir = tmp_path / "subdir"
        file_path = str(missing_dir / "test.txt")

        # IOError is only a legacy alias of OSError on Python 3.
        with pytest.raises(OSError):
            atomic_write_file(file_path, "content")

        # The docstring's claim: the failed write must not have created the
        # intermediate directory as a side effect.
        assert not missing_dir.exists()

    def test_unicode_content(self, tmp_path):
        """Unicode content is properly written."""
        file_path = str(tmp_path / "unicode.txt")
        content = "Hello, \u4e16\u754c!"

        atomic_write_file(file_path, content)

        with open(file_path, encoding="utf-8") as f:
            assert f.read() == content

    def test_multiline_content(self, tmp_path):
        """Multi-line content is properly written."""
        file_path = str(tmp_path / "multiline.txt")
        content = "line1\nline2\nline3"

        atomic_write_file(file_path, content)

        with open(file_path, encoding="utf-8") as f:
            assert f.read() == content
|
|
|
|
|
|
class TestGetMapContents:
    """get_map_contents() returns (domain, backend) pairs stored in SQLite."""

    def test_empty_db(self, patch_config_paths):
        """A database with no rows yields an empty list."""
        assert get_map_contents() == []

    def test_read_domains(self, patch_config_paths):
        """Plain domains are returned paired with their pool."""
        for domain, pool in (("example.com", "pool_1"), ("api.example.com", "pool_2")):
            add_domain_to_map(domain, pool)

        rows = get_map_contents()

        assert ("example.com", "pool_1") in rows
        assert ("api.example.com", "pool_2") in rows

    def test_read_with_wildcards(self, patch_config_paths):
        """Wildcard rows are returned alongside exact ones."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map(".example.com", "pool_1", is_wildcard=True)

        rows = get_map_contents()

        assert ("example.com", "pool_1") in rows
        assert (".example.com", "pool_1") in rows
|
|
|
|
|
|
class TestSplitDomainEntries:
    """Tests for split_domain_entries function."""

    def test_split_entries(self):
        """Split entries into exact and wildcard."""
        entries = [
            ("example.com", "pool_1"),
            (".example.com", "pool_1"),
            ("api.example.com", "pool_2"),
            (".api.example.com", "pool_2"),
        ]

        exact, wildcards = split_domain_entries(entries)

        assert len(exact) == 2
        assert len(wildcards) == 2
        # Compare each bucket's full contents. Checking a single sample per
        # bucket would still pass if the api.* pair were swapped between them.
        assert set(exact) == {
            ("example.com", "pool_1"),
            ("api.example.com", "pool_2"),
        }
        assert set(wildcards) == {
            (".example.com", "pool_1"),
            (".api.example.com", "pool_2"),
        }

    def test_empty_entries(self):
        """Empty entries returns empty lists."""
        exact, wildcards = split_domain_entries([])

        assert exact == []
        assert wildcards == []

    def test_all_exact(self):
        """All exact entries."""
        entries = [
            ("example.com", "pool_1"),
            ("api.example.com", "pool_2"),
        ]

        exact, wildcards = split_domain_entries(entries)

        assert len(exact) == 2
        assert set(exact) == set(entries)
        assert len(wildcards) == 0
|
|
|
|
|
|
class TestSaveMapFile:
    """Tests for save_map_file function (syncs from DB to map files)."""

    @staticmethod
    def _entry_lines(path):
        """Return the stripped, non-blank, non-comment lines of a map file."""
        with open(path, encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            return [line for line in stripped if line and not line.startswith("#")]

    def test_save_entries(self, patch_config_paths):
        """Exact and wildcard entries are written to separate map files."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map(".example.com", "pool_1", is_wildcard=True)

        save_map_file([])  # Entries param ignored, syncs from DB

        exact_lines = self._entry_lines(patch_config_paths["map_file"])
        wildcard_lines = self._entry_lines(patch_config_paths["wildcards_file"])

        # Compare whole lines. A substring test for "example.com pool_1" is
        # also satisfied by the wildcard line ".example.com pool_1".
        assert "example.com pool_1" in exact_lines
        assert ".example.com pool_1" not in exact_lines
        assert ".example.com pool_1" in wildcard_lines

    def test_sorted_output(self, patch_config_paths):
        """Entries are sorted in output."""
        add_domain_to_map("z.example.com", "pool_3")
        add_domain_to_map("a.example.com", "pool_1")
        add_domain_to_map("m.example.com", "pool_2")

        save_map_file([])

        assert self._entry_lines(patch_config_paths["map_file"]) == [
            "a.example.com pool_1",
            "m.example.com pool_2",
            "z.example.com pool_3",
        ]
|
|
|
|
|
|
class TestGetDomainBackend:
    """get_domain_backend() resolves a domain against the SQLite store."""

    def test_find_existing_domain(self, patch_config_paths):
        """A registered domain resolves to the pool it was mapped to."""
        add_domain_to_map("example.com", "pool_1")

        resolved = get_domain_backend("example.com")
        assert resolved == "pool_1"

    def test_domain_not_found(self, patch_config_paths):
        """A domain that was never registered resolves to None."""
        add_domain_to_map("example.com", "pool_1")

        resolved = get_domain_backend("other.com")
        assert resolved is None
|
|
|
|
|
|
class TestIsLegacyBackend:
    """is_legacy_backend() distinguishes pool_N backends from older names."""

    def test_pool_backend(self):
        """Names of the form pool_N are not considered legacy."""
        for name in ("pool_1", "pool_100"):
            assert is_legacy_backend(name) is False

    def test_legacy_backend(self):
        """Any other backend name is considered legacy."""
        for name in ("api_example_com_backend", "static_backend"):
            assert is_legacy_backend(name) is True
|
|
|
|
|
|
class TestGetLegacyBackendName:
    """get_legacy_backend_name() derives the pre-pool backend name for a domain."""

    def test_convert_domain(self):
        """Dots become underscores and a _backend suffix is appended."""
        assert get_legacy_backend_name("api.example.com") == "api_example_com_backend"
|
|
|
|
|
|
class TestGetBackendAndPrefix:
    """get_backend_and_prefix() yields the backend name and server-name prefix."""

    def test_pool_backend(self, patch_config_paths):
        """A pool-mapped domain uses the pool name as both backend and prefix."""
        add_domain_to_map("example.com", "pool_5")

        backend_name, server_prefix = get_backend_and_prefix("example.com")

        assert backend_name == "pool_5"
        assert server_prefix == "pool_5"

    def test_unknown_domain_uses_legacy(self, patch_config_paths):
        """An unmapped domain falls back to legacy naming."""
        backend_name, server_prefix = get_backend_and_prefix("unknown.com")

        assert backend_name == "unknown_com_backend"
        assert server_prefix == "unknown_com"
|
|
|
|
|
|
class TestLoadServersConfig:
    """load_servers_config() rebuilds the servers mapping from SQLite."""

    def test_load_empty_config(self, patch_config_paths):
        """With nothing stored, the result is an empty dict."""
        assert load_servers_config() == {}

    def test_load_with_servers(self, patch_config_paths):
        """Stored servers appear under their domain, keyed by slot string."""
        for slot, ip in ((1, "10.0.0.1"), (2, "10.0.0.2")):
            add_server_to_config("example.com", slot, ip, 80)

        servers = load_servers_config()

        assert "example.com" in servers
        assert servers["example.com"]["1"]["ip"] == "10.0.0.1"
        assert servers["example.com"]["2"]["ip"] == "10.0.0.2"

    def test_load_with_shared_domain(self, patch_config_paths):
        """A shared domain carries a _shares reference to its parent."""
        for domain in ("example.com", "www.example.com"):
            add_domain_to_map(domain, "pool_1")
        add_shared_domain_to_config("www.example.com", "example.com")

        servers = load_servers_config()

        assert servers["www.example.com"]["_shares"] == "example.com"
|
|
|
|
|
|
class TestAddServerToConfig:
    """add_server_to_config() stores or replaces one server slot."""

    def test_add_to_empty_config(self, patch_config_paths):
        """The first server added becomes visible with its IP and port."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)

        slot_one = load_servers_config()["example.com"]["1"]
        assert slot_one["ip"] == "10.0.0.1"
        assert slot_one["http_port"] == 80

    def test_add_to_existing_domain(self, patch_config_paths):
        """A second slot is added next to the first one."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        add_server_to_config("example.com", 2, "10.0.0.2", 80)

        domain_slots = load_servers_config()["example.com"]
        assert "1" in domain_slots
        assert "2" in domain_slots

    def test_overwrite_existing_slot(self, patch_config_paths):
        """Re-adding the same slot replaces its IP and port."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        add_server_to_config("example.com", 1, "10.0.0.99", 8080)

        slot_one = load_servers_config()["example.com"]["1"]
        assert slot_one["ip"] == "10.0.0.99"
        assert slot_one["http_port"] == 8080
|
|
|
|
|
|
class TestRemoveServerFromConfig:
    """Tests for remove_server_from_config function."""

    def test_remove_existing_server(self, patch_config_paths):
        """Remove existing server."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        add_server_to_config("example.com", 2, "10.0.0.2", 80)

        remove_server_from_config("example.com", 1)

        config = load_servers_config()
        assert "1" not in config.get("example.com", {})
        assert "2" in config["example.com"]

    def test_remove_last_server_removes_domain(self, patch_config_paths):
        """Removing the last server leaves no slot behind for the domain.

        Whether the now-empty domain key itself is still present is
        deliberately not asserted; only the slot's absence is checked.
        """
        add_server_to_config("example.com", 1, "10.0.0.1", 80)

        remove_server_from_config("example.com", 1)

        config = load_servers_config()
        # The domain key may or may not exist (no servers = possibly no entry).
        assert config.get("example.com", {}).get("1") is None

    def test_remove_nonexistent_server(self, patch_config_paths):
        """Removing non-existent server is a no-op."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)

        remove_server_from_config("example.com", 99)  # Non-existent slot

        config = load_servers_config()
        assert "1" in config["example.com"]
|
|
|
|
|
|
class TestRemoveDomainFromConfig:
    """remove_domain_from_config() drops every server of a single domain."""

    def test_remove_existing_domain(self, patch_config_paths):
        """The target domain loses its servers; other domains are untouched."""
        for domain, ip in (("example.com", "10.0.0.1"), ("other.com", "10.0.0.2")):
            add_server_to_config(domain, 1, ip, 80)

        remove_domain_from_config("example.com")

        remaining = load_servers_config()
        assert remaining.get("example.com", {}).get("1") is None
        assert "other.com" in remaining

    def test_remove_nonexistent_domain(self, patch_config_paths):
        """Removing an unknown domain leaves existing ones in place."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)

        remove_domain_from_config("other.com")  # Non-existent

        assert "example.com" in load_servers_config()
|
|
|
|
|
|
class TestLoadCertsConfig:
    """load_certs_config() lists certificate domains stored in SQLite."""

    def test_load_empty(self, patch_config_paths):
        """With no certificates stored, the list is empty."""
        assert load_certs_config() == []

    def test_load_with_certs(self, patch_config_paths):
        """Every stored certificate domain is listed."""
        for domain in ("example.com", "other.com"):
            add_cert_to_config(domain)

        cert_domains = load_certs_config()

        assert "example.com" in cert_domains
        assert "other.com" in cert_domains
|
|
|
|
|
|
class TestAddCertToConfig:
    """add_cert_to_config() registers a certificate domain once."""

    def test_add_new_cert(self, patch_config_paths):
        """A newly added domain is listed afterwards."""
        add_cert_to_config("example.com")

        assert "example.com" in load_certs_config()

    def test_add_duplicate_cert(self, patch_config_paths):
        """Adding the same domain twice still lists it only once."""
        for _ in range(2):
            add_cert_to_config("example.com")

        assert load_certs_config().count("example.com") == 1
|
|
|
|
|
|
class TestRemoveCertFromConfig:
    """remove_cert_from_config() unregisters a single certificate domain."""

    def test_remove_existing_cert(self, patch_config_paths):
        """Only the requested domain disappears from the list."""
        for domain in ("example.com", "other.com"):
            add_cert_to_config(domain)

        remove_cert_from_config("example.com")

        remaining = load_certs_config()
        assert "example.com" not in remaining
        assert "other.com" in remaining

    def test_remove_nonexistent_cert(self, patch_config_paths):
        """Removing an unknown domain leaves the list unchanged."""
        add_cert_to_config("example.com")

        remove_cert_from_config("other.com")  # Non-existent

        assert "example.com" in load_certs_config()
|
|
|
|
|
|
class TestAddDomainToMap:
    """add_domain_to_map() stores a mapping and regenerates the map file."""

    def test_add_domain(self, patch_config_paths):
        """The new mapping is queryable and written to the map file."""
        add_domain_to_map("example.com", "pool_1")

        assert get_domain_backend("example.com") == "pool_1"

        with open(patch_config_paths["map_file"]) as handle:
            written = handle.read()
        assert "example.com pool_1" in written

    def test_add_wildcard(self, patch_config_paths):
        """A wildcard mapping is returned by get_map_contents()."""
        add_domain_to_map(".example.com", "pool_1", is_wildcard=True)

        assert (".example.com", "pool_1") in get_map_contents()
|
|
|
|
|
|
class TestRemoveDomainFromMap:
    """remove_domain_from_map() drops a domain together with its wildcard."""

    def test_remove_domain(self, patch_config_paths):
        """Both the exact and the wildcard mapping are gone afterwards."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map(".example.com", "pool_1", is_wildcard=True)

        remove_domain_from_map("example.com")

        assert get_domain_backend("example.com") is None
        remaining = get_map_contents()
        assert (".example.com", "pool_1") not in remaining
|
|
|
|
|
|
class TestFindAvailablePool:
    """find_available_pool() picks the lowest pool not yet assigned."""

    def test_first_pool_available(self, patch_config_paths):
        """With no domains mapped, pool_1 is chosen."""
        assert find_available_pool() == "pool_1"

    def test_skip_used_pools(self, patch_config_paths):
        """Pools already taken by domains are skipped over."""
        for domain, pool in (("example.com", "pool_1"), ("other.com", "pool_2")):
            add_domain_to_map(domain, pool)

        assert find_available_pool() == "pool_3"
|
|
|
|
|
|
class TestSharedDomains:
    """Shared-domain helpers: parent lookup, shared check, pool members."""

    @staticmethod
    def _share_www_with_apex():
        """Map example.com and www.example.com to pool_1, with www sharing the apex."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map("www.example.com", "pool_1")
        add_shared_domain_to_config("www.example.com", "example.com")

    def test_get_shared_domain(self, patch_config_paths):
        """A shared domain reports the domain it shares with."""
        self._share_www_with_apex()

        assert get_shared_domain("www.example.com") == "example.com"

    def test_is_shared_domain(self, patch_config_paths):
        """Only the sharing domain is flagged as shared."""
        self._share_www_with_apex()

        assert is_shared_domain("www.example.com") is True
        assert is_shared_domain("example.com") is False

    def test_get_domains_sharing_pool(self, patch_config_paths):
        """Every exact domain on the pool is listed; wildcards are not."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map("www.example.com", "pool_1")
        add_domain_to_map(".example.com", "pool_1", is_wildcard=True)

        members = get_domains_sharing_pool("pool_1")

        assert "example.com" in members
        assert "www.example.com" in members
        assert ".example.com" not in members  # Wildcards excluded
|