Files
xdp-defense/lib/xdp_defense_daemon.py
kaffa 667c6eac81 Fix 12 code review issues (4 MEDIUM + 8 LOW)
MEDIUM:
- M1: Whitelist direct IP/CIDR additions now persist to direct.txt
- M2: get_map_id() uses 5s TTL cache (single bpftool call for all maps)
- M3: IPv6 extension header parsing in xdp_ddos.c (hop-by-hop/routing/frag/dst)
- M4: Shell injection prevention - sanitize_input() + sys.argv[] for all Python calls

LOW:
- L1: Remove redundant self.running (uses _stop_event only)
- L2: Remove unused config values (rate_limit_after, cooldown_multiplier, retrain_interval)
- L3: Thread poll intervals reloaded on SIGHUP
- L4: batch_map_operation counts only successfully written entries
- L5: Clarify unique_ips_approx comment (per-packet counter)
- L6: Document LRU_HASH multi-CPU race condition as acceptable
- L7: Download Cloudflare IPv6 ranges in whitelist preset
- L8: Fix file handle leak in xdp_country.py list_countries()

Also: SIGHUP now preserves EWMA/violation state, daemon skips whitelisted
IPs in EWMA/AI escalation, deep copy for default config, IHL validation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 09:23:41 +09:00

731 lines
27 KiB
Python
Executable File

#!/usr/bin/env python3
"""
XDP Defense Daemon
Userspace component: EWMA-based rate analysis, AI anomaly detection,
time-profile switching, and automatic escalation.
4 worker threads + main thread (signal handling):
- EWMA Thread: polls rate counters, calculates EWMA, detects violations
- AI Thread: reads traffic features, runs Isolation Forest inference
- Profile Thread: checks time-of-day, switches rate_config profiles
- Cleanup Thread: removes expired entries from blocked_ips maps
"""
import copy
import os
import sys
import time
import signal
import threading
import logging
import logging.handlers
import csv
import pickle
from collections import defaultdict
from datetime import datetime
import yaml
# ==================== Logging ====================
# Module-wide logger: always echoes to the console and, when the local
# syslog socket is reachable, to syslog as well.
log = logging.getLogger('xdp-defense-daemon')

_console = logging.StreamHandler()
_console.setFormatter(
    logging.Formatter(fmt='%(asctime)s [%(levelname)s] %(message)s')
)
log.addHandler(_console)
log.setLevel(logging.INFO)

try:
    _syslog = logging.handlers.SysLogHandler(address='/dev/log')
    _syslog.setFormatter(
        logging.Formatter(fmt='xdp-defense-daemon: %(message)s')
    )
    log.addHandler(_syslog)
except Exception:
    # Syslog is best-effort: if the handler cannot be set up the daemon
    # keeps running with console logging only.
    pass
# ==================== Configuration ====================
# Built-in defaults, grouped by config section. Every value here can be
# overridden by the same key in the matching section of the YAML file.
DEFAULT_CONFIG = dict(
    general=dict(
        interface='eth0',
        log_level='info',
        pid_file='/var/lib/xdp-defense/daemon.pid',
        data_dir='/var/lib/xdp-defense',
    ),
    rate_limits=dict(
        default_pps=1000,
        default_bps=0,
        window_sec=1,
        profiles={},  # name -> {'hours', 'weekdays', 'pps', 'bps'}
    ),
    escalation=dict(
        temp_block_after=5,       # violations inside the window
        perm_block_after=20,      # violations inside the window
        temp_block_duration=300,
        violation_window=600,     # seconds
    ),
    ewma=dict(
        alpha=0.3,
        poll_interval=1,          # seconds between counter polls
        threshold_multiplier=3.0,
    ),
    ai=dict(
        enabled=True,
        model_type='IsolationForest',
        contamination=0.05,
        n_estimators=100,
        learning_duration=3 * 24 * 60 * 60,  # seconds (three days)
        min_samples=1000,
        poll_interval=5,          # seconds between feature polls
        anomaly_threshold=-0.3,
        min_packets_for_sample=20,
        model_file='/var/lib/xdp-defense/ai_model.pkl',
        training_data_file='/var/lib/xdp-defense/training_data.csv',
    ),
)

# Config file used when no path is given explicitly.
CONFIG_PATH = '/etc/xdp-defense/config.yaml'
def load_config(path=CONFIG_PATH):
    """Return the effective configuration as a dict of section dicts.

    Starts from a deep copy of DEFAULT_CONFIG and overlays the YAML file at
    *path* one section at a time: only sections that exist in the defaults
    and whose YAML value is a mapping are merged, key by key. A missing
    file logs a warning; any other failure logs an error. Nothing is
    raised to the caller.
    """
    merged = copy.deepcopy(DEFAULT_CONFIG)
    try:
        with open(path) as handle:
            overrides = yaml.safe_load(handle) or {}
        for name, defaults in merged.items():
            if name not in overrides:
                continue
            supplied = overrides[name]
            if isinstance(supplied, dict):
                defaults.update(supplied)
    except FileNotFoundError:
        log.warning("Config not found at %s, using defaults", path)
    except Exception as e:
        log.error("Failed to load config: %s", e)
    return merged
# ==================== ViolationTracker ====================
class ViolationTracker:
    """Sliding-window violation counter that maps counts to escalation levels."""

    def __init__(self, escalation_cfg):
        """Keep a reference to the escalation config and start empty."""
        self.cfg = escalation_cfg
        self.violations = defaultdict(list)  # ip -> list of time.time() stamps
        self.lock = threading.Lock()

    def record_violation(self, ip):
        """Register one violation for *ip* and report the escalation level.

        Timestamps older than 'violation_window' seconds are dropped first,
        then the current time is appended and the remaining count is
        compared against the configured thresholds.

        Returns: 'perm_block', 'temp_block' or 'rate_limit'.
        """
        now = time.time()
        window = self.cfg.get('violation_window', 600)
        with self.lock:
            recent = [stamp for stamp in self.violations[ip] if now - stamp < window]
            recent.append(now)
            self.violations[ip] = recent
            count = len(recent)

        # The permanent threshold is checked first so it wins over the
        # temporary one when both are reached.
        if count >= self.cfg.get('perm_block_after', 20):
            return 'perm_block'
        if count >= self.cfg.get('temp_block_after', 5):
            return 'temp_block'
        return 'rate_limit'

    def clear(self, ip):
        """Forget every recorded violation for *ip*."""
        with self.lock:
            self.violations.pop(ip, None)

    def cleanup_expired(self):
        """Drop IPs whose recorded violations all lie outside the window."""
        now = time.time()
        window = self.cfg.get('violation_window', 600)
        with self.lock:
            stale = [
                ip for ip, stamps in self.violations.items()
                if not any(now - stamp < window for stamp in stamps)
            ]
            for ip in stale:
                self.violations.pop(ip)
# ==================== EWMAAnalyzer ====================
class EWMAAnalyzer:
    """Per-IP exponentially weighted moving averages for rate anomaly detection.

    Two averages are kept per IP: a responsive one weighted by ``alpha`` and
    a slow baseline with a fixed weight of 0.01.
    """

    def __init__(self, alpha=0.3, threshold_multiplier=3.0):
        self.alpha = alpha
        self.threshold_multiplier = threshold_multiplier
        self.ewma = {}      # ip -> responsive average
        self.baseline = {}  # ip -> slow average
        self.lock = threading.Lock()

    def update(self, ip, current_pps):
        """Fold *current_pps* into the averages for *ip*.

        The first sample for an IP only seeds both averages and is never
        anomalous. Later samples are anomalous when the responsive average
        exceeds ``threshold_multiplier`` times the baseline, with the
        baseline floored at 1.

        Returns True if anomalous, otherwise False.
        """
        with self.lock:
            if ip in self.ewma:
                smoothed = self.alpha * current_pps + (1 - self.alpha) * self.ewma[ip]
                slow = 0.01 * current_pps + 0.99 * self.baseline[ip]
                self.ewma[ip] = smoothed
                self.baseline[ip] = slow
                exceeded = smoothed > max(slow, 1) * self.threshold_multiplier
                return exceeded
            self.ewma[ip] = current_pps
            self.baseline[ip] = current_pps
            return False

    def get_stats(self, ip):
        """Return {'ewma': ..., 'baseline': ...} for *ip*; zeros if unknown."""
        with self.lock:
            responsive = self.ewma.get(ip, 0)
            slow = self.baseline.get(ip, 0)
            return {'ewma': responsive, 'baseline': slow}

    def cleanup_stale(self, active_ips):
        """Discard state for every tracked IP that is not in *active_ips*."""
        with self.lock:
            keep = set(active_ips)
            for ip in [tracked for tracked in self.ewma if tracked not in keep]:
                self.ewma.pop(ip, None)
                self.baseline.pop(ip, None)
# ==================== AIDetector ====================
class AIDetector:
    """Isolation Forest based anomaly detection on traffic features.

    The detector starts in a learning phase (``is_learning`` is True) in
    which feature vectors are accumulated via collect_sample(). Once a model
    has been trained or loaded from disk, ``is_learning`` becomes False and
    predict() can score new vectors.
    """

    def __init__(self, ai_cfg):
        """Keep a reference to the 'ai' config section and start learning."""
        self.cfg = ai_cfg
        self.model = None
        self.scaler = None
        self.started_at = time.time()
        self.training_data = []
        self.is_learning = True
        self._retrain_requested = False

    @property
    def enabled(self):
        """Whether AI detection is switched on in the current config."""
        return self.cfg.get('enabled', False)

    def request_retrain(self):
        """Ask for a training run on the next collect_sample() call.

        NOTE(review): the flag is consumed only inside collect_sample(). A
        caller that stops feeding samples once ``is_learning`` is False will
        never act on a retrain request - confirm that is intended.
        """
        self._retrain_requested = True

    def collect_sample(self, features):
        """Store one feature vector and train when the conditions are met.

        Training runs when both the learning duration has elapsed and at
        least 'min_samples' vectors are stored, or when a retrain has been
        requested. Does nothing while the detector is disabled.
        """
        if not self.enabled:
            return
        self.training_data.append(features)
        learning_dur = self.cfg.get('learning_duration', 259200)
        min_samples = self.cfg.get('min_samples', 1000)
        elapsed = time.time() - self.started_at
        learned_enough = (
            elapsed >= learning_dur and len(self.training_data) >= min_samples
        )
        if learned_enough or self._retrain_requested:
            self._train()
            self._retrain_requested = False

    def _train(self):
        """Fit a scaler and an Isolation Forest on the collected samples.

        The new scaler/model pair is built in local variables and installed
        only after fitting succeeded, so a failed run never leaves an
        unfitted or mismatched pair behind. Saving to disk is a separate
        step whose failure does not discard the freshly trained model.
        """
        try:
            from sklearn.ensemble import IsolationForest
            from sklearn.preprocessing import StandardScaler
            import numpy as np
        except ImportError:
            log.error("scikit-learn not installed. AI detection disabled.")
            self.cfg['enabled'] = False
            return
        if len(self.training_data) < 10:
            log.warning("Not enough training data (%d samples)", len(self.training_data))
            return
        log.info("Training AI model with %d samples...", len(self.training_data))
        try:
            X = np.array(self.training_data)
            scaler = StandardScaler()
            X_scaled = scaler.fit_transform(X)
            model = IsolationForest(
                n_estimators=self.cfg.get('n_estimators', 100),
                contamination=self.cfg.get('contamination', 'auto'),
                random_state=42,
            )
            model.fit(X_scaled)
        except Exception as e:
            log.error("AI training failed: %s", e)
            return

        # Install the scaler first: predict() treats a non-None model as
        # "ready" and immediately uses the scaler.
        self.scaler = scaler
        self.model = model
        self.is_learning = False
        self._save()

    def _save(self):
        """Write the model pickle and the training-data CSV to disk."""
        model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
        try:
            with open(model_file, 'wb') as f:
                pickle.dump({'model': self.model, 'scaler': self.scaler}, f)
            with open(data_file, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([
                    'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
                    'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
                    'small_pkt_count', 'large_pkt_count',
                    'syn_ratio', 'udp_ratio', 'icmp_ratio', 'small_pkt_ratio', 'avg_pkt_size'
                ])
                writer.writerows(self.training_data)
        except Exception as e:
            # The in-memory model stays active; only persistence failed.
            log.error("AI model trained but could not be saved: %s", e)
            return
        log.info("AI model trained and saved to %s", model_file)

    def load_model(self):
        """Load a previously trained model.

        Returns True when both the model and the scaler were read from the
        configured file, False when the file is missing or unusable. On
        failure the detector's current state is left untouched.
        """
        model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        if not os.path.exists(model_file):
            return False
        try:
            with open(model_file, 'rb') as f:
                # SECURITY: unpickling can execute arbitrary code. The model
                # file must only be writable by trusted users.
                data = pickle.load(f)
            model = data['model']
            scaler = data['scaler']
        except Exception as e:
            log.error("Failed to load AI model: %s", e)
            return False
        # Assign only after both parts were extracted successfully.
        self.scaler = scaler
        self.model = model
        self.is_learning = False
        log.info("AI model loaded from %s", model_file)
        return True

    def predict(self, features):
        """Run anomaly detection. Returns (is_anomaly, score).

        A vector is anomalous when the model's decision_function score is
        below 'anomaly_threshold'. Returns (False, 0.0) when the detector is
        disabled, has no model yet, or scoring raises.
        """
        if not self.enabled or self.model is None:
            return False, 0.0
        try:
            import numpy as np
            X = np.array([features])
            X_scaled = self.scaler.transform(X)
            score = self.model.decision_function(X_scaled)[0]
            threshold = self.cfg.get('anomaly_threshold', -0.3)
            return bool(score < threshold), float(score)
        except Exception as e:
            log.error("AI prediction error: %s", e)
            return False, 0.0
# ==================== ProfileManager ====================
class ProfileManager:
    """Manage time-based rate limit profiles.

    A profile is a mapping with optional 'hours' ("HH:MM-HH:MM") and
    'weekdays' ("mon-fri" or a single day such as "sat") filters plus
    optional 'pps' / 'bps' overrides. The first profile whose filters match
    the current local time wins; when none match, the defaults apply.
    """

    # Index in this tuple equals datetime.weekday() (Monday is 0).
    _DAY_NAMES = ('mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun')

    def __init__(self, rate_cfg):
        """Keep a reference to the 'rate_limits' config section."""
        self.cfg = rate_cfg
        self.current_profile = 'default'

    @classmethod
    def _weekday_matches(cls, spec, weekday_idx):
        """Return True if *weekday_idx* (Monday=0) falls inside *spec*.

        *spec* is either one day name or an inclusive "start-end" range that
        may wrap around the week ("fri-mon"). Anything that cannot be parsed
        yields False so a malformed filter never matches.
        """
        try:
            parts = [part.strip() for part in spec.lower().split('-')]
        except AttributeError:
            # Not a string (e.g. a YAML list).
            return False
        if len(parts) == 1:
            parts = parts * 2  # single day == range of one
        if len(parts) != 2:
            return False
        try:
            start_idx = cls._DAY_NAMES.index(parts[0])
            end_idx = cls._DAY_NAMES.index(parts[1])
        except ValueError:
            return False
        if start_idx <= end_idx:
            return start_idx <= weekday_idx <= end_idx
        return weekday_idx >= start_idx or weekday_idx <= end_idx

    @staticmethod
    def _hours_match(spec, minute_of_day):
        """Return True if *minute_of_day* lies in the "HH:MM-HH:MM" range.

        The start is inclusive and the end exclusive; a start later than
        the end denotes a range that wraps past midnight. Unparseable specs
        yield False.
        """
        try:
            start_str, end_str = spec.split('-')
            sh, sm = map(int, start_str.split(':'))
            eh, em = map(int, end_str.split(':'))
        except (ValueError, AttributeError):
            return False
        start_min = sh * 60 + sm
        end_min = eh * 60 + em
        if start_min <= end_min:
            return start_min <= minute_of_day < end_min
        return minute_of_day >= start_min or minute_of_day < end_min

    @classmethod
    def _is_active(cls, profile, now):
        """Return True if every filter present in *profile* matches *now*."""
        weekdays = profile.get('weekdays', '')
        # datetime.weekday() is used instead of strftime('%a') so the
        # result does not depend on the process locale.
        if weekdays and not cls._weekday_matches(weekdays, now.weekday()):
            return False
        hours = profile.get('hours', '')
        if hours and not cls._hours_match(hours, now.hour * 60 + now.minute):
            return False
        return True

    def check_and_apply(self):
        """Check current time and apply matching profile.

        Writes the rate configuration only when the matching profile name
        differs from the one currently applied. A failed write is logged and
        retried on the next call because ``current_profile`` is not updated.
        """
        from xdp_common import write_rate_config

        # "profiles:" with no value loads as None from YAML.
        profiles = self.cfg.get('profiles') or {}
        now = datetime.now()

        matched_profile = None
        matched_name = 'default'
        for name, profile in profiles.items():
            if self._is_active(profile, now):
                matched_profile = profile
                matched_name = name
                break

        if matched_name == self.current_profile:
            return

        default_pps = self.cfg.get('default_pps', 1000)
        default_bps = self.cfg.get('default_bps', 0)
        if matched_profile:
            pps = matched_profile.get('pps', default_pps)
            bps = matched_profile.get('bps', default_bps)
        else:
            pps = default_pps
            bps = default_bps
        window = self.cfg.get('window_sec', 1)
        try:
            # The window is handed over in nanoseconds.
            write_rate_config(pps, bps, window * 1_000_000_000)
            log.info("Profile switched: %s -> %s (pps=%d)", self.current_profile, matched_name, pps)
            self.current_profile = matched_name
        except Exception as e:
            log.error("Failed to apply profile %s: %s", matched_name, e)
# ==================== DDoSDaemon ====================
class DDoSDaemon:
    """Main daemon orchestrator.

    Loads the configuration, builds the analysis components and runs four
    worker threads (EWMA, AI, profile, cleanup) until SIGTERM or SIGINT sets
    the stop event. SIGHUP reloads the config in place and SIGUSR1 requests
    an AI retrain.
    """

    # Raw counters read from the per-CPU feature map, in model input order.
    _AI_FEATURE_NAMES = (
        'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
        'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
        'small_pkt_count', 'large_pkt_count',
    )

    def __init__(self, config_path=CONFIG_PATH):
        """Load the config at *config_path* and build all components."""
        self.config_path = config_path
        self.cfg = load_config(config_path)
        self._stop_event = threading.Event()
        # Kept as attributes so the worker loops see SIGHUP changes on
        # their next iteration.
        self._ewma_interval = self.cfg['ewma'].get('poll_interval', 1)
        self._ai_interval = self.cfg['ai'].get('poll_interval', 5)
        self._setup_components()

    def _setup_components(self):
        """Create tracker/analyzer/detector/profile objects from the config."""
        self.violation_tracker = ViolationTracker(self.cfg['escalation'])
        self.ewma_analyzer = EWMAAnalyzer(
            alpha=self.cfg['ewma'].get('alpha', 0.3),
            threshold_multiplier=self.cfg['ewma'].get('threshold_multiplier', 3.0),
        )
        self.ai_detector = AIDetector(self.cfg['ai'])
        self.profile_manager = ProfileManager(self.cfg['rate_limits'])
        if self.ai_detector.enabled:
            self.ai_detector.load_model()
        level = self.cfg['general'].get('log_level', 'info').upper()
        log.setLevel(getattr(logging, level, logging.INFO))

    def _write_pid(self):
        """Write this process's PID to the configured PID file."""
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        os.makedirs(os.path.dirname(pid_file), exist_ok=True)
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

    def _remove_pid(self):
        """Delete the PID file, ignoring the case where it is already gone."""
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        try:
            os.unlink(pid_file)
        except OSError:
            pass

    def _ensure_single_instance(self):
        """Stop any existing daemon before starting.

        Sends SIGTERM to the PID recorded in the PID file, waits up to 30
        seconds for it to exit and falls back to SIGKILL.

        NOTE(review): the recorded PID is not verified to belong to an
        xdp-defense daemon; a reused PID of an unrelated process would be
        signalled as well.
        """
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        if not os.path.exists(pid_file):
            return
        try:
            with open(pid_file) as f:
                old_pid = int(f.read().strip())
            # Never signal ourselves (stale file + PID reuse, typical in
            # containers), and never pass 0 or a negative value to
            # os.kill(): those address whole process groups.
            if old_pid <= 0 or old_pid == os.getpid():
                return
            os.kill(old_pid, 0)  # signal 0 only probes for existence
            log.info("Stopping existing daemon (PID %d)...", old_pid)
            os.kill(old_pid, signal.SIGTERM)
            for _ in range(30):
                time.sleep(1)
                try:
                    os.kill(old_pid, 0)
                except OSError:
                    log.info("Old daemon stopped")
                    return
            log.warning("Daemon PID %d did not stop, sending SIGKILL", old_pid)
            os.kill(old_pid, signal.SIGKILL)
            time.sleep(1)
        except (ValueError, OSError):
            # Unreadable PID file, or the process is gone / not signalable.
            pass

    def _handle_sighup(self, signum, frame):
        """Reload the config file and push new values into live components."""
        log.info("SIGHUP received, reloading config...")
        self.cfg = load_config(self.config_path)
        # Update existing components without rebuilding (preserves EWMA/violation state)
        self.violation_tracker.cfg = self.cfg['escalation']
        self.ewma_analyzer.alpha = self.cfg['ewma'].get('alpha', 0.3)
        self.ewma_analyzer.threshold_multiplier = self.cfg['ewma'].get('threshold_multiplier', 3.0)
        self.ai_detector.cfg = self.cfg['ai']
        self.profile_manager.cfg = self.cfg['rate_limits']
        # Update poll intervals (used by threads on next iteration)
        self._ewma_interval = self.cfg['ewma'].get('poll_interval', 1)
        self._ai_interval = self.cfg['ai'].get('poll_interval', 5)
        level = self.cfg['general'].get('log_level', 'info').upper()
        log.setLevel(getattr(logging, level, logging.INFO))
        log.info("Config reloaded (state preserved)")

    def _handle_sigterm(self, signum, frame):
        """Request shutdown; also installed for SIGINT."""
        log.info("SIGTERM received, shutting down...")
        self._stop_event.set()

    def _handle_sigusr1(self, signum, frame):
        """Flag the AI detector for retraining."""
        log.info("SIGUSR1 received, requesting AI retrain...")
        self.ai_detector.request_retrain()

    # ---- Escalation helper ----

    def _apply_block(self, ip_str, level, label=''):
        """Install the block that corresponds to *level* for *ip_str*.

        'temp_block' blocks for 'temp_block_duration' seconds, 'perm_block'
        blocks with a duration of 0; any other level does nothing. *label*
        is inserted into the log messages ('' for EWMA, 'AI ' for AI).
        Failures are logged, never raised.
        """
        from xdp_common import block_ip

        if level == 'temp_block':
            dur = self.cfg['escalation'].get('temp_block_duration', 300)
            try:
                block_ip(ip_str, dur)
                log.warning("%sTEMP BLOCK: %s for %ds", label, ip_str, dur)
            except Exception as e:
                log.error("Failed to %stemp-block %s: %s", label, ip_str, e)
        elif level == 'perm_block':
            try:
                block_ip(ip_str, 0)
                log.warning("%sPERM BLOCK: %s", label, ip_str)
            except Exception as e:
                log.error("Failed to %sperm-block %s: %s", label, ip_str, e)

    # ---- Worker Threads ----

    def _ewma_thread(self):
        """Poll rate counters, compute EWMA, detect violations, escalate."""
        from xdp_common import dump_rate_counters, is_whitelisted

        prev_counters = {}  # ip -> packet counter seen on the previous poll
        while not self._stop_event.is_set():
            interval = self._ewma_interval
            try:
                entries = dump_rate_counters('rate_counter_v4', top_n=1000)
                active_ips = set()
                for ip_str, pkts, bts, last_seen in entries:
                    active_ips.add(ip_str)
                    prev = prev_counters.get(ip_str, 0)
                    # A counter lower than last time is treated as having
                    # restarted, so its full value is the delta.
                    delta = pkts - prev if pkts >= prev else pkts
                    prev_counters[ip_str] = pkts
                    if delta <= 0:
                        continue
                    pps = delta / max(interval, 0.1)
                    if not self.ewma_analyzer.update(ip_str, pps):
                        continue
                    # Skip whitelisted IPs
                    if is_whitelisted(ip_str):
                        log.debug("EWMA anomaly skipped (whitelisted): %s", ip_str)
                        continue
                    level = self.violation_tracker.record_violation(ip_str)
                    ew = self.ewma_analyzer.get_stats(ip_str)
                    log.warning(
                        "EWMA anomaly: %s pps=%.1f ewma=%.1f baseline=%.1f -> %s",
                        ip_str, pps, ew['ewma'], ew['baseline'], level
                    )
                    self._apply_block(ip_str, level)
                # Forget IPs that left the counter dump; without this the
                # dict grows by one entry per source address ever seen.
                for stale_ip in prev_counters.keys() - active_ips:
                    del prev_counters[stale_ip]
                self.ewma_analyzer.cleanup_stale(active_ips)
            except Exception as e:
                log.error("EWMA thread error: %s", e)
            self._stop_event.wait(interval)

    def _analyze_feature_window(self, features, prev_features):
        """Turn two consecutive feature snapshots into one AI sample.

        Computes per-counter deltas (clamped at 0), skips windows with fewer
        than 'min_packets_for_sample' packets, appends five derived ratios
        and then either feeds the sample to the learner or scores it. On an
        anomaly the top five entries of the IPv4 rate counters are escalated
        unless whitelisted.

        NOTE(review): collect_sample() is called only while the detector is
        learning, so a retrain requested after a model exists is never
        triggered from here - confirm that is intended.
        """
        from xdp_common import dump_rate_counters, is_whitelisted

        names = self._AI_FEATURE_NAMES
        deltas = [
            max(0, features.get(name, 0) - prev_features.get(name, 0))
            for name in names
        ]
        min_pkts = self.cfg['ai'].get('min_packets_for_sample', 20)
        if deltas[0] < min_pkts:
            return

        total = deltas[0] + 1e-6  # epsilon keeps the divisions safe
        deltas.extend([
            deltas[2] / total,  # syn_ratio
            deltas[4] / total,  # udp_ratio
            deltas[5] / total,  # icmp_ratio
            deltas[8] / total,  # small_pkt_ratio
            deltas[1] / total,  # avg_pkt_size
        ])

        if self.ai_detector.is_learning:
            self.ai_detector.collect_sample(deltas)
            if len(self.ai_detector.training_data) % 100 == 0:
                log.debug("AI learning: %d samples collected",
                          len(self.ai_detector.training_data))
            return

        is_anomaly, score = self.ai_detector.predict(deltas)
        if not is_anomaly:
            return
        log.warning(
            "AI ANOMALY detected: score=%.4f deltas=%s",
            score, dict(zip(names, deltas))
        )
        top_ips = dump_rate_counters('rate_counter_v4', top_n=5)
        for ip_str, pkts, bts, _ in top_ips:
            # Skip whitelisted IPs
            if is_whitelisted(ip_str):
                log.debug("AI escalation skipped (whitelisted): %s", ip_str)
                continue
            level = self.violation_tracker.record_violation(ip_str)
            log.warning("AI escalation: %s -> %s", ip_str, level)
            self._apply_block(ip_str, level, label='AI ')

    def _ai_thread(self):
        """Read traffic features, run AI inference or collect training data."""
        from xdp_common import read_percpu_features

        prev_features = None
        while not self._stop_event.is_set():
            interval = self._ai_interval
            try:
                if self.ai_detector.enabled:
                    features = read_percpu_features()
                    if features:
                        if prev_features is not None:
                            self._analyze_feature_window(features, prev_features)
                        # Reached only when analysis did not raise, so a
                        # failed window is retried against the old snapshot.
                        prev_features = features
            except Exception as e:
                log.error("AI thread error: %s", e)
            self._stop_event.wait(interval)

    def _profile_thread(self):
        """Check time-of-day and switch rate profiles."""
        while not self._stop_event.is_set():
            try:
                self.profile_manager.check_and_apply()
            except Exception as e:
                log.error("Profile thread error: %s", e)
            self._stop_event.wait(60)

    def _cleanup_thread(self):
        """Periodically clean up expired blocked IPs and stale violations."""
        from xdp_common import dump_blocked_ips, unblock_ip

        while not self._stop_event.is_set():
            try:
                # Current time as nanoseconds since boot.
                # NOTE(review): assumes expire_ns uses the same clock - confirm.
                with open('/proc/uptime') as f:
                    now_ns = int(float(f.read().split()[0]) * 1_000_000_000)
                for map_name in ['blocked_ips_v4', 'blocked_ips_v6']:
                    entries = dump_blocked_ips(map_name)
                    for ip_str, expire_ns, blocked_at, drop_count in entries:
                        # expire_ns == 0 entries are never expired here.
                        if expire_ns != 0 and now_ns > expire_ns:
                            try:
                                unblock_ip(ip_str)
                                self.violation_tracker.clear(ip_str)
                                log.info("Expired block removed: %s (dropped %d pkts)", ip_str, drop_count)
                            except Exception as e:
                                log.error("Failed to remove expired block %s: %s", ip_str, e)
                self.violation_tracker.cleanup_expired()
            except Exception as e:
                log.error("Cleanup thread error: %s", e)
            self._stop_event.wait(60)

    # ---- Main Loop ----

    def run(self):
        """Start the daemon.

        Installs signal handlers, replaces any running instance, writes the
        PID file, starts the worker threads and blocks until the stop event
        is set. Workers are then joined (5 s each) and the PID file removed.
        """
        log.info("XDP Defense Daemon starting...")
        signal.signal(signal.SIGHUP, self._handle_sighup)
        signal.signal(signal.SIGTERM, self._handle_sigterm)
        signal.signal(signal.SIGINT, self._handle_sigterm)
        signal.signal(signal.SIGUSR1, self._handle_sigusr1)
        self._ensure_single_instance()
        self._write_pid()
        threads = [
            threading.Thread(target=self._ewma_thread, name='ewma', daemon=True),
            threading.Thread(target=self._ai_thread, name='ai', daemon=True),
            threading.Thread(target=self._profile_thread, name='profile', daemon=True),
            threading.Thread(target=self._cleanup_thread, name='cleanup', daemon=True),
        ]
        for t in threads:
            t.start()
            log.info("Started %s thread", t.name)
        log.info("Daemon running (PID %d)", os.getpid())
        try:
            while not self._stop_event.is_set():
                self._stop_event.wait(1)
        except KeyboardInterrupt:
            pass
        log.info("Shutting down...")
        self._stop_event.set()
        for t in threads:
            t.join(timeout=5)
        self._remove_pid()
        log.info("Daemon stopped")
# ==================== Entry Point ====================
def main():
    """Run the daemon, using the first CLI argument as config path if given."""
    cli_args = sys.argv[1:]
    config_path = cli_args[0] if cli_args else CONFIG_PATH
    DDoSDaemon(config_path).run()


if __name__ == '__main__':
    main()