feat(correlation/attribution): wire bus handler, persist state (Phase 4)

attribution_worker.handle_observation_event now executes the full
end-to-end path:

* ensure stub identity (Phase 1)
* observations_for_identity_primitive() — new repo helper joining
  observations through attackers.identity_id, so v1's clusterer
  gets cross-attacker rollup for free
* aggregate_observations() with ValueKind dispatched off the BEHAVE
  PRIMITIVE_REGISTRY; unknown primitives default to categorical
* upsert_attribution_state() — last_change_ts locked when state is
  unchanged so the dashboard can render "stable since X"
* publish attribution.profile.state_changed only on transition;
  idempotent re-runs over the same observation set fire nothing
  (loop-prevention invariant matching ttp.tagged)

Tests:
* 5 end-to-end attribution scenarios over in-memory SQLite + FakeBus.
* test_base_repo's DummyRepo + coverage body now stub every abstract
  surface BaseRepository declares — the 6 added by this branch plus
  the 12 left un-stubbed by earlier work (BEHAVE Phase 1, TTP
  rollups, iter helpers). Previously the coverage test could not
  even instantiate DummyRepo.
* test_aggregate_categorical's dispatcher rejection updated for the
  Phase 3 + 4 contract — ValueError on unknown kinds, not
  NotImplementedError.
This commit is contained in:
2026-05-09 02:16:12 -04:00
parent c39802a4bb
commit dd265d7520
6 changed files with 536 additions and 17 deletions

View File

@@ -26,12 +26,25 @@ from decnet.bus import topics as _topics
from decnet.bus.base import BaseBus
from decnet.bus.factory import get_bus
from decnet.bus.publish import (
publish_safely,
run_control_listener_signal as _run_control_listener_signal,
run_health_heartbeat as _run_health_heartbeat,
)
from decnet.correlation.attribution.aggregate import aggregate_observations
from decnet.logging import get_logger
from decnet.web.db.repository import BaseRepository
try:
from decnet_behave_shell.spec import (
PRIMITIVE_REGISTRY,
ValueKind,
)
_BEHAVE_REGISTRY_AVAILABLE = True
except ImportError: # pragma: no cover
PRIMITIVE_REGISTRY = {}
ValueKind = None
_BEHAVE_REGISTRY_AVAILABLE = False
log = get_logger("correlation.attribution_worker")
_WORKER_NAME = "attribution"
@@ -156,13 +169,103 @@ async def handle_observation_event(
attacker_uuid,
)
return
# Phase 4 will run the merger here and emit
# ``attribution.profile.state_changed`` on transition. Phase 1
# ends with stub materialisation only.
log.debug(
"attribution worker: stub identity=%s for attacker=%s primitive=%s",
identity_uuid, attacker_uuid, primitive,
primitive_str = str(primitive)
# Load the full per-(identity, primitive) observation series.
# v0 with 1:1 stub identities, this is the single attacker's
# series; v1's clusterer makes it a cross-attacker union.
observations = await repo.observations_for_identity_primitive(
identity_uuid, primitive_str,
)
if not observations:
log.debug(
"attribution worker: no observations yet for identity=%s "
"primitive=%s (race with upsert)",
identity_uuid, primitive_str,
)
return
# Run merger.
value_kind = _value_kind_for(primitive_str)
new_state = aggregate_observations(observations, value_kind=value_kind)
# Load prior state to detect transitions.
prior = await repo.get_attribution_state(identity_uuid, primitive_str)
state_changed = prior is None or prior.get("state") != new_state.state
# Persist. last_change_ts is locked to the prior row when state is
# unchanged so the dashboard's "stable since" timestamp doesn't
# reset on every observation.
if prior is not None and not state_changed:
last_change_ts = float(prior.get("last_change_ts", new_state.last_observation_ts))
else:
last_change_ts = new_state.last_observation_ts
await repo.upsert_attribution_state({
"identity_uuid": identity_uuid,
"primitive": primitive_str,
"current_value": new_state.current_value,
"state": new_state.state,
"confidence": new_state.confidence,
"observation_count": new_state.observation_count,
"last_change_ts": last_change_ts,
"last_observation_ts": new_state.last_observation_ts,
})
# Emit state_changed only on transition. Idempotent re-runs (same
# observations, same merger output) produce no event — matches
# the loop-prevention invariant that ttp.tagged uses.
if state_changed and bus is not None:
await publish_safely(
bus,
_topics.attribution(_topics.ATTRIBUTION_PROFILE_STATE_CHANGED),
{
"identity_uuid": identity_uuid,
"primitive": primitive_str,
"old_state": prior.get("state") if prior else None,
"new_state": new_state.state,
"current_value": new_state.current_value,
"confidence": new_state.confidence,
"observation_count": new_state.observation_count,
"ts": new_state.last_observation_ts,
},
event_type=_topics.ATTRIBUTION_PROFILE_STATE_CHANGED,
)
log.info(
"attribution worker: identity=%s primitive=%s %s -> %s confidence=%.2f",
identity_uuid, primitive_str,
(prior or {}).get("state") or "<new>", new_state.state,
new_state.confidence,
)
def _value_kind_for(primitive: str) -> str:
    """Pick the merger tag (``ValueKind`` string) for a BEHAVE primitive.

    The BEHAVE registry entry's ``kind`` is collapsed onto the three
    merger tags this function can return:

    * ``ValueKind.NUMERIC`` → ``"numeric"``
    * ``ValueKind.HASH`` → ``"hash"``
    * every other kind → ``"categorical"`` (per the design notes this
      covers CATEGORICAL, BOOL, FREE_STRING and ARRAY until a
      specialised merger exists)

    ``"categorical"`` is also the answer when the BEHAVE registry
    could not be imported or when *primitive* is not registered.
    """
    if _BEHAVE_REGISTRY_AVAILABLE:
        entry = PRIMITIVE_REGISTRY.get(primitive)
        if entry is not None and ValueKind is not None:
            declared = entry.kind
            if declared is ValueKind.NUMERIC:
                return "numeric"
            if declared is ValueKind.HASH:
                return "hash"
    # Registry unavailable, primitive unregistered, or a kind without a
    # dedicated merger: fall back to the categorical merger.
    return "categorical"
def _payload_of(event: Any) -> dict[str, Any]:

View File

@@ -341,6 +341,20 @@ class BaseRepository(ABC):
ordered by ``ts`` ASC. Empty list when none."""
raise NotImplementedError
@abstractmethod
async def observations_for_identity_primitive(
    self, identity_uuid: str, primitive: str,
) -> list[dict[str, Any]]:
    """Return the ``primitive`` observation series for one identity.

    Implementations collect the observations of ``primitive`` from
    every attacker that rolls up to ``identity_uuid`` and return them
    sorted by ``ts`` ascending. An identity without any observation
    of this primitive yields an empty list.

    With v0's 1:1 stub identities the result equals
    ``observations_time_series(attacker_uuid, primitive)``; the
    cross-attacker union only becomes meaningful once v1's clusterer
    maps several attackers onto one identity.
    """
    raise NotImplementedError
@abstractmethod
async def has_observations_for_evidence(self, evidence_ref: str) -> bool:
"""True iff any observation row carries this ``evidence_ref``.

View File

@@ -25,7 +25,7 @@ from typing import Any, Optional
from sqlalchemy import desc, func, select
from sqlmodel import col
from decnet.web.db.models import ObservationRow
from decnet.web.db.models import Attacker, ObservationRow
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
@@ -164,6 +164,34 @@ class ObservationsMixin(_MixinBase):
return None
return row.model_dump(mode="json")
async def observations_for_identity_primitive(
    self, identity_uuid: str, primitive: str,
) -> list[dict[str, Any]]:
    """Fetch the *primitive* observations of every attacker linked to
    *identity_uuid*, oldest first.

    Observations are joined to their attacker row and filtered on
    ``Attacker.identity_id``, so all attackers pointing at the same
    identity contribute to one series. Under v0's 1:1 stub identities
    that is the same set ``observations_time_series(attacker_uuid,
    primitive)`` returns; v1's clusterer is what makes the union span
    several attackers.

    Each returned dict carries only ``ts``, ``value`` and
    ``confidence``. The list is empty when nothing matches.
    """
    query = (
        select(ObservationRow)
        .join(Attacker, ObservationRow.attacker_uuid == Attacker.uuid)
        # Consecutive .where() calls are AND-ed together.
        .where(Attacker.identity_id == identity_uuid)
        .where(ObservationRow.primitive == primitive)
        .order_by(ObservationRow.ts)
    )
    async with self._session() as session:
        result = await session.execute(query)
        matched = result.scalars().all()
        return [
            {
                "ts": observation.ts,
                "value": observation.value,
                "confidence": observation.confidence,
            }
            for observation in matched
        ]
async def has_observations_for_evidence(
self, evidence_ref: str,
) -> bool:

View File

@@ -165,11 +165,12 @@ def test_dispatcher_routes_categorical() -> None:
assert a == b == c
def test_dispatcher_rejects_unimplemented_kinds() -> None:
"""numeric / hash kinds land in Phase 3; surface the gap loudly
so a misuse doesn't silently fall through to categorical."""
def test_dispatcher_rejects_unknown_value_kind() -> None:
"""Unknown ValueKind tags surface as ValueError so misuse doesn't
silently fall through to categorical. Phase 3 wired numeric +
hash; the rejection is for typos and v1 kinds that haven't
landed yet."""
import pytest
obs = _pad(value=5000.0, count=5)
for kind in ("numeric", "hash"):
with pytest.raises(NotImplementedError):
aggregate_observations(obs, value_kind=kind)
obs = _pad(value="typed", count=5)
with pytest.raises(ValueError):
aggregate_observations(obs, value_kind="bogus_kind")

View File

@@ -0,0 +1,275 @@
"""Phase 4 — end-to-end worker wiring.
Observation event → stub identity → load series → merger → upsert
state → emit ``attribution.profile.state_changed`` on transition.
Phase 1 covered stub-only wiring; this file pins the merger /
persist / publish path against an in-memory SQLite + FakeBus.
"""
from __future__ import annotations
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import pytest
from decnet.bus import topics as _topics
from decnet.bus.fake import FakeBus
from decnet.correlation import attribution_worker as _aw
from decnet.web.db.factory import get_repository
@pytest.fixture
async def repo(tmp_path: Path):
    """Provide an initialised repository backed by ``phase4.db`` inside
    pytest's per-test temporary directory."""
    database_file = tmp_path / "phase4.db"
    repository = get_repository(db_path=str(database_file))
    await repository.initialize()
    return repository
@pytest.fixture
async def attacker_uuid(repo) -> str:
    """Upsert one attacker (ip ``10.0.0.5``, first/last seen = now in
    UTC) and hand back whatever ``upsert_attacker`` returns for it."""
    seen_at = datetime.now(timezone.utc)
    attacker_record = {
        "ip": "10.0.0.5",
        "first_seen": seen_at,
        "last_seen": seen_at,
    }
    return await repo.upsert_attacker(attacker_record)
def _envelope(
    *,
    primitive: str,
    value: Any,
    attacker_uuid: str,
    evidence_ref: str,
    ts: float,
    confidence: float = 0.9,
) -> dict[str, Any]:
    """Build a version-1 observation envelope dict for the tests.

    The observation window is a single instant (``window_start_ts``
    and ``window_end_ts`` both equal *ts*), ``source`` is always
    ``"test"``, and the ``id`` is derived from *evidence_ref* and
    *primitive*.
    """
    observation_id = f"obs-{evidence_ref}-{primitive}"
    return dict(
        id=observation_id,
        primitive=primitive,
        value=value,
        confidence=confidence,
        window_start_ts=ts,
        window_end_ts=ts,
        source="test",
        evidence_ref=evidence_ref,
        envelope_v=1,
        ts=ts,
        attacker_uuid=attacker_uuid,
    )
def _bus_event(payload: dict[str, Any]) -> dict[str, Any]:
    """Return *payload* unchanged, to be passed as the bus event.

    NOTE(review): the worker presumably reads ``event.payload`` and
    falls back to treating the event itself as the payload dict, which
    is why a bare dict is enough here — confirm against the worker's
    ``_payload_of``.
    """
    return payload
async def _seed_observations(
    repo, attacker_uuid: str, primitive: str, values: list[Any],
    *, start_ts: float = 1714000000.0,
) -> None:
    """Upsert one observation per entry of *values*, spaced 60 seconds
    apart beginning at *start_ts*.

    The timestamp is embedded in ``evidence_ref`` so that separate
    calls reusing the same index but a different *start_ts* still
    produce distinct rows.
    """
    for offset, observed in enumerate(values):
        stamp = start_ts + offset * 60.0
        envelope = _envelope(
            primitive=primitive,
            value=observed,
            attacker_uuid=attacker_uuid,
            evidence_ref=f"shard:test#{primitive}-{stamp}",
            ts=stamp,
        )
        await repo.upsert_observation(envelope)
@pytest.mark.anyio
async def test_handler_writes_unknown_below_threshold(
    repo, attacker_uuid: str,
) -> None:
    """With only two observations of a primitive (fewer than
    MIN_OBSERVATIONS_FOR_STATE) the handler still persists a state
    row, and that row's state is ``'unknown'``."""
    primitive = "motor.input_modality"
    bus = FakeBus()
    await bus.connect()

    await _seed_observations(
        repo, attacker_uuid, primitive, ["typed", "typed"],
    )
    event = _bus_event({
        "attacker_uuid": attacker_uuid,
        "primitive": primitive,
    })
    await _aw.handle_observation_event(bus, repo, event)

    # The handler materialises the stub identity; look it up through
    # the attacker row to find the persisted state.
    attacker_row = await repo.get_attacker_by_uuid(attacker_uuid)
    assert attacker_row is not None
    persisted = await repo.get_attribution_state(
        attacker_row["identity_id"], primitive,
    )
    assert persisted is not None
    assert persisted["state"] == "unknown"
    await bus.close()
@pytest.mark.anyio
async def test_handler_emits_state_changed_on_transition(
    repo, attacker_uuid: str, monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Feeding observations one at a time across
    MIN_OBSERVATIONS_FOR_STATE publishes exactly two transitions —
    <new>→unknown, then unknown→stable — and nothing for the runs in
    between where the state did not move."""
    primitive = "motor.input_modality"
    published: list[dict[str, Any]] = []

    async def _record(_bus, topic, payload, *, event_type=""):
        published.append({"topic": topic, "payload": payload})

    bus = FakeBus()
    await bus.connect()
    # Intercept the worker's publish call so every emitted payload is
    # recorded synchronously, without going through the bus.
    monkeypatch.setattr(_aw, "publish_safely", _record)

    for step in range(5):
        await _seed_observations(
            repo, attacker_uuid, primitive,
            ["typed"], start_ts=1714000000.0 + step * 60.0,
        )
        event = _bus_event({
            "attacker_uuid": attacker_uuid,
            "primitive": primitive,
        })
        await _aw.handle_observation_event(bus, repo, event)

    transitions = [entry["payload"]["new_state"] for entry in published]
    assert transitions == ["unknown", "stable"], transitions
    # Each transition payload names the state it left: nothing for the
    # very first row, then "unknown" for the flip to stable.
    assert published[0]["payload"]["old_state"] is None
    assert published[1]["payload"]["old_state"] == "unknown"
    await bus.close()
@pytest.mark.anyio
async def test_handler_no_event_when_state_unchanged(
    repo, attacker_uuid: str,
) -> None:
    """Re-running the merger over an unchanged observation set must
    not emit a duplicate state_changed event (loop-prevention).

    The first run over five identical observations must publish
    exactly one event; three further runs with no new observations
    must publish none.
    """
    import asyncio

    primitive = "motor.input_modality"
    bus = FakeBus()
    await bus.connect()
    captured: list[Any] = []
    sub = bus.subscribe(
        _topics.attribution(_topics.ATTRIBUTION_PROFILE_STATE_CHANGED),
    )

    async def drain() -> None:
        async with sub:
            async for ev in sub:
                captured.append(ev)

    drain_task = asyncio.create_task(drain())
    try:
        # Yield once so the drain task can enter the subscription
        # before anything is published.
        await asyncio.sleep(0)
        await _seed_observations(
            repo, attacker_uuid, primitive,
            ["typed"] * 5,
        )
        # First run: <new> → stable, fires event.
        await _aw.handle_observation_event(bus, repo, _bus_event({
            "attacker_uuid": attacker_uuid,
            "primitive": primitive,
        }))
        await asyncio.sleep(0.05)
        first_count = len(captured)
        # Without this check the test would pass vacuously whenever the
        # subscription delivers nothing at all (0 == 0 below).
        assert first_count == 1, (
            f"expected exactly one state_changed event after the first "
            f"run, got {first_count}"
        )
        # Re-run with no new observations: state stays "stable", no event.
        for _ in range(3):
            await _aw.handle_observation_event(bus, repo, _bus_event({
                "attacker_uuid": attacker_uuid,
                "primitive": primitive,
            }))
        await asyncio.sleep(0.05)
        assert len(captured) == first_count, (
            "state didn't change; no additional events should fire"
        )
    finally:
        # Always stop the subscriber and close the bus, even when an
        # assertion failed. Awaiting the cancelled task (collecting its
        # outcome instead of raising) avoids leaving a pending task and
        # an unretrieved exception behind.
        drain_task.cancel()
        await asyncio.gather(drain_task, return_exceptions=True)
        await bus.close()
@pytest.mark.anyio
async def test_handler_locks_last_change_ts_when_unchanged(
    repo, attacker_uuid: str,
) -> None:
    """``last_change_ts`` stays put while the state is unchanged, so
    the dashboard can show 'stable since X' instead of 'stable since
    the most recent observation'; ``last_observation_ts`` still moves
    forward."""
    primitive = "motor.input_modality"
    bus = FakeBus()
    await bus.connect()

    async def _run_handler() -> None:
        await _aw.handle_observation_event(bus, repo, _bus_event({
            "attacker_uuid": attacker_uuid,
            "primitive": primitive,
        }))

    await _seed_observations(
        repo, attacker_uuid, primitive,
        ["typed"] * 5,
    )
    await _run_handler()

    attacker_row = await repo.get_attacker_by_uuid(attacker_uuid)
    assert attacker_row is not None
    identity_uuid = attacker_row["identity_id"]
    before = await repo.get_attribution_state(identity_uuid, primitive)
    assert before is not None
    locked_ts = before["last_change_ts"]

    # One more observation with the same value, later in time, then
    # run the handler again.
    await _seed_observations(
        repo, attacker_uuid, primitive,
        ["typed"], start_ts=1714010000.0,
    )
    await _run_handler()

    after = await repo.get_attribution_state(identity_uuid, primitive)
    assert after is not None
    assert after["last_change_ts"] == locked_ts
    # The observation timestamp, by contrast, did advance.
    assert after["last_observation_ts"] > locked_ts
    await bus.close()
@pytest.mark.anyio
async def test_handler_routes_numeric_primitive(
    repo, attacker_uuid: str,
) -> None:
    """A primitive whose registry kind is NUMERIC is handled by the
    numeric merger: the persisted ``current_value`` is a float and
    five close readings settle into ``'stable'``."""
    # toolchain.c2.beacon_interval_ms is registered NUMERIC in BEHAVE.
    primitive = "toolchain.c2.beacon_interval_ms"
    readings = [5000.0, 5050.0, 4980.0, 5020.0, 5010.0]

    bus = FakeBus()
    await bus.connect()
    await _seed_observations(repo, attacker_uuid, primitive, readings)
    event = _bus_event({
        "attacker_uuid": attacker_uuid,
        "primitive": primitive,
    })
    await _aw.handle_observation_event(bus, repo, event)

    attacker_row = await repo.get_attacker_by_uuid(attacker_uuid)
    assert attacker_row is not None
    persisted = await repo.get_attribution_state(
        attacker_row["identity_id"], primitive,
    )
    assert persisted is not None
    # The numeric merger stores a number here, not a string token.
    assert isinstance(persisted["current_value"], float)
    assert persisted["state"] == "stable"
    await bus.close()

View File

@@ -42,6 +42,24 @@ class DummyRepo(BaseRepository):
async def upsert_observation(self, data): await super().upsert_observation(data); return ""
async def latest_observation_per_primitive(self, attacker_uuid): await super().latest_observation_per_primitive(attacker_uuid); return {}
async def observations_time_series(self, attacker_uuid, primitive): await super().observations_time_series(attacker_uuid, primitive); return []
async def observations_for_identity_primitive(self, identity_uuid, primitive):
await super().observations_for_identity_primitive(identity_uuid, primitive)
return []
# Attribution engine v0 (ATTRIBUTION-ENGINE.md Phase 1)
async def ensure_stub_identity_for_attacker(self, attacker_uuid):
await super().ensure_stub_identity_for_attacker(attacker_uuid)
return None
async def upsert_attribution_state(self, data):
await super().upsert_attribution_state(data)
async def get_attribution_state(self, identity_uuid, primitive):
await super().get_attribution_state(identity_uuid, primitive)
return None
async def get_attribution_state_for_identity(self, identity_uuid):
await super().get_attribution_state_for_identity(identity_uuid)
return []
async def list_multi_actor_identities(self):
await super().list_multi_actor_identities()
return []
async def increment_smtp_target(self, u, d): await super().increment_smtp_target(u, d)
async def list_smtp_targets(self, u): await super().list_smtp_targets(u)
async def get_attacker_stored_mail(self, u): await super().get_attacker_stored_mail(u)
@@ -86,6 +104,38 @@ class DummyRepo(BaseRepository):
async def set_identity_campaign_id(self, i, c): await super().set_identity_campaign_id(i, c)
async def list_all_campaigns(self): await super().list_all_campaigns(); return []
async def update_campaign_merged_into(self, u, w): await super().update_campaign_merged_into(u, w)
# Pre-existing abstract surface that DummyRepo never stubbed —
# added here so the coverage test exercises the full BaseRepository
# contract.
async def get_log_histogram(self, *a, **kw):
await super().get_log_histogram(*a, **kw); return []
async def has_observations_for_evidence(self, evidence_ref):
await super().has_observations_for_evidence(evidence_ref); return False
async def get_attacker_uuid_by_ip(self, ip):
await super().get_attacker_uuid_by_ip(ip); return None
# TTP rollup surface (TTP_TAGGING.md)
async def insert_tags(self, rows): await super().insert_tags(rows); return 0
async def list_techniques_by_identity(self, uuid):
await super().list_techniques_by_identity(uuid); return []
async def list_techniques_by_attacker(self, uuid):
await super().list_techniques_by_attacker(uuid); return []
async def list_techniques_by_campaign(self, uuid):
await super().list_techniques_by_campaign(uuid); return []
async def list_techniques_by_session(self, sid):
await super().list_techniques_by_session(sid); return []
async def list_tags_by_scope_and_technique(self, **kw):
await super().list_tags_by_scope_and_technique(**kw); return []
async def list_distinct_techniques(self):
await super().list_distinct_techniques(); return []
# Iter helpers — async generators, can't `await super()` on them
# because the base raises in the body before any yield. Just yield
# nothing so the consumer's ``async for`` exits cleanly.
async def iter_attacker_commands_since(self, since):
return
yield # unreachable, marks the function as a generator
async def iter_canary_triggers_since(self, since):
return
yield
@pytest.mark.asyncio
async def test_base_repo_coverage():
@@ -127,9 +177,26 @@ async def test_base_repo_coverage():
await dr.upsert_attacker_behavior("a", {})
await dr.get_attacker_behavior("a")
await dr.get_behaviors_for_ips({"1.1.1.1"})
await dr.upsert_observation({})
await dr.latest_observation_per_primitive("a")
await dr.observations_time_series("a", "motor.input_modality")
# Observation surface — bases raise NotImplementedError.
with pytest.raises(NotImplementedError):
await dr.upsert_observation({})
with pytest.raises(NotImplementedError):
await dr.latest_observation_per_primitive("a")
with pytest.raises(NotImplementedError):
await dr.observations_time_series("a", "motor.input_modality")
# observations_for_identity_primitive + attribution engine v0
with pytest.raises(NotImplementedError):
await dr.observations_for_identity_primitive("i", "motor.input_modality")
with pytest.raises(NotImplementedError):
await dr.ensure_stub_identity_for_attacker("a")
with pytest.raises(NotImplementedError):
await dr.upsert_attribution_state({})
with pytest.raises(NotImplementedError):
await dr.get_attribution_state("i", "motor.input_modality")
with pytest.raises(NotImplementedError):
await dr.get_attribution_state_for_identity("i")
with pytest.raises(NotImplementedError):
await dr.list_multi_actor_identities()
await dr.increment_smtp_target("uuid", "corp.com")
await dr.list_smtp_targets("uuid")
await dr.get_attacker_stored_mail("uuid")
@@ -174,6 +241,37 @@ async def test_base_repo_coverage():
await dr.update_campaign_merged_into("c", "d")
await dr.update_campaign_merged_into("c", None)
# Pre-existing abstract surface. get_log_histogram's base body
# is ``pass`` (returns None), the rest raise NotImplementedError.
from datetime import datetime, timezone
await dr.get_log_histogram()
with pytest.raises(NotImplementedError):
await dr.has_observations_for_evidence("shard:x#1")
with pytest.raises(NotImplementedError):
await dr.get_attacker_uuid_by_ip("1.1.1.1")
with pytest.raises(NotImplementedError):
await dr.insert_tags([])
with pytest.raises(NotImplementedError):
await dr.list_techniques_by_identity("i")
with pytest.raises(NotImplementedError):
await dr.list_techniques_by_attacker("a")
with pytest.raises(NotImplementedError):
await dr.list_techniques_by_campaign("c")
with pytest.raises(NotImplementedError):
await dr.list_techniques_by_session("s")
with pytest.raises(NotImplementedError):
await dr.list_tags_by_scope_and_technique(
scope="identity", uuid="i", technique_id="T1059",
)
with pytest.raises(NotImplementedError):
await dr.list_distinct_techniques()
# Iter helpers: just consume the empty generator.
now = datetime.now(timezone.utc)
async for _ in dr.iter_attacker_commands_since(now):
pass
async for _ in dr.iter_canary_triggers_since(now):
pass
# Swarm methods: default NotImplementedError on BaseRepository. Covering
# them here keeps the coverage contract honest for the swarm CRUD surface.
for coro, args in [