feat(dns): persist tunneling burst state across restarts
Switch burst deque from monotonic() to time.time() (wall-clock, serializable).
Add DNS_STATE_PATH env var: on startup _load_state() reads {src:[ts,...]} JSON
and prunes entries older than the burst window. _flush_state() writes a temp file
and atomically renames it into place; the _state_flusher() coroutine flushes every 5s when dirty. Detection of
the 5th event also triggers an immediate flush. No-op when DNS_STATE_PATH is
unset, so the default deployment is unchanged.
This commit is contained in:
@@ -20,6 +20,7 @@ event_type values emitted:
|
|||||||
import asyncio
|
import asyncio
|
||||||
import collections
|
import collections
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
@@ -41,6 +42,7 @@ _AUTHORS = os.environ.get("DNS_AUTHORS", "BIND9 Developers")
|
|||||||
_NSID_RAW = os.environ.get("DNS_NSID", "")
|
_NSID_RAW = os.environ.get("DNS_NSID", "")
|
||||||
_EXTRA_RAW = os.environ.get("DNS_EXTRA_RECORDS", "")
|
_EXTRA_RAW = os.environ.get("DNS_EXTRA_RECORDS", "")
|
||||||
REAL_RECURSIVE = os.environ.get("DNS_REAL_RECURSIVE", "").lower() in ("1", "true", "yes")
|
REAL_RECURSIVE = os.environ.get("DNS_REAL_RECURSIVE", "").lower() in ("1", "true", "yes")
|
||||||
|
_STATE_PATH = os.environ.get("DNS_STATE_PATH", "")
|
||||||
|
|
||||||
_upstream_raw = os.environ.get("DNS_UPSTREAM", "8.8.8.8:53")
|
_upstream_raw = os.environ.get("DNS_UPSTREAM", "8.8.8.8:53")
|
||||||
try:
|
try:
|
||||||
@@ -379,8 +381,9 @@ _FORWARD_BUDGET_WIN = float(os.environ.get("DNS_FORWARD_WINDOW", "1.0"))
|
|||||||
|
|
||||||
# ── Per-src state ─────────────────────────────────────────────────────────────
|
# ── Per-src state ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
# Tunneling: src_ip -> deque of recent timestamps (all _TUNNEL_QTYPES counted)
|
# Tunneling: src_ip -> deque of recent wall-clock timestamps (all _TUNNEL_QTYPES counted)
|
||||||
_tunnel_times: collections.OrderedDict[str, collections.deque] = collections.OrderedDict()
|
_tunnel_times: collections.OrderedDict[str, collections.deque] = collections.OrderedDict()
|
||||||
|
_state_dirty = False
|
||||||
|
|
||||||
# Flood: src_ip -> deque of recent query timestamps
|
# Flood: src_ip -> deque of recent query timestamps
|
||||||
_qps_window: collections.OrderedDict[str, collections.deque] = collections.OrderedDict()
|
_qps_window: collections.OrderedDict[str, collections.deque] = collections.OrderedDict()
|
||||||
@@ -429,6 +432,53 @@ def _track_lru(table: collections.OrderedDict, key: str, tracker_name: str) -> N
|
|||||||
table.popitem(last=False)
|
table.popitem(last=False)
|
||||||
_note_eviction(tracker_name)
|
_note_eviction(tracker_name)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_state() -> None:
    """Populate _tunnel_times from the state file on startup, pruning stale entries.

    Does nothing when ``_STATE_PATH`` is empty or the file does not exist.
    A file that cannot be read or parsed, or whose top level is not a JSON
    object, is logged and ignored so the module still imports.

    Individual entries are validated instead of trusted: a value that is not
    a list is skipped, and only real numbers inside ``(now - window, now]``
    are kept. That rejects strings/nulls (which would raise on comparison)
    as well as NaN, Infinity and future timestamps, which the burst logic in
    ``_is_tunneling`` could never expire.
    """
    if not _STATE_PATH:
        return
    try:
        with open(_STATE_PATH, encoding="utf-8") as fh:
            data = json.load(fh)
    except FileNotFoundError:
        return
    except Exception:
        _log("startup", severity=5, msg="dns state file unreadable, starting fresh")
        return
    if not isinstance(data, dict):
        # Valid JSON but not the {src: [ts, ...]} shape we write.
        _log("startup", severity=5, msg="dns state file unreadable, starting fresh")
        return

    now = time.time()
    cutoff = now - _TXT_BURST_WINDOW
    for src, timestamps in data.items():
        if not isinstance(timestamps, list):
            continue
        # bool is an int subclass; exclude it explicitly. NaN fails both
        # comparisons and Infinity fails the upper bound, so neither survives.
        recent = sorted(
            float(t)
            for t in timestamps
            if isinstance(t, (int, float))
            and not isinstance(t, bool)
            and cutoff < t <= now
        )
        if recent:
            # Sorted ascending because _is_tunneling expires from the left.
            # NOTE(review): entries loaded here do not pass through
            # _track_lru — confirm the table bound cannot be exceeded.
            _tunnel_times[src] = collections.deque(recent)
|
||||||
|
|
||||||
|
|
||||||
|
def _flush_state() -> None:
    """Write _tunnel_times atomically to the state file if _STATE_PATH is set.

    The snapshot is serialized as ``{src: [ts, ...]}`` JSON into
    ``_STATE_PATH + ".tmp"`` and then moved over the real path with
    ``os.replace`` so a reader does not see a half-written file.

    Persistence is best-effort: I/O or serialization failures are absorbed
    so the caller is never interrupted. On failure the partial temp file is
    removed and ``_state_dirty`` is left untouched, so a later flush retries.
    ``_state_dirty`` is cleared only after a successful replace.
    """
    global _state_dirty
    if not _STATE_PATH:
        return
    tmp = _STATE_PATH + ".tmp"
    serialized = {src: list(q) for src, q in _tunnel_times.items()}
    try:
        with open(tmp, "w", encoding="utf-8") as fh:
            json.dump(serialized, fh)
        os.replace(tmp, _STATE_PATH)
    except (OSError, TypeError, ValueError):
        # Do not leave a truncated temp file lying next to the state file.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        return
    _state_dirty = False
|
||||||
|
|
||||||
|
|
||||||
|
async def _state_flusher() -> None:
    """Periodically flush _tunnel_times to disk when dirty.

    Wakes every 5 seconds and calls ``_flush_state()`` only if
    ``_state_dirty`` is set. The loop never exits on its own; it ends when
    the task is cancelled. On the way out, any state still marked dirty is
    flushed once more so a shutdown does not drop the last few seconds of
    recorded timestamps.
    """
    try:
        while True:
            await asyncio.sleep(5.0)
            if _state_dirty:
                _flush_state()
    finally:
        # Runs when the sleep is interrupted by cancellation; the flush is
        # synchronous, so it completes before the cancellation propagates.
        if _state_dirty:
            _flush_state()
|
||||||
|
|
||||||
|
|
||||||
|
_load_state()
|
||||||
|
|
||||||
# ── Tunneling heuristic ───────────────────────────────────────────────────────
|
# ── Tunneling heuristic ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
def _shannon_entropy(s: str) -> float:
|
def _shannon_entropy(s: str) -> float:
|
||||||
@@ -459,7 +509,8 @@ def _is_tunneling(qname: str, qtype: int, src: str) -> str | None:
|
|||||||
):
|
):
|
||||||
return "qname_entropy"
|
return "qname_entropy"
|
||||||
if qtype in _TUNNEL_QTYPES:
|
if qtype in _TUNNEL_QTYPES:
|
||||||
now = time.monotonic()
|
global _state_dirty
|
||||||
|
now = time.time()
|
||||||
if src not in _tunnel_times:
|
if src not in _tunnel_times:
|
||||||
_tunnel_times[src] = collections.deque()
|
_tunnel_times[src] = collections.deque()
|
||||||
_track_lru(_tunnel_times, src, "tunnel_times")
|
_track_lru(_tunnel_times, src, "tunnel_times")
|
||||||
@@ -467,7 +518,9 @@ def _is_tunneling(qname: str, qtype: int, src: str) -> str | None:
|
|||||||
q.append(now)
|
q.append(now)
|
||||||
while q and now - q[0] > _TXT_BURST_WINDOW:
|
while q and now - q[0] > _TXT_BURST_WINDOW:
|
||||||
q.popleft()
|
q.popleft()
|
||||||
|
_state_dirty = True
|
||||||
if len(q) >= _TXT_BURST_COUNT:
|
if len(q) >= _TXT_BURST_COUNT:
|
||||||
|
_flush_state()
|
||||||
return "burst"
|
return "burst"
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -910,9 +963,11 @@ async def main() -> None:
|
|||||||
tcp_server = await asyncio.start_server(
|
tcp_server = await asyncio.start_server(
|
||||||
_tcp_session, "0.0.0.0", 53 # nosec B104
|
_tcp_session, "0.0.0.0", 53 # nosec B104
|
||||||
)
|
)
|
||||||
|
flusher = asyncio.ensure_future(_state_flusher())
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(float("inf"))
|
await asyncio.sleep(float("inf"))
|
||||||
finally:
|
finally:
|
||||||
|
flusher.cancel()
|
||||||
udp_transport.close()
|
udp_transport.close()
|
||||||
tcp_server.close()
|
tcp_server.close()
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import importlib.util
|
|||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
@@ -76,6 +77,9 @@ def _load_dns(extra_env: dict | None = None):
|
|||||||
mod._flood_cooldown.clear()
|
mod._flood_cooldown.clear()
|
||||||
mod._recon_window.clear()
|
mod._recon_window.clear()
|
||||||
mod._recon_cooldown.clear()
|
mod._recon_cooldown.clear()
|
||||||
|
# Re-load tunnel state from file if path is set (clear above wiped it)
|
||||||
|
if env.get("DNS_STATE_PATH"):
|
||||||
|
mod._load_state()
|
||||||
|
|
||||||
return mod, bridge._events
|
return mod, bridge._events
|
||||||
|
|
||||||
@@ -475,6 +479,59 @@ class TestTunnelingHeuristic:
|
|||||||
mod._handle(_build_query(f"n{i}.test.local", mod.TYPE_NULL), src, 1234, "udp")
|
mod._handle(_build_query(f"n{i}.test.local", mod.TYPE_NULL), src, 1234, "udp")
|
||||||
assert _events_of(events, "tunneling_suspect")
|
assert _events_of(events, "tunneling_suspect")
|
||||||
|
|
||||||
|
# ── State persistence ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestStatePersistence:
    """Round-trip tests for the tunneling burst state stored at DNS_STATE_PATH."""

    def test_state_path_unset_no_file_written(self, tmp_path):
        """Burst detection without DNS_STATE_PATH must not create any file."""
        mod, events = _load_dns()
        attacker = "6.6.6.1"
        for idx in range(mod._TXT_BURST_COUNT):
            query = _build_query(f"b{idx}.test.local", mod.TYPE_TXT)
            mod._handle(query, attacker, 1234, "udp")
        assert _events_of(events, "tunneling_suspect")
        # tmp_path must still be empty: nothing was persisted.
        leftovers = list(tmp_path.iterdir())
        assert leftovers == []

    def test_state_persists_across_reload(self, tmp_path):
        """One query short of a burst, flush, reload: the next query fires."""
        state_path = tmp_path / "dns_state.json"
        mod, events = _load_dns({"DNS_STATE_PATH": str(state_path)})
        attacker = "6.6.6.2"

        # Stay exactly one query below the burst threshold.
        below_threshold = mod._TXT_BURST_COUNT - 1
        for idx in range(below_threshold):
            query = _build_query(f"b{idx}.test.local", mod.TYPE_TXT)
            mod._handle(query, attacker, 1234, "udp")
        assert not _events_of(events, "tunneling_suspect")

        # Persist explicitly instead of waiting for the periodic flusher.
        mod._flush_state()
        assert state_path.exists()

        # A fresh load with the same env stands in for a container restart.
        mod2, events2 = _load_dns({"DNS_STATE_PATH": str(state_path)})
        assert attacker in mod2._tunnel_times

        # The restored timestamps plus this query reach the threshold.
        final_query = _build_query("b5.test.local", mod2.TYPE_TXT)
        mod2._handle(final_query, attacker, 1234, "udp")
        assert _events_of(events2, "tunneling_suspect"), "5th query after restart must fire"

    def test_state_file_prunes_old_entries_on_load(self, tmp_path):
        """Timestamps far outside _TXT_BURST_WINDOW are dropped while loading."""
        import json as _json

        state_path = tmp_path / "dns_state.json"
        stale = time.time() - 9999  # far older than the burst window
        payload = _json.dumps({"1.2.3.4": [stale, stale]})
        state_path.write_text(payload)

        mod, _ = _load_dns({"DNS_STATE_PATH": str(state_path)})
        assert "1.2.3.4" not in mod._tunnel_times

    def test_state_file_corrupt_json_tolerated(self, tmp_path):
        """An unparsable state file is ignored: the module loads with empty state."""
        state_path = tmp_path / "dns_state.json"
        state_path.write_text("{garbage!!!")

        mod, events = _load_dns({"DNS_STATE_PATH": str(state_path)})

        # An ordinary query is still answered after the failed load.
        query = _build_query("test.local", mod.TYPE_A)
        resp = mod._handle(query, "7.7.7.7", 1234, "udp")
        assert resp is not None
        assert mod._tunnel_times == {}
|
||||||
|
|
||||||
# ── Flood detection ───────────────────────────────────────────────────────────
|
# ── Flood detection ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
class TestFloodDetection:
|
class TestFloodDetection:
|
||||||
|
|||||||
Reference in New Issue
Block a user