diff --git a/decnet/web/db/sqlmodel_repo/__init__.py b/decnet/web/db/sqlmodel_repo/__init__.py index 162fb6f1..beac5e96 100644 --- a/decnet/web/db/sqlmodel_repo/__init__.py +++ b/decnet/web/db/sqlmodel_repo/__init__.py @@ -40,9 +40,6 @@ from decnet.web.db.models import ( Campaign, SessionProfile, SmtpTarget, - DeckyShard, - FleetDecky, - LOCAL_HOST_SENTINEL, Topology, LAN, TopologyDecky, @@ -63,11 +60,14 @@ from decnet.web.db.sqlmodel_repo._helpers import ( # noqa: F401 (re-exported f _safe_session, _detach_close, _cleanup_tasks, + _serialize_json_fields, + _deserialize_json_fields, ) from decnet.web.db.sqlmodel_repo.attacker_intel import AttackerIntelMixin from decnet.web.db.sqlmodel_repo.auth import AuthMixin from decnet.web.db.sqlmodel_repo.bounties import BountiesMixin from decnet.web.db.sqlmodel_repo.deckies import DeckiesMixin +from decnet.web.db.sqlmodel_repo.fleet import FleetMixin from decnet.web.db.sqlmodel_repo.swarm import SwarmMixin from decnet.web.db.sqlmodel_repo.webhooks import WebhooksMixin @@ -77,6 +77,7 @@ class SQLModelRepository( AuthMixin, BountiesMixin, DeckiesMixin, + FleetMixin, SwarmMixin, WebhooksMixin, BaseRepository, @@ -1597,165 +1598,10 @@ class SQLModelRepository( ) return [r.model_dump(mode="json") for r in rows.scalars().all()] - # ------------------------------------------------------------- swarm - - # -------------------------------------------------------------- fleet - - async def upsert_fleet_decky(self, data: dict[str, Any]) -> None: - payload: dict[str, Any] = { - **data, - "updated_at": datetime.now(timezone.utc), - } - payload.setdefault("host_uuid", LOCAL_HOST_SENTINEL) - if payload.get("host_uuid") is None: - payload["host_uuid"] = LOCAL_HOST_SENTINEL - if isinstance(payload.get("services"), list): - payload["services"] = orjson.dumps(payload["services"]).decode() - if isinstance(payload.get("decky_config"), dict): - payload["decky_config"] = orjson.dumps(payload["decky_config"]).decode() - async with self._session() as session: - result = await session.execute( - select(FleetDecky).where( - FleetDecky.host_uuid == payload["host_uuid"], - FleetDecky.name == payload["name"], - ) - ) - existing = result.scalar_one_or_none() - if existing: - for k, v in payload.items(): - setattr(existing, k, v) - session.add(existing) - else: - session.add(FleetDecky(**payload)) - await session.commit() - - async def delete_fleet_decky(self, *, host_uuid: str, name: str) -> None: - async with self._session() as session: - await session.execute( - text( - "DELETE FROM fleet_deckies " - "WHERE host_uuid = :h AND name = :n" - ), - {"h": host_uuid, "n": name}, - ) - await session.commit() - - async def list_fleet_deckies( - self, *, host_uuid: Optional[str] = None, - ) -> list[dict[str, Any]]: - stmt = select(FleetDecky).order_by(asc(FleetDecky.name)) - if host_uuid: - stmt = stmt.where(FleetDecky.host_uuid == host_uuid) - async with self._session() as session: - result = await session.execute(stmt) - return [ - self._deserialize_json_fields( - r.model_dump(mode="json"), ("services", "decky_config") - ) - for r in result.scalars().all() - ] - - async def list_running_fleet_deckies(self) -> list[dict[str, Any]]: - async with self._session() as session: - result = await session.execute( - select(FleetDecky).where(FleetDecky.state == "running") - ) - return [ - self._deserialize_json_fields( - r.model_dump(mode="json"), ("services", "decky_config") - ) - for r in result.scalars().all() - ] - - async def update_fleet_decky_state( - self, *, host_uuid: str, name: str, state: str, - last_error: Optional[str] = None, - ) -> None: - now = datetime.now(timezone.utc) - values: dict[str, Any] = { - "state": state, - "updated_at": now, - "last_seen": now, - } - if last_error is not None: - values["last_error"] = last_error - async with self._session() as session: - await session.execute( - update(FleetDecky) - .where( - FleetDecky.host_uuid == host_uuid, - FleetDecky.name == name, - ) - .values(**values) - ) - await session.commit() - - async def list_running_deckies(self) -> list[dict[str, Any]]: - out: list[dict[str, Any]] = [] - # MazeNET — already shaped {uuid, name, ip, services}. We carry - # topology_id through so consumers (emailgen scheduler) can walk - # back to the parent topology row without a second round-trip; - # fleet/shard rows never have one, hence Optional. - for d in await self.list_running_topology_deckies(): - out.append({ - "uuid": d.get("uuid"), - "name": d.get("name"), - "ip": d.get("ip"), - "services": d.get("services") or [], - "topology_id": d.get("topology_id"), - "source": "topology", - }) - # Fleet — column is `decky_ip`, PK is composite (host_uuid, name) - for d in await self.list_running_fleet_deckies(): - out.append({ - "uuid": f"{d.get('host_uuid')}:{d.get('name')}", - "name": d.get("name"), - "ip": d.get("decky_ip"), - "services": d.get("services") or [], - "source": "fleet", - }) - # SWARM — DeckyShard rows in 'running' state on enrolled workers. - async with self._session() as session: - shard_rows = await session.execute( - select(DeckyShard).where(DeckyShard.state == "running") - ) - for s in shard_rows.scalars().all(): - d = self._deserialize_json_fields( - s.model_dump(mode="json"), ("services", "decky_config") - ) - out.append({ - "uuid": f"{d.get('host_uuid')}:{d.get('decky_name')}", - "name": d.get("decky_name"), - "ip": d.get("decky_ip"), - "services": d.get("services") or [], - "source": "shard", - }) - return out - # ------------------------------------------------------------ mazenet - @staticmethod - def _serialize_json_fields(data: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]: - out = dict(data) - for k in keys: - v = out.get(k) - if v is not None and not isinstance(v, str): - out[k] = orjson.dumps(v).decode() - return out - - @staticmethod - def _deserialize_json_fields(d: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]: - for k in keys: - v = d.get(k) - if isinstance(v, str): - try: - d[k] = json.loads(v) - except (json.JSONDecodeError, TypeError): - pass - return d - async def create_topology(self, data: dict[str, Any]) -> str: - payload = self._serialize_json_fields(data, ("config_snapshot",)) + payload = _serialize_json_fields(data, ("config_snapshot",)) async with self._session() as session: row = Topology(**payload) session.add(row) @@ -1772,7 +1618,7 @@ class SQLModelRepository( if not row: return None d = row.model_dump(mode="json") - return self._deserialize_json_fields(d, ("config_snapshot",)) + return _deserialize_json_fields(d, ("config_snapshot",)) async def list_topologies( self, @@ -1790,7 +1636,7 @@ class SQLModelRepository( async with self._session() as session: result = await session.execute(statement) return [ - self._deserialize_json_fields( + _deserialize_json_fields( r.model_dump(mode="json"), ("config_snapshot",) ) for r in result.scalars().all() @@ -1878,7 +1724,7 @@ class SQLModelRepository( select(Topology).where(Topology.needs_resync == True) # noqa: E712 ) return [ - self._deserialize_json_fields( + _deserialize_json_fields( r.model_dump(mode="json"), ("config_snapshot",) ) for r in result.scalars().all() @@ -2080,7 +1926,7 @@ class SQLModelRepository( *, expected_version: Optional[int] = None, ) -> str: - payload = self._serialize_json_fields(data, ("services", "decky_config")) + payload = _serialize_json_fields(data, ("services", "decky_config")) async with self._session() as session: await self._check_and_bump_version( session, data["topology_id"], expected_version @@ -2101,7 +1947,7 @@ class SQLModelRepository( ) -> None: if not fields: return - payload = self._serialize_json_fields(fields, ("services", "decky_config")) + payload = _serialize_json_fields(fields, ("services", "decky_config")) payload.setdefault("updated_at", datetime.now(timezone.utc)) async with self._session() as session: result = await session.execute( @@ -2162,7 +2008,7 @@ class SQLModelRepository( .order_by(asc(TopologyDecky.name)) ) return [ - self._deserialize_json_fields( + _deserialize_json_fields( r.model_dump(mode="json"), ("services", "decky_config") ) for r in result.scalars().all() @@ -2593,7 +2439,7 @@ class SQLModelRepository( select(TopologyDecky).where(TopologyDecky.state == "running") ) return [ - self._deserialize_json_fields( + _deserialize_json_fields( r.model_dump(mode="json"), ("services", "decky_config") ) for r in result.scalars().all() diff --git a/decnet/web/db/sqlmodel_repo/_helpers.py b/decnet/web/db/sqlmodel_repo/_helpers.py index eae889b0..8791dbd0 100644 --- a/decnet/web/db/sqlmodel_repo/_helpers.py +++ b/decnet/web/db/sqlmodel_repo/_helpers.py @@ -2,12 +2,20 @@ ``_safe_session`` and ``_detach_close`` make session cleanup robust under client-cancellation. See ``_detach_close`` for the full rationale. + +``_serialize_json_fields`` / ``_deserialize_json_fields`` live here +because they're used across multiple domain mixins (fleet, topology, +…); putting them in a single mixin would force the others to inherit +that mixin or import a free function — both worse than a shared helper. """ from __future__ import annotations import asyncio +import json from contextlib import asynccontextmanager +from typing import Any +import orjson from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from decnet.logging import get_logger @@ -81,3 +89,25 @@ async def _safe_session(factory: async_sessionmaker[AsyncSession]): raise else: await session.close() + + +def _serialize_json_fields(data: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]: + """Encode the named keys as JSON strings if they're not already.""" + out = dict(data) + for k in keys: + v = out.get(k) + if v is not None and not isinstance(v, str): + out[k] = orjson.dumps(v).decode() + return out + + +def _deserialize_json_fields(d: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]: + """Decode the named JSON-string keys in place.""" + for k in keys: + v = d.get(k) + if isinstance(v, str): + try: + d[k] = json.loads(v) + except (json.JSONDecodeError, TypeError): + pass + return d diff --git a/decnet/web/db/sqlmodel_repo/fleet.py b/decnet/web/db/sqlmodel_repo/fleet.py new file mode 100644 index 00000000..60eef39f --- /dev/null +++ b/decnet/web/db/sqlmodel_repo/fleet.py @@ -0,0 +1,152 @@ +"""Fleet decky CRUD + cross-source running-decky aggregator.""" +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any, Optional + +import orjson +from sqlalchemy import asc, select, text, update + +from decnet.web.db.models import DeckyShard, FleetDecky, LOCAL_HOST_SENTINEL +from decnet.web.db.sqlmodel_repo._helpers import _deserialize_json_fields + + +class FleetMixin: + """Mixin: composed onto ``SQLModelRepository``. + + ``list_running_deckies`` aggregates topology + fleet + swarm-shard + sources and stays here because the fleet entry is the canonical + shape; ``list_running_topology_deckies`` / ``list_running_fleet_deckies`` + on ``self`` resolve through the composed class. + """ + + async def upsert_fleet_decky(self, data: dict[str, Any]) -> None: + payload: dict[str, Any] = { + **data, + "updated_at": datetime.now(timezone.utc), + } + payload.setdefault("host_uuid", LOCAL_HOST_SENTINEL) + if payload.get("host_uuid") is None: + payload["host_uuid"] = LOCAL_HOST_SENTINEL + if isinstance(payload.get("services"), list): + payload["services"] = orjson.dumps(payload["services"]).decode() + if isinstance(payload.get("decky_config"), dict): + payload["decky_config"] = orjson.dumps(payload["decky_config"]).decode() + async with self._session() as session: + result = await session.execute( + select(FleetDecky).where( + FleetDecky.host_uuid == payload["host_uuid"], + FleetDecky.name == payload["name"], + ) + ) + existing = result.scalar_one_or_none() + if existing: + for k, v in payload.items(): + setattr(existing, k, v) + session.add(existing) + else: + session.add(FleetDecky(**payload)) + await session.commit() + + async def delete_fleet_decky(self, *, host_uuid: str, name: str) -> None: + async with self._session() as session: + await session.execute( + text( + "DELETE FROM fleet_deckies " + "WHERE host_uuid = :h AND name = :n" + ), + {"h": host_uuid, "n": name}, + ) + await session.commit() + + async def list_fleet_deckies( + self, *, host_uuid: Optional[str] = None, + ) -> list[dict[str, Any]]: + stmt = select(FleetDecky).order_by(asc(FleetDecky.name)) + if host_uuid: + stmt = stmt.where(FleetDecky.host_uuid == host_uuid) + async with self._session() as session: + result = await session.execute(stmt) + return [ + _deserialize_json_fields( + r.model_dump(mode="json"), ("services", "decky_config") + ) + for r in result.scalars().all() + ] + + async def list_running_fleet_deckies(self) -> list[dict[str, Any]]: + async with self._session() as session: + result = await session.execute( + select(FleetDecky).where(FleetDecky.state == "running") + ) + return [ + _deserialize_json_fields( + r.model_dump(mode="json"), ("services", "decky_config") + ) + for r in result.scalars().all() + ] + + async def update_fleet_decky_state( + self, *, host_uuid: str, name: str, state: str, + last_error: Optional[str] = None, + ) -> None: + now = datetime.now(timezone.utc) + values: dict[str, Any] = { + "state": state, + "updated_at": now, + "last_seen": now, + } + if last_error is not None: + values["last_error"] = last_error + async with self._session() as session: + await session.execute( + update(FleetDecky) + .where( + FleetDecky.host_uuid == host_uuid, + FleetDecky.name == name, + ) + .values(**values) + ) + await session.commit() + + async def list_running_deckies(self) -> list[dict[str, Any]]: + out: list[dict[str, Any]] = [] + # MazeNET — already shaped {uuid, name, ip, services}. We carry + # topology_id through so consumers (emailgen scheduler) can walk + # back to the parent topology row without a second round-trip; + # fleet/shard rows never have one, hence Optional. + for d in await self.list_running_topology_deckies(): + out.append({ + "uuid": d.get("uuid"), + "name": d.get("name"), + "ip": d.get("ip"), + "services": d.get("services") or [], + "topology_id": d.get("topology_id"), + "source": "topology", + }) + # Fleet — column is `decky_ip`, PK is composite (host_uuid, name) + for d in await self.list_running_fleet_deckies(): + out.append({ + "uuid": f"{d.get('host_uuid')}:{d.get('name')}", + "name": d.get("name"), + "ip": d.get("decky_ip"), + "services": d.get("services") or [], + "source": "fleet", + }) + # SWARM — DeckyShard rows in 'running' state on enrolled workers. + async with self._session() as session: + shard_rows = await session.execute( + select(DeckyShard).where(DeckyShard.state == "running") + ) + for s in shard_rows.scalars().all(): + d = _deserialize_json_fields( + s.model_dump(mode="json"), ("services", "decky_config") + ) + out.append({ + "uuid": f"{d.get('host_uuid')}:{d.get('decky_name')}", + "name": d.get("decky_name"), + "ip": d.get("decky_ip"), + "services": d.get("services") or [], + "source": "shard", + }) + return out