refactor: Improve code quality, error handling, and test coverage
- Add file_lock context manager to eliminate duplicate locking patterns - Add ValidationError, ConfigurationError, CertificateError exceptions - Improve rollback logic in haproxy_add_servers (track successful ops only) - Decompose haproxy_add_domain into smaller helper functions - Consolidate certificate constants (CERTS_DIR, ACME_HOME) to config.py - Enhance docstrings for internal functions and magic numbers - Add pytest framework with 48 new tests (269 -> 317 total) - Increase test coverage from 76% to 86% - servers.py: 58% -> 82% - certificates.py: 67% -> 86% - configuration.py: 69% -> 94% Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
476
tests/unit/tools/test_domains.py
Normal file
476
tests/unit/tools/test_domains.py
Normal file
@@ -0,0 +1,476 @@
|
||||
"""Unit tests for domain management tools."""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.exceptions import HaproxyError
|
||||
|
||||
|
||||
class TestHaproxyListDomains:
|
||||
"""Tests for haproxy_list_domains tool function."""
|
||||
|
||||
def test_list_empty_domains(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""List domains when none configured."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
# Import here to get patched config
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_list_domains"](include_wildcards=False)
|
||||
|
||||
assert result == "No domains configured"
|
||||
|
||||
def test_list_domains_with_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""List domains with configured servers."""
|
||||
# Write map file
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
{"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80},
|
||||
]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_list_domains"](include_wildcards=False)
|
||||
|
||||
assert "example.com" in result
|
||||
assert "pool_1" in result
|
||||
assert "10.0.0.1" in result
|
||||
|
||||
def test_list_domains_exclude_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""List domains excluding wildcards by default."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||
f.write(".example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_list_domains"](include_wildcards=False)
|
||||
|
||||
assert "example.com" in result
|
||||
assert ".example.com" not in result
|
||||
|
||||
def test_list_domains_include_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""List domains including wildcards when requested."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||
f.write(".example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_list_domains"](include_wildcards=True)
|
||||
|
||||
assert "example.com" in result
|
||||
assert ".example.com" in result
|
||||
|
||||
|
||||
class TestHaproxyAddDomain:
|
||||
"""Tests for haproxy_add_domain tool function."""
|
||||
|
||||
def test_add_domain_invalid_format(self, patch_config_paths):
|
||||
"""Reject invalid domain format."""
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_add_domain"](
|
||||
domain="-invalid.com",
|
||||
ip="",
|
||||
http_port=80
|
||||
)
|
||||
|
||||
assert "Error" in result
|
||||
assert "Invalid domain" in result
|
||||
|
||||
def test_add_domain_invalid_ip(self, patch_config_paths):
|
||||
"""Reject invalid IP address."""
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_add_domain"](
|
||||
domain="example.com",
|
||||
ip="not-an-ip",
|
||||
http_port=80
|
||||
)
|
||||
|
||||
assert "Error" in result
|
||||
assert "Invalid IP" in result
|
||||
|
||||
def test_add_domain_invalid_port(self, patch_config_paths):
|
||||
"""Reject invalid port."""
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_add_domain"](
|
||||
domain="example.com",
|
||||
ip="10.0.0.1",
|
||||
http_port=70000
|
||||
)
|
||||
|
||||
assert "Error" in result
|
||||
assert "Port" in result
|
||||
|
||||
def test_add_domain_starts_with_dot(self, patch_config_paths):
|
||||
"""Reject domain starting with dot (wildcard)."""
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_add_domain"](
|
||||
domain=".example.com",
|
||||
ip="",
|
||||
http_port=80
|
||||
)
|
||||
|
||||
assert "Error" in result
|
||||
assert "cannot start with '.'" in result
|
||||
|
||||
def test_add_domain_already_exists(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Reject adding domain that already exists."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_add_domain"](
|
||||
domain="example.com",
|
||||
ip="",
|
||||
http_port=80
|
||||
)
|
||||
|
||||
assert "Error" in result
|
||||
assert "already exists" in result
|
||||
|
||||
def test_add_domain_success_without_ip(self, mock_socket_class, mock_select, patch_config_paths, response_builder, mock_subprocess):
|
||||
"""Successfully add domain without IP."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"add map": "",
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_add_domain"](
|
||||
domain="newdomain.com",
|
||||
ip="",
|
||||
http_port=80
|
||||
)
|
||||
|
||||
assert "newdomain.com" in result
|
||||
assert "pool_1" in result
|
||||
assert "no servers configured" in result
|
||||
|
||||
def test_add_domain_success_with_ip(self, mock_socket_class, mock_select, patch_config_paths, response_builder, mock_subprocess):
|
||||
"""Successfully add domain with IP."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"add map": "",
|
||||
"set server": "",
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_add_domain"](
|
||||
domain="newdomain.com",
|
||||
ip="10.0.0.1",
|
||||
http_port=8080
|
||||
)
|
||||
|
||||
assert "newdomain.com" in result
|
||||
assert "pool_1" in result
|
||||
assert "10.0.0.1:8080" in result
|
||||
|
||||
|
||||
class TestHaproxyRemoveDomain:
|
||||
"""Tests for haproxy_remove_domain tool function."""
|
||||
|
||||
def test_remove_domain_invalid_format(self, patch_config_paths):
|
||||
"""Reject invalid domain format."""
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_remove_domain"](domain="-invalid.com")
|
||||
|
||||
assert "Error" in result
|
||||
assert "Invalid domain" in result
|
||||
|
||||
def test_remove_domain_not_found(self, patch_config_paths):
|
||||
"""Reject removing domain that doesn't exist."""
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_remove_domain"](domain="nonexistent.com")
|
||||
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
|
||||
def test_remove_legacy_domain_rejected(self, patch_config_paths):
|
||||
"""Reject removing legacy (non-pool) domain."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com legacy_backend\n")
|
||||
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_remove_domain"](domain="example.com")
|
||||
|
||||
assert "Error" in result
|
||||
assert "legacy" in result.lower()
|
||||
|
||||
def test_remove_domain_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Successfully remove domain."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||
f.write(".example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"del map": "",
|
||||
"set server": "",
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_remove_domain"](domain="example.com")
|
||||
|
||||
assert "example.com" in result
|
||||
assert "removed" in result.lower()
|
||||
|
||||
|
||||
class TestCheckCertificateCoverage:
|
||||
"""Tests for check_certificate_coverage function."""
|
||||
|
||||
def test_no_cert_directory(self, tmp_path):
|
||||
"""No certificate coverage when directory doesn't exist."""
|
||||
from haproxy_mcp.tools.domains import check_certificate_coverage
|
||||
|
||||
with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(tmp_path / "nonexistent")):
|
||||
covered, info = check_certificate_coverage("example.com")
|
||||
|
||||
assert covered is False
|
||||
assert "not found" in info.lower()
|
||||
|
||||
def test_exact_cert_match(self, tmp_path):
|
||||
"""Exact certificate match."""
|
||||
from haproxy_mcp.tools.domains import check_certificate_coverage
|
||||
|
||||
certs_dir = tmp_path / "certs"
|
||||
certs_dir.mkdir()
|
||||
(certs_dir / "example.com.pem").write_text("cert content")
|
||||
|
||||
with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(certs_dir)):
|
||||
covered, info = check_certificate_coverage("example.com")
|
||||
|
||||
assert covered is True
|
||||
assert info == "example.com"
|
||||
|
||||
def test_wildcard_cert_coverage(self, tmp_path, mock_subprocess):
|
||||
"""Wildcard certificate covers subdomain."""
|
||||
from haproxy_mcp.tools.domains import check_certificate_coverage
|
||||
|
||||
certs_dir = tmp_path / "certs"
|
||||
certs_dir.mkdir()
|
||||
(certs_dir / "example.com.pem").write_text("cert content")
|
||||
|
||||
# Mock openssl output showing wildcard SAN
|
||||
mock_subprocess.return_value = MagicMock(
|
||||
returncode=0,
|
||||
stdout="X509v3 Subject Alternative Name:\n DNS:example.com, DNS:*.example.com"
|
||||
)
|
||||
|
||||
with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(certs_dir)):
|
||||
covered, info = check_certificate_coverage("api.example.com")
|
||||
|
||||
assert covered is True
|
||||
assert "wildcard" in info
|
||||
|
||||
def test_no_matching_cert(self, tmp_path):
|
||||
"""No matching certificate."""
|
||||
from haproxy_mcp.tools.domains import check_certificate_coverage
|
||||
|
||||
certs_dir = tmp_path / "certs"
|
||||
certs_dir.mkdir()
|
||||
|
||||
with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(certs_dir)):
|
||||
covered, info = check_certificate_coverage("example.com")
|
||||
|
||||
assert covered is False
|
||||
assert "No matching" in info
|
||||
Reference in New Issue
Block a user