"""Shared pytest fixtures for HAProxy MCP Server tests.""" import json import os import sys from typing import Any from unittest.mock import MagicMock, patch import pytest # Add the parent directory to sys.path for imports sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) class MockSocket: """Mock socket for testing HAProxy client communication.""" def __init__(self, responses: dict[str, str] | None = None, default_response: str = ""): """Initialize mock socket. Args: responses: Dict mapping command prefixes to responses default_response: Response for commands not in responses dict """ self.responses = responses or {} self.default_response = default_response self.sent_commands: list[str] = [] self._closed = False self._response_buffer = b"" def connect(self, address: tuple[str, int]) -> None: """Mock connect - does nothing.""" pass def settimeout(self, timeout: float) -> None: """Mock settimeout - does nothing.""" pass def setblocking(self, blocking: bool) -> None: """Mock setblocking - does nothing.""" pass def sendall(self, data: bytes) -> None: """Mock sendall - stores sent command.""" command = data.decode().strip() self.sent_commands.append(command) # Prepare response for this command response = self.default_response for prefix, resp in self.responses.items(): if command.startswith(prefix): response = resp break self._response_buffer = response.encode() def shutdown(self, how: int) -> None: """Mock shutdown - does nothing.""" pass def recv(self, bufsize: int) -> bytes: """Mock recv - returns prepared response.""" if self._response_buffer: data = self._response_buffer[:bufsize] self._response_buffer = self._response_buffer[bufsize:] return data return b"" def close(self) -> None: """Mock close.""" self._closed = True def fileno(self) -> int: """Mock fileno for select.""" return 999 def __enter__(self): return self def __exit__(self, *args): self.close() class HAProxyResponseBuilder: """Helper class to build HAProxy-style responses for tests.""" @staticmethod def servers_state(servers: list[dict[str, Any]]) -> str: """Build a 'show servers state' response. Args: servers: List of server dicts with keys: - be_id: Backend ID (int) - be_name: Backend name - srv_id: Server ID (int) - srv_name: Server name - srv_addr: Server address (IP) - srv_op_state: Operational state (int) - srv_admin_state: Admin state (int) - srv_port: Server port Returns: HAProxy-formatted server state string """ lines = ["1"] # Version line lines.append("# be_id be_name srv_id srv_name srv_addr srv_op_state srv_admin_state srv_uweight srv_iweight srv_time_since_last_change srv_check_status srv_check_result srv_check_health srv_check_state srv_agent_state bk_f_forced_id srv_f_forced_id srv_fqdn srv_port srvrecord") for srv in servers: # Fill in defaults for minimal state line (need 19+ columns) line_parts = [ str(srv.get("be_id", 1)), srv.get("be_name", "pool_1"), str(srv.get("srv_id", 1)), srv.get("srv_name", "pool_1_1"), srv.get("srv_addr", "0.0.0.0"), str(srv.get("srv_op_state", 2)), str(srv.get("srv_admin_state", 0)), "1", # srv_uweight "1", # srv_iweight "100", # srv_time_since_last_change "6", # srv_check_status "3", # srv_check_result "4", # srv_check_health "6", # srv_check_state "0", # srv_agent_state "0", # bk_f_forced_id "0", # srv_f_forced_id "-", # srv_fqdn str(srv.get("srv_port", 0)), # srv_port ] lines.append(" ".join(line_parts)) return "\n".join(lines) @staticmethod def stat_csv(entries: list[dict[str, Any]]) -> str: """Build a 'show stat' CSV response. Args: entries: List of stat dicts with keys: - pxname: Proxy name - svname: Server name (or FRONTEND/BACKEND) - scur: Current sessions (optional, default 0) - status: Status (UP/DOWN/MAINT) - weight: Weight (optional, default 1) - check_status: Check status (optional) Returns: HAProxy-formatted CSV stat string """ lines = ["# pxname,svname,qcur,qmax,scur,smax,slim,stot,bin,bout,dreq,dresp,ereq,econ,eresp,wretr,wredis,status,weight,act,bck,chkfail,chkdown,lastchg,downtime,qlimit,pid,iid,sid,throttle,lbtot,tracked,type,rate,rate_lim,rate_max,check_status,check_code,check_duration,hrsp_1xx,hrsp_2xx,hrsp_3xx,hrsp_4xx,hrsp_5xx,hrsp_other,hanafail,req_rate,req_rate_max,req_tot,cli_abrt,srv_abrt,"] for entry in entries: # Build CSV row with proper field positions # Fields: pxname(0), svname(1), qcur(2), qmax(3), scur(4), smax(5), slim(6), ... # status(17), weight(18), ..., check_status(36) row = [""] * 50 row[0] = entry.get("pxname", "pool_1") row[1] = entry.get("svname", "pool_1_1") row[4] = str(entry.get("scur", 0)) # SCUR row[5] = str(entry.get("smax", 0)) # SMAX row[17] = entry.get("status", "UP") # STATUS row[18] = str(entry.get("weight", 1)) # WEIGHT row[36] = entry.get("check_status", "L4OK") # CHECK_STATUS lines.append(",".join(row)) return "\n".join(lines) @staticmethod def info(version: str = "3.3.2", uptime: int = 3600) -> str: """Build a 'show info' response. Args: version: HAProxy version string uptime: Uptime in seconds Returns: HAProxy-formatted info string """ return f"""Name: HAProxy Version: {version} Release_date: 2024/01/01 Nbthread: 4 Nbproc: 1 Process_num: 1 Pid: 1 Uptime: 1h0m0s Uptime_sec: {uptime} Memmax_MB: 0 PoolAlloc_MB: 0 PoolUsed_MB: 0 PoolFailed: 0 Ulimit-n: 200015 Maxsock: 200015 Maxconn: 100000 Hard_maxconn: 100000 CurrConns: 5 CumConns: 1000 CumReq: 5000""" @staticmethod def map_show(entries: list[tuple[str, str]]) -> str: """Build a 'show map' response. Args: entries: List of (key, value) tuples Returns: HAProxy-formatted map show string """ lines = [] for i, (key, value) in enumerate(entries): lines.append(f"0x{i:08x} {key} {value}") return "\n".join(lines) @pytest.fixture def mock_socket_class(): """Fixture that returns MockSocket class for custom configuration.""" return MockSocket @pytest.fixture def response_builder(): """Fixture that returns HAProxyResponseBuilder class.""" return HAProxyResponseBuilder @pytest.fixture def mock_haproxy_socket(mock_socket_class, response_builder): """Fixture providing a pre-configured mock socket with common responses.""" responses = { "show info": response_builder.info(), "show servers state": response_builder.servers_state([]), "show stat": response_builder.stat_csv([]), "show map": response_builder.map_show([]), "show backend": "pool_1\npool_2\npool_3", "add map": "", "del map": "", "set server": "", } return mock_socket_class(responses=responses) @pytest.fixture def temp_config_dir(tmp_path): """Fixture providing a temporary directory with config files.""" # Create config files map_file = tmp_path / "domains.map" map_file.write_text("# Domain to Backend mapping\n") wildcards_file = tmp_path / "wildcards.map" wildcards_file.write_text("# Wildcard Domain mapping\n") servers_file = tmp_path / "servers.json" servers_file.write_text("{}") certs_file = tmp_path / "certificates.json" certs_file.write_text('{"domains": []}') state_file = tmp_path / "servers.state" state_file.write_text("") db_file = tmp_path / "haproxy_mcp.db" return { "dir": tmp_path, "map_file": str(map_file), "wildcards_file": str(wildcards_file), "servers_file": str(servers_file), "certs_file": str(certs_file), "state_file": str(state_file), "db_file": str(db_file), } @pytest.fixture def patch_config_paths(temp_config_dir): """Fixture that patches config module paths to use temporary directory.""" from haproxy_mcp.db import close_connection, init_db with patch.multiple( "haproxy_mcp.config", MAP_FILE=temp_config_dir["map_file"], WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"], SERVERS_FILE=temp_config_dir["servers_file"], CERTS_FILE=temp_config_dir["certs_file"], STATE_FILE=temp_config_dir["state_file"], DB_FILE=temp_config_dir["db_file"], ): # Also patch file_ops module which imports these with patch.multiple( "haproxy_mcp.file_ops", MAP_FILE=temp_config_dir["map_file"], WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"], ): # Patch db module which imports these with patch.multiple( "haproxy_mcp.db", MAP_FILE=temp_config_dir["map_file"], WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"], SERVERS_FILE=temp_config_dir["servers_file"], CERTS_FILE=temp_config_dir["certs_file"], DB_FILE=temp_config_dir["db_file"], ): # Patch health module which imports MAP_FILE and DB_FILE with patch.multiple( "haproxy_mcp.tools.health", MAP_FILE=temp_config_dir["map_file"], DB_FILE=temp_config_dir["db_file"], ): # Close any existing connection and initialize fresh DB close_connection() init_db() yield temp_config_dir close_connection() @pytest.fixture def mock_subprocess(): """Fixture that mocks subprocess.run for external command testing.""" with patch("subprocess.run") as mock_run: # Default to successful command mock_run.return_value = MagicMock( returncode=0, stdout="", stderr="", ) yield mock_run @pytest.fixture def mock_socket_module(mock_haproxy_socket): """Fixture that patches socket module to use mock socket.""" def create_socket(*args, **kwargs): return mock_haproxy_socket with patch("socket.socket", side_effect=create_socket): yield mock_haproxy_socket @pytest.fixture def mock_select(): """Fixture that patches select.select for socket recv loops.""" with patch("select.select") as mock_sel: # Default: socket is ready immediately mock_sel.return_value = ([True], [], []) yield mock_sel @pytest.fixture def sample_servers_config(): """Sample servers.json content for testing.""" return { "example.com": { "1": {"ip": "10.0.0.1", "http_port": 80}, "2": {"ip": "10.0.0.2", "http_port": 80}, }, "api.example.com": { "1": {"ip": "10.0.0.10", "http_port": 8080}, }, } @pytest.fixture def sample_map_entries(): """Sample domains.map entries for testing.""" return [ ("example.com", "pool_1"), (".example.com", "pool_1"), ("api.example.com", "pool_2"), (".api.example.com", "pool_2"), ]