Chain two XDP programs via libxdp dispatcher on the same interface: xdp_blocker (priority 10) handles CIDR/country/whitelist blocking, xdp_ddos (priority 20) handles rate limiting, EWMA analysis, and AI anomaly detection. Whitelist maps are shared via BPF map pinning so whitelisted IPs bypass both blocklist checks and DDoS rate limiting. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
715 lines
25 KiB
Python
Executable File
715 lines
25 KiB
Python
Executable File
#!/usr/bin/env python3
"""
XDP Defense Daemon
Userspace component: EWMA-based rate analysis, AI anomaly detection,
time-profile switching, and automatic escalation.

4 worker threads + main thread (signal handling):
- EWMA Thread: polls rate counters, calculates EWMA, detects violations
- AI Thread: reads traffic features, runs Isolation Forest inference
- Profile Thread: checks time-of-day, switches rate_config profiles
- Cleanup Thread: removes expired entries from blocked_ips maps
"""
|
|
|
|
import copy
import csv
import logging
import logging.handlers
import os
import pickle
import signal
import sys
import threading
import time
from collections import defaultdict
from datetime import datetime

import yaml
|
|
|
|
# ==================== Logging ====================
|
|
|
|
# Module-wide logger shared by every component of the daemon.
log = logging.getLogger('xdp-defense-daemon')
log.setLevel(logging.INFO)

# Console output (StreamHandler defaults to stderr) with timestamps.
_console = logging.StreamHandler()
_console.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(message)s'))
log.addHandler(_console)

# Syslog is deliberately best-effort: when /dev/log cannot be opened the
# daemon keeps running with console logging only.
try:
    _syslog = logging.handlers.SysLogHandler(address='/dev/log')
    _syslog.setFormatter(logging.Formatter('xdp-defense-daemon: %(message)s'))
    log.addHandler(_syslog)
except Exception as _syslog_error:
    # Only visible once the level is lowered to DEBUG; never fatal.
    log.debug("Syslog handler unavailable: %s", _syslog_error)
|
|
|
|
# ==================== Configuration ====================
|
|
|
|
# Built-in defaults. load_config() overlays the matching sections of the
# user's YAML file on top of these, one section at a time.
DEFAULT_CONFIG = {
    'general': {
        'interface': 'eth0',
        'log_level': 'info',
        'pid_file': '/var/lib/xdp-defense/daemon.pid',
        'data_dir': '/var/lib/xdp-defense',
    },
    'rate_limits': {
        'default_pps': 1000,
        'default_bps': 0,
        # Multiplied by 1e9 before being handed to write_rate_config().
        'window_sec': 1,
        # name -> {'hours': 'HH:MM-HH:MM', 'weekdays': 'mon-fri', 'pps': ..., 'bps': ...}
        'profiles': {},
    },
    'escalation': {
        'rate_limit_after': 1,
        # Violation counts (within violation_window) that trigger each level.
        'temp_block_after': 5,
        'perm_block_after': 20,
        # Seconds; passed to block_ip() for temporary blocks.
        'temp_block_duration': 300,
        # Seconds; violations older than this no longer count.
        'violation_window': 600,
        'cooldown_multiplier': 0.5,
    },
    'ewma': {
        'alpha': 0.3,
        # Seconds between polls of the rate counters.
        'poll_interval': 1,
        # An IP is anomalous when its EWMA exceeds baseline * this factor.
        'threshold_multiplier': 3.0,
    },
    'ai': {
        'enabled': True,
        'model_type': 'IsolationForest',
        'contamination': 0.05,
        'n_estimators': 100,
        # Seconds of sample collection before the first training (3 days).
        'learning_duration': 259200,
        'min_samples': 1000,
        # Seconds between polls of the traffic feature counters.
        'poll_interval': 5,
        # decision_function scores below this value are flagged.
        'anomaly_threshold': -0.3,
        'retrain_interval': 604800,
        # Polls with fewer packets than this are not used as samples.
        'min_packets_for_sample': 20,
        'model_file': '/var/lib/xdp-defense/ai_model.pkl',
        'training_data_file': '/var/lib/xdp-defense/training_data.csv',
    },
}

# Default location of the user configuration file.
CONFIG_PATH = '/etc/xdp-defense/config.yaml'
|
|
|
|
|
|
def load_config(path=CONFIG_PATH):
    """Load the YAML config at *path* layered over DEFAULT_CONFIG.

    Only sections that exist in DEFAULT_CONFIG are considered, and a user
    section is merged only when it is a mapping; its keys override the
    defaults of that section.

    A missing file, unreadable file, or invalid YAML is logged and the
    defaults are returned, so this function never raises for bad input.

    Returns:
        A fresh config dict. It shares no state with DEFAULT_CONFIG, so
        overrides from one load cannot leak into a later reload.
    """
    # Deep copy: a shallow .copy() would share the per-section dicts and the
    # update() below would permanently modify the module-level defaults.
    cfg = copy.deepcopy(DEFAULT_CONFIG)

    try:
        with open(path) as f:
            user = yaml.safe_load(f) or {}
    except FileNotFoundError:
        log.warning("Config not found at %s, using defaults", path)
        return cfg
    except Exception as e:
        log.error("Failed to load config: %s", e)
        return cfg

    if not isinstance(user, dict):
        log.error("Failed to load config: top-level YAML value in %s is not a mapping", path)
        return cfg

    for section, defaults in cfg.items():
        overrides = user.get(section)
        if isinstance(overrides, dict):
            defaults.update(overrides)
    return cfg
|
|
|
|
|
|
# ==================== ViolationTracker ====================
|
|
|
|
class ViolationTracker:
    """Track per-IP violation timestamps and derive the escalation level.

    Thresholds are read from the escalation config on every call, so a
    replaced ``cfg`` takes effect immediately. All access to ``violations``
    is serialized by ``lock``.
    """

    def __init__(self, escalation_cfg):
        # Mapping with the keys 'violation_window', 'rate_limit_after',
        # 'temp_block_after' and 'perm_block_after' (all optional).
        self.cfg = escalation_cfg
        # ip -> list of time.time() stamps of violations inside the window
        self.violations = defaultdict(list)
        self.lock = threading.Lock()

    def record_violation(self, ip):
        """Record a violation for *ip* and return the escalation level.

        Violations older than ``violation_window`` seconds are dropped
        before counting.

        Returns:
            'perm_block', 'temp_block' or 'rate_limit' depending on how many
            violations fall inside the window, or None when the count is
            still below ``rate_limit_after`` (default 1, i.e. never None).
        """
        now = time.time()
        window = self.cfg.get('violation_window', 600)

        with self.lock:
            recent = [t for t in self.violations[ip] if now - t < window]
            recent.append(now)
            self.violations[ip] = recent
            count = len(recent)

        # Check the most severe level first.
        if count >= self.cfg.get('perm_block_after', 20):
            return 'perm_block'
        if count >= self.cfg.get('temp_block_after', 5):
            return 'temp_block'
        if count >= self.cfg.get('rate_limit_after', 1):
            return 'rate_limit'
        return None

    def clear(self, ip):
        """Forget every recorded violation for *ip* (no-op if unknown)."""
        with self.lock:
            self.violations.pop(ip, None)

    def cleanup_expired(self):
        """Drop violations older than the window and IPs left with none."""
        now = time.time()
        window = self.cfg.get('violation_window', 600)
        with self.lock:
            # Iterate over a snapshot of the keys because entries are deleted.
            for ip in list(self.violations):
                recent = [t for t in self.violations[ip] if now - t < window]
                if recent:
                    self.violations[ip] = recent
                else:
                    del self.violations[ip]
|
|
|
|
|
|
# ==================== EWMAAnalyzer ====================
|
|
|
|
class EWMAAnalyzer:
    """Per-IP EWMA calculation for rate anomaly detection.

    Two exponentially weighted averages are kept per IP: a fast one
    (``ewma``, weight ``alpha``) and a slow one (``baseline``, fixed weight
    0.01). An IP is anomalous when the fast average exceeds the slow one by
    more than ``threshold_multiplier``.
    """

    def __init__(self, alpha=0.3, threshold_multiplier=3.0):
        self.alpha = alpha
        self.threshold_multiplier = threshold_multiplier
        self.ewma = {}       # ip -> fast moving average of pps
        self.baseline = {}   # ip -> slow moving average of pps
        self.lock = threading.Lock()

    def update(self, ip, current_pps):
        """Feed a new rate sample for *ip*.

        The first sample for an IP seeds both averages and is never
        reported as anomalous.

        Returns:
            True when the updated EWMA is above
            ``max(baseline, 1) * threshold_multiplier``, else False.
        """
        with self.lock:
            if ip not in self.ewma:
                self.ewma[ip] = current_pps
                self.baseline[ip] = current_pps
                return False

            self.ewma[ip] = self.alpha * current_pps + (1 - self.alpha) * self.ewma[ip]
            self.baseline[ip] = 0.01 * current_pps + 0.99 * self.baseline[ip]

            # Floor the baseline at 1 so near-idle IPs need a real burst
            # before the ratio test can fire.
            floor = max(self.baseline[ip], 1)
            return self.ewma[ip] > floor * self.threshold_multiplier

    def get_stats(self, ip):
        """Return ``{'ewma': ..., 'baseline': ...}`` for *ip* (0 if unknown)."""
        with self.lock:
            return {
                'ewma': self.ewma.get(ip, 0),
                'baseline': self.baseline.get(ip, 0),
            }

    def cleanup_stale(self, active_ips):
        """Drop state for every tracked IP that is not in *active_ips*."""
        with self.lock:
            stale = set(self.ewma) - set(active_ips)
            for ip in stale:
                self.ewma.pop(ip, None)
                self.baseline.pop(ip, None)
|
|
|
|
|
|
# ==================== AIDetector ====================
|
|
|
|
class AIDetector:
    """Isolation Forest based anomaly detection on traffic features.

    The detector starts in learning mode (``is_learning``) and buffers the
    samples passed to collect_sample(). Once the learning criteria are met,
    or a retrain is requested, it fits a StandardScaler + IsolationForest
    and switches to inference via predict().
    """

    # Header row written to the training-data CSV (one name per feature).
    _CSV_HEADER = (
        'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
        'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
        'small_pkt_count', 'large_pkt_count',
        'syn_ratio', 'udp_ratio', 'icmp_ratio', 'small_pkt_ratio', 'avg_pkt_size',
    )

    def __init__(self, ai_cfg):
        self.cfg = ai_cfg
        self.model = None
        self.scaler = None
        self.started_at = time.time()   # start of the learning period
        self.training_data = []         # buffered feature vectors
        self.is_learning = True
        self._retrain_requested = False

    @property
    def enabled(self):
        """Whether AI detection is switched on in the config."""
        return self.cfg.get('enabled', False)

    def request_retrain(self):
        """Ask for the model to be retrained on the next collected sample.

        The flag is only examined in collect_sample(), and the daemon feeds
        samples there only while ``is_learning`` is set. Learning mode is
        therefore re-entered here; otherwise a request made after a model
        exists would never be acted on. If fewer than 10 samples are
        buffered when the retrain runs, the detector stays in learning mode
        until the regular learning criteria are met.
        """
        self._retrain_requested = True
        self.is_learning = True

    def collect_sample(self, features):
        """Buffer one feature vector and train when the criteria are met.

        Training is triggered when ``learning_duration`` seconds have passed
        since construction and at least ``min_samples`` samples are
        buffered, or when a retrain was requested. Does nothing when the
        detector is disabled.
        """
        if not self.enabled:
            return

        self.training_data.append(features)

        learning_dur = self.cfg.get('learning_duration', 259200)
        min_samples = self.cfg.get('min_samples', 1000)
        elapsed = time.time() - self.started_at
        learned_enough = elapsed >= learning_dur and len(self.training_data) >= min_samples

        if learned_enough or self._retrain_requested:
            self._train()
            self._retrain_requested = False

    def _train(self):
        """Fit scaler and Isolation Forest on the buffered samples.

        Disables AI detection (``cfg['enabled'] = False``) when scikit-learn
        or numpy cannot be imported. With fewer than 10 samples nothing is
        trained. On success the model is activated and then persisted; a
        persistence failure is logged but leaves the new model active.
        """
        try:
            from sklearn.ensemble import IsolationForest
            from sklearn.preprocessing import StandardScaler
            import numpy as np
        except ImportError:
            log.error("scikit-learn not installed. AI detection disabled.")
            self.cfg['enabled'] = False
            return

        if len(self.training_data) < 10:
            log.warning("Not enough training data (%d samples)", len(self.training_data))
            return

        log.info("Training AI model with %d samples...", len(self.training_data))

        # Fit into locals first: if fitting fails half-way the previously
        # active scaler/model pair must stay consistent with each other.
        try:
            X = np.array(self.training_data)
            scaler = StandardScaler()
            X_scaled = scaler.fit_transform(X)
            model = IsolationForest(
                n_estimators=self.cfg.get('n_estimators', 100),
                contamination=self.cfg.get('contamination', 'auto'),
                random_state=42,
            )
            model.fit(X_scaled)
        except Exception as e:
            log.error("AI training failed: %s", e)
            return

        self.scaler = scaler
        self.model = model
        self.is_learning = False

        model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        try:
            self._persist(model_file)
        except Exception as e:
            log.error("AI model trained but could not be saved: %s", e)
            return

        log.info("AI model trained and saved to %s", model_file)

    def _persist(self, model_file):
        """Write the active model to *model_file* and the samples to CSV."""
        # Write to a sibling temp file and rename, so a crash mid-write
        # cannot leave a truncated pickle for the next start-up to load.
        tmp_file = model_file + '.tmp'
        with open(tmp_file, 'wb') as f:
            pickle.dump({'model': self.model, 'scaler': self.scaler}, f)
        os.replace(tmp_file, model_file)

        data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
        with open(data_file, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(self._CSV_HEADER)
            writer.writerows(self.training_data)

    def load_model(self):
        """Load a previously trained model from ``model_file``.

        Returns:
            True when a model and scaler were loaded (learning mode is then
            left), False when the file is missing or cannot be used.
        """
        model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        try:
            with open(model_file, 'rb') as f:
                # NOTE(security): unpickling runs code embedded in the file.
                # The model file must only be writable by trusted users.
                data = pickle.load(f)
            model = data['model']
            scaler = data['scaler']
        except FileNotFoundError:
            return False
        except Exception as e:
            log.error("Failed to load AI model: %s", e)
            return False

        self.model = model
        self.scaler = scaler
        self.is_learning = False
        log.info("AI model loaded from %s", model_file)
        return True

    def predict(self, features):
        """Score one feature vector.

        Returns:
            ``(is_anomaly, score)`` where ``is_anomaly`` is True when the
            model's decision_function score is below ``anomaly_threshold``.
            ``(False, 0.0)`` when disabled, when no model is available, or
            when inference fails.
        """
        if not self.enabled or self.model is None:
            return False, 0.0

        try:
            import numpy as np
            X = np.array([features])
            X_scaled = self.scaler.transform(X)
            score = self.model.decision_function(X_scaled)[0]
            threshold = self.cfg.get('anomaly_threshold', -0.3)
            return score < threshold, float(score)
        except Exception as e:
            log.error("AI prediction error: %s", e)
            return False, 0.0
|
|
|
|
|
|
# ==================== ProfileManager ====================
|
|
|
|
class ProfileManager:
    """Manage time-based rate limit profiles.

    Each profile may restrict itself with ``weekdays`` (e.g. 'mon-fri',
    'sat', 'sat,sun') and/or ``hours`` ('HH:MM-HH:MM', may wrap past
    midnight). The first profile, in config order, whose restrictions match
    the current local time wins; otherwise the defaults apply.
    """

    # Index in this tuple equals datetime.weekday() (Monday == 0).
    _DAY_NAMES = ('mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun')

    def __init__(self, rate_cfg):
        self.cfg = rate_cfg
        # Name of the profile most recently written via write_rate_config().
        self.current_profile = 'default'

    @classmethod
    def _weekday_matches(cls, spec, weekday_idx):
        """Return True when *weekday_idx* (Monday == 0) satisfies *spec*.

        *spec* is a day ('sat'), an inclusive range ('mon-fri', may wrap as
        in 'fri-mon'), a comma separated mix of both, or a list of such
        items. An unrecognized day name makes the spec match nothing.
        """
        items = spec if isinstance(spec, (list, tuple)) else str(spec).split(',')
        try:
            for item in items:
                bounds = [part.strip() for part in str(item).lower().split('-')]
                if len(bounds) > 2:
                    raise ValueError(item)
                start = cls._DAY_NAMES.index(bounds[0])
                end = cls._DAY_NAMES.index(bounds[-1])
                if start <= end:
                    hit = start <= weekday_idx <= end
                else:
                    # Range wraps around the end of the week.
                    hit = weekday_idx >= start or weekday_idx <= end
                if hit:
                    return True
        except ValueError:
            return False
        return False

    @staticmethod
    def _hours_match(spec, minute_of_day):
        """Return True when *minute_of_day* lies in the 'HH:MM-HH:MM' *spec*.

        The start is inclusive and the end exclusive. When the start is
        later than the end the range wraps past midnight. A malformed spec
        matches nothing.
        """
        try:
            start_str, end_str = spec.split('-')
            start_h, start_m = map(int, start_str.split(':'))
            end_h, end_m = map(int, end_str.split(':'))
        except (ValueError, AttributeError):
            return False

        start_min = start_h * 60 + start_m
        end_min = end_h * 60 + end_m
        if start_min <= end_min:
            return start_min <= minute_of_day < end_min
        return minute_of_day >= start_min or minute_of_day < end_min

    def _match_profile(self, now):
        """Return ``(name, profile)`` of the first profile matching *now*.

        Returns ``('default', None)`` when no profile matches.
        """
        minute_of_day = now.hour * 60 + now.minute
        # weekday() is locale independent, unlike strftime('%a').
        weekday_idx = now.weekday()

        for name, profile in (self.cfg.get('profiles') or {}).items():
            if not isinstance(profile, dict):
                continue
            weekdays = profile.get('weekdays', '')
            if weekdays and not self._weekday_matches(weekdays, weekday_idx):
                continue
            hours = profile.get('hours', '')
            if hours and not self._hours_match(hours, minute_of_day):
                continue
            return name, profile
        return 'default', None

    def check_and_apply(self):
        """Check the current local time and apply the matching profile.

        Nothing is written when the matching profile is already current.
        A failed write is logged and retried on the next call because
        ``current_profile`` is only updated after a successful write.
        """
        matched_name, matched_profile = self._match_profile(datetime.now())
        if matched_name == self.current_profile:
            return

        default_pps = self.cfg.get('default_pps', 1000)
        default_bps = self.cfg.get('default_bps', 0)
        if matched_profile:
            pps = matched_profile.get('pps', default_pps)
            bps = matched_profile.get('bps', default_bps)
        else:
            pps = default_pps
            bps = default_bps
        window = self.cfg.get('window_sec', 1)

        # Imported here so the helper module is only needed when a switch
        # actually has to be written.
        from xdp_common import write_rate_config

        try:
            write_rate_config(pps, bps, window * 1_000_000_000)
            log.info("Profile switched: %s -> %s (pps=%d)", self.current_profile, matched_name, pps)
            self.current_profile = matched_name
        except Exception as e:
            log.error("Failed to apply profile %s: %s", matched_name, e)
|
|
|
|
|
|
# ==================== DDoSDaemon ====================
|
|
|
|
class DDoSDaemon:
    """Main daemon orchestrator.

    Owns the analysis components, runs the four worker threads and reacts
    to signals: SIGHUP reloads the config, SIGTERM/SIGINT shut down, and
    SIGUSR1 requests an AI retrain.
    """

    # Counters read from read_percpu_features(), in feature-vector order.
    _AI_FEATURE_NAMES = (
        'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
        'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
        'small_pkt_count', 'large_pkt_count',
    )

    def __init__(self, config_path=CONFIG_PATH):
        self.config_path = config_path
        self.cfg = load_config(config_path)
        self.running = False
        # Set on shutdown; also used as an interruptible sleep by workers.
        self._stop_event = threading.Event()
        self._setup_components()

    # ---- Setup / configuration ----

    def _setup_components(self):
        """Create all analysis components from the current config."""
        self.violation_tracker = ViolationTracker(self.cfg['escalation'])
        self.ewma_analyzer = EWMAAnalyzer(
            alpha=self.cfg['ewma'].get('alpha', 0.3),
            threshold_multiplier=self.cfg['ewma'].get('threshold_multiplier', 3.0),
        )
        self.ai_detector = AIDetector(self.cfg['ai'])
        self.profile_manager = ProfileManager(self.cfg['rate_limits'])

        if self.ai_detector.enabled:
            self.ai_detector.load_model()

        self._apply_log_level()

    def _reload_components(self):
        """Apply a freshly loaded config without discarding learned state.

        Recreating the components would throw away violation history, EWMA
        baselines and the AI training buffer, so the existing objects are
        reconfigured in place. The profile manager holds no learned state
        and is rebuilt.
        """
        self.violation_tracker.cfg = self.cfg['escalation']
        self.ewma_analyzer.alpha = self.cfg['ewma'].get('alpha', 0.3)
        self.ewma_analyzer.threshold_multiplier = self.cfg['ewma'].get('threshold_multiplier', 3.0)

        self.ai_detector.cfg = self.cfg['ai']
        if self.ai_detector.enabled and self.ai_detector.model is None:
            self.ai_detector.load_model()

        self.profile_manager = ProfileManager(self.cfg['rate_limits'])
        self._apply_log_level()

    def _apply_log_level(self):
        """Set the logger level from ``general.log_level`` (default INFO)."""
        name = str(self.cfg['general'].get('log_level', 'info')).upper()
        level = getattr(logging, name, None)
        # getattr could resolve to a non-level attribute of the module.
        if not isinstance(level, int):
            level = logging.INFO
        log.setLevel(level)

    def _pid_file(self):
        """Return the configured PID file path."""
        return self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')

    def _write_pid(self):
        """Write this process's PID to the PID file, creating its directory."""
        pid_file = self._pid_file()
        pid_dir = os.path.dirname(pid_file)
        if pid_dir:  # makedirs('') raises for a bare file name
            os.makedirs(pid_dir, exist_ok=True)
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

    def _remove_pid(self):
        """Remove the PID file, ignoring errors."""
        try:
            os.unlink(self._pid_file())
        except OSError:
            pass

    def _ensure_single_instance(self):
        """Stop any existing daemon named in the PID file before starting.

        Sends SIGTERM, waits up to 30 seconds for the process to exit, then
        sends SIGKILL. A missing or unparseable PID file, a PID that is not
        running, or a PID that may not be signalled are all ignored.
        """
        try:
            with open(self._pid_file()) as f:
                old_pid = int(f.read().strip())
        except (OSError, ValueError):
            return

        # Never signal ourselves (stale file with a reused PID), and never
        # pass 0 or a negative value to kill(): those address process groups.
        if old_pid <= 0 or old_pid == os.getpid():
            return

        try:
            os.kill(old_pid, 0)  # existence probe; raises if not running
            log.info("Stopping existing daemon (PID %d)...", old_pid)
            os.kill(old_pid, signal.SIGTERM)
            for _ in range(30):
                time.sleep(1)
                try:
                    os.kill(old_pid, 0)
                except OSError:
                    log.info("Old daemon stopped")
                    return
            log.warning("Daemon PID %d did not stop, sending SIGKILL", old_pid)
            os.kill(old_pid, signal.SIGKILL)
            time.sleep(1)
        except OSError:
            pass

    # ---- Signal handlers ----

    def _handle_sighup(self, signum, frame):
        log.info("SIGHUP received, reloading config...")
        self.cfg = load_config(self.config_path)
        self._reload_components()
        log.info("Config reloaded")

    def _handle_sigterm(self, signum, frame):
        log.info("SIGTERM received, shutting down...")
        self.running = False
        self._stop_event.set()

    def _handle_sigusr1(self, signum, frame):
        log.info("SIGUSR1 received, requesting AI retrain...")
        self.ai_detector.request_retrain()

    # ---- Shared helpers ----

    def _active(self):
        """Return True while the daemon should keep running.

        The stop event is checked as well as ``running``: once the event is
        set every ``wait()`` returns immediately, so looping on ``running``
        alone would spin if a stop signal arrived before ``run()`` set it.
        """
        return self.running and not self._stop_event.is_set()

    def _escalate(self, ip_str, level, tag=''):
        """Block *ip_str* according to *level*; other levels are a no-op.

        *tag* is inserted into the log messages ('' for the EWMA path,
        'AI ' for the AI path). Failures to block are logged, not raised.
        """
        from xdp_common import block_ip

        if level == 'temp_block':
            dur = self.cfg['escalation'].get('temp_block_duration', 300)
            try:
                block_ip(ip_str, dur)
                log.warning("%sTEMP BLOCK: %s for %ds", tag, ip_str, dur)
            except Exception as e:
                log.error("Failed to %stemp-block %s: %s", tag, ip_str, e)
        elif level == 'perm_block':
            try:
                block_ip(ip_str, 0)  # duration 0 is used for permanent blocks
                log.warning("%sPERM BLOCK: %s", tag, ip_str)
            except Exception as e:
                log.error("Failed to %sperm-block %s: %s", tag, ip_str, e)

    # ---- Worker Threads ----

    def _ewma_thread(self):
        """Poll rate counters, compute EWMA, detect violations, escalate."""
        from xdp_common import dump_rate_counters

        prev_counters = {}  # ip -> packet counter seen on the previous poll

        while self._active():
            # Re-read every cycle so a config reload takes effect.
            interval = self.cfg['ewma'].get('poll_interval', 1)
            try:
                entries = dump_rate_counters('rate_counter_v4', top_n=1000)
                active_ips = set()

                for ip_str, pkts, _bytes, _last_seen in entries:
                    active_ips.add(ip_str)

                    prev = prev_counters.get(ip_str, 0)
                    # A counter lower than last time is treated as reset.
                    delta = pkts - prev if pkts >= prev else pkts
                    prev_counters[ip_str] = pkts
                    if delta <= 0:
                        continue

                    pps = delta / max(interval, 0.1)
                    if not self.ewma_analyzer.update(ip_str, pps):
                        continue

                    level = self.violation_tracker.record_violation(ip_str)
                    ew = self.ewma_analyzer.get_stats(ip_str)
                    log.warning(
                        "EWMA anomaly: %s pps=%.1f ewma=%.1f baseline=%.1f -> %s",
                        ip_str, pps, ew['ewma'], ew['baseline'], level
                    )
                    self._escalate(ip_str, level)

                self.ewma_analyzer.cleanup_stale(active_ips)
                # Forget counters of IPs no longer reported; without this the
                # dict grows with every source address ever seen.
                for ip in [ip for ip in prev_counters if ip not in active_ips]:
                    del prev_counters[ip]

            except Exception as e:
                log.error("EWMA thread error: %s", e)

            self._stop_event.wait(interval)

    def _ai_thread(self):
        """Read traffic features, run AI inference or collect training data."""
        from xdp_common import read_percpu_features, dump_rate_counters

        prev_features = None

        while self._active():
            interval = self.cfg['ai'].get('poll_interval', 5)
            try:
                if self.ai_detector.enabled:
                    features = read_percpu_features()
                    if features:
                        # The first reading only seeds the delta computation.
                        if prev_features is not None:
                            self._process_ai_sample(prev_features, features, dump_rate_counters)
                        prev_features = features
            except Exception as e:
                log.error("AI thread error: %s", e)

            self._stop_event.wait(interval)

    def _process_ai_sample(self, prev_features, features, dump_rate_counters):
        """Turn two counter readings into a sample and learn or predict.

        The sample holds the per-counter deltas followed by five derived
        values (SYN, UDP, ICMP and small-packet ratios, average packet
        size). Intervals with fewer than ``min_packets_for_sample`` packets
        are skipped. On an anomaly the top 5 IPs reported by
        *dump_rate_counters* each receive a violation and are escalated.
        """
        names = self._AI_FEATURE_NAMES
        deltas = [max(0, features.get(n, 0) - prev_features.get(n, 0)) for n in names]

        min_pkts = self.cfg['ai'].get('min_packets_for_sample', 20)
        if deltas[0] < min_pkts:
            return

        total = deltas[0] + 1e-6  # epsilon guards against division by zero
        deltas.extend([
            deltas[2] / total,   # syn_ratio
            deltas[4] / total,   # udp_ratio
            deltas[5] / total,   # icmp_ratio
            deltas[8] / total,   # small_pkt_ratio
            deltas[1] / total,   # avg_pkt_size
        ])

        if self.ai_detector.is_learning:
            self.ai_detector.collect_sample(deltas)
            if len(self.ai_detector.training_data) % 100 == 0:
                log.debug("AI learning: %d samples collected",
                          len(self.ai_detector.training_data))
            return

        is_anomaly, score = self.ai_detector.predict(deltas)
        if not is_anomaly:
            return

        log.warning(
            "AI ANOMALY detected: score=%.4f deltas=%s",
            score, dict(zip(names, deltas[:len(names)]))
        )
        for ip_str, _pkts, _bytes, _last_seen in dump_rate_counters('rate_counter_v4', top_n=5):
            level = self.violation_tracker.record_violation(ip_str)
            log.warning("AI escalation: %s -> %s", ip_str, level)
            self._escalate(ip_str, level, tag='AI ')

    def _profile_thread(self):
        """Check time-of-day and switch rate profiles."""
        while self._active():
            try:
                self.profile_manager.check_and_apply()
            except Exception as e:
                log.error("Profile thread error: %s", e)
            self._stop_event.wait(60)

    def _cleanup_thread(self):
        """Periodically clean up expired blocked IPs and stale violations."""
        from xdp_common import dump_blocked_ips, unblock_ip

        while self._active():
            try:
                # NOTE(review): assumes expire_ns uses the same clock as the
                # first field of /proc/uptime - confirm against the BPF side.
                with open('/proc/uptime') as f:
                    now_ns = int(float(f.read().split()[0]) * 1_000_000_000)

                for map_name in ('blocked_ips_v4', 'blocked_ips_v6'):
                    for ip_str, expire_ns, _blocked_at, drop_count in dump_blocked_ips(map_name):
                        # expire_ns == 0 entries are never expired here.
                        if expire_ns == 0 or now_ns <= expire_ns:
                            continue
                        try:
                            unblock_ip(ip_str)
                            self.violation_tracker.clear(ip_str)
                            log.info("Expired block removed: %s (dropped %d pkts)", ip_str, drop_count)
                        except Exception as e:
                            log.error("Failed to remove expired block %s: %s", ip_str, e)

                self.violation_tracker.cleanup_expired()

            except Exception as e:
                log.error("Cleanup thread error: %s", e)

            self._stop_event.wait(60)

    # ---- Main Loop ----

    def run(self):
        """Start the daemon and block until it is asked to stop."""
        log.info("XDP Defense Daemon starting...")

        signal.signal(signal.SIGHUP, self._handle_sighup)
        signal.signal(signal.SIGTERM, self._handle_sigterm)
        signal.signal(signal.SIGINT, self._handle_sigterm)
        signal.signal(signal.SIGUSR1, self._handle_sigusr1)

        self._ensure_single_instance()
        self._write_pid()

        threads = []
        try:
            self.running = True

            threads = [
                threading.Thread(target=self._ewma_thread, name='ewma', daemon=True),
                threading.Thread(target=self._ai_thread, name='ai', daemon=True),
                threading.Thread(target=self._profile_thread, name='profile', daemon=True),
                threading.Thread(target=self._cleanup_thread, name='cleanup', daemon=True),
            ]
            for t in threads:
                t.start()
                log.info("Started %s thread", t.name)

            log.info("Daemon running (PID %d)", os.getpid())

            try:
                while self._active():
                    self._stop_event.wait(1)
            except KeyboardInterrupt:
                pass

            log.info("Shutting down...")
        finally:
            # Runs on errors too, so the PID file never outlives the process.
            self.running = False
            self._stop_event.set()
            for t in threads:
                if t.is_alive():
                    t.join(timeout=5)
            self._remove_pid()

        log.info("Daemon stopped")
|
|
|
|
|
|
# ==================== Entry Point ====================
|
|
|
|
def main():
    """Run the daemon; an optional first CLI argument overrides the config path."""
    config_path = sys.argv[1] if len(sys.argv) > 1 else CONFIG_PATH
    daemon = DDoSDaemon(config_path)
    daemon.run()
|
|
|
|
|
|
# Start the daemon only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|