Moves the 19 attacker-domain methods (core CRUD, behavior, sessions, smtp targets, log-derived activity views) plus the _deserialize_attacker and _deserialize_behavior helpers into sqlmodel_repo/attackers.py.
1159 lines
44 KiB
Python
1159 lines
44 KiB
Python
"""
|
|
Shared SQLModel-based repository implementation.
|
|
|
|
Contains all dialect-portable query code used by the SQLite and MySQL
|
|
backends. Dialect-specific behavior lives in subclasses:
|
|
|
|
* engine/session construction (``__init__``)
|
|
* ``_migrate_attackers_table`` (legacy schema check; DDL introspection
|
|
is not portable)
|
|
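* ``_migrate_session_profile_table`` (keystroke-dynamics column backfill;
  same DDL-introspection caveat)
* ``get_log_histogram`` (date-bucket expression differs per dialect)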
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
|
|
import orjson
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Optional, List
|
|
|
|
from sqlalchemy import func, select, desc, asc, text, update
|
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
|
|
|
from decnet.config import load_state
|
|
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
|
from decnet.web.auth import get_password_hash
|
|
from decnet.web.db.repository import BaseRepository
|
|
from decnet.web.db.models import (
|
|
User,
|
|
State,
|
|
Attacker,
|
|
AttackerIdentity,
|
|
Campaign,
|
|
Topology,
|
|
LAN,
|
|
TopologyDecky,
|
|
TopologyEdge,
|
|
TopologyStatusEvent,
|
|
TopologyMutation,
|
|
)
|
|
|
|
|
|
from decnet.web.db.sqlmodel_repo._helpers import ( # noqa: F401 (re-exported for tests/external)
|
|
_safe_session,
|
|
_detach_close,
|
|
_cleanup_tasks,
|
|
_serialize_json_fields,
|
|
_deserialize_json_fields,
|
|
)
|
|
from decnet.web.db.sqlmodel_repo.attacker_intel import AttackerIntelMixin
|
|
from decnet.web.db.sqlmodel_repo.attackers import AttackersMixin
|
|
from decnet.web.db.sqlmodel_repo.auth import AuthMixin
|
|
from decnet.web.db.sqlmodel_repo.bounties import BountiesMixin
|
|
from decnet.web.db.sqlmodel_repo.canary import CanaryMixin
|
|
from decnet.web.db.sqlmodel_repo.credentials import CredentialsMixin
|
|
from decnet.web.db.sqlmodel_repo.deckies import DeckiesMixin
|
|
from decnet.web.db.sqlmodel_repo.fleet import FleetMixin
|
|
from decnet.web.db.sqlmodel_repo.logs import LogsMixin
|
|
from decnet.web.db.sqlmodel_repo.orchestrator import OrchestratorMixin
|
|
from decnet.web.db.sqlmodel_repo.realism import RealismMixin
|
|
from decnet.web.db.sqlmodel_repo.swarm import SwarmMixin
|
|
from decnet.web.db.sqlmodel_repo.webhooks import WebhooksMixin
|
|
|
|
|
|
class SQLModelRepository(
|
|
AttackerIntelMixin,
|
|
AttackersMixin,
|
|
AuthMixin,
|
|
BountiesMixin,
|
|
CanaryMixin,
|
|
CredentialsMixin,
|
|
DeckiesMixin,
|
|
FleetMixin,
|
|
LogsMixin,
|
|
OrchestratorMixin,
|
|
RealismMixin,
|
|
SwarmMixin,
|
|
WebhooksMixin,
|
|
BaseRepository,
|
|
):
|
|
"""Concrete SQLModel/SQLAlchemy-async repository.
|
|
|
|
Subclasses provide ``self.engine`` (AsyncEngine) and ``self.session_factory``
|
|
in ``__init__``, and override the few dialect-specific helpers.
|
|
"""
|
|
|
|
engine: AsyncEngine
|
|
session_factory: async_sessionmaker[AsyncSession]
|
|
|
|
def _session(self):
    """Open a cancellation-safe session.

    Wraps ``self.session_factory`` with the shared ``_safe_session`` helper
    and hands back the resulting async context manager.
    """
    factory = self.session_factory
    return _safe_session(factory)
|
|
|
|
# ------------------------------------------------------------ lifecycle
|
|
|
|
async def initialize(self) -> None:
    """Bring the schema up to date and make sure the admin account exists.

    Runs the dialect-specific legacy migration hooks first. It then runs
    ``SQLModel.metadata.create_all``, which creates only missing tables.
    Finally it seeds or heals the admin user.
    """
    from sqlmodel import SQLModel

    await self._migrate_attackers_table()
    await self._migrate_session_profile_table()

    metadata = SQLModel.metadata
    async with self.engine.begin() as connection:
        await connection.run_sync(metadata.create_all)

    await self._ensure_admin_user()
|
|
|
|
async def reinitialize(self) -> None:
    """Run ``create_all`` again and re-seed the admin user.

    Intended for tests and reset flows. Existing tables are left in place
    (nothing is dropped) and the legacy migration hooks are not run.
    """
    from sqlmodel import SQLModel

    metadata = SQLModel.metadata
    async with self.engine.begin() as connection:
        await connection.run_sync(metadata.create_all)

    await self._ensure_admin_user()
|
|
|
|
async def _ensure_admin_user(self) -> None:
    """Create the bootstrap admin, or re-sync its hash while still unfinalized.

    * No row for ``DECNET_ADMIN_USER``: insert one with role ``admin``, the
      hash of ``DECNET_ADMIN_PASSWORD`` and ``must_change_password=True``.
    * Row exists with ``must_change_password`` still set: overwrite the hash
      from ``DECNET_ADMIN_PASSWORD``, so env drift self-heals.
    * Otherwise the admin has chosen their own password; leave it alone.
    """
    async with self._session() as session:
        lookup = select(User).where(User.username == DECNET_ADMIN_USER)
        admin = (await session.execute(lookup)).scalar_one_or_none()

        if admin is not None:
            if not admin.must_change_password:
                # Password already finalized by the user - nothing to do.
                return
            admin.password_hash = get_password_hash(DECNET_ADMIN_PASSWORD)
            session.add(admin)
        else:
            session.add(User(
                uuid=str(uuid.uuid4()),
                username=DECNET_ADMIN_USER,
                password_hash=get_password_hash(DECNET_ADMIN_PASSWORD),
                role="admin",
                must_change_password=True,
            ))

        await session.commit()
|
|
|
|
async def _migrate_attackers_table(self) -> None:
    """Hook for cleaning up a legacy ``attackers`` schema; a no-op here.

    Dialect subclasses override it because DDL introspection is not portable.
    """
    return
|
|
|
|
async def _migrate_session_profile_table(self) -> None:
    """Hook that adds the DEBT-036 keystroke-dynamics columns; a no-op here.

    Dialect subclasses override it to alter existing ``session_profile``
    tables, because DDL introspection is not portable.
    """
    return
|
|
|
|
async def get_deckies(self) -> List[dict]:
    """Return the deckies from the persisted state as plain dicts.

    ``load_state`` runs in a worker thread via ``asyncio.to_thread``. A falsy
    state yields ``[]``. Otherwise the first element's ``deckies`` entries
    are dumped with ``model_dump()`` (presumably pydantic models - they
    expose that method).
    """
    loaded = await asyncio.to_thread(load_state)
    if not loaded:
        return []
    return [decky.model_dump() for decky in loaded[0].deckies]
|
|
|
|
# --------------------------------------------------------------- state
|
|
|
|
async def get_state(self, key: str) -> Optional[dict[str, Any]]:
    """Look up *key* in the ``State`` table and decode its value.

    The stored string is parsed with ``json.loads``. Returns ``None`` when
    no row exists for *key*.
    """
    async with self._session() as session:
        row = (
            await session.execute(select(State).where(State.key == key))
        ).scalar_one_or_none()
        return json.loads(row.value) if row else None
|
|
|
|
async def set_state(self, key: str, value: Any) -> None:  # noqa: ANN401
    """Upsert *value* under *key* in the ``State`` table.

    The value is encoded with ``orjson.dumps`` and stored as text. An
    existing row is updated in place; otherwise a new row is inserted.
    """
    async with self._session() as session:
        existing = (
            await session.execute(select(State).where(State.key == key))
        ).scalar_one_or_none()

        encoded = orjson.dumps(value).decode()
        if not existing:
            session.add(State(key=key, value=encoded))
        else:
            existing.value = encoded
            session.add(existing)

        await session.commit()
|
|
|
|
# ------------------------------------------- attacker identities & campaigns
|
|
|
|
# ─── Identity resolution reads ────────────────────────────────────────
|
|
|
|
async def get_identity_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
    """Resolve *uuid* to its surviving identity by following merge links.

    Starting at *uuid*, each row's ``merged_into_uuid`` is followed until a
    row with no merge target is reached. That row is returned as a JSON-mode
    dict. Returns ``None`` if any uuid on the chain has no row.

    The walk is capped at 8 lookups so a corrupted merge ring cannot spin
    the worker. The clusterer should never produce cycles, so this is a
    safety net. When the cap is hit, the last row read (itself still merged
    out) is returned rather than recursing further.
    """
    max_hops = 8
    async with self._session() as session:
        cursor = uuid
        hops = 0
        while True:
            found = (
                await session.execute(
                    select(AttackerIdentity).where(AttackerIdentity.uuid == cursor)
                )
            ).scalar_one_or_none()
            if found is None:
                return None
            hops += 1
            if found.merged_into_uuid is None or hops >= max_hops:
                return found.model_dump(mode="json")
            cursor = found.merged_into_uuid
|
|
|
|
async def list_identities(
    self, limit: int = 50, offset: int = 0,
) -> list[dict[str, Any]]:
    """Page through identities that have not been merged into another.

    Merged-out rows are hidden so this is the de-duplicated view; their
    history stays reachable per uuid through ``get_identity_by_uuid``.
    Rows are ordered newest ``updated_at`` first.
    """
    not_merged = AttackerIdentity.merged_into_uuid.is_(None)
    query = (
        select(AttackerIdentity)
        .where(not_merged)
        .order_by(AttackerIdentity.updated_at.desc())
        .offset(offset)
        .limit(limit)
    )
    async with self._session() as session:
        identities = (await session.execute(query)).scalars().all()
        return [identity.model_dump(mode="json") for identity in identities]
|
|
|
|
async def count_identities(self) -> int:
    """Count identities not merged into another; 0 when the query yields nothing."""
    query = select(func.count()).select_from(AttackerIdentity)
    query = query.where(AttackerIdentity.merged_into_uuid.is_(None))
    async with self._session() as session:
        total = (await session.execute(query)).scalar()
        return total or 0
|
|
|
|
async def list_observations_for_identity(
    self, identity_uuid: str, limit: int = 50, offset: int = 0,
) -> list[dict[str, Any]]:
    """Page through the ``Attacker`` rows linked to *identity_uuid*.

    Rows are ordered by ``last_seen`` descending. Each is dumped in JSON mode
    and passed through ``self._deserialize_attacker``.
    """
    query = select(Attacker).where(Attacker.identity_id == identity_uuid)
    query = query.order_by(Attacker.last_seen.desc()).offset(offset).limit(limit)
    async with self._session() as session:
        attackers = (await session.execute(query)).scalars().all()
        return [
            self._deserialize_attacker(attacker.model_dump(mode="json"))
            for attacker in attackers
        ]
|
|
|
|
async def count_observations_for_identity(self, identity_uuid: str) -> int:
    """Count ``Attacker`` rows whose ``identity_id`` is *identity_uuid* (0 if none)."""
    linked = Attacker.identity_id == identity_uuid
    query = select(func.count()).select_from(Attacker).where(linked)
    async with self._session() as session:
        total = (await session.execute(query)).scalar()
        return total or 0
|
|
|
|
# ─── Identity resolution writes (clusterer worker) ─────────────────────
|
|
|
|
async def list_attackers_for_clustering(
    self, limit: Optional[int] = None,
) -> list[dict[str, Any]]:
    """Return the narrow per-attacker projection the identity clusterer reads.

    Only ``uuid``, ``asn``, ``identity_id`` and the raw ``fingerprints`` JSON
    are selected; the clusterer parses JA3 / HASSH itself. Rows are ordered
    by ``first_seen``. The projection is kept narrow so future denormalised
    columns can be added here without churning callers. A *limit* of
    ``None`` means unlimited.
    """
    keys = ("uuid", "asn", "identity_id", "fingerprints")
    query = select(
        Attacker.uuid, Attacker.asn, Attacker.identity_id, Attacker.fingerprints,
    ).order_by(Attacker.first_seen)
    if limit is not None:
        query = query.limit(limit)
    async with self._session() as session:
        rows = (await session.execute(query)).all()
        # Row tuples follow the select() column order, which matches ``keys``.
        return [dict(zip(keys, row)) for row in rows]
|
|
|
|
async def create_attacker_identity(self, row: dict[str, Any]) -> str:
    """Insert an ``AttackerIdentity`` built from *row* and return its uuid."""
    new_identity = AttackerIdentity(**row)
    async with self._session() as session:
        session.add(new_identity)
        await session.commit()
        return new_identity.uuid
|
|
|
|
async def set_attacker_identity_id(
    self, attacker_uuid: str, identity_uuid: str,
) -> None:
    """Set ``identity_id`` on the ``Attacker`` row *attacker_uuid* and commit."""
    link = update(Attacker).where(Attacker.uuid == attacker_uuid)
    link = link.values(identity_id=identity_uuid)
    async with self._session() as session:
        await session.execute(link)
        await session.commit()
|
|
|
|
async def list_all_identities(self) -> list[dict[str, Any]]:
    """Return every identity, merged or not, oldest ``created_at`` first."""
    async with self._session() as session:
        result = await session.execute(
            select(AttackerIdentity).order_by(AttackerIdentity.created_at)
        )
        return [identity.model_dump(mode="json") for identity in result.scalars().all()]
|
|
|
|
async def update_identity_merged_into(
    self, identity_uuid: str, winner_uuid: Optional[str],
) -> None:
    """Set (or clear, with ``None``) an identity's ``merged_into_uuid``.

    ``updated_at`` is stamped with the current UTC time in the same UPDATE.
    """
    stamp = datetime.now(timezone.utc)
    merge = (
        update(AttackerIdentity)
        .where(AttackerIdentity.uuid == identity_uuid)
        .values(merged_into_uuid=winner_uuid, updated_at=stamp)
    )
    async with self._session() as session:
        await session.execute(merge)
        await session.commit()
|
|
|
|
async def update_identity_fingerprints(
    self,
    identity_uuid: str,
    *,
    ja3_hashes: Optional[str] = None,
    hassh_hashes: Optional[str] = None,
    tls_cert_sha256: Optional[str] = None,
) -> None:
    """Overwrite an identity's fingerprint columns and stamp ``updated_at``.

    All three columns are always written. An argument left at ``None``
    stores ``None`` and clears that column; it is not skipped.
    """
    new_values = {
        "ja3_hashes": ja3_hashes,
        "hassh_hashes": hassh_hashes,
        "tls_cert_sha256": tls_cert_sha256,
        "updated_at": datetime.now(timezone.utc),
    }
    stmt = (
        update(AttackerIdentity)
        .where(AttackerIdentity.uuid == identity_uuid)
        .values(**new_values)
    )
    async with self._session() as session:
        await session.execute(stmt)
        await session.commit()
|
|
|
|
# ─── Campaign clustering reads ────────────────────────────────────────
|
|
|
|
async def get_campaign_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
    """Resolve *uuid* to its surviving campaign by following merge links.

    Same chain walk as for identities: ``merged_into_uuid`` is followed for
    at most 8 lookups, guarding against corrupted rings. Returns ``None``
    when a uuid on the chain has no row. If the cap is hit, the last row
    read is returned.
    """
    hop_limit = 8
    async with self._session() as session:
        target = uuid
        campaign = None
        for _hop in range(hop_limit):
            campaign = (
                await session.execute(select(Campaign).where(Campaign.uuid == target))
            ).scalar_one_or_none()
            if campaign is None:
                return None
            target = campaign.merged_into_uuid
            if target is None:
                break
        return campaign.model_dump(mode="json")
|
|
|
|
async def list_campaigns(
    self, limit: int = 50, offset: int = 0,
) -> list[dict[str, Any]]:
    """Page through campaigns not merged into another, newest ``updated_at`` first."""
    query = (
        select(Campaign)
        .where(Campaign.merged_into_uuid.is_(None))
        .order_by(Campaign.updated_at.desc())
        .offset(offset)
        .limit(limit)
    )
    async with self._session() as session:
        campaigns = (await session.execute(query)).scalars().all()
        return [campaign.model_dump(mode="json") for campaign in campaigns]
|
|
|
|
async def count_campaigns(self) -> int:
    """Count campaigns not merged into another; 0 when the query yields nothing."""
    query = select(func.count()).select_from(Campaign)
    query = query.where(Campaign.merged_into_uuid.is_(None))
    async with self._session() as session:
        total = (await session.execute(query)).scalar()
        return total or 0
|
|
|
|
async def list_identities_for_campaign(
    self, campaign_uuid: str, limit: int = 50, offset: int = 0,
) -> list[dict[str, Any]]:
    """Page through identities whose ``campaign_id`` is *campaign_uuid*.

    Rows are ordered newest ``updated_at`` first. Merged-out identities are
    not filtered out here.
    """
    member = AttackerIdentity.campaign_id == campaign_uuid
    query = (
        select(AttackerIdentity)
        .where(member)
        .order_by(AttackerIdentity.updated_at.desc())
        .offset(offset)
        .limit(limit)
    )
    async with self._session() as session:
        identities = (await session.execute(query)).scalars().all()
        return [identity.model_dump(mode="json") for identity in identities]
|
|
|
|
async def count_identities_for_campaign(self, campaign_uuid: str) -> int:
    """Count identities whose ``campaign_id`` is *campaign_uuid* (0 if none)."""
    member = AttackerIdentity.campaign_id == campaign_uuid
    query = select(func.count()).select_from(AttackerIdentity).where(member)
    async with self._session() as session:
        total = (await session.execute(query)).scalar()
        return total or 0
|
|
|
|
# ─── Campaign clustering writes (campaign-clusterer worker) ───────────
|
|
|
|
async def list_identities_for_clustering(
    self, limit: Optional[int] = None,
) -> list[dict[str, Any]]:
    """Return the narrow per-identity projection the campaign clusterer reads.

    The projection is deliberately narrow so future denormalised columns
    (e.g. commands_by_phase, decky-set aggregates) can be added without
    churning callers. Timestamps are rendered with ``isoformat()`` (``None``
    is kept as-is). Rows are ordered by ``created_at``. A *limit* of ``None``
    means unlimited.
    """
    def _iso(moment):
        return None if moment is None else moment.isoformat()

    query = select(
        AttackerIdentity.uuid,
        AttackerIdentity.campaign_id,
        AttackerIdentity.merged_into_uuid,
        AttackerIdentity.first_seen_at,
        AttackerIdentity.last_seen_at,
        AttackerIdentity.ja3_hashes,
        AttackerIdentity.hassh_hashes,
        AttackerIdentity.payload_simhashes,
        AttackerIdentity.c2_endpoints,
    ).order_by(AttackerIdentity.created_at)
    if limit is not None:
        query = query.limit(limit)

    async with self._session() as session:
        rows = (await session.execute(query)).all()
        return [
            {
                "uuid": r.uuid,
                "campaign_id": r.campaign_id,
                "merged_into_uuid": r.merged_into_uuid,
                "first_seen_at": _iso(r.first_seen_at),
                "last_seen_at": _iso(r.last_seen_at),
                "ja3_hashes": r.ja3_hashes,
                "hassh_hashes": r.hassh_hashes,
                "payload_simhashes": r.payload_simhashes,
                "c2_endpoints": r.c2_endpoints,
            }
            for r in rows
        ]
|
|
|
|
async def create_campaign(self, row: dict[str, Any]) -> str:
    """Insert a ``Campaign`` built from *row* and return its uuid."""
    new_campaign = Campaign(**row)
    async with self._session() as session:
        session.add(new_campaign)
        await session.commit()
        return new_campaign.uuid
|
|
|
|
async def set_identity_campaign_id(
    self, identity_uuid: str, campaign_uuid: Optional[str],
) -> None:
    """Assign (or clear, with ``None``) an identity's ``campaign_id``.

    ``updated_at`` is stamped with the current UTC time in the same UPDATE.
    """
    stamp = datetime.now(timezone.utc)
    assign = (
        update(AttackerIdentity)
        .where(AttackerIdentity.uuid == identity_uuid)
        .values(campaign_id=campaign_uuid, updated_at=stamp)
    )
    async with self._session() as session:
        await session.execute(assign)
        await session.commit()
|
|
|
|
async def list_all_campaigns(self) -> list[dict[str, Any]]:
    """Return every campaign, merged or not, oldest ``created_at`` first."""
    async with self._session() as session:
        result = await session.execute(select(Campaign).order_by(Campaign.created_at))
        return [campaign.model_dump(mode="json") for campaign in result.scalars().all()]
|
|
|
|
async def update_campaign_merged_into(
    self, campaign_uuid: str, winner_uuid: Optional[str],
) -> None:
    """Set (or clear, with ``None``) a campaign's ``merged_into_uuid``.

    ``updated_at`` is stamped with the current UTC time in the same UPDATE.
    """
    stamp = datetime.now(timezone.utc)
    merge = (
        update(Campaign)
        .where(Campaign.uuid == campaign_uuid)
        .values(merged_into_uuid=winner_uuid, updated_at=stamp)
    )
    async with self._session() as session:
        await session.execute(merge)
        await session.commit()
|
|
|
|
# ------------------------------------------------------------ mazenet
|
|
|
|
async def create_topology(self, data: dict[str, Any]) -> str:
    """Insert a ``Topology`` from *data* and return its id.

    The ``config_snapshot`` field is passed through
    ``_serialize_json_fields`` first. The row is refreshed after commit, and
    ``id`` is read from the refreshed row.
    """
    serialized = _serialize_json_fields(data, ("config_snapshot",))
    async with self._session() as session:
        topology = Topology(**serialized)
        session.add(topology)
        await session.commit()
        await session.refresh(topology)
        return topology.id
|
|
|
|
async def get_topology(self, topology_id: str) -> Optional[dict[str, Any]]:
    """Fetch one topology as a JSON-mode dict, or ``None`` if it does not exist.

    ``config_snapshot`` is passed through ``_deserialize_json_fields``.
    """
    lookup = select(Topology).where(Topology.id == topology_id)
    async with self._session() as session:
        topology = (await session.execute(lookup)).scalar_one_or_none()
        if not topology:
            return None
        dumped = topology.model_dump(mode="json")
        return _deserialize_json_fields(dumped, ("config_snapshot",))
|
|
|
|
async def list_topologies(
    self,
    status: Optional[str] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> list[dict[str, Any]]:
    """List topologies newest ``created_at`` first, with optional filtering/paging.

    A falsy *status* disables the status filter. *offset* and *limit* are
    applied only when not ``None``. ``config_snapshot`` is passed through
    ``_deserialize_json_fields`` for every row.
    """
    query = select(Topology).order_by(Topology.created_at.desc())
    if status:
        query = query.where(Topology.status == status)
    if offset is not None:
        query = query.offset(offset)
    if limit is not None:
        query = query.limit(limit)

    json_fields = ("config_snapshot",)
    async with self._session() as session:
        topologies = (await session.execute(query)).scalars().all()
        return [
            _deserialize_json_fields(topology.model_dump(mode="json"), json_fields)
            for topology in topologies
        ]
|
|
|
|
async def count_topologies(self, status: Optional[str] = None) -> int:
    """Count topologies, optionally only those with ``status == status``.

    A falsy *status* (``None`` or ``""``) counts every topology, which
    mirrors the filter semantics of ``list_topologies``.
    """
    # ``func`` is already imported at module level; the previous
    # function-local re-import was redundant.
    statement = select(func.count(Topology.id))
    if status:
        statement = statement.where(Topology.status == status)
    async with self._session() as session:
        result = await session.execute(statement)
        return int(result.scalar_one() or 0)
|
|
|
|
async def update_topology_status(
    self,
    topology_id: str,
    new_status: str,
    reason: Optional[str] = None,
) -> None:
    """Set a topology's status and record the transition in the same commit.

    Updates ``status`` and ``status_changed_at`` on the ``Topology`` row.
    In the same session it adds a ``TopologyStatusEvent`` with the previous
    status, *new_status*, the timestamp and *reason*. An unknown
    *topology_id* is a silent no-op.

    Transition legality is enforced in ``decnet.topology.status``; this
    method trusts the caller.
    """
    changed_at = datetime.now(timezone.utc)
    async with self._session() as session:
        topo = (
            await session.execute(select(Topology).where(Topology.id == topology_id))
        ).scalar_one_or_none()
        if topo is None:
            return

        # Capture the old status before it is overwritten below.
        event = TopologyStatusEvent(
            topology_id=topology_id,
            from_status=topo.status,
            to_status=new_status,
            at=changed_at,
            reason=reason,
        )
        topo.status = new_status
        topo.status_changed_at = changed_at
        session.add(topo)
        session.add(event)
        await session.commit()
|
|
|
|
async def set_topology_resync(self, topology_id: str, value: bool) -> None:
    """Set ``needs_resync`` to ``bool(value)``; no-op for an unknown topology."""
    lookup = select(Topology).where(Topology.id == topology_id)
    async with self._session() as session:
        topo = (await session.execute(lookup)).scalar_one_or_none()
        if topo is not None:
            topo.needs_resync = bool(value)
            session.add(topo)
            await session.commit()
|
|
|
|
async def set_topology_email_personas(
    self, topology_id: str, personas_json: str,
) -> bool:
    """Replace ``Topology.email_personas`` with *personas_json*, stored verbatim.

    No validation or parsing happens here; that is the caller's job (and
    the emailgen scheduler repeats it on every tick anyway).

    Returns:
        True if the topology existed and was updated, False otherwise.
    """
    lookup = select(Topology).where(Topology.id == topology_id)
    async with self._session() as session:
        topo = (await session.execute(lookup)).scalar_one_or_none()
        if topo is None:
            return False
        topo.email_personas = personas_json
        session.add(topo)
        await session.commit()
        return True
|
|
|
|
async def list_topologies_needing_resync(self) -> list[dict[str, Any]]:
    """Return every topology flagged ``needs_resync`` as JSON-mode dicts.

    ``config_snapshot`` is passed through ``_deserialize_json_fields``.
    """
    flagged = Topology.needs_resync == True  # noqa: E712
    async with self._session() as session:
        topologies = (await session.execute(select(Topology).where(flagged))).scalars().all()
        return [
            _deserialize_json_fields(topology.model_dump(mode="json"), ("config_snapshot",))
            for topology in topologies
        ]
|
|
|
|
async def delete_topology_cascade(self, topology_id: str) -> bool:
    """Delete a topology and all children. No portable ON DELETE CASCADE.

    The child tables are cleared explicitly, in order: status events, edges,
    deckies, LANs, mutations. Then the ``Topology`` row itself is removed.
    The child deletes are committed even when the topology row is missing.

    Returns:
        True if the topology row existed and was deleted, False otherwise.
    """
    child_deletes = (
        "DELETE FROM topology_status_events WHERE topology_id = :t",
        "DELETE FROM topology_edges WHERE topology_id = :t",
        "DELETE FROM topology_deckies WHERE topology_id = :t",
        "DELETE FROM lans WHERE topology_id = :t",
        "DELETE FROM topology_mutations WHERE topology_id = :t",
    )
    async with self._session() as session:
        bind = {"t": topology_id}
        for sql in child_deletes:
            await session.execute(text(sql), bind)

        topo = (
            await session.execute(select(Topology).where(Topology.id == topology_id))
        ).scalar_one_or_none()
        found = bool(topo)
        if found:
            await session.delete(topo)
        await session.commit()
        return found
|
|
|
|
async def _assert_pending(self, session, topology_id: str) -> None:
    """Refuse free-form edits unless the topology is still pending.

    Loads the topology through *session*.

    Raises:
        ValueError: no topology with *topology_id* exists.
        TopologyNotEditable: the status is not ``TopologyStatus.PENDING``;
            post-deploy changes must go through ``topology_mutations``.
    """
    from decnet.topology.status import TopologyNotEditable, TopologyStatus

    topo = (
        await session.execute(select(Topology).where(Topology.id == topology_id))
    ).scalar_one_or_none()
    if topo is None:
        raise ValueError(f"topology {topology_id!r} not found")
    if topo.status == TopologyStatus.PENDING:
        return
    raise TopologyNotEditable(
        status=topo.status,
        reason="free-form edits are pending-only; use the "
        "mutator (topology_mutations) after deploy",
    )
|
|
|
|
async def _check_and_bump_version(
    self,
    session,
    topology_id: str,
    expected_version: Optional[int],
) -> None:
    """Optimistic-concurrency guard used by child-row mutators.

    If ``expected_version`` is None, no check happens (backward-compat for
    internal callers that don't need concurrency protection).

    If supplied, the compare and the increment run as ONE conditional
    statement in *session*::

        UPDATE topology SET version = version + 1
         WHERE id = :id AND version = :expected

    A read-then-compare-then-write sequence would let two concurrent
    sessions both read ``version == expected``, both pass, and both commit
    ``expected + 1``, silently losing one writer's protection. With the
    conditional UPDATE the database serialises the writers, and the loser
    matches zero rows. When nothing matches, the row is re-read to report
    which failure occurred. The caller must commit the enclosing session.

    Raises:
        ValueError: no topology with ``topology_id`` exists.
        VersionConflict: the stored version differs from ``expected_version``.
    """
    from decnet.topology.status import VersionConflict

    if expected_version is None:
        return

    bumped = await session.execute(
        update(Topology)
        .where(Topology.id == topology_id, Topology.version == expected_version)
        .values(version=Topology.version + 1)
    )
    if bumped.rowcount:
        return

    # Zero rows matched: either the topology is gone or someone else won.
    current = await session.execute(
        select(Topology.version).where(Topology.id == topology_id)
    )
    row = current.first()
    if row is None:
        raise ValueError(f"topology {topology_id!r} not found")
    raise VersionConflict(current=row[0], expected=expected_version)
|
|
|
|
async def add_lan(
    self,
    data: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
) -> str:
    """Insert a ``LAN`` built from *data* and return its id.

    ``data["topology_id"]`` names the owning topology. Its version is
    checked and bumped via ``_check_and_bump_version`` (which is a no-op when
    *expected_version* is ``None``). That happens in the same session, so it
    is committed together with the insert.
    """
    async with self._session() as session:
        owner = data["topology_id"]
        await self._check_and_bump_version(session, owner, expected_version)
        lan = LAN(**data)
        session.add(lan)
        await session.commit()
        await session.refresh(lan)
        return lan.id
|
|
|
|
async def update_lan(
    self,
    lan_id: str,
    fields: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
    enforce_pending: bool = False,
) -> None:
    """Apply *fields* to the LAN *lan_id*.

    Empty *fields* is a no-op. Optionally requires the owning topology to be
    pending (*enforce_pending*) and/or checks and bumps its version
    (*expected_version*) before the UPDATE. All of it happens in one commit.

    Raises:
        ValueError: the LAN does not exist.
    """
    if not fields:
        return
    async with self._session() as session:
        lan = (
            await session.execute(select(LAN).where(LAN.id == lan_id))
        ).scalar_one_or_none()
        if lan is None:
            raise ValueError(f"lan {lan_id!r} not found")

        owner = lan.topology_id
        if enforce_pending:
            await self._assert_pending(session, owner)
        if expected_version is not None:
            await self._check_and_bump_version(session, owner, expected_version)

        await session.execute(update(LAN).where(LAN.id == lan_id).values(**fields))
        await session.commit()
|
|
|
|
async def delete_lan(
    self,
    lan_id: str,
    *,
    expected_version: Optional[int] = None,
) -> None:
    """Cascade-delete a LAN (and its edges) from a pending topology.

    Does nothing if *lan_id* does not exist. Otherwise, in one session:

    1. ``_assert_pending`` rejects the call unless the owning topology is
       still pending.
    2. Every decky with an edge to this LAN must also have an edge to some
       other LAN. Deleting the LAN would otherwise orphan it, so the caller
       must delete or re-home such deckies first.
    3. The topology version is checked/bumped when *expected_version* is
       given.
    4. The LAN's edges, then the LAN row, are deleted and committed.

    Raises:
        TopologyNotEditable: the topology is not pending.
        ValueError: a decky would be left with no LAN. When several would,
            the lexicographically smallest decky uuid is reported so the
            error is deterministic.
    """
    async with self._session() as session:
        result = await session.execute(select(LAN).where(LAN.id == lan_id))
        lan = result.scalar_one_or_none()
        if lan is None:
            return
        await self._assert_pending(session, lan.topology_id)

        # Orphan check in two queries rather than one per decky: collect the
        # deckies attached here, then find which of them also reach another LAN.
        attached = set(
            (
                await session.execute(
                    select(TopologyEdge.decky_uuid).where(TopologyEdge.lan_id == lan_id)
                )
            ).scalars().all()
        )
        if attached:
            anchored = set(
                (
                    await session.execute(
                        select(TopologyEdge.decky_uuid).where(
                            TopologyEdge.decky_uuid.in_(attached),
                            TopologyEdge.lan_id != lan_id,
                        )
                    )
                ).scalars().all()
            )
            orphans = sorted(attached - anchored, key=str)
            if orphans:
                raise ValueError(
                    f"cannot delete LAN {lan.name!r}: decky "
                    f"{orphans[0]} has no other LAN (would be orphaned)"
                )

        if expected_version is not None:
            await self._check_and_bump_version(
                session, lan.topology_id, expected_version
            )
        # Cascade edges -> LAN.
        await session.execute(
            text("DELETE FROM topology_edges WHERE lan_id = :l"),
            {"l": lan_id},
        )
        await session.execute(text("DELETE FROM lans WHERE id = :l"), {"l": lan_id})
        await session.commit()
|
|
|
|
async def list_lans_for_topology(
    self, topology_id: str
) -> list[dict[str, Any]]:
    """Return the LANs of *topology_id* as JSON-mode dicts, sorted by name."""
    query = (
        select(LAN)
        .where(LAN.topology_id == topology_id)
        .order_by(LAN.name.asc())
    )
    async with self._session() as session:
        lans = (await session.execute(query)).scalars().all()
        return [lan.model_dump(mode="json") for lan in lans]
|
|
|
|
async def add_topology_decky(
    self,
    data: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
) -> str:
    """Insert a ``TopologyDecky`` from *data* and return its uuid.

    ``services`` and ``decky_config`` are passed through
    ``_serialize_json_fields`` before the model is built. The owning
    topology's version goes through ``_check_and_bump_version`` in the same
    session; that call is a no-op when *expected_version* is ``None``.
    """
    serialized = _serialize_json_fields(data, ("services", "decky_config"))
    async with self._session() as session:
        owner = data["topology_id"]
        await self._check_and_bump_version(session, owner, expected_version)
        decky = TopologyDecky(**serialized)
        session.add(decky)
        await session.commit()
        await session.refresh(decky)
        return decky.uuid
|
|
|
|
async def update_topology_decky(
    self,
    decky_uuid: str,
    fields: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
    enforce_pending: bool = False,
) -> None:
    """Apply *fields* to the topology decky *decky_uuid*.

    Empty *fields* is a no-op. ``services`` / ``decky_config`` go through
    ``_serialize_json_fields``, and ``updated_at`` defaults to now (UTC) if
    the caller did not supply one. The pending check (*enforce_pending*) and
    the version check/bump (*expected_version*) run before the UPDATE, in the
    same commit.

    Raises:
        ValueError: the decky does not exist.
    """
    if not fields:
        return
    values = _serialize_json_fields(fields, ("services", "decky_config"))
    values.setdefault("updated_at", datetime.now(timezone.utc))

    async with self._session() as session:
        decky = (
            await session.execute(
                select(TopologyDecky).where(TopologyDecky.uuid == decky_uuid)
            )
        ).scalar_one_or_none()
        if decky is None:
            raise ValueError(f"decky {decky_uuid!r} not found")

        owner = decky.topology_id
        if enforce_pending:
            await self._assert_pending(session, owner)
        if expected_version is not None:
            await self._check_and_bump_version(session, owner, expected_version)

        await session.execute(
            update(TopologyDecky)
            .where(TopologyDecky.uuid == decky_uuid)
            .values(**values)
        )
        await session.commit()
|
|
|
|
async def delete_topology_decky(
    self,
    decky_uuid: str,
    *,
    expected_version: Optional[int] = None,
) -> None:
    """Cascade-delete a decky and all of its edges from a pending topology.

    A missing decky is a no-op. Otherwise ``_assert_pending`` must pass, and
    the version is checked/bumped when *expected_version* is given. Then the
    edges and the decky row are removed and committed.
    """
    async with self._session() as session:
        decky = (
            await session.execute(
                select(TopologyDecky).where(TopologyDecky.uuid == decky_uuid)
            )
        ).scalar_one_or_none()
        if decky is None:
            return

        owner = decky.topology_id
        await self._assert_pending(session, owner)
        if expected_version is not None:
            await self._check_and_bump_version(session, owner, expected_version)

        bind = {"u": decky_uuid}
        await session.execute(text("DELETE FROM topology_edges WHERE decky_uuid = :u"), bind)
        await session.execute(text("DELETE FROM topology_deckies WHERE uuid = :u"), bind)
        await session.commit()
|
|
|
|
async def list_topology_deckies(
    self, topology_id: str
) -> list[dict[str, Any]]:
    """Return the deckies of *topology_id*, sorted by name, as JSON-mode dicts.

    ``services`` and ``decky_config`` are passed through
    ``_deserialize_json_fields``.
    """
    query = (
        select(TopologyDecky)
        .where(TopologyDecky.topology_id == topology_id)
        .order_by(TopologyDecky.name.asc())
    )
    json_fields = ("services", "decky_config")
    async with self._session() as session:
        deckies = (await session.execute(query)).scalars().all()
        return [
            _deserialize_json_fields(decky.model_dump(mode="json"), json_fields)
            for decky in deckies
        ]
|
|
|
|
async def add_topology_edge(
    self,
    data: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
) -> str:
    """Insert a ``TopologyEdge`` from *data* and return its id.

    The owning topology (``data["topology_id"]``) goes through
    ``_check_and_bump_version`` in the same session. That call is a no-op
    when *expected_version* is ``None``.
    """
    async with self._session() as session:
        owner = data["topology_id"]
        await self._check_and_bump_version(session, owner, expected_version)
        edge = TopologyEdge(**data)
        session.add(edge)
        await session.commit()
        await session.refresh(edge)
        return edge.id
|
|
|
|
async def delete_topology_edge(
    self,
    edge_id: str,
    *,
    expected_version: Optional[int] = None,
) -> None:
    """Delete one edge from a pending topology; a missing edge is a no-op.

    ``_assert_pending`` must pass. The version is checked/bumped when
    *expected_version* is given, before the delete is committed.
    """
    async with self._session() as session:
        edge = (
            await session.execute(select(TopologyEdge).where(TopologyEdge.id == edge_id))
        ).scalar_one_or_none()
        if edge is None:
            return

        owner = edge.topology_id
        await self._assert_pending(session, owner)
        if expected_version is not None:
            await self._check_and_bump_version(session, owner, expected_version)

        await session.execute(
            text("DELETE FROM topology_edges WHERE id = :e"), {"e": edge_id}
        )
        await session.commit()
|
|
|
|
async def list_topology_edges(
    self, topology_id: str
) -> list[dict[str, Any]]:
    """Return every edge of *topology_id* as JSON-mode dicts (no explicit order)."""
    query = select(TopologyEdge).where(TopologyEdge.topology_id == topology_id)
    async with self._session() as session:
        edges = (await session.execute(query)).scalars().all()
        return [edge.model_dump(mode="json") for edge in edges]
|
|
|
|
async def list_topology_status_events(
    self, topology_id: str, limit: int = 100
) -> list[dict[str, Any]]:
    """Return up to *limit* status events for *topology_id*, newest ``at`` first."""
    query = (
        select(TopologyStatusEvent)
        .where(TopologyStatusEvent.topology_id == topology_id)
        .order_by(TopologyStatusEvent.at.desc())
        .limit(limit)
    )
    async with self._session() as session:
        events = (await session.execute(query)).scalars().all()
        return [event.model_dump(mode="json") for event in events]
|
|
|
|
# ---------------- topology_mutations (live reconciler queue) ----------------
|
|
|
|
async def enqueue_topology_mutation(
    self,
    topology_id: str,
    op: str,
    payload: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
) -> str:
    """Queue a pending mutation on ``topology_id`` and return its id.

    The topology version goes through ``_check_and_bump_version`` in the
    same session as the insert. ``payload`` is persisted as an
    orjson-encoded string. Meant for topologies in ``active|degraded``;
    the reconciler consumes queued rows on its next tick.
    """
    async with self._session() as session:
        await self._check_and_bump_version(session, topology_id, expected_version)
        mutation = TopologyMutation(
            topology_id=topology_id,
            op=op,
            payload=orjson.dumps(payload).decode(),
        )
        session.add(mutation)
        await session.commit()
        await session.refresh(mutation)
        return mutation.id
|
|
|
|
async def claim_next_mutation(
    self, topology_id: str
) -> Optional[dict[str, Any]]:
    """Atomically claim the oldest pending mutation for ``topology_id``.

    The chosen row moves from ``pending`` to ``applying`` and is returned
    as a JSON-mode dict. Returns ``None`` in two cases:

    * nothing is pending;
    * a concurrent caller claimed the same row first. This is the
      expected, non-error outcome under contention.

    Correctness notes:

    * The claim is a compare-and-set:
      ``UPDATE ... WHERE id = :i AND state = 'pending'``. The state
      re-check inside the UPDATE makes it race-safe. If two watch-loops
      pick the same candidate id, only one UPDATE matches a row; the
      other sees ``rowcount == 0`` and returns ``None``, so no op runs
      twice.
    * The claimed row is re-read **by id**. Re-reading "the oldest
      ``applying`` row of the topology" is wrong whenever another row is
      already ``applying``, whether from a concurrent claim or left
      behind by a crashed reconciler. The caller would then receive and
      re-execute someone else's mutation.
    * The UPDATE no longer references its own table in a subquery, so
      the MySQL ERROR 1093 derived-table workaround is unnecessary.
    """
    async with self._session() as session:
        # Pick the oldest pending candidate; this read alone grants no
        # ownership.
        candidate = await session.execute(
            select(TopologyMutation.id)
            .where(TopologyMutation.topology_id == topology_id)
            .where(TopologyMutation.state == "pending")
            .order_by(asc(TopologyMutation.requested_at))
            .limit(1)
        )
        mutation_id = candidate.scalar_one_or_none()
        if mutation_id is None:
            await session.commit()
            return None

        # Ownership is decided here: only one caller can flip this row
        # from 'pending' to 'applying'.
        claimed = await session.execute(
            text(
                "UPDATE topology_mutations "
                "SET state = 'applying' "
                "WHERE id = :i AND state = 'pending'"
            ),
            {"i": mutation_id},
        )
        if claimed.rowcount == 0:
            # Lost the race for this row to another caller.
            await session.commit()
            return None

        sel = await session.execute(
            select(TopologyMutation).where(TopologyMutation.id == mutation_id)
        )
        row = sel.scalar_one_or_none()
        # Dump before commit in case the session factory expires
        # instances on commit (attribute access would then need a reload).
        result = None if row is None else row.model_dump(mode="json")
        await session.commit()
        return result
|
|
|
|
async def mark_mutation_applied(self, mutation_id: str) -> None:
    """Set mutation ``mutation_id`` to ``applied``.

    ``applied_at`` is stamped with the current UTC time as an ISO-8601
    string. If no row has that id, nothing is updated and no error is
    raised.
    """
    statement = text(
        "UPDATE topology_mutations "
        "SET state = 'applied', applied_at = :at "
        "WHERE id = :i"
    )
    async with self._session() as session:
        bind_params = {
            "at": datetime.now(timezone.utc).isoformat(),
            "i": mutation_id,
        }
        await session.execute(statement, bind_params)
        await session.commit()
|
|
|
|
async def mark_mutation_failed(
    self, mutation_id: str, reason: str
) -> None:
    """Set mutation ``mutation_id`` to ``failed`` and record ``reason``.

    ``applied_at`` is stamped with the current UTC time as an ISO-8601
    string. If no row has that id, nothing is updated and no error is
    raised.
    """
    statement = text(
        "UPDATE topology_mutations "
        "SET state = 'failed', applied_at = :at, reason = :r "
        "WHERE id = :i"
    )
    async with self._session() as session:
        bind_params = {
            "at": datetime.now(timezone.utc).isoformat(),
            "r": reason,
            "i": mutation_id,
        }
        await session.execute(statement, bind_params)
        await session.commit()
|
|
|
|
async def list_topology_mutations(
    self,
    topology_id: str,
    state: Optional[str] = None,
) -> list[dict[str, Any]]:
    """Return the mutations of ``topology_id``, newest ``requested_at`` first.

    When ``state`` is given, only mutations in that state are returned.
    Rows are dumped in JSON mode.
    """
    async with self._session() as session:
        query = select(TopologyMutation).where(
            TopologyMutation.topology_id == topology_id
        )
        if state is not None:
            query = query.where(TopologyMutation.state == state)
        query = query.order_by(desc(TopologyMutation.requested_at))
        mutations = (await session.execute(query)).scalars().all()
        return [mutation.model_dump(mode="json") for mutation in mutations]
|
|
|
|
async def has_pending_topology_mutation(self) -> bool:
    """Tell the watch loop whether any live topology has queued work.

    Returns True when at least one ``pending`` mutation belongs to a
    topology whose status is ``active`` or ``degraded``, otherwise False.
    A False result lets the caller skip the reconciler path. The query
    stops at the first match (``LIMIT 1``). It is intended to be served by
    the ``ix_topology_mutations_state_topology`` composite index; confirm
    that index exists in the schema.
    """
    probe = text(
        "SELECT 1 FROM topology_mutations "
        "WHERE state = 'pending' "
        "AND topology_id IN ("
        " SELECT id FROM topologies "
        " WHERE status IN ('active', 'degraded')"
        ") LIMIT 1"
    )
    async with self._session() as session:
        first_hit = (await session.execute(probe)).first()
        return first_hit is not None
|
|
|
|
async def list_live_topology_ids(self) -> list[str]:
    """Return the ids of all topologies whose status is ``active`` or ``degraded``.

    Row order is whatever the database returns (no ORDER BY).
    """
    async with self._session() as session:
        result = await session.execute(
            select(Topology.id).where(
                Topology.status.in_(["active", "degraded"])
            )
        )
        # A plain copy into a list; the former identity comprehension
        # ``[r for r in ...]`` added nothing.
        return list(result.scalars().all())
|
|
|
|
async def list_running_topology_deckies(self) -> list[dict[str, Any]]:
    """Return every decky, across all topologies, whose state is ``running``.

    Each row is dumped in JSON mode, and its ``services`` and
    ``decky_config`` fields are run through ``_deserialize_json_fields``.
    """
    json_fields = ("services", "decky_config")
    async with self._session() as session:
        query = select(TopologyDecky).where(TopologyDecky.state == "running")
        deckies = (await session.execute(query)).scalars().all()
        return [
            _deserialize_json_fields(decky.model_dump(mode="json"), json_fields)
            for decky in deckies
        ]
|
|
|