feat(dns): global upstream forward rate limit with sinkhole fallback

Adds DNS_FORWARD_BUDGET (default 50) and DNS_FORWARD_WINDOW (default 1.0s)
env vars. _can_forward() maintains a rolling deque of upstream call
timestamps; queries that exceed the budget within the window are answered
with the sinkhole (127.x) instead of being forwarded, making the honeypot
ineligible as a sustained amp vector even when real_recursive is enabled.
Rate limit is global (not per-source) so IP-spoofed amplification floods
hit the ceiling regardless of how many source addresses are rotated.
This commit is contained in:
2026-05-21 20:50:20 -04:00
parent e5847b7e1e
commit da2ad7a82a
3 changed files with 94 additions and 2 deletions

View File

@@ -65,6 +65,20 @@ class DNSService(BaseService):
placeholder="8.8.8.8:53", placeholder="8.8.8.8:53",
help="Upstream DNS resolver used when real_recursive is enabled (host:port).", help="Upstream DNS resolver used when real_recursive is enabled (host:port).",
), ),
ServiceConfigField(
key="forward_budget",
label="Forward budget (queries/window)",
type="string",
default="50",
help="Maximum upstream forwarding calls allowed within the rate window. Excess queries fall back to sinkhole.",
),
ServiceConfigField(
key="forward_window",
label="Forward budget window (seconds)",
type="string",
default="1.0",
help="Rolling window in seconds for the forward budget counter.",
),
] ]
def compose_fragment( def compose_fragment(
@@ -82,7 +96,9 @@ class DNSService(BaseService):
"DNS_NSID": str(cfg.get("nsid", "")), "DNS_NSID": str(cfg.get("nsid", "")),
"DNS_EXTRA_RECORDS": str(cfg.get("extra_records", "")), "DNS_EXTRA_RECORDS": str(cfg.get("extra_records", "")),
"DNS_REAL_RECURSIVE": "true" if cfg.get("real_recursive") else "false", "DNS_REAL_RECURSIVE": "true" if cfg.get("real_recursive") else "false",
"DNS_UPSTREAM": str(cfg.get("upstream", "8.8.8.8:53")), "DNS_UPSTREAM": str(cfg.get("upstream", "8.8.8.8:53")),
"DNS_FORWARD_BUDGET": str(cfg.get("forward_budget", "50")),
"DNS_FORWARD_WINDOW": str(cfg.get("forward_window", "1.0")),
} }
if log_target: if log_target:
env["LOG_TARGET"] = log_target env["LOG_TARGET"] = log_target

View File

@@ -313,6 +313,10 @@ _RECON_SIGNAL_TYPES = frozenset({"fingerprint_probe", "zone_transfer", "am
# Eviction telemetry # Eviction telemetry
_EVICT_EVENT_EVERY = 100 _EVICT_EVENT_EVERY = 100
# Global upstream forwarding budget
_FORWARD_BUDGET_MAX = int(os.environ.get("DNS_FORWARD_BUDGET", "50"))
_FORWARD_BUDGET_WIN = float(os.environ.get("DNS_FORWARD_WINDOW", "1.0"))
# ── Per-src state ───────────────────────────────────────────────────────────── # ── Per-src state ─────────────────────────────────────────────────────────────
# Tunneling: src_ip -> deque of recent TXT timestamps # Tunneling: src_ip -> deque of recent TXT timestamps
@@ -332,6 +336,18 @@ _recon_cooldown: dict[str, float] = {}
_evictions_total = 0 _evictions_total = 0
# Global forward budget: timestamps of recent upstream calls
_forward_timestamps: collections.deque[float] = collections.deque()
def _can_forward() -> bool:
"""Return True and consume one budget slot if under the global forward limit."""
now = time.monotonic()
_forward_timestamps.append(now)
while _forward_timestamps[0] < now - _FORWARD_BUDGET_WIN:
_forward_timestamps.popleft()
return len(_forward_timestamps) <= _FORWARD_BUDGET_MAX
def _note_eviction(tracker_name: str) -> None: def _note_eviction(tracker_name: str) -> None:
global _evictions_total global _evictions_total
@@ -619,7 +635,7 @@ async def _dispatch(data: bytes, src_ip: str, src_port: int, transport: str) ->
"""Async dispatcher: runs sync _handle (logging + detection), then overlays """Async dispatcher: runs sync _handle (logging + detection), then overlays
upstream forwarding for real-recursive out-of-zone queries.""" upstream forwarding for real-recursive out-of-zone queries."""
sinkhole = _handle(data, src_ip, src_port, transport) sinkhole = _handle(data, src_ip, src_port, transport)
if _is_upstream_candidate(data): if _is_upstream_candidate(data) and _can_forward():
upstream = await _forward_upstream(data) upstream = await _forward_upstream(data)
if upstream is not None: if upstream is not None:
return upstream return upstream

View File

@@ -576,6 +576,66 @@ class TestRealRecursive:
assert frag["environment"]["DNS_REAL_RECURSIVE"] == "false" assert frag["environment"]["DNS_REAL_RECURSIVE"] == "false"
class TestForwardBudget:
def _load_with_budget(self, budget: int = 3):
mod, events = _load_dns({
"DNS_ZONE_MODE": "recursive",
"DNS_REAL_RECURSIVE": "true",
"DNS_FORWARD_BUDGET": str(budget),
"DNS_FORWARD_WINDOW": "60", # wide window so nothing expires mid-test
})
mod._forward_timestamps.clear()
return mod, events
def test_within_budget_forwards(self):
mod, _ = self._load_with_budget(budget=5)
import asyncio
from unittest.mock import AsyncMock, patch
fake_resp = b"\x12\x34" + b"\x81\x80" + b"\x00" * 8
mock_fwd = AsyncMock(return_value=fake_resp)
with patch.object(mod, "_forward_upstream", mock_fwd):
query = _build_query("evil.example.com", mod.TYPE_A)
for _ in range(5):
asyncio.get_event_loop().run_until_complete(
mod._dispatch(query, "1.1.1.1", 1234, "udp")
)
assert mock_fwd.await_count == 5
def test_over_budget_falls_back_to_sinkhole(self):
mod, _ = self._load_with_budget(budget=2)
import asyncio
from unittest.mock import AsyncMock, patch
fake_resp = b"\x12\x34" + b"\x81\x80" + b"\x00" * 8
mock_fwd = AsyncMock(return_value=fake_resp)
with patch.object(mod, "_forward_upstream", mock_fwd):
query = _build_query("evil.example.com", mod.TYPE_A)
responses = []
for _ in range(5):
resp = asyncio.get_event_loop().run_until_complete(
mod._dispatch(query, "1.1.1.1", 1234, "udp")
)
responses.append(resp)
# Upstream called at most budget+1 times (budget check appends before pruning)
assert mock_fwd.await_count <= 3
# All responses are non-None (sinkhole for over-budget ones)
assert all(r is not None for r in responses)
def test_budget_is_global_not_per_src(self):
"""Budget counts all upstream calls regardless of source IP."""
mod, _ = self._load_with_budget(budget=2)
import asyncio
from unittest.mock import AsyncMock, patch
fake_resp = b"\x12\x34" + b"\x81\x80" + b"\x00" * 8
mock_fwd = AsyncMock(return_value=fake_resp)
with patch.object(mod, "_forward_upstream", mock_fwd):
query = _build_query("evil.example.com", mod.TYPE_A)
for i in range(5):
asyncio.get_event_loop().run_until_complete(
mod._dispatch(query, f"10.0.0.{i+1}", 1234, "udp")
)
assert mock_fwd.await_count <= 3
class TestZoneModeRecursive: class TestZoneModeRecursive:
def test_recursive_mode_sets_ra_flag(self): def test_recursive_mode_sets_ra_flag(self):
mod, _ = _load_dns({"DNS_ZONE_MODE": "recursive"}) mod, _ = _load_dns({"DNS_ZONE_MODE": "recursive"})