diff --git a/decnet/profiler/behave_shell/_handler.py b/decnet/profiler/behave_shell/_handler.py
new file mode 100644
index 00000000..00bf2ed4
--- /dev/null
+++ b/decnet/profiler/behave_shell/_handler.py
@@ -0,0 +1,219 @@
+"""``attacker.session.ended`` handler — Phase 4 wiring.
+
+Pure handler module: takes a payload (from bus or poll fallback),
+disk-reaches the asciinema shard, runs ``extract_session()``,
+upserts observations, and publishes them on the bus best-effort.
+Lives outside ``worker.py`` so unit tests can exercise it without
+spinning up the asyncio worker loop.
+
+Trigger isolation: every public entry point is wrapped in a single
+try/except in the worker; this module is allowed to raise. The worker
+logs and continues with the next event.
+"""
+from __future__ import annotations
+
+import collections
+import json
+from pathlib import Path
+from typing import Any, Callable, Iterable, Optional
+
+from decnet_behave_core.spec.envelope import Observation
+from decnet_behave_shell.spec.event_adapter import event_topic_for, to_event_payload
+
+from decnet.logging import get_logger
+from decnet.profiler.behave_shell import extract_session
+from decnet.profiler.behave_shell._parse import AsciinemaEvent, parse_shard_line
+from decnet.web.db.repository import BaseRepository
+
+log = get_logger("profiler.behave_handler")
+
+PublishFn = Callable[[str, dict[str, Any], str], None]
+"""Bus-publish callable (sync). The thread-safe publisher returned by
+``decnet.bus.publish.make_thread_safe_publisher`` matches this shape;
+``None`` is also accepted (no-op publish path)."""
+
+_REQUIRED_FIELDS: tuple[str, ...] = (
+    "session_id", "decky_id", "service", "attacker_ip",
+)
+
+
+def _build_evidence_ref(decky: str, service: str, shard_path: str, sid: str) -> str:
+    """Canonical ``shard:{decky}/{service}/{shard_basename}#{sid}`` pointer.
+
+    Stays a *pointer*, never the evidence itself. Worker uses it as
+    the idempotency key against the ``observations`` table.
+    """
+    basename = Path(shard_path).name
+    return f"shard:{decky}/{service}/{basename}#{sid}"
+
+
+def _events_for_sid(shard: Path, sid: str) -> list[AsciinemaEvent]:
+    """Read ``shard``, return parsed events whose ``sid`` matches.
+
+    Mirrors the loader pattern in
+    ``tests/profiler/behave_shell/test_calibration_grid.py``: skip
+    headers / non-matching sids / unparseable lines silently.
+    """
+    events: list[AsciinemaEvent] = []
+    with shard.open(encoding="utf-8", errors="replace") as f:  # not locale-dependent
+        for line in f:
+            try:
+                rec = json.loads(line)
+            except ValueError:  # json.JSONDecodeError is a ValueError subclass
+                continue
+            if not isinstance(rec, dict):
+                continue
+            if rec.get("sid") != sid or "hdr" in rec:
+                continue
+            ev = parse_shard_line(line)
+            if ev is not None:
+                events.append(ev)
+    return events
+
+
+def _flatten_observation(obs: Observation, attacker_uuid: str) -> dict[str, Any]:
+    """Project a BEHAVE Observation onto the ObservationRow column shape.
+
+    Mirrors the storage schema in
+    ``decnet/web/db/models/observations.py`` — flattens
+    ``window.{start,end}_ts`` and stamps the DECNET-side
+    ``attacker_uuid`` denorm. ``id`` / ``ts`` / ``v`` / ``identity_ref``
+    / ``evidence_ref`` ride through unchanged.
+    """
+    return {
+        "id": obs.id,
+        "identity_ref": obs.identity_ref,
+        "primitive": obs.primitive,
+        "value": obs.value,
+        "confidence": obs.confidence,
+        "window_start_ts": obs.window.start_ts,
+        "window_end_ts": obs.window.end_ts,
+        "source": obs.source,
+        "evidence_ref": obs.evidence_ref,
+        "envelope_v": obs.v,
+        "ts": obs.ts,
+        "attacker_uuid": attacker_uuid,
+    }
+
+
+def _publish_observation(publish: Optional[PublishFn], obs: Observation) -> None:
+    """Best-effort publish; never raise. Re-merges id/ts/v into payload
+    per BEHAVE-INTEGRATION.md §339-366 deviation note."""
+    if publish is None:
+        return
+    payload = to_event_payload(obs) | {
+        "id": obs.id,
+        "ts": obs.ts,
+        "v": obs.v,
+    }
+    try:
+        publish(event_topic_for(obs.primitive), payload, obs.primitive)
+    except Exception as exc:  # noqa: BLE001
+        log.debug(
+            "behave_handler: publish failed for primitive=%s: %s",
+            obs.primitive, exc,
+        )
+
+
+async def handle_session_ended(
+    repo: BaseRepository,
+    payload: dict[str, Any],
+    publish: Optional[PublishFn] = None,
+) -> int:
+    """Process one ``attacker.session.ended`` event end-to-end.
+
+    Returns the number of observations persisted (zero on any skip
+    path: missing fields, missing shard, idempotency hit, attacker
+    not yet resolved, sid not in shard, extractor produced nothing).
+
+    Order: persist first, publish best-effort. DB is the source of
+    truth (see BEHAVE-INTEGRATION.md §"Persistence").
+    """
+    # 1. Required-field guard.
+    missing = [k for k in _REQUIRED_FIELDS if not payload.get(k)]
+    if missing:
+        log.debug(
+            "behave_handler: skipping session.ended (missing fields=%s)",
+            missing,
+        )
+        return 0
+    shard_path = payload.get("shard_path")
+    if not shard_path:
+        log.debug("behave_handler: skipping session.ended (no shard_path)")
+        return 0
+
+    sid = str(payload["session_id"])
+    decky = str(payload["decky_id"])
+    service = str(payload["service"])
+    attacker_ip = str(payload["attacker_ip"])
+
+    # 2. evidence_ref + idempotency.
+    evidence_ref = _build_evidence_ref(decky, service, str(shard_path), sid)
+    if await repo.has_observations_for_evidence(evidence_ref):
+        log.debug(
+            "behave_handler: already profiled evidence_ref=%s", evidence_ref,
+        )
+        return 0
+
+    # 3. Resolve attacker_uuid. Skip until profiler tick has materialised
+    # the Attacker row — same posture as TTP's _resolve_attacker_uuid.
+    attacker_uuid = await repo.get_attacker_uuid_by_ip(attacker_ip)
+    if not attacker_uuid:
+        log.info(
+            "behave_handler: no Attacker row for ip=%s yet; deferring",
+            attacker_ip,
+        )
+        return 0
+
+    # 4. Load shard, slice events.
+    shard = Path(shard_path)
+    if not shard.is_file():
+        log.info(
+            "behave_handler: shard not on disk yet path=%s sid=%s; deferring",
+            shard_path, sid,
+        )
+        return 0
+    events = _events_for_sid(shard, sid)
+    if not events:
+        log.info(
+            "behave_handler: sid=%s not present in shard=%s; skipping",
+            sid, shard_path,
+        )
+        return 0
+
+    # 5. Extract.
+    observations: list[Observation] = list(
+        extract_session(events, sid=sid, evidence_ref=evidence_ref),
+    )
+    if not observations:
+        log.info(
+            "behave_handler: extractor produced zero observations sid=%s",
+            sid,
+        )
+        return 0
+
+    # 6. Persist, one upsert per observation. An upsert failure propagates
+    # before anything is published. NOTE(review): rows written before that
+    # failure satisfy the step-2 check, so a retry would skip — confirm.
+    persisted = 0
+    for obs in observations:
+        await repo.upsert_observation(_flatten_observation(obs, attacker_uuid))
+        persisted += 1
+
+    # 7. Publish — fire-and-forget, never raises out.
+    for obs in observations:
+        _publish_observation(publish, obs)
+
+    log.info(
+        "behave_handler: persisted=%d primitives sid=%s attacker_ip=%s",
+        persisted, sid, attacker_ip,
+    )
+    return persisted
+
+
+def primitive_counts(observations: Iterable[Observation]) -> dict[str, int]:
+    """Small debug helper: count emissions by primitive name."""
+    counter: collections.Counter[str] = collections.Counter(
+        obs.primitive for obs in observations
+    )
+    return dict(counter)
diff --git a/decnet/profiler/worker.py b/decnet/profiler/worker.py
index 9ac19098..501fd2d9 100644
--- a/decnet/profiler/worker.py
+++ b/decnet/profiler/worker.py
@@ -20,7 +20,9 @@ from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from typing import Any, Callable
 
+from decnet.artifacts.shards import find_shard_with_sid
 from decnet.bus import topics as _topics
+from decnet.bus.base import BaseBus, Event
 from decnet.bus.factory import get_bus
 from decnet.bus.publish import (
     make_thread_safe_publisher,
@@ -33,6 +35,7 @@ from decnet.asn import enrich_ip as enrich_ip_asn
 from decnet.geoip import enrich_ip
 from decnet.geoip.ptr import resolve_ptr_record
 from decnet.logging import get_logger
+from decnet.profiler.behave_shell._handler import handle_session_ended
 from decnet.profiler.behavioral import build_behavior_record
 from decnet.telemetry import traced as _traced, get_tracer as _get_tracer
 from decnet.web.db.repository import BaseRepository
@@ -41,6 +44,14 @@ logger = get_logger("attacker_worker")
 
 _BATCH_SIZE = 500
 _STATE_KEY = "attacker_worker_cursor"
+# Separate cursor for the BEHAVE-SHELL poll fallback so it doesn't
+# conflate with the correlation tick's log-id cursor (memory rule:
+# "Poll fallback's Log cursor — use a separate state key").
+_BEHAVE_POLL_STATE_KEY = "attacker_worker_session_cursor"
+# Pattern the bus subscription matches. Single-topic for BEHAVE-SHELL
+# wiring; matches what the collector publishes from
+# ``_SessionAggregator._emit_session``.
+_BEHAVE_TOPIC = _topics.attacker(_topics.ATTACKER_SESSION_ENDED)
 
 # Event types that indicate active command/query execution — the
 # shell-family subset of INTERACTION_EVENT_TYPES in
@@ -126,6 +137,17 @@ async def attacker_profile_worker(repo: BaseRepository, *, interval: int = 30) -
     control_task = asyncio.create_task(
         run_control_listener(bus, "profiler", shutdown),
     )
+
+    # BEHAVE-SHELL session-ended handler — bus subscription pump (when
+    # bus is available) feeds an asyncio.Queue; the tick body drains
+    # the queue per iteration. Same shape as decnet/ttp/worker.py.
+    behave_queue: "asyncio.Queue[tuple[str, Event] | None]" = asyncio.Queue()
+    behave_pump_task: asyncio.Task[None] | None = None
+    if bus is not None:
+        behave_pump_task = asyncio.create_task(
+            _behave_pump(bus, behave_queue),
+        )
+
     try:
         while not shutdown.is_set():
             try:
@@ -138,11 +160,26 @@ async def attacker_profile_worker(repo: BaseRepository, *, interval: int = 30) -
                 await _incremental_update(repo, state)
             except Exception as exc:
                 logger.error("attacker worker: update failed: %s", exc)
+            # BEHAVE-SHELL drain (bus path).
+            await _drain_behave_queue(repo, behave_queue, raw_publish)
+            # BEHAVE-SHELL poll fallback. Always runs — when bus is up
+            # this catches anything the subscription missed during a
+            # transient reconnect; when bus is down it's the only path.
+            try:
+                await _behave_poll_tick(repo, raw_publish)
+            except Exception as exc:  # noqa: BLE001
+                logger.error(
+                    "attacker worker: behave poll tick failed: %s", exc,
+                )
     finally:
         for t in (heartbeat_task, control_task):
             t.cancel()
             with contextlib.suppress(Exception, asyncio.CancelledError):
                 await t
+        if behave_pump_task is not None:
+            behave_pump_task.cancel()
+            with contextlib.suppress(Exception, asyncio.CancelledError):
+                await behave_pump_task
         if bus is not None:
             with contextlib.suppress(Exception):
                 await bus.close()
@@ -440,3 +477,132 @@ def _extract_smtp_domains(events: list[LogEvent]) -> set[str]:
             if domain:
                 domains.add(domain)
     return domains
+
+
+# ── BEHAVE-SHELL session-ended wiring (Phase 4) ─────────────────────────────
+
+
+async def _behave_pump(
+    bus: BaseBus,
+    queue: "asyncio.Queue[tuple[str, Event] | None]",
+) -> None:
+    """Forward every ``attacker.session.ended`` event into ``queue``.
+
+    Tolerance contract mirrors :func:`decnet.ttp.worker._pump`: the
+    subscriber dies → log-and-fall-back-to-poll, never crash the worker
+    loop. The poll path (always-on per tick) catches anything missed
+    while the subscription is down.
+    """
+    try:
+        sub = bus.subscribe(_BEHAVE_TOPIC)
+        async with sub:
+            async for event in sub:
+                await queue.put((event.topic, event))
+    except asyncio.CancelledError:
+        raise
+    except Exception as exc:  # noqa: BLE001
+        logger.warning(
+            "attacker worker: behave subscriber for %s died (%s); "
+            "falling back to poll", _BEHAVE_TOPIC, exc,
+        )
+
+
+async def _drain_behave_queue(
+    repo: BaseRepository,
+    queue: "asyncio.Queue[tuple[str, Event] | None]",
+    publish: Callable[[str, dict[str, Any], str], None] | None,
+) -> None:
+    """Drain queued ``attacker.session.ended`` events through the
+    handler. Each handler invocation is isolated — exceptions log and
+    do not block the next event."""
+    while not queue.empty():
+        item = queue.get_nowait()
+        if item is None:
+            continue
+        _topic, event = item
+        try:
+            await handle_session_ended(repo, event.payload, publish)
+        except Exception as exc:  # noqa: BLE001
+            logger.error(
+                "attacker worker: behave handler raised on bus path: %s", exc,
+            )
+
+
+async def _behave_poll_tick(
+    repo: BaseRepository,
+    publish: Callable[[str, dict[str, Any], str], None] | None,
+) -> None:
+    """Poll fallback: run the handler on ``session_recorded`` ``Log`` rows
+    after the cursor stored under :data:`_BEHAVE_POLL_STATE_KEY` (separate
+    from the correlation tick's cursor so the two never conflate).
+
+    NOTE(review): the cursor advances past every scanned row, so a session
+    the handler "defers" (returns 0) is not polled again — confirm intent.
+    """
+    cursor_state = await repo.get_state(_BEHAVE_POLL_STATE_KEY) or {}
+    last_id = int(cursor_state.get("last_log_id", 0))
+    rows = await repo.get_logs_after_id(last_id, limit=_BATCH_SIZE)
+    if not rows:
+        return
+    new_cursor = last_id
+    for row in rows:
+        new_cursor = max(new_cursor, int(row.get("id", 0)))
+        if row.get("event_type") != "session_recorded":
+            continue
+        try:  # row projection included: one bad row must not stall the cursor
+            payload = _payload_from_log_row(row)
+            if payload is None:
+                continue
+            await handle_session_ended(repo, payload, publish)
+        except Exception as exc:  # noqa: BLE001
+            logger.error(
+                "attacker worker: behave handler raised on poll path: %s", exc,
+            )
+    if new_cursor > last_id:
+        await repo.set_state(
+            _BEHAVE_POLL_STATE_KEY, {"last_log_id": new_cursor},
+        )
+
+
+def _payload_from_log_row(row: dict[str, Any]) -> dict[str, Any] | None:
+    """Project a ``session_recorded`` Log row into the same shape the
+    collector publishes on the bus.
+
+    Returns ``None`` when required fields are missing — the handler
+    has its own guard, but pre-filtering here avoids the round-trip to
+    the handler's logger for malformed rows.
+    """
+    fields = row.get("fields") or "{}"
+    if not isinstance(fields, dict):
+        try:
+            fields = json.loads(fields)
+        except (ValueError, TypeError):
+            return None
+    if not isinstance(fields, dict):  # JSON list/str/number: no ``sid`` to read
+        return None
+    sid = fields.get("sid")
+    decky = row.get("decky")
+    service = fields.get("service") or row.get("service")
+    attacker_ip = row.get("attacker_ip")
+    if not (sid and decky and service and attacker_ip):
+        return None
+    # Resolve shard_path locally — the Log row may not carry one
+    # (sessrec.c does not yet emit fields.shard_path).
+    shard_path: str | None = None
+    try:
+        resolved = find_shard_with_sid(str(decky), str(service), str(sid))
+    except (ValueError, OSError):  # PermissionError is an OSError subclass
+        resolved = None
+    if resolved is not None:
+        shard_path = str(resolved)
+    return {
+        "session_id": str(sid),
+        "attacker_uuid": None,
+        "attacker_ip": str(attacker_ip),
+        "decky_id": str(decky),
+        "service": str(service),
+        "ended_at": row.get("timestamp"),
+        "duration_s": fields.get("duration_s"),
+        "commands": [],
+        "shard_path": shard_path,
+    }
diff --git a/decnet/web/db/repository.py b/decnet/web/db/repository.py
index 399c8105..8ffc86d0 100644
--- a/decnet/web/db/repository.py
+++ b/decnet/web/db/repository.py
@@ -341,6 +341,16 @@ class BaseRepository(ABC):
         ordered by ``ts`` ASC. Empty list when none."""
         raise NotImplementedError
 
+    @abstractmethod
+    async def has_observations_for_evidence(self, evidence_ref: str) -> bool:
+        """True iff any observation row carries this ``evidence_ref``.
+
+        Worker uses this as the "have we already profiled this session?"
+        check before kicking the BEHAVE-SHELL extractor — equivalent
+        to "is this ``(decky, service, sid)`` already in the table?".
+        """
+        raise NotImplementedError
+
     async def upsert_observed_attachment(
         self,
         *,
diff --git a/tests/profiler/behave_shell/test_handler_session_ended.py b/tests/profiler/behave_shell/test_handler_session_ended.py
new file mode 100644
index 00000000..9a058177
--- /dev/null
+++ b/tests/profiler/behave_shell/test_handler_session_ended.py
@@ -0,0 +1,177 @@
+"""Unit tests for ``decnet.profiler.behave_shell._handler``.
+
+Direct exercise of ``handle_session_ended()`` without the worker loop
+or a real bus. The handler is the load-bearing piece — bus / poll
+fallback paths in the worker just feed it. Pin the contract here.
+"""
+from __future__ import annotations
+
+import json
+from typing import Any
+from unittest.mock import AsyncMock
+
+import pytest
+
+from decnet.profiler.behave_shell._handler import (
+    _build_evidence_ref,
+    handle_session_ended,
+)
+
+
+_SID = "11111111-2222-3333-4444-555555555555"
+_DECKY = "test-decky"
+_SERVICE = "ssh"
+_IP = "10.0.0.5"
+_ATTACKER_UUID = "att-uuid-abc"
+
+
+def _write_shard(tmp_path, sid: str, lines: list[dict]) -> str:
+    """Write a synthetic asciinema shard JSONL and return its path."""
+    shard_dir = tmp_path / _DECKY / _SERVICE / "transcripts"
+    shard_dir.mkdir(parents=True, exist_ok=True)
+    shard = shard_dir / "sessions-2026-05-08.jsonl"
+    with shard.open("w") as f:
+        for line in lines:
+            f.write(json.dumps(line) + "\n")
+    return str(shard)
+
+
+def _shard_with_typing_session(tmp_path, sid: str = _SID) -> str:
+    """A minimal session with enough events to fire the calibration floor."""
+    lines = [{"sid": sid, "hdr": {"version": 2, "width": 80, "height": 24,
+                                  "timestamp": 1714521600}}]
+    text = "ls\rps\rid\rwhoami\rpwd\runame\r"
+    for i, c in enumerate(text):
+        lines.append({"sid": sid, "t": i * 0.05, "ch": "i", "d": c})
+    lines.append({"sid": sid, "t": 5.0, "ch": "o", "d": "anti@host:~$ "})
+    return _write_shard(tmp_path, sid, lines)
+
+
+def _payload(shard_path: str | None) -> dict[str, Any]:
+    return {
+        "session_id": _SID,
+        "attacker_uuid": None,
+        "attacker_ip": _IP,
+        "decky_id": _DECKY,
+        "service": _SERVICE,
+        "ended_at": "2026-05-08T10:00:00",
+        "duration_s": 5.0,
+        "commands": [],
+        "shard_path": shard_path,
+    }
+
+
+def _make_repo(*, has_evidence: bool = False, attacker_uuid: str | None = _ATTACKER_UUID):
+    repo = AsyncMock()
+    repo.has_observations_for_evidence = AsyncMock(return_value=has_evidence)
+    repo.get_attacker_uuid_by_ip = AsyncMock(return_value=attacker_uuid)
+    repo.upsert_observation = AsyncMock(return_value="row-uuid")
+    return repo
+
+
+def test_evidence_ref_shape() -> None:
+    ref = _build_evidence_ref(
+        "deck", "ssh", "/var/lib/decnet/artifacts/deck/ssh/transcripts/sessions-2026-05-08.jsonl",
+        "abc",
+    )
+    assert ref == "shard:deck/ssh/sessions-2026-05-08.jsonl#abc"
+
+
+async def test_happy_path_persists_and_publishes(tmp_path) -> None:
+    shard_path = _shard_with_typing_session(tmp_path)
+    repo = _make_repo()
+    published: list[tuple[str, dict[str, Any], str]] = []
+    publish = lambda topic, payload, etype: published.append((topic, payload, etype))
+
+    n = await handle_session_ended(repo, _payload(shard_path), publish)
+
+    assert n > 0
+    assert repo.upsert_observation.await_count == n
+    # Every persistence row must include the attacker_uuid denorm.
+    for call in repo.upsert_observation.await_args_list:
+        row = call.args[0]
+        assert row["attacker_uuid"] == _ATTACKER_UUID
+        assert row["evidence_ref"].startswith("shard:")
+    # Bus published once per observation.
+    assert len(published) == n
+    for topic, payload, etype in published:
+        assert topic.startswith("attacker.observation.")
+        # Adapter excludes id/ts/v from payload body; handler re-merges.
+        assert "id" in payload and "ts" in payload and "v" in payload
+
+
+async def test_missing_session_id_skipped(tmp_path) -> None:
+    shard_path = _shard_with_typing_session(tmp_path)
+    p = _payload(shard_path)
+    p["session_id"] = None
+    repo = _make_repo()
+    n = await handle_session_ended(repo, p, None)
+    assert n == 0
+    repo.upsert_observation.assert_not_awaited()
+
+
+async def test_missing_shard_path_skipped(tmp_path) -> None:
+    repo = _make_repo()
+    n = await handle_session_ended(repo, _payload(None), None)
+    assert n == 0
+    repo.has_observations_for_evidence.assert_not_awaited()
+
+
+async def test_already_profiled_skipped(tmp_path) -> None:
+    """Idempotency: handler returns 0 if has_observations_for_evidence True."""
+    shard_path = _shard_with_typing_session(tmp_path)
+    repo = _make_repo(has_evidence=True)
+    n = await handle_session_ended(repo, _payload(shard_path), None)
+    assert n == 0
+    repo.get_attacker_uuid_by_ip.assert_not_awaited()
+    repo.upsert_observation.assert_not_awaited()
+
+
+async def test_attacker_uuid_unresolved_defers(tmp_path) -> None:
+    """Cold IP — no Attacker row yet. Skip and let the next tick retry."""
+    shard_path = _shard_with_typing_session(tmp_path)
+    repo = _make_repo(attacker_uuid=None)
+    n = await handle_session_ended(repo, _payload(shard_path), None)
+    assert n == 0
+    repo.upsert_observation.assert_not_awaited()
+
+
+async def test_shard_missing_on_disk_defers(tmp_path) -> None:
+    """shard_path points at a file that hasn't been flushed yet."""
+    fake_path = str(tmp_path / "nope" / "sessions-2026-05-08.jsonl")
+    repo = _make_repo()
+    n = await handle_session_ended(repo, _payload(fake_path), None)
+    assert n == 0
+    repo.upsert_observation.assert_not_awaited()
+
+
+async def test_sid_not_in_shard_skipped(tmp_path) -> None:
+    """Shard exists but doesn't contain our sid."""
+    other_sid = "ffffffff-eeee-dddd-cccc-bbbbbbbbbbbb"
+    shard_path = _shard_with_typing_session(tmp_path, sid=other_sid)
+    repo = _make_repo()
+    n = await handle_session_ended(repo, _payload(shard_path), None)
+    assert n == 0
+    repo.upsert_observation.assert_not_awaited()
+
+
+async def test_publish_failure_does_not_raise(tmp_path) -> None:
+    """Bus publish failures are best-effort; persistence already
+    succeeded so we don't roll back."""
+    shard_path = _shard_with_typing_session(tmp_path)
+    repo = _make_repo()
+
+    def _bad(*_a: Any, **_k: Any) -> None:
+        raise RuntimeError("bus exploded")
+
+    n = await handle_session_ended(repo, _payload(shard_path), _bad)
+    assert n > 0
+    assert repo.upsert_observation.await_count == n
+
+
+async def test_publish_none_is_silent(tmp_path) -> None:
+    """publish=None is the no-op path used in poll-fallback mode."""
+    shard_path = _shard_with_typing_session(tmp_path)
+    repo = _make_repo()
+    n = await handle_session_ended(repo, _payload(shard_path), None)
+    assert n > 0
diff --git a/tests/profiler/behave_shell/test_worker_behave_drain.py b/tests/profiler/behave_shell/test_worker_behave_drain.py
new file mode 100644
index 00000000..0cd209d1
--- /dev/null
+++ b/tests/profiler/behave_shell/test_worker_behave_drain.py
@@ -0,0 +1,84 @@
+"""W.3 bus-path drain tests.
+
+Exercises ``_drain_behave_queue`` directly without the asyncio worker
+loop. The handler is unit-tested in
+``test_handler_session_ended.py``; this file pins the queue-drain
+plumbing (Event unwrapping, isolation against handler exceptions,
+empty-queue no-op).
+"""
+from __future__ import annotations
+
+import asyncio
+from typing import Any
+from unittest.mock import AsyncMock, MagicMock
+
+from decnet.profiler.worker import _drain_behave_queue
+
+
+async def _make_event(payload: dict[str, Any]):
+    """Build a minimal Event-like object the drain expects."""
+    ev = MagicMock()
+    ev.topic = "attacker.session.ended"
+    ev.payload = payload
+    return ev
+
+
+async def test_drain_empty_queue_is_noop() -> None:
+    repo = AsyncMock()
+    queue: asyncio.Queue = asyncio.Queue()
+    await _drain_behave_queue(repo, queue, None)
+    repo.has_observations_for_evidence.assert_not_awaited()
+
+
+async def test_drain_skips_none_sentinel() -> None:
+    repo = AsyncMock()
+    queue: asyncio.Queue = asyncio.Queue()
+    await queue.put(None)
+    await _drain_behave_queue(repo, queue, None)
+    repo.has_observations_for_evidence.assert_not_awaited()
+
+
+async def test_drain_passes_event_payload_to_handler(monkeypatch) -> None:
+    """The drain unwraps Event.payload and feeds it to the handler."""
+    captured: list[dict[str, Any]] = []
+
+    async def _fake_handler(repo, payload, publish):
+        captured.append(payload)
+        return 0
+
+    monkeypatch.setattr(
+        "decnet.profiler.worker.handle_session_ended", _fake_handler,
+    )
+    repo = AsyncMock()
+    queue: asyncio.Queue = asyncio.Queue()
+    ev = await _make_event({"session_id": "abc", "decky_id": "d"})
+    await queue.put((ev.topic, ev))
+    await _drain_behave_queue(repo, queue, None)
+    assert captured == [{"session_id": "abc", "decky_id": "d"}]
+
+
+async def test_drain_isolates_handler_exception(monkeypatch) -> None:
+    """A handler that raises must not crash subsequent events."""
+    call_count = 0
+
+    async def _maybe_blowing_handler(repo, payload, publish):
+        nonlocal call_count
+        call_count += 1
+        if call_count == 1:
+            raise RuntimeError("handler exploded")
+        return 0
+
+    monkeypatch.setattr(
+        "decnet.profiler.worker.handle_session_ended",
+        _maybe_blowing_handler,
+    )
+    repo = AsyncMock()
+    queue: asyncio.Queue = asyncio.Queue()
+    ev1 = await _make_event({"session_id": "a"})
+    ev2 = await _make_event({"session_id": "b"})
+    await queue.put((ev1.topic, ev1))
+    await queue.put((ev2.topic, ev2))
+
+    # Should not raise; both events should be drained.
+    await _drain_behave_queue(repo, queue, None)
+    assert call_count == 2
diff --git a/tests/profiler/behave_shell/test_worker_behave_poll.py b/tests/profiler/behave_shell/test_worker_behave_poll.py
new file mode 100644
index 00000000..a7db4546
--- /dev/null
+++ b/tests/profiler/behave_shell/test_worker_behave_poll.py
@@ -0,0 +1,179 @@
+"""W.3 poll-fallback tests.
+
+Exercises ``_behave_poll_tick`` and ``_payload_from_log_row`` —
+the path used when the bus is unavailable
+(``DECNET_BUS_ENABLED=false`` or transient subscriber failure).
+"""
+from __future__ import annotations
+
+import json
+from typing import Any
+from unittest.mock import AsyncMock
+
+from decnet.profiler.worker import (
+    _behave_poll_tick,
+    _BEHAVE_POLL_STATE_KEY,
+    _payload_from_log_row,
+)
+
+
+def _log_row(
+    log_id: int = 42,
+    event_type: str = "session_recorded",
+    fields: dict | None = None,
+) -> dict[str, Any]:
+    base_fields = {"sid": "11111111-2222-3333-4444-555555555555",
+                   "service": "ssh", "duration_s": "5.0",
+                   "src_ip": "10.0.0.5"}
+    if fields is not None:
+        base_fields.update(fields)
+    return {
+        "id": log_id,
+        "event_type": event_type,
+        "decky": "test-decky",
+        "service": "ssh",
+        "attacker_ip": "10.0.0.5",
+        "timestamp": "2026-05-08T10:00:00",
+        "fields": json.dumps(base_fields),
+    }
+
+
+def test_payload_from_log_row_happy() -> None:
+    payload = _payload_from_log_row(_log_row())
+    assert payload is not None
+    assert payload["session_id"] == "11111111-2222-3333-4444-555555555555"
+    assert payload["decky_id"] == "test-decky"
+    assert payload["service"] == "ssh"
+    assert payload["attacker_ip"] == "10.0.0.5"
+    # shard_path may be None (no fixture file) — that's the honest
+    # "skip until next tick" path.
+    assert "shard_path" in payload
+
+
+def test_payload_from_log_row_returns_none_on_missing_fields() -> None:
+    """Empty fields blob → required-field guard short-circuits."""
+    row = _log_row(fields={"sid": ""})
+    row["fields"] = "{}"
+    assert _payload_from_log_row(row) is None
+
+
+def test_payload_from_log_row_returns_none_on_unparseable_fields() -> None:
+    row = _log_row()
+    row["fields"] = "not json"
+    assert _payload_from_log_row(row) is None
+
+
+def test_payload_from_log_row_returns_none_on_non_object_fields() -> None:
+    """Valid JSON that is not an object is rejected instead of raising."""
+    row = _log_row()
+    row["fields"] = "[1, 2]"
+    assert _payload_from_log_row(row) is None
+
+
+async def test_poll_tick_no_logs_does_nothing() -> None:
+    repo = AsyncMock()
+    repo.get_state = AsyncMock(return_value=None)
+    repo.get_logs_after_id = AsyncMock(return_value=[])
+
+    await _behave_poll_tick(repo, None)
+
+    repo.get_logs_after_id.assert_awaited_once()
+    repo.set_state.assert_not_awaited()
+
+
+async def test_poll_tick_skips_non_session_recorded_event_types() -> None:
+    repo = AsyncMock()
+    repo.get_state = AsyncMock(return_value=None)
+    repo.get_logs_after_id = AsyncMock(return_value=[
+        _log_row(log_id=1, event_type="command"),
+        _log_row(log_id=2, event_type="connection.opened"),
+    ])
+
+    await _behave_poll_tick(repo, None)
+
+    # Cursor still advances even when nothing is processed.
+    repo.set_state.assert_awaited_once_with(
+        _BEHAVE_POLL_STATE_KEY, {"last_log_id": 2},
+    )
+    repo.has_observations_for_evidence.assert_not_awaited()
+
+
+async def test_poll_tick_drives_handler_for_session_recorded(monkeypatch) -> None:
+    captured: list[dict[str, Any]] = []
+
+    async def _fake_handler(repo, payload, publish):
+        captured.append(payload)
+        return 0
+
+    monkeypatch.setattr(
+        "decnet.profiler.worker.handle_session_ended", _fake_handler,
+    )
+
+    repo = AsyncMock()
+    repo.get_state = AsyncMock(return_value={"last_log_id": 0})
+    repo.get_logs_after_id = AsyncMock(return_value=[_log_row(log_id=99)])
+
+    await _behave_poll_tick(repo, None)
+
+    assert len(captured) == 1
+    assert captured[0]["session_id"] == "11111111-2222-3333-4444-555555555555"
+    repo.set_state.assert_awaited_once_with(
+        _BEHAVE_POLL_STATE_KEY, {"last_log_id": 99},
+    )
+
+
+async def test_poll_tick_uses_separate_cursor_state_key(monkeypatch) -> None:
+    """Cursor key must be _BEHAVE_POLL_STATE_KEY, NOT
+    attacker_worker_cursor (which the correlation tick owns)."""
+    repo = AsyncMock()
+    repo.get_state = AsyncMock(return_value=None)
+    repo.get_logs_after_id = AsyncMock(return_value=[_log_row(log_id=5)])
+
+    async def _noop(*_a, **_k):
+        return 0
+
+    monkeypatch.setattr(
+        "decnet.profiler.worker.handle_session_ended", _noop,
+    )
+
+    await _behave_poll_tick(repo, None)
+
+    # Read uses the separate key.
+    repo.get_state.assert_awaited_with(_BEHAVE_POLL_STATE_KEY)
+    # Write also uses it.
+    repo.set_state.assert_awaited_with(
+        _BEHAVE_POLL_STATE_KEY, {"last_log_id": 5},
+    )
+
+
+async def test_poll_tick_isolates_handler_exception(monkeypatch) -> None:
+    """A blowing-up handler must not stop cursor advancement on
+    subsequent rows."""
+    call_count = 0
+
+    async def _maybe_blowing_handler(repo, payload, publish):
+        nonlocal call_count
+        call_count += 1
+        if call_count == 1:
+            raise RuntimeError("handler exploded")
+        return 0
+
+    monkeypatch.setattr(
+        "decnet.profiler.worker.handle_session_ended",
+        _maybe_blowing_handler,
+    )
+
+    repo = AsyncMock()
+    repo.get_state = AsyncMock(return_value=None)
+    repo.get_logs_after_id = AsyncMock(return_value=[
+        _log_row(log_id=1),
+        _log_row(log_id=2),
+    ])
+
+    # Should not raise.
+    await _behave_poll_tick(repo, None)
+    assert call_count == 2
+    # Cursor advanced past both rows.
+    repo.set_state.assert_awaited_once_with(
+        _BEHAVE_POLL_STATE_KEY, {"last_log_id": 2},
+    )