diff --git a/decnet/correlation/attribution_worker.py b/decnet/correlation/attribution_worker.py index 653fbccd..38487c88 100644 --- a/decnet/correlation/attribution_worker.py +++ b/decnet/correlation/attribution_worker.py @@ -26,12 +26,25 @@ from decnet.bus import topics as _topics from decnet.bus.base import BaseBus from decnet.bus.factory import get_bus from decnet.bus.publish import ( + publish_safely, run_control_listener_signal as _run_control_listener_signal, run_health_heartbeat as _run_health_heartbeat, ) +from decnet.correlation.attribution.aggregate import aggregate_observations from decnet.logging import get_logger from decnet.web.db.repository import BaseRepository +try: + from decnet_behave_shell.spec import ( + PRIMITIVE_REGISTRY, + ValueKind, + ) + _BEHAVE_REGISTRY_AVAILABLE = True +except ImportError: # pragma: no cover + PRIMITIVE_REGISTRY = {} + ValueKind = None + _BEHAVE_REGISTRY_AVAILABLE = False + log = get_logger("correlation.attribution_worker") _WORKER_NAME = "attribution" @@ -156,13 +169,103 @@ async def handle_observation_event( attacker_uuid, ) return - # Phase 4 will run the merger here and emit - # ``attribution.profile.state_changed`` on transition. Phase 1 - # ends with stub materialisation only. - log.debug( - "attribution worker: stub identity=%s for attacker=%s primitive=%s", - identity_uuid, attacker_uuid, primitive, + primitive_str = str(primitive) + + # Load the full per-(identity, primitive) observation series. + # v0 with 1:1 stub identities, this is the single attacker's + # series; v1's clusterer makes it a cross-attacker union. + observations = await repo.observations_for_identity_primitive( + identity_uuid, primitive_str, ) + if not observations: + log.debug( + "attribution worker: no observations yet for identity=%s " + "primitive=%s (race with upsert)", + identity_uuid, primitive_str, + ) + return + + # Run merger. + value_kind = _value_kind_for(primitive_str) + new_state = aggregate_observations(observations, value_kind=value_kind) + + # Load prior state to detect transitions. + prior = await repo.get_attribution_state(identity_uuid, primitive_str) + state_changed = prior is None or prior.get("state") != new_state.state + + # Persist. last_change_ts is locked to the prior row when state is + # unchanged so the dashboard's "stable since" timestamp doesn't + # reset on every observation. + if prior is not None and not state_changed: + last_change_ts = float(prior.get("last_change_ts", new_state.last_observation_ts)) + else: + last_change_ts = new_state.last_observation_ts + await repo.upsert_attribution_state({ + "identity_uuid": identity_uuid, + "primitive": primitive_str, + "current_value": new_state.current_value, + "state": new_state.state, + "confidence": new_state.confidence, + "observation_count": new_state.observation_count, + "last_change_ts": last_change_ts, + "last_observation_ts": new_state.last_observation_ts, + }) + + # Emit state_changed only on transition. Idempotent re-runs (same + # observations, same merger output) produce no event — matches + # the loop-prevention invariant that ttp.tagged uses. + if state_changed and bus is not None: + await publish_safely( + bus, + _topics.attribution(_topics.ATTRIBUTION_PROFILE_STATE_CHANGED), + { + "identity_uuid": identity_uuid, + "primitive": primitive_str, + "old_state": prior.get("state") if prior else None, + "new_state": new_state.state, + "current_value": new_state.current_value, + "confidence": new_state.confidence, + "observation_count": new_state.observation_count, + "ts": new_state.last_observation_ts, + }, + event_type=_topics.ATTRIBUTION_PROFILE_STATE_CHANGED, + ) + log.info( + "attribution worker: identity=%s primitive=%s %s -> %s confidence=%.2f", + identity_uuid, primitive_str, + (prior or {}).get("state") or "", new_state.state, + new_state.confidence, + ) + + +def _value_kind_for(primitive: str) -> str: + """Resolve a BEHAVE primitive name to the merger's ValueKind tag. + + Maps the BEHAVE registry's ``ValueKind`` enum onto the three + mergers the engine ships: + + * ``CATEGORICAL`` / ``BOOL`` / ``FREE_STRING`` / ``ARRAY`` → + ``"categorical"`` (BOOL is a 2-cardinality categorical; + FREE_STRING and ARRAY collapse to opaque-token categorical + until a v1 specialised merger lands) + * ``NUMERIC`` → ``"numeric"`` + * ``HASH`` → ``"hash"`` + + Unknown primitives (registry miss) default to categorical — the + safest fallback because the categorical merger is one-outlier- + tolerant and won't lie about confidence on noisy categorical + data the way a numeric merger would on non-numeric values. + """ + if not _BEHAVE_REGISTRY_AVAILABLE: + return "categorical" + spec = PRIMITIVE_REGISTRY.get(primitive) + if spec is None or ValueKind is None: + return "categorical" + if spec.kind is ValueKind.NUMERIC: + return "numeric" + if spec.kind is ValueKind.HASH: + return "hash" + return "categorical" def _payload_of(event: Any) -> dict[str, Any]: diff --git a/decnet/web/db/repository.py b/decnet/web/db/repository.py index 29b6e4a9..b051c3fd 100644 --- a/decnet/web/db/repository.py +++ b/decnet/web/db/repository.py @@ -341,6 +341,20 @@ class BaseRepository(ABC): ordered by ``ts`` ASC. Empty list when none.""" raise NotImplementedError + @abstractmethod + async def observations_for_identity_primitive( + self, identity_uuid: str, primitive: str, + ) -> list[dict[str, Any]]: + """Every observation of ``primitive`` across all attackers + rolling up to ``identity_uuid``, ordered by ``ts`` ASC. + + Empty list when the identity has no observations of this + primitive. v0 with 1:1 stub identities returns the same set + as ``observations_time_series(attacker_uuid, primitive)``; + v1's clusterer makes the union meaningful. + """ + raise NotImplementedError + @abstractmethod async def has_observations_for_evidence(self, evidence_ref: str) -> bool: """True iff any observation row carries this ``evidence_ref``. diff --git a/decnet/web/db/sqlmodel_repo/observations.py b/decnet/web/db/sqlmodel_repo/observations.py index dc103e1f..2a90c9f7 100644 --- a/decnet/web/db/sqlmodel_repo/observations.py +++ b/decnet/web/db/sqlmodel_repo/observations.py @@ -25,7 +25,7 @@ from typing import Any, Optional from sqlalchemy import desc, func, select from sqlmodel import col -from decnet.web.db.models import ObservationRow +from decnet.web.db.models import Attacker, ObservationRow from decnet.web.db.sqlmodel_repo._helpers import _MixinBase @@ -164,6 +164,34 @@ class ObservationsMixin(_MixinBase): return None return row.model_dump(mode="json") + async def observations_for_identity_primitive( + self, identity_uuid: str, primitive: str, + ) -> list[dict[str, Any]]: + """Union of every observation of *primitive* across the + attackers rolling up to *identity_uuid*, ordered ``ts`` ASC. + + v0 with 1:1 stub identities returns the same set as + ``observations_time_series(attacker_uuid, primitive)``. + v1's clusterer makes the union load-bearing — multiple + attackers point at the same identity_id and this query is + what gives the merger a cross-attacker view. + """ + async with self._session() as session: + stmt = ( + select(ObservationRow) + .join(Attacker, ObservationRow.attacker_uuid == Attacker.uuid) + .where( + Attacker.identity_id == identity_uuid, + ObservationRow.primitive == primitive, + ) + .order_by(ObservationRow.ts) + ) + rows = (await session.execute(stmt)).scalars().all() + return [ + {"ts": row.ts, "value": row.value, "confidence": row.confidence} + for row in rows + ] + async def has_observations_for_evidence( self, evidence_ref: str, ) -> bool: diff --git a/tests/correlation/attribution/test_aggregate_categorical.py b/tests/correlation/attribution/test_aggregate_categorical.py index a00d994e..1262dbb6 100644 --- a/tests/correlation/attribution/test_aggregate_categorical.py +++ b/tests/correlation/attribution/test_aggregate_categorical.py @@ -165,11 +165,12 @@ def test_dispatcher_routes_categorical() -> None: assert a == b == c -def test_dispatcher_rejects_unimplemented_kinds() -> None: - """numeric / hash kinds land in Phase 3; surface the gap loudly - so a misuse doesn't silently fall through to categorical.""" +def test_dispatcher_rejects_unknown_value_kind() -> None: + """Unknown ValueKind tags surface as ValueError so misuse doesn't + silently fall through to categorical. Phase 3 wired numeric + + hash; the rejection is for typos and v1 kinds that haven't + landed yet.""" import pytest - obs = _pad(value=5000.0, count=5) - for kind in ("numeric", "hash"): - with pytest.raises(NotImplementedError): - aggregate_observations(obs, value_kind=kind) + obs = _pad(value="typed", count=5) + with pytest.raises(ValueError): + aggregate_observations(obs, value_kind="bogus_kind") diff --git a/tests/correlation/attribution/test_attribution_worker_phase4.py b/tests/correlation/attribution/test_attribution_worker_phase4.py new file mode 100644 index 00000000..3edbb239 --- /dev/null +++ b/tests/correlation/attribution/test_attribution_worker_phase4.py @@ -0,0 +1,275 @@ +"""Phase 4 — end-to-end worker wiring. + +Observation event → stub identity → load series → merger → upsert +state → emit ``attribution.profile.state_changed`` on transition. + +Phase 1 covered stub-only wiring; this file pins the merger / +persist / publish path against an in-memory SQLite + FakeBus. +""" +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import pytest + +from decnet.bus import topics as _topics +from decnet.bus.fake import FakeBus +from decnet.correlation import attribution_worker as _aw +from decnet.web.db.factory import get_repository + + +@pytest.fixture +async def repo(tmp_path: Path): + r = get_repository(db_path=str(tmp_path / "phase4.db")) + await r.initialize() + return r + + +@pytest.fixture +async def attacker_uuid(repo) -> str: + now = datetime.now(timezone.utc) + return await repo.upsert_attacker({ + "ip": "10.0.0.5", + "first_seen": now, + "last_seen": now, + }) + + +def _envelope( + *, + primitive: str, + value: Any, + attacker_uuid: str, + evidence_ref: str, + ts: float, + confidence: float = 0.9, +) -> dict[str, Any]: + return { + "id": f"obs-{evidence_ref}-{primitive}", + "primitive": primitive, + "value": value, + "confidence": confidence, + "window_start_ts": ts, + "window_end_ts": ts, + "source": "test", + "evidence_ref": evidence_ref, + "envelope_v": 1, + "ts": ts, + "attacker_uuid": attacker_uuid, + } + + +def _bus_event(payload: dict[str, Any]) -> dict[str, Any]: + """Worker reads payload via getattr(.payload, fallback to dict).""" + return payload + + +async def _seed_observations( + repo, attacker_uuid: str, primitive: str, values: list[Any], + *, start_ts: float = 1714000000.0, +) -> None: + for i, v in enumerate(values): + ts = start_ts + i * 60.0 + # ts in evidence_ref so repeated calls with overlapping i but + # distinct start_ts produce distinct rows. + await repo.upsert_observation(_envelope( + primitive=primitive, + value=v, + attacker_uuid=attacker_uuid, + evidence_ref=f"shard:test#{primitive}-{ts}", + ts=ts, + )) + + +@pytest.mark.anyio +async def test_handler_writes_unknown_below_threshold( + repo, attacker_uuid: str, +) -> None: + """Two observations for one primitive → state row written with + state='unknown' (< MIN_OBSERVATIONS_FOR_STATE).""" + bus = FakeBus() + await bus.connect() + await _seed_observations( + repo, attacker_uuid, "motor.input_modality", ["typed", "typed"], + ) + await _aw.handle_observation_event(bus, repo, _bus_event({ + "attacker_uuid": attacker_uuid, + "primitive": "motor.input_modality", + })) + + attacker = await repo.get_attacker_by_uuid(attacker_uuid) + assert attacker is not None + identity_uuid = attacker["identity_id"] + state = await repo.get_attribution_state( + identity_uuid, "motor.input_modality", + ) + assert state is not None + assert state["state"] == "unknown" + await bus.close() + + +@pytest.mark.anyio +async def test_handler_emits_state_changed_on_transition( + repo, attacker_uuid: str, monkeypatch: pytest.MonkeyPatch, +) -> None: + """As observations cross MIN_OBSERVATIONS_FOR_STATE, the worker + fires →unknown then unknown→stable; idempotent re-runs in + between fire nothing.""" + bus = FakeBus() + await bus.connect() + + captured: list[dict[str, Any]] = [] + + async def _capture(_bus, topic, payload, *, event_type=""): + captured.append({"topic": topic, "payload": payload}) + + monkeypatch.setattr(_aw, "publish_safely", _capture) + + for i in range(5): + await _seed_observations( + repo, attacker_uuid, "motor.input_modality", + ["typed"], start_ts=1714000000.0 + i * 60.0, + ) + await _aw.handle_observation_event(bus, repo, _bus_event({ + "attacker_uuid": attacker_uuid, + "primitive": "motor.input_modality", + })) + + states_seen = [c["payload"]["new_state"] for c in captured] + assert states_seen == ["unknown", "stable"], states_seen + # The transition payload carries old + new + the observation that + # caused the flip. + assert captured[0]["payload"]["old_state"] is None + assert captured[1]["payload"]["old_state"] == "unknown" + await bus.close() + + +@pytest.mark.anyio +async def test_handler_no_event_when_state_unchanged( + repo, attacker_uuid: str, +) -> None: + """Re-running the merger over an unchanged observation set must + not emit a duplicate state_changed event (loop-prevention).""" + bus = FakeBus() + await bus.connect() + + captured: list[Any] = [] + sub = bus.subscribe( + _topics.attribution(_topics.ATTRIBUTION_PROFILE_STATE_CHANGED), + ) + + import asyncio + + async def drain() -> None: + try: + async with sub: + async for ev in sub: + captured.append(ev) + except Exception: + pass + + drain_task = asyncio.create_task(drain()) + await asyncio.sleep(0) + + await _seed_observations( + repo, attacker_uuid, "motor.input_modality", + ["typed"] * 5, + ) + # First run: → stable, fires event. + await _aw.handle_observation_event(bus, repo, _bus_event({ + "attacker_uuid": attacker_uuid, + "primitive": "motor.input_modality", + })) + await asyncio.sleep(0.05) + first_count = len(captured) + + # Re-run with no new observations: state stays "stable", no event. + for _ in range(3): + await _aw.handle_observation_event(bus, repo, _bus_event({ + "attacker_uuid": attacker_uuid, + "primitive": "motor.input_modality", + })) + await asyncio.sleep(0.05) + + drain_task.cancel() + assert len(captured) == first_count, ( + "state didn't change; no additional events should fire" + ) + await bus.close() + + +@pytest.mark.anyio +async def test_handler_locks_last_change_ts_when_unchanged( + repo, attacker_uuid: str, +) -> None: + """When the state doesn't change, last_change_ts must NOT advance — + that's what tells the dashboard 'stable since X', not 'stable + since most-recent-observation'.""" + bus = FakeBus() + await bus.connect() + await _seed_observations( + repo, attacker_uuid, "motor.input_modality", + ["typed"] * 5, + ) + await _aw.handle_observation_event(bus, repo, _bus_event({ + "attacker_uuid": attacker_uuid, + "primitive": "motor.input_modality", + })) + attacker = await repo.get_attacker_by_uuid(attacker_uuid) + assert attacker is not None + identity_uuid = attacker["identity_id"] + first = await repo.get_attribution_state( + identity_uuid, "motor.input_modality", + ) + assert first is not None + locked_ts = first["last_change_ts"] + + # Add another stable observation, re-run. + await _seed_observations( + repo, attacker_uuid, "motor.input_modality", + ["typed"], start_ts=1714010000.0, + ) + await _aw.handle_observation_event(bus, repo, _bus_event({ + "attacker_uuid": attacker_uuid, + "primitive": "motor.input_modality", + })) + second = await repo.get_attribution_state( + identity_uuid, "motor.input_modality", + ) + assert second is not None + assert second["last_change_ts"] == locked_ts + # last_observation_ts DID advance. + assert second["last_observation_ts"] > locked_ts + await bus.close() + + +@pytest.mark.anyio +async def test_handler_routes_numeric_primitive( + repo, attacker_uuid: str, +) -> None: + """Worker dispatches to the numeric merger when the primitive + registry kind is NUMERIC.""" + bus = FakeBus() + await bus.connect() + # toolchain.c2.beacon_interval_ms is registered NUMERIC in BEHAVE. + primitive = "toolchain.c2.beacon_interval_ms" + await _seed_observations( + repo, attacker_uuid, primitive, + [5000.0, 5050.0, 4980.0, 5020.0, 5010.0], + ) + await _aw.handle_observation_event(bus, repo, _bus_event({ + "attacker_uuid": attacker_uuid, + "primitive": primitive, + })) + attacker = await repo.get_attacker_by_uuid(attacker_uuid) + assert attacker is not None + state = await repo.get_attribution_state( + attacker["identity_id"], primitive, + ) + assert state is not None + # Numeric merger returns a smoothed mean, not a string. + assert isinstance(state["current_value"], float) + assert state["state"] == "stable" + await bus.close() diff --git a/tests/db/test_base_repo.py b/tests/db/test_base_repo.py index 1a0245ca..72156e5f 100644 --- a/tests/db/test_base_repo.py +++ b/tests/db/test_base_repo.py @@ -42,6 +42,24 @@ class DummyRepo(BaseRepository): async def upsert_observation(self, data): await super().upsert_observation(data); return "" async def latest_observation_per_primitive(self, attacker_uuid): await super().latest_observation_per_primitive(attacker_uuid); return {} async def observations_time_series(self, attacker_uuid, primitive): await super().observations_time_series(attacker_uuid, primitive); return [] + async def observations_for_identity_primitive(self, identity_uuid, primitive): + await super().observations_for_identity_primitive(identity_uuid, primitive) + return [] + # Attribution engine v0 (ATTRIBUTION-ENGINE.md Phase 1) + async def ensure_stub_identity_for_attacker(self, attacker_uuid): + await super().ensure_stub_identity_for_attacker(attacker_uuid) + return None + async def upsert_attribution_state(self, data): + await super().upsert_attribution_state(data) + async def get_attribution_state(self, identity_uuid, primitive): + await super().get_attribution_state(identity_uuid, primitive) + return None + async def get_attribution_state_for_identity(self, identity_uuid): + await super().get_attribution_state_for_identity(identity_uuid) + return [] + async def list_multi_actor_identities(self): + await super().list_multi_actor_identities() + return [] async def increment_smtp_target(self, u, d): await super().increment_smtp_target(u, d) async def list_smtp_targets(self, u): await super().list_smtp_targets(u) async def get_attacker_stored_mail(self, u): await super().get_attacker_stored_mail(u) @@ -86,6 +104,38 @@ class DummyRepo(BaseRepository): async def set_identity_campaign_id(self, i, c): await super().set_identity_campaign_id(i, c) async def list_all_campaigns(self): await super().list_all_campaigns(); return [] async def update_campaign_merged_into(self, u, w): await super().update_campaign_merged_into(u, w) + # Pre-existing abstract surface that DummyRepo never stubbed — + # added here so the coverage test exercises the full BaseRepository + # contract. + async def get_log_histogram(self, *a, **kw): + await super().get_log_histogram(*a, **kw); return [] + async def has_observations_for_evidence(self, evidence_ref): + await super().has_observations_for_evidence(evidence_ref); return False + async def get_attacker_uuid_by_ip(self, ip): + await super().get_attacker_uuid_by_ip(ip); return None + # TTP rollup surface (TTP_TAGGING.md) + async def insert_tags(self, rows): await super().insert_tags(rows); return 0 + async def list_techniques_by_identity(self, uuid): + await super().list_techniques_by_identity(uuid); return [] + async def list_techniques_by_attacker(self, uuid): + await super().list_techniques_by_attacker(uuid); return [] + async def list_techniques_by_campaign(self, uuid): + await super().list_techniques_by_campaign(uuid); return [] + async def list_techniques_by_session(self, sid): + await super().list_techniques_by_session(sid); return [] + async def list_tags_by_scope_and_technique(self, **kw): + await super().list_tags_by_scope_and_technique(**kw); return [] + async def list_distinct_techniques(self): + await super().list_distinct_techniques(); return [] + # Iter helpers — async generators, can't `await super()` on them + # because the base raises in the body before any yield. Just yield + # nothing so the consumer's ``async for`` exits cleanly. + async def iter_attacker_commands_since(self, since): + return + yield # unreachable, marks the function as a generator + async def iter_canary_triggers_since(self, since): + return + yield @pytest.mark.asyncio async def test_base_repo_coverage(): @@ -127,9 +177,26 @@ async def test_base_repo_coverage(): await dr.upsert_attacker_behavior("a", {}) await dr.get_attacker_behavior("a") await dr.get_behaviors_for_ips({"1.1.1.1"}) - await dr.upsert_observation({}) - await dr.latest_observation_per_primitive("a") - await dr.observations_time_series("a", "motor.input_modality") + # Observation surface — bases raise NotImplementedError. + with pytest.raises(NotImplementedError): + await dr.upsert_observation({}) + with pytest.raises(NotImplementedError): + await dr.latest_observation_per_primitive("a") + with pytest.raises(NotImplementedError): + await dr.observations_time_series("a", "motor.input_modality") + # observations_for_identity_primitive + attribution engine v0 + with pytest.raises(NotImplementedError): + await dr.observations_for_identity_primitive("i", "motor.input_modality") + with pytest.raises(NotImplementedError): + await dr.ensure_stub_identity_for_attacker("a") + with pytest.raises(NotImplementedError): + await dr.upsert_attribution_state({}) + with pytest.raises(NotImplementedError): + await dr.get_attribution_state("i", "motor.input_modality") + with pytest.raises(NotImplementedError): + await dr.get_attribution_state_for_identity("i") + with pytest.raises(NotImplementedError): + await dr.list_multi_actor_identities() await dr.increment_smtp_target("uuid", "corp.com") await dr.list_smtp_targets("uuid") await dr.get_attacker_stored_mail("uuid") @@ -174,6 +241,37 @@ async def test_base_repo_coverage(): await dr.update_campaign_merged_into("c", "d") await dr.update_campaign_merged_into("c", None) + # Pre-existing abstract surface. get_log_histogram's base body + # is ``pass`` (returns None), the rest raise NotImplementedError. + from datetime import datetime, timezone + await dr.get_log_histogram() + with pytest.raises(NotImplementedError): + await dr.has_observations_for_evidence("shard:x#1") + with pytest.raises(NotImplementedError): + await dr.get_attacker_uuid_by_ip("1.1.1.1") + with pytest.raises(NotImplementedError): + await dr.insert_tags([]) + with pytest.raises(NotImplementedError): + await dr.list_techniques_by_identity("i") + with pytest.raises(NotImplementedError): + await dr.list_techniques_by_attacker("a") + with pytest.raises(NotImplementedError): + await dr.list_techniques_by_campaign("c") + with pytest.raises(NotImplementedError): + await dr.list_techniques_by_session("s") + with pytest.raises(NotImplementedError): + await dr.list_tags_by_scope_and_technique( + scope="identity", uuid="i", technique_id="T1059", + ) + with pytest.raises(NotImplementedError): + await dr.list_distinct_techniques() + # Iter helpers: just consume the empty generator. + now = datetime.now(timezone.utc) + async for _ in dr.iter_attacker_commands_since(now): + pass + async for _ in dr.iter_canary_triggers_since(now): + pass + # Swarm methods: default NotImplementedError on BaseRepository. Covering # them here keeps the coverage contract honest for the swarm CRUD surface. for coro, args in [