Instead of syncing JSON files back, the SQLite DB itself is now the persistent store on the remote HAProxy host: - Startup: download remote DB via SCP (skip migration if exists) - After writes: upload local DB via SCP (WAL checkpoint first) - JSON sync removed (sync_servers_json, sync_certs_json deleted) New functions: - ssh_ops: remote_download_file(), remote_upload_file() via SCP - db: sync_db_to_remote(), _try_download_remote_db() Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
370 lines
12 KiB
Python
370 lines
12 KiB
Python
"""Shared pytest fixtures for HAProxy MCP Server tests."""
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
from typing import Any
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
# Make the package under test importable without installing it: put the
# parent of the directory containing this file at the front of sys.path.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _PROJECT_ROOT)
|
|
|
|
|
|
class MockSocket:
|
|
"""Mock socket for testing HAProxy client communication."""
|
|
|
|
def __init__(self, responses: dict[str, str] | None = None, default_response: str = ""):
|
|
"""Initialize mock socket.
|
|
|
|
Args:
|
|
responses: Dict mapping command prefixes to responses
|
|
default_response: Response for commands not in responses dict
|
|
"""
|
|
self.responses = responses or {}
|
|
self.default_response = default_response
|
|
self.sent_commands: list[str] = []
|
|
self._closed = False
|
|
self._response_buffer = b""
|
|
|
|
def connect(self, address: tuple[str, int]) -> None:
|
|
"""Mock connect - does nothing."""
|
|
pass
|
|
|
|
def settimeout(self, timeout: float) -> None:
|
|
"""Mock settimeout - does nothing."""
|
|
pass
|
|
|
|
def setblocking(self, blocking: bool) -> None:
|
|
"""Mock setblocking - does nothing."""
|
|
pass
|
|
|
|
def sendall(self, data: bytes) -> None:
|
|
"""Mock sendall - stores sent command."""
|
|
command = data.decode().strip()
|
|
self.sent_commands.append(command)
|
|
# Prepare response for this command
|
|
response = self.default_response
|
|
for prefix, resp in self.responses.items():
|
|
if command.startswith(prefix):
|
|
response = resp
|
|
break
|
|
self._response_buffer = response.encode()
|
|
|
|
def shutdown(self, how: int) -> None:
|
|
"""Mock shutdown - does nothing."""
|
|
pass
|
|
|
|
def recv(self, bufsize: int) -> bytes:
|
|
"""Mock recv - returns prepared response."""
|
|
if self._response_buffer:
|
|
data = self._response_buffer[:bufsize]
|
|
self._response_buffer = self._response_buffer[bufsize:]
|
|
return data
|
|
return b""
|
|
|
|
def close(self) -> None:
|
|
"""Mock close."""
|
|
self._closed = True
|
|
|
|
def fileno(self) -> int:
|
|
"""Mock fileno for select."""
|
|
return 999
|
|
|
|
def __enter__(self):
|
|
return self
|
|
|
|
def __exit__(self, *args):
|
|
self.close()
|
|
|
|
|
|
class HAProxyResponseBuilder:
    """Build HAProxy runtime-API style response text for tests.

    All builders are static methods; the class only serves as a namespace.
    """

    @staticmethod
    def servers_state(servers: list[dict[str, Any]]) -> str:
        """Build a 'show servers state' response.

        Args:
            servers: List of server dicts; every key is optional:
                - be_id: Backend ID (int, default 1)
                - be_name: Backend name (default "pool_1")
                - srv_id: Server ID (int, default 1)
                - srv_name: Server name (default "pool_1_1")
                - srv_addr: Server address (IP, default "0.0.0.0")
                - srv_op_state: Operational state (int, default 2)
                - srv_admin_state: Admin state (int, default 0)
                - srv_port: Server port (default 0)

        Returns:
            HAProxy-formatted server state string: a "1" version line, the
            column header line, then one space-separated row per server.
        """
        lines = ["1"]  # Version line
        lines.append("# be_id be_name srv_id srv_name srv_addr srv_op_state srv_admin_state srv_uweight srv_iweight srv_time_since_last_change srv_check_status srv_check_result srv_check_health srv_check_state srv_agent_state bk_f_forced_id srv_f_forced_id srv_fqdn srv_port srvrecord")
        for srv in servers:
            # Minimal state line: the 19 columns up to and including srv_port
            # (srvrecord is omitted). Columns not configurable through *srv*
            # get fixed placeholder values.
            line_parts = [
                str(srv.get("be_id", 1)),
                srv.get("be_name", "pool_1"),
                str(srv.get("srv_id", 1)),
                srv.get("srv_name", "pool_1_1"),
                srv.get("srv_addr", "0.0.0.0"),
                str(srv.get("srv_op_state", 2)),
                str(srv.get("srv_admin_state", 0)),
                "1",  # srv_uweight
                "1",  # srv_iweight
                "100",  # srv_time_since_last_change
                "6",  # srv_check_status
                "3",  # srv_check_result
                "4",  # srv_check_health
                "6",  # srv_check_state
                "0",  # srv_agent_state
                "0",  # bk_f_forced_id
                "0",  # srv_f_forced_id
                "-",  # srv_fqdn
                str(srv.get("srv_port", 0)),  # srv_port
            ]
            lines.append(" ".join(line_parts))
        return "\n".join(lines)

    @staticmethod
    def stat_csv(entries: list[dict[str, Any]]) -> str:
        """Build a 'show stat' CSV response.

        Args:
            entries: List of stat dicts; every key is optional:
                - pxname: Proxy name (default "pool_1")
                - svname: Server name, or FRONTEND/BACKEND (default "pool_1_1")
                - scur: Current sessions (default 0)
                - smax: Max sessions (default 0)
                - status: Status such as UP/DOWN/MAINT (default "UP")
                - weight: Weight (default 1)
                - check_status: Check status (default "L4OK")

        Returns:
            HAProxy-formatted CSV: the "# pxname,..." header line, then one
            50-field row per entry with every unset field left empty.
        """
        lines = ["# pxname,svname,qcur,qmax,scur,smax,slim,stot,bin,bout,dreq,dresp,ereq,econ,eresp,wretr,wredis,status,weight,act,bck,chkfail,chkdown,lastchg,downtime,qlimit,pid,iid,sid,throttle,lbtot,tracked,type,rate,rate_lim,rate_max,check_status,check_code,check_duration,hrsp_1xx,hrsp_2xx,hrsp_3xx,hrsp_4xx,hrsp_5xx,hrsp_other,hanafail,req_rate,req_rate_max,req_tot,cli_abrt,srv_abrt,"]
        for entry in entries:
            # Field positions follow the header above: pxname(0), svname(1),
            # scur(4), smax(5), status(17), weight(18), check_status(36).
            row = [""] * 50
            row[0] = entry.get("pxname", "pool_1")
            row[1] = entry.get("svname", "pool_1_1")
            row[4] = str(entry.get("scur", 0))  # SCUR
            row[5] = str(entry.get("smax", 0))  # SMAX
            row[17] = entry.get("status", "UP")  # STATUS
            row[18] = str(entry.get("weight", 1))  # WEIGHT
            row[36] = entry.get("check_status", "L4OK")  # CHECK_STATUS
            lines.append(",".join(row))
        return "\n".join(lines)

    @staticmethod
    def info(version: str = "3.3.2", uptime: int = 3600) -> str:
        """Build a 'show info' response.

        Args:
            version: HAProxy version string
            uptime: Uptime in seconds. It drives both ``Uptime_sec`` and the
                human-readable ``Uptime`` line (``{h}h{m}m{s}s``), so the two
                always agree. The default of 3600 yields ``1h0m0s``.

        Returns:
            HAProxy-formatted info string, one ``Key: value`` pair per line.
        """
        # Previously "Uptime" was hard-coded to 1h0m0s regardless of *uptime*.
        hours, remainder = divmod(uptime, 3600)
        minutes, seconds = divmod(remainder, 60)
        return f"""Name: HAProxy
Version: {version}
Release_date: 2024/01/01
Nbthread: 4
Nbproc: 1
Process_num: 1
Pid: 1
Uptime: {hours}h{minutes}m{seconds}s
Uptime_sec: {uptime}
Memmax_MB: 0
PoolAlloc_MB: 0
PoolUsed_MB: 0
PoolFailed: 0
Ulimit-n: 200015
Maxsock: 200015
Maxconn: 100000
Hard_maxconn: 100000
CurrConns: 5
CumConns: 1000
CumReq: 5000"""

    @staticmethod
    def map_show(entries: list[tuple[str, str]]) -> str:
        """Build a 'show map' response.

        Args:
            entries: List of (key, value) tuples

        Returns:
            HAProxy-formatted map show string: one ``0x<id> key value`` line
            per entry, where the id is the entry's index as 8 hex digits.
        """
        return "\n".join(
            f"0x{i:08x} {key} {value}" for i, (key, value) in enumerate(entries)
        )
|
|
|
|
|
|
@pytest.fixture
def mock_socket_class():
    """Provide the ``MockSocket`` class itself (not an instance).

    Tests use it to build sockets with their own ``responses`` mapping.
    """
    socket_cls = MockSocket
    return socket_cls
|
|
|
|
|
|
@pytest.fixture
def response_builder():
    """Provide the ``HAProxyResponseBuilder`` class (all its builders are static)."""
    builder_cls = HAProxyResponseBuilder
    return builder_cls
|
|
|
|
|
|
@pytest.fixture
def mock_haproxy_socket(mock_socket_class, response_builder):
    """Provide a mock socket pre-loaded with replies for common commands.

    Replies come from these sources:
    - "show info" gets the builder's default info text.
    - "show servers state", "show stat" and "show map" get builder output
      with no entries.
    - "show backend" lists pool_1..pool_3.
    - "add map", "del map" and "set server" get an empty reply.
    """
    canned = {
        "show info": response_builder.info(),
        "show servers state": response_builder.servers_state([]),
        "show stat": response_builder.stat_csv([]),
        "show map": response_builder.map_show([]),
        "show backend": "pool_1\npool_2\npool_3",
    }
    # Mutating commands answer with nothing.
    canned.update(dict.fromkeys(("add map", "del map", "set server"), ""))
    return mock_socket_class(responses=canned)
|
|
|
|
|
|
@pytest.fixture
def temp_config_dir(tmp_path):
    """Seed config files in *tmp_path* and return their locations.

    Returns a dict holding the directory itself under ``"dir"`` and the
    string path of each file under ``map_file``, ``wildcards_file``,
    ``servers_file``, ``certs_file``, ``state_file`` and ``db_file``.
    """
    # result key -> (file name, initial content)
    seeds = {
        "map_file": ("domains.map", "# Domain to Backend mapping\n"),
        "wildcards_file": ("wildcards.map", "# Wildcard Domain mapping\n"),
        "servers_file": ("servers.json", "{}"),
        "certs_file": ("certificates.json", '{"domains": []}'),
        "state_file": ("servers.state", ""),
    }
    paths = {"dir": tmp_path}
    for key, (filename, content) in seeds.items():
        target = tmp_path / filename
        target.write_text(content)
        paths[key] = str(target)
    # The SQLite file is only named here; this fixture does not create it.
    paths["db_file"] = str(tmp_path / "haproxy_mcp.db")
    return paths
|
|
|
|
|
|
@pytest.fixture
def patch_config_paths(temp_config_dir):
    """Redirect every module's config path constants to the temp directory.

    ``haproxy_mcp.config`` is patched, and so are the modules that hold their
    own bindings of those constants (``file_ops``, ``db``, ``tools.health``).
    Patching only ``config`` would not change a name another module already
    imported with ``from ... import``. ``db.REMOTE_DB_FILE`` is pointed at
    the same local DB file. The DB connection is reset and ``init_db()`` runs
    before the test, and the connection is closed again afterwards. Yields
    the ``temp_config_dir`` dict.
    """
    from haproxy_mcp.db import close_connection, init_db

    paths = temp_config_dir
    map_file = paths["map_file"]
    wildcards_file = paths["wildcards_file"]
    db_file = paths["db_file"]

    with (
        patch.multiple(
            "haproxy_mcp.config",
            MAP_FILE=map_file,
            WILDCARDS_MAP_FILE=wildcards_file,
            SERVERS_FILE=paths["servers_file"],
            CERTS_FILE=paths["certs_file"],
            STATE_FILE=paths["state_file"],
            DB_FILE=db_file,
        ),
        patch.multiple(
            "haproxy_mcp.file_ops",
            MAP_FILE=map_file,
            WILDCARDS_MAP_FILE=wildcards_file,
        ),
        patch.multiple(
            "haproxy_mcp.db",
            MAP_FILE=map_file,
            WILDCARDS_MAP_FILE=wildcards_file,
            SERVERS_FILE=paths["servers_file"],
            CERTS_FILE=paths["certs_file"],
            DB_FILE=db_file,
            REMOTE_DB_FILE=db_file,
        ),
        patch.multiple(
            "haproxy_mcp.tools.health",
            MAP_FILE=map_file,
            DB_FILE=db_file,
        ),
    ):
        # Drop any connection left from an earlier test, then build a fresh DB.
        close_connection()
        init_db()
        yield paths
        close_connection()
|
|
|
|
|
|
@pytest.fixture
def mock_subprocess():
    """Patch ``subprocess.run`` for the test and yield the mock.

    By default every call returns an object with ``returncode=0`` and empty
    ``stdout``/``stderr``. Tests may override ``return_value`` or
    ``side_effect``. Only the attribute on the ``subprocess`` module is
    replaced, so code that bound ``run`` via ``from subprocess import run``
    is unaffected.
    """
    succeeded = MagicMock(returncode=0, stdout="", stderr="")
    with patch("subprocess.run", return_value=succeeded) as mock_run:
        yield mock_run
|
|
|
|
|
|
@pytest.fixture
def mock_socket_module(mock_haproxy_socket):
    """Patch ``socket.socket`` so every construction returns one shared mock.

    Yields that ``mock_haproxy_socket`` instance, not the patch object, so
    tests can inspect ``sent_commands`` directly.
    """
    with patch("socket.socket", side_effect=lambda *args, **kwargs: mock_haproxy_socket):
        yield mock_haproxy_socket
|
|
|
|
|
|
@pytest.fixture
def mock_select():
    """Patch ``select.select`` and yield the mock.

    By default each call returns ``([True], [], [])``, i.e. one ready entry
    in the readable list, so recv loops proceed without waiting.
    """
    with patch("select.select", return_value=([True], [], [])) as mock_sel:
        yield mock_sel
|
|
|
|
|
|
@pytest.fixture
def sample_servers_config():
    """Return sample data shaped like servers.json content.

    Top-level keys are domains. Each domain maps string server numbers
    ("1", "2", ...) to a dict with ``ip`` and ``http_port``.
    """
    example_servers = {
        slot: {"ip": ip, "http_port": 80}
        for slot, ip in (("1", "10.0.0.1"), ("2", "10.0.0.2"))
    }
    return {
        "example.com": example_servers,
        "api.example.com": {"1": {"ip": "10.0.0.10", "http_port": 8080}},
    }
|
|
|
|
|
|
@pytest.fixture
def sample_map_entries():
    """Return sample (key, backend) pairs like those in domains.map.

    Each domain appears twice: once as-is and once with a leading dot
    (presumably the subdomain-match form; confirm against the map loader).
    Both entries point at the same pool.
    """
    entries = []
    for domain, pool in (("example.com", "pool_1"), ("api.example.com", "pool_2")):
        entries += [(domain, pool), (f".{domain}", pool)]
    return entries
|