diff --git a/decnet/web/db/sqlmodel_repo/__init__.py b/decnet/web/db/sqlmodel_repo/__init__.py index 46fa6111..c336bb55 100644 --- a/decnet/web/db/sqlmodel_repo/__init__.py +++ b/decnet/web/db/sqlmodel_repo/__init__.py @@ -46,9 +46,6 @@ from decnet.web.db.models import ( TopologyEdge, TopologyStatusEvent, TopologyMutation, - CanaryBlob, - CanaryToken, - CanaryTrigger, ) @@ -62,6 +59,7 @@ from decnet.web.db.sqlmodel_repo._helpers import ( # noqa: F401 (re-exported f from decnet.web.db.sqlmodel_repo.attacker_intel import AttackerIntelMixin from decnet.web.db.sqlmodel_repo.auth import AuthMixin from decnet.web.db.sqlmodel_repo.bounties import BountiesMixin +from decnet.web.db.sqlmodel_repo.canary import CanaryMixin from decnet.web.db.sqlmodel_repo.deckies import DeckiesMixin from decnet.web.db.sqlmodel_repo.fleet import FleetMixin from decnet.web.db.sqlmodel_repo.logs import LogsMixin @@ -75,6 +73,7 @@ class SQLModelRepository( AttackerIntelMixin, AuthMixin, BountiesMixin, + CanaryMixin, DeckiesMixin, FleetMixin, LogsMixin, @@ -2052,196 +2051,6 @@ class SQLModelRepository( ) return [r for r in result.scalars().all()] - # ---------------------------------------------------------- canary - - async def upsert_canary_blob(self, data: dict[str, Any]) -> dict[str, Any]: - sha = data.get("sha256") - if not sha: - raise ValueError("upsert_canary_blob: sha256 is required") - async with self._session() as session: - existing = await session.execute( - select(CanaryBlob).where(CanaryBlob.sha256 == sha) - ) - row = existing.scalar_one_or_none() - if row: - return row.model_dump(mode="json") - row = CanaryBlob(**data) - session.add(row) - await session.commit() - await session.refresh(row) - return row.model_dump(mode="json") - - async def get_canary_blob(self, uuid: str) -> Optional[dict[str, Any]]: - async with self._session() as session: - result = await session.execute( - select(CanaryBlob).where(CanaryBlob.uuid == uuid) - ) - row = result.scalar_one_or_none() - return row.model_dump(mode="json") if row else None - - async def get_canary_blob_by_sha256( - self, sha256: str - ) -> Optional[dict[str, Any]]: - async with self._session() as session: - result = await session.execute( - select(CanaryBlob).where(CanaryBlob.sha256 == sha256) - ) - row = result.scalar_one_or_none() - return row.model_dump(mode="json") if row else None - - async def list_canary_blobs(self) -> list[dict[str, Any]]: - # One round-trip: outer-join blobs -> tokens, group by blob, count - # live (non-revoked) references. Revoked tokens still occupy the - # blob conceptually until garbage-collected, so we count them too; - # the operator deletes blobs explicitly via the API. - async with self._session() as session: - stmt = ( - select(CanaryBlob, func.count(CanaryToken.uuid)) - .join( - CanaryToken, - CanaryToken.blob_uuid == CanaryBlob.uuid, - isouter=True, - ) - .group_by(CanaryBlob.uuid) - .order_by(desc(CanaryBlob.uploaded_at)) - ) - result = await session.execute(stmt) - out: list[dict[str, Any]] = [] - for blob, count in result.all(): - d = blob.model_dump(mode="json") - d["token_count"] = int(count or 0) - out.append(d) - return out - - async def delete_canary_blob(self, uuid: str) -> bool: - async with self._session() as session: - ref = await session.execute( - select(func.count(CanaryToken.uuid)).where( - CanaryToken.blob_uuid == uuid - ) - ) - if (ref.scalar_one() or 0) > 0: - return False - result = await session.execute( - select(CanaryBlob).where(CanaryBlob.uuid == uuid) - ) - row = result.scalar_one_or_none() - if not row: - return False - await session.delete(row) - await session.commit() - return True - - async def create_canary_token(self, data: dict[str, Any]) -> None: - async with self._session() as session: - session.add(CanaryToken(**data)) - await session.commit() - - async def get_canary_token(self, uuid: str) -> Optional[dict[str, Any]]: - async with self._session() as session: - result = await session.execute( - select(CanaryToken).where(CanaryToken.uuid == uuid) - ) - row = result.scalar_one_or_none() - return row.model_dump(mode="json") if row else None - - async def get_canary_token_by_slug( - self, callback_token: str - ) -> Optional[dict[str, Any]]: - async with self._session() as session: - result = await session.execute( - select(CanaryToken).where( - CanaryToken.callback_token == callback_token - ) - ) - row = result.scalar_one_or_none() - return row.model_dump(mode="json") if row else None - - async def list_canary_tokens( - self, - *, - decky_name: Optional[str] = None, - state: Optional[str] = None, - kind: Optional[str] = None, - ) -> list[dict[str, Any]]: - async with self._session() as session: - stmt = select(CanaryToken) - if decky_name is not None: - stmt = stmt.where(CanaryToken.decky_name == decky_name) - if state is not None: - stmt = stmt.where(CanaryToken.state == state) - if kind is not None: - stmt = stmt.where(CanaryToken.kind == kind) - stmt = stmt.order_by(desc(CanaryToken.placed_at)) - result = await session.execute(stmt) - return [r.model_dump(mode="json") for r in result.scalars().all()] - - async def update_canary_token_state( - self, - uuid: str, - state: str, - last_error: Optional[str] = None, - ) -> bool: - async with self._session() as session: - result = await session.execute( - update(CanaryToken) - .where(CanaryToken.uuid == uuid) - .values(state=state, last_error=last_error) - ) - await session.commit() - return result.rowcount > 0 - - async def record_canary_trigger(self, data: dict[str, Any]) -> str: - # Persist the trigger row + bump the token's counters in the same - # session so a subscriber that reads the token row right after - # receiving the bus event sees the updated count. - headers = data.get("raw_headers") - if isinstance(headers, dict): - data = {**data, "raw_headers": json.dumps(headers)} - async with self._session() as session: - row = CanaryTrigger(**data) - session.add(row) - ts = data.get("occurred_at") or datetime.now(timezone.utc) - await session.execute( - update(CanaryToken) - .where(CanaryToken.uuid == row.token_uuid) - .values( - last_triggered_at=ts, - trigger_count=CanaryToken.trigger_count + 1, - ) - ) - await session.commit() - await session.refresh(row) - return row.uuid - - async def list_canary_triggers( - self, token_uuid: str, *, limit: int = 100, offset: int = 0, - ) -> list[dict[str, Any]]: - async with self._session() as session: - stmt = ( - select(CanaryTrigger) - .where(CanaryTrigger.token_uuid == token_uuid) - .order_by(desc(CanaryTrigger.occurred_at)) - .limit(limit) - .offset(offset) - ) - result = await session.execute(stmt) - return [r.model_dump(mode="json") for r in result.scalars().all()] - - async def attribute_canary_trigger( - self, trigger_uuid: str, attacker_id: str, - ) -> bool: - async with self._session() as session: - result = await session.execute( - update(CanaryTrigger) - .where(CanaryTrigger.uuid == trigger_uuid) - .values(attacker_id=attacker_id) - ) - await session.commit() - return result.rowcount > 0 - - # ---------------------------------------------------------- orchestrator - async def list_running_topology_deckies(self) -> list[dict[str, Any]]: async with self._session() as session: result = await session.execute( diff --git a/decnet/web/db/sqlmodel_repo/canary.py b/decnet/web/db/sqlmodel_repo/canary.py new file mode 100644 index 00000000..c42d4d4f --- /dev/null +++ b/decnet/web/db/sqlmodel_repo/canary.py @@ -0,0 +1,200 @@ +"""Canary blob/token CRUD + trigger ingestion.""" +from __future__ import annotations + +import json +from datetime import datetime, timezone +from typing import Any, Optional + +from sqlalchemy import desc, func, select, update + +from decnet.web.db.models import CanaryBlob, CanaryToken, CanaryTrigger + + +class CanaryMixin: + """Mixin: composed onto ``SQLModelRepository``.""" + + async def upsert_canary_blob(self, data: dict[str, Any]) -> dict[str, Any]: + sha = data.get("sha256") + if not sha: + raise ValueError("upsert_canary_blob: sha256 is required") + async with self._session() as session: + existing = await session.execute( + select(CanaryBlob).where(CanaryBlob.sha256 == sha) + ) + row = existing.scalar_one_or_none() + if row: + return row.model_dump(mode="json") + row = CanaryBlob(**data) + session.add(row) + await session.commit() + await session.refresh(row) + return row.model_dump(mode="json") + + async def get_canary_blob(self, uuid: str) -> Optional[dict[str, Any]]: + async with self._session() as session: + result = await session.execute( + select(CanaryBlob).where(CanaryBlob.uuid == uuid) + ) + row = result.scalar_one_or_none() + return row.model_dump(mode="json") if row else None + + async def get_canary_blob_by_sha256( + self, sha256: str + ) -> Optional[dict[str, Any]]: + async with self._session() as session: + result = await session.execute( + select(CanaryBlob).where(CanaryBlob.sha256 == sha256) + ) + row = result.scalar_one_or_none() + return row.model_dump(mode="json") if row else None + + async def list_canary_blobs(self) -> list[dict[str, Any]]: + # One round-trip: outer-join blobs -> tokens, group by blob, count + # live (non-revoked) references. Revoked tokens still occupy the + # blob conceptually until garbage-collected, so we count them too; + # the operator deletes blobs explicitly via the API. + async with self._session() as session: + stmt = ( + select(CanaryBlob, func.count(CanaryToken.uuid)) + .join( + CanaryToken, + CanaryToken.blob_uuid == CanaryBlob.uuid, + isouter=True, + ) + .group_by(CanaryBlob.uuid) + .order_by(desc(CanaryBlob.uploaded_at)) + ) + result = await session.execute(stmt) + out: list[dict[str, Any]] = [] + for blob, count in result.all(): + d = blob.model_dump(mode="json") + d["token_count"] = int(count or 0) + out.append(d) + return out + + async def delete_canary_blob(self, uuid: str) -> bool: + async with self._session() as session: + ref = await session.execute( + select(func.count(CanaryToken.uuid)).where( + CanaryToken.blob_uuid == uuid + ) + ) + if (ref.scalar_one() or 0) > 0: + return False + result = await session.execute( + select(CanaryBlob).where(CanaryBlob.uuid == uuid) + ) + row = result.scalar_one_or_none() + if not row: + return False + await session.delete(row) + await session.commit() + return True + + async def create_canary_token(self, data: dict[str, Any]) -> None: + async with self._session() as session: + session.add(CanaryToken(**data)) + await session.commit() + + async def get_canary_token(self, uuid: str) -> Optional[dict[str, Any]]: + async with self._session() as session: + result = await session.execute( + select(CanaryToken).where(CanaryToken.uuid == uuid) + ) + row = result.scalar_one_or_none() + return row.model_dump(mode="json") if row else None + + async def get_canary_token_by_slug( + self, callback_token: str + ) -> Optional[dict[str, Any]]: + async with self._session() as session: + result = await session.execute( + select(CanaryToken).where( + CanaryToken.callback_token == callback_token + ) + ) + row = result.scalar_one_or_none() + return row.model_dump(mode="json") if row else None + + async def list_canary_tokens( + self, + *, + decky_name: Optional[str] = None, + state: Optional[str] = None, + kind: Optional[str] = None, + ) -> list[dict[str, Any]]: + async with self._session() as session: + stmt = select(CanaryToken) + if decky_name is not None: + stmt = stmt.where(CanaryToken.decky_name == decky_name) + if state is not None: + stmt = stmt.where(CanaryToken.state == state) + if kind is not None: + stmt = stmt.where(CanaryToken.kind == kind) + stmt = stmt.order_by(desc(CanaryToken.placed_at)) + result = await session.execute(stmt) + return [r.model_dump(mode="json") for r in result.scalars().all()] + + async def update_canary_token_state( + self, + uuid: str, + state: str, + last_error: Optional[str] = None, + ) -> bool: + async with self._session() as session: + result = await session.execute( + update(CanaryToken) + .where(CanaryToken.uuid == uuid) + .values(state=state, last_error=last_error) + ) + await session.commit() + return result.rowcount > 0 + + async def record_canary_trigger(self, data: dict[str, Any]) -> str: + # Persist the trigger row + bump the token's counters in the same + # session so a subscriber that reads the token row right after + # receiving the bus event sees the updated count. + headers = data.get("raw_headers") + if isinstance(headers, dict): + data = {**data, "raw_headers": json.dumps(headers)} + async with self._session() as session: + row = CanaryTrigger(**data) + session.add(row) + ts = data.get("occurred_at") or datetime.now(timezone.utc) + await session.execute( + update(CanaryToken) + .where(CanaryToken.uuid == row.token_uuid) + .values( + last_triggered_at=ts, + trigger_count=CanaryToken.trigger_count + 1, + ) + ) + await session.commit() + await session.refresh(row) + return row.uuid + + async def list_canary_triggers( + self, token_uuid: str, *, limit: int = 100, offset: int = 0, + ) -> list[dict[str, Any]]: + async with self._session() as session: + stmt = ( + select(CanaryTrigger) + .where(CanaryTrigger.token_uuid == token_uuid) + .order_by(desc(CanaryTrigger.occurred_at)) + .limit(limit) + .offset(offset) + ) + result = await session.execute(stmt) + return [r.model_dump(mode="json") for r in result.scalars().all()] + + async def attribute_canary_trigger( + self, trigger_uuid: str, attacker_id: str, + ) -> bool: + async with self._session() as session: + result = await session.execute( + update(CanaryTrigger) + .where(CanaryTrigger.uuid == trigger_uuid) + .values(attacker_id=attacker_id) + ) + await session.commit() + return result.rowcount > 0