Fix 12 code review issues (4 MEDIUM + 8 LOW)
MEDIUM: - M1: Whitelist direct IP/CIDR additions now persist to direct.txt - M2: get_map_id() uses 5s TTL cache (single bpftool call for all maps) - M3: IPv6 extension header parsing in xdp_ddos.c (hop-by-hop/routing/frag/dst) - M4: Shell injection prevention - sanitize_input() + sys.argv[] for all Python calls LOW: - L1: Remove redundant self.running (uses _stop_event only) - L2: Remove unused config values (rate_limit_after, cooldown_multiplier, retrain_interval) - L3: Thread poll intervals reloaded on SIGHUP - L4: batch_map_operation counts only successfully written entries - L5: Clarify unique_ips_approx comment (per-packet counter) - L6: Document LRU_HASH multi-CPU race condition as acceptable - L7: Download Cloudflare IPv6 ranges in whitelist preset - L8: Fix file handle leak in xdp_country.py list_countries() Also: SIGHUP now preserves EWMA/violation state, daemon skips whitelisted IPs in EWMA/AI escalation, deep copy for default config, IHL validation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -32,8 +32,18 @@ if _syslog_handler:
|
||||
|
||||
# ==================== BPF Map Helpers ====================
|
||||
|
||||
_map_cache = {} # {map_name: (map_id, timestamp)}
|
||||
_MAP_CACHE_TTL = 5.0 # seconds
|
||||
|
||||
def get_map_id(map_name):
|
||||
"""Get BPF map ID by name."""
|
||||
"""Get BPF map ID by name (cached with 5s TTL)."""
|
||||
import time as _time
|
||||
now = _time.monotonic()
|
||||
|
||||
cached = _map_cache.get(map_name)
|
||||
if cached and (now - cached[1]) < _MAP_CACHE_TTL:
|
||||
return cached[0]
|
||||
|
||||
result = subprocess.run(
|
||||
["bpftool", "map", "show", "-j"],
|
||||
capture_output=True, text=True
|
||||
@@ -42,12 +52,15 @@ def get_map_id(map_name):
|
||||
return None
|
||||
try:
|
||||
maps = json.loads(result.stdout)
|
||||
# Update cache for all maps found
|
||||
for m in maps:
|
||||
if m.get("name") == map_name:
|
||||
return m.get("id")
|
||||
name = m.get("name")
|
||||
if name:
|
||||
_map_cache[name] = (m.get("id"), now)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
cached = _map_cache.get(map_name)
|
||||
return cached[0] if cached else None
|
||||
|
||||
|
||||
# ==================== CIDR / IPv4 Helpers (from blocker) ====================
|
||||
@@ -130,10 +143,12 @@ def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 0
|
||||
"""
|
||||
total = len(cidrs)
|
||||
processed = 0
|
||||
written = 0
|
||||
|
||||
for i in range(0, total, batch_size):
|
||||
batch = cidrs[i:i + batch_size]
|
||||
batch_file = None
|
||||
batch_written = 0
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.batch', delete=False) as f:
|
||||
@@ -145,6 +160,7 @@ def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 0
|
||||
f.write(f"map update id {map_id} key hex {key_hex} value hex {value_hex}\n")
|
||||
else:
|
||||
f.write(f"map delete id {map_id} key hex {key_hex}\n")
|
||||
batch_written += 1
|
||||
except (ValueError, Exception):
|
||||
continue
|
||||
|
||||
@@ -157,13 +173,44 @@ def batch_map_operation(map_id, cidrs, operation="update", value_hex="01 00 00 0
|
||||
os.unlink(batch_file)
|
||||
|
||||
processed += len(batch)
|
||||
written += batch_written
|
||||
pct = processed * 100 // total if total > 0 else 100
|
||||
print(f"\r Progress: {processed}/{total} ({pct}%)", end="", flush=True)
|
||||
|
||||
print()
|
||||
proto = "v6" if ipv6 else "v4"
|
||||
audit_log.info(f"batch {operation} {proto}: {processed} entries on map {map_id}")
|
||||
return processed
|
||||
audit_log.info(f"batch {operation} {proto}: {written} entries on map {map_id}")
|
||||
return written
|
||||
|
||||
|
||||
# ==================== Whitelist Check (for daemon) ====================
|
||||
|
||||
def is_whitelisted(ip_str):
|
||||
"""Check if an IP is in the BPF whitelist maps.
|
||||
Returns True if whitelisted, False otherwise.
|
||||
"""
|
||||
try:
|
||||
addr = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if isinstance(addr, ipaddress.IPv6Address):
|
||||
map_name = "whitelist_v6"
|
||||
key_hex = cidr_to_key_v6(f"{ip_str}/128")
|
||||
else:
|
||||
map_name = "whitelist_v4"
|
||||
key_hex = cidr_to_key(f"{ip_str}/32")
|
||||
|
||||
map_id = get_map_id(map_name)
|
||||
if map_id is None:
|
||||
return False
|
||||
|
||||
result = subprocess.run(
|
||||
["bpftool", "map", "lookup", "id", str(map_id),
|
||||
"key", "hex"] + key_hex.split(),
|
||||
capture_output=True, text=True
|
||||
)
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
# ==================== IP Encoding Helpers (from ddos) ====================
|
||||
@@ -360,16 +407,17 @@ def block_ip(ip_str, duration_sec=0):
|
||||
|
||||
key_hex = ip_to_hex_key(ip_str)
|
||||
|
||||
# Use CLOCK_BOOTTIME (matches BPF ktime_get_ns)
|
||||
with open('/proc/uptime', 'r') as f:
|
||||
uptime_sec = float(f.read().split()[0])
|
||||
now_ns = int(uptime_sec * 1_000_000_000)
|
||||
|
||||
if duration_sec > 0:
|
||||
with open('/proc/uptime', 'r') as f:
|
||||
uptime_sec = float(f.read().split()[0])
|
||||
now_ns = int(uptime_sec * 1_000_000_000)
|
||||
expire_ns = now_ns + (duration_sec * 1_000_000_000)
|
||||
else:
|
||||
expire_ns = 0
|
||||
|
||||
now_ns_val = 0
|
||||
raw = struct.pack('<QQQ', expire_ns, now_ns_val, 0)
|
||||
raw = struct.pack('<QQQ', expire_ns, now_ns, 0)
|
||||
val_hex = ' '.join(f"{b:02x}" for b in raw)
|
||||
|
||||
result = subprocess.run(
|
||||
|
||||
Reference in New Issue
Block a user