feat(dns): real recursive forwarding with sinkhole fallback

When DNS_REAL_RECURSIVE=true and DNS_ZONE_MODE=recursive, out-of-zone
queries are forwarded to DNS_UPSTREAM (default 8.8.8.8:53) via async
UDP. Upstream response is relayed as-is; on timeout or error the
already-computed sinkhole (127.x) is returned instead.

_handle() always runs first so logging, tunneling detection, flood
tracking, and recon-burst aggregation fire on every query regardless
of whether the response ultimately comes from upstream. _dispatch()
overlays forwarding on top of the sync handler.

The UDP handler (datagram_received) now schedules an async coroutine via
asyncio.ensure_future; the TCP session awaits _dispatch() directly. The
service class exposes real_recursive (bool) and upstream (string) config fields.
This commit is contained in:
2026-05-21 20:49:19 -04:00
parent 8f33f1b849
commit e5847b7e1e
3 changed files with 186 additions and 11 deletions

View File

@@ -18,7 +18,7 @@ class DNSService(BaseService):
type="enum", type="enum",
enum=["auth", "recursive", "open"], enum=["auth", "recursive", "open"],
default="auth", default="auth",
help="auth: authoritative only; recursive: RA flag set, NXDOMAIN for out-of-zone; open: responds to everything (amp bait)", help="auth: authoritative only; recursive: forwards out-of-zone queries upstream (real_recursive=true) or sinkholes them; open: responds to everything (amp bait)",
), ),
ServiceConfigField( ServiceConfigField(
key="domain", key="domain",
@@ -50,6 +50,21 @@ class DNSService(BaseService):
placeholder="www A 10.0.0.5\nmail TXT v=spf1 ~all", placeholder="www A 10.0.0.5\nmail TXT v=spf1 ~all",
help="Additional zone records, one per line: <name> <TYPE> <value>", help="Additional zone records, one per line: <name> <TYPE> <value>",
), ),
ServiceConfigField(
key="real_recursive",
label="Real recursive forwarding",
type="bool",
default=False,
help="When zone_mode=recursive, forward out-of-zone queries to an upstream resolver instead of returning a sinkhole. Falls back to sinkhole on upstream timeout.",
),
ServiceConfigField(
key="upstream",
label="Upstream resolver",
type="string",
default="8.8.8.8:53",
placeholder="8.8.8.8:53",
help="Upstream DNS resolver used when real_recursive is enabled (host:port).",
),
] ]
def compose_fragment( def compose_fragment(
@@ -65,7 +80,9 @@ class DNSService(BaseService):
"DNS_DOMAIN": str(cfg.get("domain", "")), "DNS_DOMAIN": str(cfg.get("domain", "")),
"DNS_BIND_VERSION": str(cfg.get("bind_version", _DEFAULT_VERSION)), "DNS_BIND_VERSION": str(cfg.get("bind_version", _DEFAULT_VERSION)),
"DNS_NSID": str(cfg.get("nsid", "")), "DNS_NSID": str(cfg.get("nsid", "")),
"DNS_EXTRA_RECORDS": str(cfg.get("extra_records", "")), "DNS_EXTRA_RECORDS": str(cfg.get("extra_records", "")),
"DNS_REAL_RECURSIVE": "true" if cfg.get("real_recursive") else "false",
"DNS_UPSTREAM": str(cfg.get("upstream", "8.8.8.8:53")),
} }
if log_target: if log_target:
env["LOG_TARGET"] = log_target env["LOG_TARGET"] = log_target

View File

@@ -28,13 +28,21 @@ import instance_seed as seed
# ── Config ──────────────────────────────────────────────────────────────────── # ── Config ────────────────────────────────────────────────────────────────────
NODE_NAME = os.environ.get("NODE_NAME", "ns1") NODE_NAME = os.environ.get("NODE_NAME", "ns1")
SERVICE_NAME = "dns" SERVICE_NAME = "dns"
LOG_TARGET = os.environ.get("LOG_TARGET", "") LOG_TARGET = os.environ.get("LOG_TARGET", "")
ZONE_MODE = os.environ.get("DNS_ZONE_MODE", "auth") ZONE_MODE = os.environ.get("DNS_ZONE_MODE", "auth")
BIND_VERSION = os.environ.get("DNS_BIND_VERSION", "9.11.4-P2-RedHat-9.11.4-26.P2.el7_9.10") BIND_VERSION = os.environ.get("DNS_BIND_VERSION", "9.11.4-P2-RedHat-9.11.4-26.P2.el7_9.10")
_NSID_RAW = os.environ.get("DNS_NSID", "") _NSID_RAW = os.environ.get("DNS_NSID", "")
_EXTRA_RAW = os.environ.get("DNS_EXTRA_RECORDS", "") _EXTRA_RAW = os.environ.get("DNS_EXTRA_RECORDS", "")
# Forwarding is opt-in: only an explicit truthy string enables it.
REAL_RECURSIVE = os.environ.get("DNS_REAL_RECURSIVE", "").lower() in ("1", "true", "yes")

# Resolver used when DNS_UPSTREAM is missing or unusable.
_DEFAULT_UPSTREAM: tuple[str, int] = ("8.8.8.8", 53)


def _parse_upstream(raw: str) -> tuple[str, int]:
    """Turn a ``host[:port]`` string into a ``(host, port)`` tuple.

    A bare host (no ``:port`` suffix) keeps the configured host and uses the
    default port instead of discarding the whole setting. An empty value, an
    empty host, a non-numeric port, or a port outside 1-65535 yields the
    default resolver, so a bad setting never produces an address that can
    only fail at send time.
    """
    text = raw.strip()
    if not text:
        return _DEFAULT_UPSTREAM
    host, sep, port_str = text.rpartition(":")
    if not sep:
        # No colon at all: the operator supplied only a host.
        return (text, _DEFAULT_UPSTREAM[1])
    try:
        port = int(port_str)
    except ValueError:
        return _DEFAULT_UPSTREAM
    if not host or not 0 < port < 65536:
        return _DEFAULT_UPSTREAM
    return (host, port)


_upstream_raw = os.environ.get("DNS_UPSTREAM", "8.8.8.8:53")
_UPSTREAM_ADDR: tuple[str, int] = _parse_upstream(_upstream_raw)
# ── Zone generation ─────────────────────────────────────────────────────────── # ── Zone generation ───────────────────────────────────────────────────────────
@@ -566,6 +574,57 @@ def _auth_response(qid: int, rd: bool, qname: str, qtype: int) -> bytes:
+ q + answer_bytes + auth_bytes + q + answer_bytes + auth_bytes
) )
# ── Real recursive forwarding ─────────────────────────────────────────────────
def _is_upstream_candidate(data: bytes) -> bool:
    """Decide whether the raw query *data* should go to the upstream resolver.

    Returns True only when REAL_RECURSIVE is on, ZONE_MODE is "recursive",
    the packet is at least 12 bytes with a non-zero question count, and the
    first question is class IN, is not AXFR/IXFR, and names something that
    is neither DOMAIN_BARE nor a subdomain of it. Any exception raised while
    inspecting the packet is treated as "do not forward".
    """
    forwarding_enabled = REAL_RECURSIVE and ZONE_MODE == "recursive"
    if not forwarding_enabled:
        return False
    if len(data) < 12:
        return False
    try:
        # The question count is the 16-bit big-endian field at offset 4.
        (question_count,) = struct.unpack_from(">H", data, 4)
        if not question_count:
            return False
        name, rtype, rclass, _ = _parse_question(data, 12)
        if rclass != CLASS_IN:
            return False
        if rtype in (TYPE_AXFR, TYPE_IXFR):
            return False
        bare = name.rstrip(".")
        if bare == DOMAIN_BARE:
            return False
        # Anything under our own domain stays local; the rest is forwarded.
        return not bare.endswith("." + DOMAIN_BARE)
    except Exception:
        return False
async def _forward_upstream(data: bytes) -> bytes | None:
    """Relay the raw query *data* to _UPSTREAM_ADDR over UDP.

    Opens a fresh non-blocking IPv4 datagram socket, sends the query, and
    waits up to 3 seconds for a single datagram of at most 4096 bytes.

    Returns the raw reply bytes, or None when the reply is shorter than
    12 bytes or when any Exception (including the timeout) occurs. The
    socket is always closed before returning.
    """
    loop = asyncio.get_running_loop()
    udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    udp.setblocking(False)
    reply: bytes | None = None
    try:
        await loop.sock_connect(udp, _UPSTREAM_ADDR)
        await loop.sock_sendall(udp, data)
        payload = await asyncio.wait_for(loop.sock_recv(udp, 4096), timeout=3.0)
        if len(payload) >= 12:
            reply = payload
    except Exception:
        # Best effort: the caller falls back to its own answer on None.
        reply = None
    finally:
        try:
            udp.close()
        except Exception:
            pass
    return reply
async def _dispatch(data: bytes, src_ip: str, src_port: int, transport: str) -> bytes | None:
    """Answer one query, preferring an upstream reply when forwarding applies.

    The synchronous _handle() is always called first and its result kept as
    the fallback. If _is_upstream_candidate() accepts the query, it is sent
    through _forward_upstream(); a non-None upstream reply replaces the
    fallback, otherwise the _handle() result is returned unchanged.
    """
    local_reply = _handle(data, src_ip, src_port, transport)
    if not _is_upstream_candidate(data):
        return local_reply
    relayed = await _forward_upstream(data)
    return local_reply if relayed is None else relayed
# ── Request dispatcher ──────────────────────────────────────────────────────── # ── Request dispatcher ────────────────────────────────────────────────────────
def _handle(data: bytes, src_ip: str, src_port: int, transport: str) -> bytes | None: def _handle(data: bytes, src_ip: str, src_port: int, transport: str) -> bytes | None:
@@ -651,8 +710,11 @@ class _DNSUDPProtocol(asyncio.DatagramProtocol):
self._transport = cast(asyncio.DatagramTransport, transport) self._transport = cast(asyncio.DatagramTransport, transport)
def datagram_received(self, data: bytes, addr: tuple) -> None:
    """Schedule asynchronous handling of one received datagram.

    The task running self._handle_datagram(data, addr) is kept in a
    per-instance set until it finishes. The event loop itself only holds
    weak references to tasks, so a task nobody else references may be
    garbage-collected before it completes.
    """
    task = asyncio.ensure_future(self._handle_datagram(data, addr))
    # Created lazily because the constructor is not part of this block.
    pending = getattr(self, "_pending_tasks", None)
    if pending is None:
        pending = self._pending_tasks = set()
    pending.add(task)
    # Drop the strong reference as soon as the task is done.
    task.add_done_callback(pending.discard)
async def _handle_datagram(self, data: bytes, addr: tuple) -> None:
try: try:
response = _handle(data, addr[0], addr[1], "udp") response = await _dispatch(data, addr[0], addr[1], "udp")
if response and self._transport: if response and self._transport:
self._transport.sendto(response, addr) self._transport.sendto(response, addr)
except Exception: except Exception:
@@ -674,7 +736,7 @@ async def _tcp_session(reader: asyncio.StreamReader, writer: asyncio.StreamWrite
if msg_len == 0: if msg_len == 0:
break break
data = await asyncio.wait_for(reader.readexactly(msg_len), timeout=10.0) data = await asyncio.wait_for(reader.readexactly(msg_len), timeout=10.0)
response = _handle(data, src_ip, src_port, "tcp") response = await _dispatch(data, src_ip, src_port, "tcp")
if response: if response:
writer.write(struct.pack(">H", len(response)) + response) writer.write(struct.pack(">H", len(response)) + response)
await writer.drain() await writer.drain()

View File

@@ -480,6 +480,102 @@ class TestZoneModeOpen:
# ── Zone mode: recursive ────────────────────────────────────────────────────── # ── Zone mode: recursive ──────────────────────────────────────────────────────
class TestRealRecursive:
    """Tests for real-recursive forwarding layered on the sync handler."""

    @staticmethod
    def _run(coro):
        """Run *coro* to completion on a private event loop and return its result.

        A dedicated loop avoids asyncio.get_event_loop(), which is deprecated
        when no loop is running, and leaves the thread's current-loop state
        untouched for other tests.
        """
        import asyncio

        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    def test_upstream_response_relayed_when_available(self):
        """Upstream response is returned instead of sinkhole when forwarding succeeds."""
        from unittest.mock import AsyncMock, patch

        mod, _ = _load_dns({"DNS_ZONE_MODE": "recursive", "DNS_REAL_RECURSIVE": "true"})

        # Minimal upstream answer: header (ID 0x1234, flags 0x8180 = QR, RD, RA,
        # RCODE 0; QDCOUNT=1, ANCOUNT=1) + question + one A record.
        answer_hdr = struct.pack(">HHHHHH", 0x1234, 0x8180, 1, 1, 0, 0)
        qname_wire = b"\x04evil\x07example\x03com\x00"
        question = qname_wire + struct.pack(">HH", mod.TYPE_A, mod.CLASS_IN)
        rdata = bytes([1, 2, 3, 4])
        rr = qname_wire + struct.pack(">HHIH", mod.TYPE_A, mod.CLASS_IN, 60, 4) + rdata
        fake_response = answer_hdr + question + rr

        mock_forward = AsyncMock(return_value=fake_response)
        with patch.object(mod, "_forward_upstream", mock_forward):
            query = _build_query("evil.example.com", mod.TYPE_A, qid=0x1234)
            resp = self._run(mod._dispatch(query, "1.1.1.1", 1234, "udp"))

        assert resp == fake_response
        mock_forward.assert_awaited_once()

    def test_sinkhole_fallback_when_upstream_fails(self):
        """Sinkhole is returned when upstream times out."""
        from unittest.mock import AsyncMock, patch

        mod, _ = _load_dns({"DNS_ZONE_MODE": "recursive", "DNS_REAL_RECURSIVE": "true"})

        with patch.object(mod, "_forward_upstream", AsyncMock(return_value=None)):
            query = _build_query("evil.example.com", mod.TYPE_A)
            resp = self._run(mod._dispatch(query, "1.1.1.1", 1234, "udp"))

        assert resp is not None
        assert _rcode(resp) == mod.RCODE_NOERROR
        assert b"\x7f" in resp  # sinkhole

    def test_in_zone_query_not_forwarded(self):
        """In-zone queries never hit upstream even with real_recursive=true."""
        from unittest.mock import AsyncMock, patch

        mod, _ = _load_dns({"DNS_ZONE_MODE": "recursive", "DNS_REAL_RECURSIVE": "true"})

        mock_forward = AsyncMock(return_value=None)
        with patch.object(mod, "_forward_upstream", mock_forward):
            query = _build_query("test.local", mod.TYPE_A)
            self._run(mod._dispatch(query, "1.1.1.1", 1234, "udp"))

        mock_forward.assert_not_awaited()

    def test_real_recursive_false_never_forwards(self):
        """_forward_upstream is never called when REAL_RECURSIVE is off."""
        from unittest.mock import AsyncMock, patch

        mod, _ = _load_dns({"DNS_ZONE_MODE": "recursive", "DNS_REAL_RECURSIVE": "false"})

        mock_forward = AsyncMock(return_value=None)
        with patch.object(mod, "_forward_upstream", mock_forward):
            query = _build_query("evil.example.com", mod.TYPE_A)
            self._run(mod._dispatch(query, "1.1.1.1", 1234, "udp"))

        mock_forward.assert_not_awaited()

    def test_logging_fires_even_when_forwarding(self):
        """query event is still emitted for forwarded queries (via _handle)."""
        from unittest.mock import AsyncMock, patch

        mod, events = _load_dns({"DNS_ZONE_MODE": "recursive", "DNS_REAL_RECURSIVE": "true"})

        fake_resp = b"\x12\x34\x81\x80" + b"\x00" * 8  # minimal valid header
        with patch.object(mod, "_forward_upstream", AsyncMock(return_value=fake_resp)):
            query = _build_query("evil.example.com", mod.TYPE_A)
            self._run(mod._dispatch(query, "1.1.1.1", 1234, "udp"))

        assert _events_of(events, "query")

    def test_compose_fragment_includes_real_recursive_vars(self):
        """compose_fragment passes both forwarding settings into the environment."""
        from decnet.services.registry import get_service

        svc = get_service("dns")
        frag = svc.compose_fragment(
            "decky-01",
            service_cfg={"real_recursive": True, "upstream": "1.1.1.1:53"},
        )
        assert frag["environment"]["DNS_REAL_RECURSIVE"] == "true"
        assert frag["environment"]["DNS_UPSTREAM"] == "1.1.1.1:53"

    def test_compose_fragment_real_recursive_default_false(self):
        """Without configuration, DNS_REAL_RECURSIVE is emitted as "false"."""
        from decnet.services.registry import get_service

        svc = get_service("dns")
        frag = svc.compose_fragment("decky-01")
        assert frag["environment"]["DNS_REAL_RECURSIVE"] == "false"
class TestZoneModeRecursive: class TestZoneModeRecursive:
def test_recursive_mode_sets_ra_flag(self): def test_recursive_mode_sets_ra_flag(self):
mod, _ = _load_dns({"DNS_ZONE_MODE": "recursive"}) mod, _ = _load_dns({"DNS_ZONE_MODE": "recursive"})