refactor: Improve code quality, error handling, and test coverage

- Add file_lock context manager to eliminate duplicate locking patterns
- Add ValidationError, ConfigurationError, CertificateError exceptions
- Improve rollback logic in haproxy_add_servers (track successful ops only)
- Decompose haproxy_add_domain into smaller helper functions
- Consolidate certificate constants (CERTS_DIR, ACME_HOME) to config.py
- Enhance docstrings for internal functions and magic numbers
- Add pytest framework with 48 new tests (269 -> 317 total)
- Increase test coverage from 76% to 86%
  - servers.py: 58% -> 82%
  - certificates.py: 67% -> 86%
  - configuration.py: 69% -> 94%

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
kaffa
2026-02-03 12:50:00 +09:00
parent 18ce812920
commit 6bcfee519c
25 changed files with 6852 additions and 125 deletions

345
tests/conftest.py Normal file
View File

@@ -0,0 +1,345 @@
"""Shared pytest fixtures for HAProxy MCP Server tests."""
import json
import os
import sys
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
# Add the parent directory to sys.path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class MockSocket:
"""Mock socket for testing HAProxy client communication."""
def __init__(self, responses: dict[str, str] | None = None, default_response: str = ""):
"""Initialize mock socket.
Args:
responses: Dict mapping command prefixes to responses
default_response: Response for commands not in responses dict
"""
self.responses = responses or {}
self.default_response = default_response
self.sent_commands: list[str] = []
self._closed = False
self._response_buffer = b""
def connect(self, address: tuple[str, int]) -> None:
"""Mock connect - does nothing."""
pass
def settimeout(self, timeout: float) -> None:
"""Mock settimeout - does nothing."""
pass
def setblocking(self, blocking: bool) -> None:
"""Mock setblocking - does nothing."""
pass
def sendall(self, data: bytes) -> None:
"""Mock sendall - stores sent command."""
command = data.decode().strip()
self.sent_commands.append(command)
# Prepare response for this command
response = self.default_response
for prefix, resp in self.responses.items():
if command.startswith(prefix):
response = resp
break
self._response_buffer = response.encode()
def shutdown(self, how: int) -> None:
"""Mock shutdown - does nothing."""
pass
def recv(self, bufsize: int) -> bytes:
"""Mock recv - returns prepared response."""
if self._response_buffer:
data = self._response_buffer[:bufsize]
self._response_buffer = self._response_buffer[bufsize:]
return data
return b""
def close(self) -> None:
"""Mock close."""
self._closed = True
def fileno(self) -> int:
"""Mock fileno for select."""
return 999
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
class HAProxyResponseBuilder:
"""Helper class to build HAProxy-style responses for tests."""
@staticmethod
def servers_state(servers: list[dict[str, Any]]) -> str:
"""Build a 'show servers state' response.
Args:
servers: List of server dicts with keys:
- be_id: Backend ID (int)
- be_name: Backend name
- srv_id: Server ID (int)
- srv_name: Server name
- srv_addr: Server address (IP)
- srv_op_state: Operational state (int)
- srv_admin_state: Admin state (int)
- srv_port: Server port
Returns:
HAProxy-formatted server state string
"""
lines = ["1"] # Version line
lines.append("# be_id be_name srv_id srv_name srv_addr srv_op_state srv_admin_state srv_uweight srv_iweight srv_time_since_last_change srv_check_status srv_check_result srv_check_health srv_check_state srv_agent_state bk_f_forced_id srv_f_forced_id srv_fqdn srv_port srvrecord")
for srv in servers:
# Fill in defaults for minimal state line (need 19+ columns)
line_parts = [
str(srv.get("be_id", 1)),
srv.get("be_name", "pool_1"),
str(srv.get("srv_id", 1)),
srv.get("srv_name", "pool_1_1"),
srv.get("srv_addr", "0.0.0.0"),
str(srv.get("srv_op_state", 2)),
str(srv.get("srv_admin_state", 0)),
"1", # srv_uweight
"1", # srv_iweight
"100", # srv_time_since_last_change
"6", # srv_check_status
"3", # srv_check_result
"4", # srv_check_health
"6", # srv_check_state
"0", # srv_agent_state
"0", # bk_f_forced_id
"0", # srv_f_forced_id
"-", # srv_fqdn
str(srv.get("srv_port", 0)), # srv_port
]
lines.append(" ".join(line_parts))
return "\n".join(lines)
@staticmethod
def stat_csv(entries: list[dict[str, Any]]) -> str:
"""Build a 'show stat' CSV response.
Args:
entries: List of stat dicts with keys:
- pxname: Proxy name
- svname: Server name (or FRONTEND/BACKEND)
- scur: Current sessions (optional, default 0)
- status: Status (UP/DOWN/MAINT)
- weight: Weight (optional, default 1)
- check_status: Check status (optional)
Returns:
HAProxy-formatted CSV stat string
"""
lines = ["# pxname,svname,qcur,qmax,scur,smax,slim,stot,bin,bout,dreq,dresp,ereq,econ,eresp,wretr,wredis,status,weight,act,bck,chkfail,chkdown,lastchg,downtime,qlimit,pid,iid,sid,throttle,lbtot,tracked,type,rate,rate_lim,rate_max,check_status,check_code,check_duration,hrsp_1xx,hrsp_2xx,hrsp_3xx,hrsp_4xx,hrsp_5xx,hrsp_other,hanafail,req_rate,req_rate_max,req_tot,cli_abrt,srv_abrt,"]
for entry in entries:
# Build CSV row with proper field positions
# Fields: pxname(0), svname(1), qcur(2), qmax(3), scur(4), smax(5), slim(6), ...
# status(17), weight(18), ..., check_status(36)
row = [""] * 50
row[0] = entry.get("pxname", "pool_1")
row[1] = entry.get("svname", "pool_1_1")
row[4] = str(entry.get("scur", 0)) # SCUR
row[5] = str(entry.get("smax", 0)) # SMAX
row[17] = entry.get("status", "UP") # STATUS
row[18] = str(entry.get("weight", 1)) # WEIGHT
row[36] = entry.get("check_status", "L4OK") # CHECK_STATUS
lines.append(",".join(row))
return "\n".join(lines)
@staticmethod
def info(version: str = "3.3.2", uptime: int = 3600) -> str:
"""Build a 'show info' response.
Args:
version: HAProxy version string
uptime: Uptime in seconds
Returns:
HAProxy-formatted info string
"""
return f"""Name: HAProxy
Version: {version}
Release_date: 2024/01/01
Nbthread: 4
Nbproc: 1
Process_num: 1
Pid: 1
Uptime: 1h0m0s
Uptime_sec: {uptime}
Memmax_MB: 0
PoolAlloc_MB: 0
PoolUsed_MB: 0
PoolFailed: 0
Ulimit-n: 200015
Maxsock: 200015
Maxconn: 100000
Hard_maxconn: 100000
CurrConns: 5
CumConns: 1000
CumReq: 5000"""
@staticmethod
def map_show(entries: list[tuple[str, str]]) -> str:
"""Build a 'show map' response.
Args:
entries: List of (key, value) tuples
Returns:
HAProxy-formatted map show string
"""
lines = []
for i, (key, value) in enumerate(entries):
lines.append(f"0x{i:08x} {key} {value}")
return "\n".join(lines)
@pytest.fixture
def mock_socket_class():
"""Fixture that returns MockSocket class for custom configuration."""
return MockSocket
@pytest.fixture
def response_builder():
"""Fixture that returns HAProxyResponseBuilder class."""
return HAProxyResponseBuilder
@pytest.fixture
def mock_haproxy_socket(mock_socket_class, response_builder):
"""Fixture providing a pre-configured mock socket with common responses."""
responses = {
"show info": response_builder.info(),
"show servers state": response_builder.servers_state([]),
"show stat": response_builder.stat_csv([]),
"show map": response_builder.map_show([]),
"show backend": "pool_1\npool_2\npool_3",
"add map": "",
"del map": "",
"set server": "",
}
return mock_socket_class(responses=responses)
@pytest.fixture
def temp_config_dir(tmp_path):
"""Fixture providing a temporary directory with config files."""
# Create config files
map_file = tmp_path / "domains.map"
map_file.write_text("# Domain to Backend mapping\n")
wildcards_file = tmp_path / "wildcards.map"
wildcards_file.write_text("# Wildcard Domain mapping\n")
servers_file = tmp_path / "servers.json"
servers_file.write_text("{}")
certs_file = tmp_path / "certificates.json"
certs_file.write_text('{"domains": []}')
state_file = tmp_path / "servers.state"
state_file.write_text("")
return {
"dir": tmp_path,
"map_file": str(map_file),
"wildcards_file": str(wildcards_file),
"servers_file": str(servers_file),
"certs_file": str(certs_file),
"state_file": str(state_file),
}
@pytest.fixture
def patch_config_paths(temp_config_dir):
"""Fixture that patches config module paths to use temporary directory."""
with patch.multiple(
"haproxy_mcp.config",
MAP_FILE=temp_config_dir["map_file"],
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
SERVERS_FILE=temp_config_dir["servers_file"],
CERTS_FILE=temp_config_dir["certs_file"],
STATE_FILE=temp_config_dir["state_file"],
):
# Also patch file_ops module which imports these
with patch.multiple(
"haproxy_mcp.file_ops",
MAP_FILE=temp_config_dir["map_file"],
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
SERVERS_FILE=temp_config_dir["servers_file"],
CERTS_FILE=temp_config_dir["certs_file"],
):
yield temp_config_dir
@pytest.fixture
def mock_subprocess():
"""Fixture that mocks subprocess.run for external command testing."""
with patch("subprocess.run") as mock_run:
# Default to successful command
mock_run.return_value = MagicMock(
returncode=0,
stdout="",
stderr="",
)
yield mock_run
@pytest.fixture
def mock_socket_module(mock_haproxy_socket):
"""Fixture that patches socket module to use mock socket."""
def create_socket(*args, **kwargs):
return mock_haproxy_socket
with patch("socket.socket", side_effect=create_socket):
yield mock_haproxy_socket
@pytest.fixture
def mock_select():
"""Fixture that patches select.select for socket recv loops."""
with patch("select.select") as mock_sel:
# Default: socket is ready immediately
mock_sel.return_value = ([True], [], [])
yield mock_sel
@pytest.fixture
def sample_servers_config():
"""Sample servers.json content for testing."""
return {
"example.com": {
"1": {"ip": "10.0.0.1", "http_port": 80},
"2": {"ip": "10.0.0.2", "http_port": 80},
},
"api.example.com": {
"1": {"ip": "10.0.0.10", "http_port": 8080},
},
}
@pytest.fixture
def sample_map_entries():
"""Sample domains.map entries for testing."""
return [
("example.com", "pool_1"),
(".example.com", "pool_1"),
("api.example.com", "pool_2"),
(".api.example.com", "pool_2"),
]