feat(mazenet): repo methods for topology/LAN/decky/edge/status events

Adds topology CRUD to BaseRepository (NotImplementedError defaults) and
implements them in SQLModelRepository: create/get/list/delete topologies,
add/update/list LANs and TopologyDeckies, add/list edges, plus an atomic
update_topology_status that appends a TopologyStatusEvent in the same
transaction.  Cascade delete sweeps children before the topology row.

Covered by tests/topology/test_repo.py (roundtrip, per-topology name
uniqueness, status event log, cascade delete, status filter) and an
extension to tests/test_base_repo.py for the NotImplementedError surface.
This commit is contained in:
2026-04-20 16:43:49 -04:00
parent 096a35b24a
commit 47cd200e1d
5 changed files with 466 additions and 0 deletions

View File

@@ -234,3 +234,67 @@ class BaseRepository(ABC):
async def delete_decky_shard(self, decky_name: str) -> bool:
raise NotImplementedError
# ----------------------------------------------------------- mazenet
# MazeNET topology persistence. Default no-op / NotImplementedError so
# non-default backends stay functional; SQLModelRepository provides the
# real implementation used by SQLite and MySQL.
async def create_topology(self, data: dict[str, Any]) -> str:
raise NotImplementedError
async def get_topology(self, topology_id: str) -> Optional[dict[str, Any]]:
raise NotImplementedError
async def list_topologies(
self, status: Optional[str] = None
) -> list[dict[str, Any]]:
raise NotImplementedError
async def update_topology_status(
self,
topology_id: str,
new_status: str,
reason: Optional[str] = None,
) -> None:
raise NotImplementedError
async def delete_topology_cascade(self, topology_id: str) -> bool:
raise NotImplementedError
async def add_lan(self, data: dict[str, Any]) -> str:
raise NotImplementedError
async def update_lan(self, lan_id: str, fields: dict[str, Any]) -> None:
raise NotImplementedError
async def list_lans_for_topology(
self, topology_id: str
) -> list[dict[str, Any]]:
raise NotImplementedError
async def add_topology_decky(self, data: dict[str, Any]) -> str:
raise NotImplementedError
async def update_topology_decky(
self, decky_uuid: str, fields: dict[str, Any]
) -> None:
raise NotImplementedError
async def list_topology_deckies(
self, topology_id: str
) -> list[dict[str, Any]]:
raise NotImplementedError
async def add_topology_edge(self, data: dict[str, Any]) -> str:
raise NotImplementedError
async def list_topology_edges(
self, topology_id: str
) -> list[dict[str, Any]]:
raise NotImplementedError
async def list_topology_status_events(
self, topology_id: str, limit: int = 100
) -> list[dict[str, Any]]:
raise NotImplementedError

View File

@@ -36,6 +36,11 @@ from decnet.web.db.models import (
AttackerBehavior,
SwarmHost,
DeckyShard,
Topology,
LAN,
TopologyDecky,
TopologyEdge,
TopologyStatusEvent,
)
@@ -899,3 +904,220 @@ class SQLModelRepository(BaseRepository):
)
await session.commit()
return bool(result.rowcount)
# ------------------------------------------------------------ mazenet
@staticmethod
def _serialize_json_fields(data: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]:
out = dict(data)
for k in keys:
v = out.get(k)
if v is not None and not isinstance(v, str):
out[k] = orjson.dumps(v).decode()
return out
@staticmethod
def _deserialize_json_fields(d: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]:
for k in keys:
v = d.get(k)
if isinstance(v, str):
try:
d[k] = json.loads(v)
except (json.JSONDecodeError, TypeError):
pass
return d
async def create_topology(self, data: dict[str, Any]) -> str:
payload = self._serialize_json_fields(data, ("config_snapshot",))
async with self._session() as session:
row = Topology(**payload)
session.add(row)
await session.commit()
await session.refresh(row)
return row.id
async def get_topology(self, topology_id: str) -> Optional[dict[str, Any]]:
async with self._session() as session:
result = await session.execute(
select(Topology).where(Topology.id == topology_id)
)
row = result.scalar_one_or_none()
if not row:
return None
d = row.model_dump(mode="json")
return self._deserialize_json_fields(d, ("config_snapshot",))
async def list_topologies(
self, status: Optional[str] = None
) -> list[dict[str, Any]]:
statement = select(Topology).order_by(desc(Topology.created_at))
if status:
statement = statement.where(Topology.status == status)
async with self._session() as session:
result = await session.execute(statement)
return [
self._deserialize_json_fields(
r.model_dump(mode="json"), ("config_snapshot",)
)
for r in result.scalars().all()
]
async def update_topology_status(
self,
topology_id: str,
new_status: str,
reason: Optional[str] = None,
) -> None:
"""Update topology.status and append a TopologyStatusEvent atomically.
Transition legality is enforced in ``decnet.topology.status``; this
method trusts the caller.
"""
now = datetime.now(timezone.utc)
async with self._session() as session:
result = await session.execute(
select(Topology).where(Topology.id == topology_id)
)
topo = result.scalar_one_or_none()
if topo is None:
return
from_status = topo.status
topo.status = new_status
topo.status_changed_at = now
session.add(topo)
session.add(
TopologyStatusEvent(
topology_id=topology_id,
from_status=from_status,
to_status=new_status,
at=now,
reason=reason,
)
)
await session.commit()
async def delete_topology_cascade(self, topology_id: str) -> bool:
"""Delete topology and all children. No portable ON DELETE CASCADE."""
async with self._session() as session:
params = {"t": topology_id}
await session.execute(
text("DELETE FROM topology_status_events WHERE topology_id = :t"),
params,
)
await session.execute(
text("DELETE FROM topology_edges WHERE topology_id = :t"),
params,
)
await session.execute(
text("DELETE FROM topology_deckies WHERE topology_id = :t"),
params,
)
await session.execute(
text("DELETE FROM lans WHERE topology_id = :t"),
params,
)
result = await session.execute(
select(Topology).where(Topology.id == topology_id)
)
topo = result.scalar_one_or_none()
if not topo:
await session.commit()
return False
await session.delete(topo)
await session.commit()
return True
async def add_lan(self, data: dict[str, Any]) -> str:
async with self._session() as session:
row = LAN(**data)
session.add(row)
await session.commit()
await session.refresh(row)
return row.id
async def update_lan(self, lan_id: str, fields: dict[str, Any]) -> None:
if not fields:
return
async with self._session() as session:
await session.execute(
update(LAN).where(LAN.id == lan_id).values(**fields)
)
await session.commit()
async def list_lans_for_topology(
self, topology_id: str
) -> list[dict[str, Any]]:
async with self._session() as session:
result = await session.execute(
select(LAN).where(LAN.topology_id == topology_id).order_by(asc(LAN.name))
)
return [r.model_dump(mode="json") for r in result.scalars().all()]
async def add_topology_decky(self, data: dict[str, Any]) -> str:
payload = self._serialize_json_fields(data, ("services", "decky_config"))
async with self._session() as session:
row = TopologyDecky(**payload)
session.add(row)
await session.commit()
await session.refresh(row)
return row.uuid
async def update_topology_decky(
self, decky_uuid: str, fields: dict[str, Any]
) -> None:
if not fields:
return
payload = self._serialize_json_fields(fields, ("services", "decky_config"))
payload.setdefault("updated_at", datetime.now(timezone.utc))
async with self._session() as session:
await session.execute(
update(TopologyDecky)
.where(TopologyDecky.uuid == decky_uuid)
.values(**payload)
)
await session.commit()
async def list_topology_deckies(
self, topology_id: str
) -> list[dict[str, Any]]:
async with self._session() as session:
result = await session.execute(
select(TopologyDecky)
.where(TopologyDecky.topology_id == topology_id)
.order_by(asc(TopologyDecky.name))
)
return [
self._deserialize_json_fields(
r.model_dump(mode="json"), ("services", "decky_config")
)
for r in result.scalars().all()
]
async def add_topology_edge(self, data: dict[str, Any]) -> str:
async with self._session() as session:
row = TopologyEdge(**data)
session.add(row)
await session.commit()
await session.refresh(row)
return row.id
async def list_topology_edges(
self, topology_id: str
) -> list[dict[str, Any]]:
async with self._session() as session:
result = await session.execute(
select(TopologyEdge).where(TopologyEdge.topology_id == topology_id)
)
return [r.model_dump(mode="json") for r in result.scalars().all()]
async def list_topology_status_events(
self, topology_id: str, limit: int = 100
) -> list[dict[str, Any]]:
async with self._session() as session:
result = await session.execute(
select(TopologyStatusEvent)
.where(TopologyStatusEvent.topology_id == topology_id)
.order_by(desc(TopologyStatusEvent.at))
.limit(limit)
)
return [r.model_dump(mode="json") for r in result.scalars().all()]