Files
xdp-defense/lib/xdp_defense_daemon.py
2026-02-07 10:49:35 +09:00

1022 lines
39 KiB
Python
Executable File

#!/usr/bin/env python3
"""
XDP Defense Daemon
Userspace component: EWMA-based rate analysis, AI anomaly detection,
time-profile switching, and automatic escalation.
4 worker threads + main thread (signal handling):
- EWMA Thread: polls rate counters, calculates EWMA, detects violations
- AI Thread: reads traffic features, runs Isolation Forest inference
- Profile Thread: checks time-of-day, switches rate_config profiles
- Cleanup Thread: removes expired entries from blocked_ips maps
"""
import copy
import math
import os
import sys
import time
import signal
import threading
import logging
import logging.handlers
import csv
import pickle
import sqlite3
from collections import defaultdict
from datetime import datetime, timedelta
import yaml
# ==================== Logging ====================
# Module-wide logger: console output always, syslog when available.
log = logging.getLogger('xdp-defense-daemon')
log.setLevel(logging.INFO)

_console = logging.StreamHandler()
_console.setFormatter(
    logging.Formatter('%(asctime)s [%(levelname)s] %(message)s'))
log.addHandler(_console)

# Syslog is best-effort: if /dev/log cannot be used, keep console-only logging.
try:
    _syslog = logging.handlers.SysLogHandler(address='/dev/log')
except Exception:
    pass
else:
    _syslog.setFormatter(logging.Formatter('xdp-defense-daemon: %(message)s'))
    log.addHandler(_syslog)
# ==================== Configuration ====================
# Built-in defaults. load_config() deep-copies this and shallow-merges each
# section with the same-named mapping from the YAML config file.
# Durations are in seconds (they are compared against time.time() deltas,
# fed to timedelta(seconds=...), or used as Event.wait() timeouts).
DEFAULT_CONFIG = {
    'general': {
        'interface': 'eth0',
        'log_level': 'info',  # mapped to a logging level name via .upper()
        'pid_file': '/var/lib/xdp-defense/daemon.pid',
        'data_dir': '/var/lib/xdp-defense',
    },
    'rate_limits': {
        # Used when no time profile matches (passed to write_rate_config).
        'default_pps': 1000,
        'default_bps': 0,
        'window_sec': 1,  # converted to nanoseconds before being written
        # name -> {'hours': 'HH:MM-HH:MM', 'weekdays': ..., 'pps': ..., 'bps': ...}
        'profiles': {},
    },
    'escalation': {
        # Violation counts (within violation_window) that trigger each level.
        'temp_block_after': 5,
        'perm_block_after': 20,
        'temp_block_duration': 5 * 60,
        'violation_window': 10 * 60,
    },
    'ewma': {
        'alpha': 0.3,                 # weight of the newest sample in the fast average
        'poll_interval': 1,
        'threshold_multiplier': 3.0,  # anomaly when ewma > baseline * this
    },
    'ai': {
        'enabled': True,
        'model_type': 'IsolationForest',
        'contamination': 0.05,
        'n_estimators': 100,
        'learning_duration': 3 * 24 * 60 * 60,   # 3 days
        'min_samples': 1000,
        'poll_interval': 5,
        'anomaly_threshold': -0.3,    # decision_function score below this = anomaly
        'min_packets_for_sample': 20,
        'model_file': '/var/lib/xdp-defense/ai_model.pkl',
        'training_data_file': '/var/lib/xdp-defense/training_data.csv',
        'traffic_log_db': '/var/lib/xdp-defense/traffic_log.db',
        'traffic_log_retention_days': 7,
        'retrain_interval': 24 * 60 * 60,        # 1 day
        'retrain_window': 7 * 24 * 60 * 60,      # 7 days
    },
}
CONFIG_PATH = '/etc/xdp-defense/config.yaml'
def load_config(path=CONFIG_PATH):
    """Return the daemon configuration with YAML overrides layered on defaults.

    A deep copy of ``DEFAULT_CONFIG`` is made, and each of its top-level
    sections is shallow-updated with the same-named mapping from the YAML file
    at *path*. Sections the file omits, or gives as a non-mapping, keep their
    defaults. Sections the file adds that are not in the defaults are ignored.

    A missing file logs a warning. Any other failure logs an error. In both
    cases the defaults (plus any sections merged before the failure) are
    returned; this function never raises for a bad config file.
    """
    merged = copy.deepcopy(DEFAULT_CONFIG)
    try:
        with open(path) as fh:
            overrides = yaml.safe_load(fh) or {}
        for name, section in merged.items():
            if name not in overrides:
                continue
            supplied = overrides[name]
            if isinstance(supplied, dict):
                section.update(supplied)
    except FileNotFoundError:
        log.warning("Config not found at %s, using defaults", path)
    except Exception as e:
        log.error("Failed to load config: %s", e)
    return merged
# ==================== ViolationTracker ====================
class ViolationTracker:
    """Sliding-window per-IP violation counter that maps counts to escalation levels.

    The escalation config is re-read on every call, so it can be replaced at
    runtime. Keys used: ``violation_window`` (seconds, default 600),
    ``temp_block_after`` (default 5) and ``perm_block_after`` (default 20).
    All access to the per-IP history is guarded by ``self.lock``.
    """

    def __init__(self, escalation_cfg):
        self.cfg = escalation_cfg
        self.violations = defaultdict(list)  # ip -> [time.time() of each violation]
        self.lock = threading.Lock()

    def record_violation(self, ip):
        """Record one violation for *ip* and return the resulting escalation level.

        Timestamps older than ``violation_window`` are discarded first. The
        remaining count, including this violation, is then compared against
        the thresholds.

        Returns:
            ``'perm_block'``, ``'temp_block'`` or ``'rate_limit'``.
        """
        now = time.time()
        window = self.cfg.get('violation_window', 600)
        with self.lock:
            recent = [stamp for stamp in self.violations[ip] if now - stamp < window]
            recent.append(now)
            self.violations[ip] = recent
            count = len(recent)
        if count >= self.cfg.get('perm_block_after', 20):
            return 'perm_block'
        if count >= self.cfg.get('temp_block_after', 5):
            return 'temp_block'
        return 'rate_limit'

    def clear(self, ip):
        """Forget every recorded violation for *ip* (no-op if none are recorded)."""
        with self.lock:
            if ip in self.violations:
                del self.violations[ip]

    def cleanup_expired(self):
        """Drop IPs that have no violation inside the current window."""
        now = time.time()
        window = self.cfg.get('violation_window', 600)
        with self.lock:
            stale = [
                ip for ip, stamps in self.violations.items()
                if not any(now - stamp < window for stamp in stamps)
            ]
            for ip in stale:
                self.violations.pop(ip)
# ==================== EWMAAnalyzer ====================
class EWMAAnalyzer:
    """Per-IP fast/slow exponential moving averages for rate anomaly detection.

    ``ewma`` is a fast average that weights the newest sample by ``alpha``.
    ``baseline`` is a slow average that weights it by a fixed 0.01. A sample is
    anomalous when the fast average exceeds ``threshold_multiplier`` times the
    baseline. The baseline is floored at 1 so that near-idle IPs need a
    non-trivial rate to trip the check. All state is guarded by ``self.lock``.
    """

    def __init__(self, alpha=0.3, threshold_multiplier=3.0):
        self.alpha = alpha
        self.threshold_multiplier = threshold_multiplier
        self.ewma = {}      # ip -> fast average of pps
        self.baseline = {}  # ip -> slow average of pps
        self.lock = threading.Lock()

    def update(self, ip, current_pps):
        """Fold *current_pps* into *ip*'s averages; return True if it is anomalous.

        The first sample for an IP only seeds both averages and is never
        reported as anomalous.
        """
        with self.lock:
            if ip not in self.ewma:
                self.ewma[ip] = current_pps
                self.baseline[ip] = current_pps
                return False
            fast = self.alpha * current_pps + (1 - self.alpha) * self.ewma[ip]
            slow = 0.01 * current_pps + 0.99 * self.baseline[ip]
            self.ewma[ip] = fast
            self.baseline[ip] = slow
            return bool(fast > max(slow, 1) * self.threshold_multiplier)

    def get_stats(self, ip):
        """Return ``{'ewma': ..., 'baseline': ...}`` for *ip* (zeros if untracked)."""
        with self.lock:
            return {'ewma': self.ewma.get(ip, 0), 'baseline': self.baseline.get(ip, 0)}

    def cleanup_stale(self, active_ips):
        """Stop tracking every IP that is not listed in *active_ips*."""
        keep = set(active_ips)
        with self.lock:
            for ip in [tracked for tracked in self.ewma if tracked not in keep]:
                self.ewma.pop(ip, None)
                self.baseline.pop(ip, None)
# ==================== AIDetector ====================
class AIDetector:
    """Time-period Isolation Forest anomaly detection.

    Maintains up to four models, one per time-of-day period:
    night (00-06), morning (06-12), afternoon (12-18), evening (18-24).
    Each model learns the normal traffic pattern for its time window from
    17-element feature vectors (column order in ``_FEATURE_NAMES``).

    Lifecycle: while ``is_learning`` is set, samples are collected per period.
    Training happens once ``learning_duration`` seconds have elapsed and
    ``min_samples`` samples exist, or sooner if a retrain was requested. The
    trained models are swapped in and saved to disk. Later retrains read
    samples back from the SQLite traffic log (``retrain_from_log``).
    """

    TIME_PERIODS = {
        'night': (0, 6),
        'morning': (6, 12),
        'afternoon': (12, 18),
        'evening': (18, 24),
    }

    # Column order of a feature vector; also written as the training CSV header.
    _FEATURE_NAMES = (
        'hour_sin', 'hour_cos',
        'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
        'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
        'small_pkt_count', 'large_pkt_count',
        'syn_ratio', 'udp_ratio', 'icmp_ratio', 'small_pkt_ratio', 'avg_pkt_size',
    )

    # Minimum samples for a period to be fitted/filtered. Also the overall
    # minimum below which no training is attempted.
    _MIN_SAMPLES_PER_PERIOD = 10

    def __init__(self, ai_cfg):
        self.cfg = ai_cfg
        self.models = {}  # {period_name: {'model': ..., 'scaler': ...}}
        self.started_at = time.time()
        self.training_data = defaultdict(list)  # {period_name: [samples]}
        self.is_learning = True
        self._retrain_requested = False
        # Taken non-blocking for a whole DB retrain, so only one runs at a time.
        self._retrain_lock = threading.Lock()
        # Serializes _train(). It can be entered from the AI thread
        # (collect_sample) and from a background retrain thread at the same
        # time, and both paths write the same .tmp files.
        self._train_lock = threading.Lock()

    @property
    def enabled(self):
        return self.cfg.get('enabled', False)

    @staticmethod
    def get_period(hour):
        """Map hour (0-24) to period name."""
        for name, (start, end) in AIDetector.TIME_PERIODS.items():
            if start <= hour < end:
                return name
        return 'night'  # hour == 24 edge case

    def request_retrain(self):
        """Make the next collect_sample() call train immediately."""
        self._retrain_requested = True

    def collect_sample(self, features, hour):
        """Collect a feature sample during learning phase, bucketed by period."""
        if not self.enabled:
            return
        period = self.get_period(hour)
        self.training_data[period].append(features)
        learning_dur = self.cfg.get('learning_duration', 259200)
        min_samples = self.cfg.get('min_samples', 1000)
        elapsed = time.time() - self.started_at
        total = sum(len(v) for v in self.training_data.values())
        if (elapsed >= learning_dur and total >= min_samples) or self._retrain_requested:
            self._train()
            self._retrain_requested = False

    def _train(self, training_data=None):
        """Train per-period Isolation Forest models, swap them in and persist them.

        Args:
            training_data: optional ``{period: [feature_vector, ...]}``. It
                defaults to a snapshot of ``self.training_data``.

        Returns:
            True if at least one period model was trained and is now live in
            ``self.models`` (even if saving it to disk then failed), else False.
        """
        try:
            # Availability check only; _fit_period_models re-imports (cached).
            import sklearn.ensemble  # noqa: F401
            import sklearn.preprocessing  # noqa: F401
            import numpy  # noqa: F401
        except ImportError:
            log.error("scikit-learn not installed. AI detection disabled.")
            self.cfg['enabled'] = False
            return False

        source = self.training_data if training_data is None else training_data
        # Snapshot: the AI thread may append samples while a background retrain
        # iterates, which would otherwise change the dict mid-iteration.
        data = {period: list(samples) for period, samples in dict(source).items()}

        total = sum(len(s) for s in data.values())
        if total < self._MIN_SAMPLES_PER_PERIOD:
            log.warning("Not enough training data (%d samples)", total)
            return False

        with self._train_lock:
            log.info("Training AI models: %s",
                     {p: len(s) for p, s in data.items() if s})
            try:
                new_models, used_samples = self._fit_period_models(data)
            except Exception as e:
                log.error("AI training failed: %s", e)
                return False
            if not new_models:
                log.warning("No period had enough data to train")
                return False

            # Atomic swap: predict() sees either the old or the new dict.
            self.models = new_models
            self.is_learning = False

            try:
                self._save_models(new_models, used_samples)
            except Exception as e:
                # The new models are already in use; only persistence failed.
                log.error("Failed to save AI models: %s", e)
                return True
            log.info("AI models trained and saved (%d periods)", len(new_models))
            return True

    def _fit_period_models(self, data):
        """Fit a StandardScaler + IsolationForest for each period with enough samples.

        Returns:
            ``(models, used_samples)``. *models* maps a period name to
            ``{'model': ..., 'scaler': ...}``. *used_samples* concatenates the
            vectors of every fitted period.
        """
        from sklearn.ensemble import IsolationForest
        from sklearn.preprocessing import StandardScaler
        import numpy as np

        models = {}
        used_samples = []
        for period, samples in data.items():
            if len(samples) < self._MIN_SAMPLES_PER_PERIOD:
                log.info("Period %s: %d samples (too few, skip)", period, len(samples))
                continue
            scaler = StandardScaler()
            X_scaled = scaler.fit_transform(np.array(samples))
            model = IsolationForest(
                n_estimators=self.cfg.get('n_estimators', 100),
                contamination=self.cfg.get('contamination', 'auto'),
                random_state=42,
            )
            model.fit(X_scaled)
            models[period] = {'model': model, 'scaler': scaler}
            used_samples.extend(samples)
            log.info("Period %s: trained with %d samples", period, len(samples))
        return models, used_samples

    def _save_models(self, models, samples):
        """Write the model pickle and the training-data CSV, each atomically."""
        model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        payload = {
            'format': 'period_models',
            'models': {p: {'model': m['model'], 'scaler': m['scaler']}
                       for p, m in models.items()},
            'feature_count': len(self._FEATURE_NAMES),
        }
        self._atomic_write(model_file, 'wb', lambda f: pickle.dump(payload, f))

        def write_csv(f):
            writer = csv.writer(f)
            writer.writerow(self._FEATURE_NAMES)
            writer.writerows(samples)

        data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
        self._atomic_write(data_file, 'w', write_csv)

    @staticmethod
    def _atomic_write(path, mode, write_fn):
        """Write *path* via ``path + '.tmp'``, fsync, then rename over *path*.

        *write_fn* receives the open temp file. On any failure the temp file
        is removed and the exception re-raised, so *path* is either the old
        content or the complete new content.
        """
        tmp = path + '.tmp'
        open_kwargs = {} if 'b' in mode else {'newline': ''}
        try:
            with open(tmp, mode, **open_kwargs) as f:
                write_fn(f)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp, path)
        except BaseException:
            try:
                os.unlink(tmp)
            except OSError:
                pass
            raise

    def load_model(self):
        """Load previously trained models. Handle old and new formats.

        Returns True if period models were loaded. Returns False if the file
        is missing, is in the old single-model format (learning mode is
        re-enabled), or cannot be read.
        """
        model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        if not os.path.exists(model_file):
            return False
        try:
            # NOTE(review): pickle executes code on load; this is only safe
            # because the file is written by this daemon into its own data dir.
            with open(model_file, 'rb') as f:
                data = pickle.load(f)
            # New period_models format
            if data.get('format') == 'period_models':
                self.models = data['models']
                self.is_learning = False
                periods = list(self.models.keys())
                log.info("AI models loaded: %s (%d features)",
                         periods, data.get('feature_count', '?'))
                return True
            # Old single-model format -> discard
            log.warning("Old single-model format detected. Switching to learning mode.")
            self.is_learning = True
            return False
        except Exception as e:
            log.error("Failed to load AI model: %s", e)
            return False

    @classmethod
    def _nearest_model(cls, models, hour):
        """Return the model whose period centre is closest (circularly) to *hour*.

        Periods not listed in TIME_PERIODS sort last. Returns None if
        *models* is empty.
        """
        if not models:
            return None

        def distance(period):
            bounds = cls.TIME_PERIODS.get(period)
            if bounds is None:
                return float('inf')
            gap = abs(hour - (bounds[0] + bounds[1]) / 2.0) % 24
            return min(gap, 24 - gap)

        return models[min(models, key=distance)]

    def predict(self, features, hour):
        """Run anomaly detection using the period-specific model.

        If the current period has no model yet, the model of the nearest
        period (by time of day) is used instead.

        Returns:
            ``(is_anomaly, score)``; ``(False, 0.0)`` when disabled, when no
            model exists, or on error.
        """
        models = self.models  # local snapshot; a retrain may swap self.models
        if not self.enabled or not models:
            return False, 0.0
        model_info = models.get(self.get_period(hour))
        if model_info is None:
            model_info = self._nearest_model(models, hour)
        if model_info is None:
            return False, 0.0
        try:
            import numpy as np
            X_scaled = model_info['scaler'].transform(np.array([features]))
            score = model_info['model'].decision_function(X_scaled)[0]
            threshold = self.cfg.get('anomaly_threshold', -0.3)
            return bool(score < threshold), float(score)
        except Exception as e:
            log.error("AI prediction error: %s", e)
            return False, 0.0

    def retrain_from_log(self, db_path=None, background=False):
        """Retrain models from traffic_log.db. Optionally in background thread.

        In background mode this returns True as soon as the thread starts;
        otherwise it returns whether new models were actually trained.
        """
        if db_path is None:
            db_path = self.cfg.get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
        if background:
            t = threading.Thread(target=self._do_retrain, args=(db_path,), daemon=True)
            t.start()
            return True  # started, not yet finished
        return self._do_retrain(db_path)

    def _do_retrain(self, db_path):
        """Actual retrain logic (can run in background thread)."""
        if not self._retrain_lock.acquire(blocking=False):
            log.info("Retrain already in progress, skipping")
            return False
        try:
            return self._retrain_impl(db_path)
        finally:
            self._retrain_lock.release()

    def _retrain_impl(self, db_path):
        """Load recent samples from the DB, filter outliers, train per-period models.

        Returns True only if this call trained and installed new models.
        """
        if not os.path.exists(db_path):
            log.warning("Traffic log DB not found: %s", db_path)
            return False
        # Fallback matches DEFAULT_CONFIG['ai']['retrain_window'] (7 days).
        retrain_window = self.cfg.get('retrain_window', 604800)
        cutoff = (datetime.now() - timedelta(seconds=retrain_window)).isoformat()
        rows = self._load_recent_rows(db_path, cutoff)
        if rows is None:
            return False
        if len(rows) < self._MIN_SAMPLES_PER_PERIOD:
            log.warning("Not enough recent samples for retrain (%d)", len(rows))
            return False

        # Split by period: row[0] = hour, row[1:] = the 17 features.
        period_data = defaultdict(list)
        for row in rows:
            period_data[self.get_period(row[0])].append(list(row[1:]))

        filtered_count = self._filter_outliers(period_data)
        if filtered_count is None:
            return False
        log.info("Retrain: %d samples loaded, %d outliers filtered, %s",
                 len(rows), filtered_count,
                 {p: len(s) for p, s in period_data.items() if s})
        self.training_data = period_data
        return self._train(period_data)

    def _load_recent_rows(self, db_path, cutoff):
        """Return ``(hour, *features)`` rows newer than *cutoff*, or None on error."""
        from urllib.parse import quote
        conn = None
        try:
            # Read-only URI so a retrain can never modify the live log. quote()
            # stops '?' or '#' in the path from being parsed as URI syntax.
            conn = sqlite3.connect(f'file:{quote(db_path)}?mode=ro', uri=True)
            cur = conn.execute(
                'SELECT hour, hour_sin, hour_cos, total_packets, total_bytes, '
                'tcp_syn_count, tcp_other_count, udp_count, icmp_count, '
                'other_proto_count, unique_ips_approx, small_pkt_count, '
                'large_pkt_count, syn_ratio, udp_ratio, icmp_ratio, '
                'small_pkt_ratio, avg_pkt_size '
                'FROM traffic_samples WHERE timestamp >= ? ORDER BY timestamp',
                (cutoff,)
            )
            return cur.fetchall()
        except Exception as e:
            log.error("retrain_from_log DB read failed: %s", e)
            return None
        finally:
            if conn is not None:
                conn.close()

    def _filter_outliers(self, period_data):
        """Drop samples that the current period model already scores as outliers.

        Mutates *period_data* in place. The threshold comes from
        ``retrain_outlier_threshold`` (default -0.5). A period is left
        untouched if it has no model, has fewer than the minimum samples, or
        its model fails to score the data (e.g. a feature-count mismatch).

        Returns:
            The number of samples removed, or None if numpy is unavailable.
        """
        try:
            import numpy as np
        except ImportError:
            log.error("numpy not available for outlier filtering")
            return None
        threshold = self.cfg.get('retrain_outlier_threshold', -0.5)
        models = self.models
        removed_total = 0
        for period, samples in list(period_data.items()):
            model_info = models.get(period)
            if model_info is None or len(samples) < self._MIN_SAMPLES_PER_PERIOD:
                continue
            try:
                X_scaled = model_info['scaler'].transform(np.array(samples))
                scores = model_info['model'].decision_function(X_scaled)
            except Exception as e:
                log.warning("Outlier filter skipped for period %s: %s", period, e)
                continue
            clean = [s for s, sc in zip(samples, scores) if sc >= threshold]
            removed = len(samples) - len(clean)
            if removed > 0:
                removed_total += removed
                period_data[period] = clean
        return removed_total
# ==================== ProfileManager ====================
class ProfileManager:
    """Manage time-based rate limit profiles.

    ``rate_cfg['profiles']`` maps a profile name to a dict with these optional
    keys:

    - ``hours``: ``"HH:MM-HH:MM"``, end exclusive; may wrap past midnight.
    - ``weekdays``: ``"mon-fri"``, ``"fri-mon"``, ``"sat"``, ``"sat,sun"``,
      or a list of such strings.
    - ``pps`` and ``bps``.

    The first profile (in mapping order) whose constraints all match the
    current local time wins. Otherwise ``default_pps``/``default_bps`` apply.
    Rates are only written when the selected profile name changes.
    """

    _DAY_NAMES = ('mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun')

    def __init__(self, rate_cfg):
        self.cfg = rate_cfg
        self.current_profile = 'default'

    @staticmethod
    def _weekday_matches(spec, weekday_idx):
        """Return True if day index *weekday_idx* (0 = Monday) satisfies *spec*.

        An empty spec always matches. An unparseable spec never matches, so a
        typo disables the profile instead of applying it every day.
        """
        if not spec:
            return True
        if isinstance(spec, (list, tuple)):
            spec = ','.join(str(part) for part in spec)
        days = ProfileManager._DAY_NAMES
        allowed = set()
        try:
            for part in str(spec).lower().replace(' ', '').split(','):
                if '-' in part:
                    first, last = part.split('-')
                    start, end = days.index(first), days.index(last)
                    # Ranges may wrap the week, e.g. fri-mon.
                    span = end - start if start <= end else end + 7 - start
                    allowed.update((start + k) % 7 for k in range(span + 1))
                else:
                    allowed.add(days.index(part))
        except ValueError:
            return False
        return weekday_idx in allowed

    @staticmethod
    def _hours_match(spec, minute_of_day):
        """Return True if *minute_of_day* lies in the ``"HH:MM-HH:MM"`` window *spec*.

        The start is inclusive and the end exclusive. A window whose end is
        before its start wraps past midnight. An empty spec always matches;
        an unparseable one never does.
        """
        if not spec:
            return True
        try:
            start_str, end_str = spec.split('-')
            sh, sm = map(int, start_str.split(':'))
            eh, em = map(int, end_str.split(':'))
        except (ValueError, AttributeError):
            return False
        start_min = sh * 60 + sm
        end_min = eh * 60 + em
        if start_min <= end_min:
            return start_min <= minute_of_day < end_min
        return minute_of_day >= start_min or minute_of_day < end_min

    def check_and_apply(self):
        """Check current time and apply matching profile."""
        from xdp_common import write_rate_config
        profiles = self.cfg.get('profiles') or {}
        if not isinstance(profiles, dict):
            profiles = {}
        now = datetime.now()
        minute_of_day = now.hour * 60 + now.minute
        # weekday() is locale-independent (0 = Monday), unlike strftime('%a').
        weekday_idx = now.weekday()

        matched_profile = None
        matched_name = 'default'
        for name, profile in profiles.items():
            if not isinstance(profile, dict):
                # e.g. a YAML key with no body; it cannot constrain or set rates.
                log.debug("Ignoring profile %s: not a mapping", name)
                continue
            if not self._weekday_matches(profile.get('weekdays', ''), weekday_idx):
                continue
            if not self._hours_match(profile.get('hours', ''), minute_of_day):
                continue
            matched_profile = profile
            matched_name = name
            break

        if matched_name == self.current_profile:
            return
        if matched_profile:
            pps = matched_profile.get('pps', self.cfg.get('default_pps', 1000))
            bps = matched_profile.get('bps', self.cfg.get('default_bps', 0))
        else:
            pps = self.cfg.get('default_pps', 1000)
            bps = self.cfg.get('default_bps', 0)
        window = self.cfg.get('window_sec', 1)
        try:
            write_rate_config(pps, bps, window * 1_000_000_000)
            log.info("Profile switched: %s -> %s (pps=%d)", self.current_profile, matched_name, pps)
            self.current_profile = matched_name
        except Exception as e:
            log.error("Failed to apply profile %s: %s", matched_name, e)
# ==================== DDoSDaemon ====================
class DDoSDaemon:
    """Main daemon orchestrator.

    Builds the detectors from the YAML config and owns the SQLite traffic
    log. It runs four daemon worker threads (EWMA, AI, profile, cleanup) while
    the main thread waits for a stop request. Signals: SIGHUP reloads the
    config in place, SIGTERM/SIGINT stop the daemon, and SIGUSR1 starts a
    background AI retrain from the traffic log.
    """

    # Seconds an IP may be absent from the rate-counter dump before the EWMA
    # thread forgets its last packet count. This bounds memory when source
    # addresses churn.
    _PREV_COUNTER_TTL = 600.0

    def __init__(self, config_path=CONFIG_PATH):
        self.config_path = config_path
        self.cfg = load_config(config_path)
        self._stop_event = threading.Event()
        self._ewma_interval = self.cfg['ewma'].get('poll_interval', 1)
        self._ai_interval = self.cfg['ai'].get('poll_interval', 5)
        self._setup_components()

    def _setup_components(self):
        """Create detectors/managers, load any saved AI model, open the traffic DB."""
        self.violation_tracker = ViolationTracker(self.cfg['escalation'])
        self.ewma_analyzer = EWMAAnalyzer(
            alpha=self.cfg['ewma'].get('alpha', 0.3),
            threshold_multiplier=self.cfg['ewma'].get('threshold_multiplier', 3.0),
        )
        self.ai_detector = AIDetector(self.cfg['ai'])
        self.profile_manager = ProfileManager(self.cfg['rate_limits'])
        if self.ai_detector.enabled:
            self.ai_detector.load_model()
        self._last_retrain_time = self._get_model_mtime()
        self._last_log_cleanup = time.time()
        self._init_traffic_db()
        level = self.cfg['general'].get('log_level', 'info').upper()
        log.setLevel(getattr(logging, level, logging.INFO))

    def _pid_file(self):
        """Configured PID file path."""
        return self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')

    def _write_pid(self):
        pid_file = self._pid_file()
        os.makedirs(os.path.dirname(pid_file), exist_ok=True)
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

    def _remove_pid(self):
        try:
            os.unlink(self._pid_file())
        except OSError:
            pass

    def _ensure_single_instance(self):
        """Stop any existing daemon before starting.

        If the PID file names another live process, this sends SIGTERM,
        waits up to 30 s for the process to exit, then falls back to SIGKILL.
        Unreadable or garbage PID files and already-gone processes are ignored.
        """
        pid_file = self._pid_file()
        if not os.path.exists(pid_file):
            return
        try:
            with open(pid_file) as f:
                old_pid = int(f.read().strip())
            # Never signal ourselves: a stale file may hold a PID that was
            # reused by this process, e.g. after a reboot or container
            # restart. Also never signal a non-positive PID: os.kill(0, ...)
            # targets our own process group and os.kill(-1, ...) every
            # process we are allowed to signal.
            if old_pid <= 0 or old_pid == os.getpid():
                log.info("Ignoring stale PID file %s (pid=%d)", pid_file, old_pid)
                return
            os.kill(old_pid, 0)
            log.info("Stopping existing daemon (PID %d)...", old_pid)
            os.kill(old_pid, signal.SIGTERM)
            for _ in range(30):
                time.sleep(1)
                try:
                    os.kill(old_pid, 0)
                except OSError:
                    log.info("Old daemon stopped")
                    return
            log.warning("Daemon PID %d did not stop, sending SIGKILL", old_pid)
            os.kill(old_pid, signal.SIGKILL)
            time.sleep(1)
        except (ValueError, OSError):
            pass

    def _handle_sighup(self, signum, frame):
        log.info("SIGHUP received, reloading config...")
        self.cfg = load_config(self.config_path)
        # Update existing components without rebuilding (preserves EWMA/violation state)
        self.violation_tracker.cfg = self.cfg['escalation']
        self.ewma_analyzer.alpha = self.cfg['ewma'].get('alpha', 0.3)
        self.ewma_analyzer.threshold_multiplier = self.cfg['ewma'].get('threshold_multiplier', 3.0)
        self.ai_detector.cfg = self.cfg['ai']
        self.profile_manager.cfg = self.cfg['rate_limits']
        # Update poll intervals (used by threads on next iteration)
        self._ewma_interval = self.cfg['ewma'].get('poll_interval', 1)
        self._ai_interval = self.cfg['ai'].get('poll_interval', 5)
        level = self.cfg['general'].get('log_level', 'info').upper()
        log.setLevel(getattr(logging, level, logging.INFO))
        log.info("Config reloaded (state preserved)")

    def _handle_sigterm(self, signum, frame):
        log.info("SIGTERM received, shutting down...")
        self._stop_event.set()

    def _handle_sigusr1(self, signum, frame):
        log.info("SIGUSR1 received, triggering retrain from traffic log (background)...")
        db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
        self.ai_detector.retrain_from_log(db_path, background=True)
        self._last_retrain_time = time.time()

    # ---- Traffic Logging (SQLite) ----

    def _init_traffic_db(self):
        """Initialize SQLite database for traffic logging."""
        db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
        os.makedirs(os.path.dirname(db_path), exist_ok=True)
        self._traffic_db = sqlite3.connect(db_path, check_same_thread=False)
        self._traffic_db.execute(
            'CREATE TABLE IF NOT EXISTS traffic_samples ('
            ' id INTEGER PRIMARY KEY AUTOINCREMENT,'
            ' timestamp TEXT NOT NULL,'
            ' hour REAL NOT NULL,'
            ' hour_sin REAL NOT NULL,'
            ' hour_cos REAL NOT NULL,'
            ' total_packets REAL NOT NULL,'
            ' total_bytes REAL NOT NULL,'
            ' tcp_syn_count REAL NOT NULL,'
            ' tcp_other_count REAL NOT NULL,'
            ' udp_count REAL NOT NULL,'
            ' icmp_count REAL NOT NULL,'
            ' other_proto_count REAL NOT NULL,'
            ' unique_ips_approx REAL NOT NULL,'
            ' small_pkt_count REAL NOT NULL,'
            ' large_pkt_count REAL NOT NULL,'
            ' syn_ratio REAL NOT NULL,'
            ' udp_ratio REAL NOT NULL,'
            ' icmp_ratio REAL NOT NULL,'
            ' small_pkt_ratio REAL NOT NULL,'
            ' avg_pkt_size REAL NOT NULL'
            ')'
        )
        self._traffic_db.execute(
            'CREATE INDEX IF NOT EXISTS idx_timestamp ON traffic_samples(timestamp)'
        )
        self._traffic_db.commit()
        log.info("Traffic log DB initialized: %s", db_path)

    def _log_traffic(self, now, hour, features):
        """Insert one row into traffic_samples table."""
        try:
            self._traffic_db.execute(
                'INSERT INTO traffic_samples ('
                ' timestamp, hour, hour_sin, hour_cos,'
                ' total_packets, total_bytes, tcp_syn_count, tcp_other_count,'
                ' udp_count, icmp_count, other_proto_count, unique_ips_approx,'
                ' small_pkt_count, large_pkt_count,'
                ' syn_ratio, udp_ratio, icmp_ratio, small_pkt_ratio, avg_pkt_size'
                ') VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)',
                (now.isoformat(), hour, *features)
            )
            self._traffic_db.commit()
        except Exception as e:
            log.error("Failed to write traffic log: %s", e)

    def _cleanup_traffic_log(self):
        """Remove entries older than retention_days from traffic_samples."""
        retention_days = self.cfg['ai'].get('traffic_log_retention_days', 7)
        cutoff = (datetime.now() - timedelta(days=retention_days)).isoformat()
        try:
            cur = self._traffic_db.execute(
                'DELETE FROM traffic_samples WHERE timestamp < ?', (cutoff,)
            )
            deleted = cur.rowcount
            self._traffic_db.commit()
            if deleted > 1000:
                self._traffic_db.execute('VACUUM')
            log.info("Traffic log cleanup: deleted %d rows (retention=%dd)", deleted, retention_days)
        except Exception as e:
            log.error("Traffic log cleanup failed: %s", e)

    def _get_model_mtime(self):
        """Get model file modification time, or current time if not found."""
        model_file = self.cfg['ai'].get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        try:
            return os.path.getmtime(model_file)
        except OSError:
            return time.time()

    # ---- Escalation ----

    def _escalate(self, ip_str, level, tag=''):
        """Apply the block action for escalation *level* to *ip_str*.

        'temp_block' blocks for ``escalation.temp_block_duration`` seconds and
        'perm_block' blocks with duration 0. Any other level does nothing.
        *tag* prefixes the log messages ('' for EWMA, 'AI ' for AI detections).
        Block failures are logged, not raised.
        """
        from xdp_common import block_ip
        if level == 'temp_block':
            dur = self.cfg['escalation'].get('temp_block_duration', 300)
            try:
                block_ip(ip_str, dur)
                log.warning("%sTEMP BLOCK: %s for %ds", tag, ip_str, dur)
            except Exception as e:
                log.error("Failed to %stemp-block %s: %s", tag, ip_str, e)
        elif level == 'perm_block':
            try:
                block_ip(ip_str, 0)
                log.warning("%sPERM BLOCK: %s", tag, ip_str)
            except Exception as e:
                log.error("Failed to %sperm-block %s: %s", tag, ip_str, e)

    # ---- Worker Threads ----

    @staticmethod
    def _prune_counters(prev_counters, seen_at, cutoff):
        """Forget IPs whose last appearance (monotonic time) is before *cutoff*."""
        stale = [ip for ip, ts in seen_at.items() if ts < cutoff]
        for ip in stale:
            del seen_at[ip]
            prev_counters.pop(ip, None)

    def _ewma_thread(self):
        """Poll rate counters, compute EWMA, detect violations, escalate.

        The last packet count per IP is kept so that each poll works on the
        delta. IPs missing from the dump for longer than ``_PREV_COUNTER_TTL``
        seconds are dropped, so this state cannot grow without bound.
        """
        from xdp_common import dump_rate_counters, is_whitelisted
        prev_counters = {}  # ip -> packet count at its last appearance
        seen_at = {}        # ip -> time.monotonic() of its last appearance
        last_prune = time.monotonic()
        while not self._stop_event.is_set():
            interval = self._ewma_interval
            try:
                entries = dump_rate_counters('rate_counter_v4', top_n=1000)
                now_mono = time.monotonic()
                active_ips = []
                for ip_str, pkts, _bts, _last in entries:
                    active_ips.append(ip_str)
                    seen_at[ip_str] = now_mono
                    prev = prev_counters.get(ip_str, 0)
                    # A count below the previous one is treated as a counter
                    # restart: the whole new count becomes the delta.
                    delta = pkts - prev if pkts >= prev else pkts
                    prev_counters[ip_str] = pkts
                    if delta <= 0:
                        continue
                    pps = delta / max(interval, 0.1)
                    if not self.ewma_analyzer.update(ip_str, pps):
                        continue
                    # Skip whitelisted IPs
                    if is_whitelisted(ip_str):
                        log.debug("EWMA anomaly skipped (whitelisted): %s", ip_str)
                        continue
                    level = self.violation_tracker.record_violation(ip_str)
                    ew = self.ewma_analyzer.get_stats(ip_str)
                    log.warning(
                        "EWMA anomaly: %s pps=%.1f ewma=%.1f baseline=%.1f -> %s",
                        ip_str, pps, ew['ewma'], ew['baseline'], level
                    )
                    self._escalate(ip_str, level)
                self.ewma_analyzer.cleanup_stale(active_ips)
                if now_mono - last_prune >= self._PREV_COUNTER_TTL:
                    self._prune_counters(prev_counters, seen_at,
                                         now_mono - self._PREV_COUNTER_TTL)
                    last_prune = now_mono
            except Exception as e:
                log.error("EWMA thread error: %s", e)
            self._stop_event.wait(interval)

    def _ai_thread(self):
        """Read traffic features, run AI inference or collect training data."""
        from xdp_common import read_percpu_features, dump_rate_counters, is_whitelisted
        feature_names = [
            'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
            'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
            'small_pkt_count', 'large_pkt_count'
        ]
        prev_features = None
        self._last_retrain_time = self._get_model_mtime()
        self._last_log_cleanup = time.time()
        while not self._stop_event.is_set():
            interval = self._ai_interval
            try:
                if not self.ai_detector.enabled:
                    self._stop_event.wait(interval)
                    continue
                features = read_percpu_features()
                if not features:
                    self._stop_event.wait(interval)
                    continue
                if prev_features is not None:
                    deltas = [max(0, features.get(name, 0) - prev_features.get(name, 0))
                              for name in feature_names]
                    min_pkts = self.cfg['ai'].get('min_packets_for_sample', 20)
                    if deltas[0] < min_pkts:
                        prev_features = features
                        self._stop_event.wait(interval)
                        continue
                    total = deltas[0] + 1e-6
                    syn_ratio = deltas[2] / total
                    udp_ratio = deltas[4] / total
                    icmp_ratio = deltas[5] / total
                    small_pkt_ratio = deltas[8] / total
                    avg_pkt_size = deltas[1] / total
                    deltas.extend([syn_ratio, udp_ratio, icmp_ratio, small_pkt_ratio, avg_pkt_size])
                    # Add time features (hour_sin, hour_cos) at the front
                    now = datetime.now()
                    hour = now.hour + now.minute / 60.0
                    hour_sin = math.sin(2 * math.pi * hour / 24)
                    hour_cos = math.cos(2 * math.pi * hour / 24)
                    deltas_with_time = [hour_sin, hour_cos] + deltas  # 17 features
                    # Log to traffic DB
                    self._log_traffic(now, hour, deltas_with_time)
                    # Periodic log file cleanup (once per day)
                    if time.time() - self._last_log_cleanup > 86400:
                        self._cleanup_traffic_log()
                        self._last_log_cleanup = time.time()
                    if self.ai_detector.is_learning:
                        self.ai_detector.collect_sample(deltas_with_time, hour)
                        total_samples = sum(len(v) for v in self.ai_detector.training_data.values())
                        if total_samples % 100 == 0 and total_samples > 0:
                            log.debug("AI learning: %d samples collected", total_samples)
                    else:
                        # Auto-retrain check (background, no inference gap)
                        retrain_interval = self.cfg['ai'].get('retrain_interval', 86400)
                        if time.time() - self._last_retrain_time >= retrain_interval:
                            log.info("Auto-retrain triggered (interval=%ds, background)", retrain_interval)
                            db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
                            self.ai_detector.retrain_from_log(db_path, background=True)
                            self._last_retrain_time = time.time()
                        is_anomaly, score = self.ai_detector.predict(deltas_with_time, hour)
                        if is_anomaly:
                            log.warning(
                                "AI ANOMALY detected: score=%.4f deltas=%s",
                                score, dict(zip(feature_names, deltas[:len(feature_names)]))
                            )
                            top_ips = dump_rate_counters('rate_counter_v4', top_n=5)
                            for ip_str, pkts, bts, _ in top_ips:
                                # Skip whitelisted IPs
                                if is_whitelisted(ip_str):
                                    log.debug("AI escalation skipped (whitelisted): %s", ip_str)
                                    continue
                                level = self.violation_tracker.record_violation(ip_str)
                                log.warning("AI escalation: %s -> %s", ip_str, level)
                                self._escalate(ip_str, level, 'AI ')
                prev_features = features
            except Exception as e:
                log.error("AI thread error: %s", e)
            self._stop_event.wait(interval)

    def _profile_thread(self):
        """Check time-of-day and switch rate profiles."""
        while not self._stop_event.is_set():
            try:
                self.profile_manager.check_and_apply()
            except Exception as e:
                log.error("Profile thread error: %s", e)
            self._stop_event.wait(60)

    def _cleanup_thread(self):
        """Periodically clean up expired blocked IPs and stale violations."""
        from xdp_common import dump_blocked_ips, unblock_ip
        while not self._stop_event.is_set():
            try:
                # NOTE(review): assumes expire_ns uses the same boot-relative
                # clock as /proc/uptime; confirm against xdp_common.block_ip.
                with open('/proc/uptime') as f:
                    now_ns = int(float(f.read().split()[0]) * 1_000_000_000)
                for map_name in ['blocked_ips_v4', 'blocked_ips_v6']:
                    entries = dump_blocked_ips(map_name)
                    for ip_str, expire_ns, blocked_at, drop_count in entries:
                        # expire_ns == 0 marks a permanent block.
                        if expire_ns != 0 and now_ns > expire_ns:
                            try:
                                unblock_ip(ip_str)
                                self.violation_tracker.clear(ip_str)
                                log.info("Expired block removed: %s (dropped %d pkts)", ip_str, drop_count)
                            except Exception as e:
                                log.error("Failed to remove expired block %s: %s", ip_str, e)
                self.violation_tracker.cleanup_expired()
            except Exception as e:
                log.error("Cleanup thread error: %s", e)
            self._stop_event.wait(60)

    # ---- Main Loop ----

    def run(self):
        """Start the daemon and block until a stop signal is received."""
        log.info("XDP Defense Daemon starting...")
        signal.signal(signal.SIGHUP, self._handle_sighup)
        signal.signal(signal.SIGTERM, self._handle_sigterm)
        signal.signal(signal.SIGINT, self._handle_sigterm)
        signal.signal(signal.SIGUSR1, self._handle_sigusr1)
        self._ensure_single_instance()
        self._write_pid()
        threads = [
            threading.Thread(target=self._ewma_thread, name='ewma', daemon=True),
            threading.Thread(target=self._ai_thread, name='ai', daemon=True),
            threading.Thread(target=self._profile_thread, name='profile', daemon=True),
            threading.Thread(target=self._cleanup_thread, name='cleanup', daemon=True),
        ]
        for t in threads:
            t.start()
            log.info("Started %s thread", t.name)
        log.info("Daemon running (PID %d)", os.getpid())
        try:
            while not self._stop_event.is_set():
                self._stop_event.wait(1)
        except KeyboardInterrupt:
            pass
        log.info("Shutting down...")
        self._stop_event.set()
        for t in threads:
            t.join(timeout=5)
        self._remove_pid()
        log.info("Daemon stopped")
# ==================== Entry Point ====================
def main():
    """Run the daemon; an optional first CLI argument overrides the config path."""
    config_path = sys.argv[1] if len(sys.argv) > 1 else CONFIG_PATH
    DDoSDaemon(config_path).run()


if __name__ == '__main__':
    main()