3176 lines
123 KiB
Python
3176 lines
123 KiB
Python
"""
|
||
Shared SQLModel-based repository implementation.
|
||
|
||
Contains all dialect-portable query code used by the SQLite and MySQL
|
||
backends. Dialect-specific behavior lives in subclasses:
|
||
|
||
* engine/session construction (``__init__``)
|
||
* ``_migrate_attackers_table`` (legacy schema check; DDL introspection
|
||
is not portable)
|
||
* ``get_log_histogram`` (date-bucket expression differs per dialect)
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
|
||
import orjson
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
from typing import Any, Optional, List
|
||
|
||
from sqlalchemy import func, select, desc, asc, text, or_, update
|
||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||
from sqlmodel.sql.expression import SelectOfScalar
|
||
|
||
from decnet.config import load_state
|
||
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
||
from decnet.web.auth import get_password_hash
|
||
from decnet.web.db.repository import BaseRepository
|
||
from decnet.web.db.models import (
|
||
User,
|
||
Log,
|
||
Bounty,
|
||
Credential,
|
||
CredentialReuse,
|
||
State,
|
||
Attacker,
|
||
AttackerBehavior,
|
||
AttackerIdentity,
|
||
Campaign,
|
||
SessionProfile,
|
||
SmtpTarget,
|
||
DeckyShard,
|
||
FleetDecky,
|
||
LOCAL_HOST_SENTINEL,
|
||
Topology,
|
||
LAN,
|
||
TopologyDecky,
|
||
TopologyEdge,
|
||
TopologyStatusEvent,
|
||
TopologyMutation,
|
||
OrchestratorEmail,
|
||
OrchestratorEvent,
|
||
RealismConfig,
|
||
SyntheticFile,
|
||
WebhookSubscription,
|
||
CanaryBlob,
|
||
CanaryToken,
|
||
CanaryTrigger,
|
||
)
|
||
|
||
|
||
from decnet.web.db.sqlmodel_repo._helpers import ( # noqa: F401 (re-exported for tests/external)
|
||
_safe_session,
|
||
_detach_close,
|
||
_cleanup_tasks,
|
||
)
|
||
from decnet.web.db.sqlmodel_repo.attacker_intel import AttackerIntelMixin
|
||
from decnet.web.db.sqlmodel_repo.auth import AuthMixin
|
||
from decnet.web.db.sqlmodel_repo.deckies import DeckiesMixin
|
||
from decnet.web.db.sqlmodel_repo.swarm import SwarmMixin
|
||
|
||
|
||
class SQLModelRepository(
|
||
AttackerIntelMixin,
|
||
AuthMixin,
|
||
DeckiesMixin,
|
||
SwarmMixin,
|
||
BaseRepository,
|
||
):
|
||
"""Concrete SQLModel/SQLAlchemy-async repository.
|
||
|
||
Subclasses provide ``self.engine`` (AsyncEngine) and ``self.session_factory``
|
||
in ``__init__``, and override the few dialect-specific helpers.
|
||
"""
|
||
|
||
engine: AsyncEngine
|
||
session_factory: async_sessionmaker[AsyncSession]
|
||
|
||
def _session(self):
    """Return a cancellation-safe session context manager.

    Thin wrapper: hands this repository's ``session_factory`` to
    ``_safe_session`` so every query method opens its session the same way.
    """
    factory = self.session_factory
    return _safe_session(factory)
|
||
|
||
# ------------------------------------------------------------ lifecycle
|
||
|
||
async def initialize(self) -> None:
    """Create missing tables and seed the admin user.

    Runs, in this order: the two dialect-specific migration hooks,
    ``SQLModel.metadata.create_all`` on a fresh engine transaction, and
    finally ``_ensure_admin_user``.
    """
    from sqlmodel import SQLModel

    for migration in (
        self._migrate_attackers_table,
        self._migrate_session_profile_table,
    ):
        await migration()

    metadata = SQLModel.metadata
    async with self.engine.begin() as connection:
        await connection.run_sync(metadata.create_all)

    await self._ensure_admin_user()
|
||
|
||
async def reinitialize(self) -> None:
    """Re-run schema creation and admin seeding (tests / reset flows).

    Existing tables are NOT dropped; ``create_all`` only adds what is
    missing. Unlike ``initialize`` the migration hooks are not invoked.
    """
    from sqlmodel import SQLModel

    metadata = SQLModel.metadata
    async with self.engine.begin() as connection:
        await connection.run_sync(metadata.create_all)

    await self._ensure_admin_user()
|
||
|
||
async def _ensure_admin_user(self) -> None:
    """Seed the admin account, or re-sync its hash after env drift.

    * No row named ``DECNET_ADMIN_USER``: insert one with role ``admin``
      and ``must_change_password=True``.
    * Row exists with ``must_change_password`` still set: overwrite the
      hash from ``DECNET_ADMIN_PASSWORD`` (the admin never finalized a
      password, so the environment value stays authoritative).
    * Row exists and the flag is cleared: leave the chosen password alone.
    """
    lookup = select(User).where(User.username == DECNET_ADMIN_USER)

    async with self._session() as session:
        admin = (await session.execute(lookup)).scalar_one_or_none()

        if admin is not None:
            if not admin.must_change_password:
                # Password was finalized by the user; nothing to heal.
                return
            admin.password_hash = get_password_hash(DECNET_ADMIN_PASSWORD)
            session.add(admin)
        else:
            session.add(
                User(
                    uuid=str(uuid.uuid4()),
                    username=DECNET_ADMIN_USER,
                    password_hash=get_password_hash(DECNET_ADMIN_PASSWORD),
                    role="admin",
                    must_change_password=True,
                )
            )

        await session.commit()
|
||
|
||
async def _migrate_attackers_table(self) -> None:
|
||
"""Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable)."""
|
||
return None
|
||
|
||
async def _migrate_session_profile_table(self) -> None:
|
||
"""Add DEBT-036 keystroke-dynamics columns to existing session_profile
|
||
rows. Override per dialect — DDL introspection is non-portable."""
|
||
return None
|
||
|
||
# ---------------------------------------------------------------- logs
|
||
|
||
@staticmethod
|
||
def _normalize_log_row(log_data: dict[str, Any]) -> dict[str, Any]:
|
||
data = log_data.copy()
|
||
if "fields" in data and isinstance(data["fields"], dict):
|
||
data["fields"] = orjson.dumps(data["fields"]).decode()
|
||
if "timestamp" in data and isinstance(data["timestamp"], str):
|
||
try:
|
||
data["timestamp"] = datetime.fromisoformat(
|
||
data["timestamp"].replace("Z", "+00:00")
|
||
)
|
||
except ValueError:
|
||
pass
|
||
return data
|
||
|
||
async def add_log(self, log_data: dict[str, Any]) -> None:
    """Insert a single log event as a ``Log`` row and commit it.

    ``log_data`` goes through ``_normalize_log_row`` before the model is
    built.
    """
    normalized = self._normalize_log_row(log_data)
    async with self._session() as session:
        entry = Log(**normalized)
        session.add(entry)
        await session.commit()
|
||
|
||
async def add_logs(self, log_entries: list[dict[str, Any]]) -> None:
    """Bulk-insert log events: one session and one commit per batch.

    An empty batch returns immediately without opening a session.
    """
    if not log_entries:
        return

    batch = [Log(**row) for row in map(self._normalize_log_row, log_entries)]

    async with self._session() as session:
        session.add_all(batch)
        await session.commit()
|
||
|
||
def _apply_filters(
|
||
self,
|
||
statement: SelectOfScalar,
|
||
search: Optional[str],
|
||
start_time: Optional[str],
|
||
end_time: Optional[str],
|
||
) -> SelectOfScalar:
|
||
import re
|
||
import shlex
|
||
|
||
if start_time:
|
||
statement = statement.where(Log.timestamp >= start_time)
|
||
if end_time:
|
||
statement = statement.where(Log.timestamp <= end_time)
|
||
|
||
if search:
|
||
try:
|
||
tokens = shlex.split(search)
|
||
except ValueError:
|
||
tokens = search.split()
|
||
|
||
core_fields = {
|
||
"decky": Log.decky,
|
||
"service": Log.service,
|
||
"event": Log.event_type,
|
||
"attacker": Log.attacker_ip,
|
||
"attacker-ip": Log.attacker_ip,
|
||
"attacker_ip": Log.attacker_ip,
|
||
}
|
||
|
||
for token in tokens:
|
||
if ":" in token:
|
||
key, val = token.split(":", 1)
|
||
if key in core_fields:
|
||
statement = statement.where(core_fields[key] == val)
|
||
else:
|
||
key_safe = re.sub(r"[^a-zA-Z0-9_]", "", key)
|
||
if key_safe:
|
||
statement = statement.where(
|
||
self._json_field_equals(key_safe)
|
||
).params(val=val)
|
||
else:
|
||
lk = f"%{token}%"
|
||
statement = statement.where(
|
||
or_(
|
||
Log.raw_line.like(lk),
|
||
Log.decky.like(lk),
|
||
Log.service.like(lk),
|
||
Log.attacker_ip.like(lk),
|
||
)
|
||
)
|
||
return statement
|
||
|
||
def _json_field_equals(self, key: str):
    """Build a ``text()`` predicate for ``fields -> key == :val``.

    ``JSON_EXTRACT`` is used because both SQLite and MySQL provide it.
    ``key`` is interpolated into the JSON path, so callers must pass an
    already-sanitized identifier. The compared value is NOT interpolated:
    it stays the bound parameter ``:val`` and must be supplied by the
    caller, which keeps the value itself safe from injection.
    """
    predicate_sql = f"JSON_EXTRACT(fields, '$.{key}') = :val"
    return text(predicate_sql)
|
||
|
||
async def get_logs(
    self,
    limit: int = 50,
    offset: int = 0,
    search: Optional[str] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None,
) -> List[dict]:
    """Return one page of logs, newest first, as JSON-ready dicts.

    Paging (``limit``/``offset``) and the optional search / time-window
    filters are applied in the database.
    """
    page = select(Log).order_by(desc(Log.timestamp)).offset(offset).limit(limit)
    query = self._apply_filters(page, search, start_time, end_time)

    async with self._session() as session:
        rows = (await session.execute(query)).scalars().all()
        return [row.model_dump(mode="json") for row in rows]
|
||
|
||
async def get_max_log_id(self) -> int:
    """Return the highest ``Log.id`` currently stored, or 0 when the
    aggregate yields no value (empty table)."""
    newest_id_query = select(func.max(Log.id))
    async with self._session() as session:
        newest_id = (await session.execute(newest_id_query)).scalar()
    return 0 if newest_id is None else newest_id
|
||
|
||
async def get_logs_after_id(
    self,
    last_id: int,
    limit: int = 50,
    search: Optional[str] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None,
) -> List[dict]:
    """Return up to ``limit`` logs with ``id > last_id``, oldest first.

    The same search / time-window filters as ``get_logs`` apply. Rows
    come back as JSON-ready dicts in ascending id order.
    """
    newer = (
        select(Log)
        .where(Log.id > last_id)
        .order_by(asc(Log.id))
        .limit(limit)
    )
    query = self._apply_filters(newer, search, start_time, end_time)

    async with self._session() as session:
        rows = (await session.execute(query)).scalars().all()
        return [row.model_dump(mode="json") for row in rows]
|
||
|
||
async def get_total_logs(
    self,
    search: Optional[str] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None,
) -> int:
    """Count the logs matching the given search / time-window filters."""
    counting = self._apply_filters(
        select(func.count()).select_from(Log),
        search,
        start_time,
        end_time,
    )

    async with self._session() as session:
        matched = (await session.execute(counting)).scalar()
    return matched or 0
|
||
|
||
async def get_log_histogram(
    self,
    search: Optional[str] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None,
    interval_minutes: int = 15,
) -> List[dict]:
    """Time-bucketed log counts; must be overridden per backend.

    The base class cannot implement this because the date-bucket
    expression is dialect-specific. Calling it here always raises.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
|
||
|
||
async def get_stats_summary(self) -> dict[str, Any]:
    """Collect the headline counters shown on the dashboard.

    Returns a dict with ``total_logs``, ``unique_attackers``,
    ``active_deckies`` and ``deployed_deckies``. Decky counts combine the
    fleet state file (read in a worker thread) with ``TopologyDecky`` rows.
    """
    count_logs = select(func.count()).select_from(Log)
    count_attackers = select(func.count(func.distinct(Log.attacker_ip)))
    count_topo = select(func.count()).select_from(TopologyDecky)
    count_topo_running = (
        select(func.count())
        .select_from(TopologyDecky)
        .where(TopologyDecky.state == "running")
    )

    async with self._session() as session:
        total_logs = (await session.execute(count_logs)).scalar() or 0
        unique_attackers = (await session.execute(count_attackers)).scalar() or 0
        topo_total = (await session.execute(count_topo)).scalar() or 0
        topo_running = (await session.execute(count_topo_running)).scalar() or 0

    fleet_state = await asyncio.to_thread(load_state)
    if fleet_state:
        fleet_deckies = len(fleet_state[0].deckies)
    else:
        fleet_deckies = 0

    # Fleet state file doesn't track per-decky runtime; treat all fleet
    # rows as active and add MazeNET running rows on top.
    active = fleet_deckies + topo_running
    deployed = fleet_deckies + topo_total

    return {
        "total_logs": total_logs,
        "unique_attackers": unique_attackers,
        "active_deckies": active,
        "deployed_deckies": deployed,
    }
|
||
|
||
async def get_deckies(self) -> List[dict]:
    """Return the fleet deckies from the state file as plain dicts.

    ``load_state`` runs in a worker thread; a falsy result yields ``[]``.
    """
    fleet_state = await asyncio.to_thread(load_state)
    if not fleet_state:
        return []
    return [decky.model_dump() for decky in fleet_state[0].deckies]
|
||
|
||
# --------------------------------------------------------------- users
|
||
|
||
async def purge_logs_and_bounties(self) -> dict[str, int]:
    """Delete all logs, bounties, credentials and attacker data.

    Everything runs in one session with a single commit. Returns the
    number of rows removed per category (``logs``, ``bounties``,
    ``credentials``, ``attackers``); ``attacker_behavior`` rows are
    deleted too but not reported.
    """
    # (report key, statement) in execution order. attacker_behavior has
    # FK -> attackers.uuid, so children are deleted before their parents.
    steps = (
        ("logs", "DELETE FROM logs"),
        ("bounties", "DELETE FROM bounty"),
        ("credentials", "DELETE FROM credentials"),
        (None, "DELETE FROM attacker_behavior"),
        ("attackers", "DELETE FROM attackers"),
    )

    removed: dict[str, int] = {}
    async with self._session() as session:
        for report_key, sql in steps:
            outcome = await session.execute(text(sql))
            if report_key is not None:
                removed[report_key] = outcome.rowcount
        await session.commit()
    return removed
|
||
|
||
# ------------------------------------------------------------ bounties
|
||
|
||
async def add_bounty(self, bounty_data: dict[str, Any]) -> None:
    """Store a bounty unless an identical one is already recorded.

    A dict ``payload`` is serialized to JSON first. "Identical" means the
    same ``bounty_type``, ``attacker_ip`` and serialized ``payload``; in
    that case nothing is written.
    """
    record = bounty_data.copy()
    payload = record.get("payload")
    if isinstance(payload, dict):
        record["payload"] = orjson.dumps(payload).decode()

    async with self._session() as session:
        duplicate_probe = (
            select(Bounty.id)
            .where(
                Bounty.bounty_type == record.get("bounty_type"),
                Bounty.attacker_ip == record.get("attacker_ip"),
                Bounty.payload == record.get("payload"),
            )
            .limit(1)
        )
        hit = (await session.execute(duplicate_probe)).first()
        if hit is not None:
            return

        session.add(Bounty(**record))
        await session.commit()
|
||
|
||
def _apply_bounty_filters(
|
||
self,
|
||
statement: SelectOfScalar,
|
||
bounty_type: Optional[str],
|
||
search: Optional[str],
|
||
) -> SelectOfScalar:
|
||
if bounty_type:
|
||
statement = statement.where(Bounty.bounty_type == bounty_type)
|
||
if search:
|
||
lk = f"%{search}%"
|
||
statement = statement.where(
|
||
or_(
|
||
Bounty.decky.like(lk),
|
||
Bounty.service.like(lk),
|
||
Bounty.attacker_ip.like(lk),
|
||
Bounty.payload.like(lk),
|
||
)
|
||
)
|
||
return statement
|
||
|
||
async def get_bounties(
    self,
    limit: int = 50,
    offset: int = 0,
    bounty_type: Optional[str] = None,
    search: Optional[str] = None,
) -> List[dict]:
    """Return one page of bounties, newest first, as JSON-ready dicts.

    ``payload`` is decoded from its stored JSON string where possible;
    values that do not decode are returned unchanged.
    """
    page = (
        select(Bounty)
        .order_by(desc(Bounty.timestamp))
        .offset(offset)
        .limit(limit)
    )
    query = self._apply_bounty_filters(page, bounty_type, search)

    async with self._session() as session:
        bounties = (await session.execute(query)).scalars().all()
        listing = []
        for bounty in bounties:
            entry = bounty.model_dump(mode="json")
            try:
                entry["payload"] = json.loads(entry["payload"])
            except (json.JSONDecodeError, TypeError):
                # Not JSON (or not a string): keep the stored value.
                pass
            listing.append(entry)
        return listing
|
||
|
||
async def get_total_bounties(
    self, bounty_type: Optional[str] = None, search: Optional[str] = None
) -> int:
    """Count the bounties matching the optional type / search filters."""
    counting = self._apply_bounty_filters(
        select(func.count()).select_from(Bounty), bounty_type, search
    )

    async with self._session() as session:
        matched = (await session.execute(counting)).scalar()
    return matched or 0
|
||
|
||
# ─── credentials ──────────────────────────────────────────────────────
|
||
|
||
async def upsert_credential(self, data: dict[str, Any]) -> int:
    """Insert or bump a credential attempt and return its row id.

    Rows are deduplicated on (attacker_ip, decky_name, service,
    secret_kind, secret_sha256, principal-or-NULL). On a match
    ``attempt_count`` is incremented, ``last_seen`` moves to now and a
    non-null ``outcome`` replaces the stored one; ``first_seen`` and
    ``fields`` keep the values from the first sighting. Otherwise a new
    row is created with ``attempt_count=1``.

    ``secret_kind`` defaults to ``"plaintext"`` when missing or falsy.
    """
    payload = dict(data)

    raw_fields = payload.get("fields")
    if isinstance(raw_fields, dict):
        # ensure_ascii=True keeps utf8mb4 columns safe even when
        # service-specific keys carry non-ASCII bytes.
        payload["fields"] = json.dumps(raw_fields, ensure_ascii=True)

    principal = payload.get("principal")
    secret_kind = payload.get("secret_kind") or "plaintext"

    async with self._session() as session:
        # NULL == NULL is False under SQL, so a missing principal needs
        # an explicit IS NULL test instead of an equality.
        if principal is None:
            principal_match = Credential.principal.is_(None)
        else:
            principal_match = Credential.principal == principal

        lookup = select(Credential).where(
            Credential.attacker_ip == payload["attacker_ip"],
            Credential.decky_name == payload["decky_name"],
            Credential.service == payload["service"],
            Credential.secret_kind == secret_kind,
            Credential.secret_sha256 == payload["secret_sha256"],
            principal_match,
        )
        known = (await session.execute(lookup)).scalar_one_or_none()
        now = datetime.now(timezone.utc)

        if known is None:
            fresh = Credential(
                attacker_ip=payload["attacker_ip"],
                decky_name=payload["decky_name"],
                service=payload["service"],
                principal=principal,
                secret_kind=secret_kind,
                secret_sha256=payload["secret_sha256"],
                secret_b64=payload.get("secret_b64"),
                secret_printable=payload.get("secret_printable"),
                outcome=payload.get("outcome"),
                fields=payload.get("fields", "{}"),
                first_seen=now,
                last_seen=now,
                attempt_count=1,
            )
            session.add(fresh)
            await session.commit()
            await session.refresh(fresh)
            return fresh.id  # type: ignore[return-value]

        known.attempt_count = (known.attempt_count or 1) + 1
        known.last_seen = now
        if payload.get("outcome") is not None:
            known.outcome = payload["outcome"]
        session.add(known)
        await session.commit()
        return known.id  # type: ignore[return-value]
|
||
|
||
def _apply_credential_filters(
|
||
self,
|
||
statement: SelectOfScalar,
|
||
search: Optional[str],
|
||
service: Optional[str],
|
||
attacker_ip: Optional[str],
|
||
) -> SelectOfScalar:
|
||
if service:
|
||
statement = statement.where(Credential.service == service)
|
||
if attacker_ip:
|
||
statement = statement.where(Credential.attacker_ip == attacker_ip)
|
||
if search:
|
||
lk = f"%{search}%"
|
||
statement = statement.where(
|
||
or_(
|
||
Credential.decky_name.like(lk),
|
||
Credential.service.like(lk),
|
||
Credential.principal.like(lk),
|
||
Credential.secret_printable.like(lk),
|
||
)
|
||
)
|
||
return statement
|
||
|
||
async def get_credentials(
    self,
    limit: int = 50,
    offset: int = 0,
    search: Optional[str] = None,
    service: Optional[str] = None,
    attacker_ip: Optional[str] = None,
) -> List[dict[str, Any]]:
    """Return one page of credentials, most recently seen first.

    Rows are JSON-ready dicts; ``fields`` is decoded from its stored JSON
    string where possible and otherwise returned unchanged.
    """
    page = (
        select(Credential)
        .order_by(desc(Credential.last_seen))
        .offset(offset)
        .limit(limit)
    )
    query = self._apply_credential_filters(page, search, service, attacker_ip)

    async with self._session() as session:
        credentials = (await session.execute(query)).scalars().all()
        listing: List[dict[str, Any]] = []
        for credential in credentials:
            entry = credential.model_dump(mode="json")
            try:
                entry["fields"] = json.loads(entry["fields"])
            except (json.JSONDecodeError, TypeError):
                # Not JSON (or not a string): keep the stored value.
                pass
            listing.append(entry)
        return listing
|
||
|
||
async def get_total_credentials(
    self,
    search: Optional[str] = None,
    service: Optional[str] = None,
    attacker_ip: Optional[str] = None,
) -> int:
    """Count the credentials matching the optional filters."""
    counting = self._apply_credential_filters(
        select(func.count()).select_from(Credential),
        search,
        service,
        attacker_ip,
    )

    async with self._session() as session:
        matched = (await session.execute(counting)).scalar()
    return matched or 0
|
||
|
||
async def get_credentials_for_attacker(
    self, attacker_ip: str
) -> List[dict[str, Any]]:
    """Return every credential row for one attacker IP, most recently
    seen first, with ``fields`` JSON-decoded where possible."""
    query = (
        select(Credential)
        .where(Credential.attacker_ip == attacker_ip)
        .order_by(desc(Credential.last_seen))
    )

    async with self._session() as session:
        credentials = (await session.execute(query)).scalars().all()
        listing: List[dict[str, Any]] = []
        for credential in credentials:
            entry = credential.model_dump(mode="json")
            try:
                entry["fields"] = json.loads(entry["fields"])
            except (json.JSONDecodeError, TypeError):
                # Not JSON (or not a string): keep the stored value.
                pass
            listing.append(entry)
        return listing
|
||
|
||
async def get_credential_attempts_for_secret(
    self, secret_sha256: str
) -> List[dict[str, Any]]:
    """Return every credential row that shares this secret hash.

    Each row is one (attacker_ip, decky, service, principal) sighting,
    ordered most recently seen first, with ``fields`` JSON-decoded where
    possible. Indexed lookup via ix_credentials_secret_service.
    """
    query = (
        select(Credential)
        .where(Credential.secret_sha256 == secret_sha256)
        .order_by(desc(Credential.last_seen))
    )

    async with self._session() as session:
        credentials = (await session.execute(query)).scalars().all()
        listing: List[dict[str, Any]] = []
        for credential in credentials:
            entry = credential.model_dump(mode="json")
            try:
                entry["fields"] = json.loads(entry["fields"])
            except (json.JSONDecodeError, TypeError):
                # Not JSON (or not a string): keep the stored value.
                pass
            listing.append(entry)
        return listing
|
||
|
||
# ─── credential reuse (findings) ──────────────────────────────────────
|
||
|
||
async def update_credential_attacker_uuid(
    self, attacker_ip: str, attacker_uuid: str
) -> int:
    """Backfill ``attacker_uuid`` on credentials that lack one.

    Only rows for ``attacker_ip`` whose ``attacker_uuid`` is currently
    NULL are touched. Run by the profiler after it mints/updates an
    Attacker row.

    Returns:
        The number of rows updated (0 when the driver reports none).
    """
    backfill = (
        update(Credential)
        .where(
            Credential.attacker_ip == attacker_ip,
            Credential.attacker_uuid.is_(None),
        )
        .values(attacker_uuid=attacker_uuid)
    )

    async with self._session() as session:
        outcome = await session.execute(backfill)
        await session.commit()
        return int(outcome.rowcount or 0)
|
||
|
||
@staticmethod
|
||
def _merge_unique(existing_json: str, value: Optional[str]) -> tuple[str, bool]:
|
||
"""Append ``value`` to a JSON list[str] column if not present.
|
||
Returns (new_json, changed). None values and duplicates are skipped.
|
||
"""
|
||
if value is None:
|
||
return existing_json, False
|
||
try:
|
||
current = json.loads(existing_json) if existing_json else []
|
||
if not isinstance(current, list):
|
||
current = []
|
||
except (json.JSONDecodeError, TypeError):
|
||
current = []
|
||
if value in current:
|
||
return existing_json, False
|
||
current.append(value)
|
||
return json.dumps(current, ensure_ascii=True), True
|
||
|
||
async def upsert_credential_reuse(
    self,
    *,
    secret_sha256: str,
    secret_kind: str,
    principal: Optional[str],
    attacker_uuid: Optional[str],
    attacker_ip: str,
    decky: str,
    service: str,
    attempt_count: int,
    ts: Optional[datetime] = None,
) -> Optional[dict[str, Any]]:
    """Upsert a credential-reuse finding.

    The row is keyed by ``(secret_sha256, secret_kind, principal_key)``.
    ``principal_key`` is the canonicalised non-null form ("" when
    principal is null) so the unique constraint behaves the same on
    SQLite and MySQL.

    Insert path: a new finding is created from this single sighting
    (``target_count=1``, ``confidence=1.0``).

    Update path: the attacker uuid / ip / decky / service are merged into
    their JSON list columns, ``attempt_count`` is added to the stored
    total, the timestamps advance, and ``target_count`` is recomputed from
    the ``credentials`` table.

    Args:
        secret_sha256: Hash identifying the secret.
        secret_kind: Kind of secret the hash refers to.
        principal: Username/principal, or ``None``.
        attacker_uuid: Attacker row uuid if known; ``None`` is not merged.
        attacker_ip: Source IP of the sighting.
        decky: Decky the secret was used against.
        service: Service the secret was used against.
        attempt_count: Attempts to add to the finding's running total.
        ts: Timestamp to record; defaults to the current UTC time.

    Returns:
        The row as a JSON-ready dict plus ``inserted: bool`` and
        ``changed: bool`` so the correlator can decide whether to publish
        a bus event. ``changed`` is true when any list gained a member or
        ``target_count`` moved.
    """
    principal_key = principal or ""
    now = ts or datetime.now(timezone.utc)

    async with self._session() as session:
        existing = (
            await session.execute(
                select(CredentialReuse).where(
                    CredentialReuse.secret_sha256 == secret_sha256,
                    CredentialReuse.secret_kind == secret_kind,
                    CredentialReuse.principal_key == principal_key,
                )
            )
        ).scalar_one_or_none()

        if existing is None:
            row = CredentialReuse(
                id=str(uuid.uuid4()),
                secret_sha256=secret_sha256,
                secret_kind=secret_kind,
                principal=principal,
                principal_key=principal_key,
                attacker_uuids=json.dumps(
                    [attacker_uuid] if attacker_uuid else [], ensure_ascii=True
                ),
                attacker_ips=json.dumps([attacker_ip], ensure_ascii=True),
                deckies=json.dumps([decky], ensure_ascii=True),
                services=json.dumps([service], ensure_ascii=True),
                target_count=1,
                attempt_count=int(attempt_count),
                confidence=1.0,
                first_seen=now,
                last_seen=now,
                updated_at=now,
            )
            session.add(row)
            await session.commit()
            await session.refresh(row)
            d = row.model_dump(mode="json")
            d["inserted"] = True
            d["changed"] = True
            return d

        new_uuids, c1 = self._merge_unique(existing.attacker_uuids, attacker_uuid)
        new_ips, c2 = self._merge_unique(existing.attacker_ips, attacker_ip)
        new_deckies, c3 = self._merge_unique(existing.deckies, decky)
        new_services, c4 = self._merge_unique(existing.services, service)
        # _merge_unique hands back the original text when nothing was
        # added, so unconditional assignment is a no-op in that case.
        existing.attacker_uuids = new_uuids
        existing.attacker_ips = new_ips
        existing.deckies = new_deckies
        existing.services = new_services

        # Recount target tuples from the underlying credentials table: a
        # (decky, service) tuple only counts when both were observed
        # together, which the JSON lists alone can't tell us. This runs on
        # every update, because a NEW pairing of an already-known decky
        # with an already-known service adds a target without growing
        # either list; gating the recount on list growth would leave
        # target_count stale in that case.
        stmt = select(
            func.count(
                func.distinct(Credential.decky_name + ":" + Credential.service)
            )
        ).where(
            Credential.secret_sha256 == secret_sha256,
            Credential.secret_kind == secret_kind,
            # NULL == NULL is False under SQL, so branch the predicate.
            (Credential.principal == principal)
            if principal is not None
            else Credential.principal.is_(None),
        )
        target_count = int((await session.execute(stmt)).scalar() or 0)
        target_changed = target_count != existing.target_count
        existing.target_count = target_count

        existing.attempt_count = (existing.attempt_count or 0) + int(attempt_count)
        existing.last_seen = now
        existing.updated_at = now

        changed = bool(c1 or c2 or c3 or c4 or target_changed)

        session.add(existing)
        await session.commit()
        await session.refresh(existing)
        d = existing.model_dump(mode="json")
        d["inserted"] = False
        d["changed"] = changed
        return d
|
||
|
||
async def find_credential_reuse_candidates(
    self, min_targets: int = 2
) -> List[dict[str, Any]]:
    """List the credential groups that reach the reuse threshold.

    Credentials are grouped by ``(secret_sha256, secret_kind, principal)``
    and a group qualifies once it spans at least ``min_targets`` distinct
    ``decky:service`` targets. Each result dict carries the three group
    keys, the group's ``target_count`` and a ``credentials`` list with the
    underlying rows, so the correlator can fold each one into
    ``CredentialReuse`` via ``upsert_credential_reuse``.
    """
    group_keys = (
        Credential.secret_sha256,
        Credential.secret_kind,
        Credential.principal,
    )
    distinct_targets = func.count(
        func.distinct(Credential.decky_name + ":" + Credential.service)
    ).label("target_count")

    async with self._session() as session:
        grouping = (
            select(*group_keys, distinct_targets)
            .group_by(*group_keys)
            .having(distinct_targets >= int(min_targets))
        )
        qualifying = (await session.execute(grouping)).all()

        candidates: List[dict[str, Any]] = []
        for sha, kind, principal, target_count in qualifying:
            # A NULL principal needs IS NULL rather than an equality test.
            if principal is None:
                principal_match = Credential.principal.is_(None)
            else:
                principal_match = Credential.principal == principal

            members_query = select(Credential).where(
                Credential.secret_sha256 == sha,
                Credential.secret_kind == kind,
                principal_match,
            )
            members = (await session.execute(members_query)).scalars().all()

            candidates.append(
                {
                    "secret_sha256": sha,
                    "secret_kind": kind,
                    "principal": principal,
                    "target_count": int(target_count or 0),
                    "credentials": [m.model_dump(mode="json") for m in members],
                }
            )
        return candidates
|
||
|
||
async def list_credential_reuses(
    self,
    limit: int = 50,
    offset: int = 0,
    min_target_count: int = 2,
    secret_kind: Optional[str] = None,
) -> tuple[int, List[dict[str, Any]]]:
    """Return ``(total, page)`` of credential-reuse findings.

    Only findings with ``target_count >= min_target_count`` (and matching
    ``secret_kind`` when given) are considered. The page is ordered by
    target count, then last-seen, both descending. The four JSON list
    columns are decoded (``[]`` when undecodable) and each row is enriched
    with ``secret_printable`` / ``secret_b64``.
    """
    list_columns = ("attacker_uuids", "attacker_ips", "deckies", "services")

    async with self._session() as session:
        matching = select(CredentialReuse).where(
            CredentialReuse.target_count >= min_target_count
        )
        if secret_kind:
            matching = matching.where(CredentialReuse.secret_kind == secret_kind)

        counting = select(func.count()).select_from(matching.subquery())
        total = (await session.execute(counting)).scalar() or 0

        paging = (
            matching.order_by(
                desc(CredentialReuse.target_count),
                desc(CredentialReuse.last_seen),
            )
            .offset(offset)
            .limit(limit)
        )
        findings = (await session.execute(paging)).scalars().all()

        page: List[dict[str, Any]] = []
        for finding in findings:
            entry = finding.model_dump(mode="json")
            for column in list_columns:
                try:
                    decoded = json.loads(entry[column])
                except (json.JSONDecodeError, TypeError):
                    decoded = []
                entry[column] = decoded
            page.append(entry)

        await self._enrich_with_secret(session, page)
        return int(total), page
|
||
|
||
async def get_credential_reuse_by_id(
    self, reuse_id: str
) -> Optional[dict[str, Any]]:
    """Fetch one credential-reuse finding by id, or ``None`` if absent.

    The four JSON list columns are decoded (``[]`` when undecodable) and
    the row is enriched with ``secret_printable`` / ``secret_b64``.
    """
    by_id = select(CredentialReuse).where(CredentialReuse.id == reuse_id)

    async with self._session() as session:
        finding = (await session.execute(by_id)).scalar_one_or_none()
        if finding is None:
            return None

        entry = finding.model_dump(mode="json")
        for column in ("attacker_uuids", "attacker_ips", "deckies", "services"):
            try:
                decoded = json.loads(entry[column])
            except (json.JSONDecodeError, TypeError):
                decoded = []
            entry[column] = decoded

        await self._enrich_with_secret(session, [entry])
        return entry
|
||
|
||
@staticmethod
|
||
async def _enrich_with_secret(
|
||
session: Any, rows: List[dict[str, Any]]
|
||
) -> None:
|
||
"""Tack ``secret_printable`` + ``secret_b64`` onto each reuse row.
|
||
|
||
``CredentialReuse`` only stores the sha256+kind hash of the
|
||
secret — the actual printable/b64 representations live on the
|
||
underlying ``Credential`` rows. The dashboard wants to show the
|
||
secret in the drawer, so we lift one matching credential per
|
||
``(sha256, kind, principal)`` finding. One batched query for the
|
||
whole page; rows with no surviving credential (shouldn't happen
|
||
in practice) get nulls.
|
||
"""
|
||
if not rows:
|
||
return
|
||
sha_set = {r["secret_sha256"] for r in rows}
|
||
if not sha_set:
|
||
return
|
||
stmt = select(
|
||
Credential.secret_sha256,
|
||
Credential.secret_kind,
|
||
Credential.principal,
|
||
Credential.secret_printable,
|
||
Credential.secret_b64,
|
||
).where(Credential.secret_sha256.in_(sha_set))
|
||
secret_map: dict[
|
||
tuple[str, str, Optional[str]],
|
||
tuple[Optional[str], Optional[str]],
|
||
] = {}
|
||
for sha, kind, principal, printable, b64 in (
|
||
(await session.execute(stmt)).all()
|
||
):
|
||
secret_map.setdefault((sha, kind, principal), (printable, b64))
|
||
for r in rows:
|
||
key = (r["secret_sha256"], r["secret_kind"], r.get("principal"))
|
||
printable, b64 = secret_map.get(key, (None, None))
|
||
r["secret_printable"] = printable
|
||
r["secret_b64"] = b64
|
||
|
||
async def get_state(self, key: str) -> Optional[dict[str, Any]]:
    """Return the JSON-decoded value stored under ``key``, or ``None``
    when no ``State`` row exists for it."""
    by_key = select(State).where(State.key == key)
    async with self._session() as session:
        entry = (await session.execute(by_key)).scalar_one_or_none()
        return json.loads(entry.value) if entry else None
|
||
|
||
async def set_state(self, key: str, value: Any) -> None:  # noqa: ANN401
    """Store ``value`` (JSON-serialized) under ``key``.

    An existing ``State`` row is overwritten; otherwise a new one is
    created. The write is committed before returning.
    """
    by_key = select(State).where(State.key == key)

    async with self._session() as session:
        entry = (await session.execute(by_key)).scalar_one_or_none()
        serialized = orjson.dumps(value).decode()

        if not entry:
            entry = State(key=key, value=serialized)
        else:
            entry.value = serialized
        session.add(entry)

        await session.commit()
|
||
|
||
# ----------------------------------------------------------- attackers
|
||
|
||
async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]:
    """Return every bounty grouped by attacker IP.

    Within each IP the bounties are in ascending timestamp order. Each
    bounty is a JSON-ready dict whose ``payload`` is decoded from JSON
    where possible and otherwise left as stored.
    """
    oldest_first = select(Bounty).order_by(asc(Bounty.timestamp))

    async with self._session() as session:
        bounties = (await session.execute(oldest_first)).scalars().all()

        by_ip: dict[str, List[dict[str, Any]]] = {}
        for bounty in bounties:
            entry = bounty.model_dump(mode="json")
            try:
                entry["payload"] = json.loads(entry["payload"])
            except (json.JSONDecodeError, TypeError):
                # Not JSON (or not a string): keep the stored value.
                pass
            by_ip.setdefault(bounty.attacker_ip, []).append(entry)
        return by_ip
|
||
|
||
async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]:
    """Return the bounties of the given attacker IPs, grouped by IP.

    Within each IP the bounties are in ascending timestamp order. Each
    bounty is a JSON-ready dict whose ``payload`` is decoded from JSON
    where possible and otherwise left as stored. IPs without bounties do
    not appear in the result.

    Args:
        ips: Attacker IPs to look up. An empty collection returns ``{}``
            immediately, without opening a session or issuing an empty
            ``IN ()`` query.

    Returns:
        Mapping of attacker IP to its list of bounty dicts.
    """
    if not ips:
        return {}

    from collections import defaultdict

    async with self._session() as session:
        result = await session.execute(
            select(Bounty)
            .where(Bounty.attacker_ip.in_(ips))
            .order_by(asc(Bounty.timestamp))
        )
        grouped: dict[str, List[dict[str, Any]]] = defaultdict(list)
        for item in result.scalars().all():
            d = item.model_dump(mode="json")
            try:
                d["payload"] = json.loads(d["payload"])
            except (json.JSONDecodeError, TypeError):
                # Not JSON (or not a string): keep the stored value.
                pass
            grouped[item.attacker_ip].append(d)
        return dict(grouped)
|
||
|
||
async def upsert_attacker(self, data: dict[str, Any]) -> str:
    """Create or update the ``Attacker`` row for ``data["ip"]``.

    An existing row has every key of ``data`` assigned onto it. Otherwise
    a new row is inserted under a freshly generated uuid, which replaces
    any ``uuid`` present in ``data``.

    Returns:
        The uuid of the affected row.
    """
    by_ip = select(Attacker).where(Attacker.ip == data["ip"])

    async with self._session() as session:
        known = (await session.execute(by_ip)).scalar_one_or_none()

        if not known:
            row_uuid = str(uuid.uuid4())
            session.add(Attacker(**{**data, "uuid": row_uuid}))
        else:
            for attribute, new_value in data.items():
                setattr(known, attribute, new_value)
            session.add(known)
            row_uuid = known.uuid

        await session.commit()
        return row_uuid
|
||
|
||
async def upsert_attacker_behavior(
    self,
    attacker_uuid: str,
    data: dict[str, Any],
) -> None:
    """Insert or update the ``AttackerBehavior`` row for *attacker_uuid*.

    *data* is written together with an ``updated_at`` value of the current
    UTC time. An existing row has each key assigned onto it; otherwise a
    new row is created. The change is committed before returning.
    """
    async with self._session() as session:
        lookup = await session.execute(
            select(AttackerBehavior).where(
                AttackerBehavior.attacker_uuid == attacker_uuid
            )
        )
        behavior = lookup.scalar_one_or_none()
        fields = dict(data, updated_at=datetime.now(timezone.utc))
        if behavior:
            for name, value in fields.items():
                setattr(behavior, name, value)
        else:
            behavior = AttackerBehavior(attacker_uuid=attacker_uuid, **fields)
        session.add(behavior)
        await session.commit()
|
||
|
||
async def get_attacker_behavior(
    self,
    attacker_uuid: str,
) -> Optional[dict[str, Any]]:
    """Return the behavior row for *attacker_uuid*, or ``None`` if absent.

    The row is dumped in JSON mode and passed through
    ``_deserialize_behavior`` so its JSON-encoded columns come back decoded.
    """
    query = select(AttackerBehavior).where(
        AttackerBehavior.attacker_uuid == attacker_uuid
    )
    async with self._session() as session:
        behavior = (await session.execute(query)).scalar_one_or_none()
        if behavior:
            return self._deserialize_behavior(behavior.model_dump(mode="json"))
        return None
|
||
|
||
async def get_behaviors_for_ips(
    self,
    ips: set[str],
) -> dict[str, dict[str, Any]]:
    """Return deserialized behavior rows keyed by attacker IP.

    ``Attacker`` is joined to ``AttackerBehavior`` on the attacker uuid and
    filtered to the requested IPs. An empty *ips* short-circuits to ``{}``
    without querying. IPs that have no behavior row are absent from the
    result.
    """
    if not ips:
        return {}
    query = (
        select(Attacker.ip, AttackerBehavior)
        .join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid)
        .where(Attacker.ip.in_(ips))
    )
    async with self._session() as session:
        pairs = await session.execute(query)
        return {
            ip: self._deserialize_behavior(behavior.model_dump(mode="json"))
            for ip, behavior in pairs.all()
        }
|
||
|
||
@staticmethod
|
||
def _deserialize_behavior(d: dict[str, Any]) -> dict[str, Any]:
|
||
for key in ("tcp_fingerprint", "timing_stats", "phase_sequence"):
|
||
if isinstance(d.get(key), str):
|
||
try:
|
||
d[key] = json.loads(d[key])
|
||
except (json.JSONDecodeError, TypeError):
|
||
pass
|
||
# Deserialize tool_guesses JSON array; normalise None → [].
|
||
raw = d.get("tool_guesses")
|
||
if isinstance(raw, str):
|
||
try:
|
||
parsed = json.loads(raw)
|
||
d["tool_guesses"] = parsed if isinstance(parsed, list) else [parsed]
|
||
except (json.JSONDecodeError, TypeError):
|
||
d["tool_guesses"] = []
|
||
elif raw is None:
|
||
d["tool_guesses"] = []
|
||
# Same list-or-None pattern for kex_order_raw.
|
||
raw_kex = d.get("kex_order_raw")
|
||
if isinstance(raw_kex, str):
|
||
try:
|
||
parsed_kex = json.loads(raw_kex)
|
||
d["kex_order_raw"] = parsed_kex if isinstance(parsed_kex, list) else [parsed_kex]
|
||
except (json.JSONDecodeError, TypeError):
|
||
d["kex_order_raw"] = []
|
||
elif raw_kex is None:
|
||
d["kex_order_raw"] = []
|
||
# Same list-or-None pattern for ssh_client_banners.
|
||
raw_banners = d.get("ssh_client_banners")
|
||
if isinstance(raw_banners, str):
|
||
try:
|
||
parsed_banners = json.loads(raw_banners)
|
||
d["ssh_client_banners"] = parsed_banners if isinstance(parsed_banners, list) else [parsed_banners]
|
||
except (json.JSONDecodeError, TypeError):
|
||
d["ssh_client_banners"] = []
|
||
elif raw_banners is None:
|
||
d["ssh_client_banners"] = []
|
||
return d
|
||
|
||
async def upsert_session_profile(
    self,
    sid: str,
    data: dict[str, Any],
) -> None:
    """
    Write (or update) the session_profile row for *sid*.

    Pre-v1, the typical call is the empty-write path at session close:
    `upsert_session_profile(sid, {"log_id": <id>})` — all keystroke
    feature columns stay NULL until the V2 ingestion job populates them.

    An existing row has every key of *data* assigned onto it; otherwise a
    new row is created from *sid* and *data*. The change is committed
    before returning.
    """
    async with self._session() as session:
        lookup = await session.execute(
            select(SessionProfile).where(SessionProfile.sid == sid)
        )
        profile = lookup.scalar_one_or_none()
        if not profile:
            profile = SessionProfile(sid=sid, **data)
        else:
            for column, value in data.items():
                setattr(profile, column, value)
        session.add(profile)
        await session.commit()
|
||
|
||
async def get_session_profile(
    self,
    sid: str,
) -> Optional[dict[str, Any]]:
    """Return the session_profile row for *sid* dumped in JSON mode.

    Returns ``None`` when no row matches.
    """
    query = select(SessionProfile).where(SessionProfile.sid == sid)
    async with self._session() as session:
        profile = (await session.execute(query)).scalar_one_or_none()
        return profile.model_dump(mode="json") if profile else None
|
||
|
||
async def increment_smtp_target(self, attacker_uuid: str, domain: str) -> None:
    """Upsert an (attacker_uuid, domain) pair and bump count + last_seen.

    Read-then-write under a single session — the UNIQUE constraint on
    (attacker_uuid, domain) guards against duplicate rows if the race
    ever materialises; we accept the ~1ms extra round-trip in exchange
    for a single dialect-portable implementation.

    A new pair starts with ``count=1`` and ``first_seen == last_seen``.
    """
    query = (
        select(SmtpTarget)
        .where(SmtpTarget.attacker_uuid == attacker_uuid)
        .where(SmtpTarget.domain == domain)
    )
    async with self._session() as session:
        target = (await session.execute(query)).scalar_one_or_none()
        seen_at = datetime.now(timezone.utc)
        if target:
            target.count += 1
            target.last_seen = seen_at
        else:
            target = SmtpTarget(
                attacker_uuid=attacker_uuid,
                domain=domain,
                first_seen=seen_at,
                last_seen=seen_at,
                count=1,
            )
        session.add(target)
        await session.commit()
|
||
|
||
async def list_smtp_targets(self, attacker_uuid: str) -> list[dict[str, Any]]:
    """Return the SMTP target rows for one attacker, most recently seen first.

    Each row is dumped in JSON mode.
    """
    query = (
        select(SmtpTarget)
        .where(SmtpTarget.attacker_uuid == attacker_uuid)
        .order_by(desc(SmtpTarget.last_seen))
    )
    async with self._session() as session:
        targets = (await session.execute(query)).scalars().all()
        return [target.model_dump(mode="json") for target in targets]
|
||
|
||
async def smtp_target_seen(self, domain: str) -> dict[str, Any]:
    """Aggregate rows for this domain across every attacker in the DB.

    Returns a dict with ``seen`` (whether the summed count is above zero),
    ``count`` (the summed count, 0 when there are no rows), and the
    earliest ``first_seen`` / latest ``last_seen`` as returned by the
    database.
    """
    aggregate = select(
        func.coalesce(func.sum(SmtpTarget.count), 0),
        func.min(SmtpTarget.first_seen),
        func.max(SmtpTarget.last_seen),
    ).where(SmtpTarget.domain == domain)
    async with self._session() as session:
        total, first_seen, last_seen = (await session.execute(aggregate)).one()
        hits = int(total)
        return {
            "seen": hits > 0,
            "count": hits,
            "first_seen": first_seen,
            "last_seen": last_seen,
        }
|
||
|
||
@staticmethod
|
||
def _deserialize_attacker(d: dict[str, Any]) -> dict[str, Any]:
|
||
for key in ("services", "deckies", "fingerprints", "commands"):
|
||
if isinstance(d.get(key), str):
|
||
try:
|
||
d[key] = json.loads(d[key])
|
||
except (json.JSONDecodeError, TypeError):
|
||
pass
|
||
return d
|
||
|
||
async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
    """Return the attacker row with this uuid, or ``None`` if absent.

    The row is dumped in JSON mode and passed through
    ``_deserialize_attacker`` so its JSON-encoded columns come back decoded.
    """
    query = select(Attacker).where(Attacker.uuid == uuid)
    async with self._session() as session:
        found = (await session.execute(query)).scalar_one_or_none()
        if found:
            return self._deserialize_attacker(found.model_dump(mode="json"))
        return None
|
||
|
||
async def get_attackers(
    self,
    limit: int = 50,
    offset: int = 0,
    search: Optional[str] = None,
    sort_by: str = "recent",
    service: Optional[str] = None,
) -> List[dict[str, Any]]:
    """Return one page of attacker rows, deserialized.

    ``sort_by`` selects the ordering: ``"active"`` sorts by ``event_count``
    descending, ``"traversals"`` by ``is_traversal`` descending, and any
    other value by ``last_seen`` descending. ``search`` is a substring
    match on the IP; ``service`` matches the quoted service name inside the
    stored ``services`` text.
    """
    if sort_by == "active":
        ordering = desc(Attacker.event_count)
    elif sort_by == "traversals":
        ordering = desc(Attacker.is_traversal)
    else:
        ordering = desc(Attacker.last_seen)

    query = select(Attacker)
    if search:
        query = query.where(Attacker.ip.like(f"%{search}%"))
    if service:
        query = query.where(Attacker.services.like(f'%"{service}"%'))
    query = query.order_by(ordering).offset(offset).limit(limit)

    async with self._session() as session:
        attackers = (await session.execute(query)).scalars().all()
        return [
            self._deserialize_attacker(attacker.model_dump(mode="json"))
            for attacker in attackers
        ]
|
||
|
||
async def get_total_attackers(
    self, search: Optional[str] = None, service: Optional[str] = None
) -> int:
    """Count attacker rows, applying the same filters as ``get_attackers``.

    ``search`` is a substring match on the IP; ``service`` matches the
    quoted service name inside the stored ``services`` text. Returns 0 when
    the count comes back empty.
    """
    counter = select(func.count()).select_from(Attacker)
    if search:
        counter = counter.where(Attacker.ip.like(f"%{search}%"))
    if service:
        counter = counter.where(Attacker.services.like(f'%"{service}"%'))

    async with self._session() as session:
        total = (await session.execute(counter)).scalar()
        return total or 0
|
||
|
||
# ─── Identity resolution reads ────────────────────────────────────────
|
||
|
||
async def get_identity_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
    """Resolve an identity uuid to the identity it was merged into.

    Follows ``merged_into_uuid`` links starting at *uuid* and returns the
    first identity whose link is ``None``, dumped in JSON mode. Returns
    ``None`` as soon as a uuid in the chain has no row. The walk is capped
    at eight lookups; if the cap is reached the last identity fetched is
    returned as-is.
    """
    # The clusterer is responsible for never producing a cycle; the cap is
    # a belt-and-braces guard so a corrupted ring cannot spin the worker.
    max_hops = 8
    async with self._session() as session:
        cursor = uuid
        for _hop in range(max_hops):
            lookup = await session.execute(
                select(AttackerIdentity).where(AttackerIdentity.uuid == cursor)
            )
            found = lookup.scalar_one_or_none()
            if found is None:
                return None
            if found.merged_into_uuid is None:
                return found.model_dump(mode="json")
            cursor = found.merged_into_uuid
        # Hop cap reached: surface what we have rather than keep walking.
        return found.model_dump(mode="json")
|
||
|
||
async def list_identities(
    self, limit: int = 50, offset: int = 0,
) -> list[dict[str, Any]]:
    """Return one page of identities that have not been merged away.

    Rows with a non-NULL ``merged_into_uuid`` are excluded so the list is
    the de-duplicated view; merged rows stay reachable per uuid through
    ``get_identity_by_uuid``. Ordered by ``updated_at`` descending and
    dumped in JSON mode.
    """
    query = select(AttackerIdentity).where(
        AttackerIdentity.merged_into_uuid.is_(None)
    )
    query = query.order_by(desc(AttackerIdentity.updated_at))
    query = query.offset(offset).limit(limit)
    async with self._session() as session:
        identities = (await session.execute(query)).scalars().all()
        return [identity.model_dump(mode="json") for identity in identities]
|
||
|
||
async def count_identities(self) -> int:
    """Count identities whose ``merged_into_uuid`` is NULL.

    Returns 0 when the count comes back empty.
    """
    counter = select(func.count()).select_from(AttackerIdentity)
    counter = counter.where(AttackerIdentity.merged_into_uuid.is_(None))
    async with self._session() as session:
        total = (await session.execute(counter)).scalar()
        return total or 0
|
||
|
||
async def list_observations_for_identity(
    self, identity_uuid: str, limit: int = 50, offset: int = 0,
) -> list[dict[str, Any]]:
    """Return one page of attacker rows attached to an identity.

    Selects ``Attacker`` rows whose ``identity_id`` equals *identity_uuid*,
    ordered by ``last_seen`` descending, and returns them deserialized.
    """
    query = select(Attacker).where(Attacker.identity_id == identity_uuid)
    query = query.order_by(desc(Attacker.last_seen))
    query = query.offset(offset).limit(limit)
    async with self._session() as session:
        attackers = (await session.execute(query)).scalars().all()
        return [
            self._deserialize_attacker(attacker.model_dump(mode="json"))
            for attacker in attackers
        ]
|
||
|
||
async def count_observations_for_identity(self, identity_uuid: str) -> int:
    """Count attacker rows whose ``identity_id`` equals *identity_uuid*.

    Returns 0 when the count comes back empty.
    """
    counter = select(func.count()).select_from(Attacker)
    counter = counter.where(Attacker.identity_id == identity_uuid)
    async with self._session() as session:
        total = (await session.execute(counter)).scalar()
        return total or 0
|
||
|
||
# ─── Identity resolution writes (clusterer worker) ─────────────────────
|
||
|
||
async def list_attackers_for_clustering(
    self, limit: Optional[int] = None,
) -> list[dict[str, Any]]:
    """Return the narrow attacker projection the clusterer reads.

    Only ``uuid``, ``asn``, ``identity_id`` and ``fingerprints`` are
    selected, ordered by ``first_seen``. ``fingerprints`` is returned
    exactly as stored, without decoding. ``limit`` caps the row count when
    it is not ``None``.
    """
    # Kept deliberately narrow so future denormalised projections can land
    # here without churning every caller.
    columns = (
        Attacker.uuid,
        Attacker.asn,
        Attacker.identity_id,
        Attacker.fingerprints,
    )
    query = select(*columns).order_by(Attacker.first_seen)
    if limit is not None:
        query = query.limit(limit)
    async with self._session() as session:
        rows = (await session.execute(query)).all()
        return [
            dict(
                uuid=row.uuid,
                asn=row.asn,
                identity_id=row.identity_id,
                fingerprints=row.fingerprints,
            )
            for row in rows
        ]
|
||
|
||
async def create_attacker_identity(self, row: dict[str, Any]) -> str:
    """Insert a new ``AttackerIdentity`` built from *row* and return its uuid.

    The model is constructed before the session is opened, then added and
    committed.
    """
    new_identity = AttackerIdentity(**row)
    async with self._session() as session:
        session.add(new_identity)
        await session.commit()
        return new_identity.uuid
|
||
|
||
async def set_attacker_identity_id(
    self, attacker_uuid: str, identity_uuid: str,
) -> None:
    """Set ``identity_id`` on the attacker row with the given uuid.

    Issues a single UPDATE and commits; nothing is raised here when no row
    matches.
    """
    assignment = update(Attacker).where(Attacker.uuid == attacker_uuid)
    assignment = assignment.values(identity_id=identity_uuid)
    async with self._session() as session:
        await session.execute(assignment)
        await session.commit()
|
||
|
||
async def list_all_identities(self) -> list[dict[str, Any]]:
    """Return every identity row, merged ones included, oldest first.

    Ordered by ``created_at`` and dumped in JSON mode.
    """
    query = select(AttackerIdentity).order_by(AttackerIdentity.created_at)
    async with self._session() as session:
        identities = (await session.execute(query)).scalars().all()
        return [identity.model_dump(mode="json") for identity in identities]
|
||
|
||
async def update_identity_merged_into(
    self, identity_uuid: str, winner_uuid: Optional[str],
) -> None:
    """Point an identity at the identity it was merged into.

    Writes ``merged_into_uuid`` (``None`` clears the link) and refreshes
    ``updated_at`` to the current UTC time, then commits.
    """
    changes = {
        "merged_into_uuid": winner_uuid,
        "updated_at": datetime.now(timezone.utc),
    }
    merge = (
        update(AttackerIdentity)
        .where(AttackerIdentity.uuid == identity_uuid)
        .values(**changes)
    )
    async with self._session() as session:
        await session.execute(merge)
        await session.commit()
|
||
|
||
async def update_identity_fingerprints(
    self,
    identity_uuid: str,
    *,
    ja3_hashes: Optional[str] = None,
    hassh_hashes: Optional[str] = None,
    tls_cert_sha256: Optional[str] = None,
) -> None:
    """Overwrite the fingerprint columns of one identity.

    All three columns are written on every call, so an argument left at its
    ``None`` default stores NULL rather than keeping the previous value.
    ``updated_at`` is refreshed to the current UTC time and the change is
    committed.
    """
    changes = {
        "ja3_hashes": ja3_hashes,
        "hassh_hashes": hassh_hashes,
        "tls_cert_sha256": tls_cert_sha256,
        "updated_at": datetime.now(timezone.utc),
    }
    rewrite = (
        update(AttackerIdentity)
        .where(AttackerIdentity.uuid == identity_uuid)
        .values(**changes)
    )
    async with self._session() as session:
        await session.execute(rewrite)
        await session.commit()
|
||
|
||
# ─── Campaign clustering reads ────────────────────────────────────────
|
||
|
||
async def get_campaign_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
    """Resolve a campaign uuid to the campaign it was merged into.

    Follows ``merged_into_uuid`` links starting at *uuid* and returns the
    first campaign whose link is ``None``, dumped in JSON mode. Returns
    ``None`` as soon as a uuid in the chain has no row. The walk is capped
    at eight lookups; if the cap is reached the last campaign fetched is
    returned as-is.
    """
    # Same chain-walk as get_identity_by_uuid; the cap bounds the loop
    # against corrupted rings.
    max_hops = 8
    async with self._session() as session:
        cursor = uuid
        for _hop in range(max_hops):
            lookup = await session.execute(
                select(Campaign).where(Campaign.uuid == cursor)
            )
            found = lookup.scalar_one_or_none()
            if found is None:
                return None
            if found.merged_into_uuid is None:
                return found.model_dump(mode="json")
            cursor = found.merged_into_uuid
        return found.model_dump(mode="json")
|
||
|
||
async def list_campaigns(
    self, limit: int = 50, offset: int = 0,
) -> list[dict[str, Any]]:
    """Return one page of campaigns that have not been merged away.

    Rows with a non-NULL ``merged_into_uuid`` are excluded. Ordered by
    ``updated_at`` descending and dumped in JSON mode.
    """
    query = select(Campaign).where(Campaign.merged_into_uuid.is_(None))
    query = query.order_by(desc(Campaign.updated_at))
    query = query.offset(offset).limit(limit)
    async with self._session() as session:
        campaigns = (await session.execute(query)).scalars().all()
        return [campaign.model_dump(mode="json") for campaign in campaigns]
|
||
|
||
async def count_campaigns(self) -> int:
    """Count campaigns whose ``merged_into_uuid`` is NULL.

    Returns 0 when the count comes back empty.
    """
    counter = select(func.count()).select_from(Campaign)
    counter = counter.where(Campaign.merged_into_uuid.is_(None))
    async with self._session() as session:
        total = (await session.execute(counter)).scalar()
        return total or 0
|
||
|
||
async def list_identities_for_campaign(
    self, campaign_uuid: str, limit: int = 50, offset: int = 0,
) -> list[dict[str, Any]]:
    """Return one page of identities attached to a campaign.

    Selects identities whose ``campaign_id`` equals *campaign_uuid*,
    ordered by ``updated_at`` descending and dumped in JSON mode.
    """
    query = select(AttackerIdentity).where(
        AttackerIdentity.campaign_id == campaign_uuid
    )
    query = query.order_by(desc(AttackerIdentity.updated_at))
    query = query.offset(offset).limit(limit)
    async with self._session() as session:
        identities = (await session.execute(query)).scalars().all()
        return [identity.model_dump(mode="json") for identity in identities]
|
||
|
||
async def count_identities_for_campaign(self, campaign_uuid: str) -> int:
    """Count identities whose ``campaign_id`` equals *campaign_uuid*.

    Returns 0 when the count comes back empty.
    """
    counter = select(func.count()).select_from(AttackerIdentity)
    counter = counter.where(AttackerIdentity.campaign_id == campaign_uuid)
    async with self._session() as session:
        total = (await session.execute(counter)).scalar()
        return total or 0
|
||
|
||
# ─── Campaign clustering writes (campaign-clusterer worker) ───────────
|
||
|
||
async def list_identities_for_clustering(
    self, limit: Optional[int] = None,
) -> list[dict[str, Any]]:
    """Return the narrow identity projection the campaign clusterer reads.

    Selects the uuid, campaign/merge links, first/last seen timestamps and
    the fingerprint / payload / C2 columns, ordered by ``created_at``. The
    two timestamps are returned as ISO-8601 strings (``None`` stays
    ``None``); the remaining columns are returned exactly as stored.
    ``limit`` caps the row count when it is not ``None``.
    """

    def _iso(stamp):
        # Timestamps are the only columns converted on the way out.
        return stamp.isoformat() if stamp is not None else None

    # Narrow on purpose — future denormalised projections can land here
    # without churning callers.
    query = select(
        AttackerIdentity.uuid,
        AttackerIdentity.campaign_id,
        AttackerIdentity.merged_into_uuid,
        AttackerIdentity.first_seen_at,
        AttackerIdentity.last_seen_at,
        AttackerIdentity.ja3_hashes,
        AttackerIdentity.hassh_hashes,
        AttackerIdentity.payload_simhashes,
        AttackerIdentity.c2_endpoints,
    ).order_by(AttackerIdentity.created_at)
    if limit is not None:
        query = query.limit(limit)

    async with self._session() as session:
        rows = (await session.execute(query)).all()
        projected: list[dict[str, Any]] = []
        for row in rows:
            projected.append({
                "uuid": row.uuid,
                "campaign_id": row.campaign_id,
                "merged_into_uuid": row.merged_into_uuid,
                "first_seen_at": _iso(row.first_seen_at),
                "last_seen_at": _iso(row.last_seen_at),
                "ja3_hashes": row.ja3_hashes,
                "hassh_hashes": row.hassh_hashes,
                "payload_simhashes": row.payload_simhashes,
                "c2_endpoints": row.c2_endpoints,
            })
        return projected
|
||
|
||
async def create_campaign(self, row: dict[str, Any]) -> str:
    """Insert a new ``Campaign`` built from *row* and return its uuid.

    The model is constructed before the session is opened, then added and
    committed.
    """
    new_campaign = Campaign(**row)
    async with self._session() as session:
        session.add(new_campaign)
        await session.commit()
        return new_campaign.uuid
|
||
|
||
async def set_identity_campaign_id(
    self, identity_uuid: str, campaign_uuid: Optional[str],
) -> None:
    """Attach an identity to a campaign (or detach it with ``None``).

    Writes ``campaign_id`` and refreshes ``updated_at`` to the current UTC
    time, then commits.
    """
    changes = {
        "campaign_id": campaign_uuid,
        "updated_at": datetime.now(timezone.utc),
    }
    assignment = (
        update(AttackerIdentity)
        .where(AttackerIdentity.uuid == identity_uuid)
        .values(**changes)
    )
    async with self._session() as session:
        await session.execute(assignment)
        await session.commit()
|
||
|
||
async def list_all_campaigns(self) -> list[dict[str, Any]]:
    """Return every campaign row, merged ones included, oldest first.

    Ordered by ``created_at`` and dumped in JSON mode.
    """
    query = select(Campaign).order_by(Campaign.created_at)
    async with self._session() as session:
        campaigns = (await session.execute(query)).scalars().all()
        return [campaign.model_dump(mode="json") for campaign in campaigns]
|
||
|
||
async def update_campaign_merged_into(
    self, campaign_uuid: str, winner_uuid: Optional[str],
) -> None:
    """Point a campaign at the campaign it was merged into.

    Writes ``merged_into_uuid`` (``None`` clears the link) and refreshes
    ``updated_at`` to the current UTC time, then commits.
    """
    changes = {
        "merged_into_uuid": winner_uuid,
        "updated_at": datetime.now(timezone.utc),
    }
    merge = (
        update(Campaign)
        .where(Campaign.uuid == campaign_uuid)
        .values(**changes)
    )
    async with self._session() as session:
        await session.execute(merge)
        await session.commit()
|
||
|
||
async def get_attacker_commands(
    self,
    uuid: str,
    limit: int = 50,
    offset: int = 0,
    service: Optional[str] = None,
) -> dict[str, Any]:
    """Return one page of the commands stored on an attacker row.

    Reads the ``commands`` column for *uuid*, decoding it from JSON when it
    is stored as a string. When *service* is given, only entries whose
    ``"service"`` equals it are kept. ``total`` is the count after
    filtering and ``data`` is the ``offset``/``limit`` slice. A missing
    attacker or a NULL column yields ``{"total": 0, "data": []}``.
    """
    query = select(Attacker.commands).where(Attacker.uuid == uuid)
    async with self._session() as session:
        stored = (await session.execute(query)).scalar_one_or_none()
        if stored is None:
            return {"total": 0, "data": []}

    entries: list = json.loads(stored) if isinstance(stored, str) else stored
    if service:
        entries = [entry for entry in entries if entry.get("service") == service]

    # Pagination happens in Python, after the service filter.
    return {"total": len(entries), "data": entries[offset: offset + limit]}
|
||
|
||
async def get_attacker_service_activity(
    self, attacker_uuid: str
) -> list[tuple[str, str]]:
    """Return distinct ``(service, event_type)`` pairs for an attacker.

    Resolves IP then ``SELECT DISTINCT service, event_type FROM logs
    WHERE attacker_ip = :ip`` — the result set is bounded by the
    cardinality of services × event_types (tens, not thousands), so
    this stays cheap even for attackers with long event streams.
    Caller applies `event_kinds.bucket_services` to split into
    scanned vs. interacted.

    Returns ``[]`` when the uuid does not resolve to an IP.
    """
    async with self._session() as session:
        ip_lookup = await session.execute(
            select(Attacker.ip).where(Attacker.uuid == attacker_uuid)
        )
        attacker_ip = ip_lookup.scalar_one_or_none()
        if not attacker_ip:
            return []
        pairs_query = (
            select(Log.service, Log.event_type)
            .where(Log.attacker_ip == attacker_ip)
            .distinct()
        )
        pairs = await session.execute(pairs_query)
        return [(pair[0], pair[1]) for pair in pairs.all()]
|
||
|
||
async def get_attacker_ip_leaks(
    self, attacker_uuid: str, *, limit: int = 10,
) -> list[dict[str, Any]]:
    """Return ``bounty_type='ip_leak'`` rows for this attacker, newest
    first, capped at ``limit``. Shape matches the XFF-mismatch
    payload emitted by the ingester: keys include ``real_ip_claim``,
    ``source_header``, ``headers_seen``. Use
    :meth:`count_attacker_ip_leaks` to get the unbounded total for
    rotation detection.

    A string ``payload`` is decoded from JSON; one that fails to decode is
    replaced with ``{}``. Returns ``[]`` when the uuid does not resolve to
    an IP.
    """
    async with self._session() as session:
        ip_lookup = await session.execute(
            select(Attacker.ip).where(Attacker.uuid == attacker_uuid)
        )
        attacker_ip = ip_lookup.scalar_one_or_none()
        if not attacker_ip:
            return []
        leaks_query = (
            select(Bounty)
            .where(Bounty.attacker_ip == attacker_ip)
            .where(Bounty.bounty_type == "ip_leak")
            .order_by(desc(Bounty.timestamp))
            .limit(limit)
        )
        leaks = await session.execute(leaks_query)
        records: list[dict[str, Any]] = []
        for leak in leaks.scalars().all():
            record = leak.model_dump(mode="json")
            # Bounty.payload is stored JSON-encoded; pre-decode for UX.
            encoded = record.get("payload")
            if isinstance(encoded, str):
                try:
                    record["payload"] = json.loads(encoded)
                except (ValueError, TypeError):
                    record["payload"] = {}
            records.append(record)
        return records
|
||
|
||
async def count_attacker_ip_leaks(self, attacker_uuid: str) -> int:
    """Cheap COUNT(*) for XFF-rotation detection.

    Counts ``ip_leak`` bounties for the attacker's IP. Returns 0 when the
    uuid does not resolve to an IP or the count comes back empty.
    """
    async with self._session() as session:
        ip_lookup = await session.execute(
            select(Attacker.ip).where(Attacker.uuid == attacker_uuid)
        )
        attacker_ip = ip_lookup.scalar_one_or_none()
        if not attacker_ip:
            return 0
        counter = (
            select(func.count(Bounty.id))
            .where(Bounty.attacker_ip == attacker_ip)
            .where(Bounty.bounty_type == "ip_leak")
        )
        total = (await session.execute(counter)).scalar()
        return int(total or 0)
|
||
|
||
async def get_attacker_artifacts(self, uuid: str) -> list[dict[str, Any]]:
    """Return `file_captured` logs for the attacker identified by UUID.

    Resolves the attacker's IP first, then queries the logs table on two
    indexed columns (``attacker_ip`` and ``event_type``). No JSON extract
    needed — the decky/stored_as are already decoded into ``fields`` by
    the ingester and returned to the frontend for drawer rendering.

    Rows are newest first, capped at 200, and dumped in JSON mode. Returns
    ``[]`` when the uuid does not resolve to an IP.
    """
    async with self._session() as session:
        ip_lookup = await session.execute(
            select(Attacker.ip).where(Attacker.uuid == uuid)
        )
        attacker_ip = ip_lookup.scalar_one_or_none()
        if not attacker_ip:
            return []
        logs_query = (
            select(Log)
            .where(Log.attacker_ip == attacker_ip)
            .where(Log.event_type == "file_captured")
            .order_by(desc(Log.timestamp))
            .limit(200)
        )
        logs = (await session.execute(logs_query)).scalars().all()
        return [entry.model_dump(mode="json") for entry in logs]
|
||
|
||
async def get_attacker_stored_mail(self, uuid: str) -> list[dict[str, Any]]:
    """Return `message_stored` logs for an attacker, newest first.

    Mirrors :meth:`get_attacker_artifacts` — the SMTP template emits one
    `message_stored` row per accepted DATA body, with headers + sha256 +
    attachment manifest already decoded into ``fields`` by the ingester.
    Capped at 200 rows to match the artifact/transcript query shape.

    Rows are dumped in JSON mode. Returns ``[]`` when the uuid does not
    resolve to an IP.
    """
    async with self._session() as session:
        ip_lookup = await session.execute(
            select(Attacker.ip).where(Attacker.uuid == uuid)
        )
        attacker_ip = ip_lookup.scalar_one_or_none()
        if not attacker_ip:
            return []
        logs_query = (
            select(Log)
            .where(Log.attacker_ip == attacker_ip)
            .where(Log.event_type == "message_stored")
            .order_by(desc(Log.timestamp))
            .limit(200)
        )
        logs = (await session.execute(logs_query)).scalars().all()
        return [entry.model_dump(mode="json") for entry in logs]
|
||
|
||
async def get_session_log(self, sid: str) -> Optional[dict[str, Any]]:
    """Look up the `session_recorded` Log row that owns a given sid.

    sid is a v4 UUID embedded in the row's ``fields`` JSON blob. Matched
    with LIKE on the textual sid substring — cheap given the bounded
    cardinality of session_recorded rows vs. the full logs table.

    The needle is matched literally: ``autoescape=True`` escapes ``%`` and
    ``_`` in *sid*, so a caller-supplied value such as ``"%"`` cannot act
    as a wildcard and pull back another session's row.

    Returns the first matching row dumped in JSON mode, or ``None`` when
    nothing matches.
    """
    # The match relies on ``fields`` being stored as compact JSON, with no
    # space after the colon.
    needle = f'"sid":"{sid}"'
    async with self._session() as session:
        rows = await session.execute(
            select(Log)
            .where(Log.event_type == "session_recorded")
            .where(Log.fields.contains(needle, autoescape=True))
            .limit(1)
        )
        row = rows.scalars().first()
        return row.model_dump(mode="json") if row else None
|
||
|
||
async def get_attacker_transcripts(self, uuid: str) -> list[dict[str, Any]]:
    """Return `session_recorded` logs for the attacker identified by UUID.

    Mirror of :meth:`get_attacker_artifacts` — sessions ride in the same
    Log table with event_type=session_recorded; the ingester decodes the
    RFC 5424 SD fields (sid, service, decky, src_ip, duration_s, bytes,
    truncated, shard_path) into the returned ``fields`` blob.

    Rows are newest first, capped at 200, and dumped in JSON mode. Returns
    ``[]`` when the uuid does not resolve to an IP.
    """
    async with self._session() as session:
        ip_lookup = await session.execute(
            select(Attacker.ip).where(Attacker.uuid == uuid)
        )
        attacker_ip = ip_lookup.scalar_one_or_none()
        if not attacker_ip:
            return []
        logs_query = (
            select(Log)
            .where(Log.attacker_ip == attacker_ip)
            .where(Log.event_type == "session_recorded")
            .order_by(desc(Log.timestamp))
            .limit(200)
        )
        logs = (await session.execute(logs_query)).scalars().all()
        return [entry.model_dump(mode="json") for entry in logs]
|
||
|
||
# ------------------------------------------------------------- swarm
|
||
|
||
# -------------------------------------------------------------- fleet
|
||
|
||
async def upsert_fleet_decky(self, data: dict[str, Any]) -> None:
    """Insert or update the ``FleetDecky`` row keyed by (host_uuid, name).

    A copy of *data* is written with ``updated_at`` set to the current UTC
    time. A missing or ``None`` ``host_uuid`` is replaced with
    ``LOCAL_HOST_SENTINEL``. A list-valued ``services`` and a dict-valued
    ``decky_config`` are JSON-encoded with ``orjson`` before storage. The
    change is committed before returning.
    """
    record: dict[str, Any] = dict(data)
    record["updated_at"] = datetime.now(timezone.utc)
    # Covers both an absent key and an explicit None.
    if record.get("host_uuid") is None:
        record["host_uuid"] = LOCAL_HOST_SENTINEL
    for column, container in (("services", list), ("decky_config", dict)):
        if isinstance(record.get(column), container):
            record[column] = orjson.dumps(record[column]).decode()

    async with self._session() as session:
        lookup = await session.execute(
            select(FleetDecky).where(
                FleetDecky.host_uuid == record["host_uuid"],
                FleetDecky.name == record["name"],
            )
        )
        decky = lookup.scalar_one_or_none()
        if decky:
            for column, value in record.items():
                setattr(decky, column, value)
        else:
            decky = FleetDecky(**record)
        session.add(decky)
        await session.commit()
|
||
|
||
async def delete_fleet_decky(self, *, host_uuid: str, name: str) -> None:
    """Delete the ``fleet_deckies`` row matching (host_uuid, name).

    Runs a parameterized textual DELETE and commits; nothing is raised
    here when no row matches.
    """
    removal = text(
        "DELETE FROM fleet_deckies "
        "WHERE host_uuid = :h AND name = :n"
    )
    bind_values = {"h": host_uuid, "n": name}
    async with self._session() as session:
        await session.execute(removal, bind_values)
        await session.commit()
|
||
|
||
async def list_fleet_deckies(
    self, *, host_uuid: Optional[str] = None,
) -> list[dict[str, Any]]:
    """Return fleet decky rows ordered by name, optionally for one host.

    A truthy *host_uuid* restricts the result to that host. Each row is
    dumped in JSON mode with ``services`` and ``decky_config`` decoded from
    JSON where possible.
    """
    query = select(FleetDecky).order_by(asc(FleetDecky.name))
    if host_uuid:
        query = query.where(FleetDecky.host_uuid == host_uuid)
    json_columns = ("services", "decky_config")
    async with self._session() as session:
        deckies = (await session.execute(query)).scalars().all()
        return [
            self._deserialize_json_fields(decky.model_dump(mode="json"), json_columns)
            for decky in deckies
        ]
|
||
|
||
async def list_running_fleet_deckies(self) -> list[dict[str, Any]]:
    """Return fleet decky rows whose ``state`` is ``"running"``.

    Each row is dumped in JSON mode with ``services`` and ``decky_config``
    decoded from JSON where possible.
    """
    query = select(FleetDecky).where(FleetDecky.state == "running")
    json_columns = ("services", "decky_config")
    async with self._session() as session:
        deckies = (await session.execute(query)).scalars().all()
        return [
            self._deserialize_json_fields(decky.model_dump(mode="json"), json_columns)
            for decky in deckies
        ]
|
||
|
||
async def update_fleet_decky_state(
    self, *, host_uuid: str, name: str, state: str,
    last_error: Optional[str] = None,
) -> None:
    """Update the state of the fleet decky matching (host_uuid, name).

    ``state`` is written and both ``updated_at`` and ``last_seen`` are set
    to the current UTC time. ``last_error`` is written only when it is not
    ``None``, so calling without it leaves the stored error untouched. The
    change is committed before returning.
    """
    stamp = datetime.now(timezone.utc)
    changes: dict[str, Any] = {
        "state": state,
        "updated_at": stamp,
        "last_seen": stamp,
    }
    if last_error is not None:
        changes["last_error"] = last_error
    transition = (
        update(FleetDecky)
        .where(
            FleetDecky.host_uuid == host_uuid,
            FleetDecky.name == name,
        )
        .values(**changes)
    )
    async with self._session() as session:
        await session.execute(transition)
        await session.commit()
|
||
|
||
async def list_running_deckies(self) -> list[dict[str, Any]]:
    """Return running deckies from all three sources in one uniform shape.

    Entries are appended in order: topology deckies, then fleet deckies,
    then ``DeckyShard`` rows in ``"running"`` state. Every entry carries
    ``uuid``, ``name``, ``ip``, ``services`` (``[]`` when empty or missing)
    and ``source``. Only topology entries carry a ``topology_id`` key.
    Fleet and shard entries use ``"<host_uuid>:<name>"`` as their uuid.
    """
    combined: list[dict[str, Any]] = []

    # MazeNET — already shaped {uuid, name, ip, services}. topology_id is
    # carried through so consumers can walk back to the parent topology
    # row without a second round-trip.
    combined.extend(
        {
            "uuid": decky.get("uuid"),
            "name": decky.get("name"),
            "ip": decky.get("ip"),
            "services": decky.get("services") or [],
            "topology_id": decky.get("topology_id"),
            "source": "topology",
        }
        for decky in await self.list_running_topology_deckies()
    )

    # Fleet — column is `decky_ip`, PK is composite (host_uuid, name).
    combined.extend(
        {
            "uuid": f"{decky.get('host_uuid')}:{decky.get('name')}",
            "name": decky.get("name"),
            "ip": decky.get("decky_ip"),
            "services": decky.get("services") or [],
            "source": "fleet",
        }
        for decky in await self.list_running_fleet_deckies()
    )

    # SWARM — DeckyShard rows in 'running' state on enrolled workers.
    async with self._session() as session:
        shards = await session.execute(
            select(DeckyShard).where(DeckyShard.state == "running")
        )
        for shard in shards.scalars().all():
            fields = self._deserialize_json_fields(
                shard.model_dump(mode="json"), ("services", "decky_config")
            )
            combined.append({
                "uuid": f"{fields.get('host_uuid')}:{fields.get('decky_name')}",
                "name": fields.get("decky_name"),
                "ip": fields.get("decky_ip"),
                "services": fields.get("services") or [],
                "source": "shard",
            })
    return combined
|
||
|
||
# ------------------------------------------------------------ mazenet
|
||
|
||
@staticmethod
|
||
def _serialize_json_fields(data: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]:
|
||
out = dict(data)
|
||
for k in keys:
|
||
v = out.get(k)
|
||
if v is not None and not isinstance(v, str):
|
||
out[k] = orjson.dumps(v).decode()
|
||
return out
|
||
|
||
@staticmethod
|
||
def _deserialize_json_fields(d: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]:
|
||
for k in keys:
|
||
v = d.get(k)
|
||
if isinstance(v, str):
|
||
try:
|
||
d[k] = json.loads(v)
|
||
except (json.JSONDecodeError, TypeError):
|
||
pass
|
||
return d
|
||
|
||
async def create_topology(self, data: dict[str, Any]) -> str:
    """Insert a new ``Topology`` row built from *data* and return its id.

    ``config_snapshot`` is JSON-encoded first when it is not already a
    string.
    """
    row_kwargs = self._serialize_json_fields(data, ("config_snapshot",))
    async with self._session() as db:
        topology = Topology(**row_kwargs)
        db.add(topology)
        await db.commit()
        # Refresh so attributes populated on insert are loaded before the read.
        await db.refresh(topology)
        return topology.id
|
||
|
||
async def get_topology(self, topology_id: str) -> Optional[dict[str, Any]]:
    """Fetch one topology as a dict, or ``None`` when the id is unknown.

    ``config_snapshot`` is decoded from its stored JSON text when possible.
    """
    query = select(Topology).where(Topology.id == topology_id)
    async with self._session() as db:
        topology = (await db.execute(query)).scalar_one_or_none()
        if topology:
            return self._deserialize_json_fields(
                topology.model_dump(mode="json"), ("config_snapshot",)
            )
        return None
|
||
|
||
async def list_topologies(
    self,
    status: Optional[str] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> list[dict[str, Any]]:
    """List topologies newest-first, optionally filtered and paginated.

    A falsy *status* disables the status filter. *offset* and *limit*
    are applied only when they are not ``None``. ``config_snapshot`` is
    decoded on every returned row.
    """
    query = select(Topology)
    if status:
        query = query.where(Topology.status == status)
    query = query.order_by(desc(Topology.created_at))
    if offset is not None:
        query = query.offset(offset)
    if limit is not None:
        query = query.limit(limit)

    async with self._session() as db:
        rows = (await db.execute(query)).scalars().all()
        hydrated: list[dict[str, Any]] = []
        for row in rows:
            hydrated.append(
                self._deserialize_json_fields(
                    row.model_dump(mode="json"), ("config_snapshot",)
                )
            )
        return hydrated
|
||
|
||
async def count_topologies(self, status: Optional[str] = None) -> int:
    """Return the number of topologies, optionally restricted to *status*.

    A falsy *status* counts every row. The result is always an ``int``;
    a NULL/absent count is reported as ``0``.
    """
    # `func` comes from the module-level sqlalchemy import; the former
    # function-scope re-import was redundant.
    statement = select(func.count(Topology.id))
    if status:
        statement = statement.where(Topology.status == status)
    async with self._session() as session:
        result = await session.execute(statement)
        return int(result.scalar_one() or 0)
|
||
|
||
async def update_topology_status(
    self,
    topology_id: str,
    new_status: str,
    reason: Optional[str] = None,
) -> None:
    """Set a topology's status and log the transition in one commit.

    The status change and its ``TopologyStatusEvent`` audit row share a
    single timestamp and a single transaction. An unknown *topology_id*
    is a silent no-op. Whether the transition is legal is not checked
    here; per the original note that is enforced in
    ``decnet.topology.status`` and the caller is trusted.
    """
    changed_at = datetime.now(timezone.utc)
    async with self._session() as db:
        lookup = await db.execute(
            select(Topology).where(Topology.id == topology_id)
        )
        topology = lookup.scalar_one_or_none()
        if topology is None:
            return

        # Build the audit row first so it captures the pre-change status.
        event = TopologyStatusEvent(
            topology_id=topology_id,
            from_status=topology.status,
            to_status=new_status,
            at=changed_at,
            reason=reason,
        )
        topology.status = new_status
        topology.status_changed_at = changed_at
        db.add(topology)
        db.add(event)
        await db.commit()
|
||
|
||
async def set_topology_resync(self, topology_id: str, value: bool) -> None:
    """Set ``needs_resync`` on a topology; unknown ids are ignored.

    *value* is coerced with ``bool()`` before it is stored.
    """
    query = select(Topology).where(Topology.id == topology_id)
    async with self._session() as db:
        topology = (await db.execute(query)).scalar_one_or_none()
        if topology is None:
            return
        topology.needs_resync = bool(value)
        db.add(topology)
        await db.commit()
|
||
|
||
async def set_topology_email_personas(
    self, topology_id: str, personas_json: str,
) -> bool:
    """Overwrite ``Topology.email_personas`` with *personas_json*.

    The string is written verbatim; nothing here parses or validates it,
    so that responsibility stays with the caller.

    Returns:
        ``True`` when the topology exists and was updated, ``False``
        when no row has that id.
    """
    query = select(Topology).where(Topology.id == topology_id)
    async with self._session() as db:
        topology = (await db.execute(query)).scalar_one_or_none()
        if topology is not None:
            topology.email_personas = personas_json
            db.add(topology)
            await db.commit()
            return True
        return False
|
||
|
||
async def list_topologies_needing_resync(self) -> list[dict[str, Any]]:
    """Return every topology whose ``needs_resync`` flag is set.

    Rows are returned as dicts with ``config_snapshot`` decoded.
    """
    flagged = select(Topology).where(Topology.needs_resync == True)  # noqa: E712
    async with self._session() as db:
        rows = (await db.execute(flagged)).scalars().all()
        pending: list[dict[str, Any]] = []
        for row in rows:
            pending.append(
                self._deserialize_json_fields(
                    row.model_dump(mode="json"), ("config_snapshot",)
                )
            )
        return pending
|
||
|
||
async def delete_topology_cascade(self, topology_id: str) -> bool:
    """Remove a topology together with all of its child rows.

    Child tables are purged explicitly (the original note: there is no
    portable ON DELETE CASCADE). The child deletes run, and are
    committed, even when the topology row itself does not exist.

    Returns:
        ``True`` if the topology row existed and was deleted, else ``False``.
    """
    # Executed in this exact order, before the parent row is removed.
    child_deletes = (
        "DELETE FROM topology_status_events WHERE topology_id = :t",
        "DELETE FROM topology_edges WHERE topology_id = :t",
        "DELETE FROM topology_deckies WHERE topology_id = :t",
        "DELETE FROM lans WHERE topology_id = :t",
        "DELETE FROM topology_mutations WHERE topology_id = :t",
    )
    async with self._session() as db:
        bind = {"t": topology_id}
        for statement in child_deletes:
            await db.execute(text(statement), bind)

        lookup = await db.execute(
            select(Topology).where(Topology.id == topology_id)
        )
        topology = lookup.scalar_one_or_none()
        existed = bool(topology)
        if existed:
            await db.delete(topology)
        await db.commit()
        return existed
|
||
|
||
async def _assert_pending(self, session, topology_id: str) -> None:
    """Guard for pre-deploy edits: the topology must still be pending.

    Runs inside the caller's *session*.

    Raises:
        ValueError: no topology has *topology_id*.
        TopologyNotEditable: the topology's status is not ``PENDING``.
    """
    from decnet.topology.status import TopologyNotEditable, TopologyStatus

    lookup = await session.execute(
        select(Topology).where(Topology.id == topology_id)
    )
    topology = lookup.scalar_one_or_none()
    if topology is None:
        raise ValueError(f"topology {topology_id!r} not found")
    if topology.status == TopologyStatus.PENDING:
        return
    raise TopologyNotEditable(
        status=topology.status,
        reason="free-form edits are pending-only; use the "
        "mutator (topology_mutations) after deploy",
    )
|
||
|
||
async def _check_and_bump_version(
    self,
    session,
    topology_id: str,
    expected_version: Optional[int],
) -> None:
    """Optimistic-concurrency check shared by the child-row mutators.

    With ``expected_version=None`` the check is skipped entirely (kept
    for internal callers that need no concurrency protection).
    Otherwise the topology is loaded through *session*, its ``version``
    must equal *expected_version*, and it is then incremented by one.
    Nothing is committed here; the caller commits the session.

    Raises:
        ValueError: the topology does not exist.
        VersionConflict: the stored version differs from *expected_version*.
    """
    from decnet.topology.status import VersionConflict

    if expected_version is None:
        return

    lookup = await session.execute(
        select(Topology).where(Topology.id == topology_id)
    )
    topology = lookup.scalar_one_or_none()
    if topology is None:
        raise ValueError(f"topology {topology_id!r} not found")

    current = topology.version
    if current != expected_version:
        raise VersionConflict(current=current, expected=expected_version)

    topology.version = current + 1
    session.add(topology)
|
||
|
||
async def add_lan(
    self,
    data: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
) -> str:
    """Create a ``LAN`` row from *data* and return its id.

    *data* must contain ``topology_id``. When *expected_version* is
    given, the parent topology's version is checked and bumped in the
    same transaction as the insert.
    """
    async with self._session() as db:
        await self._check_and_bump_version(
            db, data["topology_id"], expected_version
        )
        lan = LAN(**data)
        db.add(lan)
        await db.commit()
        await db.refresh(lan)
        return lan.id
|
||
|
||
async def update_lan(
    self,
    lan_id: str,
    fields: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
    enforce_pending: bool = False,
) -> None:
    """Apply a partial column update to one LAN.

    An empty *fields* mapping returns immediately without touching the
    database. With *enforce_pending* the owning topology must still be
    pending; with *expected_version* the topology version is checked
    and bumped before the update is issued.

    Raises:
        ValueError: no LAN has *lan_id*.
    """
    if not fields:
        return

    async with self._session() as db:
        lookup = await db.execute(select(LAN).where(LAN.id == lan_id))
        lan = lookup.scalar_one_or_none()
        if lan is None:
            raise ValueError(f"lan {lan_id!r} not found")

        owner_id = lan.topology_id
        if enforce_pending:
            await self._assert_pending(db, owner_id)
        if expected_version is not None:
            await self._check_and_bump_version(db, owner_id, expected_version)

        change = update(LAN).where(LAN.id == lan_id).values(**fields)
        await db.execute(change)
        await db.commit()
|
||
|
||
async def delete_lan(
    self,
    lan_id: str,
    *,
    expected_version: Optional[int] = None,
) -> None:
    """Delete a LAN and its edges from a pending topology.

    An unknown *lan_id* is a no-op. The delete is refused when any decky
    with an edge on this LAN has no edge to any other LAN, because
    removing the LAN would orphan it; the caller has to delete or
    reattach such deckies first.

    Raises:
        ValueError: a decky would be orphaned (or, via the guards, the
            topology is missing).
    """
    from decnet.topology.status import TopologyNotEditable  # noqa: F401

    async with self._session() as db:
        lan = (
            await db.execute(select(LAN).where(LAN.id == lan_id))
        ).scalar_one_or_none()
        if lan is None:
            return
        await self._assert_pending(db, lan.topology_id)

        # Orphan check: every decky attached here needs an edge elsewhere.
        attached = await db.execute(
            select(TopologyEdge).where(TopologyEdge.lan_id == lan_id)
        )
        resident_uuids = {edge.decky_uuid for edge in attached.scalars().all()}
        for decky_uuid in resident_uuids:
            elsewhere = await db.execute(
                select(TopologyEdge).where(
                    TopologyEdge.decky_uuid == decky_uuid,
                    TopologyEdge.lan_id != lan_id,
                )
            )
            if elsewhere.scalars().first() is None:
                raise ValueError(
                    f"cannot delete LAN {lan.name!r}: decky "
                    f"{decky_uuid} has no other LAN (would be orphaned)"
                )

        if expected_version is not None:
            await self._check_and_bump_version(
                db, lan.topology_id, expected_version
            )

        # Edges first, then the LAN row itself.
        bind = {"l": lan_id}
        await db.execute(
            text("DELETE FROM topology_edges WHERE lan_id = :l"), bind
        )
        await db.execute(text("DELETE FROM lans WHERE id = :l"), bind)
        await db.commit()
|
||
|
||
async def list_lans_for_topology(
    self, topology_id: str
) -> list[dict[str, Any]]:
    """Return the LANs of one topology as dicts, ordered by name ascending."""
    query = (
        select(LAN)
        .where(LAN.topology_id == topology_id)
        .order_by(asc(LAN.name))
    )
    async with self._session() as db:
        lans = (await db.execute(query)).scalars().all()
        return [lan.model_dump(mode="json") for lan in lans]
|
||
|
||
async def add_topology_decky(
    self,
    data: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
) -> str:
    """Create a ``TopologyDecky`` row from *data* and return its uuid.

    ``services`` and ``decky_config`` are JSON-encoded before insert.
    *data* must contain ``topology_id``; when *expected_version* is
    given the parent topology's version is checked and bumped in the
    same transaction.
    """
    row_kwargs = self._serialize_json_fields(data, ("services", "decky_config"))
    async with self._session() as db:
        await self._check_and_bump_version(
            db, data["topology_id"], expected_version
        )
        decky = TopologyDecky(**row_kwargs)
        db.add(decky)
        await db.commit()
        await db.refresh(decky)
        return decky.uuid
|
||
|
||
async def update_topology_decky(
    self,
    decky_uuid: str,
    fields: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
    enforce_pending: bool = False,
) -> None:
    """Apply a partial column update to one topology decky.

    An empty *fields* mapping returns immediately. ``services`` and
    ``decky_config`` are JSON-encoded, and ``updated_at`` defaults to
    the current UTC time unless the caller supplied one. The optional
    pending and version guards run against the owning topology before
    the update is issued.

    Raises:
        ValueError: no decky has *decky_uuid*.
    """
    if not fields:
        return

    changes = self._serialize_json_fields(fields, ("services", "decky_config"))
    changes.setdefault("updated_at", datetime.now(timezone.utc))

    async with self._session() as db:
        lookup = await db.execute(
            select(TopologyDecky).where(TopologyDecky.uuid == decky_uuid)
        )
        decky = lookup.scalar_one_or_none()
        if decky is None:
            raise ValueError(f"decky {decky_uuid!r} not found")

        owner_id = decky.topology_id
        if enforce_pending:
            await self._assert_pending(db, owner_id)
        if expected_version is not None:
            await self._check_and_bump_version(db, owner_id, expected_version)

        change = (
            update(TopologyDecky)
            .where(TopologyDecky.uuid == decky_uuid)
            .values(**changes)
        )
        await db.execute(change)
        await db.commit()
|
||
|
||
async def delete_topology_decky(
    self,
    decky_uuid: str,
    *,
    expected_version: Optional[int] = None,
) -> None:
    """Delete a decky and all of its edges from a pending topology.

    An unknown *decky_uuid* is a no-op. The owning topology must be
    pending; when *expected_version* is given its version is checked
    and bumped in the same transaction.
    """
    async with self._session() as db:
        lookup = await db.execute(
            select(TopologyDecky).where(TopologyDecky.uuid == decky_uuid)
        )
        decky = lookup.scalar_one_or_none()
        if decky is None:
            return

        await self._assert_pending(db, decky.topology_id)
        if expected_version is not None:
            await self._check_and_bump_version(
                db, decky.topology_id, expected_version
            )

        # Edges first, then the decky row itself.
        bind = {"u": decky_uuid}
        await db.execute(
            text("DELETE FROM topology_edges WHERE decky_uuid = :u"), bind
        )
        await db.execute(
            text("DELETE FROM topology_deckies WHERE uuid = :u"), bind
        )
        await db.commit()
|
||
|
||
async def list_topology_deckies(
    self, topology_id: str
) -> list[dict[str, Any]]:
    """Return one topology's deckies as dicts, ordered by name ascending.

    ``services`` and ``decky_config`` are decoded from JSON text.
    """
    query = (
        select(TopologyDecky)
        .where(TopologyDecky.topology_id == topology_id)
        .order_by(asc(TopologyDecky.name))
    )
    async with self._session() as db:
        deckies = (await db.execute(query)).scalars().all()
        hydrated: list[dict[str, Any]] = []
        for decky in deckies:
            hydrated.append(
                self._deserialize_json_fields(
                    decky.model_dump(mode="json"), ("services", "decky_config")
                )
            )
        return hydrated
|
||
|
||
async def add_topology_edge(
    self,
    data: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
) -> str:
    """Create a ``TopologyEdge`` row from *data* and return its id.

    *data* must contain ``topology_id``; when *expected_version* is
    given the parent topology's version is checked and bumped in the
    same transaction as the insert.
    """
    async with self._session() as db:
        await self._check_and_bump_version(
            db, data["topology_id"], expected_version
        )
        edge = TopologyEdge(**data)
        db.add(edge)
        await db.commit()
        await db.refresh(edge)
        return edge.id
|
||
|
||
async def delete_topology_edge(
    self,
    edge_id: str,
    *,
    expected_version: Optional[int] = None,
) -> None:
    """Delete a single edge from a pending topology.

    An unknown *edge_id* is a no-op. The owning topology must be
    pending; when *expected_version* is given its version is checked
    and bumped in the same transaction.
    """
    async with self._session() as db:
        lookup = await db.execute(
            select(TopologyEdge).where(TopologyEdge.id == edge_id)
        )
        edge = lookup.scalar_one_or_none()
        if edge is None:
            return

        owner_id = edge.topology_id
        await self._assert_pending(db, owner_id)
        if expected_version is not None:
            await self._check_and_bump_version(db, owner_id, expected_version)

        await db.execute(
            text("DELETE FROM topology_edges WHERE id = :e"),
            {"e": edge_id},
        )
        await db.commit()
|
||
|
||
async def list_topology_edges(
    self, topology_id: str
) -> list[dict[str, Any]]:
    """Return every edge of one topology as a dict (no explicit ordering)."""
    query = select(TopologyEdge).where(TopologyEdge.topology_id == topology_id)
    async with self._session() as db:
        edges = (await db.execute(query)).scalars().all()
        return [edge.model_dump(mode="json") for edge in edges]
|
||
|
||
async def list_topology_status_events(
    self, topology_id: str, limit: int = 100
) -> list[dict[str, Any]]:
    """Return up to *limit* status events for a topology, newest first."""
    query = (
        select(TopologyStatusEvent)
        .where(TopologyStatusEvent.topology_id == topology_id)
        .order_by(desc(TopologyStatusEvent.at))
        .limit(limit)
    )
    async with self._session() as db:
        events = (await db.execute(query)).scalars().all()
        return [event.model_dump(mode="json") for event in events]
|
||
|
||
# ---------------- topology_mutations (live reconciler queue) ----------------
|
||
|
||
async def enqueue_topology_mutation(
    self,
    topology_id: str,
    op: str,
    payload: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
) -> str:
    """Queue a mutation for a topology and return the new row's id.

    *payload* is stored as orjson-encoded text. When *expected_version*
    is given, the topology version is checked and bumped in the same
    transaction. Per the original note this is meant for topologies in
    ``active|degraded``, with the reconciler consuming the rows on its
    next tick; that status is not verified here.
    """
    async with self._session() as db:
        await self._check_and_bump_version(db, topology_id, expected_version)
        encoded_payload = orjson.dumps(payload).decode()
        mutation = TopologyMutation(
            topology_id=topology_id,
            op=op,
            payload=encoded_payload,
        )
        db.add(mutation)
        await db.commit()
        await db.refresh(mutation)
        return mutation.id
|
||
|
||
async def claim_next_mutation(
    self, topology_id: str
) -> Optional[dict[str, Any]]:
    """Atomically claim the oldest pending mutation for ``topology_id``.

    The claim is a *guarded* UPDATE: the candidate id is looked up
    first, then the row moves to ``applying`` only if it is still
    ``pending`` (``WHERE id = :id AND state = 'pending'``). When two
    watch-loops race for the same candidate, exactly one UPDATE matches
    a row; the loser matches zero rows and gets ``None``, which is the
    expected non-error outcome under contention.

    The claimed row is then re-read *by id*. The earlier implementation
    re-read "the oldest ``applying`` row for this topology", which
    returns a different mutation whenever an older one is still
    ``applying`` (a concurrent claimer, or a stale row), so the same op
    could run twice. Addressing by id also avoids a self-referencing
    subquery on the UPDATE target, which MySQL rejects (ERROR 1093).

    Returns:
        The claimed mutation as a dict, or ``None`` if nothing is
        pending or the claim was lost to a racer.
    """
    async with self._session() as session:
        candidate = await session.execute(
            select(TopologyMutation.id)
            .where(TopologyMutation.topology_id == topology_id)
            .where(TopologyMutation.state == "pending")
            .order_by(asc(TopologyMutation.requested_at))
            .limit(1)
        )
        mutation_id = candidate.scalar_one_or_none()
        if mutation_id is None:
            await session.commit()
            return None

        # The state re-check is what makes the claim exclusive: only
        # one racer can turn 'pending' into 'applying' for this id.
        claimed = await session.execute(
            update(TopologyMutation)
            .where(TopologyMutation.id == mutation_id)
            .where(TopologyMutation.state == "pending")
            .values(state="applying")
        )
        if claimed.rowcount == 0:
            await session.commit()
            return None

        # Read back exactly the row we claimed, never "some applying row".
        sel = await session.execute(
            select(TopologyMutation).where(TopologyMutation.id == mutation_id)
        )
        row = sel.scalar_one_or_none()
        await session.commit()
        if row is None:
            return None
        return row.model_dump(mode="json")
|
||
|
||
async def mark_mutation_applied(self, mutation_id: str) -> None:
    """Mark a mutation ``applied`` and stamp ``applied_at`` with now (UTC).

    The timestamp is bound as an ISO-8601 string. No check is made that
    a row with *mutation_id* exists.
    """
    statement = text(
        "UPDATE topology_mutations "
        "SET state = 'applied', applied_at = :at "
        "WHERE id = :i"
    )
    async with self._session() as db:
        bind = {
            "at": datetime.now(timezone.utc).isoformat(),
            "i": mutation_id,
        }
        await db.execute(statement, bind)
        await db.commit()
|
||
|
||
async def mark_mutation_failed(
    self, mutation_id: str, reason: str
) -> None:
    """Mark a mutation ``failed``, recording *reason* and the time (UTC).

    ``applied_at`` is set to the current time as an ISO-8601 string. No
    check is made that a row with *mutation_id* exists.
    """
    statement = text(
        "UPDATE topology_mutations "
        "SET state = 'failed', applied_at = :at, reason = :r "
        "WHERE id = :i"
    )
    async with self._session() as db:
        bind = {
            "at": datetime.now(timezone.utc).isoformat(),
            "r": reason,
            "i": mutation_id,
        }
        await db.execute(statement, bind)
        await db.commit()
|
||
|
||
async def list_topology_mutations(
    self,
    topology_id: str,
    state: Optional[str] = None,
) -> list[dict[str, Any]]:
    """Return a topology's mutations as dicts, most recently requested first.

    When *state* is not ``None`` only rows in that state are returned.
    """
    async with self._session() as db:
        query = select(TopologyMutation).where(
            TopologyMutation.topology_id == topology_id
        )
        query = query.order_by(desc(TopologyMutation.requested_at))
        if state is not None:
            query = query.where(TopologyMutation.state == state)
        mutations = (await db.execute(query)).scalars().all()
        return [mutation.model_dump(mode="json") for mutation in mutations]
|
||
|
||
async def has_pending_topology_mutation(self) -> bool:
    """Report whether any live topology has a pending mutation.

    A topology counts as live when its status is ``active`` or
    ``degraded``. The query stops at the first match (``LIMIT 1``), so
    it can serve as a cheap watch-loop guard; ``False`` means the
    reconciler path can be skipped. The original note says the
    ``ix_topology_mutations_state_topology`` composite index keeps this
    cheap; the index itself is defined elsewhere.
    """
    probe = text(
        "SELECT 1 FROM topology_mutations "
        "WHERE state = 'pending' "
        "AND topology_id IN ("
        " SELECT id FROM topologies "
        " WHERE status IN ('active', 'degraded')"
        ") LIMIT 1"
    )
    async with self._session() as db:
        hit = (await db.execute(probe)).first()
        return hit is not None
|
||
|
||
async def list_live_topology_ids(self) -> list[str]:
    """Return the ids of all topologies in ``active`` or ``degraded``."""
    live_states = ["active", "degraded"]
    query = select(Topology.id).where(Topology.status.in_(live_states))
    async with self._session() as db:
        ids = (await db.execute(query)).scalars().all()
        return list(ids)
|
||
|
||
# --------------------------------------------------------- webhooks
|
||
|
||
async def create_webhook_subscription(self, data: dict[str, Any]) -> None:
    """Insert a ``WebhookSubscription`` row built from *data*."""
    async with self._session() as db:
        subscription = WebhookSubscription(**data)
        db.add(subscription)
        await db.commit()
|
||
|
||
async def get_webhook_subscription(
    self, uuid: str
) -> Optional[dict[str, Any]]:
    """Fetch a webhook subscription by uuid; ``None`` if it does not exist.

    The row is dumped with plain ``model_dump()`` (not JSON mode).
    """
    query = select(WebhookSubscription).where(WebhookSubscription.uuid == uuid)
    async with self._session() as db:
        subscription = (await db.execute(query)).scalar_one_or_none()
        if not subscription:
            return None
        return subscription.model_dump()
|
||
|
||
async def get_webhook_subscription_by_name(
    self, name: str
) -> Optional[dict[str, Any]]:
    """Fetch a webhook subscription by name; ``None`` if it does not exist.

    The row is dumped with plain ``model_dump()`` (not JSON mode).
    """
    query = select(WebhookSubscription).where(WebhookSubscription.name == name)
    async with self._session() as db:
        subscription = (await db.execute(query)).scalar_one_or_none()
        if not subscription:
            return None
        return subscription.model_dump()
|
||
|
||
async def list_webhook_subscriptions(
    self, enabled_only: bool = False
) -> list[dict[str, Any]]:
    """List webhook subscriptions ordered by creation time.

    With *enabled_only* set, only rows whose ``enabled`` flag is true
    are returned.
    """
    async with self._session() as db:
        query = select(WebhookSubscription)
        if enabled_only:
            query = query.where(WebhookSubscription.enabled.is_(True))
        query = query.order_by(WebhookSubscription.created_at)
        subscriptions = (await db.execute(query)).scalars().all()
        return [subscription.model_dump() for subscription in subscriptions]
|
||
|
||
async def update_webhook_subscription(
    self, uuid: str, patch: dict[str, Any]
) -> bool:
    """Apply *patch* to a webhook subscription.

    ``updated_at`` is always overwritten with the current UTC time, even
    if *patch* supplies its own value.

    Returns:
        ``True`` when at least one row was updated. An empty *patch*
        also returns ``True``, without touching the database.
    """
    if not patch:
        return True

    changes = dict(patch)
    changes["updated_at"] = datetime.now(timezone.utc)
    async with self._session() as db:
        outcome = await db.execute(
            update(WebhookSubscription)
            .where(WebhookSubscription.uuid == uuid)
            .values(**changes)
        )
        await db.commit()
        return outcome.rowcount > 0
|
||
|
||
async def delete_webhook_subscription(self, uuid: str) -> bool:
    """Delete a webhook subscription; return whether a row was removed."""
    query = select(WebhookSubscription).where(WebhookSubscription.uuid == uuid)
    async with self._session() as db:
        subscription = (await db.execute(query)).scalar_one_or_none()
        if subscription:
            await db.delete(subscription)
            await db.commit()
            return True
        return False
|
||
|
||
async def record_webhook_success(
    self, uuid: str, ts: datetime
) -> None:
    """Record a successful delivery for a subscription.

    Resets ``consecutive_failures`` to 0, clears ``last_error`` and
    stamps both ``last_success_at`` and ``updated_at`` with *ts*.
    """
    reset = dict(
        consecutive_failures=0,
        last_success_at=ts,
        last_error=None,
        updated_at=ts,
    )
    async with self._session() as db:
        await db.execute(
            update(WebhookSubscription)
            .where(WebhookSubscription.uuid == uuid)
            .values(**reset)
        )
        await db.commit()
|
||
|
||
async def record_webhook_failure(
    self, uuid: str, ts: datetime, error: str
) -> int:
    """Record a failed delivery and return the new consecutive-failure count.

    The counter is read, incremented in Python and written back, so
    concurrent failures for one subscription can lose an increment. The
    original author accepts that: the count feeds a circuit-breaker
    heuristic, not a correctness invariant. *error* is truncated to 512
    characters; an empty error is stored as NULL.
    """
    async with self._session() as db:
        lookup = await db.execute(
            select(WebhookSubscription.consecutive_failures).where(
                WebhookSubscription.uuid == uuid
            )
        )
        # A missing row or a NULL counter both start from zero.
        previous = lookup.scalar_one_or_none() or 0
        failures = previous + 1

        stored_error = error[:512] if error else None
        await db.execute(
            update(WebhookSubscription)
            .where(WebhookSubscription.uuid == uuid)
            .values(
                consecutive_failures=failures,
                last_failure_at=ts,
                last_error=stored_error,
                updated_at=ts,
            )
        )
        await db.commit()
        return failures
|
||
|
||
async def trip_webhook_circuit(self, uuid: str, ts: datetime) -> None:
    """Disable a subscription and stamp ``auto_disabled_at`` with *ts*.

    ``updated_at`` is set to *ts* as well.
    """
    tripped = dict(enabled=False, auto_disabled_at=ts, updated_at=ts)
    async with self._session() as db:
        await db.execute(
            update(WebhookSubscription)
            .where(WebhookSubscription.uuid == uuid)
            .values(**tripped)
        )
        await db.commit()
|
||
|
||
# ---------------------------------------------------------- canary
|
||
|
||
async def upsert_canary_blob(self, data: dict[str, Any]) -> dict[str, Any]:
    """Return the blob row for ``data["sha256"]``, inserting it if absent.

    An existing row with the same sha256 is returned unchanged; *data*
    is only used when a new row has to be created.

    Raises:
        ValueError: *data* has no (or an empty) ``sha256``.
    """
    digest = data.get("sha256")
    if not digest:
        raise ValueError("upsert_canary_blob: sha256 is required")

    async with self._session() as db:
        lookup = await db.execute(
            select(CanaryBlob).where(CanaryBlob.sha256 == digest)
        )
        blob = lookup.scalar_one_or_none()
        if not blob:
            blob = CanaryBlob(**data)
            db.add(blob)
            await db.commit()
            await db.refresh(blob)
        return blob.model_dump(mode="json")
|
||
|
||
async def get_canary_blob(self, uuid: str) -> Optional[dict[str, Any]]:
    """Fetch a canary blob by uuid; ``None`` if it does not exist."""
    query = select(CanaryBlob).where(CanaryBlob.uuid == uuid)
    async with self._session() as db:
        blob = (await db.execute(query)).scalar_one_or_none()
        if not blob:
            return None
        return blob.model_dump(mode="json")
|
||
|
||
async def get_canary_blob_by_sha256(
    self, sha256: str
) -> Optional[dict[str, Any]]:
    """Fetch a canary blob by content hash; ``None`` if it does not exist."""
    query = select(CanaryBlob).where(CanaryBlob.sha256 == sha256)
    async with self._session() as db:
        blob = (await db.execute(query)).scalar_one_or_none()
        if not blob:
            return None
        return blob.model_dump(mode="json")
|
||
|
||
async def list_canary_blobs(self) -> list[dict[str, Any]]:
    """List canary blobs, newest upload first, each with a ``token_count``.

    ``token_count`` is the number of ``CanaryToken`` rows referencing
    the blob. The query has no state filter, so revoked tokens are
    counted too. Blobs with no tokens report 0.
    """
    # One round-trip: outer-join blobs to tokens and count per blob.
    query = (
        select(CanaryBlob, func.count(CanaryToken.uuid))
        .join(
            CanaryToken,
            CanaryToken.blob_uuid == CanaryBlob.uuid,
            isouter=True,
        )
        .group_by(CanaryBlob.uuid)
        .order_by(desc(CanaryBlob.uploaded_at))
    )
    async with self._session() as db:
        pairs = (await db.execute(query)).all()
        listing: list[dict[str, Any]] = []
        for blob, references in pairs:
            entry = blob.model_dump(mode="json")
            entry["token_count"] = int(references or 0)
            listing.append(entry)
        return listing
|
||
|
||
async def delete_canary_blob(self, uuid: str) -> bool:
    """Delete a canary blob that no token references.

    Returns:
        ``True`` when the blob existed and was deleted; ``False`` when
        at least one ``CanaryToken`` still points at it or when no blob
        has that uuid.
    """
    async with self._session() as db:
        references = await db.execute(
            select(func.count(CanaryToken.uuid)).where(
                CanaryToken.blob_uuid == uuid
            )
        )
        in_use = (references.scalar_one() or 0) > 0
        if in_use:
            return False

        lookup = await db.execute(
            select(CanaryBlob).where(CanaryBlob.uuid == uuid)
        )
        blob = lookup.scalar_one_or_none()
        if blob:
            await db.delete(blob)
            await db.commit()
            return True
        return False
|
||
|
||
async def create_canary_token(self, data: dict[str, Any]) -> None:
    """Insert a ``CanaryToken`` row built from *data*."""
    async with self._session() as db:
        token = CanaryToken(**data)
        db.add(token)
        await db.commit()
|
||
|
||
async def get_canary_token(self, uuid: str) -> Optional[dict[str, Any]]:
    """Fetch a canary token by uuid; ``None`` if it does not exist."""
    query = select(CanaryToken).where(CanaryToken.uuid == uuid)
    async with self._session() as db:
        token = (await db.execute(query)).scalar_one_or_none()
        if not token:
            return None
        return token.model_dump(mode="json")
|
||
|
||
async def get_canary_token_by_slug(
    self, callback_token: str
) -> Optional[dict[str, Any]]:
    """Fetch a canary token by its ``callback_token``; ``None`` if unknown."""
    query = select(CanaryToken).where(
        CanaryToken.callback_token == callback_token
    )
    async with self._session() as db:
        token = (await db.execute(query)).scalar_one_or_none()
        if not token:
            return None
        return token.model_dump(mode="json")
|
||
|
||
async def list_canary_tokens(
    self,
    *,
    decky_name: Optional[str] = None,
    state: Optional[str] = None,
    kind: Optional[str] = None,
) -> list[dict[str, Any]]:
    """List canary tokens, most recently placed first.

    Each keyword filter is applied only when it is not ``None``; the
    filters are combined with AND.
    """
    requested = (
        (CanaryToken.decky_name, decky_name),
        (CanaryToken.state, state),
        (CanaryToken.kind, kind),
    )
    async with self._session() as db:
        query = select(CanaryToken)
        for column, wanted in requested:
            if wanted is not None:
                query = query.where(column == wanted)
        query = query.order_by(desc(CanaryToken.placed_at))
        tokens = (await db.execute(query)).scalars().all()
        return [token.model_dump(mode="json") for token in tokens]
|
||
|
||
async def update_canary_token_state(
    self,
    uuid: str,
    state: str,
    last_error: Optional[str] = None,
) -> bool:
    """Set a token's ``state`` and ``last_error``.

    ``last_error`` is always written, so leaving it at ``None`` clears
    any previous error.

    Returns:
        ``True`` when a row was updated.
    """
    change = (
        update(CanaryToken)
        .where(CanaryToken.uuid == uuid)
        .values(state=state, last_error=last_error)
    )
    async with self._session() as db:
        outcome = await db.execute(change)
        await db.commit()
        return outcome.rowcount > 0
|
||
|
||
async def record_canary_trigger(self, data: dict[str, Any]) -> str:
    """Store a canary trigger, bump its token's counters, return its uuid.

    The trigger insert and the token update share one session and one
    commit, so a reader who looks at the token after the commit sees
    the incremented ``trigger_count``. A dict ``raw_headers`` is
    JSON-encoded without mutating the caller's *data*.
    ``last_triggered_at`` uses ``data["occurred_at"]`` when it is set
    and truthy, otherwise the current UTC time.
    """
    raw_headers = data.get("raw_headers")
    if isinstance(raw_headers, dict):
        data = {**data, "raw_headers": json.dumps(raw_headers)}

    async with self._session() as db:
        trigger = CanaryTrigger(**data)
        db.add(trigger)

        fired_at = data.get("occurred_at") or datetime.now(timezone.utc)
        # Increment is computed in SQL rather than read-modify-write.
        bump = (
            update(CanaryToken)
            .where(CanaryToken.uuid == trigger.token_uuid)
            .values(
                last_triggered_at=fired_at,
                trigger_count=CanaryToken.trigger_count + 1,
            )
        )
        await db.execute(bump)
        await db.commit()
        await db.refresh(trigger)
        return trigger.uuid
|
||
|
||
async def list_canary_triggers(
    self, token_uuid: str, *, limit: int = 100, offset: int = 0,
) -> list[dict[str, Any]]:
    """Return one page of a token's triggers, most recent first."""
    async with self._session() as db:
        query = select(CanaryTrigger).where(
            CanaryTrigger.token_uuid == token_uuid
        )
        query = query.order_by(desc(CanaryTrigger.occurred_at))
        query = query.limit(limit).offset(offset)
        triggers = (await db.execute(query)).scalars().all()
        return [trigger.model_dump(mode="json") for trigger in triggers]
|
||
|
||
async def attribute_canary_trigger(
    self, trigger_uuid: str, attacker_id: str,
) -> bool:
    """Set ``attacker_id`` on a trigger; return whether a row was updated."""
    change = (
        update(CanaryTrigger)
        .where(CanaryTrigger.uuid == trigger_uuid)
        .values(attacker_id=attacker_id)
    )
    async with self._session() as db:
        outcome = await db.execute(change)
        await db.commit()
        return outcome.rowcount > 0
|
||
|
||
# ---------------------------------------------------------- orchestrator
|
||
|
||
async def list_running_topology_deckies(self) -> list[dict[str, Any]]:
    """Return all topology deckies whose state is ``running``, as dicts.

    ``services`` and ``decky_config`` are decoded from JSON text.
    """
    query = select(TopologyDecky).where(TopologyDecky.state == "running")
    async with self._session() as db:
        deckies = (await db.execute(query)).scalars().all()
        running: list[dict[str, Any]] = []
        for decky in deckies:
            running.append(
                self._deserialize_json_fields(
                    decky.model_dump(mode="json"), ("services", "decky_config")
                )
            )
        return running
|
||
|
||
async def record_orchestrator_event(self, data: dict[str, Any]) -> str:
    """Insert one ``OrchestratorEvent`` row and return its ``uuid``.

    A ``payload`` supplied as a dict or list is JSON-encoded into a copy
    of *data* before the row is built; the caller's mapping is not
    mutated.
    """
    raw_payload = data.get("payload")
    fields = data
    if isinstance(raw_payload, (dict, list)):
        fields = {**data, "payload": json.dumps(raw_payload)}
    async with self._session() as session:
        event = OrchestratorEvent(**fields)
        session.add(event)
        await session.commit()
        # Reload the row after commit before reading its uuid.
        await session.refresh(event)
        return event.uuid
|
||
|
||
async def list_orchestrator_events(
    self,
    limit: int = 100,
    offset: int = 0,
    *,
    kind: Optional[str] = None,
    since_ts: Optional[datetime] = None,
) -> list[dict[str, Any]]:
    """Return one page of orchestrator events, newest ``ts`` first.

    Args:
        limit: Maximum number of rows to return.
        offset: Number of leading rows to skip.
        kind: When given, keep only rows whose ``kind`` equals it.
        since_ts: When given, keep only rows with ``ts >= since_ts``.

    Returns:
        JSON-mode ``model_dump`` dicts for the matching rows.
    """
    filters = []
    if kind is not None:
        filters.append(OrchestratorEvent.kind == kind)
    if since_ts is not None:
        filters.append(OrchestratorEvent.ts >= since_ts)

    query = select(OrchestratorEvent)
    if filters:
        # Multiple criteria passed to one where() are AND-ed together.
        query = query.where(*filters)
    query = query.order_by(desc(OrchestratorEvent.ts)).offset(offset).limit(limit)

    async with self._session() as session:
        events = (await session.execute(query)).scalars().all()
        return [event.model_dump(mode="json") for event in events]
|
||
|
||
async def count_orchestrator_events(
    self, *, kind: Optional[str] = None,
) -> int:
    """Count orchestrator events, optionally restricted to one *kind*.

    Returns:
        The row count, or ``0`` when the query yields a falsy scalar.
    """
    query = select(func.count()).select_from(OrchestratorEvent)
    if kind is not None:
        query = query.where(OrchestratorEvent.kind == kind)
    async with self._session() as session:
        total = (await session.execute(query)).scalar()
        return total or 0
|
||
|
||
async def prune_orchestrator_events(self, per_dst_cap: int = 10000) -> int:
    """Trim per-dst rows to *per_dst_cap*, oldest-first. Returns deleted count.

    For every distinct ``dst_decky_uuid`` the newest *per_dst_cap* rows
    (by ``ts`` descending) are kept and all other rows for that
    destination are deleted. All deletes are committed together at the
    end.

    Args:
        per_dst_cap: Number of newest rows to keep per destination. A
            cap that selects no rows (e.g. ``0``) leaves that
            destination untouched.

    Returns:
        Sum of the ``rowcount`` values reported by the DELETE statements.
    """
    # Imported once up front instead of on every loop iteration.
    from sqlalchemy import delete as _delete

    deleted = 0
    async with self._session() as session:
        dst_rows = await session.execute(
            select(OrchestratorEvent.dst_decky_uuid).distinct()
        )
        for (dst,) in dst_rows.all():
            keep = await session.execute(
                select(OrchestratorEvent.uuid)
                .where(OrchestratorEvent.dst_decky_uuid == dst)
                .order_by(desc(OrchestratorEvent.ts))
                .limit(per_dst_cap)
            )
            keep_uuids = [u for (u,) in keep.all()]
            if not keep_uuids:
                continue
            # Fewer rows than the cap means every row of this destination
            # is already in the keep-set, so there is nothing to trim.
            # Skipping avoids a DELETE with a large NOT IN parameter list.
            if len(keep_uuids) < per_dst_cap:
                continue
            stmt = _delete(OrchestratorEvent).where(
                OrchestratorEvent.dst_decky_uuid == dst,
                OrchestratorEvent.uuid.notin_(keep_uuids),
            )
            res = await session.execute(stmt)
            deleted += res.rowcount or 0
        await session.commit()
    return deleted
|
||
|
||
# ---------------------------------------------------------- emailgen
|
||
|
||
async def record_orchestrator_email(self, data: dict[str, Any]) -> str:
    """Insert one ``OrchestratorEmail`` row and return its ``uuid``.

    A ``payload`` supplied as a dict or list is JSON-encoded into a copy
    of *data* before the row is built; the caller's mapping is not
    mutated.
    """
    raw_payload = data.get("payload")
    fields = data
    if isinstance(raw_payload, (dict, list)):
        fields = {**data, "payload": json.dumps(raw_payload)}
    async with self._session() as session:
        email = OrchestratorEmail(**fields)
        session.add(email)
        await session.commit()
        # Reload the row after commit before reading its uuid.
        await session.refresh(email)
        return email.uuid
|
||
|
||
async def list_orchestrator_emails(
    self,
    limit: int = 100,
    offset: int = 0,
    *,
    mail_decky_uuid: Optional[str] = None,
    thread_id: Optional[str] = None,
    since_ts: Optional[datetime] = None,
) -> list[dict[str, Any]]:
    """Return one page of orchestrator emails, newest ``ts`` first.

    Args:
        limit: Maximum number of rows to return.
        offset: Number of leading rows to skip.
        mail_decky_uuid: When given, keep only rows for that mail decky.
        thread_id: When given, keep only rows with that ``thread_id``.
        since_ts: When given, keep only rows with ``ts >= since_ts``.

    Returns:
        JSON-mode ``model_dump`` dicts for the matching rows.
    """
    filters = []
    if mail_decky_uuid is not None:
        filters.append(OrchestratorEmail.mail_decky_uuid == mail_decky_uuid)
    if thread_id is not None:
        filters.append(OrchestratorEmail.thread_id == thread_id)
    if since_ts is not None:
        filters.append(OrchestratorEmail.ts >= since_ts)

    query = select(OrchestratorEmail)
    if filters:
        # Multiple criteria passed to one where() are AND-ed together.
        query = query.where(*filters)
    query = query.order_by(desc(OrchestratorEmail.ts)).offset(offset).limit(limit)

    async with self._session() as session:
        emails = (await session.execute(query)).scalars().all()
        return [email.model_dump(mode="json") for email in emails]
|
||
|
||
async def count_orchestrator_emails(
    self,
    *,
    mail_decky_uuid: Optional[str] = None,
) -> int:
    """Count orchestrator emails, optionally for a single mail decky.

    Returns:
        The row count, or ``0`` when the query yields a falsy scalar.
    """
    query = select(func.count()).select_from(OrchestratorEmail)
    if mail_decky_uuid is not None:
        query = query.where(OrchestratorEmail.mail_decky_uuid == mail_decky_uuid)
    async with self._session() as session:
        total = (await session.execute(query)).scalar()
        return total or 0
|
||
|
||
async def list_orchestrator_email_threads(
    self,
    mail_decky_uuid: str,
    sender_email: str,
    recipient_email: str,
    *,
    limit: int = 50,
) -> list[dict[str, Any]]:
    """Return recent successful emails exchanged between two addresses.

    Rows are matched in either direction (sender -> recipient or
    recipient -> sender) under the given mail decky, restricted to
    ``success`` being true, ordered by ``ts`` descending and capped at
    *limit* rows.

    Returns:
        JSON-mode ``model_dump`` dicts, newest first.
    """
    # The scheduler only needs the latest message_id/subject/thread_id to
    # construct a reply, which is why the newest rows come first.
    outbound = (OrchestratorEmail.sender_email == sender_email) & (
        OrchestratorEmail.recipient_email == recipient_email
    )
    inbound = (OrchestratorEmail.sender_email == recipient_email) & (
        OrchestratorEmail.recipient_email == sender_email
    )
    query = (
        select(OrchestratorEmail)
        .where(
            OrchestratorEmail.mail_decky_uuid == mail_decky_uuid,
            or_(outbound, inbound),
            OrchestratorEmail.success.is_(True),
        )
        .order_by(desc(OrchestratorEmail.ts))
        .limit(limit)
    )
    async with self._session() as session:
        emails = (await session.execute(query)).scalars().all()
        return [email.model_dump(mode="json") for email in emails]
|
||
|
||
async def prune_orchestrator_emails(self, per_decky_cap: int = 10000) -> int:
    """Trim per-mail-decky rows to *per_decky_cap*, oldest-first.

    For every distinct ``mail_decky_uuid`` the newest *per_decky_cap*
    rows (by ``ts`` descending) are kept and all other rows for that
    mail decky are deleted. All deletes are committed together at the
    end.

    Args:
        per_decky_cap: Number of newest rows to keep per mail decky. A
            cap that selects no rows (e.g. ``0``) leaves that mail decky
            untouched.

    Returns:
        Sum of the ``rowcount`` values reported by the DELETE statements.
    """
    # Imported once up front instead of on every loop iteration.
    from sqlalchemy import delete as _delete

    deleted = 0
    async with self._session() as session:
        decky_rows = await session.execute(
            select(OrchestratorEmail.mail_decky_uuid).distinct()
        )
        for (mail_uuid,) in decky_rows.all():
            keep = await session.execute(
                select(OrchestratorEmail.uuid)
                .where(OrchestratorEmail.mail_decky_uuid == mail_uuid)
                .order_by(desc(OrchestratorEmail.ts))
                .limit(per_decky_cap)
            )
            keep_uuids = [u for (u,) in keep.all()]
            if not keep_uuids:
                continue
            # Fewer rows than the cap means every row of this mail decky
            # is already in the keep-set, so there is nothing to trim.
            # Skipping avoids a DELETE with a large NOT IN parameter list.
            if len(keep_uuids) < per_decky_cap:
                continue
            stmt = _delete(OrchestratorEmail).where(
                OrchestratorEmail.mail_decky_uuid == mail_uuid,
                OrchestratorEmail.uuid.notin_(keep_uuids),
            )
            res = await session.execute(stmt)
            deleted += res.rowcount or 0
        await session.commit()
    return deleted
|
||
|
||
# ------------------------------------------------------------ realism
|
||
|
||
async def record_synthetic_file(self, data: dict[str, Any]) -> str:
    """Insert one ``SyntheticFile`` row and return its ``uuid``.

    A non-``None`` ``last_body`` is sliced to its first
    ``SYNTHETIC_FILE_BODY_LIMIT`` items in a copy of *data*; the
    caller's mapping is not mutated.
    """
    from decnet.web.db.models.realism import SYNTHETIC_FILE_BODY_LIMIT

    fields = data
    if data.get("last_body") is not None:
        clipped = data["last_body"][:SYNTHETIC_FILE_BODY_LIMIT]
        fields = {**data, "last_body": clipped}
    async with self._session() as session:
        synthetic = SyntheticFile(**fields)
        session.add(synthetic)
        await session.commit()
        # Reload the row after commit before reading its uuid.
        await session.refresh(synthetic)
        return synthetic.uuid
|
||
|
||
async def update_synthetic_file(
    self, row_uuid: str, data: dict[str, Any],
) -> None:
    """Apply the column values in *data* to one ``SyntheticFile`` row.

    A non-``None`` ``last_body`` is sliced to its first
    ``SYNTHETIC_FILE_BODY_LIMIT`` items in a copy of *data* before the
    UPDATE is issued; the caller's mapping is not mutated.

    Args:
        row_uuid: Value matched against ``SyntheticFile.uuid``.
        data: Column-name -> new-value mapping passed to ``.values()``.
    """
    from decnet.web.db.models.realism import SYNTHETIC_FILE_BODY_LIMIT

    changes = data
    if data.get("last_body") is not None:
        clipped = data["last_body"][:SYNTHETIC_FILE_BODY_LIMIT]
        changes = {**data, "last_body": clipped}
    async with self._session() as session:
        await session.execute(
            update(SyntheticFile)
            .where(SyntheticFile.uuid == row_uuid)
            .values(**changes)
        )
        await session.commit()
|
||
|
||
async def list_synthetic_files(
    self,
    *,
    decky_uuid: Optional[str] = None,
    persona: Optional[str] = None,
    content_class: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
) -> list[dict[str, Any]]:
    """Return one page of synthetic files, newest ``last_modified`` first.

    Args:
        decky_uuid: When given, keep only rows for that decky.
        persona: When given, keep only rows with that persona.
        content_class: When given, keep only rows of that content class.
        limit: Maximum number of rows to return.
        offset: Number of leading rows to skip.

    Returns:
        JSON-mode ``model_dump`` dicts for the matching rows.
    """
    # (column, requested value) pairs; a None value means "no filter".
    optional_filters = (
        (SyntheticFile.decky_uuid, decky_uuid),
        (SyntheticFile.persona, persona),
        (SyntheticFile.content_class, content_class),
    )
    query = select(SyntheticFile)
    for column, wanted in optional_filters:
        if wanted is not None:
            query = query.where(column == wanted)
    query = query.order_by(desc(SyntheticFile.last_modified))
    query = query.offset(offset).limit(limit)

    async with self._session() as session:
        files = (await session.execute(query)).scalars().all()
        return [entry.model_dump(mode="json") for entry in files]
|
||
|
||
async def count_synthetic_files(
    self,
    *,
    decky_uuid: Optional[str] = None,
    persona: Optional[str] = None,
    content_class: Optional[str] = None,
) -> int:
    """Count synthetic files matching the optional equality filters.

    Each argument that is not ``None`` restricts the count to rows whose
    column of the same name equals it.

    Returns:
        The count as an ``int``; a falsy scalar result yields ``0``.
    """
    from sqlalchemy import func as _f

    # (column, requested value) pairs; a None value means "no filter".
    optional_filters = (
        (SyntheticFile.decky_uuid, decky_uuid),
        (SyntheticFile.persona, persona),
        (SyntheticFile.content_class, content_class),
    )
    query = select(_f.count(SyntheticFile.uuid))
    for column, wanted in optional_filters:
        if wanted is not None:
            query = query.where(column == wanted)

    async with self._session() as session:
        total = (await session.execute(query)).scalar()
        return int(total or 0)
|
||
|
||
async def get_synthetic_file(
    self, uuid: str,
) -> Optional[dict[str, Any]]:
    """Fetch one ``SyntheticFile`` row by its ``uuid``.

    Returns:
        The row as a JSON-mode ``model_dump`` dict, or ``None`` when no
        row matches.
    """
    query = select(SyntheticFile).where(SyntheticFile.uuid == uuid)
    async with self._session() as session:
        found = (await session.execute(query)).scalars().first()
        return None if found is None else found.model_dump(mode="json")
|
||
|
||
async def get_realism_config(
    self, key: str,
) -> Optional[dict[str, Any]]:
    """Fetch the ``RealismConfig`` row stored under *key*.

    Returns:
        The row as a JSON-mode ``model_dump`` dict, or ``None`` when no
        row has that key.
    """
    query = select(RealismConfig).where(RealismConfig.key == key)
    async with self._session() as session:
        found = (await session.execute(query)).scalars().first()
        return None if found is None else found.model_dump(mode="json")
|
||
|
||
async def set_realism_config(
    self, key: str, value: str,
) -> None:
    """Upsert one realism_config row. Last-write-wins.

    Looks the row up by *key*; an existing row gets its ``value`` and
    ``updated_at`` overwritten, otherwise a new row is added. The change
    is committed before returning.
    """
    lookup = select(RealismConfig).where(RealismConfig.key == key)
    async with self._session() as session:
        existing = (await session.execute(lookup)).scalars().first()
        if existing is not None:
            # Row already present: overwrite in place, addressed by uuid.
            await session.execute(
                update(RealismConfig)
                .where(RealismConfig.uuid == existing.uuid)
                .values(value=value, updated_at=datetime.now(timezone.utc))
            )
        else:
            fresh = RealismConfig(
                key=key,
                value=value,
                updated_at=datetime.now(timezone.utc),
            )
            session.add(fresh)
        await session.commit()
|
||
|
||
async def pick_random_synthetic_file_for_edit(
    self,
    decky_uuid: str,
    *,
    max_age_days: int = 30,
) -> Optional[dict[str, Any]]:
    """Pick one random, recently modified, editable file for a decky.

    Candidates belong to *decky_uuid*, have a ``content_class`` in the
    editable set below, and a ``last_modified`` no older than
    *max_age_days* days before the current UTC time.

    Returns:
        The chosen row as a JSON-mode ``model_dump`` dict, or ``None``
        when there is no candidate.
    """
    from datetime import timedelta

    # Editable classes: anything whose body is plain text we can mutate
    # idempotently. Binary canary artifacts are excluded; they rotate
    # via a fresh plant, not an edit.
    editable_classes = (
        "note", "todo", "draft", "script", "log_cron", "log_daemon",
    )
    oldest_allowed = datetime.now(timezone.utc) - timedelta(days=max_age_days)
    candidates = select(SyntheticFile).where(
        SyntheticFile.decky_uuid == decky_uuid,
        SyntheticFile.content_class.in_(editable_classes),  # type: ignore[attr-defined]
        SyntheticFile.last_modified >= oldest_allowed,
    )
    # SQLite + MySQL both support func.random() / RAND();
    # SQLAlchemy's func.random() compiles per-dialect.
    query = candidates.order_by(func.random()).limit(1)
    async with self._session() as session:
        picked = (await session.execute(query)).scalars().first()
        if picked is None:
            return None
        return picked.model_dump(mode="json")
|