Add time-period models, outlier filtering, and background retrain
- Split single IsolationForest into 4 period models (night/morning/afternoon/evening) - Each period trained independently on its time window data - Filter attack samples during retrain using existing model scores (threshold -0.5) - Retrain runs in background thread with lock, inference continues uninterrupted - New pickle format 'period_models' with automatic old format detection - SIGUSR1 and auto-retrain both use background mode Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -201,41 +201,63 @@ class EWMAAnalyzer:
|
|||||||
# ==================== AIDetector ====================
|
# ==================== AIDetector ====================
|
||||||
|
|
||||||
class AIDetector:
|
class AIDetector:
|
||||||
"""Isolation Forest based anomaly detection on traffic features."""
|
"""Time-period Isolation Forest anomaly detection.
|
||||||
|
|
||||||
|
Maintains 4 separate models for different time periods:
|
||||||
|
night (00-06), morning (06-12), afternoon (12-18), evening (18-24).
|
||||||
|
Each model learns the normal traffic pattern for its time window.
|
||||||
|
"""
|
||||||
|
|
||||||
|
TIME_PERIODS = {
|
||||||
|
'night': (0, 6),
|
||||||
|
'morning': (6, 12),
|
||||||
|
'afternoon': (12, 18),
|
||||||
|
'evening': (18, 24),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, ai_cfg):
|
def __init__(self, ai_cfg):
|
||||||
self.cfg = ai_cfg
|
self.cfg = ai_cfg
|
||||||
self.model = None
|
self.models = {} # {period_name: {'model': ..., 'scaler': ...}}
|
||||||
self.scaler = None
|
|
||||||
self.started_at = time.time()
|
self.started_at = time.time()
|
||||||
self.training_data = []
|
self.training_data = defaultdict(list) # {period_name: [samples]}
|
||||||
self.is_learning = True
|
self.is_learning = True
|
||||||
self._retrain_requested = False
|
self._retrain_requested = False
|
||||||
|
self._retrain_lock = threading.Lock()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def enabled(self):
|
def enabled(self):
|
||||||
return self.cfg.get('enabled', False)
|
return self.cfg.get('enabled', False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_period(hour):
|
||||||
|
"""Map hour (0-24) to period name."""
|
||||||
|
for name, (start, end) in AIDetector.TIME_PERIODS.items():
|
||||||
|
if start <= hour < end:
|
||||||
|
return name
|
||||||
|
return 'night' # hour == 24 edge case
|
||||||
|
|
||||||
def request_retrain(self):
|
def request_retrain(self):
|
||||||
self._retrain_requested = True
|
self._retrain_requested = True
|
||||||
|
|
||||||
def collect_sample(self, features):
|
def collect_sample(self, features, hour):
|
||||||
"""Collect a feature sample during learning phase."""
|
"""Collect a feature sample during learning phase, bucketed by period."""
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.training_data.append(features)
|
period = self.get_period(hour)
|
||||||
|
self.training_data[period].append(features)
|
||||||
|
|
||||||
learning_dur = self.cfg.get('learning_duration', 259200)
|
learning_dur = self.cfg.get('learning_duration', 259200)
|
||||||
min_samples = self.cfg.get('min_samples', 1000)
|
min_samples = self.cfg.get('min_samples', 1000)
|
||||||
elapsed = time.time() - self.started_at
|
elapsed = time.time() - self.started_at
|
||||||
|
total = sum(len(v) for v in self.training_data.values())
|
||||||
|
|
||||||
if (elapsed >= learning_dur and len(self.training_data) >= min_samples) or self._retrain_requested:
|
if (elapsed >= learning_dur and total >= min_samples) or self._retrain_requested:
|
||||||
self._train()
|
self._train()
|
||||||
self._retrain_requested = False
|
self._retrain_requested = False
|
||||||
|
|
||||||
def _train(self):
|
def _train(self):
|
||||||
"""Train the Isolation Forest model."""
|
"""Train per-period Isolation Forest models."""
|
||||||
try:
|
try:
|
||||||
from sklearn.ensemble import IsolationForest
|
from sklearn.ensemble import IsolationForest
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
@@ -245,32 +267,63 @@ class AIDetector:
|
|||||||
self.cfg['enabled'] = False
|
self.cfg['enabled'] = False
|
||||||
return
|
return
|
||||||
|
|
||||||
if len(self.training_data) < 10:
|
total = sum(len(v) for v in self.training_data.values())
|
||||||
log.warning("Not enough training data (%d samples)", len(self.training_data))
|
if total < 10:
|
||||||
|
log.warning("Not enough training data (%d samples)", total)
|
||||||
return
|
return
|
||||||
|
|
||||||
log.info("Training AI model with %d samples...", len(self.training_data))
|
log.info("Training AI models: %s",
|
||||||
|
{p: len(s) for p, s in self.training_data.items() if s})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
X = np.array(self.training_data)
|
new_models = {}
|
||||||
|
all_samples = []
|
||||||
|
|
||||||
self.scaler = StandardScaler()
|
for period, samples in self.training_data.items():
|
||||||
X_scaled = self.scaler.fit_transform(X)
|
if len(samples) < 10:
|
||||||
|
log.info("Period %s: %d samples (too few, skip)", period, len(samples))
|
||||||
|
continue
|
||||||
|
|
||||||
self.model = IsolationForest(
|
X = np.array(samples)
|
||||||
|
scaler = StandardScaler()
|
||||||
|
X_scaled = scaler.fit_transform(X)
|
||||||
|
|
||||||
|
model = IsolationForest(
|
||||||
n_estimators=self.cfg.get('n_estimators', 100),
|
n_estimators=self.cfg.get('n_estimators', 100),
|
||||||
contamination=self.cfg.get('contamination', 'auto'),
|
contamination=self.cfg.get('contamination', 'auto'),
|
||||||
random_state=42,
|
random_state=42,
|
||||||
)
|
)
|
||||||
self.model.fit(X_scaled)
|
model.fit(X_scaled)
|
||||||
|
new_models[period] = {'model': model, 'scaler': scaler}
|
||||||
|
all_samples.extend(samples)
|
||||||
|
log.info("Period %s: trained with %d samples", period, len(samples))
|
||||||
|
|
||||||
|
if not new_models:
|
||||||
|
log.warning("No period had enough data to train")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Atomic swap
|
||||||
|
self.models = new_models
|
||||||
self.is_learning = False
|
self.is_learning = False
|
||||||
|
|
||||||
|
# Save to disk (atomic)
|
||||||
model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
|
model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
|
||||||
with open(model_file, 'wb') as f:
|
tmp_model = model_file + '.tmp'
|
||||||
pickle.dump({'model': self.model, 'scaler': self.scaler}, f)
|
with open(tmp_model, 'wb') as f:
|
||||||
|
pickle.dump({
|
||||||
|
'format': 'period_models',
|
||||||
|
'models': {p: {'model': m['model'], 'scaler': m['scaler']}
|
||||||
|
for p, m in new_models.items()},
|
||||||
|
'feature_count': 17,
|
||||||
|
}, f)
|
||||||
|
f.flush()
|
||||||
|
os.fsync(f.fileno())
|
||||||
|
os.rename(tmp_model, model_file)
|
||||||
|
|
||||||
|
# Save training data CSV
|
||||||
data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
|
data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
|
||||||
with open(data_file, 'w', newline='') as f:
|
tmp_data = data_file + '.tmp'
|
||||||
|
with open(tmp_data, 'w', newline='') as f:
|
||||||
writer = csv.writer(f)
|
writer = csv.writer(f)
|
||||||
writer.writerow([
|
writer.writerow([
|
||||||
'hour_sin', 'hour_cos',
|
'hour_sin', 'hour_cos',
|
||||||
@@ -279,63 +332,97 @@ class AIDetector:
|
|||||||
'small_pkt_count', 'large_pkt_count',
|
'small_pkt_count', 'large_pkt_count',
|
||||||
'syn_ratio', 'udp_ratio', 'icmp_ratio', 'small_pkt_ratio', 'avg_pkt_size'
|
'syn_ratio', 'udp_ratio', 'icmp_ratio', 'small_pkt_ratio', 'avg_pkt_size'
|
||||||
])
|
])
|
||||||
writer.writerows(self.training_data)
|
writer.writerows(all_samples)
|
||||||
|
f.flush()
|
||||||
|
os.fsync(f.fileno())
|
||||||
|
os.rename(tmp_data, data_file)
|
||||||
|
|
||||||
log.info("AI model trained and saved to %s", model_file)
|
log.info("AI models trained and saved (%d periods)", len(new_models))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("AI training failed: %s", e)
|
log.error("AI training failed: %s", e)
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
"""Load a previously trained model. Check feature dimension compatibility."""
|
"""Load previously trained models. Handle old and new formats."""
|
||||||
model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
|
model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
|
||||||
if not os.path.exists(model_file):
|
if not os.path.exists(model_file):
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
with open(model_file, 'rb') as f:
|
with open(model_file, 'rb') as f:
|
||||||
data = pickle.load(f)
|
data = pickle.load(f)
|
||||||
model = data['model']
|
|
||||||
scaler = data['scaler']
|
|
||||||
|
|
||||||
expected_features = 17
|
# New period_models format
|
||||||
if hasattr(scaler, 'n_features_in_') and scaler.n_features_in_ != expected_features:
|
if data.get('format') == 'period_models':
|
||||||
log.warning(
|
self.models = data['models']
|
||||||
"Model has %d features, expected %d. Switching to learning mode.",
|
self.is_learning = False
|
||||||
scaler.n_features_in_, expected_features
|
periods = list(self.models.keys())
|
||||||
)
|
log.info("AI models loaded: %s (%d features)",
|
||||||
|
periods, data.get('feature_count', '?'))
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Old single-model format → discard
|
||||||
|
log.warning("Old single-model format detected. Switching to learning mode.")
|
||||||
self.is_learning = True
|
self.is_learning = True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self.model = model
|
|
||||||
self.scaler = scaler
|
|
||||||
self.is_learning = False
|
|
||||||
log.info("AI model loaded from %s (%d features)",
|
|
||||||
model_file, getattr(scaler, 'n_features_in_', '?'))
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("Failed to load AI model: %s", e)
|
log.error("Failed to load AI model: %s", e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def predict(self, features):
|
def predict(self, features, hour):
|
||||||
"""Run anomaly detection. Returns (is_anomaly, score)."""
|
"""Run anomaly detection using the period-specific model.
|
||||||
if not self.enabled or self.model is None:
|
Returns (is_anomaly, score).
|
||||||
|
"""
|
||||||
|
if not self.enabled or not self.models:
|
||||||
|
return False, 0.0
|
||||||
|
|
||||||
|
period = self.get_period(hour)
|
||||||
|
model_info = self.models.get(period)
|
||||||
|
|
||||||
|
# Fallback: try adjacent periods
|
||||||
|
if model_info is None:
|
||||||
|
for p in self.models:
|
||||||
|
model_info = self.models[p]
|
||||||
|
break
|
||||||
|
if model_info is None:
|
||||||
return False, 0.0
|
return False, 0.0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
X = np.array([features])
|
X = np.array([features])
|
||||||
X_scaled = self.scaler.transform(X)
|
X_scaled = model_info['scaler'].transform(X)
|
||||||
score = self.model.decision_function(X_scaled)[0]
|
score = model_info['model'].decision_function(X_scaled)[0]
|
||||||
threshold = self.cfg.get('anomaly_threshold', -0.3)
|
threshold = self.cfg.get('anomaly_threshold', -0.3)
|
||||||
return score < threshold, float(score)
|
return score < threshold, float(score)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("AI prediction error: %s", e)
|
log.error("AI prediction error: %s", e)
|
||||||
return False, 0.0
|
return False, 0.0
|
||||||
|
|
||||||
def retrain_from_log(self, db_path=None):
|
def retrain_from_log(self, db_path=None, background=False):
|
||||||
"""Retrain the model from traffic_log.db data."""
|
"""Retrain models from traffic_log.db. Optionally in background thread."""
|
||||||
if db_path is None:
|
if db_path is None:
|
||||||
db_path = self.cfg.get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
|
db_path = self.cfg.get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
|
||||||
|
|
||||||
|
if background:
|
||||||
|
t = threading.Thread(target=self._do_retrain, args=(db_path,), daemon=True)
|
||||||
|
t.start()
|
||||||
|
return True # started, not yet finished
|
||||||
|
|
||||||
|
return self._do_retrain(db_path)
|
||||||
|
|
||||||
|
def _do_retrain(self, db_path):
|
||||||
|
"""Actual retrain logic (can run in background thread)."""
|
||||||
|
if not self._retrain_lock.acquire(blocking=False):
|
||||||
|
log.info("Retrain already in progress, skipping")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self._retrain_impl(db_path)
|
||||||
|
finally:
|
||||||
|
self._retrain_lock.release()
|
||||||
|
|
||||||
|
def _retrain_impl(self, db_path):
|
||||||
|
"""Load data from DB, filter outliers, train per-period models."""
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
log.warning("Traffic log DB not found: %s", db_path)
|
log.warning("Traffic log DB not found: %s", db_path)
|
||||||
return False
|
return False
|
||||||
@@ -347,7 +434,7 @@ class AIDetector:
|
|||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(f'file:{db_path}?mode=ro', uri=True)
|
conn = sqlite3.connect(f'file:{db_path}?mode=ro', uri=True)
|
||||||
cur = conn.execute(
|
cur = conn.execute(
|
||||||
'SELECT hour_sin, hour_cos, total_packets, total_bytes, '
|
'SELECT hour, hour_sin, hour_cos, total_packets, total_bytes, '
|
||||||
'tcp_syn_count, tcp_other_count, udp_count, icmp_count, '
|
'tcp_syn_count, tcp_other_count, udp_count, icmp_count, '
|
||||||
'other_proto_count, unique_ips_approx, small_pkt_count, '
|
'other_proto_count, unique_ips_approx, small_pkt_count, '
|
||||||
'large_pkt_count, syn_ratio, udp_ratio, icmp_ratio, '
|
'large_pkt_count, syn_ratio, udp_ratio, icmp_ratio, '
|
||||||
@@ -355,25 +442,56 @@ class AIDetector:
|
|||||||
'FROM traffic_samples WHERE timestamp >= ? ORDER BY timestamp',
|
'FROM traffic_samples WHERE timestamp >= ? ORDER BY timestamp',
|
||||||
(cutoff,)
|
(cutoff,)
|
||||||
)
|
)
|
||||||
samples = [list(row) for row in cur.fetchall()]
|
rows = cur.fetchall()
|
||||||
|
|
||||||
if len(samples) < 10:
|
|
||||||
log.warning("Not enough recent samples for retrain (%d)", len(samples))
|
|
||||||
return False
|
|
||||||
|
|
||||||
log.info("Auto-retrain: loading %d samples from traffic log (window=%ds)",
|
|
||||||
len(samples), retrain_window)
|
|
||||||
self.training_data = samples
|
|
||||||
self._train()
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("retrain_from_log failed: %s", e)
|
log.error("retrain_from_log DB read failed: %s", e)
|
||||||
return False
|
return False
|
||||||
finally:
|
finally:
|
||||||
if conn:
|
if conn:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
if len(rows) < 10:
|
||||||
|
log.warning("Not enough recent samples for retrain (%d)", len(rows))
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Split by period: row[0]=hour, row[1:]=17 features
|
||||||
|
period_data = defaultdict(list)
|
||||||
|
for row in rows:
|
||||||
|
hour = row[0]
|
||||||
|
features = list(row[1:]) # 17 features
|
||||||
|
period = self.get_period(hour)
|
||||||
|
period_data[period].append(features)
|
||||||
|
|
||||||
|
# Filter outliers using existing models (item #2)
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
except ImportError:
|
||||||
|
log.error("numpy not available for outlier filtering")
|
||||||
|
return False
|
||||||
|
|
||||||
|
filtered_count = 0
|
||||||
|
for period, samples in period_data.items():
|
||||||
|
model_info = self.models.get(period)
|
||||||
|
if model_info is None or len(samples) < 10:
|
||||||
|
continue
|
||||||
|
X = np.array(samples)
|
||||||
|
X_scaled = model_info['scaler'].transform(X)
|
||||||
|
scores = model_info['model'].decision_function(X_scaled)
|
||||||
|
clean = [s for s, sc in zip(samples, scores) if sc >= -0.5]
|
||||||
|
removed = len(samples) - len(clean)
|
||||||
|
if removed > 0:
|
||||||
|
filtered_count += removed
|
||||||
|
period_data[period] = clean
|
||||||
|
|
||||||
|
total = sum(len(v) for v in period_data.values())
|
||||||
|
log.info("Retrain: %d samples loaded, %d outliers filtered, %s",
|
||||||
|
len(rows), filtered_count,
|
||||||
|
{p: len(s) for p, s in period_data.items() if s})
|
||||||
|
|
||||||
|
self.training_data = period_data
|
||||||
|
self._train()
|
||||||
|
return not self.is_learning
|
||||||
|
|
||||||
|
|
||||||
# ==================== ProfileManager ====================
|
# ==================== ProfileManager ====================
|
||||||
|
|
||||||
@@ -549,14 +667,10 @@ class DDoSDaemon:
|
|||||||
self._stop_event.set()
|
self._stop_event.set()
|
||||||
|
|
||||||
def _handle_sigusr1(self, signum, frame):
|
def _handle_sigusr1(self, signum, frame):
|
||||||
log.info("SIGUSR1 received, triggering retrain from traffic log...")
|
log.info("SIGUSR1 received, triggering retrain from traffic log (background)...")
|
||||||
db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
|
db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
|
||||||
if self.ai_detector.retrain_from_log(db_path):
|
self.ai_detector.retrain_from_log(db_path, background=True)
|
||||||
self._last_retrain_time = time.time()
|
self._last_retrain_time = time.time()
|
||||||
log.info("SIGUSR1 retrain completed successfully")
|
|
||||||
else:
|
|
||||||
log.warning("SIGUSR1 retrain failed (falling back to collect mode)")
|
|
||||||
self.ai_detector.request_retrain()
|
|
||||||
|
|
||||||
# ---- Traffic Logging (SQLite) ----
|
# ---- Traffic Logging (SQLite) ----
|
||||||
|
|
||||||
@@ -762,24 +876,20 @@ class DDoSDaemon:
|
|||||||
self._last_log_cleanup = time.time()
|
self._last_log_cleanup = time.time()
|
||||||
|
|
||||||
if self.ai_detector.is_learning:
|
if self.ai_detector.is_learning:
|
||||||
self.ai_detector.collect_sample(deltas_with_time)
|
self.ai_detector.collect_sample(deltas_with_time, hour)
|
||||||
if len(self.ai_detector.training_data) % 100 == 0:
|
total_samples = sum(len(v) for v in self.ai_detector.training_data.values())
|
||||||
log.debug("AI learning: %d samples collected",
|
if total_samples % 100 == 0 and total_samples > 0:
|
||||||
len(self.ai_detector.training_data))
|
log.debug("AI learning: %d samples collected", total_samples)
|
||||||
else:
|
else:
|
||||||
# Auto-retrain check
|
# Auto-retrain check (background, no inference gap)
|
||||||
retrain_interval = self.cfg['ai'].get('retrain_interval', 86400)
|
retrain_interval = self.cfg['ai'].get('retrain_interval', 86400)
|
||||||
if time.time() - self._last_retrain_time >= retrain_interval:
|
if time.time() - self._last_retrain_time >= retrain_interval:
|
||||||
log.info("Auto-retrain triggered (interval=%ds)", retrain_interval)
|
log.info("Auto-retrain triggered (interval=%ds, background)", retrain_interval)
|
||||||
db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
|
db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
|
||||||
if self.ai_detector.retrain_from_log(db_path):
|
self.ai_detector.retrain_from_log(db_path, background=True)
|
||||||
self._last_retrain_time = time.time()
|
|
||||||
log.info("Auto-retrain completed successfully")
|
|
||||||
else:
|
|
||||||
log.warning("Auto-retrain failed, will retry next interval")
|
|
||||||
self._last_retrain_time = time.time()
|
self._last_retrain_time = time.time()
|
||||||
|
|
||||||
is_anomaly, score = self.ai_detector.predict(deltas_with_time)
|
is_anomaly, score = self.ai_detector.predict(deltas_with_time, hour)
|
||||||
if is_anomaly:
|
if is_anomaly:
|
||||||
log.warning(
|
log.warning(
|
||||||
"AI ANOMALY detected: score=%.4f deltas=%s",
|
"AI ANOMALY detected: score=%.4f deltas=%s",
|
||||||
|
|||||||
Reference in New Issue
Block a user