All config/data paths now use /etc/xdp-defense/ consistently, eliminating the legacy xdp-blocker directory reference. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
591 lines
18 KiB
Python
591 lines
18 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
XDP Defense - Common Utilities
|
|
Merged from xdp-defense common utilities
|
|
Provides: map management, CIDR handling, IP encoding, rate config, block/unblock, stats
|
|
"""
|
|
|
|
import subprocess
|
|
import json
|
|
import os
|
|
import tempfile
|
|
import struct
|
|
import socket
|
|
import ipaddress
|
|
import logging
|
|
import logging.handlers
|
|
|
|
# ==================== Logging ====================
|
|
|
|
# Audit logger shared by every helper below; records are emitted at INFO.
audit_log = logging.getLogger('xdp-defense')
audit_log.setLevel(logging.INFO)

# Attach a syslog handler when the local /dev/log socket can be opened.
_syslog_handler = None
try:
    _syslog_handler = logging.handlers.SysLogHandler(address='/dev/log')
    _syslog_handler.setFormatter(logging.Formatter('xdp-defense: %(message)s'))
except Exception:
    # Best effort: with no syslog socket, audit_log gets no handler of its
    # own and records only reach handlers on ancestor loggers.
    pass

if _syslog_handler is not None:
    audit_log.addHandler(_syslog_handler)
|
|
|
|
|
|
# ==================== BPF Map Helpers ====================
|
|
|
|
_map_cache = {}  # {map_name: (map_id, timestamp)}
_MAP_CACHE_TTL = 5.0  # seconds


def get_map_id(map_name):
    """Get BPF map ID by name (cached with 5s TTL).

    On a cache miss or an expired entry, ``bpftool map show -j`` is run once
    and the whole cache is rebuilt from its output, so every listed map is
    refreshed at the same time.

    Args:
        map_name: name of the BPF map as reported by bpftool.

    Returns:
        The map ID, or None if bpftool fails, its output cannot be parsed,
        or no map with that name is currently loaded.
    """
    import time as _time
    now = _time.monotonic()

    cached = _map_cache.get(map_name)
    if cached and (now - cached[1]) < _MAP_CACHE_TTL:
        return cached[0]

    result = subprocess.run(
        ["bpftool", "map", "show", "-j"],
        capture_output=True, text=True
    )
    if result.returncode != 0:
        return None
    try:
        maps = json.loads(result.stdout)
    except ValueError:
        return None
    if not isinstance(maps, list):
        return None

    # Replace the cache with the fresh listing instead of merging into it.
    # Merging would keep expired entries for maps that have disappeared
    # (e.g. after the XDP program is reloaded) and return their dead IDs.
    # When several maps share a name, the last one listed wins.
    fresh = {}
    for m in maps:
        if isinstance(m, dict) and m.get("name"):
            fresh[m["name"]] = (m.get("id"), now)
    _map_cache.clear()
    _map_cache.update(fresh)

    cached = _map_cache.get(map_name)
    return cached[0] if cached else None
|
|
|
|
|
|
# ==================== CIDR / IPv4 Helpers (from blocker) ====================
|
|
|
|
def validate_cidr(cidr):
    """Parse and check an IPv4 address with an optional "/prefix".

    A missing prefix means a single host (/32).

    Returns:
        (ip, prefix, octets). ip is the address text before the slash,
        prefix is an int in 0..32, and octets is a list of four ints.

    Raises:
        ValueError: on a non-numeric component, a prefix outside 0..32,
            or an address that is not four octets in 0..255.
    """
    ip, slash, prefix_text = cidr.partition('/')
    prefix = int(prefix_text) if slash else 32

    if not 0 <= prefix <= 32:
        raise ValueError(f"Invalid prefix: /{prefix}")

    octets = [int(octet) for octet in ip.split('.')]
    if len(octets) != 4 or not all(0 <= octet <= 255 for octet in octets):
        raise ValueError(f"Invalid IP: {ip}")

    return ip, prefix, octets
|
|
|
|
|
|
def cidr_to_key(cidr):
    """Build the IPv4 LPM-trie key for *cidr* as space-separated hex bytes.

    Layout: the prefix length as 4 little-endian bytes, followed by the four
    address octets exactly as written (host bits are not masked off).

    Raises:
        ValueError: propagated from validate_cidr for malformed input.
    """
    _, prefixlen, octets = validate_cidr(cidr)
    key_bytes = prefixlen.to_bytes(4, 'little') + bytes(octets)
    return ' '.join(f"{byte:02x}" for byte in key_bytes)
|
|
|
|
|
|
def is_ipv6(cidr):
    """Return True if *cidr* should be treated as IPv6.

    This is only a cheap heuristic: any string containing a colon counts as
    IPv6. The text itself is not validated.
    """
    contains_colon = ':' in cidr
    return contains_colon
|
|
|
|
|
|
def validate_cidr_v6(cidr):
    """Parse an IPv6 address or network, tolerating host bits.

    Returns:
        (network, prefixlen). network is an ipaddress.IPv6Network built with
        strict=False, and prefixlen is its prefix length.

    Raises:
        ValueError: "Invalid IPv6 CIDR: ..." chained from the ipaddress error.
    """
    try:
        network = ipaddress.IPv6Network(cidr, strict=False)
    except ValueError as exc:
        # AddressValueError and NetmaskValueError are both ValueError subclasses.
        raise ValueError(f"Invalid IPv6 CIDR: {cidr}") from exc
    return network, network.prefixlen
|
|
|
|
|
|
def cidr_to_key_v6(cidr):
    """Build the IPv6 LPM-trie key for *cidr* as space-separated hex bytes.

    Key format: prefixlen (4 bytes LE) + addr (16 bytes). The address is the
    network address, so host bits beyond the prefix are zeroed.

    Raises:
        ValueError: propagated from validate_cidr_v6 for malformed input.
    """
    network, prefixlen = validate_cidr_v6(cidr)
    key_bytes = prefixlen.to_bytes(4, 'little') + network.network_address.packed
    return ' '.join(f"{byte:02x}" for byte in key_bytes)
|
|
|
|
|
|
def classify_cidrs(cidrs):
    """Partition *cidrs* by address family, preserving input order.

    Returns:
        (v4_list, v6_list), where membership is decided by is_ipv6().
    """
    v4, v6 = [], []
    for entry in cidrs:
        (v6 if is_ipv6(entry) else v4).append(entry)
    return v4, v6
|
|
|
|
|
|
def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 00 00 00 00 00", batch_size=1000, ipv6=False):
    """Execute batch map operations with progress reporting.

    CIDRs are converted to LPM-trie keys and applied in chunks of
    *batch_size* through ``bpftool batch file``. Entries that fail CIDR
    validation are skipped (best effort) rather than aborting the run.

    Args:
        map_id: BPF map ID
        cidrs: list of CIDR strings
        operation: "update" or "delete"
        value_hex: hex value for update operations (ignored for "delete")
        batch_size: number of operations per batch
        ipv6: build IPv6 keys (cidr_to_key_v6) instead of IPv4 keys

    Returns:
        Number of entries in batches that bpftool reported as successful.
        Skipped (invalid) CIDRs and entries of failed batches are not counted.

    Raises:
        ValueError: if *operation* is neither "update" nor "delete".
    """
    # Any other value would fall into the delete branch below, so a typo could
    # silently remove entries. Reject it before touching the map.
    if operation not in ("update", "delete"):
        raise ValueError(f"Unsupported batch operation: {operation!r}")

    to_key = cidr_to_key_v6 if ipv6 else cidr_to_key
    total = len(cidrs)
    processed = 0
    written = 0
    skipped = 0
    failed = 0

    for start in range(0, total, batch_size):
        batch = cidrs[start:start + batch_size]

        commands = []
        for cidr in batch:
            try:
                key_hex = to_key(cidr)
            except Exception:
                # Best effort: one malformed entry must not abort the whole load.
                skipped += 1
                continue
            if operation == "update":
                commands.append(f"map update id {map_id} key hex {key_hex} value hex {value_hex}\n")
            else:
                commands.append(f"map delete id {map_id} key hex {key_hex}\n")

        # Skip bpftool entirely when every entry in this chunk was invalid.
        if commands:
            batch_file = None
            try:
                with tempfile.NamedTemporaryFile(mode='w', suffix='.batch', delete=False) as f:
                    batch_file = f.name
                    f.writelines(commands)
                # The with-block has closed (and flushed) the file before bpftool reads it.
                result = subprocess.run(
                    ["bpftool", "batch", "file", batch_file],
                    capture_output=True, text=True
                )
            finally:
                if batch_file and os.path.exists(batch_file):
                    os.unlink(batch_file)

            if result.returncode == 0:
                written += len(commands)
            else:
                # bpftool may have applied part of the batch before failing;
                # the batch is counted as not written and reported.
                failed += len(commands)
                audit_log.warning(
                    f"batch {operation} failed on map {map_id}: {result.stderr.strip()}"
                )

        processed += len(batch)
        pct = processed * 100 // total if total > 0 else 100
        print(f"\r Progress: {processed}/{total} ({pct}%)", end="", flush=True)

    print()
    proto = "v6" if ipv6 else "v4"
    audit_log.info(f"batch {operation} {proto}: {written} entries on map {map_id}")
    if skipped:
        audit_log.warning(f"batch {operation} {proto}: skipped {skipped} invalid entries")
    if failed:
        audit_log.warning(f"batch {operation} {proto}: {failed} entries in failed batches on map {map_id}")
    return written
|
|
|
|
|
|
# ==================== Whitelist Check (for daemon) ====================
|
|
|
|
def is_whitelisted(ip_str):
    """Report whether *ip_str* has an entry in the matching whitelist map.

    A /32 (IPv4) or /128 (IPv6) key is built and looked up with
    ``bpftool map lookup`` in whitelist_v4 or whitelist_v6.

    Returns:
        True only when the lookup succeeds. Returns False for unparsable
        input, when the map cannot be found, or when bpftool reports an error.
    """
    try:
        address = ipaddress.ip_address(ip_str)
    except ValueError:
        return False

    if isinstance(address, ipaddress.IPv6Address):
        map_name, key_hex = "whitelist_v6", cidr_to_key_v6(f"{ip_str}/128")
    else:
        map_name, key_hex = "whitelist_v4", cidr_to_key(f"{ip_str}/32")

    map_id = get_map_id(map_name)
    if map_id is None:
        return False

    command = ["bpftool", "map", "lookup", "id", str(map_id), "key", "hex", *key_hex.split()]
    lookup = subprocess.run(command, capture_output=True, text=True)
    return lookup.returncode == 0
|
|
|
|
|
|
# ==================== IP Encoding Helpers (from ddos) ====================
|
|
|
|
def ip_to_hex_key(ip_str):
    """Convert IP address to hex key for LRU_HASH maps.

    IPv4 yields 4 bytes and IPv6 yields 16 bytes, both in network byte order
    (the order of ipaddress's ``packed`` attribute).

    Returns:
        Space-separated lowercase hex string, e.g. "0a 00 00 01".

    Raises:
        ValueError: if *ip_str* is not a valid IPv4/IPv6 address. The original
            ipaddress error is attached as ``__cause__``.
    """
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError as exc:
        # Chain explicitly so the traceback shows the real cause instead of
        # "During handling of the above exception, another exception occurred".
        raise ValueError(f"Invalid IP address: {ip_str}") from exc

    return ' '.join(f"{b:02x}" for b in addr.packed)
|
|
|
|
|
|
def hex_key_to_ip(hex_bytes, version=4):
    """Turn a list of hex byte strings back into a textual IP address.

    hex_bytes: e.g. ['0a', '00', '00', '01'] (a '0x' prefix is also accepted)
    version: 6 selects IPv6; any other value decodes as IPv4.

    Raises:
        ipaddress errors / ValueError if the byte count or a token is invalid.
    """
    packed = bytes(int(token, 16) for token in hex_bytes)
    address_cls = ipaddress.IPv6Address if version == 6 else ipaddress.IPv4Address
    return str(address_cls(packed))
|
|
|
|
|
|
# ==================== DDoS Stats / Features (from ddos) ====================
|
|
|
|
def read_percpu_stats(map_name="global_stats", num_entries=5):
    """Read PERCPU_ARRAY map and return summed values per key.

    Args:
        map_name: name of the per-CPU array map to dump.
        num_entries: currently unused; kept so existing callers keep working.

    Returns:
        dict: {key_index: value summed over all CPUs}. The dict is empty if
        the map is missing, bpftool fails, or its output is not valid JSON.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return {}

    result = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if result.returncode != 0:
        return {}

    try:
        data = json.loads(result.stdout)
    except Exception:
        return {}

    stats = {}
    for entry in data:
        fmt = entry.get("formatted", entry)
        key = fmt.get("key")
        if isinstance(key, list):
            # Raw output lists every key byte. Decode them all little-endian
            # (as the values below are) so indices >= 256 do not collapse
            # onto their low byte and overwrite other entries.
            key = int.from_bytes(
                bytes(int(x, 16) if isinstance(x, str) else x for x in key),
                byteorder='little'
            )
        elif isinstance(key, str):
            key = int(key, 16)

        total = 0
        for v in fmt.get("values", []):
            val = v.get("value", v.get("val", 0))
            if isinstance(val, str):
                val = int(val, 0)
            elif isinstance(val, list):
                val = int.from_bytes(
                    bytes(int(x, 16) for x in val), byteorder='little'
                )
            total += val
        stats[key] = total

    return stats
|
|
|
|
|
|
def read_percpu_features():
    """Sum the per-CPU traffic feature counters from the traffic_feature map.

    Each CPU's value is either a BTF-formatted dict keyed by field name, or a
    raw hex-byte list that is decoded as ten consecutive little-endian u64
    counters in the order of ``field_names``.

    Returns:
        dict {field_name: total across CPUs and entries}, or {} when the map
        is missing, bpftool fails, or its output is not valid JSON.
    """
    map_id = get_map_id("traffic_feature")
    if map_id is None:
        return {}

    proc = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if proc.returncode != 0:
        return {}

    try:
        records = json.loads(proc.stdout)
    except Exception:
        return {}

    field_names = (
        "total_packets", "total_bytes", "tcp_syn_count", "tcp_other_count",
        "udp_count", "icmp_count", "other_proto_count", "unique_ips_approx",
        "small_pkt_count", "large_pkt_count",
    )
    field_count = len(field_names)
    struct_size = field_count * 8
    totals = dict.fromkeys(field_names, 0)

    for record in records:
        per_cpu = record.get("formatted", record).get("values", [])
        for cpu_val in per_cpu:
            payload = cpu_val.get("value", [])
            if isinstance(payload, dict):
                for name in field_names:
                    totals[name] += payload.get(name, 0)
            elif isinstance(payload, list) and len(payload) >= struct_size:
                raw = bytes(int(x, 16) for x in payload[:struct_size])
                counters = struct.unpack_from(f'<{field_count}Q', raw)
                for name, count in zip(field_names, counters):
                    totals[name] += count

    return totals
|
|
|
|
|
|
# ==================== Rate Config (from ddos) ====================
|
|
|
|
def read_rate_config():
    """Fetch entry 0 of the rate_config map.

    Returns:
        dict {pps_threshold, bps_threshold, window_ns}, taken either from
        BTF-formatted fields or from the first 24 raw bytes decoded as three
        little-endian u64s. Returns None if the map is missing, bpftool fails,
        or the value cannot be decoded.
    """
    map_id = get_map_id("rate_config")
    if map_id is None:
        return None

    command = ["bpftool", "map", "lookup", "id", str(map_id), "key", "0", "0", "0", "0", "-j"]
    proc = subprocess.run(command, capture_output=True, text=True)
    if proc.returncode != 0:
        return None

    try:
        payload = json.loads(proc.stdout)
        value = payload.get("formatted", payload).get("value", {})

        if isinstance(value, dict):
            return {
                field: value.get(field, 0)
                for field in ("pps_threshold", "bps_threshold", "window_ns")
            }
        if isinstance(value, list):
            raw = bytes(int(x, 16) for x in value[:24])
            pps, bps, window = struct.unpack_from('<QQQ', raw)
            return {"pps_threshold": pps, "bps_threshold": bps, "window_ns": window}
    except Exception:
        # Any decoding problem is reported as "no config" (None).
        pass
    return None
|
|
|
|
|
|
def write_rate_config(pps_threshold, bps_threshold=0, window_ns=1000000000):
    """Store the rate limits in entry 0 of the rate_config map.

    The value is packed as three little-endian u64s: pps, bps, window_ns.

    Raises:
        RuntimeError: if the map is missing or bpftool rejects the update.
        struct.error: if a value does not fit in an unsigned 64-bit field.
    """
    map_id = get_map_id("rate_config")
    if map_id is None:
        raise RuntimeError("rate_config map not found")

    packed = struct.pack('<QQQ', pps_threshold, bps_threshold, window_ns)
    command = ["bpftool", "map", "update", "id", str(map_id),
               "key", "0", "0", "0", "0", "value", "hex"]
    command.extend(f"{byte:02x}" for byte in packed)

    proc = subprocess.run(command, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Failed to write rate_config: {proc.stderr}")

    audit_log.info(f"rate_config updated: pps={pps_threshold} bps={bps_threshold} window={window_ns}")
|
|
|
|
|
|
# ==================== Block / Unblock (from ddos) ====================
|
|
|
|
def block_ip(ip_str, duration_sec=0):
    """Add IP to blocked_ips map.

    duration_sec: 0 = permanent, >0 = temporary block (int or float seconds)

    The stored value is three little-endian u64s: expire_ns (0 = never
    expires), blocked_at (the current timestamp in ns) and a drop counter
    reset to 0.

    Raises:
        ValueError: invalid IP address or negative duration.
        RuntimeError: the blocked_ips map is missing or bpftool fails.
    """
    import time as _time

    addr = ipaddress.ip_address(ip_str)
    if duration_sec < 0:
        # Reject instead of silently turning a negative duration into a
        # permanent block (expire_ns == 0).
        raise ValueError(f"Invalid block duration: {duration_sec}")

    is_v6 = isinstance(addr, ipaddress.IPv6Address)
    map_name = "blocked_ips_v6" if is_v6 else "blocked_ips_v4"

    map_id = get_map_id(map_name)
    if map_id is None:
        raise RuntimeError(f"{map_name} map not found")

    key_hex = ip_to_hex_key(ip_str)

    # Timestamps use CLOCK_BOOTTIME (the clock /proc/uptime reports), read
    # with full ns precision. /proc/uptime is used only where the constant
    # is unavailable.
    # NOTE(review): bpf_ktime_get_ns() is CLOCK_MONOTONIC while
    # bpf_ktime_get_boot_ns() is CLOCK_BOOTTIME. Confirm which one the XDP
    # program compares expire_ns against; the two diverge after a suspend.
    boottime_clock = getattr(_time, "CLOCK_BOOTTIME", None)
    if boottime_clock is not None:
        now_ns = _time.clock_gettime_ns(boottime_clock)
    else:
        with open('/proc/uptime', 'r') as f:
            now_ns = int(float(f.read().split()[0]) * 1_000_000_000)

    if duration_sec > 0:
        # int() so float durations still pack as an unsigned u64.
        expire_ns = now_ns + int(duration_sec * 1_000_000_000)
    else:
        expire_ns = 0

    raw = struct.pack('<QQQ', expire_ns, now_ns, 0)
    val_hex = ' '.join(f"{b:02x}" for b in raw)

    result = subprocess.run(
        ["bpftool", "map", "update", "id", str(map_id),
         "key", "hex"] + key_hex.split() + ["value", "hex"] + val_hex.split(),
        capture_output=True, text=True
    )
    if result.returncode != 0:
        raise RuntimeError(f"Failed to block {ip_str}: {result.stderr}")

    duration_str = f"{duration_sec}s" if duration_sec > 0 else "permanent"
    audit_log.info(f"blocked {ip_str} ({duration_str})")
|
|
|
|
|
|
def unblock_ip(ip_str):
    """Delete *ip_str* from blocked_ips_v4 or blocked_ips_v6.

    Raises:
        ValueError: if *ip_str* is not a valid IP address.
        RuntimeError: if the map is missing or bpftool fails (for example
            because the key is not present).
    """
    family_is_v6 = isinstance(ipaddress.ip_address(ip_str), ipaddress.IPv6Address)
    map_name = "blocked_ips_v6" if family_is_v6 else "blocked_ips_v4"

    map_id = get_map_id(map_name)
    if map_id is None:
        raise RuntimeError(f"{map_name} map not found")

    command = ["bpftool", "map", "delete", "id", str(map_id), "key", "hex"]
    command.extend(ip_to_hex_key(ip_str).split())

    proc = subprocess.run(command, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Failed to unblock {ip_str}: {proc.stderr}")

    audit_log.info(f"unblocked {ip_str}")
|
|
|
|
|
|
# ==================== Dump Helpers (from ddos) ====================
|
|
|
|
def dump_rate_counters(map_name="rate_counter_v4", top_n=10):
    """Dump top-N IPs by packet count from rate counter map.

    Keys are decoded from raw hex-byte lists or from BTF-formatted output:
    an ``in6_u.u6_addr8`` struct for IPv6 maps, or an int for IPv4 maps.
    Entries whose key or value layout is not recognised are skipped.

    Args:
        map_name: rate counter map; a name containing "v6" selects IPv6 decoding.
        top_n: number of entries to return after sorting by packet count.

    Returns:
        list of (ip_str, packets, bytes, last_seen_ns), highest packet count
        first. The list is empty if the map is missing, bpftool fails, or its
        output is not valid JSON.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return []

    is_v6 = "v6" in map_name

    result = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if result.returncode != 0:
        return []

    try:
        data = json.loads(result.stdout)
    except Exception:
        return []

    entries = []
    for entry in data:
        fmt = entry.get("formatted", entry)
        key = fmt.get("key", [])
        val = fmt.get("value", {})

        try:
            if isinstance(key, list):
                ip_str = hex_key_to_ip(key, version=6 if is_v6 else 4)
            elif is_v6 and isinstance(key, dict):
                addr8 = key.get('in6_u', {}).get('u6_addr8', [])
                if not addr8:
                    continue
                ip_str = str(ipaddress.IPv6Address(bytes(addr8)))
            elif not is_v6 and isinstance(key, int):
                # The u32 holds the address bytes in network order; packing it
                # back little-endian restores those bytes for inet_ntoa.
                ip_str = socket.inet_ntoa(struct.pack('<I', key))
            else:
                # Unrecognised layout (e.g. a struct key on a v4 map). Skip it
                # rather than emit a bogus str(dict) "address".
                continue
        except Exception:
            continue

        try:
            if isinstance(val, dict):
                pkts = val.get("packets", 0)
                bts = val.get("bytes", 0)
                last = val.get("last_seen", 0)
            elif isinstance(val, list) and len(val) >= 24:
                raw = bytes(int(x, 16) for x in val[:24])
                pkts, bts, last = struct.unpack_from('<QQQ', raw)
            else:
                continue
        except Exception:
            continue

        entries.append((ip_str, pkts, bts, last))

    entries.sort(key=lambda x: x[1], reverse=True)
    return entries[:top_n]
|
|
|
|
|
|
def dump_blocked_ips(map_name="blocked_ips_v4"):
    """List every entry of a blocked_ips map.

    Keys may be raw hex-byte lists, BTF ``in6_u.u6_addr8`` structs (IPv6 maps
    only) or ints. Values may be BTF dicts or at least 24 raw bytes holding
    three little-endian u64s. Entries that cannot be decoded are skipped.

    Returns:
        list of (ip_str, expire_ns, blocked_at, drop_count), in bpftool's
        dump order. The list is empty if the map is missing, bpftool fails,
        or its output is not valid JSON.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return []

    is_v6 = "v6" in map_name

    proc = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if proc.returncode != 0:
        return []

    try:
        records = json.loads(proc.stdout)
    except Exception:
        return []

    def _decode_ip(key):
        # Returns the IP text, or None when the key layout is not handled.
        if isinstance(key, list):
            return hex_key_to_ip(key, version=6 if is_v6 else 4)
        if isinstance(key, dict) and is_v6:
            addr8 = key.get('in6_u', {}).get('u6_addr8', [])
            return str(ipaddress.IPv6Address(bytes(addr8))) if addr8 else None
        if isinstance(key, int):
            return socket.inet_ntoa(struct.pack('<I', key))
        return None

    def _decode_info(val):
        # Returns (expire, blocked_at, drops), or None for an unknown layout.
        if isinstance(val, dict):
            return val.get("expire_ns", 0), val.get("blocked_at", 0), val.get("drop_count", 0)
        if isinstance(val, list) and len(val) >= 24:
            return struct.unpack_from('<QQQ', bytes(int(x, 16) for x in val[:24]))
        return None

    blocked = []
    for record in records:
        fmt = record.get("formatted", record)
        key = fmt.get("key", [])
        val = fmt.get("value", {})

        try:
            ip_str = _decode_ip(key)
        except Exception:
            continue
        if ip_str is None:
            continue

        try:
            info = _decode_info(val)
        except Exception:
            continue
        if info is None:
            continue

        blocked.append((ip_str, *info))

    return blocked
|