#!/usr/bin/env python3
"""
XDP Defense - Common Utilities
Merged from xdp-defense common utilities
Provides: map management, CIDR handling, IP encoding, rate config, block/unblock, stats
"""

import ipaddress
import json
import logging
import logging.handlers
import os
import socket
import struct
import subprocess
import tempfile
import time

# ==================== Logging ====================

_syslog_handler = None
try:
    _syslog_handler = logging.handlers.SysLogHandler(address='/dev/log')
    _syslog_handler.setFormatter(logging.Formatter('xdp-defense: %(message)s'))
except Exception:
    # Best effort: audit logging to syslog is optional (e.g. /dev/log may be
    # unavailable); the module must still import without it.
    pass

audit_log = logging.getLogger('xdp-defense')
audit_log.setLevel(logging.INFO)
if _syslog_handler:
    audit_log.addHandler(_syslog_handler)

# ==================== BPF Map Helpers ====================

_map_cache = {}  # {map_name: (map_id, timestamp)}
_MAP_CACHE_TTL = 5.0  # seconds


def get_map_id(map_name):
    """Get BPF map ID by name (cached with 5s TTL).

    On a cache miss the whole cache is rebuilt from ``bpftool map show -j``,
    so maps that no longer exist are forgotten instead of resolving to a
    stale ID.

    Args:
        map_name: map name as reported by bpftool.

    Returns:
        The map ID, or None if the map is not listed, bpftool cannot be run
        or fails, or its output cannot be parsed.
    """
    now = time.monotonic()
    hit = _map_cache.get(map_name)
    if hit is not None and (now - hit[1]) < _MAP_CACHE_TTL:
        return hit[0]

    try:
        result = subprocess.run(
            ["bpftool", "map", "show", "-j"],
            capture_output=True, text=True
        )
    except OSError:
        # bpftool missing or not executable: report "map not found".
        return None
    if result.returncode != 0:
        return None

    try:
        maps = json.loads(result.stdout)
    except ValueError:
        return None
    if not isinstance(maps, list):
        return None

    # Build the replacement cache first, then swap it in, so entries for maps
    # that disappeared since the previous refresh are dropped.
    fresh = {}
    for m in maps:
        if not isinstance(m, dict):
            continue
        name = m.get("name")
        if name:
            fresh[name] = (m.get("id"), now)
    _map_cache.clear()
    _map_cache.update(fresh)

    hit = _map_cache.get(map_name)
    return hit[0] if hit is not None else None


# ==================== CIDR / IPv4 Helpers (from blocker) ====================

def _parse_decimal(text):
    """Parse a string made only of ASCII digits; raise ValueError otherwise.

    Plain ``int()`` is too lenient for validation: it also accepts signs,
    underscores and non-ASCII digits.
    """
    if not (text.isascii() and text.isdigit()):
        raise ValueError(f"not a decimal number: {text!r}")
    return int(text)


def validate_cidr(cidr):
    """Validate IPv4 CIDR notation.

    Surrounding whitespace (including a trailing newline) is ignored. A bare
    address without ``/prefix`` is treated as a /32.

    Args:
        cidr: string such as ``"10.0.0.0/8"`` or ``"10.0.0.1"``.

    Returns:
        (ip, prefix, parts): the address text, the prefix length and the
        four octets as ints.

    Raises:
        ValueError: if the address or the prefix is malformed or out of range.
    """
    cidr = cidr.strip()
    if '/' in cidr:
        ip, prefix_str = cidr.split('/', 1)
        ip = ip.strip()
        prefix_str = prefix_str.strip()
        try:
            prefix = _parse_decimal(prefix_str)
        except ValueError:
            raise ValueError(f"Invalid prefix: /{prefix_str}") from None
    else:
        ip = cidr
        prefix = 32

    if prefix < 0 or prefix > 32:
        raise ValueError(f"Invalid prefix: /{prefix}")

    try:
        parts = [_parse_decimal(x) for x in ip.split('.')]
    except ValueError:
        raise ValueError(f"Invalid IP: {ip}") from None
    if len(parts) != 4 or any(p > 255 for p in parts):
        raise ValueError(f"Invalid IP: {ip}")

    return ip, prefix, parts


def cidr_to_key(cidr):
    """Convert CIDR to LPM trie key hex string.

    Key format: prefixlen (4 bytes LE) + addr (4 bytes), as space-separated
    hex bytes. Raises ValueError for an invalid CIDR.
    """
    _, prefix, parts = validate_cidr(cidr)
    key_hex = f"{prefix:02x} 00 00 00 {parts[0]:02x} {parts[1]:02x} {parts[2]:02x} {parts[3]:02x}"
    return key_hex


def is_ipv6(cidr):
    """Check if a CIDR string is IPv6 (i.e. contains a colon)."""
    return ':' in cidr


def validate_cidr_v6(cidr):
    """Validate IPv6 CIDR notation.

    Host bits are allowed and masked off (``strict=False``).

    Returns:
        (network, prefix): an ``ipaddress.IPv6Network`` and its prefix length.

    Raises:
        ValueError: if *cidr* is not a valid IPv6 address or network.
    """
    try:
        net = ipaddress.IPv6Network(cidr, strict=False)
        return net, net.prefixlen
    except (ipaddress.AddressValueError, ValueError) as e:
        raise ValueError(f"Invalid IPv6 CIDR: {cidr}") from e


def cidr_to_key_v6(cidr):
    """Convert IPv6 CIDR to LPM trie key hex string.

    Key format: prefixlen (4 bytes LE) + addr (16 bytes)
    """
    net, prefix = validate_cidr_v6(cidr)
    addr_bytes = net.network_address.packed
    prefix_hex = f"{prefix:02x} 00 00 00"
    addr_hex = ' '.join(f"{b:02x}" for b in addr_bytes)
    return f"{prefix_hex} {addr_hex}"


def classify_cidrs(cidrs):
    """Split a list of CIDRs into v4 and v6 lists.

    Returns:
        (v4, v6): two lists preserving the input order. Entries are only
        classified, not validated.
    """
    v4 = []
    v6 = []
    for c in cidrs:
        (v6 if is_ipv6(c) else v4).append(c)
    return v4, v6


def _run_bpftool_batch(commands):
    """Write *commands* to a temporary file and run ``bpftool batch`` on it.

    Returns the ``subprocess.CompletedProcess``; the temporary file is removed
    even if bpftool cannot be started.
    """
    batch_file = None
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix='.batch', delete=False) as f:
            batch_file = f.name
            f.writelines(commands)
        # The file is closed (and therefore flushed) before bpftool reads it.
        return subprocess.run(
            ["bpftool", "batch", "file", batch_file],
            capture_output=True, text=True
        )
    finally:
        if batch_file and os.path.exists(batch_file):
            os.unlink(batch_file)


def batch_map_operation(map_id, cidrs, operation="update",
                        value_hex="01 00 00 00 00 00 00 00",
                        batch_size=1000, ipv6=False):
    """Execute batch map operations with progress reporting.

    Args:
        map_id: BPF map ID
        cidrs: list of CIDR strings; invalid entries are skipped
        operation: "update"; any other value is treated as "delete"
        value_hex: hex value for update operations
        batch_size: number of operations per batch (must be >= 1)
        ipv6: encode keys as IPv6 instead of IPv4

    Returns:
        Number of entries submitted to bpftool (valid CIDRs only). A failing
        bpftool run is reported through the audit log; the count is not
        reduced because bpftool does not report how many commands it applied.

    Raises:
        ValueError: if batch_size is smaller than 1.
        OSError: if bpftool cannot be executed.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    to_key = cidr_to_key_v6 if ipv6 else cidr_to_key
    proto = "v6" if ipv6 else "v4"
    total = len(cidrs)
    processed = 0
    written = 0

    for start in range(0, total, batch_size):
        batch = cidrs[start:start + batch_size]
        commands = []
        for cidr in batch:
            try:
                key_hex = to_key(cidr)
            except (ValueError, TypeError, AttributeError):
                # Malformed or non-string entry: skip it, keep the rest.
                continue
            if operation == "update":
                commands.append(f"map update id {map_id} key hex {key_hex} value hex {value_hex}\n")
            else:
                commands.append(f"map delete id {map_id} key hex {key_hex}\n")

        # Nothing valid in this batch: do not spawn bpftool on an empty file.
        if commands:
            result = _run_bpftool_batch(commands)
            if result.returncode != 0:
                audit_log.warning(
                    "batch %s %s: bpftool exited with %s on map %s: %s",
                    operation, proto, result.returncode, map_id,
                    (result.stderr or "").strip()
                )

        processed += len(batch)
        written += len(commands)
        pct = processed * 100 // total if total > 0 else 100
        print(f"\r  Progress: {processed}/{total} ({pct}%)", end="", flush=True)

    print()
    audit_log.info(f"batch {operation} {proto}: {written} entries on map {map_id}")
    return written


# ==================== Whitelist Check (for daemon) ====================

def is_whitelisted(ip_str):
    """Check if an IP is in the BPF whitelist maps.

    The lookup key is built from the parsed address with a full-length prefix
    (/32 or /128), in the same layout as cidr_to_key / cidr_to_key_v6.
    NOTE(review): this presumably relies on the whitelist maps being LPM
    tries, so a host key matches any covering prefix - confirm against the
    BPF program.

    Returns True if whitelisted, False otherwise (including when the address
    is invalid, the map is missing or bpftool cannot be run).
    """
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError:
        return False

    if isinstance(addr, ipaddress.IPv6Address):
        map_name = "whitelist_v6"
    else:
        map_name = "whitelist_v4"
    addr_hex = ' '.join(f"{b:02x}" for b in addr.packed)
    key_hex = f"{addr.max_prefixlen:02x} 00 00 00 {addr_hex}"

    map_id = get_map_id(map_name)
    if map_id is None:
        return False

    try:
        result = subprocess.run(
            ["bpftool", "map", "lookup", "id", str(map_id), "key", "hex"] + key_hex.split(),
            capture_output=True, text=True
        )
    except OSError:
        return False
    # A successful lookup (exit status 0) means an entry matched.
    return result.returncode == 0


# ==================== IP Encoding Helpers (from ddos) ====================

def ip_to_hex_key(ip_str):
    """Convert IP address to hex key for LRU_HASH maps.

    IPv4: 4 bytes (network byte order)
    IPv6: 16 bytes (network byte order)
    Returns space-separated hex string.

    Raises:
        ValueError: if *ip_str* is not a valid IPv4 or IPv6 address.
    """
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError as exc:
        raise ValueError(f"Invalid IP address: {ip_str}") from exc
    return ' '.join(f"{b:02x}" for b in addr.packed)


def hex_key_to_ip(hex_bytes, version=4):
    """Convert hex byte list back to IP string.

    hex_bytes: list of hex strings like ['0a', '00', '00', '01']
    version: 6 to decode 16 bytes as IPv6; anything else decodes as IPv4.

    Raises ValueError if the byte count does not fit the address family.
    """
    raw = bytes(int(h, 16) for h in hex_bytes)
    if version == 6:
        return str(ipaddress.IPv6Address(raw))
    return str(ipaddress.IPv4Address(raw))


# ==================== DDoS Stats / Features (from ddos) ====================

def read_percpu_stats(map_name="global_stats", num_entries=5):
    """Read PERCPU_ARRAY map and return summed values per key.

    Args:
        map_name: name of the map to dump.
        num_entries: currently unused; kept so existing callers keep working.

    Returns dict: {key_index: summed_value}. An empty dict is returned when
    the map is missing or bpftool fails or prints unparsable output.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return {}

    try:
        result = subprocess.run(
            ["bpftool", "map", "dump", "id", str(map_id), "-j"],
            capture_output=True, text=True
        )
    except OSError:
        return {}
    if result.returncode != 0:
        return {}

    try:
        data = json.loads(result.stdout)
    except ValueError:
        return {}
    if not isinstance(data, list):
        return {}

    stats = {}
    for entry in data:
        fmt = entry.get("formatted", entry)
        key = fmt.get("key")
        if isinstance(key, list):
            if not key:
                continue
            # Only the first element is used; presumably the low byte of a
            # little-endian array index - verify if indices can exceed 255.
            key = int(key[0], 16) if isinstance(key[0], str) else key[0]
        elif isinstance(key, str):
            key = int(key, 16)
        if key is None:
            # An entry without a key cannot be attributed to an index.
            continue

        total = 0
        for v in fmt.get("values", []):
            val = v.get("value", v.get("val", 0))
            if isinstance(val, str):
                val = int(val, 0)
            elif isinstance(val, list):
                # Raw dump: list of hex byte strings, decoded little-endian.
                val = int.from_bytes(
                    bytes(int(x, 16) for x in val), byteorder='little'
                )
            total += val
        stats[key] = total

    return stats
Returns dict: {key_index: summed_value} """ map_id = get_map_id(map_name) if map_id is None: return {} result = subprocess.run( ["bpftool", "map", "dump", "id", str(map_id), "-j"], capture_output=True, text=True ) if result.returncode != 0: return {} try: data = json.loads(result.stdout) except Exception: return {} stats = {} for entry in data: fmt = entry.get("formatted", entry) key = fmt.get("key") if isinstance(key, list): key = int(key[0], 16) if isinstance(key[0], str) else key[0] elif isinstance(key, str): key = int(key, 16) values = fmt.get("values", []) total = 0 for v in values: val = v.get("value", v.get("val", 0)) if isinstance(val, str): val = int(val, 0) elif isinstance(val, list): val = int.from_bytes( bytes(int(x, 16) for x in val), byteorder='little' ) total += val stats[key] = total return stats def read_percpu_features(): """Read traffic_features PERCPU_ARRAY and return aggregated struct. Returns dict with field names and summed values. """ map_id = get_map_id("traffic_feature") if map_id is None: return {} result = subprocess.run( ["bpftool", "map", "dump", "id", str(map_id), "-j"], capture_output=True, text=True ) if result.returncode != 0: return {} try: data = json.loads(result.stdout) except Exception: return {} field_names = [ "total_packets", "total_bytes", "tcp_syn_count", "tcp_other_count", "udp_count", "icmp_count", "other_proto_count", "unique_ips_approx", "small_pkt_count", "large_pkt_count" ] num_fields = len(field_names) aggregated = {f: 0 for f in field_names} for entry in data: fmt = entry.get("formatted", entry) values = fmt.get("values", []) for cpu_val in values: val = cpu_val.get("value", []) if isinstance(val, list) and len(val) >= num_fields * 8: raw = bytes(int(x, 16) for x in val[:num_fields * 8]) for i, name in enumerate(field_names): v = struct.unpack_from('0 = temporary block """ addr = ipaddress.ip_address(ip_str) is_v6 = isinstance(addr, ipaddress.IPv6Address) map_name = "blocked_ips_v6" if is_v6 else "blocked_ips_v4" 
map_id = get_map_id(map_name) if map_id is None: raise RuntimeError(f"{map_name} map not found") key_hex = ip_to_hex_key(ip_str) # Use CLOCK_BOOTTIME (matches BPF ktime_get_ns) with open('/proc/uptime', 'r') as f: uptime_sec = float(f.read().split()[0]) now_ns = int(uptime_sec * 1_000_000_000) if duration_sec > 0: expire_ns = now_ns + (duration_sec * 1_000_000_000) else: expire_ns = 0 raw = struct.pack(' 0 else "permanent" audit_log.info(f"blocked {ip_str} ({duration_str})") def unblock_ip(ip_str): """Remove IP from blocked_ips map.""" addr = ipaddress.ip_address(ip_str) is_v6 = isinstance(addr, ipaddress.IPv6Address) map_name = "blocked_ips_v6" if is_v6 else "blocked_ips_v4" map_id = get_map_id(map_name) if map_id is None: raise RuntimeError(f"{map_name} map not found") key_hex = ip_to_hex_key(ip_str) result = subprocess.run( ["bpftool", "map", "delete", "id", str(map_id), "key", "hex"] + key_hex.split(), capture_output=True, text=True ) if result.returncode != 0: raise RuntimeError(f"Failed to unblock {ip_str}: {result.stderr}") audit_log.info(f"unblocked {ip_str}") # ==================== Dump Helpers (from ddos) ==================== def dump_rate_counters(map_name="rate_counter_v4", top_n=10): """Dump top-N IPs by packet count from rate counter map. Returns list of (ip_str, packets, bytes, last_seen_ns). 
""" map_id = get_map_id(map_name) if map_id is None: return [] is_v6 = "v6" in map_name result = subprocess.run( ["bpftool", "map", "dump", "id", str(map_id), "-j"], capture_output=True, text=True ) if result.returncode != 0: return [] try: data = json.loads(result.stdout) except Exception: return [] entries = [] for entry in data: fmt = entry.get("formatted", entry) key = fmt.get("key", []) val = fmt.get("value", {}) try: if isinstance(key, list): ip_str = hex_key_to_ip(key, version=6 if is_v6 else 4) elif isinstance(key, dict): if is_v6: addr8 = key.get('in6_u', {}).get('u6_addr8', []) if addr8: raw = bytes(addr8) ip_str = str(ipaddress.IPv6Address(raw)) else: continue else: ip_str = socket.inet_ntoa(struct.pack('= 24: raw = bytes(int(x, 16) for x in val[:24]) pkts = struct.unpack_from('= 24: raw = bytes(int(x, 16) for x in val[:24]) expire = struct.unpack_from('