feat(dns): persist tunneling burst state across restarts
Switch burst deque from monotonic() to time.time() (wall-clock, serializable).
Add DNS_STATE_PATH env var: on startup _load_state() reads {src:[ts,...]} JSON
and prunes entries older than the burst window. _flush_state() write-then-renames
atomically; _state_flusher() coroutine flushes every 5s when dirty. Detection of
the 5th event also triggers an immediate flush. No-op when DNS_STATE_PATH is
unset, so the default deployment is unchanged.
This commit is contained in:
@@ -20,6 +20,7 @@ event_type values emitted:
|
||||
import asyncio
|
||||
import collections
|
||||
import hashlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import socket
|
||||
@@ -41,6 +42,7 @@ _AUTHORS = os.environ.get("DNS_AUTHORS", "BIND9 Developers")
|
||||
_NSID_RAW = os.environ.get("DNS_NSID", "")
|
||||
_EXTRA_RAW = os.environ.get("DNS_EXTRA_RECORDS", "")
|
||||
REAL_RECURSIVE = os.environ.get("DNS_REAL_RECURSIVE", "").lower() in ("1", "true", "yes")
|
||||
_STATE_PATH = os.environ.get("DNS_STATE_PATH", "")
|
||||
|
||||
_upstream_raw = os.environ.get("DNS_UPSTREAM", "8.8.8.8:53")
|
||||
try:
|
||||
@@ -379,8 +381,9 @@ _FORWARD_BUDGET_WIN = float(os.environ.get("DNS_FORWARD_WINDOW", "1.0"))
|
||||
|
||||
# ── Per-src state ─────────────────────────────────────────────────────────────
|
||||
|
||||
# Tunneling: src_ip -> deque of recent timestamps (all _TUNNEL_QTYPES counted)
|
||||
# Tunneling: src_ip -> deque of recent wall-clock timestamps (all _TUNNEL_QTYPES counted)
|
||||
_tunnel_times: collections.OrderedDict[str, collections.deque] = collections.OrderedDict()
|
||||
_state_dirty = False
|
||||
|
||||
# Flood: src_ip -> deque of recent query timestamps
|
||||
_qps_window: collections.OrderedDict[str, collections.deque] = collections.OrderedDict()
|
||||
@@ -429,6 +432,53 @@ def _track_lru(table: collections.OrderedDict, key: str, tracker_name: str) -> N
|
||||
table.popitem(last=False)
|
||||
_note_eviction(tracker_name)
|
||||
|
||||
|
||||
def _load_state() -> None:
|
||||
"""Populate _tunnel_times from the state file on startup, pruning stale entries."""
|
||||
if not _STATE_PATH:
|
||||
return
|
||||
try:
|
||||
with open(_STATE_PATH) as fh:
|
||||
data: dict[str, list[float]] = json.load(fh)
|
||||
except FileNotFoundError:
|
||||
return
|
||||
except Exception:
|
||||
_log("startup", severity=5, msg="dns state file unreadable, starting fresh")
|
||||
return
|
||||
now = time.time()
|
||||
cutoff = now - _TXT_BURST_WINDOW
|
||||
for src, timestamps in data.items():
|
||||
recent = [t for t in timestamps if t > cutoff]
|
||||
if recent:
|
||||
_tunnel_times[src] = collections.deque(recent)
|
||||
|
||||
|
||||
def _flush_state() -> None:
|
||||
"""Write _tunnel_times atomically to the state file if _STATE_PATH is set."""
|
||||
global _state_dirty
|
||||
if not _STATE_PATH:
|
||||
return
|
||||
tmp = _STATE_PATH + ".tmp"
|
||||
try:
|
||||
serialized = {src: list(q) for src, q in _tunnel_times.items()}
|
||||
with open(tmp, "w") as fh:
|
||||
json.dump(serialized, fh)
|
||||
os.replace(tmp, _STATE_PATH)
|
||||
_state_dirty = False
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def _state_flusher() -> None:
|
||||
"""Periodically flush _tunnel_times to disk when dirty."""
|
||||
while True:
|
||||
await asyncio.sleep(5.0)
|
||||
if _state_dirty:
|
||||
_flush_state()
|
||||
|
||||
|
||||
_load_state()
|
||||
|
||||
# ── Tunneling heuristic ───────────────────────────────────────────────────────
|
||||
|
||||
def _shannon_entropy(s: str) -> float:
|
||||
@@ -459,7 +509,8 @@ def _is_tunneling(qname: str, qtype: int, src: str) -> str | None:
|
||||
):
|
||||
return "qname_entropy"
|
||||
if qtype in _TUNNEL_QTYPES:
|
||||
now = time.monotonic()
|
||||
global _state_dirty
|
||||
now = time.time()
|
||||
if src not in _tunnel_times:
|
||||
_tunnel_times[src] = collections.deque()
|
||||
_track_lru(_tunnel_times, src, "tunnel_times")
|
||||
@@ -467,7 +518,9 @@ def _is_tunneling(qname: str, qtype: int, src: str) -> str | None:
|
||||
q.append(now)
|
||||
while q and now - q[0] > _TXT_BURST_WINDOW:
|
||||
q.popleft()
|
||||
_state_dirty = True
|
||||
if len(q) >= _TXT_BURST_COUNT:
|
||||
_flush_state()
|
||||
return "burst"
|
||||
return None
|
||||
|
||||
@@ -910,9 +963,11 @@ async def main() -> None:
|
||||
tcp_server = await asyncio.start_server(
|
||||
_tcp_session, "0.0.0.0", 53 # nosec B104
|
||||
)
|
||||
flusher = asyncio.ensure_future(_state_flusher())
|
||||
try:
|
||||
await asyncio.sleep(float("inf"))
|
||||
finally:
|
||||
flusher.cancel()
|
||||
udp_transport.close()
|
||||
tcp_server.close()
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import importlib.util
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
from types import ModuleType
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -76,6 +77,9 @@ def _load_dns(extra_env: dict | None = None):
|
||||
mod._flood_cooldown.clear()
|
||||
mod._recon_window.clear()
|
||||
mod._recon_cooldown.clear()
|
||||
# Re-load tunnel state from file if path is set (clear above wiped it)
|
||||
if env.get("DNS_STATE_PATH"):
|
||||
mod._load_state()
|
||||
|
||||
return mod, bridge._events
|
||||
|
||||
@@ -475,6 +479,59 @@ class TestTunnelingHeuristic:
|
||||
mod._handle(_build_query(f"n{i}.test.local", mod.TYPE_NULL), src, 1234, "udp")
|
||||
assert _events_of(events, "tunneling_suspect")
|
||||
|
||||
# ── State persistence ─────────────────────────────────────────────────────────
|
||||
|
||||
class TestStatePersistence:
|
||||
def test_state_path_unset_no_file_written(self, tmp_path):
|
||||
"""With DNS_STATE_PATH unset, no file should be written after burst detection."""
|
||||
mod, events = _load_dns()
|
||||
src = "6.6.6.1"
|
||||
for i in range(mod._TXT_BURST_COUNT):
|
||||
mod._handle(_build_query(f"b{i}.test.local", mod.TYPE_TXT), src, 1234, "udp")
|
||||
assert _events_of(events, "tunneling_suspect")
|
||||
# No state file should have appeared anywhere under tmp_path
|
||||
assert list(tmp_path.iterdir()) == []
|
||||
|
||||
def test_state_persists_across_reload(self, tmp_path):
|
||||
"""4 burst queries → flush → fresh module load → 5th query trips the burst."""
|
||||
state_file = str(tmp_path / "dns_state.json")
|
||||
mod, events = _load_dns({"DNS_STATE_PATH": state_file})
|
||||
src = "6.6.6.2"
|
||||
# Send _TXT_BURST_COUNT - 1 queries (not enough to trigger on their own)
|
||||
for i in range(mod._TXT_BURST_COUNT - 1):
|
||||
mod._handle(_build_query(f"b{i}.test.local", mod.TYPE_TXT), src, 1234, "udp")
|
||||
# No burst yet
|
||||
assert not _events_of(events, "tunneling_suspect")
|
||||
# Manually flush state to disk
|
||||
mod._flush_state()
|
||||
assert tmp_path.joinpath("dns_state.json").exists()
|
||||
# Reload module (simulates container restart) — same env so it reads the file
|
||||
mod2, events2 = _load_dns({"DNS_STATE_PATH": state_file})
|
||||
# The reloaded module should have populated _tunnel_times from the file
|
||||
assert src in mod2._tunnel_times
|
||||
# 5th query → burst fires
|
||||
mod2._handle(_build_query("b5.test.local", mod2.TYPE_TXT), src, 1234, "udp")
|
||||
assert _events_of(events2, "tunneling_suspect"), "5th query after restart must fire"
|
||||
|
||||
def test_state_file_prunes_old_entries_on_load(self, tmp_path):
|
||||
"""Entries older than _TXT_BURST_WINDOW are pruned on startup, not loaded."""
|
||||
import json as _json
|
||||
state_file = tmp_path / "dns_state.json"
|
||||
old_ts = time.time() - 9999 # way outside window
|
||||
state_file.write_text(_json.dumps({"1.2.3.4": [old_ts, old_ts]}))
|
||||
mod, _ = _load_dns({"DNS_STATE_PATH": str(state_file)})
|
||||
assert "1.2.3.4" not in mod._tunnel_times
|
||||
|
||||
def test_state_file_corrupt_json_tolerated(self, tmp_path):
|
||||
"""A corrupt state file must not crash the module; state starts empty."""
|
||||
state_file = tmp_path / "dns_state.json"
|
||||
state_file.write_text("{garbage!!!")
|
||||
mod, events = _load_dns({"DNS_STATE_PATH": str(state_file)})
|
||||
# Module loads, handles a query without crashing
|
||||
resp = mod._handle(_build_query("test.local", mod.TYPE_A), "7.7.7.7", 1234, "udp")
|
||||
assert resp is not None
|
||||
assert mod._tunnel_times == {}
|
||||
|
||||
# ── Flood detection ───────────────────────────────────────────────────────────
|
||||
|
||||
class TestFloodDetection:
|
||||
|
||||
Reference in New Issue
Block a user