"""Canary blob/token CRUD + trigger ingestion."""

from __future__ import annotations

import json
from datetime import datetime, timezone
from typing import Any, Optional

from sqlalchemy import desc, func, select, update
from sqlalchemy.exc import IntegrityError

from decnet.web.db.models import CanaryBlob, CanaryToken, CanaryTrigger
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase


class CanaryMixin(_MixinBase):
    """Mixin: composed onto ``SQLModelRepository``.

    Persistence for canary blobs (deduplicated by ``sha256``), canary tokens
    (which reference a blob via ``blob_uuid``), and trigger events recorded
    against tokens. Every read returns rows as ``model_dump(mode="json")``
    dicts rather than ORM instances.
    """

    # ------------------------------------------------------------------ blobs

    async def upsert_canary_blob(self, data: dict[str, Any]) -> dict[str, Any]:
        """Insert a blob unless a blob with the same ``sha256`` already exists.

        Args:
            data: Column values for a new ``CanaryBlob``. ``sha256`` is
                required and is the deduplication key.

        Returns:
            The pre-existing or newly inserted blob as a JSON-mode dict.

        Raises:
            ValueError: If ``data`` has no ``sha256`` (or it is empty).
        """
        sha = data.get("sha256")
        if not sha:
            raise ValueError("upsert_canary_blob: sha256 is required")
        by_sha = select(CanaryBlob).where(CanaryBlob.sha256 == sha)
        async with self._session() as session:
            existing = (await session.execute(by_sha)).scalar_one_or_none()
            if existing is not None:
                return existing.model_dump(mode="json")
            row = CanaryBlob(**data)
            session.add(row)
            try:
                await session.commit()
            except IntegrityError:
                # Check-then-insert is racy: a concurrent upload of the same
                # content may commit between our SELECT and INSERT. Assuming
                # the schema enforces uniqueness on sha256 (confirm), the
                # loser lands here; return the winner's row so callers keep
                # upsert semantics. Any other integrity failure propagates.
                await session.rollback()
                existing = (await session.execute(by_sha)).scalar_one_or_none()
                if existing is None:
                    raise
                return existing.model_dump(mode="json")
            await session.refresh(row)
            return row.model_dump(mode="json")

    async def get_canary_blob(self, uuid: str) -> Optional[dict[str, Any]]:
        """Return the blob with primary key ``uuid``, or ``None`` if absent."""
        async with self._session() as session:
            result = await session.execute(
                select(CanaryBlob).where(CanaryBlob.uuid == uuid)
            )
            row = result.scalar_one_or_none()
            return row.model_dump(mode="json") if row else None

    async def get_canary_blob_by_sha256(
        self, sha256: str
    ) -> Optional[dict[str, Any]]:
        """Return the blob whose content hash is ``sha256``, or ``None``."""
        async with self._session() as session:
            result = await session.execute(
                select(CanaryBlob).where(CanaryBlob.sha256 == sha256)
            )
            row = result.scalar_one_or_none()
            return row.model_dump(mode="json") if row else None

    async def list_canary_blobs(self) -> list[dict[str, Any]]:
        """Return all blobs, newest ``uploaded_at`` first, with ``token_count``.

        ``token_count`` counts *every* token referencing the blob, revoked
        ones included: a revoked token still occupies the blob conceptually
        until it is garbage-collected, and blobs are removed only when the
        operator deletes them explicitly via the API.
        """
        async with self._session() as session:
            # One round-trip: LEFT OUTER JOIN blobs -> tokens, grouped per
            # blob. COUNT(token.uuid) ignores the NULLs produced for blobs
            # with no tokens, so those blobs report 0 instead of vanishing.
            stmt = (
                select(CanaryBlob, func.count(CanaryToken.uuid))
                .join(
                    CanaryToken,
                    CanaryToken.blob_uuid == CanaryBlob.uuid,
                    isouter=True,
                )
                .group_by(CanaryBlob.uuid)
                .order_by(desc(CanaryBlob.uploaded_at))
            )
            result = await session.execute(stmt)
            out: list[dict[str, Any]] = []
            for blob, count in result.all():
                d = blob.model_dump(mode="json")
                d["token_count"] = int(count or 0)
                out.append(d)
            return out

    async def delete_canary_blob(self, uuid: str) -> bool:
        """Delete the blob ``uuid`` if no token references it.

        Returns:
            ``True`` if the blob was deleted. ``False`` if any token (in any
            state) still references it, or if no such blob exists.
        """
        async with self._session() as session:
            ref = await session.execute(
                select(func.count(CanaryToken.uuid)).where(
                    CanaryToken.blob_uuid == uuid
                )
            )
            if (ref.scalar_one() or 0) > 0:
                return False
            result = await session.execute(
                select(CanaryBlob).where(CanaryBlob.uuid == uuid)
            )
            row = result.scalar_one_or_none()
            if row is None:
                return False
            await session.delete(row)
            await session.commit()
            return True

    # ----------------------------------------------------------------- tokens

    async def create_canary_token(self, data: dict[str, Any]) -> None:
        """Insert a new ``CanaryToken`` built from ``data`` and commit."""
        async with self._session() as session:
            session.add(CanaryToken(**data))
            await session.commit()

    async def get_canary_token(self, uuid: str) -> Optional[dict[str, Any]]:
        """Return the token with primary key ``uuid`` (any state), or ``None``."""
        async with self._session() as session:
            result = await session.execute(
                select(CanaryToken).where(CanaryToken.uuid == uuid)
            )
            row = result.scalar_one_or_none()
            return row.model_dump(mode="json") if row else None

    async def get_canary_token_by_slug(
        self, callback_token: str
    ) -> Optional[dict[str, Any]]:
        """Look up a token by its ``callback_token``.

        Only tokens in state ``"planted"`` match; a token in any other state
        yields ``None`` just as an unknown slug does.
        """
        async with self._session() as session:
            result = await session.execute(
                select(CanaryToken).where(
                    CanaryToken.callback_token == callback_token,
                    CanaryToken.state == "planted",
                )
            )
            row = result.scalar_one_or_none()
            return row.model_dump(mode="json") if row else None

    async def list_canary_tokens(
        self,
        *,
        decky_name: Optional[str] = None,
        state: Optional[str] = None,
        kind: Optional[str] = None,
        topology_id: Optional[str] = None,
    ) -> list[dict[str, Any]]:
        """Return tokens matching every non-``None`` filter, newest placed first."""
        async with self._session() as session:
            stmt = select(CanaryToken)
            if decky_name is not None:
                stmt = stmt.where(CanaryToken.decky_name == decky_name)
            if state is not None:
                stmt = stmt.where(CanaryToken.state == state)
            if kind is not None:
                stmt = stmt.where(CanaryToken.kind == kind)
            if topology_id is not None:
                stmt = stmt.where(CanaryToken.topology_id == topology_id)
            stmt = stmt.order_by(desc(CanaryToken.placed_at))
            result = await session.execute(stmt)
            return [r.model_dump(mode="json") for r in result.scalars().all()]

    async def update_canary_token_state(
        self,
        uuid: str,
        state: str,
        last_error: Optional[str] = None,
    ) -> bool:
        """Set a token's ``state`` and ``last_error``.

        ``last_error`` is always written, so omitting it clears any previous
        error to ``None``.

        Returns:
            ``True`` if a token with ``uuid`` existed and was updated.
        """
        async with self._session() as session:
            result = await session.execute(
                update(CanaryToken)
                .where(CanaryToken.uuid == uuid)
                .values(state=state, last_error=last_error)
            )
            await session.commit()
            return result.rowcount > 0

    # --------------------------------------------------------------- triggers

    async def record_canary_trigger(self, data: dict[str, Any]) -> str:
        """Persist a trigger and bump its token's counters in one commit.

        A dict ``raw_headers`` is JSON-encoded before storage; the caller's
        ``data`` dict is not mutated.

        Returns:
            The ``uuid`` of the new ``CanaryTrigger`` row.
        """
        # Insert the trigger row and bump the token's counters in the same
        # transaction, so a subscriber that reads the token row right after
        # receiving the bus event already sees the updated count.
        headers = data.get("raw_headers")
        if isinstance(headers, dict):
            data = {**data, "raw_headers": json.dumps(headers)}
        async with self._session() as session:
            row = CanaryTrigger(**data)
            session.add(row)
            # NOTE(review): when occurred_at is absent/falsy, the token's
            # last_triggered_at uses this fresh now() while the trigger row's
            # occurred_at comes from whatever CanaryTrigger defaults to; the
            # two may differ slightly — confirm whether they should share one
            # value.
            ts = data.get("occurred_at") or datetime.now(timezone.utc)
            await session.execute(
                update(CanaryToken)
                .where(CanaryToken.uuid == row.token_uuid)
                .values(
                    last_triggered_at=ts,
                    # Server-side increment avoids a read-modify-write cycle.
                    trigger_count=CanaryToken.trigger_count + 1,
                )
            )
            await session.commit()
            await session.refresh(row)
            return row.uuid

    async def list_canary_triggers(
        self,
        token_uuid: str,
        *,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Return one page of a token's triggers, most recent first."""
        async with self._session() as session:
            stmt = (
                select(CanaryTrigger)
                .where(CanaryTrigger.token_uuid == token_uuid)
                .order_by(desc(CanaryTrigger.occurred_at))
                .limit(limit)
                .offset(offset)
            )
            result = await session.execute(stmt)
            return [r.model_dump(mode="json") for r in result.scalars().all()]

    async def attribute_canary_trigger(
        self,
        trigger_uuid: str,
        attacker_id: str,
    ) -> bool:
        """Set ``attacker_id`` on the trigger ``trigger_uuid``.

        Returns:
            ``True`` if the trigger existed and was updated.
        """
        async with self._session() as session:
            result = await session.execute(
                update(CanaryTrigger)
                .where(CanaryTrigger.uuid == trigger_uuid)
                .values(attacker_id=attacker_id)
            )
            await session.commit()
            return result.rowcount > 0