"""Unit tests for file_ops module.""" import json import os from unittest.mock import patch import pytest from haproxy_mcp.file_ops import ( atomic_write_file, get_map_contents, save_map_file, get_domain_backend, split_domain_entries, is_legacy_backend, get_legacy_backend_name, get_backend_and_prefix, load_servers_config, save_servers_config, add_server_to_config, remove_server_from_config, remove_domain_from_config, load_certs_config, save_certs_config, add_cert_to_config, remove_cert_from_config, ) class TestAtomicWriteFile: """Tests for atomic_write_file function.""" def test_write_new_file(self, tmp_path): """Write to a new file.""" file_path = str(tmp_path / "test.txt") content = "Hello, World!" atomic_write_file(file_path, content) assert os.path.exists(file_path) with open(file_path) as f: assert f.read() == content def test_overwrite_existing_file(self, tmp_path): """Overwrite an existing file.""" file_path = str(tmp_path / "test.txt") with open(file_path, "w") as f: f.write("Old content") atomic_write_file(file_path, "New content") with open(file_path) as f: assert f.read() == "New content" def test_preserves_directory(self, tmp_path): """Writing does not create intermediate directories.""" file_path = str(tmp_path / "subdir" / "test.txt") with pytest.raises(IOError): atomic_write_file(file_path, "content") def test_unicode_content(self, tmp_path): """Unicode content is properly written.""" file_path = str(tmp_path / "unicode.txt") content = "Hello, \u4e16\u754c!" # "Hello, World!" in Chinese atomic_write_file(file_path, content) with open(file_path, encoding="utf-8") as f: assert f.read() == content def test_multiline_content(self, tmp_path): """Multi-line content is properly written.""" file_path = str(tmp_path / "multiline.txt") content = "line1\nline2\nline3" atomic_write_file(file_path, content) with open(file_path) as f: assert f.read() == content class TestGetMapContents: """Tests for get_map_contents function.""" def test_read_map_file(self, patch_config_paths): """Read entries from map file.""" # Write test content to map file with open(patch_config_paths["map_file"], "w") as f: f.write("example.com pool_1\n") f.write("api.example.com pool_2\n") entries = get_map_contents() assert ("example.com", "pool_1") in entries assert ("api.example.com", "pool_2") in entries def test_read_both_map_files(self, patch_config_paths): """Read entries from both domains.map and wildcards.map.""" with open(patch_config_paths["map_file"], "w") as f: f.write("example.com pool_1\n") with open(patch_config_paths["wildcards_file"], "w") as f: f.write(".example.com pool_1\n") entries = get_map_contents() assert ("example.com", "pool_1") in entries assert (".example.com", "pool_1") in entries def test_skip_comments(self, patch_config_paths): """Comments are skipped.""" with open(patch_config_paths["map_file"], "w") as f: f.write("# This is a comment\n") f.write("example.com pool_1\n") f.write("# Another comment\n") entries = get_map_contents() assert len(entries) == 1 assert entries[0] == ("example.com", "pool_1") def test_skip_empty_lines(self, patch_config_paths): """Empty lines are skipped.""" with open(patch_config_paths["map_file"], "w") as f: f.write("\n") f.write("example.com pool_1\n") f.write("\n") f.write("api.example.com pool_2\n") entries = get_map_contents() assert len(entries) == 2 def test_file_not_found(self, patch_config_paths): """Missing file returns empty list.""" os.unlink(patch_config_paths["map_file"]) os.unlink(patch_config_paths["wildcards_file"]) entries = get_map_contents() assert entries == [] class TestSplitDomainEntries: """Tests for split_domain_entries function.""" def test_split_entries(self): """Split entries into exact and wildcard.""" entries = [ ("example.com", "pool_1"), (".example.com", "pool_1"), ("api.example.com", "pool_2"), (".api.example.com", "pool_2"), ] exact, wildcards = split_domain_entries(entries) assert len(exact) == 2 assert len(wildcards) == 2 assert ("example.com", "pool_1") in exact assert (".example.com", "pool_1") in wildcards def test_empty_entries(self): """Empty entries returns empty lists.""" exact, wildcards = split_domain_entries([]) assert exact == [] assert wildcards == [] def test_all_exact(self): """All exact entries.""" entries = [ ("example.com", "pool_1"), ("api.example.com", "pool_2"), ] exact, wildcards = split_domain_entries(entries) assert len(exact) == 2 assert len(wildcards) == 0 class TestSaveMapFile: """Tests for save_map_file function.""" def test_save_entries(self, patch_config_paths): """Save entries to separate map files.""" entries = [ ("example.com", "pool_1"), (".example.com", "pool_1"), ] save_map_file(entries) # Check exact domains file with open(patch_config_paths["map_file"]) as f: content = f.read() assert "example.com pool_1" in content # Check wildcards file with open(patch_config_paths["wildcards_file"]) as f: content = f.read() assert ".example.com pool_1" in content def test_sorted_output(self, patch_config_paths): """Entries are sorted in output.""" entries = [ ("z.example.com", "pool_3"), ("a.example.com", "pool_1"), ("m.example.com", "pool_2"), ] save_map_file(entries) with open(patch_config_paths["map_file"]) as f: lines = [l.strip() for l in f if l.strip() and not l.startswith("#")] assert lines[0] == "a.example.com pool_1" assert lines[1] == "m.example.com pool_2" assert lines[2] == "z.example.com pool_3" class TestGetDomainBackend: """Tests for get_domain_backend function.""" def test_find_existing_domain(self, patch_config_paths): """Find backend for existing domain.""" with open(patch_config_paths["map_file"], "w") as f: f.write("example.com pool_1\n") backend = get_domain_backend("example.com") assert backend == "pool_1" def test_domain_not_found(self, patch_config_paths): """Non-existent domain returns None.""" with open(patch_config_paths["map_file"], "w") as f: f.write("example.com pool_1\n") backend = get_domain_backend("other.com") assert backend is None class TestIsLegacyBackend: """Tests for is_legacy_backend function.""" def test_pool_backend(self): """Pool backend is not legacy.""" assert is_legacy_backend("pool_1") is False assert is_legacy_backend("pool_100") is False def test_legacy_backend(self): """Non-pool backend is legacy.""" assert is_legacy_backend("api_example_com_backend") is True assert is_legacy_backend("static_backend") is True class TestGetLegacyBackendName: """Tests for get_legacy_backend_name function.""" def test_convert_domain(self): """Convert domain to legacy backend name.""" result = get_legacy_backend_name("api.example.com") assert result == "api_example_com_backend" class TestGetBackendAndPrefix: """Tests for get_backend_and_prefix function.""" def test_pool_backend(self, patch_config_paths): """Pool backend returns pool-based prefix.""" with open(patch_config_paths["map_file"], "w") as f: f.write("example.com pool_5\n") backend, prefix = get_backend_and_prefix("example.com") assert backend == "pool_5" assert prefix == "pool_5" def test_unknown_domain_uses_legacy(self, patch_config_paths): """Unknown domain uses legacy backend naming.""" backend, prefix = get_backend_and_prefix("unknown.com") assert backend == "unknown_com_backend" assert prefix == "unknown_com" class TestLoadServersConfig: """Tests for load_servers_config function.""" def test_load_existing_config(self, patch_config_paths, sample_servers_config): """Load existing config file.""" with open(patch_config_paths["servers_file"], "w") as f: json.dump(sample_servers_config, f) config = load_servers_config() assert "example.com" in config assert config["example.com"]["1"]["ip"] == "10.0.0.1" def test_file_not_found(self, patch_config_paths): """Missing file returns empty dict.""" os.unlink(patch_config_paths["servers_file"]) config = load_servers_config() assert config == {} def test_invalid_json(self, patch_config_paths): """Invalid JSON returns empty dict.""" with open(patch_config_paths["servers_file"], "w") as f: f.write("not valid json {{{") config = load_servers_config() assert config == {} class TestSaveServersConfig: """Tests for save_servers_config function.""" def test_save_config(self, patch_config_paths): """Save config to file.""" config = {"example.com": {"1": {"ip": "10.0.0.1", "http_port": 80}}} save_servers_config(config) with open(patch_config_paths["servers_file"]) as f: loaded = json.load(f) assert loaded == config class TestAddServerToConfig: """Tests for add_server_to_config function.""" def test_add_to_empty_config(self, patch_config_paths): """Add server to empty config.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) config = load_servers_config() assert config["example.com"]["1"]["ip"] == "10.0.0.1" assert config["example.com"]["1"]["http_port"] == 80 def test_add_to_existing_domain(self, patch_config_paths): """Add server to domain with existing servers.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) add_server_to_config("example.com", 2, "10.0.0.2", 80) config = load_servers_config() assert "1" in config["example.com"] assert "2" in config["example.com"] def test_overwrite_existing_slot(self, patch_config_paths): """Overwrite existing slot.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) add_server_to_config("example.com", 1, "10.0.0.99", 8080) config = load_servers_config() assert config["example.com"]["1"]["ip"] == "10.0.0.99" assert config["example.com"]["1"]["http_port"] == 8080 class TestRemoveServerFromConfig: """Tests for remove_server_from_config function.""" def test_remove_existing_server(self, patch_config_paths): """Remove existing server.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) add_server_to_config("example.com", 2, "10.0.0.2", 80) remove_server_from_config("example.com", 1) config = load_servers_config() assert "1" not in config["example.com"] assert "2" in config["example.com"] def test_remove_last_server_removes_domain(self, patch_config_paths): """Removing last server removes domain entry.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) remove_server_from_config("example.com", 1) config = load_servers_config() assert "example.com" not in config def test_remove_nonexistent_server(self, patch_config_paths): """Removing non-existent server is a no-op.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) remove_server_from_config("example.com", 99) # Non-existent slot config = load_servers_config() assert "1" in config["example.com"] class TestRemoveDomainFromConfig: """Tests for remove_domain_from_config function.""" def test_remove_existing_domain(self, patch_config_paths): """Remove existing domain.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) add_server_to_config("other.com", 1, "10.0.0.2", 80) remove_domain_from_config("example.com") config = load_servers_config() assert "example.com" not in config assert "other.com" in config def test_remove_nonexistent_domain(self, patch_config_paths): """Removing non-existent domain is a no-op.""" add_server_to_config("example.com", 1, "10.0.0.1", 80) remove_domain_from_config("other.com") # Non-existent config = load_servers_config() assert "example.com" in config class TestLoadCertsConfig: """Tests for load_certs_config function.""" def test_load_existing_config(self, patch_config_paths): """Load existing certs config.""" with open(patch_config_paths["certs_file"], "w") as f: json.dump({"domains": ["example.com", "other.com"]}, f) domains = load_certs_config() assert "example.com" in domains assert "other.com" in domains def test_file_not_found(self, patch_config_paths): """Missing file returns empty list.""" os.unlink(patch_config_paths["certs_file"]) domains = load_certs_config() assert domains == [] class TestSaveCertsConfig: """Tests for save_certs_config function.""" def test_save_domains(self, patch_config_paths): """Save domains to certs config.""" save_certs_config(["z.com", "a.com"]) with open(patch_config_paths["certs_file"]) as f: data = json.load(f) # Should be sorted assert data["domains"] == ["a.com", "z.com"] class TestAddCertToConfig: """Tests for add_cert_to_config function.""" def test_add_new_cert(self, patch_config_paths): """Add new certificate domain.""" add_cert_to_config("example.com") domains = load_certs_config() assert "example.com" in domains def test_add_duplicate_cert(self, patch_config_paths): """Adding duplicate cert is a no-op.""" add_cert_to_config("example.com") add_cert_to_config("example.com") domains = load_certs_config() assert domains.count("example.com") == 1 class TestRemoveCertFromConfig: """Tests for remove_cert_from_config function.""" def test_remove_existing_cert(self, patch_config_paths): """Remove existing certificate domain.""" add_cert_to_config("example.com") add_cert_to_config("other.com") remove_cert_from_config("example.com") domains = load_certs_config() assert "example.com" not in domains assert "other.com" in domains def test_remove_nonexistent_cert(self, patch_config_paths): """Removing non-existent cert is a no-op.""" add_cert_to_config("example.com") remove_cert_from_config("other.com") # Non-existent domains = load_certs_config() assert "example.com" in domains