Files
xdp-defense/lib/xdp_common.py
kaffa 667c6eac81 Fix 12 code review issues (4 MEDIUM + 8 LOW)
MEDIUM:
- M1: Whitelist direct IP/CIDR additions now persist to direct.txt
- M2: get_map_id() uses 5s TTL cache (single bpftool call for all maps)
- M3: IPv6 extension header parsing in xdp_ddos.c (hop-by-hop/routing/frag/dst)
- M4: Shell injection prevention - sanitize_input() + sys.argv[] for all Python calls

LOW:
- L1: Remove redundant self.running (uses _stop_event only)
- L2: Remove unused config values (rate_limit_after, cooldown_multiplier, retrain_interval)
- L3: Thread poll intervals reloaded on SIGHUP
- L4: batch_map_operation counts only successfully written entries
- L5: Clarify unique_ips_approx comment (per-packet counter)
- L6: Document LRU_HASH multi-CPU race condition as acceptable
- L7: Download Cloudflare IPv6 ranges in whitelist preset
- L8: Fix file handle leak in xdp_country.py list_countries()

Also: SIGHUP now preserves EWMA/violation state, daemon skips whitelisted
IPs in EWMA/AI escalation, deep copy for default config, IHL validation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 09:23:41 +09:00

591 lines
18 KiB
Python

#!/usr/bin/env python3
"""
XDP Defense - Common Utilities
Merged from xdp-blocker/xdp_common.py and xdp-ddos/xdp_ddos_common.py
Provides: map management, CIDR handling, IP encoding, rate config, block/unblock, stats
"""
import subprocess
import json
import os
import tempfile
import struct
import socket
import ipaddress
import logging
import logging.handlers
# ==================== Logging ====================
# Module-wide audit logger; records go to the local syslog socket when it
# can be opened, otherwise the logger is left without this handler.
audit_log = logging.getLogger('xdp-defense')
audit_log.setLevel(logging.INFO)

_syslog_handler = None
try:
    _syslog_handler = logging.handlers.SysLogHandler(address='/dev/log')
    _syslog_handler.setFormatter(logging.Formatter('xdp-defense: %(message)s'))
except Exception:
    # Best effort: any failure while setting up the syslog handler is ignored.
    pass

if _syslog_handler is not None:
    audit_log.addHandler(_syslog_handler)
# ==================== BPF Map Helpers ====================
_map_cache = {}  # {map_name: (map_id, timestamp)}
_MAP_CACHE_TTL = 5.0  # seconds


def get_map_id(map_name):
    """Get BPF map ID by name (cached with a 5s TTL).

    One ``bpftool map show -j`` call refreshes the cache entry of every map
    it lists, so looking up several maps in a row costs a single subprocess.

    Args:
        map_name: map name as reported by bpftool.

    Returns:
        The map ID, or None when bpftool cannot be started, exits non-zero,
        or does not list a map with that name.
    """
    import time as _time

    now = _time.monotonic()
    cached = _map_cache.get(map_name)
    if cached and (now - cached[1]) < _MAP_CACHE_TTL:
        return cached[0]

    try:
        result = subprocess.run(
            ["bpftool", "map", "show", "-j"],
            capture_output=True, text=True
        )
    except OSError:
        # bpftool missing or not executable: report it like any other
        # bpftool failure instead of leaking the exception to callers.
        return None
    if result.returncode != 0:
        return None

    try:
        maps = json.loads(result.stdout)
        # Update cache for all maps found
        for m in maps:
            name = m.get("name")
            if name:
                _map_cache[name] = (m.get("id"), now)
    except Exception:
        pass

    refreshed = _map_cache.get(map_name)
    if refreshed and (now - refreshed[1]) < _MAP_CACHE_TTL:
        return refreshed[0]

    # The listing did not (re)confirm this map: whatever is cached for it is
    # an expired ID of a map that may no longer exist, so drop it rather than
    # hand it out.
    _map_cache.pop(map_name, None)
    return None
# ==================== CIDR / IPv4 Helpers (from blocker) ====================
def validate_cidr(cidr):
    """Validate IPv4 CIDR notation.

    Args:
        cidr: "a.b.c.d" or "a.b.c.d/prefix". A bare address is treated as
            /32. Whitespace around the whole string, the address and the
            prefix is ignored.

    Returns:
        Tuple (ip, prefix, parts): the address string, the prefix length as
        an int, and the four octets as a list of ints.

    Raises:
        ValueError: if the prefix is not a plain decimal number in 0-32, or
            the address is not four plain decimal octets in 0-255.
    """
    def _is_plain_decimal(text):
        # int() alone also accepts signs, underscores and non-ASCII digits
        # ("+8", "1_0"), which are not valid in an address or prefix.
        return text.isascii() and text.isdigit()

    cidr = cidr.strip()
    if '/' in cidr:
        ip, prefix_str = cidr.split('/', 1)
        ip = ip.strip()
        prefix_str = prefix_str.strip()
        if not _is_plain_decimal(prefix_str):
            raise ValueError(f"Invalid prefix: /{prefix_str}")
        prefix = int(prefix_str)
    else:
        ip = cidr
        prefix = 32
    if prefix < 0 or prefix > 32:
        raise ValueError(f"Invalid prefix: /{prefix}")

    octets = ip.split('.')
    if len(octets) != 4 or not all(_is_plain_decimal(o) for o in octets):
        raise ValueError(f"Invalid IP: {ip}")
    parts = [int(o) for o in octets]
    if any(p > 255 for p in parts):
        raise ValueError(f"Invalid IP: {ip}")
    return ip, prefix, parts
def cidr_to_key(cidr):
    """Build the LPM trie key hex string for an IPv4 CIDR.

    The result is eight space-separated hex bytes: the prefix length followed
    by three zero bytes, then the four address octets.

    Raises:
        ValueError: propagated from validate_cidr() for malformed input.
    """
    _ip, prefix_len, octets = validate_cidr(cidr)
    head = f"{prefix_len:02x} 00 00 00"
    tail = ' '.join(f"{octet:02x}" for octet in octets)
    return f"{head} {tail}"
def is_ipv6(cidr):
    """Return True when the CIDR string contains a colon, i.e. looks like IPv6."""
    has_colon = ':' in cidr
    return has_colon
def validate_cidr_v6(cidr):
    """Parse an IPv6 CIDR string.

    Host bits are allowed (the network is built with strict=False).

    Returns:
        Tuple (network, prefix): the ipaddress.IPv6Network and its prefix
        length.

    Raises:
        ValueError: if the string is not a valid IPv6 network.
    """
    try:
        network = ipaddress.IPv6Network(cidr, strict=False)
    except ValueError as exc:
        # AddressValueError is a ValueError subclass, so this covers both.
        raise ValueError(f"Invalid IPv6 CIDR: {cidr}") from exc
    return network, network.prefixlen
def cidr_to_key_v6(cidr):
    """Build the LPM trie key hex string for an IPv6 CIDR.

    Key format: prefixlen (4 bytes LE) + addr (16 bytes), rendered as
    space-separated hex bytes. The address part is the network address of the
    parsed CIDR.

    Raises:
        ValueError: propagated from validate_cidr_v6() for malformed input.
    """
    network, prefix_len = validate_cidr_v6(cidr)
    tokens = [f"{prefix_len:02x}", "00", "00", "00"]
    tokens.extend(f"{byte:02x}" for byte in network.network_address.packed)
    return ' '.join(tokens)
def classify_cidrs(cidrs):
    """Partition CIDR strings into (ipv4_list, ipv6_list), keeping input order."""
    ipv4_entries, ipv6_entries = [], []
    for entry in cidrs:
        bucket = ipv6_entries if is_ipv6(entry) else ipv4_entries
        bucket.append(entry)
    return ipv4_entries, ipv6_entries
def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 00 00 00 00 00", batch_size=1000, ipv6=False):
    """Execute batch map operations with progress reporting.

    The CIDRs are converted to LPM trie keys, written to temporary bpftool
    batch files of at most ``batch_size`` commands each, and every file is
    run with ``bpftool batch file``. CIDRs that cannot be converted to a key
    are skipped. A progress line is printed to stdout after each batch.

    Args:
        map_id: BPF map ID
        cidrs: list of CIDR strings
        operation: "update" writes entries; any other value deletes them
        value_hex: hex value for update operations
        batch_size: number of operations per batch
        ipv6: when True, keys are built with cidr_to_key_v6() instead of
            cidr_to_key()

    Returns:
        Number of entries written to the batch files (input entries minus the
        skipped ones). A bpftool failure is logged but does not reduce this
        count, because it is not known how many commands of the failed batch
        were applied.
    """
    total = len(cidrs)
    processed = 0
    written = 0
    skipped = 0
    for start in range(0, total, batch_size):
        batch = cidrs[start:start + batch_size]
        batch_file = None
        batch_written = 0
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.batch', delete=False) as f:
                batch_file = f.name
                for cidr in batch:
                    try:
                        key_hex = cidr_to_key_v6(cidr) if ipv6 else cidr_to_key(cidr)
                        if operation == "update":
                            f.write(f"map update id {map_id} key hex {key_hex} value hex {value_hex}\n")
                        else:
                            f.write(f"map delete id {map_id} key hex {key_hex}\n")
                        batch_written += 1
                    except Exception:
                        # Best effort: one bad entry must not abort the batch.
                        skipped += 1
                        continue
            # The file is closed (and therefore flushed) before bpftool reads it.
            result = subprocess.run(
                ["bpftool", "batch", "file", batch_file],
                capture_output=True, text=True
            )
            if result.returncode != 0:
                # Previously ignored: make failed batches visible in the audit log.
                audit_log.warning(
                    f"bpftool batch {operation} failed on map {map_id} "
                    f"(exit {result.returncode}): {result.stderr.strip()}"
                )
        finally:
            if batch_file and os.path.exists(batch_file):
                os.unlink(batch_file)
        processed += len(batch)
        written += batch_written
        pct = processed * 100 // total if total > 0 else 100
        print(f"\r  Progress: {processed}/{total} ({pct}%)", end="", flush=True)
    print()
    proto = "v6" if ipv6 else "v4"
    if skipped:
        audit_log.warning(f"batch {operation} {proto}: skipped {skipped} invalid entries on map {map_id}")
    audit_log.info(f"batch {operation} {proto}: {written} entries on map {map_id}")
    return written
# ==================== Whitelist Check (for daemon) ====================
def is_whitelisted(ip_str):
    """Look an IP up in the BPF whitelist map of its address family.

    Returns:
        True when ``bpftool map lookup`` succeeds for the address (as a /32
        or /128 key); False when the string is not a valid IP, the whitelist
        map is not found, or the lookup fails.
    """
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError:
        return False

    if addr.version == 6:
        map_name = "whitelist_v6"
        key_hex = cidr_to_key_v6(f"{ip_str}/128")
    else:
        map_name = "whitelist_v4"
        key_hex = cidr_to_key(f"{ip_str}/32")

    map_id = get_map_id(map_name)
    if map_id is None:
        return False

    cmd = ["bpftool", "map", "lookup", "id", str(map_id), "key", "hex"]
    cmd.extend(key_hex.split())
    lookup = subprocess.run(cmd, capture_output=True, text=True)
    return lookup.returncode == 0
# ==================== IP Encoding Helpers (from ddos) ====================
def ip_to_hex_key(ip_str):
    """Encode an IP address as a space-separated hex key for LRU_HASH maps.

    IPv4 yields 4 bytes and IPv6 yields 16 bytes, both in network byte order.

    Raises:
        ValueError: if ``ip_str`` is not a valid IPv4/IPv6 address.
    """
    try:
        packed = ipaddress.ip_address(ip_str).packed
    except ValueError:
        raise ValueError(f"Invalid IP address: {ip_str}")
    return ' '.join(map('{:02x}'.format, packed))
def hex_key_to_ip(hex_bytes, version=4):
    """Decode a list of hex byte strings into an IP address string.

    Args:
        hex_bytes: list of hex strings like ['0a', '00', '00', '01']
        version: 6 decodes as IPv6; any other value decodes as IPv4
    """
    address_cls = ipaddress.IPv6Address if version == 6 else ipaddress.IPv4Address
    packed = bytes(int(token, 16) for token in hex_bytes)
    return str(address_cls(packed))
# ==================== DDoS Stats / Features (from ddos) ====================
def read_percpu_stats(map_name="global_stats", num_entries=5):
    """Dump a PERCPU_ARRAY map and sum the per-CPU values of each key.

    Args:
        map_name: name of the map to dump
        num_entries: accepted for compatibility; not used by this function

    Returns:
        dict {key_index: summed_value}; empty when the map is not found, the
        dump fails, or its output is not valid JSON.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return {}
    dump = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if dump.returncode != 0:
        return {}
    try:
        entries = json.loads(dump.stdout)
    except Exception:
        return {}

    def _as_number(val):
        # Strings are parsed with base auto-detection, byte lists are read as
        # a little-endian integer, anything else is used as-is.
        if isinstance(val, str):
            return int(val, 0)
        if isinstance(val, list):
            return int.from_bytes(
                bytes(int(x, 16) for x in val), byteorder='little'
            )
        return val

    stats = {}
    for entry in entries:
        fmt = entry.get("formatted", entry)
        key = fmt.get("key")
        if isinstance(key, str):
            key = int(key, 16)
        elif isinstance(key, list):
            # Only the first element of a list key is used as the index.
            first = key[0]
            key = int(first, 16) if isinstance(first, str) else first
        stats[key] = sum(
            _as_number(cpu.get("value", cpu.get("val", 0)))
            for cpu in fmt.get("values", [])
        )
    return stats
def read_percpu_features():
    """Dump the traffic feature PERCPU_ARRAY and sum each field over all CPUs.

    The map is looked up as "traffic_feature".
    NOTE(review): presumably the 15-character truncated form of
    "traffic_features" - confirm against the BPF program.

    Per-CPU values are accepted either as a dict keyed by field name or as a
    list of hex bytes holding consecutive little-endian 64-bit fields.

    Returns:
        dict {field_name: summed_value}; empty when the map is not found, the
        dump fails, or its output is not valid JSON.
    """
    map_id = get_map_id("traffic_feature")
    if map_id is None:
        return {}
    dump = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if dump.returncode != 0:
        return {}
    try:
        entries = json.loads(dump.stdout)
    except Exception:
        return {}

    field_names = [
        "total_packets", "total_bytes", "tcp_syn_count", "tcp_other_count",
        "udp_count", "icmp_count", "other_proto_count", "unique_ips_approx",
        "small_pkt_count", "large_pkt_count"
    ]
    raw_len = len(field_names) * 8
    layout = f"<{len(field_names)}Q"
    totals = dict.fromkeys(field_names, 0)

    for entry in entries:
        fmt = entry.get("formatted", entry)
        for cpu_val in fmt.get("values", []):
            val = cpu_val.get("value", [])
            if isinstance(val, dict):
                for name in field_names:
                    totals[name] += val.get(name, 0)
            elif isinstance(val, list) and len(val) >= raw_len:
                raw = bytes(int(x, 16) for x in val[:raw_len])
                for name, amount in zip(field_names, struct.unpack_from(layout, raw)):
                    totals[name] += amount
    return totals
# ==================== Rate Config (from ddos) ====================
def read_rate_config():
    """Read the current rate_config entry (key 0) from its BPF map.

    Returns:
        dict {pps_threshold, bps_threshold, window_ns}, or None when the map
        is not found, the lookup fails, or the output cannot be decoded.
    """
    map_id = get_map_id("rate_config")
    if map_id is None:
        return None
    lookup = subprocess.run(
        ["bpftool", "map", "lookup", "id", str(map_id), "key", "0", "0", "0", "0", "-j"],
        capture_output=True, text=True
    )
    if lookup.returncode != 0:
        return None

    fields = ("pps_threshold", "bps_threshold", "window_ns")
    try:
        data = json.loads(lookup.stdout)
        val = data.get("formatted", data).get("value", {})
        if isinstance(val, dict):
            return {name: val.get(name, 0) for name in fields}
        if isinstance(val, list):
            # Raw form: three consecutive little-endian 64-bit integers.
            raw = bytes(int(x, 16) for x in val[:24])
            return dict(zip(fields, struct.unpack_from('<QQQ', raw)))
    except Exception:
        pass
    return None
def write_rate_config(pps_threshold, bps_threshold=0, window_ns=1000000000):
    """Write the rate_config entry (key 0) to its BPF map.

    The three values are packed as consecutive little-endian 64-bit unsigned
    integers and written with ``bpftool map update``.

    Raises:
        RuntimeError: if the map is not found or bpftool fails.
    """
    map_id = get_map_id("rate_config")
    if map_id is None:
        raise RuntimeError("rate_config map not found")

    packed = struct.pack('<QQQ', pps_threshold, bps_threshold, window_ns)
    cmd = ["bpftool", "map", "update", "id", str(map_id),
           "key", "0", "0", "0", "0", "value", "hex"]
    cmd.extend(f"{byte:02x}" for byte in packed)

    update = subprocess.run(cmd, capture_output=True, text=True)
    if update.returncode != 0:
        raise RuntimeError(f"Failed to write rate_config: {update.stderr}")
    audit_log.info(f"rate_config updated: pps={pps_threshold} bps={bps_threshold} window={window_ns}")
# ==================== Block / Unblock (from ddos) ====================
def block_ip(ip_str, duration_sec=0):
    """Add IP to the blocked_ips map of its address family.

    The map value is three little-endian 64-bit integers: expiry timestamp
    (0 for a permanent block), block timestamp, and a zeroed third field.

    Args:
        ip_str: IPv4 or IPv6 address string
        duration_sec: 0 (or negative) = permanent, >0 = temporary block;
            fractional seconds are accepted

    Raises:
        ValueError: if ``ip_str`` is not a valid IP address.
        RuntimeError: if the map is not found or bpftool fails.
    """
    addr = ipaddress.ip_address(ip_str)
    is_v6 = isinstance(addr, ipaddress.IPv6Address)
    map_name = "blocked_ips_v6" if is_v6 else "blocked_ips_v4"
    map_id = get_map_id(map_name)
    if map_id is None:
        raise RuntimeError(f"{map_name} map not found")
    key_hex = ip_to_hex_key(ip_str)

    # Timestamps are derived from /proc/uptime so they share a time base with
    # the BPF program.
    # NOTE(review): /proc/uptime includes time spent suspended, while
    # bpf_ktime_get_ns() does not - confirm which helper the BPF side uses.
    with open('/proc/uptime', 'r') as f:
        uptime_sec = float(f.read().split()[0])
    now_ns = int(uptime_sec * 1_000_000_000)

    if duration_sec > 0:
        # int(): a float duration would otherwise make expire_ns a float and
        # struct.pack('<Q', ...) would reject it.
        expire_ns = now_ns + int(duration_sec * 1_000_000_000)
    else:
        expire_ns = 0

    raw = struct.pack('<QQQ', expire_ns, now_ns, 0)
    val_hex = ' '.join(f"{b:02x}" for b in raw)
    result = subprocess.run(
        ["bpftool", "map", "update", "id", str(map_id),
         "key", "hex"] + key_hex.split() + ["value", "hex"] + val_hex.split(),
        capture_output=True, text=True
    )
    if result.returncode != 0:
        raise RuntimeError(f"Failed to block {ip_str}: {result.stderr}")
    duration_str = f"{duration_sec}s" if duration_sec > 0 else "permanent"
    audit_log.info(f"blocked {ip_str} ({duration_str})")
def unblock_ip(ip_str):
    """Delete an IP from the blocked_ips map of its address family.

    Raises:
        ValueError: if ``ip_str`` is not a valid IP address.
        RuntimeError: if the map is not found or bpftool fails.
    """
    addr = ipaddress.ip_address(ip_str)
    map_name = "blocked_ips_v6" if addr.version == 6 else "blocked_ips_v4"
    map_id = get_map_id(map_name)
    if map_id is None:
        raise RuntimeError(f"{map_name} map not found")

    cmd = ["bpftool", "map", "delete", "id", str(map_id), "key", "hex"]
    cmd.extend(ip_to_hex_key(ip_str).split())
    removal = subprocess.run(cmd, capture_output=True, text=True)
    if removal.returncode != 0:
        raise RuntimeError(f"Failed to unblock {ip_str}: {removal.stderr}")
    audit_log.info(f"unblocked {ip_str}")
# ==================== Dump Helpers (from ddos) ====================
def dump_rate_counters(map_name="rate_counter_v4", top_n=10):
    """Dump a rate counter map and return its top-N entries by packet count.

    The map is treated as IPv6 when its name contains "v6". Entries whose key
    or value cannot be decoded are skipped.

    Returns:
        list of (ip_str, packets, bytes, last_seen_ns), highest packet count
        first; empty when the map is not found, the dump fails, or its output
        is not valid JSON.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return []
    is_v6 = "v6" in map_name
    dump = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if dump.returncode != 0:
        return []
    try:
        records = json.loads(dump.stdout)
    except Exception:
        return []

    ip_version = 6 if is_v6 else 4
    rows = []
    for record in records:
        fmt = record.get("formatted", record)
        key = fmt.get("key", [])
        val = fmt.get("value", {})

        # Decode the key into an address string.
        try:
            if isinstance(key, list):
                ip_str = hex_key_to_ip(key, version=ip_version)
            elif isinstance(key, dict):
                if is_v6:
                    addr8 = key.get('in6_u', {}).get('u6_addr8', [])
                    if not addr8:
                        continue
                    ip_str = str(ipaddress.IPv6Address(bytes(addr8)))
                else:
                    # A dict key on a non-v6 map is shown in its str() form.
                    ip_str = str(key)
            elif isinstance(key, int):
                ip_str = socket.inet_ntoa(struct.pack('<I', key))
            else:
                continue
        except Exception:
            continue

        # Decode the value: named fields, or 24 raw bytes holding three
        # little-endian 64-bit integers.
        try:
            if isinstance(val, dict):
                counters = (
                    val.get("packets", 0),
                    val.get("bytes", 0),
                    val.get("last_seen", 0),
                )
            elif isinstance(val, list) and len(val) >= 24:
                raw = bytes(int(x, 16) for x in val[:24])
                counters = struct.unpack_from('<QQQ', raw)
            else:
                continue
        except Exception:
            continue

        rows.append((ip_str, *counters))

    return sorted(rows, key=lambda row: row[1], reverse=True)[:top_n]
def dump_blocked_ips(map_name="blocked_ips_v4"):
    """Dump a blocked_ips map with the block info of every entry.

    The map is treated as IPv6 when its name contains "v6". Entries whose key
    or value cannot be decoded are skipped.

    Returns:
        list of (ip_str, expire_ns, blocked_at, drop_count) in dump order;
        empty when the map is not found, the dump fails, or its output is not
        valid JSON.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return []
    is_v6 = "v6" in map_name
    dump = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if dump.returncode != 0:
        return []
    try:
        records = json.loads(dump.stdout)
    except Exception:
        return []

    ip_version = 6 if is_v6 else 4
    rows = []
    for record in records:
        fmt = record.get("formatted", record)
        key = fmt.get("key", [])
        val = fmt.get("value", {})

        # Decode the key into an address string.
        try:
            if isinstance(key, list):
                ip_str = hex_key_to_ip(key, version=ip_version)
            elif isinstance(key, dict) and is_v6:
                addr8 = key.get('in6_u', {}).get('u6_addr8', [])
                if not addr8:
                    continue
                ip_str = str(ipaddress.IPv6Address(bytes(addr8)))
            elif isinstance(key, int):
                ip_str = socket.inet_ntoa(struct.pack('<I', key))
            else:
                continue
        except Exception:
            continue

        # Decode the value: named fields, or 24 raw bytes holding three
        # little-endian 64-bit integers.
        try:
            if isinstance(val, dict):
                info = (
                    val.get("expire_ns", 0),
                    val.get("blocked_at", 0),
                    val.get("drop_count", 0),
                )
            elif isinstance(val, list) and len(val) >= 24:
                raw = bytes(int(x, 16) for x in val[:24])
                info = struct.unpack_from('<QQQ', raw)
            else:
                continue
        except Exception:
            continue

        rows.append((ip_str, *info))
    return rows