Unify xdp-blocker and xdp-ddos into single xdp-defense project
Chain two XDP programs via libxdp dispatcher on the same interface: xdp_blocker (priority 10) handles CIDR/country/whitelist blocking, xdp_ddos (priority 20) handles rate limiting, EWMA analysis, and AI anomaly detection. Whitelist maps are shared via BPF map pinning so whitelisted IPs bypass both blocklist checks and DDoS rate limiting. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
542
lib/xdp_common.py
Normal file
542
lib/xdp_common.py
Normal file
@@ -0,0 +1,542 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
XDP Defense - Common Utilities
|
||||
Merged from xdp-blocker/xdp_common.py and xdp-ddos/xdp_ddos_common.py
|
||||
Provides: map management, CIDR handling, IP encoding, rate config, block/unblock, stats
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import struct
|
||||
import socket
|
||||
import ipaddress
|
||||
import logging
|
||||
import logging.handlers
|
||||
|
||||
# ==================== Logging ====================
|
||||
|
||||
_syslog_handler = None
|
||||
try:
|
||||
_syslog_handler = logging.handlers.SysLogHandler(address='/dev/log')
|
||||
_syslog_handler.setFormatter(logging.Formatter('xdp-defense: %(message)s'))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
audit_log = logging.getLogger('xdp-defense')
|
||||
audit_log.setLevel(logging.INFO)
|
||||
if _syslog_handler:
|
||||
audit_log.addHandler(_syslog_handler)
|
||||
|
||||
|
||||
# ==================== BPF Map Helpers ====================
|
||||
|
||||
def get_map_id(map_name):
|
||||
"""Get BPF map ID by name."""
|
||||
result = subprocess.run(
|
||||
["bpftool", "map", "show", "-j"],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
try:
|
||||
maps = json.loads(result.stdout)
|
||||
for m in maps:
|
||||
if m.get("name") == map_name:
|
||||
return m.get("id")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
# ==================== CIDR / IPv4 Helpers (from blocker) ====================
|
||||
|
||||
def validate_cidr(cidr):
|
||||
"""Validate IPv4 CIDR notation. Returns (ip, prefix, parts) or raises ValueError."""
|
||||
if '/' in cidr:
|
||||
ip, prefix_str = cidr.split('/', 1)
|
||||
prefix = int(prefix_str)
|
||||
else:
|
||||
ip = cidr
|
||||
prefix = 32
|
||||
|
||||
if prefix < 0 or prefix > 32:
|
||||
raise ValueError(f"Invalid prefix: /{prefix}")
|
||||
|
||||
parts = [int(x) for x in ip.split('.')]
|
||||
if len(parts) != 4 or any(p < 0 or p > 255 for p in parts):
|
||||
raise ValueError(f"Invalid IP: {ip}")
|
||||
|
||||
return ip, prefix, parts
|
||||
|
||||
|
||||
def cidr_to_key(cidr):
|
||||
"""Convert CIDR to LPM trie key hex string."""
|
||||
_, prefix, parts = validate_cidr(cidr)
|
||||
key_hex = f"{prefix:02x} 00 00 00 {parts[0]:02x} {parts[1]:02x} {parts[2]:02x} {parts[3]:02x}"
|
||||
return key_hex
|
||||
|
||||
|
||||
def is_ipv6(cidr):
|
||||
"""Check if a CIDR string is IPv6."""
|
||||
return ':' in cidr
|
||||
|
||||
|
||||
def validate_cidr_v6(cidr):
|
||||
"""Validate IPv6 CIDR notation. Returns (network, prefix) or raises ValueError."""
|
||||
try:
|
||||
net = ipaddress.IPv6Network(cidr, strict=False)
|
||||
return net, net.prefixlen
|
||||
except (ipaddress.AddressValueError, ValueError) as e:
|
||||
raise ValueError(f"Invalid IPv6 CIDR: {cidr}") from e
|
||||
|
||||
|
||||
def cidr_to_key_v6(cidr):
|
||||
"""Convert IPv6 CIDR to LPM trie key hex string.
|
||||
Key format: prefixlen (4 bytes LE) + addr (16 bytes)
|
||||
"""
|
||||
net, prefix = validate_cidr_v6(cidr)
|
||||
addr_bytes = net.network_address.packed
|
||||
prefix_hex = f"{prefix:02x} 00 00 00"
|
||||
addr_hex = ' '.join(f"{b:02x}" for b in addr_bytes)
|
||||
return f"{prefix_hex} {addr_hex}"
|
||||
|
||||
|
||||
def classify_cidrs(cidrs):
|
||||
"""Split a list of CIDRs into v4 and v6 lists."""
|
||||
v4 = []
|
||||
v6 = []
|
||||
for c in cidrs:
|
||||
if is_ipv6(c):
|
||||
v6.append(c)
|
||||
else:
|
||||
v4.append(c)
|
||||
return v4, v6
|
||||
|
||||
|
||||
def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 00 00 00 00 00", batch_size=1000, ipv6=False):
|
||||
"""Execute batch map operations with progress reporting.
|
||||
|
||||
Args:
|
||||
map_id: BPF map ID
|
||||
cidrs: list of CIDR strings
|
||||
operation: "update" or "delete"
|
||||
value_hex: hex value for update operations
|
||||
batch_size: number of operations per batch
|
||||
|
||||
Returns:
|
||||
Number of processed entries
|
||||
"""
|
||||
total = len(cidrs)
|
||||
processed = 0
|
||||
|
||||
for i in range(0, total, batch_size):
|
||||
batch = cidrs[i:i + batch_size]
|
||||
batch_file = None
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.batch', delete=False) as f:
|
||||
batch_file = f.name
|
||||
for cidr in batch:
|
||||
try:
|
||||
key_hex = cidr_to_key_v6(cidr) if ipv6 else cidr_to_key(cidr)
|
||||
if operation == "update":
|
||||
f.write(f"map update id {map_id} key hex {key_hex} value hex {value_hex}\n")
|
||||
else:
|
||||
f.write(f"map delete id {map_id} key hex {key_hex}\n")
|
||||
except (ValueError, Exception):
|
||||
continue
|
||||
|
||||
subprocess.run(
|
||||
["bpftool", "batch", "file", batch_file],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
finally:
|
||||
if batch_file and os.path.exists(batch_file):
|
||||
os.unlink(batch_file)
|
||||
|
||||
processed += len(batch)
|
||||
pct = processed * 100 // total if total > 0 else 100
|
||||
print(f"\r Progress: {processed}/{total} ({pct}%)", end="", flush=True)
|
||||
|
||||
print()
|
||||
proto = "v6" if ipv6 else "v4"
|
||||
audit_log.info(f"batch {operation} {proto}: {processed} entries on map {map_id}")
|
||||
return processed
|
||||
|
||||
|
||||
# ==================== IP Encoding Helpers (from ddos) ====================
|
||||
|
||||
def ip_to_hex_key(ip_str):
|
||||
"""Convert IP address to hex key for LRU_HASH maps.
|
||||
IPv4: 4 bytes (network byte order)
|
||||
IPv6: 16 bytes (network byte order)
|
||||
Returns space-separated hex string.
|
||||
"""
|
||||
try:
|
||||
addr = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid IP address: {ip_str}")
|
||||
|
||||
packed = addr.packed
|
||||
return ' '.join(f"{b:02x}" for b in packed)
|
||||
|
||||
|
||||
def hex_key_to_ip(hex_bytes, version=4):
|
||||
"""Convert hex byte list back to IP string.
|
||||
hex_bytes: list of hex strings like ['0a', '00', '00', '01']
|
||||
"""
|
||||
raw = bytes(int(h, 16) for h in hex_bytes)
|
||||
if version == 6:
|
||||
return str(ipaddress.IPv6Address(raw))
|
||||
return str(ipaddress.IPv4Address(raw))
|
||||
|
||||
|
||||
# ==================== DDoS Stats / Features (from ddos) ====================
|
||||
|
||||
def read_percpu_stats(map_name="global_stats", num_entries=5):
|
||||
"""Read PERCPU_ARRAY map and return summed values per key.
|
||||
Returns dict: {key_index: summed_value}
|
||||
"""
|
||||
map_id = get_map_id(map_name)
|
||||
if map_id is None:
|
||||
return {}
|
||||
|
||||
result = subprocess.run(
|
||||
["bpftool", "map", "dump", "id", str(map_id), "-j"],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return {}
|
||||
|
||||
try:
|
||||
data = json.loads(result.stdout)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
stats = {}
|
||||
for entry in data:
|
||||
fmt = entry.get("formatted", entry)
|
||||
key = fmt.get("key")
|
||||
if isinstance(key, list):
|
||||
key = int(key[0], 16) if isinstance(key[0], str) else key[0]
|
||||
elif isinstance(key, str):
|
||||
key = int(key, 16)
|
||||
|
||||
values = fmt.get("values", [])
|
||||
total = 0
|
||||
for v in values:
|
||||
val = v.get("value", v.get("val", 0))
|
||||
if isinstance(val, str):
|
||||
val = int(val, 0)
|
||||
elif isinstance(val, list):
|
||||
val = int.from_bytes(
|
||||
bytes(int(x, 16) for x in val), byteorder='little'
|
||||
)
|
||||
total += val
|
||||
stats[key] = total
|
||||
return stats
|
||||
|
||||
|
||||
def read_percpu_features():
|
||||
"""Read traffic_features PERCPU_ARRAY and return aggregated struct.
|
||||
Returns dict with field names and summed values.
|
||||
"""
|
||||
map_id = get_map_id("traffic_feature")
|
||||
if map_id is None:
|
||||
return {}
|
||||
|
||||
result = subprocess.run(
|
||||
["bpftool", "map", "dump", "id", str(map_id), "-j"],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return {}
|
||||
|
||||
try:
|
||||
data = json.loads(result.stdout)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
field_names = [
|
||||
"total_packets", "total_bytes", "tcp_syn_count", "tcp_other_count",
|
||||
"udp_count", "icmp_count", "other_proto_count", "unique_ips_approx",
|
||||
"small_pkt_count", "large_pkt_count"
|
||||
]
|
||||
num_fields = len(field_names)
|
||||
|
||||
aggregated = {f: 0 for f in field_names}
|
||||
|
||||
for entry in data:
|
||||
fmt = entry.get("formatted", entry)
|
||||
values = fmt.get("values", [])
|
||||
for cpu_val in values:
|
||||
val = cpu_val.get("value", [])
|
||||
if isinstance(val, list) and len(val) >= num_fields * 8:
|
||||
raw = bytes(int(x, 16) for x in val[:num_fields * 8])
|
||||
for i, name in enumerate(field_names):
|
||||
v = struct.unpack_from('<Q', raw, i * 8)[0]
|
||||
aggregated[name] += v
|
||||
elif isinstance(val, dict):
|
||||
for name in field_names:
|
||||
aggregated[name] += val.get(name, 0)
|
||||
|
||||
return aggregated
|
||||
|
||||
|
||||
# ==================== Rate Config (from ddos) ====================
|
||||
|
||||
def read_rate_config():
|
||||
"""Read current rate_config from BPF map.
|
||||
Returns dict: {pps_threshold, bps_threshold, window_ns}
|
||||
"""
|
||||
map_id = get_map_id("rate_config")
|
||||
if map_id is None:
|
||||
return None
|
||||
|
||||
result = subprocess.run(
|
||||
["bpftool", "map", "lookup", "id", str(map_id), "key", "0", "0", "0", "0", "-j"],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = json.loads(result.stdout)
|
||||
fmt = data.get("formatted", data)
|
||||
val = fmt.get("value", {})
|
||||
|
||||
if isinstance(val, dict):
|
||||
return {
|
||||
"pps_threshold": val.get("pps_threshold", 0),
|
||||
"bps_threshold": val.get("bps_threshold", 0),
|
||||
"window_ns": val.get("window_ns", 0),
|
||||
}
|
||||
elif isinstance(val, list):
|
||||
raw = bytes(int(x, 16) for x in val[:24])
|
||||
pps = struct.unpack_from('<Q', raw, 0)[0]
|
||||
bps = struct.unpack_from('<Q', raw, 8)[0]
|
||||
win = struct.unpack_from('<Q', raw, 16)[0]
|
||||
return {"pps_threshold": pps, "bps_threshold": bps, "window_ns": win}
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def write_rate_config(pps_threshold, bps_threshold=0, window_ns=1000000000):
|
||||
"""Write rate_config to BPF map."""
|
||||
map_id = get_map_id("rate_config")
|
||||
if map_id is None:
|
||||
raise RuntimeError("rate_config map not found")
|
||||
|
||||
raw = struct.pack('<QQQ', pps_threshold, bps_threshold, window_ns)
|
||||
val_hex = ' '.join(f"{b:02x}" for b in raw)
|
||||
|
||||
result = subprocess.run(
|
||||
["bpftool", "map", "update", "id", str(map_id),
|
||||
"key", "0", "0", "0", "0", "value", "hex"] + val_hex.split(),
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to write rate_config: {result.stderr}")
|
||||
|
||||
audit_log.info(f"rate_config updated: pps={pps_threshold} bps={bps_threshold} window={window_ns}")
|
||||
|
||||
|
||||
# ==================== Block / Unblock (from ddos) ====================
|
||||
|
||||
def block_ip(ip_str, duration_sec=0):
|
||||
"""Add IP to blocked_ips map.
|
||||
duration_sec: 0 = permanent, >0 = temporary block
|
||||
"""
|
||||
addr = ipaddress.ip_address(ip_str)
|
||||
is_v6 = isinstance(addr, ipaddress.IPv6Address)
|
||||
map_name = "blocked_ips_v6" if is_v6 else "blocked_ips_v4"
|
||||
|
||||
map_id = get_map_id(map_name)
|
||||
if map_id is None:
|
||||
raise RuntimeError(f"{map_name} map not found")
|
||||
|
||||
key_hex = ip_to_hex_key(ip_str)
|
||||
|
||||
if duration_sec > 0:
|
||||
with open('/proc/uptime', 'r') as f:
|
||||
uptime_sec = float(f.read().split()[0])
|
||||
now_ns = int(uptime_sec * 1_000_000_000)
|
||||
expire_ns = now_ns + (duration_sec * 1_000_000_000)
|
||||
else:
|
||||
expire_ns = 0
|
||||
|
||||
now_ns_val = 0
|
||||
raw = struct.pack('<QQQ', expire_ns, now_ns_val, 0)
|
||||
val_hex = ' '.join(f"{b:02x}" for b in raw)
|
||||
|
||||
result = subprocess.run(
|
||||
["bpftool", "map", "update", "id", str(map_id),
|
||||
"key", "hex"] + key_hex.split() + ["value", "hex"] + val_hex.split(),
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to block {ip_str}: {result.stderr}")
|
||||
|
||||
duration_str = f"{duration_sec}s" if duration_sec > 0 else "permanent"
|
||||
audit_log.info(f"blocked {ip_str} ({duration_str})")
|
||||
|
||||
|
||||
def unblock_ip(ip_str):
|
||||
"""Remove IP from blocked_ips map."""
|
||||
addr = ipaddress.ip_address(ip_str)
|
||||
is_v6 = isinstance(addr, ipaddress.IPv6Address)
|
||||
map_name = "blocked_ips_v6" if is_v6 else "blocked_ips_v4"
|
||||
|
||||
map_id = get_map_id(map_name)
|
||||
if map_id is None:
|
||||
raise RuntimeError(f"{map_name} map not found")
|
||||
|
||||
key_hex = ip_to_hex_key(ip_str)
|
||||
|
||||
result = subprocess.run(
|
||||
["bpftool", "map", "delete", "id", str(map_id),
|
||||
"key", "hex"] + key_hex.split(),
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to unblock {ip_str}: {result.stderr}")
|
||||
|
||||
audit_log.info(f"unblocked {ip_str}")
|
||||
|
||||
|
||||
# ==================== Dump Helpers (from ddos) ====================
|
||||
|
||||
def dump_rate_counters(map_name="rate_counter_v4", top_n=10):
|
||||
"""Dump top-N IPs by packet count from rate counter map.
|
||||
Returns list of (ip_str, packets, bytes, last_seen_ns).
|
||||
"""
|
||||
map_id = get_map_id(map_name)
|
||||
if map_id is None:
|
||||
return []
|
||||
|
||||
is_v6 = "v6" in map_name
|
||||
|
||||
result = subprocess.run(
|
||||
["bpftool", "map", "dump", "id", str(map_id), "-j"],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return []
|
||||
|
||||
try:
|
||||
data = json.loads(result.stdout)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
entries = []
|
||||
for entry in data:
|
||||
fmt = entry.get("formatted", entry)
|
||||
key = fmt.get("key", [])
|
||||
val = fmt.get("value", {})
|
||||
|
||||
try:
|
||||
if isinstance(key, list):
|
||||
ip_str = hex_key_to_ip(key, version=6 if is_v6 else 4)
|
||||
elif isinstance(key, dict):
|
||||
if is_v6:
|
||||
addr8 = key.get('in6_u', {}).get('u6_addr8', [])
|
||||
if addr8:
|
||||
raw = bytes(addr8)
|
||||
ip_str = str(ipaddress.IPv6Address(raw))
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
ip_str = socket.inet_ntoa(struct.pack('<I', key)) if isinstance(key, int) else str(key)
|
||||
elif isinstance(key, int):
|
||||
ip_str = socket.inet_ntoa(struct.pack('<I', key))
|
||||
else:
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
try:
|
||||
if isinstance(val, dict):
|
||||
pkts = val.get("packets", 0)
|
||||
bts = val.get("bytes", 0)
|
||||
last = val.get("last_seen", 0)
|
||||
elif isinstance(val, list) and len(val) >= 24:
|
||||
raw = bytes(int(x, 16) for x in val[:24])
|
||||
pkts = struct.unpack_from('<Q', raw, 0)[0]
|
||||
bts = struct.unpack_from('<Q', raw, 8)[0]
|
||||
last = struct.unpack_from('<Q', raw, 16)[0]
|
||||
else:
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
entries.append((ip_str, pkts, bts, last))
|
||||
|
||||
entries.sort(key=lambda x: x[1], reverse=True)
|
||||
return entries[:top_n]
|
||||
|
||||
|
||||
def dump_blocked_ips(map_name="blocked_ips_v4"):
|
||||
"""Dump all blocked IPs with their block info.
|
||||
Returns list of (ip_str, expire_ns, blocked_at, drop_count).
|
||||
"""
|
||||
map_id = get_map_id(map_name)
|
||||
if map_id is None:
|
||||
return []
|
||||
|
||||
is_v6 = "v6" in map_name
|
||||
|
||||
result = subprocess.run(
|
||||
["bpftool", "map", "dump", "id", str(map_id), "-j"],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return []
|
||||
|
||||
try:
|
||||
data = json.loads(result.stdout)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
entries = []
|
||||
for entry in data:
|
||||
fmt = entry.get("formatted", entry)
|
||||
key = fmt.get("key", [])
|
||||
val = fmt.get("value", {})
|
||||
|
||||
try:
|
||||
if isinstance(key, list):
|
||||
ip_str = hex_key_to_ip(key, version=6 if is_v6 else 4)
|
||||
elif isinstance(key, dict) and is_v6:
|
||||
addr8 = key.get('in6_u', {}).get('u6_addr8', [])
|
||||
if addr8:
|
||||
ip_str = str(ipaddress.IPv6Address(bytes(addr8)))
|
||||
else:
|
||||
continue
|
||||
elif isinstance(key, int):
|
||||
ip_str = socket.inet_ntoa(struct.pack('<I', key))
|
||||
else:
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
try:
|
||||
if isinstance(val, dict):
|
||||
expire = val.get("expire_ns", 0)
|
||||
blocked = val.get("blocked_at", 0)
|
||||
drops = val.get("drop_count", 0)
|
||||
elif isinstance(val, list) and len(val) >= 24:
|
||||
raw = bytes(int(x, 16) for x in val[:24])
|
||||
expire = struct.unpack_from('<Q', raw, 0)[0]
|
||||
blocked = struct.unpack_from('<Q', raw, 8)[0]
|
||||
drops = struct.unpack_from('<Q', raw, 16)[0]
|
||||
else:
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
entries.append((ip_str, expire, blocked, drops))
|
||||
|
||||
return entries
|
||||
175
lib/xdp_country.py
Executable file
175
lib/xdp_country.py
Executable file
@@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
XDP Country Blocker - Fast batch IP blocking
|
||||
Uses bpftool batch for high-speed map updates
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
from xdp_common import get_map_id, batch_map_operation, classify_cidrs
|
||||
|
||||
COUNTRY_DIR = Path("/etc/xdp-blocker/countries")
|
||||
IPDENY_V4_URL = "https://www.ipdeny.com/ipblocks/data/countries/{}.zone"
|
||||
IPDENY_V6_URL = "https://www.ipdeny.com/ipblocks/data/ipv6/ipv6-country-blocks/{}.zone"
|
||||
|
||||
def download_country(cc):
|
||||
"""Download country IPv4 + IPv6 IP lists"""
|
||||
COUNTRY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
cc_file = COUNTRY_DIR / f"{cc.lower()}.txt"
|
||||
cidrs = []
|
||||
|
||||
# Download IPv4
|
||||
print(f"[INFO] Downloading {cc.upper()} IPv4 ranges...")
|
||||
try:
|
||||
urllib.request.urlretrieve(IPDENY_V4_URL.format(cc.lower()), cc_file)
|
||||
with open(cc_file) as f:
|
||||
cidrs.extend(line.strip() for line in f if line.strip())
|
||||
print(f" IPv4: {len(cidrs)} CIDRs")
|
||||
except Exception as e:
|
||||
print(f" [WARN] IPv4 download failed: {e}")
|
||||
|
||||
# Download IPv6
|
||||
v6_count = 0
|
||||
try:
|
||||
v6_tmp = COUNTRY_DIR / f"{cc.lower()}_v6.tmp"
|
||||
urllib.request.urlretrieve(IPDENY_V6_URL.format(cc.lower()), v6_tmp)
|
||||
with open(v6_tmp) as f:
|
||||
v6_cidrs = [line.strip() for line in f if line.strip()]
|
||||
cidrs.extend(v6_cidrs)
|
||||
v6_count = len(v6_cidrs)
|
||||
v6_tmp.unlink(missing_ok=True)
|
||||
print(f" IPv6: {v6_count} CIDRs")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not cidrs:
|
||||
print(f"[ERROR] No IP ranges found for {cc.upper()}")
|
||||
return None
|
||||
|
||||
with open(cc_file, 'w') as f:
|
||||
f.write('\n'.join(cidrs) + '\n')
|
||||
|
||||
return cc_file
|
||||
|
||||
def add_country(cc):
|
||||
"""Add country IPs to XDP blocklist using batch (IPv4 + IPv6)"""
|
||||
cc = cc.lower()
|
||||
cc_file = COUNTRY_DIR / f"{cc}.txt"
|
||||
|
||||
if not cc_file.exists():
|
||||
cc_file = download_country(cc)
|
||||
if not cc_file:
|
||||
return False
|
||||
|
||||
with open(cc_file) as f:
|
||||
cidrs = [line.strip() for line in f if line.strip()]
|
||||
|
||||
v4_cidrs, v6_cidrs = classify_cidrs(cidrs)
|
||||
|
||||
if v4_cidrs:
|
||||
map_id = get_map_id("blocklist_v4")
|
||||
if not map_id:
|
||||
print("[ERROR] blocklist_v4 map not found. Is XDP loaded?")
|
||||
return False
|
||||
print(f"[INFO] Adding {len(v4_cidrs)} IPv4 CIDRs for {cc.upper()}...")
|
||||
batch_map_operation(map_id, v4_cidrs, operation="update")
|
||||
|
||||
if v6_cidrs:
|
||||
map_id_v6 = get_map_id("blocklist_v6")
|
||||
if map_id_v6:
|
||||
print(f"[INFO] Adding {len(v6_cidrs)} IPv6 CIDRs for {cc.upper()}...")
|
||||
batch_map_operation(map_id_v6, v6_cidrs, operation="update", ipv6=True)
|
||||
else:
|
||||
print("[WARN] blocklist_v6 map not found, skipping IPv6")
|
||||
|
||||
print(f"[OK] Added {cc.upper()}: {len(v4_cidrs)} v4 + {len(v6_cidrs)} v6 CIDRs")
|
||||
return True
|
||||
|
||||
def del_country(cc):
|
||||
"""Remove country IPs from XDP blocklist"""
|
||||
cc = cc.lower()
|
||||
cc_file = COUNTRY_DIR / f"{cc}.txt"
|
||||
|
||||
if not cc_file.exists():
|
||||
print(f"[ERROR] Country {cc.upper()} is not blocked")
|
||||
return False
|
||||
|
||||
with open(cc_file) as f:
|
||||
cidrs = [line.strip() for line in f if line.strip()]
|
||||
|
||||
v4_cidrs, v6_cidrs = classify_cidrs(cidrs)
|
||||
|
||||
if v4_cidrs:
|
||||
map_id = get_map_id("blocklist_v4")
|
||||
if map_id:
|
||||
print(f"[INFO] Removing {len(v4_cidrs)} IPv4 CIDRs for {cc.upper()}...")
|
||||
batch_map_operation(map_id, v4_cidrs, operation="delete")
|
||||
|
||||
if v6_cidrs:
|
||||
map_id_v6 = get_map_id("blocklist_v6")
|
||||
if map_id_v6:
|
||||
print(f"[INFO] Removing {len(v6_cidrs)} IPv6 CIDRs for {cc.upper()}...")
|
||||
batch_map_operation(map_id_v6, v6_cidrs, operation="delete", ipv6=True)
|
||||
|
||||
cc_file.unlink()
|
||||
print(f"[OK] Removed {cc.upper()}: {len(v4_cidrs)} v4 + {len(v6_cidrs)} v6 CIDRs")
|
||||
return True
|
||||
|
||||
def list_countries():
|
||||
"""List blocked countries"""
|
||||
print("=== Blocked Countries ===")
|
||||
|
||||
if not COUNTRY_DIR.exists():
|
||||
print(" (none)")
|
||||
return
|
||||
|
||||
files = list(COUNTRY_DIR.glob("*.txt"))
|
||||
if not files:
|
||||
print(" (none)")
|
||||
return
|
||||
|
||||
for cc_file in sorted(files):
|
||||
cc = cc_file.stem.upper()
|
||||
count = sum(1 for _ in open(cc_file))
|
||||
mtime = cc_file.stat().st_mtime
|
||||
age = int((time.time() - mtime) / 86400)
|
||||
print(f" {cc}: {count} CIDRs (updated {age}d ago)")
|
||||
|
||||
def show_help():
|
||||
print("""XDP Country Blocker - Fast batch IP blocking
|
||||
|
||||
Usage: xdp-country <command> [args]
|
||||
|
||||
Commands:
|
||||
add <cc> Block a country (e.g., br, cn, ru, kp)
|
||||
del <cc> Unblock a country
|
||||
list List blocked countries
|
||||
|
||||
Examples:
|
||||
xdp-country add br # Block Brazil (~13K CIDRs in seconds)
|
||||
xdp-country add cn # Block China
|
||||
xdp-country del br # Unblock Brazil
|
||||
xdp-country list
|
||||
""")
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
show_help()
|
||||
return
|
||||
|
||||
cmd = sys.argv[1]
|
||||
|
||||
if cmd == "add" and len(sys.argv) >= 3:
|
||||
add_country(sys.argv[2])
|
||||
elif cmd == "del" and len(sys.argv) >= 3:
|
||||
del_country(sys.argv[2])
|
||||
elif cmd == "list":
|
||||
list_countries()
|
||||
else:
|
||||
show_help()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
714
lib/xdp_defense_daemon.py
Executable file
714
lib/xdp_defense_daemon.py
Executable file
@@ -0,0 +1,714 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
XDP Defense Daemon
|
||||
Userspace component: EWMA-based rate analysis, AI anomaly detection,
|
||||
time-profile switching, and automatic escalation.
|
||||
|
||||
4 worker threads + main thread (signal handling):
|
||||
- EWMA Thread: polls rate counters, calculates EWMA, detects violations
|
||||
- AI Thread: reads traffic features, runs Isolation Forest inference
|
||||
- Profile Thread: checks time-of-day, switches rate_config profiles
|
||||
- Cleanup Thread: removes expired entries from blocked_ips maps
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import signal
|
||||
import threading
|
||||
import logging
|
||||
import logging.handlers
|
||||
import csv
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
import yaml
|
||||
|
||||
# ==================== Logging ====================
|
||||
|
||||
log = logging.getLogger('xdp-defense-daemon')
|
||||
log.setLevel(logging.INFO)
|
||||
|
||||
_console = logging.StreamHandler()
|
||||
_console.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(message)s'))
|
||||
log.addHandler(_console)
|
||||
|
||||
try:
|
||||
_syslog = logging.handlers.SysLogHandler(address='/dev/log')
|
||||
_syslog.setFormatter(logging.Formatter('xdp-defense-daemon: %(message)s'))
|
||||
log.addHandler(_syslog)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ==================== Configuration ====================
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
'general': {
|
||||
'interface': 'eth0',
|
||||
'log_level': 'info',
|
||||
'pid_file': '/var/lib/xdp-defense/daemon.pid',
|
||||
'data_dir': '/var/lib/xdp-defense',
|
||||
},
|
||||
'rate_limits': {
|
||||
'default_pps': 1000,
|
||||
'default_bps': 0,
|
||||
'window_sec': 1,
|
||||
'profiles': {},
|
||||
},
|
||||
'escalation': {
|
||||
'rate_limit_after': 1,
|
||||
'temp_block_after': 5,
|
||||
'perm_block_after': 20,
|
||||
'temp_block_duration': 300,
|
||||
'violation_window': 600,
|
||||
'cooldown_multiplier': 0.5,
|
||||
},
|
||||
'ewma': {
|
||||
'alpha': 0.3,
|
||||
'poll_interval': 1,
|
||||
'threshold_multiplier': 3.0,
|
||||
},
|
||||
'ai': {
|
||||
'enabled': True,
|
||||
'model_type': 'IsolationForest',
|
||||
'contamination': 0.05,
|
||||
'n_estimators': 100,
|
||||
'learning_duration': 259200,
|
||||
'min_samples': 1000,
|
||||
'poll_interval': 5,
|
||||
'anomaly_threshold': -0.3,
|
||||
'retrain_interval': 604800,
|
||||
'min_packets_for_sample': 20,
|
||||
'model_file': '/var/lib/xdp-defense/ai_model.pkl',
|
||||
'training_data_file': '/var/lib/xdp-defense/training_data.csv',
|
||||
},
|
||||
}
|
||||
|
||||
CONFIG_PATH = '/etc/xdp-defense/config.yaml'
|
||||
|
||||
|
||||
def load_config(path=CONFIG_PATH):
|
||||
"""Load config with defaults."""
|
||||
cfg = DEFAULT_CONFIG.copy()
|
||||
try:
|
||||
with open(path) as f:
|
||||
user = yaml.safe_load(f) or {}
|
||||
for section in cfg:
|
||||
if section in user and isinstance(user[section], dict):
|
||||
cfg[section].update(user[section])
|
||||
except FileNotFoundError:
|
||||
log.warning("Config not found at %s, using defaults", path)
|
||||
except Exception as e:
|
||||
log.error("Failed to load config: %s", e)
|
||||
return cfg
|
||||
|
||||
|
||||
# ==================== ViolationTracker ====================
|
||||
|
||||
class ViolationTracker:
|
||||
"""Track per-IP violation counts and manage escalation."""
|
||||
|
||||
def __init__(self, escalation_cfg):
|
||||
self.cfg = escalation_cfg
|
||||
self.violations = defaultdict(list)
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def record_violation(self, ip):
|
||||
"""Record a violation and return escalation level.
|
||||
Returns: 'rate_limit', 'temp_block', 'perm_block', or None
|
||||
"""
|
||||
now = time.time()
|
||||
window = self.cfg.get('violation_window', 600)
|
||||
|
||||
with self.lock:
|
||||
self.violations[ip] = [t for t in self.violations[ip] if now - t < window]
|
||||
self.violations[ip].append(now)
|
||||
count = len(self.violations[ip])
|
||||
|
||||
perm_after = self.cfg.get('perm_block_after', 20)
|
||||
temp_after = self.cfg.get('temp_block_after', 5)
|
||||
|
||||
if count >= perm_after:
|
||||
return 'perm_block'
|
||||
elif count >= temp_after:
|
||||
return 'temp_block'
|
||||
return 'rate_limit'
|
||||
|
||||
def clear(self, ip):
|
||||
with self.lock:
|
||||
self.violations.pop(ip, None)
|
||||
|
||||
def cleanup_expired(self):
|
||||
"""Remove entries with no recent violations."""
|
||||
now = time.time()
|
||||
window = self.cfg.get('violation_window', 600)
|
||||
with self.lock:
|
||||
expired = [ip for ip, times in self.violations.items()
|
||||
if all(now - t >= window for t in times)]
|
||||
for ip in expired:
|
||||
del self.violations[ip]
|
||||
|
||||
|
||||
# ==================== EWMAAnalyzer ====================
|
||||
|
||||
class EWMAAnalyzer:
|
||||
"""Per-IP EWMA calculation for rate anomaly detection."""
|
||||
|
||||
def __init__(self, alpha=0.3, threshold_multiplier=3.0):
|
||||
self.alpha = alpha
|
||||
self.threshold_multiplier = threshold_multiplier
|
||||
self.ewma = {}
|
||||
self.baseline = {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def update(self, ip, current_pps):
|
||||
"""Update EWMA for an IP. Returns True if anomalous."""
|
||||
with self.lock:
|
||||
if ip not in self.ewma:
|
||||
self.ewma[ip] = current_pps
|
||||
self.baseline[ip] = current_pps
|
||||
return False
|
||||
|
||||
self.ewma[ip] = self.alpha * current_pps + (1 - self.alpha) * self.ewma[ip]
|
||||
self.baseline[ip] = 0.01 * current_pps + 0.99 * self.baseline[ip]
|
||||
|
||||
base = max(self.baseline[ip], 1)
|
||||
if self.ewma[ip] > base * self.threshold_multiplier:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_stats(self, ip):
|
||||
with self.lock:
|
||||
return {
|
||||
'ewma': self.ewma.get(ip, 0),
|
||||
'baseline': self.baseline.get(ip, 0),
|
||||
}
|
||||
|
||||
def cleanup_stale(self, active_ips):
|
||||
"""Remove tracking for IPs no longer in rate counters."""
|
||||
with self.lock:
|
||||
stale = set(self.ewma.keys()) - set(active_ips)
|
||||
for ip in stale:
|
||||
self.ewma.pop(ip, None)
|
||||
self.baseline.pop(ip, None)
|
||||
|
||||
|
||||
# ==================== AIDetector ====================
|
||||
|
||||
class AIDetector:
|
||||
"""Isolation Forest based anomaly detection on traffic features."""
|
||||
|
||||
def __init__(self, ai_cfg):
|
||||
self.cfg = ai_cfg
|
||||
self.model = None
|
||||
self.scaler = None
|
||||
self.started_at = time.time()
|
||||
self.training_data = []
|
||||
self.is_learning = True
|
||||
self._retrain_requested = False
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.cfg.get('enabled', False)
|
||||
|
||||
def request_retrain(self):
|
||||
self._retrain_requested = True
|
||||
|
||||
def collect_sample(self, features):
|
||||
"""Collect a feature sample during learning phase."""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
self.training_data.append(features)
|
||||
|
||||
learning_dur = self.cfg.get('learning_duration', 259200)
|
||||
min_samples = self.cfg.get('min_samples', 1000)
|
||||
elapsed = time.time() - self.started_at
|
||||
|
||||
if (elapsed >= learning_dur and len(self.training_data) >= min_samples) or self._retrain_requested:
|
||||
self._train()
|
||||
self._retrain_requested = False
|
||||
|
||||
def _train(self):
|
||||
"""Train the Isolation Forest model."""
|
||||
try:
|
||||
from sklearn.ensemble import IsolationForest
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
log.error("scikit-learn not installed. AI detection disabled.")
|
||||
self.cfg['enabled'] = False
|
||||
return
|
||||
|
||||
if len(self.training_data) < 10:
|
||||
log.warning("Not enough training data (%d samples)", len(self.training_data))
|
||||
return
|
||||
|
||||
log.info("Training AI model with %d samples...", len(self.training_data))
|
||||
|
||||
try:
|
||||
X = np.array(self.training_data)
|
||||
|
||||
self.scaler = StandardScaler()
|
||||
X_scaled = self.scaler.fit_transform(X)
|
||||
|
||||
self.model = IsolationForest(
|
||||
n_estimators=self.cfg.get('n_estimators', 100),
|
||||
contamination=self.cfg.get('contamination', 'auto'),
|
||||
random_state=42,
|
||||
)
|
||||
self.model.fit(X_scaled)
|
||||
self.is_learning = False
|
||||
|
||||
model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
|
||||
with open(model_file, 'wb') as f:
|
||||
pickle.dump({'model': self.model, 'scaler': self.scaler}, f)
|
||||
|
||||
data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
|
||||
with open(data_file, 'w', newline='') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow([
|
||||
'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
|
||||
'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
|
||||
'small_pkt_count', 'large_pkt_count',
|
||||
'syn_ratio', 'udp_ratio', 'icmp_ratio', 'small_pkt_ratio', 'avg_pkt_size'
|
||||
])
|
||||
writer.writerows(self.training_data)
|
||||
|
||||
log.info("AI model trained and saved to %s", model_file)
|
||||
|
||||
except Exception as e:
|
||||
log.error("AI training failed: %s", e)
|
||||
|
||||
def load_model(self):
|
||||
"""Load a previously trained model."""
|
||||
model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
|
||||
if not os.path.exists(model_file):
|
||||
return False
|
||||
try:
|
||||
with open(model_file, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
self.model = data['model']
|
||||
self.scaler = data['scaler']
|
||||
self.is_learning = False
|
||||
log.info("AI model loaded from %s", model_file)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error("Failed to load AI model: %s", e)
|
||||
return False
|
||||
|
||||
def predict(self, features):
|
||||
"""Run anomaly detection. Returns (is_anomaly, score)."""
|
||||
if not self.enabled or self.model is None:
|
||||
return False, 0.0
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
X = np.array([features])
|
||||
X_scaled = self.scaler.transform(X)
|
||||
score = self.model.decision_function(X_scaled)[0]
|
||||
threshold = self.cfg.get('anomaly_threshold', -0.3)
|
||||
return score < threshold, float(score)
|
||||
except Exception as e:
|
||||
log.error("AI prediction error: %s", e)
|
||||
return False, 0.0
|
||||
|
||||
|
||||
# ==================== ProfileManager ====================
|
||||
|
||||
class ProfileManager:
|
||||
"""Manage time-based rate limit profiles."""
|
||||
|
||||
def __init__(self, rate_cfg):
|
||||
self.cfg = rate_cfg
|
||||
self.current_profile = 'default'
|
||||
|
||||
def check_and_apply(self):
|
||||
"""Check current time and apply matching profile."""
|
||||
from xdp_common import write_rate_config
|
||||
|
||||
profiles = self.cfg.get('profiles', {})
|
||||
now = datetime.now()
|
||||
current_hour = now.hour
|
||||
current_min = now.minute
|
||||
current_time = current_hour * 60 + current_min
|
||||
weekday = now.strftime('%a').lower()
|
||||
|
||||
matched_profile = None
|
||||
matched_name = 'default'
|
||||
|
||||
for name, profile in profiles.items():
|
||||
hours = profile.get('hours', '')
|
||||
weekdays = profile.get('weekdays', '')
|
||||
|
||||
if weekdays:
|
||||
day_range = weekdays.lower().split('-')
|
||||
day_names = ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun']
|
||||
if len(day_range) == 2:
|
||||
try:
|
||||
start_idx = day_names.index(day_range[0])
|
||||
end_idx = day_names.index(day_range[1])
|
||||
current_idx = day_names.index(weekday)
|
||||
if start_idx <= end_idx:
|
||||
if not (start_idx <= current_idx <= end_idx):
|
||||
continue
|
||||
else:
|
||||
if not (current_idx >= start_idx or current_idx <= end_idx):
|
||||
continue
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if hours:
|
||||
try:
|
||||
start_str, end_str = hours.split('-')
|
||||
sh, sm = map(int, start_str.split(':'))
|
||||
eh, em = map(int, end_str.split(':'))
|
||||
start_min = sh * 60 + sm
|
||||
end_min = eh * 60 + em
|
||||
|
||||
if start_min <= end_min:
|
||||
if not (start_min <= current_time < end_min):
|
||||
continue
|
||||
else:
|
||||
if not (current_time >= start_min or current_time < end_min):
|
||||
continue
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
|
||||
matched_profile = profile
|
||||
matched_name = name
|
||||
break
|
||||
|
||||
if matched_name != self.current_profile:
|
||||
if matched_profile:
|
||||
pps = matched_profile.get('pps', self.cfg.get('default_pps', 1000))
|
||||
bps = matched_profile.get('bps', self.cfg.get('default_bps', 0))
|
||||
else:
|
||||
pps = self.cfg.get('default_pps', 1000)
|
||||
bps = self.cfg.get('default_bps', 0)
|
||||
|
||||
window = self.cfg.get('window_sec', 1)
|
||||
|
||||
try:
|
||||
write_rate_config(pps, bps, window * 1_000_000_000)
|
||||
log.info("Profile switched: %s -> %s (pps=%d)", self.current_profile, matched_name, pps)
|
||||
self.current_profile = matched_name
|
||||
except Exception as e:
|
||||
log.error("Failed to apply profile %s: %s", matched_name, e)
|
||||
|
||||
|
||||
# ==================== DDoSDaemon ====================
|
||||
|
||||
class DDoSDaemon:
|
||||
"""Main daemon orchestrator."""
|
||||
|
||||
def __init__(self, config_path=CONFIG_PATH):
|
||||
self.config_path = config_path
|
||||
self.cfg = load_config(config_path)
|
||||
self.running = False
|
||||
self._stop_event = threading.Event()
|
||||
self._setup_components()
|
||||
|
||||
def _setup_components(self):
|
||||
self.violation_tracker = ViolationTracker(self.cfg['escalation'])
|
||||
self.ewma_analyzer = EWMAAnalyzer(
|
||||
alpha=self.cfg['ewma'].get('alpha', 0.3),
|
||||
threshold_multiplier=self.cfg['ewma'].get('threshold_multiplier', 3.0),
|
||||
)
|
||||
self.ai_detector = AIDetector(self.cfg['ai'])
|
||||
self.profile_manager = ProfileManager(self.cfg['rate_limits'])
|
||||
|
||||
if self.ai_detector.enabled:
|
||||
self.ai_detector.load_model()
|
||||
|
||||
level = self.cfg['general'].get('log_level', 'info').upper()
|
||||
log.setLevel(getattr(logging, level, logging.INFO))
|
||||
|
||||
def _write_pid(self):
|
||||
pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
|
||||
os.makedirs(os.path.dirname(pid_file), exist_ok=True)
|
||||
with open(pid_file, 'w') as f:
|
||||
f.write(str(os.getpid()))
|
||||
|
||||
def _remove_pid(self):
|
||||
pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
|
||||
try:
|
||||
os.unlink(pid_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _ensure_single_instance(self):
|
||||
"""Stop any existing daemon before starting."""
|
||||
pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
|
||||
if not os.path.exists(pid_file):
|
||||
return
|
||||
try:
|
||||
with open(pid_file) as f:
|
||||
old_pid = int(f.read().strip())
|
||||
os.kill(old_pid, 0)
|
||||
log.info("Stopping existing daemon (PID %d)...", old_pid)
|
||||
os.kill(old_pid, signal.SIGTERM)
|
||||
for _ in range(30):
|
||||
time.sleep(1)
|
||||
try:
|
||||
os.kill(old_pid, 0)
|
||||
except OSError:
|
||||
log.info("Old daemon stopped")
|
||||
return
|
||||
log.warning("Daemon PID %d did not stop, sending SIGKILL", old_pid)
|
||||
os.kill(old_pid, signal.SIGKILL)
|
||||
time.sleep(1)
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
|
||||
def _handle_sighup(self, signum, frame):
|
||||
log.info("SIGHUP received, reloading config...")
|
||||
self.cfg = load_config(self.config_path)
|
||||
self._setup_components()
|
||||
log.info("Config reloaded")
|
||||
|
||||
def _handle_sigterm(self, signum, frame):
|
||||
log.info("SIGTERM received, shutting down...")
|
||||
self.running = False
|
||||
self._stop_event.set()
|
||||
|
||||
def _handle_sigusr1(self, signum, frame):
|
||||
log.info("SIGUSR1 received, requesting AI retrain...")
|
||||
self.ai_detector.request_retrain()
|
||||
|
||||
# ---- Worker Threads ----
|
||||
|
||||
def _ewma_thread(self):
|
||||
"""Poll rate counters, compute EWMA, detect violations, escalate."""
|
||||
from xdp_common import dump_rate_counters, block_ip
|
||||
|
||||
interval = self.cfg['ewma'].get('poll_interval', 1)
|
||||
prev_counters = {}
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
entries = dump_rate_counters('rate_counter_v4', top_n=1000)
|
||||
active_ips = []
|
||||
|
||||
for ip_str, pkts, bts, last_seen in entries:
|
||||
active_ips.append(ip_str)
|
||||
|
||||
prev = prev_counters.get(ip_str, 0)
|
||||
delta = pkts - prev if pkts >= prev else pkts
|
||||
prev_counters[ip_str] = pkts
|
||||
|
||||
if delta <= 0:
|
||||
continue
|
||||
|
||||
pps = delta / max(interval, 0.1)
|
||||
|
||||
is_anomalous = self.ewma_analyzer.update(ip_str, pps)
|
||||
if is_anomalous:
|
||||
level = self.violation_tracker.record_violation(ip_str)
|
||||
ew = self.ewma_analyzer.get_stats(ip_str)
|
||||
log.warning(
|
||||
"EWMA anomaly: %s pps=%.1f ewma=%.1f baseline=%.1f -> %s",
|
||||
ip_str, pps, ew['ewma'], ew['baseline'], level
|
||||
)
|
||||
|
||||
if level == 'temp_block':
|
||||
dur = self.cfg['escalation'].get('temp_block_duration', 300)
|
||||
try:
|
||||
block_ip(ip_str, dur)
|
||||
log.warning("TEMP BLOCK: %s for %ds", ip_str, dur)
|
||||
except Exception as e:
|
||||
log.error("Failed to temp-block %s: %s", ip_str, e)
|
||||
|
||||
elif level == 'perm_block':
|
||||
try:
|
||||
block_ip(ip_str, 0)
|
||||
log.warning("PERM BLOCK: %s", ip_str)
|
||||
except Exception as e:
|
||||
log.error("Failed to perm-block %s: %s", ip_str, e)
|
||||
|
||||
self.ewma_analyzer.cleanup_stale(active_ips)
|
||||
|
||||
except Exception as e:
|
||||
log.error("EWMA thread error: %s", e)
|
||||
|
||||
self._stop_event.wait(interval)
|
||||
|
||||
def _ai_thread(self):
|
||||
"""Read traffic features, run AI inference or collect training data."""
|
||||
from xdp_common import read_percpu_features, dump_rate_counters, block_ip
|
||||
|
||||
interval = self.cfg['ai'].get('poll_interval', 5)
|
||||
prev_features = None
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
if not self.ai_detector.enabled:
|
||||
self._stop_event.wait(interval)
|
||||
continue
|
||||
|
||||
features = read_percpu_features()
|
||||
if not features:
|
||||
self._stop_event.wait(interval)
|
||||
continue
|
||||
|
||||
feature_names = [
|
||||
'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
|
||||
'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
|
||||
'small_pkt_count', 'large_pkt_count'
|
||||
]
|
||||
|
||||
if prev_features is not None:
|
||||
deltas = []
|
||||
for name in feature_names:
|
||||
cur = features.get(name, 0)
|
||||
prev = prev_features.get(name, 0)
|
||||
deltas.append(max(0, cur - prev))
|
||||
|
||||
min_pkts = self.cfg['ai'].get('min_packets_for_sample', 20)
|
||||
if deltas[0] < min_pkts:
|
||||
prev_features = features
|
||||
self._stop_event.wait(interval)
|
||||
continue
|
||||
|
||||
total = deltas[0] + 1e-6
|
||||
syn_ratio = deltas[2] / total
|
||||
udp_ratio = deltas[4] / total
|
||||
icmp_ratio = deltas[5] / total
|
||||
small_pkt_ratio = deltas[8] / total
|
||||
avg_pkt_size = deltas[1] / total
|
||||
deltas.extend([syn_ratio, udp_ratio, icmp_ratio, small_pkt_ratio, avg_pkt_size])
|
||||
|
||||
if self.ai_detector.is_learning:
|
||||
self.ai_detector.collect_sample(deltas)
|
||||
if len(self.ai_detector.training_data) % 100 == 0:
|
||||
log.debug("AI learning: %d samples collected",
|
||||
len(self.ai_detector.training_data))
|
||||
else:
|
||||
is_anomaly, score = self.ai_detector.predict(deltas)
|
||||
if is_anomaly:
|
||||
log.warning(
|
||||
"AI ANOMALY detected: score=%.4f deltas=%s",
|
||||
score, dict(zip(feature_names, deltas[:len(feature_names)]))
|
||||
)
|
||||
top_ips = dump_rate_counters('rate_counter_v4', top_n=5)
|
||||
for ip_str, pkts, bts, _ in top_ips:
|
||||
level = self.violation_tracker.record_violation(ip_str)
|
||||
log.warning("AI escalation: %s -> %s", ip_str, level)
|
||||
|
||||
if level == 'temp_block':
|
||||
dur = self.cfg['escalation'].get('temp_block_duration', 300)
|
||||
try:
|
||||
block_ip(ip_str, dur)
|
||||
log.warning("AI TEMP BLOCK: %s for %ds", ip_str, dur)
|
||||
except Exception as e:
|
||||
log.error("Failed to AI temp-block %s: %s", ip_str, e)
|
||||
|
||||
elif level == 'perm_block':
|
||||
try:
|
||||
block_ip(ip_str, 0)
|
||||
log.warning("AI PERM BLOCK: %s", ip_str)
|
||||
except Exception as e:
|
||||
log.error("Failed to AI perm-block %s: %s", ip_str, e)
|
||||
|
||||
prev_features = features
|
||||
|
||||
except Exception as e:
|
||||
log.error("AI thread error: %s", e)
|
||||
|
||||
self._stop_event.wait(interval)
|
||||
|
||||
def _profile_thread(self):
|
||||
"""Check time-of-day and switch rate profiles."""
|
||||
while self.running:
|
||||
try:
|
||||
self.profile_manager.check_and_apply()
|
||||
except Exception as e:
|
||||
log.error("Profile thread error: %s", e)
|
||||
self._stop_event.wait(60)
|
||||
|
||||
def _cleanup_thread(self):
|
||||
"""Periodically clean up expired blocked IPs and stale violations."""
|
||||
from xdp_common import dump_blocked_ips, unblock_ip
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
with open('/proc/uptime') as f:
|
||||
now_ns = int(float(f.read().split()[0]) * 1_000_000_000)
|
||||
|
||||
for map_name in ['blocked_ips_v4', 'blocked_ips_v6']:
|
||||
entries = dump_blocked_ips(map_name)
|
||||
for ip_str, expire_ns, blocked_at, drop_count in entries:
|
||||
if expire_ns != 0 and now_ns > expire_ns:
|
||||
try:
|
||||
unblock_ip(ip_str)
|
||||
self.violation_tracker.clear(ip_str)
|
||||
log.info("Expired block removed: %s (dropped %d pkts)", ip_str, drop_count)
|
||||
except Exception as e:
|
||||
log.error("Failed to remove expired block %s: %s", ip_str, e)
|
||||
|
||||
self.violation_tracker.cleanup_expired()
|
||||
|
||||
except Exception as e:
|
||||
log.error("Cleanup thread error: %s", e)
|
||||
|
||||
self._stop_event.wait(60)
|
||||
|
||||
# ---- Main Loop ----
|
||||
|
||||
def run(self):
|
||||
"""Start the daemon."""
|
||||
log.info("XDP Defense Daemon starting...")
|
||||
|
||||
signal.signal(signal.SIGHUP, self._handle_sighup)
|
||||
signal.signal(signal.SIGTERM, self._handle_sigterm)
|
||||
signal.signal(signal.SIGINT, self._handle_sigterm)
|
||||
signal.signal(signal.SIGUSR1, self._handle_sigusr1)
|
||||
|
||||
self._ensure_single_instance()
|
||||
self._write_pid()
|
||||
self.running = True
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=self._ewma_thread, name='ewma', daemon=True),
|
||||
threading.Thread(target=self._ai_thread, name='ai', daemon=True),
|
||||
threading.Thread(target=self._profile_thread, name='profile', daemon=True),
|
||||
threading.Thread(target=self._cleanup_thread, name='cleanup', daemon=True),
|
||||
]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
log.info("Started %s thread", t.name)
|
||||
|
||||
log.info("Daemon running (PID %d)", os.getpid())
|
||||
|
||||
try:
|
||||
while self.running:
|
||||
self._stop_event.wait(1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
log.info("Shutting down...")
|
||||
self.running = False
|
||||
self._stop_event.set()
|
||||
|
||||
for t in threads:
|
||||
t.join(timeout=5)
|
||||
|
||||
self._remove_pid()
|
||||
log.info("Daemon stopped")
|
||||
|
||||
|
||||
# ==================== Entry Point ====================
|
||||
|
||||
def main():
|
||||
config_path = CONFIG_PATH
|
||||
if len(sys.argv) > 1:
|
||||
config_path = sys.argv[1]
|
||||
|
||||
daemon = DDoSDaemon(config_path)
|
||||
daemon.run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
256
lib/xdp_whitelist.py
Executable file
256
lib/xdp_whitelist.py
Executable file
@@ -0,0 +1,256 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
XDP Whitelist Manager - Fast batch IP whitelisting
|
||||
Supports presets like Cloudflare, AWS, Google, etc.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
from xdp_common import get_map_id, batch_map_operation, classify_cidrs
|
||||
|
||||
WHITELIST_DIR = Path("/etc/xdp-blocker/whitelist")
|
||||
|
||||
# Preset URLs for trusted services
|
||||
PRESETS = {
|
||||
"cloudflare": {
|
||||
"v4": "https://www.cloudflare.com/ips-v4",
|
||||
"v6": "https://www.cloudflare.com/ips-v6",
|
||||
"desc": "Cloudflare CDN/Proxy"
|
||||
},
|
||||
"aws": {
|
||||
"v4": "https://ip-ranges.amazonaws.com/ip-ranges.json",
|
||||
"desc": "Amazon Web Services (all regions)"
|
||||
},
|
||||
"google": {
|
||||
"v4": "https://www.gstatic.com/ipranges/cloud.json",
|
||||
"desc": "Google Cloud Platform"
|
||||
},
|
||||
"github": {
|
||||
"v4": "https://api.github.com/meta",
|
||||
"desc": "GitHub Services"
|
||||
}
|
||||
}
|
||||
|
||||
def download_cloudflare():
|
||||
"""Download Cloudflare IP ranges"""
|
||||
cidrs = []
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
PRESETS["cloudflare"]["v4"],
|
||||
headers={"User-Agent": "xdp-whitelist/1.0"}
|
||||
)
|
||||
with urllib.request.urlopen(req) as r:
|
||||
cidrs.extend(r.read().decode().strip().split('\n'))
|
||||
print(f" Downloaded {len(cidrs)} IPv4 ranges")
|
||||
except Exception as e:
|
||||
print(f" [WARN] Failed to download IPv4: {e}")
|
||||
return cidrs
|
||||
|
||||
def download_aws():
|
||||
"""Download AWS IP ranges"""
|
||||
cidrs = []
|
||||
try:
|
||||
with urllib.request.urlopen(PRESETS["aws"]["v4"]) as r:
|
||||
data = json.loads(r.read().decode())
|
||||
for prefix in data.get("prefixes", []):
|
||||
cidrs.append(prefix["ip_prefix"])
|
||||
print(f" Downloaded {len(cidrs)} IPv4 ranges")
|
||||
except Exception as e:
|
||||
print(f" [WARN] Failed to download: {e}")
|
||||
return cidrs
|
||||
|
||||
def download_google():
|
||||
"""Download Google Cloud IP ranges"""
|
||||
cidrs = []
|
||||
try:
|
||||
with urllib.request.urlopen(PRESETS["google"]["v4"]) as r:
|
||||
data = json.loads(r.read().decode())
|
||||
for prefix in data.get("prefixes", []):
|
||||
if "ipv4Prefix" in prefix:
|
||||
cidrs.append(prefix["ipv4Prefix"])
|
||||
print(f" Downloaded {len(cidrs)} IPv4 ranges")
|
||||
except Exception as e:
|
||||
print(f" [WARN] Failed to download: {e}")
|
||||
return cidrs
|
||||
|
||||
def download_github():
|
||||
"""Download GitHub IP ranges"""
|
||||
cidrs = []
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
PRESETS["github"]["v4"],
|
||||
headers={"User-Agent": "xdp-whitelist"}
|
||||
)
|
||||
with urllib.request.urlopen(req) as r:
|
||||
data = json.loads(r.read().decode())
|
||||
for key in ["hooks", "web", "api", "git", "packages", "pages", "importer", "actions", "dependabot"]:
|
||||
if key in data:
|
||||
cidrs.extend(data[key])
|
||||
cidrs = list(set(c for c in cidrs if ':' not in c))
|
||||
print(f" Downloaded {len(cidrs)} IPv4 ranges")
|
||||
except Exception as e:
|
||||
print(f" [WARN] Failed to download: {e}")
|
||||
return cidrs
|
||||
|
||||
def add_whitelist(name, cidrs=None):
|
||||
"""Add IPs to whitelist"""
|
||||
name = name.lower()
|
||||
WHITELIST_DIR.mkdir(parents=True, exist_ok=True)
|
||||
wl_file = WHITELIST_DIR / f"{name}.txt"
|
||||
|
||||
if cidrs is None and wl_file.exists():
|
||||
with open(wl_file) as f:
|
||||
cidrs = [line.strip() for line in f if line.strip() and ':' not in line]
|
||||
if cidrs:
|
||||
print(f"[INFO] Using cached {name} ({len(cidrs)} CIDRs)")
|
||||
|
||||
if cidrs is None:
|
||||
if name == "cloudflare":
|
||||
print(f"[INFO] Downloading Cloudflare IP ranges...")
|
||||
cidrs = download_cloudflare()
|
||||
elif name == "aws":
|
||||
print(f"[INFO] Downloading AWS IP ranges...")
|
||||
cidrs = download_aws()
|
||||
elif name == "google":
|
||||
print(f"[INFO] Downloading Google Cloud IP ranges...")
|
||||
cidrs = download_google()
|
||||
elif name == "github":
|
||||
print(f"[INFO] Downloading GitHub IP ranges...")
|
||||
cidrs = download_github()
|
||||
else:
|
||||
print(f"[ERROR] Unknown preset: {name}")
|
||||
print(f"Available presets: {', '.join(PRESETS.keys())}")
|
||||
return False
|
||||
|
||||
if not cidrs:
|
||||
print("[ERROR] No CIDRs to add")
|
||||
return False
|
||||
|
||||
with open(wl_file, 'w') as f:
|
||||
f.write('\n'.join(cidrs))
|
||||
|
||||
map_id = get_map_id("whitelist_v4")
|
||||
if not map_id:
|
||||
print("[ERROR] whitelist_v4 map not found. Is XDP loaded?")
|
||||
return False
|
||||
|
||||
v4_cidrs, v6_cidrs = classify_cidrs(cidrs)
|
||||
|
||||
if v4_cidrs:
|
||||
print(f"[INFO] Adding {len(v4_cidrs)} IPv4 CIDRs to whitelist...")
|
||||
batch_map_operation(map_id, v4_cidrs, operation="update", batch_size=500)
|
||||
|
||||
if v6_cidrs:
|
||||
map_id_v6 = get_map_id("whitelist_v6")
|
||||
if map_id_v6:
|
||||
print(f"[INFO] Adding {len(v6_cidrs)} IPv6 CIDRs to whitelist...")
|
||||
batch_map_operation(map_id_v6, v6_cidrs, operation="update", batch_size=500, ipv6=True)
|
||||
|
||||
print(f"[OK] Whitelisted {name}: {len(v4_cidrs)} v4 + {len(v6_cidrs)} v6 CIDRs")
|
||||
return True
|
||||
|
||||
def del_whitelist(name):
|
||||
"""Remove IPs from whitelist"""
|
||||
name = name.lower()
|
||||
wl_file = WHITELIST_DIR / f"{name}.txt"
|
||||
|
||||
if not wl_file.exists():
|
||||
print(f"[ERROR] {name} is not whitelisted")
|
||||
return False
|
||||
|
||||
map_id = get_map_id("whitelist_v4")
|
||||
if not map_id:
|
||||
print("[ERROR] whitelist_v4 map not found")
|
||||
return False
|
||||
|
||||
with open(wl_file) as f:
|
||||
cidrs = [line.strip() for line in f if line.strip()]
|
||||
|
||||
v4_cidrs, v6_cidrs = classify_cidrs(cidrs)
|
||||
|
||||
if v4_cidrs:
|
||||
print(f"[INFO] Removing {len(v4_cidrs)} IPv4 CIDRs from whitelist...")
|
||||
batch_map_operation(map_id, v4_cidrs, operation="delete", batch_size=500)
|
||||
|
||||
if v6_cidrs:
|
||||
map_id_v6 = get_map_id("whitelist_v6")
|
||||
if map_id_v6:
|
||||
print(f"[INFO] Removing {len(v6_cidrs)} IPv6 CIDRs from whitelist...")
|
||||
batch_map_operation(map_id_v6, v6_cidrs, operation="delete", batch_size=500, ipv6=True)
|
||||
|
||||
wl_file.unlink()
|
||||
print(f"[OK] Removed {name} from whitelist")
|
||||
return True
|
||||
|
||||
def list_whitelist():
|
||||
"""List whitelisted services"""
|
||||
print("=== Whitelisted Services ===")
|
||||
|
||||
if not WHITELIST_DIR.exists():
|
||||
print(" (none)")
|
||||
return
|
||||
|
||||
files = list(WHITELIST_DIR.glob("*.txt"))
|
||||
if not files:
|
||||
print(" (none)")
|
||||
return
|
||||
|
||||
for wl_file in sorted(files):
|
||||
name = wl_file.stem
|
||||
count = sum(1 for line in open(wl_file) if line.strip() and ':' not in line)
|
||||
desc = PRESETS.get(name, {}).get("desc", "Custom")
|
||||
print(f" {name}: {count} CIDRs ({desc})")
|
||||
|
||||
def show_presets():
|
||||
"""Show available presets"""
|
||||
print("=== Available Presets ===")
|
||||
for name, info in PRESETS.items():
|
||||
print(f" {name}: {info['desc']}")
|
||||
|
||||
def show_help():
|
||||
print("""XDP Whitelist Manager - Fast batch IP whitelisting
|
||||
|
||||
Usage: xdp-whitelist <command> [args]
|
||||
|
||||
Commands:
|
||||
add <preset|file> Whitelist a preset or custom file
|
||||
del <name> Remove from whitelist
|
||||
list List whitelisted services
|
||||
presets Show available presets
|
||||
|
||||
Presets:
|
||||
cloudflare Cloudflare CDN/Proxy IPs
|
||||
aws Amazon Web Services
|
||||
google Google Cloud Platform
|
||||
github GitHub Services
|
||||
|
||||
Examples:
|
||||
xdp-whitelist add cloudflare # Whitelist Cloudflare
|
||||
xdp-whitelist add aws # Whitelist AWS
|
||||
xdp-whitelist del cloudflare # Remove Cloudflare
|
||||
xdp-whitelist list
|
||||
""")
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
show_help()
|
||||
return
|
||||
|
||||
cmd = sys.argv[1]
|
||||
|
||||
if cmd == "add" and len(sys.argv) >= 3:
|
||||
add_whitelist(sys.argv[2])
|
||||
elif cmd == "del" and len(sys.argv) >= 3:
|
||||
del_whitelist(sys.argv[2])
|
||||
elif cmd == "list":
|
||||
list_whitelist()
|
||||
elif cmd == "presets":
|
||||
show_presets()
|
||||
else:
|
||||
show_help()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user