From 47cd200e1d9364e18d6067023b31a098157e6969 Mon Sep 17 00:00:00 2001 From: anti Date: Mon, 20 Apr 2026 16:43:49 -0400 Subject: [PATCH] feat(mazenet): repo methods for topology/LAN/decky/edge/status events Adds topology CRUD to BaseRepository (NotImplementedError defaults) and implements them in SQLModelRepository: create/get/list/delete topologies, add/update/list LANs and TopologyDeckies, add/list edges, plus an atomic update_topology_status that appends a TopologyStatusEvent in the same transaction. Cascade delete sweeps children before the topology row. Covered by tests/topology/test_repo.py (roundtrip, per-topology name uniqueness, status event log, cascade delete, status filter) and an extension to tests/test_base_repo.py for the NotImplementedError surface. --- decnet/web/db/repository.py | 64 ++++++++++ decnet/web/db/sqlmodel_repo.py | 222 +++++++++++++++++++++++++++++++++ tests/test_base_repo.py | 14 +++ tests/topology/__init__.py | 0 tests/topology/test_repo.py | 166 ++++++++++++++++++++++++ 5 files changed, 466 insertions(+) create mode 100644 tests/topology/__init__.py create mode 100644 tests/topology/test_repo.py diff --git a/decnet/web/db/repository.py b/decnet/web/db/repository.py index d0513d4..d7f1f53 100644 --- a/decnet/web/db/repository.py +++ b/decnet/web/db/repository.py @@ -234,3 +234,67 @@ class BaseRepository(ABC): async def delete_decky_shard(self, decky_name: str) -> bool: raise NotImplementedError + + # ----------------------------------------------------------- mazenet + # MazeNET topology persistence. Default no-op / NotImplementedError so + # non-default backends stay functional; SQLModelRepository provides the + # real implementation used by SQLite and MySQL. + + async def create_topology(self, data: dict[str, Any]) -> str: + raise NotImplementedError + + async def get_topology(self, topology_id: str) -> Optional[dict[str, Any]]: + raise NotImplementedError + + async def list_topologies( + self, status: Optional[str] = None + ) -> list[dict[str, Any]]: + raise NotImplementedError + + async def update_topology_status( + self, + topology_id: str, + new_status: str, + reason: Optional[str] = None, + ) -> None: + raise NotImplementedError + + async def delete_topology_cascade(self, topology_id: str) -> bool: + raise NotImplementedError + + async def add_lan(self, data: dict[str, Any]) -> str: + raise NotImplementedError + + async def update_lan(self, lan_id: str, fields: dict[str, Any]) -> None: + raise NotImplementedError + + async def list_lans_for_topology( + self, topology_id: str + ) -> list[dict[str, Any]]: + raise NotImplementedError + + async def add_topology_decky(self, data: dict[str, Any]) -> str: + raise NotImplementedError + + async def update_topology_decky( + self, decky_uuid: str, fields: dict[str, Any] + ) -> None: + raise NotImplementedError + + async def list_topology_deckies( + self, topology_id: str + ) -> list[dict[str, Any]]: + raise NotImplementedError + + async def add_topology_edge(self, data: dict[str, Any]) -> str: + raise NotImplementedError + + async def list_topology_edges( + self, topology_id: str + ) -> list[dict[str, Any]]: + raise NotImplementedError + + async def list_topology_status_events( + self, topology_id: str, limit: int = 100 + ) -> list[dict[str, Any]]: + raise NotImplementedError diff --git a/decnet/web/db/sqlmodel_repo.py b/decnet/web/db/sqlmodel_repo.py index b5f40f4..910a4d8 100644 --- a/decnet/web/db/sqlmodel_repo.py +++ b/decnet/web/db/sqlmodel_repo.py @@ -36,6 +36,11 @@ from decnet.web.db.models import ( AttackerBehavior, SwarmHost, DeckyShard, + Topology, + LAN, + TopologyDecky, + TopologyEdge, + TopologyStatusEvent, ) @@ -899,3 +904,220 @@ class SQLModelRepository(BaseRepository): ) await session.commit() return bool(result.rowcount) + + # ------------------------------------------------------------ mazenet + + @staticmethod + def _serialize_json_fields(data: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]: + out = dict(data) + for k in keys: + v = out.get(k) + if v is not None and not isinstance(v, str): + out[k] = orjson.dumps(v).decode() + return out + + @staticmethod + def _deserialize_json_fields(d: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]: + for k in keys: + v = d.get(k) + if isinstance(v, str): + try: + d[k] = json.loads(v) + except (json.JSONDecodeError, TypeError): + pass + return d + + async def create_topology(self, data: dict[str, Any]) -> str: + payload = self._serialize_json_fields(data, ("config_snapshot",)) + async with self._session() as session: + row = Topology(**payload) + session.add(row) + await session.commit() + await session.refresh(row) + return row.id + + async def get_topology(self, topology_id: str) -> Optional[dict[str, Any]]: + async with self._session() as session: + result = await session.execute( + select(Topology).where(Topology.id == topology_id) + ) + row = result.scalar_one_or_none() + if not row: + return None + d = row.model_dump(mode="json") + return self._deserialize_json_fields(d, ("config_snapshot",)) + + async def list_topologies( + self, status: Optional[str] = None + ) -> list[dict[str, Any]]: + statement = select(Topology).order_by(desc(Topology.created_at)) + if status: + statement = statement.where(Topology.status == status) + async with self._session() as session: + result = await session.execute(statement) + return [ + self._deserialize_json_fields( + r.model_dump(mode="json"), ("config_snapshot",) + ) + for r in result.scalars().all() + ] + + async def update_topology_status( + self, + topology_id: str, + new_status: str, + reason: Optional[str] = None, + ) -> None: + """Update topology.status and append a TopologyStatusEvent atomically. + + Transition legality is enforced in ``decnet.topology.status``; this + method trusts the caller. + """ + now = datetime.now(timezone.utc) + async with self._session() as session: + result = await session.execute( + select(Topology).where(Topology.id == topology_id) + ) + topo = result.scalar_one_or_none() + if topo is None: + return + from_status = topo.status + topo.status = new_status + topo.status_changed_at = now + session.add(topo) + session.add( + TopologyStatusEvent( + topology_id=topology_id, + from_status=from_status, + to_status=new_status, + at=now, + reason=reason, + ) + ) + await session.commit() + + async def delete_topology_cascade(self, topology_id: str) -> bool: + """Delete topology and all children. No portable ON DELETE CASCADE.""" + async with self._session() as session: + params = {"t": topology_id} + await session.execute( + text("DELETE FROM topology_status_events WHERE topology_id = :t"), + params, + ) + await session.execute( + text("DELETE FROM topology_edges WHERE topology_id = :t"), + params, + ) + await session.execute( + text("DELETE FROM topology_deckies WHERE topology_id = :t"), + params, + ) + await session.execute( + text("DELETE FROM lans WHERE topology_id = :t"), + params, + ) + result = await session.execute( + select(Topology).where(Topology.id == topology_id) + ) + topo = result.scalar_one_or_none() + if not topo: + await session.commit() + return False + await session.delete(topo) + await session.commit() + return True + + async def add_lan(self, data: dict[str, Any]) -> str: + async with self._session() as session: + row = LAN(**data) + session.add(row) + await session.commit() + await session.refresh(row) + return row.id + + async def update_lan(self, lan_id: str, fields: dict[str, Any]) -> None: + if not fields: + return + async with self._session() as session: + await session.execute( + update(LAN).where(LAN.id == lan_id).values(**fields) + ) + await session.commit() + + async def list_lans_for_topology( + self, topology_id: str + ) -> list[dict[str, Any]]: + async with self._session() as session: + result = await session.execute( + select(LAN).where(LAN.topology_id == topology_id).order_by(asc(LAN.name)) + ) + return [r.model_dump(mode="json") for r in result.scalars().all()] + + async def add_topology_decky(self, data: dict[str, Any]) -> str: + payload = self._serialize_json_fields(data, ("services", "decky_config")) + async with self._session() as session: + row = TopologyDecky(**payload) + session.add(row) + await session.commit() + await session.refresh(row) + return row.uuid + + async def update_topology_decky( + self, decky_uuid: str, fields: dict[str, Any] + ) -> None: + if not fields: + return + payload = self._serialize_json_fields(fields, ("services", "decky_config")) + payload.setdefault("updated_at", datetime.now(timezone.utc)) + async with self._session() as session: + await session.execute( + update(TopologyDecky) + .where(TopologyDecky.uuid == decky_uuid) + .values(**payload) + ) + await session.commit() + + async def list_topology_deckies( + self, topology_id: str + ) -> list[dict[str, Any]]: + async with self._session() as session: + result = await session.execute( + select(TopologyDecky) + .where(TopologyDecky.topology_id == topology_id) + .order_by(asc(TopologyDecky.name)) + ) + return [ + self._deserialize_json_fields( + r.model_dump(mode="json"), ("services", "decky_config") + ) + for r in result.scalars().all() + ] + + async def add_topology_edge(self, data: dict[str, Any]) -> str: + async with self._session() as session: + row = TopologyEdge(**data) + session.add(row) + await session.commit() + await session.refresh(row) + return row.id + + async def list_topology_edges( + self, topology_id: str + ) -> list[dict[str, Any]]: + async with self._session() as session: + result = await session.execute( + select(TopologyEdge).where(TopologyEdge.topology_id == topology_id) + ) + return [r.model_dump(mode="json") for r in result.scalars().all()] + + async def list_topology_status_events( + self, topology_id: str, limit: int = 100 + ) -> list[dict[str, Any]]: + async with self._session() as session: + result = await session.execute( + select(TopologyStatusEvent) + .where(TopologyStatusEvent.topology_id == topology_id) + .order_by(desc(TopologyStatusEvent.at)) + .limit(limit) + ) + return [r.model_dump(mode="json") for r in result.scalars().all()] diff --git a/tests/test_base_repo.py b/tests/test_base_repo.py index 7750f69..ac23fea 100644 --- a/tests/test_base_repo.py +++ b/tests/test_base_repo.py @@ -88,6 +88,20 @@ async def test_base_repo_coverage(): (dr.upsert_decky_shard, ({},)), (dr.list_decky_shards, ()), (dr.delete_decky_shards_for_host, ("u",)), + (dr.create_topology, ({},)), + (dr.get_topology, ("t",)), + (dr.list_topologies, ()), + (dr.update_topology_status, ("t", "active")), + (dr.delete_topology_cascade, ("t",)), + (dr.add_lan, ({},)), + (dr.update_lan, ("l", {})), + (dr.list_lans_for_topology, ("t",)), + (dr.add_topology_decky, ({},)), + (dr.update_topology_decky, ("d", {})), + (dr.list_topology_deckies, ("t",)), + (dr.add_topology_edge, ({},)), + (dr.list_topology_edges, ("t",)), + (dr.list_topology_status_events, ("t",)), ]: with pytest.raises(NotImplementedError): await coro(*args) diff --git a/tests/topology/__init__.py b/tests/topology/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/topology/test_repo.py b/tests/topology/test_repo.py new file mode 100644 index 0000000..f1fb138 --- /dev/null +++ b/tests/topology/test_repo.py @@ -0,0 +1,166 @@ +"""Direct async tests for MazeNET topology persistence. + +Exercises the repository layer without going through the HTTP stack or +the in-memory generator. The synthetic topology here is hand-built so +the test remains meaningful even if generator.py regresses. +""" +import pytest +from decnet.web.db.factory import get_repository + + +@pytest.fixture +async def repo(tmp_path): + r = get_repository(db_path=str(tmp_path / "mazenet.db")) + await r.initialize() + return r + + +@pytest.mark.anyio +async def test_topology_roundtrip(repo): + t_id = await repo.create_topology( + { + "name": "alpha", + "mode": "unihost", + "config_snapshot": {"depth": 3, "seed": 42}, + } + ) + assert t_id + t = await repo.get_topology(t_id) + assert t is not None + assert t["name"] == "alpha" + assert t["status"] == "pending" + # JSON field round-trips as a dict, not a string + assert t["config_snapshot"] == {"depth": 3, "seed": 42} + + +@pytest.mark.anyio +async def test_lan_add_update_list(repo): + t_id = await repo.create_topology( + {"name": "beta", "mode": "unihost", "config_snapshot": {}} + ) + lan_id = await repo.add_lan( + {"topology_id": t_id, "name": "DMZ", "subnet": "172.20.0.0/24", "is_dmz": True} + ) + await repo.add_lan( + {"topology_id": t_id, "name": "LAN-A", "subnet": "172.20.1.0/24"} + ) + await repo.update_lan(lan_id, {"docker_network_id": "abc123"}) + lans = await repo.list_lans_for_topology(t_id) + assert len(lans) == 2 + by_name = {lan["name"]: lan for lan in lans} + assert by_name["DMZ"]["docker_network_id"] == "abc123" + assert by_name["DMZ"]["is_dmz"] is True + assert by_name["LAN-A"]["is_dmz"] is False + + +@pytest.mark.anyio +async def test_topology_decky_json_roundtrip(repo): + t_id = await repo.create_topology( + {"name": "gamma", "mode": "unihost", "config_snapshot": {}} + ) + d_uuid = await repo.add_topology_decky( + { + "topology_id": t_id, + "name": "decky-01", + "services": ["ssh", "http"], + "decky_config": {"hostname": "bastion"}, + "ip": "172.20.0.10", + } + ) + assert d_uuid + deckies = await repo.list_topology_deckies(t_id) + assert len(deckies) == 1 + assert deckies[0]["services"] == ["ssh", "http"] + assert deckies[0]["decky_config"] == {"hostname": "bastion"} + assert deckies[0]["state"] == "pending" + + await repo.update_topology_decky(d_uuid, {"state": "running", "ip": "172.20.0.11"}) + deckies = await repo.list_topology_deckies(t_id) + assert deckies[0]["state"] == "running" + assert deckies[0]["ip"] == "172.20.0.11" + + +@pytest.mark.anyio +async def test_topology_decky_name_unique_within_topology(repo): + """Same decky name is legal across topologies, forbidden within one.""" + t1 = await repo.create_topology( + {"name": "one", "mode": "unihost", "config_snapshot": {}} + ) + t2 = await repo.create_topology( + {"name": "two", "mode": "unihost", "config_snapshot": {}} + ) + await repo.add_topology_decky( + {"topology_id": t1, "name": "decky-01", "services": []} + ) + # Same name, different topology — must succeed. + await repo.add_topology_decky( + {"topology_id": t2, "name": "decky-01", "services": []} + ) + # Same name, same topology — must fail at the DB level. + with pytest.raises(Exception): + await repo.add_topology_decky( + {"topology_id": t1, "name": "decky-01", "services": []} + ) + + +@pytest.mark.anyio +async def test_status_transition_writes_event(repo): + t_id = await repo.create_topology( + {"name": "delta", "mode": "unihost", "config_snapshot": {}} + ) + await repo.update_topology_status(t_id, "deploying", reason="kickoff") + await repo.update_topology_status(t_id, "active") + topo = await repo.get_topology(t_id) + assert topo["status"] == "active" + + events = await repo.list_topology_status_events(t_id) + assert len(events) == 2 + # Ordered desc by at — latest first + assert events[0]["to_status"] == "active" + assert events[0]["from_status"] == "deploying" + assert events[1]["to_status"] == "deploying" + assert events[1]["from_status"] == "pending" + assert events[1]["reason"] == "kickoff" + + +@pytest.mark.anyio +async def test_cascade_delete_clears_all_children(repo): + t_id = await repo.create_topology( + {"name": "eps", "mode": "unihost", "config_snapshot": {}} + ) + lan_id = await repo.add_lan( + {"topology_id": t_id, "name": "L", "subnet": "10.0.0.0/24"} + ) + d_uuid = await repo.add_topology_decky( + {"topology_id": t_id, "name": "d", "services": []} + ) + await repo.add_topology_edge( + {"topology_id": t_id, "decky_uuid": d_uuid, "lan_id": lan_id} + ) + await repo.update_topology_status(t_id, "deploying") + + assert await repo.delete_topology_cascade(t_id) is True + assert await repo.get_topology(t_id) is None + assert await repo.list_lans_for_topology(t_id) == [] + assert await repo.list_topology_deckies(t_id) == [] + assert await repo.list_topology_edges(t_id) == [] + assert await repo.list_topology_status_events(t_id) == [] + # Second delete on a missing row returns False, no raise + assert await repo.delete_topology_cascade(t_id) is False + + +@pytest.mark.anyio +async def test_list_topologies_filters_by_status(repo): + a = await repo.create_topology( + {"name": "a", "mode": "unihost", "config_snapshot": {}} + ) + b = await repo.create_topology( + {"name": "b", "mode": "unihost", "config_snapshot": {}} + ) + await repo.update_topology_status(b, "deploying") + pend = await repo.list_topologies(status="pending") + assert {t["id"] for t in pend} == {a} + dep = await repo.list_topologies(status="deploying") + assert {t["id"] for t in dep} == {b} + both = await repo.list_topologies() + assert {t["id"] for t in both} == {a, b}