Add time-period models, outlier filtering, and background retrain
- Split the single IsolationForest into 4 period models (night/morning/afternoon/evening)
- Each period model is trained independently on the data from its own time window
- Filter attack samples during retrain using the existing models' scores (threshold -0.5)
- Retrain runs in a background thread guarded by a lock; inference continues uninterrupted
- New pickle format 'period_models' with automatic detection of the old format
- SIGUSR1 and auto-retrain both use background mode

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -201,41 +201,63 @@ class EWMAAnalyzer:
|
||||
# ==================== AIDetector ====================
|
||||
|
||||
class AIDetector:
|
||||
"""Isolation Forest based anomaly detection on traffic features."""
|
||||
"""Time-period Isolation Forest anomaly detection.
|
||||
|
||||
Maintains 4 separate models for different time periods:
|
||||
night (00-06), morning (06-12), afternoon (12-18), evening (18-24).
|
||||
Each model learns the normal traffic pattern for its time window.
|
||||
"""
|
||||
|
||||
TIME_PERIODS = {
|
||||
'night': (0, 6),
|
||||
'morning': (6, 12),
|
||||
'afternoon': (12, 18),
|
||||
'evening': (18, 24),
|
||||
}
|
||||
|
||||
def __init__(self, ai_cfg):
|
||||
self.cfg = ai_cfg
|
||||
self.model = None
|
||||
self.scaler = None
|
||||
self.models = {} # {period_name: {'model': ..., 'scaler': ...}}
|
||||
self.started_at = time.time()
|
||||
self.training_data = []
|
||||
self.training_data = defaultdict(list) # {period_name: [samples]}
|
||||
self.is_learning = True
|
||||
self._retrain_requested = False
|
||||
self._retrain_lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.cfg.get('enabled', False)
|
||||
|
||||
@staticmethod
|
||||
def get_period(hour):
|
||||
"""Map hour (0-24) to period name."""
|
||||
for name, (start, end) in AIDetector.TIME_PERIODS.items():
|
||||
if start <= hour < end:
|
||||
return name
|
||||
return 'night' # hour == 24 edge case
|
||||
|
||||
def request_retrain(self):
|
||||
self._retrain_requested = True
|
||||
|
||||
def collect_sample(self, features):
|
||||
"""Collect a feature sample during learning phase."""
|
||||
def collect_sample(self, features, hour):
|
||||
"""Collect a feature sample during learning phase, bucketed by period."""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
self.training_data.append(features)
|
||||
period = self.get_period(hour)
|
||||
self.training_data[period].append(features)
|
||||
|
||||
learning_dur = self.cfg.get('learning_duration', 259200)
|
||||
min_samples = self.cfg.get('min_samples', 1000)
|
||||
elapsed = time.time() - self.started_at
|
||||
total = sum(len(v) for v in self.training_data.values())
|
||||
|
||||
if (elapsed >= learning_dur and len(self.training_data) >= min_samples) or self._retrain_requested:
|
||||
if (elapsed >= learning_dur and total >= min_samples) or self._retrain_requested:
|
||||
self._train()
|
||||
self._retrain_requested = False
|
||||
|
||||
def _train(self):
|
||||
"""Train the Isolation Forest model."""
|
||||
"""Train per-period Isolation Forest models."""
|
||||
try:
|
||||
from sklearn.ensemble import IsolationForest
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
@@ -245,32 +267,63 @@ class AIDetector:
|
||||
self.cfg['enabled'] = False
|
||||
return
|
||||
|
||||
if len(self.training_data) < 10:
|
||||
log.warning("Not enough training data (%d samples)", len(self.training_data))
|
||||
total = sum(len(v) for v in self.training_data.values())
|
||||
if total < 10:
|
||||
log.warning("Not enough training data (%d samples)", total)
|
||||
return
|
||||
|
||||
log.info("Training AI model with %d samples...", len(self.training_data))
|
||||
log.info("Training AI models: %s",
|
||||
{p: len(s) for p, s in self.training_data.items() if s})
|
||||
|
||||
try:
|
||||
X = np.array(self.training_data)
|
||||
new_models = {}
|
||||
all_samples = []
|
||||
|
||||
self.scaler = StandardScaler()
|
||||
X_scaled = self.scaler.fit_transform(X)
|
||||
for period, samples in self.training_data.items():
|
||||
if len(samples) < 10:
|
||||
log.info("Period %s: %d samples (too few, skip)", period, len(samples))
|
||||
continue
|
||||
|
||||
self.model = IsolationForest(
|
||||
X = np.array(samples)
|
||||
scaler = StandardScaler()
|
||||
X_scaled = scaler.fit_transform(X)
|
||||
|
||||
model = IsolationForest(
|
||||
n_estimators=self.cfg.get('n_estimators', 100),
|
||||
contamination=self.cfg.get('contamination', 'auto'),
|
||||
random_state=42,
|
||||
)
|
||||
self.model.fit(X_scaled)
|
||||
model.fit(X_scaled)
|
||||
new_models[period] = {'model': model, 'scaler': scaler}
|
||||
all_samples.extend(samples)
|
||||
log.info("Period %s: trained with %d samples", period, len(samples))
|
||||
|
||||
if not new_models:
|
||||
log.warning("No period had enough data to train")
|
||||
return
|
||||
|
||||
# Atomic swap
|
||||
self.models = new_models
|
||||
self.is_learning = False
|
||||
|
||||
# Save to disk (atomic)
|
||||
model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
|
||||
with open(model_file, 'wb') as f:
|
||||
pickle.dump({'model': self.model, 'scaler': self.scaler}, f)
|
||||
tmp_model = model_file + '.tmp'
|
||||
with open(tmp_model, 'wb') as f:
|
||||
pickle.dump({
|
||||
'format': 'period_models',
|
||||
'models': {p: {'model': m['model'], 'scaler': m['scaler']}
|
||||
for p, m in new_models.items()},
|
||||
'feature_count': 17,
|
||||
}, f)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.rename(tmp_model, model_file)
|
||||
|
||||
# Save training data CSV
|
||||
data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
|
||||
with open(data_file, 'w', newline='') as f:
|
||||
tmp_data = data_file + '.tmp'
|
||||
with open(tmp_data, 'w', newline='') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow([
|
||||
'hour_sin', 'hour_cos',
|
||||
@@ -279,63 +332,97 @@ class AIDetector:
|
||||
'small_pkt_count', 'large_pkt_count',
|
||||
'syn_ratio', 'udp_ratio', 'icmp_ratio', 'small_pkt_ratio', 'avg_pkt_size'
|
||||
])
|
||||
writer.writerows(self.training_data)
|
||||
writer.writerows(all_samples)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.rename(tmp_data, data_file)
|
||||
|
||||
log.info("AI model trained and saved to %s", model_file)
|
||||
log.info("AI models trained and saved (%d periods)", len(new_models))
|
||||
|
||||
except Exception as e:
|
||||
log.error("AI training failed: %s", e)
|
||||
|
||||
def load_model(self):
|
||||
"""Load a previously trained model. Check feature dimension compatibility."""
|
||||
"""Load previously trained models. Handle old and new formats."""
|
||||
model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
|
||||
if not os.path.exists(model_file):
|
||||
return False
|
||||
try:
|
||||
with open(model_file, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
model = data['model']
|
||||
scaler = data['scaler']
|
||||
|
||||
expected_features = 17
|
||||
if hasattr(scaler, 'n_features_in_') and scaler.n_features_in_ != expected_features:
|
||||
log.warning(
|
||||
"Model has %d features, expected %d. Switching to learning mode.",
|
||||
scaler.n_features_in_, expected_features
|
||||
)
|
||||
# New period_models format
|
||||
if data.get('format') == 'period_models':
|
||||
self.models = data['models']
|
||||
self.is_learning = False
|
||||
periods = list(self.models.keys())
|
||||
log.info("AI models loaded: %s (%d features)",
|
||||
periods, data.get('feature_count', '?'))
|
||||
return True
|
||||
|
||||
# Old single-model format → discard
|
||||
log.warning("Old single-model format detected. Switching to learning mode.")
|
||||
self.is_learning = True
|
||||
return False
|
||||
|
||||
self.model = model
|
||||
self.scaler = scaler
|
||||
self.is_learning = False
|
||||
log.info("AI model loaded from %s (%d features)",
|
||||
model_file, getattr(scaler, 'n_features_in_', '?'))
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error("Failed to load AI model: %s", e)
|
||||
return False
|
||||
|
||||
def predict(self, features):
|
||||
"""Run anomaly detection. Returns (is_anomaly, score)."""
|
||||
if not self.enabled or self.model is None:
|
||||
def predict(self, features, hour):
|
||||
"""Run anomaly detection using the period-specific model.
|
||||
Returns (is_anomaly, score).
|
||||
"""
|
||||
if not self.enabled or not self.models:
|
||||
return False, 0.0
|
||||
|
||||
period = self.get_period(hour)
|
||||
model_info = self.models.get(period)
|
||||
|
||||
# Fallback: try adjacent periods
|
||||
if model_info is None:
|
||||
for p in self.models:
|
||||
model_info = self.models[p]
|
||||
break
|
||||
if model_info is None:
|
||||
return False, 0.0
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
X = np.array([features])
|
||||
X_scaled = self.scaler.transform(X)
|
||||
score = self.model.decision_function(X_scaled)[0]
|
||||
X_scaled = model_info['scaler'].transform(X)
|
||||
score = model_info['model'].decision_function(X_scaled)[0]
|
||||
threshold = self.cfg.get('anomaly_threshold', -0.3)
|
||||
return score < threshold, float(score)
|
||||
except Exception as e:
|
||||
log.error("AI prediction error: %s", e)
|
||||
return False, 0.0
|
||||
|
||||
def retrain_from_log(self, db_path=None):
|
||||
"""Retrain the model from traffic_log.db data."""
|
||||
def retrain_from_log(self, db_path=None, background=False):
|
||||
"""Retrain models from traffic_log.db. Optionally in background thread."""
|
||||
if db_path is None:
|
||||
db_path = self.cfg.get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
|
||||
|
||||
if background:
|
||||
t = threading.Thread(target=self._do_retrain, args=(db_path,), daemon=True)
|
||||
t.start()
|
||||
return True # started, not yet finished
|
||||
|
||||
return self._do_retrain(db_path)
|
||||
|
||||
def _do_retrain(self, db_path):
|
||||
"""Actual retrain logic (can run in background thread)."""
|
||||
if not self._retrain_lock.acquire(blocking=False):
|
||||
log.info("Retrain already in progress, skipping")
|
||||
return False
|
||||
|
||||
try:
|
||||
return self._retrain_impl(db_path)
|
||||
finally:
|
||||
self._retrain_lock.release()
|
||||
|
||||
def _retrain_impl(self, db_path):
|
||||
"""Load data from DB, filter outliers, train per-period models."""
|
||||
if not os.path.exists(db_path):
|
||||
log.warning("Traffic log DB not found: %s", db_path)
|
||||
return False
|
||||
@@ -347,7 +434,7 @@ class AIDetector:
|
||||
try:
|
||||
conn = sqlite3.connect(f'file:{db_path}?mode=ro', uri=True)
|
||||
cur = conn.execute(
|
||||
'SELECT hour_sin, hour_cos, total_packets, total_bytes, '
|
||||
'SELECT hour, hour_sin, hour_cos, total_packets, total_bytes, '
|
||||
'tcp_syn_count, tcp_other_count, udp_count, icmp_count, '
|
||||
'other_proto_count, unique_ips_approx, small_pkt_count, '
|
||||
'large_pkt_count, syn_ratio, udp_ratio, icmp_ratio, '
|
||||
@@ -355,25 +442,56 @@ class AIDetector:
|
||||
'FROM traffic_samples WHERE timestamp >= ? ORDER BY timestamp',
|
||||
(cutoff,)
|
||||
)
|
||||
samples = [list(row) for row in cur.fetchall()]
|
||||
|
||||
if len(samples) < 10:
|
||||
log.warning("Not enough recent samples for retrain (%d)", len(samples))
|
||||
return False
|
||||
|
||||
log.info("Auto-retrain: loading %d samples from traffic log (window=%ds)",
|
||||
len(samples), retrain_window)
|
||||
self.training_data = samples
|
||||
self._train()
|
||||
return True
|
||||
|
||||
rows = cur.fetchall()
|
||||
except Exception as e:
|
||||
log.error("retrain_from_log failed: %s", e)
|
||||
log.error("retrain_from_log DB read failed: %s", e)
|
||||
return False
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
if len(rows) < 10:
|
||||
log.warning("Not enough recent samples for retrain (%d)", len(rows))
|
||||
return False
|
||||
|
||||
# Split by period: row[0]=hour, row[1:]=17 features
|
||||
period_data = defaultdict(list)
|
||||
for row in rows:
|
||||
hour = row[0]
|
||||
features = list(row[1:]) # 17 features
|
||||
period = self.get_period(hour)
|
||||
period_data[period].append(features)
|
||||
|
||||
# Filter outliers using existing models (item #2)
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
log.error("numpy not available for outlier filtering")
|
||||
return False
|
||||
|
||||
filtered_count = 0
|
||||
for period, samples in period_data.items():
|
||||
model_info = self.models.get(period)
|
||||
if model_info is None or len(samples) < 10:
|
||||
continue
|
||||
X = np.array(samples)
|
||||
X_scaled = model_info['scaler'].transform(X)
|
||||
scores = model_info['model'].decision_function(X_scaled)
|
||||
clean = [s for s, sc in zip(samples, scores) if sc >= -0.5]
|
||||
removed = len(samples) - len(clean)
|
||||
if removed > 0:
|
||||
filtered_count += removed
|
||||
period_data[period] = clean
|
||||
|
||||
total = sum(len(v) for v in period_data.values())
|
||||
log.info("Retrain: %d samples loaded, %d outliers filtered, %s",
|
||||
len(rows), filtered_count,
|
||||
{p: len(s) for p, s in period_data.items() if s})
|
||||
|
||||
self.training_data = period_data
|
||||
self._train()
|
||||
return not self.is_learning
|
||||
|
||||
|
||||
# ==================== ProfileManager ====================
|
||||
|
||||
@@ -549,14 +667,10 @@ class DDoSDaemon:
|
||||
self._stop_event.set()
|
||||
|
||||
def _handle_sigusr1(self, signum, frame):
    """SIGUSR1 handler: start a background retrain from the traffic log.

    signum and frame are the standard signal-handler arguments and are
    not used.
    """
    log.info("SIGUSR1 received, triggering retrain from traffic log (background)...")
    ai_cfg = self.cfg['ai']
    self.ai_detector.retrain_from_log(
        ai_cfg.get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db'),
        background=True,
    )
    # Restart the auto-retrain interval from this manual trigger.
    self._last_retrain_time = time.time()
|
||||
|
||||
# ---- Traffic Logging (SQLite) ----
|
||||
|
||||
@@ -762,24 +876,20 @@ class DDoSDaemon:
|
||||
self._last_log_cleanup = time.time()
|
||||
|
||||
if self.ai_detector.is_learning:
|
||||
self.ai_detector.collect_sample(deltas_with_time)
|
||||
if len(self.ai_detector.training_data) % 100 == 0:
|
||||
log.debug("AI learning: %d samples collected",
|
||||
len(self.ai_detector.training_data))
|
||||
self.ai_detector.collect_sample(deltas_with_time, hour)
|
||||
total_samples = sum(len(v) for v in self.ai_detector.training_data.values())
|
||||
if total_samples % 100 == 0 and total_samples > 0:
|
||||
log.debug("AI learning: %d samples collected", total_samples)
|
||||
else:
|
||||
# Auto-retrain check
|
||||
# Auto-retrain check (background, no inference gap)
|
||||
retrain_interval = self.cfg['ai'].get('retrain_interval', 86400)
|
||||
if time.time() - self._last_retrain_time >= retrain_interval:
|
||||
log.info("Auto-retrain triggered (interval=%ds)", retrain_interval)
|
||||
log.info("Auto-retrain triggered (interval=%ds, background)", retrain_interval)
|
||||
db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
|
||||
if self.ai_detector.retrain_from_log(db_path):
|
||||
self._last_retrain_time = time.time()
|
||||
log.info("Auto-retrain completed successfully")
|
||||
else:
|
||||
log.warning("Auto-retrain failed, will retry next interval")
|
||||
self.ai_detector.retrain_from_log(db_path, background=True)
|
||||
self._last_retrain_time = time.time()
|
||||
|
||||
is_anomaly, score = self.ai_detector.predict(deltas_with_time)
|
||||
is_anomaly, score = self.ai_detector.predict(deltas_with_time, hour)
|
||||
if is_anomaly:
|
||||
log.warning(
|
||||
"AI ANOMALY detected: score=%.4f deltas=%s",
|
||||
|
||||
Reference in New Issue
Block a user