refactor: Improve code quality, error handling, and test coverage
- Add file_lock context manager to eliminate duplicate locking patterns - Add ValidationError, ConfigurationError, CertificateError exceptions - Improve rollback logic in haproxy_add_servers (track successful ops only) - Decompose haproxy_add_domain into smaller helper functions - Consolidate certificate constants (CERTS_DIR, ACME_HOME) to config.py - Enhance docstrings for internal functions and magic numbers - Add pytest framework with 48 new tests (269 -> 317 total) - Increase test coverage from 76% to 86% - servers.py: 58% -> 82% - certificates.py: 67% -> 86% - configuration.py: 69% -> 94% Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
345
tests/conftest.py
Normal file
345
tests/conftest.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""Shared pytest fixtures for HAProxy MCP Server tests."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Add the parent directory to sys.path for imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
class MockSocket:
|
||||
"""Mock socket for testing HAProxy client communication."""
|
||||
|
||||
def __init__(self, responses: dict[str, str] | None = None, default_response: str = ""):
|
||||
"""Initialize mock socket.
|
||||
|
||||
Args:
|
||||
responses: Dict mapping command prefixes to responses
|
||||
default_response: Response for commands not in responses dict
|
||||
"""
|
||||
self.responses = responses or {}
|
||||
self.default_response = default_response
|
||||
self.sent_commands: list[str] = []
|
||||
self._closed = False
|
||||
self._response_buffer = b""
|
||||
|
||||
def connect(self, address: tuple[str, int]) -> None:
|
||||
"""Mock connect - does nothing."""
|
||||
pass
|
||||
|
||||
def settimeout(self, timeout: float) -> None:
|
||||
"""Mock settimeout - does nothing."""
|
||||
pass
|
||||
|
||||
def setblocking(self, blocking: bool) -> None:
|
||||
"""Mock setblocking - does nothing."""
|
||||
pass
|
||||
|
||||
def sendall(self, data: bytes) -> None:
|
||||
"""Mock sendall - stores sent command."""
|
||||
command = data.decode().strip()
|
||||
self.sent_commands.append(command)
|
||||
# Prepare response for this command
|
||||
response = self.default_response
|
||||
for prefix, resp in self.responses.items():
|
||||
if command.startswith(prefix):
|
||||
response = resp
|
||||
break
|
||||
self._response_buffer = response.encode()
|
||||
|
||||
def shutdown(self, how: int) -> None:
|
||||
"""Mock shutdown - does nothing."""
|
||||
pass
|
||||
|
||||
def recv(self, bufsize: int) -> bytes:
|
||||
"""Mock recv - returns prepared response."""
|
||||
if self._response_buffer:
|
||||
data = self._response_buffer[:bufsize]
|
||||
self._response_buffer = self._response_buffer[bufsize:]
|
||||
return data
|
||||
return b""
|
||||
|
||||
def close(self) -> None:
|
||||
"""Mock close."""
|
||||
self._closed = True
|
||||
|
||||
def fileno(self) -> int:
|
||||
"""Mock fileno for select."""
|
||||
return 999
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.close()
|
||||
|
||||
|
||||
class HAProxyResponseBuilder:
|
||||
"""Helper class to build HAProxy-style responses for tests."""
|
||||
|
||||
@staticmethod
|
||||
def servers_state(servers: list[dict[str, Any]]) -> str:
|
||||
"""Build a 'show servers state' response.
|
||||
|
||||
Args:
|
||||
servers: List of server dicts with keys:
|
||||
- be_id: Backend ID (int)
|
||||
- be_name: Backend name
|
||||
- srv_id: Server ID (int)
|
||||
- srv_name: Server name
|
||||
- srv_addr: Server address (IP)
|
||||
- srv_op_state: Operational state (int)
|
||||
- srv_admin_state: Admin state (int)
|
||||
- srv_port: Server port
|
||||
|
||||
Returns:
|
||||
HAProxy-formatted server state string
|
||||
"""
|
||||
lines = ["1"] # Version line
|
||||
lines.append("# be_id be_name srv_id srv_name srv_addr srv_op_state srv_admin_state srv_uweight srv_iweight srv_time_since_last_change srv_check_status srv_check_result srv_check_health srv_check_state srv_agent_state bk_f_forced_id srv_f_forced_id srv_fqdn srv_port srvrecord")
|
||||
for srv in servers:
|
||||
# Fill in defaults for minimal state line (need 19+ columns)
|
||||
line_parts = [
|
||||
str(srv.get("be_id", 1)),
|
||||
srv.get("be_name", "pool_1"),
|
||||
str(srv.get("srv_id", 1)),
|
||||
srv.get("srv_name", "pool_1_1"),
|
||||
srv.get("srv_addr", "0.0.0.0"),
|
||||
str(srv.get("srv_op_state", 2)),
|
||||
str(srv.get("srv_admin_state", 0)),
|
||||
"1", # srv_uweight
|
||||
"1", # srv_iweight
|
||||
"100", # srv_time_since_last_change
|
||||
"6", # srv_check_status
|
||||
"3", # srv_check_result
|
||||
"4", # srv_check_health
|
||||
"6", # srv_check_state
|
||||
"0", # srv_agent_state
|
||||
"0", # bk_f_forced_id
|
||||
"0", # srv_f_forced_id
|
||||
"-", # srv_fqdn
|
||||
str(srv.get("srv_port", 0)), # srv_port
|
||||
]
|
||||
lines.append(" ".join(line_parts))
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def stat_csv(entries: list[dict[str, Any]]) -> str:
|
||||
"""Build a 'show stat' CSV response.
|
||||
|
||||
Args:
|
||||
entries: List of stat dicts with keys:
|
||||
- pxname: Proxy name
|
||||
- svname: Server name (or FRONTEND/BACKEND)
|
||||
- scur: Current sessions (optional, default 0)
|
||||
- status: Status (UP/DOWN/MAINT)
|
||||
- weight: Weight (optional, default 1)
|
||||
- check_status: Check status (optional)
|
||||
|
||||
Returns:
|
||||
HAProxy-formatted CSV stat string
|
||||
"""
|
||||
lines = ["# pxname,svname,qcur,qmax,scur,smax,slim,stot,bin,bout,dreq,dresp,ereq,econ,eresp,wretr,wredis,status,weight,act,bck,chkfail,chkdown,lastchg,downtime,qlimit,pid,iid,sid,throttle,lbtot,tracked,type,rate,rate_lim,rate_max,check_status,check_code,check_duration,hrsp_1xx,hrsp_2xx,hrsp_3xx,hrsp_4xx,hrsp_5xx,hrsp_other,hanafail,req_rate,req_rate_max,req_tot,cli_abrt,srv_abrt,"]
|
||||
for entry in entries:
|
||||
# Build CSV row with proper field positions
|
||||
# Fields: pxname(0), svname(1), qcur(2), qmax(3), scur(4), smax(5), slim(6), ...
|
||||
# status(17), weight(18), ..., check_status(36)
|
||||
row = [""] * 50
|
||||
row[0] = entry.get("pxname", "pool_1")
|
||||
row[1] = entry.get("svname", "pool_1_1")
|
||||
row[4] = str(entry.get("scur", 0)) # SCUR
|
||||
row[5] = str(entry.get("smax", 0)) # SMAX
|
||||
row[17] = entry.get("status", "UP") # STATUS
|
||||
row[18] = str(entry.get("weight", 1)) # WEIGHT
|
||||
row[36] = entry.get("check_status", "L4OK") # CHECK_STATUS
|
||||
lines.append(",".join(row))
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def info(version: str = "3.3.2", uptime: int = 3600) -> str:
|
||||
"""Build a 'show info' response.
|
||||
|
||||
Args:
|
||||
version: HAProxy version string
|
||||
uptime: Uptime in seconds
|
||||
|
||||
Returns:
|
||||
HAProxy-formatted info string
|
||||
"""
|
||||
return f"""Name: HAProxy
|
||||
Version: {version}
|
||||
Release_date: 2024/01/01
|
||||
Nbthread: 4
|
||||
Nbproc: 1
|
||||
Process_num: 1
|
||||
Pid: 1
|
||||
Uptime: 1h0m0s
|
||||
Uptime_sec: {uptime}
|
||||
Memmax_MB: 0
|
||||
PoolAlloc_MB: 0
|
||||
PoolUsed_MB: 0
|
||||
PoolFailed: 0
|
||||
Ulimit-n: 200015
|
||||
Maxsock: 200015
|
||||
Maxconn: 100000
|
||||
Hard_maxconn: 100000
|
||||
CurrConns: 5
|
||||
CumConns: 1000
|
||||
CumReq: 5000"""
|
||||
|
||||
@staticmethod
|
||||
def map_show(entries: list[tuple[str, str]]) -> str:
|
||||
"""Build a 'show map' response.
|
||||
|
||||
Args:
|
||||
entries: List of (key, value) tuples
|
||||
|
||||
Returns:
|
||||
HAProxy-formatted map show string
|
||||
"""
|
||||
lines = []
|
||||
for i, (key, value) in enumerate(entries):
|
||||
lines.append(f"0x{i:08x} {key} {value}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_socket_class():
|
||||
"""Fixture that returns MockSocket class for custom configuration."""
|
||||
return MockSocket
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def response_builder():
|
||||
"""Fixture that returns HAProxyResponseBuilder class."""
|
||||
return HAProxyResponseBuilder
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_haproxy_socket(mock_socket_class, response_builder):
|
||||
"""Fixture providing a pre-configured mock socket with common responses."""
|
||||
responses = {
|
||||
"show info": response_builder.info(),
|
||||
"show servers state": response_builder.servers_state([]),
|
||||
"show stat": response_builder.stat_csv([]),
|
||||
"show map": response_builder.map_show([]),
|
||||
"show backend": "pool_1\npool_2\npool_3",
|
||||
"add map": "",
|
||||
"del map": "",
|
||||
"set server": "",
|
||||
}
|
||||
return mock_socket_class(responses=responses)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_config_dir(tmp_path):
|
||||
"""Fixture providing a temporary directory with config files."""
|
||||
# Create config files
|
||||
map_file = tmp_path / "domains.map"
|
||||
map_file.write_text("# Domain to Backend mapping\n")
|
||||
|
||||
wildcards_file = tmp_path / "wildcards.map"
|
||||
wildcards_file.write_text("# Wildcard Domain mapping\n")
|
||||
|
||||
servers_file = tmp_path / "servers.json"
|
||||
servers_file.write_text("{}")
|
||||
|
||||
certs_file = tmp_path / "certificates.json"
|
||||
certs_file.write_text('{"domains": []}')
|
||||
|
||||
state_file = tmp_path / "servers.state"
|
||||
state_file.write_text("")
|
||||
|
||||
return {
|
||||
"dir": tmp_path,
|
||||
"map_file": str(map_file),
|
||||
"wildcards_file": str(wildcards_file),
|
||||
"servers_file": str(servers_file),
|
||||
"certs_file": str(certs_file),
|
||||
"state_file": str(state_file),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_config_paths(temp_config_dir):
|
||||
"""Fixture that patches config module paths to use temporary directory."""
|
||||
with patch.multiple(
|
||||
"haproxy_mcp.config",
|
||||
MAP_FILE=temp_config_dir["map_file"],
|
||||
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
|
||||
SERVERS_FILE=temp_config_dir["servers_file"],
|
||||
CERTS_FILE=temp_config_dir["certs_file"],
|
||||
STATE_FILE=temp_config_dir["state_file"],
|
||||
):
|
||||
# Also patch file_ops module which imports these
|
||||
with patch.multiple(
|
||||
"haproxy_mcp.file_ops",
|
||||
MAP_FILE=temp_config_dir["map_file"],
|
||||
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
|
||||
SERVERS_FILE=temp_config_dir["servers_file"],
|
||||
CERTS_FILE=temp_config_dir["certs_file"],
|
||||
):
|
||||
yield temp_config_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_subprocess():
|
||||
"""Fixture that mocks subprocess.run for external command testing."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
# Default to successful command
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0,
|
||||
stdout="",
|
||||
stderr="",
|
||||
)
|
||||
yield mock_run
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_socket_module(mock_haproxy_socket):
|
||||
"""Fixture that patches socket module to use mock socket."""
|
||||
|
||||
def create_socket(*args, **kwargs):
|
||||
return mock_haproxy_socket
|
||||
|
||||
with patch("socket.socket", side_effect=create_socket):
|
||||
yield mock_haproxy_socket
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_select():
|
||||
"""Fixture that patches select.select for socket recv loops."""
|
||||
with patch("select.select") as mock_sel:
|
||||
# Default: socket is ready immediately
|
||||
mock_sel.return_value = ([True], [], [])
|
||||
yield mock_sel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_servers_config():
|
||||
"""Sample servers.json content for testing."""
|
||||
return {
|
||||
"example.com": {
|
||||
"1": {"ip": "10.0.0.1", "http_port": 80},
|
||||
"2": {"ip": "10.0.0.2", "http_port": 80},
|
||||
},
|
||||
"api.example.com": {
|
||||
"1": {"ip": "10.0.0.10", "http_port": 8080},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_map_entries():
|
||||
"""Sample domains.map entries for testing."""
|
||||
return [
|
||||
("example.com", "pool_1"),
|
||||
(".example.com", "pool_1"),
|
||||
("api.example.com", "pool_2"),
|
||||
(".api.example.com", "pool_2"),
|
||||
]
|
||||
Reference in New Issue
Block a user