""" TLS fingerprinting engine for the fleet-wide MACVLAN sniffer. Extracted from templates/sniffer/server.py. All pure-Python TLS parsing, JA3/JA3S/JA4/JA4S/JA4L computation, session tracking, and dedup logic lives here. The packet callback is parameterized to accept an IP-to-decky mapping and a write function, so it works for fleet-wide sniffing. """ from __future__ import annotations import hashlib import struct import time from typing import Any, Callable from decnet.sniffer.syslog import SEVERITY_INFO, SEVERITY_WARNING, syslog_line # ─── Constants ─────────────────────────────────────────────────────────────── SERVICE_NAME: str = "sniffer" _SESSION_TTL: float = 60.0 _DEDUP_TTL: float = 300.0 _GREASE: frozenset[int] = frozenset(0x0A0A + i * 0x1010 for i in range(16)) _TLS_RECORD_HANDSHAKE: int = 0x16 _TLS_HT_CLIENT_HELLO: int = 0x01 _TLS_HT_SERVER_HELLO: int = 0x02 _TLS_HT_CERTIFICATE: int = 0x0B _EXT_SNI: int = 0x0000 _EXT_SUPPORTED_GROUPS: int = 0x000A _EXT_EC_POINT_FORMATS: int = 0x000B _EXT_SIGNATURE_ALGORITHMS: int = 0x000D _EXT_ALPN: int = 0x0010 _EXT_SESSION_TICKET: int = 0x0023 _EXT_SUPPORTED_VERSIONS: int = 0x002B _EXT_PRE_SHARED_KEY: int = 0x0029 _EXT_EARLY_DATA: int = 0x002A _TCP_SYN: int = 0x02 _TCP_ACK: int = 0x10 # ─── GREASE helpers ────────────────────────────────────────────────────────── def _is_grease(value: int) -> bool: return value in _GREASE def _filter_grease(values: list[int]) -> list[int]: return [v for v in values if not _is_grease(v)] # ─── TLS parsers ───────────────────────────────────────────────────────────── def _parse_client_hello(data: bytes) -> dict[str, Any] | None: try: if len(data) < 6: return None if data[0] != _TLS_RECORD_HANDSHAKE: return None record_len = struct.unpack_from("!H", data, 3)[0] if len(data) < 5 + record_len: return None hs = data[5:] if hs[0] != _TLS_HT_CLIENT_HELLO: return None hs_len = struct.unpack_from("!I", b"\x00" + hs[1:4])[0] body = hs[4: 4 + hs_len] if len(body) < 34: return None pos = 0 tls_version = struct.unpack_from("!H", body, pos)[0] pos += 2 pos += 32 # Random session_id_len = body[pos] session_id = body[pos + 1: pos + 1 + session_id_len] pos += 1 + session_id_len cs_len = struct.unpack_from("!H", body, pos)[0] pos += 2 cipher_suites = [ struct.unpack_from("!H", body, pos + i * 2)[0] for i in range(cs_len // 2) ] pos += cs_len comp_len = body[pos] pos += 1 + comp_len extensions: list[int] = [] supported_groups: list[int] = [] ec_point_formats: list[int] = [] signature_algorithms: list[int] = [] supported_versions: list[int] = [] sni: str = "" alpn: list[str] = [] has_session_ticket_data: bool = False has_pre_shared_key: bool = False has_early_data: bool = False if pos + 2 <= len(body): ext_total = struct.unpack_from("!H", body, pos)[0] pos += 2 ext_end = pos + ext_total while pos + 4 <= ext_end: ext_type = struct.unpack_from("!H", body, pos)[0] ext_len = struct.unpack_from("!H", body, pos + 2)[0] ext_data = body[pos + 4: pos + 4 + ext_len] pos += 4 + ext_len if not _is_grease(ext_type): extensions.append(ext_type) if ext_type == _EXT_SNI and len(ext_data) > 5: sni = ext_data[5:].decode("ascii", errors="replace") elif ext_type == _EXT_SUPPORTED_GROUPS and len(ext_data) >= 2: grp_len = struct.unpack_from("!H", ext_data, 0)[0] supported_groups = [ struct.unpack_from("!H", ext_data, 2 + i * 2)[0] for i in range(grp_len // 2) ] elif ext_type == _EXT_EC_POINT_FORMATS and len(ext_data) >= 1: pf_len = ext_data[0] ec_point_formats = list(ext_data[1: 1 + pf_len]) elif ext_type == _EXT_ALPN and len(ext_data) >= 2: proto_list_len = struct.unpack_from("!H", ext_data, 0)[0] ap = 2 while ap < 2 + proto_list_len: plen = ext_data[ap] alpn.append(ext_data[ap + 1: ap + 1 + plen].decode("ascii", errors="replace")) ap += 1 + plen elif ext_type == _EXT_SIGNATURE_ALGORITHMS and len(ext_data) >= 2: sa_len = struct.unpack_from("!H", ext_data, 0)[0] signature_algorithms = [ struct.unpack_from("!H", ext_data, 2 + i * 2)[0] for i in range(sa_len // 2) ] elif ext_type == _EXT_SUPPORTED_VERSIONS and len(ext_data) >= 1: sv_len = ext_data[0] supported_versions = [ struct.unpack_from("!H", ext_data, 1 + i * 2)[0] for i in range(sv_len // 2) ] elif ext_type == _EXT_SESSION_TICKET: has_session_ticket_data = len(ext_data) > 0 elif ext_type == _EXT_PRE_SHARED_KEY: has_pre_shared_key = True elif ext_type == _EXT_EARLY_DATA: has_early_data = True filtered_ciphers = _filter_grease(cipher_suites) filtered_groups = _filter_grease(supported_groups) filtered_sig_algs = _filter_grease(signature_algorithms) filtered_versions = _filter_grease(supported_versions) return { "tls_version": tls_version, "cipher_suites": filtered_ciphers, "extensions": extensions, "supported_groups": filtered_groups, "ec_point_formats": ec_point_formats, "signature_algorithms": filtered_sig_algs, "supported_versions": filtered_versions, "sni": sni, "alpn": alpn, "session_id": session_id, "has_session_ticket_data": has_session_ticket_data, "has_pre_shared_key": has_pre_shared_key, "has_early_data": has_early_data, } except Exception: return None def _parse_server_hello(data: bytes) -> dict[str, Any] | None: try: if len(data) < 6 or data[0] != _TLS_RECORD_HANDSHAKE: return None hs = data[5:] if hs[0] != _TLS_HT_SERVER_HELLO: return None hs_len = struct.unpack_from("!I", b"\x00" + hs[1:4])[0] body = hs[4: 4 + hs_len] if len(body) < 35: return None pos = 0 tls_version = struct.unpack_from("!H", body, pos)[0] pos += 2 pos += 32 # Random session_id_len = body[pos] pos += 1 + session_id_len if pos + 2 > len(body): return None cipher_suite = struct.unpack_from("!H", body, pos)[0] pos += 2 pos += 1 # Compression method extensions: list[int] = [] selected_version: int | None = None alpn: str = "" if pos + 2 <= len(body): ext_total = struct.unpack_from("!H", body, pos)[0] pos += 2 ext_end = pos + ext_total while pos + 4 <= ext_end: ext_type = struct.unpack_from("!H", body, pos)[0] ext_len = struct.unpack_from("!H", body, pos + 2)[0] ext_data = body[pos + 4: pos + 4 + ext_len] pos += 4 + ext_len if not _is_grease(ext_type): extensions.append(ext_type) if ext_type == _EXT_SUPPORTED_VERSIONS and len(ext_data) >= 2: selected_version = struct.unpack_from("!H", ext_data, 0)[0] elif ext_type == _EXT_ALPN and len(ext_data) >= 2: proto_list_len = struct.unpack_from("!H", ext_data, 0)[0] if proto_list_len > 0 and len(ext_data) >= 4: plen = ext_data[2] alpn = ext_data[3: 3 + plen].decode("ascii", errors="replace") return { "tls_version": tls_version, "cipher_suite": cipher_suite, "extensions": extensions, "selected_version": selected_version, "alpn": alpn, } except Exception: return None def _parse_certificate(data: bytes) -> dict[str, Any] | None: try: if len(data) < 6 or data[0] != _TLS_RECORD_HANDSHAKE: return None hs = data[5:] if hs[0] != _TLS_HT_CERTIFICATE: return None hs_len = struct.unpack_from("!I", b"\x00" + hs[1:4])[0] body = hs[4: 4 + hs_len] if len(body) < 3: return None certs_len = struct.unpack_from("!I", b"\x00" + body[0:3])[0] if certs_len == 0: return None pos = 3 if pos + 3 > len(body): return None cert_len = struct.unpack_from("!I", b"\x00" + body[pos:pos + 3])[0] pos += 3 if pos + cert_len > len(body): return None cert_der = body[pos: pos + cert_len] return _parse_x509_der(cert_der) except Exception: return None # ─── Minimal DER/ASN.1 X.509 parser ───────────────────────────────────────── def _der_read_tag_len(data: bytes, pos: int) -> tuple[int, int, int]: tag = data[pos] pos += 1 length_byte = data[pos] pos += 1 if length_byte & 0x80: num_bytes = length_byte & 0x7F length = int.from_bytes(data[pos: pos + num_bytes], "big") pos += num_bytes else: length = length_byte return tag, pos, length def _der_read_sequence(data: bytes, pos: int) -> tuple[int, int]: tag, content_start, length = _der_read_tag_len(data, pos) return content_start, length def _der_read_oid(data: bytes, pos: int, length: int) -> str: if length < 1: return "" first = data[pos] oid_parts = [str(first // 40), str(first % 40)] val = 0 for i in range(1, length): b = data[pos + i] val = (val << 7) | (b & 0x7F) if not (b & 0x80): oid_parts.append(str(val)) val = 0 return ".".join(oid_parts) def _der_extract_cn(data: bytes, start: int, length: int) -> str: pos = start end = start + length while pos < end: set_tag, set_start, set_len = _der_read_tag_len(data, pos) if set_tag != 0x31: break set_end = set_start + set_len attr_pos = set_start while attr_pos < set_end: seq_tag, seq_start, seq_len = _der_read_tag_len(data, attr_pos) if seq_tag != 0x30: break oid_tag, oid_start, oid_len = _der_read_tag_len(data, seq_start) if oid_tag == 0x06: oid = _der_read_oid(data, oid_start, oid_len) if oid == "2.5.4.3": val_tag, val_start, val_len = _der_read_tag_len(data, oid_start + oid_len) return data[val_start: val_start + val_len].decode("utf-8", errors="replace") attr_pos = seq_start + seq_len pos = set_end return "" def _der_extract_name_str(data: bytes, start: int, length: int) -> str: parts: list[str] = [] pos = start end = start + length oid_names = { "2.5.4.3": "CN", "2.5.4.6": "C", "2.5.4.7": "L", "2.5.4.8": "ST", "2.5.4.10": "O", "2.5.4.11": "OU", } while pos < end: set_tag, set_start, set_len = _der_read_tag_len(data, pos) if set_tag != 0x31: break set_end = set_start + set_len attr_pos = set_start while attr_pos < set_end: seq_tag, seq_start, seq_len = _der_read_tag_len(data, attr_pos) if seq_tag != 0x30: break oid_tag, oid_start, oid_len = _der_read_tag_len(data, seq_start) if oid_tag == 0x06: oid = _der_read_oid(data, oid_start, oid_len) val_tag, val_start, val_len = _der_read_tag_len(data, oid_start + oid_len) val = data[val_start: val_start + val_len].decode("utf-8", errors="replace") name = oid_names.get(oid, oid) parts.append(f"{name}={val}") attr_pos = seq_start + seq_len pos = set_end return ", ".join(parts) def _parse_x509_der(cert_der: bytes) -> dict[str, Any] | None: try: outer_start, outer_len = _der_read_sequence(cert_der, 0) tbs_tag, tbs_start, tbs_len = _der_read_tag_len(cert_der, outer_start) tbs_end = tbs_start + tbs_len pos = tbs_start if cert_der[pos] == 0xA0: _, v_start, v_len = _der_read_tag_len(cert_der, pos) pos = v_start + v_len _, sn_start, sn_len = _der_read_tag_len(cert_der, pos) pos = sn_start + sn_len _, sa_start, sa_len = _der_read_tag_len(cert_der, pos) pos = sa_start + sa_len issuer_tag, issuer_start, issuer_len = _der_read_tag_len(cert_der, pos) issuer_str = _der_extract_name_str(cert_der, issuer_start, issuer_len) issuer_cn = _der_extract_cn(cert_der, issuer_start, issuer_len) pos = issuer_start + issuer_len val_tag, val_start, val_len = _der_read_tag_len(cert_der, pos) nb_tag, nb_start, nb_len = _der_read_tag_len(cert_der, val_start) not_before = cert_der[nb_start: nb_start + nb_len].decode("ascii", errors="replace") na_tag, na_start, na_len = _der_read_tag_len(cert_der, nb_start + nb_len) not_after = cert_der[na_start: na_start + na_len].decode("ascii", errors="replace") pos = val_start + val_len subj_tag, subj_start, subj_len = _der_read_tag_len(cert_der, pos) subject_cn = _der_extract_cn(cert_der, subj_start, subj_len) subject_str = _der_extract_name_str(cert_der, subj_start, subj_len) self_signed = (issuer_cn == subject_cn) and subject_cn != "" pos = subj_start + subj_len sans: list[str] = _extract_sans(cert_der, pos, tbs_end) return { "subject_cn": subject_cn, "subject": subject_str, "issuer": issuer_str, "issuer_cn": issuer_cn, "not_before": not_before, "not_after": not_after, "self_signed": self_signed, "sans": sans, } except Exception: return None def _extract_sans(cert_der: bytes, pos: int, end: int) -> list[str]: sans: list[str] = [] try: if pos >= end: return sans spki_tag, spki_start, spki_len = _der_read_tag_len(cert_der, pos) pos = spki_start + spki_len while pos < end: tag = cert_der[pos] if tag == 0xA3: _, ext_wrap_start, ext_wrap_len = _der_read_tag_len(cert_der, pos) _, exts_start, exts_len = _der_read_tag_len(cert_der, ext_wrap_start) epos = exts_start eend = exts_start + exts_len while epos < eend: ext_tag, ext_start, ext_len = _der_read_tag_len(cert_der, epos) ext_end = ext_start + ext_len oid_tag, oid_start, oid_len = _der_read_tag_len(cert_der, ext_start) if oid_tag == 0x06: oid = _der_read_oid(cert_der, oid_start, oid_len) if oid == "2.5.29.17": vpos = oid_start + oid_len if vpos < ext_end and cert_der[vpos] == 0x01: _, bs, bl = _der_read_tag_len(cert_der, vpos) vpos = bs + bl if vpos < ext_end: os_tag, os_start, os_len = _der_read_tag_len(cert_der, vpos) if os_tag == 0x04: sans = _parse_san_sequence(cert_der, os_start, os_len) epos = ext_end break else: _, skip_start, skip_len = _der_read_tag_len(cert_der, pos) pos = skip_start + skip_len except Exception: pass return sans def _parse_san_sequence(data: bytes, start: int, length: int) -> list[str]: names: list[str] = [] try: seq_tag, seq_start, seq_len = _der_read_tag_len(data, start) pos = seq_start end = seq_start + seq_len while pos < end: tag = data[pos] _, val_start, val_len = _der_read_tag_len(data, pos) context_tag = tag & 0x1F if context_tag == 2: names.append(data[val_start: val_start + val_len].decode("ascii", errors="replace")) elif context_tag == 7 and val_len == 4: names.append(".".join(str(b) for b in data[val_start: val_start + val_len])) pos = val_start + val_len except Exception: pass return names # ─── JA3 / JA3S ───────────────────────────────────────────────────────────── def _tls_version_str(version: int) -> str: return { 0x0301: "TLS 1.0", 0x0302: "TLS 1.1", 0x0303: "TLS 1.2", 0x0304: "TLS 1.3", 0x0200: "SSL 2.0", 0x0300: "SSL 3.0", }.get(version, f"0x{version:04x}") def _ja3(ch: dict[str, Any]) -> tuple[str, str]: parts = [ str(ch["tls_version"]), "-".join(str(c) for c in ch["cipher_suites"]), "-".join(str(e) for e in ch["extensions"]), "-".join(str(g) for g in ch["supported_groups"]), "-".join(str(p) for p in ch["ec_point_formats"]), ] ja3_str = ",".join(parts) return ja3_str, hashlib.md5(ja3_str.encode()).hexdigest() # nosec B324 def _ja3s(sh: dict[str, Any]) -> tuple[str, str]: parts = [ str(sh["tls_version"]), str(sh["cipher_suite"]), "-".join(str(e) for e in sh["extensions"]), ] ja3s_str = ",".join(parts) return ja3s_str, hashlib.md5(ja3s_str.encode()).hexdigest() # nosec B324 # ─── JA4 / JA4S ───────────────────────────────────────────────────────────── def _ja4_version(ch: dict[str, Any]) -> str: versions = ch.get("supported_versions", []) if versions: best = max(versions) else: best = ch["tls_version"] return { 0x0304: "13", 0x0303: "12", 0x0302: "11", 0x0301: "10", 0x0300: "s3", 0x0200: "s2", }.get(best, "00") def _ja4_alpn_tag(alpn_list: list[str] | str) -> str: if isinstance(alpn_list, str): proto = alpn_list elif alpn_list: proto = alpn_list[0] else: return "00" if not proto: return "00" if len(proto) == 1: return proto[0] + proto[0] return proto[0] + proto[-1] def _sha256_12(text: str) -> str: return hashlib.sha256(text.encode()).hexdigest()[:12] def _ja4(ch: dict[str, Any]) -> str: proto = "t" ver = _ja4_version(ch) sni_flag = "d" if ch.get("sni") else "i" cs_count = min(len(ch["cipher_suites"]), 99) ext_count = min(len(ch["extensions"]), 99) alpn_tag = _ja4_alpn_tag(ch.get("alpn", [])) section_a = f"{proto}{ver}{sni_flag}{cs_count:02d}{ext_count:02d}{alpn_tag}" sorted_cs = sorted(ch["cipher_suites"]) section_b = _sha256_12(",".join(str(c) for c in sorted_cs)) sorted_ext = sorted(ch["extensions"]) sorted_sa = sorted(ch.get("signature_algorithms", [])) ext_str = ",".join(str(e) for e in sorted_ext) sa_str = ",".join(str(s) for s in sorted_sa) combined = f"{ext_str}_{sa_str}" if sa_str else ext_str section_c = _sha256_12(combined) return f"{section_a}_{section_b}_{section_c}" def _ja4s(sh: dict[str, Any]) -> str: proto = "t" selected = sh.get("selected_version") if selected: ver = {0x0304: "13", 0x0303: "12", 0x0302: "11", 0x0301: "10", 0x0300: "s3", 0x0200: "s2"}.get(selected, "00") else: ver = {0x0304: "13", 0x0303: "12", 0x0302: "11", 0x0301: "10", 0x0300: "s3", 0x0200: "s2"}.get(sh["tls_version"], "00") ext_count = min(len(sh["extensions"]), 99) alpn_tag = _ja4_alpn_tag(sh.get("alpn", "")) section_a = f"{proto}{ver}{ext_count:02d}{alpn_tag}" sorted_ext = sorted(sh["extensions"]) inner = f"{sh['cipher_suite']},{','.join(str(e) for e in sorted_ext)}" section_b = _sha256_12(inner) return f"{section_a}_{section_b}" # ─── JA4L (latency) ───────────────────────────────────────────────────────── def _ja4l( key: tuple[str, int, str, int], tcp_rtt: dict[tuple[str, int, str, int], dict[str, Any]], ) -> dict[str, Any] | None: return tcp_rtt.get(key) # ─── Session resumption ───────────────────────────────────────────────────── def _session_resumption_info(ch: dict[str, Any]) -> dict[str, Any]: mechanisms: list[str] = [] if ch.get("has_session_ticket_data"): mechanisms.append("session_ticket") if ch.get("has_pre_shared_key"): mechanisms.append("psk") if ch.get("has_early_data"): mechanisms.append("early_data_0rtt") if ch.get("session_id") and len(ch["session_id"]) > 0: mechanisms.append("session_id") return { "resumption_attempted": len(mechanisms) > 0, "mechanisms": mechanisms, } # ─── Sniffer engine (stateful, one instance per worker) ───────────────────── class SnifferEngine: """ Stateful TLS fingerprinting engine. Tracks sessions, TCP RTTs, and dedup state. Thread-safe only when called from a single thread (the scapy sniff thread). """ def __init__( self, ip_to_decky: dict[str, str], write_fn: Callable[[str], None], dedup_ttl: float = 300.0, ): self._ip_to_decky = ip_to_decky self._write_fn = write_fn self._dedup_ttl = dedup_ttl self._sessions: dict[tuple[str, int, str, int], dict[str, Any]] = {} self._session_ts: dict[tuple[str, int, str, int], float] = {} self._tcp_syn: dict[tuple[str, int, str, int], dict[str, Any]] = {} self._tcp_rtt: dict[tuple[str, int, str, int], dict[str, Any]] = {} self._dedup_cache: dict[tuple[str, str, str], float] = {} self._dedup_last_cleanup: float = 0.0 self._DEDUP_CLEANUP_INTERVAL: float = 60.0 def update_ip_map(self, ip_to_decky: dict[str, str]) -> None: self._ip_to_decky = ip_to_decky def _resolve_decky(self, src_ip: str, dst_ip: str) -> str | None: """Map a packet to a decky name. Returns None if neither IP is a known decky.""" if dst_ip in self._ip_to_decky: return self._ip_to_decky[dst_ip] if src_ip in self._ip_to_decky: return self._ip_to_decky[src_ip] return None def _cleanup_sessions(self) -> None: now = time.monotonic() stale = [k for k, ts in self._session_ts.items() if now - ts > _SESSION_TTL] for k in stale: self._sessions.pop(k, None) self._session_ts.pop(k, None) stale_syn = [k for k, v in self._tcp_syn.items() if now - v.get("time", 0) > _SESSION_TTL] for k in stale_syn: self._tcp_syn.pop(k, None) stale_rtt = [k for k, _ in self._tcp_rtt.items() if k not in self._sessions and k not in self._session_ts] for k in stale_rtt: self._tcp_rtt.pop(k, None) def _dedup_key_for(self, event_type: str, fields: dict[str, Any]) -> str: if event_type == "tls_client_hello": return fields.get("ja3", "") + "|" + fields.get("ja4", "") if event_type == "tls_session": return (fields.get("ja3", "") + "|" + fields.get("ja3s", "") + "|" + fields.get("ja4", "") + "|" + fields.get("ja4s", "")) if event_type == "tls_certificate": return fields.get("subject_cn", "") + "|" + fields.get("issuer", "") return fields.get("mechanisms", fields.get("resumption", "")) def _is_duplicate(self, event_type: str, fields: dict[str, Any]) -> bool: if self._dedup_ttl <= 0: return False now = time.monotonic() if now - self._dedup_last_cleanup > self._DEDUP_CLEANUP_INTERVAL: stale = [k for k, ts in self._dedup_cache.items() if now - ts > self._dedup_ttl] for k in stale: del self._dedup_cache[k] self._dedup_last_cleanup = now src_ip = fields.get("src_ip", "") fp = self._dedup_key_for(event_type, fields) cache_key = (src_ip, event_type, fp) last_seen = self._dedup_cache.get(cache_key) if last_seen is not None and now - last_seen < self._dedup_ttl: return True self._dedup_cache[cache_key] = now return False def _log(self, node_name: str, event_type: str, severity: int = SEVERITY_INFO, **fields: Any) -> None: if self._is_duplicate(event_type, fields): return line = syslog_line(SERVICE_NAME, node_name, event_type, severity=severity, **fields) self._write_fn(line) def on_packet(self, pkt: Any) -> None: """Process a single scapy packet. Called from the sniff thread.""" try: from scapy.layers.inet import IP, TCP except ImportError: return if not (pkt.haslayer(IP) and pkt.haslayer(TCP)): return ip = pkt[IP] tcp = pkt[TCP] src_ip: str = ip.src dst_ip: str = ip.dst src_port: int = tcp.sport dst_port: int = tcp.dport flags: int = tcp.flags.value if hasattr(tcp.flags, 'value') else int(tcp.flags) # Skip traffic not involving any decky node_name = self._resolve_decky(src_ip, dst_ip) if node_name is None: return # TCP SYN tracking for JA4L if flags & _TCP_SYN and not (flags & _TCP_ACK): key = (src_ip, src_port, dst_ip, dst_port) self._tcp_syn[key] = {"time": time.monotonic(), "ttl": ip.ttl} elif flags & _TCP_SYN and flags & _TCP_ACK: rev_key = (dst_ip, dst_port, src_ip, src_port) syn_data = self._tcp_syn.pop(rev_key, None) if syn_data: rtt_ms = round((time.monotonic() - syn_data["time"]) * 1000, 2) self._tcp_rtt[rev_key] = { "rtt_ms": rtt_ms, "client_ttl": syn_data["ttl"], } payload = bytes(tcp.payload) if not payload: return if payload[0] != _TLS_RECORD_HANDSHAKE: return # ClientHello ch = _parse_client_hello(payload) if ch is not None: self._cleanup_sessions() key = (src_ip, src_port, dst_ip, dst_port) ja3_str, ja3_hash = _ja3(ch) ja4_hash = _ja4(ch) resumption = _session_resumption_info(ch) rtt_data = _ja4l(key, self._tcp_rtt) self._sessions[key] = { "ja3": ja3_hash, "ja3_str": ja3_str, "ja4": ja4_hash, "tls_version": ch["tls_version"], "cipher_suites": ch["cipher_suites"], "extensions": ch["extensions"], "signature_algorithms": ch.get("signature_algorithms", []), "supported_versions": ch.get("supported_versions", []), "sni": ch["sni"], "alpn": ch["alpn"], "resumption": resumption, } self._session_ts[key] = time.monotonic() log_fields: dict[str, Any] = { "src_ip": src_ip, "src_port": str(src_port), "dst_ip": dst_ip, "dst_port": str(dst_port), "ja3": ja3_hash, "ja4": ja4_hash, "tls_version": _tls_version_str(ch["tls_version"]), "sni": ch["sni"] or "", "alpn": ",".join(ch["alpn"]), "raw_ciphers": "-".join(str(c) for c in ch["cipher_suites"]), "raw_extensions": "-".join(str(e) for e in ch["extensions"]), } if resumption["resumption_attempted"]: log_fields["resumption"] = ",".join(resumption["mechanisms"]) if rtt_data: log_fields["ja4l_rtt_ms"] = str(rtt_data["rtt_ms"]) log_fields["ja4l_client_ttl"] = str(rtt_data["client_ttl"]) # Resolve node for the *destination* (the decky being attacked) target_node = self._ip_to_decky.get(dst_ip, node_name) self._log(target_node, "tls_client_hello", **log_fields) return # ServerHello sh = _parse_server_hello(payload) if sh is not None: rev_key = (dst_ip, dst_port, src_ip, src_port) ch_data = self._sessions.pop(rev_key, None) self._session_ts.pop(rev_key, None) ja3s_str, ja3s_hash = _ja3s(sh) ja4s_hash = _ja4s(sh) fields: dict[str, Any] = { "src_ip": dst_ip, "src_port": str(dst_port), "dst_ip": src_ip, "dst_port": str(src_port), "ja3s": ja3s_hash, "ja4s": ja4s_hash, "tls_version": _tls_version_str(sh["tls_version"]), } if ch_data: fields["ja3"] = ch_data["ja3"] fields["ja4"] = ch_data.get("ja4", "") fields["sni"] = ch_data["sni"] or "" fields["alpn"] = ",".join(ch_data["alpn"]) fields["raw_ciphers"] = "-".join(str(c) for c in ch_data["cipher_suites"]) fields["raw_extensions"] = "-".join(str(e) for e in ch_data["extensions"]) if ch_data.get("resumption", {}).get("resumption_attempted"): fields["resumption"] = ",".join(ch_data["resumption"]["mechanisms"]) rtt_data = self._tcp_rtt.pop(rev_key, None) if rtt_data: fields["ja4l_rtt_ms"] = str(rtt_data["rtt_ms"]) fields["ja4l_client_ttl"] = str(rtt_data["client_ttl"]) # Server response — resolve by src_ip (the decky responding) target_node = self._ip_to_decky.get(src_ip, node_name) self._log(target_node, "tls_session", severity=SEVERITY_WARNING, **fields) return # Certificate (TLS 1.2 only) cert = _parse_certificate(payload) if cert is not None: rev_key = (dst_ip, dst_port, src_ip, src_port) ch_data = self._sessions.get(rev_key) cert_fields: dict[str, Any] = { "src_ip": dst_ip, "src_port": str(dst_port), "dst_ip": src_ip, "dst_port": str(src_port), "subject_cn": cert["subject_cn"], "issuer": cert["issuer"], "self_signed": str(cert["self_signed"]).lower(), "not_before": cert["not_before"], "not_after": cert["not_after"], } if cert["sans"]: cert_fields["sans"] = ",".join(cert["sans"]) if ch_data: cert_fields["sni"] = ch_data.get("sni", "") target_node = self._ip_to_decky.get(src_ip, node_name) self._log(target_node, "tls_certificate", **cert_fields)