refactor: Improve code quality, error handling, and test coverage
- Add file_lock context manager to eliminate duplicate locking patterns - Add ValidationError, ConfigurationError, CertificateError exceptions - Improve rollback logic in haproxy_add_servers (track successful ops only) - Decompose haproxy_add_domain into smaller helper functions - Consolidate certificate constants (CERTS_DIR, ACME_HOME) to config.py - Enhance docstrings for internal functions and magic numbers - Add pytest framework with 48 new tests (269 -> 317 total) - Increase test coverage from 76% to 86% - servers.py: 58% -> 82% - certificates.py: 67% -> 86% - configuration.py: 69% -> 94% Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
498
tests/unit/test_file_ops.py
Normal file
498
tests/unit/test_file_ops.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""Unit tests for file_ops module."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.file_ops import (
|
||||
atomic_write_file,
|
||||
get_map_contents,
|
||||
save_map_file,
|
||||
get_domain_backend,
|
||||
split_domain_entries,
|
||||
is_legacy_backend,
|
||||
get_legacy_backend_name,
|
||||
get_backend_and_prefix,
|
||||
load_servers_config,
|
||||
save_servers_config,
|
||||
add_server_to_config,
|
||||
remove_server_from_config,
|
||||
remove_domain_from_config,
|
||||
load_certs_config,
|
||||
save_certs_config,
|
||||
add_cert_to_config,
|
||||
remove_cert_from_config,
|
||||
)
|
||||
|
||||
|
||||
class TestAtomicWriteFile:
|
||||
"""Tests for atomic_write_file function."""
|
||||
|
||||
def test_write_new_file(self, tmp_path):
|
||||
"""Write to a new file."""
|
||||
file_path = str(tmp_path / "test.txt")
|
||||
content = "Hello, World!"
|
||||
|
||||
atomic_write_file(file_path, content)
|
||||
|
||||
assert os.path.exists(file_path)
|
||||
with open(file_path) as f:
|
||||
assert f.read() == content
|
||||
|
||||
def test_overwrite_existing_file(self, tmp_path):
|
||||
"""Overwrite an existing file."""
|
||||
file_path = str(tmp_path / "test.txt")
|
||||
with open(file_path, "w") as f:
|
||||
f.write("Old content")
|
||||
|
||||
atomic_write_file(file_path, "New content")
|
||||
|
||||
with open(file_path) as f:
|
||||
assert f.read() == "New content"
|
||||
|
||||
def test_preserves_directory(self, tmp_path):
|
||||
"""Writing does not create intermediate directories."""
|
||||
file_path = str(tmp_path / "subdir" / "test.txt")
|
||||
|
||||
with pytest.raises(IOError):
|
||||
atomic_write_file(file_path, "content")
|
||||
|
||||
def test_unicode_content(self, tmp_path):
|
||||
"""Unicode content is properly written."""
|
||||
file_path = str(tmp_path / "unicode.txt")
|
||||
content = "Hello, \u4e16\u754c!" # "Hello, World!" in Chinese
|
||||
|
||||
atomic_write_file(file_path, content)
|
||||
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
assert f.read() == content
|
||||
|
||||
def test_multiline_content(self, tmp_path):
|
||||
"""Multi-line content is properly written."""
|
||||
file_path = str(tmp_path / "multiline.txt")
|
||||
content = "line1\nline2\nline3"
|
||||
|
||||
atomic_write_file(file_path, content)
|
||||
|
||||
with open(file_path) as f:
|
||||
assert f.read() == content
|
||||
|
||||
|
||||
class TestGetMapContents:
|
||||
"""Tests for get_map_contents function."""
|
||||
|
||||
def test_read_map_file(self, patch_config_paths):
|
||||
"""Read entries from map file."""
|
||||
# Write test content to map file
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("api.example.com pool_2\n")
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert ("example.com", "pool_1") in entries
|
||||
assert ("api.example.com", "pool_2") in entries
|
||||
|
||||
def test_read_both_map_files(self, patch_config_paths):
|
||||
"""Read entries from both domains.map and wildcards.map."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||
f.write(".example.com pool_1\n")
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert ("example.com", "pool_1") in entries
|
||||
assert (".example.com", "pool_1") in entries
|
||||
|
||||
def test_skip_comments(self, patch_config_paths):
|
||||
"""Comments are skipped."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("# This is a comment\n")
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("# Another comment\n")
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert len(entries) == 1
|
||||
assert entries[0] == ("example.com", "pool_1")
|
||||
|
||||
def test_skip_empty_lines(self, patch_config_paths):
|
||||
"""Empty lines are skipped."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("\n")
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("\n")
|
||||
f.write("api.example.com pool_2\n")
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert len(entries) == 2
|
||||
|
||||
def test_file_not_found(self, patch_config_paths):
|
||||
"""Missing file returns empty list."""
|
||||
os.unlink(patch_config_paths["map_file"])
|
||||
os.unlink(patch_config_paths["wildcards_file"])
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert entries == []
|
||||
|
||||
|
||||
class TestSplitDomainEntries:
|
||||
"""Tests for split_domain_entries function."""
|
||||
|
||||
def test_split_entries(self):
|
||||
"""Split entries into exact and wildcard."""
|
||||
entries = [
|
||||
("example.com", "pool_1"),
|
||||
(".example.com", "pool_1"),
|
||||
("api.example.com", "pool_2"),
|
||||
(".api.example.com", "pool_2"),
|
||||
]
|
||||
|
||||
exact, wildcards = split_domain_entries(entries)
|
||||
|
||||
assert len(exact) == 2
|
||||
assert len(wildcards) == 2
|
||||
assert ("example.com", "pool_1") in exact
|
||||
assert (".example.com", "pool_1") in wildcards
|
||||
|
||||
def test_empty_entries(self):
|
||||
"""Empty entries returns empty lists."""
|
||||
exact, wildcards = split_domain_entries([])
|
||||
|
||||
assert exact == []
|
||||
assert wildcards == []
|
||||
|
||||
def test_all_exact(self):
|
||||
"""All exact entries."""
|
||||
entries = [
|
||||
("example.com", "pool_1"),
|
||||
("api.example.com", "pool_2"),
|
||||
]
|
||||
|
||||
exact, wildcards = split_domain_entries(entries)
|
||||
|
||||
assert len(exact) == 2
|
||||
assert len(wildcards) == 0
|
||||
|
||||
|
||||
class TestSaveMapFile:
|
||||
"""Tests for save_map_file function."""
|
||||
|
||||
def test_save_entries(self, patch_config_paths):
|
||||
"""Save entries to separate map files."""
|
||||
entries = [
|
||||
("example.com", "pool_1"),
|
||||
(".example.com", "pool_1"),
|
||||
]
|
||||
|
||||
save_map_file(entries)
|
||||
|
||||
# Check exact domains file
|
||||
with open(patch_config_paths["map_file"]) as f:
|
||||
content = f.read()
|
||||
assert "example.com pool_1" in content
|
||||
|
||||
# Check wildcards file
|
||||
with open(patch_config_paths["wildcards_file"]) as f:
|
||||
content = f.read()
|
||||
assert ".example.com pool_1" in content
|
||||
|
||||
def test_sorted_output(self, patch_config_paths):
|
||||
"""Entries are sorted in output."""
|
||||
entries = [
|
||||
("z.example.com", "pool_3"),
|
||||
("a.example.com", "pool_1"),
|
||||
("m.example.com", "pool_2"),
|
||||
]
|
||||
|
||||
save_map_file(entries)
|
||||
|
||||
with open(patch_config_paths["map_file"]) as f:
|
||||
lines = [l.strip() for l in f if l.strip() and not l.startswith("#")]
|
||||
|
||||
assert lines[0] == "a.example.com pool_1"
|
||||
assert lines[1] == "m.example.com pool_2"
|
||||
assert lines[2] == "z.example.com pool_3"
|
||||
|
||||
|
||||
class TestGetDomainBackend:
|
||||
"""Tests for get_domain_backend function."""
|
||||
|
||||
def test_find_existing_domain(self, patch_config_paths):
|
||||
"""Find backend for existing domain."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
backend = get_domain_backend("example.com")
|
||||
|
||||
assert backend == "pool_1"
|
||||
|
||||
def test_domain_not_found(self, patch_config_paths):
|
||||
"""Non-existent domain returns None."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
backend = get_domain_backend("other.com")
|
||||
|
||||
assert backend is None
|
||||
|
||||
|
||||
class TestIsLegacyBackend:
|
||||
"""Tests for is_legacy_backend function."""
|
||||
|
||||
def test_pool_backend(self):
|
||||
"""Pool backend is not legacy."""
|
||||
assert is_legacy_backend("pool_1") is False
|
||||
assert is_legacy_backend("pool_100") is False
|
||||
|
||||
def test_legacy_backend(self):
|
||||
"""Non-pool backend is legacy."""
|
||||
assert is_legacy_backend("api_example_com_backend") is True
|
||||
assert is_legacy_backend("static_backend") is True
|
||||
|
||||
|
||||
class TestGetLegacyBackendName:
|
||||
"""Tests for get_legacy_backend_name function."""
|
||||
|
||||
def test_convert_domain(self):
|
||||
"""Convert domain to legacy backend name."""
|
||||
result = get_legacy_backend_name("api.example.com")
|
||||
assert result == "api_example_com_backend"
|
||||
|
||||
|
||||
class TestGetBackendAndPrefix:
|
||||
"""Tests for get_backend_and_prefix function."""
|
||||
|
||||
def test_pool_backend(self, patch_config_paths):
|
||||
"""Pool backend returns pool-based prefix."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_5\n")
|
||||
|
||||
backend, prefix = get_backend_and_prefix("example.com")
|
||||
|
||||
assert backend == "pool_5"
|
||||
assert prefix == "pool_5"
|
||||
|
||||
def test_unknown_domain_uses_legacy(self, patch_config_paths):
|
||||
"""Unknown domain uses legacy backend naming."""
|
||||
backend, prefix = get_backend_and_prefix("unknown.com")
|
||||
|
||||
assert backend == "unknown_com_backend"
|
||||
assert prefix == "unknown_com"
|
||||
|
||||
|
||||
class TestLoadServersConfig:
|
||||
"""Tests for load_servers_config function."""
|
||||
|
||||
def test_load_existing_config(self, patch_config_paths, sample_servers_config):
|
||||
"""Load existing config file."""
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(sample_servers_config, f)
|
||||
|
||||
config = load_servers_config()
|
||||
|
||||
assert "example.com" in config
|
||||
assert config["example.com"]["1"]["ip"] == "10.0.0.1"
|
||||
|
||||
def test_file_not_found(self, patch_config_paths):
|
||||
"""Missing file returns empty dict."""
|
||||
os.unlink(patch_config_paths["servers_file"])
|
||||
|
||||
config = load_servers_config()
|
||||
|
||||
assert config == {}
|
||||
|
||||
def test_invalid_json(self, patch_config_paths):
|
||||
"""Invalid JSON returns empty dict."""
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
f.write("not valid json {{{")
|
||||
|
||||
config = load_servers_config()
|
||||
|
||||
assert config == {}
|
||||
|
||||
|
||||
class TestSaveServersConfig:
|
||||
"""Tests for save_servers_config function."""
|
||||
|
||||
def test_save_config(self, patch_config_paths):
|
||||
"""Save config to file."""
|
||||
config = {"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}
|
||||
|
||||
save_servers_config(config)
|
||||
|
||||
with open(patch_config_paths["servers_file"]) as f:
|
||||
loaded = json.load(f)
|
||||
assert loaded == config
|
||||
|
||||
|
||||
class TestAddServerToConfig:
|
||||
"""Tests for add_server_to_config function."""
|
||||
|
||||
def test_add_to_empty_config(self, patch_config_paths):
|
||||
"""Add server to empty config."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
|
||||
config = load_servers_config()
|
||||
assert config["example.com"]["1"]["ip"] == "10.0.0.1"
|
||||
assert config["example.com"]["1"]["http_port"] == 80
|
||||
|
||||
def test_add_to_existing_domain(self, patch_config_paths):
|
||||
"""Add server to domain with existing servers."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
add_server_to_config("example.com", 2, "10.0.0.2", 80)
|
||||
|
||||
config = load_servers_config()
|
||||
assert "1" in config["example.com"]
|
||||
assert "2" in config["example.com"]
|
||||
|
||||
def test_overwrite_existing_slot(self, patch_config_paths):
|
||||
"""Overwrite existing slot."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
add_server_to_config("example.com", 1, "10.0.0.99", 8080)
|
||||
|
||||
config = load_servers_config()
|
||||
assert config["example.com"]["1"]["ip"] == "10.0.0.99"
|
||||
assert config["example.com"]["1"]["http_port"] == 8080
|
||||
|
||||
|
||||
class TestRemoveServerFromConfig:
|
||||
"""Tests for remove_server_from_config function."""
|
||||
|
||||
def test_remove_existing_server(self, patch_config_paths):
|
||||
"""Remove existing server."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
add_server_to_config("example.com", 2, "10.0.0.2", 80)
|
||||
|
||||
remove_server_from_config("example.com", 1)
|
||||
|
||||
config = load_servers_config()
|
||||
assert "1" not in config["example.com"]
|
||||
assert "2" in config["example.com"]
|
||||
|
||||
def test_remove_last_server_removes_domain(self, patch_config_paths):
|
||||
"""Removing last server removes domain entry."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
|
||||
remove_server_from_config("example.com", 1)
|
||||
|
||||
config = load_servers_config()
|
||||
assert "example.com" not in config
|
||||
|
||||
def test_remove_nonexistent_server(self, patch_config_paths):
|
||||
"""Removing non-existent server is a no-op."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
|
||||
remove_server_from_config("example.com", 99) # Non-existent slot
|
||||
|
||||
config = load_servers_config()
|
||||
assert "1" in config["example.com"]
|
||||
|
||||
|
||||
class TestRemoveDomainFromConfig:
|
||||
"""Tests for remove_domain_from_config function."""
|
||||
|
||||
def test_remove_existing_domain(self, patch_config_paths):
|
||||
"""Remove existing domain."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
add_server_to_config("other.com", 1, "10.0.0.2", 80)
|
||||
|
||||
remove_domain_from_config("example.com")
|
||||
|
||||
config = load_servers_config()
|
||||
assert "example.com" not in config
|
||||
assert "other.com" in config
|
||||
|
||||
def test_remove_nonexistent_domain(self, patch_config_paths):
|
||||
"""Removing non-existent domain is a no-op."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
|
||||
remove_domain_from_config("other.com") # Non-existent
|
||||
|
||||
config = load_servers_config()
|
||||
assert "example.com" in config
|
||||
|
||||
|
||||
class TestLoadCertsConfig:
|
||||
"""Tests for load_certs_config function."""
|
||||
|
||||
def test_load_existing_config(self, patch_config_paths):
|
||||
"""Load existing certs config."""
|
||||
with open(patch_config_paths["certs_file"], "w") as f:
|
||||
json.dump({"domains": ["example.com", "other.com"]}, f)
|
||||
|
||||
domains = load_certs_config()
|
||||
|
||||
assert "example.com" in domains
|
||||
assert "other.com" in domains
|
||||
|
||||
def test_file_not_found(self, patch_config_paths):
|
||||
"""Missing file returns empty list."""
|
||||
os.unlink(patch_config_paths["certs_file"])
|
||||
|
||||
domains = load_certs_config()
|
||||
|
||||
assert domains == []
|
||||
|
||||
|
||||
class TestSaveCertsConfig:
|
||||
"""Tests for save_certs_config function."""
|
||||
|
||||
def test_save_domains(self, patch_config_paths):
|
||||
"""Save domains to certs config."""
|
||||
save_certs_config(["z.com", "a.com"])
|
||||
|
||||
with open(patch_config_paths["certs_file"]) as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Should be sorted
|
||||
assert data["domains"] == ["a.com", "z.com"]
|
||||
|
||||
|
||||
class TestAddCertToConfig:
|
||||
"""Tests for add_cert_to_config function."""
|
||||
|
||||
def test_add_new_cert(self, patch_config_paths):
|
||||
"""Add new certificate domain."""
|
||||
add_cert_to_config("example.com")
|
||||
|
||||
domains = load_certs_config()
|
||||
assert "example.com" in domains
|
||||
|
||||
def test_add_duplicate_cert(self, patch_config_paths):
|
||||
"""Adding duplicate cert is a no-op."""
|
||||
add_cert_to_config("example.com")
|
||||
add_cert_to_config("example.com")
|
||||
|
||||
domains = load_certs_config()
|
||||
assert domains.count("example.com") == 1
|
||||
|
||||
|
||||
class TestRemoveCertFromConfig:
|
||||
"""Tests for remove_cert_from_config function."""
|
||||
|
||||
def test_remove_existing_cert(self, patch_config_paths):
|
||||
"""Remove existing certificate domain."""
|
||||
add_cert_to_config("example.com")
|
||||
add_cert_to_config("other.com")
|
||||
|
||||
remove_cert_from_config("example.com")
|
||||
|
||||
domains = load_certs_config()
|
||||
assert "example.com" not in domains
|
||||
assert "other.com" in domains
|
||||
|
||||
def test_remove_nonexistent_cert(self, patch_config_paths):
|
||||
"""Removing non-existent cert is a no-op."""
|
||||
add_cert_to_config("example.com")
|
||||
|
||||
remove_cert_from_config("other.com") # Non-existent
|
||||
|
||||
domains = load_certs_config()
|
||||
assert "example.com" in domains
|
||||
Reference in New Issue
Block a user