Add time-period models, outlier filtering, and background retrain

- Split single IsolationForest into 4 period models (night/morning/afternoon/evening)
- Each period trained independently on its time window data
- Filter attack samples during retrain using existing model scores (threshold -0.5)
- Retrain runs in background thread with lock, inference continues uninterrupted
- New pickle format 'period_models' with automatic old format detection
- SIGUSR1 and auto-retrain both use background mode

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
kaffa
2026-02-07 10:43:38 +09:00
parent 3d1e353b1a
commit a28d7fd646

View File

@@ -201,41 +201,63 @@ class EWMAAnalyzer:
# ==================== AIDetector ====================
class AIDetector:
    """Time-period Isolation Forest anomaly detection.

    Maintains 4 separate models for different time periods:
    night (00-06), morning (06-12), afternoon (12-18), evening (18-24).
    Each model learns the normal traffic pattern for its time window.
    """

    # Half-open [start, end) hour ranges for each model bucket;
    # the hour == 24 edge case is resolved by get_period().
    TIME_PERIODS = {
        'night': (0, 6),
        'morning': (6, 12),
        'afternoon': (12, 18),
        'evening': (18, 24),
    }
def __init__(self, ai_cfg):
    """Initialize detector state from the 'ai' config section.

    Args:
        ai_cfg: dict of AI settings (enabled flag, thresholds, file paths, ...).
    """
    self.cfg = ai_cfg
    # Per-period trained models: {period_name: {'model': ..., 'scaler': ...}}
    self.models = {}
    self.started_at = time.time()
    # Learning-phase samples bucketed by period: {period_name: [samples]}
    self.training_data = defaultdict(list)
    self.is_learning = True
    self._retrain_requested = False
    # Ensures only one (background) retrain runs at a time.
    self._retrain_lock = threading.Lock()
@property
def enabled(self):
    """Whether AI detection is turned on in the config (defaults to False)."""
    flag = self.cfg.get('enabled', False)
    return flag
@staticmethod
def get_period(hour):
    """Map hour (0-24) to period name."""
    found = next(
        (name for name, (lo, hi) in AIDetector.TIME_PERIODS.items()
         if lo <= hour < hi),
        None,
    )
    # hour == 24 falls through every half-open range — treat as night.
    return found if found is not None else 'night'
def request_retrain(self):
    """Flag the detector so the next collect_sample() call triggers training."""
    self._retrain_requested = True
def collect_sample(self, features, hour):
    """Collect a feature sample during learning phase, bucketed by period.

    Triggers training once the learning window has elapsed and enough
    total samples exist, or when a retrain was explicitly requested.

    Args:
        features: list of numeric traffic features for one interval.
        hour: hour of day (0-24) the sample was taken at.
    """
    if not self.enabled:
        return
    period = self.get_period(hour)
    self.training_data[period].append(features)
    learning_dur = self.cfg.get('learning_duration', 259200)
    min_samples = self.cfg.get('min_samples', 1000)
    elapsed = time.time() - self.started_at
    # Sample count is summed across all period buckets.
    total = sum(len(v) for v in self.training_data.values())
    if (elapsed >= learning_dur and total >= min_samples) or self._retrain_requested:
        self._train()
        self._retrain_requested = False
def _train(self):
"""Train the Isolation Forest model."""
"""Train per-period Isolation Forest models."""
try:
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler
@@ -245,32 +267,63 @@ class AIDetector:
self.cfg['enabled'] = False
return
if len(self.training_data) < 10:
log.warning("Not enough training data (%d samples)", len(self.training_data))
total = sum(len(v) for v in self.training_data.values())
if total < 10:
log.warning("Not enough training data (%d samples)", total)
return
log.info("Training AI model with %d samples...", len(self.training_data))
log.info("Training AI models: %s",
{p: len(s) for p, s in self.training_data.items() if s})
try:
X = np.array(self.training_data)
new_models = {}
all_samples = []
self.scaler = StandardScaler()
X_scaled = self.scaler.fit_transform(X)
for period, samples in self.training_data.items():
if len(samples) < 10:
log.info("Period %s: %d samples (too few, skip)", period, len(samples))
continue
self.model = IsolationForest(
n_estimators=self.cfg.get('n_estimators', 100),
contamination=self.cfg.get('contamination', 'auto'),
random_state=42,
)
self.model.fit(X_scaled)
X = np.array(samples)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
model = IsolationForest(
n_estimators=self.cfg.get('n_estimators', 100),
contamination=self.cfg.get('contamination', 'auto'),
random_state=42,
)
model.fit(X_scaled)
new_models[period] = {'model': model, 'scaler': scaler}
all_samples.extend(samples)
log.info("Period %s: trained with %d samples", period, len(samples))
if not new_models:
log.warning("No period had enough data to train")
return
# Atomic swap
self.models = new_models
self.is_learning = False
# Save to disk (atomic)
model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
with open(model_file, 'wb') as f:
pickle.dump({'model': self.model, 'scaler': self.scaler}, f)
tmp_model = model_file + '.tmp'
with open(tmp_model, 'wb') as f:
pickle.dump({
'format': 'period_models',
'models': {p: {'model': m['model'], 'scaler': m['scaler']}
for p, m in new_models.items()},
'feature_count': 17,
}, f)
f.flush()
os.fsync(f.fileno())
os.rename(tmp_model, model_file)
# Save training data CSV
data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
with open(data_file, 'w', newline='') as f:
tmp_data = data_file + '.tmp'
with open(tmp_data, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow([
'hour_sin', 'hour_cos',
@@ -279,63 +332,97 @@ class AIDetector:
'small_pkt_count', 'large_pkt_count',
'syn_ratio', 'udp_ratio', 'icmp_ratio', 'small_pkt_ratio', 'avg_pkt_size'
])
writer.writerows(self.training_data)
writer.writerows(all_samples)
f.flush()
os.fsync(f.fileno())
os.rename(tmp_data, data_file)
log.info("AI model trained and saved to %s", model_file)
log.info("AI models trained and saved (%d periods)", len(new_models))
except Exception as e:
log.error("AI training failed: %s", e)
def load_model(self):
    """Load previously trained models. Handle old and new formats.

    Returns:
        True when the new 'period_models' pickle format was loaded;
        False when the file is missing, unreadable, or in the old
        single-model format (which switches back to learning mode).
    """
    model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
    if not os.path.exists(model_file):
        return False
    try:
        with open(model_file, 'rb') as f:
            data = pickle.load(f)
        # New period_models format
        if data.get('format') == 'period_models':
            self.models = data['models']
            self.is_learning = False
            periods = list(self.models.keys())
            # %s (not %d): feature_count may fall back to the string '?'.
            log.info("AI models loaded: %s (%s features)",
                     periods, data.get('feature_count', '?'))
            return True
        # Old single-model format → discard
        log.warning("Old single-model format detected. Switching to learning mode.")
        self.is_learning = True
        return False
    except Exception as e:
        log.error("Failed to load AI model: %s", e)
        return False
def predict(self, features, hour):
    """Run anomaly detection using the period-specific model.

    Args:
        features: list of numeric traffic features for one interval.
        hour: hour of day (0-24), selects which period model is used.

    Returns:
        (is_anomaly, score). A score below the configured
        'anomaly_threshold' (default -0.3) marks the sample anomalous.
        (False, 0.0) when disabled, no models exist, or prediction fails.
    """
    if not self.enabled or not self.models:
        return False, 0.0
    period = self.get_period(hour)
    model_info = self.models.get(period)
    if model_info is None:
        # No model trained for this period yet — fall back to any
        # available period model rather than skipping detection.
        model_info = next(iter(self.models.values()), None)
    if model_info is None:
        return False, 0.0
    try:
        import numpy as np
        X = np.array([features])
        X_scaled = model_info['scaler'].transform(X)
        score = model_info['model'].decision_function(X_scaled)[0]
        threshold = self.cfg.get('anomaly_threshold', -0.3)
        return score < threshold, float(score)
    except Exception as e:
        log.error("AI prediction error: %s", e)
        return False, 0.0
def retrain_from_log(self, db_path=None, background=False):
    """Retrain models from traffic_log.db. Optionally in background thread.

    Args:
        db_path: path to the SQLite traffic log; defaults to the
            configured 'traffic_log_db' path.
        background: when True, run retraining in a daemon thread and
            return immediately so inference continues uninterrupted.

    Returns:
        Foreground mode: the retrain result (True/False).
        Background mode: True, meaning started — not yet finished.
    """
    if db_path is None:
        db_path = self.cfg.get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
    if background:
        t = threading.Thread(target=self._do_retrain, args=(db_path,), daemon=True)
        t.start()
        return True  # started, not yet finished
    return self._do_retrain(db_path)
def _do_retrain(self, db_path):
    """Actual retrain logic (can run in background thread).

    Uses a non-blocking lock acquire so overlapping retrain requests
    are skipped instead of queued.
    """
    got_lock = self._retrain_lock.acquire(blocking=False)
    if not got_lock:
        log.info("Retrain already in progress, skipping")
        return False
    try:
        return self._retrain_impl(db_path)
    finally:
        self._retrain_lock.release()
def _retrain_impl(self, db_path):
"""Load data from DB, filter outliers, train per-period models."""
if not os.path.exists(db_path):
log.warning("Traffic log DB not found: %s", db_path)
return False
@@ -347,7 +434,7 @@ class AIDetector:
try:
conn = sqlite3.connect(f'file:{db_path}?mode=ro', uri=True)
cur = conn.execute(
'SELECT hour_sin, hour_cos, total_packets, total_bytes, '
'SELECT hour, hour_sin, hour_cos, total_packets, total_bytes, '
'tcp_syn_count, tcp_other_count, udp_count, icmp_count, '
'other_proto_count, unique_ips_approx, small_pkt_count, '
'large_pkt_count, syn_ratio, udp_ratio, icmp_ratio, '
@@ -355,25 +442,56 @@ class AIDetector:
'FROM traffic_samples WHERE timestamp >= ? ORDER BY timestamp',
(cutoff,)
)
samples = [list(row) for row in cur.fetchall()]
if len(samples) < 10:
log.warning("Not enough recent samples for retrain (%d)", len(samples))
return False
log.info("Auto-retrain: loading %d samples from traffic log (window=%ds)",
len(samples), retrain_window)
self.training_data = samples
self._train()
return True
rows = cur.fetchall()
except Exception as e:
log.error("retrain_from_log failed: %s", e)
log.error("retrain_from_log DB read failed: %s", e)
return False
finally:
if conn:
conn.close()
if len(rows) < 10:
log.warning("Not enough recent samples for retrain (%d)", len(rows))
return False
# Split by period: row[0]=hour, row[1:]=17 features
period_data = defaultdict(list)
for row in rows:
hour = row[0]
features = list(row[1:]) # 17 features
period = self.get_period(hour)
period_data[period].append(features)
# Filter outliers using existing models (item #2)
try:
import numpy as np
except ImportError:
log.error("numpy not available for outlier filtering")
return False
filtered_count = 0
for period, samples in period_data.items():
model_info = self.models.get(period)
if model_info is None or len(samples) < 10:
continue
X = np.array(samples)
X_scaled = model_info['scaler'].transform(X)
scores = model_info['model'].decision_function(X_scaled)
clean = [s for s, sc in zip(samples, scores) if sc >= -0.5]
removed = len(samples) - len(clean)
if removed > 0:
filtered_count += removed
period_data[period] = clean
total = sum(len(v) for v in period_data.values())
log.info("Retrain: %d samples loaded, %d outliers filtered, %s",
len(rows), filtered_count,
{p: len(s) for p, s in period_data.items() if s})
self.training_data = period_data
self._train()
return not self.is_learning
# ==================== ProfileManager ====================
@@ -549,14 +667,10 @@ class DDoSDaemon:
self._stop_event.set()
def _handle_sigusr1(self, signum, frame):
    """SIGUSR1 handler: kick off a background retrain from the traffic log.

    Background mode returns immediately, so the signal handler stays
    fast and inference continues while the retrain thread runs.
    """
    log.info("SIGUSR1 received, triggering retrain from traffic log (background)...")
    db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
    self.ai_detector.retrain_from_log(db_path, background=True)
    self._last_retrain_time = time.time()
# ---- Traffic Logging (SQLite) ----
@@ -762,24 +876,20 @@ class DDoSDaemon:
self._last_log_cleanup = time.time()
if self.ai_detector.is_learning:
self.ai_detector.collect_sample(deltas_with_time)
if len(self.ai_detector.training_data) % 100 == 0:
log.debug("AI learning: %d samples collected",
len(self.ai_detector.training_data))
self.ai_detector.collect_sample(deltas_with_time, hour)
total_samples = sum(len(v) for v in self.ai_detector.training_data.values())
if total_samples % 100 == 0 and total_samples > 0:
log.debug("AI learning: %d samples collected", total_samples)
else:
# Auto-retrain check
# Auto-retrain check (background, no inference gap)
retrain_interval = self.cfg['ai'].get('retrain_interval', 86400)
if time.time() - self._last_retrain_time >= retrain_interval:
log.info("Auto-retrain triggered (interval=%ds)", retrain_interval)
log.info("Auto-retrain triggered (interval=%ds, background)", retrain_interval)
db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
if self.ai_detector.retrain_from_log(db_path):
self._last_retrain_time = time.time()
log.info("Auto-retrain completed successfully")
else:
log.warning("Auto-retrain failed, will retry next interval")
self._last_retrain_time = time.time()
self.ai_detector.retrain_from_log(db_path, background=True)
self._last_retrain_time = time.time()
is_anomaly, score = self.ai_detector.predict(deltas_with_time)
is_anomaly, score = self.ai_detector.predict(deltas_with_time, hour)
if is_anomaly:
log.warning(
"AI ANOMALY detected: score=%.4f deltas=%s",