"""Unit tests for file_ops module (SQLite-backed).""" import os from unittest.mock import patch import pytest from haproxy_mcp.file_ops import ( atomic_write_file, get_map_contents, save_map_file, get_domain_backend, split_domain_entries, is_legacy_backend, get_legacy_backend_name, get_backend_and_prefix, load_servers_config, add_server_to_config, remove_server_from_config, remove_domain_from_config, load_certs_config, add_cert_to_config, remove_cert_from_config, add_domain_to_map, remove_domain_from_map, find_available_pool, add_shared_domain_to_config, get_shared_domain, is_shared_domain, get_domains_sharing_pool, ) class TestAtomicWriteFile: """Tests for atomic_write_file function.""" def test_write_new_file(self, tmp_path): """Write to a new file.""" file_path = str(tmp_path / "test.txt") content = "Hello, World!" atomic_write_file(file_path, content) assert os.path.exists(file_path) with open(file_path) as f: assert f.read() == content def test_overwrite_existing_file(self, tmp_path): """Overwrite an existing file.""" file_path = str(tmp_path / "test.txt") with open(file_path, "w") as f: f.write("Old content") atomic_write_file(file_path, "New content") with open(file_path) as f: assert f.read() == "New content" def test_preserves_directory(self, tmp_path): """Writing does not create intermediate directories.""" file_path = str(tmp_path / "subdir" / "test.txt") with pytest.raises(IOError): atomic_write_file(file_path, "content") def test_unicode_content(self, tmp_path): """Unicode content is properly written.""" file_path = str(tmp_path / "unicode.txt") content = "Hello, \u4e16\u754c!" atomic_write_file(file_path, content) with open(file_path, encoding="utf-8") as f: assert f.read() == content def test_multiline_content(self, tmp_path): """Multi-line content is properly written.""" file_path = str(tmp_path / "multiline.txt") content = "line1\nline2\nline3" atomic_write_file(file_path, content) with open(file_path) as f: assert f.read() == content class TestGetMapContents: """Tests for get_map_contents function (SQLite-backed).""" def test_empty_db(self, patch_config_paths): """Empty database returns empty list.""" entries = get_map_contents() assert entries == [] def test_read_domains(self, patch_config_paths): """Read entries from database.""" add_domain_to_map("example.com", "pool_1") add_domain_to_map("api.example.com", "pool_2") entries = get_map_contents() assert ("example.com", "pool_1") in entries assert ("api.example.com", "pool_2") in entries def test_read_with_wildcards(self, patch_config_paths): """Read entries including wildcards.""" add_domain_to_map("example.com", "pool_1") add_domain_to_map(".example.com", "pool_1", is_wildcard=True) entries = get_map_contents() assert ("example.com", "pool_1") in entries assert (".example.com", "pool_1") in entries class TestSplitDomainEntries: """Tests for split_domain_entries function.""" def test_split_entries(self): """Split entries into exact and wildcard.""" entries = [ ("example.com", "pool_1"), (".example.com", "pool_1"), ("api.example.com", "pool_2"), (".api.example.com", "pool_2"), ] exact, wildcards = split_domain_entries(entries) assert len(exact) == 2 assert len(wildcards) == 2 assert ("example.com", "pool_1") in exact assert (".example.com", "pool_1") in wildcards def test_empty_entries(self): """Empty entries returns empty lists.""" exact, wildcards = split_domain_entries([]) assert exact == [] assert wildcards == [] def test_all_exact(self): """All exact entries.""" entries = [ ("example.com", "pool_1"), ("api.example.com", "pool_2"), ] exact, wildcards = split_domain_entries(entries) assert len(exact) == 2 assert len(wildcards) == 0 class TestSaveMapFile: """Tests for save_map_file function (syncs from DB to map files).""" def test_save_entries(self, patch_config_paths): """Save entries to separate map files.""" add_domain_to_map("example.com", "pool_1") add_domain_to_map(".example.com", "pool_1", is_wildcard=True) save_map_file([]) # Entries param ignored, syncs from DB with open(patch_config_paths["map_file"]) as f: content = f.read() assert "example.com pool_1" in content with open(patch_config_paths["wildcards_file"]) as f: content = f.read() assert ".example.com pool_1" in content def test_sorted_output(self, patch_config_paths): """Entries are sorted in output.""" add_domain_to_map("z.example.com", "pool_3") add_domain_to_map("a.example.com", "pool_1") add_domain_to_map("m.example.com", "pool_2") save_map_file([]) with open(patch_config_paths["map_file"]) as f: lines = [l.strip() for l in f if l.strip() and not l.startswith("#")] assert lines[0] == "a.example.com pool_1" assert lines[1] == "m.example.com pool_2" assert lines[2] == "z.example.com pool_3" class TestGetDomainBackend: """Tests for get_domain_backend function (SQLite-backed).""" def test_find_existing_domain(self, patch_config_paths): """Find backend for existing domain.""" add_domain_to_map("example.com", "pool_1") backend = get_domain_backend("example.com") assert backend == "pool_1" def test_domain_not_found(self, patch_config_paths): """Non-existent domain returns None.""" add_domain_to_map("example.com", "pool_1") backend = get_domain_backend("other.com") assert backend is None class TestIsLegacyBackend: """Tests for is_legacy_backend function.""" def test_pool_backend(self): """Pool backend is not legacy.""" assert is_legacy_backend("pool_1") is False assert is_legacy_backend("pool_100") is False def test_legacy_backend(self): """Non-pool backend is legacy.""" assert is_legacy_backend("api_example_com_backend") is True assert is_legacy_backend("static_backend") is True class TestGetLegacyBackendName: """Tests for get_legacy_backend_name function.""" def test_convert_domain(self): """Convert domain to legacy backend name.""" result = get_legacy_backend_name("api.example.com") assert result == "api_example_com_backend" class TestGetBackendAndPrefix: """Tests for get_backend_and_prefix function.""" def test_pool_backend(self, patch_config_paths): """Pool backend returns pool-based prefix.""" add_domain_to_map("example.com", "pool_5") backend, prefix = get_backend_and_prefix("example.com") assert backend == "pool_5" assert prefix == "pool_5" def test_unknown_domain_uses_legacy(self, patch_config_paths): """Unknown domain uses legacy backend naming.""" backend, prefix = get_backend_and_prefix("unknown.com") assert backend == "unknown_com_backend" assert prefix == "unknown_com" class TestLoadServersConfig: """Tests for load_servers_config function (SQLite-backed).""" def test_load_empty_config(self, patch_config_paths): """Empty database returns empty dict.""" config = load_servers_config() assert config == {} def test_load_with_servers(self, patch_config_paths): """Load config with server entries.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) add_server_to_config("example.com", 2, "10.0.0.2", 80) config = load_servers_config() assert "example.com" in config assert config["example.com"]["1"]["ip"] == "10.0.0.1" assert config["example.com"]["2"]["ip"] == "10.0.0.2" def test_load_with_shared_domain(self, patch_config_paths): """Load config with shared domain reference.""" add_domain_to_map("example.com", "pool_1") add_domain_to_map("www.example.com", "pool_1") add_shared_domain_to_config("www.example.com", "example.com") config = load_servers_config() assert config["www.example.com"]["_shares"] == "example.com" class TestAddServerToConfig: """Tests for add_server_to_config function.""" def test_add_to_empty_config(self, patch_config_paths): """Add server to empty config.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) config = load_servers_config() assert config["example.com"]["1"]["ip"] == "10.0.0.1" assert config["example.com"]["1"]["http_port"] == 80 def test_add_to_existing_domain(self, patch_config_paths): """Add server to domain with existing servers.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) add_server_to_config("example.com", 2, "10.0.0.2", 80) config = load_servers_config() assert "1" in config["example.com"] assert "2" in config["example.com"] def test_overwrite_existing_slot(self, patch_config_paths): """Overwrite existing slot.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) add_server_to_config("example.com", 1, "10.0.0.99", 8080) config = load_servers_config() assert config["example.com"]["1"]["ip"] == "10.0.0.99" assert config["example.com"]["1"]["http_port"] == 8080 class TestRemoveServerFromConfig: """Tests for remove_server_from_config function.""" def test_remove_existing_server(self, patch_config_paths): """Remove existing server.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) add_server_to_config("example.com", 2, "10.0.0.2", 80) remove_server_from_config("example.com", 1) config = load_servers_config() assert "1" not in config.get("example.com", {}) assert "2" in config["example.com"] def test_remove_last_server_removes_domain(self, patch_config_paths): """Removing last server removes domain entry from servers.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) remove_server_from_config("example.com", 1) config = load_servers_config() # Domain may or may not exist (no servers = no entry) assert config.get("example.com", {}).get("1") is None def test_remove_nonexistent_server(self, patch_config_paths): """Removing non-existent server is a no-op.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) remove_server_from_config("example.com", 99) # Non-existent slot config = load_servers_config() assert "1" in config["example.com"] class TestRemoveDomainFromConfig: """Tests for remove_domain_from_config function.""" def test_remove_existing_domain(self, patch_config_paths): """Remove existing domain's servers.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) add_server_to_config("other.com", 1, "10.0.0.2", 80) remove_domain_from_config("example.com") config = load_servers_config() assert config.get("example.com", {}).get("1") is None assert "other.com" in config def test_remove_nonexistent_domain(self, patch_config_paths): """Removing non-existent domain is a no-op.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) remove_domain_from_config("other.com") # Non-existent config = load_servers_config() assert "example.com" in config class TestLoadCertsConfig: """Tests for load_certs_config function (SQLite-backed).""" def test_load_empty(self, patch_config_paths): """Empty database returns empty list.""" domains = load_certs_config() assert domains == [] def test_load_with_certs(self, patch_config_paths): """Load certs from database.""" add_cert_to_config("example.com") add_cert_to_config("other.com") domains = load_certs_config() assert "example.com" in domains assert "other.com" in domains class TestAddCertToConfig: """Tests for add_cert_to_config function.""" def test_add_new_cert(self, patch_config_paths): """Add new certificate domain.""" add_cert_to_config("example.com") domains = load_certs_config() assert "example.com" in domains def test_add_duplicate_cert(self, patch_config_paths): """Adding duplicate cert is a no-op.""" add_cert_to_config("example.com") add_cert_to_config("example.com") domains = load_certs_config() assert domains.count("example.com") == 1 class TestRemoveCertFromConfig: """Tests for remove_cert_from_config function.""" def test_remove_existing_cert(self, patch_config_paths): """Remove existing certificate domain.""" add_cert_to_config("example.com") add_cert_to_config("other.com") remove_cert_from_config("example.com") domains = load_certs_config() assert "example.com" not in domains assert "other.com" in domains def test_remove_nonexistent_cert(self, patch_config_paths): """Removing non-existent cert is a no-op.""" add_cert_to_config("example.com") remove_cert_from_config("other.com") # Non-existent domains = load_certs_config() assert "example.com" in domains class TestAddDomainToMap: """Tests for add_domain_to_map function.""" def test_add_domain(self, patch_config_paths): """Add a domain and verify map files are synced.""" add_domain_to_map("example.com", "pool_1") assert get_domain_backend("example.com") == "pool_1" with open(patch_config_paths["map_file"]) as f: assert "example.com pool_1" in f.read() def test_add_wildcard(self, patch_config_paths): """Add a wildcard domain.""" add_domain_to_map(".example.com", "pool_1", is_wildcard=True) entries = get_map_contents() assert (".example.com", "pool_1") in entries class TestRemoveDomainFromMap: """Tests for remove_domain_from_map function.""" def test_remove_domain(self, patch_config_paths): """Remove a domain and its wildcard.""" add_domain_to_map("example.com", "pool_1") add_domain_to_map(".example.com", "pool_1", is_wildcard=True) remove_domain_from_map("example.com") assert get_domain_backend("example.com") is None entries = get_map_contents() assert (".example.com", "pool_1") not in entries class TestFindAvailablePool: """Tests for find_available_pool function.""" def test_first_pool_available(self, patch_config_paths): """When no domains exist, pool_1 is returned.""" pool = find_available_pool() assert pool == "pool_1" def test_skip_used_pools(self, patch_config_paths): """Used pools are skipped.""" add_domain_to_map("example.com", "pool_1") add_domain_to_map("other.com", "pool_2") pool = find_available_pool() assert pool == "pool_3" class TestSharedDomains: """Tests for shared domain functions.""" def test_get_shared_domain(self, patch_config_paths): """Get parent domain for shared domain.""" add_domain_to_map("example.com", "pool_1") add_domain_to_map("www.example.com", "pool_1") add_shared_domain_to_config("www.example.com", "example.com") assert get_shared_domain("www.example.com") == "example.com" def test_is_shared_domain(self, patch_config_paths): """Check if domain is shared.""" add_domain_to_map("example.com", "pool_1") add_domain_to_map("www.example.com", "pool_1") add_shared_domain_to_config("www.example.com", "example.com") assert is_shared_domain("www.example.com") is True assert is_shared_domain("example.com") is False def test_get_domains_sharing_pool(self, patch_config_paths): """Get all domains using a pool.""" add_domain_to_map("example.com", "pool_1") add_domain_to_map("www.example.com", "pool_1") add_domain_to_map(".example.com", "pool_1", is_wildcard=True) domains = get_domains_sharing_pool("pool_1") assert "example.com" in domains assert "www.example.com" in domains assert ".example.com" not in domains # Wildcards excluded