Unify xdp-blocker and xdp-ddos into single xdp-defense project

Chain two XDP programs via libxdp dispatcher on the same interface:
xdp_blocker (priority 10) handles CIDR/country/whitelist blocking,
xdp_ddos (priority 20) handles rate limiting, EWMA analysis, and AI
anomaly detection. Whitelist maps are shared via BPF map pinning so
whitelisted IPs bypass both blocklist checks and DDoS rate limiting.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
kaffa
2026-02-07 08:39:21 +09:00
commit 1bcaddce25
12 changed files with 3523 additions and 0 deletions

542
lib/xdp_common.py Normal file
View File

@@ -0,0 +1,542 @@
#!/usr/bin/env python3
"""
XDP Defense - Common Utilities
Merged from xdp-blocker/xdp_common.py and xdp-ddos/xdp_ddos_common.py
Provides: map management, CIDR handling, IP encoding, rate config, block/unblock, stats
"""
import subprocess
import json
import os
import tempfile
import struct
import socket
import ipaddress
import logging
import logging.handlers
# ==================== Logging ====================
# Best-effort syslog audit logging: /dev/log may be absent (containers,
# non-systemd hosts), in which case audit messages are silently dropped.
_syslog_handler = None
try:
    _syslog_handler = logging.handlers.SysLogHandler(address='/dev/log')
    _syslog_handler.setFormatter(logging.Formatter('xdp-defense: %(message)s'))
except Exception:
    pass
# Module-wide audit logger used by block/unblock and batch map operations.
audit_log = logging.getLogger('xdp-defense')
audit_log.setLevel(logging.INFO)
if _syslog_handler:
    audit_log.addHandler(_syslog_handler)
# ==================== BPF Map Helpers ====================
def get_map_id(map_name):
    """Return the BPF map ID for *map_name*, or None if not found or on error."""
    proc = subprocess.run(
        ["bpftool", "map", "show", "-j"],
        capture_output=True, text=True
    )
    if proc.returncode != 0:
        return None
    try:
        for entry in json.loads(proc.stdout):
            if entry.get("name") == map_name:
                return entry.get("id")
    except Exception:
        pass
    return None
# ==================== CIDR / IPv4 Helpers (from blocker) ====================
def validate_cidr(cidr):
    """Validate IPv4 CIDR notation.

    Accepts "a.b.c.d" (treated as /32) or "a.b.c.d/prefix".

    Args:
        cidr: IPv4 address or CIDR string.

    Returns:
        Tuple (ip, prefix, parts): the IP string, integer prefix length,
        and the list of four octet integers.

    Raises:
        ValueError: on any malformed input (non-numeric prefix or octets,
            out-of-range values, wrong octet count).
    """
    if '/' in cidr:
        ip, prefix_str = cidr.split('/', 1)
        try:
            prefix = int(prefix_str)
        except ValueError:
            # int() on e.g. "/abc" previously escaped with a cryptic message.
            raise ValueError(f"Invalid prefix: /{prefix_str}") from None
    else:
        ip = cidr
        prefix = 32
    if prefix < 0 or prefix > 32:
        raise ValueError(f"Invalid prefix: /{prefix}")
    try:
        parts = [int(x) for x in ip.split('.')]
    except ValueError:
        # int() on e.g. "a.b.c.d" previously escaped with a cryptic message.
        raise ValueError(f"Invalid IP: {ip}") from None
    if len(parts) != 4 or any(p < 0 or p > 255 for p in parts):
        raise ValueError(f"Invalid IP: {ip}")
    return ip, prefix, parts
def cidr_to_key(cidr):
    """Build the LPM trie key hex string for an IPv4 CIDR.

    Layout: prefix length as 4-byte little-endian, then the 4 address octets.
    """
    _, prefix, octets = validate_cidr(cidr)
    fields = [f"{prefix:02x}", "00", "00", "00"] + [f"{o:02x}" for o in octets]
    return ' '.join(fields)
def is_ipv6(cidr):
    """Return True when the CIDR string looks like IPv6 (contains a colon)."""
    return cidr.find(':') != -1
def validate_cidr_v6(cidr):
    """Validate IPv6 CIDR notation.

    Args:
        cidr: IPv6 address or CIDR string.

    Returns:
        Tuple (IPv6Network, prefix length).

    Raises:
        ValueError: if the string is not a valid IPv6 network.
    """
    try:
        network = ipaddress.IPv6Network(cidr, strict=False)
    except (ipaddress.AddressValueError, ValueError) as exc:
        raise ValueError(f"Invalid IPv6 CIDR: {cidr}") from exc
    return network, network.prefixlen
def cidr_to_key_v6(cidr):
    """Build the LPM trie key hex string for an IPv6 CIDR.

    Layout: prefix length as 4-byte little-endian, then the 16 address bytes.
    """
    network, plen = validate_cidr_v6(cidr)
    parts = [f"{plen:02x}", "00", "00", "00"]
    parts.extend(f"{b:02x}" for b in network.network_address.packed)
    return ' '.join(parts)
def classify_cidrs(cidrs):
    """Partition *cidrs* into (v4_list, v6_list), preserving order."""
    v4, v6 = [], []
    for cidr in cidrs:
        (v6 if is_ipv6(cidr) else v4).append(cidr)
    return v4, v6
def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 00 00 00 00 00", batch_size=1000, ipv6=False):
    """Execute batch map operations with progress reporting.

    Malformed CIDRs are skipped; bpftool batch failures are logged to the
    audit log but do not abort the run (best-effort semantics).

    Args:
        map_id: BPF map ID.
        cidrs: list of CIDR strings.
        operation: "update" or "delete".
        value_hex: hex value for update operations.
        batch_size: number of operations per batch file.
        ipv6: encode keys for the IPv6 LPM trie instead of IPv4.

    Returns:
        Number of processed entries (including skipped malformed ones).
    """
    total = len(cidrs)
    processed = 0
    for i in range(0, total, batch_size):
        batch = cidrs[i:i + batch_size]
        batch_file = None
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.batch', delete=False) as f:
                batch_file = f.name
                for cidr in batch:
                    try:
                        key_hex = cidr_to_key_v6(cidr) if ipv6 else cidr_to_key(cidr)
                    except ValueError:
                        # Skip malformed entries only; the previous
                        # `except (ValueError, Exception)` swallowed every
                        # error, hiding real bugs.
                        continue
                    if operation == "update":
                        f.write(f"map update id {map_id} key hex {key_hex} value hex {value_hex}\n")
                    else:
                        f.write(f"map delete id {map_id} key hex {key_hex}\n")
            result = subprocess.run(
                ["bpftool", "batch", "file", batch_file],
                capture_output=True, text=True
            )
            if result.returncode != 0:
                # Surface the failure instead of silently dropping it.
                audit_log.warning(f"bpftool batch failed on map {map_id}: {result.stderr.strip()}")
        finally:
            if batch_file and os.path.exists(batch_file):
                os.unlink(batch_file)
        processed += len(batch)
        pct = processed * 100 // total if total > 0 else 100
        print(f"\r Progress: {processed}/{total} ({pct}%)", end="", flush=True)
    print()
    proto = "v6" if ipv6 else "v4"
    audit_log.info(f"batch {operation} {proto}: {processed} entries on map {map_id}")
    return processed
# ==================== IP Encoding Helpers (from ddos) ====================
def ip_to_hex_key(ip_str):
    """Convert an IP address to a hex key string for LRU_HASH maps.

    IPv4 encodes to 4 bytes, IPv6 to 16 bytes (network byte order).

    Args:
        ip_str: IPv4 or IPv6 address string.

    Returns:
        Space-separated lowercase hex byte string.

    Raises:
        ValueError: if *ip_str* is not a valid IP address.
    """
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError:
        # Re-raise with a clearer message; drop the redundant chained context.
        raise ValueError(f"Invalid IP address: {ip_str}") from None
    return ' '.join(f"{b:02x}" for b in addr.packed)
def hex_key_to_ip(hex_bytes, version=4):
    """Reassemble an IP string from a list of hex byte strings.

    Args:
        hex_bytes: e.g. ['0a', '00', '00', '01'].
        version: 4 or 6, selecting the address family.

    Returns:
        Canonical IP address string.
    """
    raw = bytes(int(part, 16) for part in hex_bytes)
    cls = ipaddress.IPv6Address if version == 6 else ipaddress.IPv4Address
    return str(cls(raw))
# ==================== DDoS Stats / Features (from ddos) ====================
def read_percpu_stats(map_name="global_stats", num_entries=5):
    """Read a PERCPU_ARRAY map and return per-key values summed across CPUs.

    Args:
        map_name: BPF map name to dump.
        num_entries: currently unused; kept for interface compatibility.
            TODO(review): either honor it or drop it at the next API break.

    Returns:
        Dict {key_index: summed_value}; empty dict on any bpftool/JSON error.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return {}
    result = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if result.returncode != 0:
        return {}
    try:
        data = json.loads(result.stdout)
    except Exception:
        return {}
    stats = {}
    for entry in data:
        # bpftool JSON may nest typed output under "formatted".
        fmt = entry.get("formatted", entry)
        key = fmt.get("key")
        if isinstance(key, list):
            # Key rendered as hex byte strings; the first byte suffices for
            # small array indices (assumes < 256 entries — TODO confirm).
            key = int(key[0], 16) if isinstance(key[0], str) else key[0]
        elif isinstance(key, str):
            key = int(key, 16)
        values = fmt.get("values", [])
        total = 0
        for v in values:
            # Per-CPU value may appear under "value" or "val" depending on
            # bpftool version/output shape.
            val = v.get("value", v.get("val", 0))
            if isinstance(val, str):
                val = int(val, 0)
            elif isinstance(val, list):
                # Raw per-CPU bytes: interpret as little-endian integer.
                val = int.from_bytes(
                    bytes(int(x, 16) for x in val), byteorder='little'
                )
            total += val
        stats[key] = total
    return stats
def read_percpu_features():
    """Read the traffic-features PERCPU_ARRAY and return an aggregated struct.

    Returns:
        Dict mapping feature field names to values summed across all CPUs;
        empty dict on any bpftool/JSON error.
    """
    # NOTE(review): map is looked up as "traffic_feature" while the docstring
    # said "traffic_features" — presumably the kernel truncates BPF object
    # names to 15 chars; confirm against the loaded BPF object.
    map_id = get_map_id("traffic_feature")
    if map_id is None:
        return {}
    result = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if result.returncode != 0:
        return {}
    try:
        data = json.loads(result.stdout)
    except Exception:
        return {}
    # Field order must match the C struct layout in the XDP program.
    field_names = [
        "total_packets", "total_bytes", "tcp_syn_count", "tcp_other_count",
        "udp_count", "icmp_count", "other_proto_count", "unique_ips_approx",
        "small_pkt_count", "large_pkt_count"
    ]
    num_fields = len(field_names)
    aggregated = {f: 0 for f in field_names}
    for entry in data:
        fmt = entry.get("formatted", entry)
        values = fmt.get("values", [])
        for cpu_val in values:
            val = cpu_val.get("value", [])
            if isinstance(val, list) and len(val) >= num_fields * 8:
                # Raw bytes: each field is a little-endian u64.
                raw = bytes(int(x, 16) for x in val[:num_fields * 8])
                for i, name in enumerate(field_names):
                    v = struct.unpack_from('<Q', raw, i * 8)[0]
                    aggregated[name] += v
            elif isinstance(val, dict):
                # bpftool already decoded the struct via BTF.
                for name in field_names:
                    aggregated[name] += val.get(name, 0)
    return aggregated
# ==================== Rate Config (from ddos) ====================
def read_rate_config():
    """Read the current rate_config struct from its BPF map.

    Returns:
        Dict {pps_threshold, bps_threshold, window_ns}, or None if the map
        is missing or the lookup/parse fails.
    """
    map_id = get_map_id("rate_config")
    if map_id is None:
        return None
    # Single-entry ARRAY map: key 0 as four zero bytes.
    result = subprocess.run(
        ["bpftool", "map", "lookup", "id", str(map_id), "key", "0", "0", "0", "0", "-j"],
        capture_output=True, text=True
    )
    if result.returncode != 0:
        return None
    try:
        data = json.loads(result.stdout)
        fmt = data.get("formatted", data)
        val = fmt.get("value", {})
        if isinstance(val, dict):
            # BTF-decoded struct: fields already named.
            return {
                "pps_threshold": val.get("pps_threshold", 0),
                "bps_threshold": val.get("bps_threshold", 0),
                "window_ns": val.get("window_ns", 0),
            }
        elif isinstance(val, list):
            # Raw bytes: three little-endian u64s.
            raw = bytes(int(x, 16) for x in val[:24])
            pps = struct.unpack_from('<Q', raw, 0)[0]
            bps = struct.unpack_from('<Q', raw, 8)[0]
            win = struct.unpack_from('<Q', raw, 16)[0]
            return {"pps_threshold": pps, "bps_threshold": bps, "window_ns": win}
    except Exception:
        pass
    return None
def write_rate_config(pps_threshold, bps_threshold=0, window_ns=1000000000):
    """Write the rate_config struct (three little-endian u64s) to its BPF map.

    Args:
        pps_threshold: packets-per-second limit.
        bps_threshold: bytes-per-second limit (0 = disabled).
        window_ns: measurement window in nanoseconds.

    Raises:
        RuntimeError: if the map is missing or bpftool rejects the update.
    """
    map_id = get_map_id("rate_config")
    if map_id is None:
        raise RuntimeError("rate_config map not found")
    packed = struct.pack('<QQQ', pps_threshold, bps_threshold, window_ns)
    value_parts = [f"{b:02x}" for b in packed]
    cmd = (["bpftool", "map", "update", "id", str(map_id),
            "key", "0", "0", "0", "0", "value", "hex"] + value_parts)
    proc = subprocess.run(cmd, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Failed to write rate_config: {proc.stderr}")
    audit_log.info(f"rate_config updated: pps={pps_threshold} bps={bps_threshold} window={window_ns}")
# ==================== Block / Unblock (from ddos) ====================
def block_ip(ip_str, duration_sec=0):
    """Add an IP to the appropriate blocked_ips map (v4 or v6).

    Args:
        ip_str: IPv4 or IPv6 address string.
        duration_sec: 0 = permanent, >0 = temporary block in seconds.

    Raises:
        ValueError: if *ip_str* is not a valid IP address.
        RuntimeError: if the map is missing or the bpftool update fails.
    """
    addr = ipaddress.ip_address(ip_str)
    is_v6 = isinstance(addr, ipaddress.IPv6Address)
    map_name = "blocked_ips_v6" if is_v6 else "blocked_ips_v4"
    map_id = get_map_id(map_name)
    if map_id is None:
        raise RuntimeError(f"{map_name} map not found")
    key_hex = ip_to_hex_key(ip_str)
    if duration_sec > 0:
        # Expiry is expressed on the kernel uptime clock; /proc/uptime is
        # used as a proxy — presumably to match bpf_ktime_get_ns() in the
        # XDP program. TODO(review): confirm the clock sources agree.
        with open('/proc/uptime', 'r') as f:
            uptime_sec = float(f.read().split()[0])
        now_ns = int(uptime_sec * 1_000_000_000)
        expire_ns = now_ns + (duration_sec * 1_000_000_000)
    else:
        expire_ns = 0
    # Value layout: expire_ns, blocked_at, drop_count — three LE u64s.
    # blocked_at is written as 0 here; drop_count starts at 0.
    now_ns_val = 0
    raw = struct.pack('<QQQ', expire_ns, now_ns_val, 0)
    val_hex = ' '.join(f"{b:02x}" for b in raw)
    result = subprocess.run(
        ["bpftool", "map", "update", "id", str(map_id),
        "key", "hex"] + key_hex.split() + ["value", "hex"] + val_hex.split(),
        capture_output=True, text=True
    )
    if result.returncode != 0:
        raise RuntimeError(f"Failed to block {ip_str}: {result.stderr}")
    duration_str = f"{duration_sec}s" if duration_sec > 0 else "permanent"
    audit_log.info(f"blocked {ip_str} ({duration_str})")
def unblock_ip(ip_str):
    """Remove an IP from the appropriate blocked_ips map (v4 or v6).

    Raises:
        ValueError: if *ip_str* is not a valid IP address.
        RuntimeError: if the map is missing or the bpftool delete fails.
    """
    parsed = ipaddress.ip_address(ip_str)
    map_name = "blocked_ips_v6" if isinstance(parsed, ipaddress.IPv6Address) else "blocked_ips_v4"
    map_id = get_map_id(map_name)
    if map_id is None:
        raise RuntimeError(f"{map_name} map not found")
    key_parts = ip_to_hex_key(ip_str).split()
    cmd = ["bpftool", "map", "delete", "id", str(map_id), "key", "hex"] + key_parts
    proc = subprocess.run(cmd, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Failed to unblock {ip_str}: {proc.stderr}")
    audit_log.info(f"unblocked {ip_str}")
# ==================== Dump Helpers (from ddos) ====================
def dump_rate_counters(map_name="rate_counter_v4", top_n=10):
    """Dump top-N IPs by packet count from a rate counter map.

    Args:
        map_name: counter map to dump; a "v6" substring selects IPv6 decoding.
        top_n: number of entries to return after sorting by packets desc.

    Returns:
        List of (ip_str, packets, bytes, last_seen_ns) tuples; empty list
        on any bpftool/JSON error. Undecodable entries are skipped.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return []
    is_v6 = "v6" in map_name
    result = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if result.returncode != 0:
        return []
    try:
        data = json.loads(result.stdout)
    except Exception:
        return []
    entries = []
    for entry in data:
        fmt = entry.get("formatted", entry)
        key = fmt.get("key", [])
        val = fmt.get("value", {})
        # Keys appear in several shapes depending on bpftool version and BTF:
        # hex byte list, BTF-decoded dict, or plain integer.
        try:
            if isinstance(key, list):
                ip_str = hex_key_to_ip(key, version=6 if is_v6 else 4)
            elif isinstance(key, dict):
                if is_v6:
                    # BTF-decoded struct in6_addr.
                    addr8 = key.get('in6_u', {}).get('u6_addr8', [])
                    if addr8:
                        raw = bytes(addr8)
                        ip_str = str(ipaddress.IPv6Address(raw))
                    else:
                        continue
                else:
                    # NOTE(review): key is a dict here, so isinstance(key, int)
                    # is always False and this falls through to str(key) —
                    # confirm what v4 dict keys actually look like.
                    ip_str = socket.inet_ntoa(struct.pack('<I', key)) if isinstance(key, int) else str(key)
            elif isinstance(key, int):
                # assumes the integer is the address in host (LE) byte order
                # as bpftool prints raw u32 keys — TODO confirm.
                ip_str = socket.inet_ntoa(struct.pack('<I', key))
            else:
                continue
        except Exception:
            continue
        # Value layout: packets, bytes, last_seen — three LE u64s (or a
        # BTF-decoded dict with those field names).
        try:
            if isinstance(val, dict):
                pkts = val.get("packets", 0)
                bts = val.get("bytes", 0)
                last = val.get("last_seen", 0)
            elif isinstance(val, list) and len(val) >= 24:
                raw = bytes(int(x, 16) for x in val[:24])
                pkts = struct.unpack_from('<Q', raw, 0)[0]
                bts = struct.unpack_from('<Q', raw, 8)[0]
                last = struct.unpack_from('<Q', raw, 16)[0]
            else:
                continue
        except Exception:
            continue
        entries.append((ip_str, pkts, bts, last))
    entries.sort(key=lambda x: x[1], reverse=True)
    return entries[:top_n]
def dump_blocked_ips(map_name="blocked_ips_v4"):
    """Dump all blocked IPs with their block metadata.

    Args:
        map_name: blocklist map to dump; a "v6" substring selects IPv6 decoding.

    Returns:
        List of (ip_str, expire_ns, blocked_at, drop_count) tuples; empty
        list on any bpftool/JSON error. Undecodable entries are skipped.
    """
    map_id = get_map_id(map_name)
    if map_id is None:
        return []
    is_v6 = "v6" in map_name
    result = subprocess.run(
        ["bpftool", "map", "dump", "id", str(map_id), "-j"],
        capture_output=True, text=True
    )
    if result.returncode != 0:
        return []
    try:
        data = json.loads(result.stdout)
    except Exception:
        return []
    entries = []
    for entry in data:
        fmt = entry.get("formatted", entry)
        key = fmt.get("key", [])
        val = fmt.get("value", {})
        # Same key-shape handling as dump_rate_counters: hex byte list,
        # BTF-decoded in6_addr dict, or plain integer.
        try:
            if isinstance(key, list):
                ip_str = hex_key_to_ip(key, version=6 if is_v6 else 4)
            elif isinstance(key, dict) and is_v6:
                addr8 = key.get('in6_u', {}).get('u6_addr8', [])
                if addr8:
                    ip_str = str(ipaddress.IPv6Address(bytes(addr8)))
                else:
                    continue
            elif isinstance(key, int):
                # assumes raw u32 key in host (LE) byte order — TODO confirm.
                ip_str = socket.inet_ntoa(struct.pack('<I', key))
            else:
                continue
        except Exception:
            continue
        # Value layout: expire_ns, blocked_at, drop_count — three LE u64s.
        try:
            if isinstance(val, dict):
                expire = val.get("expire_ns", 0)
                blocked = val.get("blocked_at", 0)
                drops = val.get("drop_count", 0)
            elif isinstance(val, list) and len(val) >= 24:
                raw = bytes(int(x, 16) for x in val[:24])
                expire = struct.unpack_from('<Q', raw, 0)[0]
                blocked = struct.unpack_from('<Q', raw, 8)[0]
                drops = struct.unpack_from('<Q', raw, 16)[0]
            else:
                continue
        except Exception:
            continue
        entries.append((ip_str, expire, blocked, drops))
    return entries

175
lib/xdp_country.py Executable file
View File

@@ -0,0 +1,175 @@
#!/usr/bin/env python3
"""
XDP Country Blocker - Fast batch IP blocking
Uses bpftool batch for high-speed map updates
"""
import sys
import time
import urllib.request
from pathlib import Path
from xdp_common import get_map_id, batch_map_operation, classify_cidrs
COUNTRY_DIR = Path("/etc/xdp-blocker/countries")
IPDENY_V4_URL = "https://www.ipdeny.com/ipblocks/data/countries/{}.zone"
IPDENY_V6_URL = "https://www.ipdeny.com/ipblocks/data/ipv6/ipv6-country-blocks/{}.zone"
def download_country(cc):
    """Download country IPv4 + IPv6 CIDR lists from ipdeny.com.

    Writes the combined list to COUNTRY_DIR/<cc>.txt. Uses explicit
    timeouts so a stalled download cannot hang the CLI forever (the
    previous urlretrieve() had no timeout), and avoids temp files.

    Args:
        cc: two-letter country code (case-insensitive).

    Returns:
        Path to the combined list file, or None if nothing was downloaded.
    """
    COUNTRY_DIR.mkdir(parents=True, exist_ok=True)
    cc = cc.lower()
    cc_file = COUNTRY_DIR / f"{cc}.txt"
    cidrs = []
    # Download IPv4 (warn but continue on failure — IPv6 may still exist)
    print(f"[INFO] Downloading {cc.upper()} IPv4 ranges...")
    try:
        with urllib.request.urlopen(IPDENY_V4_URL.format(cc), timeout=30) as resp:
            text = resp.read().decode('utf-8', errors='replace')
        cidrs.extend(line.strip() for line in text.splitlines() if line.strip())
        print(f" IPv4: {len(cidrs)} CIDRs")
    except Exception as e:
        print(f" [WARN] IPv4 download failed: {e}")
    # Download IPv6 (best-effort: many countries have no IPv6 zone file)
    try:
        with urllib.request.urlopen(IPDENY_V6_URL.format(cc), timeout=30) as resp:
            v6_text = resp.read().decode('utf-8', errors='replace')
        v6_cidrs = [line.strip() for line in v6_text.splitlines() if line.strip()]
        cidrs.extend(v6_cidrs)
        print(f" IPv6: {len(v6_cidrs)} CIDRs")
    except Exception:
        pass
    if not cidrs:
        print(f"[ERROR] No IP ranges found for {cc.upper()}")
        return None
    with open(cc_file, 'w') as f:
        f.write('\n'.join(cidrs) + '\n')
    return cc_file
def add_country(cc):
    """Add country IPs to XDP blocklist using batch (IPv4 + IPv6)"""
    cc = cc.lower()
    cc_file = COUNTRY_DIR / f"{cc}.txt"
    if not cc_file.exists():
        cc_file = download_country(cc)
        if not cc_file:
            return False
    with open(cc_file) as f:
        all_cidrs = [ln.strip() for ln in f if ln.strip()]
    v4_list, v6_list = classify_cidrs(all_cidrs)
    if v4_list:
        v4_map = get_map_id("blocklist_v4")
        if not v4_map:
            print("[ERROR] blocklist_v4 map not found. Is XDP loaded?")
            return False
        print(f"[INFO] Adding {len(v4_list)} IPv4 CIDRs for {cc.upper()}...")
        batch_map_operation(v4_map, v4_list, operation="update")
    if v6_list:
        v6_map = get_map_id("blocklist_v6")
        if v6_map:
            print(f"[INFO] Adding {len(v6_list)} IPv6 CIDRs for {cc.upper()}...")
            batch_map_operation(v6_map, v6_list, operation="update", ipv6=True)
        else:
            print("[WARN] blocklist_v6 map not found, skipping IPv6")
    print(f"[OK] Added {cc.upper()}: {len(v4_list)} v4 + {len(v6_list)} v6 CIDRs")
    return True
def del_country(cc):
    """Remove country IPs from XDP blocklist"""
    cc = cc.lower()
    cc_file = COUNTRY_DIR / f"{cc}.txt"
    if not cc_file.exists():
        print(f"[ERROR] Country {cc.upper()} is not blocked")
        return False
    with open(cc_file) as f:
        all_cidrs = [ln.strip() for ln in f if ln.strip()]
    v4_list, v6_list = classify_cidrs(all_cidrs)
    if v4_list:
        v4_map = get_map_id("blocklist_v4")
        if v4_map:
            print(f"[INFO] Removing {len(v4_list)} IPv4 CIDRs for {cc.upper()}...")
            batch_map_operation(v4_map, v4_list, operation="delete")
    if v6_list:
        v6_map = get_map_id("blocklist_v6")
        if v6_map:
            print(f"[INFO] Removing {len(v6_list)} IPv6 CIDRs for {cc.upper()}...")
            batch_map_operation(v6_map, v6_list, operation="delete", ipv6=True)
    cc_file.unlink()
    print(f"[OK] Removed {cc.upper()}: {len(v4_list)} v4 + {len(v6_list)} v6 CIDRs")
    return True
def list_countries():
    """Print the blocked-country list with CIDR counts and file age in days."""
    print("=== Blocked Countries ===")
    if not COUNTRY_DIR.exists():
        print(" (none)")
        return
    files = sorted(COUNTRY_DIR.glob("*.txt"))
    if not files:
        print(" (none)")
        return
    for cc_file in files:
        cc = cc_file.stem.upper()
        # Count lines with the file properly closed (the bare open() in a
        # genexp leaked the file handle).
        with open(cc_file) as f:
            count = sum(1 for _ in f)
        mtime = cc_file.stat().st_mtime
        age = int((time.time() - mtime) / 86400)
        print(f" {cc}: {count} CIDRs (updated {age}d ago)")
def show_help():
    """Print CLI usage text for the xdp-country tool."""
    print("""XDP Country Blocker - Fast batch IP blocking
Usage: xdp-country <command> [args]
Commands:
add <cc> Block a country (e.g., br, cn, ru, kp)
del <cc> Unblock a country
list List blocked countries
Examples:
xdp-country add br # Block Brazil (~13K CIDRs in seconds)
xdp-country add cn # Block China
xdp-country del br # Unblock Brazil
xdp-country list
""")
def main():
    """CLI entry point: dispatch add/del/list commands, else show help."""
    args = sys.argv[1:]
    if not args:
        show_help()
        return
    command = args[0]
    if command == "add" and len(args) >= 2:
        add_country(args[1])
    elif command == "del" and len(args) >= 2:
        del_country(args[1])
    elif command == "list":
        list_countries()
    else:
        show_help()


if __name__ == "__main__":
    main()

714
lib/xdp_defense_daemon.py Executable file
View File

@@ -0,0 +1,714 @@
#!/usr/bin/env python3
"""
XDP Defense Daemon
Userspace component: EWMA-based rate analysis, AI anomaly detection,
time-profile switching, and automatic escalation.
4 worker threads + main thread (signal handling):
- EWMA Thread: polls rate counters, calculates EWMA, detects violations
- AI Thread: reads traffic features, runs Isolation Forest inference
- Profile Thread: checks time-of-day, switches rate_config profiles
- Cleanup Thread: removes expired entries from blocked_ips maps
"""
import os
import sys
import time
import signal
import threading
import logging
import logging.handlers
import csv
import pickle
from collections import defaultdict
from datetime import datetime
import yaml
# ==================== Logging ====================
# Console always; syslog best-effort (/dev/log may be absent in containers).
log = logging.getLogger('xdp-defense-daemon')
log.setLevel(logging.INFO)
_console = logging.StreamHandler()
_console.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(message)s'))
log.addHandler(_console)
try:
    _syslog = logging.handlers.SysLogHandler(address='/dev/log')
    _syslog.setFormatter(logging.Formatter('xdp-defense-daemon: %(message)s'))
    log.addHandler(_syslog)
except Exception:
    pass
# ==================== Configuration ====================
# Built-in defaults; user config at CONFIG_PATH is merged over these
# section-by-section in load_config().
DEFAULT_CONFIG = {
    'general': {
        'interface': 'eth0',
        'log_level': 'info',
        'pid_file': '/var/lib/xdp-defense/daemon.pid',
        'data_dir': '/var/lib/xdp-defense',
    },
    'rate_limits': {
        'default_pps': 1000,
        'default_bps': 0,          # 0 = byte-rate limiting disabled
        'window_sec': 1,
        'profiles': {},            # optional time-of-day profiles
    },
    'escalation': {
        'rate_limit_after': 1,
        'temp_block_after': 5,
        'perm_block_after': 20,
        'temp_block_duration': 300,    # seconds
        'violation_window': 600,       # seconds a violation stays counted
        'cooldown_multiplier': 0.5,
    },
    'ewma': {
        'alpha': 0.3,                  # EWMA smoothing factor
        'poll_interval': 1,            # seconds between counter polls
        'threshold_multiplier': 3.0,   # anomaly when ewma > baseline * this
    },
    'ai': {
        'enabled': True,
        'model_type': 'IsolationForest',
        'contamination': 0.05,
        'n_estimators': 100,
        'learning_duration': 259200,   # 3 days of sample collection
        'min_samples': 1000,
        'poll_interval': 5,
        'anomaly_threshold': -0.3,     # decision_function score cutoff
        'retrain_interval': 604800,    # 7 days
        'min_packets_for_sample': 20,
        'model_file': '/var/lib/xdp-defense/ai_model.pkl',
        'training_data_file': '/var/lib/xdp-defense/training_data.csv',
    },
}
CONFIG_PATH = '/etc/xdp-defense/config.yaml'
def load_config(path=CONFIG_PATH):
    """Load YAML config merged over DEFAULT_CONFIG.

    Args:
        path: YAML config file path.

    Returns:
        Config dict with one entry per DEFAULT_CONFIG section.
    """
    # Copy each section dict so user overrides never mutate DEFAULT_CONFIG.
    # The previous shallow DEFAULT_CONFIG.copy() shared the nested section
    # dicts, so every load (including SIGHUP reloads) merged user values
    # into the defaults themselves, polluting subsequent loads.
    cfg = {section: dict(defaults) for section, defaults in DEFAULT_CONFIG.items()}
    try:
        with open(path) as f:
            user = yaml.safe_load(f) or {}
        for section in cfg:
            if section in user and isinstance(user[section], dict):
                cfg[section].update(user[section])
    except FileNotFoundError:
        log.warning("Config not found at %s, using defaults", path)
    except Exception as e:
        log.error("Failed to load config: %s", e)
    return cfg
# ==================== ViolationTracker ====================
class ViolationTracker:
    """Track per-IP violation timestamps and decide escalation levels."""

    def __init__(self, escalation_cfg):
        self.cfg = escalation_cfg
        self.violations = defaultdict(list)
        self.lock = threading.Lock()

    def record_violation(self, ip):
        """Record one violation for *ip* within the sliding window.

        Returns the escalation level: 'rate_limit', 'temp_block', or
        'perm_block', based on how many violations fall inside the window.
        """
        now = time.time()
        window = self.cfg.get('violation_window', 600)
        with self.lock:
            recent = [t for t in self.violations[ip] if now - t < window]
            recent.append(now)
            self.violations[ip] = recent
            count = len(recent)
            if count >= self.cfg.get('perm_block_after', 20):
                return 'perm_block'
            if count >= self.cfg.get('temp_block_after', 5):
                return 'temp_block'
            return 'rate_limit'

    def clear(self, ip):
        """Forget all recorded violations for *ip*."""
        with self.lock:
            self.violations.pop(ip, None)

    def cleanup_expired(self):
        """Drop IPs whose recorded violations are all older than the window."""
        now = time.time()
        window = self.cfg.get('violation_window', 600)
        with self.lock:
            stale = [ip for ip, stamps in self.violations.items()
                     if not any(now - t < window for t in stamps)]
            for ip in stale:
                del self.violations[ip]
# ==================== EWMAAnalyzer ====================
class EWMAAnalyzer:
    """Per-IP EWMA tracking for rate anomaly detection."""

    def __init__(self, alpha=0.3, threshold_multiplier=3.0):
        self.alpha = alpha
        self.threshold_multiplier = threshold_multiplier
        self.ewma = {}       # ip -> fast-moving average pps
        self.baseline = {}   # ip -> slow-moving baseline pps
        self.lock = threading.Lock()

    def update(self, ip, current_pps):
        """Fold *current_pps* into the IP's EWMA; return True when anomalous.

        The first observation seeds both averages and is never flagged.
        Anomalous means the fast EWMA exceeds the slow baseline (floored
        at 1) times the threshold multiplier.
        """
        with self.lock:
            if ip not in self.ewma:
                self.ewma[ip] = current_pps
                self.baseline[ip] = current_pps
                return False
            a = self.alpha
            self.ewma[ip] = a * current_pps + (1 - a) * self.ewma[ip]
            self.baseline[ip] = 0.01 * current_pps + 0.99 * self.baseline[ip]
            floor = max(self.baseline[ip], 1)
            return self.ewma[ip] > floor * self.threshold_multiplier

    def get_stats(self, ip):
        """Return {'ewma', 'baseline'} for *ip* (zeros if unseen)."""
        with self.lock:
            return {'ewma': self.ewma.get(ip, 0),
                    'baseline': self.baseline.get(ip, 0)}

    def cleanup_stale(self, active_ips):
        """Drop tracking state for IPs not present in *active_ips*."""
        keep = set(active_ips)
        with self.lock:
            for ip in set(self.ewma) - keep:
                self.ewma.pop(ip, None)
                self.baseline.pop(ip, None)
# ==================== AIDetector ====================
class AIDetector:
    """Isolation Forest based anomaly detection on aggregated traffic features.

    Lifecycle: collect samples during a learning phase, then train once
    enough time has passed and enough samples exist (or on explicit retrain
    request); afterwards predict() scores live feature vectors.
    """
    def __init__(self, ai_cfg):
        self.cfg = ai_cfg
        self.model = None            # fitted IsolationForest, or None
        self.scaler = None           # fitted StandardScaler, or None
        self.started_at = time.time()
        self.training_data = []      # collected feature vectors (unbounded until train)
        self.is_learning = True
        self._retrain_requested = False
    @property
    def enabled(self):
        return self.cfg.get('enabled', False)
    def request_retrain(self):
        # Triggers _train() on the next collect_sample() call.
        self._retrain_requested = True
    def collect_sample(self, features):
        """Collect a feature sample during learning phase."""
        if not self.enabled:
            return
        self.training_data.append(features)
        learning_dur = self.cfg.get('learning_duration', 259200)
        min_samples = self.cfg.get('min_samples', 1000)
        elapsed = time.time() - self.started_at
        # Train when the learning period has elapsed AND enough samples
        # exist, or immediately when a retrain was requested (SIGUSR1).
        if (elapsed >= learning_dur and len(self.training_data) >= min_samples) or self._retrain_requested:
            self._train()
            self._retrain_requested = False
    def _train(self):
        """Train the Isolation Forest model and persist model + training data."""
        # sklearn/numpy imported lazily so the daemon runs without them
        # when AI detection is off or unavailable.
        try:
            from sklearn.ensemble import IsolationForest
            from sklearn.preprocessing import StandardScaler
            import numpy as np
        except ImportError:
            log.error("scikit-learn not installed. AI detection disabled.")
            self.cfg['enabled'] = False
            return
        if len(self.training_data) < 10:
            log.warning("Not enough training data (%d samples)", len(self.training_data))
            return
        log.info("Training AI model with %d samples...", len(self.training_data))
        try:
            X = np.array(self.training_data)
            self.scaler = StandardScaler()
            X_scaled = self.scaler.fit_transform(X)
            self.model = IsolationForest(
                n_estimators=self.cfg.get('n_estimators', 100),
                contamination=self.cfg.get('contamination', 'auto'),
                random_state=42,
            )
            self.model.fit(X_scaled)
            self.is_learning = False
            # Persist model+scaler together so load_model() restores both.
            model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
            with open(model_file, 'wb') as f:
                pickle.dump({'model': self.model, 'scaler': self.scaler}, f)
            # Also dump the raw training data as CSV for offline analysis.
            # NOTE(review): header lists 15 columns but samples are appended
            # as-is — confirm sample vectors match this column list.
            data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
            with open(data_file, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([
                    'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
                    'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
                    'small_pkt_count', 'large_pkt_count',
                    'syn_ratio', 'udp_ratio', 'icmp_ratio', 'small_pkt_ratio', 'avg_pkt_size'
                ])
                writer.writerows(self.training_data)
            log.info("AI model trained and saved to %s", model_file)
        except Exception as e:
            log.error("AI training failed: %s", e)
    def load_model(self):
        """Load a previously trained model; return True on success."""
        model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        if not os.path.exists(model_file):
            return False
        try:
            # pickle.load of a root-owned state file; treated as trusted input.
            with open(model_file, 'rb') as f:
                data = pickle.load(f)
            self.model = data['model']
            self.scaler = data['scaler']
            self.is_learning = False
            log.info("AI model loaded from %s", model_file)
            return True
        except Exception as e:
            log.error("Failed to load AI model: %s", e)
            return False
    def predict(self, features):
        """Run anomaly detection. Returns (is_anomaly, score).

        Scores below the configured anomaly_threshold are anomalies.
        Returns (False, 0.0) when disabled, untrained, or on error.
        """
        if not self.enabled or self.model is None:
            return False, 0.0
        try:
            import numpy as np
            X = np.array([features])
            X_scaled = self.scaler.transform(X)
            score = self.model.decision_function(X_scaled)[0]
            threshold = self.cfg.get('anomaly_threshold', -0.3)
            return score < threshold, float(score)
        except Exception as e:
            log.error("AI prediction error: %s", e)
            return False, 0.0
# ==================== ProfileManager ====================
class ProfileManager:
    """Manage time-based rate limit profiles.

    Each profile may restrict itself by weekday range ("mon-fri") and/or
    hour range ("09:00-18:00"); the first matching profile wins, otherwise
    the built-in defaults apply.
    """
    def __init__(self, rate_cfg):
        self.cfg = rate_cfg
        self.current_profile = 'default'  # name of the last applied profile
    def check_and_apply(self):
        """Check current time and apply the first matching profile, if changed."""
        # Imported here to avoid a module-level circular import with xdp_common.
        from xdp_common import write_rate_config
        profiles = self.cfg.get('profiles', {})
        now = datetime.now()
        current_hour = now.hour
        current_min = now.minute
        current_time = current_hour * 60 + current_min  # minutes since midnight
        weekday = now.strftime('%a').lower()
        matched_profile = None
        matched_name = 'default'
        for name, profile in profiles.items():
            hours = profile.get('hours', '')
            weekdays = profile.get('weekdays', '')
            if weekdays:
                day_range = weekdays.lower().split('-')
                day_names = ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun']
                # NOTE(review): a single-day spec ("mon") is silently ignored
                # (len != 2), so the profile then matches every day — confirm
                # whether that is intended.
                if len(day_range) == 2:
                    try:
                        start_idx = day_names.index(day_range[0])
                        end_idx = day_names.index(day_range[1])
                        current_idx = day_names.index(weekday)
                        if start_idx <= end_idx:
                            # Normal range, e.g. mon-fri.
                            if not (start_idx <= current_idx <= end_idx):
                                continue
                        else:
                            # Wrapped range across the week, e.g. fri-mon.
                            if not (current_idx >= start_idx or current_idx <= end_idx):
                                continue
                    except ValueError:
                        continue
            if hours:
                try:
                    start_str, end_str = hours.split('-')
                    sh, sm = map(int, start_str.split(':'))
                    eh, em = map(int, end_str.split(':'))
                    start_min = sh * 60 + sm
                    end_min = eh * 60 + em
                    if start_min <= end_min:
                        # Same-day window; end is exclusive.
                        if not (start_min <= current_time < end_min):
                            continue
                    else:
                        # Window wraps past midnight, e.g. 22:00-06:00.
                        if not (current_time >= start_min or current_time < end_min):
                            continue
                except (ValueError, AttributeError):
                    continue
            matched_profile = profile
            matched_name = name
            break
        # Only write to the BPF map when the effective profile changed.
        if matched_name != self.current_profile:
            if matched_profile:
                pps = matched_profile.get('pps', self.cfg.get('default_pps', 1000))
                bps = matched_profile.get('bps', self.cfg.get('default_bps', 0))
            else:
                pps = self.cfg.get('default_pps', 1000)
                bps = self.cfg.get('default_bps', 0)
            window = self.cfg.get('window_sec', 1)
            try:
                write_rate_config(pps, bps, window * 1_000_000_000)
                log.info("Profile switched: %s -> %s (pps=%d)", self.current_profile, matched_name, pps)
                self.current_profile = matched_name
            except Exception as e:
                log.error("Failed to apply profile %s: %s", matched_name, e)
# ==================== DDoSDaemon ====================
class DDoSDaemon:
"""Main daemon orchestrator."""
    def __init__(self, config_path=CONFIG_PATH):
        """Load config and construct all analysis components."""
        self.config_path = config_path
        self.cfg = load_config(config_path)
        self.running = False  # worker loops run while True; cleared on SIGTERM
        self._stop_event = threading.Event()  # wakes worker sleeps early on shutdown
        self._setup_components()
    def _setup_components(self):
        """(Re)build trackers/analyzers from self.cfg; also called on SIGHUP reload."""
        self.violation_tracker = ViolationTracker(self.cfg['escalation'])
        self.ewma_analyzer = EWMAAnalyzer(
            alpha=self.cfg['ewma'].get('alpha', 0.3),
            threshold_multiplier=self.cfg['ewma'].get('threshold_multiplier', 3.0),
        )
        self.ai_detector = AIDetector(self.cfg['ai'])
        self.profile_manager = ProfileManager(self.cfg['rate_limits'])
        if self.ai_detector.enabled:
            # Resume from a previously trained model if one is on disk.
            self.ai_detector.load_model()
        level = self.cfg['general'].get('log_level', 'info').upper()
        log.setLevel(getattr(logging, level, logging.INFO))
    def _write_pid(self):
        """Write the daemon PID file, creating its directory if needed."""
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        os.makedirs(os.path.dirname(pid_file), exist_ok=True)
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))
    def _remove_pid(self):
        """Remove the PID file; missing file is not an error."""
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        try:
            os.unlink(pid_file)
        except OSError:
            pass
    def _ensure_single_instance(self):
        """Stop any existing daemon (by PID file) before starting.

        Sends SIGTERM, polls up to 30s for exit, then escalates to SIGKILL.
        A stale/invalid PID file is silently ignored.
        """
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        if not os.path.exists(pid_file):
            return
        try:
            with open(pid_file) as f:
                old_pid = int(f.read().strip())
            # Signal 0 only probes whether the process exists.
            os.kill(old_pid, 0)
            log.info("Stopping existing daemon (PID %d)...", old_pid)
            os.kill(old_pid, signal.SIGTERM)
            for _ in range(30):
                time.sleep(1)
                try:
                    os.kill(old_pid, 0)
                except OSError:
                    log.info("Old daemon stopped")
                    return
            log.warning("Daemon PID %d did not stop, sending SIGKILL", old_pid)
            os.kill(old_pid, signal.SIGKILL)
            time.sleep(1)
        except (ValueError, OSError):
            # Unreadable PID file or process already gone — nothing to do.
            pass
    def _handle_sighup(self, signum, frame):
        """SIGHUP: reload config from disk and rebuild all components."""
        log.info("SIGHUP received, reloading config...")
        self.cfg = load_config(self.config_path)
        self._setup_components()
        log.info("Config reloaded")
    def _handle_sigterm(self, signum, frame):
        """SIGTERM: request shutdown; worker threads observe running/_stop_event."""
        log.info("SIGTERM received, shutting down...")
        self.running = False
        self._stop_event.set()
    def _handle_sigusr1(self, signum, frame):
        """SIGUSR1: request an AI model retrain on the next collected sample."""
        log.info("SIGUSR1 received, requesting AI retrain...")
        self.ai_detector.request_retrain()
# ---- Worker Threads ----
    def _ewma_thread(self):
        """Worker loop: poll rate counters, compute EWMA, detect violations, escalate.

        Runs until self.running goes False; sleeps via _stop_event so
        shutdown interrupts the wait immediately.
        """
        # Imported here to avoid a module-level circular import with xdp_common.
        from xdp_common import dump_rate_counters, block_ip
        interval = self.cfg['ewma'].get('poll_interval', 1)
        prev_counters = {}  # ip -> last observed cumulative packet count
        while self.running:
            try:
                entries = dump_rate_counters('rate_counter_v4', top_n=1000)
                active_ips = []
                for ip_str, pkts, bts, last_seen in entries:
                    active_ips.append(ip_str)
                    prev = prev_counters.get(ip_str, 0)
                    # Counter delta since last poll; a decrease means the
                    # LRU entry was recycled, so restart from the raw count.
                    delta = pkts - prev if pkts >= prev else pkts
                    prev_counters[ip_str] = pkts
                    if delta <= 0:
                        continue
                    pps = delta / max(interval, 0.1)
                    is_anomalous = self.ewma_analyzer.update(ip_str, pps)
                    if is_anomalous:
                        level = self.violation_tracker.record_violation(ip_str)
                        ew = self.ewma_analyzer.get_stats(ip_str)
                        log.warning(
                            "EWMA anomaly: %s pps=%.1f ewma=%.1f baseline=%.1f -> %s",
                            ip_str, pps, ew['ewma'], ew['baseline'], level
                        )
                        # 'rate_limit' level needs no action here: the XDP
                        # program enforces per-IP rate limiting in-kernel.
                        if level == 'temp_block':
                            dur = self.cfg['escalation'].get('temp_block_duration', 300)
                            try:
                                block_ip(ip_str, dur)
                                log.warning("TEMP BLOCK: %s for %ds", ip_str, dur)
                            except Exception as e:
                                log.error("Failed to temp-block %s: %s", ip_str, e)
                        elif level == 'perm_block':
                            try:
                                block_ip(ip_str, 0)
                                log.warning("PERM BLOCK: %s", ip_str)
                            except Exception as e:
                                log.error("Failed to perm-block %s: %s", ip_str, e)
                self.ewma_analyzer.cleanup_stale(active_ips)
            except Exception as e:
                log.error("EWMA thread error: %s", e)
            self._stop_event.wait(interval)
def _ai_thread(self):
    """Read traffic features, run AI inference or collect training data.

    Every `ai.poll_interval` seconds, reads the global per-CPU traffic
    feature counters, converts them to per-interval deltas plus derived
    ratios, then either feeds the sample to the detector (learning mode)
    or runs anomaly prediction and escalates the current top talkers.
    """
    from xdp_common import read_percpu_features, dump_rate_counters, block_ip
    interval = self.cfg['ai'].get('poll_interval', 5)
    prev_features = None
    while self.running:
        try:
            if not self.ai_detector.enabled:
                self._stop_event.wait(interval)
                continue
            features = read_percpu_features()
            if not features:
                self._stop_event.wait(interval)
                continue
            # Raw cumulative counters exported by the XDP program; order
            # matters -- the index positions are used for the ratios below.
            feature_names = [
                'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
                'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
                'small_pkt_count', 'large_pkt_count'
            ]
            # First iteration only seeds prev_features; deltas need two reads.
            if prev_features is not None:
                deltas = []
                for name in feature_names:
                    cur = features.get(name, 0)
                    prev = prev_features.get(name, 0)
                    # max(0, ...) guards against counter resets going negative.
                    deltas.append(max(0, cur - prev))
                # Skip near-idle intervals -- too little traffic to be a
                # meaningful sample for either training or inference.
                min_pkts = self.cfg['ai'].get('min_packets_for_sample', 20)
                if deltas[0] < min_pkts:
                    prev_features = features
                    self._stop_event.wait(interval)
                    continue
                # Derived features: protocol mix ratios and average packet
                # size, normalized by total packets (epsilon avoids /0).
                total = deltas[0] + 1e-6
                syn_ratio = deltas[2] / total
                udp_ratio = deltas[4] / total
                icmp_ratio = deltas[5] / total
                small_pkt_ratio = deltas[8] / total
                avg_pkt_size = deltas[1] / total
                deltas.extend([syn_ratio, udp_ratio, icmp_ratio, small_pkt_ratio, avg_pkt_size])
                if self.ai_detector.is_learning:
                    self.ai_detector.collect_sample(deltas)
                    # Progress log every 100 samples while building the model.
                    if len(self.ai_detector.training_data) % 100 == 0:
                        log.debug("AI learning: %d samples collected",
                                  len(self.ai_detector.training_data))
                else:
                    is_anomaly, score = self.ai_detector.predict(deltas)
                    if is_anomaly:
                        log.warning(
                            "AI ANOMALY detected: score=%.4f deltas=%s",
                            score, dict(zip(feature_names, deltas[:len(feature_names)]))
                        )
                        # The detector is aggregate (interface-wide), so
                        # attribute the anomaly to the current top-5 talkers
                        # and escalate each through the violation tracker.
                        top_ips = dump_rate_counters('rate_counter_v4', top_n=5)
                        for ip_str, pkts, bts, _ in top_ips:
                            level = self.violation_tracker.record_violation(ip_str)
                            log.warning("AI escalation: %s -> %s", ip_str, level)
                            if level == 'temp_block':
                                dur = self.cfg['escalation'].get('temp_block_duration', 300)
                                try:
                                    block_ip(ip_str, dur)
                                    log.warning("AI TEMP BLOCK: %s for %ds", ip_str, dur)
                                except Exception as e:
                                    log.error("Failed to AI temp-block %s: %s", ip_str, e)
                            elif level == 'perm_block':
                                try:
                                    block_ip(ip_str, 0)
                                    log.warning("AI PERM BLOCK: %s", ip_str)
                                except Exception as e:
                                    log.error("Failed to AI perm-block %s: %s", ip_str, e)
            prev_features = features
        except Exception as e:
            log.error("AI thread error: %s", e)
        self._stop_event.wait(interval)
def _profile_thread(self):
    """Once a minute, let the profile manager switch rate-limit profiles
    according to the time of day."""
    while self.running:
        try:
            self.profile_manager.check_and_apply()
        except Exception as exc:
            log.error("Profile thread error: %s", exc)
        self._stop_event.wait(60)
def _cleanup_thread(self):
    """Periodically clean up expired blocked IPs and stale violations.

    Every 60s: walk blocked_ips_v4/v6, remove entries whose expiry
    timestamp has passed, then age out old violation records.
    """
    from xdp_common import dump_blocked_ips, unblock_ip
    while self.running:
        try:
            # Current kernel time in ns, derived from /proc/uptime.
            # NOTE(review): /proc/uptime includes suspended time, while
            # BPF timestamps commonly come from bpf_ktime_get_ns()
            # (CLOCK_MONOTONIC, excludes suspend) -- confirm both sides
            # use the same clock, or temp blocks may expire early after
            # a suspend/resume cycle.
            with open('/proc/uptime') as f:
                now_ns = int(float(f.read().split()[0]) * 1_000_000_000)
            for map_name in ['blocked_ips_v4', 'blocked_ips_v6']:
                entries = dump_blocked_ips(map_name)
                for ip_str, expire_ns, blocked_at, drop_count in entries:
                    # expire_ns == 0 marks a permanent block: never expires.
                    if expire_ns != 0 and now_ns > expire_ns:
                        try:
                            unblock_ip(ip_str)
                            self.violation_tracker.clear(ip_str)
                            log.info("Expired block removed: %s (dropped %d pkts)", ip_str, drop_count)
                        except Exception as e:
                            log.error("Failed to remove expired block %s: %s", ip_str, e)
            self.violation_tracker.cleanup_expired()
        except Exception as e:
            log.error("Cleanup thread error: %s", e)
        self._stop_event.wait(60)
# ---- Main Loop ----
def run(self):
    """Start the daemon: install signal handlers, replace any running
    instance, spawn the worker threads, and block until shutdown."""
    log.info("XDP Defense Daemon starting...")
    signal.signal(signal.SIGHUP, self._handle_sighup)
    signal.signal(signal.SIGTERM, self._handle_sigterm)
    signal.signal(signal.SIGINT, self._handle_sigterm)
    signal.signal(signal.SIGUSR1, self._handle_sigusr1)
    self._ensure_single_instance()
    self._write_pid()
    self.running = True
    workers = []
    for worker_name, target in (
        ('ewma', self._ewma_thread),
        ('ai', self._ai_thread),
        ('profile', self._profile_thread),
        ('cleanup', self._cleanup_thread),
    ):
        worker = threading.Thread(target=target, name=worker_name, daemon=True)
        workers.append(worker)
        worker.start()
        log.info("Started %s thread", worker.name)
    log.info("Daemon running (PID %d)", os.getpid())
    try:
        # Main thread just idles; signal handlers drive the shutdown.
        while self.running:
            self._stop_event.wait(1)
    except KeyboardInterrupt:
        pass
    log.info("Shutting down...")
    self.running = False
    self._stop_event.set()
    for worker in workers:
        worker.join(timeout=5)
    self._remove_pid()
    log.info("Daemon stopped")
# ==================== Entry Point ====================
def main():
    """Entry point: optional argv[1] overrides the default config path."""
    config_path = sys.argv[1] if len(sys.argv) > 1 else CONFIG_PATH
    DDoSDaemon(config_path).run()


if __name__ == '__main__':
    main()

256
lib/xdp_whitelist.py Executable file
View File

@@ -0,0 +1,256 @@
#!/usr/bin/env python3
"""
XDP Whitelist Manager - Fast batch IP whitelisting
Supports presets like Cloudflare, AWS, Google, etc.
"""
import sys
import json
import urllib.request
from pathlib import Path
from xdp_common import get_map_id, batch_map_operation, classify_cidrs
# Directory of per-service cache files: one <name>.txt with one CIDR per line.
WHITELIST_DIR = Path("/etc/xdp-blocker/whitelist")
# Preset URLs for trusted services.
# "v4"/"v6" are the services' published IP-range endpoints (plain text for
# Cloudflare, JSON documents for AWS/Google/GitHub); "desc" is the label
# shown by the `list` and `presets` commands.
PRESETS = {
    "cloudflare": {
        "v4": "https://www.cloudflare.com/ips-v4",
        "v6": "https://www.cloudflare.com/ips-v6",
        "desc": "Cloudflare CDN/Proxy"
    },
    "aws": {
        "v4": "https://ip-ranges.amazonaws.com/ip-ranges.json",
        "desc": "Amazon Web Services (all regions)"
    },
    "google": {
        "v4": "https://www.gstatic.com/ipranges/cloud.json",
        "desc": "Google Cloud Platform"
    },
    "github": {
        "v4": "https://api.github.com/meta",
        "desc": "GitHub Services"
    }
}
def download_cloudflare():
    """Fetch Cloudflare's published IPv4 CIDR list (plain text, one per line)."""
    ranges = []
    try:
        request = urllib.request.Request(
            PRESETS["cloudflare"]["v4"],
            headers={"User-Agent": "xdp-whitelist/1.0"}
        )
        with urllib.request.urlopen(request) as resp:
            body = resp.read().decode()
        ranges.extend(body.strip().split('\n'))
        print(f" Downloaded {len(ranges)} IPv4 ranges")
    except Exception as e:
        print(f" [WARN] Failed to download IPv4: {e}")
    return ranges
def download_aws():
    """Fetch AWS ip-ranges.json and collect every IPv4 prefix."""
    ranges = []
    try:
        with urllib.request.urlopen(PRESETS["aws"]["v4"]) as resp:
            payload = json.loads(resp.read().decode())
        for entry in payload.get("prefixes", []):
            ranges.append(entry["ip_prefix"])
        print(f" Downloaded {len(ranges)} IPv4 ranges")
    except Exception as e:
        print(f" [WARN] Failed to download: {e}")
    return ranges
def download_google():
    """Fetch Google Cloud's cloud.json and collect every IPv4 prefix."""
    ranges = []
    try:
        with urllib.request.urlopen(PRESETS["google"]["v4"]) as resp:
            payload = json.loads(resp.read().decode())
        for entry in payload.get("prefixes", []):
            # Entries carry either "ipv4Prefix" or "ipv6Prefix"; keep v4 only.
            v4 = entry.get("ipv4Prefix")
            if v4 is not None:
                ranges.append(v4)
        print(f" Downloaded {len(ranges)} IPv4 ranges")
    except Exception as e:
        print(f" [WARN] Failed to download: {e}")
    return ranges
def download_github():
    """Fetch GitHub's /meta document and collect IPv4 CIDRs across services."""
    ranges = []
    try:
        request = urllib.request.Request(
            PRESETS["github"]["v4"],
            headers={"User-Agent": "xdp-whitelist"}
        )
        with urllib.request.urlopen(request) as resp:
            payload = json.loads(resp.read().decode())
        for key in ["hooks", "web", "api", "git", "packages", "pages", "importer", "actions", "dependabot"]:
            if key in payload:
                ranges.extend(payload[key])
        # De-duplicate and drop IPv6 (GitHub mixes v6 into the same lists).
        ranges = list(set(c for c in ranges if ':' not in c))
        print(f" Downloaded {len(ranges)} IPv4 ranges")
    except Exception as e:
        print(f" [WARN] Failed to download: {e}")
    return ranges
def add_whitelist(name, cidrs=None):
    """Whitelist a preset service or an explicit CIDR list.

    Resolution order when *cidrs* is not given:
      1. a cached WHITELIST_DIR/<name>.txt from a previous run;
      2. a fresh download for a known preset (cloudflare/aws/google/github).

    The resolved list is cached to disk, then pushed into the
    whitelist_v4 / whitelist_v6 BPF maps in batches.

    Returns True on success, False on any error.
    """
    name = name.lower()
    WHITELIST_DIR.mkdir(parents=True, exist_ok=True)
    wl_file = WHITELIST_DIR / f"{name}.txt"
    if cidrs is None and wl_file.exists():
        # Fix: keep IPv6 entries from the cache. classify_cidrs() splits
        # v4/v6 below and del_whitelist reads this file unfiltered, so the
        # old "':' not in line" filter silently lost cached v6 ranges.
        with open(wl_file) as f:
            cidrs = [line.strip() for line in f if line.strip()]
        if cidrs:
            print(f"[INFO] Using cached {name} ({len(cidrs)} CIDRs)")
    if cidrs is None:
        if name == "cloudflare":
            print(f"[INFO] Downloading Cloudflare IP ranges...")
            cidrs = download_cloudflare()
        elif name == "aws":
            print(f"[INFO] Downloading AWS IP ranges...")
            cidrs = download_aws()
        elif name == "google":
            print(f"[INFO] Downloading Google Cloud IP ranges...")
            cidrs = download_google()
        elif name == "github":
            print(f"[INFO] Downloading GitHub IP ranges...")
            cidrs = download_github()
        else:
            print(f"[ERROR] Unknown preset: {name}")
            print(f"Available presets: {', '.join(PRESETS.keys())}")
            return False
    if not cidrs:
        print("[ERROR] No CIDRs to add")
        return False
    # Cache to disk so subsequent add/del runs work offline.
    with open(wl_file, 'w') as f:
        f.write('\n'.join(cidrs))
    map_id = get_map_id("whitelist_v4")
    if not map_id:
        print("[ERROR] whitelist_v4 map not found. Is XDP loaded?")
        return False
    v4_cidrs, v6_cidrs = classify_cidrs(cidrs)
    if v4_cidrs:
        print(f"[INFO] Adding {len(v4_cidrs)} IPv4 CIDRs to whitelist...")
        batch_map_operation(map_id, v4_cidrs, operation="update", batch_size=500)
    if v6_cidrs:
        map_id_v6 = get_map_id("whitelist_v6")
        if map_id_v6:
            print(f"[INFO] Adding {len(v6_cidrs)} IPv6 CIDRs to whitelist...")
            batch_map_operation(map_id_v6, v6_cidrs, operation="update", batch_size=500, ipv6=True)
    print(f"[OK] Whitelisted {name}: {len(v4_cidrs)} v4 + {len(v6_cidrs)} v6 CIDRs")
    return True
def del_whitelist(name):
    """Remove a whitelisted service from the BPF maps and delete its cache file."""
    name = name.lower()
    wl_file = WHITELIST_DIR / f"{name}.txt"
    if not wl_file.exists():
        print(f"[ERROR] {name} is not whitelisted")
        return False
    map_id = get_map_id("whitelist_v4")
    if not map_id:
        print("[ERROR] whitelist_v4 map not found")
        return False
    with open(wl_file) as fh:
        entries = [line.strip() for line in fh if line.strip()]
    v4_cidrs, v6_cidrs = classify_cidrs(entries)
    if v4_cidrs:
        print(f"[INFO] Removing {len(v4_cidrs)} IPv4 CIDRs from whitelist...")
        batch_map_operation(map_id, v4_cidrs, operation="delete", batch_size=500)
    if v6_cidrs:
        map_id_v6 = get_map_id("whitelist_v6")
        if map_id_v6:
            print(f"[INFO] Removing {len(v6_cidrs)} IPv6 CIDRs from whitelist...")
            batch_map_operation(map_id_v6, v6_cidrs, operation="delete", batch_size=500, ipv6=True)
    wl_file.unlink()
    print(f"[OK] Removed {name} from whitelist")
    return True
def list_whitelist():
    """Print every whitelisted service with its IPv4 CIDR count and description."""
    print("=== Whitelisted Services ===")
    if not WHITELIST_DIR.exists():
        print(" (none)")
        return
    files = list(WHITELIST_DIR.glob("*.txt"))
    if not files:
        print(" (none)")
        return
    for wl_file in sorted(files):
        name = wl_file.stem
        # Fix: the previous bare open() inside a genexpr left the file
        # handle for the GC to close; use a context manager instead.
        with open(wl_file) as f:
            # Count IPv4 entries only (':' marks IPv6 lines in the cache).
            count = sum(1 for line in f if line.strip() and ':' not in line)
        desc = PRESETS.get(name, {}).get("desc", "Custom")
        print(f" {name}: {count} CIDRs ({desc})")
def show_presets():
    """Print each built-in preset name alongside its description."""
    print("=== Available Presets ===")
    for preset, info in PRESETS.items():
        print(f" {preset}: {info['desc']}")
def show_help():
    """Print CLI usage and examples to stdout."""
    print("""XDP Whitelist Manager - Fast batch IP whitelisting
Usage: xdp-whitelist <command> [args]
Commands:
add <preset|file> Whitelist a preset or custom file
del <name> Remove from whitelist
list List whitelisted services
presets Show available presets
Presets:
cloudflare Cloudflare CDN/Proxy IPs
aws Amazon Web Services
google Google Cloud Platform
github GitHub Services
Examples:
xdp-whitelist add cloudflare # Whitelist Cloudflare
xdp-whitelist add aws # Whitelist AWS
xdp-whitelist del cloudflare # Remove Cloudflare
xdp-whitelist list
""")
def main():
    """Dispatch the command-line subcommand (add/del/list/presets)."""
    args = sys.argv[1:]
    if not args:
        show_help()
        return
    cmd = args[0]
    if cmd == "add" and len(args) >= 2:
        add_whitelist(args[1])
    elif cmd == "del" and len(args) >= 2:
        del_whitelist(args[1])
    elif cmd == "list":
        list_whitelist()
    elif cmd == "presets":
        show_presets()
    else:
        show_help()


if __name__ == "__main__":
    main()