Fix 12 code review issues (4 MEDIUM + 8 LOW)

MEDIUM:
- M1: Whitelist direct IP/CIDR additions now persist to direct.txt
- M2: get_map_id() uses 5s TTL cache (single bpftool call for all maps)
- M3: IPv6 extension header parsing in xdp_ddos.c (hop-by-hop/routing/frag/dst)
- M4: Shell injection prevention - sanitize_input() + sys.argv[] for all Python calls

LOW:
- L1: Remove redundant self.running (uses _stop_event only)
- L2: Remove unused config values (rate_limit_after, cooldown_multiplier, retrain_interval)
- L3: Thread poll intervals reloaded on SIGHUP
- L4: batch_map_operation counts only successfully written entries
- L5: Clarify unique_ips_approx comment (per-packet counter)
- L6: Document LRU_HASH multi-CPU race condition as acceptable
- L7: Download Cloudflare IPv6 ranges in whitelist preset
- L8: Fix file handle leak in xdp_country.py list_countries()

Also: SIGHUP now preserves EWMA/violation state, daemon skips whitelisted
IPs in EWMA/AI escalation, deep copy for default config, IHL validation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
kaffa
2026-02-07 09:23:41 +09:00
parent dbfcb62cdf
commit 667c6eac81
7 changed files with 218 additions and 67 deletions

View File

@@ -32,8 +32,18 @@ if _syslog_handler:
# ==================== BPF Map Helpers ====================
_map_cache = {} # {map_name: (map_id, timestamp)}
_MAP_CACHE_TTL = 5.0 # seconds
def get_map_id(map_name):
"""Get BPF map ID by name."""
"""Get BPF map ID by name (cached with 5s TTL)."""
import time as _time
now = _time.monotonic()
cached = _map_cache.get(map_name)
if cached and (now - cached[1]) < _MAP_CACHE_TTL:
return cached[0]
result = subprocess.run(
["bpftool", "map", "show", "-j"],
capture_output=True, text=True
@@ -42,12 +52,15 @@ def get_map_id(map_name):
return None
try:
maps = json.loads(result.stdout)
# Update cache for all maps found
for m in maps:
if m.get("name") == map_name:
return m.get("id")
name = m.get("name")
if name:
_map_cache[name] = (m.get("id"), now)
except Exception:
pass
return None
cached = _map_cache.get(map_name)
return cached[0] if cached else None
# ==================== CIDR / IPv4 Helpers (from blocker) ====================
@@ -130,10 +143,12 @@ def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 0
"""
total = len(cidrs)
processed = 0
written = 0
for i in range(0, total, batch_size):
batch = cidrs[i:i + batch_size]
batch_file = None
batch_written = 0
try:
with tempfile.NamedTemporaryFile(mode='w', suffix='.batch', delete=False) as f:
@@ -145,6 +160,7 @@ def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 0
f.write(f"map update id {map_id} key hex {key_hex} value hex {value_hex}\n")
else:
f.write(f"map delete id {map_id} key hex {key_hex}\n")
batch_written += 1
except (ValueError, Exception):
continue
@@ -157,13 +173,44 @@ def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 0
os.unlink(batch_file)
processed += len(batch)
written += batch_written
pct = processed * 100 // total if total > 0 else 100
print(f"\r Progress: {processed}/{total} ({pct}%)", end="", flush=True)
print()
proto = "v6" if ipv6 else "v4"
audit_log.info(f"batch {operation} {proto}: {processed} entries on map {map_id}")
return processed
audit_log.info(f"batch {operation} {proto}: {written} entries on map {map_id}")
return written
# ==================== Whitelist Check (for daemon) ====================
def is_whitelisted(ip_str):
"""Check if an IP is in the BPF whitelist maps.
Returns True if whitelisted, False otherwise.
"""
try:
addr = ipaddress.ip_address(ip_str)
except ValueError:
return False
if isinstance(addr, ipaddress.IPv6Address):
map_name = "whitelist_v6"
key_hex = cidr_to_key_v6(f"{ip_str}/128")
else:
map_name = "whitelist_v4"
key_hex = cidr_to_key(f"{ip_str}/32")
map_id = get_map_id(map_name)
if map_id is None:
return False
result = subprocess.run(
["bpftool", "map", "lookup", "id", str(map_id),
"key", "hex"] + key_hex.split(),
capture_output=True, text=True
)
return result.returncode == 0
# ==================== IP Encoding Helpers (from ddos) ====================
@@ -360,16 +407,17 @@ def block_ip(ip_str, duration_sec=0):
key_hex = ip_to_hex_key(ip_str)
# Use CLOCK_BOOTTIME (matches BPF ktime_get_ns)
with open('/proc/uptime', 'r') as f:
uptime_sec = float(f.read().split()[0])
now_ns = int(uptime_sec * 1_000_000_000)
if duration_sec > 0:
with open('/proc/uptime', 'r') as f:
uptime_sec = float(f.read().split()[0])
now_ns = int(uptime_sec * 1_000_000_000)
expire_ns = now_ns + (duration_sec * 1_000_000_000)
else:
expire_ns = 0
now_ns_val = 0
raw = struct.pack('<QQQ', expire_ns, now_ns_val, 0)
raw = struct.pack('<QQQ', expire_ns, now_ns, 0)
val_hex = ' '.join(f"{b:02x}" for b in raw)
result = subprocess.run(