feat(dns): persist tunneling burst state across restarts

Switch burst deque from monotonic() to time.time() (wall-clock, serializable).
Add DNS_STATE_PATH env var: on startup _load_state() reads {src:[ts,...]} JSON
and prunes entries older than the burst window. _flush_state() write-then-renames
atomically; _state_flusher() coroutine flushes every 5s when dirty. Detection of
the 5th event also triggers an immediate flush. No-op when DNS_STATE_PATH is
unset, so the default deployment is unchanged.
This commit is contained in:
2026-05-21 22:10:10 -04:00
parent 457e2d990c
commit 757aff4671
2 changed files with 114 additions and 2 deletions

View File

@@ -20,6 +20,7 @@ event_type values emitted:
import asyncio
import collections
import hashlib
import json
import math
import os
import socket
@@ -41,6 +42,7 @@ _AUTHORS = os.environ.get("DNS_AUTHORS", "BIND9 Developers")
_NSID_RAW = os.environ.get("DNS_NSID", "")
_EXTRA_RAW = os.environ.get("DNS_EXTRA_RECORDS", "")
REAL_RECURSIVE = os.environ.get("DNS_REAL_RECURSIVE", "").lower() in ("1", "true", "yes")
_STATE_PATH = os.environ.get("DNS_STATE_PATH", "")
_upstream_raw = os.environ.get("DNS_UPSTREAM", "8.8.8.8:53")
try:
@@ -379,8 +381,9 @@ _FORWARD_BUDGET_WIN = float(os.environ.get("DNS_FORWARD_WINDOW", "1.0"))
# ── Per-src state ─────────────────────────────────────────────────────────────
# Tunneling: src_ip -> deque of recent timestamps (all _TUNNEL_QTYPES counted)
# Tunneling: src_ip -> deque of recent wall-clock timestamps (all _TUNNEL_QTYPES counted)
_tunnel_times: collections.OrderedDict[str, collections.deque] = collections.OrderedDict()
_state_dirty = False
# Flood: src_ip -> deque of recent query timestamps
_qps_window: collections.OrderedDict[str, collections.deque] = collections.OrderedDict()
@@ -429,6 +432,53 @@ def _track_lru(table: collections.OrderedDict, key: str, tracker_name: str) -> N
table.popitem(last=False)
_note_eviction(tracker_name)
def _load_state() -> None:
"""Populate _tunnel_times from the state file on startup, pruning stale entries."""
if not _STATE_PATH:
return
try:
with open(_STATE_PATH) as fh:
data: dict[str, list[float]] = json.load(fh)
except FileNotFoundError:
return
except Exception:
_log("startup", severity=5, msg="dns state file unreadable, starting fresh")
return
now = time.time()
cutoff = now - _TXT_BURST_WINDOW
for src, timestamps in data.items():
recent = [t for t in timestamps if t > cutoff]
if recent:
_tunnel_times[src] = collections.deque(recent)
def _flush_state() -> None:
"""Write _tunnel_times atomically to the state file if _STATE_PATH is set."""
global _state_dirty
if not _STATE_PATH:
return
tmp = _STATE_PATH + ".tmp"
try:
serialized = {src: list(q) for src, q in _tunnel_times.items()}
with open(tmp, "w") as fh:
json.dump(serialized, fh)
os.replace(tmp, _STATE_PATH)
_state_dirty = False
except Exception:
pass
async def _state_flusher() -> None:
"""Periodically flush _tunnel_times to disk when dirty."""
while True:
await asyncio.sleep(5.0)
if _state_dirty:
_flush_state()
_load_state()
# ── Tunneling heuristic ───────────────────────────────────────────────────────
def _shannon_entropy(s: str) -> float:
@@ -459,7 +509,8 @@ def _is_tunneling(qname: str, qtype: int, src: str) -> str | None:
):
return "qname_entropy"
if qtype in _TUNNEL_QTYPES:
now = time.monotonic()
global _state_dirty
now = time.time()
if src not in _tunnel_times:
_tunnel_times[src] = collections.deque()
_track_lru(_tunnel_times, src, "tunnel_times")
@@ -467,7 +518,9 @@ def _is_tunneling(qname: str, qtype: int, src: str) -> str | None:
q.append(now)
while q and now - q[0] > _TXT_BURST_WINDOW:
q.popleft()
_state_dirty = True
if len(q) >= _TXT_BURST_COUNT:
_flush_state()
return "burst"
return None
@@ -910,9 +963,11 @@ async def main() -> None:
tcp_server = await asyncio.start_server(
_tcp_session, "0.0.0.0", 53 # nosec B104
)
flusher = asyncio.ensure_future(_state_flusher())
try:
await asyncio.sleep(float("inf"))
finally:
flusher.cancel()
udp_transport.close()
tcp_server.close()

View File

@@ -6,6 +6,7 @@ import importlib.util
import socket
import struct
import sys
import time
from types import ModuleType
from unittest.mock import MagicMock, patch
@@ -76,6 +77,9 @@ def _load_dns(extra_env: dict | None = None):
mod._flood_cooldown.clear()
mod._recon_window.clear()
mod._recon_cooldown.clear()
# Re-load tunnel state from file if path is set (clear above wiped it)
if env.get("DNS_STATE_PATH"):
mod._load_state()
return mod, bridge._events
@@ -475,6 +479,59 @@ class TestTunnelingHeuristic:
mod._handle(_build_query(f"n{i}.test.local", mod.TYPE_NULL), src, 1234, "udp")
assert _events_of(events, "tunneling_suspect")
# ── State persistence ─────────────────────────────────────────────────────────
class TestStatePersistence:
def test_state_path_unset_no_file_written(self, tmp_path):
"""With DNS_STATE_PATH unset, no file should be written after burst detection."""
mod, events = _load_dns()
src = "6.6.6.1"
for i in range(mod._TXT_BURST_COUNT):
mod._handle(_build_query(f"b{i}.test.local", mod.TYPE_TXT), src, 1234, "udp")
assert _events_of(events, "tunneling_suspect")
# No state file should have appeared anywhere under tmp_path
assert list(tmp_path.iterdir()) == []
def test_state_persists_across_reload(self, tmp_path):
"""4 burst queries → flush → fresh module load → 5th query trips the burst."""
state_file = str(tmp_path / "dns_state.json")
mod, events = _load_dns({"DNS_STATE_PATH": state_file})
src = "6.6.6.2"
# Send _TXT_BURST_COUNT - 1 queries (not enough to trigger on their own)
for i in range(mod._TXT_BURST_COUNT - 1):
mod._handle(_build_query(f"b{i}.test.local", mod.TYPE_TXT), src, 1234, "udp")
# No burst yet
assert not _events_of(events, "tunneling_suspect")
# Manually flush state to disk
mod._flush_state()
assert tmp_path.joinpath("dns_state.json").exists()
# Reload module (simulates container restart) — same env so it reads the file
mod2, events2 = _load_dns({"DNS_STATE_PATH": state_file})
# The reloaded module should have populated _tunnel_times from the file
assert src in mod2._tunnel_times
# 5th query → burst fires
mod2._handle(_build_query("b5.test.local", mod2.TYPE_TXT), src, 1234, "udp")
assert _events_of(events2, "tunneling_suspect"), "5th query after restart must fire"
def test_state_file_prunes_old_entries_on_load(self, tmp_path):
"""Entries older than _TXT_BURST_WINDOW are pruned on startup, not loaded."""
import json as _json
state_file = tmp_path / "dns_state.json"
old_ts = time.time() - 9999 # way outside window
state_file.write_text(_json.dumps({"1.2.3.4": [old_ts, old_ts]}))
mod, _ = _load_dns({"DNS_STATE_PATH": str(state_file)})
assert "1.2.3.4" not in mod._tunnel_times
def test_state_file_corrupt_json_tolerated(self, tmp_path):
"""A corrupt state file must not crash the module; state starts empty."""
state_file = tmp_path / "dns_state.json"
state_file.write_text("{garbage!!!")
mod, events = _load_dns({"DNS_STATE_PATH": str(state_file)})
# Module loads, handles a query without crashing
resp = mod._handle(_build_query("test.local", mod.TYPE_A), "7.7.7.7", 1234, "udp")
assert resp is not None
assert mod._tunnel_times == {}
# ── Flood detection ───────────────────────────────────────────────────────────
class TestFloodDetection: