- Log traffic features with timestamps to CSV every 5s (ai.poll_interval) - Add hour_sin/hour_cos time features (15 → 17 feature vector) - Auto-retrain from traffic log at configurable interval (default 24h) - Detect old 15-feature models and switch to learning mode - SIGUSR1 now retrains from traffic log first, falls back to collect mode - Add CLI: `ai traffic` (time-bucketed summary), `ai log` (recent entries) - Add config keys: traffic_log_file, traffic_log_retention_days, retrain_window Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
898 lines
34 KiB
Python
Executable File
898 lines
34 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
XDP Defense Daemon
|
|
Userspace component: EWMA-based rate analysis, AI anomaly detection,
|
|
time-profile switching, and automatic escalation.
|
|
|
|
4 worker threads + main thread (signal handling):
|
|
- EWMA Thread: polls rate counters, calculates EWMA, detects violations
|
|
- AI Thread: reads traffic features, runs Isolation Forest inference
|
|
- Profile Thread: checks time-of-day, switches rate_config profiles
|
|
- Cleanup Thread: removes expired entries from blocked_ips maps
|
|
"""
|
|
|
|
import copy
|
|
import math
|
|
import os
|
|
import sys
|
|
import time
|
|
import signal
|
|
import threading
|
|
import logging
|
|
import logging.handlers
|
|
import csv
|
|
import pickle
|
|
from collections import defaultdict
|
|
from datetime import datetime, timedelta
|
|
|
|
import yaml
|
|
|
|
# ==================== Logging ====================
|
|
|
|
log = logging.getLogger('xdp-defense-daemon')
log.setLevel(logging.INFO)


def _attach_handler(handler, fmt):
    """Give *handler* a formatter built from *fmt* and register it on ``log``."""
    handler.setFormatter(logging.Formatter(fmt))
    log.addHandler(handler)


# Console output is always available.
_console = logging.StreamHandler()
_attach_handler(_console, '%(asctime)s [%(levelname)s] %(message)s')

# Syslog is best-effort: hosts without /dev/log just log to the console.
try:
    _syslog = logging.handlers.SysLogHandler(address='/dev/log')
    _attach_handler(_syslog, 'xdp-defense-daemon: %(message)s')
except Exception:
    pass
|
|
|
|
# ==================== Configuration ====================
|
|
|
|
# Built-in defaults. load_config() deep-copies this mapping and overlays each
# top-level section of the YAML file onto the matching section here.
DEFAULT_CONFIG = dict(
    general=dict(
        interface='eth0',
        log_level='info',
        pid_file='/var/lib/xdp-defense/daemon.pid',
        data_dir='/var/lib/xdp-defense',
    ),
    rate_limits=dict(
        default_pps=1000,
        default_bps=0,
        window_sec=1,         # passed to write_rate_config() as window_sec * 1e9
        profiles={},          # name -> {hours, weekdays, pps, bps}
    ),
    escalation=dict(
        temp_block_after=5,        # violations within the window -> temp block
        perm_block_after=20,       # violations within the window -> perm block
        temp_block_duration=300,   # seconds
        violation_window=600,      # seconds
    ),
    ewma=dict(
        alpha=0.3,
        poll_interval=1,           # seconds
        threshold_multiplier=3.0,
    ),
    ai=dict(
        enabled=True,
        model_type='IsolationForest',
        contamination=0.05,
        n_estimators=100,
        learning_duration=259200,  # seconds (3 days)
        min_samples=1000,
        poll_interval=5,           # seconds
        anomaly_threshold=-0.3,
        min_packets_for_sample=20,
        model_file='/var/lib/xdp-defense/ai_model.pkl',
        training_data_file='/var/lib/xdp-defense/training_data.csv',
        traffic_log_file='/var/lib/xdp-defense/traffic_log.csv',
        traffic_log_retention_days=7,
        retrain_interval=86400,    # seconds
        retrain_window=86400,      # seconds of traffic log used per retrain
    ),
)

# Default location of the YAML configuration file.
CONFIG_PATH = '/etc/xdp-defense/config.yaml'
|
|
|
|
|
|
def load_config(path=CONFIG_PATH):
    """Load the YAML config at *path*, overlaid on DEFAULT_CONFIG.

    Each top-level section of the file that also exists in DEFAULT_CONFIG is
    merged into a deep copy of the defaults with a shallow ``dict.update()``,
    so nested values such as ``rate_limits.profiles`` are replaced wholesale.
    Unknown sections are ignored; a section that is present but not a mapping
    is ignored with a warning.

    Never raises: a missing file, unreadable or invalid YAML, or a file whose
    top level is not a mapping is logged and the defaults are returned.

    Returns:
        dict: a freshly built config dict (safe for the caller to mutate).
    """
    cfg = copy.deepcopy(DEFAULT_CONFIG)
    try:
        with open(path, encoding='utf-8') as f:
            user = yaml.safe_load(f)
    except FileNotFoundError:
        log.warning("Config not found at %s, using defaults", path)
        return cfg
    except Exception as e:
        log.error("Failed to load config: %s", e)
        return cfg

    if user is None:
        return cfg  # empty file
    if not isinstance(user, dict):
        # e.g. a YAML list or scalar at top level; membership tests on it
        # would otherwise fail with an unhelpful TypeError.
        log.error("Failed to load config: top level of %s must be a mapping, got %s",
                  path, type(user).__name__)
        return cfg

    for section, defaults in cfg.items():
        if section not in user:
            continue
        override = user[section]
        if isinstance(override, dict):
            defaults.update(override)
        elif override is not None:
            log.warning("Ignoring config section '%s': expected a mapping, got %s",
                        section, type(override).__name__)
    return cfg
|
|
|
|
|
|
# ==================== ViolationTracker ====================
|
|
|
|
class ViolationTracker:
    """Count per-IP violations in a sliding time window and map them to escalation levels."""

    def __init__(self, escalation_cfg):
        # escalation_cfg is the 'escalation' config section; the daemon swaps it on SIGHUP.
        self.cfg = escalation_cfg
        self.violations = defaultdict(list)  # ip -> list of time.time() stamps
        self.lock = threading.Lock()

    def record_violation(self, ip):
        """Register one violation for *ip* and return the escalation level it reaches.

        Only violations younger than ``violation_window`` seconds (the new one
        included) are counted. Returns 'perm_block' when the count reaches
        ``perm_block_after``, 'temp_block' when it reaches ``temp_block_after``,
        and 'rate_limit' otherwise.
        """
        stamp = time.time()
        horizon = self.cfg.get('violation_window', 600)

        with self.lock:
            recent = [t for t in self.violations[ip] if stamp - t < horizon]
            recent.append(stamp)
            self.violations[ip] = recent
            hits = len(recent)

        if hits >= self.cfg.get('perm_block_after', 20):
            return 'perm_block'
        if hits >= self.cfg.get('temp_block_after', 5):
            return 'temp_block'
        return 'rate_limit'

    def clear(self, ip):
        """Forget every recorded violation for *ip* (no-op if none)."""
        with self.lock:
            if ip in self.violations:
                del self.violations[ip]

    def cleanup_expired(self):
        """Drop IPs none of whose violations fall inside the current window."""
        stamp = time.time()
        horizon = self.cfg.get('violation_window', 600)
        with self.lock:
            stale = [addr for addr, stamps in self.violations.items()
                     if not any(stamp - t < horizon for t in stamps)]
            for addr in stale:
                del self.violations[addr]
|
|
|
|
|
|
# ==================== EWMAAnalyzer ====================
|
|
|
|
class EWMAAnalyzer:
    """Per-IP packet-rate averages used to flag sudden rate increases.

    Two averages are kept per IP: a fast EWMA weighted by ``alpha`` and a slow
    baseline with a fixed 0.01 weight. A sample is anomalous when the fast
    EWMA exceeds ``threshold_multiplier`` times the baseline (floored at 1).
    """

    def __init__(self, alpha=0.3, threshold_multiplier=3.0):
        self.alpha = alpha
        self.threshold_multiplier = threshold_multiplier
        self.ewma = {}       # ip -> fast average of pps
        self.baseline = {}   # ip -> slow average of pps
        self.lock = threading.Lock()

    def update(self, ip, current_pps):
        """Fold *current_pps* into the averages for *ip*; return True when anomalous.

        The first sample for an IP only seeds both averages and is never
        reported as anomalous.
        """
        with self.lock:
            if ip not in self.ewma:
                self.ewma[ip] = current_pps
                self.baseline[ip] = current_pps
                return False

            fast = self.alpha * current_pps + (1 - self.alpha) * self.ewma[ip]
            slow = 0.01 * current_pps + 0.99 * self.baseline[ip]
            self.ewma[ip] = fast
            self.baseline[ip] = slow
            return fast > max(slow, 1) * self.threshold_multiplier

    def get_stats(self, ip):
        """Return ``{'ewma': ..., 'baseline': ...}`` for *ip* (0 for unknown IPs)."""
        with self.lock:
            return dict(ewma=self.ewma.get(ip, 0), baseline=self.baseline.get(ip, 0))

    def cleanup_stale(self, active_ips):
        """Forget every tracked IP that is not in *active_ips*."""
        keep = set(active_ips)
        with self.lock:
            for addr in [a for a in self.ewma if a not in keep]:
                del self.ewma[addr]
                self.baseline.pop(addr, None)
|
|
|
|
|
|
# ==================== AIDetector ====================
|
|
|
|
class AIDetector:
    """Isolation Forest based anomaly detection on traffic features.

    While ``is_learning`` is True, feature vectors are gathered with
    collect_sample(). Once ``learning_duration`` seconds have passed since
    construction and at least ``min_samples`` vectors exist (or a retrain was
    requested), a StandardScaler + IsolationForest pair is fitted, saved to
    ``model_file`` and used by predict().

    Training can run on the main thread (SIGUSR1 handler) while the AI worker
    thread calls predict(). The model/scaler pair is therefore fitted on
    locals and published and read together under ``_model_lock``. A failed
    or in-progress training run never leaves predict() with an unfitted
    scaler or a mismatched pair.
    """

    # Length of the feature vector built by the daemon's AI thread:
    # hour_sin, hour_cos, 10 counter deltas and 5 derived ratios.
    N_FEATURES = 17

    # Column names written to training_data_file (same order as the vector).
    FEATURE_NAMES = [
        'hour_sin', 'hour_cos',
        'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
        'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
        'small_pkt_count', 'large_pkt_count',
        'syn_ratio', 'udp_ratio', 'icmp_ratio', 'small_pkt_ratio', 'avg_pkt_size',
    ]

    def __init__(self, ai_cfg):
        self.cfg = ai_cfg  # the 'ai' config section; replaced on SIGHUP
        self.model = None
        self.scaler = None
        self.started_at = time.time()
        self.training_data = []
        self.is_learning = True
        self._retrain_requested = False
        # RLock: a nested signal handler on the same thread may re-enter.
        self._model_lock = threading.RLock()

    @property
    def enabled(self):
        return self.cfg.get('enabled', False)

    def request_retrain(self):
        """Make the next collect_sample() call train immediately."""
        self._retrain_requested = True

    def collect_sample(self, features):
        """Collect a feature sample during learning phase; train when due."""
        if not self.enabled:
            return

        self.training_data.append(features)

        learning_dur = self.cfg.get('learning_duration', 259200)
        min_samples = self.cfg.get('min_samples', 1000)
        elapsed = time.time() - self.started_at

        if (elapsed >= learning_dur and len(self.training_data) >= min_samples) or self._retrain_requested:
            self._train()
            self._retrain_requested = False

    @staticmethod
    def _write_atomically(path, mode, write_fn):
        """Write *path* through ``path + '.tmp'`` + os.replace() so it is never half-written.

        *mode* is 'wb' or 'w'; text mode is opened with newline='' for csv.
        The parent directory is created if needed.
        """
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        tmp_path = path + '.tmp'
        open_kwargs = {} if 'b' in mode else {'newline': ''}
        try:
            with open(tmp_path, mode, **open_kwargs) as f:
                write_fn(f)
            os.replace(tmp_path, path)
        except BaseException:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    def _train(self):
        """Fit a scaler + Isolation Forest on ``training_data`` and publish them.

        Returns:
            bool: True if a new model was fitted and is now used by predict()
            (even if saving it to disk failed), False otherwise.
        """
        try:
            from sklearn.ensemble import IsolationForest
            from sklearn.preprocessing import StandardScaler
            import numpy as np
        except ImportError:
            log.error("scikit-learn not installed. AI detection disabled.")
            self.cfg['enabled'] = False
            return False

        # Snapshot: retrain_from_log() may replace the list from another thread.
        data = list(self.training_data)
        if len(data) < 10:
            log.warning("Not enough training data (%d samples)", len(data))
            return False

        log.info("Training AI model with %d samples...", len(data))

        # Fit on locals so a failure leaves the currently published model intact.
        try:
            X = np.array(data)
            scaler = StandardScaler()
            X_scaled = scaler.fit_transform(X)
            model = IsolationForest(
                n_estimators=self.cfg.get('n_estimators', 100),
                contamination=self.cfg.get('contamination', 'auto'),
                random_state=42,
            )
            model.fit(X_scaled)
        except Exception as e:
            log.error("AI training failed: %s", e)
            return False

        with self._model_lock:
            self.model = model
            self.scaler = scaler
            self.is_learning = False

        model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')

        def dump_training_csv(f):
            writer = csv.writer(f)
            writer.writerow(self.FEATURE_NAMES)
            writer.writerows(data)

        try:
            self._write_atomically(
                model_file, 'wb',
                lambda f: pickle.dump({'model': model, 'scaler': scaler}, f))
            self._write_atomically(data_file, 'w', dump_training_csv)
            log.info("AI model trained and saved to %s", model_file)
        except Exception as e:
            # The new model is already active; only persisting it failed.
            log.error("AI model trained but saving failed: %s", e)
        return True

    def load_model(self):
        """Load a previously trained model. Check feature dimension compatibility.

        Returns:
            bool: True if the model was loaded and is now used by predict().
            A model built for a different feature count switches the detector
            to learning mode and returns False.
        """
        model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        if not os.path.exists(model_file):
            return False
        try:
            # NOTE(security): pickle.load can execute arbitrary code; model_file
            # must only be writable by trusted users.
            with open(model_file, 'rb') as f:
                data = pickle.load(f)
            model = data['model']
            scaler = data['scaler']

            n_features = getattr(scaler, 'n_features_in_', None)
            if n_features is not None and n_features != self.N_FEATURES:
                log.warning(
                    "Model has %d features, expected %d. Switching to learning mode.",
                    n_features, self.N_FEATURES
                )
                self.is_learning = True
                return False

            with self._model_lock:
                self.model = model
                self.scaler = scaler
                self.is_learning = False
            # %s, not %d: the value is '?' when the scaler lacks n_features_in_.
            log.info("AI model loaded from %s (%s features)",
                     model_file, n_features if n_features is not None else '?')
            return True
        except Exception as e:
            log.error("Failed to load AI model: %s", e)
            return False

    def predict(self, features):
        """Run anomaly detection. Returns (is_anomaly, score) as (bool, float)."""
        if not self.enabled:
            return False, 0.0

        # Read the pair together so a concurrent retrain cannot mix them.
        with self._model_lock:
            model, scaler = self.model, self.scaler
        if model is None or scaler is None:
            return False, 0.0

        try:
            import numpy as np
            X_scaled = scaler.transform(np.array([features]))
            score = float(model.decision_function(X_scaled)[0])
            threshold = self.cfg.get('anomaly_threshold', -0.3)
            return score < threshold, score
        except Exception as e:
            log.error("AI prediction error: %s", e)
            return False, 0.0

    def retrain_from_log(self):
        """Retrain the model from traffic_log.csv data.

        Uses rows stamped within the last ``retrain_window`` seconds that carry
        exactly N_FEATURES values after the ``timestamp`` and ``hour`` columns;
        malformed rows are skipped.

        Returns:
            bool: True only if a new model was actually fitted.
        """
        log_file = self.cfg.get('traffic_log_file', '/var/lib/xdp-defense/traffic_log.csv')
        if not os.path.exists(log_file):
            log.warning("Traffic log not found: %s", log_file)
            return False

        retrain_window = self.cfg.get('retrain_window', 86400)
        cutoff = datetime.now() - timedelta(seconds=retrain_window)

        try:
            samples = []
            with open(log_file, 'r', newline='') as f:
                reader = csv.reader(f)
                header = next(reader, None)
                if header is None:
                    log.warning("Traffic log is empty")
                    return False

                # Columns: timestamp, hour, then the feature vector.
                for row in reader:
                    try:
                        ts = datetime.fromisoformat(row[0])
                        if ts < cutoff:
                            continue
                        features = [float(v) for v in row[2:]]
                    except (ValueError, IndexError, TypeError):
                        continue
                    if len(features) == self.N_FEATURES:
                        samples.append(features)

            if len(samples) < 10:
                log.warning("Not enough recent samples for retrain (%d)", len(samples))
                return False

            log.info("Auto-retrain: loading %d samples from traffic log (window=%ds)",
                     len(samples), retrain_window)
            self.training_data = samples
            return self._train()

        except Exception as e:
            log.error("retrain_from_log failed: %s", e)
            return False
|
|
|
|
|
|
# ==================== ProfileManager ====================
|
|
|
|
class ProfileManager:
    """Manage time-based rate limit profiles.

    ``rate_cfg`` is the ``rate_limits`` config section. Each entry of
    ``rate_cfg['profiles']`` may restrict when it applies with:

    * ``weekdays`` - day names and/or day ranges, comma-separated or as a
      YAML list: ``'mon-fri'``, ``'sat,sun'``, ``'fri-mon'`` (wraps past
      Sunday), ``'wed'``. A malformed spec makes the profile never match.
    * ``hours`` - ``'HH:MM-HH:MM'``, end exclusive, may wrap past midnight.
      A malformed spec makes the profile never match.

    The first matching profile (in mapping order) wins; with no match the
    default pps/bps apply. write_rate_config() is only called when the
    selected profile name differs from ``current_profile``.
    """

    _DAY_NAMES = ('mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun')

    def __init__(self, rate_cfg):
        self.cfg = rate_cfg
        self.current_profile = 'default'

    @classmethod
    def _weekday_matches(cls, spec, weekday):
        """Return True if *weekday* (0 = Monday) is selected by *spec*.

        Raises:
            ValueError: unknown day name or a range with more than two ends.
            AttributeError: *spec* is neither a string nor a list/tuple.
        """
        parts = spec if isinstance(spec, (list, tuple)) else spec.split(',')
        for part in parts:
            part = str(part).strip().lower()
            if not part:
                continue
            bounds = [b.strip() for b in part.split('-')]
            if len(bounds) == 1:
                if cls._DAY_NAMES.index(bounds[0]) == weekday:
                    return True
            elif len(bounds) == 2:
                start = cls._DAY_NAMES.index(bounds[0])
                end = cls._DAY_NAMES.index(bounds[1])
                if start <= end:
                    if start <= weekday <= end:
                        return True
                elif weekday >= start or weekday <= end:
                    return True
            else:
                raise ValueError("bad weekday range: %r" % part)
        return False

    @staticmethod
    def _hours_match(spec, minute_of_day):
        """Return True if *minute_of_day* falls inside ``'HH:MM-HH:MM'`` (end exclusive).

        Raises ValueError/AttributeError on a malformed *spec*.
        """
        start_str, end_str = spec.split('-')
        sh, sm = map(int, start_str.split(':'))
        eh, em = map(int, end_str.split(':'))
        start_min = sh * 60 + sm
        end_min = eh * 60 + em
        if start_min <= end_min:
            return start_min <= minute_of_day < end_min
        # Window wraps past midnight, e.g. 22:00-06:00.
        return minute_of_day >= start_min or minute_of_day < end_min

    def check_and_apply(self):
        """Check current time and apply the matching profile if it changed."""
        from xdp_common import write_rate_config

        profiles = self.cfg.get('profiles') or {}
        now = datetime.now()
        current_time = now.hour * 60 + now.minute
        # weekday() is locale-independent, unlike strftime('%a').
        weekday = now.weekday()

        matched_profile = None
        matched_name = 'default'

        for name, profile in profiles.items():
            weekdays = profile.get('weekdays', '')
            if weekdays:
                try:
                    if not self._weekday_matches(weekdays, weekday):
                        continue
                except (ValueError, AttributeError, TypeError):
                    continue  # malformed weekdays: never match

            hours = profile.get('hours', '')
            if hours:
                try:
                    if not self._hours_match(hours, current_time):
                        continue
                except (ValueError, AttributeError):
                    continue  # malformed hours: never match

            matched_profile = profile
            matched_name = name
            break

        if matched_name == self.current_profile:
            return

        if matched_profile:
            pps = matched_profile.get('pps', self.cfg.get('default_pps', 1000))
            bps = matched_profile.get('bps', self.cfg.get('default_bps', 0))
        else:
            pps = self.cfg.get('default_pps', 1000)
            bps = self.cfg.get('default_bps', 0)

        window = self.cfg.get('window_sec', 1)

        try:
            write_rate_config(pps, bps, window * 1_000_000_000)
            log.info("Profile switched: %s -> %s (pps=%d)", self.current_profile, matched_name, pps)
            self.current_profile = matched_name
        except Exception as e:
            # current_profile is left unchanged so the next call retries.
            log.error("Failed to apply profile %s: %s", matched_name, e)
|
|
|
|
|
|
# ==================== DDoSDaemon ====================
|
|
|
|
class DDoSDaemon:
    """Main daemon orchestrator.

    Owns the shared components (ViolationTracker, EWMAAnalyzer, AIDetector,
    ProfileManager) and runs four worker threads. The main thread handles
    signals:

    * SIGHUP - reload the config file in place (EWMA/violation state kept)
    * SIGTERM / SIGINT - stop all threads and exit
    * SIGUSR1 - retrain the AI model from the traffic log
    """

    # Column layout of the traffic log written by _log_traffic(): 'timestamp'
    # and 'hour', followed by the 17-value feature vector.
    TRAFFIC_CSV_HEADER = [
        'timestamp', 'hour',
        'hour_sin', 'hour_cos',
        'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
        'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
        'small_pkt_count', 'large_pkt_count',
        'syn_ratio', 'udp_ratio', 'icmp_ratio', 'small_pkt_ratio', 'avg_pkt_size'
    ]

    def __init__(self, config_path=CONFIG_PATH):
        self.config_path = config_path
        self.cfg = load_config(config_path)
        self._stop_event = threading.Event()
        # Re-read by the worker threads every iteration so SIGHUP can change them.
        self._ewma_interval = self.cfg['ewma'].get('poll_interval', 1)
        self._ai_interval = self.cfg['ai'].get('poll_interval', 5)
        self._setup_components()

    def _setup_components(self):
        """Build the analysis components from the current config."""
        self.violation_tracker = ViolationTracker(self.cfg['escalation'])
        self.ewma_analyzer = EWMAAnalyzer(
            alpha=self.cfg['ewma'].get('alpha', 0.3),
            threshold_multiplier=self.cfg['ewma'].get('threshold_multiplier', 3.0),
        )
        self.ai_detector = AIDetector(self.cfg['ai'])
        self.profile_manager = ProfileManager(self.cfg['rate_limits'])

        if self.ai_detector.enabled:
            self.ai_detector.load_model()

        self._last_retrain_time = self._get_model_mtime()
        self._last_log_cleanup = time.time()

        level = self.cfg['general'].get('log_level', 'info').upper()
        log.setLevel(getattr(logging, level, logging.INFO))

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the parent directory of *path*, if it has one."""
        parent = os.path.dirname(path)
        if parent:  # os.makedirs('') raises for bare file names
            os.makedirs(parent, exist_ok=True)

    def _write_pid(self):
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        self._ensure_parent_dir(pid_file)
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

    def _remove_pid(self):
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        try:
            os.unlink(pid_file)
        except OSError:
            pass

    def _ensure_single_instance(self):
        """Stop any existing daemon (PID from pid_file) before starting.

        Sends SIGTERM, waits up to 30 s, then SIGKILL. Unreadable pid files
        and already-dead PIDs are ignored.
        """
        pid_file = self.cfg['general'].get('pid_file', '/var/lib/xdp-defense/daemon.pid')
        if not os.path.exists(pid_file):
            return
        try:
            with open(pid_file) as f:
                old_pid = int(f.read().strip())
            # kill() with pid 0 or a negative pid signals a whole process group,
            # and a stale file may hold our own PID after PID reuse (e.g. PID 1
            # in a container); none of these is another daemon instance.
            if old_pid <= 0 or old_pid == os.getpid():
                return
            os.kill(old_pid, 0)
            log.info("Stopping existing daemon (PID %d)...", old_pid)
            os.kill(old_pid, signal.SIGTERM)
            for _ in range(30):
                time.sleep(1)
                try:
                    os.kill(old_pid, 0)
                except OSError:
                    log.info("Old daemon stopped")
                    return
            log.warning("Daemon PID %d did not stop, sending SIGKILL", old_pid)
            os.kill(old_pid, signal.SIGKILL)
            time.sleep(1)
        except (ValueError, OSError):
            pass

    def _handle_sighup(self, signum, frame):
        log.info("SIGHUP received, reloading config...")
        self.cfg = load_config(self.config_path)
        # Update existing components without rebuilding (preserves EWMA/violation state)
        self.violation_tracker.cfg = self.cfg['escalation']
        self.ewma_analyzer.alpha = self.cfg['ewma'].get('alpha', 0.3)
        self.ewma_analyzer.threshold_multiplier = self.cfg['ewma'].get('threshold_multiplier', 3.0)
        self.ai_detector.cfg = self.cfg['ai']
        self.profile_manager.cfg = self.cfg['rate_limits']
        # ProfileManager only writes rate_config when the selected profile name
        # changes; clear it so reloaded pps/bps values are applied on its next check.
        self.profile_manager.current_profile = None
        # Update poll intervals (used by threads on next iteration)
        self._ewma_interval = self.cfg['ewma'].get('poll_interval', 1)
        self._ai_interval = self.cfg['ai'].get('poll_interval', 5)
        level = self.cfg['general'].get('log_level', 'info').upper()
        log.setLevel(getattr(logging, level, logging.INFO))
        log.info("Config reloaded (state preserved)")

    def _handle_sigterm(self, signum, frame):
        log.info("SIGTERM received, shutting down...")
        self._stop_event.set()

    def _handle_sigusr1(self, signum, frame):
        """Retrain from the traffic log; if that fails, fall back to collect mode."""
        log.info("SIGUSR1 received, triggering retrain from traffic log...")
        if self.ai_detector.retrain_from_log():
            self._last_retrain_time = time.time()
            log.info("SIGUSR1 retrain completed successfully")
        else:
            log.warning("SIGUSR1 retrain failed (falling back to collect mode)")
            # request_retrain() is only honoured by collect_sample(), which the AI
            # thread calls only while is_learning is True. Switch to learning so
            # the fallback actually happens; AI inference is paused until a model
            # is trained from the collected samples.
            self.ai_detector.is_learning = True
            self.ai_detector.request_retrain()

    # ---- Traffic Logging ----

    def _log_traffic(self, now, hour, features):
        """Append one row (timestamp, hour, features) to the traffic log CSV."""
        log_file = self.cfg['ai'].get('traffic_log_file', '/var/lib/xdp-defense/traffic_log.csv')
        try:
            write_header = not os.path.exists(log_file) or os.path.getsize(log_file) == 0
            self._ensure_parent_dir(log_file)
            with open(log_file, 'a', newline='') as f:
                writer = csv.writer(f)
                if write_header:
                    writer.writerow(self.TRAFFIC_CSV_HEADER)
                row = [now.isoformat(), f'{hour:.4f}'] + [f'{v:.6f}' for v in features]
                writer.writerow(row)
        except Exception as e:
            log.error("Failed to write traffic log: %s", e)

    def _cleanup_traffic_log(self):
        """Remove entries older than traffic_log_retention_days from the traffic log.

        Rows whose timestamp cannot be parsed are kept. The file is rewritten
        through a temporary file and os.replace(), so a crash mid-rewrite
        cannot truncate the log.
        """
        log_file = self.cfg['ai'].get('traffic_log_file', '/var/lib/xdp-defense/traffic_log.csv')
        retention_days = self.cfg['ai'].get('traffic_log_retention_days', 7)
        cutoff = datetime.now() - timedelta(days=retention_days)

        if not os.path.exists(log_file):
            return

        tmp_file = log_file + '.tmp'
        try:
            kept = []
            with open(log_file, 'r', newline='') as f:
                reader = csv.reader(f)
                header = next(reader, None)
                for row in reader:
                    try:
                        if datetime.fromisoformat(row[0]) >= cutoff:
                            kept.append(row)
                    except (ValueError, IndexError, TypeError):
                        kept.append(row)  # keep unparseable rows

            with open(tmp_file, 'w', newline='') as f:
                writer = csv.writer(f)
                if header:
                    writer.writerow(header)
                writer.writerows(kept)
            # Keep the original file's permission bits on the replacement.
            os.chmod(tmp_file, os.stat(log_file).st_mode & 0o7777)
            os.replace(tmp_file, log_file)

            log.info("Traffic log cleanup: kept %d rows (retention=%dd)", len(kept), retention_days)
        except Exception as e:
            log.error("Traffic log cleanup failed: %s", e)
            try:
                os.unlink(tmp_file)
            except OSError:
                pass

    def _get_model_mtime(self):
        """Get model file modification time, or current time if not found."""
        model_file = self.cfg['ai'].get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
        try:
            return os.path.getmtime(model_file)
        except OSError:
            return time.time()

    # ---- Worker Threads ----

    def _ewma_thread(self):
        """Poll rate counters, compute EWMA, detect violations, escalate."""
        from xdp_common import dump_rate_counters, block_ip, is_whitelisted

        prev_counters = {}  # ip -> cumulative packet counter at the previous poll
        last_poll = None    # time.monotonic() of the previous successful dump

        while not self._stop_event.is_set():
            interval = self._ewma_interval
            try:
                entries = dump_rate_counters('rate_counter_v4', top_n=1000)
                poll_time = time.monotonic()
                # Rates use the real time between dumps (sleep + processing),
                # not just the configured sleep interval.
                elapsed = interval if last_poll is None else poll_time - last_poll
                elapsed = max(elapsed, 0.1)
                last_poll = poll_time
                active_ips = []

                for ip_str, pkts, bts, last_seen in entries:
                    active_ips.append(ip_str)

                    prev = prev_counters.get(ip_str)
                    prev_counters[ip_str] = pkts
                    if prev is None:
                        # First sighting: the counter is cumulative, so no
                        # per-interval rate exists until the next poll.
                        continue

                    # pkts < prev means the counter was reset; count from zero.
                    delta = pkts - prev if pkts >= prev else pkts
                    if delta <= 0:
                        continue

                    pps = delta / elapsed

                    is_anomalous = self.ewma_analyzer.update(ip_str, pps)
                    if is_anomalous:
                        # Skip whitelisted IPs
                        if is_whitelisted(ip_str):
                            log.debug("EWMA anomaly skipped (whitelisted): %s", ip_str)
                            continue

                        level = self.violation_tracker.record_violation(ip_str)
                        ew = self.ewma_analyzer.get_stats(ip_str)
                        log.warning(
                            "EWMA anomaly: %s pps=%.1f ewma=%.1f baseline=%.1f -> %s",
                            ip_str, pps, ew['ewma'], ew['baseline'], level
                        )

                        if level == 'temp_block':
                            dur = self.cfg['escalation'].get('temp_block_duration', 300)
                            try:
                                block_ip(ip_str, dur)
                                log.warning("TEMP BLOCK: %s for %ds", ip_str, dur)
                            except Exception as e:
                                log.error("Failed to temp-block %s: %s", ip_str, e)

                        elif level == 'perm_block':
                            try:
                                block_ip(ip_str, 0)
                                log.warning("PERM BLOCK: %s", ip_str)
                            except Exception as e:
                                log.error("Failed to perm-block %s: %s", ip_str, e)

                # Forget IPs absent from this dump so prev_counters does not grow
                # with every source address ever seen (e.g. spoofed floods).
                active = set(active_ips)
                for ip_str in [ip for ip in prev_counters if ip not in active]:
                    del prev_counters[ip_str]
                self.ewma_analyzer.cleanup_stale(active_ips)

            except Exception as e:
                log.error("EWMA thread error: %s", e)

            self._stop_event.wait(interval)

    def _ai_thread(self):
        """Read traffic features, run AI inference or collect training data.

        Each poll computes per-interval deltas of the global feature counters.
        Intervals with fewer than ``min_packets_for_sample`` packets are
        skipped. The 17-value vector is appended to the traffic log and then
        either fed to the learner or, once a model exists, used for
        inference; retraining from the log happens when it is due.
        """
        from xdp_common import read_percpu_features, dump_rate_counters, block_ip, is_whitelisted

        prev_features = None
        self._last_retrain_time = self._get_model_mtime()
        self._last_log_cleanup = time.time()

        while not self._stop_event.is_set():
            interval = self._ai_interval
            try:
                if not self.ai_detector.enabled:
                    self._stop_event.wait(interval)
                    continue

                features = read_percpu_features()
                if not features:
                    self._stop_event.wait(interval)
                    continue

                feature_names = [
                    'total_packets', 'total_bytes', 'tcp_syn_count', 'tcp_other_count',
                    'udp_count', 'icmp_count', 'other_proto_count', 'unique_ips_approx',
                    'small_pkt_count', 'large_pkt_count'
                ]

                if prev_features is not None:
                    deltas = []
                    for name in feature_names:
                        cur = features.get(name, 0)
                        prev = prev_features.get(name, 0)
                        deltas.append(max(0, cur - prev))

                    min_pkts = self.cfg['ai'].get('min_packets_for_sample', 20)
                    if deltas[0] < min_pkts:
                        prev_features = features
                        self._stop_event.wait(interval)
                        continue

                    total = deltas[0] + 1e-6
                    syn_ratio = deltas[2] / total
                    udp_ratio = deltas[4] / total
                    icmp_ratio = deltas[5] / total
                    small_pkt_ratio = deltas[8] / total
                    avg_pkt_size = deltas[1] / total
                    deltas.extend([syn_ratio, udp_ratio, icmp_ratio, small_pkt_ratio, avg_pkt_size])

                    # Add time features (hour_sin, hour_cos) at the front
                    now = datetime.now()
                    hour = now.hour + now.minute / 60.0
                    hour_sin = math.sin(2 * math.pi * hour / 24)
                    hour_cos = math.cos(2 * math.pi * hour / 24)
                    deltas_with_time = [hour_sin, hour_cos] + deltas  # 17 features

                    # Log to traffic CSV
                    self._log_traffic(now, hour, deltas_with_time)

                    # Periodic log file cleanup (once per day)
                    if time.time() - self._last_log_cleanup > 86400:
                        self._cleanup_traffic_log()
                        self._last_log_cleanup = time.time()

                    if self.ai_detector.is_learning:
                        self.ai_detector.collect_sample(deltas_with_time)
                        if len(self.ai_detector.training_data) % 100 == 0:
                            log.debug("AI learning: %d samples collected",
                                      len(self.ai_detector.training_data))
                    else:
                        # Auto-retrain check
                        retrain_interval = self.cfg['ai'].get('retrain_interval', 86400)
                        if time.time() - self._last_retrain_time >= retrain_interval:
                            log.info("Auto-retrain triggered (interval=%ds)", retrain_interval)
                            if self.ai_detector.retrain_from_log():
                                self._last_retrain_time = time.time()
                                log.info("Auto-retrain completed successfully")
                            else:
                                log.warning("Auto-retrain failed, will retry next interval")
                                self._last_retrain_time = time.time()

                        is_anomaly, score = self.ai_detector.predict(deltas_with_time)
                        if is_anomaly:
                            log.warning(
                                "AI ANOMALY detected: score=%.4f deltas=%s",
                                score, dict(zip(feature_names, deltas[:len(feature_names)]))
                            )
                            top_ips = dump_rate_counters('rate_counter_v4', top_n=5)
                            for ip_str, pkts, bts, _ in top_ips:
                                # Skip whitelisted IPs
                                if is_whitelisted(ip_str):
                                    log.debug("AI escalation skipped (whitelisted): %s", ip_str)
                                    continue

                                level = self.violation_tracker.record_violation(ip_str)
                                log.warning("AI escalation: %s -> %s", ip_str, level)

                                if level == 'temp_block':
                                    dur = self.cfg['escalation'].get('temp_block_duration', 300)
                                    try:
                                        block_ip(ip_str, dur)
                                        log.warning("AI TEMP BLOCK: %s for %ds", ip_str, dur)
                                    except Exception as e:
                                        log.error("Failed to AI temp-block %s: %s", ip_str, e)

                                elif level == 'perm_block':
                                    try:
                                        block_ip(ip_str, 0)
                                        log.warning("AI PERM BLOCK: %s", ip_str)
                                    except Exception as e:
                                        log.error("Failed to AI perm-block %s: %s", ip_str, e)

                prev_features = features

            except Exception as e:
                log.error("AI thread error: %s", e)

            self._stop_event.wait(interval)

    def _profile_thread(self):
        """Check time-of-day and switch rate profiles (every 60 s)."""
        while not self._stop_event.is_set():
            try:
                self.profile_manager.check_and_apply()
            except Exception as e:
                log.error("Profile thread error: %s", e)
            self._stop_event.wait(60)

    def _cleanup_thread(self):
        """Periodically clean up expired blocked IPs and stale violations (every 60 s)."""
        from xdp_common import dump_blocked_ips, unblock_ip

        while not self._stop_event.is_set():
            try:
                # Expiry times are compared against system uptime in ns.
                with open('/proc/uptime') as f:
                    now_ns = int(float(f.read().split()[0]) * 1_000_000_000)

                for map_name in ['blocked_ips_v4', 'blocked_ips_v6']:
                    entries = dump_blocked_ips(map_name)
                    for ip_str, expire_ns, blocked_at, drop_count in entries:
                        if expire_ns != 0 and now_ns > expire_ns:
                            try:
                                unblock_ip(ip_str)
                                self.violation_tracker.clear(ip_str)
                                log.info("Expired block removed: %s (dropped %d pkts)", ip_str, drop_count)
                            except Exception as e:
                                log.error("Failed to remove expired block %s: %s", ip_str, e)

                self.violation_tracker.cleanup_expired()

            except Exception as e:
                log.error("Cleanup thread error: %s", e)

            self._stop_event.wait(60)

    # ---- Main Loop ----

    def run(self):
        """Start the daemon and block until SIGTERM/SIGINT."""
        log.info("XDP Defense Daemon starting...")

        signal.signal(signal.SIGHUP, self._handle_sighup)
        signal.signal(signal.SIGTERM, self._handle_sigterm)
        signal.signal(signal.SIGINT, self._handle_sigterm)
        signal.signal(signal.SIGUSR1, self._handle_sigusr1)

        self._ensure_single_instance()
        self._write_pid()

        threads = [
            threading.Thread(target=self._ewma_thread, name='ewma', daemon=True),
            threading.Thread(target=self._ai_thread, name='ai', daemon=True),
            threading.Thread(target=self._profile_thread, name='profile', daemon=True),
            threading.Thread(target=self._cleanup_thread, name='cleanup', daemon=True),
        ]

        for t in threads:
            t.start()
            log.info("Started %s thread", t.name)

        log.info("Daemon running (PID %d)", os.getpid())

        try:
            while not self._stop_event.is_set():
                self._stop_event.wait(1)
        except KeyboardInterrupt:
            pass

        log.info("Shutting down...")
        self._stop_event.set()

        for t in threads:
            t.join(timeout=5)

        self._remove_pid()
        log.info("Daemon stopped")
|
|
|
|
|
|
# ==================== Entry Point ====================
|
|
|
|
def main():
    """Run the daemon; an optional first CLI argument overrides CONFIG_PATH."""
    config_path = sys.argv[1] if len(sys.argv) > 1 else CONFIG_PATH
    DDoSDaemon(config_path).run()


if __name__ == '__main__':
    main()
|