Files
DECNET/decnet/ttp/worker.py
anti 44ade3eb63 fix(ttp): E.3.18a worker hydrates per-lifter rule indexes via watch_store
Each per-source lifter holds its own RuleIndex and exposes an
`async watch_store()` that loads the corpus and drains store change
events forever. Until this commit nothing called `watch_store()` in
production — every dispatch index stayed empty and no rule fired.

- Add `WatchableTagger` runtime-checkable Protocol in `decnet.ttp.base`.
- `CompositeTagger.iter_watchables()` yields lifters that satisfy it.
- `run_ttp_worker_loop` fans out one task per watchable, cancelled
  and awaited alongside pump/heartbeat/control in the existing finally.
- Watch failures log and exit the watch task without taking the
  worker down — mirrors the pump-task tolerance contract.
2026-05-02 01:25:15 -04:00

411 lines
15 KiB
Python

"""Long-running TTP-tagging worker.
E.3.14 of ``development/TTP_TAGGING.md``. Drains the bus topics
declared in :data:`_TOPICS`, dispatches each event through the
:class:`~decnet.ttp.factory.CompositeTagger`, persists the produced
:class:`~decnet.web.db.models.ttp.TTPTag` rows via
:meth:`BaseRepository.insert_tags`, and publishes the documented
``ttp.tagged`` + ``ttp.rule.fired.<technique_id>`` events — but
*only* when ``insert_tags`` reported a non-zero rowcount, per the
"loop-prevention invariant" in TTP_TAGGING.md §"Bus topics".
Bus subscriptions are enumerated as the module-level constant
:data:`_TOPICS` so E.2.12 can assert subscription wiring without
invoking the loop. The constant is the *single source of truth* —
the loop iterates over it; tests introspect it.
The inner loop drains a shared ``asyncio.Queue`` populated by one
task per topic. Each queued item is a ``(topic, Event)`` pair —
the topic decides the lifter family (and therefore the
``source_kind``), the payload carries the per-event identifiers.
Bus loss is tolerated: on transport error the per-topic pump task
exits and the loop falls back to the poll interval, which still
heartbeats and accepts a clean shutdown.
"""
from __future__ import annotations
import asyncio
import contextlib
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any, Optional
from decnet import telemetry as _telemetry
from decnet.bus import topics as _topics
from decnet.bus.base import BaseBus, Event
from decnet.bus.factory import get_bus
from decnet.bus.publish import (
run_control_listener_signal as _run_control_listener_signal,
run_health_heartbeat as _run_health_heartbeat,
)
from decnet.logging import get_logger
from decnet.ttp.base import Tagger, TaggerEvent
from decnet.ttp.factory import CompositeTagger, get_tagger
from decnet.web.db.models.ttp import TTPTag
from decnet.web.db.repository import BaseRepository
# Module-level logger obtained for the "ttp.worker" name.
log = get_logger("ttp.worker")
# Default for ``run_ttp_worker_loop(poll_interval_secs=...)``: the longest
# the main loop blocks on its queue, in seconds, before it re-checks the
# shutdown flag.
_DEFAULT_POLL_SECS = 60.0
# Bus topics the worker subscribes to. Kept as a module-level constant
# so E.2.12 can assert subscription wiring without invoking the loop —
# the test introspects this tuple, the loop iterates it. The set
# matches the design doc "Worker shape" section: session-ended primary
# trigger, observed for low-latency rules, intel-enriched + identity
# events for opportunistic re-tag, credential-reuse + email for the
# dedicated lifters, and ``canary.>`` for fleet-wide canary triggers.
_TOPICS: tuple[str, ...] = (
    _topics.attacker(_topics.ATTACKER_SESSION_ENDED),
    _topics.attacker(_topics.ATTACKER_OBSERVED),
    _topics.attacker(_topics.ATTACKER_INTEL_ENRICHED),
    _topics.identity(_topics.IDENTITY_FORMED),
    _topics.identity(_topics.IDENTITY_MERGED),
    _topics.credential(_topics.CREDENTIAL_REUSE_DETECTED),
    _topics.email_topic(_topics.EMAIL_RECEIVED),
    # Canary triggers carry a per-token segment, so subscribe with the
    # multi-token wildcard rather than enumerating per-token. Pattern
    # validated against ``decnet.bus.topics.canary()``'s shape.
    f"{_topics.CANARY}.>",
)
# Topic-segment → ``source_kind`` for the resulting TaggerEvent. We
# match on a short token contained in the topic so wildcard topics
# (``canary.{id}.triggered``) and per-event topics work uniformly.
_TOPIC_SOURCE_KIND: tuple[tuple[str, str], ...] = (
("session.ended", "session"),
("observed", "session"),
("intel.enriched", "intel"),
("identity.formed", "identity"),
("identity.merged", "identity"),
("reuse.detected", "credential"),
("email.received", "email"),
("canary.", "canary_fingerprint"),
)
def _source_kind_for(topic: str) -> str | None:
for fragment, kind in _TOPIC_SOURCE_KIND:
if fragment in topic:
return kind
return None
@contextmanager
def _span(name: str, **attrs: Any) -> Iterator[Any]:
    """Open a tracing span called *name*, or yield ``None`` when tracing is off.

    When ``_telemetry._ENABLED`` is falsy the context manager yields
    ``None`` without touching the tracer. Otherwise the tracer is looked
    up on every call (late binding), a span is started as the current
    span, and each keyword in *attrs* is set on it as an attribute.
    Attributes whose value the span rejects with ``TypeError`` or
    ``ValueError`` are skipped rather than aborting the span.
    """
    if _telemetry._ENABLED:
        tracer = _telemetry.get_tracer("ttp.worker")
        with tracer.start_as_current_span(name) as active_span:
            for attr_name, attr_value in attrs.items():
                # Best-effort: one unsupported attribute value must not
                # prevent the remaining attributes from being recorded.
                with contextlib.suppress(TypeError, ValueError):
                    active_span.set_attribute(attr_name, attr_value)
            yield active_span
    else:
        yield None
def _build_event(topic: str, payload: dict[str, Any]) -> TaggerEvent | None:
    """Turn a bus ``(topic, payload)`` pair into a :class:`TaggerEvent`.

    ``None`` is returned when :func:`_source_kind_for` has no mapping for
    *topic*; a wildcard subscription can deliver topics the dispatch
    table does not cover.

    ``source_id`` is the first truthy value found under the payload keys
    listed below, tried in that order; when none is present the topic
    string itself is used. Whatever is chosen is stringified. Picking
    the most specific identifier available is what lets a replay of the
    same upstream event yield the same ``source_id``.

    The correlation fields (``attacker_uuid``, ``identity_uuid``,
    ``session_id``, ``decky_id``) are copied from the payload as strings,
    or ``None`` when absent, and the event carries a shallow copy of the
    payload.
    """
    kind = _source_kind_for(topic)
    if kind is None:
        return None

    # Most specific identifier first; order is part of the contract.
    id_keys = (
        "source_id",
        "session_id",
        "token_id",
        "identity_uuid",
        "credential_id",
        "attacker_uuid",
        "uuid",
    )
    candidates = (payload.get(key) for key in id_keys)
    chosen_id = next((value for value in candidates if value), topic)

    return TaggerEvent(
        source_kind=kind,
        source_id=str(chosen_id),
        attacker_uuid=_str_or_none(payload.get("attacker_uuid")),
        identity_uuid=_str_or_none(payload.get("identity_uuid")),
        session_id=_str_or_none(payload.get("session_id")),
        decky_id=_str_or_none(payload.get("decky_id")),
        payload=dict(payload),
    )
def _str_or_none(value: Any) -> str | None:
if value is None:
return None
return str(value)
async def run_ttp_worker_loop(
    repo: BaseRepository,
    *,
    poll_interval_secs: float = _DEFAULT_POLL_SECS,
    tagger: Optional[Tagger] = None,
    shutdown: Optional[asyncio.Event] = None,
    bus: Optional[BaseBus] = None,
) -> None:
    """Run the TTP-tagging loop until cancelled.

    *tagger* defaults to :func:`decnet.ttp.factory.get_tagger`; tests
    pass a fake. *shutdown* is an optional external stop signal; it is
    checked between queue reads, so it takes effect within at most
    *poll_interval_secs*. The loop also exits cleanly on
    :class:`asyncio.CancelledError` and :class:`KeyboardInterrupt`.
    *bus* is an optional pre-wired bus; when omitted the worker calls
    :func:`get_bus` itself, falling back to poll-only when the bus is
    unavailable (typical dev box without a NATS daemon). A bus the
    worker opened itself is closed on exit; a caller-supplied bus is
    left open.

    Every background task started here (store watchers, per-topic pumps,
    heartbeat, control listener) is cancelled and awaited on the way
    out, including when start-up itself is interrupted — for example by
    cancellation while the bus connection is still being established.
    """
    if tagger is None:
        tagger = get_tagger()
    log.info(
        "ttp worker started tagger=%s poll_interval_secs=%s topics=%d",
        tagger.name, poll_interval_secs, len(_TOPICS),
    )
    owned_bus = False
    queue: asyncio.Queue[tuple[str, Event] | None] = asyncio.Queue()
    pump_tasks: list[asyncio.Task[None]] = []
    watch_tasks: list[asyncio.Task[None]] = []
    heartbeat_task: Optional[asyncio.Task[None]] = None
    control_task: Optional[asyncio.Task[None]] = None
    # One try/finally spans start-up AND the main loop. CancelledError is
    # not an ``Exception``, so the ``except Exception`` handlers below do
    # not stop it; without this outer guard a cancellation delivered
    # while awaiting ``connect()`` would skip the cleanup and leave the
    # already-started watch tasks running.
    try:
        # Hydrate per-lifter rule indexes. Each WatchableTagger
        # (CompositeTagger children + the RuleEngineTagger) owns its own
        # RuleIndex and drains store change events forever via
        # `watch_store`. Without these tasks every dispatch index stays
        # empty and no rule fires — the bus subscriptions work, the
        # pump tasks run, and tagger.tag() returns [] every call. Tasks
        # are independent of the bus, so this fan-out runs even in
        # poll-only mode.
        if isinstance(tagger, CompositeTagger):
            for watchable in tagger.iter_watchables():
                watch_tasks.append(asyncio.create_task(
                    _run_watch(watchable),
                ))
        try:
            if bus is None:
                try:
                    candidate = get_bus(client_name="ttp")
                    await candidate.connect()
                    bus = candidate
                    owned_bus = True
                except Exception as exc:  # noqa: BLE001
                    log.warning(
                        "ttp worker: bus unavailable, running in poll-only mode: %s",
                        exc,
                    )
                    bus = None
            if bus is not None:
                for pattern in _TOPICS:
                    pump_tasks.append(asyncio.create_task(
                        _pump(bus, queue, pattern),
                    ))
                heartbeat_task = asyncio.create_task(
                    _run_health_heartbeat(bus, "ttp"),
                )
                control_task = asyncio.create_task(
                    _run_control_listener_signal(bus, "ttp"),
                )
        except Exception as exc:  # noqa: BLE001
            log.warning(
                "ttp worker: bus setup failed, running in poll-only mode: %s", exc,
            )
        if shutdown is None:
            shutdown = asyncio.Event()
        try:
            while not shutdown.is_set():
                try:
                    item = await asyncio.wait_for(
                        queue.get(), timeout=float(poll_interval_secs),
                    )
                except asyncio.TimeoutError:
                    # Nothing arrived within the poll interval; loop
                    # around so the shutdown flag is re-checked.
                    continue
                if item is None:
                    continue
                topic, event = item
                await _process_event(topic, event, tagger, repo, bus)
        except (asyncio.CancelledError, KeyboardInterrupt):
            log.info("ttp worker stopped")
    finally:
        background: list[asyncio.Task[None]] = [*pump_tasks, *watch_tasks]
        background.extend(
            task for task in (heartbeat_task, control_task) if task is not None
        )
        # Request cancellation of everything first, then reap, so the
        # tasks wind down concurrently rather than one after another.
        for task in background:
            task.cancel()
        for task in background:
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await task
        if owned_bus and bus is not None:
            with contextlib.suppress(Exception):
                await bus.close()
async def _process_event(
    topic: str,
    event: Event,
    tagger: Tagger,
    repo: BaseRepository,
    bus: BaseBus | None,
) -> None:
    """Dispatch one event through the tagger, persist, publish if new.

    Loop-prevention invariant: ``ttp.tagged`` is published ONLY when
    :meth:`BaseRepository.insert_tags` returned a non-zero count. A
    replay of the same upstream event hits the idempotent
    ``INSERT OR IGNORE`` and writes zero rows → publishes zero events.

    The tagging, persistence and publish stages are each fault-isolated:
    an exception in any of them is logged with its traceback and the
    event is abandoned at that stage, so one bad event or a flaky bus
    cannot terminate the caller's loop. Nothing is published when *bus*
    is ``None`` (poll-only mode). Events whose topic has no known
    ``source_kind`` are dropped silently.
    """
    tagger_event = _build_event(topic, event.payload)
    if tagger_event is None:
        return
    with _span(
        "ttp.worker.tick",
        topic=topic,
        source_kind=tagger_event.source_kind,
    ):
        try:
            tags = await tagger.tag(tagger_event)
        except Exception:  # noqa: BLE001
            # Composite + TolerantTagger normally swallow per-lifter
            # blow-ups already; this is the worst-case backstop so a
            # single bad event can't take down the whole loop.
            log.exception(
                "ttp worker: tagger raised on topic=%r", topic,
            )
            return
        if not tags:
            return
        try:
            inserted = await repo.insert_tags(tags)
        except Exception:  # noqa: BLE001
            log.exception(
                "ttp worker: insert_tags failed on topic=%r", topic,
            )
            return
        if inserted <= 0:
            # Idempotent re-eval — the loop-prevention invariant
            # forbids publishing here.
            return
        if bus is None:
            return
        try:
            await _publish_tagged(bus, tags)
        except Exception:  # noqa: BLE001
            # The rows are already persisted at this point; a publish
            # failure only costs the notification, so log it and keep
            # the worker alive instead of propagating into the loop.
            log.exception(
                "ttp worker: publish failed on topic=%r", topic,
            )
async def _publish_tagged(bus: BaseBus, tags: list[TTPTag]) -> None:
    """Announce a batch of tags on the bus.

    Publishes one aggregate ``ttp.tagged`` event followed by one
    ``ttp.rule.fired.*`` event per distinct technique in *tags*, in
    sorted technique-id order. Techniques are deduplicated, so a batch
    in which three rules all emitted ``T1110`` still produces a single
    ``ttp.rule.fired.T1110`` event that lists all three tag UUIDs.

    The attacker / identity / session identifiers on every payload are
    taken from the first tag of the batch. An empty batch publishes
    nothing.
    """
    if not tags:
        return

    first = tags[0]
    correlation: dict[str, Any] = {
        "attacker_uuid": first.attacker_uuid,
        "identity_uuid": first.identity_uuid,
        "session_id": first.session_id,
    }

    # Single pass: collect tag UUIDs per technique, preserving batch order.
    uuids_by_technique: dict[Any, list[Any]] = {}
    for tag in tags:
        uuids_by_technique.setdefault(tag.technique_id, []).append(tag.uuid)
    techniques = sorted(uuids_by_technique)

    aggregate_payload: dict[str, Any] = {
        **correlation,
        "tag_uuids": [tag.uuid for tag in tags],
        "techniques_added": techniques,
    }
    await bus.publish(
        _topics.ttp(_topics.TTP_TAGGED),
        aggregate_payload,
        event_type=_topics.TTP_TAGGED,
    )

    for technique_id in techniques:
        per_tech_payload: dict[str, Any] = {
            "technique_id": technique_id,
            "tag_uuids": uuids_by_technique[technique_id],
            **correlation,
        }
        await bus.publish(
            _topics.ttp_rule_fired(technique_id),
            per_tech_payload,
            event_type=_topics.TTP_RULE_FIRED,
        )
async def _run_watch(watchable: Any) -> None:
    """Await one lifter's ``watch_store()`` until it ends or fails.

    Cancellation is re-raised untouched so the worker's shutdown path
    can reap the task. Any other exception is reduced to a warning and
    the task simply finishes, leaving the worker itself running; nothing
    in this function restarts the watch afterwards.

    The lifter is identified in the log by its ``name`` attribute, or by
    its class name when it has none.
    """
    label = getattr(watchable, "name", watchable.__class__.__name__)
    try:
        await watchable.watch_store()
    except asyncio.CancelledError:
        raise
    except Exception as err:  # noqa: BLE001
        message = (
            "ttp worker: watch_store for %s died (%s); index will not "
            "hot-reload until next worker restart"
        )
        log.warning(message, label, err)
async def _pump(
    bus: BaseBus,
    queue: "asyncio.Queue[tuple[str, Event] | None]",
    pattern: str,
) -> None:
    """Copy every event delivered for *pattern* onto *queue*.

    Each event is enqueued as ``(event.topic, event)`` — the concrete
    topic of the event, not the subscription pattern, so wildcard
    subscriptions keep their per-event topic.

    Cancellation is re-raised. Any other failure while subscribing or
    iterating is logged as a warning and ends this pump; the exception
    is not propagated to the worker.
    """
    try:
        subscription = bus.subscribe(pattern)
        async with subscription:
            async for incoming in subscription:
                await queue.put((incoming.topic, incoming))
    except asyncio.CancelledError:
        raise
    except Exception as err:  # noqa: BLE001
        log.warning(
            "ttp worker: subscriber for %s died (%s); falling back to poll",
            pattern, err,
        )
# ``_TOPICS`` is exported despite its leading underscore because the
# subscription-wiring test introspects it instead of running the loop.
__all__ = ["run_ttp_worker_loop", "_TOPICS"]