refactor(db): extract TopologyMixin
Moves the 31 MazeNET topology methods (topologies CRUD, LANs, deckies, edges, status events, mutation queue) into sqlmodel_repo/topology.py. Includes _assert_pending and _check_and_bump_version concurrency guards. This is the last domain extraction; sqlmodel_repo/__init__.py is now ~165 lines: lifecycle (initialize/reinitialize/migrations), the admin self-heal seed, get_state/set_state, and the mixin composition.
This commit is contained in:
@@ -16,34 +16,22 @@ import json
|
||||
|
||||
import orjson
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional, List
|
||||
|
||||
from sqlalchemy import func, select, desc, asc, text, update
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||
|
||||
from decnet.config import load_state
|
||||
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
||||
from decnet.web.auth import get_password_hash
|
||||
from decnet.web.db.repository import BaseRepository
|
||||
from decnet.web.db.models import (
|
||||
User,
|
||||
State,
|
||||
Topology,
|
||||
LAN,
|
||||
TopologyDecky,
|
||||
TopologyEdge,
|
||||
TopologyStatusEvent,
|
||||
TopologyMutation,
|
||||
)
|
||||
from decnet.web.db.models import State, User
|
||||
|
||||
|
||||
from decnet.web.db.sqlmodel_repo._helpers import ( # noqa: F401 (re-exported for tests/external)
|
||||
_safe_session,
|
||||
_detach_close,
|
||||
_cleanup_tasks,
|
||||
_serialize_json_fields,
|
||||
_deserialize_json_fields,
|
||||
)
|
||||
from decnet.web.db.sqlmodel_repo.attacker_intel import AttackerIntelMixin
|
||||
from decnet.web.db.sqlmodel_repo.attackers import AttackersMixin
|
||||
@@ -59,6 +47,7 @@ from decnet.web.db.sqlmodel_repo.logs import LogsMixin
|
||||
from decnet.web.db.sqlmodel_repo.orchestrator import OrchestratorMixin
|
||||
from decnet.web.db.sqlmodel_repo.realism import RealismMixin
|
||||
from decnet.web.db.sqlmodel_repo.swarm import SwarmMixin
|
||||
from decnet.web.db.sqlmodel_repo.topology import TopologyMixin
|
||||
from decnet.web.db.sqlmodel_repo.webhooks import WebhooksMixin
|
||||
|
||||
|
||||
@@ -77,6 +66,7 @@ class SQLModelRepository(
|
||||
OrchestratorMixin,
|
||||
RealismMixin,
|
||||
SwarmMixin,
|
||||
TopologyMixin,
|
||||
WebhooksMixin,
|
||||
BaseRepository,
|
||||
):
|
||||
@@ -174,660 +164,3 @@ class SQLModelRepository(
|
||||
|
||||
await session.commit()
|
||||
|
||||
# ------------------------------------------------------------ mazenet
|
||||
|
||||
async def create_topology(self, data: dict[str, Any]) -> str:
|
||||
payload = _serialize_json_fields(data, ("config_snapshot",))
|
||||
async with self._session() as session:
|
||||
row = Topology(**payload)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return row.id
|
||||
|
||||
async def get_topology(self, topology_id: str) -> Optional[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.id == topology_id)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
d = row.model_dump(mode="json")
|
||||
return _deserialize_json_fields(d, ("config_snapshot",))
|
||||
|
||||
async def list_topologies(
|
||||
self,
|
||||
status: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
statement = select(Topology).order_by(desc(Topology.created_at))
|
||||
if status:
|
||||
statement = statement.where(Topology.status == status)
|
||||
if offset is not None:
|
||||
statement = statement.offset(offset)
|
||||
if limit is not None:
|
||||
statement = statement.limit(limit)
|
||||
async with self._session() as session:
|
||||
result = await session.execute(statement)
|
||||
return [
|
||||
_deserialize_json_fields(
|
||||
r.model_dump(mode="json"), ("config_snapshot",)
|
||||
)
|
||||
for r in result.scalars().all()
|
||||
]
|
||||
|
||||
async def count_topologies(self, status: Optional[str] = None) -> int:
|
||||
from sqlalchemy import func
|
||||
statement = select(func.count(Topology.id))
|
||||
if status:
|
||||
statement = statement.where(Topology.status == status)
|
||||
async with self._session() as session:
|
||||
result = await session.execute(statement)
|
||||
return int(result.scalar_one() or 0)
|
||||
|
||||
async def update_topology_status(
|
||||
self,
|
||||
topology_id: str,
|
||||
new_status: str,
|
||||
reason: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Update topology.status and append a TopologyStatusEvent atomically.
|
||||
|
||||
Transition legality is enforced in ``decnet.topology.status``; this
|
||||
method trusts the caller.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.id == topology_id)
|
||||
)
|
||||
topo = result.scalar_one_or_none()
|
||||
if topo is None:
|
||||
return
|
||||
from_status = topo.status
|
||||
topo.status = new_status
|
||||
topo.status_changed_at = now
|
||||
session.add(topo)
|
||||
session.add(
|
||||
TopologyStatusEvent(
|
||||
topology_id=topology_id,
|
||||
from_status=from_status,
|
||||
to_status=new_status,
|
||||
at=now,
|
||||
reason=reason,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def set_topology_resync(self, topology_id: str, value: bool) -> None:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.id == topology_id)
|
||||
)
|
||||
topo = result.scalar_one_or_none()
|
||||
if topo is None:
|
||||
return
|
||||
topo.needs_resync = bool(value)
|
||||
session.add(topo)
|
||||
await session.commit()
|
||||
|
||||
async def set_topology_email_personas(
|
||||
self, topology_id: str, personas_json: str,
|
||||
) -> bool:
|
||||
"""Replace ``Topology.email_personas`` with the supplied JSON.
|
||||
|
||||
The string is stored as-is; validation/parsing is the caller's
|
||||
job (and is repeated by the emailgen scheduler each tick anyway).
|
||||
Returns True if a row was updated.
|
||||
"""
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.id == topology_id)
|
||||
)
|
||||
topo = result.scalar_one_or_none()
|
||||
if topo is None:
|
||||
return False
|
||||
topo.email_personas = personas_json
|
||||
session.add(topo)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def list_topologies_needing_resync(self) -> list[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.needs_resync == True) # noqa: E712
|
||||
)
|
||||
return [
|
||||
_deserialize_json_fields(
|
||||
r.model_dump(mode="json"), ("config_snapshot",)
|
||||
)
|
||||
for r in result.scalars().all()
|
||||
]
|
||||
|
||||
async def delete_topology_cascade(self, topology_id: str) -> bool:
|
||||
"""Delete topology and all children. No portable ON DELETE CASCADE."""
|
||||
async with self._session() as session:
|
||||
params = {"t": topology_id}
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_status_events WHERE topology_id = :t"),
|
||||
params,
|
||||
)
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_edges WHERE topology_id = :t"),
|
||||
params,
|
||||
)
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_deckies WHERE topology_id = :t"),
|
||||
params,
|
||||
)
|
||||
await session.execute(
|
||||
text("DELETE FROM lans WHERE topology_id = :t"),
|
||||
params,
|
||||
)
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_mutations WHERE topology_id = :t"),
|
||||
params,
|
||||
)
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.id == topology_id)
|
||||
)
|
||||
topo = result.scalar_one_or_none()
|
||||
if not topo:
|
||||
await session.commit()
|
||||
return False
|
||||
await session.delete(topo)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def _assert_pending(self, session, topology_id: str) -> None:
|
||||
"""Pre-deploy edits are pending-only. Raises TopologyNotEditable."""
|
||||
from decnet.topology.status import TopologyNotEditable, TopologyStatus
|
||||
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.id == topology_id)
|
||||
)
|
||||
topo = result.scalar_one_or_none()
|
||||
if topo is None:
|
||||
raise ValueError(f"topology {topology_id!r} not found")
|
||||
if topo.status != TopologyStatus.PENDING:
|
||||
raise TopologyNotEditable(
|
||||
status=topo.status,
|
||||
reason="free-form edits are pending-only; use the "
|
||||
"mutator (topology_mutations) after deploy",
|
||||
)
|
||||
|
||||
async def _check_and_bump_version(
|
||||
self,
|
||||
session,
|
||||
topology_id: str,
|
||||
expected_version: Optional[int],
|
||||
) -> None:
|
||||
"""Optimistic-concurrency guard used by child-row mutators.
|
||||
|
||||
If ``expected_version`` is None, no check happens (backward-compat
|
||||
for internal callers that don't need concurrency protection).
|
||||
|
||||
If supplied, loads the Topology row in the same session,
|
||||
compares ``version == expected_version``, raises VersionConflict
|
||||
on mismatch, otherwise bumps ``version += 1``. The caller must
|
||||
commit the enclosing session.
|
||||
"""
|
||||
from decnet.topology.status import VersionConflict
|
||||
|
||||
if expected_version is None:
|
||||
return
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.id == topology_id)
|
||||
)
|
||||
topo = result.scalar_one_or_none()
|
||||
if topo is None:
|
||||
raise ValueError(f"topology {topology_id!r} not found")
|
||||
if topo.version != expected_version:
|
||||
raise VersionConflict(
|
||||
current=topo.version, expected=expected_version
|
||||
)
|
||||
topo.version = topo.version + 1
|
||||
session.add(topo)
|
||||
|
||||
async def add_lan(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
) -> str:
|
||||
async with self._session() as session:
|
||||
await self._check_and_bump_version(
|
||||
session, data["topology_id"], expected_version
|
||||
)
|
||||
row = LAN(**data)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return row.id
|
||||
|
||||
async def update_lan(
|
||||
self,
|
||||
lan_id: str,
|
||||
fields: dict[str, Any],
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
enforce_pending: bool = False,
|
||||
) -> None:
|
||||
if not fields:
|
||||
return
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(LAN).where(LAN.id == lan_id)
|
||||
)
|
||||
lan = result.scalar_one_or_none()
|
||||
if lan is None:
|
||||
raise ValueError(f"lan {lan_id!r} not found")
|
||||
if enforce_pending:
|
||||
await self._assert_pending(session, lan.topology_id)
|
||||
if expected_version is not None:
|
||||
await self._check_and_bump_version(
|
||||
session, lan.topology_id, expected_version
|
||||
)
|
||||
await session.execute(
|
||||
update(LAN).where(LAN.id == lan_id).values(**fields)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def delete_lan(
|
||||
self,
|
||||
lan_id: str,
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Cascade-delete a LAN from a pending topology.
|
||||
|
||||
Rejects if any decky declares this LAN as its home (i.e. has a
|
||||
non-bridge edge to it — the only LAN that decky lives in). The
|
||||
caller must delete or reassign the home-deckies first.
|
||||
"""
|
||||
from decnet.topology.status import TopologyNotEditable # noqa: F401
|
||||
|
||||
async with self._session() as session:
|
||||
result = await session.execute(select(LAN).where(LAN.id == lan_id))
|
||||
lan = result.scalar_one_or_none()
|
||||
if lan is None:
|
||||
return
|
||||
await self._assert_pending(session, lan.topology_id)
|
||||
|
||||
# Home-decky check: any decky whose only edge lands here?
|
||||
edges_result = await session.execute(
|
||||
select(TopologyEdge).where(TopologyEdge.lan_id == lan_id)
|
||||
)
|
||||
edges_here = edges_result.scalars().all()
|
||||
decky_uuids_on_this_lan = {e.decky_uuid for e in edges_here}
|
||||
for decky_uuid in decky_uuids_on_this_lan:
|
||||
other = await session.execute(
|
||||
select(TopologyEdge).where(
|
||||
TopologyEdge.decky_uuid == decky_uuid,
|
||||
TopologyEdge.lan_id != lan_id,
|
||||
)
|
||||
)
|
||||
if other.scalars().first() is None:
|
||||
raise ValueError(
|
||||
f"cannot delete LAN {lan.name!r}: decky "
|
||||
f"{decky_uuid} has no other LAN (would be orphaned)"
|
||||
)
|
||||
|
||||
if expected_version is not None:
|
||||
await self._check_and_bump_version(
|
||||
session, lan.topology_id, expected_version
|
||||
)
|
||||
# Cascade edges → LAN.
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_edges WHERE lan_id = :l"),
|
||||
{"l": lan_id},
|
||||
)
|
||||
await session.execute(text("DELETE FROM lans WHERE id = :l"), {"l": lan_id})
|
||||
await session.commit()
|
||||
|
||||
async def list_lans_for_topology(
|
||||
self, topology_id: str
|
||||
) -> list[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(LAN).where(LAN.topology_id == topology_id).order_by(asc(LAN.name))
|
||||
)
|
||||
return [r.model_dump(mode="json") for r in result.scalars().all()]
|
||||
|
||||
async def add_topology_decky(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
) -> str:
|
||||
payload = _serialize_json_fields(data, ("services", "decky_config"))
|
||||
async with self._session() as session:
|
||||
await self._check_and_bump_version(
|
||||
session, data["topology_id"], expected_version
|
||||
)
|
||||
row = TopologyDecky(**payload)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return row.uuid
|
||||
|
||||
async def update_topology_decky(
|
||||
self,
|
||||
decky_uuid: str,
|
||||
fields: dict[str, Any],
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
enforce_pending: bool = False,
|
||||
) -> None:
|
||||
if not fields:
|
||||
return
|
||||
payload = _serialize_json_fields(fields, ("services", "decky_config"))
|
||||
payload.setdefault("updated_at", datetime.now(timezone.utc))
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(TopologyDecky).where(TopologyDecky.uuid == decky_uuid)
|
||||
)
|
||||
d = result.scalar_one_or_none()
|
||||
if d is None:
|
||||
raise ValueError(f"decky {decky_uuid!r} not found")
|
||||
if enforce_pending:
|
||||
await self._assert_pending(session, d.topology_id)
|
||||
if expected_version is not None:
|
||||
await self._check_and_bump_version(
|
||||
session, d.topology_id, expected_version
|
||||
)
|
||||
await session.execute(
|
||||
update(TopologyDecky)
|
||||
.where(TopologyDecky.uuid == decky_uuid)
|
||||
.values(**payload)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def delete_topology_decky(
|
||||
self,
|
||||
decky_uuid: str,
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Cascade-delete a decky + all its edges from a pending topology."""
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(TopologyDecky).where(TopologyDecky.uuid == decky_uuid)
|
||||
)
|
||||
d = result.scalar_one_or_none()
|
||||
if d is None:
|
||||
return
|
||||
await self._assert_pending(session, d.topology_id)
|
||||
if expected_version is not None:
|
||||
await self._check_and_bump_version(
|
||||
session, d.topology_id, expected_version
|
||||
)
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_edges WHERE decky_uuid = :u"),
|
||||
{"u": decky_uuid},
|
||||
)
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_deckies WHERE uuid = :u"),
|
||||
{"u": decky_uuid},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def list_topology_deckies(
|
||||
self, topology_id: str
|
||||
) -> list[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(TopologyDecky)
|
||||
.where(TopologyDecky.topology_id == topology_id)
|
||||
.order_by(asc(TopologyDecky.name))
|
||||
)
|
||||
return [
|
||||
_deserialize_json_fields(
|
||||
r.model_dump(mode="json"), ("services", "decky_config")
|
||||
)
|
||||
for r in result.scalars().all()
|
||||
]
|
||||
|
||||
async def add_topology_edge(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
) -> str:
|
||||
async with self._session() as session:
|
||||
await self._check_and_bump_version(
|
||||
session, data["topology_id"], expected_version
|
||||
)
|
||||
row = TopologyEdge(**data)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return row.id
|
||||
|
||||
async def delete_topology_edge(
|
||||
self,
|
||||
edge_id: str,
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
) -> None:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(TopologyEdge).where(TopologyEdge.id == edge_id)
|
||||
)
|
||||
edge = result.scalar_one_or_none()
|
||||
if edge is None:
|
||||
return
|
||||
await self._assert_pending(session, edge.topology_id)
|
||||
if expected_version is not None:
|
||||
await self._check_and_bump_version(
|
||||
session, edge.topology_id, expected_version
|
||||
)
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_edges WHERE id = :e"),
|
||||
{"e": edge_id},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def list_topology_edges(
|
||||
self, topology_id: str
|
||||
) -> list[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(TopologyEdge).where(TopologyEdge.topology_id == topology_id)
|
||||
)
|
||||
return [r.model_dump(mode="json") for r in result.scalars().all()]
|
||||
|
||||
async def list_topology_status_events(
|
||||
self, topology_id: str, limit: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(TopologyStatusEvent)
|
||||
.where(TopologyStatusEvent.topology_id == topology_id)
|
||||
.order_by(desc(TopologyStatusEvent.at))
|
||||
.limit(limit)
|
||||
)
|
||||
return [r.model_dump(mode="json") for r in result.scalars().all()]
|
||||
|
||||
# ---------------- topology_mutations (live reconciler queue) ----------------
|
||||
|
||||
async def enqueue_topology_mutation(
|
||||
self,
|
||||
topology_id: str,
|
||||
op: str,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Append a pending mutation row and bump the topology version.
|
||||
|
||||
Intended for use while the topology is ``active|degraded``; the
|
||||
reconciler picks these rows up on its next tick.
|
||||
"""
|
||||
async with self._session() as session:
|
||||
await self._check_and_bump_version(
|
||||
session, topology_id, expected_version
|
||||
)
|
||||
row = TopologyMutation(
|
||||
topology_id=topology_id,
|
||||
op=op,
|
||||
payload=orjson.dumps(payload).decode(),
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return row.id
|
||||
|
||||
async def claim_next_mutation(
|
||||
self, topology_id: str
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""Atomically claim the oldest pending mutation for ``topology_id``.
|
||||
|
||||
Correctness-critical: this is ONE SQL statement. Splitting it
|
||||
into SELECT-then-UPDATE would let two racing watch-loops both
|
||||
see the same ``pending`` row and both transition it to
|
||||
``applying`` — double-executing the op. With the single
|
||||
``UPDATE ... WHERE id = (SELECT ... LIMIT 1) AND state='pending'``
|
||||
pattern the loser's UPDATE matches zero rows and returns
|
||||
``None`` — that is the expected, non-error outcome under
|
||||
contention.
|
||||
"""
|
||||
async with self._session() as session:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
# Single-statement atomic claim. The inner SELECT picks the
|
||||
# oldest pending row; the outer UPDATE re-checks state so a
|
||||
# second racer that also saw that id finds state='applying'
|
||||
# and matches zero rows.
|
||||
# MySQL forbids referencing the UPDATE target inside a
|
||||
# subquery (ERROR 1093). Wrapping the inner SELECT in a
|
||||
# derived table forces materialisation and sidesteps the
|
||||
# rule. SQLite accepts both forms, so this stays portable.
|
||||
sql = text(
|
||||
"""
|
||||
UPDATE topology_mutations
|
||||
SET state = 'applying'
|
||||
WHERE id = (
|
||||
SELECT id FROM (
|
||||
SELECT id FROM topology_mutations
|
||||
WHERE topology_id = :t AND state = 'pending'
|
||||
ORDER BY requested_at ASC
|
||||
LIMIT 1
|
||||
) AS _next
|
||||
)
|
||||
AND state = 'pending'
|
||||
"""
|
||||
)
|
||||
result = await session.execute(sql, {"t": topology_id})
|
||||
if result.rowcount == 0:
|
||||
await session.commit()
|
||||
return None
|
||||
# Re-read the row we just claimed. The post-UPDATE SELECT is
|
||||
# safe: no racer can now transition an ``applying`` row back
|
||||
# to ``pending``.
|
||||
sel = await session.execute(
|
||||
select(TopologyMutation)
|
||||
.where(TopologyMutation.topology_id == topology_id)
|
||||
.where(TopologyMutation.state == "applying")
|
||||
.order_by(asc(TopologyMutation.requested_at))
|
||||
.limit(1)
|
||||
)
|
||||
row = sel.scalar_one_or_none()
|
||||
await session.commit()
|
||||
_ = now
|
||||
if row is None:
|
||||
return None
|
||||
return row.model_dump(mode="json")
|
||||
|
||||
async def mark_mutation_applied(self, mutation_id: str) -> None:
|
||||
async with self._session() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE topology_mutations "
|
||||
"SET state = 'applied', applied_at = :at "
|
||||
"WHERE id = :i"
|
||||
),
|
||||
{
|
||||
"at": datetime.now(timezone.utc).isoformat(),
|
||||
"i": mutation_id,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def mark_mutation_failed(
|
||||
self, mutation_id: str, reason: str
|
||||
) -> None:
|
||||
async with self._session() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE topology_mutations "
|
||||
"SET state = 'failed', applied_at = :at, reason = :r "
|
||||
"WHERE id = :i"
|
||||
),
|
||||
{
|
||||
"at": datetime.now(timezone.utc).isoformat(),
|
||||
"r": reason,
|
||||
"i": mutation_id,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def list_topology_mutations(
|
||||
self,
|
||||
topology_id: str,
|
||||
state: Optional[str] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
stmt = (
|
||||
select(TopologyMutation)
|
||||
.where(TopologyMutation.topology_id == topology_id)
|
||||
.order_by(desc(TopologyMutation.requested_at))
|
||||
)
|
||||
if state is not None:
|
||||
stmt = stmt.where(TopologyMutation.state == state)
|
||||
result = await session.execute(stmt)
|
||||
return [r.model_dump(mode="json") for r in result.scalars().all()]
|
||||
|
||||
async def has_pending_topology_mutation(self) -> bool:
|
||||
"""Cheap watch-loop guard: any pending mutation on a live topology?
|
||||
|
||||
Uses the ``ix_topology_mutations_state_topology`` composite index
|
||||
to keep the join cheap at scale. Returns False as soon as the
|
||||
reconciler path should be skipped.
|
||||
"""
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
text(
|
||||
"SELECT 1 FROM topology_mutations "
|
||||
"WHERE state = 'pending' "
|
||||
"AND topology_id IN ("
|
||||
" SELECT id FROM topologies "
|
||||
" WHERE status IN ('active', 'degraded')"
|
||||
") LIMIT 1"
|
||||
)
|
||||
)
|
||||
return result.first() is not None
|
||||
|
||||
async def list_live_topology_ids(self) -> list[str]:
|
||||
"""Return ids of topologies currently in ``active|degraded``."""
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Topology.id).where(
|
||||
Topology.status.in_(["active", "degraded"])
|
||||
)
|
||||
)
|
||||
return [r for r in result.scalars().all()]
|
||||
|
||||
async def list_running_topology_deckies(self) -> list[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(TopologyDecky).where(TopologyDecky.state == "running")
|
||||
)
|
||||
return [
|
||||
_deserialize_json_fields(
|
||||
r.model_dump(mode="json"), ("services", "decky_config")
|
||||
)
|
||||
for r in result.scalars().all()
|
||||
]
|
||||
|
||||
|
||||
694
decnet/web/db/sqlmodel_repo/topology.py
Normal file
694
decnet/web/db/sqlmodel_repo/topology.py
Normal file
@@ -0,0 +1,694 @@
|
||||
"""MazeNET topology: topologies, LANs, deckies, edges, status events,
|
||||
and the live mutation queue.
|
||||
|
||||
This is the largest domain in the repo (~600 lines of methods).
|
||||
Sections below correspond to the MazeNET dashboard panels.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import asc, desc, func, select, text, update
|
||||
|
||||
from decnet.web.db.models import (
|
||||
LAN,
|
||||
Topology,
|
||||
TopologyDecky,
|
||||
TopologyEdge,
|
||||
TopologyMutation,
|
||||
TopologyStatusEvent,
|
||||
)
|
||||
from decnet.web.db.sqlmodel_repo._helpers import (
|
||||
_deserialize_json_fields,
|
||||
_serialize_json_fields,
|
||||
)
|
||||
|
||||
|
||||
class TopologyMixin:
|
||||
"""Mixin: composed onto ``SQLModelRepository``."""
|
||||
|
||||
# ─── topologies ────────────────────────────────────────────────────────
|
||||
|
||||
async def create_topology(self, data: dict[str, Any]) -> str:
|
||||
payload = _serialize_json_fields(data, ("config_snapshot",))
|
||||
async with self._session() as session:
|
||||
row = Topology(**payload)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return row.id
|
||||
|
||||
async def get_topology(self, topology_id: str) -> Optional[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.id == topology_id)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
d = row.model_dump(mode="json")
|
||||
return _deserialize_json_fields(d, ("config_snapshot",))
|
||||
|
||||
async def list_topologies(
|
||||
self,
|
||||
status: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
statement = select(Topology).order_by(desc(Topology.created_at))
|
||||
if status:
|
||||
statement = statement.where(Topology.status == status)
|
||||
if offset is not None:
|
||||
statement = statement.offset(offset)
|
||||
if limit is not None:
|
||||
statement = statement.limit(limit)
|
||||
async with self._session() as session:
|
||||
result = await session.execute(statement)
|
||||
return [
|
||||
_deserialize_json_fields(
|
||||
r.model_dump(mode="json"), ("config_snapshot",)
|
||||
)
|
||||
for r in result.scalars().all()
|
||||
]
|
||||
|
||||
async def count_topologies(self, status: Optional[str] = None) -> int:
|
||||
statement = select(func.count(Topology.id))
|
||||
if status:
|
||||
statement = statement.where(Topology.status == status)
|
||||
async with self._session() as session:
|
||||
result = await session.execute(statement)
|
||||
return int(result.scalar_one() or 0)
|
||||
|
||||
async def update_topology_status(
|
||||
self,
|
||||
topology_id: str,
|
||||
new_status: str,
|
||||
reason: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Update topology.status and append a TopologyStatusEvent atomically.
|
||||
|
||||
Transition legality is enforced in ``decnet.topology.status``; this
|
||||
method trusts the caller.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.id == topology_id)
|
||||
)
|
||||
topo = result.scalar_one_or_none()
|
||||
if topo is None:
|
||||
return
|
||||
from_status = topo.status
|
||||
topo.status = new_status
|
||||
topo.status_changed_at = now
|
||||
session.add(topo)
|
||||
session.add(
|
||||
TopologyStatusEvent(
|
||||
topology_id=topology_id,
|
||||
from_status=from_status,
|
||||
to_status=new_status,
|
||||
at=now,
|
||||
reason=reason,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def set_topology_resync(self, topology_id: str, value: bool) -> None:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.id == topology_id)
|
||||
)
|
||||
topo = result.scalar_one_or_none()
|
||||
if topo is None:
|
||||
return
|
||||
topo.needs_resync = bool(value)
|
||||
session.add(topo)
|
||||
await session.commit()
|
||||
|
||||
async def set_topology_email_personas(
|
||||
self, topology_id: str, personas_json: str,
|
||||
) -> bool:
|
||||
"""Replace ``Topology.email_personas`` with the supplied JSON.
|
||||
|
||||
The string is stored as-is; validation/parsing is the caller's
|
||||
job (and is repeated by the emailgen scheduler each tick anyway).
|
||||
Returns True if a row was updated.
|
||||
"""
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.id == topology_id)
|
||||
)
|
||||
topo = result.scalar_one_or_none()
|
||||
if topo is None:
|
||||
return False
|
||||
topo.email_personas = personas_json
|
||||
session.add(topo)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def list_topologies_needing_resync(self) -> list[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.needs_resync == True) # noqa: E712
|
||||
)
|
||||
return [
|
||||
_deserialize_json_fields(
|
||||
r.model_dump(mode="json"), ("config_snapshot",)
|
||||
)
|
||||
for r in result.scalars().all()
|
||||
]
|
||||
|
||||
async def delete_topology_cascade(self, topology_id: str) -> bool:
|
||||
"""Delete topology and all children. No portable ON DELETE CASCADE."""
|
||||
async with self._session() as session:
|
||||
params = {"t": topology_id}
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_status_events WHERE topology_id = :t"),
|
||||
params,
|
||||
)
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_edges WHERE topology_id = :t"),
|
||||
params,
|
||||
)
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_deckies WHERE topology_id = :t"),
|
||||
params,
|
||||
)
|
||||
await session.execute(
|
||||
text("DELETE FROM lans WHERE topology_id = :t"),
|
||||
params,
|
||||
)
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_mutations WHERE topology_id = :t"),
|
||||
params,
|
||||
)
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.id == topology_id)
|
||||
)
|
||||
topo = result.scalar_one_or_none()
|
||||
if not topo:
|
||||
await session.commit()
|
||||
return False
|
||||
await session.delete(topo)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
# ─── concurrency / pending-state guards ───────────────────────────────
|
||||
|
||||
async def _assert_pending(self, session, topology_id: str) -> None:
|
||||
"""Pre-deploy edits are pending-only. Raises TopologyNotEditable."""
|
||||
from decnet.topology.status import TopologyNotEditable, TopologyStatus
|
||||
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.id == topology_id)
|
||||
)
|
||||
topo = result.scalar_one_or_none()
|
||||
if topo is None:
|
||||
raise ValueError(f"topology {topology_id!r} not found")
|
||||
if topo.status != TopologyStatus.PENDING:
|
||||
raise TopologyNotEditable(
|
||||
status=topo.status,
|
||||
reason="free-form edits are pending-only; use the "
|
||||
"mutator (topology_mutations) after deploy",
|
||||
)
|
||||
|
||||
async def _check_and_bump_version(
|
||||
self,
|
||||
session,
|
||||
topology_id: str,
|
||||
expected_version: Optional[int],
|
||||
) -> None:
|
||||
"""Optimistic-concurrency guard used by child-row mutators.
|
||||
|
||||
If ``expected_version`` is None, no check happens (backward-compat
|
||||
for internal callers that don't need concurrency protection).
|
||||
|
||||
If supplied, loads the Topology row in the same session,
|
||||
compares ``version == expected_version``, raises VersionConflict
|
||||
on mismatch, otherwise bumps ``version += 1``. The caller must
|
||||
commit the enclosing session.
|
||||
"""
|
||||
from decnet.topology.status import VersionConflict
|
||||
|
||||
if expected_version is None:
|
||||
return
|
||||
result = await session.execute(
|
||||
select(Topology).where(Topology.id == topology_id)
|
||||
)
|
||||
topo = result.scalar_one_or_none()
|
||||
if topo is None:
|
||||
raise ValueError(f"topology {topology_id!r} not found")
|
||||
if topo.version != expected_version:
|
||||
raise VersionConflict(
|
||||
current=topo.version, expected=expected_version
|
||||
)
|
||||
topo.version = topo.version + 1
|
||||
session.add(topo)
|
||||
|
||||
# ─── LANs ──────────────────────────────────────────────────────────────
|
||||
|
||||
async def add_lan(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
) -> str:
|
||||
async with self._session() as session:
|
||||
await self._check_and_bump_version(
|
||||
session, data["topology_id"], expected_version
|
||||
)
|
||||
row = LAN(**data)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return row.id
|
||||
|
||||
async def update_lan(
|
||||
self,
|
||||
lan_id: str,
|
||||
fields: dict[str, Any],
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
enforce_pending: bool = False,
|
||||
) -> None:
|
||||
if not fields:
|
||||
return
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(LAN).where(LAN.id == lan_id)
|
||||
)
|
||||
lan = result.scalar_one_or_none()
|
||||
if lan is None:
|
||||
raise ValueError(f"lan {lan_id!r} not found")
|
||||
if enforce_pending:
|
||||
await self._assert_pending(session, lan.topology_id)
|
||||
if expected_version is not None:
|
||||
await self._check_and_bump_version(
|
||||
session, lan.topology_id, expected_version
|
||||
)
|
||||
await session.execute(
|
||||
update(LAN).where(LAN.id == lan_id).values(**fields)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def delete_lan(
|
||||
self,
|
||||
lan_id: str,
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Cascade-delete a LAN from a pending topology.
|
||||
|
||||
Rejects if any decky declares this LAN as its home (i.e. has a
|
||||
non-bridge edge to it — the only LAN that decky lives in). The
|
||||
caller must delete or reassign the home-deckies first.
|
||||
"""
|
||||
from decnet.topology.status import TopologyNotEditable # noqa: F401
|
||||
|
||||
async with self._session() as session:
|
||||
result = await session.execute(select(LAN).where(LAN.id == lan_id))
|
||||
lan = result.scalar_one_or_none()
|
||||
if lan is None:
|
||||
return
|
||||
await self._assert_pending(session, lan.topology_id)
|
||||
|
||||
# Home-decky check: any decky whose only edge lands here?
|
||||
edges_result = await session.execute(
|
||||
select(TopologyEdge).where(TopologyEdge.lan_id == lan_id)
|
||||
)
|
||||
edges_here = edges_result.scalars().all()
|
||||
decky_uuids_on_this_lan = {e.decky_uuid for e in edges_here}
|
||||
for decky_uuid in decky_uuids_on_this_lan:
|
||||
other = await session.execute(
|
||||
select(TopologyEdge).where(
|
||||
TopologyEdge.decky_uuid == decky_uuid,
|
||||
TopologyEdge.lan_id != lan_id,
|
||||
)
|
||||
)
|
||||
if other.scalars().first() is None:
|
||||
raise ValueError(
|
||||
f"cannot delete LAN {lan.name!r}: decky "
|
||||
f"{decky_uuid} has no other LAN (would be orphaned)"
|
||||
)
|
||||
|
||||
if expected_version is not None:
|
||||
await self._check_and_bump_version(
|
||||
session, lan.topology_id, expected_version
|
||||
)
|
||||
# Cascade edges → LAN.
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_edges WHERE lan_id = :l"),
|
||||
{"l": lan_id},
|
||||
)
|
||||
await session.execute(text("DELETE FROM lans WHERE id = :l"), {"l": lan_id})
|
||||
await session.commit()
|
||||
|
||||
async def list_lans_for_topology(
|
||||
self, topology_id: str
|
||||
) -> list[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(LAN).where(LAN.topology_id == topology_id).order_by(asc(LAN.name))
|
||||
)
|
||||
return [r.model_dump(mode="json") for r in result.scalars().all()]
|
||||
|
||||
# ─── topology deckies ─────────────────────────────────────────────────
|
||||
|
||||
async def add_topology_decky(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
) -> str:
|
||||
payload = _serialize_json_fields(data, ("services", "decky_config"))
|
||||
async with self._session() as session:
|
||||
await self._check_and_bump_version(
|
||||
session, data["topology_id"], expected_version
|
||||
)
|
||||
row = TopologyDecky(**payload)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return row.uuid
|
||||
|
||||
async def update_topology_decky(
|
||||
self,
|
||||
decky_uuid: str,
|
||||
fields: dict[str, Any],
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
enforce_pending: bool = False,
|
||||
) -> None:
|
||||
if not fields:
|
||||
return
|
||||
payload = _serialize_json_fields(fields, ("services", "decky_config"))
|
||||
payload.setdefault("updated_at", datetime.now(timezone.utc))
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(TopologyDecky).where(TopologyDecky.uuid == decky_uuid)
|
||||
)
|
||||
d = result.scalar_one_or_none()
|
||||
if d is None:
|
||||
raise ValueError(f"decky {decky_uuid!r} not found")
|
||||
if enforce_pending:
|
||||
await self._assert_pending(session, d.topology_id)
|
||||
if expected_version is not None:
|
||||
await self._check_and_bump_version(
|
||||
session, d.topology_id, expected_version
|
||||
)
|
||||
await session.execute(
|
||||
update(TopologyDecky)
|
||||
.where(TopologyDecky.uuid == decky_uuid)
|
||||
.values(**payload)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def delete_topology_decky(
|
||||
self,
|
||||
decky_uuid: str,
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Cascade-delete a decky + all its edges from a pending topology."""
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(TopologyDecky).where(TopologyDecky.uuid == decky_uuid)
|
||||
)
|
||||
d = result.scalar_one_or_none()
|
||||
if d is None:
|
||||
return
|
||||
await self._assert_pending(session, d.topology_id)
|
||||
if expected_version is not None:
|
||||
await self._check_and_bump_version(
|
||||
session, d.topology_id, expected_version
|
||||
)
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_edges WHERE decky_uuid = :u"),
|
||||
{"u": decky_uuid},
|
||||
)
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_deckies WHERE uuid = :u"),
|
||||
{"u": decky_uuid},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def list_topology_deckies(
|
||||
self, topology_id: str
|
||||
) -> list[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(TopologyDecky)
|
||||
.where(TopologyDecky.topology_id == topology_id)
|
||||
.order_by(asc(TopologyDecky.name))
|
||||
)
|
||||
return [
|
||||
_deserialize_json_fields(
|
||||
r.model_dump(mode="json"), ("services", "decky_config")
|
||||
)
|
||||
for r in result.scalars().all()
|
||||
]
|
||||
|
||||
# ─── topology edges ───────────────────────────────────────────────────
|
||||
|
||||
async def add_topology_edge(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
) -> str:
|
||||
async with self._session() as session:
|
||||
await self._check_and_bump_version(
|
||||
session, data["topology_id"], expected_version
|
||||
)
|
||||
row = TopologyEdge(**data)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return row.id
|
||||
|
||||
async def delete_topology_edge(
|
||||
self,
|
||||
edge_id: str,
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
) -> None:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(TopologyEdge).where(TopologyEdge.id == edge_id)
|
||||
)
|
||||
edge = result.scalar_one_or_none()
|
||||
if edge is None:
|
||||
return
|
||||
await self._assert_pending(session, edge.topology_id)
|
||||
if expected_version is not None:
|
||||
await self._check_and_bump_version(
|
||||
session, edge.topology_id, expected_version
|
||||
)
|
||||
await session.execute(
|
||||
text("DELETE FROM topology_edges WHERE id = :e"),
|
||||
{"e": edge_id},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def list_topology_edges(
|
||||
self, topology_id: str
|
||||
) -> list[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(TopologyEdge).where(TopologyEdge.topology_id == topology_id)
|
||||
)
|
||||
return [r.model_dump(mode="json") for r in result.scalars().all()]
|
||||
|
||||
async def list_topology_status_events(
|
||||
self, topology_id: str, limit: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(TopologyStatusEvent)
|
||||
.where(TopologyStatusEvent.topology_id == topology_id)
|
||||
.order_by(desc(TopologyStatusEvent.at))
|
||||
.limit(limit)
|
||||
)
|
||||
return [r.model_dump(mode="json") for r in result.scalars().all()]
|
||||
|
||||
# ─── topology mutations (live reconciler queue) ──────────────────────
|
||||
|
||||
async def enqueue_topology_mutation(
|
||||
self,
|
||||
topology_id: str,
|
||||
op: str,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
expected_version: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Append a pending mutation row and bump the topology version.
|
||||
|
||||
Intended for use while the topology is ``active|degraded``; the
|
||||
reconciler picks these rows up on its next tick.
|
||||
"""
|
||||
async with self._session() as session:
|
||||
await self._check_and_bump_version(
|
||||
session, topology_id, expected_version
|
||||
)
|
||||
row = TopologyMutation(
|
||||
topology_id=topology_id,
|
||||
op=op,
|
||||
payload=orjson.dumps(payload).decode(),
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return row.id
|
||||
|
||||
async def claim_next_mutation(
|
||||
self, topology_id: str
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""Atomically claim the oldest pending mutation for ``topology_id``.
|
||||
|
||||
Correctness-critical: this is ONE SQL statement. Splitting it
|
||||
into SELECT-then-UPDATE would let two racing watch-loops both
|
||||
see the same ``pending`` row and both transition it to
|
||||
``applying`` — double-executing the op. With the single
|
||||
``UPDATE ... WHERE id = (SELECT ... LIMIT 1) AND state='pending'``
|
||||
pattern the loser's UPDATE matches zero rows and returns
|
||||
``None`` — that is the expected, non-error outcome under
|
||||
contention.
|
||||
"""
|
||||
async with self._session() as session:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
# Single-statement atomic claim. The inner SELECT picks the
|
||||
# oldest pending row; the outer UPDATE re-checks state so a
|
||||
# second racer that also saw that id finds state='applying'
|
||||
# and matches zero rows.
|
||||
# MySQL forbids referencing the UPDATE target inside a
|
||||
# subquery (ERROR 1093). Wrapping the inner SELECT in a
|
||||
# derived table forces materialisation and sidesteps the
|
||||
# rule. SQLite accepts both forms, so this stays portable.
|
||||
sql = text(
|
||||
"""
|
||||
UPDATE topology_mutations
|
||||
SET state = 'applying'
|
||||
WHERE id = (
|
||||
SELECT id FROM (
|
||||
SELECT id FROM topology_mutations
|
||||
WHERE topology_id = :t AND state = 'pending'
|
||||
ORDER BY requested_at ASC
|
||||
LIMIT 1
|
||||
) AS _next
|
||||
)
|
||||
AND state = 'pending'
|
||||
"""
|
||||
)
|
||||
result = await session.execute(sql, {"t": topology_id})
|
||||
if result.rowcount == 0:
|
||||
await session.commit()
|
||||
return None
|
||||
# Re-read the row we just claimed. The post-UPDATE SELECT is
|
||||
# safe: no racer can now transition an ``applying`` row back
|
||||
# to ``pending``.
|
||||
sel = await session.execute(
|
||||
select(TopologyMutation)
|
||||
.where(TopologyMutation.topology_id == topology_id)
|
||||
.where(TopologyMutation.state == "applying")
|
||||
.order_by(asc(TopologyMutation.requested_at))
|
||||
.limit(1)
|
||||
)
|
||||
row = sel.scalar_one_or_none()
|
||||
await session.commit()
|
||||
_ = now
|
||||
if row is None:
|
||||
return None
|
||||
return row.model_dump(mode="json")
|
||||
|
||||
async def mark_mutation_applied(self, mutation_id: str) -> None:
|
||||
async with self._session() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE topology_mutations "
|
||||
"SET state = 'applied', applied_at = :at "
|
||||
"WHERE id = :i"
|
||||
),
|
||||
{
|
||||
"at": datetime.now(timezone.utc).isoformat(),
|
||||
"i": mutation_id,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def mark_mutation_failed(
|
||||
self, mutation_id: str, reason: str
|
||||
) -> None:
|
||||
async with self._session() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE topology_mutations "
|
||||
"SET state = 'failed', applied_at = :at, reason = :r "
|
||||
"WHERE id = :i"
|
||||
),
|
||||
{
|
||||
"at": datetime.now(timezone.utc).isoformat(),
|
||||
"r": reason,
|
||||
"i": mutation_id,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def list_topology_mutations(
|
||||
self,
|
||||
topology_id: str,
|
||||
state: Optional[str] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
stmt = (
|
||||
select(TopologyMutation)
|
||||
.where(TopologyMutation.topology_id == topology_id)
|
||||
.order_by(desc(TopologyMutation.requested_at))
|
||||
)
|
||||
if state is not None:
|
||||
stmt = stmt.where(TopologyMutation.state == state)
|
||||
result = await session.execute(stmt)
|
||||
return [r.model_dump(mode="json") for r in result.scalars().all()]
|
||||
|
||||
async def has_pending_topology_mutation(self) -> bool:
|
||||
"""Cheap watch-loop guard: any pending mutation on a live topology?
|
||||
|
||||
Uses the ``ix_topology_mutations_state_topology`` composite index
|
||||
to keep the join cheap at scale. Returns False as soon as the
|
||||
reconciler path should be skipped.
|
||||
"""
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
text(
|
||||
"SELECT 1 FROM topology_mutations "
|
||||
"WHERE state = 'pending' "
|
||||
"AND topology_id IN ("
|
||||
" SELECT id FROM topologies "
|
||||
" WHERE status IN ('active', 'degraded')"
|
||||
") LIMIT 1"
|
||||
)
|
||||
)
|
||||
return result.first() is not None
|
||||
|
||||
async def list_live_topology_ids(self) -> list[str]:
|
||||
"""Return ids of topologies currently in ``active|degraded``."""
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Topology.id).where(
|
||||
Topology.status.in_(["active", "degraded"])
|
||||
)
|
||||
)
|
||||
return [r for r in result.scalars().all()]
|
||||
|
||||
async def list_running_topology_deckies(self) -> list[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(TopologyDecky).where(TopologyDecky.state == "running")
|
||||
)
|
||||
return [
|
||||
_deserialize_json_fields(
|
||||
r.model_dump(mode="json"), ("services", "decky_config")
|
||||
)
|
||||
for r in result.scalars().all()
|
||||
]
|
||||
Reference in New Issue
Block a user