diff --git a/decnet/topology/status.py b/decnet/topology/status.py
index 2e1b8c76..cc5c818c 100644
--- a/decnet/topology/status.py
+++ b/decnet/topology/status.py
@@ -53,6 +53,22 @@ class TopologyStatusError(ValueError):
     """Raised when an illegal topology status transition is attempted."""
 
 
+class VersionConflict(RuntimeError):
+    """Raised when a topology write is supplied a stale ``expected_version``.
+
+    Optimistic concurrency guard: the caller passed the version it last
+    observed, and the topology has since been mutated by someone else.
+    The caller should re-read and retry.
+    """
+
+    def __init__(self, *, current: int, expected: int) -> None:
+        self.current = current
+        self.expected = expected
+        super().__init__(
+            f"topology version conflict: expected {expected}, current is {current}"
+        )
+
+
 def assert_transition(current: str, new: str) -> None:
     """Validate ``current → new`` or raise :class:`TopologyStatusError`."""
     if current not in TopologyStatus.ALL:
diff --git a/decnet/web/db/models.py b/decnet/web/db/models.py
index c94f4392..3a44ea08 100644
--- a/decnet/web/db/models.py
+++ b/decnet/web/db/models.py
@@ -216,6 +216,10 @@ class Topology(SQLModel, table=True):
     created_at: datetime = Field(
         default_factory=lambda: datetime.now(timezone.utc), index=True
     )
+    # Optimistic-concurrency token. Bumped by repo methods that mutate
+    # the topology or any child row when an expected_version is supplied.
+    # Callers pass their last-seen version; mismatch raises VersionConflict.
+    version: int = Field(default=1, nullable=False)
 
 
 class LAN(SQLModel, table=True):
diff --git a/decnet/web/db/sqlmodel_repo.py b/decnet/web/db/sqlmodel_repo.py
index 910a4d8a..306d3d19 100644
--- a/decnet/web/db/sqlmodel_repo.py
+++ b/decnet/web/db/sqlmodel_repo.py
@@ -1027,18 +1027,92 @@ class SQLModelRepository(BaseRepository):
             await session.commit()
             return True
 
-    async def add_lan(self, data: dict[str, Any]) -> str:
+    async def _check_and_bump_version(
+        self,
+        session,
+        topology_id: str,
+        expected_version: Optional[int],
+    ) -> None:
+        """Optimistic-concurrency guard used by child-row mutators.
+
+        If ``expected_version`` is None, no check happens (backward-compat
+        for internal callers that don't need concurrency protection).
+
+        If supplied, the comparison and the bump are issued as a single
+        ``UPDATE ... WHERE id = :id AND version = :expected`` so that two
+        writers holding the same token cannot both pass the check (a
+        separate SELECT followed by a write would allow exactly that).
+        When no row is updated, the topology is re-read to tell a missing
+        row from a stale token. The caller must commit the enclosing
+        session.
+
+        Raises:
+            ValueError: no topology with ``topology_id`` exists.
+            VersionConflict: the stored version differs from the token.
+        """
+        from decnet.topology.status import VersionConflict
+
+        if expected_version is None:
+            return
+        bumped = await session.execute(
+            update(Topology)
+            .where(
+                Topology.id == topology_id,
+                Topology.version == expected_version,
+            )
+            .values(version=Topology.version + 1)
+            # The bump is computed by the database; skip in-Python
+            # synchronization of session objects.
+            .execution_options(synchronize_session=False)
+        )
+        if bumped.rowcount == 1:
+            return
+        # Nothing was updated: the row is missing or the token is stale.
+        result = await session.execute(
+            select(Topology.version).where(Topology.id == topology_id)
+        )
+        current = result.scalar_one_or_none()
+        if current is None:
+            raise ValueError(f"topology {topology_id!r} not found")
+        raise VersionConflict(current=current, expected=expected_version)
+
+    async def add_lan(
+        self,
+        data: dict[str, Any],
+        *,
+        expected_version: Optional[int] = None,
+    ) -> str:
         async with self._session() as session:
+            await self._check_and_bump_version(
+                session, data["topology_id"], expected_version
+            )
             row = LAN(**data)
             session.add(row)
             await session.commit()
             await session.refresh(row)
             return row.id
 
-    async def update_lan(self, lan_id: str, fields: dict[str, Any]) -> None:
+    async def update_lan(
+        self,
+        lan_id: str,
+        fields: dict[str, Any],
+        *,
+        expected_version: Optional[int] = None,
+    ) -> None:
         if not fields:
             return
         async with self._session() as session:
+            if expected_version is not None:
+                # Need the LAN's topology_id to check version.
+                result = await session.execute(
+                    select(LAN).where(LAN.id == lan_id)
+                )
+                lan = result.scalar_one_or_none()
+                if lan is None:
+                    raise ValueError(f"lan {lan_id!r} not found")
+                await self._check_and_bump_version(
+                    session, lan.topology_id, expected_version
+                )
             await session.execute(
                 update(LAN).where(LAN.id == lan_id).values(**fields)
             )
@@ -1053,9 +1127,17 @@ class SQLModelRepository(BaseRepository):
             )
             return [r.model_dump(mode="json") for r in result.scalars().all()]
 
-    async def add_topology_decky(self, data: dict[str, Any]) -> str:
+    async def add_topology_decky(
+        self,
+        data: dict[str, Any],
+        *,
+        expected_version: Optional[int] = None,
+    ) -> str:
         payload = self._serialize_json_fields(data, ("services", "decky_config"))
         async with self._session() as session:
+            await self._check_and_bump_version(
+                session, data["topology_id"], expected_version
+            )
             row = TopologyDecky(**payload)
             session.add(row)
             await session.commit()
@@ -1063,13 +1145,27 @@ class SQLModelRepository(BaseRepository):
             return row.uuid
 
     async def update_topology_decky(
-        self, decky_uuid: str, fields: dict[str, Any]
+        self,
+        decky_uuid: str,
+        fields: dict[str, Any],
+        *,
+        expected_version: Optional[int] = None,
     ) -> None:
         if not fields:
             return
         payload = self._serialize_json_fields(fields, ("services", "decky_config"))
         payload.setdefault("updated_at", datetime.now(timezone.utc))
         async with self._session() as session:
+            if expected_version is not None:
+                result = await session.execute(
+                    select(TopologyDecky).where(TopologyDecky.uuid == decky_uuid)
+                )
+                d = result.scalar_one_or_none()
+                if d is None:
+                    raise ValueError(f"decky {decky_uuid!r} not found")
+                await self._check_and_bump_version(
+                    session, d.topology_id, expected_version
+                )
             await session.execute(
                 update(TopologyDecky)
                 .where(TopologyDecky.uuid == decky_uuid)
@@ -1093,8 +1189,16 @@ class SQLModelRepository(BaseRepository):
                 for r in result.scalars().all()
             ]
 
-    async def add_topology_edge(self, data: dict[str, Any]) -> str:
+    async def add_topology_edge(
+        self,
+        data: dict[str, Any],
+        *,
+        expected_version: Optional[int] = None,
+    ) -> str:
         async with self._session() as session:
+            await self._check_and_bump_version(
+                session, data["topology_id"], expected_version
+            )
             row = TopologyEdge(**data)
             session.add(row)
             await session.commit()
diff --git a/tests/topology/test_concurrency.py b/tests/topology/test_concurrency.py
new file mode 100644
index 00000000..af6bf77a
--- /dev/null
+++ b/tests/topology/test_concurrency.py
@@ -0,0 +1,118 @@
+"""Optimistic-concurrency (version) checks on topology child mutations."""
+from __future__ import annotations
+
+import pytest
+
+from decnet.topology.config import TopologyConfig
+from decnet.topology.generator import generate
+from decnet.topology.persistence import persist
+from decnet.topology.status import VersionConflict
+from decnet.web.db.factory import get_repository
+
+
+def _cfg(**kw) -> TopologyConfig:
+    base = dict(
+        name="ver",
+        depth=1,
+        branching_factor=1,
+        deckies_per_lan_min=1,
+        deckies_per_lan_max=1,
+        cross_edge_probability=0.0,
+        randomize_services=False,
+        services_explicit=["ssh"],
+        seed=2,
+    )
+    base.update(kw)
+    return TopologyConfig(**base)
+
+
+@pytest.fixture
+async def repo(tmp_path):
+    r = get_repository(db_path=str(tmp_path / "ver.db"))
+    await r.initialize()
+    return r
+
+
+@pytest.mark.anyio
+async def test_version_starts_at_one_after_persist(repo):
+    plan = generate(_cfg())
+    # persist() adds LANs/deckies/edges without expected_version, so
+    # the version token stays at 1.
+    tid = await persist(repo, plan)
+    topo = await repo.get_topology(tid)
+    assert topo["version"] == 1
+
+
+@pytest.mark.anyio
+async def test_happy_path_two_sequential_writes(repo):
+    plan = generate(_cfg())
+    tid = await persist(repo, plan)
+
+    await repo.add_lan(
+        {"topology_id": tid, "name": "LAN-A", "subnet": "10.9.0.0/24", "is_dmz": False},
+        expected_version=1,
+    )
+    assert (await repo.get_topology(tid))["version"] == 2
+
+    await repo.add_lan(
+        {"topology_id": tid, "name": "LAN-B", "subnet": "10.9.1.0/24", "is_dmz": False},
+        expected_version=2,
+    )
+    assert (await repo.get_topology(tid))["version"] == 3
+
+
+@pytest.mark.anyio
+async def test_stale_expected_version_raises(repo):
+    plan = generate(_cfg())
+    tid = await persist(repo, plan)
+
+    await repo.add_lan(
+        {"topology_id": tid, "name": "LAN-A", "subnet": "10.8.0.0/24", "is_dmz": False},
+        expected_version=1,
+    )
+    with pytest.raises(VersionConflict) as ei:
+        await repo.add_lan(
+            {"topology_id": tid, "name": "LAN-B", "subnet": "10.8.1.0/24", "is_dmz": False},
+            expected_version=1,  # stale
+        )
+    assert ei.value.current == 2
+    assert ei.value.expected == 1
+
+
+@pytest.mark.anyio
+async def test_no_expected_version_skips_check(repo):
+    """Existing callers (persist) don't pass expected_version and must
+    continue to work without version bumps."""
+    plan = generate(_cfg())
+    tid = await persist(repo, plan)
+    before = (await repo.get_topology(tid))["version"]
+    await repo.add_lan(
+        {"topology_id": tid, "name": "LAN-X", "subnet": "10.7.0.0/24", "is_dmz": False}
+    )
+    after = (await repo.get_topology(tid))["version"]
+    assert before == after  # no bump when version not asserted
+
+
+@pytest.mark.anyio
+async def test_update_topology_decky_bumps_version(repo):
+    plan = generate(_cfg())
+    tid = await persist(repo, plan)
+    decky = (await repo.list_topology_deckies(tid))[0]
+    await repo.update_topology_decky(
+        decky["uuid"],
+        {"decky_config": {"name": decky["name"], "services": ["ssh"],
+                          "ips_by_lan": decky["decky_config"]["ips_by_lan"],
+                          "forwards_l3": False,
+                          "service_config": {"ssh": {"password": "x"}}}},
+        expected_version=1,
+    )
+    assert (await repo.get_topology(tid))["version"] == 2
+
+
+@pytest.mark.anyio
+async def test_update_lan_bumps_version(repo):
+    plan = generate(_cfg())
+    tid = await persist(repo, plan)
+    lan = (await repo.list_lans_for_topology(tid))[0]
+    await repo.update_lan(lan["id"], {"name": "LAN-RENAMED"}, expected_version=1)
+    assert (await repo.get_topology(tid))["version"] == 2