feat(dns): persist tunneling burst state across restarts

Switch the burst deque from time.monotonic() to time.time() (wall-clock, serializable).
Add the DNS_STATE_PATH env var: on startup, _load_state() reads {src: [ts, ...]} JSON
and prunes entries older than the burst window. _flush_state() writes to a temp file
and then atomically renames it over the state file; the _state_flusher() coroutine
flushes every 5s when dirty. Detecting a burst (the 5th event in the window) also
triggers an immediate flush. All of this is a no-op when DNS_STATE_PATH is unset, so
the default deployment is unchanged.
This commit is contained in:
2026-05-21 22:10:10 -04:00
parent 457e2d990c
commit 757aff4671
2 changed files with 114 additions and 2 deletions

View File

@@ -20,6 +20,7 @@ event_type values emitted:
import asyncio import asyncio
import collections import collections
import hashlib import hashlib
import json
import math import math
import os import os
import socket import socket
@@ -41,6 +42,7 @@ _AUTHORS = os.environ.get("DNS_AUTHORS", "BIND9 Developers")
_NSID_RAW = os.environ.get("DNS_NSID", "") _NSID_RAW = os.environ.get("DNS_NSID", "")
_EXTRA_RAW = os.environ.get("DNS_EXTRA_RECORDS", "") _EXTRA_RAW = os.environ.get("DNS_EXTRA_RECORDS", "")
REAL_RECURSIVE = os.environ.get("DNS_REAL_RECURSIVE", "").lower() in ("1", "true", "yes") REAL_RECURSIVE = os.environ.get("DNS_REAL_RECURSIVE", "").lower() in ("1", "true", "yes")
_STATE_PATH = os.environ.get("DNS_STATE_PATH", "")
_upstream_raw = os.environ.get("DNS_UPSTREAM", "8.8.8.8:53") _upstream_raw = os.environ.get("DNS_UPSTREAM", "8.8.8.8:53")
try: try:
@@ -379,8 +381,9 @@ _FORWARD_BUDGET_WIN = float(os.environ.get("DNS_FORWARD_WINDOW", "1.0"))
# ── Per-src state ───────────────────────────────────────────────────────────── # ── Per-src state ─────────────────────────────────────────────────────────────
# Tunneling: src_ip -> deque of recent timestamps (all _TUNNEL_QTYPES counted) # Tunneling: src_ip -> deque of recent wall-clock timestamps (all _TUNNEL_QTYPES counted)
_tunnel_times: collections.OrderedDict[str, collections.deque] = collections.OrderedDict() _tunnel_times: collections.OrderedDict[str, collections.deque] = collections.OrderedDict()
_state_dirty = False
# Flood: src_ip -> deque of recent query timestamps # Flood: src_ip -> deque of recent query timestamps
_qps_window: collections.OrderedDict[str, collections.deque] = collections.OrderedDict() _qps_window: collections.OrderedDict[str, collections.deque] = collections.OrderedDict()
@@ -429,6 +432,53 @@ def _track_lru(table: collections.OrderedDict, key: str, tracker_name: str) -> N
table.popitem(last=False) table.popitem(last=False)
_note_eviction(tracker_name) _note_eviction(tracker_name)
def _load_state() -> None:
    """Populate _tunnel_times from the state file on startup, pruning stale entries.

    Does nothing when ``_STATE_PATH`` is empty or the file does not exist.  The
    file is expected to hold a JSON object mapping each source address to a
    list of wall-clock timestamps.  Content that does not match that shape is
    skipped rather than raising, because this function is invoked at import
    time and a damaged state file must not prevent the module from loading.

    Only timestamps inside the current burst window are kept: entries older
    than ``_TXT_BURST_WINDOW`` seconds are dropped, and so are entries dated in
    the future (for example after the wall clock was stepped back), since the
    eviction loop could not age those out.  Kept timestamps are sorted so the
    deque is ordered oldest-first.
    """
    if not _STATE_PATH:
        return
    try:
        with open(_STATE_PATH, encoding="utf-8") as fh:
            data = json.load(fh)
    except FileNotFoundError:
        return
    except Exception:
        _log("startup", severity=5, msg="dns state file unreadable, starting fresh")
        return
    if not isinstance(data, dict):
        # Well-formed JSON, but not a {src: [ts, ...]} object.
        _log("startup", severity=5, msg="dns state file unreadable, starting fresh")
        return
    now = time.time()
    cutoff = now - _TXT_BURST_WINDOW
    for src, timestamps in data.items():
        if not isinstance(timestamps, list):
            continue
        # NaN and +/-Infinity (which json.load accepts) fail the range test
        # below and are discarded along with stale and future-dated values.
        recent = sorted(
            float(t)
            for t in timestamps
            if isinstance(t, (int, float))
            and not isinstance(t, bool)
            and cutoff < t <= now
        )
        if recent:
            _tunnel_times[src] = collections.deque(recent)
def _flush_state() -> None:
    """Write _tunnel_times atomically to the state file if _STATE_PATH is set.

    The snapshot is written to ``_STATE_PATH + ".tmp"`` and then moved over the
    real path with ``os.replace``, so a reader never observes a half-written
    file.  ``_state_dirty`` is cleared only after the replace succeeds; on
    failure it is left untouched so a later call can retry.

    Persistence is best-effort: any failure is logged and swallowed rather than
    raised, so that a broken or read-only state path cannot interrupt the
    caller.  The write is synchronous.
    """
    global _state_dirty
    if not _STATE_PATH:
        return
    tmp = _STATE_PATH + ".tmp"
    try:
        serialized = {src: list(q) for src, q in _tunnel_times.items()}
        with open(tmp, "w", encoding="utf-8") as fh:
            json.dump(serialized, fh)
        os.replace(tmp, _STATE_PATH)
    except Exception as exc:
        # Deliberately broad: state persistence must never break the caller.
        _log("state_flush_error", severity=4, msg=f"dns state flush failed: {exc}")
        try:
            # Do not leave a partial temp file next to the real state file.
            os.unlink(tmp)
        except OSError:
            pass  # temp file was never created, or cannot be removed either
    else:
        _state_dirty = False
async def _state_flusher() -> None:
    """Periodically flush _tunnel_times to disk when dirty.

    Wakes every 5 seconds and calls ``_flush_state()`` if ``_state_dirty`` is
    set.  The loop never returns on its own; it ends when the task is
    cancelled.  On the way out (cancellation or any other exit) one last flush
    is attempted so that changes made since the previous tick are not lost.
    """
    try:
        while True:
            await asyncio.sleep(5.0)
            if _state_dirty:
                _flush_state()
    finally:
        # Runs when the task is cancelled while sleeping; the flush itself is
        # synchronous, so it completes before the cancellation propagates.
        if _state_dirty:
            _flush_state()
# Restore any persisted burst timestamps once, when the module is imported.
# _load_state() returns immediately when _STATE_PATH is empty.
_load_state()
# ── Tunneling heuristic ─────────────────────────────────────────────────────── # ── Tunneling heuristic ───────────────────────────────────────────────────────
def _shannon_entropy(s: str) -> float: def _shannon_entropy(s: str) -> float:
@@ -459,7 +509,8 @@ def _is_tunneling(qname: str, qtype: int, src: str) -> str | None:
): ):
return "qname_entropy" return "qname_entropy"
if qtype in _TUNNEL_QTYPES: if qtype in _TUNNEL_QTYPES:
now = time.monotonic() global _state_dirty
now = time.time()
if src not in _tunnel_times: if src not in _tunnel_times:
_tunnel_times[src] = collections.deque() _tunnel_times[src] = collections.deque()
_track_lru(_tunnel_times, src, "tunnel_times") _track_lru(_tunnel_times, src, "tunnel_times")
@@ -467,7 +518,9 @@ def _is_tunneling(qname: str, qtype: int, src: str) -> str | None:
q.append(now) q.append(now)
while q and now - q[0] > _TXT_BURST_WINDOW: while q and now - q[0] > _TXT_BURST_WINDOW:
q.popleft() q.popleft()
_state_dirty = True
if len(q) >= _TXT_BURST_COUNT: if len(q) >= _TXT_BURST_COUNT:
_flush_state()
return "burst" return "burst"
return None return None
@@ -910,9 +963,11 @@ async def main() -> None:
tcp_server = await asyncio.start_server( tcp_server = await asyncio.start_server(
_tcp_session, "0.0.0.0", 53 # nosec B104 _tcp_session, "0.0.0.0", 53 # nosec B104
) )
flusher = asyncio.ensure_future(_state_flusher())
try: try:
await asyncio.sleep(float("inf")) await asyncio.sleep(float("inf"))
finally: finally:
flusher.cancel()
udp_transport.close() udp_transport.close()
tcp_server.close() tcp_server.close()

View File

@@ -6,6 +6,7 @@ import importlib.util
import socket import socket
import struct import struct
import sys import sys
import time
from types import ModuleType from types import ModuleType
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@@ -76,6 +77,9 @@ def _load_dns(extra_env: dict | None = None):
mod._flood_cooldown.clear() mod._flood_cooldown.clear()
mod._recon_window.clear() mod._recon_window.clear()
mod._recon_cooldown.clear() mod._recon_cooldown.clear()
# Re-load tunnel state from file if path is set (clear above wiped it)
if env.get("DNS_STATE_PATH"):
mod._load_state()
return mod, bridge._events return mod, bridge._events
@@ -475,6 +479,59 @@ class TestTunnelingHeuristic:
mod._handle(_build_query(f"n{i}.test.local", mod.TYPE_NULL), src, 1234, "udp") mod._handle(_build_query(f"n{i}.test.local", mod.TYPE_NULL), src, 1234, "udp")
assert _events_of(events, "tunneling_suspect") assert _events_of(events, "tunneling_suspect")
# ── State persistence ─────────────────────────────────────────────────────────
class TestStatePersistence:
    """Persistence of the tunneling burst window through DNS_STATE_PATH."""

    def test_state_path_unset_no_file_written(self, tmp_path):
        """With DNS_STATE_PATH unset, no file should be written after burst detection."""
        mod, events = _load_dns()
        src = "6.6.6.1"
        # _flush_state() publishes the state file with os.replace(); if that is
        # never invoked, no state file can have been produced at any path.
        # (Checking tmp_path alone proves nothing, because nothing points at it.)
        with patch.object(mod.os, "replace") as replace_mock:
            for i in range(mod._TXT_BURST_COUNT):
                mod._handle(_build_query(f"b{i}.test.local", mod.TYPE_TXT), src, 1234, "udp")
        assert _events_of(events, "tunneling_suspect")
        replace_mock.assert_not_called()
        # No state file should have appeared anywhere under tmp_path
        assert list(tmp_path.iterdir()) == []

    def test_state_persists_across_reload(self, tmp_path):
        """4 burst queries → flush → fresh module load → 5th query trips the burst."""
        state_file = str(tmp_path / "dns_state.json")
        mod, events = _load_dns({"DNS_STATE_PATH": state_file})
        src = "6.6.6.2"
        # Send _TXT_BURST_COUNT - 1 queries (not enough to trigger on their own)
        for i in range(mod._TXT_BURST_COUNT - 1):
            mod._handle(_build_query(f"b{i}.test.local", mod.TYPE_TXT), src, 1234, "udp")
        # No burst yet
        assert not _events_of(events, "tunneling_suspect")
        # Manually flush state to disk
        mod._flush_state()
        assert tmp_path.joinpath("dns_state.json").exists()

        # Reload module (simulates container restart) — same env so it reads the file
        mod2, events2 = _load_dns({"DNS_STATE_PATH": state_file})
        # The reloaded module should have populated _tunnel_times from the file
        assert src in mod2._tunnel_times
        # One more query on top of the restored ones → burst fires
        mod2._handle(_build_query("b5.test.local", mod2.TYPE_TXT), src, 1234, "udp")
        assert _events_of(events2, "tunneling_suspect"), "5th query after restart must fire"

    def test_state_file_prunes_old_entries_on_load(self, tmp_path):
        """Entries older than _TXT_BURST_WINDOW are pruned on startup, not loaded."""
        import json as _json

        state_file = tmp_path / "dns_state.json"
        old_ts = time.time() - 9999  # way outside window
        state_file.write_text(_json.dumps({"1.2.3.4": [old_ts, old_ts]}))
        mod, _ = _load_dns({"DNS_STATE_PATH": str(state_file)})
        assert "1.2.3.4" not in mod._tunnel_times

    def test_state_file_corrupt_json_tolerated(self, tmp_path):
        """A corrupt state file must not crash the module; state starts empty."""
        state_file = tmp_path / "dns_state.json"
        state_file.write_text("{garbage!!!")
        mod, events = _load_dns({"DNS_STATE_PATH": str(state_file)})
        # Module loads, handles a query without crashing
        resp = mod._handle(_build_query("test.local", mod.TYPE_A), "7.7.7.7", 1234, "udp")
        assert resp is not None
        assert mod._tunnel_times == {}
# ── Flood detection ─────────────────────────────────────────────────────────── # ── Flood detection ───────────────────────────────────────────────────────────
class TestFloodDetection: class TestFloodDetection: