From 757aff4671b65cf121e1e8b7ab54d886807004ff Mon Sep 17 00:00:00 2001 From: anti Date: Thu, 21 May 2026 22:10:10 -0400 Subject: [PATCH] feat(dns): persist tunneling burst state across restarts Switch burst deque from monotonic() to time.time() (wall-clock, serializable). Add DNS_STATE_PATH env var: on startup _load_state() reads {src:[ts,...]} JSON and prunes entries older than the burst window. _flush_state() write-then-renames atomically; _state_flusher() coroutine flushes every 5s when dirty. Detection of the 5th event also triggers an immediate flush. No-op when DNS_STATE_PATH is unset, so the default deployment is unchanged. --- decnet/templates/dns/server.py | 59 +++++++++++++++++++++++++++++-- tests/service_testing/test_dns.py | 57 +++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+), 2 deletions(-) diff --git a/decnet/templates/dns/server.py b/decnet/templates/dns/server.py index 61f32981..223d08de 100644 --- a/decnet/templates/dns/server.py +++ b/decnet/templates/dns/server.py @@ -20,6 +20,7 @@ event_type values emitted: import asyncio import collections import hashlib +import json import math import os import socket @@ -41,6 +42,7 @@ _AUTHORS = os.environ.get("DNS_AUTHORS", "BIND9 Developers") _NSID_RAW = os.environ.get("DNS_NSID", "") _EXTRA_RAW = os.environ.get("DNS_EXTRA_RECORDS", "") REAL_RECURSIVE = os.environ.get("DNS_REAL_RECURSIVE", "").lower() in ("1", "true", "yes") +_STATE_PATH = os.environ.get("DNS_STATE_PATH", "") _upstream_raw = os.environ.get("DNS_UPSTREAM", "8.8.8.8:53") try: @@ -379,8 +381,9 @@ _FORWARD_BUDGET_WIN = float(os.environ.get("DNS_FORWARD_WINDOW", "1.0")) # ── Per-src state ───────────────────────────────────────────────────────────── -# Tunneling: src_ip -> deque of recent timestamps (all _TUNNEL_QTYPES counted) +# Tunneling: src_ip -> deque of recent wall-clock timestamps (all _TUNNEL_QTYPES counted) _tunnel_times: collections.OrderedDict[str, collections.deque] = collections.OrderedDict() +_state_dirty = False # 
def _load_state() -> None:
    """Populate ``_tunnel_times`` from the JSON state file.

    The file is expected to hold ``{src: [wall_clock_ts, ...]}`` as written by
    ``_flush_state()``.  Does nothing when ``_STATE_PATH`` is empty or the
    file does not exist.  This is called at module import, so nothing in here
    may raise: an unreadable file or a non-object JSON root is logged and
    ignored, and mis-shaped entries are skipped individually.

    Only timestamps inside ``(now - _TXT_BURST_WINDOW, now]`` are kept.  The
    upper bound matters because the burst deque is pruned from its head with
    ``now - q[0] > window``; a future or ``Infinity`` timestamp would never
    satisfy that and would count towards every later burst.
    """
    if not _STATE_PATH:
        return
    try:
        with open(_STATE_PATH, encoding="utf-8") as fh:
            data = json.load(fh)
    except FileNotFoundError:
        return
    except Exception:
        _log("startup", severity=5, msg="dns state file unreadable, starting fresh")
        return
    if not isinstance(data, dict):
        # Syntactically valid JSON, but not the mapping _flush_state() writes.
        _log("startup", severity=5, msg="dns state file unreadable, starting fresh")
        return

    now = time.time()
    cutoff = now - _TXT_BURST_WINDOW
    for src, timestamps in data.items():
        if not isinstance(timestamps, list):
            continue
        recent = sorted(
            t
            for t in timestamps
            # bool is an int subclass; NaN fails both comparisons and is dropped.
            if isinstance(t, (int, float))
            and not isinstance(t, bool)
            and cutoff < t <= now
        )
        if recent:
            # Sorted so the oldest entry sits at the head, where pruning looks.
            _tunnel_times[src] = collections.deque(recent)


def _flush_state() -> None:
    """Write ``_tunnel_times`` to the state file using write-then-rename.

    No-op when ``_STATE_PATH`` is empty.  The snapshot is written to
    ``_STATE_PATH + ".tmp"`` and moved into place with ``os.replace`` so a
    reader never sees a half-written file.  ``_state_dirty`` is cleared only
    after the rename succeeds.

    Persistence is best effort: on failure the partial temp file is removed,
    the dirty flag is left set so the next flush retries, and nothing is
    raised.

    NOTE(review): this performs blocking file I/O on the caller's thread;
    callers on the per-query path may want to throttle how often they invoke
    it.
    """
    global _state_dirty
    if not _STATE_PATH:
        return
    tmp = _STATE_PATH + ".tmp"
    snapshot = {src: list(q) for src, q in _tunnel_times.items()}
    try:
        with open(tmp, "w", encoding="utf-8") as fh:
            json.dump(snapshot, fh)
        os.replace(tmp, _STATE_PATH)
    except (OSError, TypeError, ValueError):
        # Do not leave a stale temp file next to the real state file.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        return
    _state_dirty = False


async def _state_flusher() -> None:
    """Flush ``_tunnel_times`` to disk every 5 seconds while it is dirty.

    Runs until cancelled.  On the way out (cancellation included) any state
    still marked dirty is flushed once more, so a graceful shutdown does not
    lose the last few seconds of burst history.
    NOTE(review): the final flush only happens if the event loop lets the
    cancelled task run to completion — confirm against how main() is driven.
    """
    try:
        while True:
            await asyncio.sleep(5.0)
            if _state_dirty:
                _flush_state()
    finally:
        if _state_dirty:
            _flush_state()


_load_state()
class TestStatePersistence:
    """Persistence of the tunneling burst window through ``DNS_STATE_PATH``."""

    def test_state_path_unset_no_file_written(self, tmp_path, monkeypatch):
        """With DNS_STATE_PATH unset, burst detection must not write any file."""
        monkeypatch.delenv("DNS_STATE_PATH", raising=False)
        mod, events = _load_dns()
        assert not mod._STATE_PATH
        # Run the queries with tmp_path as the working directory: a flush that
        # ignored the empty path would create the relative file ".tmp" here,
        # so the emptiness check below can actually fail.  The chdir happens
        # after the module is loaded so loading is unaffected by it.
        monkeypatch.chdir(tmp_path)
        src = "6.6.6.1"
        for i in range(mod._TXT_BURST_COUNT):
            mod._handle(_build_query(f"b{i}.test.local", mod.TYPE_TXT), src, 1234, "udp")
        assert _events_of(events, "tunneling_suspect")
        assert list(tmp_path.iterdir()) == []

    def test_state_persists_across_reload(self, tmp_path):
        """N-1 burst queries → flush → fresh module load → Nth query trips the burst."""
        state_path = tmp_path / "dns_state.json"
        mod, events = _load_dns({"DNS_STATE_PATH": str(state_path)})
        src = "6.6.6.2"
        below_threshold = mod._TXT_BURST_COUNT - 1
        # One query short of the threshold: must not trigger on its own.
        for i in range(below_threshold):
            mod._handle(_build_query(f"b{i}.test.local", mod.TYPE_TXT), src, 1234, "udp")
        assert not _events_of(events, "tunneling_suspect")
        # Manually flush state to disk.
        mod._flush_state()
        assert state_path.exists()
        # Write-then-rename must not leave the temporary file behind.
        assert not tmp_path.joinpath("dns_state.json.tmp").exists()
        # Reload module (simulates container restart) — same env so it reads the file.
        mod2, events2 = _load_dns({"DNS_STATE_PATH": str(state_path)})
        assert src in mod2._tunnel_times
        assert len(mod2._tunnel_times[src]) == below_threshold
        # The query that reaches the threshold fires the burst.
        mod2._handle(
            _build_query(f"b{below_threshold}.test.local", mod2.TYPE_TXT), src, 1234, "udp"
        )
        assert _events_of(events2, "tunneling_suspect"), "threshold query after restart must fire"

    def test_state_file_prunes_old_entries_on_load(self, tmp_path):
        """Entries older than _TXT_BURST_WINDOW are pruned on startup; fresh ones survive."""
        import json as _json
        state_file = tmp_path / "dns_state.json"
        fresh_ts = time.time()
        old_ts = fresh_ts - 9999  # way outside window
        state_file.write_text(
            _json.dumps({"1.2.3.4": [old_ts, old_ts], "5.6.7.8": [old_ts, fresh_ts]})
        )
        mod, _ = _load_dns({"DNS_STATE_PATH": str(state_file)})
        assert "1.2.3.4" not in mod._tunnel_times
        # Pruning is per timestamp, not per source.
        assert list(mod._tunnel_times["5.6.7.8"]) == [fresh_ts]

    def test_state_file_corrupt_json_tolerated(self, tmp_path):
        """A corrupt state file must not crash the module; state starts empty."""
        state_file = tmp_path / "dns_state.json"
        state_file.write_text("{garbage!!!")
        mod, events = _load_dns({"DNS_STATE_PATH": str(state_file)})
        # Module loads and handles a query without crashing.
        resp = mod._handle(_build_query("test.local", mod.TYPE_A), "7.7.7.7", 1234, "udp")
        assert resp is not None
        assert mod._tunnel_times == {}