Fix HIGH severity security and thread-safety issues

Daemon fixes:
- Add _db_lock for thread-safe SQLite access
- Atomic SIGHUP config swap (build all values before applying)
- Check world-writable permission before loading pickle model
- Write model files with 0o600 permissions via os.open
- Module-level xdp_common import with fatal exit on failure
- Close traffic DB on shutdown
- Add period_data parameter to _train() to avoid race condition

CLI fixes:
- Replace $COMMON_PY variable with hardcoded 'xdp_common'
- Pass CONFIG_FILE via sys.argv instead of string interpolation
- Add key_hex regex validation before all bpftool commands
- Switch sanitize_input from denylist to strict allowlist

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
kaffa
2026-02-07 11:22:55 +09:00
parent a6519fd664
commit 2c29eab991
2 changed files with 125 additions and 86 deletions

View File

@@ -14,7 +14,6 @@ COUNTRY_DIR="/etc/xdp-blocker/countries"
GEOIP_DB="/usr/share/GeoIP/GeoLite2-Country.mmdb" GEOIP_DB="/usr/share/GeoIP/GeoLite2-Country.mmdb"
CITY_DB="/usr/share/GeoIP/GeoLite2-City.mmdb" CITY_DB="/usr/share/GeoIP/GeoLite2-City.mmdb"
ASN_DB="/usr/share/GeoIP/GeoLite2-ASN.mmdb" ASN_DB="/usr/share/GeoIP/GeoLite2-ASN.mmdb"
COMMON_PY="xdp_common"
# Ensure Python can find xdp_common.py (installed to /usr/local/bin) # Ensure Python can find xdp_common.py (installed to /usr/local/bin)
export PYTHONPATH="/usr/local/bin:${PYTHONPATH:-}" export PYTHONPATH="/usr/local/bin:${PYTHONPATH:-}"
@@ -30,10 +29,11 @@ log_ok() { echo -e "${GREEN}[OK]${NC} $1"; logger -t xdp-defense "OK: $1" 2>/d
log_err() { echo -e "${RED}[ERROR]${NC} $1" >&2; logger -t xdp-defense -p user.err "ERROR: $1" 2>/dev/null || true; } log_err() { echo -e "${RED}[ERROR]${NC} $1" >&2; logger -t xdp-defense -p user.err "ERROR: $1" 2>/dev/null || true; }
log_info() { echo -e "${CYAN}[INFO]${NC} $1"; logger -t xdp-defense "INFO: $1" 2>/dev/null || true; } log_info() { echo -e "${CYAN}[INFO]${NC} $1"; logger -t xdp-defense "INFO: $1" 2>/dev/null || true; }
# Sanitize input for safe embedding in Python strings (reject dangerous chars) # Sanitize input - strict allowlist for IP addresses, CIDR notation, preset names
sanitize_input() { sanitize_input() {
local val="$1" local val="$1"
if [[ "$val" =~ [\'\"\;\$\`\\] ]]; then # Allow only alphanumeric, dots, colons, slashes, hyphens, underscores
if [[ ! "$val" =~ ^[a-zA-Z0-9._:/\ -]+$ ]]; then
log_err "Invalid characters in input: $val" log_err "Invalid characters in input: $val"
return 1 return 1
fi fi
@@ -44,10 +44,10 @@ get_iface() {
if [ -f "$CONFIG_FILE" ]; then if [ -f "$CONFIG_FILE" ]; then
local iface local iface
iface=$(python3 -c " iface=$(python3 -c "
import yaml import yaml, sys
with open('$CONFIG_FILE') as f: with open(sys.argv[1]) as f:
print(yaml.safe_load(f).get('general',{}).get('interface','eth0')) print(yaml.safe_load(f).get('general',{}).get('interface','eth0'))
" 2>/dev/null) " "$CONFIG_FILE" 2>/dev/null)
echo "${iface:-eth0}" echo "${iface:-eth0}"
else else
echo "eth0" echo "eth0"
@@ -249,7 +249,7 @@ cmd_status() {
# Rate config # Rate config
python3 -c " python3 -c "
from ${COMMON_PY} import read_rate_config from xdp_common import read_rate_config
cfg = read_rate_config() cfg = read_rate_config()
if cfg: if cfg:
print(f'Rate limit: {cfg[\"pps_threshold\"]} pps (window: {cfg[\"window_ns\"] // 1_000_000_000}s)') print(f'Rate limit: {cfg[\"pps_threshold\"]} pps (window: {cfg[\"window_ns\"] // 1_000_000_000}s)')
@@ -259,7 +259,7 @@ else:
# Blocked IPs # Blocked IPs
python3 -c " python3 -c "
from ${COMMON_PY} import dump_blocked_ips from xdp_common import dump_blocked_ips
v4 = dump_blocked_ips('blocked_ips_v4') v4 = dump_blocked_ips('blocked_ips_v4')
v6 = dump_blocked_ips('blocked_ips_v6') v6 = dump_blocked_ips('blocked_ips_v6')
print(f'Blocked IPs: {len(v4) + len(v6)}') print(f'Blocked IPs: {len(v4) + len(v6)}')
@@ -283,8 +283,9 @@ cmd_blocker_add() {
[ -z "$map_id" ] && { log_err "IPv6 map not found. Is XDP loaded?"; exit 1; } [ -z "$map_id" ] && { log_err "IPv6 map not found. Is XDP loaded?"; exit 1; }
local key_hex local key_hex
key_hex=$(python3 -c "import sys; from ${COMMON_PY} import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$cidr" 2>/dev/null) key_hex=$(python3 -c "import sys; from xdp_common import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$cidr" 2>/dev/null)
[ -z "$key_hex" ] && { log_err "Invalid IPv6 CIDR: $cidr"; exit 1; } [ -z "$key_hex" ] && { log_err "Invalid IPv6 CIDR: $cidr"; exit 1; }
[[ "$key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
bpftool map update id "$map_id" key hex $key_hex value hex 01 00 00 00 00 00 00 00 2>/dev/null bpftool map update id "$map_id" key hex $key_hex value hex 01 00 00 00 00 00 00 00 2>/dev/null
[ "$2" != "quiet" ] && log_ok "Added (v6): $cidr" || true [ "$2" != "quiet" ] && log_ok "Added (v6): $cidr" || true
@@ -312,6 +313,7 @@ cmd_blocker_add() {
IFS='.' read -r a b c d <<< "$ip" IFS='.' read -r a b c d <<< "$ip"
local key_hex local key_hex
key_hex=$(printf '%02x 00 00 00 %02x %02x %02x %02x' "$prefix" "$a" "$b" "$c" "$d") key_hex=$(printf '%02x 00 00 00 %02x %02x %02x %02x' "$prefix" "$a" "$b" "$c" "$d")
[[ "$key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
bpftool map update id "$map_id" key hex $key_hex value hex 01 00 00 00 00 00 00 00 2>/dev/null bpftool map update id "$map_id" key hex $key_hex value hex 01 00 00 00 00 00 00 00 2>/dev/null
[ "$2" != "quiet" ] && log_ok "Added: $cidr" || true [ "$2" != "quiet" ] && log_ok "Added: $cidr" || true
@@ -331,8 +333,9 @@ cmd_blocker_del() {
[ -z "$map_id" ] && { log_err "IPv6 map not found"; exit 1; } [ -z "$map_id" ] && { log_err "IPv6 map not found"; exit 1; }
local key_hex local key_hex
key_hex=$(python3 -c "import sys; from ${COMMON_PY} import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$cidr" 2>/dev/null) key_hex=$(python3 -c "import sys; from xdp_common import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$cidr" 2>/dev/null)
[ -z "$key_hex" ] && { log_err "Invalid IPv6 CIDR: $cidr"; exit 1; } [ -z "$key_hex" ] && { log_err "Invalid IPv6 CIDR: $cidr"; exit 1; }
[[ "$key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
bpftool map delete id "$map_id" key hex $key_hex 2>/dev/null && log_ok "Removed (v6): $cidr" bpftool map delete id "$map_id" key hex $key_hex 2>/dev/null && log_ok "Removed (v6): $cidr"
local tmpfile="${BLOCKLIST_FILE}.tmp.$$" local tmpfile="${BLOCKLIST_FILE}.tmp.$$"
@@ -358,6 +361,7 @@ cmd_blocker_del() {
IFS='.' read -r a b c d <<< "$ip" IFS='.' read -r a b c d <<< "$ip"
local key_hex local key_hex
key_hex=$(printf '%02x 00 00 00 %02x %02x %02x %02x' "$prefix" "$a" "$b" "$c" "$d") key_hex=$(printf '%02x 00 00 00 %02x %02x %02x %02x' "$prefix" "$a" "$b" "$c" "$d")
[[ "$key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
bpftool map delete id "$map_id" key hex $key_hex 2>/dev/null && log_ok "Removed: $cidr" bpftool map delete id "$map_id" key hex $key_hex 2>/dev/null && log_ok "Removed: $cidr"
local tmpfile="${BLOCKLIST_FILE}.tmp.$$" local tmpfile="${BLOCKLIST_FILE}.tmp.$$"
@@ -461,12 +465,13 @@ cmd_whitelist_add() {
local map_name key_hex local map_name key_hex
if [[ "$name" == *":"* ]]; then if [[ "$name" == *":"* ]]; then
map_name="whitelist_v6" map_name="whitelist_v6"
key_hex=$(python3 -c "import sys; from ${COMMON_PY} import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$name" 2>/dev/null) key_hex=$(python3 -c "import sys; from xdp_common import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$name" 2>/dev/null)
else else
map_name="whitelist_v4" map_name="whitelist_v4"
key_hex=$(python3 -c "import sys; from ${COMMON_PY} import cidr_to_key; print(cidr_to_key(sys.argv[1]))" "$name" 2>/dev/null) key_hex=$(python3 -c "import sys; from xdp_common import cidr_to_key; print(cidr_to_key(sys.argv[1]))" "$name" 2>/dev/null)
fi fi
[ -z "$key_hex" ] && { log_err "Invalid CIDR: $name"; exit 1; } [ -z "$key_hex" ] && { log_err "Invalid CIDR: $name"; exit 1; }
[[ "$key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
local map_id local map_id
map_id=$(get_map_id "$map_name") map_id=$(get_map_id "$map_name")
@@ -496,12 +501,13 @@ cmd_whitelist_del() {
local map_name key_hex local map_name key_hex
if [[ "$name" == *":"* ]]; then if [[ "$name" == *":"* ]]; then
map_name="whitelist_v6" map_name="whitelist_v6"
key_hex=$(python3 -c "import sys; from ${COMMON_PY} import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$name" 2>/dev/null) key_hex=$(python3 -c "import sys; from xdp_common import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$name" 2>/dev/null)
else else
map_name="whitelist_v4" map_name="whitelist_v4"
key_hex=$(python3 -c "import sys; from ${COMMON_PY} import cidr_to_key; print(cidr_to_key(sys.argv[1]))" "$name" 2>/dev/null) key_hex=$(python3 -c "import sys; from xdp_common import cidr_to_key; print(cidr_to_key(sys.argv[1]))" "$name" 2>/dev/null)
fi fi
[ -z "$key_hex" ] && { log_err "Invalid CIDR: $name"; exit 1; } [ -z "$key_hex" ] && { log_err "Invalid CIDR: $name"; exit 1; }
[[ "$key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
local map_id local map_id
map_id=$(get_map_id "$map_name") map_id=$(get_map_id "$map_name")
@@ -532,7 +538,7 @@ cmd_whitelist_list() {
cmd_ddos_stats() { cmd_ddos_stats() {
echo -e "${BOLD}=== DDoS Statistics ===${NC}" echo -e "${BOLD}=== DDoS Statistics ===${NC}"
python3 -c " python3 -c "
from ${COMMON_PY} import read_percpu_stats from xdp_common import read_percpu_stats
stats = read_percpu_stats('global_stats', 5) stats = read_percpu_stats('global_stats', 5)
labels = ['Passed', 'Dropped (blocked)', 'Dropped (rate)', 'Total', 'Errors'] labels = ['Passed', 'Dropped (blocked)', 'Dropped (rate)', 'Total', 'Errors']
for i, label in enumerate(labels): for i, label in enumerate(labels):
@@ -547,7 +553,7 @@ cmd_ddos_top() {
echo -e "${BOLD}=== Top $n IPs by Packet Count ===${NC}" echo -e "${BOLD}=== Top $n IPs by Packet Count ===${NC}"
python3 -c " python3 -c "
import sys import sys
from ${COMMON_PY} import dump_rate_counters from xdp_common import dump_rate_counters
entries = dump_rate_counters('rate_counter_v4', int(sys.argv[1])) entries = dump_rate_counters('rate_counter_v4', int(sys.argv[1]))
if not entries: if not entries:
print(' (empty)') print(' (empty)')
@@ -569,7 +575,7 @@ if entries6:
cmd_ddos_blocked() { cmd_ddos_blocked() {
echo -e "${BOLD}=== Blocked IPs ===${NC}" echo -e "${BOLD}=== Blocked IPs ===${NC}"
python3 -c " python3 -c "
from ${COMMON_PY} import dump_blocked_ips from xdp_common import dump_blocked_ips
with open('/proc/uptime') as f: with open('/proc/uptime') as f:
now_ns = int(float(f.read().split()[0]) * 1_000_000_000) now_ns = int(float(f.read().split()[0]) * 1_000_000_000)
@@ -604,7 +610,7 @@ cmd_ddos_block() {
[ -z "$ip" ] && { log_err "Usage: xdp-defense ddos block <ip> [duration_sec]"; exit 1; } [ -z "$ip" ] && { log_err "Usage: xdp-defense ddos block <ip> [duration_sec]"; exit 1; }
[[ "$duration" =~ ^[0-9]+$ ]] || { log_err "Invalid duration: $duration"; exit 1; } [[ "$duration" =~ ^[0-9]+$ ]] || { log_err "Invalid duration: $duration"; exit 1; }
python3 -c "import sys; from ${COMMON_PY} import block_ip; block_ip(sys.argv[1], int(sys.argv[2]))" "$ip" "$duration" 2>/dev/null || \ python3 -c "import sys; from xdp_common import block_ip; block_ip(sys.argv[1], int(sys.argv[2]))" "$ip" "$duration" 2>/dev/null || \
{ log_err "Failed to block $ip"; exit 1; } { log_err "Failed to block $ip"; exit 1; }
if [ "$duration" -gt 0 ] 2>/dev/null; then if [ "$duration" -gt 0 ] 2>/dev/null; then
@@ -619,7 +625,7 @@ cmd_ddos_unblock() {
ip=$(sanitize_input "$1") || exit 1 ip=$(sanitize_input "$1") || exit 1
[ -z "$ip" ] && { log_err "Usage: xdp-defense ddos unblock <ip>"; exit 1; } [ -z "$ip" ] && { log_err "Usage: xdp-defense ddos unblock <ip>"; exit 1; }
python3 -c "import sys; from ${COMMON_PY} import unblock_ip; unblock_ip(sys.argv[1])" "$ip" 2>/dev/null || \ python3 -c "import sys; from xdp_common import unblock_ip; unblock_ip(sys.argv[1])" "$ip" 2>/dev/null || \
{ log_err "Failed to unblock $ip"; exit 1; } { log_err "Failed to unblock $ip"; exit 1; }
log_ok "Unblocked $ip" log_ok "Unblocked $ip"
} }
@@ -631,7 +637,7 @@ cmd_ddos_config() {
echo -e "${BOLD}=== Rate Configuration ===${NC}" echo -e "${BOLD}=== Rate Configuration ===${NC}"
echo -e "\n${CYAN}Active (BPF map):${NC}" echo -e "\n${CYAN}Active (BPF map):${NC}"
python3 -c " python3 -c "
from ${COMMON_PY} import read_rate_config from xdp_common import read_rate_config
cfg = read_rate_config() cfg = read_rate_config()
if cfg: if cfg:
pps = cfg['pps_threshold'] pps = cfg['pps_threshold']
@@ -647,8 +653,8 @@ else:
if [ -f "$CONFIG_FILE" ]; then if [ -f "$CONFIG_FILE" ]; then
echo -e "\n${CYAN}Config file ($CONFIG_FILE):${NC}" echo -e "\n${CYAN}Config file ($CONFIG_FILE):${NC}"
python3 -c " python3 -c "
import yaml import yaml, sys
with open('$CONFIG_FILE') as f: with open(sys.argv[1]) as f:
cfg = yaml.safe_load(f) cfg = yaml.safe_load(f)
rl = cfg.get('rate_limits', {}) rl = cfg.get('rate_limits', {})
print(f' Default PPS: {rl.get(\"default_pps\", \"N/A\")}') print(f' Default PPS: {rl.get(\"default_pps\", \"N/A\")}')
@@ -662,7 +668,7 @@ if profiles:
hours = p.get('hours', '') hours = p.get('hours', '')
pps = p.get('pps', 'N/A') pps = p.get('pps', 'N/A')
print(f' {name}: pps={pps}, hours={hours}') print(f' {name}: pps={pps}, hours={hours}')
" 2>/dev/null " "$CONFIG_FILE" 2>/dev/null
fi fi
;; ;;
set) set)
@@ -674,7 +680,7 @@ if profiles:
python3 -c " python3 -c "
import sys import sys
from ${COMMON_PY} import read_rate_config, write_rate_config from xdp_common import read_rate_config, write_rate_config
cfg = read_rate_config() cfg = read_rate_config()
if not cfg: if not cfg:
cfg = {'pps_threshold': 1000, 'bps_threshold': 0, 'window_ns': 1000000000} cfg = {'pps_threshold': 1000, 'bps_threshold': 0, 'window_ns': 1000000000}
@@ -700,16 +706,16 @@ cmd_ddos_config_apply() {
[ ! -f "$CONFIG_FILE" ] && return [ ! -f "$CONFIG_FILE" ] && return
python3 -c " python3 -c "
import yaml import yaml, sys
from ${COMMON_PY} import write_rate_config from xdp_common import write_rate_config
with open('$CONFIG_FILE') as f: with open(sys.argv[1]) as f:
cfg = yaml.safe_load(f) cfg = yaml.safe_load(f)
rl = cfg.get('rate_limits', {}) rl = cfg.get('rate_limits', {})
pps = rl.get('default_pps', 1000) pps = rl.get('default_pps', 1000)
bps = rl.get('default_bps', 0) bps = rl.get('default_bps', 0)
win = rl.get('window_sec', 1) win = rl.get('window_sec', 1)
write_rate_config(pps, bps, win * 1000000000) write_rate_config(pps, bps, win * 1000000000)
" 2>/dev/null || return 0 " "$CONFIG_FILE" 2>/dev/null || return 0
[ "$quiet" != "quiet" ] && log_ok "Config applied from $CONFIG_FILE" || true [ "$quiet" != "quiet" ] && log_ok "Config applied from $CONFIG_FILE" || true
} }
@@ -777,8 +783,8 @@ cmd_ai_status() {
if [ -f "$CONFIG_FILE" ]; then if [ -f "$CONFIG_FILE" ]; then
python3 -c " python3 -c "
import yaml import yaml, sys
with open('$CONFIG_FILE') as f: with open(sys.argv[1]) as f:
cfg = yaml.safe_load(f) cfg = yaml.safe_load(f)
ai = cfg.get('ai', {}) ai = cfg.get('ai', {})
enabled = ai.get('enabled', False) enabled = ai.get('enabled', False)
@@ -786,7 +792,7 @@ if enabled:
print(f'AI Detection: enabled ({ai.get(\"model_type\", \"IsolationForest\")})') print(f'AI Detection: enabled ({ai.get(\"model_type\", \"IsolationForest\")})')
else: else:
print('AI Detection: disabled') print('AI Detection: disabled')
" 2>/dev/null " "$CONFIG_FILE" 2>/dev/null
fi fi
} }
@@ -804,11 +810,11 @@ cmd_ai_retrain() {
cmd_ai_traffic() { cmd_ai_traffic() {
local db_file local db_file
db_file=$(python3 -c " db_file=$(python3 -c "
import yaml import yaml, sys
with open('$CONFIG_FILE') as f: with open(sys.argv[1]) as f:
cfg = yaml.safe_load(f) cfg = yaml.safe_load(f)
print(cfg.get('ai',{}).get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')) print(cfg.get('ai',{}).get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db'))
" 2>/dev/null || echo "/var/lib/xdp-defense/traffic_log.db") " "$CONFIG_FILE" 2>/dev/null || echo "/var/lib/xdp-defense/traffic_log.db")
[ ! -f "$db_file" ] && { log_err "Traffic log not found: $db_file"; exit 1; } [ ! -f "$db_file" ] && { log_err "Traffic log not found: $db_file"; exit 1; }
@@ -872,7 +878,7 @@ conn.close()
# Show next retrain time # Show next retrain time
import yaml, os, time import yaml, os, time
try: try:
with open('$CONFIG_FILE') as f: with open(sys.argv[2]) as f:
cfg = yaml.safe_load(f) cfg = yaml.safe_load(f)
retrain_interval = cfg.get('ai',{}).get('retrain_interval', 86400) retrain_interval = cfg.get('ai',{}).get('retrain_interval', 86400)
model_file = cfg.get('ai',{}).get('model_file', '/var/lib/xdp-defense/ai_model.pkl') model_file = cfg.get('ai',{}).get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
@@ -890,7 +896,7 @@ try:
except: except:
pass pass
print() print()
" "$db_file" " "$db_file" "$CONFIG_FILE"
} }
cmd_ai_log() { cmd_ai_log() {
@@ -899,11 +905,11 @@ cmd_ai_log() {
local db_file local db_file
db_file=$(python3 -c " db_file=$(python3 -c "
import yaml import yaml, sys
with open('$CONFIG_FILE') as f: with open(sys.argv[1]) as f:
cfg = yaml.safe_load(f) cfg = yaml.safe_load(f)
print(cfg.get('ai',{}).get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')) print(cfg.get('ai',{}).get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db'))
" 2>/dev/null || echo "/var/lib/xdp-defense/traffic_log.db") " "$CONFIG_FILE" 2>/dev/null || echo "/var/lib/xdp-defense/traffic_log.db")
[ ! -f "$db_file" ] && { log_err "Traffic log not found: $db_file"; exit 1; } [ ! -f "$db_file" ] && { log_err "Traffic log not found: $db_file"; exit 1; }
@@ -1001,6 +1007,7 @@ cmd_geoip() {
IFS='.' read -r a b c d <<< "$ip" IFS='.' read -r a b c d <<< "$ip"
local key_hex local key_hex
key_hex=$(printf '20 00 00 00 %02x %02x %02x %02x' "$a" "$b" "$c" "$d") key_hex=$(printf '20 00 00 00 %02x %02x %02x %02x' "$a" "$b" "$c" "$d")
[[ "$key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
if bpftool map lookup id "$map_id" key hex $key_hex 2>/dev/null | grep -q "value"; then if bpftool map lookup id "$map_id" key hex $key_hex 2>/dev/null | grep -q "value"; then
echo -e "Blocker: ${RED}BLOCKED${NC}" echo -e "Blocker: ${RED}BLOCKED${NC}"
else else
@@ -1013,7 +1020,8 @@ cmd_geoip() {
ddos_map_id=$(get_map_id blocked_ips_v4) ddos_map_id=$(get_map_id blocked_ips_v4)
if [ -n "$ddos_map_id" ]; then if [ -n "$ddos_map_id" ]; then
local ddos_key_hex local ddos_key_hex
ddos_key_hex=$(python3 -c "import sys; from ${COMMON_PY} import ip_to_hex_key; print(ip_to_hex_key(sys.argv[1]))" "$ip" 2>/dev/null) ddos_key_hex=$(python3 -c "import sys; from xdp_common import ip_to_hex_key; print(ip_to_hex_key(sys.argv[1]))" "$ip" 2>/dev/null)
[[ "$ddos_key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
if [ -n "$ddos_key_hex" ] && bpftool map lookup id "$ddos_map_id" key hex $ddos_key_hex 2>/dev/null | grep -q "value"; then if [ -n "$ddos_key_hex" ] && bpftool map lookup id "$ddos_map_id" key hex $ddos_key_hex 2>/dev/null | grep -q "value"; then
echo -e "DDoS: ${RED}BLOCKED${NC}" echo -e "DDoS: ${RED}BLOCKED${NC}"
else else

View File

@@ -14,6 +14,7 @@ time-profile switching, and automatic escalation.
import copy import copy
import math import math
import os import os
import stat
import sys import sys
import time import time
import signal import signal
@@ -28,6 +29,16 @@ from datetime import datetime, timedelta
import yaml import yaml
try:
from xdp_common import (
dump_rate_counters, block_ip, is_whitelisted,
read_percpu_features, dump_blocked_ips, unblock_ip,
write_rate_config, read_rate_config,
)
except ImportError as e:
print(f"FATAL: Cannot import xdp_common: {e}", file=sys.stderr)
sys.exit(1)
# ==================== Logging ==================== # ==================== Logging ====================
log = logging.getLogger('xdp-defense-daemon') log = logging.getLogger('xdp-defense-daemon')
@@ -256,8 +267,13 @@ class AIDetector:
self._train() self._train()
self._retrain_requested = False self._retrain_requested = False
def _train(self): def _train(self, period_data=None):
"""Train per-period Isolation Forest models.""" """Train per-period Isolation Forest models.
If period_data provided, use it instead of self.training_data.
"""
if period_data is None:
period_data = self.training_data
try: try:
from sklearn.ensemble import IsolationForest from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
@@ -267,19 +283,19 @@ class AIDetector:
self.cfg['enabled'] = False self.cfg['enabled'] = False
return return
total = sum(len(v) for v in self.training_data.values()) total = sum(len(v) for v in period_data.values())
if total < 10: if total < 10:
log.warning("Not enough training data (%d samples)", total) log.warning("Not enough training data (%d samples)", total)
return return
log.info("Training AI models: %s", log.info("Training AI models: %s",
{p: len(s) for p, s in self.training_data.items() if s}) {p: len(s) for p, s in period_data.items() if s})
try: try:
new_models = {} new_models = {}
all_samples = [] all_samples = []
for period, samples in self.training_data.items(): for period, samples in period_data.items():
if len(samples) < 10: if len(samples) < 10:
log.info("Period %s: %d samples (too few, skip)", period, len(samples)) log.info("Period %s: %d samples (too few, skip)", period, len(samples))
continue continue
@@ -309,7 +325,8 @@ class AIDetector:
# Save to disk (atomic) # Save to disk (atomic)
model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl') model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
tmp_model = model_file + '.tmp' tmp_model = model_file + '.tmp'
with open(tmp_model, 'wb') as f: fd = os.open(tmp_model, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, 'wb') as f:
pickle.dump({ pickle.dump({
'format': 'period_models', 'format': 'period_models',
'models': {p: {'model': m['model'], 'scaler': m['scaler']} 'models': {p: {'model': m['model'], 'scaler': m['scaler']}
@@ -323,7 +340,8 @@ class AIDetector:
# Save training data CSV # Save training data CSV
data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv') data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
tmp_data = data_file + '.tmp' tmp_data = data_file + '.tmp'
with open(tmp_data, 'w', newline='') as f: fd = os.open(tmp_data, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)
with os.fdopen(fd, 'w', newline='') as f:
writer = csv.writer(f) writer = csv.writer(f)
writer.writerow([ writer.writerow([
'hour_sin', 'hour_cos', 'hour_sin', 'hour_cos',
@@ -348,6 +366,12 @@ class AIDetector:
if not os.path.exists(model_file): if not os.path.exists(model_file):
return False return False
try: try:
st = os.stat(model_file)
# Warn if file is world-writable
if st.st_mode & stat.S_IWOTH:
log.warning("Model file %s is world-writable! Refusing to load.", model_file)
return False
with open(model_file, 'rb') as f: with open(model_file, 'rb') as f:
data = pickle.load(f) data = pickle.load(f)
@@ -488,8 +512,7 @@ class AIDetector:
len(rows), filtered_count, len(rows), filtered_count,
{p: len(s) for p, s in period_data.items() if s}) {p: len(s) for p, s in period_data.items() if s})
self.training_data = period_data self._train(period_data=period_data)
self._train()
return not self.is_learning return not self.is_learning
@@ -504,8 +527,6 @@ class ProfileManager:
def check_and_apply(self): def check_and_apply(self):
"""Check current time and apply matching profile.""" """Check current time and apply matching profile."""
from xdp_common import write_rate_config
profiles = self.cfg.get('profiles', {}) profiles = self.cfg.get('profiles', {})
now = datetime.now() now = datetime.now()
current_hour = now.hour current_hour = now.hour
@@ -648,17 +669,25 @@ class DDoSDaemon:
def _handle_sighup(self, signum, frame): def _handle_sighup(self, signum, frame):
log.info("SIGHUP received, reloading config...") log.info("SIGHUP received, reloading config...")
self.cfg = load_config(self.config_path) new_cfg = load_config(self.config_path)
# Update existing components without rebuilding (preserves EWMA/violation state) # Build all new values before swapping anything
self.violation_tracker.cfg = self.cfg['escalation'] new_escalation = new_cfg['escalation']
self.ewma_analyzer.alpha = self.cfg['ewma'].get('alpha', 0.3) new_alpha = new_cfg['ewma'].get('alpha', 0.3)
self.ewma_analyzer.threshold_multiplier = self.cfg['ewma'].get('threshold_multiplier', 3.0) new_threshold = new_cfg['ewma'].get('threshold_multiplier', 3.0)
self.ai_detector.cfg = self.cfg['ai'] new_ai_cfg = new_cfg['ai']
self.profile_manager.cfg = self.cfg['rate_limits'] new_rate_cfg = new_cfg['rate_limits']
# Update poll intervals (used by threads on next iteration) new_ewma_interval = new_cfg['ewma'].get('poll_interval', 1)
self._ewma_interval = self.cfg['ewma'].get('poll_interval', 1) new_ai_interval = new_cfg['ai'].get('poll_interval', 5)
self._ai_interval = self.cfg['ai'].get('poll_interval', 5) level = new_cfg['general'].get('log_level', 'info').upper()
level = self.cfg['general'].get('log_level', 'info').upper() # Now apply all at once
self.cfg = new_cfg
self.violation_tracker.cfg = new_escalation
self.ewma_analyzer.alpha = new_alpha
self.ewma_analyzer.threshold_multiplier = new_threshold
self.ai_detector.cfg = new_ai_cfg
self.profile_manager.cfg = new_rate_cfg
self._ewma_interval = new_ewma_interval
self._ai_interval = new_ai_interval
log.setLevel(getattr(logging, level, logging.INFO)) log.setLevel(getattr(logging, level, logging.INFO))
log.info("Config reloaded (state preserved)") log.info("Config reloaded (state preserved)")
@@ -678,6 +707,7 @@ class DDoSDaemon:
"""Initialize SQLite database for traffic logging.""" """Initialize SQLite database for traffic logging."""
db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db') db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
os.makedirs(os.path.dirname(db_path), exist_ok=True) os.makedirs(os.path.dirname(db_path), exist_ok=True)
self._db_lock = threading.Lock()
self._traffic_db = sqlite3.connect(db_path, check_same_thread=False) self._traffic_db = sqlite3.connect(db_path, check_same_thread=False)
self._traffic_db.execute( self._traffic_db.execute(
'CREATE TABLE IF NOT EXISTS traffic_samples (' 'CREATE TABLE IF NOT EXISTS traffic_samples ('
@@ -712,6 +742,7 @@ class DDoSDaemon:
def _log_traffic(self, now, hour, features): def _log_traffic(self, now, hour, features):
"""Insert one row into traffic_samples table.""" """Insert one row into traffic_samples table."""
try: try:
with self._db_lock:
self._traffic_db.execute( self._traffic_db.execute(
'INSERT INTO traffic_samples (' 'INSERT INTO traffic_samples ('
' timestamp, hour, hour_sin, hour_cos,' ' timestamp, hour, hour_sin, hour_cos,'
@@ -732,6 +763,7 @@ class DDoSDaemon:
cutoff = (datetime.now() - timedelta(days=retention_days)).isoformat() cutoff = (datetime.now() - timedelta(days=retention_days)).isoformat()
try: try:
with self._db_lock:
cur = self._traffic_db.execute( cur = self._traffic_db.execute(
'DELETE FROM traffic_samples WHERE timestamp < ?', (cutoff,) 'DELETE FROM traffic_samples WHERE timestamp < ?', (cutoff,)
) )
@@ -755,8 +787,6 @@ class DDoSDaemon:
def _ewma_thread(self): def _ewma_thread(self):
"""Poll rate counters, compute EWMA, detect violations, escalate.""" """Poll rate counters, compute EWMA, detect violations, escalate."""
from xdp_common import dump_rate_counters, block_ip, is_whitelisted
prev_counters = {} prev_counters = {}
while not self._stop_event.is_set(): while not self._stop_event.is_set():
@@ -815,8 +845,6 @@ class DDoSDaemon:
def _ai_thread(self): def _ai_thread(self):
"""Read traffic features, run AI inference or collect training data.""" """Read traffic features, run AI inference or collect training data."""
from xdp_common import read_percpu_features, dump_rate_counters, block_ip, is_whitelisted
prev_features = None prev_features = None
self._last_retrain_time = self._get_model_mtime() self._last_retrain_time = self._get_model_mtime()
self._last_log_cleanup = time.time() self._last_log_cleanup = time.time()
@@ -938,8 +966,6 @@ class DDoSDaemon:
def _cleanup_thread(self): def _cleanup_thread(self):
"""Periodically clean up expired blocked IPs and stale violations.""" """Periodically clean up expired blocked IPs and stale violations."""
from xdp_common import dump_blocked_ips, unblock_ip
while not self._stop_event.is_set(): while not self._stop_event.is_set():
try: try:
with open('/proc/uptime') as f: with open('/proc/uptime') as f:
@@ -1003,6 +1029,11 @@ class DDoSDaemon:
t.join(timeout=5) t.join(timeout=5)
self._remove_pid() self._remove_pid()
if hasattr(self, '_traffic_db') and self._traffic_db:
try:
self._traffic_db.close()
except Exception:
pass
log.info("Daemon stopped") log.info("Daemon stopped")