Daemon fixes: - Add _db_lock for thread-safe SQLite access - Atomic SIGHUP config swap (build all values before applying) - Check world-writable permission before loading pickle model - Write model files with 0o600 permissions via os.open - Module-level xdp_common import with fatal exit on failure - Close traffic DB on shutdown - Add period_data parameter to _train() to avoid race condition CLI fixes: - Replace $COMMON_PY variable with hardcoded 'xdp_common' - Pass CONFIG_FILE via sys.argv instead of string interpolation - Add key_hex regex validation before all bpftool commands - Switch sanitize_input from denylist to strict allowlist Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1053 lines
40 KiB
Python
Executable File
1053 lines
40 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
XDP Defense Daemon
|
|
Userspace component: EWMA-based rate analysis, AI anomaly detection,
|
|
time-profile switching, and automatic escalation.
|
|
|
|
4 worker threads + main thread (signal handling):
|
|
- EWMA Thread: polls rate counters, calculates EWMA, detects violations
|
|
- AI Thread: reads traffic features, runs Isolation Forest inference
|
|
- Profile Thread: checks time-of-day, switches rate_config profiles
|
|
- Cleanup Thread: removes expired entries from blocked_ips maps
|
|
"""
|
|
|
|
import copy
|
|
import math
|
|
import os
|
|
import stat
|
|
import sys
|
|
import time
|
|
import signal
|
|
import threading
|
|
import logging
|
|
import logging.handlers
|
|
import csv
|
|
import pickle
|
|
import sqlite3
|
|
from collections import defaultdict
|
|
from datetime import datetime, timedelta
|
|
|
|
import yaml
|
|
|
|
try:
|
|
from xdp_common import (
|
|
dump_rate_counters, block_ip, is_whitelisted,
|
|
read_percpu_features, dump_blocked_ips, unblock_ip,
|
|
write_rate_config, read_rate_config,
|
|
)
|
|
except ImportError as e:
|
|
print(f"FATAL: Cannot import xdp_common: {e}", file=sys.stderr)
|
|
sys.exit(1)
|
|
|
|
# ==================== Logging ====================
|
|
|
|
# Daemon-wide logger: console always; syslog added when /dev/log exists.
log = logging.getLogger('xdp-defense-daemon')
log.setLevel(logging.INFO)

_console_fmt = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
_console = logging.StreamHandler()
_console.setFormatter(_console_fmt)
log.addHandler(_console)

# Syslog is best-effort: on hosts without a local syslog socket the
# handler constructor raises, and we silently fall back to console-only.
try:
    _syslog = logging.handlers.SysLogHandler(address='/dev/log')
    _syslog.setFormatter(logging.Formatter('xdp-defense-daemon: %(message)s'))
    log.addHandler(_syslog)
except Exception:
    pass
|
|
|
|
# ==================== Configuration ====================
|
|
|
|
# Built-in defaults; load_config() overlays user YAML section-by-section.
DEFAULT_CONFIG = {
    'general': {
        'interface': 'eth0',
        'log_level': 'info',
        'pid_file': '/var/lib/xdp-defense/daemon.pid',
        'data_dir': '/var/lib/xdp-defense',
    },
    'rate_limits': {
        'default_pps': 1000,     # packets/sec ceiling when no profile matches
        'default_bps': 0,        # bytes/sec ceiling; 0 = unlimited
        'window_sec': 1,         # rate-measurement window (converted to ns on apply)
        'profiles': {},          # named time-of-day profiles (see ProfileManager)
    },
    'escalation': {
        'temp_block_after': 5,   # violations within the window before a temp block
        'perm_block_after': 20,  # violations within the window before a permanent block
        'temp_block_duration': 300,  # seconds a temp block lasts
        'violation_window': 600,     # sliding window (s) for counting violations
    },
    'ewma': {
        'alpha': 0.3,                 # EWMA smoothing factor for the fast tracker
        'poll_interval': 1,           # seconds between rate-counter polls
        'threshold_multiplier': 3.0,  # EWMA > baseline * multiplier => anomaly
    },
    'ai': {
        'enabled': True,
        'model_type': 'IsolationForest',
        'contamination': 0.05,
        'n_estimators': 100,
        'learning_duration': 259200,  # 3 days of sample collection before first train
        'min_samples': 1000,          # minimum total samples before first train
        'poll_interval': 5,           # seconds between feature polls
        'anomaly_threshold': -0.3,    # decision_function score below this => anomaly
        'min_packets_for_sample': 20, # skip near-idle intervals
        'model_file': '/var/lib/xdp-defense/ai_model.pkl',
        'training_data_file': '/var/lib/xdp-defense/training_data.csv',
        'traffic_log_db': '/var/lib/xdp-defense/traffic_log.db',
        'traffic_log_retention_days': 7,
        'retrain_interval': 86400,    # 1 day between automatic retrains
        'retrain_window': 604800,     # 7 days of history used for a retrain
    },
}

# Default path; overridable via argv[1] (see main()).
CONFIG_PATH = '/etc/xdp-defense/config.yaml'
|
|
|
|
|
|
def load_config(path=CONFIG_PATH):
    """Return DEFAULT_CONFIG deep-copied, overlaid with the user's YAML.

    Only known top-level sections are merged, and only shallowly: a user
    section replaces keys inside the matching default section. A missing
    or unreadable file degrades to pure defaults with a log message.
    """
    merged = copy.deepcopy(DEFAULT_CONFIG)
    try:
        with open(path) as fh:
            overrides = yaml.safe_load(fh) or {}
        for section, values in overrides.items():
            if section in merged and isinstance(values, dict):
                merged[section].update(values)
    except FileNotFoundError:
        log.warning("Config not found at %s, using defaults", path)
    except Exception as e:
        log.error("Failed to load config: %s", e)
    return merged
|
|
|
|
|
|
# ==================== ViolationTracker ====================
|
|
|
|
class ViolationTracker:
    """Per-IP violation bookkeeping with sliding-window escalation.

    Each recorded violation is a timestamp; only timestamps younger than
    ``violation_window`` count toward escalation decisions.
    """

    def __init__(self, escalation_cfg):
        self.cfg = escalation_cfg
        self.violations = defaultdict(list)   # ip -> [timestamps]
        self.lock = threading.Lock()

    def record_violation(self, ip):
        """Record one violation for *ip* and return the escalation level.

        Returns one of 'rate_limit', 'temp_block', 'perm_block'.
        """
        now = time.time()
        window = self.cfg.get('violation_window', 600)

        with self.lock:
            # Drop stale timestamps, then count this new one.
            recent = [ts for ts in self.violations[ip] if now - ts < window]
            recent.append(now)
            self.violations[ip] = recent
            count = len(recent)

        if count >= self.cfg.get('perm_block_after', 20):
            return 'perm_block'
        if count >= self.cfg.get('temp_block_after', 5):
            return 'temp_block'
        return 'rate_limit'

    def clear(self, ip):
        """Forget all recorded violations for *ip* (e.g. after unblock)."""
        with self.lock:
            self.violations.pop(ip, None)

    def cleanup_expired(self):
        """Drop IPs whose every recorded violation is older than the window."""
        now = time.time()
        window = self.cfg.get('violation_window', 600)
        with self.lock:
            stale = [ip for ip, stamps in self.violations.items()
                     if not any(now - ts < window for ts in stamps)]
            for ip in stale:
                del self.violations[ip]
|
|
|
|
|
|
# ==================== EWMAAnalyzer ====================
|
|
|
|
class EWMAAnalyzer:
    """Per-IP rate anomaly detection via dual exponential averages.

    A fast EWMA (configurable alpha) tracks the current rate while a very
    slow EWMA (fixed 0.01 weight) tracks the long-term baseline; a spike is
    flagged when fast exceeds baseline * threshold_multiplier.
    """

    def __init__(self, alpha=0.3, threshold_multiplier=3.0):
        self.alpha = alpha
        self.threshold_multiplier = threshold_multiplier
        self.ewma = {}       # ip -> fast average (pps)
        self.baseline = {}   # ip -> slow average (pps)
        self.lock = threading.Lock()

    def update(self, ip, current_pps):
        """Fold *current_pps* into *ip*'s trackers; return True when anomalous."""
        with self.lock:
            if ip not in self.ewma:
                # First observation seeds both trackers; never anomalous.
                self.ewma[ip] = current_pps
                self.baseline[ip] = current_pps
                return False

            a = self.alpha
            self.ewma[ip] = a * current_pps + (1 - a) * self.ewma[ip]
            self.baseline[ip] = 0.01 * current_pps + 0.99 * self.baseline[ip]

            # Floor the baseline at 1 pps so near-idle IPs don't trip on noise.
            floor = self.baseline[ip] if self.baseline[ip] > 1 else 1
            return self.ewma[ip] > floor * self.threshold_multiplier

    def get_stats(self, ip):
        """Return the current fast/slow averages for *ip* (0 if untracked)."""
        with self.lock:
            return {'ewma': self.ewma.get(ip, 0),
                    'baseline': self.baseline.get(ip, 0)}

    def cleanup_stale(self, active_ips):
        """Drop tracking state for IPs absent from *active_ips*."""
        with self.lock:
            for ip in set(self.ewma) - set(active_ips):
                self.ewma.pop(ip, None)
                self.baseline.pop(ip, None)
|
|
|
|
|
|
# ==================== AIDetector ====================
|
|
|
|
class AIDetector:
    """Time-period Isolation Forest anomaly detection.

    Maintains 4 separate models for different time periods:
    night (00-06), morning (06-12), afternoon (12-18), evening (18-24).
    Each model learns the normal traffic pattern for its time window.

    Feature vectors are 17-wide: [hour_sin, hour_cos] followed by 10 raw
    counters and 5 derived ratios (see _train's CSV header for the order).
    """

    # Half-open [start, end) hour ranges; get_period maps an hour to a name.
    TIME_PERIODS = {
        'night': (0, 6),
        'morning': (6, 12),
        'afternoon': (12, 18),
        'evening': (18, 24),
    }

    def __init__(self, ai_cfg):
        self.cfg = ai_cfg
        self.models = {}  # {period_name: {'model': ..., 'scaler': ...}}
        self.started_at = time.time()
        self.training_data = defaultdict(list)  # {period_name: [samples]}
        # True until a first model is trained or loaded from disk.
        self.is_learning = True
        self._retrain_requested = False
        # Serializes retrains; taken non-blocking so overlapping requests skip.
        self._retrain_lock = threading.Lock()

    @property
    def enabled(self):
        # Reads the live config so a SIGHUP cfg swap takes effect immediately.
        return self.cfg.get('enabled', False)

    @staticmethod
    def get_period(hour):
        """Map hour (0-24, may be fractional) to period name."""
        for name, (start, end) in AIDetector.TIME_PERIODS.items():
            if start <= hour < end:
                return name
        return 'night'  # hour == 24 edge case

    def request_retrain(self):
        # Flag is consumed by collect_sample on its next invocation.
        self._retrain_requested = True

    def collect_sample(self, features, hour):
        """Collect a feature sample during learning phase, bucketed by period.

        Triggers the first training once both the learning duration has
        elapsed and enough samples exist, or when a retrain was requested.
        """
        if not self.enabled:
            return

        period = self.get_period(hour)
        self.training_data[period].append(features)

        learning_dur = self.cfg.get('learning_duration', 259200)
        min_samples = self.cfg.get('min_samples', 1000)
        elapsed = time.time() - self.started_at
        total = sum(len(v) for v in self.training_data.values())

        if (elapsed >= learning_dur and total >= min_samples) or self._retrain_requested:
            self._train()
            self._retrain_requested = False

    def _train(self, period_data=None):
        """Train per-period Isolation Forest models.
        If period_data provided, use it instead of self.training_data.

        On success: swaps self.models in one assignment, clears is_learning,
        and persists the model pickle (mode 0o600) and the training CSV via
        write-to-tmp + fsync + rename for crash safety.
        """
        if period_data is None:
            period_data = self.training_data

        try:
            from sklearn.ensemble import IsolationForest
            from sklearn.preprocessing import StandardScaler
            import numpy as np
        except ImportError:
            # Without sklearn the detector is permanently disabled for this run.
            log.error("scikit-learn not installed. AI detection disabled.")
            self.cfg['enabled'] = False
            return

        total = sum(len(v) for v in period_data.values())
        if total < 10:
            log.warning("Not enough training data (%d samples)", total)
            return

        log.info("Training AI models: %s",
                 {p: len(s) for p, s in period_data.items() if s})

        try:
            new_models = {}
            all_samples = []

            for period, samples in period_data.items():
                # A period with <10 samples is skipped, not trained badly.
                if len(samples) < 10:
                    log.info("Period %s: %d samples (too few, skip)", period, len(samples))
                    continue

                X = np.array(samples)
                scaler = StandardScaler()
                X_scaled = scaler.fit_transform(X)

                model = IsolationForest(
                    n_estimators=self.cfg.get('n_estimators', 100),
                    contamination=self.cfg.get('contamination', 'auto'),
                    random_state=42,  # deterministic forests across retrains
                )
                model.fit(X_scaled)
                new_models[period] = {'model': model, 'scaler': scaler}
                all_samples.extend(samples)
                log.info("Period %s: trained with %d samples", period, len(samples))

            if not new_models:
                log.warning("No period had enough data to train")
                return

            # Atomic swap
            self.models = new_models
            self.is_learning = False

            # Save to disk (atomic): tmp file with 0o600 perms, fsync, rename.
            model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
            tmp_model = model_file + '.tmp'
            fd = os.open(tmp_model, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, 'wb') as f:
                pickle.dump({
                    'format': 'period_models',
                    'models': {p: {'model': m['model'], 'scaler': m['scaler']}
                               for p, m in new_models.items()},
                    'feature_count': 17,
                }, f)
                f.flush()
                os.fsync(f.fileno())
            os.rename(tmp_model, model_file)

            # Save training data CSV (world-readable is fine; no code in it).
            data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
            tmp_data = data_file + '.tmp'
            fd = os.open(tmp_data, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)
            with os.fdopen(fd, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([
                    'hour_sin', 'hour_cos',
                    'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
                    'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
                    'small_pkt_count', 'large_pkt_count',
                    'syn_ratio', 'udp_ratio', 'icmp_ratio', 'small_pkt_ratio', 'avg_pkt_size'
                ])
                writer.writerows(all_samples)
                f.flush()
                os.fsync(f.fileno())
            os.rename(tmp_data, data_file)

            log.info("AI models trained and saved (%d periods)", len(new_models))

        except Exception as e:
            log.error("AI training failed: %s", e)

    def load_model(self):
        """Load previously trained models. Handle old and new formats.

        Returns True only when a usable 'period_models' pickle is loaded.
        Refuses world-writable files since unpickling executes arbitrary code.
        """
        model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        if not os.path.exists(model_file):
            return False
        try:
            st = os.stat(model_file)
            # Warn if file is world-writable
            if st.st_mode & stat.S_IWOTH:
                log.warning("Model file %s is world-writable! Refusing to load.", model_file)
                return False

            # NOTE: pickle is only safe because the file is daemon-written
            # with 0o600 and the permission check above rejects tampering
            # vectors via world-writability; the parent directory perms are
            # not checked here — presumed root-owned (TODO confirm).
            with open(model_file, 'rb') as f:
                data = pickle.load(f)

            # New period_models format
            if data.get('format') == 'period_models':
                self.models = data['models']
                self.is_learning = False
                periods = list(self.models.keys())
                log.info("AI models loaded: %s (%d features)",
                         periods, data.get('feature_count', '?'))
                return True

            # Old single-model format → discard
            log.warning("Old single-model format detected. Switching to learning mode.")
            self.is_learning = True
            return False

        except Exception as e:
            log.error("Failed to load AI model: %s", e)
            return False

    def predict(self, features, hour):
        """Run anomaly detection using the period-specific model.
        Returns (is_anomaly, score).
        """
        if not self.enabled or not self.models:
            return False, 0.0

        period = self.get_period(hour)
        model_info = self.models.get(period)

        # Fallback: no model for this period -> use the first available one.
        # (Comment in source says "adjacent periods" but the loop takes an
        # arbitrary first entry.)
        if model_info is None:
            for p in self.models:
                model_info = self.models[p]
                break
        if model_info is None:
            return False, 0.0

        try:
            import numpy as np
            X = np.array([features])
            X_scaled = model_info['scaler'].transform(X)
            # decision_function: negative = more anomalous.
            score = model_info['model'].decision_function(X_scaled)[0]
            threshold = self.cfg.get('anomaly_threshold', -0.3)
            return score < threshold, float(score)
        except Exception as e:
            log.error("AI prediction error: %s", e)
            return False, 0.0

    def retrain_from_log(self, db_path=None, background=False):
        """Retrain models from traffic_log.db. Optionally in background thread."""
        if db_path is None:
            db_path = self.cfg.get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')

        if background:
            t = threading.Thread(target=self._do_retrain, args=(db_path,), daemon=True)
            t.start()
            return True  # started, not yet finished

        return self._do_retrain(db_path)

    def _do_retrain(self, db_path):
        """Actual retrain logic (can run in background thread).

        Non-blocking lock acquire: a retrain already in flight makes this
        call a no-op returning False rather than queueing a second run.
        """
        if not self._retrain_lock.acquire(blocking=False):
            log.info("Retrain already in progress, skipping")
            return False

        try:
            return self._retrain_impl(db_path)
        finally:
            self._retrain_lock.release()

    def _retrain_impl(self, db_path):
        """Load data from DB, filter outliers, train per-period models.

        Opens the DB read-only (its own connection, so it is safe from a
        background thread); rows older than retrain_window are excluded.
        Returns True when training produced usable models.
        """
        if not os.path.exists(db_path):
            log.warning("Traffic log DB not found: %s", db_path)
            return False

        retrain_window = self.cfg.get('retrain_window', 86400)
        cutoff = (datetime.now() - timedelta(seconds=retrain_window)).isoformat()

        conn = None
        try:
            # mode=ro: never mutate the daemon's live traffic DB from here.
            conn = sqlite3.connect(f'file:{db_path}?mode=ro', uri=True)
            cur = conn.execute(
                'SELECT hour, hour_sin, hour_cos, total_packets, total_bytes, '
                'tcp_syn_count, tcp_other_count, udp_count, icmp_count, '
                'other_proto_count, unique_ips_approx, small_pkt_count, '
                'large_pkt_count, syn_ratio, udp_ratio, icmp_ratio, '
                'small_pkt_ratio, avg_pkt_size '
                'FROM traffic_samples WHERE timestamp >= ? ORDER BY timestamp',
                (cutoff,)
            )
            rows = cur.fetchall()
        except Exception as e:
            log.error("retrain_from_log DB read failed: %s", e)
            return False
        finally:
            if conn:
                conn.close()

        if len(rows) < 10:
            log.warning("Not enough recent samples for retrain (%d)", len(rows))
            return False

        # Split by period: row[0]=hour, row[1:]=17 features
        period_data = defaultdict(list)
        for row in rows:
            hour = row[0]
            features = list(row[1:])  # 17 features
            period = self.get_period(hour)
            period_data[period].append(features)

        # Filter outliers using existing models (item #2)
        try:
            import numpy as np
        except ImportError:
            log.error("numpy not available for outlier filtering")
            return False

        filtered_count = 0
        for period, samples in period_data.items():
            model_info = self.models.get(period)
            # No prior model (or too few samples) -> keep period unfiltered.
            if model_info is None or len(samples) < 10:
                continue
            X = np.array(samples)
            X_scaled = model_info['scaler'].transform(X)
            scores = model_info['model'].decision_function(X_scaled)
            # Drop strongly-anomalous rows so attacks in the log don't
            # become the new "normal"; -0.5 is a fixed, looser cut than
            # the runtime anomaly_threshold.
            clean = [s for s, sc in zip(samples, scores) if sc >= -0.5]
            removed = len(samples) - len(clean)
            if removed > 0:
                filtered_count += removed
                period_data[period] = clean

        total = sum(len(v) for v in period_data.values())
        log.info("Retrain: %d samples loaded, %d outliers filtered, %s",
                 len(rows), filtered_count,
                 {p: len(s) for p, s in period_data.items() if s})

        self._train(period_data=period_data)
        # _train flips is_learning to False only on success.
        return not self.is_learning
|
|
|
|
|
|
# ==================== ProfileManager ====================
|
|
|
|
class ProfileManager:
    """Manage time-based rate limit profiles.

    Profiles come from cfg['profiles'] as {name: {'hours': 'HH:MM-HH:MM',
    'weekdays': 'mon-fri', 'pps': ..., 'bps': ...}}. First matching profile
    (dict iteration order) wins; no match falls back to the defaults.
    """

    def __init__(self, rate_cfg):
        self.cfg = rate_cfg
        # Name of the profile currently written to the kernel map.
        self.current_profile = 'default'

    def check_and_apply(self):
        """Check current time and apply matching profile.

        Only writes to the XDP rate_config map when the matched profile
        name differs from the one currently applied; on write failure the
        name is NOT updated, so the switch is retried on the next call.
        """
        profiles = self.cfg.get('profiles', {})
        now = datetime.now()
        current_hour = now.hour
        current_min = now.minute
        current_time = current_hour * 60 + current_min  # minutes since midnight
        weekday = now.strftime('%a').lower()

        matched_profile = None
        matched_name = 'default'

        for name, profile in profiles.items():
            hours = profile.get('hours', '')
            weekdays = profile.get('weekdays', '')

            # Weekday filter, e.g. 'mon-fri'. Supports wrap-around ranges
            # like 'fri-mon'. A spec that is not exactly "start-end" is
            # ignored (the profile is still considered); unknown day names
            # skip the profile.
            if weekdays:
                day_range = weekdays.lower().split('-')
                day_names = ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun']
                if len(day_range) == 2:
                    try:
                        start_idx = day_names.index(day_range[0])
                        end_idx = day_names.index(day_range[1])
                        current_idx = day_names.index(weekday)
                        if start_idx <= end_idx:
                            if not (start_idx <= current_idx <= end_idx):
                                continue
                        else:
                            # Wrap-around range (e.g. sat-sun-mon style spans).
                            if not (current_idx >= start_idx or current_idx <= end_idx):
                                continue
                    except ValueError:
                        continue

            # Hour filter, e.g. '09:00-17:30'; end is exclusive. Supports
            # overnight ranges like '22:00-06:00'. Malformed specs skip
            # the profile.
            if hours:
                try:
                    start_str, end_str = hours.split('-')
                    sh, sm = map(int, start_str.split(':'))
                    eh, em = map(int, end_str.split(':'))
                    start_min = sh * 60 + sm
                    end_min = eh * 60 + em

                    if start_min <= end_min:
                        if not (start_min <= current_time < end_min):
                            continue
                    else:
                        # Overnight window crossing midnight.
                        if not (current_time >= start_min or current_time < end_min):
                            continue
                except (ValueError, AttributeError):
                    continue

            matched_profile = profile
            matched_name = name
            break

        if matched_name != self.current_profile:
            if matched_profile:
                # Profile values fall back to the section defaults per-key.
                pps = matched_profile.get('pps', self.cfg.get('default_pps', 1000))
                bps = matched_profile.get('bps', self.cfg.get('default_bps', 0))
            else:
                pps = self.cfg.get('default_pps', 1000)
                bps = self.cfg.get('default_bps', 0)

            window = self.cfg.get('window_sec', 1)

            try:
                # Kernel side expects the window in nanoseconds.
                write_rate_config(pps, bps, window * 1_000_000_000)
                log.info("Profile switched: %s -> %s (pps=%d)", self.current_profile, matched_name, pps)
                self.current_profile = matched_name
            except Exception as e:
                log.error("Failed to apply profile %s: %s", matched_name, e)
|
|
|
|
|
|
# ==================== DDoSDaemon ====================
|
|
|
|
class DDoSDaemon:
    """Main daemon orchestrator.

    Owns the four worker threads (EWMA, AI, profile, cleanup), the signal
    handlers (SIGHUP reload, SIGTERM/SIGINT stop, SIGUSR1 retrain), the
    PID file, and the shared SQLite traffic-log connection (guarded by
    _db_lock since it is used from multiple threads).
    """

    def __init__(self, config_path=CONFIG_PATH):
        self.config_path = config_path
        self.cfg = load_config(config_path)
        # Set by signal handlers / run(); all worker loops poll it.
        self._stop_event = threading.Event()
        self._ewma_interval = self.cfg['ewma'].get('poll_interval', 1)
        self._ai_interval = self.cfg['ai'].get('poll_interval', 5)
        self._setup_components()

    def _setup_components(self):
        """Build analyzers/trackers from config and open the traffic DB."""
        self.violation_tracker = ViolationTracker(self.cfg['escalation'])
        self.ewma_analyzer = EWMAAnalyzer(
            alpha=self.cfg['ewma'].get('alpha', 0.3),
            threshold_multiplier=self.cfg['ewma'].get('threshold_multiplier', 3.0),
        )
        self.ai_detector = AIDetector(self.cfg['ai'])
        self.profile_manager = ProfileManager(self.cfg['rate_limits'])

        if self.ai_detector.enabled:
            self.ai_detector.load_model()

        # Treat the model file's mtime as "last retrain" so a restart does
        # not immediately trigger an auto-retrain.
        self._last_retrain_time = self._get_model_mtime()
        self._last_log_cleanup = time.time()

        self._init_traffic_db()

        level = self.cfg['general'].get('log_level', 'info').upper()
        log.setLevel(getattr(logging, level, logging.INFO))

    def _write_pid(self):
        """Write our PID file, creating the data directory if needed."""
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        os.makedirs(os.path.dirname(pid_file), exist_ok=True)
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

    def _remove_pid(self):
        """Best-effort removal of the PID file on shutdown."""
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        try:
            os.unlink(pid_file)
        except OSError:
            pass

    def _ensure_single_instance(self):
        """Stop any existing daemon before starting.

        Reads the old PID file, sends SIGTERM, waits up to 30s, then
        escalates to SIGKILL. A stale/garbage PID file is ignored.
        """
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        if not os.path.exists(pid_file):
            return
        try:
            with open(pid_file) as f:
                old_pid = int(f.read().strip())
            os.kill(old_pid, 0)  # raises OSError if the process is gone
            log.info("Stopping existing daemon (PID %d)...", old_pid)
            os.kill(old_pid, signal.SIGTERM)
            for _ in range(30):
                time.sleep(1)
                try:
                    os.kill(old_pid, 0)
                except OSError:
                    log.info("Old daemon stopped")
                    return
            log.warning("Daemon PID %d did not stop, sending SIGKILL", old_pid)
            os.kill(old_pid, signal.SIGKILL)
            time.sleep(1)
        except (ValueError, OSError):
            pass

    def _handle_sighup(self, signum, frame):
        """SIGHUP: reload config, preserving all runtime state.

        All new values are computed before any are assigned so a failure
        in load/parsing cannot leave the daemon half-reconfigured.
        """
        log.info("SIGHUP received, reloading config...")
        new_cfg = load_config(self.config_path)
        # Build all new values before swapping anything
        new_escalation = new_cfg['escalation']
        new_alpha = new_cfg['ewma'].get('alpha', 0.3)
        new_threshold = new_cfg['ewma'].get('threshold_multiplier', 3.0)
        new_ai_cfg = new_cfg['ai']
        new_rate_cfg = new_cfg['rate_limits']
        new_ewma_interval = new_cfg['ewma'].get('poll_interval', 1)
        new_ai_interval = new_cfg['ai'].get('poll_interval', 5)
        level = new_cfg['general'].get('log_level', 'info').upper()
        # Now apply all at once
        self.cfg = new_cfg
        self.violation_tracker.cfg = new_escalation
        self.ewma_analyzer.alpha = new_alpha
        self.ewma_analyzer.threshold_multiplier = new_threshold
        self.ai_detector.cfg = new_ai_cfg
        self.profile_manager.cfg = new_rate_cfg
        self._ewma_interval = new_ewma_interval
        self._ai_interval = new_ai_interval
        log.setLevel(getattr(logging, level, logging.INFO))
        log.info("Config reloaded (state preserved)")

    def _handle_sigterm(self, signum, frame):
        """SIGTERM/SIGINT: request a clean shutdown of all threads."""
        log.info("SIGTERM received, shutting down...")
        self._stop_event.set()

    def _handle_sigusr1(self, signum, frame):
        """SIGUSR1: kick off a background retrain from the traffic log."""
        log.info("SIGUSR1 received, triggering retrain from traffic log (background)...")
        db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
        self.ai_detector.retrain_from_log(db_path, background=True)
        self._last_retrain_time = time.time()

    # ---- Traffic Logging (SQLite) ----

    def _init_traffic_db(self):
        """Initialize SQLite database for traffic logging.

        check_same_thread=False because the connection is shared between
        worker threads; all access is serialized through self._db_lock.
        """
        db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
        os.makedirs(os.path.dirname(db_path), exist_ok=True)
        self._db_lock = threading.Lock()
        self._traffic_db = sqlite3.connect(db_path, check_same_thread=False)
        self._traffic_db.execute(
            'CREATE TABLE IF NOT EXISTS traffic_samples ('
            ' id INTEGER PRIMARY KEY AUTOINCREMENT,'
            ' timestamp TEXT NOT NULL,'
            ' hour REAL NOT NULL,'
            ' hour_sin REAL NOT NULL,'
            ' hour_cos REAL NOT NULL,'
            ' total_packets REAL NOT NULL,'
            ' total_bytes REAL NOT NULL,'
            ' tcp_syn_count REAL NOT NULL,'
            ' tcp_other_count REAL NOT NULL,'
            ' udp_count REAL NOT NULL,'
            ' icmp_count REAL NOT NULL,'
            ' other_proto_count REAL NOT NULL,'
            ' unique_ips_approx REAL NOT NULL,'
            ' small_pkt_count REAL NOT NULL,'
            ' large_pkt_count REAL NOT NULL,'
            ' syn_ratio REAL NOT NULL,'
            ' udp_ratio REAL NOT NULL,'
            ' icmp_ratio REAL NOT NULL,'
            ' small_pkt_ratio REAL NOT NULL,'
            ' avg_pkt_size REAL NOT NULL'
            ')'
        )
        self._traffic_db.execute(
            'CREATE INDEX IF NOT EXISTS idx_timestamp ON traffic_samples(timestamp)'
        )
        self._traffic_db.commit()
        log.info("Traffic log DB initialized: %s", db_path)

    def _log_traffic(self, now, hour, features):
        """Insert one row into traffic_samples table.

        *features* is the 17-element vector starting with hour_sin/hour_cos.
        Failures are logged, never raised, so logging cannot kill a thread.
        """
        try:
            with self._db_lock:
                self._traffic_db.execute(
                    'INSERT INTO traffic_samples ('
                    ' timestamp, hour, hour_sin, hour_cos,'
                    ' total_packets, total_bytes, tcp_syn_count, tcp_other_count,'
                    ' udp_count, icmp_count, other_proto_count, unique_ips_approx,'
                    ' small_pkt_count, large_pkt_count,'
                    ' syn_ratio, udp_ratio, icmp_ratio, small_pkt_ratio, avg_pkt_size'
                    ') VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)',
                    (now.isoformat(), hour, *features)
                )
                self._traffic_db.commit()
        except Exception as e:
            log.error("Failed to write traffic log: %s", e)

    def _cleanup_traffic_log(self):
        """Remove entries older than retention_days from traffic_samples."""
        retention_days = self.cfg['ai'].get('traffic_log_retention_days', 7)
        cutoff = (datetime.now() - timedelta(days=retention_days)).isoformat()

        try:
            with self._db_lock:
                cur = self._traffic_db.execute(
                    'DELETE FROM traffic_samples WHERE timestamp < ?', (cutoff,)
                )
                deleted = cur.rowcount
                self._traffic_db.commit()
                # Only reclaim disk space after a substantial purge.
                if deleted > 1000:
                    self._traffic_db.execute('VACUUM')
                log.info("Traffic log cleanup: deleted %d rows (retention=%dd)", deleted, retention_days)
        except Exception as e:
            log.error("Traffic log cleanup failed: %s", e)

    def _get_model_mtime(self):
        """Get model file modification time, or current time if not found."""
        model_file = self.cfg['ai'].get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        try:
            return os.path.getmtime(model_file)
        except OSError:
            return time.time()

    # ---- Worker Threads ----

    def _ewma_thread(self):
        """Poll rate counters, compute EWMA, detect violations, escalate."""
        # Last absolute packet count seen per IP, for delta computation.
        prev_counters = {}

        while not self._stop_event.is_set():
            interval = self._ewma_interval
            try:
                entries = dump_rate_counters('rate_counter_v4', top_n=1000)
                active_ips = []

                for ip_str, pkts, bts, last_seen in entries:
                    active_ips.append(ip_str)

                    prev = prev_counters.get(ip_str, 0)
                    # Counter reset/wrap: treat the new value as the delta.
                    delta = pkts - prev if pkts >= prev else pkts
                    prev_counters[ip_str] = pkts

                    if delta <= 0:
                        continue

                    # Clamp interval to avoid division blow-up at tiny values.
                    pps = delta / max(interval, 0.1)

                    is_anomalous = self.ewma_analyzer.update(ip_str, pps)
                    if is_anomalous:
                        # Skip whitelisted IPs
                        if is_whitelisted(ip_str):
                            log.debug("EWMA anomaly skipped (whitelisted): %s", ip_str)
                            continue

                        level = self.violation_tracker.record_violation(ip_str)
                        ew = self.ewma_analyzer.get_stats(ip_str)
                        log.warning(
                            "EWMA anomaly: %s pps=%.1f ewma=%.1f baseline=%.1f -> %s",
                            ip_str, pps, ew['ewma'], ew['baseline'], level
                        )

                        if level == 'temp_block':
                            dur = self.cfg['escalation'].get('temp_block_duration', 300)
                            try:
                                block_ip(ip_str, dur)
                                log.warning("TEMP BLOCK: %s for %ds", ip_str, dur)
                            except Exception as e:
                                log.error("Failed to temp-block %s: %s", ip_str, e)

                        elif level == 'perm_block':
                            try:
                                # duration 0 == permanent block
                                block_ip(ip_str, 0)
                                log.warning("PERM BLOCK: %s", ip_str)
                            except Exception as e:
                                log.error("Failed to perm-block %s: %s", ip_str, e)

                self.ewma_analyzer.cleanup_stale(active_ips)

            except Exception as e:
                log.error("EWMA thread error: %s", e)

            self._stop_event.wait(interval)

    def _ai_thread(self):
        """Read traffic features, run AI inference or collect training data."""
        # Previous absolute feature snapshot; deltas need two readings.
        prev_features = None
        self._last_retrain_time = self._get_model_mtime()
        self._last_log_cleanup = time.time()

        while not self._stop_event.is_set():
            interval = self._ai_interval
            try:
                if not self.ai_detector.enabled:
                    self._stop_event.wait(interval)
                    continue

                features = read_percpu_features()
                if not features:
                    self._stop_event.wait(interval)
                    continue

                feature_names = [
                    'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
                    'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
                    'small_pkt_count', 'large_pkt_count'
                ]

                if prev_features is not None:
                    # Per-interval deltas of the cumulative kernel counters.
                    deltas = []
                    for name in feature_names:
                        cur = features.get(name, 0)
                        prev = prev_features.get(name, 0)
                        deltas.append(max(0, cur - prev))

                    # Skip near-idle intervals: too little traffic to be a
                    # meaningful sample.
                    min_pkts = self.cfg['ai'].get('min_packets_for_sample', 20)
                    if deltas[0] < min_pkts:
                        prev_features = features
                        self._stop_event.wait(interval)
                        continue

                    # Derived ratios (epsilon avoids division by zero).
                    total = deltas[0] + 1e-6
                    syn_ratio = deltas[2] / total
                    udp_ratio = deltas[4] / total
                    icmp_ratio = deltas[5] / total
                    small_pkt_ratio = deltas[8] / total
                    avg_pkt_size = deltas[1] / total
                    deltas.extend([syn_ratio, udp_ratio, icmp_ratio, small_pkt_ratio, avg_pkt_size])

                    # Add time features (hour_sin, hour_cos) at the front
                    now = datetime.now()
                    hour = now.hour + now.minute / 60.0
                    hour_sin = math.sin(2 * math.pi * hour / 24)
                    hour_cos = math.cos(2 * math.pi * hour / 24)
                    deltas_with_time = [hour_sin, hour_cos] + deltas  # 17 features

                    # Log to traffic DB
                    self._log_traffic(now, hour, deltas_with_time)

                    # Periodic log file cleanup (once per day)
                    if time.time() - self._last_log_cleanup > 86400:
                        self._cleanup_traffic_log()
                        self._last_log_cleanup = time.time()

                    if self.ai_detector.is_learning:
                        self.ai_detector.collect_sample(deltas_with_time, hour)
                        total_samples = sum(len(v) for v in self.ai_detector.training_data.values())
                        if total_samples % 100 == 0 and total_samples > 0:
                            log.debug("AI learning: %d samples collected", total_samples)
                    else:
                        # Auto-retrain check (background, no inference gap)
                        retrain_interval = self.cfg['ai'].get('retrain_interval', 86400)
                        if time.time() - self._last_retrain_time >= retrain_interval:
                            log.info("Auto-retrain triggered (interval=%ds, background)", retrain_interval)
                            db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
                            self.ai_detector.retrain_from_log(db_path, background=True)
                            self._last_retrain_time = time.time()

                        is_anomaly, score = self.ai_detector.predict(deltas_with_time, hour)
                        if is_anomaly:
                            log.warning(
                                "AI ANOMALY detected: score=%.4f deltas=%s",
                                score, dict(zip(feature_names, deltas[:len(feature_names)]))
                            )
                            # Aggregate anomaly: escalate against the current
                            # top talkers rather than a specific flow.
                            top_ips = dump_rate_counters('rate_counter_v4', top_n=5)
                            for ip_str, pkts, bts, _ in top_ips:
                                # Skip whitelisted IPs
                                if is_whitelisted(ip_str):
                                    log.debug("AI escalation skipped (whitelisted): %s", ip_str)
                                    continue

                                level = self.violation_tracker.record_violation(ip_str)
                                log.warning("AI escalation: %s -> %s", ip_str, level)

                                if level == 'temp_block':
                                    dur = self.cfg['escalation'].get('temp_block_duration', 300)
                                    try:
                                        block_ip(ip_str, dur)
                                        log.warning("AI TEMP BLOCK: %s for %ds", ip_str, dur)
                                    except Exception as e:
                                        log.error("Failed to AI temp-block %s: %s", ip_str, e)

                                elif level == 'perm_block':
                                    try:
                                        block_ip(ip_str, 0)
                                        log.warning("AI PERM BLOCK: %s", ip_str)
                                    except Exception as e:
                                        log.error("Failed to AI perm-block %s: %s", ip_str, e)

                prev_features = features

            except Exception as e:
                log.error("AI thread error: %s", e)

            self._stop_event.wait(interval)

    def _profile_thread(self):
        """Check time-of-day and switch rate profiles."""
        while not self._stop_event.is_set():
            try:
                self.profile_manager.check_and_apply()
            except Exception as e:
                log.error("Profile thread error: %s", e)
            self._stop_event.wait(60)

    def _cleanup_thread(self):
        """Periodically clean up expired blocked IPs and stale violations."""
        while not self._stop_event.is_set():
            try:
                # Kernel block expiries are stored as monotonic-since-boot
                # nanoseconds, so compare against /proc/uptime, not wall time.
                with open('/proc/uptime') as f:
                    now_ns = int(float(f.read().split()[0]) * 1_000_000_000)

                for map_name in ['blocked_ips_v4', 'blocked_ips_v6']:
                    entries = dump_blocked_ips(map_name)
                    for ip_str, expire_ns, blocked_at, drop_count in entries:
                        # expire_ns == 0 marks a permanent block.
                        if expire_ns != 0 and now_ns > expire_ns:
                            try:
                                unblock_ip(ip_str)
                                self.violation_tracker.clear(ip_str)
                                log.info("Expired block removed: %s (dropped %d pkts)", ip_str, drop_count)
                            except Exception as e:
                                log.error("Failed to remove expired block %s: %s", ip_str, e)

                self.violation_tracker.cleanup_expired()

            except Exception as e:
                log.error("Cleanup thread error: %s", e)

            self._stop_event.wait(60)

    # ---- Main Loop ----

    def run(self):
        """Start the daemon.

        Installs signal handlers, replaces any running instance, spawns the
        worker threads, then parks the main thread on the stop event so it
        stays free to run signal handlers. On stop: join workers, remove
        PID file, close the traffic DB.
        """
        log.info("XDP Defense Daemon starting...")

        signal.signal(signal.SIGHUP, self._handle_sighup)
        signal.signal(signal.SIGTERM, self._handle_sigterm)
        signal.signal(signal.SIGINT, self._handle_sigterm)
        signal.signal(signal.SIGUSR1, self._handle_sigusr1)

        self._ensure_single_instance()
        self._write_pid()

        threads = [
            threading.Thread(target=self._ewma_thread, name='ewma', daemon=True),
            threading.Thread(target=self._ai_thread, name='ai', daemon=True),
            threading.Thread(target=self._profile_thread, name='profile', daemon=True),
            threading.Thread(target=self._cleanup_thread, name='cleanup', daemon=True),
        ]

        for t in threads:
            t.start()
            log.info("Started %s thread", t.name)

        log.info("Daemon running (PID %d)", os.getpid())

        try:
            while not self._stop_event.is_set():
                self._stop_event.wait(1)
        except KeyboardInterrupt:
            pass

        log.info("Shutting down...")
        self._stop_event.set()

        for t in threads:
            t.join(timeout=5)

        self._remove_pid()
        if hasattr(self, '_traffic_db') and self._traffic_db:
            try:
                self._traffic_db.close()
            except Exception:
                pass
        log.info("Daemon stopped")
|
|
|
|
|
|
# ==================== Entry Point ====================
|
|
|
|
def main():
    """CLI entry point: argv[1] optionally overrides the config path."""
    config_path = sys.argv[1] if len(sys.argv) > 1 else CONFIG_PATH
    DDoSDaemon(config_path).run()


if __name__ == '__main__':
    main()
|