Files
haproxy-mcp/haproxy_mcp/tools/domains.py
kappa 170c48e257 Detect subdomains structurally to skip wildcard entries without certs
Add CUSTOM_TLDS config (HAPROXY_CUSTOM_TLDS env, default: "it.com")
and _get_base_domain() for eTLD+1 detection. _check_subdomain now uses
three layers: registered domains, certificate domains, and structural
analysis. This ensures nocodb.inouter.com never gets a *.nocodb wildcard
entry even when inouter.com has no cert or registration.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 20:34:57 +09:00

411 lines
16 KiB
Python

"""Domain management tools for HAProxy MCP Server."""
import os
from typing import Annotated, Iterable, Optional

from pydantic import Field

from ..config import (
    MAP_FILE_CONTAINER,
    WILDCARDS_MAP_FILE_CONTAINER,
    POOL_COUNT,
    MAX_SLOTS,
    SUBPROCESS_TIMEOUT,
    CERTS_DIR,
    CUSTOM_TLDS,
    REMOTE_MODE,
    logger,
)
from ..ssh_ops import run_command, remote_file_exists
from ..exceptions import HaproxyError
from ..validation import validate_domain, validate_ip, validate_port_int
from ..haproxy_client import haproxy_cmd
from ..file_ops import (
    get_map_contents,
    get_domain_backend,
    is_legacy_backend,
    add_server_to_config,
    remove_server_from_config,
    remove_domain_from_config,
    add_shared_domain_to_config,
    get_domains_sharing_pool,
    is_shared_domain,
    add_domain_to_map,
    remove_domain_from_map,
    find_available_pool,
)
from ..db import db_load_certs
from ..utils import parse_servers_state, disable_server_slot
def _get_base_domain(domain: str, custom_tlds: Optional[Iterable[str]] = None) -> Optional[str]:
    """Get the base domain (eTLD+1) considering custom multi-part TLDs.

    Examples (with CUSTOM_TLDS={"it.com"}):
        inouter.com        -> inouter.com   (base domain itself)
        nocodb.inouter.com -> inouter.com
        anvil.it.com       -> anvil.it.com  (base domain, it.com is the TLD)
        gitea.anvil.it.com -> anvil.it.com
        it.com             -> None          (the custom TLD itself)

    When several custom TLDs match (e.g. "it.com" and "x.it.com"), the one
    with the most labels wins, so the result does not depend on the
    iteration order of the configured collection. Suffix matching is
    case-insensitive (DNS names are), but the returned labels keep the
    caller's original spelling.

    Args:
        domain: Domain name without a leading dot.
        custom_tlds: Multi-label suffixes to treat as TLDs. Defaults to the
            configured CUSTOM_TLDS.

    Returns:
        The base domain, or None if the domain is a TLD itself (a single
        label, or exactly one of the custom TLDs).
    """
    if custom_tlds is None:
        custom_tlds = CUSTOM_TLDS
    parts = domain.split(".")
    lowered = domain.lower()

    # Label count of the longest custom TLD that is a proper suffix of domain.
    best_labels: Optional[int] = None
    for tld in custom_tlds:
        if not tld:
            continue  # ignore empty entries from sloppy config parsing
        suffix = tld.lower()
        if lowered == suffix:
            return None  # the domain *is* a (custom) TLD
        tld_labels = len(suffix.split("."))
        if len(parts) > tld_labels and lowered.endswith("." + suffix):
            if best_labels is None or tld_labels > best_labels:
                best_labels = tld_labels

    if best_labels is not None:
        return ".".join(parts[-(best_labels + 1):])
    # Standard single-part TLD (e.g., .com, .net, .org)
    if len(parts) >= 2:
        return ".".join(parts[-2:])
    return None
def _check_subdomain(domain: str, registered_domains: set[str]) -> tuple[bool, Optional[str]]:
    """Decide whether *domain* is a subdomain, and of which parent.

    Detection runs in three layers, first hit wins:
      1. Registered domains - vault.anvil.it.com is a subdomain when
         anvil.it.com is already registered.
      2. Certificate domains (from db_load_certs) - nocodb.inouter.com is a
         subdomain when inouter.com has a stored certificate.
      3. Structure - the domain has more labels than its eTLD+1 as
         computed by _get_base_domain (e.g. nocodb.inouter.com vs inouter.com).

    Callers use this to skip the wildcard entry for subdomains, since a
    wildcard certificate (*.example.com) only covers one label deep.

    Args:
        domain: Domain name to check (e.g., "api.example.com").
        registered_domains: Domain names already registered.

    Returns:
        (is_subdomain, parent_domain), where parent_domain is None when
        is_subdomain is False.
    """
    known = set(registered_domains)
    known.update(db_load_certs())

    # Parent suffixes from nearest to farthest: b.example.com, example.com, com
    labels = domain.split(".")
    suffixes = (".".join(labels[cut:]) for cut in range(1, len(labels)))
    match = next((suffix for suffix in suffixes if suffix in known), None)
    if match is not None:
        return True, match

    base = _get_base_domain(domain)
    is_deeper = bool(base) and base != domain
    return (True, base) if is_deeper else (False, None)
def _update_haproxy_maps(domain: str, pool: str, is_subdomain: bool) -> None:
    """Update HAProxy maps via Runtime API.

    Uses 2-stage matching: exact domains go to domains.map,
    wildcards (".domain") go to wildcards.map.

    The update is all-or-nothing at the runtime level: if adding the
    wildcard entry fails, the exact entry added just before is deleted
    again (best effort) before the error is re-raised. Callers only roll
    back persistent state, so without this the live map would keep routing
    a domain that no longer exists in storage.

    Args:
        domain: Domain name to add.
        pool: Pool backend name (e.g., "pool_5").
        is_subdomain: If True, skip adding wildcard entry.

    Raises:
        HaproxyError: If HAProxy Runtime API command fails.
    """
    haproxy_cmd(f"add map {MAP_FILE_CONTAINER} {domain} {pool}")
    if is_subdomain:
        return
    try:
        haproxy_cmd(f"add map {WILDCARDS_MAP_FILE_CONTAINER} .{domain} {pool}")
    except HaproxyError:
        try:
            haproxy_cmd(f"del map {MAP_FILE_CONTAINER} {domain}")
        except HaproxyError as cleanup_err:
            logger.warning(
                "Failed to undo runtime map entry for %s: %s", domain, cleanup_err
            )
        # Re-raise the original wildcard failure for the caller to report.
        raise
def _rollback_domain_addition(domain: str) -> None:
    """Best-effort rollback of a failed domain addition.

    Removes the domain from SQLite + map files via remove_domain_from_map.
    Used when a domain addition fails after (some of) its entries were
    already persisted.

    Never raises: a rollback failure must not mask the error that triggered
    the rollback, so it is logged (with traceback) instead.

    Args:
        domain: Domain name that was added.
    """
    try:
        remove_domain_from_map(domain)
    except Exception:  # broad on purpose: rollback is best-effort
        logger.exception("Failed to rollback domain %s after HAProxy error", domain)
def _file_exists(path: str) -> bool:
    """Return True if *path* exists where HAProxy's files live.

    In REMOTE_MODE the check is delegated to ssh_ops.remote_file_exists;
    otherwise the local filesystem is queried with os.path.exists.
    """
    checker = remote_file_exists if REMOTE_MODE else os.path.exists
    return checker(path)
def _san_has_dns_name(san_text: str, name: str) -> bool:
    """Return True if openssl subjectAltName output lists exactly ``DNS:<name>``."""
    target = name.lower()
    # openssl prints a header line followed by "DNS:a, DNS:b, ..." entries.
    for entry in san_text.replace("\n", ",").split(","):
        kind, sep, value = entry.strip().partition(":")
        if sep and kind == "DNS" and value.strip().lower() == target:
            return True
    return False


def check_certificate_coverage(domain: str) -> tuple[bool, str]:
    """Check if a domain is covered by an existing certificate.

    Looks, in order, for:
      1. ``{CERTS_DIR}/{domain}.pem`` (a certificate file named after the domain).
      2. ``{CERTS_DIR}/{parent}.pem`` whose subjectAltName contains exactly
         ``DNS:*.{parent}``, where parent is *domain* minus its first label.
         SAN entries are compared whole, so e.g. ``*.example.community`` does
         not count as ``*.example.com``.

    Args:
        domain: Domain name to check (e.g., api.example.com)

    Returns:
        Tuple of (is_covered, certificate_name or message)
    """
    if REMOTE_MODE:
        dir_check = run_command(["test", "-d", CERTS_DIR])
        if dir_check.returncode != 0:
            return False, "Certificate directory not found"
    elif not os.path.isdir(CERTS_DIR):
        return False, "Certificate directory not found"

    # Check for exact match first
    exact_pem = f"{CERTS_DIR}/{domain}.pem"
    if _file_exists(exact_pem):
        return True, domain

    # Check for wildcard coverage (e.g., api.example.com covered by *.example.com)
    parts = domain.split(".")
    if len(parts) >= 2:
        parent_domain = ".".join(parts[1:])
        parent_pem = f"{CERTS_DIR}/{parent_domain}.pem"
        if _file_exists(parent_pem):
            try:
                result = run_command(
                    ["openssl", "x509", "-in", parent_pem, "-noout", "-ext", "subjectAltName"],
                    timeout=SUBPROCESS_TIMEOUT,
                )
            except (TimeoutError, OSError):
                result = None  # treat an unreadable cert as "not covered"
            if (
                result is not None
                and result.returncode == 0
                and _san_has_dns_name(result.stdout or "", f"*.{parent_domain}")
            ):
                return True, f"{parent_domain} (wildcard)"
    return False, "No matching certificate"
def register_domain_tools(mcp):
    """Register domain management tools with MCP server.

    Defines haproxy_list_domains, haproxy_add_domain and haproxy_remove_domain
    as closures and registers each through ``mcp.tool()``. Tool docstrings and
    ``Field`` descriptions are presumably surfaced to MCP clients as tool help,
    so they are treated as runtime text.

    Args:
        mcp: Server object providing a ``tool()`` decorator factory.
    """

    @mcp.tool()
    def haproxy_list_domains(
        include_wildcards: Annotated[bool, Field(default=False, description="Include wildcard entries (.example.com). Default: False")] = False
    ) -> str:
        """List all configured domains with their backend servers."""
        try:
            domains = []
            state = haproxy_cmd("show servers state")
            parsed_state = parse_servers_state(state)
            # Build server map from HAProxy state: backend -> ["name=addr:port", ...].
            # Slots parked at 0.0.0.0 are unconfigured and left out.
            server_map: dict[str, list[str]] = {}
            for backend, servers_dict in parsed_state.items():
                for server_name, srv_info in servers_dict.items():
                    if srv_info["addr"] != "0.0.0.0":
                        server_map.setdefault(backend, []).append(
                            f"{server_name}={srv_info['addr']}:{srv_info['port']}"
                        )
            # Read from domains.map
            seen_domains: set[str] = set()
            for domain, backend in get_map_contents():
                # Skip wildcard entries unless explicitly requested
                if domain.startswith(".") and not include_wildcards:
                    continue
                if domain in seen_domains:
                    continue
                seen_domains.add(domain)
                servers = server_map.get(backend, ["(none)"])
                if domain.startswith("."):
                    backend_type = "wildcard"
                elif backend.startswith("pool_"):
                    backend_type = "pool"
                else:
                    backend_type = "static"
                domains.append(f"{domain} -> {backend} ({backend_type}): {', '.join(servers)}")
            return "\n".join(domains) if domains else "No domains configured"
        except HaproxyError as e:
            return f"Error: {e}"

    @mcp.tool()
    def haproxy_add_domain(
        domain: Annotated[str, Field(description="Domain name to add (e.g., api.example.com, example.com)")],
        ip: Annotated[str, Field(default="", description="Optional: Initial server IP. If provided, adds server to slot 1")] = "",
        http_port: Annotated[int, Field(default=80, description="HTTP port for backend server (default: 80)")] = 80,
        share_with: Annotated[str, Field(default="", description="Optional: Existing domain to share pool with. New domain uses same backend servers.")] = ""
    ) -> str:
        """Add a new domain to HAProxy (no reload required).
        Creates domain→pool mapping. Use haproxy_add_server to add more servers later.
        Pool sharing: Use share_with to reuse an existing domain's pool. This saves pool
        slots when multiple domains point to the same backend servers.
        Example: haproxy_add_domain("api.example.com", ip="10.0.0.1", http_port=8080)
        Example: haproxy_add_domain("www.example.com", share_with="example.com")
        """
        # Validate inputs
        if domain.startswith("."):
            return "Error: Domain cannot start with '.' (wildcard entries are added automatically)"
        if not validate_domain(domain):
            return "Error: Invalid domain format"
        if not validate_ip(ip, allow_empty=True):
            return "Error: Invalid IP address format"
        if not validate_port_int(http_port):
            return "Error: Port must be between 1 and 65535"
        if share_with and not validate_domain(share_with):
            return "Error: Invalid share_with domain format"
        if share_with and ip:
            return "Error: Cannot specify both ip and share_with (shared domains use existing servers)"
        # Read current entries for existence check and subdomain detection
        entries = get_map_contents()
        # Check if domain already exists
        for domain_entry, backend in entries:
            if domain_entry == domain:
                return f"Error: Domain {domain} already exists (mapped to {backend})"
        # Registered (non-wildcard) domains feed the subdomain check
        registered_domains: set[str] = {
            entry_domain for entry_domain, _ in entries if not entry_domain.startswith(".")
        }
        # Handle share_with: reuse existing domain's pool
        if share_with:
            share_backend = get_domain_backend(share_with)
            if not share_backend:
                return f"Error: Domain {share_with} not found"
            if not share_backend.startswith("pool_"):
                return f"Error: Cannot share with legacy backend {share_backend}"
            pool = share_backend
        else:
            # Find available pool (SQLite query, O(1))
            pool = find_available_pool()
            if not pool:
                return f"Error: All {POOL_COUNT} pool backends are in use"
        # Check if this is a subdomain of an existing domain
        is_subdomain, parent_domain = _check_subdomain(domain, registered_domains)
        try:
            # Save to SQLite + sync map files. The exact and wildcard entries
            # are two separate saves, so a failure on the second must undo
            # the first. Only roll back what this call actually saved, so a
            # concurrently added entry with the same name is never removed.
            exact_saved = False
            try:
                add_domain_to_map(domain, pool)
                exact_saved = True
                if not is_subdomain:
                    add_domain_to_map(f".{domain}", pool, is_wildcard=True)
            except Exception as e:
                if exact_saved:
                    _rollback_domain_addition(domain)
                return f"Error: Failed to save domain: {e}"
            # Update HAProxy maps via Runtime API
            try:
                _update_haproxy_maps(domain, pool, is_subdomain)
            except HaproxyError as e:
                _rollback_domain_addition(domain)
                return f"Error: Failed to update HAProxy map: {e}"
            # Handle server configuration based on mode
            if share_with:
                # Save shared domain reference
                add_shared_domain_to_config(domain, share_with)
                result = f"Domain {domain} added, sharing pool {pool} with {share_with}"
            elif ip:
                # Add server to slot 1 (persist first, then apply at runtime)
                add_server_to_config(domain, 1, ip, http_port)
                try:
                    server = f"{pool}_1"
                    haproxy_cmd(f"set server {pool}/{server} addr {ip} port {http_port}")
                    haproxy_cmd(f"set server {pool}/{server} state ready")
                except HaproxyError as e:
                    # Domain mapping stays; only the server slot is undone.
                    remove_server_from_config(domain, 1)
                    return f"Domain {domain} added to {pool} but server config failed: {e}"
                result = f"Domain {domain} added to {pool} with server {ip}:{http_port}"
            else:
                result = f"Domain {domain} added to {pool} (no servers configured)"
            if is_subdomain:
                result += f" (subdomain of {parent_domain}, no wildcard)"
            # Check certificate coverage
            cert_covered, cert_info = check_certificate_coverage(domain)
            if cert_covered:
                result += f"\nSSL: Using certificate {cert_info}"
            else:
                result += f"\nSSL: No certificate found. Use haproxy_issue_cert(\"{domain}\") to issue one."
            return result
        except HaproxyError as e:
            return f"Error: {e}"

    @mcp.tool()
    def haproxy_remove_domain(
        domain: Annotated[str, Field(description="Domain name to remove (e.g., api.example.com)")]
    ) -> str:
        """Remove a domain from HAProxy (no reload required)."""
        if not validate_domain(domain):
            return "Error: Invalid domain format"
        # Look up the domain in the map
        backend = get_domain_backend(domain)
        if not backend:
            return f"Error: Domain {domain} not found"
        # Check if this is a legacy backend (not a pool)
        if is_legacy_backend(backend):
            return f"Error: Cannot remove legacy domain {domain} (uses static backend {backend})"
        try:
            # Sharing info must be captured before the domain's records are removed.
            # Check if this domain is sharing another domain's pool
            domain_is_shared = is_shared_domain(domain)
            # Check if other domains are sharing this pool
            domains_using_pool = get_domains_sharing_pool(backend)
            other_domains = [d for d in domains_using_pool if d != domain]
            # Remove from SQLite + sync map files
            remove_domain_from_map(domain)
            # Remove from persistent server config
            remove_domain_from_config(domain)
            # Clear map entries via Runtime API (immediate effect)
            # 2-stage matching: exact from domains.map, wildcard from wildcards.map
            haproxy_cmd(f"del map {MAP_FILE_CONTAINER} {domain}")
            try:
                # Subdomains were added without a wildcard entry, so this
                # delete may legitimately fail; it is only logged.
                haproxy_cmd(f"del map {WILDCARDS_MAP_FILE_CONTAINER} .{domain}")
            except HaproxyError as e:
                logger.warning("Failed to remove wildcard entry for %s: %s", domain, e)
            # Only clear servers if no other domains are using this pool
            if other_domains:
                return f"Domain {domain} removed from {backend} (pool still used by: {', '.join(other_domains)})"
            # If this domain was sharing another domain's pool, don't clear servers
            if domain_is_shared:
                return f"Domain {domain} removed from shared pool {backend}"
            # Disable all servers in the pool (reset to 0.0.0.0:0)
            for slot in range(1, MAX_SLOTS + 1):
                server = f"{backend}_{slot}"
                try:
                    disable_server_slot(backend, server)
                except HaproxyError as e:
                    logger.warning(
                        "Failed to clear server %s/%s for domain %s: %s",
                        backend, server, domain, e
                    )
                    # Continue with remaining cleanup
            return f"Domain {domain} removed from {backend}"
        except IOError as e:
            return f"Error: Failed to update map file: {e}"
        except HaproxyError as e:
            return f"Error: {e}"