Files
haproxy-mcp/tests/conftest.py
kappa 12fd3b5e8f Store SQLite DB on remote host via SCP for persistence
Instead of syncing JSON files back, the SQLite DB itself is now
the persistent store on the remote HAProxy host:
- Startup: download remote DB via SCP (skip migration if exists)
- After writes: upload local DB via SCP (WAL checkpoint first)
- JSON sync removed (sync_servers_json, sync_certs_json deleted)

New functions:
- ssh_ops: remote_download_file(), remote_upload_file() via SCP
- db: sync_db_to_remote(), _try_download_remote_db()

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 11:46:36 +09:00

370 lines
12 KiB
Python

"""Shared pytest fixtures for HAProxy MCP Server tests."""
import json
import os
import sys
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
# Add the parent directory to sys.path for imports.
# Two dirname() hops from this file's absolute path = the directory that
# contains this tests/ folder; prepending it (index 0) makes it win over any
# installed copy, presumably so `haproxy_mcp` (imported by fixtures in this
# file) resolves to the working-tree source — confirm against project layout.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class MockSocket:
"""Mock socket for testing HAProxy client communication."""
def __init__(self, responses: dict[str, str] | None = None, default_response: str = ""):
"""Initialize mock socket.
Args:
responses: Dict mapping command prefixes to responses
default_response: Response for commands not in responses dict
"""
self.responses = responses or {}
self.default_response = default_response
self.sent_commands: list[str] = []
self._closed = False
self._response_buffer = b""
def connect(self, address: tuple[str, int]) -> None:
"""Mock connect - does nothing."""
pass
def settimeout(self, timeout: float) -> None:
"""Mock settimeout - does nothing."""
pass
def setblocking(self, blocking: bool) -> None:
"""Mock setblocking - does nothing."""
pass
def sendall(self, data: bytes) -> None:
"""Mock sendall - stores sent command."""
command = data.decode().strip()
self.sent_commands.append(command)
# Prepare response for this command
response = self.default_response
for prefix, resp in self.responses.items():
if command.startswith(prefix):
response = resp
break
self._response_buffer = response.encode()
def shutdown(self, how: int) -> None:
"""Mock shutdown - does nothing."""
pass
def recv(self, bufsize: int) -> bytes:
"""Mock recv - returns prepared response."""
if self._response_buffer:
data = self._response_buffer[:bufsize]
self._response_buffer = self._response_buffer[bufsize:]
return data
return b""
def close(self) -> None:
"""Mock close."""
self._closed = True
def fileno(self) -> int:
"""Mock fileno for select."""
return 999
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
class HAProxyResponseBuilder:
    """Helper class to build HAProxy-style responses for tests."""

    @staticmethod
    def servers_state(servers: list[dict[str, Any]]) -> str:
        """Build a 'show servers state' response.

        Args:
            servers: List of server dicts with keys (all optional):
                - be_id: Backend ID (int)
                - be_name: Backend name
                - srv_id: Server ID (int)
                - srv_name: Server name
                - srv_addr: Server address (IP)
                - srv_op_state: Operational state (int)
                - srv_admin_state: Admin state (int)
                - srv_port: Server port

        Returns:
            HAProxy-formatted server state string: a version line, a header
            comment line, then one space-separated line per server.
        """
        lines = ["1"]  # Version line
        lines.append("# be_id be_name srv_id srv_name srv_addr srv_op_state srv_admin_state srv_uweight srv_iweight srv_time_since_last_change srv_check_status srv_check_result srv_check_health srv_check_state srv_agent_state bk_f_forced_id srv_f_forced_id srv_fqdn srv_port srvrecord")
        for srv in servers:
            # Minimal state line with 19 columns (through srv_port); the
            # trailing srvrecord column listed in the header is omitted.
            line_parts = [
                str(srv.get("be_id", 1)),
                srv.get("be_name", "pool_1"),
                str(srv.get("srv_id", 1)),
                srv.get("srv_name", "pool_1_1"),
                srv.get("srv_addr", "0.0.0.0"),
                str(srv.get("srv_op_state", 2)),
                str(srv.get("srv_admin_state", 0)),
                "1",  # srv_uweight
                "1",  # srv_iweight
                "100",  # srv_time_since_last_change
                "6",  # srv_check_status
                "3",  # srv_check_result
                "4",  # srv_check_health
                "6",  # srv_check_state
                "0",  # srv_agent_state
                "0",  # bk_f_forced_id
                "0",  # srv_f_forced_id
                "-",  # srv_fqdn
                str(srv.get("srv_port", 0)),  # srv_port
            ]
            lines.append(" ".join(line_parts))
        return "\n".join(lines)

    @staticmethod
    def stat_csv(entries: list[dict[str, Any]]) -> str:
        """Build a 'show stat' CSV response.

        Args:
            entries: List of stat dicts with keys:
                - pxname: Proxy name (optional, default "pool_1")
                - svname: Server name or FRONTEND/BACKEND (optional, default "pool_1_1")
                - scur: Current sessions (optional, default 0)
                - smax: Max sessions (optional, default 0)
                - status: Status such as UP/DOWN/MAINT (optional, default "UP")
                - weight: Weight (optional, default 1)
                - check_status: Check status (optional, default "L4OK")

        Returns:
            HAProxy-formatted CSV stat string: a '# '-prefixed header line
            followed by one 50-field row per entry; unset fields are empty.
        """
        lines = ["# pxname,svname,qcur,qmax,scur,smax,slim,stot,bin,bout,dreq,dresp,ereq,econ,eresp,wretr,wredis,status,weight,act,bck,chkfail,chkdown,lastchg,downtime,qlimit,pid,iid,sid,throttle,lbtot,tracked,type,rate,rate_lim,rate_max,check_status,check_code,check_duration,hrsp_1xx,hrsp_2xx,hrsp_3xx,hrsp_4xx,hrsp_5xx,hrsp_other,hanafail,req_rate,req_rate_max,req_tot,cli_abrt,srv_abrt,"]
        for entry in entries:
            # Field positions follow the header above:
            # pxname(0), svname(1), scur(4), smax(5), status(17), weight(18),
            # check_status(36).
            row = [""] * 50
            row[0] = entry.get("pxname", "pool_1")
            row[1] = entry.get("svname", "pool_1_1")
            row[4] = str(entry.get("scur", 0))  # SCUR
            row[5] = str(entry.get("smax", 0))  # SMAX
            row[17] = entry.get("status", "UP")  # STATUS
            row[18] = str(entry.get("weight", 1))  # WEIGHT
            row[36] = entry.get("check_status", "L4OK")  # CHECK_STATUS
            lines.append(",".join(row))
        return "\n".join(lines)

    @staticmethod
    def info(version: str = "3.3.2", uptime: int = 3600) -> str:
        """Build a 'show info' response.

        Args:
            version: HAProxy version string
            uptime: Uptime in seconds; drives both ``Uptime_sec`` and the
                human-readable ``Uptime`` (``<h>h<m>m<s>s``) line so the two
                always agree.

        Returns:
            HAProxy-formatted info string ("Key: value" lines, no trailing
            newline).
        """
        # Previously "Uptime" was hard-coded to 1h0m0s regardless of the
        # argument; derive it so custom uptimes stay self-consistent.
        hours, remainder = divmod(uptime, 3600)
        minutes, seconds = divmod(remainder, 60)
        return "\n".join([
            "Name: HAProxy",
            f"Version: {version}",
            "Release_date: 2024/01/01",
            "Nbthread: 4",
            "Nbproc: 1",
            "Process_num: 1",
            "Pid: 1",
            f"Uptime: {hours}h{minutes}m{seconds}s",
            f"Uptime_sec: {uptime}",
            "Memmax_MB: 0",
            "PoolAlloc_MB: 0",
            "PoolUsed_MB: 0",
            "PoolFailed: 0",
            "Ulimit-n: 200015",
            "Maxsock: 200015",
            "Maxconn: 100000",
            "Hard_maxconn: 100000",
            "CurrConns: 5",
            "CumConns: 1000",
            "CumReq: 5000",
        ])

    @staticmethod
    def map_show(entries: list[tuple[str, str]]) -> str:
        """Build a 'show map' response.

        Args:
            entries: List of (key, value) tuples

        Returns:
            HAProxy-formatted map show string: one ``0x<8-hex-digit index> key value``
            line per entry (empty string for no entries).
        """
        return "\n".join(
            f"0x{i:08x} {key} {value}" for i, (key, value) in enumerate(entries)
        )


@pytest.fixture
def mock_socket_class():
    """Hand tests the MockSocket class itself (not an instance).

    Lets a test construct its own MockSocket with custom responses.
    """
    socket_cls = MockSocket
    return socket_cls


@pytest.fixture
def response_builder():
    """Hand tests the HAProxyResponseBuilder class.

    All of its helpers are static methods, so no instance is created.
    """
    builder_cls = HAProxyResponseBuilder
    return builder_cls


@pytest.fixture
def mock_haproxy_socket(mock_socket_class, response_builder):
    """Provide a MockSocket pre-loaded with canned replies for common commands.

    The read-style "show ..." commands return empty but well-formed payloads
    from the response builder (plus a fixed three-pool backend list). The
    "add map", "del map" and "set server" commands reply with an empty string.
    """
    query_replies = {
        "show info": response_builder.info(),
        "show servers state": response_builder.servers_state([]),
        "show stat": response_builder.stat_csv([]),
        "show map": response_builder.map_show([]),
        "show backend": "\n".join(["pool_1", "pool_2", "pool_3"]),
    }
    silent_commands = ("add map", "del map", "set server")
    # Merge keeps the same key insertion order as a single literal would.
    canned = {**query_replies, **dict.fromkeys(silent_commands, "")}
    return mock_socket_class(responses=canned)


@pytest.fixture
def temp_config_dir(tmp_path):
    """Provide a temporary directory seeded with minimal config files.

    Returns a dict holding the directory itself under ``"dir"`` and the
    string path of each file under its own key. The SQLite DB path
    (``"db_file"``) is reported but the file is not created here.
    """
    # key -> (file name, initial contents)
    seeds = {
        "map_file": ("domains.map", "# Domain to Backend mapping\n"),
        "wildcards_file": ("wildcards.map", "# Wildcard Domain mapping\n"),
        "servers_file": ("servers.json", "{}"),
        "certs_file": ("certificates.json", '{"domains": []}'),
        "state_file": ("servers.state", ""),
    }
    paths = {"dir": tmp_path}
    for key, (file_name, initial_text) in seeds.items():
        target = tmp_path / file_name
        target.write_text(initial_text)
        paths[key] = str(target)
    paths["db_file"] = str(tmp_path / "haproxy_mcp.db")
    return paths


@pytest.fixture
def patch_config_paths(temp_config_dir):
    """Point haproxy_mcp's file-path constants at the temporary directory.

    Patches the path constants in ``haproxy_mcp.config``, ``haproxy_mcp.file_ops``,
    ``haproxy_mcp.db`` and ``haproxy_mcp.tools.health`` (each module's own
    binding must be patched separately). The DB connection is closed and
    ``init_db()`` is run against the temporary DB file before the test. The
    connection is closed again on teardown.

    ``haproxy_mcp.db.REMOTE_DB_FILE`` is pointed at the same local DB path,
    presumably so remote-sync code stays on local files during tests (confirm).

    Yields:
        The ``temp_config_dir`` dict, unchanged.
    """
    from haproxy_mcp.db import close_connection, init_db

    paths = temp_config_dir
    map_paths = {
        "MAP_FILE": paths["map_file"],
        "WILDCARDS_MAP_FILE": paths["wildcards_file"],
    }
    json_paths = {
        "SERVERS_FILE": paths["servers_file"],
        "CERTS_FILE": paths["certs_file"],
    }
    # Context managers are entered left to right and exited in reverse,
    # exactly like the equivalent nested `with` statements.
    with (
        patch.multiple(
            "haproxy_mcp.config",
            **map_paths,
            **json_paths,
            STATE_FILE=paths["state_file"],
            DB_FILE=paths["db_file"],
        ),
        patch.multiple("haproxy_mcp.file_ops", **map_paths),
        patch.multiple(
            "haproxy_mcp.db",
            **map_paths,
            **json_paths,
            DB_FILE=paths["db_file"],
            REMOTE_DB_FILE=paths["db_file"],
        ),
        patch.multiple(
            "haproxy_mcp.tools.health",
            MAP_FILE=paths["map_file"],
            DB_FILE=paths["db_file"],
        ),
    ):
        # Drop any connection opened against the real DB, then build a fresh one.
        close_connection()
        init_db()
        yield temp_config_dir
        close_connection()


@pytest.fixture
def mock_subprocess():
    """Patch ``subprocess.run`` and yield the mock.

    By default every call reports success: it returns an object with
    ``returncode`` 0 and empty ``stdout``/``stderr``. Tests may reconfigure
    the yielded mock's ``return_value`` or ``side_effect``.
    """
    success = MagicMock(returncode=0, stdout="", stderr="")
    with patch("subprocess.run", return_value=success) as run_mock:
        yield run_mock


@pytest.fixture
def mock_socket_module(mock_haproxy_socket):
    """Patch the ``socket.socket`` attribute to always return the shared mock.

    Constructor arguments are ignored. The yielded object is the same
    MockSocket instance every patched ``socket.socket(...)`` call returns.
    """
    with patch("socket.socket", side_effect=lambda *_args, **_kwargs: mock_haproxy_socket):
        yield mock_haproxy_socket


@pytest.fixture
def mock_select():
    """Patch ``select.select`` so every call reports readiness at once.

    The default return value ``([True], [], [])`` has a non-empty first
    (readable) list. Tests may override it on the yielded mock.
    """
    ready = ([True], [], [])
    with patch("select.select", return_value=ready) as select_mock:
        yield select_mock


@pytest.fixture
def sample_servers_config():
    """Sample servers.json content for testing.

    Maps domain -> {slot id (as str) -> {"ip", "http_port"}}.
    """
    def _server(ip, port):
        return {"ip": ip, "http_port": port}

    return {
        "example.com": {
            "1": _server("10.0.0.1", 80),
            "2": _server("10.0.0.2", 80),
        },
        "api.example.com": {
            "1": _server("10.0.0.10", 8080),
        },
    }


@pytest.fixture
def sample_map_entries():
    """Sample domains.map entries for testing.

    Each (domain, pool) pair is followed immediately by the same pair with a
    leading-dot domain.
    """
    domain_pools = [("example.com", "pool_1"), ("api.example.com", "pool_2")]
    entries = []
    for domain, pool in domain_pools:
        entries.append((domain, pool))
        entries.append((f".{domain}", pool))
    return entries