refactor(db): extract FleetMixin and promote JSON helpers

Moves the 6 fleet-decky methods (incl. cross-source list_running_deckies
aggregator) into sqlmodel_repo/fleet.py. _serialize_json_fields and
_deserialize_json_fields move to _helpers.py since they're shared
across fleet, topology, and canary.
This commit is contained in:
2026-04-28 14:50:01 -04:00
parent d989cd0461
commit a0aeba5abc
3 changed files with 194 additions and 166 deletions

View File

@@ -40,9 +40,6 @@ from decnet.web.db.models import (
Campaign,
SessionProfile,
SmtpTarget,
DeckyShard,
FleetDecky,
LOCAL_HOST_SENTINEL,
Topology,
LAN,
TopologyDecky,
@@ -63,11 +60,14 @@ from decnet.web.db.sqlmodel_repo._helpers import ( # noqa: F401 (re-exported f
_safe_session,
_detach_close,
_cleanup_tasks,
_serialize_json_fields,
_deserialize_json_fields,
)
from decnet.web.db.sqlmodel_repo.attacker_intel import AttackerIntelMixin
from decnet.web.db.sqlmodel_repo.auth import AuthMixin
from decnet.web.db.sqlmodel_repo.bounties import BountiesMixin
from decnet.web.db.sqlmodel_repo.deckies import DeckiesMixin
from decnet.web.db.sqlmodel_repo.fleet import FleetMixin
from decnet.web.db.sqlmodel_repo.swarm import SwarmMixin
from decnet.web.db.sqlmodel_repo.webhooks import WebhooksMixin
@@ -77,6 +77,7 @@ class SQLModelRepository(
AuthMixin,
BountiesMixin,
DeckiesMixin,
FleetMixin,
SwarmMixin,
WebhooksMixin,
BaseRepository,
@@ -1597,165 +1598,10 @@ class SQLModelRepository(
)
return [r.model_dump(mode="json") for r in rows.scalars().all()]
# ------------------------------------------------------------- swarm
# -------------------------------------------------------------- fleet
async def upsert_fleet_decky(self, data: dict[str, Any]) -> None:
payload: dict[str, Any] = {
**data,
"updated_at": datetime.now(timezone.utc),
}
payload.setdefault("host_uuid", LOCAL_HOST_SENTINEL)
if payload.get("host_uuid") is None:
payload["host_uuid"] = LOCAL_HOST_SENTINEL
if isinstance(payload.get("services"), list):
payload["services"] = orjson.dumps(payload["services"]).decode()
if isinstance(payload.get("decky_config"), dict):
payload["decky_config"] = orjson.dumps(payload["decky_config"]).decode()
async with self._session() as session:
result = await session.execute(
select(FleetDecky).where(
FleetDecky.host_uuid == payload["host_uuid"],
FleetDecky.name == payload["name"],
)
)
existing = result.scalar_one_or_none()
if existing:
for k, v in payload.items():
setattr(existing, k, v)
session.add(existing)
else:
session.add(FleetDecky(**payload))
await session.commit()
async def delete_fleet_decky(self, *, host_uuid: str, name: str) -> None:
async with self._session() as session:
await session.execute(
text(
"DELETE FROM fleet_deckies "
"WHERE host_uuid = :h AND name = :n"
),
{"h": host_uuid, "n": name},
)
await session.commit()
async def list_fleet_deckies(
self, *, host_uuid: Optional[str] = None,
) -> list[dict[str, Any]]:
stmt = select(FleetDecky).order_by(asc(FleetDecky.name))
if host_uuid:
stmt = stmt.where(FleetDecky.host_uuid == host_uuid)
async with self._session() as session:
result = await session.execute(stmt)
return [
self._deserialize_json_fields(
r.model_dump(mode="json"), ("services", "decky_config")
)
for r in result.scalars().all()
]
async def list_running_fleet_deckies(self) -> list[dict[str, Any]]:
async with self._session() as session:
result = await session.execute(
select(FleetDecky).where(FleetDecky.state == "running")
)
return [
self._deserialize_json_fields(
r.model_dump(mode="json"), ("services", "decky_config")
)
for r in result.scalars().all()
]
async def update_fleet_decky_state(
self, *, host_uuid: str, name: str, state: str,
last_error: Optional[str] = None,
) -> None:
now = datetime.now(timezone.utc)
values: dict[str, Any] = {
"state": state,
"updated_at": now,
"last_seen": now,
}
if last_error is not None:
values["last_error"] = last_error
async with self._session() as session:
await session.execute(
update(FleetDecky)
.where(
FleetDecky.host_uuid == host_uuid,
FleetDecky.name == name,
)
.values(**values)
)
await session.commit()
async def list_running_deckies(self) -> list[dict[str, Any]]:
out: list[dict[str, Any]] = []
# MazeNET — already shaped {uuid, name, ip, services}. We carry
# topology_id through so consumers (emailgen scheduler) can walk
# back to the parent topology row without a second round-trip;
# fleet/shard rows never have one, hence Optional.
for d in await self.list_running_topology_deckies():
out.append({
"uuid": d.get("uuid"),
"name": d.get("name"),
"ip": d.get("ip"),
"services": d.get("services") or [],
"topology_id": d.get("topology_id"),
"source": "topology",
})
# Fleet — column is `decky_ip`, PK is composite (host_uuid, name)
for d in await self.list_running_fleet_deckies():
out.append({
"uuid": f"{d.get('host_uuid')}:{d.get('name')}",
"name": d.get("name"),
"ip": d.get("decky_ip"),
"services": d.get("services") or [],
"source": "fleet",
})
# SWARM — DeckyShard rows in 'running' state on enrolled workers.
async with self._session() as session:
shard_rows = await session.execute(
select(DeckyShard).where(DeckyShard.state == "running")
)
for s in shard_rows.scalars().all():
d = self._deserialize_json_fields(
s.model_dump(mode="json"), ("services", "decky_config")
)
out.append({
"uuid": f"{d.get('host_uuid')}:{d.get('decky_name')}",
"name": d.get("decky_name"),
"ip": d.get("decky_ip"),
"services": d.get("services") or [],
"source": "shard",
})
return out
# ------------------------------------------------------------ mazenet
@staticmethod
def _serialize_json_fields(data: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]:
out = dict(data)
for k in keys:
v = out.get(k)
if v is not None and not isinstance(v, str):
out[k] = orjson.dumps(v).decode()
return out
@staticmethod
def _deserialize_json_fields(d: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]:
for k in keys:
v = d.get(k)
if isinstance(v, str):
try:
d[k] = json.loads(v)
except (json.JSONDecodeError, TypeError):
pass
return d
async def create_topology(self, data: dict[str, Any]) -> str:
payload = self._serialize_json_fields(data, ("config_snapshot",))
payload = _serialize_json_fields(data, ("config_snapshot",))
async with self._session() as session:
row = Topology(**payload)
session.add(row)
@@ -1772,7 +1618,7 @@ class SQLModelRepository(
if not row:
return None
d = row.model_dump(mode="json")
return self._deserialize_json_fields(d, ("config_snapshot",))
return _deserialize_json_fields(d, ("config_snapshot",))
async def list_topologies(
self,
@@ -1790,7 +1636,7 @@ class SQLModelRepository(
async with self._session() as session:
result = await session.execute(statement)
return [
self._deserialize_json_fields(
_deserialize_json_fields(
r.model_dump(mode="json"), ("config_snapshot",)
)
for r in result.scalars().all()
@@ -1878,7 +1724,7 @@ class SQLModelRepository(
select(Topology).where(Topology.needs_resync == True) # noqa: E712
)
return [
self._deserialize_json_fields(
_deserialize_json_fields(
r.model_dump(mode="json"), ("config_snapshot",)
)
for r in result.scalars().all()
@@ -2080,7 +1926,7 @@ class SQLModelRepository(
*,
expected_version: Optional[int] = None,
) -> str:
payload = self._serialize_json_fields(data, ("services", "decky_config"))
payload = _serialize_json_fields(data, ("services", "decky_config"))
async with self._session() as session:
await self._check_and_bump_version(
session, data["topology_id"], expected_version
@@ -2101,7 +1947,7 @@ class SQLModelRepository(
) -> None:
if not fields:
return
payload = self._serialize_json_fields(fields, ("services", "decky_config"))
payload = _serialize_json_fields(fields, ("services", "decky_config"))
payload.setdefault("updated_at", datetime.now(timezone.utc))
async with self._session() as session:
result = await session.execute(
@@ -2162,7 +2008,7 @@ class SQLModelRepository(
.order_by(asc(TopologyDecky.name))
)
return [
self._deserialize_json_fields(
_deserialize_json_fields(
r.model_dump(mode="json"), ("services", "decky_config")
)
for r in result.scalars().all()
@@ -2593,7 +2439,7 @@ class SQLModelRepository(
select(TopologyDecky).where(TopologyDecky.state == "running")
)
return [
self._deserialize_json_fields(
_deserialize_json_fields(
r.model_dump(mode="json"), ("services", "decky_config")
)
for r in result.scalars().all()

View File

@@ -2,12 +2,20 @@
``_safe_session`` and ``_detach_close`` make session cleanup robust under
client-cancellation. See ``_detach_close`` for the full rationale.
``_serialize_json_fields`` / ``_deserialize_json_fields`` live here
because they're used across multiple domain mixins (fleet, topology,
…); putting them in a single mixin would force the others to inherit
that mixin or import a free function — both worse than a shared helper.
"""
from __future__ import annotations
import asyncio
import json
from contextlib import asynccontextmanager
from typing import Any
import orjson
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from decnet.logging import get_logger
@@ -81,3 +89,25 @@ async def _safe_session(factory: async_sessionmaker[AsyncSession]):
raise
else:
await session.close()
def _serialize_json_fields(data: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]:
"""Encode the named keys as JSON strings if they're not already."""
out = dict(data)
for k in keys:
v = out.get(k)
if v is not None and not isinstance(v, str):
out[k] = orjson.dumps(v).decode()
return out
def _deserialize_json_fields(d: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]:
"""Decode the named JSON-string keys in place."""
for k in keys:
v = d.get(k)
if isinstance(v, str):
try:
d[k] = json.loads(v)
except (json.JSONDecodeError, TypeError):
pass
return d

View File

@@ -0,0 +1,152 @@
"""Fleet decky CRUD + cross-source running-decky aggregator."""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, Optional
import orjson
from sqlalchemy import asc, select, text, update
from decnet.web.db.models import DeckyShard, FleetDecky, LOCAL_HOST_SENTINEL
from decnet.web.db.sqlmodel_repo._helpers import _deserialize_json_fields
class FleetMixin:
"""Mixin: composed onto ``SQLModelRepository``.
``list_running_deckies`` aggregates topology + fleet + swarm-shard
sources and stays here because the fleet entry is the canonical
shape; ``list_running_topology_deckies`` / ``list_running_fleet_deckies``
on ``self`` resolve through the composed class.
"""
async def upsert_fleet_decky(self, data: dict[str, Any]) -> None:
payload: dict[str, Any] = {
**data,
"updated_at": datetime.now(timezone.utc),
}
payload.setdefault("host_uuid", LOCAL_HOST_SENTINEL)
if payload.get("host_uuid") is None:
payload["host_uuid"] = LOCAL_HOST_SENTINEL
if isinstance(payload.get("services"), list):
payload["services"] = orjson.dumps(payload["services"]).decode()
if isinstance(payload.get("decky_config"), dict):
payload["decky_config"] = orjson.dumps(payload["decky_config"]).decode()
async with self._session() as session:
result = await session.execute(
select(FleetDecky).where(
FleetDecky.host_uuid == payload["host_uuid"],
FleetDecky.name == payload["name"],
)
)
existing = result.scalar_one_or_none()
if existing:
for k, v in payload.items():
setattr(existing, k, v)
session.add(existing)
else:
session.add(FleetDecky(**payload))
await session.commit()
async def delete_fleet_decky(self, *, host_uuid: str, name: str) -> None:
async with self._session() as session:
await session.execute(
text(
"DELETE FROM fleet_deckies "
"WHERE host_uuid = :h AND name = :n"
),
{"h": host_uuid, "n": name},
)
await session.commit()
async def list_fleet_deckies(
self, *, host_uuid: Optional[str] = None,
) -> list[dict[str, Any]]:
stmt = select(FleetDecky).order_by(asc(FleetDecky.name))
if host_uuid:
stmt = stmt.where(FleetDecky.host_uuid == host_uuid)
async with self._session() as session:
result = await session.execute(stmt)
return [
_deserialize_json_fields(
r.model_dump(mode="json"), ("services", "decky_config")
)
for r in result.scalars().all()
]
async def list_running_fleet_deckies(self) -> list[dict[str, Any]]:
async with self._session() as session:
result = await session.execute(
select(FleetDecky).where(FleetDecky.state == "running")
)
return [
_deserialize_json_fields(
r.model_dump(mode="json"), ("services", "decky_config")
)
for r in result.scalars().all()
]
async def update_fleet_decky_state(
self, *, host_uuid: str, name: str, state: str,
last_error: Optional[str] = None,
) -> None:
now = datetime.now(timezone.utc)
values: dict[str, Any] = {
"state": state,
"updated_at": now,
"last_seen": now,
}
if last_error is not None:
values["last_error"] = last_error
async with self._session() as session:
await session.execute(
update(FleetDecky)
.where(
FleetDecky.host_uuid == host_uuid,
FleetDecky.name == name,
)
.values(**values)
)
await session.commit()
async def list_running_deckies(self) -> list[dict[str, Any]]:
out: list[dict[str, Any]] = []
# MazeNET — already shaped {uuid, name, ip, services}. We carry
# topology_id through so consumers (emailgen scheduler) can walk
# back to the parent topology row without a second round-trip;
# fleet/shard rows never have one, hence Optional.
for d in await self.list_running_topology_deckies():
out.append({
"uuid": d.get("uuid"),
"name": d.get("name"),
"ip": d.get("ip"),
"services": d.get("services") or [],
"topology_id": d.get("topology_id"),
"source": "topology",
})
# Fleet — column is `decky_ip`, PK is composite (host_uuid, name)
for d in await self.list_running_fleet_deckies():
out.append({
"uuid": f"{d.get('host_uuid')}:{d.get('name')}",
"name": d.get("name"),
"ip": d.get("decky_ip"),
"services": d.get("services") or [],
"source": "fleet",
})
# SWARM — DeckyShard rows in 'running' state on enrolled workers.
async with self._session() as session:
shard_rows = await session.execute(
select(DeckyShard).where(DeckyShard.state == "running")
)
for s in shard_rows.scalars().all():
d = _deserialize_json_fields(
s.model_dump(mode="json"), ("services", "decky_config")
)
out.append({
"uuid": f"{d.get('host_uuid')}:{d.get('decky_name')}",
"name": d.get("decky_name"),
"ip": d.get("decky_ip"),
"services": d.get("services") or [],
"source": "shard",
})
return out