diff --git a/decnet/composer.py b/decnet/composer.py index 973762e..d789615 100644 --- a/decnet/composer.py +++ b/decnet/composer.py @@ -64,6 +64,8 @@ def generate_compose(config: DecnetConfig) -> dict: # --- Service containers: share base network namespace --- for svc_name in decky.services: svc = get_service(svc_name) + if svc.fleet_singleton: + continue svc_cfg = decky.service_config.get(svc_name, {}) fragment = svc.compose_fragment(decky.name, service_cfg=svc_cfg) diff --git a/decnet/fleet.py b/decnet/fleet.py index 01a38c4..f41dbee 100644 --- a/decnet/fleet.py +++ b/decnet/fleet.py @@ -17,8 +17,11 @@ from decnet.services.registry import all_services def all_service_names() -> list[str]: - """Return all registered service names from the live plugin registry.""" - return sorted(all_services().keys()) + """Return all registered per-decky service names (excludes fleet singletons).""" + return sorted( + name for name, svc in all_services().items() + if not svc.fleet_singleton + ) def resolve_distros( diff --git a/decnet/services/base.py b/decnet/services/base.py index 17c2e20..2f7936f 100644 --- a/decnet/services/base.py +++ b/decnet/services/base.py @@ -13,6 +13,7 @@ class BaseService(ABC): name: str # unique slug, e.g. "ssh", "smb" ports: list[int] # ports this service listens on inside the container default_image: str # Docker image tag, or "build" if a Dockerfile is needed + fleet_singleton: bool = False # True = runs once fleet-wide, not per-decky @abstractmethod def compose_fragment( diff --git a/decnet/services/sniffer.py b/decnet/services/sniffer.py index 6bf9c44..2f0dd2e 100644 --- a/decnet/services/sniffer.py +++ b/decnet/services/sniffer.py @@ -16,6 +16,7 @@ class SnifferService(BaseService): name = "sniffer" ports: list[int] = [] default_image = "build" + fleet_singleton = True def compose_fragment( self, diff --git a/decnet/sniffer/__init__.py b/decnet/sniffer/__init__.py new file mode 100644 index 0000000..4428ea1 --- /dev/null +++ b/decnet/sniffer/__init__.py @@ -0,0 +1,11 @@ +""" +Fleet-wide MACVLAN sniffer microservice. + +Runs as a single host-side background task (not per-decky) that sniffs +all TLS traffic on the MACVLAN interface, extracts fingerprints, and +feeds events into the existing log pipeline. +""" + +from decnet.sniffer.worker import sniffer_worker + +__all__ = ["sniffer_worker"] diff --git a/decnet/sniffer/fingerprint.py b/decnet/sniffer/fingerprint.py new file mode 100644 index 0000000..487db32 --- /dev/null +++ b/decnet/sniffer/fingerprint.py @@ -0,0 +1,884 @@ +""" +TLS fingerprinting engine for the fleet-wide MACVLAN sniffer. + +Extracted from templates/sniffer/server.py. All pure-Python TLS parsing, +JA3/JA3S/JA4/JA4S/JA4L computation, session tracking, and dedup logic +lives here. The packet callback is parameterized to accept an IP-to-decky +mapping and a write function, so it works for fleet-wide sniffing. +""" + +from __future__ import annotations + +import hashlib +import struct +import time +from pathlib import Path +from typing import Any, Callable + +from decnet.sniffer.syslog import SEVERITY_INFO, SEVERITY_WARNING, syslog_line + +# ─── Constants ─────────────────────────────────────────────────────────────── + +SERVICE_NAME: str = "sniffer" + +_SESSION_TTL: float = 60.0 +_DEDUP_TTL: float = 300.0 + +_GREASE: frozenset[int] = frozenset(0x0A0A + i * 0x1010 for i in range(16)) + +_TLS_RECORD_HANDSHAKE: int = 0x16 +_TLS_HT_CLIENT_HELLO: int = 0x01 +_TLS_HT_SERVER_HELLO: int = 0x02 +_TLS_HT_CERTIFICATE: int = 0x0B + +_EXT_SNI: int = 0x0000 +_EXT_SUPPORTED_GROUPS: int = 0x000A +_EXT_EC_POINT_FORMATS: int = 0x000B +_EXT_SIGNATURE_ALGORITHMS: int = 0x000D +_EXT_ALPN: int = 0x0010 +_EXT_SESSION_TICKET: int = 0x0023 +_EXT_SUPPORTED_VERSIONS: int = 0x002B +_EXT_PRE_SHARED_KEY: int = 0x0029 +_EXT_EARLY_DATA: int = 0x002A + +_TCP_SYN: int = 0x02 +_TCP_ACK: int = 0x10 + + +# ─── GREASE helpers ────────────────────────────────────────────────────────── + +def _is_grease(value: int) -> bool: + return value in _GREASE + + +def _filter_grease(values: list[int]) -> list[int]: + return [v for v in values if not _is_grease(v)] + + +# ─── TLS parsers ───────────────────────────────────────────────────────────── + +def _parse_client_hello(data: bytes) -> dict[str, Any] | None: + try: + if len(data) < 6: + return None + if data[0] != _TLS_RECORD_HANDSHAKE: + return None + record_len = struct.unpack_from("!H", data, 3)[0] + if len(data) < 5 + record_len: + return None + + hs = data[5:] + if hs[0] != _TLS_HT_CLIENT_HELLO: + return None + + hs_len = struct.unpack_from("!I", b"\x00" + hs[1:4])[0] + body = hs[4: 4 + hs_len] + if len(body) < 34: + return None + + pos = 0 + tls_version = struct.unpack_from("!H", body, pos)[0] + pos += 2 + pos += 32 # Random + + session_id_len = body[pos] + session_id = body[pos + 1: pos + 1 + session_id_len] + pos += 1 + session_id_len + + cs_len = struct.unpack_from("!H", body, pos)[0] + pos += 2 + cipher_suites = [ + struct.unpack_from("!H", body, pos + i * 2)[0] + for i in range(cs_len // 2) + ] + pos += cs_len + + comp_len = body[pos] + pos += 1 + comp_len + + extensions: list[int] = [] + supported_groups: list[int] = [] + ec_point_formats: list[int] = [] + signature_algorithms: list[int] = [] + supported_versions: list[int] = [] + sni: str = "" + alpn: list[str] = [] + has_session_ticket_data: bool = False + has_pre_shared_key: bool = False + has_early_data: bool = False + + if pos + 2 <= len(body): + ext_total = struct.unpack_from("!H", body, pos)[0] + pos += 2 + ext_end = pos + ext_total + + while pos + 4 <= ext_end: + ext_type = struct.unpack_from("!H", body, pos)[0] + ext_len = struct.unpack_from("!H", body, pos + 2)[0] + ext_data = body[pos + 4: pos + 4 + ext_len] + pos += 4 + ext_len + + if not _is_grease(ext_type): + extensions.append(ext_type) + + if ext_type == _EXT_SNI and len(ext_data) > 5: + sni = ext_data[5:].decode("ascii", errors="replace") + + elif ext_type == _EXT_SUPPORTED_GROUPS and len(ext_data) >= 2: + grp_len = struct.unpack_from("!H", ext_data, 0)[0] + supported_groups = [ + struct.unpack_from("!H", ext_data, 2 + i * 2)[0] + for i in range(grp_len // 2) + ] + + elif ext_type == _EXT_EC_POINT_FORMATS and len(ext_data) >= 1: + pf_len = ext_data[0] + ec_point_formats = list(ext_data[1: 1 + pf_len]) + + elif ext_type == _EXT_ALPN and len(ext_data) >= 2: + proto_list_len = struct.unpack_from("!H", ext_data, 0)[0] + ap = 2 + while ap < 2 + proto_list_len: + plen = ext_data[ap] + alpn.append(ext_data[ap + 1: ap + 1 + plen].decode("ascii", errors="replace")) + ap += 1 + plen + + elif ext_type == _EXT_SIGNATURE_ALGORITHMS and len(ext_data) >= 2: + sa_len = struct.unpack_from("!H", ext_data, 0)[0] + signature_algorithms = [ + struct.unpack_from("!H", ext_data, 2 + i * 2)[0] + for i in range(sa_len // 2) + ] + + elif ext_type == _EXT_SUPPORTED_VERSIONS and len(ext_data) >= 1: + sv_len = ext_data[0] + supported_versions = [ + struct.unpack_from("!H", ext_data, 1 + i * 2)[0] + for i in range(sv_len // 2) + ] + + elif ext_type == _EXT_SESSION_TICKET: + has_session_ticket_data = len(ext_data) > 0 + + elif ext_type == _EXT_PRE_SHARED_KEY: + has_pre_shared_key = True + + elif ext_type == _EXT_EARLY_DATA: + has_early_data = True + + filtered_ciphers = _filter_grease(cipher_suites) + filtered_groups = _filter_grease(supported_groups) + filtered_sig_algs = _filter_grease(signature_algorithms) + filtered_versions = _filter_grease(supported_versions) + + return { + "tls_version": tls_version, + "cipher_suites": filtered_ciphers, + "extensions": extensions, + "supported_groups": filtered_groups, + "ec_point_formats": ec_point_formats, + "signature_algorithms": filtered_sig_algs, + "supported_versions": filtered_versions, + "sni": sni, + "alpn": alpn, + "session_id": session_id, + "has_session_ticket_data": has_session_ticket_data, + "has_pre_shared_key": has_pre_shared_key, + "has_early_data": has_early_data, + } + + except Exception: + return None + + +def _parse_server_hello(data: bytes) -> dict[str, Any] | None: + try: + if len(data) < 6 or data[0] != _TLS_RECORD_HANDSHAKE: + return None + + hs = data[5:] + if hs[0] != _TLS_HT_SERVER_HELLO: + return None + + hs_len = struct.unpack_from("!I", b"\x00" + hs[1:4])[0] + body = hs[4: 4 + hs_len] + if len(body) < 35: + return None + + pos = 0 + tls_version = struct.unpack_from("!H", body, pos)[0] + pos += 2 + pos += 32 # Random + + session_id_len = body[pos] + pos += 1 + session_id_len + + if pos + 2 > len(body): + return None + + cipher_suite = struct.unpack_from("!H", body, pos)[0] + pos += 2 + pos += 1 # Compression method + + extensions: list[int] = [] + selected_version: int | None = None + alpn: str = "" + + if pos + 2 <= len(body): + ext_total = struct.unpack_from("!H", body, pos)[0] + pos += 2 + ext_end = pos + ext_total + while pos + 4 <= ext_end: + ext_type = struct.unpack_from("!H", body, pos)[0] + ext_len = struct.unpack_from("!H", body, pos + 2)[0] + ext_data = body[pos + 4: pos + 4 + ext_len] + pos += 4 + ext_len + if not _is_grease(ext_type): + extensions.append(ext_type) + + if ext_type == _EXT_SUPPORTED_VERSIONS and len(ext_data) >= 2: + selected_version = struct.unpack_from("!H", ext_data, 0)[0] + + elif ext_type == _EXT_ALPN and len(ext_data) >= 2: + proto_list_len = struct.unpack_from("!H", ext_data, 0)[0] + if proto_list_len > 0 and len(ext_data) >= 4: + plen = ext_data[2] + alpn = ext_data[3: 3 + plen].decode("ascii", errors="replace") + + return { + "tls_version": tls_version, + "cipher_suite": cipher_suite, + "extensions": extensions, + "selected_version": selected_version, + "alpn": alpn, + } + + except Exception: + return None + + +def _parse_certificate(data: bytes) -> dict[str, Any] | None: + try: + if len(data) < 6 or data[0] != _TLS_RECORD_HANDSHAKE: + return None + + hs = data[5:] + if hs[0] != _TLS_HT_CERTIFICATE: + return None + + hs_len = struct.unpack_from("!I", b"\x00" + hs[1:4])[0] + body = hs[4: 4 + hs_len] + if len(body) < 3: + return None + + certs_len = struct.unpack_from("!I", b"\x00" + body[0:3])[0] + if certs_len == 0: + return None + + pos = 3 + if pos + 3 > len(body): + return None + cert_len = struct.unpack_from("!I", b"\x00" + body[pos:pos + 3])[0] + pos += 3 + if pos + cert_len > len(body): + return None + + cert_der = body[pos: pos + cert_len] + return _parse_x509_der(cert_der) + + except Exception: + return None + + +# ─── Minimal DER/ASN.1 X.509 parser ───────────────────────────────────────── + +def _der_read_tag_len(data: bytes, pos: int) -> tuple[int, int, int]: + tag = data[pos] + pos += 1 + length_byte = data[pos] + pos += 1 + if length_byte & 0x80: + num_bytes = length_byte & 0x7F + length = int.from_bytes(data[pos: pos + num_bytes], "big") + pos += num_bytes + else: + length = length_byte + return tag, pos, length + + +def _der_read_sequence(data: bytes, pos: int) -> tuple[int, int]: + tag, content_start, length = _der_read_tag_len(data, pos) + return content_start, length + + +def _der_read_oid(data: bytes, pos: int, length: int) -> str: + if length < 1: + return "" + first = data[pos] + oid_parts = [str(first // 40), str(first % 40)] + val = 0 + for i in range(1, length): + b = data[pos + i] + val = (val << 7) | (b & 0x7F) + if not (b & 0x80): + oid_parts.append(str(val)) + val = 0 + return ".".join(oid_parts) + + +def _der_extract_cn(data: bytes, start: int, length: int) -> str: + pos = start + end = start + length + while pos < end: + set_tag, set_start, set_len = _der_read_tag_len(data, pos) + if set_tag != 0x31: + break + set_end = set_start + set_len + attr_pos = set_start + while attr_pos < set_end: + seq_tag, seq_start, seq_len = _der_read_tag_len(data, attr_pos) + if seq_tag != 0x30: + break + oid_tag, oid_start, oid_len = _der_read_tag_len(data, seq_start) + if oid_tag == 0x06: + oid = _der_read_oid(data, oid_start, oid_len) + if oid == "2.5.4.3": + val_tag, val_start, val_len = _der_read_tag_len(data, oid_start + oid_len) + return data[val_start: val_start + val_len].decode("utf-8", errors="replace") + attr_pos = seq_start + seq_len + pos = set_end + return "" + + +def _der_extract_name_str(data: bytes, start: int, length: int) -> str: + parts: list[str] = [] + pos = start + end = start + length + oid_names = { + "2.5.4.3": "CN", + "2.5.4.6": "C", + "2.5.4.7": "L", + "2.5.4.8": "ST", + "2.5.4.10": "O", + "2.5.4.11": "OU", + } + while pos < end: + set_tag, set_start, set_len = _der_read_tag_len(data, pos) + if set_tag != 0x31: + break + set_end = set_start + set_len + attr_pos = set_start + while attr_pos < set_end: + seq_tag, seq_start, seq_len = _der_read_tag_len(data, attr_pos) + if seq_tag != 0x30: + break + oid_tag, oid_start, oid_len = _der_read_tag_len(data, seq_start) + if oid_tag == 0x06: + oid = _der_read_oid(data, oid_start, oid_len) + val_tag, val_start, val_len = _der_read_tag_len(data, oid_start + oid_len) + val = data[val_start: val_start + val_len].decode("utf-8", errors="replace") + name = oid_names.get(oid, oid) + parts.append(f"{name}={val}") + attr_pos = seq_start + seq_len + pos = set_end + return ", ".join(parts) + + +def _parse_x509_der(cert_der: bytes) -> dict[str, Any] | None: + try: + outer_start, outer_len = _der_read_sequence(cert_der, 0) + tbs_tag, tbs_start, tbs_len = _der_read_tag_len(cert_der, outer_start) + tbs_end = tbs_start + tbs_len + pos = tbs_start + + if cert_der[pos] == 0xA0: + _, v_start, v_len = _der_read_tag_len(cert_der, pos) + pos = v_start + v_len + + _, sn_start, sn_len = _der_read_tag_len(cert_der, pos) + pos = sn_start + sn_len + + _, sa_start, sa_len = _der_read_tag_len(cert_der, pos) + pos = sa_start + sa_len + + issuer_tag, issuer_start, issuer_len = _der_read_tag_len(cert_der, pos) + issuer_str = _der_extract_name_str(cert_der, issuer_start, issuer_len) + issuer_cn = _der_extract_cn(cert_der, issuer_start, issuer_len) + pos = issuer_start + issuer_len + + val_tag, val_start, val_len = _der_read_tag_len(cert_der, pos) + nb_tag, nb_start, nb_len = _der_read_tag_len(cert_der, val_start) + not_before = cert_der[nb_start: nb_start + nb_len].decode("ascii", errors="replace") + na_tag, na_start, na_len = _der_read_tag_len(cert_der, nb_start + nb_len) + not_after = cert_der[na_start: na_start + na_len].decode("ascii", errors="replace") + pos = val_start + val_len + + subj_tag, subj_start, subj_len = _der_read_tag_len(cert_der, pos) + subject_cn = _der_extract_cn(cert_der, subj_start, subj_len) + subject_str = _der_extract_name_str(cert_der, subj_start, subj_len) + + self_signed = (issuer_cn == subject_cn) and subject_cn != "" + + pos = subj_start + subj_len + sans: list[str] = _extract_sans(cert_der, pos, tbs_end) + + return { + "subject_cn": subject_cn, + "subject": subject_str, + "issuer": issuer_str, + "issuer_cn": issuer_cn, + "not_before": not_before, + "not_after": not_after, + "self_signed": self_signed, + "sans": sans, + } + + except Exception: + return None + + +def _extract_sans(cert_der: bytes, pos: int, end: int) -> list[str]: + sans: list[str] = [] + try: + if pos >= end: + return sans + spki_tag, spki_start, spki_len = _der_read_tag_len(cert_der, pos) + pos = spki_start + spki_len + + while pos < end: + tag = cert_der[pos] + if tag == 0xA3: + _, ext_wrap_start, ext_wrap_len = _der_read_tag_len(cert_der, pos) + _, exts_start, exts_len = _der_read_tag_len(cert_der, ext_wrap_start) + epos = exts_start + eend = exts_start + exts_len + while epos < eend: + ext_tag, ext_start, ext_len = _der_read_tag_len(cert_der, epos) + ext_end = ext_start + ext_len + oid_tag, oid_start, oid_len = _der_read_tag_len(cert_der, ext_start) + if oid_tag == 0x06: + oid = _der_read_oid(cert_der, oid_start, oid_len) + if oid == "2.5.29.17": + vpos = oid_start + oid_len + if vpos < ext_end and cert_der[vpos] == 0x01: + _, bs, bl = _der_read_tag_len(cert_der, vpos) + vpos = bs + bl + if vpos < ext_end: + os_tag, os_start, os_len = _der_read_tag_len(cert_der, vpos) + if os_tag == 0x04: + sans = _parse_san_sequence(cert_der, os_start, os_len) + epos = ext_end + break + else: + _, skip_start, skip_len = _der_read_tag_len(cert_der, pos) + pos = skip_start + skip_len + except Exception: + pass + return sans + + +def _parse_san_sequence(data: bytes, start: int, length: int) -> list[str]: + names: list[str] = [] + try: + seq_tag, seq_start, seq_len = _der_read_tag_len(data, start) + pos = seq_start + end = seq_start + seq_len + while pos < end: + tag = data[pos] + _, val_start, val_len = _der_read_tag_len(data, pos) + context_tag = tag & 0x1F + if context_tag == 2: + names.append(data[val_start: val_start + val_len].decode("ascii", errors="replace")) + elif context_tag == 7 and val_len == 4: + names.append(".".join(str(b) for b in data[val_start: val_start + val_len])) + pos = val_start + val_len + except Exception: + pass + return names + + +# ─── JA3 / JA3S ───────────────────────────────────────────────────────────── + +def _tls_version_str(version: int) -> str: + return { + 0x0301: "TLS 1.0", + 0x0302: "TLS 1.1", + 0x0303: "TLS 1.2", + 0x0304: "TLS 1.3", + 0x0200: "SSL 2.0", + 0x0300: "SSL 3.0", + }.get(version, f"0x{version:04x}") + + +def _ja3(ch: dict[str, Any]) -> tuple[str, str]: + parts = [ + str(ch["tls_version"]), + "-".join(str(c) for c in ch["cipher_suites"]), + "-".join(str(e) for e in ch["extensions"]), + "-".join(str(g) for g in ch["supported_groups"]), + "-".join(str(p) for p in ch["ec_point_formats"]), + ] + ja3_str = ",".join(parts) + return ja3_str, hashlib.md5(ja3_str.encode()).hexdigest() + + +def _ja3s(sh: dict[str, Any]) -> tuple[str, str]: + parts = [ + str(sh["tls_version"]), + str(sh["cipher_suite"]), + "-".join(str(e) for e in sh["extensions"]), + ] + ja3s_str = ",".join(parts) + return ja3s_str, hashlib.md5(ja3s_str.encode()).hexdigest() + + +# ─── JA4 / JA4S ───────────────────────────────────────────────────────────── + +def _ja4_version(ch: dict[str, Any]) -> str: + versions = ch.get("supported_versions", []) + if versions: + best = max(versions) + else: + best = ch["tls_version"] + return { + 0x0304: "13", + 0x0303: "12", + 0x0302: "11", + 0x0301: "10", + 0x0300: "s3", + 0x0200: "s2", + }.get(best, "00") + + +def _ja4_alpn_tag(alpn_list: list[str] | str) -> str: + if isinstance(alpn_list, str): + proto = alpn_list + elif alpn_list: + proto = alpn_list[0] + else: + return "00" + if not proto: + return "00" + if len(proto) == 1: + return proto[0] + proto[0] + return proto[0] + proto[-1] + + +def _sha256_12(text: str) -> str: + return hashlib.sha256(text.encode()).hexdigest()[:12] + + +def _ja4(ch: dict[str, Any]) -> str: + proto = "t" + ver = _ja4_version(ch) + sni_flag = "d" if ch.get("sni") else "i" + cs_count = min(len(ch["cipher_suites"]), 99) + ext_count = min(len(ch["extensions"]), 99) + alpn_tag = _ja4_alpn_tag(ch.get("alpn", [])) + section_a = f"{proto}{ver}{sni_flag}{cs_count:02d}{ext_count:02d}{alpn_tag}" + sorted_cs = sorted(ch["cipher_suites"]) + section_b = _sha256_12(",".join(str(c) for c in sorted_cs)) + sorted_ext = sorted(ch["extensions"]) + sorted_sa = sorted(ch.get("signature_algorithms", [])) + ext_str = ",".join(str(e) for e in sorted_ext) + sa_str = ",".join(str(s) for s in sorted_sa) + combined = f"{ext_str}_{sa_str}" if sa_str else ext_str + section_c = _sha256_12(combined) + return f"{section_a}_{section_b}_{section_c}" + + +def _ja4s(sh: dict[str, Any]) -> str: + proto = "t" + selected = sh.get("selected_version") + if selected: + ver = {0x0304: "13", 0x0303: "12", 0x0302: "11", 0x0301: "10", + 0x0300: "s3", 0x0200: "s2"}.get(selected, "00") + else: + ver = {0x0304: "13", 0x0303: "12", 0x0302: "11", 0x0301: "10", + 0x0300: "s3", 0x0200: "s2"}.get(sh["tls_version"], "00") + ext_count = min(len(sh["extensions"]), 99) + alpn_tag = _ja4_alpn_tag(sh.get("alpn", "")) + section_a = f"{proto}{ver}{ext_count:02d}{alpn_tag}" + sorted_ext = sorted(sh["extensions"]) + inner = f"{sh['cipher_suite']},{','.join(str(e) for e in sorted_ext)}" + section_b = _sha256_12(inner) + return f"{section_a}_{section_b}" + + +# ─── JA4L (latency) ───────────────────────────────────────────────────────── + +def _ja4l( + key: tuple[str, int, str, int], + tcp_rtt: dict[tuple[str, int, str, int], dict[str, Any]], +) -> dict[str, Any] | None: + return tcp_rtt.get(key) + + +# ─── Session resumption ───────────────────────────────────────────────────── + +def _session_resumption_info(ch: dict[str, Any]) -> dict[str, Any]: + mechanisms: list[str] = [] + if ch.get("has_session_ticket_data"): + mechanisms.append("session_ticket") + if ch.get("has_pre_shared_key"): + mechanisms.append("psk") + if ch.get("has_early_data"): + mechanisms.append("early_data_0rtt") + if ch.get("session_id") and len(ch["session_id"]) > 0: + mechanisms.append("session_id") + return { + "resumption_attempted": len(mechanisms) > 0, + "mechanisms": mechanisms, + } + + +# ─── Sniffer engine (stateful, one instance per worker) ───────────────────── + +class SnifferEngine: + """ + Stateful TLS fingerprinting engine. Tracks sessions, TCP RTTs, + and dedup state. Thread-safe only when called from a single thread + (the scapy sniff thread). + """ + + def __init__( + self, + ip_to_decky: dict[str, str], + write_fn: Callable[[str], None], + dedup_ttl: float = 300.0, + ): + self._ip_to_decky = ip_to_decky + self._write_fn = write_fn + self._dedup_ttl = dedup_ttl + + self._sessions: dict[tuple[str, int, str, int], dict[str, Any]] = {} + self._session_ts: dict[tuple[str, int, str, int], float] = {} + self._tcp_syn: dict[tuple[str, int, str, int], dict[str, Any]] = {} + self._tcp_rtt: dict[tuple[str, int, str, int], dict[str, Any]] = {} + + self._dedup_cache: dict[tuple[str, str, str], float] = {} + self._dedup_last_cleanup: float = 0.0 + self._DEDUP_CLEANUP_INTERVAL: float = 60.0 + + def update_ip_map(self, ip_to_decky: dict[str, str]) -> None: + self._ip_to_decky = ip_to_decky + + def _resolve_decky(self, src_ip: str, dst_ip: str) -> str | None: + """Map a packet to a decky name. Returns None if neither IP is a known decky.""" + if dst_ip in self._ip_to_decky: + return self._ip_to_decky[dst_ip] + if src_ip in self._ip_to_decky: + return self._ip_to_decky[src_ip] + return None + + def _cleanup_sessions(self) -> None: + now = time.monotonic() + stale = [k for k, ts in self._session_ts.items() if now - ts > _SESSION_TTL] + for k in stale: + self._sessions.pop(k, None) + self._session_ts.pop(k, None) + stale_syn = [k for k, v in self._tcp_syn.items() + if now - v.get("time", 0) > _SESSION_TTL] + for k in stale_syn: + self._tcp_syn.pop(k, None) + stale_rtt = [k for k, _ in self._tcp_rtt.items() + if k not in self._sessions and k not in self._session_ts] + for k in stale_rtt: + self._tcp_rtt.pop(k, None) + + def _dedup_key_for(self, event_type: str, fields: dict[str, Any]) -> str: + if event_type == "tls_client_hello": + return fields.get("ja3", "") + "|" + fields.get("ja4", "") + if event_type == "tls_session": + return (fields.get("ja3", "") + "|" + fields.get("ja3s", "") + + "|" + fields.get("ja4", "") + "|" + fields.get("ja4s", "")) + if event_type == "tls_certificate": + return fields.get("subject_cn", "") + "|" + fields.get("issuer", "") + return fields.get("mechanisms", fields.get("resumption", "")) + + def _is_duplicate(self, event_type: str, fields: dict[str, Any]) -> bool: + if self._dedup_ttl <= 0: + return False + now = time.monotonic() + if now - self._dedup_last_cleanup > self._DEDUP_CLEANUP_INTERVAL: + stale = [k for k, ts in self._dedup_cache.items() if now - ts > self._dedup_ttl] + for k in stale: + del self._dedup_cache[k] + self._dedup_last_cleanup = now + src_ip = fields.get("src_ip", "") + fp = self._dedup_key_for(event_type, fields) + cache_key = (src_ip, event_type, fp) + last_seen = self._dedup_cache.get(cache_key) + if last_seen is not None and now - last_seen < self._dedup_ttl: + return True + self._dedup_cache[cache_key] = now + return False + + def _log(self, node_name: str, event_type: str, severity: int = SEVERITY_INFO, **fields: Any) -> None: + if self._is_duplicate(event_type, fields): + return + line = syslog_line(SERVICE_NAME, node_name, event_type, severity=severity, **fields) + self._write_fn(line) + + def on_packet(self, pkt: Any) -> None: + """Process a single scapy packet. Called from the sniff thread.""" + try: + from scapy.layers.inet import IP, TCP + except ImportError: + return + + if not (pkt.haslayer(IP) and pkt.haslayer(TCP)): + return + + ip = pkt[IP] + tcp = pkt[TCP] + + src_ip: str = ip.src + dst_ip: str = ip.dst + src_port: int = tcp.sport + dst_port: int = tcp.dport + flags: int = tcp.flags.value if hasattr(tcp.flags, 'value') else int(tcp.flags) + + # Skip traffic not involving any decky + node_name = self._resolve_decky(src_ip, dst_ip) + if node_name is None: + return + + # TCP SYN tracking for JA4L + if flags & _TCP_SYN and not (flags & _TCP_ACK): + key = (src_ip, src_port, dst_ip, dst_port) + self._tcp_syn[key] = {"time": time.monotonic(), "ttl": ip.ttl} + + elif flags & _TCP_SYN and flags & _TCP_ACK: + rev_key = (dst_ip, dst_port, src_ip, src_port) + syn_data = self._tcp_syn.pop(rev_key, None) + if syn_data: + rtt_ms = round((time.monotonic() - syn_data["time"]) * 1000, 2) + self._tcp_rtt[rev_key] = { + "rtt_ms": rtt_ms, + "client_ttl": syn_data["ttl"], + } + + payload = bytes(tcp.payload) + if not payload: + return + + if payload[0] != _TLS_RECORD_HANDSHAKE: + return + + # ClientHello + ch = _parse_client_hello(payload) + if ch is not None: + self._cleanup_sessions() + + key = (src_ip, src_port, dst_ip, dst_port) + ja3_str, ja3_hash = _ja3(ch) + ja4_hash = _ja4(ch) + resumption = _session_resumption_info(ch) + rtt_data = _ja4l(key, self._tcp_rtt) + + self._sessions[key] = { + "ja3": ja3_hash, + "ja3_str": ja3_str, + "ja4": ja4_hash, + "tls_version": ch["tls_version"], + "cipher_suites": ch["cipher_suites"], + "extensions": ch["extensions"], + "signature_algorithms": ch.get("signature_algorithms", []), + "supported_versions": ch.get("supported_versions", []), + "sni": ch["sni"], + "alpn": ch["alpn"], + "resumption": resumption, + } + self._session_ts[key] = time.monotonic() + + log_fields: dict[str, Any] = { + "src_ip": src_ip, + "src_port": str(src_port), + "dst_ip": dst_ip, + "dst_port": str(dst_port), + "ja3": ja3_hash, + "ja4": ja4_hash, + "tls_version": _tls_version_str(ch["tls_version"]), + "sni": ch["sni"] or "", + "alpn": ",".join(ch["alpn"]), + "raw_ciphers": "-".join(str(c) for c in ch["cipher_suites"]), + "raw_extensions": "-".join(str(e) for e in ch["extensions"]), + } + + if resumption["resumption_attempted"]: + log_fields["resumption"] = ",".join(resumption["mechanisms"]) + + if rtt_data: + log_fields["ja4l_rtt_ms"] = str(rtt_data["rtt_ms"]) + log_fields["ja4l_client_ttl"] = str(rtt_data["client_ttl"]) + + # Resolve node for the *destination* (the decky being attacked) + target_node = self._ip_to_decky.get(dst_ip, node_name) + self._log(target_node, "tls_client_hello", **log_fields) + return + + # ServerHello + sh = _parse_server_hello(payload) + if sh is not None: + rev_key = (dst_ip, dst_port, src_ip, src_port) + ch_data = self._sessions.pop(rev_key, None) + self._session_ts.pop(rev_key, None) + + ja3s_str, ja3s_hash = _ja3s(sh) + ja4s_hash = _ja4s(sh) + + fields: dict[str, Any] = { + "src_ip": dst_ip, + "src_port": str(dst_port), + "dst_ip": src_ip, + "dst_port": str(src_port), + "ja3s": ja3s_hash, + "ja4s": ja4s_hash, + "tls_version": _tls_version_str(sh["tls_version"]), + } + + if ch_data: + fields["ja3"] = ch_data["ja3"] + fields["ja4"] = ch_data.get("ja4", "") + fields["sni"] = ch_data["sni"] or "" + fields["alpn"] = ",".join(ch_data["alpn"]) + fields["raw_ciphers"] = "-".join(str(c) for c in ch_data["cipher_suites"]) + fields["raw_extensions"] = "-".join(str(e) for e in ch_data["extensions"]) + if ch_data.get("resumption", {}).get("resumption_attempted"): + fields["resumption"] = ",".join(ch_data["resumption"]["mechanisms"]) + + rtt_data = self._tcp_rtt.pop(rev_key, None) + if rtt_data: + fields["ja4l_rtt_ms"] = str(rtt_data["rtt_ms"]) + fields["ja4l_client_ttl"] = str(rtt_data["client_ttl"]) + + # Server response — resolve by src_ip (the decky responding) + target_node = self._ip_to_decky.get(src_ip, node_name) + self._log(target_node, "tls_session", severity=SEVERITY_WARNING, **fields) + return + + # Certificate (TLS 1.2 only) + cert = _parse_certificate(payload) + if cert is not None: + rev_key = (dst_ip, dst_port, src_ip, src_port) + ch_data = self._sessions.get(rev_key) + + cert_fields: dict[str, Any] = { + "src_ip": dst_ip, + "src_port": str(dst_port), + "dst_ip": src_ip, + "dst_port": str(src_port), + "subject_cn": cert["subject_cn"], + "issuer": cert["issuer"], + "self_signed": str(cert["self_signed"]).lower(), + "not_before": cert["not_before"], + "not_after": cert["not_after"], + } + if cert["sans"]: + cert_fields["sans"] = ",".join(cert["sans"]) + if ch_data: + cert_fields["sni"] = ch_data.get("sni", "") + + target_node = self._ip_to_decky.get(src_ip, node_name) + self._log(target_node, "tls_certificate", **cert_fields) diff --git a/decnet/sniffer/syslog.py b/decnet/sniffer/syslog.py new file mode 100644 index 0000000..1fd7587 --- /dev/null +++ b/decnet/sniffer/syslog.py @@ -0,0 +1,69 @@ +""" +RFC 5424 syslog formatting and log-file writing for the fleet sniffer. + +Reuses the same wire format as templates/sniffer/decnet_logging.py so the +existing collector parser and ingester can consume events without changes. +""" + +import json +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from decnet.collector.worker import parse_rfc5424 + +# ─── Constants (must match templates/sniffer/decnet_logging.py) ────────────── + +_FACILITY_LOCAL0 = 16 +_SD_ID = "decnet@55555" +_NILVALUE = "-" + +SEVERITY_INFO = 6 +SEVERITY_WARNING = 4 + +_MAX_HOSTNAME = 255 +_MAX_APPNAME = 48 +_MAX_MSGID = 32 + + +# ─── Formatter ─────────────────────────────────────────────────────────────── + +def _sd_escape(value: str) -> str: + return value.replace("\\", "\\\\").replace('"', '\\"').replace("]", "\\]") + + +def _sd_element(fields: dict[str, Any]) -> str: + if not fields: + return _NILVALUE + params = " ".join(f'{k}="{_sd_escape(str(v))}"' for k, v in fields.items()) + return f"[{_SD_ID} {params}]" + + +def syslog_line( + service: str, + hostname: str, + event_type: str, + severity: int = SEVERITY_INFO, + msg: str | None = None, + **fields: Any, +) -> str: + pri = f"<{_FACILITY_LOCAL0 * 8 + severity}>" + ts = datetime.now(timezone.utc).isoformat() + host = (hostname or _NILVALUE)[:_MAX_HOSTNAME] + appname = (service or _NILVALUE)[:_MAX_APPNAME] + msgid = (event_type or _NILVALUE)[:_MAX_MSGID] + sd = _sd_element(fields) + message = f" {msg}" if msg else "" + return f"{pri}1 {ts} {host} {appname} {_NILVALUE} {msgid} {sd}{message}" + + +def write_event(line: str, log_path: Path, json_path: Path) -> None: + """Append a syslog line to the raw log and its parsed JSON to the json log.""" + with open(log_path, "a", encoding="utf-8") as lf: + lf.write(line + "\n") + lf.flush() + parsed = parse_rfc5424(line) + if parsed: + with open(json_path, "a", encoding="utf-8") as jf: + jf.write(json.dumps(parsed) + "\n") + jf.flush() diff --git a/decnet/sniffer/worker.py b/decnet/sniffer/worker.py new file mode 100644 index 0000000..91fd15d --- /dev/null +++ b/decnet/sniffer/worker.py @@ -0,0 +1,145 @@ +""" +Fleet-wide MACVLAN sniffer worker. + +Runs as a single host-side async background task that sniffs all TLS +traffic on the MACVLAN host interface. Maps packets to deckies by IP +and feeds fingerprint events into the existing log pipeline. + +Modeled on decnet.collector.worker — same lifecycle pattern. +Fault-isolated: any exception is logged and the worker exits cleanly. +The API never depends on this worker being alive. +""" + +import asyncio +import os +import subprocess +import threading +import time +from pathlib import Path +from typing import Any + +from decnet.logging import get_logger +from decnet.network import HOST_MACVLAN_IFACE +from decnet.sniffer.fingerprint import SnifferEngine +from decnet.sniffer.syslog import write_event + +logger = get_logger("sniffer") + +_IP_MAP_REFRESH_INTERVAL: float = 60.0 + + +def _load_ip_to_decky() -> dict[str, str]: + """Build IP → decky-name mapping from decnet-state.json.""" + from decnet.config import load_state + state = load_state() + if state is None: + return {} + config, _ = state + mapping: dict[str, str] = {} + for decky in config.deckies: + mapping[decky.ip] = decky.name + return mapping + + +def _interface_exists(iface: str) -> bool: + """Check if a network interface exists on this host.""" + try: + result = subprocess.run( + ["ip", "link", "show", iface], + capture_output=True, text=True, check=False, + ) + return result.returncode == 0 + except Exception: + return False + + +def _sniff_loop( + interface: str, + log_path: Path, + json_path: Path, + stop_event: threading.Event, +) -> None: + """Blocking sniff loop. Runs in a dedicated thread via asyncio.to_thread.""" + try: + from scapy.sendrecv import sniff + except ImportError: + logger.error("scapy not installed — sniffer cannot start") + return + + ip_map = _load_ip_to_decky() + if not ip_map: + logger.warning("sniffer: no deckies in state — nothing to sniff") + return + + def _write_fn(line: str) -> None: + write_event(line, log_path, json_path) + + engine = SnifferEngine(ip_to_decky=ip_map, write_fn=_write_fn) + + # Periodically refresh IP map in a background daemon thread + def _refresh_loop() -> None: + while not stop_event.is_set(): + stop_event.wait(_IP_MAP_REFRESH_INTERVAL) + if stop_event.is_set(): + break + try: + new_map = _load_ip_to_decky() + if new_map: + engine.update_ip_map(new_map) + except Exception as exc: + logger.debug("sniffer: ip map refresh failed: %s", exc) + + refresh_thread = threading.Thread(target=_refresh_loop, daemon=True) + refresh_thread.start() + + logger.info("sniffer: sniffing on interface=%s deckies=%d", interface, len(ip_map)) + + try: + sniff( + iface=interface, + filter="tcp", + prn=engine.on_packet, + store=False, + stop_filter=lambda pkt: stop_event.is_set(), + ) + except Exception as exc: + logger.error("sniffer: scapy sniff exited: %s", exc) + finally: + stop_event.set() + logger.info("sniffer: sniff loop ended") + + +async def sniffer_worker(log_file: str) -> None: + """ + Async entry point — started as asyncio.create_task in the API lifespan. + + Fully fault-isolated: catches all exceptions, logs them, and returns + cleanly. The API continues running regardless of sniffer state. + """ + try: + interface = os.environ.get("DECNET_SNIFFER_IFACE", HOST_MACVLAN_IFACE) + + if not _interface_exists(interface): + logger.warning( + "sniffer: interface %s not found — sniffer disabled " + "(fleet may not be deployed yet)", interface, + ) + return + + log_path = Path(log_file) + json_path = log_path.with_suffix(".json") + log_path.parent.mkdir(parents=True, exist_ok=True) + + stop_event = threading.Event() + + try: + await asyncio.to_thread(_sniff_loop, interface, log_path, json_path, stop_event) + except asyncio.CancelledError: + logger.info("sniffer: shutdown requested") + stop_event.set() + raise + + except asyncio.CancelledError: + raise + except Exception as exc: + logger.error("sniffer: worker failed — API continues without sniffing: %s", exc) diff --git a/decnet/web/api.py b/decnet/web/api.py index f1cfbb7..1d8f21b 100644 --- a/decnet/web/api.py +++ b/decnet/web/api.py @@ -21,11 +21,12 @@ log = get_logger("api") ingestion_task: Optional[asyncio.Task[Any]] = None collector_task: Optional[asyncio.Task[Any]] = None attacker_task: Optional[asyncio.Task[Any]] = None +sniffer_task: Optional[asyncio.Task[Any]] = None @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: - global ingestion_task, collector_task, attacker_task + global ingestion_task, collector_task, attacker_task, sniffer_task log.info("API startup initialising database") for attempt in range(1, 6): @@ -58,13 +59,22 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: if attacker_task is None or attacker_task.done(): attacker_task = asyncio.create_task(attacker_profile_worker(repo)) log.debug("API startup attacker profile worker started") + + # Start fleet-wide MACVLAN sniffer (fault-isolated — never crashes the API) + try: + from decnet.sniffer import sniffer_worker + if sniffer_task is None or sniffer_task.done(): + sniffer_task = asyncio.create_task(sniffer_worker(_log_file)) + log.debug("API startup sniffer worker started") + except Exception as exc: + log.warning("Sniffer worker failed to start — API continues without sniffing: %s", exc) else: log.info("Contract Test Mode: skipping background worker startup") yield log.info("API shutdown cancelling background tasks") - for task in (ingestion_task, collector_task, attacker_task): + for task in (ingestion_task, collector_task, attacker_task, sniffer_task): if task and not task.done(): task.cancel() try: diff --git a/tests/test_cli_service_pool.py b/tests/test_cli_service_pool.py index 6c673a3..266f0a8 100644 --- a/tests/test_cli_service_pool.py +++ b/tests/test_cli_service_pool.py @@ -10,11 +10,12 @@ from decnet.services.registry import all_services ORIGINAL_5 = {"ssh", "smb", "rdp", "http", "ftp"} -def test_all_service_names_covers_full_registry(): - """_all_service_names() must return every service in the registry.""" +def test_all_service_names_covers_per_decky_services(): + """_all_service_names() must return every per-decky service (not fleet singletons).""" pool = set(_all_service_names()) - registry = set(all_services().keys()) - assert pool == registry + registry = all_services() + per_decky = {name for name, svc in registry.items() if not svc.fleet_singleton} + assert pool == per_decky def test_all_service_names_is_sorted(): diff --git a/tests/test_fleet_singleton.py b/tests/test_fleet_singleton.py new file mode 100644 index 0000000..78664e0 --- /dev/null +++ b/tests/test_fleet_singleton.py @@ -0,0 +1,78 @@ +""" +Tests for fleet_singleton service behavior. + +Verifies that: + - The sniffer is registered but marked as fleet_singleton + - fleet_singleton services are excluded from compose generation + - fleet_singleton services are excluded from random service assignment +""" + +from decnet.composer import generate_compose +from decnet.fleet import all_service_names, build_deckies +from decnet.models import DeckyConfig, DecnetConfig +from decnet.services.registry import all_services, get_service + + +def test_sniffer_is_fleet_singleton(): + svc = get_service("sniffer") + assert svc.fleet_singleton is True + + +def test_non_sniffer_services_are_not_fleet_singleton(): + for name, svc in all_services().items(): + if name == "sniffer": + continue + assert svc.fleet_singleton is False, f"{name} should not be fleet_singleton" + + +def test_sniffer_excluded_from_all_service_names(): + names = all_service_names() + assert "sniffer" not in names + + +def test_sniffer_still_in_registry(): + """Sniffer must remain discoverable in the registry even though it's a singleton.""" + registry = all_services() + assert "sniffer" in registry + + +def test_compose_skips_fleet_singleton(): + """When a decky lists 'sniffer' in its services, compose must not generate a container.""" + config = DecnetConfig( + mode="unihost", + interface="eth0", + subnet="192.168.1.0/24", + gateway="192.168.1.1", + host_ip="192.168.1.5", + deckies=[ + DeckyConfig( + name="decky-01", + ip="192.168.1.10", + services=["ssh", "sniffer"], + distro="debian", + base_image="debian:bookworm-slim", + hostname="test-host", + ), + ], + ) + compose = generate_compose(config) + services = compose["services"] + + assert "decky-01" in services # base container exists + assert "decky-01-ssh" in services # ssh service exists + assert "decky-01-sniffer" not in services # sniffer skipped + + +def test_randomize_never_picks_sniffer(): + """Random service assignment must never include fleet_singleton services.""" + all_drawn: set[str] = set() + for _ in range(100): + deckies = build_deckies( + n=1, + ips=["10.0.0.10"], + services_explicit=None, + randomize_services=True, + ) + all_drawn.update(deckies[0].services) + + assert "sniffer" not in all_drawn diff --git a/tests/test_sniffer_worker.py b/tests/test_sniffer_worker.py new file mode 100644 index 0000000..4a815bb --- /dev/null +++ b/tests/test_sniffer_worker.py @@ -0,0 +1,280 @@ +""" +Tests for the fleet-wide sniffer worker and fingerprinting engine. + +Tests the IP-to-decky mapping, packet callback routing, syslog output +format, dedup logic, and worker fault isolation. +""" + +import struct +import time +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from decnet.sniffer.fingerprint import ( + SnifferEngine, + _ja3, + _ja4, + _ja4_alpn_tag, + _ja4_version, + _ja4s, + _ja3s, + _parse_client_hello, + _parse_server_hello, + _session_resumption_info, + _tls_version_str, +) +from decnet.sniffer.syslog import syslog_line, write_event +from decnet.sniffer.worker import _load_ip_to_decky + + +# ─── Helpers ───────────────────────────────────────────────────────────────── + +def _build_tls_client_hello( + tls_version: int = 0x0303, + cipher_suites: list[int] | None = None, + sni: str = "example.com", +) -> bytes: + """Build a minimal TLS ClientHello payload for testing.""" + if cipher_suites is None: + cipher_suites = [0x1301, 0x1302, 0x1303] + + body = b"" + body += struct.pack("!H", tls_version) # ClientHello version + body += b"\x00" * 32 # Random + body += b"\x00" # Session ID length = 0 + + # Cipher suites + cs_data = b"".join(struct.pack("!H", cs) for cs in cipher_suites) + body += struct.pack("!H", len(cs_data)) + cs_data + + # Compression methods + body += b"\x01\x00" # 1 method, null + + # Extensions + ext_data = b"" + if sni: + sni_bytes = sni.encode("ascii") + sni_ext = struct.pack("!HBH", len(sni_bytes) + 3, 0, len(sni_bytes)) + sni_bytes + ext_data += struct.pack("!HH", 0x0000, len(sni_ext)) + sni_ext + + body += struct.pack("!H", len(ext_data)) + ext_data + + # Handshake header + hs = struct.pack("!B", 0x01) + struct.pack("!I", len(body))[1:] # type + 3-byte length + hs_with_body = hs + body + + # TLS record header + record = struct.pack("!BHH", 0x16, 0x0301, len(hs_with_body)) + hs_with_body + return record + + +def _build_tls_server_hello( + tls_version: int = 0x0303, + cipher_suite: int = 0x1301, +) -> bytes: + """Build a minimal TLS ServerHello payload for testing.""" + body = b"" + body += struct.pack("!H", tls_version) + body += b"\x00" * 32 # Random + body += b"\x00" # Session ID length = 0 + body += struct.pack("!H", cipher_suite) + body += b"\x00" # Compression method + + # No extensions + body += struct.pack("!H", 0) + + hs = struct.pack("!B", 0x02) + struct.pack("!I", len(body))[1:] + hs_with_body = hs + body + + record = struct.pack("!BHH", 0x16, 0x0301, len(hs_with_body)) + hs_with_body + return record + + +# ─── TLS parser tests ─────────────────────────────────────────────────────── + +class TestTlsParsers: + def test_parse_client_hello_valid(self): + data = _build_tls_client_hello() + result = _parse_client_hello(data) + assert result is not None + assert result["tls_version"] == 0x0303 + assert result["cipher_suites"] == [0x1301, 0x1302, 0x1303] + assert result["sni"] == "example.com" + + def test_parse_client_hello_no_sni(self): + data = _build_tls_client_hello(sni="") + result = _parse_client_hello(data) + assert result is not None + assert result["sni"] == "" + + def test_parse_client_hello_invalid_data(self): + assert _parse_client_hello(b"\x00\x01") is None + assert _parse_client_hello(b"") is None + assert _parse_client_hello(b"\x16\x03\x01\x00\x00") is None + + def test_parse_server_hello_valid(self): + data = _build_tls_server_hello() + result = _parse_server_hello(data) + assert result is not None + assert result["tls_version"] == 0x0303 + assert result["cipher_suite"] == 0x1301 + + def test_parse_server_hello_invalid(self): + assert _parse_server_hello(b"garbage") is None + + +# ─── Fingerprint computation tests ────────────────────────────────────────── + +class TestFingerprints: + def test_ja3_deterministic(self): + data = _build_tls_client_hello() + ch = _parse_client_hello(data) + ja3_str1, hash1 = _ja3(ch) + ja3_str2, hash2 = _ja3(ch) + assert hash1 == hash2 + assert len(hash1) == 32 # MD5 hex + + def test_ja4_format(self): + data = _build_tls_client_hello() + ch = _parse_client_hello(data) + ja4 = _ja4(ch) + parts = ja4.split("_") + assert len(parts) == 3 + assert parts[0].startswith("t") # TCP + + def test_ja3s_deterministic(self): + data = _build_tls_server_hello() + sh = _parse_server_hello(data) + _, hash1 = _ja3s(sh) + _, hash2 = _ja3s(sh) + assert hash1 == hash2 + + def test_ja4s_format(self): + data = _build_tls_server_hello() + sh = _parse_server_hello(data) + ja4s = _ja4s(sh) + parts = ja4s.split("_") + assert len(parts) == 2 + assert parts[0].startswith("t") + + def test_tls_version_str(self): + assert _tls_version_str(0x0303) == "TLS 1.2" + assert _tls_version_str(0x0304) == "TLS 1.3" + assert "0x" in _tls_version_str(0x9999) + + def test_ja4_version_with_supported_versions(self): + ch = {"tls_version": 0x0303, "supported_versions": [0x0303, 0x0304]} + assert _ja4_version(ch) == "13" + + def test_ja4_alpn_tag(self): + assert _ja4_alpn_tag([]) == "00" + assert _ja4_alpn_tag(["h2"]) == "h2" + assert _ja4_alpn_tag(["http/1.1"]) == "h1" + + def test_session_resumption_info(self): + ch = {"has_session_ticket_data": True, "has_pre_shared_key": False, + "has_early_data": False, "session_id": b""} + info = _session_resumption_info(ch) + assert info["resumption_attempted"] is True + assert "session_ticket" in info["mechanisms"] + + +# ─── Syslog format tests ──────────────────────────────────────────────────── + +class TestSyslog: + def test_syslog_line_format(self): + line = syslog_line("sniffer", "decky-01", "tls_client_hello", src_ip="10.0.0.1") + assert "<134>" in line # PRI for local0 + INFO + assert "decky-01" in line + assert "sniffer" in line + assert "tls_client_hello" in line + assert 'src_ip="10.0.0.1"' in line + + def test_write_event_creates_files(self, tmp_path): + log_path = tmp_path / "test.log" + json_path = tmp_path / "test.json" + line = syslog_line("sniffer", "decky-01", "tls_client_hello", src_ip="10.0.0.1") + write_event(line, log_path, json_path) + assert log_path.exists() + assert json_path.exists() + assert "decky-01" in log_path.read_text() + + +# ─── SnifferEngine tests ──────────────────────────────────────────────────── + +class TestSnifferEngine: + def test_resolve_decky_by_dst(self): + engine = SnifferEngine( + ip_to_decky={"192.168.1.10": "decky-01"}, + write_fn=lambda _: None, + ) + assert engine._resolve_decky("10.0.0.1", "192.168.1.10") == "decky-01" + + def test_resolve_decky_by_src(self): + engine = SnifferEngine( + ip_to_decky={"192.168.1.10": "decky-01"}, + write_fn=lambda _: None, + ) + assert engine._resolve_decky("192.168.1.10", "10.0.0.1") == "decky-01" + + def test_resolve_decky_unknown(self): + engine = SnifferEngine( + ip_to_decky={"192.168.1.10": "decky-01"}, + write_fn=lambda _: None, + ) + assert engine._resolve_decky("10.0.0.1", "10.0.0.2") is None + + def test_update_ip_map(self): + engine = SnifferEngine( + ip_to_decky={"192.168.1.10": "decky-01"}, + write_fn=lambda _: None, + ) + engine.update_ip_map({"192.168.1.20": "decky-02"}) + assert engine._resolve_decky("10.0.0.1", "192.168.1.20") == "decky-02" + assert engine._resolve_decky("10.0.0.1", "192.168.1.10") is None + + def test_dedup_suppresses_identical_events(self): + written: list[str] = [] + engine = SnifferEngine( + ip_to_decky={}, + write_fn=written.append, + dedup_ttl=300.0, + ) + fields = {"src_ip": "10.0.0.1", "ja3": "abc", "ja4": "def"} + engine._log("decky-01", "tls_client_hello", **fields) + engine._log("decky-01", "tls_client_hello", **fields) + assert len(written) == 1 + + def test_dedup_allows_different_fingerprints(self): + written: list[str] = [] + engine = SnifferEngine( + ip_to_decky={}, + write_fn=written.append, + dedup_ttl=300.0, + ) + engine._log("decky-01", "tls_client_hello", src_ip="10.0.0.1", ja3="abc", ja4="def") + engine._log("decky-01", "tls_client_hello", src_ip="10.0.0.1", ja3="xyz", ja4="uvw") + assert len(written) == 2 + + def test_dedup_disabled_when_ttl_zero(self): + written: list[str] = [] + engine = SnifferEngine( + ip_to_decky={}, + write_fn=written.append, + dedup_ttl=0, + ) + fields = {"src_ip": "10.0.0.1", "ja3": "abc", "ja4": "def"} + engine._log("decky-01", "tls_client_hello", **fields) + engine._log("decky-01", "tls_client_hello", **fields) + assert len(written) == 2 + + +# ─── Worker IP map loading ────────────────────────────────────────────────── + +class TestWorkerIpMap: + def test_load_ip_to_decky_no_state(self): + with patch("decnet.config.load_state", return_value=None): + result = _load_ip_to_decky() + assert result == {}