feat(dns): real recursive forwarding with sinkhole fallback
When DNS_REAL_RECURSIVE=true and DNS_ZONE_MODE=recursive, out-of-zone queries are forwarded to DNS_UPSTREAM (default 8.8.8.8:53) via async UDP. Upstream response is relayed as-is; on timeout or error the already-computed sinkhole (127.x) is returned instead. _handle() always runs first so logging, tunneling detection, flood tracking, and recon-burst aggregation fire on every query regardless of whether the response ultimately comes from upstream. _dispatch() overlays forwarding on top of the sync handler. Protocol handlers (UDP datagram_received, TCP session) are now async via asyncio.ensure_future / await _dispatch(). Service class exposes real_recursive (bool) and upstream (string) config fields.
This commit is contained in:
@@ -18,7 +18,7 @@ class DNSService(BaseService):
|
||||
type="enum",
|
||||
enum=["auth", "recursive", "open"],
|
||||
default="auth",
|
||||
help="auth: authoritative only; recursive: RA flag set, NXDOMAIN for out-of-zone; open: responds to everything (amp bait)",
|
||||
help="auth: authoritative only; recursive: forwards out-of-zone queries upstream (real_recursive=true) or sinkholes them; open: responds to everything (amp bait)",
|
||||
),
|
||||
ServiceConfigField(
|
||||
key="domain",
|
||||
@@ -50,6 +50,21 @@ class DNSService(BaseService):
|
||||
placeholder="www A 10.0.0.5\nmail TXT v=spf1 ~all",
|
||||
help="Additional zone records, one per line: <name> <TYPE> <value>",
|
||||
),
|
||||
ServiceConfigField(
|
||||
key="real_recursive",
|
||||
label="Real recursive forwarding",
|
||||
type="bool",
|
||||
default=False,
|
||||
help="When zone_mode=recursive, forward out-of-zone queries to an upstream resolver instead of returning a sinkhole. Falls back to sinkhole on upstream timeout.",
|
||||
),
|
||||
ServiceConfigField(
|
||||
key="upstream",
|
||||
label="Upstream resolver",
|
||||
type="string",
|
||||
default="8.8.8.8:53",
|
||||
placeholder="8.8.8.8:53",
|
||||
help="Upstream DNS resolver used when real_recursive is enabled (host:port).",
|
||||
),
|
||||
]
|
||||
|
||||
def compose_fragment(
|
||||
@@ -65,7 +80,9 @@ class DNSService(BaseService):
|
||||
"DNS_DOMAIN": str(cfg.get("domain", "")),
|
||||
"DNS_BIND_VERSION": str(cfg.get("bind_version", _DEFAULT_VERSION)),
|
||||
"DNS_NSID": str(cfg.get("nsid", "")),
|
||||
"DNS_EXTRA_RECORDS": str(cfg.get("extra_records", "")),
|
||||
"DNS_EXTRA_RECORDS": str(cfg.get("extra_records", "")),
|
||||
"DNS_REAL_RECURSIVE": "true" if cfg.get("real_recursive") else "false",
|
||||
"DNS_UPSTREAM": str(cfg.get("upstream", "8.8.8.8:53")),
|
||||
}
|
||||
if log_target:
|
||||
env["LOG_TARGET"] = log_target
|
||||
|
||||
@@ -28,13 +28,21 @@ import instance_seed as seed
|
||||
|
||||
# ── Config ────────────────────────────────────────────────────────────────────
|
||||
|
||||
NODE_NAME = os.environ.get("NODE_NAME", "ns1")
|
||||
SERVICE_NAME = "dns"
|
||||
LOG_TARGET = os.environ.get("LOG_TARGET", "")
|
||||
ZONE_MODE = os.environ.get("DNS_ZONE_MODE", "auth")
|
||||
BIND_VERSION = os.environ.get("DNS_BIND_VERSION", "9.11.4-P2-RedHat-9.11.4-26.P2.el7_9.10")
|
||||
_NSID_RAW = os.environ.get("DNS_NSID", "")
|
||||
_EXTRA_RAW = os.environ.get("DNS_EXTRA_RECORDS", "")
|
||||
NODE_NAME = os.environ.get("NODE_NAME", "ns1")
|
||||
SERVICE_NAME = "dns"
|
||||
LOG_TARGET = os.environ.get("LOG_TARGET", "")
|
||||
ZONE_MODE = os.environ.get("DNS_ZONE_MODE", "auth")
|
||||
BIND_VERSION = os.environ.get("DNS_BIND_VERSION", "9.11.4-P2-RedHat-9.11.4-26.P2.el7_9.10")
|
||||
_NSID_RAW = os.environ.get("DNS_NSID", "")
|
||||
_EXTRA_RAW = os.environ.get("DNS_EXTRA_RECORDS", "")
|
||||
REAL_RECURSIVE = os.environ.get("DNS_REAL_RECURSIVE", "").lower() in ("1", "true", "yes")
|
||||
|
||||
_upstream_raw = os.environ.get("DNS_UPSTREAM", "8.8.8.8:53")
|
||||
try:
|
||||
_up_host, _up_port_str = _upstream_raw.rsplit(":", 1)
|
||||
_UPSTREAM_ADDR: tuple[str, int] = (_up_host, int(_up_port_str))
|
||||
except (ValueError, AttributeError):
|
||||
_UPSTREAM_ADDR = ("8.8.8.8", 53)
|
||||
|
||||
# ── Zone generation ───────────────────────────────────────────────────────────
|
||||
|
||||
@@ -566,6 +574,57 @@ def _auth_response(qid: int, rd: bool, qname: str, qtype: int) -> bytes:
|
||||
+ q + answer_bytes + auth_bytes
|
||||
)
|
||||
|
||||
# ── Real recursive forwarding ─────────────────────────────────────────────────
|
||||
|
||||
def _is_upstream_candidate(data: bytes) -> bool:
|
||||
"""True when the query should be forwarded to the upstream resolver."""
|
||||
if not REAL_RECURSIVE or ZONE_MODE != "recursive":
|
||||
return False
|
||||
if len(data) < 12:
|
||||
return False
|
||||
try:
|
||||
qdcount = struct.unpack_from(">H", data, 4)[0]
|
||||
if qdcount == 0:
|
||||
return False
|
||||
qname, qtype, qclass, _ = _parse_question(data, 12)
|
||||
if qclass != CLASS_IN or qtype in (TYPE_AXFR, TYPE_IXFR):
|
||||
return False
|
||||
qname_bare = qname.rstrip(".")
|
||||
in_zone = qname_bare == DOMAIN_BARE or qname_bare.endswith("." + DOMAIN_BARE)
|
||||
return not in_zone
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
async def _forward_upstream(data: bytes) -> bytes | None:
|
||||
"""Send raw query bytes to the upstream resolver; return raw response or None."""
|
||||
loop = asyncio.get_running_loop()
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setblocking(False)
|
||||
try:
|
||||
await loop.sock_connect(sock, _UPSTREAM_ADDR)
|
||||
await loop.sock_sendall(sock, data)
|
||||
response = await asyncio.wait_for(loop.sock_recv(sock, 4096), timeout=3.0)
|
||||
return response if len(response) >= 12 else None
|
||||
except Exception:
|
||||
return None
|
||||
finally:
|
||||
try:
|
||||
sock.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def _dispatch(data: bytes, src_ip: str, src_port: int, transport: str) -> bytes | None:
|
||||
"""Async dispatcher: runs sync _handle (logging + detection), then overlays
|
||||
upstream forwarding for real-recursive out-of-zone queries."""
|
||||
sinkhole = _handle(data, src_ip, src_port, transport)
|
||||
if _is_upstream_candidate(data):
|
||||
upstream = await _forward_upstream(data)
|
||||
if upstream is not None:
|
||||
return upstream
|
||||
return sinkhole
|
||||
|
||||
# ── Request dispatcher ────────────────────────────────────────────────────────
|
||||
|
||||
def _handle(data: bytes, src_ip: str, src_port: int, transport: str) -> bytes | None:
|
||||
@@ -651,8 +710,11 @@ class _DNSUDPProtocol(asyncio.DatagramProtocol):
|
||||
self._transport = cast(asyncio.DatagramTransport, transport)
|
||||
|
||||
def datagram_received(self, data: bytes, addr: tuple) -> None:
|
||||
asyncio.ensure_future(self._handle_datagram(data, addr))
|
||||
|
||||
async def _handle_datagram(self, data: bytes, addr: tuple) -> None:
|
||||
try:
|
||||
response = _handle(data, addr[0], addr[1], "udp")
|
||||
response = await _dispatch(data, addr[0], addr[1], "udp")
|
||||
if response and self._transport:
|
||||
self._transport.sendto(response, addr)
|
||||
except Exception:
|
||||
@@ -674,7 +736,7 @@ async def _tcp_session(reader: asyncio.StreamReader, writer: asyncio.StreamWrite
|
||||
if msg_len == 0:
|
||||
break
|
||||
data = await asyncio.wait_for(reader.readexactly(msg_len), timeout=10.0)
|
||||
response = _handle(data, src_ip, src_port, "tcp")
|
||||
response = await _dispatch(data, src_ip, src_port, "tcp")
|
||||
if response:
|
||||
writer.write(struct.pack(">H", len(response)) + response)
|
||||
await writer.drain()
|
||||
|
||||
Reference in New Issue
Block a user