Files
xdp-defense/lib/xdp_common.py
kaffa 1bcaddce25 Unify xdp-blocker and xdp-ddos into single xdp-defense project
Chain two XDP programs via libxdp dispatcher on the same interface:
xdp_blocker (priority 10) handles CIDR/country/whitelist blocking,
xdp_ddos (priority 20) handles rate limiting, EWMA analysis, and AI
anomaly detection. Whitelist maps are shared via BPF map pinning so
whitelisted IPs bypass both blocklist checks and DDoS rate limiting.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 08:39:21 +09:00

543 lines
16 KiB
Python

#!/usr/bin/env python3
"""
XDP Defense - Common Utilities
Merged from xdp-blocker/xdp_common.py and xdp-ddos/xdp_ddos_common.py
Provides: map management, CIDR handling, IP encoding, rate config, block/unblock, stats
"""
import subprocess
import json
import os
import tempfile
import struct
import socket
import ipaddress
import logging
import logging.handlers
# ==================== Logging ====================
# Audit logger for state-changing operations. Records are sent to the local
# syslog socket (/dev/log) with an "xdp-defense: " prefix when a SysLogHandler
# can be created; otherwise the logger is left without a handler of its own.
audit_log = logging.getLogger('xdp-defense')
audit_log.setLevel(logging.INFO)

_syslog_handler = None
try:
    _syslog_handler = logging.handlers.SysLogHandler(address='/dev/log')
    _syslog_handler.setFormatter(logging.Formatter('xdp-defense: %(message)s'))
except Exception:
    # Syslog is optional: a missing or unreachable socket must not break import.
    pass
else:
    audit_log.addHandler(_syslog_handler)
# ==================== BPF Map Helpers ====================
def get_map_id(map_name):
    """Look up the ID of a loaded BPF map by its kernel-visible name.

    Runs ``bpftool map show -j`` and returns the ``id`` field of the first map
    whose ``name`` equals *map_name*.

    Returns:
        The map ID, or None when bpftool exits non-zero, its output cannot be
        parsed/iterated as a list of map objects, or no map matches.
    """
    proc = subprocess.run(
        ["bpftool", "map", "show", "-j"],
        capture_output=True, text=True
    )
    if proc.returncode:
        return None
    try:
        return next(
            (m.get("id") for m in json.loads(proc.stdout)
             if m.get("name") == map_name),
            None,
        )
    except Exception:
        # Unexpected JSON shape: treat as "not found" like a missing map.
        return None
# ==================== CIDR / IPv4 Helpers (from blocker) ====================
def validate_cidr(cidr):
    """Parse and validate an IPv4 address or CIDR string.

    A bare address (no ``/``) is treated as a /32.

    Returns:
        Tuple ``(ip, prefix, parts)``: the address text before any ``/``,
        the integer prefix length, and the four octets as a list of ints.

    Raises:
        ValueError: prefix outside 0..32, not exactly four octets, an octet
            outside 0..255, or a non-numeric field.
    """
    ip, slash, prefix_text = cidr.partition('/')
    prefix = int(prefix_text) if slash else 32
    if not 0 <= prefix <= 32:
        raise ValueError(f"Invalid prefix: /{prefix}")
    parts = list(map(int, ip.split('.')))
    if len(parts) != 4 or not all(0 <= octet <= 255 for octet in parts):
        raise ValueError(f"Invalid IP: {ip}")
    return ip, prefix, parts
def cidr_to_key(cidr):
    """Encode an IPv4 CIDR as an LPM trie key for ``bpftool ... key hex``.

    Layout: 4-byte little-endian prefix length followed by the four address
    octets as written (host bits are not masked), rendered as space-separated
    two-digit hex bytes. Raises ValueError for input validate_cidr rejects.
    """
    _, prefix, octets = validate_cidr(cidr)
    # prefix <= 32, so the little-endian u32 is simply (prefix, 0, 0, 0).
    return ' '.join(f"{b:02x}" for b in (prefix, 0, 0, 0, *octets))
def is_ipv6(cidr):
    """Return True when *cidr* contains a colon.

    This is a cheap textual heuristic used to route strings to the IPv6 code
    paths; it performs no validation.
    """
    return ':' in cidr
def validate_cidr_v6(cidr):
    """Parse an IPv6 address or CIDR string.

    Host bits are allowed and masked off (``strict=False``); a bare address
    becomes a /128.

    Returns:
        Tuple ``(network, prefix)``: the ``IPv6Network`` and its prefix length.

    Raises:
        ValueError: malformed input, chained to the underlying ipaddress error.
    """
    try:
        network = ipaddress.IPv6Network(cidr, strict=False)
    except ValueError as exc:
        # AddressValueError / NetmaskValueError are ValueError subclasses.
        raise ValueError(f"Invalid IPv6 CIDR: {cidr}") from exc
    return network, network.prefixlen
def cidr_to_key_v6(cidr):
    """Encode an IPv6 CIDR as an LPM trie key for ``bpftool ... key hex``.

    Key format: prefixlen (4 bytes LE) + network address (16 bytes, network
    order), as space-separated two-digit hex bytes. The address is the masked
    network address returned by validate_cidr_v6.
    """
    network, prefix = validate_cidr_v6(cidr)
    key_bytes = prefix.to_bytes(4, 'little') + network.network_address.packed
    return ' '.join(f"{b:02x}" for b in key_bytes)
def classify_cidrs(cidrs):
    """Partition *cidrs* into ``(v4, v6)`` lists, preserving input order.

    Classification uses the textual is_ipv6 check, so entries are not
    validated here.
    """
    v4_list, v6_list = [], []
    for entry in cidrs:
        (v6_list if is_ipv6(entry) else v4_list).append(entry)
    return v4_list, v6_list
def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 00 00 00 00 00", batch_size=1000, ipv6=False):
    """Apply LPM-trie updates or deletes in batches through ``bpftool batch``.

    Each slice of *cidrs* is rendered into a temporary bpftool batch file and
    executed with a single ``bpftool batch file`` call. Progress is printed to
    stdout and a summary line is written to the audit log.

    Args:
        map_id: BPF map ID.
        cidrs: list of CIDR strings.
        operation: "update" to insert/overwrite keys; any other value deletes.
        value_hex: space-separated hex value used for update operations.
        batch_size: number of CIDRs per batch file.
        ipv6: encode keys with cidr_to_key_v6 instead of cidr_to_key.

    Returns:
        Number of input entries processed (attempted). Entries skipped as
        invalid, and batches bpftool rejected, are still counted here but
        are reported as warnings in the audit log.
    """
    total = len(cidrs)
    processed = 0
    skipped = 0
    proto = "v6" if ipv6 else "v4"
    to_key = cidr_to_key_v6 if ipv6 else cidr_to_key
    for i in range(0, total, batch_size):
        batch = cidrs[i:i + batch_size]
        commands = []
        for cidr in batch:
            try:
                key_hex = to_key(cidr)
            except ValueError:
                # Malformed CIDR: skip it so one bad line does not abort the load.
                skipped += 1
                continue
            if operation == "update":
                commands.append(f"map update id {map_id} key hex {key_hex} value hex {value_hex}\n")
            else:
                commands.append(f"map delete id {map_id} key hex {key_hex}\n")
        if commands:
            batch_file = None
            try:
                with tempfile.NamedTemporaryFile(mode='w', suffix='.batch', delete=False) as f:
                    batch_file = f.name
                    f.writelines(commands)
                result = subprocess.run(
                    ["bpftool", "batch", "file", batch_file],
                    capture_output=True, text=True
                )
            finally:
                if batch_file and os.path.exists(batch_file):
                    os.unlink(batch_file)
            if result.returncode != 0:
                # NOTE(review): bpftool batch appears to stop at the first failing
                # command, so later lines of this batch may not have been applied.
                audit_log.warning(
                    f"batch {operation} {proto}: bpftool failed on map {map_id} "
                    f"(entries {i + 1}-{i + len(batch)}): {result.stderr.strip()}"
                )
        processed += len(batch)
        pct = processed * 100 // total if total > 0 else 100
        print(f"\r Progress: {processed}/{total} ({pct}%)", end="", flush=True)
    print()
    audit_log.info(f"batch {operation} {proto}: {processed} entries on map {map_id}")
    if skipped:
        audit_log.warning(f"batch {operation} {proto}: skipped {skipped} invalid CIDR(s)")
    return processed
# ==================== IP Encoding Helpers (from ddos) ====================
def ip_to_hex_key(ip_str):
    """Convert an IP address to the hex key used for the LRU_HASH maps.

    IPv4 yields 4 bytes and IPv6 yields 16 bytes, both in network byte order,
    rendered as space-separated two-digit hex (e.g. "0a 00 00 01").

    Raises:
        ValueError: if *ip_str* is not a valid IPv4/IPv6 address; the original
            ipaddress error is chained as ``__cause__``.
    """
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError as err:
        raise ValueError(f"Invalid IP address: {ip_str}") from err
    return ' '.join(f"{b:02x}" for b in addr.packed)
def hex_key_to_ip(hex_bytes, version=4):
    """Rebuild an IP address string from a sequence of hex byte strings.

    *hex_bytes* looks like ['0a', '00', '00', '01']; each item is parsed with
    ``int(item, 16)``, so a "0x" prefix is also accepted. *version* == 6
    decodes IPv6; any other value decodes IPv4.
    """
    packed = bytes([int(token, 16) for token in hex_bytes])
    address_cls = ipaddress.IPv6Address if version == 6 else ipaddress.IPv4Address
    return str(address_cls(packed))
# ==================== DDoS Stats / Features (from ddos) ====================
def read_percpu_stats(map_name="global_stats", num_entries=5):
    """Read a PERCPU_ARRAY map and sum each key's per-CPU values.

    Args:
        map_name: kernel name of the map to dump.
        num_entries: currently unused; kept for call compatibility.

    Returns:
        dict mapping key index (int) to the sum of its per-CPU values. Empty
        if the map is missing, bpftool fails, or the dump is not valid JSON.
        Entries without a usable key are skipped.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return {}
    result = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if result.returncode != 0:
        return {}
    try:
        data = json.loads(result.stdout)
    except ValueError:
        return {}
    stats = {}
    for entry in data:
        fmt = entry.get("formatted", entry)
        key = fmt.get("key")
        if isinstance(key, list):
            if not key:
                continue
            # Raw dump: one item per key byte. Decode all of them (little-endian,
            # matching the value decoding below) so indices above 255 are correct.
            key = int.from_bytes(
                bytes(int(b, 16) if isinstance(b, str) else b for b in key),
                byteorder='little'
            )
        elif isinstance(key, str):
            key = int(key, 16)
        elif key is None:
            continue
        total = 0
        for v in fmt.get("values", []):
            val = v.get("value", v.get("val", 0))
            if isinstance(val, str):
                val = int(val, 0)
            elif isinstance(val, list):
                val = int.from_bytes(
                    bytes(int(x, 16) for x in val), byteorder='little'
                )
            total += val
        stats[key] = total
    return stats
def read_percpu_features():
    """Sum the per-CPU traffic_features struct into a single dict.

    Returns a dict keyed by the ten u64 field names below, each summed over
    every entry and CPU, or an empty dict if the map is missing, bpftool
    fails, or the dump is not valid JSON. Per-CPU values are decoded either
    from a raw byte list (little-endian, at least 8 bytes per field) or from
    a BTF-formatted dict; any other shape is ignored.
    """
    # BPF object names are limited to 15 characters (BPF_OBJ_NAME_LEN is 16
    # including the NUL), so a map declared as "traffic_features" presumably
    # shows up in bpftool as "traffic_feature".
    map_id = get_map_id("traffic_feature")
    if map_id is None:
        return {}
    dump = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if dump.returncode != 0:
        return {}
    try:
        data = json.loads(dump.stdout)
    except Exception:
        return {}

    field_names = (
        "total_packets", "total_bytes", "tcp_syn_count", "tcp_other_count",
        "udp_count", "icmp_count", "other_proto_count", "unique_ips_approx",
        "small_pkt_count", "large_pkt_count",
    )
    byte_width = len(field_names) * 8
    layout = f'<{len(field_names)}Q'
    aggregated = dict.fromkeys(field_names, 0)
    for entry in data:
        fmt = entry.get("formatted", entry)
        for cpu_val in fmt.get("values", []):
            val = cpu_val.get("value", [])
            if isinstance(val, dict):
                for name in field_names:
                    aggregated[name] += val.get(name, 0)
            elif isinstance(val, list) and len(val) >= byte_width:
                raw = bytes(int(x, 16) for x in val[:byte_width])
                for name, number in zip(field_names, struct.unpack(layout, raw)):
                    aggregated[name] += number
    return aggregated
# ==================== Rate Config (from ddos) ====================
def read_rate_config():
    """Fetch entry 0 of the rate_config map.

    Returns:
        dict with ``pps_threshold``, ``bps_threshold`` and ``window_ns``, or
        None if the map is missing, the lookup fails, or the value cannot be
        decoded. A BTF-formatted dict value is read field by field (missing
        fields become 0); a raw byte-list value must hold at least three
        little-endian u64s.
    """
    map_id = get_map_id("rate_config")
    if map_id is None:
        return None
    lookup = subprocess.run(
        ["bpftool", "map", "lookup", "id", str(map_id), "key", "0", "0", "0", "0", "-j"],
        capture_output=True, text=True
    )
    if lookup.returncode != 0:
        return None
    try:
        data = json.loads(lookup.stdout)
        value = data.get("formatted", data).get("value", {})
        if isinstance(value, dict):
            return {
                name: value.get(name, 0)
                for name in ("pps_threshold", "bps_threshold", "window_ns")
            }
        if isinstance(value, list):
            pps, bps, win = struct.unpack('<QQQ', bytes(int(x, 16) for x in value[:24]))
            return {"pps_threshold": pps, "bps_threshold": bps, "window_ns": win}
    except Exception:
        # Malformed JSON or a short/odd value: report "unknown" as None.
        pass
    return None
def write_rate_config(pps_threshold, bps_threshold=0, window_ns=1000000000):
    """Store new rate limits in entry 0 of the rate_config map.

    The three arguments are packed as little-endian u64s in the order
    (pps_threshold, bps_threshold, window_ns) and written with
    ``bpftool map update``; the change is recorded in the audit log.

    Raises:
        RuntimeError: the map is missing or bpftool fails.
        struct.error: an argument is not an int that fits in a u64.
    """
    map_id = get_map_id("rate_config")
    if map_id is None:
        raise RuntimeError("rate_config map not found")
    value_bytes = [
        f"{b:02x}"
        for b in struct.pack('<QQQ', pps_threshold, bps_threshold, window_ns)
    ]
    command = [
        "bpftool", "map", "update", "id", str(map_id),
        "key", "0", "0", "0", "0", "value", "hex", *value_bytes,
    ]
    proc = subprocess.run(command, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Failed to write rate_config: {proc.stderr}")
    audit_log.info(f"rate_config updated: pps={pps_threshold} bps={bps_threshold} window={window_ns}")
# ==================== Block / Unblock (from ddos) ====================
def block_ip(ip_str, duration_sec=0):
    """Add an IP address to the blocked_ips_v4 / blocked_ips_v6 map.

    The map value is written as three little-endian u64s in the order
    (expire_ns, blocked_at, drop_count), the same layout dump_blocked_ips
    decodes. Both timestamps come from /proc/uptime scaled to nanoseconds.
    drop_count is written as 0, so re-blocking an IP resets it.

    Args:
        ip_str: IPv4 or IPv6 address string.
        duration_sec: 0 (or negative) = permanent block, written as
            expire_ns 0; > 0 = temporary block expiring duration_sec seconds
            from now. Fractional seconds are accepted.

    Raises:
        ValueError: if ip_str is not a valid IP address.
        RuntimeError: if the target map is missing or bpftool fails.
    """
    addr = ipaddress.ip_address(ip_str)
    map_name = "blocked_ips_v6" if addr.version == 6 else "blocked_ips_v4"
    map_id = get_map_id(map_name)
    if map_id is None:
        raise RuntimeError(f"{map_name} map not found")
    key_hex = ip_to_hex_key(ip_str)

    # NOTE(review): assumes the XDP program compares expire_ns against a clock
    # on the same timeline as /proc/uptime (e.g. bpf_ktime_get_ns) -- confirm,
    # especially on hosts that may suspend.
    with open('/proc/uptime', 'r') as f:
        now_ns = int(float(f.read().split()[0]) * 1_000_000_000)
    if duration_sec > 0:
        # int() keeps the value packable as '<Q' for fractional durations.
        expire_ns = now_ns + int(duration_sec * 1_000_000_000)
    else:
        expire_ns = 0
    # blocked_at records when this block was installed.
    raw = struct.pack('<QQQ', expire_ns, now_ns, 0)
    val_hex = ' '.join(f"{b:02x}" for b in raw)
    result = subprocess.run(
        ["bpftool", "map", "update", "id", str(map_id),
         "key", "hex"] + key_hex.split() + ["value", "hex"] + val_hex.split(),
        capture_output=True, text=True
    )
    if result.returncode != 0:
        raise RuntimeError(f"Failed to block {ip_str}: {result.stderr}")
    duration_str = f"{duration_sec}s" if duration_sec > 0 else "permanent"
    audit_log.info(f"blocked {ip_str} ({duration_str})")
def unblock_ip(ip_str):
    """Delete an IP address from blocked_ips_v4 / blocked_ips_v6.

    The map is chosen from the address family. The removal is recorded in
    the audit log on success.

    Raises:
        ValueError: if ip_str is not a valid IP address.
        RuntimeError: if the target map is missing or bpftool fails.
    """
    family = ipaddress.ip_address(ip_str).version
    map_name = f"blocked_ips_v{family}"
    map_id = get_map_id(map_name)
    if map_id is None:
        raise RuntimeError(f"{map_name} map not found")
    command = [
        "bpftool", "map", "delete", "id", str(map_id),
        "key", "hex", *ip_to_hex_key(ip_str).split(),
    ]
    proc = subprocess.run(command, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Failed to unblock {ip_str}: {proc.stderr}")
    audit_log.info(f"unblocked {ip_str}")
# ==================== Dump Helpers (from ddos) ====================
def dump_rate_counters(map_name="rate_counter_v4", top_n=10):
    """Return the top-N IPs by packet count from a rate counter map.

    Args:
        map_name: map to dump; a name containing "v6" is decoded as IPv6.
        top_n: maximum number of entries returned, highest packet count first.

    Returns:
        list of (ip_str, packets, bytes, last_seen_ns) tuples; empty if the
        map is missing, bpftool fails, or the dump is not valid JSON. Entries
        whose key or value has an unrecognized shape are skipped.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return []
    is_v6 = "v6" in map_name
    result = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if result.returncode != 0:
        return []
    try:
        data = json.loads(result.stdout)
    except Exception:
        return []
    entries = []
    for entry in data:
        fmt = entry.get("formatted", entry)
        key = fmt.get("key", [])
        val = fmt.get("value", {})
        try:
            if isinstance(key, list):
                ip_str = hex_key_to_ip(key, version=6 if is_v6 else 4)
            elif isinstance(key, dict) and is_v6:
                addr8 = key.get('in6_u', {}).get('u6_addr8', [])
                if not addr8:
                    continue
                ip_str = str(ipaddress.IPv6Address(bytes(addr8)))
            elif isinstance(key, int):
                ip_str = socket.inet_ntoa(struct.pack('<I', key))
            else:
                # Unknown key shape (e.g. a struct-typed IPv4 key): skip it
                # rather than report str(key) as an address.
                continue
        except Exception:
            # Best effort: one undecodable key must not hide the others.
            continue
        try:
            if isinstance(val, dict):
                pkts = val.get("packets", 0)
                bts = val.get("bytes", 0)
                last = val.get("last_seen", 0)
            elif isinstance(val, list) and len(val) >= 24:
                raw = bytes(int(x, 16) for x in val[:24])
                pkts, bts, last = struct.unpack_from('<QQQ', raw, 0)
            else:
                continue
        except Exception:
            continue
        entries.append((ip_str, pkts, bts, last))
    entries.sort(key=lambda x: x[1], reverse=True)
    return entries[:top_n]
def dump_blocked_ips(map_name="blocked_ips_v4"):
    """List every entry of a blocked-IP map.

    A map name containing "v6" is decoded as IPv6.

    Returns:
        list of (ip_str, expire_ns, blocked_at, drop_count) tuples in dump
        order; empty if the map is missing, bpftool fails, or the dump is not
        valid JSON. Entries whose key or value cannot be decoded are skipped.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return []
    is_v6 = "v6" in map_name

    def _key_to_ip(key):
        # Address string for a supported key shape, else None.
        if isinstance(key, list):
            return hex_key_to_ip(key, version=6 if is_v6 else 4)
        if isinstance(key, dict) and is_v6:
            addr8 = key.get('in6_u', {}).get('u6_addr8', [])
            return str(ipaddress.IPv6Address(bytes(addr8))) if addr8 else None
        if isinstance(key, int):
            return socket.inet_ntoa(struct.pack('<I', key))
        return None

    def _value_fields(val):
        # (expire_ns, blocked_at, drop_count) for a supported value shape, else None.
        if isinstance(val, dict):
            return val.get("expire_ns", 0), val.get("blocked_at", 0), val.get("drop_count", 0)
        if isinstance(val, list) and len(val) >= 24:
            return struct.unpack('<QQQ', bytes(int(x, 16) for x in val[:24]))
        return None

    dump = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if dump.returncode != 0:
        return []
    try:
        data = json.loads(dump.stdout)
    except Exception:
        return []

    blocked = []
    for entry in data:
        fmt = entry.get("formatted", entry)
        key = fmt.get("key", [])
        val = fmt.get("value", {})
        try:
            ip_str = _key_to_ip(key)
        except Exception:
            continue
        if ip_str is None:
            continue
        try:
            fields = _value_fields(val)
        except Exception:
            continue
        if fields is None:
            continue
        blocked.append((ip_str, *fields))
    return blocked