feat: fleet-wide MACVLAN sniffer microservice

Replace per-decky sniffer containers with a single host-side sniffer
that monitors all traffic on the MACVLAN interface. Runs as a background
task in the FastAPI lifespan alongside the collector, fully fault-isolated
so failures never crash the API.

- Add fleet_singleton flag to BaseService; sniffer marked as singleton
- Composer skips fleet_singleton services in compose generation
- Fleet builder excludes singletons from random service assignment
- Extract TLS fingerprinting engine from templates/sniffer/server.py
  into decnet/sniffer/ package (parameterized for fleet-wide use)
- Sniffer worker maps packets to deckies via IP→name state mapping
- Original templates/sniffer/server.py preserved for future use
This commit is contained in:
2026-04-14 15:02:34 -04:00
parent 1d73957832
commit 5a7ff285cd
12 changed files with 1493 additions and 8 deletions

View File

@@ -0,0 +1,11 @@
"""
Fleet-wide MACVLAN sniffer microservice.
Runs as a single host-side background task (not per-decky) that sniffs
all TLS traffic on the MACVLAN interface, extracts fingerprints, and
feeds events into the existing log pipeline.
"""
from decnet.sniffer.worker import sniffer_worker
__all__ = ["sniffer_worker"]

View File

@@ -0,0 +1,884 @@
"""
TLS fingerprinting engine for the fleet-wide MACVLAN sniffer.
Extracted from templates/sniffer/server.py. All pure-Python TLS parsing,
JA3/JA3S/JA4/JA4S/JA4L computation, session tracking, and dedup logic
lives here. The packet callback is parameterized to accept an IP-to-decky
mapping and a write function, so it works for fleet-wide sniffing.
"""
from __future__ import annotations
import hashlib
import struct
import time
from pathlib import Path
from typing import Any, Callable
from decnet.sniffer.syslog import SEVERITY_INFO, SEVERITY_WARNING, syslog_line
# ─── Constants ───────────────────────────────────────────────────────────────
SERVICE_NAME: str = "sniffer"  # app-name used in emitted syslog lines
_SESSION_TTL: float = 60.0  # seconds before half-open TLS session state is dropped
_DEDUP_TTL: float = 300.0  # default suppression window for repeated identical events
# The 16 reserved GREASE codepoints (RFC 8701): 0x0A0A, 0x1A1A, ..., 0xFAFA.
_GREASE: frozenset[int] = frozenset(0x0A0A + i * 0x1010 for i in range(16))
# TLS record content type / handshake message type codes (RFC 5246 / RFC 8446).
_TLS_RECORD_HANDSHAKE: int = 0x16
_TLS_HT_CLIENT_HELLO: int = 0x01
_TLS_HT_SERVER_HELLO: int = 0x02
_TLS_HT_CERTIFICATE: int = 0x0B
# TLS extension type codes the fingerprinters care about.
_EXT_SNI: int = 0x0000
_EXT_SUPPORTED_GROUPS: int = 0x000A
_EXT_EC_POINT_FORMATS: int = 0x000B
_EXT_SIGNATURE_ALGORITHMS: int = 0x000D
_EXT_ALPN: int = 0x0010
_EXT_SESSION_TICKET: int = 0x0023
_EXT_SUPPORTED_VERSIONS: int = 0x002B
_EXT_PRE_SHARED_KEY: int = 0x0029
_EXT_EARLY_DATA: int = 0x002A
# TCP flag bits used for SYN / SYN-ACK round-trip (JA4L) tracking.
_TCP_SYN: int = 0x02
_TCP_ACK: int = 0x10
# ─── GREASE helpers ──────────────────────────────────────────────────────────
def _is_grease(value: int) -> bool:
    """Return True when *value* is one of the 16 reserved GREASE codepoints (RFC 8701)."""
    return value in _GREASE
def _filter_grease(values: list[int]) -> list[int]:
    """Return *values* with GREASE codepoints removed, preserving order."""
    return [item for item in values if item not in _GREASE]
# ─── TLS parsers ─────────────────────────────────────────────────────────────
def _parse_client_hello(data: bytes) -> dict[str, Any] | None:
    """Parse a TLS ClientHello out of a raw TCP payload.

    Returns the fields needed for JA3/JA4 fingerprinting — with GREASE
    values filtered from the cipher/group/sig-alg/version lists — or
    None when *data* is not a complete, well-formed ClientHello.
    Any parse error is swallowed and reported as None.
    """
    try:
        # TLS record header: content-type(1) legacy-version(2) length(2).
        if len(data) < 6:
            return None
        if data[0] != _TLS_RECORD_HANDSHAKE:
            return None
        record_len = struct.unpack_from("!H", data, 3)[0]
        if len(data) < 5 + record_len:
            return None  # record split across TCP segments — give up
        hs = data[5:]
        if hs[0] != _TLS_HT_CLIENT_HELLO:
            return None
        # Handshake length is a 24-bit big-endian int; left-pad for "!I".
        hs_len = struct.unpack_from("!I", b"\x00" + hs[1:4])[0]
        body = hs[4: 4 + hs_len]
        if len(body) < 34:
            return None  # too short for version(2) + random(32)
        pos = 0
        tls_version = struct.unpack_from("!H", body, pos)[0]
        pos += 2
        pos += 32  # Random
        session_id_len = body[pos]
        session_id = body[pos + 1: pos + 1 + session_id_len]
        pos += 1 + session_id_len
        cs_len = struct.unpack_from("!H", body, pos)[0]
        pos += 2
        cipher_suites = [
            struct.unpack_from("!H", body, pos + i * 2)[0]
            for i in range(cs_len // 2)
        ]
        pos += cs_len
        comp_len = body[pos]
        pos += 1 + comp_len  # skip compression methods list
        extensions: list[int] = []
        supported_groups: list[int] = []
        ec_point_formats: list[int] = []
        signature_algorithms: list[int] = []
        supported_versions: list[int] = []
        sni: str = ""
        alpn: list[str] = []
        has_session_ticket_data: bool = False
        has_pre_shared_key: bool = False
        has_early_data: bool = False
        if pos + 2 <= len(body):
            ext_total = struct.unpack_from("!H", body, pos)[0]
            pos += 2
            ext_end = pos + ext_total
            # Walk the extension list: record non-GREASE types in wire order
            # (order matters for JA3) and decode the extensions of interest.
            while pos + 4 <= ext_end:
                ext_type = struct.unpack_from("!H", body, pos)[0]
                ext_len = struct.unpack_from("!H", body, pos + 2)[0]
                ext_data = body[pos + 4: pos + 4 + ext_len]
                pos += 4 + ext_len
                if not _is_grease(ext_type):
                    extensions.append(ext_type)
                if ext_type == _EXT_SNI and len(ext_data) > 5:
                    # Skip list-length(2) + name-type(1) + name-length(2);
                    # only the first server_name entry is read.
                    sni = ext_data[5:].decode("ascii", errors="replace")
                elif ext_type == _EXT_SUPPORTED_GROUPS and len(ext_data) >= 2:
                    grp_len = struct.unpack_from("!H", ext_data, 0)[0]
                    supported_groups = [
                        struct.unpack_from("!H", ext_data, 2 + i * 2)[0]
                        for i in range(grp_len // 2)
                    ]
                elif ext_type == _EXT_EC_POINT_FORMATS and len(ext_data) >= 1:
                    pf_len = ext_data[0]
                    ec_point_formats = list(ext_data[1: 1 + pf_len])
                elif ext_type == _EXT_ALPN and len(ext_data) >= 2:
                    proto_list_len = struct.unpack_from("!H", ext_data, 0)[0]
                    ap = 2
                    # Each protocol entry: length(1) + name bytes.
                    while ap < 2 + proto_list_len:
                        plen = ext_data[ap]
                        alpn.append(ext_data[ap + 1: ap + 1 + plen].decode("ascii", errors="replace"))
                        ap += 1 + plen
                elif ext_type == _EXT_SIGNATURE_ALGORITHMS and len(ext_data) >= 2:
                    sa_len = struct.unpack_from("!H", ext_data, 0)[0]
                    signature_algorithms = [
                        struct.unpack_from("!H", ext_data, 2 + i * 2)[0]
                        for i in range(sa_len // 2)
                    ]
                elif ext_type == _EXT_SUPPORTED_VERSIONS and len(ext_data) >= 1:
                    # ClientHello form: length(1) + list of 2-byte versions.
                    sv_len = ext_data[0]
                    supported_versions = [
                        struct.unpack_from("!H", ext_data, 1 + i * 2)[0]
                        for i in range(sv_len // 2)
                    ]
                elif ext_type == _EXT_SESSION_TICKET:
                    # A non-empty ticket body indicates attempted resumption.
                    has_session_ticket_data = len(ext_data) > 0
                elif ext_type == _EXT_PRE_SHARED_KEY:
                    has_pre_shared_key = True
                elif ext_type == _EXT_EARLY_DATA:
                    has_early_data = True
        filtered_ciphers = _filter_grease(cipher_suites)
        filtered_groups = _filter_grease(supported_groups)
        filtered_sig_algs = _filter_grease(signature_algorithms)
        filtered_versions = _filter_grease(supported_versions)
        return {
            "tls_version": tls_version,
            "cipher_suites": filtered_ciphers,
            "extensions": extensions,
            "supported_groups": filtered_groups,
            "ec_point_formats": ec_point_formats,
            "signature_algorithms": filtered_sig_algs,
            "supported_versions": filtered_versions,
            "sni": sni,
            "alpn": alpn,
            "session_id": session_id,
            "has_session_ticket_data": has_session_ticket_data,
            "has_pre_shared_key": has_pre_shared_key,
            "has_early_data": has_early_data,
        }
    except Exception:
        return None
def _parse_server_hello(data: bytes) -> dict[str, Any] | None:
    """Parse a TLS ServerHello out of a raw TCP payload.

    Returns the fields needed for JA3S/JA4S fingerprinting, or None
    when *data* is not a well-formed ServerHello. Parse errors are
    swallowed and reported as None.
    """
    try:
        if len(data) < 6 or data[0] != _TLS_RECORD_HANDSHAKE:
            return None
        hs = data[5:]
        if hs[0] != _TLS_HT_SERVER_HELLO:
            return None
        # 24-bit handshake length, left-padded so "!H"-style unpack works.
        hs_len = struct.unpack_from("!I", b"\x00" + hs[1:4])[0]
        body = hs[4: 4 + hs_len]
        if len(body) < 35:
            return None
        pos = 0
        tls_version = struct.unpack_from("!H", body, pos)[0]
        pos += 2
        pos += 32  # Random
        session_id_len = body[pos]
        pos += 1 + session_id_len
        if pos + 2 > len(body):
            return None
        cipher_suite = struct.unpack_from("!H", body, pos)[0]
        pos += 2
        pos += 1  # Compression method
        extensions: list[int] = []
        selected_version: int | None = None
        alpn: str = ""
        if pos + 2 <= len(body):
            ext_total = struct.unpack_from("!H", body, pos)[0]
            pos += 2
            ext_end = pos + ext_total
            while pos + 4 <= ext_end:
                ext_type = struct.unpack_from("!H", body, pos)[0]
                ext_len = struct.unpack_from("!H", body, pos + 2)[0]
                ext_data = body[pos + 4: pos + 4 + ext_len]
                pos += 4 + ext_len
                if not _is_grease(ext_type):
                    extensions.append(ext_type)
                if ext_type == _EXT_SUPPORTED_VERSIONS and len(ext_data) >= 2:
                    # TLS 1.3 servers put the real negotiated version here.
                    selected_version = struct.unpack_from("!H", ext_data, 0)[0]
                elif ext_type == _EXT_ALPN and len(ext_data) >= 2:
                    proto_list_len = struct.unpack_from("!H", ext_data, 0)[0]
                    if proto_list_len > 0 and len(ext_data) >= 4:
                        # Server returns exactly one selected protocol.
                        plen = ext_data[2]
                        alpn = ext_data[3: 3 + plen].decode("ascii", errors="replace")
        return {
            "tls_version": tls_version,
            "cipher_suite": cipher_suite,
            "extensions": extensions,
            "selected_version": selected_version,
            "alpn": alpn,
        }
    except Exception:
        return None
def _parse_certificate(data: bytes) -> dict[str, Any] | None:
    """Parse a TLS Certificate handshake message and decode the leaf cert.

    Only the first (leaf) certificate in the chain is examined; the
    decoded identity fields come from _parse_x509_der. Returns None if
    *data* is not a Certificate message. Only reachable for TLS 1.2 —
    TLS 1.3 sends certificates encrypted.
    """
    try:
        if len(data) < 6 or data[0] != _TLS_RECORD_HANDSHAKE:
            return None
        hs = data[5:]
        if hs[0] != _TLS_HT_CERTIFICATE:
            return None
        hs_len = struct.unpack_from("!I", b"\x00" + hs[1:4])[0]
        body = hs[4: 4 + hs_len]
        if len(body) < 3:
            return None
        # certificate_list total length (24-bit).
        certs_len = struct.unpack_from("!I", b"\x00" + body[0:3])[0]
        if certs_len == 0:
            return None
        pos = 3
        if pos + 3 > len(body):
            return None
        # First entry: 24-bit length + DER bytes of the leaf certificate.
        cert_len = struct.unpack_from("!I", b"\x00" + body[pos:pos + 3])[0]
        pos += 3
        if pos + cert_len > len(body):
            return None
        cert_der = body[pos: pos + cert_len]
        return _parse_x509_der(cert_der)
    except Exception:
        return None
# ─── Minimal DER/ASN.1 X.509 parser ─────────────────────────────────────────
def _der_read_tag_len(data: bytes, pos: int) -> tuple[int, int, int]:
tag = data[pos]
pos += 1
length_byte = data[pos]
pos += 1
if length_byte & 0x80:
num_bytes = length_byte & 0x7F
length = int.from_bytes(data[pos: pos + num_bytes], "big")
pos += num_bytes
else:
length = length_byte
return tag, pos, length
def _der_read_sequence(data: bytes, pos: int) -> tuple[int, int]:
    """Read a DER header at *pos*, discarding the tag byte.

    Returns (content_start, content_length). The caller is expected to
    already know the element is a SEQUENCE; the tag is not verified.
    """
    _tag, content_start, content_length = _der_read_tag_len(data, pos)
    return content_start, content_length
def _der_read_oid(data: bytes, pos: int, length: int) -> str:
if length < 1:
return ""
first = data[pos]
oid_parts = [str(first // 40), str(first % 40)]
val = 0
for i in range(1, length):
b = data[pos + i]
val = (val << 7) | (b & 0x7F)
if not (b & 0x80):
oid_parts.append(str(val))
val = 0
return ".".join(oid_parts)
def _der_extract_cn(data: bytes, start: int, length: int) -> str:
    """Extract the commonName (OID 2.5.4.3) from a DER-encoded X.501 Name.

    The Name body is a sequence of RDN SETs (tag 0x31), each holding
    AttributeTypeAndValue SEQUENCEs (tag 0x30). Returns the first CN
    found, or "" when absent or the structure is unexpected.
    """
    pos = start
    end = start + length
    while pos < end:
        set_tag, set_start, set_len = _der_read_tag_len(data, pos)
        if set_tag != 0x31:
            break  # not an RDN SET — stop scanning rather than guess
        set_end = set_start + set_len
        attr_pos = set_start
        while attr_pos < set_end:
            seq_tag, seq_start, seq_len = _der_read_tag_len(data, attr_pos)
            if seq_tag != 0x30:
                break
            oid_tag, oid_start, oid_len = _der_read_tag_len(data, seq_start)
            if oid_tag == 0x06:
                oid = _der_read_oid(data, oid_start, oid_len)
                if oid == "2.5.4.3":
                    # The attribute value element immediately follows the OID.
                    val_tag, val_start, val_len = _der_read_tag_len(data, oid_start + oid_len)
                    return data[val_start: val_start + val_len].decode("utf-8", errors="replace")
            attr_pos = seq_start + seq_len
        pos = set_end
    return ""
def _der_extract_name_str(data: bytes, start: int, length: int) -> str:
    """Render a DER-encoded X.501 Name as "CN=..., O=..., ..." text.

    Known attribute OIDs get their conventional short labels; unknown
    OIDs fall back to dotted-decimal form. Attributes are emitted in
    encoding order.
    """
    parts: list[str] = []
    pos = start
    end = start + length
    # Short labels for the common DN attribute types.
    oid_names = {
        "2.5.4.3": "CN",
        "2.5.4.6": "C",
        "2.5.4.7": "L",
        "2.5.4.8": "ST",
        "2.5.4.10": "O",
        "2.5.4.11": "OU",
    }
    while pos < end:
        set_tag, set_start, set_len = _der_read_tag_len(data, pos)
        if set_tag != 0x31:
            break  # not an RDN SET — stop scanning
        set_end = set_start + set_len
        attr_pos = set_start
        while attr_pos < set_end:
            seq_tag, seq_start, seq_len = _der_read_tag_len(data, attr_pos)
            if seq_tag != 0x30:
                break
            oid_tag, oid_start, oid_len = _der_read_tag_len(data, seq_start)
            if oid_tag == 0x06:
                oid = _der_read_oid(data, oid_start, oid_len)
                # Value element immediately follows the OID element.
                val_tag, val_start, val_len = _der_read_tag_len(data, oid_start + oid_len)
                val = data[val_start: val_start + val_len].decode("utf-8", errors="replace")
                name = oid_names.get(oid, oid)
                parts.append(f"{name}={val}")
            attr_pos = seq_start + seq_len
        pos = set_end
    return ", ".join(parts)
def _parse_x509_der(cert_der: bytes) -> dict[str, Any] | None:
    """Extract identity fields from a DER-encoded X.509 certificate.

    Walks the TBSCertificate fields in order — optional [0] version,
    serialNumber, signature AlgorithmIdentifier, issuer, validity,
    subject, then extensions — and returns subject/issuer names, the
    validity window (raw UTCTime/GeneralizedTime text), a self-signed
    heuristic (issuer CN == subject CN, no signature check), and SANs.
    Returns None on any parse error.
    """
    try:
        # Certificate ::= SEQUENCE { tbsCertificate, sigAlg, signature }
        outer_start, outer_len = _der_read_sequence(cert_der, 0)
        tbs_tag, tbs_start, tbs_len = _der_read_tag_len(cert_der, outer_start)
        tbs_end = tbs_start + tbs_len
        pos = tbs_start
        # Optional [0] EXPLICIT version wrapper.
        if cert_der[pos] == 0xA0:
            _, v_start, v_len = _der_read_tag_len(cert_der, pos)
            pos = v_start + v_len
        # serialNumber — skipped.
        _, sn_start, sn_len = _der_read_tag_len(cert_der, pos)
        pos = sn_start + sn_len
        # signature AlgorithmIdentifier — skipped.
        _, sa_start, sa_len = _der_read_tag_len(cert_der, pos)
        pos = sa_start + sa_len
        # issuer Name
        issuer_tag, issuer_start, issuer_len = _der_read_tag_len(cert_der, pos)
        issuer_str = _der_extract_name_str(cert_der, issuer_start, issuer_len)
        issuer_cn = _der_extract_cn(cert_der, issuer_start, issuer_len)
        pos = issuer_start + issuer_len
        # validity ::= SEQUENCE { notBefore, notAfter }
        val_tag, val_start, val_len = _der_read_tag_len(cert_der, pos)
        nb_tag, nb_start, nb_len = _der_read_tag_len(cert_der, val_start)
        not_before = cert_der[nb_start: nb_start + nb_len].decode("ascii", errors="replace")
        na_tag, na_start, na_len = _der_read_tag_len(cert_der, nb_start + nb_len)
        not_after = cert_der[na_start: na_start + na_len].decode("ascii", errors="replace")
        pos = val_start + val_len
        # subject Name
        subj_tag, subj_start, subj_len = _der_read_tag_len(cert_der, pos)
        subject_cn = _der_extract_cn(cert_der, subj_start, subj_len)
        subject_str = _der_extract_name_str(cert_der, subj_start, subj_len)
        # Heuristic only: matching non-empty CNs, no cryptographic check.
        self_signed = (issuer_cn == subject_cn) and subject_cn != ""
        pos = subj_start + subj_len
        sans: list[str] = _extract_sans(cert_der, pos, tbs_end)
        return {
            "subject_cn": subject_cn,
            "subject": subject_str,
            "issuer": issuer_str,
            "issuer_cn": issuer_cn,
            "not_before": not_before,
            "not_after": not_after,
            "self_signed": self_signed,
            "sans": sans,
        }
    except Exception:
        return None
def _extract_sans(cert_der: bytes, pos: int, end: int) -> list[str]:
    """Find the subjectAltName extension (OID 2.5.29.17) in a TBSCertificate.

    *pos* points just past the subject Name; *end* is the end of the
    TBSCertificate. Skips subjectPublicKeyInfo and any optional fields
    until the [3] EXPLICIT extensions wrapper. Best-effort: returns []
    when the extension is absent or anything fails to parse.
    """
    sans: list[str] = []
    try:
        if pos >= end:
            return sans
        # subjectPublicKeyInfo — skip it.
        spki_tag, spki_start, spki_len = _der_read_tag_len(cert_der, pos)
        pos = spki_start + spki_len
        while pos < end:
            tag = cert_der[pos]
            if tag == 0xA3:
                # [3] EXPLICIT Extensions ::= SEQUENCE OF Extension
                _, ext_wrap_start, ext_wrap_len = _der_read_tag_len(cert_der, pos)
                _, exts_start, exts_len = _der_read_tag_len(cert_der, ext_wrap_start)
                epos = exts_start
                eend = exts_start + exts_len
                while epos < eend:
                    ext_tag, ext_start, ext_len = _der_read_tag_len(cert_der, epos)
                    ext_end = ext_start + ext_len
                    oid_tag, oid_start, oid_len = _der_read_tag_len(cert_der, ext_start)
                    if oid_tag == 0x06:
                        oid = _der_read_oid(cert_der, oid_start, oid_len)
                        if oid == "2.5.29.17":
                            vpos = oid_start + oid_len
                            # Optional BOOLEAN "critical" flag precedes the value.
                            if vpos < ext_end and cert_der[vpos] == 0x01:
                                _, bs, bl = _der_read_tag_len(cert_der, vpos)
                                vpos = bs + bl
                            if vpos < ext_end:
                                # extnValue OCTET STRING wraps the SAN sequence.
                                os_tag, os_start, os_len = _der_read_tag_len(cert_der, vpos)
                                if os_tag == 0x04:
                                    sans = _parse_san_sequence(cert_der, os_start, os_len)
                    epos = ext_end
                break  # extensions are the last TBS field — done
            else:
                # issuerUniqueID / subjectUniqueID etc. — skip.
                _, skip_start, skip_len = _der_read_tag_len(cert_der, pos)
                pos = skip_start + skip_len
    except Exception:
        pass
    return sans
def _parse_san_sequence(data: bytes, start: int, length: int) -> list[str]:
    """Decode the GeneralNames inside a subjectAltName extension value.

    Collects dNSName entries (context tag 2) as text and iPAddress
    entries (context tag 7, 4-byte IPv4 only) as dotted quads; other
    name forms are ignored. Best-effort: returns whatever was decoded
    before any error.
    """
    names: list[str] = []
    try:
        seq_tag, seq_start, seq_len = _der_read_tag_len(data, start)
        pos = seq_start
        end = seq_start + seq_len
        while pos < end:
            tag = data[pos]
            _, val_start, val_len = _der_read_tag_len(data, pos)
            # Low 5 bits carry the context-specific tag number.
            context_tag = tag & 0x1F
            if context_tag == 2:
                # dNSName (IA5String)
                names.append(data[val_start: val_start + val_len].decode("ascii", errors="replace"))
            elif context_tag == 7 and val_len == 4:
                # iPAddress, IPv4 only
                names.append(".".join(str(b) for b in data[val_start: val_start + val_len]))
            pos = val_start + val_len
    except Exception:
        pass
    return names
# ─── JA3 / JA3S ─────────────────────────────────────────────────────────────
def _tls_version_str(version: int) -> str:
return {
0x0301: "TLS 1.0",
0x0302: "TLS 1.1",
0x0303: "TLS 1.2",
0x0304: "TLS 1.3",
0x0200: "SSL 2.0",
0x0300: "SSL 3.0",
}.get(version, f"0x{version:04x}")
def _ja3(ch: dict[str, Any]) -> tuple[str, str]:
parts = [
str(ch["tls_version"]),
"-".join(str(c) for c in ch["cipher_suites"]),
"-".join(str(e) for e in ch["extensions"]),
"-".join(str(g) for g in ch["supported_groups"]),
"-".join(str(p) for p in ch["ec_point_formats"]),
]
ja3_str = ",".join(parts)
return ja3_str, hashlib.md5(ja3_str.encode()).hexdigest()
def _ja3s(sh: dict[str, Any]) -> tuple[str, str]:
parts = [
str(sh["tls_version"]),
str(sh["cipher_suite"]),
"-".join(str(e) for e in sh["extensions"]),
]
ja3s_str = ",".join(parts)
return ja3s_str, hashlib.md5(ja3s_str.encode()).hexdigest()
# ─── JA4 / JA4S ─────────────────────────────────────────────────────────────
def _ja4_version(ch: dict[str, Any]) -> str:
versions = ch.get("supported_versions", [])
if versions:
best = max(versions)
else:
best = ch["tls_version"]
return {
0x0304: "13",
0x0303: "12",
0x0302: "11",
0x0301: "10",
0x0300: "s3",
0x0200: "s2",
}.get(best, "00")
def _ja4_alpn_tag(alpn_list: list[str] | str) -> str:
if isinstance(alpn_list, str):
proto = alpn_list
elif alpn_list:
proto = alpn_list[0]
else:
return "00"
if not proto:
return "00"
if len(proto) == 1:
return proto[0] + proto[0]
return proto[0] + proto[-1]
def _sha256_12(text: str) -> str:
return hashlib.sha256(text.encode()).hexdigest()[:12]
def _ja4(ch: dict[str, Any]) -> str:
    """Compute the JA4 client fingerprint ("a_b_c") from a parsed ClientHello.

    Section a: transport marker "t", version tag, SNI presence
    (d=domain, i=ip), cipher/extension counts capped at 99, ALPN tag.
    Section b: truncated SHA-256 of the sorted cipher list. Section c:
    truncated SHA-256 of the sorted extension list, with the sorted
    signature algorithms appended after "_" when present.
    """
    sni_flag = "d" if ch.get("sni") else "i"
    cipher_count = min(len(ch["cipher_suites"]), 99)
    extension_count = min(len(ch["extensions"]), 99)
    section_a = (
        "t"
        + _ja4_version(ch)
        + sni_flag
        + f"{cipher_count:02d}{extension_count:02d}"
        + _ja4_alpn_tag(ch.get("alpn", []))
    )
    section_b = _sha256_12(",".join(str(c) for c in sorted(ch["cipher_suites"])))
    ext_str = ",".join(str(e) for e in sorted(ch["extensions"]))
    sa_str = ",".join(str(s) for s in sorted(ch.get("signature_algorithms", [])))
    section_c = _sha256_12(f"{ext_str}_{sa_str}" if sa_str else ext_str)
    return "_".join((section_a, section_b, section_c))
def _ja4s(sh: dict[str, Any]) -> str:
    """Compute the JA4S server fingerprint ("a_b") from a parsed ServerHello.

    Section a: transport marker "t", version tag, extension count
    capped at 99, ALPN tag. Section b: truncated SHA-256 of
    "cipher,sorted-extensions".

    The version-tag table was previously duplicated verbatim in both
    branches; it is now defined once, and the effective version is the
    negotiated supported_versions value (TLS 1.3) when present, else
    the legacy tls_version field — same selection as before.
    """
    version_tags = {0x0304: "13", 0x0303: "12", 0x0302: "11",
                    0x0301: "10", 0x0300: "s3", 0x0200: "s2"}
    # `selected_version` is None (or absent) for pre-1.3 servers, so the
    # `or` falls through to the legacy field exactly as the old branches did.
    effective = sh.get("selected_version") or sh["tls_version"]
    ver = version_tags.get(effective, "00")
    ext_count = min(len(sh["extensions"]), 99)
    alpn_tag = _ja4_alpn_tag(sh.get("alpn", ""))
    section_a = f"t{ver}{ext_count:02d}{alpn_tag}"
    sorted_ext = sorted(sh["extensions"])
    inner = f"{sh['cipher_suite']},{','.join(str(e) for e in sorted_ext)}"
    return f"{section_a}_{_sha256_12(inner)}"
# ─── JA4L (latency) ─────────────────────────────────────────────────────────
def _ja4l(
key: tuple[str, int, str, int],
tcp_rtt: dict[tuple[str, int, str, int], dict[str, Any]],
) -> dict[str, Any] | None:
return tcp_rtt.get(key)
# ─── Session resumption ─────────────────────────────────────────────────────
def _session_resumption_info(ch: dict[str, Any]) -> dict[str, Any]:
mechanisms: list[str] = []
if ch.get("has_session_ticket_data"):
mechanisms.append("session_ticket")
if ch.get("has_pre_shared_key"):
mechanisms.append("psk")
if ch.get("has_early_data"):
mechanisms.append("early_data_0rtt")
if ch.get("session_id") and len(ch["session_id"]) > 0:
mechanisms.append("session_id")
return {
"resumption_attempted": len(mechanisms) > 0,
"mechanisms": mechanisms,
}
# ─── Sniffer engine (stateful, one instance per worker) ─────────────────────
class SnifferEngine:
    """
    Stateful TLS fingerprinting engine. Tracks sessions, TCP RTTs,
    and dedup state. Thread-safe only when called from a single thread
    (the scapy sniff thread).
    """

    def __init__(
        self,
        ip_to_decky: dict[str, str],
        write_fn: Callable[[str], None],
        dedup_ttl: float = 300.0,
    ):
        # IP address -> decky name; traffic touching none of these IPs is ignored.
        self._ip_to_decky = ip_to_decky
        # Sink for finished RFC 5424 lines (the existing log pipeline).
        self._write_fn = write_fn
        # Suppression window (seconds) for repeated identical events; <= 0 disables.
        self._dedup_ttl = dedup_ttl
        # (src_ip, src_port, dst_ip, dst_port) -> parsed ClientHello summary,
        # kept until the matching ServerHello arrives or _SESSION_TTL expires.
        self._sessions: dict[tuple[str, int, str, int], dict[str, Any]] = {}
        # Same key -> monotonic timestamp, used for TTL expiry.
        self._session_ts: dict[tuple[str, int, str, int], float] = {}
        # Outstanding SYNs awaiting their SYN-ACK (JA4L RTT measurement).
        self._tcp_syn: dict[tuple[str, int, str, int], dict[str, Any]] = {}
        # Measured RTT + client TTL per flow, consumed by hello handling.
        self._tcp_rtt: dict[tuple[str, int, str, int], dict[str, Any]] = {}
        # (src_ip, event_type, fingerprint) -> monotonic time last emitted.
        self._dedup_cache: dict[tuple[str, str, str], float] = {}
        self._dedup_last_cleanup: float = 0.0
        self._DEDUP_CLEANUP_INTERVAL: float = 60.0

    def update_ip_map(self, ip_to_decky: dict[str, str]) -> None:
        """Swap in a fresh IP -> decky-name mapping (plain attribute rebind;
        called from the worker's refresh thread)."""
        self._ip_to_decky = ip_to_decky

    def _resolve_decky(self, src_ip: str, dst_ip: str) -> str | None:
        """Map a packet to a decky name. Returns None if neither IP is a known decky.

        Destination takes precedence when both ends are deckies.
        """
        if dst_ip in self._ip_to_decky:
            return self._ip_to_decky[dst_ip]
        if src_ip in self._ip_to_decky:
            return self._ip_to_decky[src_ip]
        return None

    def _cleanup_sessions(self) -> None:
        """Drop session / SYN / RTT state older than _SESSION_TTL.

        Invoked on each new ClientHello, so memory is bounded by traffic.
        """
        now = time.monotonic()
        stale = [k for k, ts in self._session_ts.items() if now - ts > _SESSION_TTL]
        for k in stale:
            self._sessions.pop(k, None)
            self._session_ts.pop(k, None)
        stale_syn = [k for k, v in self._tcp_syn.items()
                     if now - v.get("time", 0) > _SESSION_TTL]
        for k in stale_syn:
            self._tcp_syn.pop(k, None)
        # RTT entries carry no timestamp of their own: keep them only while
        # a matching session (or its timestamp) is still live.
        stale_rtt = [k for k, _ in self._tcp_rtt.items()
                     if k not in self._sessions and k not in self._session_ts]
        for k in stale_rtt:
            self._tcp_rtt.pop(k, None)

    def _dedup_key_for(self, event_type: str, fields: dict[str, Any]) -> str:
        """Fingerprint string used to recognize a repeat of the same event."""
        if event_type == "tls_client_hello":
            return fields.get("ja3", "") + "|" + fields.get("ja4", "")
        if event_type == "tls_session":
            return (fields.get("ja3", "") + "|" + fields.get("ja3s", "") +
                    "|" + fields.get("ja4", "") + "|" + fields.get("ja4s", ""))
        if event_type == "tls_certificate":
            return fields.get("subject_cn", "") + "|" + fields.get("issuer", "")
        # Fallback for any other event type.
        return fields.get("mechanisms", fields.get("resumption", ""))

    def _is_duplicate(self, event_type: str, fields: dict[str, Any]) -> bool:
        """True when the same (src, event, fingerprint) fired within the dedup TTL.

        Side effects: evicts expired cache entries at most once per
        _DEDUP_CLEANUP_INTERVAL seconds, and records this sighting when
        it is not a duplicate.
        """
        if self._dedup_ttl <= 0:
            return False
        now = time.monotonic()
        if now - self._dedup_last_cleanup > self._DEDUP_CLEANUP_INTERVAL:
            stale = [k for k, ts in self._dedup_cache.items() if now - ts > self._dedup_ttl]
            for k in stale:
                del self._dedup_cache[k]
            self._dedup_last_cleanup = now
        src_ip = fields.get("src_ip", "")
        fp = self._dedup_key_for(event_type, fields)
        cache_key = (src_ip, event_type, fp)
        last_seen = self._dedup_cache.get(cache_key)
        if last_seen is not None and now - last_seen < self._dedup_ttl:
            return True
        self._dedup_cache[cache_key] = now
        return False

    def _log(self, node_name: str, event_type: str, severity: int = SEVERITY_INFO, **fields: Any) -> None:
        """Emit one event line through write_fn unless deduplicated."""
        if self._is_duplicate(event_type, fields):
            return
        line = syslog_line(SERVICE_NAME, node_name, event_type, severity=severity, **fields)
        self._write_fn(line)

    def on_packet(self, pkt: Any) -> None:
        """Process a single scapy packet. Called from the sniff thread."""
        try:
            from scapy.layers.inet import IP, TCP
        except ImportError:
            return
        if not (pkt.haslayer(IP) and pkt.haslayer(TCP)):
            return
        ip = pkt[IP]
        tcp = pkt[TCP]
        src_ip: str = ip.src
        dst_ip: str = ip.dst
        src_port: int = tcp.sport
        dst_port: int = tcp.dport
        # scapy exposes flags as a FlagValue object or a plain int depending
        # on version — normalize to int before masking.
        flags: int = tcp.flags.value if hasattr(tcp.flags, 'value') else int(tcp.flags)
        # Skip traffic not involving any decky
        node_name = self._resolve_decky(src_ip, dst_ip)
        if node_name is None:
            return
        # TCP SYN tracking for JA4L
        if flags & _TCP_SYN and not (flags & _TCP_ACK):
            key = (src_ip, src_port, dst_ip, dst_port)
            self._tcp_syn[key] = {"time": time.monotonic(), "ttl": ip.ttl}
        elif flags & _TCP_SYN and flags & _TCP_ACK:
            # SYN-ACK: match it back to the client's SYN to get the RTT.
            rev_key = (dst_ip, dst_port, src_ip, src_port)
            syn_data = self._tcp_syn.pop(rev_key, None)
            if syn_data:
                rtt_ms = round((time.monotonic() - syn_data["time"]) * 1000, 2)
                self._tcp_rtt[rev_key] = {
                    "rtt_ms": rtt_ms,
                    "client_ttl": syn_data["ttl"],
                }
        payload = bytes(tcp.payload)
        if not payload:
            return
        if payload[0] != _TLS_RECORD_HANDSHAKE:
            return
        # ClientHello
        ch = _parse_client_hello(payload)
        if ch is not None:
            self._cleanup_sessions()
            key = (src_ip, src_port, dst_ip, dst_port)
            ja3_str, ja3_hash = _ja3(ch)
            ja4_hash = _ja4(ch)
            resumption = _session_resumption_info(ch)
            rtt_data = _ja4l(key, self._tcp_rtt)
            # Remember the hello so the matching ServerHello / Certificate
            # can be correlated into a full tls_session event later.
            self._sessions[key] = {
                "ja3": ja3_hash,
                "ja3_str": ja3_str,
                "ja4": ja4_hash,
                "tls_version": ch["tls_version"],
                "cipher_suites": ch["cipher_suites"],
                "extensions": ch["extensions"],
                "signature_algorithms": ch.get("signature_algorithms", []),
                "supported_versions": ch.get("supported_versions", []),
                "sni": ch["sni"],
                "alpn": ch["alpn"],
                "resumption": resumption,
            }
            self._session_ts[key] = time.monotonic()
            log_fields: dict[str, Any] = {
                "src_ip": src_ip,
                "src_port": str(src_port),
                "dst_ip": dst_ip,
                "dst_port": str(dst_port),
                "ja3": ja3_hash,
                "ja4": ja4_hash,
                "tls_version": _tls_version_str(ch["tls_version"]),
                "sni": ch["sni"] or "",
                "alpn": ",".join(ch["alpn"]),
                "raw_ciphers": "-".join(str(c) for c in ch["cipher_suites"]),
                "raw_extensions": "-".join(str(e) for e in ch["extensions"]),
            }
            if resumption["resumption_attempted"]:
                log_fields["resumption"] = ",".join(resumption["mechanisms"])
            if rtt_data:
                log_fields["ja4l_rtt_ms"] = str(rtt_data["rtt_ms"])
                log_fields["ja4l_client_ttl"] = str(rtt_data["client_ttl"])
            # Resolve node for the *destination* (the decky being attacked)
            target_node = self._ip_to_decky.get(dst_ip, node_name)
            self._log(target_node, "tls_client_hello", **log_fields)
            return
        # ServerHello
        sh = _parse_server_hello(payload)
        if sh is not None:
            rev_key = (dst_ip, dst_port, src_ip, src_port)
            ch_data = self._sessions.pop(rev_key, None)
            self._session_ts.pop(rev_key, None)
            ja3s_str, ja3s_hash = _ja3s(sh)
            ja4s_hash = _ja4s(sh)
            # Report from the client's perspective: src is the peer that
            # sent the ClientHello (i.e. this packet's destination).
            fields: dict[str, Any] = {
                "src_ip": dst_ip,
                "src_port": str(dst_port),
                "dst_ip": src_ip,
                "dst_port": str(src_port),
                "ja3s": ja3s_hash,
                "ja4s": ja4s_hash,
                "tls_version": _tls_version_str(sh["tls_version"]),
            }
            if ch_data:
                fields["ja3"] = ch_data["ja3"]
                fields["ja4"] = ch_data.get("ja4", "")
                fields["sni"] = ch_data["sni"] or ""
                fields["alpn"] = ",".join(ch_data["alpn"])
                fields["raw_ciphers"] = "-".join(str(c) for c in ch_data["cipher_suites"])
                fields["raw_extensions"] = "-".join(str(e) for e in ch_data["extensions"])
                if ch_data.get("resumption", {}).get("resumption_attempted"):
                    fields["resumption"] = ",".join(ch_data["resumption"]["mechanisms"])
            rtt_data = self._tcp_rtt.pop(rev_key, None)
            if rtt_data:
                fields["ja4l_rtt_ms"] = str(rtt_data["rtt_ms"])
                fields["ja4l_client_ttl"] = str(rtt_data["client_ttl"])
            # Server response — resolve by src_ip (the decky responding)
            target_node = self._ip_to_decky.get(src_ip, node_name)
            self._log(target_node, "tls_session", severity=SEVERITY_WARNING, **fields)
            return
        # Certificate (TLS 1.2 only)
        cert = _parse_certificate(payload)
        if cert is not None:
            rev_key = (dst_ip, dst_port, src_ip, src_port)
            # get (not pop): session stays live for the ServerHello path.
            ch_data = self._sessions.get(rev_key)
            cert_fields: dict[str, Any] = {
                "src_ip": dst_ip,
                "src_port": str(dst_port),
                "dst_ip": src_ip,
                "dst_port": str(src_port),
                "subject_cn": cert["subject_cn"],
                "issuer": cert["issuer"],
                "self_signed": str(cert["self_signed"]).lower(),
                "not_before": cert["not_before"],
                "not_after": cert["not_after"],
            }
            if cert["sans"]:
                cert_fields["sans"] = ",".join(cert["sans"])
            if ch_data:
                cert_fields["sni"] = ch_data.get("sni", "")
            target_node = self._ip_to_decky.get(src_ip, node_name)
            self._log(target_node, "tls_certificate", **cert_fields)

69
decnet/sniffer/syslog.py Normal file
View File

@@ -0,0 +1,69 @@
"""
RFC 5424 syslog formatting and log-file writing for the fleet sniffer.
Reuses the same wire format as templates/sniffer/decnet_logging.py so the
existing collector parser and ingester can consume events without changes.
"""
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from decnet.collector.worker import parse_rfc5424
# ─── Constants (must match templates/sniffer/decnet_logging.py) ──────────────
_FACILITY_LOCAL0 = 16  # syslog facility local0; PRI = facility * 8 + severity
_SD_ID = "decnet@55555"  # structured-data ID shared with the per-decky sniffer
_NILVALUE = "-"  # RFC 5424 NILVALUE placeholder
SEVERITY_INFO = 6  # RFC 5424 severity: informational
SEVERITY_WARNING = 4  # RFC 5424 severity: warning
# RFC 5424 maximum lengths for HOSTNAME, APP-NAME and MSGID fields.
_MAX_HOSTNAME = 255
_MAX_APPNAME = 48
_MAX_MSGID = 32
# ─── Formatter ───────────────────────────────────────────────────────────────
def _sd_escape(value: str) -> str:
return value.replace("\\", "\\\\").replace('"', '\\"').replace("]", "\\]")
def _sd_element(fields: dict[str, Any]) -> str:
    """Render *fields* as a single RFC 5424 SD-ELEMENT, or NILVALUE when empty."""
    if not fields:
        return _NILVALUE
    pairs = [f'{key}="{_sd_escape(str(val))}"' for key, val in fields.items()]
    return f"[{_SD_ID} " + " ".join(pairs) + "]"
def syslog_line(
    service: str,
    hostname: str,
    event_type: str,
    severity: int = SEVERITY_INFO,
    msg: str | None = None,
    **fields: Any,
) -> str:
    """Format one RFC 5424 syslog line.

    PRI is local0 (16) * 8 + severity; VERSION is 1; PROCID is always
    NILVALUE; *fields* become one [decnet@55555 ...] SD-ELEMENT.
    Hostname / app-name / msgid are truncated to the RFC 5424 maxima,
    and NILVALUE substitutes for any empty field.
    """
    timestamp = datetime.now(timezone.utc).isoformat()
    header = " ".join((
        f"<{_FACILITY_LOCAL0 * 8 + severity}>1",
        timestamp,
        (hostname or _NILVALUE)[:_MAX_HOSTNAME],
        (service or _NILVALUE)[:_MAX_APPNAME],
        _NILVALUE,
        (event_type or _NILVALUE)[:_MAX_MSGID],
        _sd_element(fields),
    ))
    return header + (f" {msg}" if msg else "")
def write_event(line: str, log_path: Path, json_path: Path) -> None:
    """Append a syslog line to the raw log and its parsed JSON to the json log.

    Files are opened in append mode per event and flushed immediately so
    the collector sees events promptly. Lines the parser rejects are
    still kept in the raw log but skipped for the JSON stream.
    """
    with open(log_path, "a", encoding="utf-8") as raw:
        raw.write(f"{line}\n")
        raw.flush()
    parsed = parse_rfc5424(line)
    if not parsed:
        return
    with open(json_path, "a", encoding="utf-8") as out:
        out.write(json.dumps(parsed) + "\n")
        out.flush()

145
decnet/sniffer/worker.py Normal file
View File

@@ -0,0 +1,145 @@
"""
Fleet-wide MACVLAN sniffer worker.
Runs as a single host-side async background task that sniffs all TLS
traffic on the MACVLAN host interface. Maps packets to deckies by IP
and feeds fingerprint events into the existing log pipeline.
Modeled on decnet.collector.worker — same lifecycle pattern.
Fault-isolated: any exception is logged and the worker exits cleanly.
The API never depends on this worker being alive.
"""
import asyncio
import os
import subprocess
import threading
import time
from pathlib import Path
from typing import Any
from decnet.logging import get_logger
from decnet.network import HOST_MACVLAN_IFACE
from decnet.sniffer.fingerprint import SnifferEngine
from decnet.sniffer.syslog import write_event
logger = get_logger("sniffer")
# Seconds between background refreshes of the IP -> decky-name mapping
# read from decnet-state.json.
_IP_MAP_REFRESH_INTERVAL: float = 60.0
def _load_ip_to_decky() -> dict[str, str]:
    """Build IP → decky-name mapping from decnet-state.json.

    Returns {} when no state is available (fleet not deployed yet).
    """
    from decnet.config import load_state  # deferred import, as in the original
    state = load_state()
    if state is None:
        return {}
    config, _ = state
    return {decky.ip: decky.name for decky in config.deckies}
def _interface_exists(iface: str) -> bool:
"""Check if a network interface exists on this host."""
try:
result = subprocess.run(
["ip", "link", "show", iface],
capture_output=True, text=True, check=False,
)
return result.returncode == 0
except Exception:
return False
def _sniff_loop(
    interface: str,
    log_path: Path,
    json_path: Path,
    stop_event: threading.Event,
) -> None:
    """Blocking sniff loop. Runs in a dedicated thread via asyncio.to_thread.

    Builds the engine from the current decky state, starts a daemon
    thread that refreshes the IP map every _IP_MAP_REFRESH_INTERVAL
    seconds, then hands control to scapy's sniff() until *stop_event*
    is set or sniff() dies. Every failure mode logs and returns —
    nothing is raised to the caller.
    """
    try:
        from scapy.sendrecv import sniff
    except ImportError:
        logger.error("scapy not installed — sniffer cannot start")
        return
    ip_map = _load_ip_to_decky()
    if not ip_map:
        # NOTE(review): deckies added after startup are only picked up if
        # the worker started with a non-empty map — confirm acceptable.
        logger.warning("sniffer: no deckies in state — nothing to sniff")
        return

    def _write_fn(line: str) -> None:
        # Adapter binding the engine's line sink to the log files.
        write_event(line, log_path, json_path)

    engine = SnifferEngine(ip_to_decky=ip_map, write_fn=_write_fn)

    # Periodically refresh IP map in a background daemon thread
    def _refresh_loop() -> None:
        while not stop_event.is_set():
            # wait() doubles as an interruptible sleep.
            stop_event.wait(_IP_MAP_REFRESH_INTERVAL)
            if stop_event.is_set():
                break
            try:
                new_map = _load_ip_to_decky()
                if new_map:
                    engine.update_ip_map(new_map)
            except Exception as exc:
                logger.debug("sniffer: ip map refresh failed: %s", exc)

    refresh_thread = threading.Thread(target=_refresh_loop, daemon=True)
    refresh_thread.start()
    logger.info("sniffer: sniffing on interface=%s deckies=%d", interface, len(ip_map))
    try:
        # stop_filter is evaluated per captured packet, so shutdown latency
        # depends on traffic arriving on the interface.
        sniff(
            iface=interface,
            filter="tcp",
            prn=engine.on_packet,
            store=False,
            stop_filter=lambda pkt: stop_event.is_set(),
        )
    except Exception as exc:
        logger.error("sniffer: scapy sniff exited: %s", exc)
    finally:
        # Also terminates the refresh thread.
        stop_event.set()
    logger.info("sniffer: sniff loop ended")
async def sniffer_worker(log_file: str) -> None:
    """
    Async entry point — started as asyncio.create_task in the API lifespan.
    Fully fault-isolated: catches all exceptions, logs them, and returns
    cleanly. The API continues running regardless of sniffer state.

    *log_file* is the raw syslog path; the parsed-JSON stream goes to the
    same path with a ``.json`` suffix. Task cancellation is propagated
    (never swallowed) after signalling the sniff thread to stop.
    """
    try:
        # Operators may override the capture interface via environment.
        interface = os.environ.get("DECNET_SNIFFER_IFACE", HOST_MACVLAN_IFACE)
        if not _interface_exists(interface):
            logger.warning(
                "sniffer: interface %s not found — sniffer disabled "
                "(fleet may not be deployed yet)", interface,
            )
            return
        log_path = Path(log_file)
        json_path = log_path.with_suffix(".json")
        log_path.parent.mkdir(parents=True, exist_ok=True)
        stop_event = threading.Event()
        try:
            # The blocking sniff loop runs in a worker thread; this task
            # awaits it and relays cancellation via stop_event.
            await asyncio.to_thread(_sniff_loop, interface, log_path, json_path, stop_event)
        except asyncio.CancelledError:
            logger.info("sniffer: shutdown requested")
            stop_event.set()  # unblocks sniff()'s stop_filter and the refresh thread
            raise
    except asyncio.CancelledError:
        raise  # propagate task cancellation untouched
    except Exception as exc:
        logger.error("sniffer: worker failed — API continues without sniffing: %s", exc)