def make_thread_safe_publisher(
    bus: BaseBus | None,
    loop: asyncio.AbstractEventLoop,
) -> Callable[[str, dict[str, Any], str], None]:
    """Build a sync callable that marshals publishes back to *loop*.

    Workers that run their hot paths in a worker thread (scapy sniff loop,
    ``asyncio.to_thread`` probes, blocking socket reads) cannot ``await``
    the bus directly. This helper returns a plain function that schedules
    the publish on *loop* via ``run_coroutine_threadsafe`` and returns
    immediately — the calling thread is never blocked on the publish.

    A ``None`` bus yields a no-op callable, matching the degraded-mode
    contract the rest of this module already upholds.

    The returned callable has the shape ``(topic, payload, event_type="")``
    in both modes, so positional and keyword calls behave the same whether
    or not a bus is available. Scheduling failures (for example *loop*
    having been closed during shutdown) are logged at debug level and
    swallowed; the pending coroutine is closed so it is not leaked as a
    "coroutine ... was never awaited" warning.
    """
    if bus is None:
        # Same parameter names as the real publisher: a caller that passes
        # ``event_type=...`` by keyword must not start failing just because
        # the bus happens to be unavailable.
        def _noop(topic: str, payload: dict[str, Any], event_type: str = "") -> None:
            return None

        return _noop

    def _publish(topic: str, payload: dict[str, Any], event_type: str = "") -> None:
        coro = None
        try:
            coro = publish_safely(bus, topic, payload, event_type=event_type)
            asyncio.run_coroutine_threadsafe(coro, loop)
        except Exception as exc:  # noqa: BLE001
            # run_coroutine_threadsafe raises before the coroutine is handed
            # to the loop (e.g. RuntimeError on a closed loop), so nothing
            # will ever await it — close it explicitly.
            if asyncio.iscoroutine(coro):
                coro.close()
            log.debug("cross-thread bus publish failed topic=%s: %s", topic, exc)

    return _publish
from __future__ import annotations import asyncio +import contextlib import json import re from datetime import datetime, timezone from pathlib import Path -from typing import Any +from typing import Any, Callable +from decnet.bus import topics as _topics +from decnet.bus.base import BaseBus +from decnet.bus.factory import get_bus +from decnet.bus.publish import make_thread_safe_publisher from decnet.logging import get_logger from decnet.prober.hassh import hassh_server from decnet.prober.jarm import JARM_EMPTY_HASH, jarm_hash @@ -221,6 +226,9 @@ def _discover_attackers(json_path: Path, position: int) -> tuple[set[str], int]: # ─── Probe cycle ───────────────────────────────────────────────────────────── +ProbePublishFn = Callable[[str, dict[str, Any]], None] + + @_traced("prober.probe_cycle") def _probe_cycle( targets: set[str], @@ -231,6 +239,7 @@ def _probe_cycle( log_path: Path, json_path: Path, timeout: float = 5.0, + publish_fn: ProbePublishFn | None = None, ) -> None: """ Probe all known attacker IPs with JARM, HASSH, and TCP/IP fingerprinting. 
@@ -249,13 +258,13 @@ def _probe_cycle( ip_probed = probed.setdefault(ip, {}) # Phase 1: JARM (TLS fingerprinting) - _jarm_phase(ip, ip_probed, jarm_ports, log_path, json_path, timeout) + _jarm_phase(ip, ip_probed, jarm_ports, log_path, json_path, timeout, publish_fn) # Phase 2: HASSHServer (SSH fingerprinting) - _hassh_phase(ip, ip_probed, ssh_ports, log_path, json_path, timeout) + _hassh_phase(ip, ip_probed, ssh_ports, log_path, json_path, timeout, publish_fn) # Phase 3: TCP/IP stack fingerprinting - _tcpfp_phase(ip, ip_probed, tcpfp_ports, log_path, json_path, timeout) + _tcpfp_phase(ip, ip_probed, tcpfp_ports, log_path, json_path, timeout, publish_fn) @_traced("prober.jarm_phase") @@ -266,6 +275,7 @@ def _jarm_phase( log_path: Path, json_path: Path, timeout: float, + publish_fn: ProbePublishFn | None = None, ) -> None: """JARM-fingerprint an IP on the given TLS ports.""" done = ip_probed.setdefault("jarm", set()) @@ -286,6 +296,11 @@ def _jarm_phase( msg=f"JARM {ip}:{port} = {h}", ) logger.info("prober: JARM %s:%d = %s", ip, port, h) + if publish_fn is not None: + publish_fn( + "jarm", + {"attacker_ip": ip, "port": port, "jarm_hash": h}, + ) except Exception as exc: done.add(port) _write_event( @@ -308,6 +323,7 @@ def _hassh_phase( log_path: Path, json_path: Path, timeout: float, + publish_fn: ProbePublishFn | None = None, ) -> None: """HASSHServer-fingerprint an IP on the given SSH ports.""" done = ip_probed.setdefault("hassh", set()) @@ -333,6 +349,16 @@ def _hassh_phase( msg=f"HASSH {ip}:{port} = {result['hassh_server']}", ) logger.info("prober: HASSH %s:%d = %s", ip, port, result["hassh_server"]) + if publish_fn is not None: + publish_fn( + "hassh", + { + "attacker_ip": ip, + "port": port, + "hassh_server": result["hassh_server"], + "ssh_banner": result["banner"], + }, + ) except Exception as exc: done.add(port) _write_event( @@ -355,6 +381,7 @@ def _tcpfp_phase( log_path: Path, json_path: Path, timeout: float, + publish_fn: ProbePublishFn | None = None, ) 
-> None: """TCP/IP stack fingerprint an IP on the given ports.""" done = ip_probed.setdefault("tcpfp", set()) @@ -384,6 +411,17 @@ def _tcpfp_phase( msg=f"TCPFP {ip}:{port} = {result['tcpfp_hash']}", ) logger.info("prober: TCPFP %s:%d = %s", ip, port, result["tcpfp_hash"]) + if publish_fn is not None: + publish_fn( + "tcpfp", + { + "attacker_ip": ip, + "port": port, + "tcpfp_hash": result["tcpfp_hash"], + "ttl": result["ttl"], + "mss": result["mss"], + }, + ) except Exception as exc: done.add(port) _write_event( @@ -454,25 +492,58 @@ async def prober_worker( probed: dict[str, dict[str, set[int]]] = {} # IP -> {type -> ports} log_position: int = 0 - while True: - # Discover new attacker IPs from the log stream - new_ips, log_position = await asyncio.to_thread( - _discover_attackers, json_path, log_position, + loop = asyncio.get_running_loop() + + # Connect to the bus for attacker.fingerprinted fan-out. Failure is + # non-fatal: probes still run, results still land in the log file, + # they just don't push notifications to downstream consumers. + bus: BaseBus | None = None + try: + candidate = get_bus(client_name="prober") + await candidate.connect() + bus = candidate + except Exception as exc: # noqa: BLE001 + logger.warning( + "prober: bus unavailable, running in publish-off mode: %s", exc, ) - if new_ips - known_attackers: - fresh = new_ips - known_attackers - known_attackers.update(fresh) - logger.info( - "prober: discovered %d new attacker(s), total=%d", - len(fresh), len(known_attackers), + raw_publish = make_thread_safe_publisher(bus, loop) + + def _publish_attacker(event_type: str, payload: dict[str, Any]) -> None: + # Every successful probe fans out under the same topic; the probe + # family (jarm/hassh/tcpfp) goes in event_type so consumers can + # filter in-memory without needing a dedicated subscription each. 
def _make_decky_traffic_publisher(
    bus: BaseBus,
    loop: asyncio.AbstractEventLoop,
) -> Callable[[str, str, dict[str, Any]], None]:
    """Adapt the generic cross-thread publisher to the decky-traffic topic.

    The returned callable accepts ``(decky_name, event_type, payload)``,
    builds the per-decky traffic topic from *decky_name*, and forwards the
    publish through :func:`make_thread_safe_publisher`, so the calling
    (sniff) thread only schedules the publish on *loop* and does not wait
    for it to complete.
    """
    send = make_thread_safe_publisher(bus, loop)

    def _publish(decky_name: str, event_type: str, payload: dict[str, Any]) -> None:
        # The generic publisher takes (topic, payload, event_type); the
        # engine-facing order is (decky_name, event_type, payload).
        send(
            _topics.decky(decky_name, _topics.DECKY_TRAFFIC),
            payload,
            event_type,
        )

    return _publish
"""Bus wiring for the attacker prober (DEBT-031, worker 2).

The prober fingerprints observed attackers (JARM / HASSH / TCPfp) in a
``to_thread`` worker. On each successful probe it publishes an
``attacker.fingerprinted`` event under the shared attacker root; the
probe family (jarm/hassh/tcpfp) goes in ``event.type`` so a single
subscription to ``attacker.fingerprinted`` covers all three.
"""
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Callable
from pathlib import Path
from typing import Any

import pytest
import pytest_asyncio

from decnet.bus import topics as _topics
from decnet.bus.fake import FakeBus
from decnet.bus.publish import make_thread_safe_publisher
from decnet.prober import worker as worker_mod
from decnet.prober.jarm import JARM_EMPTY_HASH
from decnet.prober.worker import _hassh_phase, _jarm_phase, _tcpfp_phase


@pytest_asyncio.fixture
async def bus() -> AsyncIterator[FakeBus]:
    """Yield a connected ``FakeBus`` and close it once the test is done."""
    b = FakeBus()
    await b.connect()
    yield b
    await b.close()


def _recorder() -> tuple[
    list[tuple[str, dict[str, Any]]],
    Callable[[str, dict[str, Any]], None],
]:
    """Return ``(captured, publish_fn)``; *publish_fn* appends each call to *captured*."""
    captured: list[tuple[str, dict[str, Any]]] = []

    def _record(event_type: str, payload: dict[str, Any]) -> None:
        captured.append((event_type, payload))

    return captured, _record


# ─── Phase-level publish hooks ───────────────────────────────────────────────

def test_jarm_phase_invokes_publish_fn_on_success(monkeypatch, tmp_path: Path) -> None:
    captured, record = _recorder()
    # Stub jarm_hash so the test doesn't touch the network.
    monkeypatch.setattr(worker_mod, "jarm_hash", lambda ip, port, timeout: "aabbcc")

    _jarm_phase(
        ip="203.0.113.9",
        ip_probed={},
        ports=[443],
        log_path=tmp_path / "p.log",
        json_path=tmp_path / "p.json",
        timeout=1.0,
        publish_fn=record,
    )

    assert captured == [
        ("jarm", {"attacker_ip": "203.0.113.9", "port": 443, "jarm_hash": "aabbcc"}),
    ]


def test_jarm_phase_skips_empty_hash(monkeypatch, tmp_path: Path) -> None:
    # JARM's empty-hash sentinel means "target didn't negotiate TLS" — not
    # an observation worth publishing.
    captured, record = _recorder()
    monkeypatch.setattr(worker_mod, "jarm_hash", lambda ip, port, timeout: JARM_EMPTY_HASH)

    _jarm_phase(
        ip="1.2.3.4", ip_probed={}, ports=[443],
        log_path=tmp_path / "p.log", json_path=tmp_path / "p.json", timeout=1.0,
        publish_fn=record,
    )
    assert captured == []


def test_hassh_phase_invokes_publish_fn_on_success(monkeypatch, tmp_path: Path) -> None:
    captured, record = _recorder()
    monkeypatch.setattr(
        worker_mod, "hassh_server",
        lambda ip, port, timeout: {
            "hassh_server": "deadbeef",
            "banner": "SSH-2.0-OpenSSH_9.0",
            "kex_algorithms": "x",
            "encryption_s2c": "y",
            "mac_s2c": "z",
            "compression_s2c": "none",
        },
    )

    _hassh_phase(
        ip="1.2.3.4", ip_probed={}, ports=[22],
        log_path=tmp_path / "p.log", json_path=tmp_path / "p.json", timeout=1.0,
        publish_fn=record,
    )

    assert captured == [
        ("hassh", {
            "attacker_ip": "1.2.3.4",
            "port": 22,
            "hassh_server": "deadbeef",
            "ssh_banner": "SSH-2.0-OpenSSH_9.0",
        }),
    ]


def test_tcpfp_phase_invokes_publish_fn_on_success(monkeypatch, tmp_path: Path) -> None:
    captured, record = _recorder()
    monkeypatch.setattr(
        worker_mod, "tcp_fingerprint",
        lambda ip, port, timeout: {
            "tcpfp_hash": "cafef00d",
            "tcpfp_raw": "raw",
            "ttl": 64,
            "window_size": 29200,
            "df_bit": True,
            "mss": 1460,
            "window_scale": 7,
            "sack_ok": True,
            "timestamp": True,
            "options_order": "mss,sack,ts,nop,wscale",
        },
    )

    _tcpfp_phase(
        ip="1.2.3.4", ip_probed={}, ports=[80],
        log_path=tmp_path / "p.log", json_path=tmp_path / "p.json", timeout=1.0,
        publish_fn=record,
    )
    assert captured == [
        ("tcpfp", {
            "attacker_ip": "1.2.3.4", "port": 80,
            "tcpfp_hash": "cafef00d", "ttl": 64, "mss": 1460,
        }),
    ]


def test_phases_run_unchanged_without_publish_fn(monkeypatch, tmp_path: Path) -> None:
    # Pre-bus behavior must stay intact when publish_fn is None: the phase
    # still marks the port done — it just doesn't publish.
    monkeypatch.setattr(worker_mod, "jarm_hash", lambda ip, port, timeout: "aabbcc")

    ip_probed: dict[str, set[int]] = {}
    _jarm_phase(
        ip="1.2.3.4", ip_probed=ip_probed, ports=[443],
        log_path=tmp_path / "p.log", json_path=tmp_path / "p.json", timeout=1.0,
        publish_fn=None,
    )
    assert 443 in ip_probed["jarm"]


# ─── End-to-end through the bus ──────────────────────────────────────────────

@pytest.mark.asyncio
async def test_prober_publishes_on_attacker_fingerprinted_topic(bus: FakeBus) -> None:
    loop = asyncio.get_running_loop()
    raw = make_thread_safe_publisher(bus, loop)

    def publish(event_type: str, payload: dict) -> None:
        raw(_topics.attacker(_topics.ATTACKER_FINGERPRINTED), payload, event_type)

    sub = bus.subscribe("attacker.fingerprinted")
    async with sub:
        publish("jarm", {"attacker_ip": "1.2.3.4", "port": 443, "jarm_hash": "h"})
        event = await asyncio.wait_for(sub.__anext__(), timeout=2.0)

    assert event.topic == "attacker.fingerprinted"
    assert event.type == "jarm"
    assert event.payload == {"attacker_ip": "1.2.3.4", "port": 443, "jarm_hash": "h"}


@pytest.mark.asyncio
async def test_thread_safe_publisher_is_noop_without_bus() -> None:
    # A ``None`` bus must yield a callable that silently drops every
    # publish instead of raising into the probe thread.
    raw = make_thread_safe_publisher(None, asyncio.get_running_loop())
    topic = _topics.attacker(_topics.ATTACKER_FINGERPRINTED)

    assert raw(topic, {"attacker_ip": "1.2.3.4"}, "jarm") is None


@pytest.mark.asyncio
async def test_prober_degrades_cleanly_when_bus_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
    # DECNET_BUS_ENABLED=false returns NullBus; connect() + publish() must
    # be no-op and never raise.
    from decnet.bus.factory import get_bus

    monkeypatch.setenv("DECNET_BUS_ENABLED", "false")
    b = get_bus(client_name="prober")
    await b.connect()
    await b.publish("attacker.fingerprinted", {"x": 1}, event_type="jarm")
    await b.close()