refactor: migrate data storage from JSON/map files to SQLite
Replace servers.json, certificates.json, and map file parsing with SQLite (WAL mode) as single source of truth. HAProxy map files are now generated from SQLite via sync_map_files(). Key changes: - Add db.py with schema, connection management, and JSON migration - Add DB_FILE config constant - Delegate file_ops.py functions to db.py - Refactor domains.py to use file_ops instead of direct list manipulation - Fix subprocess.TimeoutExpired not caught (doesn't inherit TimeoutError) - Add DB health check in health.py - Init DB on startup in server.py and __main__.py - Update all 359 tests to use SQLite-backed functions Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -255,6 +255,8 @@ def temp_config_dir(tmp_path):
|
||||
state_file = tmp_path / "servers.state"
|
||||
state_file.write_text("")
|
||||
|
||||
db_file = tmp_path / "haproxy_mcp.db"
|
||||
|
||||
return {
|
||||
"dir": tmp_path,
|
||||
"map_file": str(map_file),
|
||||
@@ -262,12 +264,15 @@ def temp_config_dir(tmp_path):
|
||||
"servers_file": str(servers_file),
|
||||
"certs_file": str(certs_file),
|
||||
"state_file": str(state_file),
|
||||
"db_file": str(db_file),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_config_paths(temp_config_dir):
|
||||
"""Fixture that patches config module paths to use temporary directory."""
|
||||
from haproxy_mcp.db import close_connection, init_db
|
||||
|
||||
with patch.multiple(
|
||||
"haproxy_mcp.config",
|
||||
MAP_FILE=temp_config_dir["map_file"],
|
||||
@@ -275,16 +280,34 @@ def patch_config_paths(temp_config_dir):
|
||||
SERVERS_FILE=temp_config_dir["servers_file"],
|
||||
CERTS_FILE=temp_config_dir["certs_file"],
|
||||
STATE_FILE=temp_config_dir["state_file"],
|
||||
DB_FILE=temp_config_dir["db_file"],
|
||||
):
|
||||
# Also patch file_ops module which imports these
|
||||
with patch.multiple(
|
||||
"haproxy_mcp.file_ops",
|
||||
MAP_FILE=temp_config_dir["map_file"],
|
||||
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
|
||||
SERVERS_FILE=temp_config_dir["servers_file"],
|
||||
CERTS_FILE=temp_config_dir["certs_file"],
|
||||
):
|
||||
yield temp_config_dir
|
||||
# Patch db module which imports these
|
||||
with patch.multiple(
|
||||
"haproxy_mcp.db",
|
||||
MAP_FILE=temp_config_dir["map_file"],
|
||||
WILDCARDS_MAP_FILE=temp_config_dir["wildcards_file"],
|
||||
SERVERS_FILE=temp_config_dir["servers_file"],
|
||||
CERTS_FILE=temp_config_dir["certs_file"],
|
||||
DB_FILE=temp_config_dir["db_file"],
|
||||
):
|
||||
# Patch health module which imports MAP_FILE and DB_FILE
|
||||
with patch.multiple(
|
||||
"haproxy_mcp.tools.health",
|
||||
MAP_FILE=temp_config_dir["map_file"],
|
||||
DB_FILE=temp_config_dir["db_file"],
|
||||
):
|
||||
# Close any existing connection and initialize fresh DB
|
||||
close_connection()
|
||||
init_db()
|
||||
yield temp_config_dir
|
||||
close_connection()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
433
tests/unit/test_db.py
Normal file
433
tests/unit/test_db.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""Unit tests for db module (SQLite database operations)."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.db import (
|
||||
get_connection,
|
||||
close_connection,
|
||||
init_db,
|
||||
migrate_from_json,
|
||||
db_get_map_contents,
|
||||
db_get_domain_backend,
|
||||
db_add_domain,
|
||||
db_remove_domain,
|
||||
db_find_available_pool,
|
||||
db_get_domains_sharing_pool,
|
||||
db_load_servers_config,
|
||||
db_add_server,
|
||||
db_remove_server,
|
||||
db_remove_domain_servers,
|
||||
db_add_shared_domain,
|
||||
db_get_shared_domain,
|
||||
db_is_shared_domain,
|
||||
db_load_certs,
|
||||
db_add_cert,
|
||||
db_remove_cert,
|
||||
sync_map_files,
|
||||
SCHEMA_VERSION,
|
||||
)
|
||||
|
||||
|
||||
class TestConnectionManagement:
|
||||
"""Tests for database connection management."""
|
||||
|
||||
def test_get_connection(self, patch_config_paths):
|
||||
"""Get a connection returns a valid SQLite connection."""
|
||||
conn = get_connection()
|
||||
assert conn is not None
|
||||
# Verify WAL mode
|
||||
result = conn.execute("PRAGMA journal_mode").fetchone()
|
||||
assert result[0] == "wal"
|
||||
|
||||
def test_connection_is_thread_local(self, patch_config_paths):
|
||||
"""Same thread gets same connection."""
|
||||
conn1 = get_connection()
|
||||
conn2 = get_connection()
|
||||
assert conn1 is conn2
|
||||
|
||||
def test_close_connection(self, patch_config_paths):
|
||||
"""Close connection clears thread-local."""
|
||||
conn1 = get_connection()
|
||||
close_connection()
|
||||
conn2 = get_connection()
|
||||
assert conn1 is not conn2
|
||||
|
||||
|
||||
class TestInitDb:
|
||||
"""Tests for database initialization."""
|
||||
|
||||
def test_creates_tables(self, patch_config_paths):
|
||||
"""init_db creates all required tables."""
|
||||
conn = get_connection()
|
||||
|
||||
# Check tables exist
|
||||
tables = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
|
||||
).fetchall()
|
||||
table_names = [t["name"] for t in tables]
|
||||
|
||||
assert "domains" in table_names
|
||||
assert "servers" in table_names
|
||||
assert "certificates" in table_names
|
||||
assert "schema_version" in table_names
|
||||
|
||||
def test_schema_version_recorded(self, patch_config_paths):
|
||||
"""Schema version is recorded."""
|
||||
conn = get_connection()
|
||||
cur = conn.execute("SELECT MAX(version) FROM schema_version")
|
||||
version = cur.fetchone()[0]
|
||||
assert version == SCHEMA_VERSION
|
||||
|
||||
def test_idempotent(self, patch_config_paths):
|
||||
"""Calling init_db twice is safe."""
|
||||
# init_db is already called by patch_config_paths
|
||||
# Calling it again should not raise
|
||||
init_db()
|
||||
conn = get_connection()
|
||||
cur = conn.execute("SELECT COUNT(*) FROM schema_version")
|
||||
# May have 1 or 2 entries but should not fail
|
||||
assert cur.fetchone()[0] >= 1
|
||||
|
||||
|
||||
class TestMigrateFromJson:
|
||||
"""Tests for JSON to SQLite migration."""
|
||||
|
||||
def test_migrate_map_files(self, patch_config_paths):
|
||||
"""Migrate domain entries from map files."""
|
||||
# Write test map files
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("api.example.com pool_2\n")
|
||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||
f.write(".example.com pool_1\n")
|
||||
|
||||
migrate_from_json()
|
||||
|
||||
entries = db_get_map_contents()
|
||||
assert ("example.com", "pool_1") in entries
|
||||
assert ("api.example.com", "pool_2") in entries
|
||||
assert (".example.com", "pool_1") in entries
|
||||
|
||||
def test_migrate_servers_json(self, patch_config_paths):
|
||||
"""Migrate server entries from servers.json."""
|
||||
config = {
|
||||
"example.com": {
|
||||
"1": {"ip": "10.0.0.1", "http_port": 80},
|
||||
"2": {"ip": "10.0.0.2", "http_port": 8080},
|
||||
}
|
||||
}
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
migrate_from_json()
|
||||
|
||||
result = db_load_servers_config()
|
||||
assert "example.com" in result
|
||||
assert result["example.com"]["1"]["ip"] == "10.0.0.1"
|
||||
assert result["example.com"]["2"]["http_port"] == 8080
|
||||
|
||||
def test_migrate_shared_domains(self, patch_config_paths):
|
||||
"""Migrate shared domain references."""
|
||||
# First add the domain to map so it exists in DB
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("www.example.com pool_1\n")
|
||||
|
||||
config = {
|
||||
"www.example.com": {"_shares": "example.com"},
|
||||
}
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
migrate_from_json()
|
||||
|
||||
assert db_get_shared_domain("www.example.com") == "example.com"
|
||||
|
||||
def test_migrate_certificates(self, patch_config_paths):
|
||||
"""Migrate certificate entries."""
|
||||
with open(patch_config_paths["certs_file"], "w") as f:
|
||||
json.dump({"domains": ["example.com", "api.example.com"]}, f)
|
||||
|
||||
migrate_from_json()
|
||||
|
||||
certs = db_load_certs()
|
||||
assert "example.com" in certs
|
||||
assert "api.example.com" in certs
|
||||
|
||||
def test_migrate_idempotent(self, patch_config_paths):
|
||||
"""Migration is idempotent (INSERT OR IGNORE)."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump({"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}, f)
|
||||
|
||||
migrate_from_json()
|
||||
migrate_from_json() # Should not fail
|
||||
|
||||
entries = db_get_map_contents()
|
||||
assert len([d for d, _ in entries if d == "example.com"]) == 1
|
||||
|
||||
def test_migrate_empty_files(self, patch_config_paths):
|
||||
"""Migration with no existing data does nothing."""
|
||||
os.unlink(patch_config_paths["map_file"])
|
||||
os.unlink(patch_config_paths["wildcards_file"])
|
||||
os.unlink(patch_config_paths["servers_file"])
|
||||
os.unlink(patch_config_paths["certs_file"])
|
||||
|
||||
migrate_from_json() # Should not fail
|
||||
|
||||
def test_backup_files_after_migration(self, patch_config_paths):
|
||||
"""Original JSON files are backed up after migration."""
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump({"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}, f)
|
||||
with open(patch_config_paths["certs_file"], "w") as f:
|
||||
json.dump({"domains": ["example.com"]}, f)
|
||||
|
||||
# Write map file so migration has data
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
migrate_from_json()
|
||||
|
||||
assert os.path.exists(f"{patch_config_paths['servers_file']}.bak")
|
||||
assert os.path.exists(f"{patch_config_paths['certs_file']}.bak")
|
||||
|
||||
|
||||
class TestDomainOperations:
|
||||
"""Tests for domain data access functions."""
|
||||
|
||||
def test_add_and_get_domain(self, patch_config_paths):
|
||||
"""Add domain and retrieve its backend."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
|
||||
assert db_get_domain_backend("example.com") == "pool_1"
|
||||
|
||||
def test_get_nonexistent_domain(self, patch_config_paths):
|
||||
"""Non-existent domain returns None."""
|
||||
assert db_get_domain_backend("nonexistent.com") is None
|
||||
|
||||
def test_remove_domain(self, patch_config_paths):
|
||||
"""Remove domain removes both exact and wildcard."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
db_add_domain(".example.com", "pool_1", is_wildcard=True)
|
||||
|
||||
db_remove_domain("example.com")
|
||||
|
||||
assert db_get_domain_backend("example.com") is None
|
||||
entries = db_get_map_contents()
|
||||
assert (".example.com", "pool_1") not in entries
|
||||
|
||||
def test_find_available_pool(self, patch_config_paths):
|
||||
"""Find first available pool."""
|
||||
db_add_domain("a.com", "pool_1")
|
||||
db_add_domain("b.com", "pool_2")
|
||||
|
||||
pool = db_find_available_pool()
|
||||
assert pool == "pool_3"
|
||||
|
||||
def test_find_available_pool_empty(self, patch_config_paths):
|
||||
"""First pool available when none used."""
|
||||
pool = db_find_available_pool()
|
||||
assert pool == "pool_1"
|
||||
|
||||
def test_get_domains_sharing_pool(self, patch_config_paths):
|
||||
"""Get non-wildcard domains using a pool."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
db_add_domain("www.example.com", "pool_1")
|
||||
db_add_domain(".example.com", "pool_1", is_wildcard=True)
|
||||
|
||||
domains = db_get_domains_sharing_pool("pool_1")
|
||||
assert "example.com" in domains
|
||||
assert "www.example.com" in domains
|
||||
assert ".example.com" not in domains
|
||||
|
||||
def test_update_domain_backend(self, patch_config_paths):
|
||||
"""Updating a domain changes its backend."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
db_add_domain("example.com", "pool_5") # Update
|
||||
|
||||
assert db_get_domain_backend("example.com") == "pool_5"
|
||||
|
||||
|
||||
class TestServerOperations:
|
||||
"""Tests for server data access functions."""
|
||||
|
||||
def test_add_and_load_server(self, patch_config_paths):
|
||||
"""Add server and load config."""
|
||||
db_add_server("example.com", 1, "10.0.0.1", 80)
|
||||
|
||||
config = db_load_servers_config()
|
||||
assert config["example.com"]["1"]["ip"] == "10.0.0.1"
|
||||
assert config["example.com"]["1"]["http_port"] == 80
|
||||
|
||||
def test_update_server(self, patch_config_paths):
|
||||
"""Update existing server slot."""
|
||||
db_add_server("example.com", 1, "10.0.0.1", 80)
|
||||
db_add_server("example.com", 1, "10.0.0.99", 8080)
|
||||
|
||||
config = db_load_servers_config()
|
||||
assert config["example.com"]["1"]["ip"] == "10.0.0.99"
|
||||
assert config["example.com"]["1"]["http_port"] == 8080
|
||||
|
||||
def test_remove_server(self, patch_config_paths):
|
||||
"""Remove a server slot."""
|
||||
db_add_server("example.com", 1, "10.0.0.1", 80)
|
||||
db_add_server("example.com", 2, "10.0.0.2", 80)
|
||||
|
||||
db_remove_server("example.com", 1)
|
||||
|
||||
config = db_load_servers_config()
|
||||
assert "1" not in config.get("example.com", {})
|
||||
assert "2" in config["example.com"]
|
||||
|
||||
def test_remove_domain_servers(self, patch_config_paths):
|
||||
"""Remove all servers for a domain."""
|
||||
db_add_server("example.com", 1, "10.0.0.1", 80)
|
||||
db_add_server("example.com", 2, "10.0.0.2", 80)
|
||||
db_add_server("other.com", 1, "10.0.0.3", 80)
|
||||
|
||||
db_remove_domain_servers("example.com")
|
||||
|
||||
config = db_load_servers_config()
|
||||
assert config.get("example.com", {}).get("1") is None
|
||||
assert "other.com" in config
|
||||
|
||||
def test_load_empty(self, patch_config_paths):
|
||||
"""Empty database returns empty dict."""
|
||||
config = db_load_servers_config()
|
||||
assert config == {}
|
||||
|
||||
|
||||
class TestSharedDomainOperations:
|
||||
"""Tests for shared domain functions."""
|
||||
|
||||
def test_add_and_get_shared(self, patch_config_paths):
|
||||
"""Add shared domain reference."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
db_add_domain("www.example.com", "pool_1")
|
||||
db_add_shared_domain("www.example.com", "example.com")
|
||||
|
||||
assert db_get_shared_domain("www.example.com") == "example.com"
|
||||
|
||||
def test_is_shared(self, patch_config_paths):
|
||||
"""Check if domain is shared."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
db_add_domain("www.example.com", "pool_1")
|
||||
db_add_shared_domain("www.example.com", "example.com")
|
||||
|
||||
assert db_is_shared_domain("www.example.com") is True
|
||||
assert db_is_shared_domain("example.com") is False
|
||||
|
||||
def test_not_shared(self, patch_config_paths):
|
||||
"""Non-shared domain returns None."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
|
||||
assert db_get_shared_domain("example.com") is None
|
||||
assert db_is_shared_domain("example.com") is False
|
||||
|
||||
def test_shared_in_load_config(self, patch_config_paths):
|
||||
"""Shared domain appears in load_servers_config."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
db_add_domain("www.example.com", "pool_1")
|
||||
db_add_shared_domain("www.example.com", "example.com")
|
||||
|
||||
config = db_load_servers_config()
|
||||
assert config["www.example.com"]["_shares"] == "example.com"
|
||||
|
||||
|
||||
class TestCertificateOperations:
|
||||
"""Tests for certificate data access functions."""
|
||||
|
||||
def test_add_and_load_cert(self, patch_config_paths):
|
||||
"""Add and load certificate."""
|
||||
db_add_cert("example.com")
|
||||
|
||||
certs = db_load_certs()
|
||||
assert "example.com" in certs
|
||||
|
||||
def test_add_duplicate(self, patch_config_paths):
|
||||
"""Adding duplicate cert is a no-op."""
|
||||
db_add_cert("example.com")
|
||||
db_add_cert("example.com")
|
||||
|
||||
certs = db_load_certs()
|
||||
assert certs.count("example.com") == 1
|
||||
|
||||
def test_remove_cert(self, patch_config_paths):
|
||||
"""Remove a certificate."""
|
||||
db_add_cert("example.com")
|
||||
db_add_cert("other.com")
|
||||
|
||||
db_remove_cert("example.com")
|
||||
|
||||
certs = db_load_certs()
|
||||
assert "example.com" not in certs
|
||||
assert "other.com" in certs
|
||||
|
||||
def test_load_empty(self, patch_config_paths):
|
||||
"""Empty database returns empty list."""
|
||||
certs = db_load_certs()
|
||||
assert certs == []
|
||||
|
||||
def test_sorted_output(self, patch_config_paths):
|
||||
"""Certificates are returned sorted."""
|
||||
db_add_cert("z.com")
|
||||
db_add_cert("a.com")
|
||||
db_add_cert("m.com")
|
||||
|
||||
certs = db_load_certs()
|
||||
assert certs == ["a.com", "m.com", "z.com"]
|
||||
|
||||
|
||||
class TestSyncMapFiles:
|
||||
"""Tests for sync_map_files function."""
|
||||
|
||||
def test_sync_exact_domains(self, patch_config_paths):
|
||||
"""Sync writes exact domains to domains.map."""
|
||||
db_add_domain("example.com", "pool_1")
|
||||
db_add_domain("api.example.com", "pool_2")
|
||||
|
||||
sync_map_files()
|
||||
|
||||
with open(patch_config_paths["map_file"]) as f:
|
||||
content = f.read()
|
||||
assert "example.com pool_1" in content
|
||||
assert "api.example.com pool_2" in content
|
||||
|
||||
def test_sync_wildcards(self, patch_config_paths):
|
||||
"""Sync writes wildcards to wildcards.map."""
|
||||
db_add_domain(".example.com", "pool_1", is_wildcard=True)
|
||||
|
||||
sync_map_files()
|
||||
|
||||
with open(patch_config_paths["wildcards_file"]) as f:
|
||||
content = f.read()
|
||||
assert ".example.com pool_1" in content
|
||||
|
||||
def test_sync_empty(self, patch_config_paths):
|
||||
"""Sync with no domains writes headers only."""
|
||||
sync_map_files()
|
||||
|
||||
with open(patch_config_paths["map_file"]) as f:
|
||||
content = f.read()
|
||||
assert "Exact Domain" in content
|
||||
# No domain entries
|
||||
lines = [l.strip() for l in content.splitlines() if l.strip() and not l.startswith("#")]
|
||||
assert len(lines) == 0
|
||||
|
||||
def test_sync_sorted(self, patch_config_paths):
|
||||
"""Sync output is sorted."""
|
||||
db_add_domain("z.com", "pool_3")
|
||||
db_add_domain("a.com", "pool_1")
|
||||
|
||||
sync_map_files()
|
||||
|
||||
with open(patch_config_paths["map_file"]) as f:
|
||||
lines = [l.strip() for l in f if l.strip() and not l.startswith("#")]
|
||||
assert lines[0] == "a.com pool_1"
|
||||
assert lines[1] == "z.com pool_3"
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Unit tests for file_ops module."""
|
||||
"""Unit tests for file_ops module (SQLite-backed)."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -16,14 +15,19 @@ from haproxy_mcp.file_ops import (
|
||||
get_legacy_backend_name,
|
||||
get_backend_and_prefix,
|
||||
load_servers_config,
|
||||
save_servers_config,
|
||||
add_server_to_config,
|
||||
remove_server_from_config,
|
||||
remove_domain_from_config,
|
||||
load_certs_config,
|
||||
save_certs_config,
|
||||
add_cert_to_config,
|
||||
remove_cert_from_config,
|
||||
add_domain_to_map,
|
||||
remove_domain_from_map,
|
||||
find_available_pool,
|
||||
add_shared_domain_to_config,
|
||||
get_shared_domain,
|
||||
is_shared_domain,
|
||||
get_domains_sharing_pool,
|
||||
)
|
||||
|
||||
|
||||
@@ -62,7 +66,7 @@ class TestAtomicWriteFile:
|
||||
def test_unicode_content(self, tmp_path):
|
||||
"""Unicode content is properly written."""
|
||||
file_path = str(tmp_path / "unicode.txt")
|
||||
content = "Hello, \u4e16\u754c!" # "Hello, World!" in Chinese
|
||||
content = "Hello, \u4e16\u754c!"
|
||||
|
||||
atomic_write_file(file_path, content)
|
||||
|
||||
@@ -81,66 +85,33 @@ class TestAtomicWriteFile:
|
||||
|
||||
|
||||
class TestGetMapContents:
|
||||
"""Tests for get_map_contents function."""
|
||||
"""Tests for get_map_contents function (SQLite-backed)."""
|
||||
|
||||
def test_read_map_file(self, patch_config_paths):
|
||||
"""Read entries from map file."""
|
||||
# Write test content to map file
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("api.example.com pool_2\n")
|
||||
def test_empty_db(self, patch_config_paths):
|
||||
"""Empty database returns empty list."""
|
||||
entries = get_map_contents()
|
||||
assert entries == []
|
||||
|
||||
def test_read_domains(self, patch_config_paths):
|
||||
"""Read entries from database."""
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_domain_to_map("api.example.com", "pool_2")
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert ("example.com", "pool_1") in entries
|
||||
assert ("api.example.com", "pool_2") in entries
|
||||
|
||||
def test_read_both_map_files(self, patch_config_paths):
|
||||
"""Read entries from both domains.map and wildcards.map."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
|
||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||
f.write(".example.com pool_1\n")
|
||||
def test_read_with_wildcards(self, patch_config_paths):
|
||||
"""Read entries including wildcards."""
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert ("example.com", "pool_1") in entries
|
||||
assert (".example.com", "pool_1") in entries
|
||||
|
||||
def test_skip_comments(self, patch_config_paths):
|
||||
"""Comments are skipped."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("# This is a comment\n")
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("# Another comment\n")
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert len(entries) == 1
|
||||
assert entries[0] == ("example.com", "pool_1")
|
||||
|
||||
def test_skip_empty_lines(self, patch_config_paths):
|
||||
"""Empty lines are skipped."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("\n")
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("\n")
|
||||
f.write("api.example.com pool_2\n")
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert len(entries) == 2
|
||||
|
||||
def test_file_not_found(self, patch_config_paths):
|
||||
"""Missing file returns empty list."""
|
||||
os.unlink(patch_config_paths["map_file"])
|
||||
os.unlink(patch_config_paths["wildcards_file"])
|
||||
|
||||
entries = get_map_contents()
|
||||
|
||||
assert entries == []
|
||||
|
||||
|
||||
class TestSplitDomainEntries:
|
||||
"""Tests for split_domain_entries function."""
|
||||
@@ -182,36 +153,30 @@ class TestSplitDomainEntries:
|
||||
|
||||
|
||||
class TestSaveMapFile:
|
||||
"""Tests for save_map_file function."""
|
||||
"""Tests for save_map_file function (syncs from DB to map files)."""
|
||||
|
||||
def test_save_entries(self, patch_config_paths):
|
||||
"""Save entries to separate map files."""
|
||||
entries = [
|
||||
("example.com", "pool_1"),
|
||||
(".example.com", "pool_1"),
|
||||
]
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||
|
||||
save_map_file(entries)
|
||||
save_map_file([]) # Entries param ignored, syncs from DB
|
||||
|
||||
# Check exact domains file
|
||||
with open(patch_config_paths["map_file"]) as f:
|
||||
content = f.read()
|
||||
assert "example.com pool_1" in content
|
||||
|
||||
# Check wildcards file
|
||||
with open(patch_config_paths["wildcards_file"]) as f:
|
||||
content = f.read()
|
||||
assert ".example.com pool_1" in content
|
||||
|
||||
def test_sorted_output(self, patch_config_paths):
|
||||
"""Entries are sorted in output."""
|
||||
entries = [
|
||||
("z.example.com", "pool_3"),
|
||||
("a.example.com", "pool_1"),
|
||||
("m.example.com", "pool_2"),
|
||||
]
|
||||
add_domain_to_map("z.example.com", "pool_3")
|
||||
add_domain_to_map("a.example.com", "pool_1")
|
||||
add_domain_to_map("m.example.com", "pool_2")
|
||||
|
||||
save_map_file(entries)
|
||||
save_map_file([])
|
||||
|
||||
with open(patch_config_paths["map_file"]) as f:
|
||||
lines = [l.strip() for l in f if l.strip() and not l.startswith("#")]
|
||||
@@ -222,12 +187,11 @@ class TestSaveMapFile:
|
||||
|
||||
|
||||
class TestGetDomainBackend:
|
||||
"""Tests for get_domain_backend function."""
|
||||
"""Tests for get_domain_backend function (SQLite-backed)."""
|
||||
|
||||
def test_find_existing_domain(self, patch_config_paths):
|
||||
"""Find backend for existing domain."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
backend = get_domain_backend("example.com")
|
||||
|
||||
@@ -235,8 +199,7 @@ class TestGetDomainBackend:
|
||||
|
||||
def test_domain_not_found(self, patch_config_paths):
|
||||
"""Non-existent domain returns None."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
backend = get_domain_backend("other.com")
|
||||
|
||||
@@ -271,8 +234,7 @@ class TestGetBackendAndPrefix:
|
||||
|
||||
def test_pool_backend(self, patch_config_paths):
|
||||
"""Pool backend returns pool-based prefix."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_5\n")
|
||||
add_domain_to_map("example.com", "pool_5")
|
||||
|
||||
backend, prefix = get_backend_and_prefix("example.com")
|
||||
|
||||
@@ -288,48 +250,33 @@ class TestGetBackendAndPrefix:
|
||||
|
||||
|
||||
class TestLoadServersConfig:
|
||||
"""Tests for load_servers_config function."""
|
||||
"""Tests for load_servers_config function (SQLite-backed)."""
|
||||
|
||||
def test_load_existing_config(self, patch_config_paths, sample_servers_config):
|
||||
"""Load existing config file."""
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(sample_servers_config, f)
|
||||
def test_load_empty_config(self, patch_config_paths):
|
||||
"""Empty database returns empty dict."""
|
||||
config = load_servers_config()
|
||||
assert config == {}
|
||||
|
||||
def test_load_with_servers(self, patch_config_paths):
|
||||
"""Load config with server entries."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
add_server_to_config("example.com", 2, "10.0.0.2", 80)
|
||||
|
||||
config = load_servers_config()
|
||||
|
||||
assert "example.com" in config
|
||||
assert config["example.com"]["1"]["ip"] == "10.0.0.1"
|
||||
assert config["example.com"]["2"]["ip"] == "10.0.0.2"
|
||||
|
||||
def test_file_not_found(self, patch_config_paths):
|
||||
"""Missing file returns empty dict."""
|
||||
os.unlink(patch_config_paths["servers_file"])
|
||||
def test_load_with_shared_domain(self, patch_config_paths):
|
||||
"""Load config with shared domain reference."""
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_domain_to_map("www.example.com", "pool_1")
|
||||
add_shared_domain_to_config("www.example.com", "example.com")
|
||||
|
||||
config = load_servers_config()
|
||||
|
||||
assert config == {}
|
||||
|
||||
def test_invalid_json(self, patch_config_paths):
|
||||
"""Invalid JSON returns empty dict."""
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
f.write("not valid json {{{")
|
||||
|
||||
config = load_servers_config()
|
||||
|
||||
assert config == {}
|
||||
|
||||
|
||||
class TestSaveServersConfig:
|
||||
"""Tests for save_servers_config function."""
|
||||
|
||||
def test_save_config(self, patch_config_paths):
|
||||
"""Save config to file."""
|
||||
config = {"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}
|
||||
|
||||
save_servers_config(config)
|
||||
|
||||
with open(patch_config_paths["servers_file"]) as f:
|
||||
loaded = json.load(f)
|
||||
assert loaded == config
|
||||
assert config["www.example.com"]["_shares"] == "example.com"
|
||||
|
||||
|
||||
class TestAddServerToConfig:
|
||||
@@ -373,17 +320,18 @@ class TestRemoveServerFromConfig:
|
||||
remove_server_from_config("example.com", 1)
|
||||
|
||||
config = load_servers_config()
|
||||
assert "1" not in config["example.com"]
|
||||
assert "1" not in config.get("example.com", {})
|
||||
assert "2" in config["example.com"]
|
||||
|
||||
def test_remove_last_server_removes_domain(self, patch_config_paths):
|
||||
"""Removing last server removes domain entry."""
|
||||
"""Removing last server removes domain entry from servers."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
|
||||
remove_server_from_config("example.com", 1)
|
||||
|
||||
config = load_servers_config()
|
||||
assert "example.com" not in config
|
||||
# Domain may or may not exist (no servers = no entry)
|
||||
assert config.get("example.com", {}).get("1") is None
|
||||
|
||||
def test_remove_nonexistent_server(self, patch_config_paths):
|
||||
"""Removing non-existent server is a no-op."""
|
||||
@@ -399,14 +347,14 @@ class TestRemoveDomainFromConfig:
|
||||
"""Tests for remove_domain_from_config function."""
|
||||
|
||||
def test_remove_existing_domain(self, patch_config_paths):
|
||||
"""Remove existing domain."""
|
||||
"""Remove existing domain's servers."""
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
add_server_to_config("other.com", 1, "10.0.0.2", 80)
|
||||
|
||||
remove_domain_from_config("example.com")
|
||||
|
||||
config = load_servers_config()
|
||||
assert "example.com" not in config
|
||||
assert config.get("example.com", {}).get("1") is None
|
||||
assert "other.com" in config
|
||||
|
||||
def test_remove_nonexistent_domain(self, patch_config_paths):
|
||||
@@ -420,40 +368,23 @@ class TestRemoveDomainFromConfig:
|
||||
|
||||
|
||||
class TestLoadCertsConfig:
|
||||
"""Tests for load_certs_config function."""
|
||||
"""Tests for load_certs_config function (SQLite-backed)."""
|
||||
|
||||
def test_load_existing_config(self, patch_config_paths):
|
||||
"""Load existing certs config."""
|
||||
with open(patch_config_paths["certs_file"], "w") as f:
|
||||
json.dump({"domains": ["example.com", "other.com"]}, f)
|
||||
def test_load_empty(self, patch_config_paths):
|
||||
"""Empty database returns empty list."""
|
||||
domains = load_certs_config()
|
||||
assert domains == []
|
||||
|
||||
def test_load_with_certs(self, patch_config_paths):
|
||||
"""Load certs from database."""
|
||||
add_cert_to_config("example.com")
|
||||
add_cert_to_config("other.com")
|
||||
|
||||
domains = load_certs_config()
|
||||
|
||||
assert "example.com" in domains
|
||||
assert "other.com" in domains
|
||||
|
||||
def test_file_not_found(self, patch_config_paths):
|
||||
"""Missing file returns empty list."""
|
||||
os.unlink(patch_config_paths["certs_file"])
|
||||
|
||||
domains = load_certs_config()
|
||||
|
||||
assert domains == []
|
||||
|
||||
|
||||
class TestSaveCertsConfig:
|
||||
"""Tests for save_certs_config function."""
|
||||
|
||||
def test_save_domains(self, patch_config_paths):
|
||||
"""Save domains to certs config."""
|
||||
save_certs_config(["z.com", "a.com"])
|
||||
|
||||
with open(patch_config_paths["certs_file"]) as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Should be sorted
|
||||
assert data["domains"] == ["a.com", "z.com"]
|
||||
|
||||
|
||||
class TestAddCertToConfig:
|
||||
"""Tests for add_cert_to_config function."""
|
||||
@@ -496,3 +427,87 @@ class TestRemoveCertFromConfig:
|
||||
|
||||
domains = load_certs_config()
|
||||
assert "example.com" in domains
|
||||
|
||||
|
||||
class TestAddDomainToMap:
|
||||
"""Tests for add_domain_to_map function."""
|
||||
|
||||
def test_add_domain(self, patch_config_paths):
|
||||
"""Add a domain and verify map files are synced."""
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
assert get_domain_backend("example.com") == "pool_1"
|
||||
|
||||
with open(patch_config_paths["map_file"]) as f:
|
||||
assert "example.com pool_1" in f.read()
|
||||
|
||||
def test_add_wildcard(self, patch_config_paths):
|
||||
"""Add a wildcard domain."""
|
||||
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||
|
||||
entries = get_map_contents()
|
||||
assert (".example.com", "pool_1") in entries
|
||||
|
||||
|
||||
class TestRemoveDomainFromMap:
|
||||
"""Tests for remove_domain_from_map function."""
|
||||
|
||||
def test_remove_domain(self, patch_config_paths):
|
||||
"""Remove a domain and its wildcard."""
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||
|
||||
remove_domain_from_map("example.com")
|
||||
|
||||
assert get_domain_backend("example.com") is None
|
||||
entries = get_map_contents()
|
||||
assert (".example.com", "pool_1") not in entries
|
||||
|
||||
|
||||
class TestFindAvailablePool:
|
||||
"""Tests for find_available_pool function."""
|
||||
|
||||
def test_first_pool_available(self, patch_config_paths):
|
||||
"""When no domains exist, pool_1 is returned."""
|
||||
pool = find_available_pool()
|
||||
assert pool == "pool_1"
|
||||
|
||||
def test_skip_used_pools(self, patch_config_paths):
|
||||
"""Used pools are skipped."""
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_domain_to_map("other.com", "pool_2")
|
||||
|
||||
pool = find_available_pool()
|
||||
assert pool == "pool_3"
|
||||
|
||||
|
||||
class TestSharedDomains:
|
||||
"""Tests for shared domain functions."""
|
||||
|
||||
def test_get_shared_domain(self, patch_config_paths):
|
||||
"""Get parent domain for shared domain."""
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_domain_to_map("www.example.com", "pool_1")
|
||||
add_shared_domain_to_config("www.example.com", "example.com")
|
||||
|
||||
assert get_shared_domain("www.example.com") == "example.com"
|
||||
|
||||
def test_is_shared_domain(self, patch_config_paths):
|
||||
"""Check if domain is shared."""
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_domain_to_map("www.example.com", "pool_1")
|
||||
add_shared_domain_to_config("www.example.com", "example.com")
|
||||
|
||||
assert is_shared_domain("www.example.com") is True
|
||||
assert is_shared_domain("example.com") is False
|
||||
|
||||
def test_get_domains_sharing_pool(self, patch_config_paths):
|
||||
"""Get all domains using a pool."""
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_domain_to_map("www.example.com", "pool_1")
|
||||
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||
|
||||
domains = get_domains_sharing_pool("pool_1")
|
||||
assert "example.com" in domains
|
||||
assert "www.example.com" in domains
|
||||
assert ".example.com" not in domains # Wildcards excluded
|
||||
|
||||
@@ -6,6 +6,8 @@ from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.file_ops import add_cert_to_config
|
||||
|
||||
|
||||
class TestGetPemPaths:
|
||||
"""Tests for get_pem_paths function."""
|
||||
@@ -127,8 +129,7 @@ class TestRestoreCertificates:
|
||||
def test_restore_certificates_success(self, patch_config_paths, tmp_path, mock_socket_class, mock_select):
|
||||
"""Restore certificates successfully."""
|
||||
# Save config
|
||||
with open(patch_config_paths["certs_file"], "w") as f:
|
||||
json.dump({"domains": ["example.com"]}, f)
|
||||
add_cert_to_config("example.com")
|
||||
|
||||
# Create PEM
|
||||
certs_dir = tmp_path / "certs"
|
||||
@@ -283,11 +284,17 @@ class TestHaproxyCertInfo:
|
||||
pem_file = tmp_path / "example.com.pem"
|
||||
pem_file.write_text("cert content")
|
||||
|
||||
mock_subprocess.return_value = MagicMock(
|
||||
returncode=0,
|
||||
stdout="subject=CN = example.com\nissuer=CN = Google Trust Services\nnotBefore=Jan 1 00:00:00 2024 GMT\nnotAfter=Apr 1 00:00:00 2024 GMT",
|
||||
stderr=""
|
||||
)
|
||||
def subprocess_side_effect(*args, **kwargs):
|
||||
cmd = args[0] if args else kwargs.get("args", [])
|
||||
if isinstance(cmd, list) and "stat" in cmd:
|
||||
return MagicMock(returncode=0, stdout="1704067200", stderr="")
|
||||
return MagicMock(
|
||||
returncode=0,
|
||||
stdout="subject=CN = example.com\nissuer=CN = Google Trust Services\nnotBefore=Jan 1 00:00:00 2024 GMT\nnotAfter=Apr 1 00:00:00 2024 GMT",
|
||||
stderr=""
|
||||
)
|
||||
|
||||
mock_subprocess.side_effect = subprocess_side_effect
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show ssl cert": "/etc/haproxy/certs/example.com.pem",
|
||||
@@ -337,25 +344,33 @@ class TestHaproxyIssueCert:
|
||||
assert "Error" in result
|
||||
assert "Invalid domain" in result
|
||||
|
||||
def test_issue_cert_no_cf_token(self, tmp_path):
|
||||
def test_issue_cert_no_cf_token(self, tmp_path, mock_subprocess):
|
||||
"""Fail when CF_Token is not set."""
|
||||
acme_sh = str(tmp_path / "acme.sh")
|
||||
mock_subprocess.return_value = MagicMock(
|
||||
returncode=1,
|
||||
stdout="",
|
||||
stderr="CF_Token is not set. Please export CF_Token environment variable.",
|
||||
)
|
||||
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path)):
|
||||
with patch("os.path.exists", return_value=False):
|
||||
from haproxy_mcp.tools.certificates import register_certificate_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
with patch("haproxy_mcp.tools.certificates.ACME_SH", acme_sh):
|
||||
with patch("os.path.exists", return_value=False):
|
||||
from haproxy_mcp.tools.certificates import register_certificate_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_certificate_tools(mcp)
|
||||
mcp.tool = capture_tool
|
||||
register_certificate_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_issue_cert"](domain="example.com", wildcard=True)
|
||||
result = registered_tools["haproxy_issue_cert"](domain="example.com", wildcard=True)
|
||||
|
||||
assert "CF_Token" in result
|
||||
|
||||
@@ -845,8 +860,8 @@ class TestHaproxyRenewAllCertsMultiple:
|
||||
def test_renew_all_certs_multiple_renewals(self, mock_subprocess, mock_socket_class, mock_select, patch_config_paths, tmp_path):
|
||||
"""Renew multiple certificates successfully."""
|
||||
# Write config with multiple domains
|
||||
with open(patch_config_paths["certs_file"], "w") as f:
|
||||
json.dump({"domains": ["example.com", "example.org"]}, f)
|
||||
add_cert_to_config("example.com")
|
||||
add_cert_to_config("example.org")
|
||||
|
||||
# Create PEM files
|
||||
certs_dir = tmp_path / "certs"
|
||||
@@ -1038,30 +1053,32 @@ class TestHaproxyDeleteCertPartialFailure:
|
||||
"show ssl cert": "", # Not loaded
|
||||
})
|
||||
|
||||
# Mock os.remove to fail
|
||||
def mock_remove(path):
|
||||
if "example.com.pem" in str(path):
|
||||
raise PermissionError("Permission denied")
|
||||
raise FileNotFoundError()
|
||||
# Mock subprocess to succeed for acme.sh remove but fail for rm (PEM removal)
|
||||
def subprocess_side_effect(*args, **kwargs):
|
||||
cmd = args[0] if args else kwargs.get("args", [])
|
||||
if isinstance(cmd, list) and cmd[0] == "rm":
|
||||
return MagicMock(returncode=1, stdout="", stderr="Permission denied")
|
||||
return MagicMock(returncode=0, stdout="", stderr="")
|
||||
|
||||
mock_subprocess.side_effect = subprocess_side_effect
|
||||
|
||||
with patch("haproxy_mcp.tools.certificates.ACME_HOME", str(tmp_path / "acme")):
|
||||
with patch("haproxy_mcp.tools.certificates.CERTS_DIR", str(certs_dir)):
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with patch("os.remove", side_effect=mock_remove):
|
||||
from haproxy_mcp.tools.certificates import register_certificate_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
from haproxy_mcp.tools.certificates import register_certificate_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
def capture_tool():
|
||||
def decorator(func):
|
||||
registered_tools[func.__name__] = func
|
||||
return func
|
||||
return decorator
|
||||
|
||||
mcp.tool = capture_tool
|
||||
register_certificate_tools(mcp)
|
||||
mcp.tool = capture_tool
|
||||
register_certificate_tools(mcp)
|
||||
|
||||
result = registered_tools["haproxy_delete_cert"](domain="example.com")
|
||||
result = registered_tools["haproxy_delete_cert"](domain="example.com")
|
||||
|
||||
# Should report partial success (acme.sh deleted) and error (PEM failed)
|
||||
assert "Deleted" in result
|
||||
@@ -1118,8 +1135,8 @@ class TestRestoreCertificatesFailure:
|
||||
def test_restore_certificates_partial_failure(self, patch_config_paths, tmp_path, mock_socket_class, mock_select):
|
||||
"""Handle partial failure when restoring certificates."""
|
||||
# Save config with multiple domains
|
||||
with open(patch_config_paths["certs_file"], "w") as f:
|
||||
json.dump({"domains": ["example.com", "missing.com"]}, f)
|
||||
add_cert_to_config("example.com")
|
||||
add_cert_to_config("missing.com")
|
||||
|
||||
# Create only one PEM file
|
||||
certs_dir = tmp_path / "certs"
|
||||
|
||||
@@ -5,6 +5,8 @@ from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.file_ops import add_domain_to_map, add_server_to_config
|
||||
|
||||
|
||||
class TestRestoreServersFromConfig:
|
||||
"""Tests for restore_servers_from_config function."""
|
||||
@@ -19,12 +21,12 @@ class TestRestoreServersFromConfig:
|
||||
|
||||
def test_restore_servers_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config):
|
||||
"""Restore servers successfully."""
|
||||
# Write config and map
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(sample_servers_config, f)
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("api.example.com pool_2\n")
|
||||
# Add domains and servers to database
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
add_server_to_config("example.com", 2, "10.0.0.2", 80)
|
||||
add_domain_to_map("api.example.com", "pool_2")
|
||||
add_server_to_config("api.example.com", 1, "10.0.0.10", 8080)
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"set server": "",
|
||||
@@ -40,9 +42,8 @@ class TestRestoreServersFromConfig:
|
||||
|
||||
def test_restore_servers_skip_missing_domain(self, mock_socket_class, mock_select, patch_config_paths):
|
||||
"""Skip domains not in map file."""
|
||||
config = {"unknown.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}}
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(config, f)
|
||||
# Add server for unknown.com but no map entry (simulates missing domain)
|
||||
add_server_to_config("unknown.com", 1, "10.0.0.1", 80)
|
||||
|
||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||
|
||||
@@ -55,11 +56,9 @@ class TestRestoreServersFromConfig:
|
||||
|
||||
def test_restore_servers_skip_empty_ip(self, mock_socket_class, mock_select, patch_config_paths):
|
||||
"""Skip servers with empty IP."""
|
||||
config = {"example.com": {"1": {"ip": "", "http_port": 80}}}
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(config, f)
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
# Add domain to map and server with empty IP (will be skipped during restore)
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_server_to_config("example.com", 1, "", 80)
|
||||
|
||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||
|
||||
@@ -321,11 +320,12 @@ class TestHaproxyRestoreState:
|
||||
|
||||
def test_restore_state_success(self, mock_socket_class, mock_select, patch_config_paths, sample_servers_config):
|
||||
"""Restore state successfully."""
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(sample_servers_config, f)
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("api.example.com pool_2\n")
|
||||
# Add domains and servers to database
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
add_server_to_config("example.com", 2, "10.0.0.2", 80)
|
||||
add_domain_to_map("api.example.com", "pool_2")
|
||||
add_server_to_config("api.example.com", 1, "10.0.0.10", 8080)
|
||||
|
||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||
|
||||
@@ -373,17 +373,10 @@ class TestRestoreServersFromConfigBatchFailure:
|
||||
|
||||
def test_restore_servers_batch_failure_fallback(self, mock_socket_class, mock_select, patch_config_paths):
|
||||
"""Fall back to individual commands when batch fails."""
|
||||
# Create config with servers
|
||||
config = {
|
||||
"example.com": {
|
||||
"1": {"ip": "10.0.0.1", "http_port": 80},
|
||||
"2": {"ip": "10.0.0.2", "http_port": 80},
|
||||
}
|
||||
}
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(config, f)
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
# Add domain and servers to database
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
add_server_to_config("example.com", 2, "10.0.0.2", 80)
|
||||
|
||||
# Track call count to simulate batch failure then individual success
|
||||
call_count = [0]
|
||||
@@ -457,51 +450,51 @@ class TestRestoreServersFromConfigBatchFailure:
|
||||
|
||||
def test_restore_servers_invalid_slot(self, mock_socket_class, mock_select, patch_config_paths):
|
||||
"""Skip servers with invalid slot number."""
|
||||
# Add domain to map
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
# Mock load_servers_config to return config with invalid slot
|
||||
config = {
|
||||
"example.com": {
|
||||
"invalid": {"ip": "10.0.0.1", "http_port": 80}, # Invalid slot
|
||||
"1": {"ip": "10.0.0.2", "http_port": 80}, # Valid slot
|
||||
}
|
||||
}
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(config, f)
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
with patch("haproxy_mcp.tools.configuration.load_servers_config", return_value=config):
|
||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||
|
||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
||||
result = restore_servers_from_config()
|
||||
|
||||
result = restore_servers_from_config()
|
||||
|
||||
# Should only restore the valid server
|
||||
assert result == 1
|
||||
# Should only restore the valid server
|
||||
assert result == 1
|
||||
|
||||
def test_restore_servers_invalid_port(self, mock_socket_class, mock_select, patch_config_paths, caplog):
|
||||
"""Skip servers with invalid port."""
|
||||
import logging
|
||||
# Add domain to map
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
# Mock load_servers_config to return config with invalid port
|
||||
config = {
|
||||
"example.com": {
|
||||
"1": {"ip": "10.0.0.1", "http_port": "invalid"}, # Invalid port
|
||||
"2": {"ip": "10.0.0.2", "http_port": 80}, # Valid port
|
||||
}
|
||||
}
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(config, f)
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
with patch("haproxy_mcp.tools.configuration.load_servers_config", return_value=config):
|
||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||
|
||||
mock_sock = mock_socket_class(responses={"set server": ""})
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with caplog.at_level(logging.WARNING, logger="haproxy_mcp"):
|
||||
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
||||
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
with caplog.at_level(logging.WARNING, logger="haproxy_mcp"):
|
||||
from haproxy_mcp.tools.configuration import restore_servers_from_config
|
||||
result = restore_servers_from_config()
|
||||
|
||||
result = restore_servers_from_config()
|
||||
|
||||
# Should only restore the valid server
|
||||
assert result == 1
|
||||
# Should only restore the valid server
|
||||
assert result == 1
|
||||
|
||||
|
||||
class TestStartupRestoreFailures:
|
||||
@@ -658,11 +651,12 @@ class TestHaproxyRestoreStateFailures:
|
||||
"""Handle HAProxy error when restoring state."""
|
||||
from haproxy_mcp.exceptions import HaproxyError
|
||||
|
||||
with open(patch_config_paths["servers_file"], "w") as f:
|
||||
json.dump(sample_servers_config, f)
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
f.write("api.example.com pool_2\n")
|
||||
# Add domains and servers to database
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_server_to_config("example.com", 1, "10.0.0.1", 80)
|
||||
add_server_to_config("example.com", 2, "10.0.0.2", 80)
|
||||
add_domain_to_map("api.example.com", "pool_2")
|
||||
add_server_to_config("api.example.com", 1, "10.0.0.10", 8080)
|
||||
|
||||
with patch("haproxy_mcp.tools.configuration.restore_servers_from_config", side_effect=HaproxyError("Connection refused")):
|
||||
from haproxy_mcp.tools.configuration import register_config_tools
|
||||
|
||||
@@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.exceptions import HaproxyError
|
||||
from haproxy_mcp.file_ops import add_domain_to_map
|
||||
|
||||
|
||||
class TestHaproxyListDomains:
|
||||
@@ -38,9 +39,8 @@ class TestHaproxyListDomains:
|
||||
|
||||
def test_list_domains_with_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""List domains with configured servers."""
|
||||
# Write map file
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
# Add domain to DB
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
@@ -70,10 +70,8 @@ class TestHaproxyListDomains:
|
||||
|
||||
def test_list_domains_exclude_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""List domains excluding wildcards by default."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||
f.write(".example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([]),
|
||||
@@ -100,10 +98,8 @@ class TestHaproxyListDomains:
|
||||
|
||||
def test_list_domains_include_wildcards(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""List domains including wildcards when requested."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||
f.write(".example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([]),
|
||||
@@ -230,8 +226,7 @@ class TestHaproxyAddDomain:
|
||||
|
||||
def test_add_domain_already_exists(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Reject adding domain that already exists."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
@@ -362,8 +357,7 @@ class TestHaproxyRemoveDomain:
|
||||
|
||||
def test_remove_legacy_domain_rejected(self, patch_config_paths):
|
||||
"""Reject removing legacy (non-pool) domain."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com legacy_backend\n")
|
||||
add_domain_to_map("example.com", "legacy_backend")
|
||||
|
||||
from haproxy_mcp.tools.domains import register_domain_tools
|
||||
mcp = MagicMock()
|
||||
@@ -385,10 +379,8 @@ class TestHaproxyRemoveDomain:
|
||||
|
||||
def test_remove_domain_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Successfully remove domain."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
with open(patch_config_paths["wildcards_file"], "w") as f:
|
||||
f.write(".example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
add_domain_to_map(".example.com", "pool_1", is_wildcard=True)
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"del map": "",
|
||||
|
||||
@@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.exceptions import HaproxyError
|
||||
from haproxy_mcp.file_ops import add_domain_to_map
|
||||
|
||||
|
||||
class TestHaproxyHealth:
|
||||
@@ -80,7 +81,7 @@ class TestHaproxyHealth:
|
||||
|
||||
# Use paths that don't exist
|
||||
with patch("haproxy_mcp.tools.health.MAP_FILE", str(tmp_path / "nonexistent.map")):
|
||||
with patch("haproxy_mcp.tools.health.SERVERS_FILE", str(tmp_path / "nonexistent.json")):
|
||||
with patch("haproxy_mcp.tools.health.DB_FILE", str(tmp_path / "nonexistent.db")):
|
||||
with patch("socket.socket", return_value=mock_sock):
|
||||
from haproxy_mcp.tools.health import register_health_tools
|
||||
mcp = MagicMock()
|
||||
@@ -160,8 +161,7 @@ class TestHaproxyDomainHealth:
|
||||
|
||||
def test_domain_health_healthy(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Domain health returns healthy when all servers are UP."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
@@ -197,8 +197,7 @@ class TestHaproxyDomainHealth:
|
||||
|
||||
def test_domain_health_degraded(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Domain health returns degraded when some servers are DOWN."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
@@ -234,8 +233,7 @@ class TestHaproxyDomainHealth:
|
||||
|
||||
def test_domain_health_down(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Domain health returns down when all servers are DOWN."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
@@ -269,8 +267,7 @@ class TestHaproxyDomainHealth:
|
||||
|
||||
def test_domain_health_no_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Domain health returns no_servers when no servers configured."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
|
||||
@@ -6,6 +6,7 @@ from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
|
||||
from haproxy_mcp.exceptions import HaproxyError
|
||||
from haproxy_mcp.file_ops import add_domain_to_map, load_servers_config
|
||||
|
||||
|
||||
class TestHaproxyListServers:
|
||||
@@ -33,8 +34,7 @@ class TestHaproxyListServers:
|
||||
|
||||
def test_list_servers_empty_backend(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""List servers for domain with no servers."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
@@ -63,8 +63,7 @@ class TestHaproxyListServers:
|
||||
|
||||
def test_list_servers_with_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""List servers with active servers."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
@@ -224,8 +223,7 @@ class TestHaproxyAddServer:
|
||||
|
||||
def test_add_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Successfully add server."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"set server": "",
|
||||
@@ -258,8 +256,7 @@ class TestHaproxyAddServer:
|
||||
|
||||
def test_add_server_auto_slot(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Auto-select slot when slot=0."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
@@ -413,8 +410,7 @@ class TestHaproxyAddServers:
|
||||
|
||||
def test_add_servers_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Successfully add multiple servers."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"set server": "",
|
||||
@@ -495,8 +491,7 @@ class TestHaproxyRemoveServer:
|
||||
|
||||
def test_remove_server_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Successfully remove server."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"set server": "",
|
||||
@@ -689,8 +684,7 @@ class TestHaproxyAddServersRollback:
|
||||
|
||||
def test_add_servers_partial_failure_rollback(self, mock_socket_class, mock_select, patch_config_paths):
|
||||
"""Rollback only failed slots when HAProxy error occurs."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
# Mock configure_server_slot to fail on second slot
|
||||
call_count = [0]
|
||||
@@ -735,18 +729,16 @@ class TestHaproxyAddServersRollback:
|
||||
assert "slot 2" in result # Failed
|
||||
|
||||
# Verify servers.json only has successfully added server
|
||||
with open(patch_config_paths["servers_file"], "r") as f:
|
||||
config = json.load(f)
|
||||
config = load_servers_config()
|
||||
assert "example.com" in config
|
||||
assert "1" in config["example.com"] # Successfully added stays
|
||||
assert "2" not in config["example.com"] # Failed one was rolled back
|
||||
assert "2" not in config.get("example.com", {}) # Failed one was rolled back
|
||||
|
||||
def test_add_servers_unexpected_error_rollback_only_successful(
|
||||
self, mock_socket_class, mock_select, patch_config_paths
|
||||
):
|
||||
"""Rollback only successfully added servers on unexpected error."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
# Track which servers were configured
|
||||
configured_slots = []
|
||||
@@ -793,8 +785,7 @@ class TestHaproxyAddServersRollback:
|
||||
assert "Unexpected system error" in result
|
||||
|
||||
# Verify servers.json is empty (all rolled back)
|
||||
with open(patch_config_paths["servers_file"], "r") as f:
|
||||
config = json.load(f)
|
||||
config = load_servers_config()
|
||||
assert config == {} or "example.com" not in config or config.get("example.com") == {}
|
||||
|
||||
def test_add_servers_rollback_failure_logged(
|
||||
@@ -802,8 +793,7 @@ class TestHaproxyAddServersRollback:
|
||||
):
|
||||
"""Log rollback failures during error recovery."""
|
||||
import logging
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
def mock_configure_server_slot(backend, server_prefix, slot, ip, http_port):
|
||||
if slot == 2:
|
||||
@@ -858,8 +848,7 @@ class TestHaproxyAddServerAutoSlot:
|
||||
|
||||
def test_add_server_auto_slot_all_used(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Auto-select slot fails when all slots are in use."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
# Build response with all 10 slots used
|
||||
servers = []
|
||||
@@ -902,8 +891,7 @@ class TestHaproxyAddServerAutoSlot:
|
||||
|
||||
def test_add_server_negative_slot_auto_select(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Negative slot number triggers auto-selection."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
@@ -970,8 +958,7 @@ class TestHaproxyWaitDrain:
|
||||
|
||||
def test_wait_drain_success(self, patch_config_paths):
|
||||
"""Successfully wait for connections to drain."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
# Mock haproxy_cmd to return 0 connections
|
||||
with patch("haproxy_mcp.tools.servers.haproxy_cmd") as mock_cmd:
|
||||
@@ -1000,8 +987,7 @@ class TestHaproxyWaitDrain:
|
||||
|
||||
def test_wait_drain_timeout(self, patch_config_paths):
|
||||
"""Timeout when connections don't drain."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
time_values = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0] # Simulate time passing
|
||||
time_iter = iter(time_values)
|
||||
@@ -1087,9 +1073,6 @@ class TestHaproxyWaitDrain:
|
||||
def test_wait_drain_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths):
|
||||
"""Error when domain not found in map."""
|
||||
# Empty map file - domain not configured
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("")
|
||||
|
||||
from haproxy_mcp.tools.servers import register_server_tools
|
||||
mcp = MagicMock()
|
||||
registered_tools = {}
|
||||
@@ -1202,8 +1185,7 @@ class TestHaproxySetDomainState:
|
||||
|
||||
def test_set_domain_state_success(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Set all servers of a domain to a state."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([
|
||||
@@ -1283,8 +1265,7 @@ class TestHaproxySetDomainState:
|
||||
|
||||
def test_set_domain_state_no_active_servers(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""No active servers found for domain."""
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("example.com pool_1\n")
|
||||
add_domain_to_map("example.com", "pool_1")
|
||||
|
||||
# All servers have 0.0.0.0 address (not configured)
|
||||
mock_sock = mock_socket_class(responses={
|
||||
@@ -1318,9 +1299,6 @@ class TestHaproxySetDomainState:
|
||||
def test_set_domain_state_domain_not_found(self, mock_socket_class, mock_select, patch_config_paths, response_builder):
|
||||
"""Handle domain not found in map - shows no active servers."""
|
||||
# Empty map file
|
||||
with open(patch_config_paths["map_file"], "w") as f:
|
||||
f.write("")
|
||||
|
||||
# Mock should show no servers for unknown domain's backend
|
||||
mock_sock = mock_socket_class(responses={
|
||||
"show servers state": response_builder.servers_state([]),
|
||||
|
||||
Reference in New Issue
Block a user