Files
DECNET/decnet/sniffer/fingerprint.py
anti d7da3a7fc7 feat: add advanced OS fingerprinting via p0f integration
- decnet/sniffer/fingerprint.py: enhance TCP/IP fingerprinting pipeline
- decnet/sniffer/p0f.py: integrate p0f for passive OS classification
- Improves attacker profiling accuracy in honeypot interaction analysis
2026-04-15 12:51:17 -04:00

1135 lines
42 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
TLS fingerprinting engine for the fleet-wide MACVLAN sniffer.
Extracted from templates/sniffer/server.py. All pure-Python TLS parsing,
JA3/JA3S/JA4/JA4S/JA4L computation, session tracking, and dedup logic
lives here. The packet callback is parameterized to accept an IP-to-decky
mapping and a write function, so it works for fleet-wide sniffing.
"""
from __future__ import annotations
import hashlib
import struct
import time
from typing import Any, Callable
from decnet.prober.tcpfp import _extract_options_order
from decnet.sniffer.p0f import guess_os, hop_distance, initial_ttl
from decnet.sniffer.syslog import SEVERITY_INFO, SEVERITY_WARNING, syslog_line
# ─── Constants ───────────────────────────────────────────────────────────────
# Service tag written into every syslog line emitted by this module.
SERVICE_NAME: str = "sniffer"
# Seconds a stored ClientHello may wait for its ServerHello before being aged out.
_SESSION_TTL: float = 60.0
# Duplicate-suppression window in seconds; mirrors SnifferEngine's dedup_ttl default.
_DEDUP_TTL: float = 300.0
# Inactivity after which a TCP flow is considered closed and its timing
# summary is flushed as an event.
_FLOW_IDLE_TIMEOUT: float = 120.0
# The sixteen reserved GREASE code points (0x0A0A, 0x1A1A, ..., 0xFAFA).
_GREASE: frozenset[int] = frozenset(0x0A0A + i * 0x1010 for i in range(16))
# TLS record content type and handshake message type codes.
_TLS_RECORD_HANDSHAKE: int = 0x16
_TLS_HT_CLIENT_HELLO: int = 0x01
_TLS_HT_SERVER_HELLO: int = 0x02
_TLS_HT_CERTIFICATE: int = 0x0B
# TLS extension type codes recognized by the parsers below.
_EXT_SNI: int = 0x0000
_EXT_SUPPORTED_GROUPS: int = 0x000A
_EXT_EC_POINT_FORMATS: int = 0x000B
_EXT_SIGNATURE_ALGORITHMS: int = 0x000D
_EXT_ALPN: int = 0x0010
_EXT_SESSION_TICKET: int = 0x0023
_EXT_SUPPORTED_VERSIONS: int = 0x002B
_EXT_PRE_SHARED_KEY: int = 0x0029
_EXT_EARLY_DATA: int = 0x002A
# TCP header flag bit masks.
_TCP_SYN: int = 0x02
_TCP_ACK: int = 0x10
_TCP_FIN: int = 0x01
_TCP_RST: int = 0x04
# ─── TCP option extraction for passive fingerprinting ───────────────────────
def _extract_tcp_fingerprint(tcp_options: list) -> dict[str, Any]:
    """Summarize the fingerprint-relevant TCP options from a scapy option list.

    Returns MSS, window scale, SACK-permitted / timestamp flags, and the
    canonical options-order signature produced by the prober's shared helper.
    """
    opts = tcp_options or []
    summary: dict[str, Any] = {
        "mss": 0,
        "wscale": None,
        "sack_ok": False,
        "has_timestamps": False,
        "options_sig": _extract_options_order(opts),
    }
    for name, value in opts:
        if name == "MSS":
            summary["mss"] = value
        elif name == "WScale":
            summary["wscale"] = value
        elif name in ("SAckOK", "SAck"):
            summary["sack_ok"] = True
        elif name == "Timestamp":
            summary["has_timestamps"] = True
    return summary
# ─── GREASE helpers ──────────────────────────────────────────────────────────
def _is_grease(value: int) -> bool:
    """True when *value* is one of the reserved GREASE code points (RFC 8701)."""
    if value in _GREASE:
        return True
    return False
def _filter_grease(values: list[int]) -> list[int]:
    """Return *values* with every GREASE code point removed, order preserved."""
    return [candidate for candidate in values if candidate not in _GREASE]
# ─── TLS parsers ─────────────────────────────────────────────────────────────
def _parse_client_hello(data: bytes) -> dict[str, Any] | None:
    """Parse a TLS ClientHello from raw TCP payload bytes.

    Returns the fingerprint-relevant fields (GREASE values filtered from
    cipher suites, groups, signature algorithms and versions; GREASE
    extension types likewise excluded from the extension-order list), or
    None when the payload is not a complete, well-formed ClientHello.
    All parsing exceptions are swallowed deliberately — this path sees
    arbitrary attacker traffic.
    """
    try:
        if len(data) < 6:
            return None
        if data[0] != _TLS_RECORD_HANDSHAKE:
            return None
        # Record header: type(1) + version(2) + length(2).
        record_len = struct.unpack_from("!H", data, 3)[0]
        if len(data) < 5 + record_len:
            return None  # Record split across TCP segments; skip.
        hs = data[5:]
        if hs[0] != _TLS_HT_CLIENT_HELLO:
            return None
        # Handshake body length is a 24-bit big-endian integer.
        hs_len = struct.unpack_from("!I", b"\x00" + hs[1:4])[0]
        body = hs[4: 4 + hs_len]
        if len(body) < 34:
            return None  # Too short for version(2) + random(32).
        pos = 0
        tls_version = struct.unpack_from("!H", body, pos)[0]
        pos += 2
        pos += 32  # Random
        session_id_len = body[pos]
        session_id = body[pos + 1: pos + 1 + session_id_len]
        pos += 1 + session_id_len
        cs_len = struct.unpack_from("!H", body, pos)[0]
        pos += 2
        cipher_suites = [
            struct.unpack_from("!H", body, pos + i * 2)[0]
            for i in range(cs_len // 2)
        ]
        pos += cs_len
        comp_len = body[pos]
        pos += 1 + comp_len  # Compression methods are not fingerprinted.
        extensions: list[int] = []
        supported_groups: list[int] = []
        ec_point_formats: list[int] = []
        signature_algorithms: list[int] = []
        supported_versions: list[int] = []
        sni: str = ""
        alpn: list[str] = []
        has_session_ticket_data: bool = False
        has_pre_shared_key: bool = False
        has_early_data: bool = False
        if pos + 2 <= len(body):
            ext_total = struct.unpack_from("!H", body, pos)[0]
            pos += 2
            ext_end = pos + ext_total
            while pos + 4 <= ext_end:
                ext_type = struct.unpack_from("!H", body, pos)[0]
                ext_len = struct.unpack_from("!H", body, pos + 2)[0]
                ext_data = body[pos + 4: pos + 4 + ext_len]
                pos += 4 + ext_len
                if not _is_grease(ext_type):
                    extensions.append(ext_type)
                if ext_type == _EXT_SNI and len(ext_data) > 5:
                    # server_name_list layout (RFC 6066): list_len(2) +
                    # name_type(1) + name_len(2) + name. Honor the name
                    # length field rather than consuming the remainder of
                    # the extension, which would append trailing bytes when
                    # the list carries more than one entry.
                    name_len = struct.unpack_from("!H", ext_data, 3)[0]
                    sni = ext_data[5: 5 + name_len].decode("ascii", errors="replace")
                elif ext_type == _EXT_SUPPORTED_GROUPS and len(ext_data) >= 2:
                    grp_len = struct.unpack_from("!H", ext_data, 0)[0]
                    supported_groups = [
                        struct.unpack_from("!H", ext_data, 2 + i * 2)[0]
                        for i in range(grp_len // 2)
                    ]
                elif ext_type == _EXT_EC_POINT_FORMATS and len(ext_data) >= 1:
                    pf_len = ext_data[0]
                    ec_point_formats = list(ext_data[1: 1 + pf_len])
                elif ext_type == _EXT_ALPN and len(ext_data) >= 2:
                    # ALPN: 2-byte list length, then length-prefixed names.
                    proto_list_len = struct.unpack_from("!H", ext_data, 0)[0]
                    ap = 2
                    while ap < 2 + proto_list_len:
                        plen = ext_data[ap]
                        alpn.append(ext_data[ap + 1: ap + 1 + plen].decode("ascii", errors="replace"))
                        ap += 1 + plen
                elif ext_type == _EXT_SIGNATURE_ALGORITHMS and len(ext_data) >= 2:
                    sa_len = struct.unpack_from("!H", ext_data, 0)[0]
                    signature_algorithms = [
                        struct.unpack_from("!H", ext_data, 2 + i * 2)[0]
                        for i in range(sa_len // 2)
                    ]
                elif ext_type == _EXT_SUPPORTED_VERSIONS and len(ext_data) >= 1:
                    # One length byte, then 2-byte version words.
                    sv_len = ext_data[0]
                    supported_versions = [
                        struct.unpack_from("!H", ext_data, 1 + i * 2)[0]
                        for i in range(sv_len // 2)
                    ]
                elif ext_type == _EXT_SESSION_TICKET:
                    # A non-empty ticket body indicates attempted resumption.
                    has_session_ticket_data = len(ext_data) > 0
                elif ext_type == _EXT_PRE_SHARED_KEY:
                    has_pre_shared_key = True
                elif ext_type == _EXT_EARLY_DATA:
                    has_early_data = True
        filtered_ciphers = _filter_grease(cipher_suites)
        filtered_groups = _filter_grease(supported_groups)
        filtered_sig_algs = _filter_grease(signature_algorithms)
        filtered_versions = _filter_grease(supported_versions)
        return {
            "tls_version": tls_version,
            "cipher_suites": filtered_ciphers,
            "extensions": extensions,
            "supported_groups": filtered_groups,
            "ec_point_formats": ec_point_formats,
            "signature_algorithms": filtered_sig_algs,
            "supported_versions": filtered_versions,
            "sni": sni,
            "alpn": alpn,
            "session_id": session_id,
            "has_session_ticket_data": has_session_ticket_data,
            "has_pre_shared_key": has_pre_shared_key,
            "has_early_data": has_early_data,
        }
    except Exception:
        return None
def _parse_server_hello(data: bytes) -> dict[str, Any] | None:
    """Parse a TLS ServerHello from raw TCP payload bytes.

    Returns the negotiated version and cipher suite, the non-GREASE
    extension list, the supported_versions selection (TLS 1.3) and the
    chosen ALPN protocol, or None when the payload is not a well-formed
    ServerHello. Parse errors are swallowed deliberately.
    """
    try:
        if len(data) < 6 or data[0] != _TLS_RECORD_HANDSHAKE:
            return None
        hs = data[5:]
        if hs[0] != _TLS_HT_SERVER_HELLO:
            return None
        # Handshake body length is a 24-bit big-endian integer.
        hs_len = struct.unpack_from("!I", b"\x00" + hs[1:4])[0]
        body = hs[4: 4 + hs_len]
        if len(body) < 35:
            return None
        pos = 0
        tls_version = struct.unpack_from("!H", body, pos)[0]
        pos += 2
        pos += 32  # Random
        session_id_len = body[pos]
        pos += 1 + session_id_len
        if pos + 2 > len(body):
            return None
        cipher_suite = struct.unpack_from("!H", body, pos)[0]
        pos += 2
        pos += 1  # Compression method
        extensions: list[int] = []
        selected_version: int | None = None
        alpn: str = ""
        if pos + 2 <= len(body):
            ext_total = struct.unpack_from("!H", body, pos)[0]
            pos += 2
            ext_end = pos + ext_total
            while pos + 4 <= ext_end:
                ext_type = struct.unpack_from("!H", body, pos)[0]
                ext_len = struct.unpack_from("!H", body, pos + 2)[0]
                ext_data = body[pos + 4: pos + 4 + ext_len]
                pos += 4 + ext_len
                if not _is_grease(ext_type):
                    extensions.append(ext_type)
                if ext_type == _EXT_SUPPORTED_VERSIONS and len(ext_data) >= 2:
                    # In a ServerHello this extension carries a single version word.
                    selected_version = struct.unpack_from("!H", ext_data, 0)[0]
                elif ext_type == _EXT_ALPN and len(ext_data) >= 2:
                    # The server echoes exactly one chosen protocol.
                    proto_list_len = struct.unpack_from("!H", ext_data, 0)[0]
                    if proto_list_len > 0 and len(ext_data) >= 4:
                        plen = ext_data[2]
                        alpn = ext_data[3: 3 + plen].decode("ascii", errors="replace")
        return {
            "tls_version": tls_version,
            "cipher_suite": cipher_suite,
            "extensions": extensions,
            "selected_version": selected_version,
            "alpn": alpn,
        }
    except Exception:
        return None
def _parse_certificate(data: bytes) -> dict[str, Any] | None:
    """Parse a TLS Certificate handshake record and return leaf-cert fields.

    Only the first (leaf) certificate of the chain is examined; the result
    is the dict produced by _parse_x509_der, or None when the record is not
    a Certificate message or is malformed. Certificates are only visible in
    plaintext for TLS <= 1.2.
    """
    try:
        if len(data) < 6 or data[0] != _TLS_RECORD_HANDSHAKE:
            return None
        hs = data[5:]
        if hs[0] != _TLS_HT_CERTIFICATE:
            return None
        # Handshake body length is a 24-bit big-endian integer.
        hs_len = struct.unpack_from("!I", b"\x00" + hs[1:4])[0]
        body = hs[4: 4 + hs_len]
        if len(body) < 3:
            return None
        # certificate_list total length (24-bit).
        certs_len = struct.unpack_from("!I", b"\x00" + body[0:3])[0]
        if certs_len == 0:
            return None
        pos = 3
        if pos + 3 > len(body):
            return None
        # First entry: 24-bit DER length, then the DER bytes themselves.
        cert_len = struct.unpack_from("!I", b"\x00" + body[pos:pos + 3])[0]
        pos += 3
        if pos + cert_len > len(body):
            return None
        cert_der = body[pos: pos + cert_len]
        return _parse_x509_der(cert_der)
    except Exception:
        return None
# ─── Minimal DER/ASN.1 X.509 parser ─────────────────────────────────────────
def _der_read_tag_len(data: bytes, pos: int) -> tuple[int, int, int]:
tag = data[pos]
pos += 1
length_byte = data[pos]
pos += 1
if length_byte & 0x80:
num_bytes = length_byte & 0x7F
length = int.from_bytes(data[pos: pos + num_bytes], "big")
pos += num_bytes
else:
length = length_byte
return tag, pos, length
def _der_read_sequence(data: bytes, pos: int) -> tuple[int, int]:
    """Read a DER header at *pos* and return (content_start, content_length).

    The tag byte is discarded — callers use this where a SEQUENCE is expected.
    """
    _, content_start, content_len = _der_read_tag_len(data, pos)
    return content_start, content_len
def _der_read_oid(data: bytes, pos: int, length: int) -> str:
if length < 1:
return ""
first = data[pos]
oid_parts = [str(first // 40), str(first % 40)]
val = 0
for i in range(1, length):
b = data[pos + i]
val = (val << 7) | (b & 0x7F)
if not (b & 0x80):
oid_parts.append(str(val))
val = 0
return ".".join(oid_parts)
def _der_extract_cn(data: bytes, start: int, length: int) -> str:
    """Return the first commonName (OID 2.5.4.3) value inside an X.501 Name.

    The Name is an RDNSequence: SET entries (tag 0x31), each holding
    AttributeTypeAndValue SEQUENCEs (tag 0x30). Returns "" when no CN is
    present or the structure is not as expected.
    """
    pos = start
    end = start + length
    while pos < end:
        set_tag, set_start, set_len = _der_read_tag_len(data, pos)
        if set_tag != 0x31:
            break
        set_end = set_start + set_len
        attr_pos = set_start
        while attr_pos < set_end:
            seq_tag, seq_start, seq_len = _der_read_tag_len(data, attr_pos)
            if seq_tag != 0x30:
                break
            oid_tag, oid_start, oid_len = _der_read_tag_len(data, seq_start)
            if oid_tag == 0x06:
                oid = _der_read_oid(data, oid_start, oid_len)
                if oid == "2.5.4.3":
                    # The attribute value immediately follows the OID element.
                    val_tag, val_start, val_len = _der_read_tag_len(data, oid_start + oid_len)
                    return data[val_start: val_start + val_len].decode("utf-8", errors="replace")
            attr_pos = seq_start + seq_len
        pos = set_end
    return ""
def _der_extract_name_str(data: bytes, start: int, length: int) -> str:
    """Render an X.501 Name as "CN=..., O=..., ..." in encounter order.

    Walks the RDNSequence (SET 0x31 entries of AttributeTypeAndValue
    SEQUENCEs 0x30); attribute OIDs without a known short label are shown
    in dotted-decimal form.
    """
    parts: list[str] = []
    pos = start
    end = start + length
    # Short labels for the common distinguished-name attribute OIDs.
    oid_names = {
        "2.5.4.3": "CN",
        "2.5.4.6": "C",
        "2.5.4.7": "L",
        "2.5.4.8": "ST",
        "2.5.4.10": "O",
        "2.5.4.11": "OU",
    }
    while pos < end:
        set_tag, set_start, set_len = _der_read_tag_len(data, pos)
        if set_tag != 0x31:
            break
        set_end = set_start + set_len
        attr_pos = set_start
        while attr_pos < set_end:
            seq_tag, seq_start, seq_len = _der_read_tag_len(data, attr_pos)
            if seq_tag != 0x30:
                break
            oid_tag, oid_start, oid_len = _der_read_tag_len(data, seq_start)
            if oid_tag == 0x06:
                oid = _der_read_oid(data, oid_start, oid_len)
                # The attribute value immediately follows the OID element.
                val_tag, val_start, val_len = _der_read_tag_len(data, oid_start + oid_len)
                val = data[val_start: val_start + val_len].decode("utf-8", errors="replace")
                name = oid_names.get(oid, oid)
                parts.append(f"{name}={val}")
            attr_pos = seq_start + seq_len
        pos = set_end
    return ", ".join(parts)
def _parse_x509_der(cert_der: bytes) -> dict[str, Any] | None:
    """Extract subject/issuer/validity/SAN fields from one DER certificate.

    Walks TBSCertificate field-by-field with the minimal DER reader above.
    Returns None on any parse error (malformed or unexpected encoding).
    """
    try:
        # Certificate ::= SEQUENCE { tbsCertificate, signatureAlgorithm, signature }
        outer_start, outer_len = _der_read_sequence(cert_der, 0)
        tbs_tag, tbs_start, tbs_len = _der_read_tag_len(cert_der, outer_start)
        tbs_end = tbs_start + tbs_len
        pos = tbs_start
        # Optional [0] EXPLICIT version field.
        if cert_der[pos] == 0xA0:
            _, v_start, v_len = _der_read_tag_len(cert_der, pos)
            pos = v_start + v_len
        # serialNumber — skipped.
        _, sn_start, sn_len = _der_read_tag_len(cert_der, pos)
        pos = sn_start + sn_len
        # signature AlgorithmIdentifier — skipped.
        _, sa_start, sa_len = _der_read_tag_len(cert_der, pos)
        pos = sa_start + sa_len
        # issuer Name.
        issuer_tag, issuer_start, issuer_len = _der_read_tag_len(cert_der, pos)
        issuer_str = _der_extract_name_str(cert_der, issuer_start, issuer_len)
        issuer_cn = _der_extract_cn(cert_der, issuer_start, issuer_len)
        pos = issuer_start + issuer_len
        # validity ::= SEQUENCE { notBefore, notAfter } — raw Time strings kept.
        val_tag, val_start, val_len = _der_read_tag_len(cert_der, pos)
        nb_tag, nb_start, nb_len = _der_read_tag_len(cert_der, val_start)
        not_before = cert_der[nb_start: nb_start + nb_len].decode("ascii", errors="replace")
        na_tag, na_start, na_len = _der_read_tag_len(cert_der, nb_start + nb_len)
        not_after = cert_der[na_start: na_start + na_len].decode("ascii", errors="replace")
        pos = val_start + val_len
        # subject Name.
        subj_tag, subj_start, subj_len = _der_read_tag_len(cert_der, pos)
        subject_cn = _der_extract_cn(cert_der, subj_start, subj_len)
        subject_str = _der_extract_name_str(cert_der, subj_start, subj_len)
        # Heuristic: matching non-empty CNs is treated as self-signed.
        self_signed = (issuer_cn == subject_cn) and subject_cn != ""
        pos = subj_start + subj_len
        sans: list[str] = _extract_sans(cert_der, pos, tbs_end)
        return {
            "subject_cn": subject_cn,
            "subject": subject_str,
            "issuer": issuer_str,
            "issuer_cn": issuer_cn,
            "not_before": not_before,
            "not_after": not_after,
            "self_signed": self_signed,
            "sans": sans,
        }
    except Exception:
        return None
def _extract_sans(cert_der: bytes, pos: int, end: int) -> list[str]:
    """Locate the subjectAltName extension (OID 2.5.29.17) in TBSCertificate.

    *pos* points just past the subject Name; *end* is the end of the
    TBSCertificate. subjectPublicKeyInfo and any unique-ID fields are
    skipped until the [3] extensions wrapper (tag 0xA3) is found. Returns
    [] on any parse problem.
    """
    sans: list[str] = []
    try:
        if pos >= end:
            return sans
        # subjectPublicKeyInfo — skipped wholesale.
        spki_tag, spki_start, spki_len = _der_read_tag_len(cert_der, pos)
        pos = spki_start + spki_len
        while pos < end:
            tag = cert_der[pos]
            if tag == 0xA3:
                # [3] EXPLICIT Extensions ::= SEQUENCE OF Extension
                _, ext_wrap_start, ext_wrap_len = _der_read_tag_len(cert_der, pos)
                _, exts_start, exts_len = _der_read_tag_len(cert_der, ext_wrap_start)
                epos = exts_start
                eend = exts_start + exts_len
                while epos < eend:
                    ext_tag, ext_start, ext_len = _der_read_tag_len(cert_der, epos)
                    ext_end = ext_start + ext_len
                    oid_tag, oid_start, oid_len = _der_read_tag_len(cert_der, ext_start)
                    if oid_tag == 0x06:
                        oid = _der_read_oid(cert_der, oid_start, oid_len)
                        if oid == "2.5.29.17":
                            vpos = oid_start + oid_len
                            # Optional BOOLEAN `critical` precedes the value.
                            if vpos < ext_end and cert_der[vpos] == 0x01:
                                _, bs, bl = _der_read_tag_len(cert_der, vpos)
                                vpos = bs + bl
                            if vpos < ext_end:
                                # extnValue OCTET STRING wraps the GeneralNames.
                                os_tag, os_start, os_len = _der_read_tag_len(cert_der, vpos)
                                if os_tag == 0x04:
                                    sans = _parse_san_sequence(cert_der, os_start, os_len)
                    epos = ext_end
                # Extensions are the last TBSCertificate field; stop scanning.
                break
            else:
                _, skip_start, skip_len = _der_read_tag_len(cert_der, pos)
                pos = skip_start + skip_len
    except Exception:
        pass
    return sans
def _parse_san_sequence(data: bytes, start: int, length: int) -> list[str]:
    """Decode a GeneralNames SEQUENCE into dNSName strings and dotted IPv4s.

    Context tag 2 is a dNSName; tag 7 an iPAddress (only 4-byte IPv4 is
    rendered). Other entry types are skipped. On error, whatever was
    collected so far is returned.
    """
    collected: list[str] = []
    try:
        _, seq_start, seq_len = _der_read_tag_len(data, start)
        cursor = seq_start
        stop = seq_start + seq_len
        while cursor < stop:
            raw_tag = data[cursor]
            _, val_start, val_len = _der_read_tag_len(data, cursor)
            value = data[val_start: val_start + val_len]
            kind = raw_tag & 0x1F
            if kind == 2:  # dNSName
                collected.append(value.decode("ascii", errors="replace"))
            elif kind == 7 and val_len == 4:  # iPAddress, IPv4 only
                collected.append(".".join(str(octet) for octet in value))
            cursor = val_start + val_len
    except Exception:
        pass
    return collected
# ─── JA3 / JA3S ─────────────────────────────────────────────────────────────
def _tls_version_str(version: int) -> str:
return {
0x0301: "TLS 1.0",
0x0302: "TLS 1.1",
0x0303: "TLS 1.2",
0x0304: "TLS 1.3",
0x0200: "SSL 2.0",
0x0300: "SSL 3.0",
}.get(version, f"0x{version:04x}")
def _ja3(ch: dict[str, Any]) -> tuple[str, str]:
parts = [
str(ch["tls_version"]),
"-".join(str(c) for c in ch["cipher_suites"]),
"-".join(str(e) for e in ch["extensions"]),
"-".join(str(g) for g in ch["supported_groups"]),
"-".join(str(p) for p in ch["ec_point_formats"]),
]
ja3_str = ",".join(parts)
return ja3_str, hashlib.md5(ja3_str.encode()).hexdigest() # nosec B324
def _ja3s(sh: dict[str, Any]) -> tuple[str, str]:
parts = [
str(sh["tls_version"]),
str(sh["cipher_suite"]),
"-".join(str(e) for e in sh["extensions"]),
]
ja3s_str = ",".join(parts)
return ja3s_str, hashlib.md5(ja3s_str.encode()).hexdigest() # nosec B324
# ─── JA4 / JA4S ─────────────────────────────────────────────────────────────
def _ja4_version(ch: dict[str, Any]) -> str:
versions = ch.get("supported_versions", [])
if versions:
best = max(versions)
else:
best = ch["tls_version"]
return {
0x0304: "13",
0x0303: "12",
0x0302: "11",
0x0301: "10",
0x0300: "s3",
0x0200: "s2",
}.get(best, "00")
def _ja4_alpn_tag(alpn_list: list[str] | str) -> str:
if isinstance(alpn_list, str):
proto = alpn_list
elif alpn_list:
proto = alpn_list[0]
else:
return "00"
if not proto:
return "00"
if len(proto) == 1:
return proto[0] + proto[0]
return proto[0] + proto[-1]
def _sha256_12(text: str) -> str:
return hashlib.sha256(text.encode()).hexdigest()[:12]
def _ja4(ch: dict[str, Any]) -> str:
    """Compute the JA4 client fingerprint ("a_b_c") from a parsed ClientHello.

    Section a: transport + version + SNI flag + counts + ALPN tag;
    section b: truncated hash of sorted ciphers; section c: truncated hash
    of sorted extensions (plus sorted signature algorithms when present).
    """
    header = "".join([
        "t",  # transport: TCP
        _ja4_version(ch),
        "d" if ch.get("sni") else "i",  # domain (SNI present) vs IP
        f"{min(len(ch['cipher_suites']), 99):02d}",
        f"{min(len(ch['extensions']), 99):02d}",
        _ja4_alpn_tag(ch.get("alpn", [])),
    ])
    cipher_hash = _sha256_12(",".join(str(c) for c in sorted(ch["cipher_suites"])))
    ext_part = ",".join(str(e) for e in sorted(ch["extensions"]))
    sig_part = ",".join(str(s) for s in sorted(ch.get("signature_algorithms", [])))
    tail = f"{ext_part}_{sig_part}" if sig_part else ext_part
    return f"{header}_{cipher_hash}_{_sha256_12(tail)}"
def _ja4s(sh: dict[str, Any]) -> str:
    """Compute the JA4S server fingerprint ("a_b") from a parsed ServerHello.

    Prefers the TLS 1.3 supported_versions selection over the legacy
    record version when deciding the version code.
    """
    codes = {
        0x0304: "13",
        0x0303: "12",
        0x0302: "11",
        0x0301: "10",
        0x0300: "s3",
        0x0200: "s2",
    }
    version_word = sh.get("selected_version") or sh["tls_version"]
    ver = codes.get(version_word, "00")
    count = min(len(sh["extensions"]), 99)
    section_a = f"t{ver}{count:02d}{_ja4_alpn_tag(sh.get('alpn', ''))}"
    ext_csv = ",".join(str(e) for e in sorted(sh["extensions"]))
    section_b = _sha256_12(f"{sh['cipher_suite']},{ext_csv}")
    return f"{section_a}_{section_b}"
# ─── JA4L (latency) ─────────────────────────────────────────────────────────
def _ja4l(
key: tuple[str, int, str, int],
tcp_rtt: dict[tuple[str, int, str, int], dict[str, Any]],
) -> dict[str, Any] | None:
return tcp_rtt.get(key)
# ─── Session resumption ─────────────────────────────────────────────────────
def _session_resumption_info(ch: dict[str, Any]) -> dict[str, Any]:
mechanisms: list[str] = []
if ch.get("has_session_ticket_data"):
mechanisms.append("session_ticket")
if ch.get("has_pre_shared_key"):
mechanisms.append("psk")
if ch.get("has_early_data"):
mechanisms.append("early_data_0rtt")
if ch.get("session_id") and len(ch["session_id"]) > 0:
mechanisms.append("session_id")
return {
"resumption_attempted": len(mechanisms) > 0,
"mechanisms": mechanisms,
}
# ─── Sniffer engine (stateful, one instance per worker) ─────────────────────
class SnifferEngine:
"""
Stateful TLS fingerprinting engine. Tracks sessions, TCP RTTs,
and dedup state. Thread-safe only when called from a single thread
(the scapy sniff thread).
"""
def __init__(
    self,
    ip_to_decky: dict[str, str],
    write_fn: Callable[[str], None],
    dedup_ttl: float = 300.0,
):
    """Bind the engine to an IP→decky map and a syslog-line sink.

    ip_to_decky: maps decky IP addresses to decky names.
    write_fn: called once per finished syslog line.
    dedup_ttl: per-(source, event, fingerprint) suppression window in
        seconds; values <= 0 disable dedup (see _is_duplicate).
    """
    self._ip_to_decky = ip_to_decky
    self._write_fn = write_fn
    self._dedup_ttl = dedup_ttl
    # ClientHello data kept until the matching ServerHello arrives, keyed
    # by the client→server 4-tuple; timestamps age the entries out.
    self._sessions: dict[tuple[str, int, str, int], dict[str, Any]] = {}
    self._session_ts: dict[tuple[str, int, str, int], float] = {}
    # SYN arrival info and the SYN→SYN/ACK RTT derived from it (JA4L).
    self._tcp_syn: dict[tuple[str, int, str, int], dict[str, Any]] = {}
    self._tcp_rtt: dict[tuple[str, int, str, int], dict[str, Any]] = {}
    # Per-flow timing aggregator. Key: (src_ip, src_port, dst_ip, dst_port).
    # Flow direction is client→decky; reverse packets are associated back
    # to the forward flow so we can track retransmits and inter-arrival.
    self._flows: dict[tuple[str, int, str, int], dict[str, Any]] = {}
    self._flow_last_cleanup: float = 0.0
    self._FLOW_CLEANUP_INTERVAL: float = 30.0
    # (src_ip, event_type, fingerprint) → monotonic time of last emission.
    self._dedup_cache: dict[tuple[str, str, str], float] = {}
    self._dedup_last_cleanup: float = 0.0
    self._DEDUP_CLEANUP_INTERVAL: float = 60.0
def update_ip_map(self, ip_to_decky: dict[str, str]) -> None:
    """Replace the IP→decky mapping (fleet membership changed)."""
    self._ip_to_decky = ip_to_decky
def _resolve_decky(self, src_ip: str, dst_ip: str) -> str | None:
    """Map a packet to a decky name; None when neither IP is a known decky.

    Destination takes precedence (inbound attacker traffic).
    """
    mapping = self._ip_to_decky
    for candidate in (dst_ip, src_ip):
        if candidate in mapping:
            return mapping[candidate]
    return None
def _cleanup_sessions(self) -> None:
    """Drop TLS session, SYN, and RTT state older than _SESSION_TTL.

    Called from the packet path (on each ClientHello), so it must stay
    cheap. RTT records carry no timestamp of their own; one is considered
    stale once its session bookkeeping is gone.
    """
    now = time.monotonic()
    stale = [k for k, ts in self._session_ts.items() if now - ts > _SESSION_TTL]
    for k in stale:
        self._sessions.pop(k, None)
        self._session_ts.pop(k, None)
    stale_syn = [k for k, v in self._tcp_syn.items()
                 if now - v.get("time", 0) > _SESSION_TTL]
    for k in stale_syn:
        self._tcp_syn.pop(k, None)
    # Iterate keys directly — the RTT values themselves are never inspected.
    stale_rtt = [k for k in self._tcp_rtt
                 if k not in self._sessions and k not in self._session_ts]
    for k in stale_rtt:
        self._tcp_rtt.pop(k, None)
def _dedup_key_for(self, event_type: str, fields: dict[str, Any]) -> str:
    """Build the fingerprint half of the dedup cache key for one event."""
    get = fields.get
    if event_type == "tls_client_hello":
        return f"{get('ja3', '')}|{get('ja4', '')}"
    if event_type == "tls_session":
        return f"{get('ja3', '')}|{get('ja3s', '')}|{get('ja4', '')}|{get('ja4s', '')}"
    if event_type == "tls_certificate":
        return f"{get('subject_cn', '')}|{get('issuer', '')}"
    if event_type == "tcp_syn_fingerprint":
        # Dedupe per (OS signature, options layout). One event per unique
        # stack profile from this attacker IP per dedup window.
        return f"{get('os_guess', '')}|{get('options_sig', '')}"
    if event_type == "tcp_flow_timing":
        # Dedup per (attacker_ip, decky_port) — src_port is deliberately
        # excluded so a port scanner rotating source ports only produces
        # one timing event per dedup window. Behavior cadence doesn't
        # need per-ephemeral-port fidelity.
        return f"{get('dst_ip', '')}|{get('dst_port', '')}"
    return get("mechanisms", get("resumption", ""))
def _is_duplicate(self, event_type: str, fields: dict[str, Any]) -> bool:
    """True if an identical (source, type, fingerprint) event fired within
    the dedup TTL. Records the event's timestamp when it is NOT a duplicate.
    """
    if self._dedup_ttl <= 0:
        return False
    now = time.monotonic()
    # Opportunistic cache expiry, amortized to once per cleanup interval.
    if now - self._dedup_last_cleanup > self._DEDUP_CLEANUP_INTERVAL:
        expired = [key for key, ts in self._dedup_cache.items()
                   if now - ts > self._dedup_ttl]
        for key in expired:
            del self._dedup_cache[key]
        self._dedup_last_cleanup = now
    cache_key = (
        fields.get("src_ip", ""),
        event_type,
        self._dedup_key_for(event_type, fields),
    )
    previous = self._dedup_cache.get(cache_key)
    if previous is not None and now - previous < self._dedup_ttl:
        return True
    self._dedup_cache[cache_key] = now
    return False
def _log(self, node_name: str, event_type: str, severity: int = SEVERITY_INFO, **fields: Any) -> None:
    """Format one syslog event and write it, unless dedup suppresses it."""
    if self._is_duplicate(event_type, fields):
        return
    self._write_fn(
        syslog_line(SERVICE_NAME, node_name, event_type, severity=severity, **fields)
    )
# ── Flow tracking (per-TCP-4-tuple timing + retransmits) ────────────────
def _flow_key(
    self,
    src_ip: str,
    src_port: int,
    dst_ip: str,
    dst_port: int,
) -> tuple[str, int, str, int]:
    """Canonicalize a packet to the *client→decky* direction so forward and
    reverse packets share one flow record.
    """
    # Forward when the destination is a decky; otherwise flip endpoints.
    forward = dst_ip in self._ip_to_decky
    if forward:
        return (src_ip, src_port, dst_ip, dst_port)
    return (dst_ip, dst_port, src_ip, src_port)
def _update_flow(
    self,
    flow_key: tuple[str, int, str, int],
    now: float,
    seq: int,
    payload_len: int,
    direction_forward: bool,
) -> None:
    """Fold one packet into the per-flow timing/retransmit aggregator."""
    flow = self._flows.get(flow_key)
    if flow is None:
        flow = self._flows[flow_key] = {
            "start": now,
            "last": now,
            "packets": 0,
            "bytes": 0,
            "iat_sum": 0.0,
            "iat_min": float("inf"),
            "iat_max": 0.0,
            "iat_count": 0,
            "forward_seqs": set(),
            "retransmits": 0,
            "emitted": False,
        }
    # Inter-arrival gap versus the previous packet in either direction.
    if flow["packets"] > 0:
        gap = now - flow["last"]
        if gap >= 0:
            flow["iat_sum"] += gap
            flow["iat_count"] += 1
            flow["iat_min"] = min(flow["iat_min"], gap)
            flow["iat_max"] = max(flow["iat_max"], gap)
    flow["last"] = now
    flow["packets"] += 1
    flow["bytes"] += payload_len
    # Retransmit detection: a forward-direction packet with payload whose
    # sequence number we've already seen is a retransmit. Empty SYN/ACKs
    # are excluded because they share seq legitimately.
    if direction_forward and payload_len > 0:
        seen = flow["forward_seqs"]
        if seq in seen:
            flow["retransmits"] += 1
        else:
            seen.add(seq)
def _flush_flow(
    self,
    flow_key: tuple[str, int, str, int],
    node_name: str,
) -> None:
    """Emit one `tcp_flow_timing` event for *flow_key* and drop its state.

    Trivial flows (short scan probes: fewer than 4 packets, zero
    retransmits, under one second long) are dropped silently — they add
    noise to the log pipeline without carrying usable behavioral signal
    (beacon cadence, exfil timing, retransmits are all meaningful only on
    longer-lived flows).
    """
    flow = self._flows.pop(flow_key, None)
    if flow is None or flow.get("emitted"):
        return
    # The record was just popped; this flag only guards callers that still
    # hold a reference to it against double emission.
    flow["emitted"] = True
    # Skip uninteresting flows — keep the log pipeline from being flooded
    # by short-lived scan probes.
    duration = flow["last"] - flow["start"]
    if flow["packets"] < 4 and flow["retransmits"] == 0 and duration < 1.0:
        return
    src_ip, src_port, dst_ip, dst_port = flow_key
    iat_count = flow["iat_count"]
    # A single-packet flow has no inter-arrival samples; report zeros.
    mean_iat_ms = round((flow["iat_sum"] / iat_count) * 1000, 2) if iat_count else 0.0
    min_iat_ms = round(flow["iat_min"] * 1000, 2) if iat_count else 0.0
    max_iat_ms = round(flow["iat_max"] * 1000, 2) if iat_count else 0.0
    duration_s = round(duration, 3)
    self._log(
        node_name,
        "tcp_flow_timing",
        src_ip=src_ip,
        src_port=str(src_port),
        dst_ip=dst_ip,
        dst_port=str(dst_port),
        packets=str(flow["packets"]),
        bytes=str(flow["bytes"]),
        duration_s=str(duration_s),
        mean_iat_ms=str(mean_iat_ms),
        min_iat_ms=str(min_iat_ms),
        max_iat_ms=str(max_iat_ms),
        retransmits=str(flow["retransmits"]),
    )
def flush_all_flows(self) -> None:
    """Flush every tracked flow (emit `tcp_flow_timing` events) and drop
    state. Safe to call from outside the sniff thread; used during
    shutdown and in tests.
    """
    for flow_key in list(self._flows):
        owner = self._ip_to_decky.get(flow_key[2])
        if owner:
            self._flush_flow(flow_key, owner)
        else:
            # Destination no longer maps to a decky; just drop the state.
            self._flows.pop(flow_key, None)
def _flush_idle_flows(self) -> None:
    """Flush flows idle longer than _FLOW_IDLE_TIMEOUT.

    Rate-limited to once per _FLOW_CLEANUP_INTERVAL since it runs on the
    packet path.
    """
    now = time.monotonic()
    if now - self._flow_last_cleanup < self._FLOW_CLEANUP_INTERVAL:
        return
    self._flow_last_cleanup = now
    idle = [key for key, flow in self._flows.items()
            if now - flow["last"] > _FLOW_IDLE_TIMEOUT]
    for key in idle:
        owner = self._ip_to_decky.get(key[2])
        if owner:
            self._flush_flow(key, owner)
        else:
            # Destination no longer maps to a decky; just drop the state.
            self._flows.pop(key, None)
def on_packet(self, pkt: Any) -> None:
    """Process a single scapy packet. Called from the sniff thread.

    Per-packet pipeline: flow timing aggregation → SYN / SYN-ACK handling
    (passive OS fingerprint and RTT capture) → FIN/RST flow flush → TLS
    handshake parsing (ClientHello / ServerHello / Certificate events).
    """
    # Imported lazily so the module stays importable without scapy.
    try:
        from scapy.layers.inet import IP, TCP
    except ImportError:
        return
    if not (pkt.haslayer(IP) and pkt.haslayer(TCP)):
        return
    ip = pkt[IP]
    tcp = pkt[TCP]
    src_ip: str = ip.src
    dst_ip: str = ip.dst
    src_port: int = tcp.sport
    dst_port: int = tcp.dport
    # scapy may expose flags as a FlagValue object or a plain int.
    flags: int = tcp.flags.value if hasattr(tcp.flags, 'value') else int(tcp.flags)
    # Skip traffic not involving any decky
    node_name = self._resolve_decky(src_ip, dst_ip)
    if node_name is None:
        return
    now = time.monotonic()
    # Per-flow timing aggregation (covers all TCP traffic, not just TLS)
    flow_key = self._flow_key(src_ip, src_port, dst_ip, dst_port)
    direction_forward = (flow_key[0] == src_ip and flow_key[1] == src_port)
    tcp_payload_len = len(bytes(tcp.payload))
    self._update_flow(
        flow_key,
        now=now,
        seq=int(tcp.seq),
        payload_len=tcp_payload_len,
        direction_forward=direction_forward,
    )
    self._flush_idle_flows()
    # TCP SYN tracking for JA4L + passive SYN fingerprint
    if flags & _TCP_SYN and not (flags & _TCP_ACK):
        key = (src_ip, src_port, dst_ip, dst_port)
        self._tcp_syn[key] = {"time": now, "ttl": ip.ttl}
        # Emit passive OS fingerprint on the *client* SYN. Only do this
        # when the destination is a known decky, i.e. we're seeing an
        # attacker's initial packet.
        if dst_ip in self._ip_to_decky:
            tcp_fp = _extract_tcp_fingerprint(list(tcp.options or []))
            os_label = guess_os(
                ttl=ip.ttl,
                window=int(tcp.window),
                mss=tcp_fp["mss"],
                wscale=tcp_fp["wscale"],
                options_sig=tcp_fp["options_sig"],
            )
            target_node = self._ip_to_decky[dst_ip]
            self._log(
                target_node,
                "tcp_syn_fingerprint",
                src_ip=src_ip,
                src_port=str(src_port),
                dst_ip=dst_ip,
                dst_port=str(dst_port),
                ttl=str(ip.ttl),
                initial_ttl=str(initial_ttl(ip.ttl)),
                hop_distance=str(hop_distance(ip.ttl)),
                window=str(int(tcp.window)),
                mss=str(tcp_fp["mss"]),
                wscale=("" if tcp_fp["wscale"] is None else str(tcp_fp["wscale"])),
                options_sig=tcp_fp["options_sig"],
                has_sack=str(tcp_fp["sack_ok"]).lower(),
                has_timestamps=str(tcp_fp["has_timestamps"]).lower(),
                os_guess=os_label,
            )
    elif flags & _TCP_SYN and flags & _TCP_ACK:
        # SYN/ACK from the decky: pair it with the recorded client SYN to
        # derive the handshake RTT consumed by JA4L.
        rev_key = (dst_ip, dst_port, src_ip, src_port)
        syn_data = self._tcp_syn.pop(rev_key, None)
        if syn_data:
            rtt_ms = round((now - syn_data["time"]) * 1000, 2)
            self._tcp_rtt[rev_key] = {
                "rtt_ms": rtt_ms,
                "client_ttl": syn_data["ttl"],
            }
    # Flush flow on FIN/RST (terminal packets).
    if flags & (_TCP_FIN | _TCP_RST):
        decky = self._ip_to_decky.get(flow_key[2])
        if decky:
            self._flush_flow(flow_key, decky)
    payload = bytes(tcp.payload)
    if not payload:
        return
    if payload[0] != _TLS_RECORD_HANDSHAKE:
        return
    # ClientHello
    ch = _parse_client_hello(payload)
    if ch is not None:
        self._cleanup_sessions()
        key = (src_ip, src_port, dst_ip, dst_port)
        ja3_str, ja3_hash = _ja3(ch)
        ja4_hash = _ja4(ch)
        resumption = _session_resumption_info(ch)
        rtt_data = _ja4l(key, self._tcp_rtt)
        # Stash ClientHello context so the ServerHello / Certificate
        # handlers can correlate client and server fingerprints.
        self._sessions[key] = {
            "ja3": ja3_hash,
            "ja3_str": ja3_str,
            "ja4": ja4_hash,
            "tls_version": ch["tls_version"],
            "cipher_suites": ch["cipher_suites"],
            "extensions": ch["extensions"],
            "signature_algorithms": ch.get("signature_algorithms", []),
            "supported_versions": ch.get("supported_versions", []),
            "sni": ch["sni"],
            "alpn": ch["alpn"],
            "resumption": resumption,
        }
        self._session_ts[key] = time.monotonic()
        log_fields: dict[str, Any] = {
            "src_ip": src_ip,
            "src_port": str(src_port),
            "dst_ip": dst_ip,
            "dst_port": str(dst_port),
            "ja3": ja3_hash,
            "ja4": ja4_hash,
            "tls_version": _tls_version_str(ch["tls_version"]),
            "sni": ch["sni"] or "",
            "alpn": ",".join(ch["alpn"]),
            "raw_ciphers": "-".join(str(c) for c in ch["cipher_suites"]),
            "raw_extensions": "-".join(str(e) for e in ch["extensions"]),
        }
        if resumption["resumption_attempted"]:
            log_fields["resumption"] = ",".join(resumption["mechanisms"])
        if rtt_data:
            log_fields["ja4l_rtt_ms"] = str(rtt_data["rtt_ms"])
            log_fields["ja4l_client_ttl"] = str(rtt_data["client_ttl"])
        # Resolve node for the *destination* (the decky being attacked)
        target_node = self._ip_to_decky.get(dst_ip, node_name)
        self._log(target_node, "tls_client_hello", **log_fields)
        return
    # ServerHello
    sh = _parse_server_hello(payload)
    if sh is not None:
        # Correlate with the stored ClientHello of the reverse direction.
        rev_key = (dst_ip, dst_port, src_ip, src_port)
        ch_data = self._sessions.pop(rev_key, None)
        self._session_ts.pop(rev_key, None)
        ja3s_str, ja3s_hash = _ja3s(sh)
        ja4s_hash = _ja4s(sh)
        # src/dst swapped so the event reads client → server.
        fields: dict[str, Any] = {
            "src_ip": dst_ip,
            "src_port": str(dst_port),
            "dst_ip": src_ip,
            "dst_port": str(src_port),
            "ja3s": ja3s_hash,
            "ja4s": ja4s_hash,
            "tls_version": _tls_version_str(sh["tls_version"]),
        }
        if ch_data:
            fields["ja3"] = ch_data["ja3"]
            fields["ja4"] = ch_data.get("ja4", "")
            fields["sni"] = ch_data["sni"] or ""
            fields["alpn"] = ",".join(ch_data["alpn"])
            fields["raw_ciphers"] = "-".join(str(c) for c in ch_data["cipher_suites"])
            fields["raw_extensions"] = "-".join(str(e) for e in ch_data["extensions"])
            if ch_data.get("resumption", {}).get("resumption_attempted"):
                fields["resumption"] = ",".join(ch_data["resumption"]["mechanisms"])
        rtt_data = self._tcp_rtt.pop(rev_key, None)
        if rtt_data:
            fields["ja4l_rtt_ms"] = str(rtt_data["rtt_ms"])
            fields["ja4l_client_ttl"] = str(rtt_data["client_ttl"])
        # Server response — resolve by src_ip (the decky responding)
        target_node = self._ip_to_decky.get(src_ip, node_name)
        self._log(target_node, "tls_session", severity=SEVERITY_WARNING, **fields)
        return
    # Certificate (TLS 1.2 only)
    cert = _parse_certificate(payload)
    if cert is not None:
        rev_key = (dst_ip, dst_port, src_ip, src_port)
        # ClientHello context, if still present (the ServerHello handler
        # pops it when it runs first).
        ch_data = self._sessions.get(rev_key)
        cert_fields: dict[str, Any] = {
            "src_ip": dst_ip,
            "src_port": str(dst_port),
            "dst_ip": src_ip,
            "dst_port": str(src_port),
            "subject_cn": cert["subject_cn"],
            "issuer": cert["issuer"],
            "self_signed": str(cert["self_signed"]).lower(),
            "not_before": cert["not_before"],
            "not_after": cert["not_after"],
        }
        if cert["sans"]:
            cert_fields["sans"] = ",".join(cert["sans"])
        if ch_data:
            cert_fields["sni"] = ch_data.get("sni", "")
        target_node = self._ip_to_decky.get(src_ip, node_name)
        self._log(target_node, "tls_certificate", **cert_fields)