Fix HIGH severity security and thread-safety issues
Daemon fixes: - Add _db_lock for thread-safe SQLite access - Atomic SIGHUP config swap (build all values before applying) - Check world-writable permission before loading pickle model - Write model files with 0o600 permissions via os.open - Module-level xdp_common import with fatal exit on failure - Close traffic DB on shutdown - Add period_data parameter to _train() to avoid race condition CLI fixes: - Replace $COMMON_PY variable with hardcoded 'xdp_common' - Pass CONFIG_FILE via sys.argv instead of string interpolation - Add key_hex regex validation before all bpftool commands - Switch sanitize_input from denylist to strict allowlist Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -14,7 +14,6 @@ COUNTRY_DIR="/etc/xdp-blocker/countries"
|
|||||||
GEOIP_DB="/usr/share/GeoIP/GeoLite2-Country.mmdb"
|
GEOIP_DB="/usr/share/GeoIP/GeoLite2-Country.mmdb"
|
||||||
CITY_DB="/usr/share/GeoIP/GeoLite2-City.mmdb"
|
CITY_DB="/usr/share/GeoIP/GeoLite2-City.mmdb"
|
||||||
ASN_DB="/usr/share/GeoIP/GeoLite2-ASN.mmdb"
|
ASN_DB="/usr/share/GeoIP/GeoLite2-ASN.mmdb"
|
||||||
COMMON_PY="xdp_common"
|
|
||||||
|
|
||||||
# Ensure Python can find xdp_common.py (installed to /usr/local/bin)
|
# Ensure Python can find xdp_common.py (installed to /usr/local/bin)
|
||||||
export PYTHONPATH="/usr/local/bin:${PYTHONPATH:-}"
|
export PYTHONPATH="/usr/local/bin:${PYTHONPATH:-}"
|
||||||
@@ -30,10 +29,11 @@ log_ok() { echo -e "${GREEN}[OK]${NC} $1"; logger -t xdp-defense "OK: $1" 2>/d
|
|||||||
log_err() { echo -e "${RED}[ERROR]${NC} $1" >&2; logger -t xdp-defense -p user.err "ERROR: $1" 2>/dev/null || true; }
|
log_err() { echo -e "${RED}[ERROR]${NC} $1" >&2; logger -t xdp-defense -p user.err "ERROR: $1" 2>/dev/null || true; }
|
||||||
log_info() { echo -e "${CYAN}[INFO]${NC} $1"; logger -t xdp-defense "INFO: $1" 2>/dev/null || true; }
|
log_info() { echo -e "${CYAN}[INFO]${NC} $1"; logger -t xdp-defense "INFO: $1" 2>/dev/null || true; }
|
||||||
|
|
||||||
# Sanitize input for safe embedding in Python strings (reject dangerous chars)
|
# Sanitize input - strict allowlist for IP addresses, CIDR notation, preset names
|
||||||
sanitize_input() {
|
sanitize_input() {
|
||||||
local val="$1"
|
local val="$1"
|
||||||
if [[ "$val" =~ [\'\"\;\$\`\\] ]]; then
|
# Allow only alphanumeric, dots, colons, slashes, hyphens, underscores
|
||||||
|
if [[ ! "$val" =~ ^[a-zA-Z0-9._:/\ -]+$ ]]; then
|
||||||
log_err "Invalid characters in input: $val"
|
log_err "Invalid characters in input: $val"
|
||||||
return 1
|
return 1
|
||||||
fi
|
fi
|
||||||
@@ -44,10 +44,10 @@ get_iface() {
|
|||||||
if [ -f "$CONFIG_FILE" ]; then
|
if [ -f "$CONFIG_FILE" ]; then
|
||||||
local iface
|
local iface
|
||||||
iface=$(python3 -c "
|
iface=$(python3 -c "
|
||||||
import yaml
|
import yaml, sys
|
||||||
with open('$CONFIG_FILE') as f:
|
with open(sys.argv[1]) as f:
|
||||||
print(yaml.safe_load(f).get('general',{}).get('interface','eth0'))
|
print(yaml.safe_load(f).get('general',{}).get('interface','eth0'))
|
||||||
" 2>/dev/null)
|
" "$CONFIG_FILE" 2>/dev/null)
|
||||||
echo "${iface:-eth0}"
|
echo "${iface:-eth0}"
|
||||||
else
|
else
|
||||||
echo "eth0"
|
echo "eth0"
|
||||||
@@ -249,7 +249,7 @@ cmd_status() {
|
|||||||
|
|
||||||
# Rate config
|
# Rate config
|
||||||
python3 -c "
|
python3 -c "
|
||||||
from ${COMMON_PY} import read_rate_config
|
from xdp_common import read_rate_config
|
||||||
cfg = read_rate_config()
|
cfg = read_rate_config()
|
||||||
if cfg:
|
if cfg:
|
||||||
print(f'Rate limit: {cfg[\"pps_threshold\"]} pps (window: {cfg[\"window_ns\"] // 1_000_000_000}s)')
|
print(f'Rate limit: {cfg[\"pps_threshold\"]} pps (window: {cfg[\"window_ns\"] // 1_000_000_000}s)')
|
||||||
@@ -259,7 +259,7 @@ else:
|
|||||||
|
|
||||||
# Blocked IPs
|
# Blocked IPs
|
||||||
python3 -c "
|
python3 -c "
|
||||||
from ${COMMON_PY} import dump_blocked_ips
|
from xdp_common import dump_blocked_ips
|
||||||
v4 = dump_blocked_ips('blocked_ips_v4')
|
v4 = dump_blocked_ips('blocked_ips_v4')
|
||||||
v6 = dump_blocked_ips('blocked_ips_v6')
|
v6 = dump_blocked_ips('blocked_ips_v6')
|
||||||
print(f'Blocked IPs: {len(v4) + len(v6)}')
|
print(f'Blocked IPs: {len(v4) + len(v6)}')
|
||||||
@@ -283,8 +283,9 @@ cmd_blocker_add() {
|
|||||||
[ -z "$map_id" ] && { log_err "IPv6 map not found. Is XDP loaded?"; exit 1; }
|
[ -z "$map_id" ] && { log_err "IPv6 map not found. Is XDP loaded?"; exit 1; }
|
||||||
|
|
||||||
local key_hex
|
local key_hex
|
||||||
key_hex=$(python3 -c "import sys; from ${COMMON_PY} import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$cidr" 2>/dev/null)
|
key_hex=$(python3 -c "import sys; from xdp_common import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$cidr" 2>/dev/null)
|
||||||
[ -z "$key_hex" ] && { log_err "Invalid IPv6 CIDR: $cidr"; exit 1; }
|
[ -z "$key_hex" ] && { log_err "Invalid IPv6 CIDR: $cidr"; exit 1; }
|
||||||
|
[[ "$key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
|
||||||
|
|
||||||
bpftool map update id "$map_id" key hex $key_hex value hex 01 00 00 00 00 00 00 00 2>/dev/null
|
bpftool map update id "$map_id" key hex $key_hex value hex 01 00 00 00 00 00 00 00 2>/dev/null
|
||||||
[ "$2" != "quiet" ] && log_ok "Added (v6): $cidr" || true
|
[ "$2" != "quiet" ] && log_ok "Added (v6): $cidr" || true
|
||||||
@@ -312,6 +313,7 @@ cmd_blocker_add() {
|
|||||||
IFS='.' read -r a b c d <<< "$ip"
|
IFS='.' read -r a b c d <<< "$ip"
|
||||||
local key_hex
|
local key_hex
|
||||||
key_hex=$(printf '%02x 00 00 00 %02x %02x %02x %02x' "$prefix" "$a" "$b" "$c" "$d")
|
key_hex=$(printf '%02x 00 00 00 %02x %02x %02x %02x' "$prefix" "$a" "$b" "$c" "$d")
|
||||||
|
[[ "$key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
|
||||||
|
|
||||||
bpftool map update id "$map_id" key hex $key_hex value hex 01 00 00 00 00 00 00 00 2>/dev/null
|
bpftool map update id "$map_id" key hex $key_hex value hex 01 00 00 00 00 00 00 00 2>/dev/null
|
||||||
[ "$2" != "quiet" ] && log_ok "Added: $cidr" || true
|
[ "$2" != "quiet" ] && log_ok "Added: $cidr" || true
|
||||||
@@ -331,8 +333,9 @@ cmd_blocker_del() {
|
|||||||
[ -z "$map_id" ] && { log_err "IPv6 map not found"; exit 1; }
|
[ -z "$map_id" ] && { log_err "IPv6 map not found"; exit 1; }
|
||||||
|
|
||||||
local key_hex
|
local key_hex
|
||||||
key_hex=$(python3 -c "import sys; from ${COMMON_PY} import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$cidr" 2>/dev/null)
|
key_hex=$(python3 -c "import sys; from xdp_common import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$cidr" 2>/dev/null)
|
||||||
[ -z "$key_hex" ] && { log_err "Invalid IPv6 CIDR: $cidr"; exit 1; }
|
[ -z "$key_hex" ] && { log_err "Invalid IPv6 CIDR: $cidr"; exit 1; }
|
||||||
|
[[ "$key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
|
||||||
|
|
||||||
bpftool map delete id "$map_id" key hex $key_hex 2>/dev/null && log_ok "Removed (v6): $cidr"
|
bpftool map delete id "$map_id" key hex $key_hex 2>/dev/null && log_ok "Removed (v6): $cidr"
|
||||||
local tmpfile="${BLOCKLIST_FILE}.tmp.$$"
|
local tmpfile="${BLOCKLIST_FILE}.tmp.$$"
|
||||||
@@ -358,6 +361,7 @@ cmd_blocker_del() {
|
|||||||
IFS='.' read -r a b c d <<< "$ip"
|
IFS='.' read -r a b c d <<< "$ip"
|
||||||
local key_hex
|
local key_hex
|
||||||
key_hex=$(printf '%02x 00 00 00 %02x %02x %02x %02x' "$prefix" "$a" "$b" "$c" "$d")
|
key_hex=$(printf '%02x 00 00 00 %02x %02x %02x %02x' "$prefix" "$a" "$b" "$c" "$d")
|
||||||
|
[[ "$key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
|
||||||
|
|
||||||
bpftool map delete id "$map_id" key hex $key_hex 2>/dev/null && log_ok "Removed: $cidr"
|
bpftool map delete id "$map_id" key hex $key_hex 2>/dev/null && log_ok "Removed: $cidr"
|
||||||
local tmpfile="${BLOCKLIST_FILE}.tmp.$$"
|
local tmpfile="${BLOCKLIST_FILE}.tmp.$$"
|
||||||
@@ -461,12 +465,13 @@ cmd_whitelist_add() {
|
|||||||
local map_name key_hex
|
local map_name key_hex
|
||||||
if [[ "$name" == *":"* ]]; then
|
if [[ "$name" == *":"* ]]; then
|
||||||
map_name="whitelist_v6"
|
map_name="whitelist_v6"
|
||||||
key_hex=$(python3 -c "import sys; from ${COMMON_PY} import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$name" 2>/dev/null)
|
key_hex=$(python3 -c "import sys; from xdp_common import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$name" 2>/dev/null)
|
||||||
else
|
else
|
||||||
map_name="whitelist_v4"
|
map_name="whitelist_v4"
|
||||||
key_hex=$(python3 -c "import sys; from ${COMMON_PY} import cidr_to_key; print(cidr_to_key(sys.argv[1]))" "$name" 2>/dev/null)
|
key_hex=$(python3 -c "import sys; from xdp_common import cidr_to_key; print(cidr_to_key(sys.argv[1]))" "$name" 2>/dev/null)
|
||||||
fi
|
fi
|
||||||
[ -z "$key_hex" ] && { log_err "Invalid CIDR: $name"; exit 1; }
|
[ -z "$key_hex" ] && { log_err "Invalid CIDR: $name"; exit 1; }
|
||||||
|
[[ "$key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
|
||||||
|
|
||||||
local map_id
|
local map_id
|
||||||
map_id=$(get_map_id "$map_name")
|
map_id=$(get_map_id "$map_name")
|
||||||
@@ -496,12 +501,13 @@ cmd_whitelist_del() {
|
|||||||
local map_name key_hex
|
local map_name key_hex
|
||||||
if [[ "$name" == *":"* ]]; then
|
if [[ "$name" == *":"* ]]; then
|
||||||
map_name="whitelist_v6"
|
map_name="whitelist_v6"
|
||||||
key_hex=$(python3 -c "import sys; from ${COMMON_PY} import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$name" 2>/dev/null)
|
key_hex=$(python3 -c "import sys; from xdp_common import cidr_to_key_v6; print(cidr_to_key_v6(sys.argv[1]))" "$name" 2>/dev/null)
|
||||||
else
|
else
|
||||||
map_name="whitelist_v4"
|
map_name="whitelist_v4"
|
||||||
key_hex=$(python3 -c "import sys; from ${COMMON_PY} import cidr_to_key; print(cidr_to_key(sys.argv[1]))" "$name" 2>/dev/null)
|
key_hex=$(python3 -c "import sys; from xdp_common import cidr_to_key; print(cidr_to_key(sys.argv[1]))" "$name" 2>/dev/null)
|
||||||
fi
|
fi
|
||||||
[ -z "$key_hex" ] && { log_err "Invalid CIDR: $name"; exit 1; }
|
[ -z "$key_hex" ] && { log_err "Invalid CIDR: $name"; exit 1; }
|
||||||
|
[[ "$key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
|
||||||
|
|
||||||
local map_id
|
local map_id
|
||||||
map_id=$(get_map_id "$map_name")
|
map_id=$(get_map_id "$map_name")
|
||||||
@@ -532,7 +538,7 @@ cmd_whitelist_list() {
|
|||||||
cmd_ddos_stats() {
|
cmd_ddos_stats() {
|
||||||
echo -e "${BOLD}=== DDoS Statistics ===${NC}"
|
echo -e "${BOLD}=== DDoS Statistics ===${NC}"
|
||||||
python3 -c "
|
python3 -c "
|
||||||
from ${COMMON_PY} import read_percpu_stats
|
from xdp_common import read_percpu_stats
|
||||||
stats = read_percpu_stats('global_stats', 5)
|
stats = read_percpu_stats('global_stats', 5)
|
||||||
labels = ['Passed', 'Dropped (blocked)', 'Dropped (rate)', 'Total', 'Errors']
|
labels = ['Passed', 'Dropped (blocked)', 'Dropped (rate)', 'Total', 'Errors']
|
||||||
for i, label in enumerate(labels):
|
for i, label in enumerate(labels):
|
||||||
@@ -547,7 +553,7 @@ cmd_ddos_top() {
|
|||||||
echo -e "${BOLD}=== Top $n IPs by Packet Count ===${NC}"
|
echo -e "${BOLD}=== Top $n IPs by Packet Count ===${NC}"
|
||||||
python3 -c "
|
python3 -c "
|
||||||
import sys
|
import sys
|
||||||
from ${COMMON_PY} import dump_rate_counters
|
from xdp_common import dump_rate_counters
|
||||||
entries = dump_rate_counters('rate_counter_v4', int(sys.argv[1]))
|
entries = dump_rate_counters('rate_counter_v4', int(sys.argv[1]))
|
||||||
if not entries:
|
if not entries:
|
||||||
print(' (empty)')
|
print(' (empty)')
|
||||||
@@ -569,7 +575,7 @@ if entries6:
|
|||||||
cmd_ddos_blocked() {
|
cmd_ddos_blocked() {
|
||||||
echo -e "${BOLD}=== Blocked IPs ===${NC}"
|
echo -e "${BOLD}=== Blocked IPs ===${NC}"
|
||||||
python3 -c "
|
python3 -c "
|
||||||
from ${COMMON_PY} import dump_blocked_ips
|
from xdp_common import dump_blocked_ips
|
||||||
with open('/proc/uptime') as f:
|
with open('/proc/uptime') as f:
|
||||||
now_ns = int(float(f.read().split()[0]) * 1_000_000_000)
|
now_ns = int(float(f.read().split()[0]) * 1_000_000_000)
|
||||||
|
|
||||||
@@ -604,7 +610,7 @@ cmd_ddos_block() {
|
|||||||
[ -z "$ip" ] && { log_err "Usage: xdp-defense ddos block <ip> [duration_sec]"; exit 1; }
|
[ -z "$ip" ] && { log_err "Usage: xdp-defense ddos block <ip> [duration_sec]"; exit 1; }
|
||||||
[[ "$duration" =~ ^[0-9]+$ ]] || { log_err "Invalid duration: $duration"; exit 1; }
|
[[ "$duration" =~ ^[0-9]+$ ]] || { log_err "Invalid duration: $duration"; exit 1; }
|
||||||
|
|
||||||
python3 -c "import sys; from ${COMMON_PY} import block_ip; block_ip(sys.argv[1], int(sys.argv[2]))" "$ip" "$duration" 2>/dev/null || \
|
python3 -c "import sys; from xdp_common import block_ip; block_ip(sys.argv[1], int(sys.argv[2]))" "$ip" "$duration" 2>/dev/null || \
|
||||||
{ log_err "Failed to block $ip"; exit 1; }
|
{ log_err "Failed to block $ip"; exit 1; }
|
||||||
|
|
||||||
if [ "$duration" -gt 0 ] 2>/dev/null; then
|
if [ "$duration" -gt 0 ] 2>/dev/null; then
|
||||||
@@ -619,7 +625,7 @@ cmd_ddos_unblock() {
|
|||||||
ip=$(sanitize_input "$1") || exit 1
|
ip=$(sanitize_input "$1") || exit 1
|
||||||
[ -z "$ip" ] && { log_err "Usage: xdp-defense ddos unblock <ip>"; exit 1; }
|
[ -z "$ip" ] && { log_err "Usage: xdp-defense ddos unblock <ip>"; exit 1; }
|
||||||
|
|
||||||
python3 -c "import sys; from ${COMMON_PY} import unblock_ip; unblock_ip(sys.argv[1])" "$ip" 2>/dev/null || \
|
python3 -c "import sys; from xdp_common import unblock_ip; unblock_ip(sys.argv[1])" "$ip" 2>/dev/null || \
|
||||||
{ log_err "Failed to unblock $ip"; exit 1; }
|
{ log_err "Failed to unblock $ip"; exit 1; }
|
||||||
log_ok "Unblocked $ip"
|
log_ok "Unblocked $ip"
|
||||||
}
|
}
|
||||||
@@ -631,7 +637,7 @@ cmd_ddos_config() {
|
|||||||
echo -e "${BOLD}=== Rate Configuration ===${NC}"
|
echo -e "${BOLD}=== Rate Configuration ===${NC}"
|
||||||
echo -e "\n${CYAN}Active (BPF map):${NC}"
|
echo -e "\n${CYAN}Active (BPF map):${NC}"
|
||||||
python3 -c "
|
python3 -c "
|
||||||
from ${COMMON_PY} import read_rate_config
|
from xdp_common import read_rate_config
|
||||||
cfg = read_rate_config()
|
cfg = read_rate_config()
|
||||||
if cfg:
|
if cfg:
|
||||||
pps = cfg['pps_threshold']
|
pps = cfg['pps_threshold']
|
||||||
@@ -647,8 +653,8 @@ else:
|
|||||||
if [ -f "$CONFIG_FILE" ]; then
|
if [ -f "$CONFIG_FILE" ]; then
|
||||||
echo -e "\n${CYAN}Config file ($CONFIG_FILE):${NC}"
|
echo -e "\n${CYAN}Config file ($CONFIG_FILE):${NC}"
|
||||||
python3 -c "
|
python3 -c "
|
||||||
import yaml
|
import yaml, sys
|
||||||
with open('$CONFIG_FILE') as f:
|
with open(sys.argv[1]) as f:
|
||||||
cfg = yaml.safe_load(f)
|
cfg = yaml.safe_load(f)
|
||||||
rl = cfg.get('rate_limits', {})
|
rl = cfg.get('rate_limits', {})
|
||||||
print(f' Default PPS: {rl.get(\"default_pps\", \"N/A\")}')
|
print(f' Default PPS: {rl.get(\"default_pps\", \"N/A\")}')
|
||||||
@@ -662,7 +668,7 @@ if profiles:
|
|||||||
hours = p.get('hours', '')
|
hours = p.get('hours', '')
|
||||||
pps = p.get('pps', 'N/A')
|
pps = p.get('pps', 'N/A')
|
||||||
print(f' {name}: pps={pps}, hours={hours}')
|
print(f' {name}: pps={pps}, hours={hours}')
|
||||||
" 2>/dev/null
|
" "$CONFIG_FILE" 2>/dev/null
|
||||||
fi
|
fi
|
||||||
;;
|
;;
|
||||||
set)
|
set)
|
||||||
@@ -674,7 +680,7 @@ if profiles:
|
|||||||
|
|
||||||
python3 -c "
|
python3 -c "
|
||||||
import sys
|
import sys
|
||||||
from ${COMMON_PY} import read_rate_config, write_rate_config
|
from xdp_common import read_rate_config, write_rate_config
|
||||||
cfg = read_rate_config()
|
cfg = read_rate_config()
|
||||||
if not cfg:
|
if not cfg:
|
||||||
cfg = {'pps_threshold': 1000, 'bps_threshold': 0, 'window_ns': 1000000000}
|
cfg = {'pps_threshold': 1000, 'bps_threshold': 0, 'window_ns': 1000000000}
|
||||||
@@ -700,16 +706,16 @@ cmd_ddos_config_apply() {
|
|||||||
[ ! -f "$CONFIG_FILE" ] && return
|
[ ! -f "$CONFIG_FILE" ] && return
|
||||||
|
|
||||||
python3 -c "
|
python3 -c "
|
||||||
import yaml
|
import yaml, sys
|
||||||
from ${COMMON_PY} import write_rate_config
|
from xdp_common import write_rate_config
|
||||||
with open('$CONFIG_FILE') as f:
|
with open(sys.argv[1]) as f:
|
||||||
cfg = yaml.safe_load(f)
|
cfg = yaml.safe_load(f)
|
||||||
rl = cfg.get('rate_limits', {})
|
rl = cfg.get('rate_limits', {})
|
||||||
pps = rl.get('default_pps', 1000)
|
pps = rl.get('default_pps', 1000)
|
||||||
bps = rl.get('default_bps', 0)
|
bps = rl.get('default_bps', 0)
|
||||||
win = rl.get('window_sec', 1)
|
win = rl.get('window_sec', 1)
|
||||||
write_rate_config(pps, bps, win * 1000000000)
|
write_rate_config(pps, bps, win * 1000000000)
|
||||||
" 2>/dev/null || return 0
|
" "$CONFIG_FILE" 2>/dev/null || return 0
|
||||||
|
|
||||||
[ "$quiet" != "quiet" ] && log_ok "Config applied from $CONFIG_FILE" || true
|
[ "$quiet" != "quiet" ] && log_ok "Config applied from $CONFIG_FILE" || true
|
||||||
}
|
}
|
||||||
@@ -777,8 +783,8 @@ cmd_ai_status() {
|
|||||||
|
|
||||||
if [ -f "$CONFIG_FILE" ]; then
|
if [ -f "$CONFIG_FILE" ]; then
|
||||||
python3 -c "
|
python3 -c "
|
||||||
import yaml
|
import yaml, sys
|
||||||
with open('$CONFIG_FILE') as f:
|
with open(sys.argv[1]) as f:
|
||||||
cfg = yaml.safe_load(f)
|
cfg = yaml.safe_load(f)
|
||||||
ai = cfg.get('ai', {})
|
ai = cfg.get('ai', {})
|
||||||
enabled = ai.get('enabled', False)
|
enabled = ai.get('enabled', False)
|
||||||
@@ -786,7 +792,7 @@ if enabled:
|
|||||||
print(f'AI Detection: enabled ({ai.get(\"model_type\", \"IsolationForest\")})')
|
print(f'AI Detection: enabled ({ai.get(\"model_type\", \"IsolationForest\")})')
|
||||||
else:
|
else:
|
||||||
print('AI Detection: disabled')
|
print('AI Detection: disabled')
|
||||||
" 2>/dev/null
|
" "$CONFIG_FILE" 2>/dev/null
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -804,11 +810,11 @@ cmd_ai_retrain() {
|
|||||||
cmd_ai_traffic() {
|
cmd_ai_traffic() {
|
||||||
local db_file
|
local db_file
|
||||||
db_file=$(python3 -c "
|
db_file=$(python3 -c "
|
||||||
import yaml
|
import yaml, sys
|
||||||
with open('$CONFIG_FILE') as f:
|
with open(sys.argv[1]) as f:
|
||||||
cfg = yaml.safe_load(f)
|
cfg = yaml.safe_load(f)
|
||||||
print(cfg.get('ai',{}).get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db'))
|
print(cfg.get('ai',{}).get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db'))
|
||||||
" 2>/dev/null || echo "/var/lib/xdp-defense/traffic_log.db")
|
" "$CONFIG_FILE" 2>/dev/null || echo "/var/lib/xdp-defense/traffic_log.db")
|
||||||
|
|
||||||
[ ! -f "$db_file" ] && { log_err "Traffic log not found: $db_file"; exit 1; }
|
[ ! -f "$db_file" ] && { log_err "Traffic log not found: $db_file"; exit 1; }
|
||||||
|
|
||||||
@@ -872,7 +878,7 @@ conn.close()
|
|||||||
# Show next retrain time
|
# Show next retrain time
|
||||||
import yaml, os, time
|
import yaml, os, time
|
||||||
try:
|
try:
|
||||||
with open('$CONFIG_FILE') as f:
|
with open(sys.argv[2]) as f:
|
||||||
cfg = yaml.safe_load(f)
|
cfg = yaml.safe_load(f)
|
||||||
retrain_interval = cfg.get('ai',{}).get('retrain_interval', 86400)
|
retrain_interval = cfg.get('ai',{}).get('retrain_interval', 86400)
|
||||||
model_file = cfg.get('ai',{}).get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
|
model_file = cfg.get('ai',{}).get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
|
||||||
@@ -890,7 +896,7 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
print()
|
print()
|
||||||
" "$db_file"
|
" "$db_file" "$CONFIG_FILE"
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd_ai_log() {
|
cmd_ai_log() {
|
||||||
@@ -899,11 +905,11 @@ cmd_ai_log() {
|
|||||||
|
|
||||||
local db_file
|
local db_file
|
||||||
db_file=$(python3 -c "
|
db_file=$(python3 -c "
|
||||||
import yaml
|
import yaml, sys
|
||||||
with open('$CONFIG_FILE') as f:
|
with open(sys.argv[1]) as f:
|
||||||
cfg = yaml.safe_load(f)
|
cfg = yaml.safe_load(f)
|
||||||
print(cfg.get('ai',{}).get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db'))
|
print(cfg.get('ai',{}).get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db'))
|
||||||
" 2>/dev/null || echo "/var/lib/xdp-defense/traffic_log.db")
|
" "$CONFIG_FILE" 2>/dev/null || echo "/var/lib/xdp-defense/traffic_log.db")
|
||||||
|
|
||||||
[ ! -f "$db_file" ] && { log_err "Traffic log not found: $db_file"; exit 1; }
|
[ ! -f "$db_file" ] && { log_err "Traffic log not found: $db_file"; exit 1; }
|
||||||
|
|
||||||
@@ -1001,6 +1007,7 @@ cmd_geoip() {
|
|||||||
IFS='.' read -r a b c d <<< "$ip"
|
IFS='.' read -r a b c d <<< "$ip"
|
||||||
local key_hex
|
local key_hex
|
||||||
key_hex=$(printf '20 00 00 00 %02x %02x %02x %02x' "$a" "$b" "$c" "$d")
|
key_hex=$(printf '20 00 00 00 %02x %02x %02x %02x' "$a" "$b" "$c" "$d")
|
||||||
|
[[ "$key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
|
||||||
if bpftool map lookup id "$map_id" key hex $key_hex 2>/dev/null | grep -q "value"; then
|
if bpftool map lookup id "$map_id" key hex $key_hex 2>/dev/null | grep -q "value"; then
|
||||||
echo -e "Blocker: ${RED}BLOCKED${NC}"
|
echo -e "Blocker: ${RED}BLOCKED${NC}"
|
||||||
else
|
else
|
||||||
@@ -1013,7 +1020,8 @@ cmd_geoip() {
|
|||||||
ddos_map_id=$(get_map_id blocked_ips_v4)
|
ddos_map_id=$(get_map_id blocked_ips_v4)
|
||||||
if [ -n "$ddos_map_id" ]; then
|
if [ -n "$ddos_map_id" ]; then
|
||||||
local ddos_key_hex
|
local ddos_key_hex
|
||||||
ddos_key_hex=$(python3 -c "import sys; from ${COMMON_PY} import ip_to_hex_key; print(ip_to_hex_key(sys.argv[1]))" "$ip" 2>/dev/null)
|
ddos_key_hex=$(python3 -c "import sys; from xdp_common import ip_to_hex_key; print(ip_to_hex_key(sys.argv[1]))" "$ip" 2>/dev/null)
|
||||||
|
[[ "$ddos_key_hex" =~ ^[0-9a-f\ ]+$ ]] || { log_err "Invalid key hex"; exit 1; }
|
||||||
if [ -n "$ddos_key_hex" ] && bpftool map lookup id "$ddos_map_id" key hex $ddos_key_hex 2>/dev/null | grep -q "value"; then
|
if [ -n "$ddos_key_hex" ] && bpftool map lookup id "$ddos_map_id" key hex $ddos_key_hex 2>/dev/null | grep -q "value"; then
|
||||||
echo -e "DDoS: ${RED}BLOCKED${NC}"
|
echo -e "DDoS: ${RED}BLOCKED${NC}"
|
||||||
else
|
else
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ time-profile switching, and automatic escalation.
|
|||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import stat
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import signal
|
import signal
|
||||||
@@ -28,6 +29,16 @@ from datetime import datetime, timedelta
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
try:
|
||||||
|
from xdp_common import (
|
||||||
|
dump_rate_counters, block_ip, is_whitelisted,
|
||||||
|
read_percpu_features, dump_blocked_ips, unblock_ip,
|
||||||
|
write_rate_config, read_rate_config,
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"FATAL: Cannot import xdp_common: {e}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
# ==================== Logging ====================
|
# ==================== Logging ====================
|
||||||
|
|
||||||
log = logging.getLogger('xdp-defense-daemon')
|
log = logging.getLogger('xdp-defense-daemon')
|
||||||
@@ -256,8 +267,13 @@ class AIDetector:
|
|||||||
self._train()
|
self._train()
|
||||||
self._retrain_requested = False
|
self._retrain_requested = False
|
||||||
|
|
||||||
def _train(self):
|
def _train(self, period_data=None):
|
||||||
"""Train per-period Isolation Forest models."""
|
"""Train per-period Isolation Forest models.
|
||||||
|
If period_data provided, use it instead of self.training_data.
|
||||||
|
"""
|
||||||
|
if period_data is None:
|
||||||
|
period_data = self.training_data
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from sklearn.ensemble import IsolationForest
|
from sklearn.ensemble import IsolationForest
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
@@ -267,19 +283,19 @@ class AIDetector:
|
|||||||
self.cfg['enabled'] = False
|
self.cfg['enabled'] = False
|
||||||
return
|
return
|
||||||
|
|
||||||
total = sum(len(v) for v in self.training_data.values())
|
total = sum(len(v) for v in period_data.values())
|
||||||
if total < 10:
|
if total < 10:
|
||||||
log.warning("Not enough training data (%d samples)", total)
|
log.warning("Not enough training data (%d samples)", total)
|
||||||
return
|
return
|
||||||
|
|
||||||
log.info("Training AI models: %s",
|
log.info("Training AI models: %s",
|
||||||
{p: len(s) for p, s in self.training_data.items() if s})
|
{p: len(s) for p, s in period_data.items() if s})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
new_models = {}
|
new_models = {}
|
||||||
all_samples = []
|
all_samples = []
|
||||||
|
|
||||||
for period, samples in self.training_data.items():
|
for period, samples in period_data.items():
|
||||||
if len(samples) < 10:
|
if len(samples) < 10:
|
||||||
log.info("Period %s: %d samples (too few, skip)", period, len(samples))
|
log.info("Period %s: %d samples (too few, skip)", period, len(samples))
|
||||||
continue
|
continue
|
||||||
@@ -309,7 +325,8 @@ class AIDetector:
|
|||||||
# Save to disk (atomic)
|
# Save to disk (atomic)
|
||||||
model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
|
model_file = self.cfg.get('model_file', '/var/lib/xdp-defense/ai_model.pkl')
|
||||||
tmp_model = model_file + '.tmp'
|
tmp_model = model_file + '.tmp'
|
||||||
with open(tmp_model, 'wb') as f:
|
fd = os.open(tmp_model, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||||
|
with os.fdopen(fd, 'wb') as f:
|
||||||
pickle.dump({
|
pickle.dump({
|
||||||
'format': 'period_models',
|
'format': 'period_models',
|
||||||
'models': {p: {'model': m['model'], 'scaler': m['scaler']}
|
'models': {p: {'model': m['model'], 'scaler': m['scaler']}
|
||||||
@@ -323,7 +340,8 @@ class AIDetector:
|
|||||||
# Save training data CSV
|
# Save training data CSV
|
||||||
data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
|
data_file = self.cfg.get('training_data_file', '/var/lib/xdp-defense/training_data.csv')
|
||||||
tmp_data = data_file + '.tmp'
|
tmp_data = data_file + '.tmp'
|
||||||
with open(tmp_data, 'w', newline='') as f:
|
fd = os.open(tmp_data, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)
|
||||||
|
with os.fdopen(fd, 'w', newline='') as f:
|
||||||
writer = csv.writer(f)
|
writer = csv.writer(f)
|
||||||
writer.writerow([
|
writer.writerow([
|
||||||
'hour_sin', 'hour_cos',
|
'hour_sin', 'hour_cos',
|
||||||
@@ -348,6 +366,12 @@ class AIDetector:
|
|||||||
if not os.path.exists(model_file):
|
if not os.path.exists(model_file):
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
|
st = os.stat(model_file)
|
||||||
|
# Warn if file is world-writable
|
||||||
|
if st.st_mode & stat.S_IWOTH:
|
||||||
|
log.warning("Model file %s is world-writable! Refusing to load.", model_file)
|
||||||
|
return False
|
||||||
|
|
||||||
with open(model_file, 'rb') as f:
|
with open(model_file, 'rb') as f:
|
||||||
data = pickle.load(f)
|
data = pickle.load(f)
|
||||||
|
|
||||||
@@ -488,8 +512,7 @@ class AIDetector:
|
|||||||
len(rows), filtered_count,
|
len(rows), filtered_count,
|
||||||
{p: len(s) for p, s in period_data.items() if s})
|
{p: len(s) for p, s in period_data.items() if s})
|
||||||
|
|
||||||
self.training_data = period_data
|
self._train(period_data=period_data)
|
||||||
self._train()
|
|
||||||
return not self.is_learning
|
return not self.is_learning
|
||||||
|
|
||||||
|
|
||||||
@@ -504,8 +527,6 @@ class ProfileManager:
|
|||||||
|
|
||||||
def check_and_apply(self):
|
def check_and_apply(self):
|
||||||
"""Check current time and apply matching profile."""
|
"""Check current time and apply matching profile."""
|
||||||
from xdp_common import write_rate_config
|
|
||||||
|
|
||||||
profiles = self.cfg.get('profiles', {})
|
profiles = self.cfg.get('profiles', {})
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
current_hour = now.hour
|
current_hour = now.hour
|
||||||
@@ -648,17 +669,25 @@ class DDoSDaemon:
|
|||||||
|
|
||||||
def _handle_sighup(self, signum, frame):
|
def _handle_sighup(self, signum, frame):
|
||||||
log.info("SIGHUP received, reloading config...")
|
log.info("SIGHUP received, reloading config...")
|
||||||
self.cfg = load_config(self.config_path)
|
new_cfg = load_config(self.config_path)
|
||||||
# Update existing components without rebuilding (preserves EWMA/violation state)
|
# Build all new values before swapping anything
|
||||||
self.violation_tracker.cfg = self.cfg['escalation']
|
new_escalation = new_cfg['escalation']
|
||||||
self.ewma_analyzer.alpha = self.cfg['ewma'].get('alpha', 0.3)
|
new_alpha = new_cfg['ewma'].get('alpha', 0.3)
|
||||||
self.ewma_analyzer.threshold_multiplier = self.cfg['ewma'].get('threshold_multiplier', 3.0)
|
new_threshold = new_cfg['ewma'].get('threshold_multiplier', 3.0)
|
||||||
self.ai_detector.cfg = self.cfg['ai']
|
new_ai_cfg = new_cfg['ai']
|
||||||
self.profile_manager.cfg = self.cfg['rate_limits']
|
new_rate_cfg = new_cfg['rate_limits']
|
||||||
# Update poll intervals (used by threads on next iteration)
|
new_ewma_interval = new_cfg['ewma'].get('poll_interval', 1)
|
||||||
self._ewma_interval = self.cfg['ewma'].get('poll_interval', 1)
|
new_ai_interval = new_cfg['ai'].get('poll_interval', 5)
|
||||||
self._ai_interval = self.cfg['ai'].get('poll_interval', 5)
|
level = new_cfg['general'].get('log_level', 'info').upper()
|
||||||
level = self.cfg['general'].get('log_level', 'info').upper()
|
# Now apply all at once
|
||||||
|
self.cfg = new_cfg
|
||||||
|
self.violation_tracker.cfg = new_escalation
|
||||||
|
self.ewma_analyzer.alpha = new_alpha
|
||||||
|
self.ewma_analyzer.threshold_multiplier = new_threshold
|
||||||
|
self.ai_detector.cfg = new_ai_cfg
|
||||||
|
self.profile_manager.cfg = new_rate_cfg
|
||||||
|
self._ewma_interval = new_ewma_interval
|
||||||
|
self._ai_interval = new_ai_interval
|
||||||
log.setLevel(getattr(logging, level, logging.INFO))
|
log.setLevel(getattr(logging, level, logging.INFO))
|
||||||
log.info("Config reloaded (state preserved)")
|
log.info("Config reloaded (state preserved)")
|
||||||
|
|
||||||
@@ -678,6 +707,7 @@ class DDoSDaemon:
|
|||||||
"""Initialize SQLite database for traffic logging."""
|
"""Initialize SQLite database for traffic logging."""
|
||||||
db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
|
db_path = self.cfg['ai'].get('traffic_log_db', '/var/lib/xdp-defense/traffic_log.db')
|
||||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||||
|
self._db_lock = threading.Lock()
|
||||||
self._traffic_db = sqlite3.connect(db_path, check_same_thread=False)
|
self._traffic_db = sqlite3.connect(db_path, check_same_thread=False)
|
||||||
self._traffic_db.execute(
|
self._traffic_db.execute(
|
||||||
'CREATE TABLE IF NOT EXISTS traffic_samples ('
|
'CREATE TABLE IF NOT EXISTS traffic_samples ('
|
||||||
@@ -712,6 +742,7 @@ class DDoSDaemon:
|
|||||||
def _log_traffic(self, now, hour, features):
|
def _log_traffic(self, now, hour, features):
|
||||||
"""Insert one row into traffic_samples table."""
|
"""Insert one row into traffic_samples table."""
|
||||||
try:
|
try:
|
||||||
|
with self._db_lock:
|
||||||
self._traffic_db.execute(
|
self._traffic_db.execute(
|
||||||
'INSERT INTO traffic_samples ('
|
'INSERT INTO traffic_samples ('
|
||||||
' timestamp, hour, hour_sin, hour_cos,'
|
' timestamp, hour, hour_sin, hour_cos,'
|
||||||
@@ -732,6 +763,7 @@ class DDoSDaemon:
|
|||||||
cutoff = (datetime.now() - timedelta(days=retention_days)).isoformat()
|
cutoff = (datetime.now() - timedelta(days=retention_days)).isoformat()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
with self._db_lock:
|
||||||
cur = self._traffic_db.execute(
|
cur = self._traffic_db.execute(
|
||||||
'DELETE FROM traffic_samples WHERE timestamp < ?', (cutoff,)
|
'DELETE FROM traffic_samples WHERE timestamp < ?', (cutoff,)
|
||||||
)
|
)
|
||||||
@@ -755,8 +787,6 @@ class DDoSDaemon:
|
|||||||
|
|
||||||
def _ewma_thread(self):
|
def _ewma_thread(self):
|
||||||
"""Poll rate counters, compute EWMA, detect violations, escalate."""
|
"""Poll rate counters, compute EWMA, detect violations, escalate."""
|
||||||
from xdp_common import dump_rate_counters, block_ip, is_whitelisted
|
|
||||||
|
|
||||||
prev_counters = {}
|
prev_counters = {}
|
||||||
|
|
||||||
while not self._stop_event.is_set():
|
while not self._stop_event.is_set():
|
||||||
@@ -815,8 +845,6 @@ class DDoSDaemon:
|
|||||||
|
|
||||||
def _ai_thread(self):
|
def _ai_thread(self):
|
||||||
"""Read traffic features, run AI inference or collect training data."""
|
"""Read traffic features, run AI inference or collect training data."""
|
||||||
from xdp_common import read_percpu_features, dump_rate_counters, block_ip, is_whitelisted
|
|
||||||
|
|
||||||
prev_features = None
|
prev_features = None
|
||||||
self._last_retrain_time = self._get_model_mtime()
|
self._last_retrain_time = self._get_model_mtime()
|
||||||
self._last_log_cleanup = time.time()
|
self._last_log_cleanup = time.time()
|
||||||
@@ -938,8 +966,6 @@ class DDoSDaemon:
|
|||||||
|
|
||||||
def _cleanup_thread(self):
|
def _cleanup_thread(self):
|
||||||
"""Periodically clean up expired blocked IPs and stale violations."""
|
"""Periodically clean up expired blocked IPs and stale violations."""
|
||||||
from xdp_common import dump_blocked_ips, unblock_ip
|
|
||||||
|
|
||||||
while not self._stop_event.is_set():
|
while not self._stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
with open('/proc/uptime') as f:
|
with open('/proc/uptime') as f:
|
||||||
@@ -1003,6 +1029,11 @@ class DDoSDaemon:
|
|||||||
t.join(timeout=5)
|
t.join(timeout=5)
|
||||||
|
|
||||||
self._remove_pid()
|
self._remove_pid()
|
||||||
|
if hasattr(self, '_traffic_db') and self._traffic_db:
|
||||||
|
try:
|
||||||
|
self._traffic_db.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
log.info("Daemon stopped")
|
log.info("Daemon stopped")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user