Fix HIGH severity security and thread-safety issues
Daemon fixes: - Add _db_lock for thread-safe SQLite access - Atomic SIGHUP config swap (build all values before applying) - Check world-writable permission before loading pickle model - Write model files with 0o600 permissions via os.open - Module-level xdp_common import with fatal exit on failure - Close traffic DB on shutdown - Add period_data parameter to _train() to avoid race condition CLI fixes: - Replace $COMMON_PY variable with hardcoded 'xdp_common' - Pass CONFIG_FILE via sys.argv instead of string interpolation - Add key_hex regex validation before all bpftool commands - Switch sanitize_input from denylist to strict allowlist Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -14,6 +14,7 @@ time-profile switching, and automatic escalation.
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
import stat
|
||||
import sys
|
||||
import time
|
||||
import signal
|
||||
@@ -28,6 +29,16 @@ from datetime import datetime, timedelta
|
||||
|
||||
import yaml
|
||||
|
||||
try:
|
||||
from xdp_common import (
|
||||
dump_rate_counters, block_ip, is_whitelisted,
|
||||
read_percpu_features, dump_blocked_ips, unblock_ip,
|
||||
write_rate_config, read_rate_config,
|
||||
)
|
||||
except ImportError as e:
|
||||
print(f"FATAL: Cannot import xdp_common: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# ==================== Logging ====================
|
||||
|
||||
log = logging.getLogger('xdp-defense-daemon')
|
||||
@@ -256,8 +267,13 @@ class AIDetector:
|
||||
self._train()
|
||||
self._retrain_requested = False
|
||||
|
||||
def _train(self):
|
||||
"""Train per-period Isolation Forest models."""
|
||||
def _train(self, period_data=None):
|
||||
"""Train per-period Isolation Forest models.
|
||||
If period_data provided, use it instead of self.training_data.
|
||||
"""
|
||||
if period_data is None:
|
||||
period_data = self.training_data
|
||||
|
||||
try:
|
||||
from sklearn.ensemble import IsolationForest
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
@@ -267,19 +283,19 @@ class AIDetector:
|
||||
self.cfg['enabled'] = False
|
||||
return
|
||||
|
||||
total = sum(len(v) for v in self.training_data.values())
|
||||
total = sum(len(v) for v in period_data.values())
|
||||
if total < 10:
|
||||
log.warning("Not enough training data (%d samples)", total)
|
||||
return
|
||||
|
||||
log.info("Training AI models: %s",
|
||||
{p: len(s) for p, s in self.training_data.items() if s})
|
||||
{p: len(s) for p, s in period_data.items() if s})
|
||||
|
||||
try:
|
||||
new_models = {}
|
||||
all_samples = []
|
||||
|
||||
for period, samples in self.training_data.items():
|
||||
for period, samples in period_data.items():
|
||||
if len(samples) < 10:
|
||||
log.info("Period %s: %d samples (too few, skip)", period, len(samples))
|
||||
continue
|
||||
@@ -309,7 +325,8 @@ class AIDetector:
|
||||
# Save to disk (atomic)
|
||||
model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
|
||||
tmp_model = model_file + '.tmp'
|
||||
with open(tmp_model, 'wb') as f:
|
||||
fd = os.open(tmp_model, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||
with os.fdopen(fd, 'wb') as f:
|
||||
pickle.dump({
|
||||
'format': 'period_models',
|
||||
'models': {p: {'model': m['model'], 'scaler': m['scaler']}
|
||||
@@ -323,7 +340,8 @@ class AIDetector:
|
||||
# Save training data CSV
|
||||
data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
|
||||
tmp_data = data_file + '.tmp'
|
||||
with open(tmp_data, 'w', newline='') as f:
|
||||
fd = os.open(tmp_data, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)
|
||||
with os.fdopen(fd, 'w', newline='') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow([
|
||||
'hour_sin', 'hour_cos',
|
||||
@@ -348,6 +366,12 @@ class AIDetector:
|
||||
if not os.path.exists(model_file):
|
||||
return False
|
||||
try:
|
||||
st = os.stat(model_file)
|
||||
# Warn if file is world-writable
|
||||
if st.st_mode & stat.S_IWOTH:
|
||||
log.warning("Model file %s is world-writable! Refusing to load.", model_file)
|
||||
return False
|
||||
|
||||
with open(model_file, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
@@ -488,8 +512,7 @@ class AIDetector:
|
||||
len(rows), filtered_count,
|
||||
{p: len(s) for p, s in period_data.items() if s})
|
||||
|
||||
self.training_data = period_data
|
||||
self._train()
|
||||
self._train(period_data=period_data)
|
||||
return not self.is_learning
|
||||
|
||||
|
||||
@@ -504,8 +527,6 @@ class ProfileManager:
|
||||
|
||||
def check_and_apply(self):
|
||||
"""Check current time and apply matching profile."""
|
||||
from xdp_common import write_rate_config
|
||||
|
||||
profiles = self.cfg.get('profiles', {})
|
||||
now = datetime.now()
|
||||
current_hour = now.hour
|
||||
@@ -648,17 +669,25 @@ class DDoSDaemon:
|
||||
|
||||
def _handle_sighup(self, signum, frame):
|
||||
log.info("SIGHUP received, reloading config...")
|
||||
self.cfg = load_config(self.config_path)
|
||||
# Update existing components without rebuilding (preserves EWMA/violation state)
|
||||
self.violation_tracker.cfg = self.cfg['escalation']
|
||||
self.ewma_analyzer.alpha = self.cfg['ewma'].get('alpha', 0.3)
|
||||
self.ewma_analyzer.threshold_multiplier = self.cfg['ewma'].get('threshold_multiplier', 3.0)
|
||||
self.ai_detector.cfg = self.cfg['ai']
|
||||
self.profile_manager.cfg = self.cfg['rate_limits']
|
||||
# Update poll intervals (used by threads on next iteration)
|
||||
self._ewma_interval = self.cfg['ewma'].get('poll_interval', 1)
|
||||
self._ai_interval = self.cfg['ai'].get('poll_interval', 5)
|
||||
level = self.cfg['general'].get('log_level', 'info').upper()
|
||||
new_cfg = load_config(self.config_path)
|
||||
# Build all new values before swapping anything
|
||||
new_escalation = new_cfg['escalation']
|
||||
new_alpha = new_cfg['ewma'].get('alpha', 0.3)
|
||||
new_threshold = new_cfg['ewma'].get('threshold_multiplier', 3.0)
|
||||
new_ai_cfg = new_cfg['ai']
|
||||
new_rate_cfg = new_cfg['rate_limits']
|
||||
new_ewma_interval = new_cfg['ewma'].get('poll_interval', 1)
|
||||
new_ai_interval = new_cfg['ai'].get('poll_interval', 5)
|
||||
level = new_cfg['general'].get('log_level', 'info').upper()
|
||||
# Now apply all at once
|
||||
self.cfg = new_cfg
|
||||
self.violation_tracker.cfg = new_escalation
|
||||
self.ewma_analyzer.alpha = new_alpha
|
||||
self.ewma_analyzer.threshold_multiplier = new_threshold
|
||||
self.ai_detector.cfg = new_ai_cfg
|
||||
self.profile_manager.cfg = new_rate_cfg
|
||||
self._ewma_interval = new_ewma_interval
|
||||
self._ai_interval = new_ai_interval
|
||||
log.setLevel(getattr(logging, level, logging.INFO))
|
||||
log.info("Config reloaded (state preserved)")
|
||||
|
||||
@@ -678,6 +707,7 @@ class DDoSDaemon:
|
||||
"""Initialize SQLite database for traffic logging."""
|
||||
db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
self._db_lock = threading.Lock()
|
||||
self._traffic_db = sqlite3.connect(db_path, check_same_thread=False)
|
||||
self._traffic_db.execute(
|
||||
'CREATE TABLE IF NOT EXISTS traffic_samples ('
|
||||
@@ -712,17 +742,18 @@ class DDoSDaemon:
|
||||
def _log_traffic(self, now, hour, features):
|
||||
"""Insert one row into traffic_samples table."""
|
||||
try:
|
||||
self._traffic_db.execute(
|
||||
'INSERT INTO traffic_samples ('
|
||||
' timestamp, hour, hour_sin, hour_cos,'
|
||||
' total_packets, total_bytes, tcp_syn_count, tcp_other_count,'
|
||||
' udp_count, icmp_count, other_proto_count, unique_ips_approx,'
|
||||
' small_pkt_count, large_pkt_count,'
|
||||
' syn_ratio, udp_ratio, icmp_ratio, small_pkt_ratio, avg_pkt_size'
|
||||
') VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)',
|
||||
(now.isoformat(), hour, *features)
|
||||
)
|
||||
self._traffic_db.commit()
|
||||
with self._db_lock:
|
||||
self._traffic_db.execute(
|
||||
'INSERT INTO traffic_samples ('
|
||||
' timestamp, hour, hour_sin, hour_cos,'
|
||||
' total_packets, total_bytes, tcp_syn_count, tcp_other_count,'
|
||||
' udp_count, icmp_count, other_proto_count, unique_ips_approx,'
|
||||
' small_pkt_count, large_pkt_count,'
|
||||
' syn_ratio, udp_ratio, icmp_ratio, small_pkt_ratio, avg_pkt_size'
|
||||
') VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)',
|
||||
(now.isoformat(), hour, *features)
|
||||
)
|
||||
self._traffic_db.commit()
|
||||
except Exception as e:
|
||||
log.error("Failed to write traffic log: %s", e)
|
||||
|
||||
@@ -732,13 +763,14 @@ class DDoSDaemon:
|
||||
cutoff = (datetime.now() - timedelta(days=retention_days)).isoformat()
|
||||
|
||||
try:
|
||||
cur = self._traffic_db.execute(
|
||||
'DELETE FROM traffic_samples WHERE timestamp < ?', (cutoff,)
|
||||
)
|
||||
deleted = cur.rowcount
|
||||
self._traffic_db.commit()
|
||||
if deleted > 1000:
|
||||
self._traffic_db.execute('VACUUM')
|
||||
with self._db_lock:
|
||||
cur = self._traffic_db.execute(
|
||||
'DELETE FROM traffic_samples WHERE timestamp < ?', (cutoff,)
|
||||
)
|
||||
deleted = cur.rowcount
|
||||
self._traffic_db.commit()
|
||||
if deleted > 1000:
|
||||
self._traffic_db.execute('VACUUM')
|
||||
log.info("Traffic log cleanup: deleted %d rows (retention=%dd)", deleted, retention_days)
|
||||
except Exception as e:
|
||||
log.error("Traffic log cleanup failed: %s", e)
|
||||
@@ -755,8 +787,6 @@ class DDoSDaemon:
|
||||
|
||||
def _ewma_thread(self):
|
||||
"""Poll rate counters, compute EWMA, detect violations, escalate."""
|
||||
from xdp_common import dump_rate_counters, block_ip, is_whitelisted
|
||||
|
||||
prev_counters = {}
|
||||
|
||||
while not self._stop_event.is_set():
|
||||
@@ -815,8 +845,6 @@ class DDoSDaemon:
|
||||
|
||||
def _ai_thread(self):
|
||||
"""Read traffic features, run AI inference or collect training data."""
|
||||
from xdp_common import read_percpu_features, dump_rate_counters, block_ip, is_whitelisted
|
||||
|
||||
prev_features = None
|
||||
self._last_retrain_time = self._get_model_mtime()
|
||||
self._last_log_cleanup = time.time()
|
||||
@@ -938,8 +966,6 @@ class DDoSDaemon:
|
||||
|
||||
def _cleanup_thread(self):
|
||||
"""Periodically clean up expired blocked IPs and stale violations."""
|
||||
from xdp_common import dump_blocked_ips, unblock_ip
|
||||
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
with open('/proc/uptime') as f:
|
||||
@@ -1003,6 +1029,11 @@ class DDoSDaemon:
|
||||
t.join(timeout=5)
|
||||
|
||||
self._remove_pid()
|
||||
if hasattr(self, '_traffic_db') and self._traffic_db:
|
||||
try:
|
||||
self._traffic_db.close()
|
||||
except Exception:
|
||||
pass
|
||||
log.info("Daemon stopped")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user