#!/usr/bin/env python3
"""
XDP Defense Daemon

Userspace component: EWMA-based rate analysis, AI anomaly detection,
time-profile switching, and automatic escalation.

4 worker threads + main thread (signal handling):
- EWMA Thread: polls rate counters, calculates EWMA, detects violations
- AI Thread: reads traffic features, runs Isolation Forest inference
- Profile Thread: checks time-of-day, switches rate_config profiles
- Cleanup Thread: removes expired entries from blocked_ips maps
"""

import copy
import math
import os
import stat
import sys
import time
import signal
import threading
import logging
import logging.handlers
import csv
import pickle
import sqlite3
from collections import defaultdict
from datetime import datetime, timedelta

import yaml

try:
    from xdp_common import (
        dump_rate_counters,
        block_ip,
        is_whitelisted,
        read_percpu_features,
        dump_blocked_ips,
        unblock_ip,
        write_rate_config,
        read_rate_config,
    )
except ImportError as e:
    print(f"FATAL: Cannot import xdp_common: {e}", file=sys.stderr)
    sys.exit(1)

# ==================== Logging ====================

log = logging.getLogger('xdp-defense-daemon')
log.setLevel(logging.INFO)

_console = logging.StreamHandler()
_console.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(message)s'))
log.addHandler(_console)

try:
    _syslog = logging.handlers.SysLogHandler(address='/dev/log')
    _syslog.setFormatter(logging.Formatter('xdp-defense-daemon: %(message)s'))
    log.addHandler(_syslog)
except Exception:
    # /dev/log may not exist (e.g. minimal containers); console logging suffices.
    pass

# ==================== Configuration ====================

DEFAULT_CONFIG = {
    'general': {
        'interface': 'eth0',
        'log_level': 'info',
        'pid_file': '/var/lib/xdp-defense/daemon.pid',
        'data_dir': '/var/lib/xdp-defense',
    },
    'rate_limits': {
        'default_pps': 1000,
        'default_bps': 0,
        'window_sec': 1,
        'profiles': {},
    },
    'escalation': {
        'temp_block_after': 5,
        'perm_block_after': 20,
        'temp_block_duration': 300,
        'violation_window': 600,
    },
    'ewma': {
        'alpha': 0.3,
        'poll_interval': 1,
        'threshold_multiplier': 5.0,
        'min_pps': 20,
    },
    'ai': {
        'enabled': True,
        'model_type': 'IsolationForest',
        'contamination': 0.05,
        'n_estimators': 100,
        'learning_duration': 259200,
        'min_samples': 1000,
        'poll_interval': 5,
        'anomaly_threshold': -0.3,
        'min_packets_for_sample': 20,
        'model_file': '/var/lib/xdp-defense/ai_model.pkl',
        'training_data_file': '/var/lib/xdp-defense/training_data.csv',
        'traffic_log_db': '/var/lib/xdp-defense/traffic_log.db',
        'traffic_log_retention_days': 7,
        'retrain_interval': 86400,
        'retrain_window': 604800,
    },
}

CONFIG_PATH = '/etc/xdp-defense/config.yaml'


def load_config(path=CONFIG_PATH):
    """Load config with defaults.

    User values are merged section-by-section (shallow update) over
    DEFAULT_CONFIG; a missing or unreadable file falls back to defaults.
    """
    cfg = copy.deepcopy(DEFAULT_CONFIG)
    try:
        with open(path) as f:
            user = yaml.safe_load(f) or {}
        for section in cfg:
            if section in user and isinstance(user[section], dict):
                cfg[section].update(user[section])
    except FileNotFoundError:
        log.warning("Config not found at %s, using defaults", path)
    except Exception as e:
        log.error("Failed to load config: %s", e)
    return cfg


# ==================== ViolationTracker ====================

class ViolationTracker:
    """Track per-IP violation counts and manage escalation."""

    def __init__(self, escalation_cfg):
        self.cfg = escalation_cfg
        self.violations = defaultdict(list)  # {ip: [timestamps]}
        self.lock = threading.Lock()

    def record_violation(self, ip):
        """Record a violation and return escalation level.

        Returns: 'rate_limit', 'temp_block', 'perm_block', or None
        """
        now = time.time()
        window = self.cfg.get('violation_window', 600)
        with self.lock:
            # Drop violations that aged out of the sliding window, then count.
            self.violations[ip] = [t for t in self.violations[ip] if now - t < window]
            self.violations[ip].append(now)
            count = len(self.violations[ip])
            perm_after = self.cfg.get('perm_block_after', 20)
            temp_after = self.cfg.get('temp_block_after', 5)
            if count >= perm_after:
                return 'perm_block'
            elif count >= temp_after:
                return 'temp_block'
            return 'rate_limit'

    def clear(self, ip):
        """Forget all recorded violations for an IP (e.g. after unblock)."""
        with self.lock:
            self.violations.pop(ip, None)

    def cleanup_expired(self):
        """Remove entries with no recent violations."""
        now = time.time()
        window = self.cfg.get('violation_window', 600)
        with self.lock:
            expired = [ip for ip, times in self.violations.items()
                       if all(now - t >= window for t in times)]
            for ip in expired:
                del self.violations[ip]


# ==================== EWMAAnalyzer ====================

class EWMAAnalyzer:
    """Per-IP EWMA calculation for rate anomaly detection."""

    def __init__(self, alpha=0.3, threshold_multiplier=5.0, min_pps=20):
        self.alpha = alpha
        self.threshold_multiplier = threshold_multiplier
        self.min_pps = min_pps
        self.ewma = {}      # fast-moving average (alpha)
        self.baseline = {}  # slow-moving baseline (fixed 0.01 smoothing)
        self.lock = threading.Lock()

    def update(self, ip, current_pps):
        """Update EWMA for an IP.

        Returns True if anomalous."""
        with self.lock:
            if ip not in self.ewma:
                # First observation seeds both averages; never anomalous.
                self.ewma[ip] = current_pps
                self.baseline[ip] = current_pps
                return False
            self.ewma[ip] = self.alpha * current_pps + (1 - self.alpha) * self.ewma[ip]
            self.baseline[ip] = 0.01 * current_pps + 0.99 * self.baseline[ip]
            if current_pps < self.min_pps:
                # Too little traffic to be meaningful; avoid low-rate noise.
                return False
            base = max(self.baseline[ip], 1)
            if self.ewma[ip] > base * self.threshold_multiplier:
                return True
            return False

    def get_stats(self, ip):
        """Return current {'ewma', 'baseline'} for an IP (0 if untracked)."""
        with self.lock:
            return {
                'ewma': self.ewma.get(ip, 0),
                'baseline': self.baseline.get(ip, 0),
            }

    def cleanup_stale(self, active_ips):
        """Remove tracking for IPs no longer in rate counters."""
        with self.lock:
            stale = set(self.ewma.keys()) - set(active_ips)
            for ip in stale:
                self.ewma.pop(ip, None)
                self.baseline.pop(ip, None)


# ==================== AIDetector ====================

class AIDetector:
    """Time-period Isolation Forest anomaly detection.

    Maintains 4 separate models for different time periods:
    night (00-06), morning (06-12), afternoon (12-18), evening (18-24).
    Each model learns the normal traffic pattern for its time window.
    """

    TIME_PERIODS = {
        'night': (0, 6),
        'morning': (6, 12),
        'afternoon': (12, 18),
        'evening': (18, 24),
    }

    def __init__(self, ai_cfg):
        self.cfg = ai_cfg
        self.models = {}  # {period_name: {'model': ..., 'scaler': ...}}
        self.started_at = time.time()
        self.training_data = defaultdict(list)  # {period_name: [samples]}
        self.is_learning = True
        self._retrain_requested = False
        self._retrain_lock = threading.Lock()

    @property
    def enabled(self):
        return self.cfg.get('enabled', False)

    @staticmethod
    def get_period(hour):
        """Map hour (0-24) to period name."""
        for name, (start, end) in AIDetector.TIME_PERIODS.items():
            if start <= hour < end:
                return name
        return 'night'  # hour == 24 edge case

    def request_retrain(self):
        self._retrain_requested = True

    def collect_sample(self, features, hour):
        """Collect a feature sample during learning phase, bucketed by period."""
        if not self.enabled:
            return
        period = self.get_period(hour)
        self.training_data[period].append(features)
        learning_dur = self.cfg.get('learning_duration', 259200)
        min_samples = self.cfg.get('min_samples', 1000)
        elapsed = time.time() - self.started_at
        total = sum(len(v) for v in self.training_data.values())
        # Train when the learning window elapsed AND enough samples exist,
        # or immediately when a retrain was explicitly requested.
        if (elapsed >= learning_dur and total >= min_samples) or self._retrain_requested:
            self._train()
            self._retrain_requested = False

    def _train(self, period_data=None):
        """Train per-period Isolation Forest models.

        If period_data provided, use it instead of self.training_data.
        """
        if period_data is None:
            period_data = self.training_data
        try:
            from sklearn.ensemble import IsolationForest
            from sklearn.preprocessing import StandardScaler
            import numpy as np
        except ImportError:
            log.error("scikit-learn not installed. AI detection disabled.")
            self.cfg['enabled'] = False
            return
        total = sum(len(v) for v in period_data.values())
        if total < 10:
            log.warning("Not enough training data (%d samples)", total)
            return
        log.info("Training AI models: %s",
                 {p: len(s) for p, s in period_data.items() if s})
        try:
            new_models = {}
            all_samples = []
            for period, samples in period_data.items():
                if len(samples) < 10:
                    log.info("Period %s: %d samples (too few, skip)", period, len(samples))
                    continue
                X = np.array(samples)
                scaler = StandardScaler()
                X_scaled = scaler.fit_transform(X)
                model = IsolationForest(
                    n_estimators=self.cfg.get('n_estimators', 100),
                    contamination=self.cfg.get('contamination', 'auto'),
                    random_state=42,
                )
                model.fit(X_scaled)
                new_models[period] = {'model': model, 'scaler': scaler}
                all_samples.extend(samples)
                log.info("Period %s: trained with %d samples", period, len(samples))
            if not new_models:
                log.warning("No period had enough data to train")
                return

            # Atomic swap — readers always see a complete model set.
            self.models = new_models
            self.is_learning = False

            # Save to disk (atomic: write tmp with 0600, fsync, rename).
            model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
            tmp_model = model_file + '.tmp'
            fd = os.open(tmp_model, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, 'wb') as f:
                pickle.dump({
                    'format': 'period_models',
                    'models': {p: {'model': m['model'], 'scaler': m['scaler']}
                               for p, m in new_models.items()},
                    'feature_count': 17,
                }, f)
                f.flush()
                os.fsync(f.fileno())
            os.rename(tmp_model, model_file)

            # Save training data CSV (same atomic pattern, world-readable).
            data_file = self.cfg.get('training_data_file',
                                     '/var/lib/xdp-defense/training_data.csv')
            tmp_data = data_file + '.tmp'
            fd = os.open(tmp_data, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)
            with os.fdopen(fd, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([
                    'hour_sin', 'hour_cos', 'total_packets', 'total_bytes',
                    'tcp_syn_count', 'tcp_other_count', 'udp_count', 'icmp_count',
                    'other_proto_count', 'unique_ips_approx', 'small_pkt_count',
                    'large_pkt_count', 'syn_ratio', 'udp_ratio', 'icmp_ratio',
                    'small_pkt_ratio', 'avg_pkt_size'
                ])
                writer.writerows(all_samples)
                f.flush()
                os.fsync(f.fileno())
            os.rename(tmp_data, data_file)
            log.info("AI models trained and saved (%d periods)", len(new_models))
        except Exception as e:
            log.error("AI training failed: %s", e)

    def load_model(self):
        """Load previously trained models. Handle old and new formats."""
        model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        if not os.path.exists(model_file):
            return False
        try:
            st = os.stat(model_file)
            # Warn if file is world-writable — pickle.load on a tamperable
            # file would allow arbitrary code execution.
            if st.st_mode & stat.S_IWOTH:
                log.warning("Model file %s is world-writable! Refusing to load.", model_file)
                return False
            with open(model_file, 'rb') as f:
                data = pickle.load(f)
            # New period_models format
            if data.get('format') == 'period_models':
                self.models = data['models']
                self.is_learning = False
                periods = list(self.models.keys())
                # %s (not %d): feature_count may be absent → fallback '?'
                log.info("AI models loaded: %s (%s features)",
                         periods, data.get('feature_count', '?'))
                return True
            # Old single-model format → discard
            log.warning("Old single-model format detected. Switching to learning mode.")
            self.is_learning = True
            return False
        except Exception as e:
            log.error("Failed to load AI model: %s", e)
            return False

    def predict(self, features, hour):
        """Run anomaly detection using the period-specific model.

        Returns (is_anomaly, score).
        """
        if not self.enabled or not self.models:
            return False, 0.0
        period = self.get_period(hour)
        model_info = self.models.get(period)
        # Fallback: no model for this period — use any available model.
        if model_info is None:
            for p in self.models:
                model_info = self.models[p]
                break
        if model_info is None:
            return False, 0.0
        try:
            import numpy as np
            X = np.array([features])
            X_scaled = model_info['scaler'].transform(X)
            score = model_info['model'].decision_function(X_scaled)[0]
            threshold = self.cfg.get('anomaly_threshold', -0.3)
            return score < threshold, float(score)
        except Exception as e:
            log.error("AI prediction error: %s", e)
            return False, 0.0

    def retrain_from_log(self, db_path=None, background=False):
        """Retrain models from traffic_log.db. Optionally in background thread."""
        if db_path is None:
            db_path = self.cfg.get('traffic_log_db',
                                   '/var/lib/xdp-defense/traffic_log.db')
        if background:
            t = threading.Thread(target=self._do_retrain, args=(db_path,), daemon=True)
            t.start()
            return True  # started, not yet finished
        return self._do_retrain(db_path)

    def _do_retrain(self, db_path):
        """Actual retrain logic (can run in background thread)."""
        # Non-blocking acquire: only one retrain at a time.
        if not self._retrain_lock.acquire(blocking=False):
            log.info("Retrain already in progress, skipping")
            return False
        try:
            return self._retrain_impl(db_path)
        finally:
            self._retrain_lock.release()

    def _retrain_impl(self, db_path):
        """Load data from DB, filter outliers, train per-period models."""
        if not os.path.exists(db_path):
            log.warning("Traffic log DB not found: %s", db_path)
            return False
        retrain_window = self.cfg.get('retrain_window', 86400)
        cutoff = (datetime.now() - timedelta(seconds=retrain_window)).isoformat()
        conn = None
        try:
            # Read-only URI connection: safe alongside the writer thread.
            conn = sqlite3.connect(f'file:{db_path}?mode=ro', uri=True)
            cur = conn.execute(
                'SELECT hour, hour_sin, hour_cos, total_packets, total_bytes, '
                'tcp_syn_count, tcp_other_count, udp_count, icmp_count, '
                'other_proto_count, unique_ips_approx, small_pkt_count, '
                'large_pkt_count, syn_ratio, udp_ratio, icmp_ratio, '
                'small_pkt_ratio, avg_pkt_size '
                'FROM traffic_samples WHERE timestamp >= ? ORDER BY timestamp',
                (cutoff,)
            )
            rows = cur.fetchall()
        except Exception as e:
            log.error("retrain_from_log DB read failed: %s", e)
            return False
        finally:
            if conn:
                conn.close()
        if len(rows) < 10:
            log.warning("Not enough recent samples for retrain (%d)", len(rows))
            return False

        # Split by period: row[0]=hour, row[1:]=17 features
        period_data = defaultdict(list)
        for row in rows:
            hour = row[0]
            features = list(row[1:])  # 17 features
            period = self.get_period(hour)
            period_data[period].append(features)

        # Filter outliers using existing models (item #2)
        try:
            import numpy as np
        except ImportError:
            log.error("numpy not available for outlier filtering")
            return False
        filtered_count = 0
        for period, samples in period_data.items():
            model_info = self.models.get(period)
            if model_info is None or len(samples) < 10:
                continue
            X = np.array(samples)
            X_scaled = model_info['scaler'].transform(X)
            scores = model_info['model'].decision_function(X_scaled)
            # Keep only samples the current model does not consider extreme,
            # so attack traffic is not learned as "normal".
            clean = [s for s, sc in zip(samples, scores) if sc >= -0.5]
            removed = len(samples) - len(clean)
            if removed > 0:
                filtered_count += removed
            period_data[period] = clean

        total = sum(len(v) for v in period_data.values())
        log.info("Retrain: %d samples loaded, %d outliers filtered, %s",
                 len(rows), filtered_count,
                 {p: len(s) for p, s in period_data.items() if s})
        self._train(period_data=period_data)
        return not self.is_learning


# ==================== ProfileManager ====================

class ProfileManager:
    """Manage time-based rate limit profiles."""

    def __init__(self, rate_cfg):
        self.cfg = rate_cfg
        self.current_profile = 'default'

    def check_and_apply(self):
        """Check current time and apply matching profile."""
        profiles = self.cfg.get('profiles', {})
        now = datetime.now()
        current_hour = now.hour
        current_min = now.minute
        current_time = current_hour * 60 + current_min
        weekday = now.strftime('%a').lower()
        matched_profile = None
        matched_name = 'default'
        for name, profile in profiles.items():
            hours = profile.get('hours', '')
            weekdays = profile.get('weekdays', '')
            if weekdays:
                day_range = weekdays.lower().split('-')
                day_names = ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun']
                if len(day_range) == 2:
                    try:
                        start_idx = day_names.index(day_range[0])
                        end_idx = day_names.index(day_range[1])
                        current_idx = day_names.index(weekday)
                        if start_idx <= end_idx:
                            if not (start_idx <= current_idx <= end_idx):
                                continue
                        else:
                            # Wrap-around range, e.g. fri-mon.
                            if not (current_idx >= start_idx or current_idx <= end_idx):
                                continue
                    except ValueError:
                        continue
            if hours:
                try:
                    start_str, end_str = hours.split('-')
                    sh, sm = map(int, start_str.split(':'))
                    eh, em = map(int, end_str.split(':'))
                    start_min = sh * 60 + sm
                    end_min = eh * 60 + em
                    if start_min <= end_min:
                        if not (start_min <= current_time < end_min):
                            continue
                    else:
                        # Overnight window, e.g. 22:00-06:00.
                        if not (current_time >= start_min or current_time < end_min):
                            continue
                except (ValueError, AttributeError):
                    continue
            matched_profile = profile
            matched_name = name
            break
        if matched_name != self.current_profile:
            if matched_profile:
                pps = matched_profile.get('pps', self.cfg.get('default_pps', 1000))
                bps = matched_profile.get('bps', self.cfg.get('default_bps', 0))
            else:
                pps = self.cfg.get('default_pps', 1000)
                bps = self.cfg.get('default_bps', 0)
            window = self.cfg.get('window_sec', 1)
            try:
                # Kernel map expects the window in nanoseconds.
                write_rate_config(pps, bps, window * 1_000_000_000)
                log.info("Profile switched: %s -> %s (pps=%d)",
                         self.current_profile, matched_name, pps)
                self.current_profile = matched_name
            except Exception as e:
                log.error("Failed to apply profile %s: %s", matched_name, e)


# ==================== DDoSDaemon ====================

class DDoSDaemon:
    """Main daemon orchestrator."""

    def __init__(self, config_path=CONFIG_PATH):
        self.config_path = config_path
        self.cfg = load_config(config_path)
        self._stop_event = threading.Event()
        self._ewma_interval = self.cfg['ewma'].get('poll_interval', 1)
        self._ai_interval = self.cfg['ai'].get('poll_interval', 5)
        self._setup_components()

    def _setup_components(self):
        self.violation_tracker = ViolationTracker(self.cfg['escalation'])
        self.ewma_analyzer = EWMAAnalyzer(
            alpha=self.cfg['ewma'].get('alpha', 0.3),
            threshold_multiplier=self.cfg['ewma'].get('threshold_multiplier', 5.0),
            min_pps=self.cfg['ewma'].get('min_pps', 20),
        )
        self.ai_detector = AIDetector(self.cfg['ai'])
        self.profile_manager = ProfileManager(self.cfg['rate_limits'])
        if self.ai_detector.enabled:
            self.ai_detector.load_model()
        self._last_retrain_time = self._get_model_mtime()
        self._last_log_cleanup = time.time()
        self._init_traffic_db()
        level = self.cfg['general'].get('log_level', 'info').upper()
        log.setLevel(getattr(logging, level, logging.INFO))

    def _write_pid(self):
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        os.makedirs(os.path.dirname(pid_file), exist_ok=True)
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

    def _remove_pid(self):
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        try:
            os.unlink(pid_file)
        except OSError:
            pass

    def _ensure_single_instance(self):
        """Stop any existing daemon before starting."""
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        if not os.path.exists(pid_file):
            return
        try:
            with open(pid_file) as f:
                old_pid = int(f.read().strip())
            os.kill(old_pid, 0)  # raises OSError if the process is gone
            log.info("Stopping existing daemon (PID %d)...", old_pid)
            os.kill(old_pid, signal.SIGTERM)
            for _ in range(30):
                time.sleep(1)
                try:
                    os.kill(old_pid, 0)
                except OSError:
                    log.info("Old daemon stopped")
                    return
            log.warning("Daemon PID %d did not stop, sending SIGKILL", old_pid)
            os.kill(old_pid, signal.SIGKILL)
            time.sleep(1)
        except (ValueError, OSError):
            # Stale/garbled PID file or process already gone — nothing to do.
            pass

    def _handle_sighup(self, signum, frame):
        log.info("SIGHUP received, reloading config...")
        new_cfg = load_config(self.config_path)
        # Build all new values before swapping anything
        new_escalation = new_cfg['escalation']
        new_alpha = new_cfg['ewma'].get('alpha', 0.3)
        new_threshold = new_cfg['ewma'].get('threshold_multiplier', 5.0)
        new_min_pps = new_cfg['ewma'].get('min_pps', 20)
        new_ai_cfg = new_cfg['ai']
        new_rate_cfg = new_cfg['rate_limits']
        new_ewma_interval = new_cfg['ewma'].get('poll_interval', 1)
        new_ai_interval = new_cfg['ai'].get('poll_interval', 5)
        level = new_cfg['general'].get('log_level', 'info').upper()
        # Now apply all at once
        self.cfg = new_cfg
        self.violation_tracker.cfg = new_escalation
        self.ewma_analyzer.alpha = new_alpha
        self.ewma_analyzer.threshold_multiplier = new_threshold
        self.ewma_analyzer.min_pps = new_min_pps
        self.ai_detector.cfg = new_ai_cfg
        self.profile_manager.cfg = new_rate_cfg
        self._ewma_interval = new_ewma_interval
        self._ai_interval = new_ai_interval
        log.setLevel(getattr(logging, level, logging.INFO))
        log.info("Config reloaded (state preserved)")

    def _handle_sigterm(self, signum, frame):
        log.info("SIGTERM received, shutting down...")
        self._stop_event.set()

    def _handle_sigusr1(self, signum, frame):
        log.info("SIGUSR1 received, triggering retrain from traffic log (background)...")
        db_path = self.cfg['ai'].get('traffic_log_db',
                                     '/var/lib/xdp-defense/traffic_log.db')
        self.ai_detector.retrain_from_log(db_path, background=True)
        self._last_retrain_time = time.time()

    # ---- Traffic Logging (SQLite) ----

    def _init_traffic_db(self):
        """Initialize SQLite database for traffic logging."""
        db_path = self.cfg['ai'].get('traffic_log_db',
                                     '/var/lib/xdp-defense/traffic_log.db')
        os.makedirs(os.path.dirname(db_path), exist_ok=True)
        self._db_lock = threading.Lock()
        # check_same_thread=False: connection is shared across worker threads,
        # serialized by self._db_lock.
        self._traffic_db = sqlite3.connect(db_path, check_same_thread=False)
        self._traffic_db.execute(
            'CREATE TABLE IF NOT EXISTS traffic_samples ('
            ' id INTEGER PRIMARY KEY AUTOINCREMENT,'
            ' timestamp TEXT NOT NULL,'
            ' hour REAL NOT NULL,'
            ' hour_sin REAL NOT NULL,'
            ' hour_cos REAL NOT NULL,'
            ' total_packets REAL NOT NULL,'
            ' total_bytes REAL NOT NULL,'
            ' tcp_syn_count REAL NOT NULL,'
            ' tcp_other_count REAL NOT NULL,'
            ' udp_count REAL NOT NULL,'
            ' icmp_count REAL NOT NULL,'
            ' other_proto_count REAL NOT NULL,'
            ' unique_ips_approx REAL NOT NULL,'
            ' small_pkt_count REAL NOT NULL,'
            ' large_pkt_count REAL NOT NULL,'
            ' syn_ratio REAL NOT NULL,'
            ' udp_ratio REAL NOT NULL,'
            ' icmp_ratio REAL NOT NULL,'
            ' small_pkt_ratio REAL NOT NULL,'
            ' avg_pkt_size REAL NOT NULL'
            ')'
        )
        self._traffic_db.execute(
            'CREATE INDEX IF NOT EXISTS idx_timestamp ON traffic_samples(timestamp)'
        )
        self._traffic_db.commit()
        log.info("Traffic log DB initialized: %s", db_path)

    def _log_traffic(self, now, hour, features):
        """Insert one row into traffic_samples table."""
        try:
            with self._db_lock:
                self._traffic_db.execute(
                    'INSERT INTO traffic_samples ('
                    ' timestamp, hour, hour_sin, hour_cos,'
                    ' total_packets, total_bytes, tcp_syn_count, tcp_other_count,'
                    ' udp_count, icmp_count, other_proto_count, unique_ips_approx,'
                    ' small_pkt_count, large_pkt_count,'
                    ' syn_ratio, udp_ratio, icmp_ratio, small_pkt_ratio, avg_pkt_size'
                    ') VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)',
                    (now.isoformat(), hour, *features)
                )
                self._traffic_db.commit()
        except Exception as e:
            log.error("Failed to write traffic log: %s", e)

    def _cleanup_traffic_log(self):
        """Remove entries older than retention_days from traffic_samples."""
        retention_days = self.cfg['ai'].get('traffic_log_retention_days', 7)
        cutoff = (datetime.now() - timedelta(days=retention_days)).isoformat()
        try:
            with self._db_lock:
                cur = self._traffic_db.execute(
                    'DELETE FROM traffic_samples WHERE timestamp < ?', (cutoff,)
                )
                deleted = cur.rowcount
                self._traffic_db.commit()
                if deleted > 1000:
                    # Only reclaim disk space after a substantial purge.
                    self._traffic_db.execute('VACUUM')
            log.info("Traffic log cleanup: deleted %d rows (retention=%dd)",
                     deleted, retention_days)
        except Exception as e:
            log.error("Traffic log cleanup failed: %s", e)

    def _get_model_mtime(self):
        """Get model file modification time, or current time if not found."""
        model_file = self.cfg['ai'].get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        try:
            return os.path.getmtime(model_file)
        except OSError:
            return time.time()

    # ---- Worker Threads ----

    def _ewma_thread(self):
        """Poll rate counters, compute EWMA, detect violations, escalate."""
        prev_counters = {}
        while not self._stop_event.is_set():
            interval = self._ewma_interval
            try:
                entries = dump_rate_counters('rate_counter_v4', top_n=1000)
                active_ips = []
                for ip_str, pkts, bts, last_seen in entries:
                    active_ips.append(ip_str)
                    prev = prev_counters.get(ip_str, 0)
                    # Counter reset (map entry recycled) → treat absolute value
                    # as the delta.
                    delta = pkts - prev if pkts >= prev else pkts
                    prev_counters[ip_str] = pkts
                    if delta <= 0:
                        continue
                    pps = delta / max(interval, 0.1)
                    is_anomalous = self.ewma_analyzer.update(ip_str, pps)
                    if is_anomalous:
                        # Skip whitelisted IPs
                        if is_whitelisted(ip_str):
                            log.debug("EWMA anomaly skipped (whitelisted): %s", ip_str)
                            continue
                        level = self.violation_tracker.record_violation(ip_str)
                        ew = self.ewma_analyzer.get_stats(ip_str)
                        log.warning(
                            "EWMA anomaly: %s pps=%.1f ewma=%.1f baseline=%.1f -> %s",
                            ip_str, pps, ew['ewma'], ew['baseline'], level
                        )
                        if level == 'temp_block':
                            dur = self.cfg['escalation'].get('temp_block_duration', 300)
                            try:
                                block_ip(ip_str, dur)
                                log.warning("TEMP BLOCK: %s for %ds", ip_str, dur)
                            except Exception as e:
                                log.error("Failed to temp-block %s: %s", ip_str, e)
                        elif level == 'perm_block':
                            try:
                                block_ip(ip_str, 0)
                                log.warning("PERM BLOCK: %s", ip_str)
                            except Exception as e:
                                log.error("Failed to perm-block %s: %s", ip_str, e)
                self.ewma_analyzer.cleanup_stale(active_ips)
            except Exception as e:
                log.error("EWMA thread error: %s", e)
            self._stop_event.wait(interval)

    def _ai_thread(self):
        """Read traffic features, run AI inference or collect training data."""
        prev_features = None
        ai_prev_counters = {}
        ai_prev_counter_time = 0
        self._last_retrain_time = self._get_model_mtime()
        self._last_log_cleanup = time.time()
        while not self._stop_event.is_set():
            interval = self._ai_interval
            try:
                if not self.ai_detector.enabled:
                    self._stop_event.wait(interval)
                    continue
                features = read_percpu_features()
                if not features:
                    self._stop_event.wait(interval)
                    continue
                feature_names = [
                    'total_packets', 'total_bytes', 'tcp_syn_count',
                    'tcp_other_count', 'udp_count', 'icmp_count',
                    'other_proto_count', 'unique_ips_approx',
                    'small_pkt_count', 'large_pkt_count'
                ]
                if prev_features is not None:
                    # Per-interval deltas from cumulative kernel counters.
                    deltas = []
                    for name in feature_names:
                        cur = features.get(name, 0)
                        prev = prev_features.get(name, 0)
                        deltas.append(max(0, cur - prev))
                    min_pkts = self.cfg['ai'].get('min_packets_for_sample', 20)
                    if deltas[0] < min_pkts:
                        # Too little traffic this interval for a meaningful sample.
                        prev_features = features
                        self._stop_event.wait(interval)
                        continue
                    total = deltas[0] + 1e-6
                    syn_ratio = deltas[2] / total
                    udp_ratio = deltas[4] / total
                    icmp_ratio = deltas[5] / total
                    small_pkt_ratio = deltas[8] / total
                    avg_pkt_size = deltas[1] / total
                    deltas.extend([syn_ratio, udp_ratio, icmp_ratio,
                                   small_pkt_ratio, avg_pkt_size])
                    # Add time features (hour_sin, hour_cos) at the front
                    now = datetime.now()
                    hour = now.hour + now.minute / 60.0
                    hour_sin = math.sin(2 * math.pi * hour / 24)
                    hour_cos = math.cos(2 * math.pi * hour / 24)
                    deltas_with_time = [hour_sin, hour_cos] + deltas  # 17 features
                    # Log to traffic DB
                    self._log_traffic(now, hour, deltas_with_time)
                    # Periodic log file cleanup (once per day)
                    if time.time() - self._last_log_cleanup > 86400:
                        self._cleanup_traffic_log()
                        self._last_log_cleanup = time.time()
                    if self.ai_detector.is_learning:
                        self.ai_detector.collect_sample(deltas_with_time, hour)
                        total_samples = sum(len(v) for v in
                                            self.ai_detector.training_data.values())
                        if total_samples % 100 == 0 and total_samples > 0:
                            log.debug("AI learning: %d samples collected", total_samples)
                    else:
                        # Auto-retrain check (background, no inference gap)
                        retrain_interval = self.cfg['ai'].get('retrain_interval', 86400)
                        if time.time() - self._last_retrain_time >= retrain_interval:
                            log.info("Auto-retrain triggered (interval=%ds, background)",
                                     retrain_interval)
                            db_path = self.cfg['ai'].get(
                                'traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
                            self.ai_detector.retrain_from_log(db_path, background=True)
                            self._last_retrain_time = time.time()
                        is_anomaly, score = self.ai_detector.predict(deltas_with_time, hour)
                        if is_anomaly:
                            log.warning(
                                "AI ANOMALY detected: score=%.4f deltas=%s",
                                score, dict(zip(feature_names,
                                                deltas[:len(feature_names)]))
                            )
                            top_ips = dump_rate_counters('rate_counter_v4', top_n=10)
                            now_ts = time.time()
                            ai_elapsed = (now_ts - ai_prev_counter_time
                                          if ai_prev_counter_time > 0 else interval)
                            ai_prev_counter_time = now_ts
                            for ip_str, pkts, bts, _ in top_ips:
                                prev_pkts = ai_prev_counters.get(ip_str)
                                ai_prev_counters[ip_str] = pkts
                                if is_whitelisted(ip_str):
                                    log.debug("AI escalation skipped (whitelisted): %s",
                                              ip_str)
                                    continue
                                stats = self.ewma_analyzer.get_stats(ip_str)
                                baseline = max(stats['baseline'], 1)
                                ewma = stats['ewma']
                                if stats['baseline'] > 0:
                                    # Known IP: only escalate when its own EWMA
                                    # is clearly elevated.
                                    if ewma <= baseline * 2.0:
                                        log.debug(
                                            "AI skip (normal EWMA): %s ewma=%.1f baseline=%.1f",
                                            ip_str, ewma, baseline)
                                        continue
                                else:
                                    # BUGFIX: config section is 'rate_limits'
                                    # with key 'default_pps'; the old
                                    # cfg['rate_limit'].get('pps') raised
                                    # KeyError and aborted escalation.
                                    pps_limit = self.cfg['rate_limits'].get(
                                        'default_pps', 2000)
                                    if prev_pkts is not None:
                                        delta = (pkts - prev_pkts
                                                 if pkts >= prev_pkts else pkts)
                                        est_pps = delta / max(ai_elapsed, 1)
                                        if est_pps <= pps_limit:
                                            log.debug(
                                                "AI skip (new IP, low pps): %s est_pps=%.1f",
                                                ip_str, est_pps)
                                            continue
                                    else:
                                        log.debug("AI skip (new IP, first seen): %s",
                                                  ip_str)
                                        continue
                                level = self.violation_tracker.record_violation(ip_str)
                                log.warning("AI escalation: %s ewma=%.1f baseline=%.1f -> %s",
                                            ip_str, ewma, baseline, level)
                                if level == 'temp_block':
                                    dur = self.cfg['escalation'].get(
                                        'temp_block_duration', 300)
                                    try:
                                        block_ip(ip_str, dur)
                                        log.warning("AI TEMP BLOCK: %s for %ds",
                                                    ip_str, dur)
                                    except Exception as e:
                                        log.error("Failed to AI temp-block %s: %s",
                                                  ip_str, e)
                                elif level == 'perm_block':
                                    try:
                                        block_ip(ip_str, 0)
                                        log.warning("AI PERM BLOCK: %s", ip_str)
                                    except Exception as e:
                                        log.error("Failed to AI perm-block %s: %s",
                                                  ip_str, e)
                prev_features = features
            except Exception as e:
                log.error("AI thread error: %s", e)
            self._stop_event.wait(interval)

    def _profile_thread(self):
        """Check time-of-day and switch rate profiles."""
        while not self._stop_event.is_set():
            try:
                self.profile_manager.check_and_apply()
            except Exception as e:
                log.error("Profile thread error: %s", e)
            self._stop_event.wait(60)

    def _cleanup_thread(self):
        """Periodically clean up expired blocked IPs and stale violations."""
        while not self._stop_event.is_set():
            try:
                # Kernel map expiry timestamps are in boot-relative ns
                # (CLOCK_MONOTONIC-like), so compare against /proc/uptime.
                with open('/proc/uptime') as f:
                    now_ns = int(float(f.read().split()[0]) * 1_000_000_000)
                for map_name in ['blocked_ips_v4', 'blocked_ips_v6']:
                    entries = dump_blocked_ips(map_name)
                    for ip_str, expire_ns, blocked_at, drop_count in entries:
                        # expire_ns == 0 means a permanent block.
                        if expire_ns != 0 and now_ns > expire_ns:
                            try:
                                unblock_ip(ip_str)
                                self.violation_tracker.clear(ip_str)
                                log.info("Expired block removed: %s (dropped %d pkts)",
                                         ip_str, drop_count)
                            except Exception as e:
                                log.error("Failed to remove expired block %s: %s",
                                          ip_str, e)
                self.violation_tracker.cleanup_expired()
            except Exception as e:
                log.error("Cleanup thread error: %s", e)
            self._stop_event.wait(60)

    # ---- Main Loop ----

    def run(self):
        """Start the daemon."""
        log.info("XDP Defense Daemon starting...")
        signal.signal(signal.SIGHUP, self._handle_sighup)
        signal.signal(signal.SIGTERM, self._handle_sigterm)
        signal.signal(signal.SIGINT, self._handle_sigterm)
        signal.signal(signal.SIGUSR1, self._handle_sigusr1)
        self._ensure_single_instance()
        self._write_pid()
        threads = [
            threading.Thread(target=self._ewma_thread, name='ewma', daemon=True),
            threading.Thread(target=self._ai_thread, name='ai', daemon=True),
            threading.Thread(target=self._profile_thread, name='profile', daemon=True),
            threading.Thread(target=self._cleanup_thread, name='cleanup', daemon=True),
        ]
        for t in threads:
            t.start()
            log.info("Started %s thread", t.name)
        log.info("Daemon running (PID %d)", os.getpid())
        try:
            # Main thread only waits for shutdown; signals are handled here.
            while not self._stop_event.is_set():
                self._stop_event.wait(1)
        except KeyboardInterrupt:
            pass
        log.info("Shutting down...")
        self._stop_event.set()
        for t in threads:
            t.join(timeout=5)
        self._remove_pid()
        if hasattr(self, '_traffic_db') and self._traffic_db:
            try:
                self._traffic_db.close()
            except Exception:
                pass
        log.info("Daemon stopped")


# ==================== Entry Point ====================

def main():
    config_path = CONFIG_PATH
    if len(sys.argv) > 1:
        config_path = sys.argv[1]
    daemon = DDoSDaemon(config_path)
    daemon.run()


if __name__ == '__main__':
    main()