Chain two XDP programs via libxdp dispatcher on the same interface: xdp_blocker (priority 10) handles CIDR/country/whitelist blocking, xdp_ddos (priority 20) handles rate limiting, EWMA analysis, and AI anomaly detection. Whitelist maps are shared via BPF map pinning so whitelisted IPs bypass both blocklist checks and DDoS rate limiting. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
543 lines
16 KiB
Python
543 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
XDP Defense - Common Utilities
|
|
Merged from xdp-blocker/xdp_common.py and xdp-ddos/xdp_ddos_common.py
|
|
Provides: map management, CIDR handling, IP encoding, rate config, block/unblock, stats
|
|
"""
|
|
|
|
import subprocess
|
|
import json
|
|
import os
|
|
import tempfile
|
|
import struct
|
|
import socket
|
|
import ipaddress
|
|
import logging
|
|
import logging.handlers
|
|
|
|
# ==================== Logging ====================
|
|
|
|
# Audit logger shared by every map-mutating helper in this module.
audit_log = logging.getLogger('xdp-defense')
audit_log.setLevel(logging.INFO)

# Attach a syslog handler when the local /dev/log socket is usable. Any
# failure (no syslog daemon, minimal container image, ...) is tolerated:
# the logger then simply has no handler attached by this module.
try:
    _syslog_handler = logging.handlers.SysLogHandler(address='/dev/log')
    _syslog_handler.setFormatter(logging.Formatter('xdp-defense: %(message)s'))
except Exception:
    _syslog_handler = None
else:
    audit_log.addHandler(_syslog_handler)
|
|
|
|
|
|
# ==================== BPF Map Helpers ====================
|
|
|
|
def get_map_id(map_name):
    """Return the ID of the first loaded BPF map named *map_name*, or None.

    The kernel keeps at most 15 characters of a map name (BPF_OBJ_NAME_LEN is
    16 including the terminating NUL). bpftool therefore reports e.g. a map
    declared as ``traffic_features`` as ``traffic_feature``. The requested
    name is truncated the same way before comparing, so both spellings match.

    Returns None when bpftool cannot be executed, exits non-zero, prints
    output that is not JSON, or lists no map with that name.
    """
    try:
        result = subprocess.run(
            ["bpftool", "map", "show", "-j"],
            capture_output=True, text=True
        )
    except OSError:
        # bpftool missing / not executable: same outcome as a failed lookup.
        return None
    if result.returncode != 0:
        return None

    try:
        maps = json.loads(result.stdout)
    except ValueError:
        return None
    if not isinstance(maps, list):
        return None

    # Compare against the kernel-truncated form of the name.
    wanted = map_name[:15]
    for m in maps:
        if isinstance(m, dict) and m.get("name") == wanted:
            return m.get("id")
    return None
|
|
|
|
|
|
# ==================== CIDR / IPv4 Helpers (from blocker) ====================
|
|
|
|
def validate_cidr(cidr):
    """Parse an IPv4 address with an optional ``/prefix`` suffix.

    A bare address is treated as a /32.

    Returns:
        ``(ip, prefix, parts)``:
        - *ip* is the text before the slash.
        - *prefix* is the prefix length as an int.
        - *parts* is the list of the four octets as ints.

    Raises:
        ValueError: the prefix is non-numeric or outside 0..32 (checked
            first), or the address is not four numeric octets in 0..255.
    """
    ip, slash, prefix_text = cidr.partition('/')
    prefix = int(prefix_text) if slash else 32

    if not 0 <= prefix <= 32:
        raise ValueError(f"Invalid prefix: /{prefix}")

    octets = list(map(int, ip.split('.')))
    if len(octets) != 4 or not all(0 <= octet <= 255 for octet in octets):
        raise ValueError(f"Invalid IP: {ip}")

    return ip, prefix, octets
|
|
|
|
|
|
def cidr_to_key(cidr):
    """Encode an IPv4 CIDR as a bpftool hex key for an LPM trie map.

    Layout: the prefix length as a 4-byte little-endian integer, followed by
    the four address octets. For example, ``"10.0.0.0/8"`` encodes as
    ``"08 00 00 00 0a 00 00 00"``.

    Raises:
        ValueError: propagated from validate_cidr for malformed input.
    """
    _ip, prefix_len, octets = validate_cidr(cidr)
    # prefix_len <= 32, so the three high bytes of the LE length are always 0.
    key_fields = [prefix_len, 0, 0, 0, *octets]
    return ' '.join(f"{field:02x}" for field in key_fields)
|
|
|
|
|
|
def is_ipv6(cidr):
    """Return True when *cidr* is IPv6 text.

    The test is purely lexical. Dotted-quad IPv4 never contains a colon,
    while IPv6 notation always does, so no parsing is required.
    """
    colon_present = ':' in cidr
    return colon_present
|
|
|
|
|
|
def validate_cidr_v6(cidr):
    """Parse an IPv6 CIDR (host bits allowed) into ``(network, prefix)``.

    *network* is the ``ipaddress.IPv6Network`` with host bits cleared, and
    *prefix* is its prefix length.

    Raises:
        ValueError: "Invalid IPv6 CIDR: ..." with the parse error chained.
    """
    try:
        network = ipaddress.IPv6Network(cidr, strict=False)
    except ValueError as exc:
        # AddressValueError / NetmaskValueError are both ValueError subclasses.
        raise ValueError(f"Invalid IPv6 CIDR: {cidr}") from exc
    return network, network.prefixlen
|
|
|
|
|
|
def cidr_to_key_v6(cidr):
    """Encode an IPv6 CIDR as a bpftool hex key for an LPM trie map.

    Key format: prefixlen (4 bytes LE) + addr (16 bytes). The address is
    the network address, so host bits are cleared.

    Raises:
        ValueError: propagated from validate_cidr_v6 for malformed input.
    """
    network, prefix_len = validate_cidr_v6(cidr)
    # prefix_len <= 128 fits in the low byte; the other three bytes are 0.
    key_bytes = bytes([prefix_len, 0, 0, 0]) + network.network_address.packed
    return ' '.join(f"{octet:02x}" for octet in key_bytes)
|
|
|
|
|
|
def classify_cidrs(cidrs):
    """Partition *cidrs* into ``(v4, v6)`` lists using is_ipv6().

    Relative order is preserved within each list. The input is iterated
    exactly once, so any iterable is accepted.
    """
    ipv4_list, ipv6_list = [], []
    for entry in cidrs:
        bucket = ipv6_list if is_ipv6(entry) else ipv4_list
        bucket.append(entry)
    return ipv4_list, ipv6_list
|
|
|
|
|
|
def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 00 00 00 00 00", batch_size=1000, ipv6=False):
    """Insert or delete many LPM-trie keys through ``bpftool batch file``.

    Each CIDR is encoded with cidr_to_key (or cidr_to_key_v6 when *ipv6* is
    set). Commands are written ``batch_size`` at a time to a temporary batch
    file and executed by bpftool. Malformed CIDRs are skipped and counted.
    A one-line progress indicator is printed to stdout, and a summary goes
    to the audit log.

    Args:
        map_id: BPF map ID
        cidrs: list of CIDR strings
        operation: "update" or "delete"
        value_hex: hex value for update operations
        batch_size: number of entries per bpftool batch file
        ipv6: encode keys as IPv6 LPM keys instead of IPv4

    Returns:
        Number of entries submitted in batches that bpftool executed
        successfully. Skipped (malformed) CIDRs and entries belonging to a
        failed batch are not counted.

    Raises:
        ValueError: *operation* is neither "update" nor "delete".
    """
    if operation not in ("update", "delete"):
        # Any other value used to fall through silently to "delete".
        raise ValueError(f"Unsupported batch operation: {operation!r}")

    proto = "v6" if ipv6 else "v4"
    encode = cidr_to_key_v6 if ipv6 else cidr_to_key
    total = len(cidrs)
    consumed = 0        # entries looked at (drives the progress display)
    applied = 0         # entries in batches bpftool reported as successful
    skipped = 0         # malformed CIDRs
    failed_batches = 0

    for start in range(0, total, batch_size):
        batch = cidrs[start:start + batch_size]
        consumed += len(batch)

        commands = []
        for cidr in batch:
            try:
                key_hex = encode(cidr)
            except Exception:
                # Best-effort load: one malformed entry must not abort the rest.
                skipped += 1
                continue
            if operation == "update":
                commands.append(f"map update id {map_id} key hex {key_hex} value hex {value_hex}\n")
            else:
                commands.append(f"map delete id {map_id} key hex {key_hex}\n")

        # Only spawn bpftool when this chunk produced at least one command.
        if commands:
            batch_file = None
            try:
                with tempfile.NamedTemporaryFile(mode='w', suffix='.batch', delete=False) as f:
                    batch_file = f.name
                    f.writelines(commands)
                # The file is closed (and flushed) before bpftool reads it.
                result = subprocess.run(
                    ["bpftool", "batch", "file", batch_file],
                    capture_output=True, text=True
                )
            finally:
                if batch_file and os.path.exists(batch_file):
                    os.unlink(batch_file)

            if result.returncode == 0:
                applied += len(commands)
            else:
                # bpftool stops a batch at its first failing command (e.g.
                # deleting a key that is not present), so part of this chunk
                # may be missing. Surface it instead of silently dropping it.
                failed_batches += 1
                audit_log.warning(
                    f"batch {operation} {proto}: bpftool failed on map {map_id}: "
                    f"{result.stderr.strip()}"
                )

        pct = consumed * 100 // total if total > 0 else 100
        print(f"\r Progress: {consumed}/{total} ({pct}%)", end="", flush=True)

    print()
    audit_log.info(
        f"batch {operation} {proto}: {applied} entries on map {map_id}"
        f" ({skipped} skipped, {failed_batches} failed batches)"
    )
    return applied
|
|
|
|
|
|
# ==================== IP Encoding Helpers (from ddos) ====================
|
|
|
|
def ip_to_hex_key(ip_str):
    """Convert an IP address to the space-separated hex key used by hash maps.

    IPv4 yields 4 bytes and IPv6 yields 16 bytes, both in network byte order
    (``ipaddress``'s ``packed`` form). For example, ``"10.0.0.1"`` encodes as
    ``"0a 00 00 01"``.

    Raises:
        ValueError: *ip_str* is not a valid IPv4/IPv6 address. The underlying
            parse error is chained as ``__cause__``.
    """
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError as exc:
        raise ValueError(f"Invalid IP address: {ip_str}") from exc

    return ' '.join(f"{b:02x}" for b in addr.packed)
|
|
|
|
|
|
def hex_key_to_ip(hex_bytes, version=4):
    """Turn a list of hex byte strings back into an address string.

    Args:
        hex_bytes: bytes as hex text in network order, e.g.
            ['0a', '00', '00', '01']. A '0x' prefix is also accepted.
        version: 6 decodes an IPv6 address. Any other value decodes IPv4.

    Raises:
        ValueError: a byte is not valid hex, or the length does not match
            the address family.
    """
    packed = bytes([int(octet, 16) for octet in hex_bytes])
    address_cls = ipaddress.IPv6Address if version == 6 else ipaddress.IPv4Address
    return str(address_cls(packed))
|
|
|
|
|
|
# ==================== DDoS Stats / Features (from ddos) ====================
|
|
|
|
def read_percpu_stats(map_name="global_stats", num_entries=5):
    """Read a PERCPU_ARRAY map and sum each key's value across all CPUs.

    Args:
        map_name: name of the PERCPU_ARRAY map to dump.
        num_entries: currently unused; every key bpftool reports is returned.

    Returns:
        dict mapping key index (int) -> summed per-CPU value. The dict is
        empty if the map is not loaded, bpftool fails, or its output is not
        valid JSON.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return {}

    result = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if result.returncode != 0:
        return {}

    try:
        data = json.loads(result.stdout)
    except ValueError:
        return {}

    stats = {}
    for entry in data:
        # BTF-annotated maps carry a "formatted" sub-object; raw dumps do not.
        fmt = entry.get("formatted", entry)

        key = fmt.get("key")
        if isinstance(key, list):
            # Raw keys are the array index as little-endian bytes. Decode every
            # byte, not just the first, so indexes >= 256 neither decode wrongly
            # nor overwrite low keys.
            key = int.from_bytes(
                bytes(int(b, 16) if isinstance(b, str) else b for b in key),
                byteorder='little'
            )
        elif isinstance(key, str):
            key = int(key, 16)

        total = 0
        for v in fmt.get("values", []):
            val = v.get("value", v.get("val", 0))
            if isinstance(val, str):
                val = int(val, 0)
            elif isinstance(val, list):
                # Raw per-CPU value: little-endian byte list.
                val = int.from_bytes(
                    bytes(int(x, 16) for x in val), byteorder='little'
                )
            total += val
        stats[key] = total
    return stats
|
|
|
|
|
|
def read_percpu_features():
    """Sum every CPU's copy of the traffic-feature PERCPU_ARRAY.

    The map is looked up as ``traffic_feature``, the 15-character form of
    ``traffic_features`` that bpftool reports. Each per-CPU value is handled
    in one of two ways:
    - A raw hex byte list is decoded as ten little-endian u64 counters.
    - A BTF-formatted dict is read by field name.

    Returns:
        dict of counter name -> sum over all CPUs and entries. The dict is
        empty when the map is missing, bpftool fails, or its output is not
        JSON.
    """
    feature_names = (
        "total_packets", "total_bytes", "tcp_syn_count", "tcp_other_count",
        "udp_count", "icmp_count", "other_proto_count", "unique_ips_approx",
        "small_pkt_count", "large_pkt_count",
    )

    map_id = get_map_id("traffic_feature")
    if map_id is None:
        return {}

    dump = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if dump.returncode != 0:
        return {}

    try:
        records = json.loads(dump.stdout)
    except Exception:
        return {}

    field_count = len(feature_names)
    struct_size = field_count * 8
    totals = dict.fromkeys(feature_names, 0)

    for record in records:
        per_cpu = record.get("formatted", record).get("values", [])
        for cpu_entry in per_cpu:
            sample = cpu_entry.get("value", [])
            if isinstance(sample, list) and len(sample) >= struct_size:
                packed = bytes(int(x, 16) for x in sample[:struct_size])
                counters = struct.unpack_from(f'<{field_count}Q', packed)
                for name, count in zip(feature_names, counters):
                    totals[name] += count
            elif isinstance(sample, dict):
                for name in feature_names:
                    totals[name] += sample.get(name, 0)

    return totals
|
|
|
|
|
|
# ==================== Rate Config (from ddos) ====================
|
|
|
|
def read_rate_config():
    """Fetch entry 0 of the ``rate_config`` map.

    Returns:
        dict with ``pps_threshold``, ``bps_threshold`` and ``window_ns``, or
        None when the map is missing, the lookup fails, or the value cannot
        be decoded. A raw value is read as three little-endian u64s in that
        order. A BTF-formatted value is read by field name, with missing
        fields defaulting to 0.
    """
    map_id = get_map_id("rate_config")
    if map_id is None:
        return None

    lookup = subprocess.run(
        ["bpftool", "map", "lookup", "id", str(map_id), "key", "0", "0", "0", "0", "-j"],
        capture_output=True, text=True
    )
    if lookup.returncode != 0:
        return None

    try:
        parsed = json.loads(lookup.stdout)
        value = parsed.get("formatted", parsed).get("value", {})

        if isinstance(value, dict):
            return {
                "pps_threshold": value.get("pps_threshold", 0),
                "bps_threshold": value.get("bps_threshold", 0),
                "window_ns": value.get("window_ns", 0),
            }
        if isinstance(value, list):
            # Fewer than 24 bytes makes unpack_from raise -> caught -> None.
            pps, bps, window = struct.unpack_from(
                '<QQQ', bytes(int(x, 16) for x in value[:24])
            )
            return {"pps_threshold": pps, "bps_threshold": bps, "window_ns": window}
    except Exception:
        return None

    # Value was neither a dict nor a byte list.
    return None
|
|
|
|
|
|
def write_rate_config(pps_threshold, bps_threshold=0, window_ns=1000000000):
    """Store new rate-limit thresholds in entry 0 of the ``rate_config`` map.

    The value is packed as three little-endian u64s:
    (pps_threshold, bps_threshold, window_ns).

    Raises:
        RuntimeError: the map is not loaded or bpftool rejects the update.
        struct.error: an argument does not fit an unsigned 64-bit integer.
    """
    map_id = get_map_id("rate_config")
    if map_id is None:
        raise RuntimeError("rate_config map not found")

    packed = struct.pack('<QQQ', pps_threshold, bps_threshold, window_ns)
    value_bytes = [f"{octet:02x}" for octet in packed]
    command = [
        "bpftool", "map", "update", "id", str(map_id),
        "key", "0", "0", "0", "0",
        "value", "hex", *value_bytes,
    ]

    proc = subprocess.run(command, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Failed to write rate_config: {proc.stderr}")

    audit_log.info(f"rate_config updated: pps={pps_threshold} bps={bps_threshold} window={window_ns}")
|
|
|
|
|
|
# ==================== Block / Unblock (from ddos) ====================
|
|
|
|
def block_ip(ip_str, duration_sec=0):
    """Add an address to the blocked_ips_v4 / blocked_ips_v6 map.

    Args:
        ip_str: IPv4 or IPv6 address text. The address family selects the map.
        duration_sec: block length in seconds. 0 (or any value <= 0) means
            permanent, written as expire_ns = 0. Fractional seconds are
            accepted.

    The value written is three little-endian u64s: (expire_ns, blocked_at,
    drop_count). expire_ns is absolute, on the clock reported by
    /proc/uptime. blocked_at and drop_count are written as 0.

    Raises:
        ValueError: *ip_str* is not a valid IP address.
        RuntimeError: the map is not loaded or bpftool rejects the update.
    """
    addr = ipaddress.ip_address(ip_str)
    is_v6 = isinstance(addr, ipaddress.IPv6Address)
    map_name = "blocked_ips_v6" if is_v6 else "blocked_ips_v4"

    map_id = get_map_id(map_name)
    if map_id is None:
        raise RuntimeError(f"{map_name} map not found")

    key_hex = ip_to_hex_key(ip_str)

    if duration_sec > 0:
        # NOTE(review): /proc/uptime counts time since boot including suspend.
        # Presumably the XDP side compares expire_ns with bpf_ktime_get_ns(),
        # which excludes suspend. Confirm both sides use the same clock.
        with open('/proc/uptime', 'r') as f:
            uptime_sec = float(f.read().split()[0])
        now_ns = int(uptime_sec * 1_000_000_000)
        # int() keeps expire_ns integral for fractional durations; a float
        # here made struct.pack('<Q') raise.
        expire_ns = now_ns + int(duration_sec * 1_000_000_000)
    else:
        expire_ns = 0

    # NOTE(review): blocked_at is written as 0 for manual blocks. Confirm
    # whether the BPF program relies on that before recording a real timestamp.
    blocked_at_ns = 0
    raw = struct.pack('<QQQ', expire_ns, blocked_at_ns, 0)
    val_hex = ' '.join(f"{b:02x}" for b in raw)

    result = subprocess.run(
        ["bpftool", "map", "update", "id", str(map_id),
         "key", "hex"] + key_hex.split() + ["value", "hex"] + val_hex.split(),
        capture_output=True, text=True
    )
    if result.returncode != 0:
        raise RuntimeError(f"Failed to block {ip_str}: {result.stderr}")

    duration_str = f"{duration_sec}s" if duration_sec > 0 else "permanent"
    audit_log.info(f"blocked {ip_str} ({duration_str})")
|
|
|
|
|
|
def unblock_ip(ip_str):
    """Delete an address from the blocked_ips_v4 / blocked_ips_v6 map.

    The map is chosen by the address family of *ip_str*.

    Raises:
        ValueError: *ip_str* is not a valid IP address.
        RuntimeError: the map is not loaded or bpftool refuses the delete
            (for example, when the address is not currently blocked).
    """
    family = ipaddress.ip_address(ip_str).version
    map_name = "blocked_ips_v6" if family == 6 else "blocked_ips_v4"

    map_id = get_map_id(map_name)
    if map_id is None:
        raise RuntimeError(f"{map_name} map not found")

    command = [
        "bpftool", "map", "delete", "id", str(map_id),
        "key", "hex", *ip_to_hex_key(ip_str).split(),
    ]
    proc = subprocess.run(command, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Failed to unblock {ip_str}: {proc.stderr}")

    audit_log.info(f"unblocked {ip_str}")
|
|
|
|
|
|
# ==================== Dump Helpers (from ddos) ====================
|
|
|
|
def dump_rate_counters(map_name="rate_counter_v4", top_n=10):
    """Return the *top_n* source addresses with the highest packet counts.

    Dumps the rate-counter map with ``bpftool map dump -j``. The address
    family is inferred from the map name: "v6" anywhere in it means IPv6.
    Entries whose key or value cannot be decoded are skipped.

    Returns:
        list of ``(ip_str, packets, bytes, last_seen_ns)`` sorted by packet
        count, largest first. The list is empty when the map is missing,
        bpftool fails, or its output is not JSON.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return []

    is_v6 = "v6" in map_name

    result = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if result.returncode != 0:
        return []

    try:
        data = json.loads(result.stdout)
    except ValueError:
        return []

    entries = []
    for entry in data:
        fmt = entry.get("formatted", entry)
        key = fmt.get("key", [])
        val = fmt.get("value", {})

        try:
            if isinstance(key, list):
                # Raw key: hex byte list in network order.
                ip_str = hex_key_to_ip(key, version=6 if is_v6 else 4)
            elif isinstance(key, dict) and is_v6:
                # BTF-formatted struct in6_addr.
                addr8 = key.get('in6_u', {}).get('u6_addr8', [])
                if not addr8:
                    continue
                ip_str = str(ipaddress.IPv6Address(bytes(addr8)))
            elif isinstance(key, int):
                # BTF-formatted __u32 key (this file assumes an LE host).
                ip_str = socket.inet_ntoa(struct.pack('<I', key))
            else:
                # Includes dict keys on a v4 map, which previously were
                # str()-ed into a bogus "IP" string.
                continue
        except Exception:
            continue

        try:
            if isinstance(val, dict):
                pkts = val.get("packets", 0)
                bts = val.get("bytes", 0)
                last = val.get("last_seen", 0)
            elif isinstance(val, list) and len(val) >= 24:
                # Raw value: three little-endian u64s (packets, bytes, last_seen).
                raw = bytes(int(x, 16) for x in val[:24])
                pkts, bts, last = struct.unpack_from('<QQQ', raw)
            else:
                continue
        except Exception:
            continue

        entries.append((ip_str, pkts, bts, last))

    entries.sort(key=lambda x: x[1], reverse=True)
    return entries[:top_n]
|
|
|
|
|
|
def dump_blocked_ips(map_name="blocked_ips_v4"):
    """List every entry of a blocked-IP map.

    The address family is inferred from the map name: "v6" anywhere in it
    means IPv6. Entries whose key or value has an unsupported shape, or
    fails to decode, are skipped.

    Returns:
        list of ``(ip_str, expire_ns, blocked_at, drop_count)`` in bpftool's
        dump order. The list is empty when the map is missing, bpftool fails,
        or its output is not JSON.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return []

    family = 6 if "v6" in map_name else 4

    dump = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if dump.returncode != 0:
        return []

    try:
        records = json.loads(dump.stdout)
    except Exception:
        return []

    def decode_key(key):
        # Address string, or None when the key shape is not supported.
        if isinstance(key, list):
            return hex_key_to_ip(key, version=family)
        if isinstance(key, dict) and family == 6:
            octets = key.get('in6_u', {}).get('u6_addr8', [])
            return str(ipaddress.IPv6Address(bytes(octets))) if octets else None
        if isinstance(key, int):
            return socket.inet_ntoa(struct.pack('<I', key))
        return None

    def decode_value(val):
        # (expire_ns, blocked_at, drop_count), or None for unsupported shapes.
        if isinstance(val, dict):
            return val.get("expire_ns", 0), val.get("blocked_at", 0), val.get("drop_count", 0)
        if isinstance(val, list) and len(val) >= 24:
            return struct.unpack_from('<QQQ', bytes(int(x, 16) for x in val[:24]))
        return None

    blocked = []
    for record in records:
        fmt = record.get("formatted", record)
        key = fmt.get("key", [])
        val = fmt.get("value", {})

        try:
            ip_str = decode_key(key)
        except Exception:
            continue
        if ip_str is None:
            continue

        try:
            fields = decode_value(val)
        except Exception:
            continue
        if fields is None:
            continue

        expire, blocked_at, drops = fields
        blocked.append((ip_str, expire, blocked_at, drops))

    return blocked
|