Files
haproxy-mcp/tests/unit/test_file_ops.py
kaffa 6bcfee519c refactor: Improve code quality, error handling, and test coverage
- Add file_lock context manager to eliminate duplicate locking patterns
- Add ValidationError, ConfigurationError, CertificateError exceptions
- Improve rollback logic in haproxy_add_servers (track successful ops only)
- Decompose haproxy_add_domain into smaller helper functions
- Consolidate certificate constants (CERTS_DIR, ACME_HOME) to config.py
- Enhance docstrings for internal functions and magic numbers
- Add pytest framework with 48 new tests (269 -> 317 total)
- Increase test coverage from 76% to 86%
  - servers.py: 58% -> 82%
  - certificates.py: 67% -> 86%
  - configuration.py: 69% -> 94%

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 12:50:00 +09:00

499 lines
15 KiB
Python

"""Unit tests for file_ops module."""
import json
import os
from unittest.mock import patch
import pytest
from haproxy_mcp.file_ops import (
atomic_write_file,
get_map_contents,
save_map_file,
get_domain_backend,
split_domain_entries,
is_legacy_backend,
get_legacy_backend_name,
get_backend_and_prefix,
load_servers_config,
save_servers_config,
add_server_to_config,
remove_server_from_config,
remove_domain_from_config,
load_certs_config,
save_certs_config,
add_cert_to_config,
remove_cert_from_config,
)
class TestAtomicWriteFile:
    """Tests for atomic_write_file function."""

    def test_write_new_file(self, tmp_path):
        """Write to a new file."""
        file_path = str(tmp_path / "test.txt")
        content = "Hello, World!"
        atomic_write_file(file_path, content)
        assert os.path.exists(file_path)
        with open(file_path) as f:
            assert f.read() == content

    def test_overwrite_existing_file(self, tmp_path):
        """Overwrite an existing file."""
        file_path = str(tmp_path / "test.txt")
        with open(file_path, "w") as f:
            f.write("Old content")
        atomic_write_file(file_path, "New content")
        with open(file_path) as f:
            assert f.read() == "New content"

    def test_preserves_directory(self, tmp_path):
        """Writing into a missing directory fails and does not create it."""
        missing_dir = tmp_path / "subdir"
        file_path = str(missing_dir / "test.txt")
        # IOError has been an alias of OSError since Python 3.3; use the
        # canonical name (same exception class, so same behaviour).
        with pytest.raises(OSError):
            atomic_write_file(file_path, "content")
        # The docstring's claim is only tested if we check it explicitly:
        # the failed write must not leave the parent directory behind.
        assert not missing_dir.exists()

    def test_unicode_content(self, tmp_path):
        """Unicode content is properly written."""
        file_path = str(tmp_path / "unicode.txt")
        content = "Hello, \u4e16\u754c!"  # "Hello, World!" in Chinese
        atomic_write_file(file_path, content)
        with open(file_path, encoding="utf-8") as f:
            assert f.read() == content

    def test_multiline_content(self, tmp_path):
        """Multi-line content is properly written."""
        file_path = str(tmp_path / "multiline.txt")
        content = "line1\nline2\nline3"
        atomic_write_file(file_path, content)
        with open(file_path) as f:
            assert f.read() == content
class TestGetMapContents:
    """Exercise get_map_contents() against the patched map files."""

    def test_read_map_file(self, patch_config_paths):
        """Lines in domains.map come back as (domain, backend) pairs."""
        with open(patch_config_paths["map_file"], "w") as map_fh:
            map_fh.write("example.com pool_1\napi.example.com pool_2\n")
        result = get_map_contents()
        for expected in (("example.com", "pool_1"), ("api.example.com", "pool_2")):
            assert expected in result

    def test_read_both_map_files(self, patch_config_paths):
        """Entries from domains.map and wildcards.map are merged."""
        exact_path = patch_config_paths["map_file"]
        wildcard_path = patch_config_paths["wildcards_file"]
        with open(exact_path, "w") as fh:
            fh.write("example.com pool_1\n")
        with open(wildcard_path, "w") as fh:
            fh.write(".example.com pool_1\n")
        result = get_map_contents()
        assert ("example.com", "pool_1") in result
        assert (".example.com", "pool_1") in result

    def test_skip_comments(self, patch_config_paths):
        """Lines starting with '#' are ignored."""
        raw_lines = ["# This is a comment", "example.com pool_1", "# Another comment"]
        with open(patch_config_paths["map_file"], "w") as fh:
            fh.write("".join(line + "\n" for line in raw_lines))
        result = get_map_contents()
        assert len(result) == 1
        assert result[0] == ("example.com", "pool_1")

    def test_skip_empty_lines(self, patch_config_paths):
        """Blank lines are ignored."""
        with open(patch_config_paths["map_file"], "w") as fh:
            fh.writelines(["\n", "example.com pool_1\n", "\n", "api.example.com pool_2\n"])
        assert len(get_map_contents()) == 2

    def test_file_not_found(self, patch_config_paths):
        """With both map files absent the result is an empty list."""
        for key in ("map_file", "wildcards_file"):
            os.unlink(patch_config_paths[key])
        assert get_map_contents() == []
class TestSplitDomainEntries:
    """Tests for split_domain_entries function."""

    def test_split_entries(self):
        """Split entries into exact and wildcard (leading-dot) lists."""
        entries = [
            ("example.com", "pool_1"),
            (".example.com", "pool_1"),
            ("api.example.com", "pool_2"),
            (".api.example.com", "pool_2"),
        ]
        exact, wildcards = split_domain_entries(entries)
        assert len(exact) == 2
        assert len(wildcards) == 2
        # Check every entry's routing, not just one per list; compare as sets
        # so the test does not depend on output ordering.
        assert set(exact) == {
            ("example.com", "pool_1"),
            ("api.example.com", "pool_2"),
        }
        assert set(wildcards) == {
            (".example.com", "pool_1"),
            (".api.example.com", "pool_2"),
        }

    def test_empty_entries(self):
        """Empty entries returns empty lists."""
        exact, wildcards = split_domain_entries([])
        assert exact == []
        assert wildcards == []

    def test_all_exact(self):
        """All exact entries."""
        entries = [
            ("example.com", "pool_1"),
            ("api.example.com", "pool_2"),
        ]
        exact, wildcards = split_domain_entries(entries)
        assert len(exact) == 2
        assert len(wildcards) == 0

    def test_all_wildcards(self):
        """All wildcard entries (mirror of test_all_exact)."""
        entries = [
            (".example.com", "pool_1"),
            (".api.example.com", "pool_2"),
        ]
        exact, wildcards = split_domain_entries(entries)
        assert len(exact) == 0
        assert len(wildcards) == 2
class TestSaveMapFile:
    """Tests for save_map_file function."""

    @staticmethod
    def _read_entries(path):
        """Return the stripped, non-blank, non-comment lines of a map file."""
        with open(path) as f:
            return [
                line.strip()
                for line in f
                if line.strip() and not line.lstrip().startswith("#")
            ]

    def test_save_entries(self, patch_config_paths):
        """Exact and wildcard entries land in their own map files."""
        entries = [
            ("example.com", "pool_1"),
            (".example.com", "pool_1"),
        ]
        save_map_file(entries)
        # Compare whole lines: a substring check for "example.com pool_1"
        # also matches ".example.com pool_1", so it could not detect an entry
        # written to the wrong file.
        assert self._read_entries(patch_config_paths["map_file"]) == [
            "example.com pool_1",
        ]
        assert self._read_entries(patch_config_paths["wildcards_file"]) == [
            ".example.com pool_1",
        ]

    def test_sorted_output(self, patch_config_paths):
        """Entries are sorted in output."""
        entries = [
            ("z.example.com", "pool_3"),
            ("a.example.com", "pool_1"),
            ("m.example.com", "pool_2"),
        ]
        save_map_file(entries)
        assert self._read_entries(patch_config_paths["map_file"]) == [
            "a.example.com pool_1",
            "m.example.com pool_2",
            "z.example.com pool_3",
        ]
class TestGetDomainBackend:
    """Lookups performed by get_domain_backend()."""

    @staticmethod
    def _seed_map(paths):
        """Write a single example.com -> pool_1 mapping into domains.map."""
        with open(paths["map_file"], "w") as fh:
            fh.write("example.com pool_1\n")

    def test_find_existing_domain(self, patch_config_paths):
        """A mapped domain resolves to its backend."""
        self._seed_map(patch_config_paths)
        assert get_domain_backend("example.com") == "pool_1"

    def test_domain_not_found(self, patch_config_paths):
        """An unmapped domain resolves to None."""
        self._seed_map(patch_config_paths)
        assert get_domain_backend("other.com") is None
class TestIsLegacyBackend:
    """Classification performed by is_legacy_backend()."""

    def test_pool_backend(self):
        """pool_<N> backend names are not legacy."""
        for name in ("pool_1", "pool_100"):
            assert is_legacy_backend(name) is False

    def test_legacy_backend(self):
        """Non-pool backend names are legacy."""
        for name in ("api_example_com_backend", "static_backend"):
            assert is_legacy_backend(name) is True
class TestGetLegacyBackendName:
    """Name derivation performed by get_legacy_backend_name()."""

    def test_convert_domain(self):
        """Dots become underscores and a "_backend" suffix is appended."""
        expected = "api_example_com_backend"
        assert get_legacy_backend_name("api.example.com") == expected
class TestGetBackendAndPrefix:
    """(backend, prefix) resolution performed by get_backend_and_prefix()."""

    def test_pool_backend(self, patch_config_paths):
        """A pool-mapped domain uses the pool name for backend and prefix."""
        with open(patch_config_paths["map_file"], "w") as fh:
            fh.write("example.com pool_5\n")
        backend, prefix = get_backend_and_prefix("example.com")
        assert (backend, prefix) == ("pool_5", "pool_5")

    def test_unknown_domain_uses_legacy(self, patch_config_paths):
        """An unmapped domain falls back to legacy naming."""
        backend, prefix = get_backend_and_prefix("unknown.com")
        assert (backend, prefix) == ("unknown_com_backend", "unknown_com")
class TestLoadServersConfig:
    """Reading the servers JSON via load_servers_config()."""

    def test_load_existing_config(self, patch_config_paths, sample_servers_config):
        """A valid file is parsed into a nested dict keyed by domain and slot."""
        with open(patch_config_paths["servers_file"], "w") as fh:
            fh.write(json.dumps(sample_servers_config))
        loaded = load_servers_config()
        assert "example.com" in loaded
        slot = loaded["example.com"]["1"]
        assert slot["ip"] == "10.0.0.1"

    def test_file_not_found(self, patch_config_paths):
        """An absent file yields an empty dict."""
        os.unlink(patch_config_paths["servers_file"])
        assert load_servers_config() == {}

    def test_invalid_json(self, patch_config_paths):
        """Unparseable content yields an empty dict."""
        with open(patch_config_paths["servers_file"], "w") as fh:
            fh.write("not valid json {{{")
        assert load_servers_config() == {}
class TestSaveServersConfig:
    """Persisting the servers JSON via save_servers_config()."""

    def test_save_config(self, patch_config_paths):
        """The saved file round-trips to the same dict."""
        expected = {"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}
        save_servers_config(expected)
        with open(patch_config_paths["servers_file"]) as fh:
            on_disk = json.loads(fh.read())
        assert on_disk == expected
class TestAddServerToConfig:
    """add_server_to_config() observed through load_servers_config().

    Slots are passed as ints but read back under string keys ("1", "2").
    """

    def test_add_to_empty_config(self, patch_config_paths):
        """The first server for a domain creates the domain entry."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        entry = load_servers_config()["example.com"]["1"]
        assert entry["ip"] == "10.0.0.1"
        assert entry["http_port"] == 80

    def test_add_to_existing_domain(self, patch_config_paths):
        """A second slot is stored alongside the first."""
        for slot, ip in ((1, "10.0.0.1"), (2, "10.0.0.2")):
            add_server_to_config("example.com", slot, ip, 80)
        servers = load_servers_config()["example.com"]
        assert "1" in servers
        assert "2" in servers

    def test_overwrite_existing_slot(self, patch_config_paths):
        """Re-adding an occupied slot replaces its ip and port."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        add_server_to_config("example.com", 1, "10.0.0.99", 8080)
        entry = load_servers_config()["example.com"]["1"]
        assert entry["ip"] == "10.0.0.99"
        assert entry["http_port"] == 8080
class TestRemoveServerFromConfig:
    """Slot removal performed by remove_server_from_config()."""

    def test_remove_existing_server(self, patch_config_paths):
        """Only the targeted slot disappears."""
        for slot, ip in ((1, "10.0.0.1"), (2, "10.0.0.2")):
            add_server_to_config("example.com", slot, ip, 80)
        remove_server_from_config("example.com", 1)
        remaining = load_servers_config()["example.com"]
        assert "1" not in remaining
        assert "2" in remaining

    def test_remove_last_server_removes_domain(self, patch_config_paths):
        """Emptying a domain drops the domain key entirely."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        remove_server_from_config("example.com", 1)
        assert "example.com" not in load_servers_config()

    def test_remove_nonexistent_server(self, patch_config_paths):
        """Removing an unknown slot leaves existing slots untouched."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        remove_server_from_config("example.com", 99)  # slot 99 was never added
        assert "1" in load_servers_config()["example.com"]
class TestRemoveDomainFromConfig:
    """Domain removal performed by remove_domain_from_config()."""

    def test_remove_existing_domain(self, patch_config_paths):
        """Only the targeted domain disappears."""
        for domain, ip in (("example.com", "10.0.0.1"), ("other.com", "10.0.0.2")):
            add_server_to_config(domain, 1, ip, 80)
        remove_domain_from_config("example.com")
        domains = load_servers_config()
        assert "example.com" not in domains
        assert "other.com" in domains

    def test_remove_nonexistent_domain(self, patch_config_paths):
        """Removing an unknown domain leaves the config untouched."""
        add_server_to_config("example.com", 1, "10.0.0.1", 80)
        remove_domain_from_config("other.com")  # never added
        assert "example.com" in load_servers_config()
class TestLoadCertsConfig:
    """Reading the certificate domain list via load_certs_config()."""

    def test_load_existing_config(self, patch_config_paths):
        """Every domain listed under "domains" is returned."""
        payload = {"domains": ["example.com", "other.com"]}
        with open(patch_config_paths["certs_file"], "w") as fh:
            fh.write(json.dumps(payload))
        loaded = load_certs_config()
        for domain in payload["domains"]:
            assert domain in loaded

    def test_file_not_found(self, patch_config_paths):
        """An absent file yields an empty list."""
        os.unlink(patch_config_paths["certs_file"])
        assert load_certs_config() == []
class TestSaveCertsConfig:
    """Persisting the certificate domain list via save_certs_config()."""

    def test_save_domains(self, patch_config_paths):
        """Domains are written under "domains" in sorted order."""
        save_certs_config(["z.com", "a.com"])
        with open(patch_config_paths["certs_file"]) as fh:
            stored = json.loads(fh.read())["domains"]
        assert stored == ["a.com", "z.com"]
class TestAddCertToConfig:
    """Adding certificate domains via add_cert_to_config()."""

    def test_add_new_cert(self, patch_config_paths):
        """A new domain shows up in the loaded list."""
        add_cert_to_config("example.com")
        assert "example.com" in load_certs_config()

    def test_add_duplicate_cert(self, patch_config_paths):
        """Adding the same domain twice stores it only once."""
        for _ in range(2):
            add_cert_to_config("example.com")
        assert load_certs_config().count("example.com") == 1
class TestRemoveCertFromConfig:
    """Removing certificate domains via remove_cert_from_config()."""

    def test_remove_existing_cert(self, patch_config_paths):
        """Only the targeted domain disappears."""
        for domain in ("example.com", "other.com"):
            add_cert_to_config(domain)
        remove_cert_from_config("example.com")
        remaining = load_certs_config()
        assert "example.com" not in remaining
        assert "other.com" in remaining

    def test_remove_nonexistent_cert(self, patch_config_paths):
        """Removing an unknown domain leaves the list untouched."""
        add_cert_to_config("example.com")
        remove_cert_from_config("other.com")  # never added
        assert "example.com" in load_certs_config()