"""Unit tests for domain management tools."""

import json
from unittest.mock import patch, MagicMock

import pytest

from haproxy_mcp.exceptions import HaproxyError
from haproxy_mcp.file_ops import add_domain_to_map


def _register_tools():
    """Register the domain tools on a stub MCP server and return them by name.

    Replaces ``mcp.tool`` on a ``MagicMock`` with a decorator factory that
    records every decorated function in a dict keyed by ``func.__name__``.
    The tool functions are returned undecorated, so tests can call them
    directly.

    ``register_domain_tools`` is imported inside the helper on purpose. Each
    test previously imported it locally ("Import here to get patched config"),
    so the import happens while the test's fixtures (e.g. ``patch_config_paths``)
    are active.

    Returns:
        dict mapping tool function name -> tool function.
    """
    from haproxy_mcp.tools.domains import register_domain_tools

    registered_tools = {}

    def capture_tool():
        # Mirrors the ``@mcp.tool()`` call shape: a factory returning a decorator.
        def decorator(func):
            registered_tools[func.__name__] = func
            return func
        return decorator

    mcp = MagicMock()
    mcp.tool = capture_tool
    register_domain_tools(mcp)
    return registered_tools


class TestHaproxyListDomains:
    """Tests for haproxy_list_domains tool function."""

    def test_list_empty_domains(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
        """List domains when none configured."""
        mock_sock = mock_socket_class(responses={
            "show servers state": response_builder.servers_state([]),
        })

        with patch("socket.socket", return_value=mock_sock):
            tools = _register_tools()
            result = tools["haproxy_list_domains"](include_wildcards=False)

        assert result == "No domains configured"

    def test_list_domains_with_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
        """List domains with configured servers."""
        # Add domain to DB
        add_domain_to_map("example.com", "pool_1")

        mock_sock = mock_socket_class(responses={
            "show servers state": response_builder.servers_state([
                {"be_name": "pool_1", "srv_name": "pool_1_1", "srv_addr": "10.0.0.1", "srv_port": 80},
            ]),
        })

        with patch("socket.socket", return_value=mock_sock):
            tools = _register_tools()
            result = tools["haproxy_list_domains"](include_wildcards=False)

        assert "example.com" in result
        assert "pool_1" in result
        assert "10.0.0.1" in result

    def test_list_domains_exclude_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
        """List domains excluding wildcards by default."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map(".example.com", "pool_1", is_wildcard=True)

        mock_sock = mock_socket_class(responses={
            "show servers state": response_builder.servers_state([]),
        })

        with patch("socket.socket", return_value=mock_sock):
            tools = _register_tools()
            result = tools["haproxy_list_domains"](include_wildcards=False)

        assert "example.com" in result
        assert ".example.com" not in result

    def test_list_domains_include_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
        """List domains including wildcards when requested."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map(".example.com", "pool_1", is_wildcard=True)

        mock_sock = mock_socket_class(responses={
            "show servers state": response_builder.servers_state([]),
        })

        with patch("socket.socket", return_value=mock_sock):
            tools = _register_tools()
            result = tools["haproxy_list_domains"](include_wildcards=True)

        assert "example.com" in result
        assert ".example.com" in result


class TestHaproxyAddDomain:
    """Tests for haproxy_add_domain tool function."""

    def test_add_domain_invalid_format(self, patch_config_paths):
        """Reject invalid domain format."""
        tools = _register_tools()

        result = tools["haproxy_add_domain"](
            domain="-invalid.com", ip="", http_port=80
        )

        assert "Error" in result
        assert "Invalid domain" in result

    def test_add_domain_invalid_ip(self, patch_config_paths):
        """Reject invalid IP address."""
        tools = _register_tools()

        result = tools["haproxy_add_domain"](
            domain="example.com", ip="not-an-ip", http_port=80
        )

        assert "Error" in result
        assert "Invalid IP" in result

    def test_add_domain_invalid_port(self, patch_config_paths):
        """Reject invalid port."""
        tools = _register_tools()

        result = tools["haproxy_add_domain"](
            domain="example.com", ip="10.0.0.1", http_port=70000
        )

        assert "Error" in result
        assert "Port" in result

    def test_add_domain_starts_with_dot(self, patch_config_paths):
        """Reject domain starting with dot (wildcard)."""
        tools = _register_tools()

        result = tools["haproxy_add_domain"](
            domain=".example.com", ip="", http_port=80
        )

        assert "Error" in result
        assert "cannot start with '.'" in result

    def test_add_domain_already_exists(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
        """Reject adding domain that already exists."""
        add_domain_to_map("example.com", "pool_1")

        tools = _register_tools()

        result = tools["haproxy_add_domain"](
            domain="example.com", ip="", http_port=80
        )

        assert "Error" in result
        assert "already exists" in result

    def test_add_domain_success_without_ip(self, mock_socket_class, mock_select, patch_config_paths, response_builder, mock_subprocess):
        """Successfully add domain without IP."""
        mock_sock = mock_socket_class(responses={
            "add map": "",
        })

        with patch("socket.socket", return_value=mock_sock):
            tools = _register_tools()
            result = tools["haproxy_add_domain"](
                domain="newdomain.com", ip="", http_port=80
            )

        assert "newdomain.com" in result
        assert "pool_1" in result
        assert "no servers configured" in result

    def test_add_domain_success_with_ip(self, mock_socket_class, mock_select, patch_config_paths, response_builder, mock_subprocess):
        """Successfully add domain with IP."""
        mock_sock = mock_socket_class(responses={
            "add map": "",
            "set server": "",
        })

        with patch("socket.socket", return_value=mock_sock):
            tools = _register_tools()
            result = tools["haproxy_add_domain"](
                domain="newdomain.com", ip="10.0.0.1", http_port=8080
            )

        assert "newdomain.com" in result
        assert "pool_1" in result
        assert "10.0.0.1:8080" in result


class TestHaproxyRemoveDomain:
    """Tests for haproxy_remove_domain tool function."""

    def test_remove_domain_invalid_format(self, patch_config_paths):
        """Reject invalid domain format."""
        tools = _register_tools()

        result = tools["haproxy_remove_domain"](domain="-invalid.com")

        assert "Error" in result
        assert "Invalid domain" in result

    def test_remove_domain_not_found(self, patch_config_paths):
        """Reject removing domain that doesn't exist."""
        tools = _register_tools()

        result = tools["haproxy_remove_domain"](domain="nonexistent.com")

        assert "Error" in result
        assert "not found" in result

    def test_remove_legacy_domain_rejected(self, patch_config_paths):
        """Reject removing legacy (non-pool) domain."""
        add_domain_to_map("example.com", "legacy_backend")

        tools = _register_tools()

        result = tools["haproxy_remove_domain"](domain="example.com")

        assert "Error" in result
        assert "legacy" in result.lower()

    def test_remove_domain_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
        """Successfully remove domain."""
        add_domain_to_map("example.com", "pool_1")
        add_domain_to_map(".example.com", "pool_1", is_wildcard=True)

        mock_sock = mock_socket_class(responses={
            "del map": "",
            "set server": "",
        })

        with patch("socket.socket", return_value=mock_sock):
            tools = _register_tools()
            result = tools["haproxy_remove_domain"](domain="example.com")

        assert "example.com" in result
        assert "removed" in result.lower()


class TestCheckCertificateCoverage:
    """Tests for check_certificate_coverage function."""

    def test_no_cert_directory(self, tmp_path):
        """No certificate coverage when directory doesn't exist."""
        from haproxy_mcp.tools.domains import check_certificate_coverage

        with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(tmp_path / "nonexistent")):
            covered, info = check_certificate_coverage("example.com")

        assert covered is False
        assert "not found" in info.lower()

    def test_exact_cert_match(self, tmp_path):
        """Exact certificate match."""
        from haproxy_mcp.tools.domains import check_certificate_coverage

        certs_dir = tmp_path / "certs"
        certs_dir.mkdir()
        (certs_dir / "example.com.pem").write_text("cert content")

        with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(certs_dir)):
            covered, info = check_certificate_coverage("example.com")

        assert covered is True
        assert info == "example.com"

    def test_wildcard_cert_coverage(self, tmp_path, mock_subprocess):
        """Wildcard certificate covers subdomain."""
        from haproxy_mcp.tools.domains import check_certificate_coverage

        certs_dir = tmp_path / "certs"
        certs_dir.mkdir()
        (certs_dir / "example.com.pem").write_text("cert content")

        # Mock openssl output showing wildcard SAN
        mock_subprocess.return_value = MagicMock(
            returncode=0,
            stdout="X509v3 Subject Alternative Name:\n                DNS:example.com, DNS:*.example.com"
        )

        with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(certs_dir)):
            covered, info = check_certificate_coverage("api.example.com")

        assert covered is True
        assert "wildcard" in info

    def test_no_matching_cert(self, tmp_path):
        """No matching certificate."""
        from haproxy_mcp.tools.domains import check_certificate_coverage

        certs_dir = tmp_path / "certs"
        certs_dir.mkdir()

        with patch("haproxy_mcp.tools.domains.CERTS_DIR", str(certs_dir)):
            covered, info = check_certificate_coverage("example.com")

        assert covered is False
        assert "No matching" in info