Files
haproxy-mcp/tests/unit/test_file_ops.py
kappa cf554f3f89 refactor: migrate data storage from JSON/map files to SQLite
Replace servers.json, certificates.json, and map-file parsing with
SQLite (WAL mode) as the single source of truth. HAProxy map files are
now generated from SQLite via sync_map_files().

Key changes:
- Add db.py with schema, connection management, and JSON migration
- Add DB_FILE config constant
- Delegate file_ops.py functions to db.py
- Refactor domains.py to use file_ops instead of direct list manipulation
- Fix subprocess.TimeoutExpired not caught (doesn't inherit TimeoutError)
- Add DB health check in health.py
- Init DB on startup in server.py and __main__.py
- Update all 359 tests to use SQLite-backed functions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 11:07:29 +09:00

514 lines
17 KiB
Python

"""Unit tests for file_ops module (SQLite-backed)."""
import os
from unittest.mock import patch
import pytest
from haproxy_mcp.file_ops import (
atomic_write_file,
get_map_contents,
save_map_file,
get_domain_backend,
split_domain_entries,
is_legacy_backend,
get_legacy_backend_name,
get_backend_and_prefix,
load_servers_config,
add_server_to_config,
remove_server_from_config,
remove_domain_from_config,
load_certs_config,
add_cert_to_config,
remove_cert_from_config,
add_domain_to_map,
remove_domain_from_map,
find_available_pool,
add_shared_domain_to_config,
get_shared_domain,
is_shared_domain,
get_domains_sharing_pool,
)
class TestAtomicWriteFile:
    """Tests for atomic_write_file function.

    Files are always read back as UTF-8 so the assertions do not depend on
    the platform's default locale encoding.
    """

    def test_write_new_file(self, tmp_path):
        """Write to a new file."""
        file_path = str(tmp_path / "test.txt")
        content = "Hello, World!"
        atomic_write_file(file_path, content)
        assert os.path.exists(file_path)
        with open(file_path, encoding="utf-8") as f:
            assert f.read() == content

    def test_overwrite_existing_file(self, tmp_path):
        """Overwrite an existing file."""
        file_path = str(tmp_path / "test.txt")
        with open(file_path, "w", encoding="utf-8") as f:
            f.write("Old content")
        atomic_write_file(file_path, "New content")
        with open(file_path, encoding="utf-8") as f:
            assert f.read() == "New content"

    def test_preserves_directory(self, tmp_path):
        """Writing into a missing directory fails and does not create it."""
        missing_dir = tmp_path / "subdir"
        file_path = str(missing_dir / "test.txt")
        # IOError is only a legacy alias of OSError in Python 3; name the real class.
        with pytest.raises(OSError):
            atomic_write_file(file_path, "content")
        # The failed write must not have created the intermediate directory.
        assert not missing_dir.exists()

    def test_unicode_content(self, tmp_path):
        """Unicode content is properly written."""
        file_path = str(tmp_path / "unicode.txt")
        content = "Hello, \u4e16\u754c!"
        atomic_write_file(file_path, content)
        with open(file_path, encoding="utf-8") as f:
            assert f.read() == content

    def test_multiline_content(self, tmp_path):
        """Multi-line content is properly written."""
        file_path = str(tmp_path / "multiline.txt")
        content = "line1\nline2\nline3"
        atomic_write_file(file_path, content)
        with open(file_path, encoding="utf-8") as f:
            assert f.read() == content
class TestGetMapContents:
    """Tests for get_map_contents function (SQLite-backed)."""

    def test_empty_db(self, patch_config_paths):
        """A fresh database yields no map entries."""
        assert get_map_contents() == []

    def test_read_domains(self, patch_config_paths):
        """Domains added to the database are listed as (domain, backend) pairs."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map("api.example.com", "pool_2")
        rows = get_map_contents()
        for expected in (("example.com", "pool_1"), ("api.example.com", "pool_2")):
            assert expected in rows

    def test_read_with_wildcards(self, patch_config_paths):
        """Wildcard entries are listed alongside exact ones."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
        rows = get_map_contents()
        for expected in (("example.com", "pool_1"), (".example.com", "pool_1")):
            assert expected in rows
class TestSplitDomainEntries:
    """Tests for split_domain_entries function.

    Entries whose domain starts with "." are expected on the wildcard side.
    """

    def test_split_entries(self):
        """Split entries into exact and wildcard."""
        entries = [
            ("example.com", "pool_1"),
            (".example.com", "pool_1"),
            ("api.example.com", "pool_2"),
            (".api.example.com", "pool_2"),
        ]
        exact, wildcards = split_domain_entries(entries)
        assert len(exact) == 2
        assert len(wildcards) == 2
        # Check the full contents of both sides so that no entry can be
        # misrouted unnoticed (previously only one member per side was checked).
        assert set(exact) == {
            ("example.com", "pool_1"),
            ("api.example.com", "pool_2"),
        }
        assert set(wildcards) == {
            (".example.com", "pool_1"),
            (".api.example.com", "pool_2"),
        }

    def test_empty_entries(self):
        """Empty entries returns empty lists."""
        exact, wildcards = split_domain_entries([])
        assert exact == []
        assert wildcards == []

    def test_all_exact(self):
        """All exact entries."""
        entries = [
            ("example.com", "pool_1"),
            ("api.example.com", "pool_2"),
        ]
        exact, wildcards = split_domain_entries(entries)
        assert len(exact) == 2
        assert len(wildcards) == 0

    def test_all_wildcards(self):
        """All wildcard entries (mirror of test_all_exact)."""
        entries = [
            (".example.com", "pool_1"),
            (".api.example.com", "pool_2"),
        ]
        exact, wildcards = split_domain_entries(entries)
        assert exact == []
        assert len(wildcards) == 2
class TestSaveMapFile:
    """Tests for save_map_file function (syncs from DB to map files)."""

    @staticmethod
    def _read_entries(path):
        """Return the stripped, non-blank, non-comment lines of a map file."""
        with open(path, encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            return [line for line in stripped if line and not line.startswith("#")]

    def test_save_entries(self, patch_config_paths):
        """Save entries to separate map files."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
        save_map_file([])  # Entries param ignored, syncs from DB
        exact_lines = self._read_entries(patch_config_paths["map_file"])
        wildcard_lines = self._read_entries(patch_config_paths["wildcards_file"])
        # Compare whole lines: a substring test for "example.com pool_1" is also
        # satisfied by ".example.com pool_1", which would hide a wildcard
        # written to the exact-match map file.
        assert "example.com pool_1" in exact_lines
        assert ".example.com pool_1" not in exact_lines
        assert ".example.com pool_1" in wildcard_lines
        assert "example.com pool_1" not in wildcard_lines

    def test_sorted_output(self, patch_config_paths):
        """Entries are sorted in output."""
        add_domain_to_map("z.example.com", "pool_3")
        add_domain_to_map("a.example.com", "pool_1")
        add_domain_to_map("m.example.com", "pool_2")
        save_map_file([])
        lines = self._read_entries(patch_config_paths["map_file"])
        # Exact list equality also catches duplicated or stray entries.
        assert lines == [
            "a.example.com pool_1",
            "m.example.com pool_2",
            "z.example.com pool_3",
        ]
class TestGetDomainBackend:
    """Tests for get_domain_backend function (SQLite-backed)."""

    def test_find_existing_domain(self, patch_config_paths):
        """A mapped domain resolves to its backend."""
        add_domain_to_map("example.com", "pool_1")
        assert get_domain_backend("example.com") == "pool_1"

    def test_domain_not_found(self, patch_config_paths):
        """An unmapped domain resolves to None."""
        add_domain_to_map("example.com", "pool_1")
        assert get_domain_backend("other.com") is None
class TestIsLegacyBackend:
    """Tests for is_legacy_backend function."""

    def test_pool_backend(self):
        """pool_N backend names are not legacy."""
        for name in ("pool_1", "pool_100"):
            assert is_legacy_backend(name) is False

    def test_legacy_backend(self):
        """Backend names outside the pool scheme are legacy."""
        for name in ("api_example_com_backend", "static_backend"):
            assert is_legacy_backend(name) is True
class TestGetLegacyBackendName:
    """Tests for get_legacy_backend_name function."""

    def test_convert_domain(self):
        """A domain maps to its legacy name, e.g. api.example.com -> api_example_com_backend."""
        expected = "api_example_com_backend"
        assert get_legacy_backend_name("api.example.com") == expected
class TestGetBackendAndPrefix:
    """Tests for get_backend_and_prefix function."""

    def test_pool_backend(self, patch_config_paths):
        """A pool-mapped domain uses the pool name as both backend and prefix."""
        add_domain_to_map("example.com", "pool_5")
        name, server_prefix = get_backend_and_prefix("example.com")
        assert (name, server_prefix) == ("pool_5", "pool_5")

    def test_unknown_domain_uses_legacy(self, patch_config_paths):
        """An unmapped domain falls back to legacy backend naming."""
        name, server_prefix = get_backend_and_prefix("unknown.com")
        assert name == "unknown_com_backend"
        assert server_prefix == "unknown_com"
class TestLoadServersConfig:
    """Tests for load_servers_config function (SQLite-backed)."""

    def test_load_empty_config(self, patch_config_paths):
        """A fresh database yields an empty servers mapping."""
        assert load_servers_config() == {}

    def test_load_with_servers(self, patch_config_paths):
        """Servers are grouped per domain and keyed by slot number as a string."""
        for slot, ip in ((1, "10.0.0.1"), (2, "10.0.0.2")):
            add_server_to_config("example.com", slot, ip, 80)
        servers = load_servers_config()
        assert "example.com" in servers
        domain_servers = servers["example.com"]
        assert domain_servers["1"]["ip"] == "10.0.0.1"
        assert domain_servers["2"]["ip"] == "10.0.0.2"

    def test_load_with_shared_domain(self, patch_config_paths):
        """A shared domain exposes its parent under the "_shares" key."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map("www.example.com", "pool_1")
        add_shared_domain_to_config("www.example.com", "example.com")
        www_entry = load_servers_config()["www.example.com"]
        assert www_entry["_shares"] == "example.com"
class TestAddServerToConfig:
    """Tests for add_server_to_config function."""

    def test_add_to_empty_config(self, patch_config_paths):
        """The first server is stored with its IP and HTTP port."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        slot = load_servers_config()["example.com"]["1"]
        assert slot["ip"] == "10.0.0.1"
        assert slot["http_port"] == 80

    def test_add_to_existing_domain(self, patch_config_paths):
        """A second slot is stored next to the existing one."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        add_server_to_config("example.com", 2, "10.0.0.2", 80)
        domain_servers = load_servers_config()["example.com"]
        assert "1" in domain_servers
        assert "2" in domain_servers

    def test_overwrite_existing_slot(self, patch_config_paths):
        """Re-adding an occupied slot replaces its IP and port."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        add_server_to_config("example.com", 1, "10.0.0.99", 8080)
        slot = load_servers_config()["example.com"]["1"]
        assert slot["ip"] == "10.0.0.99"
        assert slot["http_port"] == 8080
class TestRemoveServerFromConfig:
    """Tests for remove_server_from_config function."""

    def test_remove_existing_server(self, patch_config_paths):
        """Only the targeted slot is removed."""
        for slot, ip in ((1, "10.0.0.1"), (2, "10.0.0.2")):
            add_server_to_config("example.com", slot, ip, 80)
        remove_server_from_config("example.com", 1)
        servers = load_servers_config()
        assert "1" not in servers.get("example.com", {})
        assert "2" in servers["example.com"]

    def test_remove_last_server_removes_domain(self, patch_config_paths):
        """Removing the only server leaves no slot "1" for the domain.

        Whether the domain key itself is dropped is deliberately not asserted;
        with no servers left it may or may not appear in the config.
        """
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        remove_server_from_config("example.com", 1)
        remaining = load_servers_config().get("example.com", {})
        assert remaining.get("1") is None

    def test_remove_nonexistent_server(self, patch_config_paths):
        """Removing an unknown slot leaves existing slots untouched."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        remove_server_from_config("example.com", 99)  # slot 99 was never added
        assert "1" in load_servers_config()["example.com"]
class TestRemoveDomainFromConfig:
    """Tests for remove_domain_from_config function."""

    def test_remove_existing_domain(self, patch_config_paths):
        """Only the targeted domain's servers are removed."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        add_server_to_config("other.com", 1, "10.0.0.2", 80)
        remove_domain_from_config("example.com")
        servers = load_servers_config()
        assert servers.get("example.com", {}).get("1") is None
        assert "other.com" in servers

    def test_remove_nonexistent_domain(self, patch_config_paths):
        """Removing an unknown domain leaves existing domains untouched."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        remove_domain_from_config("other.com")  # never added
        assert "example.com" in load_servers_config()
class TestLoadCertsConfig:
    """Tests for load_certs_config function (SQLite-backed)."""

    def test_load_empty(self, patch_config_paths):
        """A fresh database lists no certificate domains."""
        assert load_certs_config() == []

    def test_load_with_certs(self, patch_config_paths):
        """Every added certificate domain is listed."""
        names = ("example.com", "other.com")
        for name in names:
            add_cert_to_config(name)
        listed = load_certs_config()
        for name in names:
            assert name in listed
class TestAddCertToConfig:
    """Tests for add_cert_to_config function."""

    def test_add_new_cert(self, patch_config_paths):
        """A newly added certificate domain is listed."""
        add_cert_to_config("example.com")
        assert "example.com" in load_certs_config()

    def test_add_duplicate_cert(self, patch_config_paths):
        """Adding the same certificate twice stores it only once."""
        for _ in range(2):
            add_cert_to_config("example.com")
        assert load_certs_config().count("example.com") == 1
class TestRemoveCertFromConfig:
    """Tests for remove_cert_from_config function."""

    def test_remove_existing_cert(self, patch_config_paths):
        """Only the targeted certificate domain is removed."""
        add_cert_to_config("example.com")
        add_cert_to_config("other.com")
        remove_cert_from_config("example.com")
        remaining = load_certs_config()
        assert "example.com" not in remaining
        assert "other.com" in remaining

    def test_remove_nonexistent_cert(self, patch_config_paths):
        """Removing an unknown certificate leaves existing ones untouched."""
        add_cert_to_config("example.com")
        remove_cert_from_config("other.com")  # never added
        assert "example.com" in load_certs_config()
class TestAddDomainToMap:
    """Tests for add_domain_to_map function."""

    def test_add_domain(self, patch_config_paths):
        """Adding a domain updates both the database and the generated map file."""
        add_domain_to_map("example.com", "pool_1")
        assert get_domain_backend("example.com") == "pool_1"
        with open(patch_config_paths["map_file"]) as map_fh:
            map_text = map_fh.read()
        assert "example.com pool_1" in map_text

    def test_add_wildcard(self, patch_config_paths):
        """A wildcard domain is stored and listed by get_map_contents()."""
        add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
        assert (".example.com", "pool_1") in get_map_contents()
class TestRemoveDomainFromMap:
    """Tests for remove_domain_from_map function."""

    def test_remove_domain(self, patch_config_paths):
        """Removing a domain also drops its wildcard entry."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
        remove_domain_from_map("example.com")
        assert get_domain_backend("example.com") is None
        assert (".example.com", "pool_1") not in get_map_contents()
class TestFindAvailablePool:
    """Tests for find_available_pool function."""

    def test_first_pool_available(self, patch_config_paths):
        """With no domains mapped, pool_1 is handed out."""
        assert find_available_pool() == "pool_1"

    def test_skip_used_pools(self, patch_config_paths):
        """Pools already assigned to a domain are skipped."""
        for domain, pool in (("example.com", "pool_1"), ("other.com", "pool_2")):
            add_domain_to_map(domain, pool)
        assert find_available_pool() == "pool_3"
class TestSharedDomains:
    """Tests for shared domain functions."""

    @staticmethod
    def _share_www_with_apex():
        """Map example.com and www.example.com to pool_1; www shares example.com."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map("www.example.com", "pool_1")
        add_shared_domain_to_config("www.example.com", "example.com")

    def test_get_shared_domain(self, patch_config_paths):
        """A shared domain reports the domain it shares with."""
        self._share_www_with_apex()
        assert get_shared_domain("www.example.com") == "example.com"

    def test_is_shared_domain(self, patch_config_paths):
        """Only the sharing domain is flagged as shared, not its parent."""
        self._share_www_with_apex()
        assert is_shared_domain("www.example.com") is True
        assert is_shared_domain("example.com") is False

    def test_get_domains_sharing_pool(self, patch_config_paths):
        """All exact domains on a pool are listed; wildcard rows are not."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map("www.example.com", "pool_1")
        add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
        members = get_domains_sharing_pool("pool_1")
        assert "example.com" in members
        assert "www.example.com" in members
        assert ".example.com" not in members  # wildcard entries are excluded