diff --git a/lib/xdp_defense_daemon.py b/lib/xdp_defense_daemon.py
index e12d6dc..cec6024 100755
--- a/lib/xdp_defense_daemon.py
+++ b/lib/xdp_defense_daemon.py
@@ -201,41 +201,63 @@ class EWMAAnalyzer:
 
 # ==================== AIDetector ====================
 class AIDetector:
-    """Isolation Forest based anomaly detection on traffic features."""
+    """Time-period Isolation Forest anomaly detection.
+
+    Maintains 4 separate models for different time periods:
+    night (00-06), morning (06-12), afternoon (12-18), evening (18-24).
+    Each model learns the normal traffic pattern for its time window.
+    """
+
+    TIME_PERIODS = {
+        'night': (0, 6),
+        'morning': (6, 12),
+        'afternoon': (12, 18),
+        'evening': (18, 24),
+    }
 
     def __init__(self, ai_cfg):
         self.cfg = ai_cfg
-        self.model = None
-        self.scaler = None
+        self.models = {}  # {period_name: {'model': ..., 'scaler': ...}}
         self.started_at = time.time()
-        self.training_data = []
+        self.training_data = defaultdict(list)  # {period_name: [samples]}
         self.is_learning = True
         self._retrain_requested = False
+        self._retrain_lock = threading.Lock()
 
     @property
     def enabled(self):
         return self.cfg.get('enabled', False)
 
+    @staticmethod
+    def get_period(hour):
+        """Map hour (0-24) to period name."""
+        for name, (start, end) in AIDetector.TIME_PERIODS.items():
+            if start <= hour < end:
+                return name
+        return 'night'  # hour == 24 edge case
+
     def request_retrain(self):
         self._retrain_requested = True
 
-    def collect_sample(self, features):
-        """Collect a feature sample during learning phase."""
+    def collect_sample(self, features, hour):
+        """Collect a feature sample during learning phase, bucketed by period."""
         if not self.enabled:
             return
-        self.training_data.append(features)
+        period = self.get_period(hour)
+        self.training_data[period].append(features)
 
         learning_dur = self.cfg.get('learning_duration', 259200)
         min_samples = self.cfg.get('min_samples', 1000)
         elapsed = time.time() - self.started_at
+        total = sum(len(v) for v in self.training_data.values())
 
-        if (elapsed >= learning_dur and len(self.training_data) >= min_samples) or self._retrain_requested:
+        if (elapsed >= learning_dur and total >= min_samples) or self._retrain_requested:
             self._train()
             self._retrain_requested = False
 
     def _train(self):
-        """Train the Isolation Forest model."""
+        """Train per-period Isolation Forest models."""
         try:
             from sklearn.ensemble import IsolationForest
             from sklearn.preprocessing import StandardScaler
@@ -245,32 +267,63 @@ class AIDetector:
             self.cfg['enabled'] = False
             return
 
-        if len(self.training_data) < 10:
-            log.warning("Not enough training data (%d samples)", len(self.training_data))
+        total = sum(len(v) for v in self.training_data.values())
+        if total < 10:
+            log.warning("Not enough training data (%d samples)", total)
             return
 
-        log.info("Training AI model with %d samples...", len(self.training_data))
+        log.info("Training AI models: %s",
+                 {p: len(s) for p, s in self.training_data.items() if s})
 
         try:
-            X = np.array(self.training_data)
+            new_models = {}
+            all_samples = []
 
-            self.scaler = StandardScaler()
-            X_scaled = self.scaler.fit_transform(X)
+            for period, samples in self.training_data.items():
+                if len(samples) < 10:
+                    log.info("Period %s: %d samples (too few, skip)", period, len(samples))
+                    continue
 
-            self.model = IsolationForest(
-                n_estimators=self.cfg.get('n_estimators', 100),
-                contamination=self.cfg.get('contamination', 'auto'),
-                random_state=42,
-            )
-            self.model.fit(X_scaled)
+                X = np.array(samples)
+                scaler = StandardScaler()
+                X_scaled = scaler.fit_transform(X)
+
+                model = IsolationForest(
+                    n_estimators=self.cfg.get('n_estimators', 100),
+                    contamination=self.cfg.get('contamination', 'auto'),
+                    random_state=42,
+                )
+                model.fit(X_scaled)
+                new_models[period] = {'model': model, 'scaler': scaler}
+                all_samples.extend(samples)
+                log.info("Period %s: trained with %d samples", period, len(samples))
+
+            if not new_models:
+                log.warning("No period had enough data to train")
+                return
+
+            # Atomic swap
+            self.models = new_models
             self.is_learning = False
 
+            # Save to disk (atomic)
             model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
-            with open(model_file, 'wb') as f:
-                pickle.dump({'model': self.model, 'scaler': self.scaler}, f)
+            tmp_model = model_file + '.tmp'
+            with open(tmp_model, 'wb') as f:
+                pickle.dump({
+                    'format': 'period_models',
+                    'models': {p: {'model': m['model'], 'scaler': m['scaler']}
+                               for p, m in new_models.items()},
+                    'feature_count': 17,
+                }, f)
+                f.flush()
+                os.fsync(f.fileno())
+            os.rename(tmp_model, model_file)
 
+            # Save training data CSV
             data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
-            with open(data_file, 'w', newline='') as f:
+            tmp_data = data_file + '.tmp'
+            with open(tmp_data, 'w', newline='') as f:
                 writer = csv.writer(f)
                 writer.writerow([
                     'hour_sin', 'hour_cos',
@@ -279,63 +332,97 @@ class AIDetector:
                     'small_pkt_count', 'large_pkt_count', 'syn_ratio',
                     'udp_ratio', 'icmp_ratio', 'small_pkt_ratio', 'avg_pkt_size'
                 ])
-                writer.writerows(self.training_data)
+                writer.writerows(all_samples)
+                f.flush()
+                os.fsync(f.fileno())
+            os.rename(tmp_data, data_file)
 
-            log.info("AI model trained and saved to %s", model_file)
+            log.info("AI models trained and saved (%d periods)", len(new_models))
 
         except Exception as e:
             log.error("AI training failed: %s", e)
 
     def load_model(self):
-        """Load a previously trained model. Check feature dimension compatibility."""
+        """Load previously trained models. Handle old and new formats."""
         model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
         if not os.path.exists(model_file):
             return False
 
         try:
             with open(model_file, 'rb') as f:
                 data = pickle.load(f)
 
-            model = data['model']
-            scaler = data['scaler']
-            expected_features = 17
-            if hasattr(scaler, 'n_features_in_') and scaler.n_features_in_ != expected_features:
-                log.warning(
-                    "Model has %d features, expected %d. Switching to learning mode.",
-                    scaler.n_features_in_, expected_features
-                )
-                self.is_learning = True
-                return False
+            # New period_models format
+            if data.get('format') == 'period_models':
+                self.models = data['models']
+                self.is_learning = False
+                periods = list(self.models.keys())
+                log.info("AI models loaded: %s (%d features)",
+                         periods, data.get('feature_count', '?'))
+                return True
+
+            # Old single-model format → discard
+            log.warning("Old single-model format detected. Switching to learning mode.")
+            self.is_learning = True
+            return False
 
-            self.model = model
-            self.scaler = scaler
-            self.is_learning = False
-            log.info("AI model loaded from %s (%d features)",
-                     model_file, getattr(scaler, 'n_features_in_', '?'))
-            return True
         except Exception as e:
             log.error("Failed to load AI model: %s", e)
             return False
 
-    def predict(self, features):
-        """Run anomaly detection. Returns (is_anomaly, score)."""
-        if not self.enabled or self.model is None:
+    def predict(self, features, hour):
+        """Run anomaly detection using the period-specific model.
+        Returns (is_anomaly, score).
+        """
+        if not self.enabled or not self.models:
+            return False, 0.0
+
+        period = self.get_period(hour)
+        model_info = self.models.get(period)
+
+        # Fallback: no model for this period — use the first trained period's model
+        if model_info is None:
+            for p in self.models:
+                model_info = self.models[p]
+                break
+        if model_info is None:
             return False, 0.0
 
         try:
             import numpy as np
             X = np.array([features])
-            X_scaled = self.scaler.transform(X)
-            score = self.model.decision_function(X_scaled)[0]
+            X_scaled = model_info['scaler'].transform(X)
+            score = model_info['model'].decision_function(X_scaled)[0]
             threshold = self.cfg.get('anomaly_threshold', -0.3)
             return score < threshold, float(score)
         except Exception as e:
             log.error("AI prediction error: %s", e)
             return False, 0.0
 
-    def retrain_from_log(self, db_path=None):
-        """Retrain the model from traffic_log.db data."""
+    def retrain_from_log(self, db_path=None, background=False):
+        """Retrain models from traffic_log.db. Optionally in background thread."""
         if db_path is None:
             db_path = self.cfg.get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
+
+        if background:
+            t = threading.Thread(target=self._do_retrain, args=(db_path,), daemon=True)
+            t.start()
+            return True  # started, not yet finished
+
+        return self._do_retrain(db_path)
+
+    def _do_retrain(self, db_path):
+        """Actual retrain logic (can run in background thread)."""
+        if not self._retrain_lock.acquire(blocking=False):
+            log.info("Retrain already in progress, skipping")
+            return False
+
+        try:
+            return self._retrain_impl(db_path)
+        finally:
+            self._retrain_lock.release()
+
+    def _retrain_impl(self, db_path):
+        """Load data from DB, filter outliers, train per-period models."""
         if not os.path.exists(db_path):
             log.warning("Traffic log DB not found: %s", db_path)
             return False
@@ -347,7 +434,7 @@ class AIDetector:
         try:
             conn = sqlite3.connect(f'file:{db_path}?mode=ro', uri=True)
             cur = conn.execute(
-                'SELECT hour_sin, hour_cos, total_packets, total_bytes, '
+                'SELECT hour, hour_sin, hour_cos, total_packets, total_bytes, '
                 'tcp_syn_count, tcp_other_count, udp_count, icmp_count, '
                 'other_proto_count, unique_ips_approx, small_pkt_count, '
                 'large_pkt_count, syn_ratio, udp_ratio, icmp_ratio, '
@@ -355,25 +442,56 @@ class AIDetector:
                 'FROM traffic_samples WHERE timestamp >= ? ORDER BY timestamp',
                 (cutoff,)
             )
-            samples = [list(row) for row in cur.fetchall()]
-
-            if len(samples) < 10:
-                log.warning("Not enough recent samples for retrain (%d)", len(samples))
-                return False
-
-            log.info("Auto-retrain: loading %d samples from traffic log (window=%ds)",
-                     len(samples), retrain_window)
-            self.training_data = samples
-            self._train()
-            return True
-
+            rows = cur.fetchall()
         except Exception as e:
-            log.error("retrain_from_log failed: %s", e)
+            log.error("retrain_from_log DB read failed: %s", e)
             return False
         finally:
             if conn:
                 conn.close()
 
+        if len(rows) < 10:
+            log.warning("Not enough recent samples for retrain (%d)", len(rows))
+            return False
+
+        # Split by period: row[0]=hour, row[1:]=17 features
+        period_data = defaultdict(list)
+        for row in rows:
+            hour = row[0]
+            features = list(row[1:])  # 17 features
+            period = self.get_period(hour)
+            period_data[period].append(features)
+
+        # Filter outliers using existing models (item #2)
+        try:
+            import numpy as np
+        except ImportError:
+            log.error("numpy not available for outlier filtering")
+            return False
+
+        filtered_count = 0
+        for period, samples in period_data.items():
+            model_info = self.models.get(period)
+            if model_info is None or len(samples) < 10:
+                continue
+            X = np.array(samples)
+            X_scaled = model_info['scaler'].transform(X)
+            scores = model_info['model'].decision_function(X_scaled)
+            clean = [s for s, sc in zip(samples, scores) if sc >= -0.5]
+            removed = len(samples) - len(clean)
+            if removed > 0:
+                filtered_count += removed
+            period_data[period] = clean
+
+        total = sum(len(v) for v in period_data.values())
+        log.info("Retrain: %d samples loaded, %d outliers filtered, %s",
+                 len(rows), filtered_count,
+                 {p: len(s) for p, s in period_data.items() if s})
+
+        self.training_data = period_data
+        self._train()
+        return not self.is_learning
+
 
 # ==================== ProfileManager ====================
@@ -549,14 +667,10 @@ class DDoSDaemon:
         self._stop_event.set()
 
     def _handle_sigusr1(self, signum, frame):
-        log.info("SIGUSR1 received, triggering retrain from traffic log...")
+        log.info("SIGUSR1 received, triggering retrain from traffic log (background)...")
         db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
-        if self.ai_detector.retrain_from_log(db_path):
-            self._last_retrain_time = time.time()
-            log.info("SIGUSR1 retrain completed successfully")
-        else:
-            log.warning("SIGUSR1 retrain failed (falling back to collect mode)")
-            self.ai_detector.request_retrain()
+        self.ai_detector.retrain_from_log(db_path, background=True)
+        self._last_retrain_time = time.time()
 
     # ---- Traffic Logging (SQLite) ----
 
@@ -762,24 +876,20 @@ class DDoSDaemon:
                     self._last_log_cleanup = time.time()
 
                 if self.ai_detector.is_learning:
-                    self.ai_detector.collect_sample(deltas_with_time)
-                    if len(self.ai_detector.training_data) % 100 == 0:
-                        log.debug("AI learning: %d samples collected",
-                                  len(self.ai_detector.training_data))
+                    self.ai_detector.collect_sample(deltas_with_time, hour)
+                    total_samples = sum(len(v) for v in self.ai_detector.training_data.values())
+                    if total_samples % 100 == 0 and total_samples > 0:
+                        log.debug("AI learning: %d samples collected", total_samples)
                 else:
-                    # Auto-retrain check
+                    # Auto-retrain check (background, no inference gap)
                     retrain_interval = self.cfg['ai'].get('retrain_interval', 86400)
                     if time.time() - self._last_retrain_time >= retrain_interval:
-                        log.info("Auto-retrain triggered (interval=%ds)", retrain_interval)
+                        log.info("Auto-retrain triggered (interval=%ds, background)", retrain_interval)
                         db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
-                        if self.ai_detector.retrain_from_log(db_path):
-                            self._last_retrain_time = time.time()
-                            log.info("Auto-retrain completed successfully")
-                        else:
-                            log.warning("Auto-retrain failed, will retry next interval")
-                            self._last_retrain_time = time.time()
+                        self.ai_detector.retrain_from_log(db_path, background=True)
+                        self._last_retrain_time = time.time()
 
-                is_anomaly, score = self.ai_detector.predict(deltas_with_time)
+                is_anomaly, score = self.ai_detector.predict(deltas_with_time, hour)
                 if is_anomaly:
                     log.warning(
                         "AI ANOMALY detected: score=%.4f deltas=%s",