refactor: Improve code quality, error handling, and test coverage
- Add file_lock context manager to eliminate duplicate locking patterns - Add ValidationError, ConfigurationError, CertificateError exceptions - Improve rollback logic in haproxy_add_servers (track successful ops only) - Decompose haproxy_add_domain into smaller helper functions - Consolidate certificate constants (CERTS_DIR, ACME_HOME) to config.py - Enhance docstrings for internal functions and magic numbers - Add pytest framework with 48 new tests (269 -> 317 total) - Increase test coverage from 76% to 86% - servers.py: 58% -> 82% - certificates.py: 67% -> 86% - configuration.py: 69% -> 94% Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""HAProxy MCP Server test suite."""
|
||||
345
tests/conftest.py
Normal file
345
tests/conftest.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""Shared pytest fixtures for HAProxy MCP Server tests."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Add the parent directory to sys.path for imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
class MockSocket:
|
||||
"""Mock socket for testing HAProxy client communication."""
|
||||
|
||||
def __init__(self, responses: dict[str, str] | None = None, default_response: str = ""):
|
||||
"""Initialize mock socket.
|
||||
|
||||
Args:
|
||||
responses: Dict mapping command prefixes to responses
|
||||
default_response: Response for commands not in responses dict
|
||||
"""
|
||||
self.responses = responses or {}
|
||||
self.default_response = default_response
|
||||
self.sent_commands: list[str] = []
|
||||
self._closed = False
|
||||
self._response_buffer = b""
|
||||
|
||||
def connect(self, address: tuple[str, int]) -> None:
|
||||
"""Mock connect - does nothing."""
|
||||
pass
|
||||
|
||||
def settimeout(self, timeout: float) -> None:
|
||||
"""Mock settimeout - does nothing."""
|
||||
pass
|
||||
|
||||
def setblocking(self, blocking: bool) -> None:
|
||||
"""Mock setblocking - does nothing."""
|
||||
pass
|
||||
|
||||
def sendall(self, data: bytes) -> None:
|
||||
"""Mock sendall - stores sent command."""
|
||||
command = data.decode().strip()
|
||||
self.sent_commands.append(command)
|
||||
# Prepare response for this command
|
||||
response = self.default_response
|
||||
for prefix, resp in self.responses.items():
|
||||
if command.startswith(prefix):
|
||||
response = resp
|
||||
break
|
||||
self._response_buffer = response.encode()
|
||||
|
||||
def shutdown(self, how: int) -> None:
|
||||
"""Mock shutdown - does nothing."""
|
||||
pass
|
||||
|
||||
def recv(self, bufsize: int) -> bytes:
|
||||
"""Mock recv - returns prepared response."""
|
||||
if self._response_buffer:
|
||||
data = self._response_buffer[:bufsize]
|
||||
self._response_buffer = self._response_buffer[bufsize:]
|
||||
return data
|
||||
return b""
|
||||
|
||||
def close(self) -> None:
|
||||
"""Mock close."""
|
||||
self._closed = True
|
||||
|
||||
def fileno(self) -> int:
|
||||
"""Mock fileno for select."""
|
||||
return 999
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.close()
|
||||
|
||||
|
||||
class HAProxyResponseBuilder:
|
||||
"""Helper class to build HAProxy-style responses for tests."""
|
||||
|
||||
@staticmethod
|
||||
def servers_state(servers: list[dict[str, Any]]) -> str:
|
||||
"""Build a 'show servers state' response.
|
||||
|
||||
Args:
|
||||
servers: List of server dicts with keys:
|
||||
- be_id: Backend ID (int)
|
||||
- be_name: Backend name
|
||||
- srv_id: Server ID (int)
|
||||
- srv_name: Server name
|
||||
- srv_addr: Server address (IP)
|
||||
- srv_op_state: Operational state (int)
|
||||
- srv_admin_state: Admin state (int)
|
||||
- srv_port: Server port
|
||||
|
||||
Returns:
|
||||
HAProxy-formatted server state string
|
||||
"""
|
||||
lines = ["1"] # Version line
|
||||
lines.append("# be_id be_name srv_id srv_name srv_addr srv_op_state srv_admin_state srv_uweight srv_iweight srv_time_since_last_change srv_check_status srv_check_result srv_check_health srv_check_state srv_agent_state bk_f_forced_id srv_f_forced_id srv_fqdn srv_port srvrecord")
|
||||
for srv in servers:
|
||||
# Fill in defaults for minimal state line (need 19+ columns)
|
||||
line_parts = [
|
||||
str(srv.get("be_id", 1)),
|
||||
srv.get("be_name", "pool_1"),
|
||||
str(srv.get("srv_id", 1)),
|
||||
srv.get("srv_name", "pool_1_1"),
|
||||
srv.get("srv_addr", "0.0.0.0"),
|
||||
str(srv.get("srv_op_state", 2)),
|
||||
str(srv.get("srv_admin_state", 0)),
|
||||
"1", # srv_uweight
|
||||
"1", # srv_iweight
|
||||
"100", # srv_time_since_last_change
|
||||
"6", # srv_check_status
|
||||
"3", # srv_check_result
|
||||
"4", # srv_check_health
|
||||
"6", # srv_check_state
|
||||
"0", # srv_agent_state
|
||||
"0", # bk_f_forced_id
|
||||
"0", # srv_f_forced_id
|
||||
"-", # srv_fqdn
|
||||
str(srv.get("srv_port", 0)), # srv_port
|
||||
]
|
||||
lines.append(" ".join(line_parts))
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def stat_csv(entries: list[dict[str, Any]]) -> str:
|
||||
"""Build a 'show stat' CSV response.
|
||||
|
||||
Args:
|
||||
entries: List of stat dicts with keys:
|
||||
- pxname: Proxy name
|
||||
- svname: Server name (or FRONTEND/BACKEND)
|
||||
- scur: Current sessions (optional, default 0)
|
||||
- status: Status (UP/DOWN/MAINT)
|
||||
- weight: Weight (optional, default 1)
|
||||
- check_status: Check status (optional)
|
||||
|
||||
Returns:
|
||||
HAProxy-formatted CSV stat string
|
||||
"""
|
||||
lines = ["# pxname,svname,qcur,qmax,scur,smax,slim,stot,bin,bout,dreq,dresp,ereq,econ,eresp,wretr,wredis,status,weight,act,bck,chkfail,chkdown,lastchg,downtime,qlimit,pid,iid,sid,throttle,lbtot,tracked,type,rate,rate_lim,rate_max,check_status,check_code,check_duration,hrsp_1xx,hrsp_2xx,hrsp_3xx,hrsp_4xx,hrsp_5xx,hrsp_other,hanafail,req_rate,req_rate_max,req_tot,cli_abrt,srv_abrt,"]
|
||||
for entry in entries:
|
||||
# Build CSV row with proper field positions
|
||||
# Fields: pxname(0), svname(1), qcur(2), qmax(3), scur(4), smax(5), slim(6), ...
|
||||
# status(17), weight(18), ..., check_status(36)
|
||||
row = [""] * 50
|
||||
row[0] = entry.get("pxname", "pool_1")
|
||||
row[1] = entry.get("svname", "pool_1_1")
|
||||
row[4] = str(entry.get("scur", 0)) # SCUR
|
||||
row[5] = str(entry.get("smax", 0)) # SMAX
|
||||
row[17] = entry.get("status", "UP") # STATUS
|
||||
row[18] = str(entry.get("weight", 1)) # WEIGHT
|
||||
row[36] = entry.get("check_status", "L4OK") # CHECK_STATUS
|
||||
lines.append(",".join(row))
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def info(version: str = "3.3.2", uptime: int = 3600) -> str:
|
||||
"""Build a 'show info' response.
|
||||
|
||||
Args:
|
||||
version: HAProxy version string
|
||||
uptime: Uptime in seconds
|
||||
|
||||
Returns:
|
||||
HAProxy-formatted info string
|
||||
"""
|
||||
return f"""Name: HAProxy
|
||||
Version: {version}
|
||||
Release_date: 2024/01/01
|
||||
Nbthread: 4
|
||||
Nbproc: 1
|
||||
Process_num: 1
|
||||
Pid: 1
|
||||
Uptime: 1h0m0s
|
||||
Uptime_sec: {uptime}
|
||||
Memmax_MB: 0
|
||||
PoolAlloc_MB: 0
|
||||
PoolUsed_MB: 0
|
||||
PoolFailed: 0
|
||||
Ulimit-n: 200015
|
||||
Maxsock: 200015
|
||||
Maxconn: 100000
|
||||
Hard_maxconn: 100000
|
||||
CurrConns: 5
|
||||
CumConns: 1000
|
||||
CumReq: 5000"""
|
||||
|
||||
@staticmethod
|
||||
def map_show(entries: list[tuple[str, str]]) -> str:
|
||||
"""Build a 'show map' response.
|
||||
|
||||
Args:
|
||||
entries: List of (key, value) tuples
|
||||
|
||||
Returns:
|
||||
HAProxy-formatted map show string
|
||||
"""
|
||||
lines = []
|
||||
for i, (key, value) in enumerate(entries):
|
||||
lines.append(f"0x{i:08x} {key} {value}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_socket_class():
|
||||
"""Fixture that returns MockSocket class for custom configuration."""
|
||||
return MockSocket
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def response_builder():
|
||||
"""Fixture that returns HAProxyResponseBuilder class."""
|
||||
return HAProxyResponseBuilder
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_haproxy_socket(mock_socket_class, response_builder):
|
||||
"""Fixture providing a pre-configured mock socket with common responses."""
|
||||
responses = {
|
||||
"show info": response_builder.info(),
|
||||
"show servers state": response_builder.servers_state([]),
|
||||
"show stat": response_builder.stat_csv([]),
|
||||
"show map": response_builder.map_show([]),
|
||||
"show backend": "pool_1\npool_2\npool_3",
|
||||
"add map": "",
|
||||
"del map": "",
|
||||
"set server": "",
|
||||
}
|
||||
return mock_socket_class(responses=responses)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_config_dir(tmp_path):
|
||||
"""Fixture providing a temporary directory with config files."""
|
||||
# Create config files
|
||||
map_file = tmp_path / "domains.map"
|
||||
map_file.write_text("# Domain to Backend mapping\n")
|
||||
|
||||
wildcards_file = tmp_path / "wildcards.map"
|
||||
wildcards_file.write_text("# Wildcard Domain mapping\n")
|
||||
|
||||
servers_file = tmp_path / "servers.json"
|
||||
servers_file.write_text("{}")
|
||||
|
||||
certs_file = tmp_path / "certificates.json"
|
||||
certs_file.write_text('{"domains": []}')
|
||||
|
||||
state_file = tmp_path / "servers.state"
|
||||
state_file.write_text("")
|
||||
|
||||
return {
|
||||
"dir": tmp_path,
|
||||
"map_file": str(map_file),
|
||||
"wildcards_file": str(wildcards_file),
|
||||
"servers_file": str(servers_file),
|
||||
"certs_file": str(certs_file),
|
||||
"state_file": str(state_file),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_config_paths(temp_config_dir):
|
||||
"""Fixture that patches config module paths to use temporary directory."""
|
||||
with patch.multiple(
|
||||
"haproxy_mcp.config",
|
||||
MAP_FILE=temp_config_dir["map_file"],
|
||||
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
|
||||
SERVERS_FILE=temp_config_dir["servers_file"],
|
||||
CERTS_FILE=temp_config_dir["certs_file"],
|
||||
STATE_FILE=temp_config_dir["state_file"],
|
||||
):
|
||||
# Also patch file_ops module which imports these
|
||||
with patch.multiple(
|
||||
"haproxy_mcp.file_ops",
|
||||
MAP_FILE=temp_config_dir["map_file"],
|
||||
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
|
||||
SERVERS_FILE=temp_config_dir["servers_file"],
|
||||
CERTS_FILE=temp_config_dir["certs_file"],
|
||||
):
|
||||
yield temp_config_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_subprocess():
|
||||
"""Fixture that mocks subprocess.run for external command testing."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
# Default to successful command
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0,
|
||||
stdout="",
|
||||
stderr="",
|
||||
)
|
||||
yield mock_run
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_socket_module(mock_haproxy_socket):
|
||||
"""Fixture that patches socket module to use mock socket."""
|
||||
|
||||
def create_socket(*args, **kwargs):
|
||||
return mock_haproxy_socket
|
||||
|
||||
with patch("socket.socket", side_effect=create_socket):
|
||||
yield mock_haproxy_socket
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_select():
|
||||
"""Fixture that patches select.select for socket recv loops."""
|
||||
with patch("select.select") as mock_sel:
|
||||
# Default: socket is ready immediately
|
||||
mock_sel.return_value = ([True], [], [])
|
||||
yield mock_sel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_servers_config():
|
||||
"""Sample servers.json content for testing."""
|
||||
return {
|
||||
"example.com": {
|
||||
"1": {"ip": "10.0.0.1", "http_port": 80},
|
||||
"2": {"ip": "10.0.0.2", "http_port": 80},
|
||||
},
|
||||
"api.example.com": {
|
||||
"1": {"ip": "10.0.0.10", "http_port": 8080},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_map_entries():
|
||||
"""Sample domains.map entries for testing."""
|
||||
return [
|
||||
("example.com", "pool_1"),
|
||||
(".example.com", "pool_1"),
|
||||
("api.example.com", "pool_2"),
|
||||
(".api.example.com", "pool_2"),
|
||||
]
|
||||
1
tests/integration/__init__.py
Normal file
1
tests/integration/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Integration tests for HAProxy MCP Server."""
|
||||
1
tests/unit/__init__.py
Normal file
1
tests/unit/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Unit tests for HAProxy MCP Server."""
|
||||
198
tests/unit/test_config.py
Normal file
198
tests/unit/test_config.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Unit tests for config module."""
|
||||
|
||||
import os
|
||||
import re
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.config import (
|
||||
DOMAIN_PATTERN,
|
||||
BACKEND_NAME_PATTERN,
|
||||
NON_ALNUM_PATTERN,
|
||||
StatField,
|
||||
StateField,
|
||||
POOL_COUNT,
|
||||
MAX_SLOTS,
|
||||
MAX_RESPONSE_SIZE,
|
||||
SOCKET_TIMEOUT,
|
||||
MAX_BULK_SERVERS,
|
||||
)
|
||||
|
||||
|
||||
class TestDomainPattern:
|
||||
"""Tests for DOMAIN_PATTERN regex."""
|
||||
|
||||
def test_simple_domain(self):
|
||||
"""Match simple domain."""
|
||||
assert DOMAIN_PATTERN.match("example.com") is not None
|
||||
|
||||
def test_subdomain(self):
|
||||
"""Match subdomain."""
|
||||
assert DOMAIN_PATTERN.match("api.example.com") is not None
|
||||
|
||||
def test_deep_subdomain(self):
|
||||
"""Match deep subdomain."""
|
||||
assert DOMAIN_PATTERN.match("a.b.c.d.example.com") is not None
|
||||
|
||||
def test_hyphenated_domain(self):
|
||||
"""Match domain with hyphens."""
|
||||
assert DOMAIN_PATTERN.match("my-api.example-site.com") is not None
|
||||
|
||||
def test_numeric_labels(self):
|
||||
"""Match domain with numeric labels."""
|
||||
assert DOMAIN_PATTERN.match("api123.example.com") is not None
|
||||
|
||||
def test_invalid_starts_with_hyphen(self):
|
||||
"""Reject domain starting with hyphen."""
|
||||
assert DOMAIN_PATTERN.match("-example.com") is None
|
||||
|
||||
def test_invalid_ends_with_hyphen(self):
|
||||
"""Reject label ending with hyphen."""
|
||||
assert DOMAIN_PATTERN.match("example-.com") is None
|
||||
|
||||
def test_invalid_underscore(self):
|
||||
"""Reject domain with underscore."""
|
||||
assert DOMAIN_PATTERN.match("my_api.example.com") is None
|
||||
|
||||
|
||||
class TestBackendNamePattern:
|
||||
"""Tests for BACKEND_NAME_PATTERN regex."""
|
||||
|
||||
def test_pool_name(self):
|
||||
"""Match pool backend names."""
|
||||
assert BACKEND_NAME_PATTERN.match("pool_1") is not None
|
||||
assert BACKEND_NAME_PATTERN.match("pool_100") is not None
|
||||
|
||||
def test_alphanumeric(self):
|
||||
"""Match alphanumeric names."""
|
||||
assert BACKEND_NAME_PATTERN.match("backend123") is not None
|
||||
|
||||
def test_underscore(self):
|
||||
"""Match names with underscores."""
|
||||
assert BACKEND_NAME_PATTERN.match("my_backend") is not None
|
||||
|
||||
def test_hyphen(self):
|
||||
"""Match names with hyphens."""
|
||||
assert BACKEND_NAME_PATTERN.match("my-backend") is not None
|
||||
|
||||
def test_mixed(self):
|
||||
"""Match mixed character names."""
|
||||
assert BACKEND_NAME_PATTERN.match("api_example-com_backend") is not None
|
||||
|
||||
def test_invalid_dot(self):
|
||||
"""Reject names with dots."""
|
||||
assert BACKEND_NAME_PATTERN.match("my.backend") is None
|
||||
|
||||
def test_invalid_special_chars(self):
|
||||
"""Reject names with special characters."""
|
||||
assert BACKEND_NAME_PATTERN.match("my@backend") is None
|
||||
assert BACKEND_NAME_PATTERN.match("my/backend") is None
|
||||
|
||||
|
||||
class TestNonAlnumPattern:
|
||||
"""Tests for NON_ALNUM_PATTERN regex."""
|
||||
|
||||
def test_replace_dots(self):
|
||||
"""Replace dots."""
|
||||
result = NON_ALNUM_PATTERN.sub("_", "example.com")
|
||||
assert result == "example_com"
|
||||
|
||||
def test_replace_hyphens(self):
|
||||
"""Replace hyphens."""
|
||||
result = NON_ALNUM_PATTERN.sub("_", "my-api")
|
||||
assert result == "my_api"
|
||||
|
||||
def test_preserve_alphanumeric(self):
|
||||
"""Preserve alphanumeric characters."""
|
||||
result = NON_ALNUM_PATTERN.sub("_", "abc123")
|
||||
assert result == "abc123"
|
||||
|
||||
def test_complex_replacement(self):
|
||||
"""Complex domain replacement."""
|
||||
result = NON_ALNUM_PATTERN.sub("_", "api.my-site.example.com")
|
||||
assert result == "api_my_site_example_com"
|
||||
|
||||
|
||||
class TestStatField:
|
||||
"""Tests for StatField constants."""
|
||||
|
||||
def test_field_indices(self):
|
||||
"""Verify stat field indices."""
|
||||
assert StatField.PXNAME == 0
|
||||
assert StatField.SVNAME == 1
|
||||
assert StatField.SCUR == 4
|
||||
assert StatField.SMAX == 6
|
||||
assert StatField.STATUS == 17
|
||||
assert StatField.WEIGHT == 18
|
||||
assert StatField.CHECK_STATUS == 36
|
||||
|
||||
|
||||
class TestStateField:
|
||||
"""Tests for StateField constants."""
|
||||
|
||||
def test_field_indices(self):
|
||||
"""Verify state field indices."""
|
||||
assert StateField.BE_ID == 0
|
||||
assert StateField.BE_NAME == 1
|
||||
assert StateField.SRV_ID == 2
|
||||
assert StateField.SRV_NAME == 3
|
||||
assert StateField.SRV_ADDR == 4
|
||||
assert StateField.SRV_OP_STATE == 5
|
||||
assert StateField.SRV_ADMIN_STATE == 6
|
||||
assert StateField.SRV_PORT == 18
|
||||
|
||||
|
||||
class TestConfigConstants:
|
||||
"""Tests for configuration constants."""
|
||||
|
||||
def test_pool_count(self):
|
||||
"""Pool count has expected value."""
|
||||
assert POOL_COUNT == 100
|
||||
|
||||
def test_max_slots(self):
|
||||
"""Max slots has expected value."""
|
||||
assert MAX_SLOTS == 10
|
||||
|
||||
def test_max_response_size(self):
|
||||
"""Max response size is reasonable."""
|
||||
assert MAX_RESPONSE_SIZE == 10 * 1024 * 1024 # 10 MB
|
||||
|
||||
def test_socket_timeout(self):
|
||||
"""Socket timeout is reasonable."""
|
||||
assert SOCKET_TIMEOUT == 5
|
||||
|
||||
def test_max_bulk_servers(self):
|
||||
"""Max bulk servers is reasonable."""
|
||||
assert MAX_BULK_SERVERS == 10
|
||||
|
||||
|
||||
class TestEnvironmentVariables:
|
||||
"""Tests for environment variable configuration."""
|
||||
|
||||
def test_default_mcp_host(self):
|
||||
"""Default MCP host is 0.0.0.0."""
|
||||
# Import fresh to get defaults
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
# Re-import to test defaults
|
||||
from importlib import reload
|
||||
import haproxy_mcp.config as config
|
||||
reload(config)
|
||||
# Note: Due to Python's module caching, this test verifies the
|
||||
# default values are what we expect from the source code
|
||||
assert config.MCP_HOST == "0.0.0.0"
|
||||
|
||||
def test_default_mcp_port(self):
|
||||
"""Default MCP port is 8000."""
|
||||
from haproxy_mcp.config import MCP_PORT
|
||||
assert MCP_PORT == 8000
|
||||
|
||||
def test_default_haproxy_host(self):
|
||||
"""Default HAProxy host is localhost."""
|
||||
from haproxy_mcp.config import HAPROXY_HOST
|
||||
assert HAPROXY_HOST == "localhost"
|
||||
|
||||
def test_default_haproxy_port(self):
|
||||
"""Default HAProxy port is 9999."""
|
||||
from haproxy_mcp.config import HAPROXY_PORT
|
||||
assert HAPROXY_PORT == 9999
|
||||
498
tests/unit/test_file_ops.py
Normal file
498
tests/unit/test_file_ops.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""Unit tests for file_ops module."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.file_ops import (
|
||||
atomic_write_file,
|
||||
get_map_contents,
|
||||
save_map_file,
|
||||
get_domain_backend,
|
||||
split_domain_entries,
|
||||
is_legacy_backend,
|
||||
get_legacy_backend_name,
|
||||
get_backend_and_prefix,
|
||||
load_servers_config,
|
||||
save_servers_config,
|
||||
add_server_to_config,
|
||||
remove_server_from_config,
|
||||
remove_domain_from_config,
|
||||
load_certs_config,
|
||||
save_certs_config,
|
||||
add_cert_to_config,
|
||||
remove_cert_from_config,
|
||||
)
|
||||
|
||||
|
||||
class TestAtomicWriteFile:
|
||||
"""Tests for atomic_write_file function."""
|
||||
|
||||
def test_write_new_file(self, tmp_path):
|
||||
"""Write to a new file."""
|
||||
file_path = str(tmp_path / "test.txt")
|
||||
content = "Hello, World!"
|
||||
|
||||
atomic_write_file(file_path, content)
|
||||
|
||||
assert os.path.exists(file_path)
|
||||
with open(file_path) as f:
|
||||
assert f.read() == content
|
||||
|
||||
def test_overwrite_existing_file(self, tmp_path):
|
||||
"""Overwrite an existing file."""
|
||||
file_path = str(tmp_path / "test.txt")
|
||||
with open(file_path, "w") as f:
|
||||
f.write("Old content")
|
||||
|
||||
atomic_write_file(file_path, "New content")
|
||||
|
||||
with open(file_path) as f:
|
||||
assert f.read() == "New content"
|
||||
|
||||
def test_preserves_directory(self, tmp_path):
|
||||
"""Writing does not create intermediate directories."""
|
||||
file_path = str(tmp_path / "subdir" / "test.txt")
|
||||
|
||||
with pytest.raises(IOError):
|
||||
atomic_write_file(file_path, "content")
|
||||
|
||||
def test_unicode_content(self, tmp_path):
|
||||
"""Unicode content is properly written."""
|
||||
file_path = str(tmp_path / "unicode.txt")
|
||||
content = "Hello, \u4e16\u754c!" # "Hello, World!" in Chinese
|
||||
|
||||
atomic_write_file(file_path, content)
|
||||
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
assert f.read() == content
|
||||
|
||||
def test_multiline_content(self, tmp_path):
|
||||
"""Multi-line content is properly written."""
|
||||
file_path = str(tmp_path / "multiline.txt")
|
||||
content = "line1\nline2\nline3"
|
||||
|
||||
atomic_write_file(file_path, content)
|
||||
|
||||
with open(file_path) as f:
|
||||
assert f.read() == content
|
||||
|
||||
|
||||
class TestGetMapContents:
|
||||
"""Tests for get_map_contents function."""
|
||||
|
||||
def test_read_map_file(self, patch_config_paths):
|
||||
"""Read entries from map file."""
|
||||
# Write test content to map file
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("api.example.com pool_2\n")
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert ("example.com", "pool_1") in entries
|
||||
assert ("api.example.com", "pool_2") in entries
|
||||
|
||||
def test_read_both_map_files(self, patch_config_paths):
|
||||
"""Read entries from both domains.map and wildcards.map."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||
f.write(".example.com pool_1\n")
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert ("example.com", "pool_1") in entries
|
||||
assert (".example.com", "pool_1") in entries
|
||||
|
||||
def test_skip_comments(self, patch_config_paths):
|
||||
"""Comments are skipped."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("# This is a comment\n")
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("# Another comment\n")
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert len(entries) == 1
|
||||
assert entries[0] == ("example.com", "pool_1")
|
||||
|
||||
def test_skip_empty_lines(self, patch_config_paths):
|
||||
"""Empty lines are skipped."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("\n")
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("\n")
|
||||
f.write("api.example.com pool_2\n")
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert len(entries) == 2
|
||||
|
||||
def test_file_not_found(self, patch_config_paths):
|
||||
"""Missing file returns empty list."""
|
||||
os.unlink(patch_config_paths["map_file"])
|
||||
os.unlink(patch_config_paths["wildcards_file"])
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert entries == []
|
||||
|
||||
|
||||
class TestSplitDomainEntries:
|
||||
"""Tests for split_domain_entries function."""
|
||||
|
||||
def test_split_entries(self):
|
||||
"""Split entries into exact and wildcard."""
|
||||
entries = [
|
||||
("example.com", "pool_1"),
|
||||
(".example.com", "pool_1"),
|
||||
("api.example.com", "pool_2"),
|
||||
(".api.example.com", "pool_2"),
|
||||
]
|
||||
|
||||
exact, wildcards = split_domain_entries(entries)
|
||||
|
||||
assert len(exact) == 2
|
||||
assert len(wildcards) == 2
|
||||
assert ("example.com", "pool_1") in exact
|
||||
assert (".example.com", "pool_1") in wildcards
|
||||
|
||||
def test_empty_entries(self):
|
||||
"""Empty entries returns empty lists."""
|
||||
exact, wildcards = split_domain_entries([])
|
||||
|
||||
assert exact == []
|
||||
assert wildcards == []
|
||||
|
||||
def test_all_exact(self):
|
||||
"""All exact entries."""
|
||||
entries = [
|
||||
("example.com", "pool_1"),
|
||||
("api.example.com", "pool_2"),
|
||||
]
|
||||
|
||||
exact, wildcards = split_domain_entries(entries)
|
||||
|
||||
assert len(exact) == 2
|
||||
assert len(wildcards) == 0
|
||||
|
||||
|
||||
class TestSaveMapFile:
|
||||
"""Tests for save_map_file function."""
|
||||
|
||||
def test_save_entries(self, patch_config_paths):
|
||||
"""Save entries to separate map files."""
|
||||
entries = [
|
||||
("example.com", "pool_1"),
|
||||
(".example.com", "pool_1"),
|
||||
]
|
||||
|
||||
save_map_file(entries)
|
||||
|
||||
# Check exact domains file
|
||||
with open(patch_config_paths["map_file"]) as f:
|
||||
content = f.read()
|
||||
assert "example.com pool_1" in content
|
||||
|
||||
# Check wildcards file
|
||||
with open(patch_config_paths["wildcards_file"]) as f:
|
||||
content = f.read()
|
||||
assert ".example.com pool_1" in content
|
||||
|
||||
def test_sorted_output(self, patch_config_paths):
|
||||
"""Entries are sorted in output."""
|
||||
entries = [
|
||||
("z.example.com", "pool_3"),
|
||||
("a.example.com", "pool_1"),
|
||||
("m.example.com", "pool_2"),
|
||||
]
|
||||
|
||||
save_map_file(entries)
|
||||
|
||||
with open(patch_config_paths["map_file"]) as f:
|
||||
lines = [l.strip() for l in f if l.strip() and not l.startswith("#")]
|
||||
|
||||
assert lines[0] == "a.example.com pool_1"
|
||||
assert lines[1] == "m.example.com pool_2"
|
||||
assert lines[2] == "z.example.com pool_3"
|
||||
|
||||
|
||||
class TestGetDomainBackend:
|
||||
"""Tests for get_domain_backend function."""
|
||||
|
||||
def test_find_existing_domain(self, patch_config_paths):
|
||||
"""Find backend for existing domain."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
backend = get_domain_backend("example.com")
|
||||
|
||||
assert backend == "pool_1"
|
||||
|
||||
def test_domain_not_found(self, patch_config_paths):
|
||||
"""Non-existent domain returns None."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
backend = get_domain_backend("other.com")
|
||||
|
||||
assert backend is None
|
||||
|
||||
|
||||
class TestIsLegacyBackend:
|
||||
"""Tests for is_legacy_backend function."""
|
||||
|
||||
def test_pool_backend(self):
|
||||
"""Pool backend is not legacy."""
|
||||
assert is_legacy_backend("pool_1") is False
|
||||
assert is_legacy_backend("pool_100") is False
|
||||
|
||||
def test_legacy_backend(self):
|
||||
"""Non-pool backend is legacy."""
|
||||
assert is_legacy_backend("api_example_com_backend") is True
|
||||
assert is_legacy_backend("static_backend") is True
|
||||
|
||||
|
||||
class TestGetLegacyBackendName:
|
||||
"""Tests for get_legacy_backend_name function."""
|
||||
|
||||
def test_convert_domain(self):
|
||||
"""Convert domain to legacy backend name."""
|
||||
result = get_legacy_backend_name("api.example.com")
|
||||
assert result == "api_example_com_backend"
|
||||
|
||||
|
||||
class TestGetBackendAndPrefix:
|
||||
"""Tests for get_backend_and_prefix function."""
|
||||
|
||||
def test_pool_backend(self, patch_config_paths):
|
||||
"""Pool backend returns pool-based prefix."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_5\n")
|
||||
|
||||
backend, prefix = get_backend_and_prefix("example.com")
|
||||
|
||||
assert backend == "pool_5"
|
||||
assert prefix == "pool_5"
|
||||
|
||||
def test_unknown_domain_uses_legacy(self, patch_config_paths):
|
||||
"""Unknown domain uses legacy backend naming."""
|
||||
backend, prefix = get_backend_and_prefix("unknown.com")
|
||||
|
||||
assert backend == "unknown_com_backend"
|
||||
assert prefix == "unknown_com"
|
||||
|
||||
|
||||
class TestLoadServersConfig:
|
||||
"""Tests for load_servers_config function."""
|
||||
|
||||
def test_load_existing_config(self, patch_config_paths, sample_servers_config):
|
||||
"""Load existing config file."""
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(sample_servers_config, f)
|
||||
|
||||
config = load_servers_config()
|
||||
|
||||
assert "example.com" in config
|
||||
assert config["example.com"]["1"]["ip"] == "10.0.0.1"
|
||||
|
||||
def test_file_not_found(self, patch_config_paths):
|
||||
"""Missing file returns empty dict."""
|
||||
os.unlink(patch_config_paths["servers_file"])
|
||||
|
||||
config = load_servers_config()
|
||||
|
||||
assert config == {}
|
||||
|
||||
def test_invalid_json(self, patch_config_paths):
|
||||
"""Invalid JSON returns empty dict."""
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
f.write("not valid json {{{")
|
||||
|
||||
config = load_servers_config()
|
||||
|
||||
assert config == {}
|
||||
|
||||
|
||||
class TestSaveServersConfig:
|
||||
"""Tests for save_servers_config function."""
|
||||
|
||||
def test_save_config(self, patch_config_paths):
|
||||
"""Save config to file."""
|
||||
config = {"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}
|
||||
|
||||
save_servers_config(config)
|
||||
|
||||
with open(patch_config_paths["servers_file"]) as f:
|
||||
loaded = json.load(f)
|
||||
assert loaded == config
|
||||
|
||||
|
||||
class TestAddServerToConfig:
|
||||
"""Tests for add_server_to_config function."""
|
||||
|
||||
def test_add_to_empty_config(self, patch_config_paths):
|
||||
"""Add server to empty config."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
|
||||
config = load_servers_config()
|
||||
assert config["example.com"]["1"]["ip"] == "10.0.0.1"
|
||||
assert config["example.com"]["1"]["http_port"] == 80
|
||||
|
||||
def test_add_to_existing_domain(self, patch_config_paths):
|
||||
"""Add server to domain with existing servers."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
add_server_to_config("example.com", 2, "10.0.0.2", 80)
|
||||
|
||||
config = load_servers_config()
|
||||
assert "1" in config["example.com"]
|
||||
assert "2" in config["example.com"]
|
||||
|
||||
def test_overwrite_existing_slot(self, patch_config_paths):
|
||||
"""Overwrite existing slot."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
add_server_to_config("example.com", 1, "10.0.0.99", 8080)
|
||||
|
||||
config = load_servers_config()
|
||||
assert config["example.com"]["1"]["ip"] == "10.0.0.99"
|
||||
assert config["example.com"]["1"]["http_port"] == 8080
|
||||
|
||||
|
||||
class TestRemoveServerFromConfig:
|
||||
"""Tests for remove_server_from_config function."""
|
||||
|
||||
def test_remove_existing_server(self, patch_config_paths):
|
||||
"""Remove existing server."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
add_server_to_config("example.com", 2, "10.0.0.2", 80)
|
||||
|
||||
remove_server_from_config("example.com", 1)
|
||||
|
||||
config = load_servers_config()
|
||||
assert "1" not in config["example.com"]
|
||||
assert "2" in config["example.com"]
|
||||
|
||||
def test_remove_last_server_removes_domain(self, patch_config_paths):
|
||||
"""Removing last server removes domain entry."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
|
||||
remove_server_from_config("example.com", 1)
|
||||
|
||||
config = load_servers_config()
|
||||
assert "example.com" not in config
|
||||
|
||||
def test_remove_nonexistent_server(self, patch_config_paths):
|
||||
"""Removing non-existent server is a no-op."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
|
||||
remove_server_from_config("example.com", 99) # Non-existent slot
|
||||
|
||||
config = load_servers_config()
|
||||
assert "1" in config["example.com"]
|
||||
|
||||
|
||||
class TestRemoveDomainFromConfig:
|
||||
"""Tests for remove_domain_from_config function."""
|
||||
|
||||
def test_remove_existing_domain(self, patch_config_paths):
|
||||
"""Remove existing domain."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
add_server_to_config("other.com", 1, "10.0.0.2", 80)
|
||||
|
||||
remove_domain_from_config("example.com")
|
||||
|
||||
config = load_servers_config()
|
||||
assert "example.com" not in config
|
||||
assert "other.com" in config
|
||||
|
||||
def test_remove_nonexistent_domain(self, patch_config_paths):
|
||||
"""Removing non-existent domain is a no-op."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
|
||||
remove_domain_from_config("other.com") # Non-existent
|
||||
|
||||
config = load_servers_config()
|
||||
assert "example.com" in config
|
||||
|
||||
|
||||
class TestLoadCertsConfig:
|
||||
"""Tests for load_certs_config function."""
|
||||
|
||||
def test_load_existing_config(self, patch_config_paths):
|
||||
"""Load existing certs config."""
|
||||
with open(patch_config_paths["certs_file"], "w") as f:
|
||||
json.dump({"domains": ["example.com", "other.com"]}, f)
|
||||
|
||||
domains = load_certs_config()
|
||||
|
||||
assert "example.com" in domains
|
||||
assert "other.com" in domains
|
||||
|
||||
def test_file_not_found(self, patch_config_paths):
|
||||
"""Missing file returns empty list."""
|
||||
os.unlink(patch_config_paths["certs_file"])
|
||||
|
||||
domains = load_certs_config()
|
||||
|
||||
assert domains == []
|
||||
|
||||
|
||||
class TestSaveCertsConfig:
|
||||
"""Tests for save_certs_config function."""
|
||||
|
||||
def test_save_domains(self, patch_config_paths):
|
||||
"""Save domains to certs config."""
|
||||
save_certs_config(["z.com", "a.com"])
|
||||
|
||||
with open(patch_config_paths["certs_file"]) as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Should be sorted
|
||||
assert data["domains"] == ["a.com", "z.com"]
|
||||
|
||||
|
||||
class TestAddCertToConfig:
|
||||
"""Tests for add_cert_to_config function."""
|
||||
|
||||
def test_add_new_cert(self, patch_config_paths):
|
||||
"""Add new certificate domain."""
|
||||
add_cert_to_config("example.com")
|
||||
|
||||
domains = load_certs_config()
|
||||
assert "example.com" in domains
|
||||
|
||||
def test_add_duplicate_cert(self, patch_config_paths):
|
||||
"""Adding duplicate cert is a no-op."""
|
||||
add_cert_to_config("example.com")
|
||||
add_cert_to_config("example.com")
|
||||
|
||||
domains = load_certs_config()
|
||||
assert domains.count("example.com") == 1
|
||||
|
||||
|
||||
class TestRemoveCertFromConfig:
|
||||
"""Tests for remove_cert_from_config function."""
|
||||
|
||||
def test_remove_existing_cert(self, patch_config_paths):
|
||||
"""Remove existing certificate domain."""
|
||||
add_cert_to_config("example.com")
|
||||
add_cert_to_config("other.com")
|
||||
|
||||
remove_cert_from_config("example.com")
|
||||
|
||||
domains = load_certs_config()
|
||||
assert "example.com" not in domains
|
||||
assert "other.com" in domains
|
||||
|
||||
def test_remove_nonexistent_cert(self, patch_config_paths):
|
||||
"""Removing non-existent cert is a no-op."""
|
||||
add_cert_to_config("example.com")
|
||||
|
||||
remove_cert_from_config("other.com") # Non-existent
|
||||
|
||||
domains = load_certs_config()
|
||||
assert "example.com" in domains
|
||||
279
tests/unit/test_haproxy_client.py
Normal file
279
tests/unit/test_haproxy_client.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""Unit tests for haproxy_client module."""
|
||||
|
||||
import socket
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.haproxy_client import (
|
||||
haproxy_cmd,
|
||||
haproxy_cmd_checked,
|
||||
haproxy_cmd_batch,
|
||||
reload_haproxy,
|
||||
)
|
||||
from haproxy_mcp.exceptions import HaproxyError
|
||||
|
||||
|
||||
class TestHaproxyCmd:
|
||||
"""Tests for haproxy_cmd function."""
|
||||
|
||||
def test_successful_command(self, mock_socket_class, mock_select):
|
||||
"""Successful command execution returns response."""
|
||||
mock_sock = mock_socket_class(
|
||||
responses={"show info": "Version: 3.3.2\nUptime_sec: 3600"}
|
||||
)
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
result = haproxy_cmd("show info")
|
||||
|
||||
assert "Version: 3.3.2" in result
|
||||
assert "show info" in mock_sock.sent_commands
|
||||
|
||||
def test_empty_response(self, mock_socket_class, mock_select):
|
||||
"""Command with empty response returns empty string."""
|
||||
mock_sock = mock_socket_class(default_response="")
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
result = haproxy_cmd("set server pool_1/pool_1_1 state ready")
|
||||
|
||||
assert result == ""
|
||||
|
||||
def test_connection_refused_error(self, mock_select):
|
||||
"""Connection refused raises HaproxyError."""
|
||||
with patch("socket.socket") as mock_socket:
|
||||
mock_socket.return_value.__enter__ = MagicMock(side_effect=ConnectionRefusedError())
|
||||
mock_socket.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with pytest.raises(HaproxyError) as exc_info:
|
||||
haproxy_cmd("show info")
|
||||
|
||||
assert "Connection refused" in str(exc_info.value)
|
||||
|
||||
def test_socket_timeout_error(self, mock_select):
|
||||
"""Socket timeout raises HaproxyError."""
|
||||
with patch("socket.socket") as mock_socket:
|
||||
mock_socket.return_value.__enter__ = MagicMock(side_effect=socket.timeout())
|
||||
mock_socket.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with pytest.raises(HaproxyError) as exc_info:
|
||||
haproxy_cmd("show info")
|
||||
|
||||
assert "timeout" in str(exc_info.value).lower()
|
||||
|
||||
def test_unicode_decode_error(self, mock_socket_class, mock_select):
|
||||
"""Invalid UTF-8 response raises HaproxyError."""
|
||||
# Create a mock that returns invalid UTF-8 bytes
|
||||
class BadUtf8Socket(mock_socket_class):
|
||||
def sendall(self, data):
|
||||
self.sent_commands.append(data.decode().strip())
|
||||
self._response_buffer = b"\xff\xfe" # Invalid UTF-8
|
||||
|
||||
mock_sock = BadUtf8Socket()
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with pytest.raises(HaproxyError) as exc_info:
|
||||
haproxy_cmd("show info")
|
||||
|
||||
assert "UTF-8" in str(exc_info.value)
|
||||
|
||||
def test_multiline_response(self, mock_socket_class, mock_select):
|
||||
"""Multi-line response is properly returned."""
|
||||
multi_line = "pool_1\npool_2\npool_3"
|
||||
mock_sock = mock_socket_class(responses={"show backend": multi_line})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
result = haproxy_cmd("show backend")
|
||||
|
||||
assert "pool_1" in result
|
||||
assert "pool_2" in result
|
||||
assert "pool_3" in result
|
||||
|
||||
|
||||
class TestHaproxyCmdChecked:
|
||||
"""Tests for haproxy_cmd_checked function."""
|
||||
|
||||
def test_successful_command(self, mock_socket_class, mock_select):
|
||||
"""Successful command returns response."""
|
||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
result = haproxy_cmd_checked("set server pool_1/pool_1_1 state ready")
|
||||
|
||||
assert result == ""
|
||||
|
||||
def test_error_response_no_such(self, mock_socket_class, mock_select):
|
||||
"""Response containing 'No such' raises HaproxyError."""
|
||||
mock_sock = mock_socket_class(
|
||||
responses={"set server": "No such server."}
|
||||
)
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with pytest.raises(HaproxyError) as exc_info:
|
||||
haproxy_cmd_checked("set server pool_99/pool_99_1 state ready")
|
||||
|
||||
assert "No such" in str(exc_info.value)
|
||||
|
||||
def test_error_response_not_found(self, mock_socket_class, mock_select):
|
||||
"""Response containing 'not found' raises HaproxyError."""
|
||||
mock_sock = mock_socket_class(
|
||||
responses={"del map": "Backend not found."}
|
||||
)
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with pytest.raises(HaproxyError) as exc_info:
|
||||
haproxy_cmd_checked("del map /path example.com")
|
||||
|
||||
assert "not found" in str(exc_info.value)
|
||||
|
||||
def test_error_response_error(self, mock_socket_class, mock_select):
|
||||
"""Response containing 'error' raises HaproxyError."""
|
||||
mock_sock = mock_socket_class(
|
||||
responses={"set server": "error: invalid state"}
|
||||
)
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with pytest.raises(HaproxyError) as exc_info:
|
||||
haproxy_cmd_checked("set server pool_1/pool_1_1 state invalid")
|
||||
|
||||
assert "error" in str(exc_info.value).lower()
|
||||
|
||||
def test_error_response_failed(self, mock_socket_class, mock_select):
|
||||
"""Response containing 'failed' raises HaproxyError."""
|
||||
mock_sock = mock_socket_class(
|
||||
responses={"set server": "Command failed"}
|
||||
)
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with pytest.raises(HaproxyError) as exc_info:
|
||||
haproxy_cmd_checked("set server pool_1/pool_1_1 addr bad")
|
||||
|
||||
assert "failed" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
class TestHaproxyCmdBatch:
|
||||
"""Tests for haproxy_cmd_batch function."""
|
||||
|
||||
def test_empty_commands(self):
|
||||
"""Empty command list returns empty list."""
|
||||
result = haproxy_cmd_batch([])
|
||||
assert result == []
|
||||
|
||||
def test_single_command(self, mock_socket_class, mock_select):
|
||||
"""Single command uses haproxy_cmd_checked."""
|
||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
result = haproxy_cmd_batch(["set server pool_1/pool_1_1 state ready"])
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
def test_multiple_commands(self, mock_socket_class, mock_select):
|
||||
"""Multiple commands are executed separately."""
|
||||
# Each command gets its own socket connection
|
||||
call_count = 0
|
||||
def create_mock_socket(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_socket_class(responses={"set server": ""})
|
||||
|
||||
with patch("socket.socket", side_effect=create_mock_socket):
|
||||
result = haproxy_cmd_batch([
|
||||
"set server pool_1/pool_1_1 addr 10.0.0.1 port 80",
|
||||
"set server pool_1/pool_1_1 state ready",
|
||||
])
|
||||
|
||||
assert len(result) == 2
|
||||
assert call_count == 2 # One connection per command
|
||||
|
||||
def test_error_in_batch_raises(self, mock_socket_class, mock_select):
|
||||
"""Error in batch command raises immediately."""
|
||||
mock_sock = mock_socket_class(
|
||||
responses={
|
||||
"set server pool_1/pool_1_1 addr": "",
|
||||
"set server pool_1/pool_1_1 state": "No such server",
|
||||
}
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
def create_socket(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return mock_socket_class(responses={"set server": ""})
|
||||
else:
|
||||
return mock_socket_class(responses={"set server": "No such server"})
|
||||
|
||||
with patch("socket.socket", side_effect=create_socket):
|
||||
with pytest.raises(HaproxyError):
|
||||
haproxy_cmd_batch([
|
||||
"set server pool_1/pool_1_1 addr 10.0.0.1 port 80",
|
||||
"set server pool_1/pool_1_1 state ready",
|
||||
])
|
||||
|
||||
|
||||
class TestReloadHaproxy:
|
||||
"""Tests for reload_haproxy function."""
|
||||
|
||||
def test_successful_reload(self, mock_subprocess):
|
||||
"""Successful reload returns (True, 'OK')."""
|
||||
mock_subprocess.return_value = MagicMock(returncode=0, stdout="", stderr="")
|
||||
|
||||
success, message = reload_haproxy()
|
||||
|
||||
assert success is True
|
||||
assert message == "OK"
|
||||
|
||||
def test_validation_failure(self, mock_subprocess):
|
||||
"""Config validation failure returns (False, error)."""
|
||||
mock_subprocess.return_value = MagicMock(
|
||||
returncode=1,
|
||||
stdout="",
|
||||
stderr="[ALERT] Invalid configuration"
|
||||
)
|
||||
|
||||
success, message = reload_haproxy()
|
||||
|
||||
assert success is False
|
||||
assert "validation failed" in message.lower()
|
||||
assert "Invalid configuration" in message
|
||||
|
||||
def test_reload_failure(self, mock_subprocess):
|
||||
"""Reload command failure returns (False, error)."""
|
||||
# First call (validation) succeeds, second call (reload) fails
|
||||
mock_subprocess.side_effect = [
|
||||
MagicMock(returncode=0, stdout="", stderr=""),
|
||||
MagicMock(returncode=1, stdout="", stderr="Container not found"),
|
||||
]
|
||||
|
||||
success, message = reload_haproxy()
|
||||
|
||||
assert success is False
|
||||
assert "Reload failed" in message
|
||||
|
||||
def test_podman_not_found(self, mock_subprocess):
|
||||
"""Podman not found returns (False, error)."""
|
||||
mock_subprocess.side_effect = FileNotFoundError()
|
||||
|
||||
success, message = reload_haproxy()
|
||||
|
||||
assert success is False
|
||||
assert "podman" in message.lower()
|
||||
|
||||
def test_subprocess_timeout(self, mock_subprocess):
|
||||
"""Subprocess timeout returns (False, error)."""
|
||||
import subprocess
|
||||
mock_subprocess.side_effect = subprocess.TimeoutExpired("podman", 30)
|
||||
|
||||
success, message = reload_haproxy()
|
||||
|
||||
assert success is False
|
||||
assert "timed out" in message.lower()
|
||||
|
||||
def test_os_error(self, mock_subprocess):
|
||||
"""OS error returns (False, error)."""
|
||||
mock_subprocess.side_effect = OSError("Permission denied")
|
||||
|
||||
success, message = reload_haproxy()
|
||||
|
||||
assert success is False
|
||||
assert "OS error" in message
|
||||
130
tests/unit/test_utils.py
Normal file
130
tests/unit/test_utils.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Unit tests for utils module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.utils import parse_stat_csv
|
||||
|
||||
|
||||
class TestParseStatCsv:
|
||||
"""Tests for parse_stat_csv function."""
|
||||
|
||||
def test_parse_valid_csv(self, response_builder):
|
||||
"""Parse valid HAProxy stat CSV output."""
|
||||
csv = response_builder.stat_csv([
|
||||
{"pxname": "pool_1", "svname": "pool_1_1", "scur": 5, "status": "UP", "weight": 1, "check_status": "L4OK"},
|
||||
{"pxname": "pool_1", "svname": "pool_1_2", "scur": 3, "status": "UP", "weight": 1, "check_status": "L4OK"},
|
||||
])
|
||||
|
||||
results = list(parse_stat_csv(csv))
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["pxname"] == "pool_1"
|
||||
assert results[0]["svname"] == "pool_1_1"
|
||||
assert results[0]["scur"] == "5"
|
||||
assert results[0]["status"] == "UP"
|
||||
assert results[0]["weight"] == "1"
|
||||
assert results[0]["check_status"] == "L4OK"
|
||||
|
||||
def test_parse_empty_output(self):
|
||||
"""Parse empty output returns no results."""
|
||||
results = list(parse_stat_csv(""))
|
||||
assert results == []
|
||||
|
||||
def test_parse_header_only(self):
|
||||
"""Parse output with only header returns no results."""
|
||||
csv = "# pxname,svname,qcur,qmax,scur,smax,..."
|
||||
results = list(parse_stat_csv(csv))
|
||||
assert results == []
|
||||
|
||||
def test_skip_comment_lines(self):
|
||||
"""Comment lines are skipped."""
|
||||
csv = """# This is a comment
|
||||
# Another comment
|
||||
pool_1,pool_1_1,0,0,5,10,0,0,0,0,0,0,0,0,0,0,0,UP,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,L4OK,"""
|
||||
|
||||
results = list(parse_stat_csv(csv))
|
||||
assert len(results) == 1
|
||||
assert results[0]["pxname"] == "pool_1"
|
||||
|
||||
def test_skip_empty_lines(self):
|
||||
"""Empty lines are skipped."""
|
||||
csv = """
|
||||
pool_1,pool_1_1,0,0,5,10,0,0,0,0,0,0,0,0,0,0,0,UP,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,L4OK,
|
||||
|
||||
pool_1,pool_1_2,0,0,3,10,0,0,0,0,0,0,0,0,0,0,0,UP,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,L4OK,
|
||||
"""
|
||||
|
||||
results = list(parse_stat_csv(csv))
|
||||
assert len(results) == 2
|
||||
|
||||
def test_parse_down_status(self, response_builder):
|
||||
"""Parse server with DOWN status."""
|
||||
csv = response_builder.stat_csv([
|
||||
{"pxname": "pool_1", "svname": "pool_1_1", "status": "DOWN", "check_status": "L4TOUT"},
|
||||
])
|
||||
|
||||
results = list(parse_stat_csv(csv))
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["status"] == "DOWN"
|
||||
assert results[0]["check_status"] == "L4TOUT"
|
||||
|
||||
def test_parse_maint_status(self, response_builder):
|
||||
"""Parse server with MAINT status."""
|
||||
csv = response_builder.stat_csv([
|
||||
{"pxname": "pool_1", "svname": "pool_1_1", "status": "MAINT"},
|
||||
])
|
||||
|
||||
results = list(parse_stat_csv(csv))
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["status"] == "MAINT"
|
||||
|
||||
def test_parse_multiple_backends(self, response_builder):
|
||||
"""Parse output with multiple backends."""
|
||||
csv = response_builder.stat_csv([
|
||||
{"pxname": "pool_1", "svname": "pool_1_1", "status": "UP"},
|
||||
{"pxname": "pool_2", "svname": "pool_2_1", "status": "UP"},
|
||||
{"pxname": "pool_3", "svname": "pool_3_1", "status": "DOWN"},
|
||||
])
|
||||
|
||||
results = list(parse_stat_csv(csv))
|
||||
|
||||
assert len(results) == 3
|
||||
assert results[0]["pxname"] == "pool_1"
|
||||
assert results[1]["pxname"] == "pool_2"
|
||||
assert results[2]["pxname"] == "pool_3"
|
||||
|
||||
def test_parse_frontend_backend_rows(self):
|
||||
"""Frontend and BACKEND rows are included."""
|
||||
csv = """pool_1,FRONTEND,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,UP,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,,
|
||||
pool_1,pool_1_1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,UP,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,L4OK,
|
||||
pool_1,BACKEND,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,UP,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,,"""
|
||||
|
||||
results = list(parse_stat_csv(csv))
|
||||
|
||||
# All rows with enough columns are returned
|
||||
assert len(results) == 3
|
||||
svnames = [r["svname"] for r in results]
|
||||
assert "FRONTEND" in svnames
|
||||
assert "pool_1_1" in svnames
|
||||
assert "BACKEND" in svnames
|
||||
|
||||
def test_parse_insufficient_columns(self):
|
||||
"""Rows with insufficient columns are skipped."""
|
||||
csv = "pool_1,pool_1_1,0,0,5" # Only 5 columns, need more than 17
|
||||
|
||||
results = list(parse_stat_csv(csv))
|
||||
assert results == []
|
||||
|
||||
def test_generator_is_lazy(self, response_builder):
|
||||
"""Verify parse_stat_csv returns a generator (lazy evaluation)."""
|
||||
csv = response_builder.stat_csv([
|
||||
{"pxname": "pool_1", "svname": "pool_1_1", "status": "UP"},
|
||||
])
|
||||
|
||||
result = parse_stat_csv(csv)
|
||||
|
||||
# Should return a generator, not a list
|
||||
import types
|
||||
assert isinstance(result, types.GeneratorType)
|
||||
275
tests/unit/test_validation.py
Normal file
275
tests/unit/test_validation.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""Unit tests for validation module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.validation import (
|
||||
validate_domain,
|
||||
validate_ip,
|
||||
validate_port,
|
||||
validate_backend_name,
|
||||
domain_to_backend,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateDomain:
|
||||
"""Tests for validate_domain function."""
|
||||
|
||||
def test_valid_simple_domain(self):
|
||||
"""Valid simple domain."""
|
||||
assert validate_domain("example.com") is True
|
||||
|
||||
def test_valid_subdomain(self):
|
||||
"""Valid subdomain."""
|
||||
assert validate_domain("api.example.com") is True
|
||||
|
||||
def test_valid_deep_subdomain(self):
|
||||
"""Valid deep subdomain."""
|
||||
assert validate_domain("a.b.c.example.com") is True
|
||||
|
||||
def test_valid_domain_with_numbers(self):
|
||||
"""Valid domain with numbers."""
|
||||
assert validate_domain("api123.example.com") is True
|
||||
|
||||
def test_valid_domain_with_hyphen(self):
|
||||
"""Valid domain with hyphens."""
|
||||
assert validate_domain("my-api.example-site.com") is True
|
||||
|
||||
def test_valid_single_char_labels(self):
|
||||
"""Valid domain with single character labels."""
|
||||
assert validate_domain("a.b.c") is True
|
||||
|
||||
def test_valid_max_label_length(self):
|
||||
"""Valid domain with max label length (63 chars)."""
|
||||
label = "a" * 63
|
||||
assert validate_domain(f"{label}.com") is True
|
||||
|
||||
def test_invalid_empty_domain(self):
|
||||
"""Empty domain is invalid."""
|
||||
assert validate_domain("") is False
|
||||
|
||||
def test_invalid_none_domain(self):
|
||||
"""None domain is invalid."""
|
||||
assert validate_domain(None) is False
|
||||
|
||||
def test_invalid_starts_with_hyphen(self):
|
||||
"""Domain starting with hyphen is invalid."""
|
||||
assert validate_domain("-example.com") is False
|
||||
|
||||
def test_invalid_ends_with_hyphen(self):
|
||||
"""Domain label ending with hyphen is invalid."""
|
||||
assert validate_domain("example-.com") is False
|
||||
|
||||
def test_invalid_double_dot(self):
|
||||
"""Domain with double dot is invalid."""
|
||||
assert validate_domain("example..com") is False
|
||||
|
||||
def test_invalid_starts_with_dot(self):
|
||||
"""Domain starting with dot is invalid."""
|
||||
assert validate_domain(".example.com") is False
|
||||
|
||||
def test_invalid_special_characters(self):
|
||||
"""Domain with special characters is invalid."""
|
||||
assert validate_domain("example@.com") is False
|
||||
assert validate_domain("example!.com") is False
|
||||
assert validate_domain("example$.com") is False
|
||||
|
||||
def test_invalid_underscore(self):
|
||||
"""Domain with underscore is invalid."""
|
||||
assert validate_domain("my_api.example.com") is False
|
||||
|
||||
def test_invalid_too_long(self):
|
||||
"""Domain exceeding 253 chars is invalid."""
|
||||
long_domain = "a" * 254
|
||||
assert validate_domain(long_domain) is False
|
||||
|
||||
def test_invalid_label_too_long(self):
|
||||
"""Domain label exceeding 63 chars is invalid."""
|
||||
label = "a" * 64
|
||||
assert validate_domain(f"{label}.com") is False
|
||||
|
||||
def test_valid_numeric_domain(self):
|
||||
"""Domain with all numeric label is valid."""
|
||||
assert validate_domain("123.example.com") is True
|
||||
|
||||
def test_invalid_only_dots(self):
|
||||
"""Domain with only dots is invalid."""
|
||||
assert validate_domain("...") is False
|
||||
|
||||
|
||||
class TestValidateIP:
|
||||
"""Tests for validate_ip function."""
|
||||
|
||||
def test_valid_ipv4(self):
|
||||
"""Valid IPv4 address."""
|
||||
assert validate_ip("192.168.1.1") is True
|
||||
assert validate_ip("10.0.0.1") is True
|
||||
assert validate_ip("255.255.255.255") is True
|
||||
assert validate_ip("0.0.0.0") is True
|
||||
|
||||
def test_valid_ipv6(self):
|
||||
"""Valid IPv6 address."""
|
||||
assert validate_ip("::1") is True
|
||||
assert validate_ip("2001:db8::1") is True
|
||||
assert validate_ip("fe80::1") is True
|
||||
assert validate_ip("2001:0db8:0000:0000:0000:0000:0000:0001") is True
|
||||
|
||||
def test_invalid_empty_string(self):
|
||||
"""Empty string is invalid by default."""
|
||||
assert validate_ip("") is False
|
||||
|
||||
def test_valid_empty_string_when_allowed(self):
|
||||
"""Empty string is valid when allow_empty=True."""
|
||||
assert validate_ip("", allow_empty=True) is True
|
||||
|
||||
def test_invalid_none(self):
|
||||
"""None is invalid."""
|
||||
assert validate_ip(None) is False
|
||||
|
||||
def test_invalid_hostname(self):
|
||||
"""Hostname is not a valid IP."""
|
||||
assert validate_ip("example.com") is False
|
||||
|
||||
def test_invalid_ipv4_out_of_range(self):
|
||||
"""IPv4 with octets out of range is invalid."""
|
||||
assert validate_ip("256.1.1.1") is False
|
||||
assert validate_ip("1.1.1.300") is False
|
||||
|
||||
def test_invalid_ipv4_format(self):
|
||||
"""Invalid IPv4 format."""
|
||||
assert validate_ip("192.168.1") is False
|
||||
assert validate_ip("192.168.1.1.1") is False
|
||||
|
||||
def test_invalid_ipv6_format(self):
|
||||
"""Invalid IPv6 format."""
|
||||
assert validate_ip("2001:db8:::1") is False
|
||||
assert validate_ip("gggg::1") is False
|
||||
|
||||
def test_invalid_mixed_format(self):
|
||||
"""Mixed invalid format."""
|
||||
assert validate_ip("192.168.1.1:8080") is False
|
||||
|
||||
|
||||
class TestValidatePort:
|
||||
"""Tests for validate_port function."""
|
||||
|
||||
def test_valid_port_min(self):
|
||||
"""Valid minimum port."""
|
||||
assert validate_port("1") is True
|
||||
|
||||
def test_valid_port_max(self):
|
||||
"""Valid maximum port."""
|
||||
assert validate_port("65535") is True
|
||||
|
||||
def test_valid_port_common(self):
|
||||
"""Valid common ports."""
|
||||
assert validate_port("80") is True
|
||||
assert validate_port("443") is True
|
||||
assert validate_port("8080") is True
|
||||
|
||||
def test_invalid_port_zero(self):
|
||||
"""Port 0 is invalid."""
|
||||
assert validate_port("0") is False
|
||||
|
||||
def test_invalid_port_negative(self):
|
||||
"""Negative port is invalid."""
|
||||
assert validate_port("-1") is False
|
||||
|
||||
def test_invalid_port_too_high(self):
|
||||
"""Port above 65535 is invalid."""
|
||||
assert validate_port("65536") is False
|
||||
|
||||
def test_invalid_port_empty(self):
|
||||
"""Empty port is invalid."""
|
||||
assert validate_port("") is False
|
||||
|
||||
def test_invalid_port_none(self):
|
||||
"""None port is invalid."""
|
||||
assert validate_port(None) is False
|
||||
|
||||
def test_invalid_port_not_numeric(self):
|
||||
"""Non-numeric port is invalid."""
|
||||
assert validate_port("abc") is False
|
||||
assert validate_port("80a") is False
|
||||
|
||||
def test_invalid_port_float(self):
|
||||
"""Float port is invalid."""
|
||||
assert validate_port("80.5") is False
|
||||
|
||||
|
||||
class TestValidateBackendName:
|
||||
"""Tests for validate_backend_name function."""
|
||||
|
||||
def test_valid_pool_name(self):
|
||||
"""Valid pool backend names."""
|
||||
assert validate_backend_name("pool_1") is True
|
||||
assert validate_backend_name("pool_100") is True
|
||||
|
||||
def test_valid_alphanumeric(self):
|
||||
"""Valid alphanumeric names."""
|
||||
assert validate_backend_name("backend1") is True
|
||||
assert validate_backend_name("my_backend") is True
|
||||
assert validate_backend_name("my-backend") is True
|
||||
|
||||
def test_valid_mixed(self):
|
||||
"""Valid mixed character names."""
|
||||
assert validate_backend_name("api_example_com_backend") is True
|
||||
assert validate_backend_name("my-api-backend-1") is True
|
||||
|
||||
def test_invalid_empty(self):
|
||||
"""Empty name is invalid."""
|
||||
assert validate_backend_name("") is False
|
||||
|
||||
def test_invalid_none(self):
|
||||
"""None name is invalid."""
|
||||
assert validate_backend_name(None) is False
|
||||
|
||||
def test_invalid_special_chars(self):
|
||||
"""Names with special characters are invalid."""
|
||||
assert validate_backend_name("backend@1") is False
|
||||
assert validate_backend_name("my.backend") is False
|
||||
assert validate_backend_name("my/backend") is False
|
||||
assert validate_backend_name("my backend") is False
|
||||
|
||||
def test_invalid_too_long(self):
|
||||
"""Name exceeding 255 chars is invalid."""
|
||||
long_name = "a" * 256
|
||||
assert validate_backend_name(long_name) is False
|
||||
|
||||
def test_valid_max_length(self):
|
||||
"""Name at exactly 255 chars is valid."""
|
||||
max_name = "a" * 255
|
||||
assert validate_backend_name(max_name) is True
|
||||
|
||||
|
||||
class TestDomainToBackend:
|
||||
"""Tests for domain_to_backend function."""
|
||||
|
||||
def test_simple_domain(self):
|
||||
"""Simple domain conversion."""
|
||||
assert domain_to_backend("example.com") == "example_com"
|
||||
|
||||
def test_subdomain(self):
|
||||
"""Subdomain conversion."""
|
||||
assert domain_to_backend("api.example.com") == "api_example_com"
|
||||
|
||||
def test_domain_with_hyphens(self):
|
||||
"""Domain with hyphens."""
|
||||
result = domain_to_backend("my-api.example.com")
|
||||
assert result == "my_api_example_com"
|
||||
|
||||
def test_complex_domain(self):
|
||||
"""Complex domain conversion."""
|
||||
result = domain_to_backend("a.b.c.example-site.com")
|
||||
assert result == "a_b_c_example_site_com"
|
||||
|
||||
def test_already_simple(self):
|
||||
"""Domain that's already mostly valid."""
|
||||
result = domain_to_backend("example123")
|
||||
assert result == "example123"
|
||||
|
||||
def test_invalid_result_raises(self):
|
||||
"""Invalid conversion result raises ValueError."""
|
||||
# This should never happen with real domains, but test the safeguard
|
||||
with pytest.raises(ValueError):
|
||||
# Mock a case where conversion would fail
|
||||
domain_to_backend("")
|
||||
1
tests/unit/tools/__init__.py
Normal file
1
tests/unit/tools/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Unit tests for HAProxy MCP tools."""
|
||||
1198
tests/unit/tools/test_certificates.py
Normal file
1198
tests/unit/tools/test_certificates.py
Normal file
File diff suppressed because it is too large
Load Diff
749
tests/unit/tools/test_configuration.py
Normal file
749
tests/unit/tools/test_configuration.py
Normal file
@@ -0,0 +1,749 @@
|
||||
"""Unit tests for configuration management tools."""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestRestoreServersFromConfig:
|
||||
"""Tests for restore_servers_from_config function."""
|
||||
|
||||
def test_restore_empty_config(self, patch_config_paths):
|
||||
"""No servers to restore when config is empty."""
|
||||
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
||||
|
||||
result = restore_servers_from_config()
|
||||
|
||||
assert result == 0
|
||||
|
||||
def test_restore_servers_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config):
|
||||
"""Restore servers successfully."""
|
||||
# Write config and map
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(sample_servers_config, f)
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("api.example.com pool_2\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"set server": "",
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
||||
|
||||
result = restore_servers_from_config()
|
||||
|
||||
# example.com has 2 servers, api.example.com has 1
|
||||
assert result == 3
|
||||
|
||||
def test_restore_servers_skip_missing_domain(self, mock_socket_class, mock_select, patch_config_paths):
|
||||
"""Skip domains not in map file."""
|
||||
config = {"unknown.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
||||
|
||||
result = restore_servers_from_config()
|
||||
|
||||
assert result == 0
|
||||
|
||||
def test_restore_servers_skip_empty_ip(self, mock_socket_class, mock_select, patch_config_paths):
|
||||
"""Skip servers with empty IP."""
|
||||
config = {"example.com": {"1": {"ip": "", "http_port": 80}}}
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(config, f)
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
||||
|
||||
result = restore_servers_from_config()
|
||||
|
||||
assert result == 0
|
||||
|
||||
|
||||
class TestStartupRestore:
|
||||
"""Tests for startup_restore function."""
|
||||
|
||||
def test_startup_restore_haproxy_not_ready(self, mock_select):
|
||||
"""Skip restore if HAProxy is not ready."""
|
||||
call_count = 0
|
||||
|
||||
def raise_error(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise ConnectionRefusedError()
|
||||
|
||||
with patch("socket.socket", side_effect=raise_error):
|
||||
with patch("haproxy_mcp.tools.configuration.STARTUP_RETRY_COUNT", 2):
|
||||
from haproxy_mcp.tools.configuration import startup_restore
|
||||
|
||||
startup_restore()
|
||||
|
||||
# Should have tried multiple times
|
||||
assert call_count >= 2
|
||||
|
||||
def test_startup_restore_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Successfully restore servers and certificates on startup."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show info": response_builder.info(),
|
||||
"set server": "",
|
||||
"show ssl cert": "",
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", return_value=0):
|
||||
with patch("haproxy_mcp.tools.certificates.restore_certificates", return_value=0):
|
||||
from haproxy_mcp.tools.configuration import startup_restore
|
||||
|
||||
startup_restore()
|
||||
|
||||
# No assertions needed - just verify no exceptions
|
||||
|
||||
|
||||
class TestHaproxyReload:
|
||||
"""Tests for haproxy_reload tool function."""
|
||||
|
||||
def test_reload_success(self, mock_socket_class, mock_select, mock_subprocess, response_builder):
|
||||
"""Reload HAProxy successfully."""
|
||||
mock_subprocess.return_value = MagicMock(returncode=0, stdout="", stderr="")
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show info": response_builder.info(),
|
||||
"set server": "",
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", return_value=5):
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_reload"]()
|
||||
|
||||
assert "reloaded successfully" in result
|
||||
assert "5 servers restored" in result
|
||||
|
||||
def test_reload_validation_failure(self, mock_subprocess):
|
||||
"""Reload fails on config validation error."""
|
||||
mock_subprocess.return_value = MagicMock(
|
||||
returncode=1,
|
||||
stdout="",
|
||||
stderr="Configuration error"
|
||||
)
|
||||
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_reload"]()
|
||||
|
||||
assert "validation failed" in result.lower() or "Configuration error" in result
|
||||
|
||||
|
||||
class TestHaproxyCheckConfig:
|
||||
"""Tests for haproxy_check_config tool function."""
|
||||
|
||||
def test_check_config_valid(self, mock_subprocess):
|
||||
"""Configuration is valid."""
|
||||
mock_subprocess.return_value = MagicMock(returncode=0, stdout="", stderr="")
|
||||
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_check_config"]()
|
||||
|
||||
assert "valid" in result.lower()
|
||||
|
||||
def test_check_config_invalid(self, mock_subprocess):
|
||||
"""Configuration has errors."""
|
||||
mock_subprocess.return_value = MagicMock(
|
||||
returncode=1,
|
||||
stdout="",
|
||||
stderr="[ALERT] Syntax error"
|
||||
)
|
||||
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_check_config"]()
|
||||
|
||||
assert "error" in result.lower()
|
||||
assert "Syntax error" in result
|
||||
|
||||
def test_check_config_timeout(self, mock_subprocess):
|
||||
"""Configuration check times out."""
|
||||
import subprocess
|
||||
mock_subprocess.side_effect = subprocess.TimeoutExpired("podman", 30)
|
||||
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_check_config"]()
|
||||
|
||||
assert "timed out" in result.lower()
|
||||
|
||||
def test_check_config_podman_not_found(self, mock_subprocess):
|
||||
"""Podman not found."""
|
||||
mock_subprocess.side_effect = FileNotFoundError()
|
||||
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_check_config"]()
|
||||
|
||||
assert "podman" in result.lower()
|
||||
assert "not found" in result.lower()
|
||||
|
||||
|
||||
class TestHaproxySaveState:
|
||||
"""Tests for haproxy_save_state tool function."""
|
||||
|
||||
def test_save_state_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Save state successfully."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
{"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80},
|
||||
]),
|
||||
})
|
||||
|
||||
with patch("haproxy_mcp.tools.configuration.STATE_FILE", patch_config_paths["state_file"]):
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_save_state"]()
|
||||
|
||||
assert "saved" in result.lower()
|
||||
|
||||
def test_save_state_haproxy_error(self, mock_select):
|
||||
"""Handle HAProxy connection error."""
|
||||
def raise_error(*args, **kwargs):
|
||||
raise ConnectionRefusedError()
|
||||
|
||||
with patch("socket.socket", side_effect=raise_error):
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_save_state"]()
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
class TestHaproxyRestoreState:
|
||||
"""Tests for haproxy_restore_state tool function."""
|
||||
|
||||
def test_restore_state_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config):
|
||||
"""Restore state successfully."""
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(sample_servers_config, f)
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("api.example.com pool_2\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_restore_state"]()
|
||||
|
||||
assert "restored" in result.lower()
|
||||
assert "3 servers" in result
|
||||
|
||||
def test_restore_state_no_servers(self, patch_config_paths):
|
||||
"""No servers to restore."""
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_restore_state"]()
|
||||
|
||||
assert "No servers to restore" in result
|
||||
|
||||
|
||||
class TestRestoreServersFromConfigBatchFailure:
|
||||
"""Tests for restore_servers_from_config batch failure and fallback."""
|
||||
|
||||
def test_restore_servers_batch_failure_fallback(self, mock_socket_class, mock_select, patch_config_paths):
|
||||
"""Fall back to individual commands when batch fails."""
|
||||
# Create config with servers
|
||||
config = {
|
||||
"example.com": {
|
||||
"1": {"ip": "10.0.0.1", "http_port": 80},
|
||||
"2": {"ip": "10.0.0.2", "http_port": 80},
|
||||
}
|
||||
}
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(config, f)
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
# Track call count to simulate batch failure then individual success
|
||||
call_count = [0]
|
||||
|
||||
class BatchFailMockSocket:
|
||||
def __init__(self):
|
||||
self.sent_commands = []
|
||||
self._response_buffer = b""
|
||||
self._closed = False
|
||||
|
||||
def connect(self, address):
|
||||
pass
|
||||
|
||||
def settimeout(self, timeout):
|
||||
pass
|
||||
|
||||
def setblocking(self, blocking):
|
||||
pass
|
||||
|
||||
def sendall(self, data):
|
||||
command = data.decode().strip()
|
||||
self.sent_commands.append(command)
|
||||
call_count[0] += 1
|
||||
# First batch call fails (contains multiple commands)
|
||||
if call_count[0] == 1 and "\n" in data.decode():
|
||||
self._response_buffer = b"error: batch command failed"
|
||||
else:
|
||||
self._response_buffer = b""
|
||||
|
||||
def shutdown(self, how):
|
||||
pass
|
||||
|
||||
def recv(self, bufsize):
|
||||
if self._response_buffer:
|
||||
data = self._response_buffer[:bufsize]
|
||||
self._response_buffer = self._response_buffer[bufsize:]
|
||||
return data
|
||||
return b""
|
||||
|
||||
def close(self):
|
||||
self._closed = True
|
||||
|
||||
def fileno(self):
|
||||
return 999
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.close()
|
||||
|
||||
mock_sock = BatchFailMockSocket()
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
||||
from haproxy_mcp.exceptions import HaproxyError
|
||||
|
||||
# Mock batch to raise error
|
||||
with patch("haproxy_mcp.tools.configuration.haproxy_cmd_batch") as mock_batch:
|
||||
# First call (batch) fails, subsequent calls succeed
|
||||
mock_batch.side_effect = [
|
||||
HaproxyError("Batch failed"), # Initial batch fails
|
||||
None, # Individual server 1 succeeds
|
||||
None, # Individual server 2 succeeds
|
||||
]
|
||||
|
||||
result = restore_servers_from_config()
|
||||
|
||||
# Should have restored servers via individual commands
|
||||
assert result == 2
|
||||
|
||||
def test_restore_servers_invalid_slot(self, mock_socket_class, mock_select, patch_config_paths):
|
||||
"""Skip servers with invalid slot number."""
|
||||
config = {
|
||||
"example.com": {
|
||||
"invalid": {"ip": "10.0.0.1", "http_port": 80}, # Invalid slot
|
||||
"1": {"ip": "10.0.0.2", "http_port": 80}, # Valid slot
|
||||
}
|
||||
}
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(config, f)
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
||||
|
||||
result = restore_servers_from_config()
|
||||
|
||||
# Should only restore the valid server
|
||||
assert result == 1
|
||||
|
||||
def test_restore_servers_invalid_port(self, mock_socket_class, mock_select, patch_config_paths, caplog):
|
||||
"""Skip servers with invalid port."""
|
||||
import logging
|
||||
config = {
|
||||
"example.com": {
|
||||
"1": {"ip": "10.0.0.1", "http_port": "invalid"}, # Invalid port
|
||||
"2": {"ip": "10.0.0.2", "http_port": 80}, # Valid port
|
||||
}
|
||||
}
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(config, f)
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with caplog.at_level(logging.WARNING, logger="haproxy_mcp"):
|
||||
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
||||
|
||||
result = restore_servers_from_config()
|
||||
|
||||
# Should only restore the valid server
|
||||
assert result == 1
|
||||
|
||||
|
||||
class TestStartupRestoreFailures:
|
||||
"""Tests for startup_restore failure scenarios."""
|
||||
|
||||
def test_startup_restore_haproxy_timeout(self, mock_select):
|
||||
"""Skip restore if HAProxy doesn't become ready in time."""
|
||||
from haproxy_mcp.exceptions import HaproxyError
|
||||
|
||||
# Mock haproxy_cmd to always fail
|
||||
with patch("haproxy_mcp.tools.configuration.haproxy_cmd", side_effect=HaproxyError("Connection refused")):
|
||||
with patch("haproxy_mcp.tools.configuration.STARTUP_RETRY_COUNT", 2):
|
||||
with patch("time.sleep", return_value=None):
|
||||
from haproxy_mcp.tools.configuration import startup_restore
|
||||
|
||||
# Should not raise, just log warning
|
||||
startup_restore()
|
||||
|
||||
def test_startup_restore_server_restore_failure(self, mock_socket_class, mock_select, patch_config_paths, response_builder, caplog):
|
||||
"""Handle server restore failure during startup."""
|
||||
import logging
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show info": response_builder.info(),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=OSError("Disk error")):
|
||||
with patch("haproxy_mcp.tools.certificates.restore_certificates", return_value=0):
|
||||
with caplog.at_level(logging.WARNING, logger="haproxy_mcp"):
|
||||
from haproxy_mcp.tools.configuration import startup_restore
|
||||
|
||||
startup_restore()
|
||||
|
||||
# Should have logged the failure
|
||||
assert any("Failed to restore servers" in record.message for record in caplog.records)
|
||||
|
||||
def test_startup_restore_certificate_failure(self, mock_socket_class, mock_select, patch_config_paths, response_builder, caplog):
|
||||
"""Handle certificate restore failure during startup."""
|
||||
import logging
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show info": response_builder.info(),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", return_value=0):
|
||||
with patch("haproxy_mcp.tools.certificates.restore_certificates", side_effect=Exception("Certificate error")):
|
||||
with caplog.at_level(logging.WARNING, logger="haproxy_mcp"):
|
||||
from haproxy_mcp.tools.configuration import startup_restore
|
||||
|
||||
startup_restore()
|
||||
|
||||
# Should have logged the failure
|
||||
assert any("Failed to restore certificates" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
class TestHaproxyReloadFailures:
|
||||
"""Tests for haproxy_reload failure scenarios."""
|
||||
|
||||
def test_reload_haproxy_not_responding_after_reload(self, mock_subprocess, response_builder):
|
||||
"""Handle HAProxy not responding after reload."""
|
||||
from haproxy_mcp.exceptions import HaproxyError
|
||||
|
||||
mock_subprocess.return_value = MagicMock(returncode=0, stdout="", stderr="")
|
||||
|
||||
# Mock haproxy_cmd to fail after reload
|
||||
with patch("haproxy_mcp.haproxy_client.reload_haproxy", return_value=(True, "Reloaded")):
|
||||
with patch("haproxy_mcp.tools.configuration.haproxy_cmd", side_effect=HaproxyError("Not responding")):
|
||||
with patch("haproxy_mcp.tools.configuration.STARTUP_RETRY_COUNT", 2):
|
||||
with patch("time.sleep", return_value=None):
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_reload"]()
|
||||
|
||||
assert "not responding" in result.lower()
|
||||
|
||||
def test_reload_server_restore_failure(self, mock_subprocess, mock_socket_class, mock_select, response_builder):
|
||||
"""Handle server restore failure after reload."""
|
||||
mock_subprocess.return_value = MagicMock(returncode=0, stdout="", stderr="")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show info": response_builder.info(),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with patch("haproxy_mcp.haproxy_client.reload_haproxy", return_value=(True, "Reloaded")):
|
||||
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=Exception("Restore failed")):
|
||||
with patch("time.sleep", return_value=None):
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_reload"]()
|
||||
|
||||
assert "reloaded" in result.lower()
|
||||
assert "failed" in result.lower()
|
||||
|
||||
|
||||
class TestHaproxySaveStateFailures:
|
||||
"""Tests for haproxy_save_state failure scenarios."""
|
||||
|
||||
def test_save_state_io_error(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Handle IO error when saving state."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
{"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80},
|
||||
]),
|
||||
})
|
||||
|
||||
with patch("haproxy_mcp.tools.configuration.STATE_FILE", patch_config_paths["state_file"]):
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with patch("haproxy_mcp.tools.configuration.atomic_write_file", side_effect=IOError("Disk full")):
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_save_state"]()
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
class TestHaproxyRestoreStateFailures:
|
||||
"""Tests for haproxy_restore_state failure scenarios."""
|
||||
|
||||
def test_restore_state_haproxy_error(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config):
|
||||
"""Handle HAProxy error when restoring state."""
|
||||
from haproxy_mcp.exceptions import HaproxyError
|
||||
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(sample_servers_config, f)
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("api.example.com pool_2\n")
|
||||
|
||||
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=HaproxyError("Connection refused")):
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_restore_state"]()
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
def test_restore_state_os_error(self, patch_config_paths):
|
||||
"""Handle OS error when restoring state."""
|
||||
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=OSError("File not found")):
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_restore_state"]()
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
def test_restore_state_value_error(self, patch_config_paths):
|
||||
"""Handle ValueError when restoring state."""
|
||||
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=ValueError("Invalid config")):
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_restore_state"]()
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
class TestHaproxyCheckConfigOSError:
|
||||
"""Tests for haproxy_check_config OS error handling."""
|
||||
|
||||
def test_check_config_os_error(self, mock_subprocess):
|
||||
"""Handle OS error during config check."""
|
||||
mock_subprocess.side_effect = OSError("Permission denied")
|
||||
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_config_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_check_config"]()
|
||||
|
||||
assert "Error" in result
|
||||
assert "OS error" in result
|
||||
476
tests/unit/tools/test_domains.py
Normal file
476
tests/unit/tools/test_domains.py
Normal file
@@ -0,0 +1,476 @@
|
||||
"""Unit tests for domain management tools."""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.exceptions import HaproxyError
|
||||
|
||||
|
||||
class TestHaproxyListDomains:
|
||||
"""Tests for haproxy_list_domains tool function."""
|
||||
|
||||
def test_list_empty_domains(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""List domains when none configured."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
# Import here to get patched config
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_list_domains"](include_wildcards=False)
|
||||
|
||||
assert result == "No domains configured"
|
||||
|
||||
def test_list_domains_with_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""List domains with configured servers."""
|
||||
# Write map file
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
{"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80},
|
||||
]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_list_domains"](include_wildcards=False)
|
||||
|
||||
assert "example.com" in result
|
||||
assert "pool_1" in result
|
||||
assert "10.0.0.1" in result
|
||||
|
||||
def test_list_domains_exclude_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""List domains excluding wildcards by default."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||
f.write(".example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_list_domains"](include_wildcards=False)
|
||||
|
||||
assert "example.com" in result
|
||||
assert ".example.com" not in result
|
||||
|
||||
def test_list_domains_include_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""List domains including wildcards when requested."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||
f.write(".example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_list_domains"](include_wildcards=True)
|
||||
|
||||
assert "example.com" in result
|
||||
assert ".example.com" in result
|
||||
|
||||
|
||||
class TestHaproxyAddDomain:
|
||||
"""Tests for haproxy_add_domain tool function."""
|
||||
|
||||
def test_add_domain_invalid_format(self, patch_config_paths):
|
||||
"""Reject invalid domain format."""
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_add_domain"](
|
||||
domain="-invalid.com",
|
||||
ip="",
|
||||
http_port=80
|
||||
)
|
||||
|
||||
assert "Error" in result
|
||||
assert "Invalid domain" in result
|
||||
|
||||
def test_add_domain_invalid_ip(self, patch_config_paths):
|
||||
"""Reject invalid IP address."""
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_add_domain"](
|
||||
domain="example.com",
|
||||
ip="not-an-ip",
|
||||
http_port=80
|
||||
)
|
||||
|
||||
assert "Error" in result
|
||||
assert "Invalid IP" in result
|
||||
|
||||
def test_add_domain_invalid_port(self, patch_config_paths):
|
||||
"""Reject invalid port."""
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_add_domain"](
|
||||
domain="example.com",
|
||||
ip="10.0.0.1",
|
||||
http_port=70000
|
||||
)
|
||||
|
||||
assert "Error" in result
|
||||
assert "Port" in result
|
||||
|
||||
def test_add_domain_starts_with_dot(self, patch_config_paths):
|
||||
"""Reject domain starting with dot (wildcard)."""
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_add_domain"](
|
||||
domain=".example.com",
|
||||
ip="",
|
||||
http_port=80
|
||||
)
|
||||
|
||||
assert "Error" in result
|
||||
assert "cannot start with '.'" in result
|
||||
|
||||
def test_add_domain_already_exists(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Reject adding domain that already exists."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_add_domain"](
|
||||
domain="example.com",
|
||||
ip="",
|
||||
http_port=80
|
||||
)
|
||||
|
||||
assert "Error" in result
|
||||
assert "already exists" in result
|
||||
|
||||
def test_add_domain_success_without_ip(self, mock_socket_class, mock_select, patch_config_paths, response_builder, mock_subprocess):
|
||||
"""Successfully add domain without IP."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"add map": "",
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_add_domain"](
|
||||
domain="newdomain.com",
|
||||
ip="",
|
||||
http_port=80
|
||||
)
|
||||
|
||||
assert "newdomain.com" in result
|
||||
assert "pool_1" in result
|
||||
assert "no servers configured" in result
|
||||
|
||||
def test_add_domain_success_with_ip(self, mock_socket_class, mock_select, patch_config_paths, response_builder, mock_subprocess):
|
||||
"""Successfully add domain with IP."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"add map": "",
|
||||
"set server": "",
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_add_domain"](
|
||||
domain="newdomain.com",
|
||||
ip="10.0.0.1",
|
||||
http_port=8080
|
||||
)
|
||||
|
||||
assert "newdomain.com" in result
|
||||
assert "pool_1" in result
|
||||
assert "10.0.0.1:8080" in result
|
||||
|
||||
|
||||
class TestHaproxyRemoveDomain:
|
||||
"""Tests for haproxy_remove_domain tool function."""
|
||||
|
||||
def test_remove_domain_invalid_format(self, patch_config_paths):
|
||||
"""Reject invalid domain format."""
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_remove_domain"](domain="-invalid.com")
|
||||
|
||||
assert "Error" in result
|
||||
assert "Invalid domain" in result
|
||||
|
||||
def test_remove_domain_not_found(self, patch_config_paths):
|
||||
"""Reject removing domain that doesn't exist."""
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_remove_domain"](domain="nonexistent.com")
|
||||
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
|
||||
def test_remove_legacy_domain_rejected(self, patch_config_paths):
|
||||
"""Reject removing legacy (non-pool) domain."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com legacy_backend\n")
|
||||
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_remove_domain"](domain="example.com")
|
||||
|
||||
assert "Error" in result
|
||||
assert "legacy" in result.lower()
|
||||
|
||||
def test_remove_domain_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Successfully remove domain."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||
f.write(".example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"del map": "",
|
||||
"set server": "",
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_domain_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_remove_domain"](domain="example.com")
|
||||
|
||||
assert "example.com" in result
|
||||
assert "removed" in result.lower()
|
||||
|
||||
|
||||
class TestCheckCertificateCoverage:
|
||||
"""Tests for check_certificate_coverage function."""
|
||||
|
||||
def test_no_cert_directory(self, tmp_path):
|
||||
"""No certificate coverage when directory doesn't exist."""
|
||||
from haproxy_mcp.tools.domains import check_certificate_coverage
|
||||
|
||||
with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(tmp_path / "nonexistent")):
|
||||
covered, info = check_certificate_coverage("example.com")
|
||||
|
||||
assert covered is False
|
||||
assert "not found" in info.lower()
|
||||
|
||||
def test_exact_cert_match(self, tmp_path):
|
||||
"""Exact certificate match."""
|
||||
from haproxy_mcp.tools.domains import check_certificate_coverage
|
||||
|
||||
certs_dir = tmp_path / "certs"
|
||||
certs_dir.mkdir()
|
||||
(certs_dir / "example.com.pem").write_text("cert content")
|
||||
|
||||
with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(certs_dir)):
|
||||
covered, info = check_certificate_coverage("example.com")
|
||||
|
||||
assert covered is True
|
||||
assert info == "example.com"
|
||||
|
||||
def test_wildcard_cert_coverage(self, tmp_path, mock_subprocess):
|
||||
"""Wildcard certificate covers subdomain."""
|
||||
from haproxy_mcp.tools.domains import check_certificate_coverage
|
||||
|
||||
certs_dir = tmp_path / "certs"
|
||||
certs_dir.mkdir()
|
||||
(certs_dir / "example.com.pem").write_text("cert content")
|
||||
|
||||
# Mock openssl output showing wildcard SAN
|
||||
mock_subprocess.return_value = MagicMock(
|
||||
returncode=0,
|
||||
stdout="X509v3 Subject Alternative Name:\n DNS:example.com, DNS:*.example.com"
|
||||
)
|
||||
|
||||
with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(certs_dir)):
|
||||
covered, info = check_certificate_coverage("api.example.com")
|
||||
|
||||
assert covered is True
|
||||
assert "wildcard" in info
|
||||
|
||||
def test_no_matching_cert(self, tmp_path):
|
||||
"""No matching certificate."""
|
||||
from haproxy_mcp.tools.domains import check_certificate_coverage
|
||||
|
||||
certs_dir = tmp_path / "certs"
|
||||
certs_dir.mkdir()
|
||||
|
||||
with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(certs_dir)):
|
||||
covered, info = check_certificate_coverage("example.com")
|
||||
|
||||
assert covered is False
|
||||
assert "No matching" in info
|
||||
433
tests/unit/tools/test_health.py
Normal file
433
tests/unit/tools/test_health.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""Unit tests for health check tools."""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.exceptions import HaproxyError
|
||||
|
||||
|
||||
class TestHaproxyHealth:
|
||||
"""Tests for haproxy_health tool function."""
|
||||
|
||||
def test_health_all_ok(self, mock_socket_class, mock_select, patch_config_paths, response_builder, mock_subprocess):
|
||||
"""Health check returns healthy when all components are OK."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show info": response_builder.info(version="3.3.2", uptime=3600),
|
||||
})
|
||||
|
||||
mock_subprocess.return_value = MagicMock(
|
||||
returncode=0,
|
||||
stdout="running"
|
||||
)
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_health_tools(mcp)
|
||||
|
||||
result_str = registered_tools["haproxy_health"]()
|
||||
result = json.loads(result_str)
|
||||
|
||||
assert result["status"] == "healthy"
|
||||
assert result["components"]["mcp"]["status"] == "ok"
|
||||
assert result["components"]["haproxy"]["status"] == "ok"
|
||||
assert result["components"]["haproxy"]["version"] == "3.3.2"
|
||||
|
||||
def test_health_haproxy_error(self, mock_socket_class, mock_select, patch_config_paths, mock_subprocess):
|
||||
"""Health check returns degraded when HAProxy is unreachable."""
|
||||
|
||||
def raise_error(*args, **kwargs):
|
||||
raise ConnectionRefusedError()
|
||||
|
||||
with patch("socket.socket", side_effect=raise_error):
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_health_tools(mcp)
|
||||
|
||||
result_str = registered_tools["haproxy_health"]()
|
||||
result = json.loads(result_str)
|
||||
|
||||
assert result["status"] == "degraded"
|
||||
assert result["components"]["haproxy"]["status"] == "error"
|
||||
|
||||
def test_health_missing_config_files(self, mock_socket_class, mock_select, tmp_path, response_builder, mock_subprocess):
|
||||
"""Health check returns degraded when config files are missing."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show info": response_builder.info(),
|
||||
})
|
||||
|
||||
mock_subprocess.return_value = MagicMock(returncode=0, stdout="running")
|
||||
|
||||
# Use paths that don't exist
|
||||
with patch("haproxy_mcp.tools.health.MAP_FILE", str(tmp_path / "nonexistent.map")):
|
||||
with patch("haproxy_mcp.tools.health.SERVERS_FILE", str(tmp_path / "nonexistent.json")):
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_health_tools(mcp)
|
||||
|
||||
result_str = registered_tools["haproxy_health"]()
|
||||
result = json.loads(result_str)
|
||||
|
||||
assert result["status"] == "degraded"
|
||||
assert result["components"]["config_files"]["status"] == "warning"
|
||||
|
||||
def test_health_container_not_running(self, mock_socket_class, mock_select, patch_config_paths, response_builder, mock_subprocess):
|
||||
"""Health check returns unhealthy when container is not running."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show info": response_builder.info(),
|
||||
})
|
||||
|
||||
mock_subprocess.return_value = MagicMock(
|
||||
returncode=1,
|
||||
stdout="",
|
||||
stderr="No such container"
|
||||
)
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_health_tools(mcp)
|
||||
|
||||
result_str = registered_tools["haproxy_health"]()
|
||||
result = json.loads(result_str)
|
||||
|
||||
assert result["status"] == "unhealthy"
|
||||
assert result["components"]["container"]["status"] == "error"
|
||||
|
||||
|
||||
class TestHaproxyDomainHealth:
|
||||
"""Tests for haproxy_domain_health tool function."""
|
||||
|
||||
def test_domain_health_invalid_domain(self, patch_config_paths):
|
||||
"""Reject invalid domain format."""
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_health_tools(mcp)
|
||||
|
||||
result_str = registered_tools["haproxy_domain_health"](domain="-invalid")
|
||||
result = json.loads(result_str)
|
||||
|
||||
assert "error" in result
|
||||
assert "Invalid domain" in result["error"]
|
||||
|
||||
def test_domain_health_healthy(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Domain health returns healthy when all servers are UP."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
{"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80},
|
||||
{"be_name": "pool_1", "srv_name": "pool_1_2", "srv_addr": "10.0.0.2", "srv_port": 80},
|
||||
]),
|
||||
"show stat": response_builder.stat_csv([
|
||||
{"pxname": "pool_1", "svname": "pool_1_1", "status": "UP", "check_status": "L4OK"},
|
||||
{"pxname": "pool_1", "svname": "pool_1_2", "status": "UP", "check_status": "L4OK"},
|
||||
]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_health_tools(mcp)
|
||||
|
||||
result_str = registered_tools["haproxy_domain_health"](domain="example.com")
|
||||
result = json.loads(result_str)
|
||||
|
||||
assert result["status"] == "healthy"
|
||||
assert result["healthy_count"] == 2
|
||||
assert result["total_count"] == 2
|
||||
|
||||
def test_domain_health_degraded(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Domain health returns degraded when some servers are DOWN."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
{"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80},
|
||||
{"be_name": "pool_1", "srv_name": "pool_1_2", "srv_addr": "10.0.0.2", "srv_port": 80},
|
||||
]),
|
||||
"show stat": response_builder.stat_csv([
|
||||
{"pxname": "pool_1", "svname": "pool_1_1", "status": "UP"},
|
||||
{"pxname": "pool_1", "svname": "pool_1_2", "status": "DOWN"},
|
||||
]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_health_tools(mcp)
|
||||
|
||||
result_str = registered_tools["haproxy_domain_health"](domain="example.com")
|
||||
result = json.loads(result_str)
|
||||
|
||||
assert result["status"] == "degraded"
|
||||
assert result["healthy_count"] == 1
|
||||
assert result["total_count"] == 2
|
||||
|
||||
def test_domain_health_down(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Domain health returns down when all servers are DOWN."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
{"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80},
|
||||
]),
|
||||
"show stat": response_builder.stat_csv([
|
||||
{"pxname": "pool_1", "svname": "pool_1_1", "status": "DOWN"},
|
||||
]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_health_tools(mcp)
|
||||
|
||||
result_str = registered_tools["haproxy_domain_health"](domain="example.com")
|
||||
result = json.loads(result_str)
|
||||
|
||||
assert result["status"] == "down"
|
||||
assert result["healthy_count"] == 0
|
||||
assert result["total_count"] == 1
|
||||
|
||||
def test_domain_health_no_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Domain health returns no_servers when no servers configured."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
{"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "0.0.0.0", "srv_port": 0},
|
||||
]),
|
||||
"show stat": response_builder.stat_csv([]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_health_tools(mcp)
|
||||
|
||||
result_str = registered_tools["haproxy_domain_health"](domain="example.com")
|
||||
result = json.loads(result_str)
|
||||
|
||||
assert result["status"] == "no_servers"
|
||||
assert result["total_count"] == 0
|
||||
|
||||
|
||||
class TestHaproxyGetServerHealth:
|
||||
"""Tests for haproxy_get_server_health tool function."""
|
||||
|
||||
def test_get_server_health_invalid_backend(self, patch_config_paths):
|
||||
"""Reject invalid backend name."""
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_health_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_get_server_health"](backend="invalid@name")
|
||||
|
||||
assert "Error" in result
|
||||
assert "Invalid backend" in result
|
||||
|
||||
def test_get_server_health_all_backends(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Get health for all backends."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show stat": response_builder.stat_csv([
|
||||
{"pxname": "pool_1", "svname": "pool_1_1", "status": "UP", "weight": 1, "check_status": "L4OK"},
|
||||
{"pxname": "pool_2", "svname": "pool_2_1", "status": "DOWN", "weight": 1, "check_status": "L4TOUT"},
|
||||
]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_health_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_get_server_health"](backend="")
|
||||
|
||||
assert "pool_1" in result
|
||||
assert "pool_2" in result
|
||||
assert "UP" in result
|
||||
assert "DOWN" in result
|
||||
|
||||
def test_get_server_health_filter_backend(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Get health for specific backend."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show stat": response_builder.stat_csv([
|
||||
{"pxname": "pool_1", "svname": "pool_1_1", "status": "UP"},
|
||||
{"pxname": "pool_1", "svname": "pool_1_2", "status": "UP"},
|
||||
{"pxname": "pool_2", "svname": "pool_2_1", "status": "DOWN"},
|
||||
]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_health_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_get_server_health"](backend="pool_1")
|
||||
|
||||
assert "pool_1" in result
|
||||
assert "pool_2" not in result
|
||||
|
||||
def test_get_server_health_no_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""No servers returns appropriate message."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show stat": response_builder.stat_csv([
|
||||
{"pxname": "pool_1", "svname": "FRONTEND", "status": "OPEN"},
|
||||
{"pxname": "pool_1", "svname": "BACKEND", "status": "UP"},
|
||||
]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_health_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_get_server_health"](backend="")
|
||||
|
||||
assert "No servers found" in result
|
||||
|
||||
def test_get_server_health_haproxy_error(self, mock_socket_class, mock_select, patch_config_paths):
|
||||
"""HAProxy error returns error message."""
|
||||
def raise_error(*args, **kwargs):
|
||||
raise ConnectionRefusedError()
|
||||
|
||||
with patch("socket.socket", side_effect=raise_error):
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_health_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_get_server_health"](backend="")
|
||||
|
||||
assert "Error" in result
|
||||
325
tests/unit/tools/test_monitoring.py
Normal file
325
tests/unit/tools/test_monitoring.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""Unit tests for monitoring tools."""
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestHaproxyStats:
|
||||
"""Tests for haproxy_stats tool function."""
|
||||
|
||||
def test_stats_success(self, mock_socket_class, mock_select, response_builder):
|
||||
"""Get HAProxy stats successfully."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show info": response_builder.info(version="3.3.2", uptime=3600),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.monitoring import register_monitoring_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_monitoring_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_stats"]()
|
||||
|
||||
assert "Version" in result
|
||||
assert "3.3.2" in result
|
||||
|
||||
def test_stats_haproxy_error(self, mock_select):
|
||||
"""Handle HAProxy connection error."""
|
||||
def raise_error(*args, **kwargs):
|
||||
raise ConnectionRefusedError()
|
||||
|
||||
with patch("socket.socket", side_effect=raise_error):
|
||||
from haproxy_mcp.tools.monitoring import register_monitoring_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_monitoring_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_stats"]()
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
class TestHaproxyBackends:
|
||||
"""Tests for haproxy_backends tool function."""
|
||||
|
||||
def test_backends_success(self, mock_socket_class, mock_select):
|
||||
"""List backends successfully."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show backend": "pool_1\npool_2\npool_3",
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.monitoring import register_monitoring_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_monitoring_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_backends"]()
|
||||
|
||||
assert "Backends" in result
|
||||
assert "pool_1" in result
|
||||
assert "pool_2" in result
|
||||
assert "pool_3" in result
|
||||
|
||||
def test_backends_haproxy_error(self, mock_select):
|
||||
"""Handle HAProxy connection error."""
|
||||
def raise_error(*args, **kwargs):
|
||||
raise ConnectionRefusedError()
|
||||
|
||||
with patch("socket.socket", side_effect=raise_error):
|
||||
from haproxy_mcp.tools.monitoring import register_monitoring_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_monitoring_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_backends"]()
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
class TestHaproxyListFrontends:
|
||||
"""Tests for haproxy_list_frontends tool function."""
|
||||
|
||||
def test_list_frontends_success(self, mock_socket_class, mock_select, response_builder):
|
||||
"""List frontends successfully."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show stat": response_builder.stat_csv([
|
||||
{"pxname": "http_front", "svname": "FRONTEND", "status": "OPEN", "scur": 10},
|
||||
{"pxname": "https_front", "svname": "FRONTEND", "status": "OPEN", "scur": 50},
|
||||
{"pxname": "pool_1", "svname": "pool_1_1", "status": "UP", "scur": 5},
|
||||
]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.monitoring import register_monitoring_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_monitoring_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_list_frontends"]()
|
||||
|
||||
assert "Frontends" in result
|
||||
assert "http_front" in result
|
||||
assert "https_front" in result
|
||||
# pool_1 is not a FRONTEND
|
||||
assert "pool_1_1" not in result
|
||||
|
||||
def test_list_frontends_no_frontends(self, mock_socket_class, mock_select, response_builder):
|
||||
"""No frontends found."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show stat": response_builder.stat_csv([
|
||||
{"pxname": "pool_1", "svname": "pool_1_1", "status": "UP"},
|
||||
]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.monitoring import register_monitoring_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_monitoring_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_list_frontends"]()
|
||||
|
||||
assert "No frontends found" in result
|
||||
|
||||
def test_list_frontends_haproxy_error(self, mock_select):
|
||||
"""Handle HAProxy connection error."""
|
||||
def raise_error(*args, **kwargs):
|
||||
raise ConnectionRefusedError()
|
||||
|
||||
with patch("socket.socket", side_effect=raise_error):
|
||||
from haproxy_mcp.tools.monitoring import register_monitoring_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_monitoring_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_list_frontends"]()
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
class TestHaproxyGetConnections:
|
||||
"""Tests for haproxy_get_connections tool function."""
|
||||
|
||||
def test_get_connections_all_backends(self, mock_socket_class, mock_select, response_builder):
|
||||
"""Get connections for all backends."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show stat": response_builder.stat_csv([
|
||||
{"pxname": "pool_1", "svname": "FRONTEND", "status": "OPEN", "scur": 10, "smax": 100},
|
||||
{"pxname": "pool_1", "svname": "pool_1_1", "status": "UP", "scur": 5},
|
||||
{"pxname": "pool_1", "svname": "BACKEND", "status": "UP", "scur": 10, "smax": 100},
|
||||
]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.monitoring import register_monitoring_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_monitoring_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_get_connections"](backend="")
|
||||
|
||||
assert "pool_1" in result
|
||||
assert "FRONTEND" in result or "BACKEND" in result
|
||||
assert "connections" in result
|
||||
|
||||
def test_get_connections_filter_backend(self, mock_socket_class, mock_select, response_builder):
|
||||
"""Filter connections by backend."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show stat": response_builder.stat_csv([
|
||||
{"pxname": "pool_1", "svname": "FRONTEND", "status": "OPEN", "scur": 10, "smax": 100},
|
||||
{"pxname": "pool_2", "svname": "FRONTEND", "status": "OPEN", "scur": 20, "smax": 200},
|
||||
]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.monitoring import register_monitoring_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_monitoring_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_get_connections"](backend="pool_1")
|
||||
|
||||
assert "pool_1" in result
|
||||
assert "pool_2" not in result
|
||||
|
||||
def test_get_connections_invalid_backend(self):
|
||||
"""Reject invalid backend name."""
|
||||
from haproxy_mcp.tools.monitoring import register_monitoring_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_monitoring_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_get_connections"](backend="invalid@name")
|
||||
|
||||
assert "Error" in result
|
||||
assert "Invalid backend" in result
|
||||
|
||||
def test_get_connections_no_data(self, mock_socket_class, mock_select, response_builder):
|
||||
"""No connection data found."""
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show stat": response_builder.stat_csv([]),
|
||||
})
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.monitoring import register_monitoring_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_monitoring_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_get_connections"](backend="")
|
||||
|
||||
assert "No connection data" in result
|
||||
|
||||
def test_get_connections_haproxy_error(self, mock_select):
|
||||
"""Handle HAProxy connection error."""
|
||||
def raise_error(*args, **kwargs):
|
||||
raise ConnectionRefusedError()
|
||||
|
||||
with patch("socket.socket", side_effect=raise_error):
|
||||
from haproxy_mcp.tools.monitoring import register_monitoring_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_monitoring_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_get_connections"](backend="")
|
||||
|
||||
assert "Error" in result
|
||||
1350
tests/unit/tools/test_servers.py
Normal file
1350
tests/unit/tools/test_servers.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user