Fix 12 code review issues (4 MEDIUM + 8 LOW)
MEDIUM: - M1: Whitelist direct IP/CIDR additions now persist to direct.txt - M2: get_map_id() uses 5s TTL cache (single bpftool call for all maps) - M3: IPv6 extension header parsing in xdp_ddos.c (hop-by-hop/routing/frag/dst) - M4: Shell injection prevention - sanitize_input() + sys.argv[] for all Python calls LOW: - L1: Remove redundant self.running (uses _stop_event only) - L2: Remove unused config values (rate_limit_after, cooldown_multiplier, retrain_interval) - L3: Thread poll intervals reloaded on SIGHUP - L4: batch_map_operation counts only successfully written entries - L5: Clarify unique_ips_approx comment (per-packet counter) - L6: Document LRU_HASH multi-CPU race condition as acceptable - L7: Download Cloudflare IPv6 ranges in whitelist preset - L8: Fix file handle leak in xdp_country.py list_countries() Also: SIGHUP now preserves EWMA/violation state, daemon skips whitelisted IPs in EWMA/AI escalation, deep copy for default config, IHL validation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -32,8 +32,18 @@ if _syslog_handler:
|
||||
|
||||
# ==================== BPF Map Helpers ====================
|
||||
|
||||
_map_cache = {} # {map_name: (map_id, timestamp)}
|
||||
_MAP_CACHE_TTL = 5.0 # seconds
|
||||
|
||||
def get_map_id(map_name):
|
||||
"""Get BPF map ID by name."""
|
||||
"""Get BPF map ID by name (cached with 5s TTL)."""
|
||||
import time as _time
|
||||
now = _time.monotonic()
|
||||
|
||||
cached = _map_cache.get(map_name)
|
||||
if cached and (now - cached[1]) < _MAP_CACHE_TTL:
|
||||
return cached[0]
|
||||
|
||||
result = subprocess.run(
|
||||
["bpftool", "map", "show", "-j"],
|
||||
capture_output=True, text=True
|
||||
@@ -42,12 +52,15 @@ def get_map_id(map_name):
|
||||
return None
|
||||
try:
|
||||
maps = json.loads(result.stdout)
|
||||
# Update cache for all maps found
|
||||
for m in maps:
|
||||
if m.get("name") == map_name:
|
||||
return m.get("id")
|
||||
name = m.get("name")
|
||||
if name:
|
||||
_map_cache[name] = (m.get("id"), now)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
cached = _map_cache.get(map_name)
|
||||
return cached[0] if cached else None
|
||||
|
||||
|
||||
# ==================== CIDR / IPv4 Helpers (from blocker) ====================
|
||||
@@ -130,10 +143,12 @@ def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 0
|
||||
"""
|
||||
total = len(cidrs)
|
||||
processed = 0
|
||||
written = 0
|
||||
|
||||
for i in range(0, total, batch_size):
|
||||
batch = cidrs[i:i + batch_size]
|
||||
batch_file = None
|
||||
batch_written = 0
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.batch', delete=False) as f:
|
||||
@@ -145,6 +160,7 @@ def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 0
|
||||
f.write(f"map update id {map_id} key hex {key_hex} value hex {value_hex}\n")
|
||||
else:
|
||||
f.write(f"map delete id {map_id} key hex {key_hex}\n")
|
||||
batch_written += 1
|
||||
except (ValueError, Exception):
|
||||
continue
|
||||
|
||||
@@ -157,13 +173,44 @@ def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 0
|
||||
os.unlink(batch_file)
|
||||
|
||||
processed += len(batch)
|
||||
written += batch_written
|
||||
pct = processed * 100 // total if total > 0 else 100
|
||||
print(f"\r Progress: {processed}/{total} ({pct}%)", end="", flush=True)
|
||||
|
||||
print()
|
||||
proto = "v6" if ipv6 else "v4"
|
||||
audit_log.info(f"batch {operation} {proto}: {processed} entries on map {map_id}")
|
||||
return processed
|
||||
audit_log.info(f"batch {operation} {proto}: {written} entries on map {map_id}")
|
||||
return written
|
||||
|
||||
|
||||
# ==================== Whitelist Check (for daemon) ====================
|
||||
|
||||
def is_whitelisted(ip_str):
|
||||
"""Check if an IP is in the BPF whitelist maps.
|
||||
Returns True if whitelisted, False otherwise.
|
||||
"""
|
||||
try:
|
||||
addr = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if isinstance(addr, ipaddress.IPv6Address):
|
||||
map_name = "whitelist_v6"
|
||||
key_hex = cidr_to_key_v6(f"{ip_str}/128")
|
||||
else:
|
||||
map_name = "whitelist_v4"
|
||||
key_hex = cidr_to_key(f"{ip_str}/32")
|
||||
|
||||
map_id = get_map_id(map_name)
|
||||
if map_id is None:
|
||||
return False
|
||||
|
||||
result = subprocess.run(
|
||||
["bpftool", "map", "lookup", "id", str(map_id),
|
||||
"key", "hex"] + key_hex.split(),
|
||||
capture_output=True, text=True
|
||||
)
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
# ==================== IP Encoding Helpers (from ddos) ====================
|
||||
@@ -360,16 +407,17 @@ def block_ip(ip_str, duration_sec=0):
|
||||
|
||||
key_hex = ip_to_hex_key(ip_str)
|
||||
|
||||
# Use CLOCK_BOOTTIME (matches BPF ktime_get_ns)
|
||||
with open('/proc/uptime', 'r') as f:
|
||||
uptime_sec = float(f.read().split()[0])
|
||||
now_ns = int(uptime_sec * 1_000_000_000)
|
||||
|
||||
if duration_sec > 0:
|
||||
with open('/proc/uptime', 'r') as f:
|
||||
uptime_sec = float(f.read().split()[0])
|
||||
now_ns = int(uptime_sec * 1_000_000_000)
|
||||
expire_ns = now_ns + (duration_sec * 1_000_000_000)
|
||||
else:
|
||||
expire_ns = 0
|
||||
|
||||
now_ns_val = 0
|
||||
raw = struct.pack('<QQQ', expire_ns, now_ns_val, 0)
|
||||
raw = struct.pack('<QQQ', expire_ns, now_ns, 0)
|
||||
val_hex = ' '.join(f"{b:02x}" for b in raw)
|
||||
|
||||
result = subprocess.run(
|
||||
|
||||
@@ -133,7 +133,8 @@ def list_countries():
|
||||
|
||||
for cc_file in sorted(files):
|
||||
cc = cc_file.stem.upper()
|
||||
count = sum(1 for _ in open(cc_file))
|
||||
with open(cc_file) as f:
|
||||
count = sum(1 for _ in f)
|
||||
mtime = cc_file.stat().st_mtime
|
||||
age = int((time.time() - mtime) / 86400)
|
||||
print(f" {cc}: {count} CIDRs (updated {age}d ago)")
|
||||
|
||||
@@ -11,6 +11,7 @@ time-profile switching, and automatic escalation.
|
||||
- Cleanup Thread: removes expired entries from blocked_ips maps
|
||||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
@@ -57,12 +58,10 @@ DEFAULT_CONFIG = {
|
||||
'profiles': {},
|
||||
},
|
||||
'escalation': {
|
||||
'rate_limit_after': 1,
|
||||
'temp_block_after': 5,
|
||||
'perm_block_after': 20,
|
||||
'temp_block_duration': 300,
|
||||
'violation_window': 600,
|
||||
'cooldown_multiplier': 0.5,
|
||||
},
|
||||
'ewma': {
|
||||
'alpha': 0.3,
|
||||
@@ -78,7 +77,6 @@ DEFAULT_CONFIG = {
|
||||
'min_samples': 1000,
|
||||
'poll_interval': 5,
|
||||
'anomaly_threshold': -0.3,
|
||||
'retrain_interval': 604800,
|
||||
'min_packets_for_sample': 20,
|
||||
'model_file': '/var/lib/xdp-defense/ai_model.pkl',
|
||||
'training_data_file': '/var/lib/xdp-defense/training_data.csv',
|
||||
@@ -90,7 +88,7 @@ CONFIG_PATH = '/etc/xdp-defense/config.yaml'
|
||||
|
||||
def load_config(path=CONFIG_PATH):
|
||||
"""Load config with defaults."""
|
||||
cfg = DEFAULT_CONFIG.copy()
|
||||
cfg = copy.deepcopy(DEFAULT_CONFIG)
|
||||
try:
|
||||
with open(path) as f:
|
||||
user = yaml.safe_load(f) or {}
|
||||
@@ -406,8 +404,9 @@ class DDoSDaemon:
|
||||
def __init__(self, config_path=CONFIG_PATH):
|
||||
self.config_path = config_path
|
||||
self.cfg = load_config(config_path)
|
||||
self.running = False
|
||||
self._stop_event = threading.Event()
|
||||
self._ewma_interval = self.cfg['ewma'].get('poll_interval', 1)
|
||||
self._ai_interval = self.cfg['ai'].get('poll_interval', 5)
|
||||
self._setup_components()
|
||||
|
||||
def _setup_components(self):
|
||||
@@ -465,12 +464,21 @@ class DDoSDaemon:
|
||||
def _handle_sighup(self, signum, frame):
|
||||
log.info("SIGHUP received, reloading config...")
|
||||
self.cfg = load_config(self.config_path)
|
||||
self._setup_components()
|
||||
log.info("Config reloaded")
|
||||
# Update existing components without rebuilding (preserves EWMA/violation state)
|
||||
self.violation_tracker.cfg = self.cfg['escalation']
|
||||
self.ewma_analyzer.alpha = self.cfg['ewma'].get('alpha', 0.3)
|
||||
self.ewma_analyzer.threshold_multiplier = self.cfg['ewma'].get('threshold_multiplier', 3.0)
|
||||
self.ai_detector.cfg = self.cfg['ai']
|
||||
self.profile_manager.cfg = self.cfg['rate_limits']
|
||||
# Update poll intervals (used by threads on next iteration)
|
||||
self._ewma_interval = self.cfg['ewma'].get('poll_interval', 1)
|
||||
self._ai_interval = self.cfg['ai'].get('poll_interval', 5)
|
||||
level = self.cfg['general'].get('log_level', 'info').upper()
|
||||
log.setLevel(getattr(logging, level, logging.INFO))
|
||||
log.info("Config reloaded (state preserved)")
|
||||
|
||||
def _handle_sigterm(self, signum, frame):
|
||||
log.info("SIGTERM received, shutting down...")
|
||||
self.running = False
|
||||
self._stop_event.set()
|
||||
|
||||
def _handle_sigusr1(self, signum, frame):
|
||||
@@ -481,12 +489,12 @@ class DDoSDaemon:
|
||||
|
||||
def _ewma_thread(self):
|
||||
"""Poll rate counters, compute EWMA, detect violations, escalate."""
|
||||
from xdp_common import dump_rate_counters, block_ip
|
||||
from xdp_common import dump_rate_counters, block_ip, is_whitelisted
|
||||
|
||||
interval = self.cfg['ewma'].get('poll_interval', 1)
|
||||
prev_counters = {}
|
||||
|
||||
while self.running:
|
||||
while not self._stop_event.is_set():
|
||||
interval = self._ewma_interval
|
||||
try:
|
||||
entries = dump_rate_counters('rate_counter_v4', top_n=1000)
|
||||
active_ips = []
|
||||
@@ -505,6 +513,11 @@ class DDoSDaemon:
|
||||
|
||||
is_anomalous = self.ewma_analyzer.update(ip_str, pps)
|
||||
if is_anomalous:
|
||||
# Skip whitelisted IPs
|
||||
if is_whitelisted(ip_str):
|
||||
log.debug("EWMA anomaly skipped (whitelisted): %s", ip_str)
|
||||
continue
|
||||
|
||||
level = self.violation_tracker.record_violation(ip_str)
|
||||
ew = self.ewma_analyzer.get_stats(ip_str)
|
||||
log.warning(
|
||||
@@ -536,12 +549,12 @@ class DDoSDaemon:
|
||||
|
||||
def _ai_thread(self):
|
||||
"""Read traffic features, run AI inference or collect training data."""
|
||||
from xdp_common import read_percpu_features, dump_rate_counters, block_ip
|
||||
from xdp_common import read_percpu_features, dump_rate_counters, block_ip, is_whitelisted
|
||||
|
||||
interval = self.cfg['ai'].get('poll_interval', 5)
|
||||
prev_features = None
|
||||
|
||||
while self.running:
|
||||
while not self._stop_event.is_set():
|
||||
interval = self._ai_interval
|
||||
try:
|
||||
if not self.ai_detector.enabled:
|
||||
self._stop_event.wait(interval)
|
||||
@@ -593,6 +606,11 @@ class DDoSDaemon:
|
||||
)
|
||||
top_ips = dump_rate_counters('rate_counter_v4', top_n=5)
|
||||
for ip_str, pkts, bts, _ in top_ips:
|
||||
# Skip whitelisted IPs
|
||||
if is_whitelisted(ip_str):
|
||||
log.debug("AI escalation skipped (whitelisted): %s", ip_str)
|
||||
continue
|
||||
|
||||
level = self.violation_tracker.record_violation(ip_str)
|
||||
log.warning("AI escalation: %s -> %s", ip_str, level)
|
||||
|
||||
@@ -620,7 +638,7 @@ class DDoSDaemon:
|
||||
|
||||
def _profile_thread(self):
|
||||
"""Check time-of-day and switch rate profiles."""
|
||||
while self.running:
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
self.profile_manager.check_and_apply()
|
||||
except Exception as e:
|
||||
@@ -631,7 +649,7 @@ class DDoSDaemon:
|
||||
"""Periodically clean up expired blocked IPs and stale violations."""
|
||||
from xdp_common import dump_blocked_ips, unblock_ip
|
||||
|
||||
while self.running:
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
with open('/proc/uptime') as f:
|
||||
now_ns = int(float(f.read().split()[0]) * 1_000_000_000)
|
||||
@@ -667,7 +685,6 @@ class DDoSDaemon:
|
||||
|
||||
self._ensure_single_instance()
|
||||
self._write_pid()
|
||||
self.running = True
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=self._ewma_thread, name='ewma', daemon=True),
|
||||
@@ -683,13 +700,12 @@ class DDoSDaemon:
|
||||
log.info("Daemon running (PID %d)", os.getpid())
|
||||
|
||||
try:
|
||||
while self.running:
|
||||
while not self._stop_event.is_set():
|
||||
self._stop_event.wait(1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
log.info("Shutting down...")
|
||||
self.running = False
|
||||
self._stop_event.set()
|
||||
|
||||
for t in threads:
|
||||
|
||||
@@ -35,7 +35,7 @@ PRESETS = {
|
||||
}
|
||||
|
||||
def download_cloudflare():
|
||||
"""Download Cloudflare IP ranges"""
|
||||
"""Download Cloudflare IP ranges (IPv4 + IPv6)"""
|
||||
cidrs = []
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
@@ -43,10 +43,23 @@ def download_cloudflare():
|
||||
headers={"User-Agent": "xdp-whitelist/1.0"}
|
||||
)
|
||||
with urllib.request.urlopen(req) as r:
|
||||
cidrs.extend(r.read().decode().strip().split('\n'))
|
||||
print(f" Downloaded {len(cidrs)} IPv4 ranges")
|
||||
v4 = r.read().decode().strip().split('\n')
|
||||
cidrs.extend(v4)
|
||||
print(f" Downloaded {len(v4)} IPv4 ranges")
|
||||
except Exception as e:
|
||||
print(f" [WARN] Failed to download IPv4: {e}")
|
||||
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
PRESETS["cloudflare"]["v6"],
|
||||
headers={"User-Agent": "xdp-whitelist/1.0"}
|
||||
)
|
||||
with urllib.request.urlopen(req) as r:
|
||||
v6 = r.read().decode().strip().split('\n')
|
||||
cidrs.extend(v6)
|
||||
print(f" Downloaded {len(v6)} IPv6 ranges")
|
||||
except Exception as e:
|
||||
print(f" [WARN] Failed to download IPv6: {e}")
|
||||
return cidrs
|
||||
|
||||
def download_aws():
|
||||
@@ -103,7 +116,7 @@ def add_whitelist(name, cidrs=None):
|
||||
|
||||
if cidrs is None and wl_file.exists():
|
||||
with open(wl_file) as f:
|
||||
cidrs = [line.strip() for line in f if line.strip() and ':' not in line]
|
||||
cidrs = [line.strip() for line in f if line.strip()]
|
||||
if cidrs:
|
||||
print(f"[INFO] Using cached {name} ({len(cidrs)} CIDRs)")
|
||||
|
||||
@@ -200,9 +213,15 @@ def list_whitelist():
|
||||
|
||||
for wl_file in sorted(files):
|
||||
name = wl_file.stem
|
||||
count = sum(1 for line in open(wl_file) if line.strip() and ':' not in line)
|
||||
with open(wl_file) as f:
|
||||
cidrs = [line.strip() for line in f if line.strip()]
|
||||
v4_count = sum(1 for c in cidrs if ':' not in c)
|
||||
v6_count = len(cidrs) - v4_count
|
||||
desc = PRESETS.get(name, {}).get("desc", "Custom")
|
||||
print(f" {name}: {count} CIDRs ({desc})")
|
||||
if v6_count > 0:
|
||||
print(f" {name}: {v4_count} v4 + {v6_count} v6 CIDRs ({desc})")
|
||||
else:
|
||||
print(f" {name}: {v4_count} CIDRs ({desc})")
|
||||
|
||||
def show_presets():
|
||||
"""Show available presets"""
|
||||
|
||||
Reference in New Issue
Block a user