delete_topology_cascade manually deletes status_events, edges, deckies and lans but overlooked topology_mutations, so deleting any topology that ever had a mutation enqueued (i.e. edits while active|degraded) failed with an FK IntegrityError. Add the missing DELETE and extend the cascade test to seed a mutation row.
1600 lines
61 KiB
Python
1600 lines
61 KiB
Python
"""
|
|
Shared SQLModel-based repository implementation.
|
|
|
|
Contains all dialect-portable query code used by the SQLite and MySQL
|
|
backends. Dialect-specific behavior lives in subclasses:
|
|
|
|
* engine/session construction (``__init__``)
|
|
* ``_migrate_attackers_table`` (legacy schema check; DDL introspection
|
|
is not portable)
|
|
* ``get_log_histogram`` (date-bucket expression differs per dialect)
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
|
|
import orjson
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Optional, List
|
|
|
|
from sqlalchemy import func, select, desc, asc, text, or_, update
|
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
|
from sqlmodel.sql.expression import SelectOfScalar
|
|
|
|
from decnet.config import load_state
|
|
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
|
from decnet.web.auth import get_password_hash
|
|
from decnet.web.db.repository import BaseRepository
|
|
from decnet.web.db.models import (
|
|
User,
|
|
Log,
|
|
Bounty,
|
|
State,
|
|
Attacker,
|
|
AttackerBehavior,
|
|
SwarmHost,
|
|
DeckyShard,
|
|
Topology,
|
|
LAN,
|
|
TopologyDecky,
|
|
TopologyEdge,
|
|
TopologyStatusEvent,
|
|
TopologyMutation,
|
|
)
|
|
|
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
from decnet.logging import get_logger
|
|
|
|
# Logger shared by the session-cleanup helpers defined below.
_log = get_logger("db.pool")

# Strong references to in-flight cleanup tasks: the event loop does not keep
# tasks alive on its own, so an unreferenced task could be collected mid-run.
_cleanup_tasks: set[asyncio.Task] = set()
|
|
|
|
|
|
def _detach_close(session: AsyncSession) -> None:
    """Close ``session`` from a fresh task so the caller's cancellation can't interrupt it.

    ``asyncio.shield`` doesn't help on the exception path: shield prevents
    *other* tasks from cancelling the inner coroutine, but if the *current*
    task is already cancelled, its next ``await`` re-raises
    ``CancelledError`` as soon as the inner coroutine yields. That's what
    happens when uvicorn cancels a request mid-query — the rollback inside
    ``session.close()`` can't complete, and the aiomysql connection is
    orphaned (pool logs "non-checked-in connection" on GC).

    A fresh task isn't subject to the caller's pending cancellation, so
    ``close()`` (or the ``invalidate()`` fallback for a dead connection)
    runs to completion and the pool reclaims the connection promptly.

    Fire-and-forget on purpose: the caller is already unwinding and must
    not wait on cleanup. When no event loop is running, the session is
    invalidated synchronously instead.

    Args:
        session: The session whose connection must be released.
    """

    async def _cleanup() -> None:
        try:
            await session.close()
        except BaseException:
            # close() could not finish; drop the connection instead of
            # leaving it checked out.
            try:
                session.sync_session.invalidate()
            except BaseException:
                _log.debug("detach-close: invalidate failed", exc_info=True)

    def _on_done(finished: asyncio.Task) -> None:
        _cleanup_tasks.discard(finished)
        # Task.exception() raises CancelledError on a cancelled task; calling
        # it unconditionally would make this callback itself fail (and the
        # loop log "Exception in callback") when cleanup is cancelled, e.g.
        # during loop shutdown.
        if not finished.cancelled():
            # Retrieve any exception to silence
            # "Task exception was never retrieved".
            finished.exception()

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running loop (shutdown path) — best-effort sync invalidate.
        try:
            session.sync_session.invalidate()
        except BaseException:
            _log.debug("detach-close: no-loop invalidate failed", exc_info=True)
        return

    task = loop.create_task(_cleanup())
    _cleanup_tasks.add(task)
    task.add_done_callback(_on_done)
|
|
|
|
|
|
@asynccontextmanager
async def _safe_session(factory: async_sessionmaker[AsyncSession]):
    """Yield a session from ``factory`` and close it reliably, even under cancellation.

    If the managed body finishes normally, ``close()`` is awaited inline so
    the caller observes cleanup (commit visibility, connection release)
    before proceeding.

    If the body raises anything — including ``CancelledError`` from a client
    disconnect — ``close()`` is handed to a detached task and the exception
    is re-raised. The caller is unwinding, and its own cancellation would
    abort an inline close mid-rollback, leaving the aiomysql connection
    orphaned.
    """
    db = factory()
    try:
        yield db
    except BaseException:
        _detach_close(db)
        raise
    # Reached only when the body completed without raising.
    await db.close()
|
|
|
|
|
|
class SQLModelRepository(BaseRepository):
|
|
"""Concrete SQLModel/SQLAlchemy-async repository.
|
|
|
|
Subclasses provide ``self.engine`` (AsyncEngine) and ``self.session_factory``
|
|
in ``__init__``, and override the few dialect-specific helpers.
|
|
"""
|
|
|
|
engine: AsyncEngine
|
|
session_factory: async_sessionmaker[AsyncSession]
|
|
|
|
def _session(self):
    """Build a cancellation-safe session context manager from ``self.session_factory``."""
    factory = self.session_factory
    return _safe_session(factory)
|
|
|
|
# ------------------------------------------------------------ lifecycle
|
|
|
|
async def initialize(self) -> None:
    """Prepare the database for use.

    Runs the legacy ``attackers`` migration hook, then issues ``create_all``
    for the tables registered on ``SQLModel.metadata``, and finally ensures
    the admin user exists.
    """
    from sqlmodel import SQLModel

    await self._migrate_attackers_table()
    metadata = SQLModel.metadata
    async with self.engine.begin() as connection:
        await connection.run_sync(metadata.create_all)
    await self._ensure_admin_user()
|
|
|
|
async def reinitialize(self) -> None:
    """Run ``create_all`` again and re-check the admin user (tests / reset flows).

    Existing tables are NOT dropped, and the legacy migration hook is not
    invoked here.
    """
    from sqlmodel import SQLModel

    metadata = SQLModel.metadata
    async with self.engine.begin() as connection:
        await connection.run_sync(metadata.create_all)
    await self._ensure_admin_user()
|
|
|
|
async def _ensure_admin_user(self) -> None:
    """Create the admin account if missing, or re-sync its not-yet-finalized password.

    A missing admin is inserted with ``must_change_password=True``. An
    existing admin only has its hash rewritten while that flag is still
    set; once the user has chosen a password it is left alone.
    """
    lookup = select(User).where(User.username == DECNET_ADMIN_USER)
    async with self._session() as session:
        admin = (await session.execute(lookup)).scalar_one_or_none()

        if admin is not None:
            # Self-heal env drift: the password was never finalized, so
            # DECNET_ADMIN_PASSWORD remains the source of truth.
            if admin.must_change_password:
                admin.password_hash = get_password_hash(DECNET_ADMIN_PASSWORD)
                session.add(admin)
                await session.commit()
            return

        fresh_admin = User(
            uuid=str(uuid.uuid4()),
            username=DECNET_ADMIN_USER,
            password_hash=get_password_hash(DECNET_ADMIN_PASSWORD),
            role="admin",
            must_change_password=True,
        )
        session.add(fresh_admin)
        await session.commit()
|
|
|
|
async def _migrate_attackers_table(self) -> None:
|
|
"""Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable)."""
|
|
return None
|
|
|
|
# ---------------------------------------------------------------- logs
|
|
|
|
@staticmethod
|
|
def _normalize_log_row(log_data: dict[str, Any]) -> dict[str, Any]:
|
|
data = log_data.copy()
|
|
if "fields" in data and isinstance(data["fields"], dict):
|
|
data["fields"] = orjson.dumps(data["fields"]).decode()
|
|
if "timestamp" in data and isinstance(data["timestamp"], str):
|
|
try:
|
|
data["timestamp"] = datetime.fromisoformat(
|
|
data["timestamp"].replace("Z", "+00:00")
|
|
)
|
|
except ValueError:
|
|
pass
|
|
return data
|
|
|
|
async def add_log(self, log_data: dict[str, Any]) -> None:
    """Normalize one log dict and insert it as a ``Log`` row in its own commit."""
    entry = Log(**self._normalize_log_row(log_data))
    async with self._session() as session:
        session.add(entry)
        await session.commit()
|
|
|
|
async def add_logs(self, log_entries: list[dict[str, Any]]) -> None:
    """Insert many logs using one session and one commit for the whole batch.

    An empty ``log_entries`` returns immediately without opening a session.
    """
    if not log_entries:
        return

    normalize = self._normalize_log_row
    batch = [Log(**normalize(entry)) for entry in log_entries]
    async with self._session() as session:
        session.add_all(batch)
        await session.commit()
|
|
|
|
def _apply_filters(
    self,
    statement: SelectOfScalar,
    search: Optional[str],
    start_time: Optional[str],
    end_time: Optional[str],
) -> SelectOfScalar:
    """Add time-range and search predicates for ``Log`` to ``statement``.

    ``search`` is split shell-style (falling back to whitespace splitting
    when quoting is unbalanced). Each token narrows the result further:

    * ``key:value`` where ``key`` is a known column alias (``decky``,
      ``service``, ``event``, ``attacker``, ``attacker-ip``,
      ``attacker_ip``) → exact match on that column;
    * ``key:value`` with any other key → exact match on that key of the
      JSON ``fields`` blob (the key is reduced to ``[a-zA-Z0-9_]``; a key
      that becomes empty is ignored);
    * a bare token → substring ``LIKE`` over raw_line, decky, service and
      attacker_ip.

    Args:
        statement: The select to extend.
        search: Free-text / ``key:value`` query, or ``None``/empty.
        start_time: Inclusive lower bound for ``Log.timestamp``, if any.
        end_time: Inclusive upper bound for ``Log.timestamp``, if any.

    Returns:
        The statement with all applicable ``WHERE`` clauses added.
    """
    import re
    import shlex

    from sqlalchemy import bindparam

    if start_time:
        statement = statement.where(Log.timestamp >= start_time)
    if end_time:
        statement = statement.where(Log.timestamp <= end_time)

    if not search:
        return statement

    try:
        tokens = shlex.split(search)
    except ValueError:
        # Unbalanced quotes — degrade to plain whitespace splitting.
        tokens = search.split()

    core_fields = {
        "decky": Log.decky,
        "service": Log.service,
        "event": Log.event_type,
        "attacker": Log.attacker_ip,
        "attacker-ip": Log.attacker_ip,
        "attacker_ip": Log.attacker_ip,
    }

    for token in tokens:
        if ":" not in token:
            lk = f"%{token}%"
            statement = statement.where(
                or_(
                    Log.raw_line.like(lk),
                    Log.decky.like(lk),
                    Log.service.like(lk),
                    Log.attacker_ip.like(lk),
                )
            )
            continue

        key, val = token.split(":", 1)
        if key in core_fields:
            statement = statement.where(core_fields[key] == val)
            continue

        key_safe = re.sub(r"[^a-zA-Z0-9_]", "", key)
        if not key_safe:
            continue

        # Every JSON predicate uses the placeholder name ``:val``. Binding it
        # through ``statement.params(val=...)`` rewrote *all* earlier
        # predicates too, so with several custom ``key:value`` tokens each
        # one ended up comparing against the last value. A unique bind
        # parameter attached to this clause keeps the values separate.
        predicate = self._json_field_equals(key_safe).bindparams(
            bindparam("val", value=val, unique=True)
        )
        statement = statement.where(predicate)

    return statement
|
|
|
|
def _json_field_equals(self, key: str):
    """Build a ``text()`` predicate matching rows where ``fields->key == :val``.

    ``JSON_EXTRACT`` exists on both SQLite and MySQL (MySQL resolves the
    name case-insensitively). The value is not embedded here: ``:val`` is
    a bound parameter that the caller must supply separately, which keeps
    the value out of the SQL string and safe from injection.
    """
    predicate_sql = f"JSON_EXTRACT(fields, '$.{key}') = :val"
    return text(predicate_sql)
|
|
|
|
async def get_logs(
    self,
    limit: int = 50,
    offset: int = 0,
    search: Optional[str] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None,
) -> List[dict]:
    """Return one page of logs, newest first, as JSON-ready dicts.

    ``search``, ``start_time`` and ``end_time`` are applied through
    ``_apply_filters``.
    """
    newest_first = select(Log).order_by(desc(Log.timestamp))
    page = newest_first.offset(offset).limit(limit)
    query = self._apply_filters(page, search, start_time, end_time)

    async with self._session() as session:
        rows = (await session.execute(query)).scalars().all()
        return [row.model_dump(mode="json") for row in rows]
|
|
|
|
async def get_max_log_id(self) -> int:
    """Return the largest ``Log.id``, or ``0`` when the query yields no value."""
    highest_id_query = select(func.max(Log.id))
    async with self._session() as session:
        highest = (await session.execute(highest_id_query)).scalar()
        if highest is None:
            return 0
        return highest
|
|
|
|
async def get_logs_after_id(
    self,
    last_id: int,
    limit: int = 50,
    search: Optional[str] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None,
) -> List[dict]:
    """Return up to ``limit`` logs with ``id > last_id``, in ascending id order.

    The usual search / time filters from ``_apply_filters`` are applied on
    top.
    """
    newer = select(Log).where(Log.id > last_id)
    query = self._apply_filters(
        newer.order_by(asc(Log.id)).limit(limit),
        search,
        start_time,
        end_time,
    )

    async with self._session() as session:
        rows = (await session.execute(query)).scalars().all()
        return [row.model_dump(mode="json") for row in rows]
|
|
|
|
async def get_total_logs(
    self,
    search: Optional[str] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None,
) -> int:
    """Count the logs matching the given search / time filters."""
    counting = select(func.count()).select_from(Log)
    query = self._apply_filters(counting, search, start_time, end_time)

    async with self._session() as session:
        matched = (await session.execute(query)).scalar()
        return matched or 0
|
|
|
|
async def get_log_histogram(
    self,
    search: Optional[str] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None,
    interval_minutes: int = 15,
) -> List[dict]:
    """Dialect-specific placeholder; each backend must override it.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
|
|
|
|
async def get_stats_summary(self) -> dict[str, Any]:
    """Collect dashboard counters from the database and the fleet state file.

    Returns a dict with ``total_logs``, ``unique_attackers``,
    ``active_deckies`` and ``deployed_deckies``.
    """
    count_logs = select(func.count()).select_from(Log)
    count_attackers = select(func.count(func.distinct(Log.attacker_ip)))
    count_topo = select(func.count()).select_from(TopologyDecky)
    count_topo_running = (
        select(func.count())
        .select_from(TopologyDecky)
        .where(TopologyDecky.state == "running")
    )

    async with self._session() as session:
        total_logs = (await session.execute(count_logs)).scalar() or 0
        unique_attackers = (await session.execute(count_attackers)).scalar() or 0
        topo_total = (await session.execute(count_topo)).scalar() or 0
        topo_running = (await session.execute(count_topo_running)).scalar() or 0

    # load_state is run in a worker thread rather than on the event loop.
    state = await asyncio.to_thread(load_state)
    fleet_deckies = len(state[0].deckies) if state else 0

    summary: dict[str, Any] = {
        "total_logs": total_logs,
        "unique_attackers": unique_attackers,
    }
    # Fleet state file doesn't track per-decky runtime; treat all fleet rows
    # as active and add MazeNET running rows on top.
    summary["active_deckies"] = fleet_deckies + topo_running
    summary["deployed_deckies"] = fleet_deckies + topo_total
    return summary
|
|
|
|
async def get_deckies(self) -> List[dict]:
    """Return the fleet deckies from the state file as dicts (empty if no state)."""
    state = await asyncio.to_thread(load_state)
    if not state:
        return []
    return [decky.model_dump() for decky in state[0].deckies]
|
|
|
|
# --------------------------------------------------------------- users
|
|
|
|
async def get_user_by_username(self, username: str) -> Optional[dict]:
    """Fetch the user with this exact username as a dict, or ``None``."""
    lookup = select(User).where(User.username == username)
    async with self._session() as session:
        match = (await session.execute(lookup)).scalar_one_or_none()
        if not match:
            return None
        return match.model_dump()
|
|
|
|
async def get_user_by_uuid(self, uuid: str) -> Optional[dict]:
    """Fetch the user with this UUID as a dict, or ``None``."""
    lookup = select(User).where(User.uuid == uuid)
    async with self._session() as session:
        match = (await session.execute(lookup)).scalar_one_or_none()
        if not match:
            return None
        return match.model_dump()
|
|
|
|
async def create_user(self, user_data: dict[str, Any]) -> None:
    """Insert a ``User`` built from ``user_data`` and commit."""
    new_user = User(**user_data)
    async with self._session() as session:
        session.add(new_user)
        await session.commit()
|
|
|
|
async def update_user_password(
    self, uuid: str, password_hash: str, must_change_password: bool = False
) -> None:
    """Set a user's password hash and ``must_change_password`` flag in one UPDATE."""
    new_values = {
        "password_hash": password_hash,
        "must_change_password": must_change_password,
    }
    change = update(User).where(User.uuid == uuid).values(**new_values)
    async with self._session() as session:
        await session.execute(change)
        await session.commit()
|
|
|
|
async def list_users(self) -> list[dict]:
    """Return every user row as a dict."""
    async with self._session() as session:
        users = (await session.execute(select(User))).scalars().all()
        return [user.model_dump() for user in users]
|
|
|
|
async def delete_user(self, uuid: str) -> bool:
    """Delete the user with this UUID.

    Returns:
        ``True`` if a user was found and deleted, ``False`` otherwise.
    """
    lookup = select(User).where(User.uuid == uuid)
    async with self._session() as session:
        target = (await session.execute(lookup)).scalar_one_or_none()
        if target:
            await session.delete(target)
            await session.commit()
            return True
        return False
|
|
|
|
async def update_user_role(self, uuid: str, role: str) -> None:
    """Set the ``role`` column of the user with this UUID and commit."""
    change = update(User).where(User.uuid == uuid).values(role=role)
    async with self._session() as session:
        await session.execute(change)
        await session.commit()
|
|
|
|
async def purge_logs_and_bounties(self) -> dict[str, int]:
    """Delete all logs, bounties, attacker behaviors and attackers in one commit.

    Returns:
        The reported row counts under ``"logs"``, ``"bounties"`` and
        ``"attackers"`` (attacker_behavior rows are not counted).
    """
    async with self._session() as session:
        deleted: dict[str, int] = {}
        deleted["logs"] = (await session.execute(text("DELETE FROM logs"))).rowcount
        deleted["bounties"] = (
            await session.execute(text("DELETE FROM bounty"))
        ).rowcount
        # attacker_behavior has FK → attackers.uuid; delete children first.
        await session.execute(text("DELETE FROM attacker_behavior"))
        deleted["attackers"] = (
            await session.execute(text("DELETE FROM attackers"))
        ).rowcount
        await session.commit()
        return deleted
|
|
|
|
# ------------------------------------------------------------ bounties
|
|
|
|
async def add_bounty(self, bounty_data: dict[str, Any]) -> None:
    """Insert a bounty unless an identical one is already stored.

    A dict ``payload`` is serialized to JSON first. A row with the same
    ``bounty_type``, ``attacker_ip`` and (serialized) ``payload`` counts as
    a duplicate, in which case nothing is written.
    """
    row = bounty_data.copy()
    payload = row.get("payload")
    if isinstance(payload, dict):
        row["payload"] = orjson.dumps(payload).decode()

    duplicate_probe = (
        select(Bounty.id)
        .where(
            Bounty.bounty_type == row.get("bounty_type"),
            Bounty.attacker_ip == row.get("attacker_ip"),
            Bounty.payload == row.get("payload"),
        )
        .limit(1)
    )

    async with self._session() as session:
        already_there = (await session.execute(duplicate_probe)).first()
        if already_there is not None:
            return
        session.add(Bounty(**row))
        await session.commit()
|
|
|
|
def _apply_bounty_filters(
    self,
    statement: SelectOfScalar,
    bounty_type: Optional[str],
    search: Optional[str],
) -> SelectOfScalar:
    """Narrow a bounty query by exact type and/or a substring search.

    ``search`` is matched with ``LIKE %search%`` against decky, service,
    attacker_ip and payload; a hit in any one of them is enough.
    """
    if bounty_type:
        statement = statement.where(Bounty.bounty_type == bounty_type)

    if not search:
        return statement

    pattern = f"%{search}%"
    searchable = (Bounty.decky, Bounty.service, Bounty.attacker_ip, Bounty.payload)
    return statement.where(or_(*(column.like(pattern) for column in searchable)))
|
|
|
|
async def get_bounties(
    self,
    limit: int = 50,
    offset: int = 0,
    bounty_type: Optional[str] = None,
    search: Optional[str] = None,
) -> List[dict]:
    """Return one page of bounties, newest first, with ``payload`` JSON-decoded.

    A payload that is not valid JSON (or not a string) is returned unchanged.
    """

    def _decoded(bounty) -> dict:
        record = bounty.model_dump(mode="json")
        try:
            record["payload"] = json.loads(record["payload"])
        except (json.JSONDecodeError, TypeError):
            pass  # keep the stored value as-is
        return record

    page = select(Bounty).order_by(desc(Bounty.timestamp)).offset(offset).limit(limit)
    query = self._apply_bounty_filters(page, bounty_type, search)

    async with self._session() as session:
        bounties = (await session.execute(query)).scalars().all()
        return [_decoded(bounty) for bounty in bounties]
|
|
|
|
async def get_total_bounties(
    self, bounty_type: Optional[str] = None, search: Optional[str] = None
) -> int:
    """Count the bounties matching the given type / search filters."""
    counting = select(func.count()).select_from(Bounty)
    query = self._apply_bounty_filters(counting, bounty_type, search)

    async with self._session() as session:
        matched = (await session.execute(query)).scalar()
        return matched or 0
|
|
|
|
async def get_state(self, key: str) -> Optional[dict[str, Any]]:
    """Load the JSON value stored under ``key`` in the ``State`` table.

    Returns:
        The decoded value, or ``None`` when no row exists for ``key``.
    """
    lookup = select(State).where(State.key == key)
    async with self._session() as session:
        entry = (await session.execute(lookup)).scalar_one_or_none()
        return json.loads(entry.value) if entry else None
|
|
|
|
async def set_state(self, key: str, value: Any) -> None:  # noqa: ANN401
    """Store ``value`` as JSON under ``key``, updating the row if it already exists."""
    lookup = select(State).where(State.key == key)
    async with self._session() as session:
        entry = (await session.execute(lookup)).scalar_one_or_none()

        encoded = orjson.dumps(value).decode()
        if not entry:
            entry = State(key=key, value=encoded)
        else:
            entry.value = encoded
        session.add(entry)

        await session.commit()
|
|
|
|
# ----------------------------------------------------------- attackers
|
|
|
|
async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]:
    """Return every bounty grouped by attacker IP, oldest first within each group.

    Each bounty is a JSON-ready dict whose ``payload`` is decoded when it
    holds valid JSON and left untouched otherwise.
    """
    from collections import defaultdict

    oldest_first = select(Bounty).order_by(asc(Bounty.timestamp))
    async with self._session() as session:
        bounties = (await session.execute(oldest_first)).scalars().all()
        by_ip: dict[str, List[dict[str, Any]]] = defaultdict(list)
        for bounty in bounties:
            record = bounty.model_dump(mode="json")
            try:
                record["payload"] = json.loads(record["payload"])
            except (json.JSONDecodeError, TypeError):
                pass  # not JSON — keep the stored value
            by_ip[bounty.attacker_ip].append(record)
        # Hand back a plain dict so lookups of unknown IPs don't create keys.
        return dict(by_ip)
|
|
|
|
async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]:
    """Return the bounties of the given attacker IPs, grouped by IP.

    Args:
        ips: Attacker IPs to look up.

    Returns:
        A dict mapping each IP that has bounties to its bounties (oldest
        first) as JSON-ready dicts. ``payload`` is JSON-decoded when
        possible and left untouched otherwise. IPs without bounties are
        absent from the result.
    """
    # Nothing to look up: skip the session and the empty ``IN ()`` query,
    # mirroring the guard in get_behaviors_for_ips.
    if not ips:
        return {}

    from collections import defaultdict

    query = (
        select(Bounty)
        .where(Bounty.attacker_ip.in_(ips))
        .order_by(asc(Bounty.timestamp))
    )
    async with self._session() as session:
        result = await session.execute(query)
        grouped: dict[str, List[dict[str, Any]]] = defaultdict(list)
        for item in result.scalars().all():
            d = item.model_dump(mode="json")
            try:
                d["payload"] = json.loads(d["payload"])
            except (json.JSONDecodeError, TypeError):
                pass  # not JSON — keep the stored value
            grouped[item.attacker_ip].append(d)
        return dict(grouped)
|
|
|
|
async def upsert_attacker(self, data: dict[str, Any]) -> str:
    """Insert or update the attacker row identified by ``data["ip"]``.

    An existing row has every key of ``data`` assigned onto it; otherwise a
    new row is created with a freshly generated UUID.

    Returns:
        The UUID of the updated or newly created row.
    """
    lookup = select(Attacker).where(Attacker.ip == data["ip"])
    async with self._session() as session:
        current = (await session.execute(lookup)).scalar_one_or_none()
        if not current:
            row_uuid = str(uuid.uuid4())
            session.add(Attacker(**{**data, "uuid": row_uuid}))
        else:
            for column, value in data.items():
                setattr(current, column, value)
            session.add(current)
            row_uuid = current.uuid
        await session.commit()
        return row_uuid
|
|
|
|
async def upsert_attacker_behavior(
    self,
    attacker_uuid: str,
    data: dict[str, Any],
) -> None:
    """Insert or update the behavior row of one attacker.

    ``updated_at`` is always set to the current UTC time, overriding any
    value supplied in ``data``.
    """
    lookup = select(AttackerBehavior).where(
        AttackerBehavior.attacker_uuid == attacker_uuid
    )
    async with self._session() as session:
        current = (await session.execute(lookup)).scalar_one_or_none()
        values = {**data, "updated_at": datetime.now(timezone.utc)}
        if not current:
            session.add(AttackerBehavior(attacker_uuid=attacker_uuid, **values))
        else:
            for column, value in values.items():
                setattr(current, column, value)
            session.add(current)
        await session.commit()
|
|
|
|
async def get_attacker_behavior(
    self,
    attacker_uuid: str,
) -> Optional[dict[str, Any]]:
    """Return the behavior row of one attacker with its JSON columns decoded.

    Returns ``None`` when the attacker has no behavior row.
    """
    lookup = select(AttackerBehavior).where(
        AttackerBehavior.attacker_uuid == attacker_uuid
    )
    async with self._session() as session:
        behavior = (await session.execute(lookup)).scalar_one_or_none()
        if behavior:
            return self._deserialize_behavior(behavior.model_dump(mode="json"))
        return None
|
|
|
|
async def get_behaviors_for_ips(
    self,
    ips: set[str],
) -> dict[str, dict[str, Any]]:
    """Return decoded behavior rows keyed by attacker IP for the given IPs.

    An empty ``ips`` returns ``{}`` without touching the database. IPs with
    no behavior row are absent from the result.
    """
    if not ips:
        return {}

    query = (
        select(Attacker.ip, AttackerBehavior)
        .join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid)
        .where(Attacker.ip.in_(ips))
    )
    async with self._session() as session:
        pairs = (await session.execute(query)).all()
        return {
            ip: self._deserialize_behavior(behavior.model_dump(mode="json"))
            for ip, behavior in pairs
        }
|
|
|
|
@staticmethod
|
|
def _deserialize_behavior(d: dict[str, Any]) -> dict[str, Any]:
|
|
for key in ("tcp_fingerprint", "timing_stats", "phase_sequence"):
|
|
if isinstance(d.get(key), str):
|
|
try:
|
|
d[key] = json.loads(d[key])
|
|
except (json.JSONDecodeError, TypeError):
|
|
pass
|
|
# Deserialize tool_guesses JSON array; normalise None → [].
|
|
raw = d.get("tool_guesses")
|
|
if isinstance(raw, str):
|
|
try:
|
|
parsed = json.loads(raw)
|
|
d["tool_guesses"] = parsed if isinstance(parsed, list) else [parsed]
|
|
except (json.JSONDecodeError, TypeError):
|
|
d["tool_guesses"] = []
|
|
elif raw is None:
|
|
d["tool_guesses"] = []
|
|
return d
|
|
|
|
@staticmethod
|
|
def _deserialize_attacker(d: dict[str, Any]) -> dict[str, Any]:
|
|
for key in ("services", "deckies", "fingerprints", "commands"):
|
|
if isinstance(d.get(key), str):
|
|
try:
|
|
d[key] = json.loads(d[key])
|
|
except (json.JSONDecodeError, TypeError):
|
|
pass
|
|
return d
|
|
|
|
async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
    """Return the attacker with this UUID (JSON columns decoded), or ``None``."""
    lookup = select(Attacker).where(Attacker.uuid == uuid)
    async with self._session() as session:
        found = (await session.execute(lookup)).scalar_one_or_none()
        if found:
            return self._deserialize_attacker(found.model_dump(mode="json"))
        return None
|
|
|
|
async def get_attackers(
    self,
    limit: int = 50,
    offset: int = 0,
    search: Optional[str] = None,
    sort_by: str = "recent",
    service: Optional[str] = None,
) -> List[dict[str, Any]]:
    """Return one page of attackers with their JSON columns decoded.

    Args:
        limit: Page size.
        offset: Rows to skip.
        search: Substring to match against the attacker IP.
        sort_by: ``"active"`` orders by event count, ``"traversals"`` by
            the traversal flag; any other value orders by last-seen time.
            All orderings are descending.
        service: Keep only attackers whose serialized ``services`` column
            contains this name as a quoted string.
    """
    if sort_by == "active":
        ordering = desc(Attacker.event_count)
    elif sort_by == "traversals":
        ordering = desc(Attacker.is_traversal)
    else:
        ordering = desc(Attacker.last_seen)

    query = select(Attacker).order_by(ordering).offset(offset).limit(limit)
    if search:
        query = query.where(Attacker.ip.like(f"%{search}%"))
    if service:
        query = query.where(Attacker.services.like(f'%"{service}"%'))

    async with self._session() as session:
        attackers = (await session.execute(query)).scalars().all()
        decoded = []
        for attacker in attackers:
            decoded.append(self._deserialize_attacker(attacker.model_dump(mode="json")))
        return decoded
|
|
|
|
async def get_total_attackers(
    self, search: Optional[str] = None, service: Optional[str] = None
) -> int:
    """Count attackers using the same IP / service filters as ``get_attackers``."""
    query = select(func.count()).select_from(Attacker)
    if search:
        ip_pattern = f"%{search}%"
        query = query.where(Attacker.ip.like(ip_pattern))
    if service:
        service_pattern = f'%"{service}"%'
        query = query.where(Attacker.services.like(service_pattern))

    async with self._session() as session:
        matched = (await session.execute(query)).scalar()
        return matched or 0
|
|
|
|
async def get_attacker_commands(
    self,
    uuid: str,
    limit: int = 50,
    offset: int = 0,
    service: Optional[str] = None,
) -> dict[str, Any]:
    """Return a page of the commands recorded for one attacker.

    The ``commands`` column is decoded from JSON when stored as text, then
    optionally filtered by ``service`` and sliced in memory.

    Returns:
        ``{"total": <count after filtering>, "data": <requested slice>}``;
        both are empty when the lookup yields no commands value.
    """
    lookup = select(Attacker.commands).where(Attacker.uuid == uuid)
    async with self._session() as session:
        stored = (await session.execute(lookup)).scalar_one_or_none()
        if stored is None:
            return {"total": 0, "data": []}

        if isinstance(stored, str):
            commands: list = json.loads(stored)
        else:
            commands = stored

        if service:
            commands = [cmd for cmd in commands if cmd.get("service") == service]

        window = commands[offset: offset + limit]
        return {"total": len(commands), "data": window}
|
|
|
|
async def get_attacker_artifacts(self, uuid: str) -> list[dict[str, Any]]:
    """Return `file_captured` logs for the attacker identified by UUID.

    The attacker's IP is resolved first; the logs table is then filtered on
    ``attacker_ip`` and ``event_type`` only, so no JSON extraction is
    involved. At most 200 rows are returned, newest first. An unknown UUID
    yields an empty list.
    """
    ip_lookup = select(Attacker.ip).where(Attacker.uuid == uuid)
    async with self._session() as session:
        ip = (await session.execute(ip_lookup)).scalar_one_or_none()
        if not ip:
            return []

        captured = (
            select(Log)
            .where(Log.attacker_ip == ip, Log.event_type == "file_captured")
            .order_by(desc(Log.timestamp))
            .limit(200)
        )
        logs = (await session.execute(captured)).scalars().all()
        return [log.model_dump(mode="json") for log in logs]
|
|
|
|
async def get_session_log(self, sid: str) -> Optional[dict[str, Any]]:
    """Look up the `session_recorded` Log row that owns a given sid.

    sid is a v4 UUID embedded in the row's ``fields`` JSON blob. Matched
    with LIKE on the textual sid substring — cheap given the bounded
    cardinality of session_recorded rows vs. the full logs table.

    Args:
        sid: Session id to search for. It is matched literally.

    Returns:
        The first matching log as a JSON-ready dict, or ``None``.
    """
    needle = f'"sid":"{sid}"'
    query = (
        select(Log)
        .where(Log.event_type == "session_recorded")
        # autoescape makes "%" and "_" inside sid match themselves; without
        # it they act as LIKE wildcards and a crafted sid could match a
        # different session's row.
        .where(Log.fields.contains(needle, autoescape=True))
        .limit(1)
    )
    async with self._session() as session:
        rows = await session.execute(query)
        row = rows.scalars().first()
        return row.model_dump(mode="json") if row else None
|
|
|
|
async def get_attacker_transcripts(self, uuid: str) -> list[dict[str, Any]]:
    """Return `session_recorded` logs for the attacker identified by UUID.

    Works like :meth:`get_attacker_artifacts`: the attacker's IP is
    resolved, then the logs table is filtered on ``attacker_ip`` and
    ``event_type``. At most 200 rows are returned, newest first. An unknown
    UUID yields an empty list.
    """
    ip_lookup = select(Attacker.ip).where(Attacker.uuid == uuid)
    async with self._session() as session:
        ip = (await session.execute(ip_lookup)).scalar_one_or_none()
        if not ip:
            return []

        recorded = (
            select(Log)
            .where(Log.attacker_ip == ip, Log.event_type == "session_recorded")
            .order_by(desc(Log.timestamp))
            .limit(200)
        )
        logs = (await session.execute(recorded)).scalars().all()
        return [log.model_dump(mode="json") for log in logs]
|
|
|
|
# ------------------------------------------------------------- swarm
|
|
|
|
async def add_swarm_host(self, data: dict[str, Any]) -> None:
    """Insert a ``SwarmHost`` built from ``data`` and commit."""
    host = SwarmHost(**data)
    async with self._session() as session:
        session.add(host)
        await session.commit()
|
|
|
|
async def get_swarm_host_by_name(self, name: str) -> Optional[dict[str, Any]]:
    """Return the swarm host with this name as a JSON-ready dict, or ``None``."""
    lookup = select(SwarmHost).where(SwarmHost.name == name)
    async with self._session() as session:
        host = (await session.execute(lookup)).scalar_one_or_none()
        if not host:
            return None
        return host.model_dump(mode="json")
|
|
|
|
async def get_swarm_host_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
    """Return the swarm host with this UUID as a JSON-ready dict, or ``None``."""
    lookup = select(SwarmHost).where(SwarmHost.uuid == uuid)
    async with self._session() as session:
        host = (await session.execute(lookup)).scalar_one_or_none()
        if not host:
            return None
        return host.model_dump(mode="json")
|
|
|
|
async def get_swarm_host_by_fingerprint(self, fingerprint: str) -> Optional[dict[str, Any]]:
    """Return the swarm host whose client-certificate fingerprint matches, or ``None``."""
    lookup = select(SwarmHost).where(
        SwarmHost.client_cert_fingerprint == fingerprint
    )
    async with self._session() as session:
        host = (await session.execute(lookup)).scalar_one_or_none()
        if not host:
            return None
        return host.model_dump(mode="json")
|
|
|
|
async def list_swarm_hosts(self, status: Optional[str] = None) -> list[dict[str, Any]]:
    """List swarm hosts ordered by name, optionally restricted to one status."""
    query = select(SwarmHost).order_by(asc(SwarmHost.name))
    if status:
        query = query.where(SwarmHost.status == status)

    async with self._session() as session:
        hosts = (await session.execute(query)).scalars().all()
        return [host.model_dump(mode="json") for host in hosts]
|
|
|
|
async def update_swarm_host(self, uuid: str, fields: dict[str, Any]) -> None:
    """Apply ``fields`` to the swarm host with this UUID.

    An empty ``fields`` returns immediately without opening a session.
    """
    if not fields:
        return

    change = update(SwarmHost).where(SwarmHost.uuid == uuid).values(**fields)
    async with self._session() as session:
        await session.execute(change)
        await session.commit()
|
|
|
|
async def delete_swarm_host(self, uuid: str) -> bool:
    """Delete a swarm host together with its decky shards.

    The shard rows are removed (and committed) even when no host row with
    this UUID exists.

    Returns:
        ``True`` if a host row was found and deleted, ``False`` otherwise.
    """
    async with self._session() as session:
        # Clean up child shards first (no ON DELETE CASCADE portable across dialects).
        await session.execute(
            text("DELETE FROM decky_shards WHERE host_uuid = :u"), {"u": uuid}
        )
        lookup = select(SwarmHost).where(SwarmHost.uuid == uuid)
        host = (await session.execute(lookup)).scalar_one_or_none()

        host_found = bool(host)
        if host_found:
            await session.delete(host)
        await session.commit()
        return host_found
|
|
|
|
async def upsert_decky_shard(self, data: dict[str, Any]) -> None:
    """Insert or update the shard row identified by ``data["decky_name"]``.

    ``updated_at`` is set to the current UTC time, and a list under
    ``services`` is stored as a JSON string.
    """
    values = {**data, "updated_at": datetime.now(timezone.utc)}
    services = values.get("services")
    if isinstance(services, list):
        values["services"] = orjson.dumps(services).decode()

    lookup = select(DeckyShard).where(DeckyShard.decky_name == values["decky_name"])
    async with self._session() as session:
        current = (await session.execute(lookup)).scalar_one_or_none()
        if not current:
            session.add(DeckyShard(**values))
        else:
            for column, value in values.items():
                setattr(current, column, value)
            session.add(current)
        await session.commit()
|
|
|
|
async def list_decky_shards(
    self, host_uuid: Optional[str] = None
) -> list[dict[str, Any]]:
    """List decky shards ordered by decky name, optionally for a single host.

    Each row is returned as a JSON-ready dict with two adjustments:

    * ``services`` stored as text is JSON-decoded (``[]`` if undecodable);
    * when ``decky_config`` holds a JSON object, selected keys of that
      snapshot are copied into the row wherever the row has no value yet,
      and ``decky_ip`` falls back to the snapshot's ``ip``.
    """
    # Snapshot keys copied into the row when the row's own value is None.
    snapshot_keys = (
        "hostname",
        "distro",
        "archetype",
        "service_config",
        "mutate_interval",
        "last_mutated",
    )

    query = select(DeckyShard).order_by(asc(DeckyShard.decky_name))
    if host_uuid:
        query = query.where(DeckyShard.host_uuid == host_uuid)

    async with self._session() as session:
        shards = (await session.execute(query)).scalars().all()
        rows: list[dict[str, Any]] = []
        for shard in shards:
            row = shard.model_dump(mode="json")

            services = row.get("services")
            if isinstance(services, str):
                try:
                    row["services"] = json.loads(services)
                except (json.JSONDecodeError, TypeError):
                    row["services"] = []

            # Rows without a stored snapshot (decky_config NULL) skip the
            # flattening below and keep their column values as they are.
            snapshot_text = row.get("decky_config")
            if isinstance(snapshot_text, str):
                try:
                    snapshot = json.loads(snapshot_text)
                except (json.JSONDecodeError, TypeError):
                    snapshot = {}
                if isinstance(snapshot, dict):
                    for name in snapshot_keys:
                        if name in snapshot and row.get(name) is None:
                            row[name] = snapshot[name]
                    # The decky_ip column stays authoritative; the snapshot
                    # is only consulted while the column is still empty.
                    if not row.get("decky_ip") and snapshot.get("ip"):
                        row["decky_ip"] = snapshot["ip"]

            rows.append(row)
        return rows
|
|
|
|
async def delete_decky_shards_for_host(self, host_uuid: str) -> int:
    """Delete every ``decky_shards`` row for ``host_uuid``.

    Returns the reported row count, or ``0`` when that count is falsy.
    """
    statement = text("DELETE FROM decky_shards WHERE host_uuid = :u")
    async with self._session() as session:
        outcome = await session.execute(statement, {"u": host_uuid})
        await session.commit()
        removed = outcome.rowcount
        return removed if removed else 0
|
|
|
|
async def delete_decky_shard(self, decky_name: str) -> bool:
    """Delete the ``decky_shards`` row(s) named ``decky_name``.

    Returns ``True`` when the reported row count is truthy.
    """
    statement = text("DELETE FROM decky_shards WHERE decky_name = :n")
    async with self._session() as session:
        outcome = await session.execute(statement, {"n": decky_name})
        await session.commit()
        return bool(outcome.rowcount)
|
|
|
|
# ------------------------------------------------------------ mazenet
|
|
|
|
@staticmethod
|
|
def _serialize_json_fields(data: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]:
|
|
out = dict(data)
|
|
for k in keys:
|
|
v = out.get(k)
|
|
if v is not None and not isinstance(v, str):
|
|
out[k] = orjson.dumps(v).decode()
|
|
return out
|
|
|
|
@staticmethod
|
|
def _deserialize_json_fields(d: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]:
|
|
for k in keys:
|
|
v = d.get(k)
|
|
if isinstance(v, str):
|
|
try:
|
|
d[k] = json.loads(v)
|
|
except (json.JSONDecodeError, TypeError):
|
|
pass
|
|
return d
|
|
|
|
async def create_topology(self, data: dict[str, Any]) -> str:
    """Insert a ``Topology`` row built from ``data`` and return its id.

    ``config_snapshot`` is JSON-encoded first when it is not already a
    string. The row is refreshed after commit before its id is read.
    """
    serialized = self._serialize_json_fields(data, ("config_snapshot",))
    async with self._session() as session:
        topology = Topology(**serialized)
        session.add(topology)
        await session.commit()
        await session.refresh(topology)
        new_id = topology.id
        return new_id
|
|
|
|
async def get_topology(self, topology_id: str) -> Optional[dict[str, Any]]:
    """Fetch one topology as a JSON-friendly dict, or ``None`` if absent.

    ``config_snapshot`` is JSON-decoded when it is stored as a string.
    """
    lookup = select(Topology).where(Topology.id == topology_id)
    async with self._session() as session:
        topology = (await session.execute(lookup)).scalar_one_or_none()
        if not topology:
            return None
        return self._deserialize_json_fields(
            topology.model_dump(mode="json"), ("config_snapshot",)
        )
|
|
|
|
async def list_topologies(
    self,
    status: Optional[str] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> list[dict[str, Any]]:
    """List topologies, newest ``created_at`` first, as JSON-friendly dicts.

    ``status`` filters only when truthy; ``offset`` and ``limit`` apply
    whenever they are not ``None`` (so ``0`` is honoured). Each row's
    ``config_snapshot`` is JSON-decoded when stored as a string.
    """
    query = select(Topology).order_by(desc(Topology.created_at))
    if status:
        query = query.where(Topology.status == status)
    if offset is not None:
        query = query.offset(offset)
    if limit is not None:
        query = query.limit(limit)
    async with self._session() as session:
        fetched = await session.execute(query)
        topologies: list[dict[str, Any]] = []
        for record in fetched.scalars().all():
            dumped = record.model_dump(mode="json")
            topologies.append(
                self._deserialize_json_fields(dumped, ("config_snapshot",))
            )
        return topologies
|
|
|
|
async def count_topologies(self, status: Optional[str] = None) -> int:
    """Return the number of topologies, optionally restricted to ``status``.

    ``status`` filters only when truthy. A falsy scalar result is
    reported as ``0``.
    """
    # ``func`` comes from the module-level sqlalchemy import; the former
    # function-local re-import was redundant.
    statement = select(func.count(Topology.id))
    if status:
        statement = statement.where(Topology.status == status)
    async with self._session() as session:
        result = await session.execute(statement)
        return int(result.scalar_one() or 0)
|
|
|
|
async def update_topology_status(
    self,
    topology_id: str,
    new_status: str,
    reason: Optional[str] = None,
) -> None:
    """Change a topology's status and record the transition in one commit.

    The topology's ``status`` and ``status_changed_at`` are overwritten and
    a ``TopologyStatusEvent`` (previous status, new status, timestamp,
    reason) is added in the same session. An unknown ``topology_id`` is a
    silent no-op.

    This method does not validate the transition; that is the job of
    ``decnet.topology.status`` and the caller is trusted.
    """
    changed_at = datetime.now(timezone.utc)
    lookup = select(Topology).where(Topology.id == topology_id)
    async with self._session() as session:
        topology = (await session.execute(lookup)).scalar_one_or_none()
        if topology is None:
            return
        # Capture the previous status before it is overwritten below.
        previous_status = topology.status
        topology.status = new_status
        topology.status_changed_at = changed_at
        session.add(topology)
        event = TopologyStatusEvent(
            topology_id=topology_id,
            from_status=previous_status,
            to_status=new_status,
            at=changed_at,
            reason=reason,
        )
        session.add(event)
        await session.commit()
|
|
|
|
async def set_topology_resync(self, topology_id: str, value: bool) -> None:
    """Set ``needs_resync`` on a topology to ``bool(value)``.

    Does nothing when the topology does not exist.
    """
    lookup = select(Topology).where(Topology.id == topology_id)
    async with self._session() as session:
        topology = (await session.execute(lookup)).scalar_one_or_none()
        if topology is None:
            return
        topology.needs_resync = bool(value)
        session.add(topology)
        await session.commit()
|
|
|
|
async def list_topologies_needing_resync(self) -> list[dict[str, Any]]:
    """Return every topology whose ``needs_resync`` flag is true.

    Rows are JSON-friendly dicts with ``config_snapshot`` decoded when it
    is stored as a string.
    """
    flagged = select(Topology).where(Topology.needs_resync == True)  # noqa: E712
    async with self._session() as session:
        fetched = await session.execute(flagged)
        pending_resync: list[dict[str, Any]] = []
        for record in fetched.scalars().all():
            dumped = record.model_dump(mode="json")
            pending_resync.append(
                self._deserialize_json_fields(dumped, ("config_snapshot",))
            )
        return pending_resync
|
|
|
|
async def delete_topology_cascade(self, topology_id: str) -> bool:
    """Delete a topology together with all of its child rows.

    The cascade is done by hand (no portable ON DELETE CASCADE): status
    events, edges, deckies, LANs and queued mutations are removed first,
    then the topology row itself. Child deletes run — and are committed —
    even when the topology row is missing.

    Returns ``True`` if a topology row was deleted, ``False`` otherwise.
    """
    # Child tables are cleared in this order before the parent row.
    child_deletes = (
        "DELETE FROM topology_status_events WHERE topology_id = :t",
        "DELETE FROM topology_edges WHERE topology_id = :t",
        "DELETE FROM topology_deckies WHERE topology_id = :t",
        "DELETE FROM lans WHERE topology_id = :t",
        "DELETE FROM topology_mutations WHERE topology_id = :t",
    )
    async with self._session() as session:
        bind = {"t": topology_id}
        for statement in child_deletes:
            await session.execute(text(statement), bind)
        found = await session.execute(
            select(Topology).where(Topology.id == topology_id)
        )
        topology = found.scalar_one_or_none()
        if topology:
            await session.delete(topology)
            await session.commit()
            return True
        await session.commit()
        return False
|
|
|
|
async def _assert_pending(self, session, topology_id: str) -> None:
    """Guard for pre-deploy edits: the topology must still be pending.

    Runs inside the caller's ``session``.

    Raises:
        ValueError: the topology does not exist.
        TopologyNotEditable: the topology's status is not ``PENDING``.
    """
    from decnet.topology.status import TopologyNotEditable, TopologyStatus

    lookup = select(Topology).where(Topology.id == topology_id)
    topology = (await session.execute(lookup)).scalar_one_or_none()
    if topology is None:
        raise ValueError(f"topology {topology_id!r} not found")
    current_status = topology.status
    if current_status != TopologyStatus.PENDING:
        raise TopologyNotEditable(
            status=current_status,
            reason="free-form edits are pending-only; use the "
                   "mutator (topology_mutations) after deploy",
        )
|
|
|
|
async def _check_and_bump_version(
    self,
    session,
    topology_id: str,
    expected_version: Optional[int],
) -> None:
    """Optimistic-concurrency check shared by the child-row mutators.

    With ``expected_version=None`` nothing is checked or changed (kept for
    internal callers that do not need concurrency protection).

    Otherwise the topology is loaded in ``session`` and its ``version`` must
    equal ``expected_version``; on success the version is incremented by
    one and the row is added to the session. This method never commits —
    the caller owns the enclosing transaction.

    Raises:
        ValueError: the topology does not exist.
        VersionConflict: the stored version differs from the expected one.
    """
    from decnet.topology.status import VersionConflict

    if expected_version is None:
        return
    lookup = select(Topology).where(Topology.id == topology_id)
    topology = (await session.execute(lookup)).scalar_one_or_none()
    if topology is None:
        raise ValueError(f"topology {topology_id!r} not found")
    stored_version = topology.version
    if stored_version != expected_version:
        raise VersionConflict(
            current=stored_version, expected=expected_version
        )
    topology.version = stored_version + 1
    session.add(topology)
|
|
|
|
async def add_lan(
    self,
    data: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
) -> str:
    """Insert a ``LAN`` row built from ``data`` and return its id.

    ``data`` must contain ``topology_id``. When ``expected_version`` is
    given, the topology version is checked and bumped in the same
    transaction before the row is created.
    """
    async with self._session() as session:
        owner_id = data["topology_id"]
        await self._check_and_bump_version(session, owner_id, expected_version)
        lan = LAN(**data)
        session.add(lan)
        await session.commit()
        await session.refresh(lan)
        return lan.id
|
|
|
|
async def update_lan(
    self,
    lan_id: str,
    fields: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
    enforce_pending: bool = False,
) -> None:
    """Apply ``fields`` to the LAN identified by ``lan_id``.

    An empty ``fields`` mapping returns immediately without touching the
    database. With ``enforce_pending`` the owning topology must be pending;
    with ``expected_version`` the topology version is checked and bumped.

    Raises:
        ValueError: no LAN with ``lan_id`` exists.
    """
    if not fields:
        return
    async with self._session() as session:
        found = await session.execute(select(LAN).where(LAN.id == lan_id))
        lan = found.scalar_one_or_none()
        if lan is None:
            raise ValueError(f"lan {lan_id!r} not found")
        owner_id = lan.topology_id
        if enforce_pending:
            await self._assert_pending(session, owner_id)
        if expected_version is not None:
            await self._check_and_bump_version(
                session, owner_id, expected_version
            )
        change = update(LAN).where(LAN.id == lan_id).values(**fields)
        await session.execute(change)
        await session.commit()
|
|
|
|
async def delete_lan(
    self,
    lan_id: str,
    *,
    expected_version: Optional[int] = None,
) -> None:
    """Remove a LAN, and every edge attached to it, from a pending topology.

    A missing LAN is a silent no-op. The delete is refused when any decky
    with an edge on this LAN has no edge to any other LAN, because it
    would be orphaned; the caller must delete or reassign it first.

    Raises:
        ValueError: a decky would be left without any LAN.
        TopologyNotEditable: the owning topology is not pending (raised by
            ``_assert_pending``).
    """
    from decnet.topology.status import TopologyNotEditable  # noqa: F401

    async with self._session() as session:
        found = await session.execute(select(LAN).where(LAN.id == lan_id))
        lan = found.scalar_one_or_none()
        if lan is None:
            return
        await self._assert_pending(session, lan.topology_id)

        # Collect the distinct deckies that have at least one edge here.
        attached = await session.execute(
            select(TopologyEdge).where(TopologyEdge.lan_id == lan_id)
        )
        resident_uuids = {edge.decky_uuid for edge in attached.scalars().all()}
        for resident in resident_uuids:
            # Does this decky have an edge to any *other* LAN?
            elsewhere = await session.execute(
                select(TopologyEdge).where(
                    TopologyEdge.decky_uuid == resident,
                    TopologyEdge.lan_id != lan_id,
                )
            )
            if elsewhere.scalars().first() is None:
                raise ValueError(
                    f"cannot delete LAN {lan.name!r}: decky "
                    f"{resident} has no other LAN (would be orphaned)"
                )

        if expected_version is not None:
            await self._check_and_bump_version(
                session, lan.topology_id, expected_version
            )
        # Edges go first, then the LAN row itself.
        bind = {"l": lan_id}
        await session.execute(
            text("DELETE FROM topology_edges WHERE lan_id = :l"), bind
        )
        await session.execute(text("DELETE FROM lans WHERE id = :l"), bind)
        await session.commit()
|
|
|
|
async def list_lans_for_topology(
    self, topology_id: str
) -> list[dict[str, Any]]:
    """Return the LANs of ``topology_id`` as JSON-friendly dicts, by name."""
    query = (
        select(LAN)
        .where(LAN.topology_id == topology_id)
        .order_by(asc(LAN.name))
    )
    async with self._session() as session:
        fetched = await session.execute(query)
        lans: list[dict[str, Any]] = []
        for record in fetched.scalars().all():
            lans.append(record.model_dump(mode="json"))
        return lans
|
|
|
|
async def add_topology_decky(
    self,
    data: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
) -> str:
    """Insert a ``TopologyDecky`` built from ``data`` and return its uuid.

    ``services`` and ``decky_config`` are JSON-encoded when they are not
    already strings. ``data`` must contain ``topology_id``; when
    ``expected_version`` is given the topology version is checked and
    bumped in the same transaction.
    """
    serialized = self._serialize_json_fields(data, ("services", "decky_config"))
    async with self._session() as session:
        owner_id = data["topology_id"]
        await self._check_and_bump_version(session, owner_id, expected_version)
        decky = TopologyDecky(**serialized)
        session.add(decky)
        await session.commit()
        await session.refresh(decky)
        return decky.uuid
|
|
|
|
async def update_topology_decky(
    self,
    decky_uuid: str,
    fields: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
    enforce_pending: bool = False,
) -> None:
    """Apply ``fields`` to the decky identified by ``decky_uuid``.

    An empty ``fields`` mapping returns immediately. ``services`` and
    ``decky_config`` are JSON-encoded when not already strings, and
    ``updated_at`` defaults to the current UTC time unless the caller
    supplied one. With ``enforce_pending`` the owning topology must be
    pending; with ``expected_version`` its version is checked and bumped.

    Raises:
        ValueError: no decky with ``decky_uuid`` exists.
    """
    if not fields:
        return
    changes = self._serialize_json_fields(fields, ("services", "decky_config"))
    changes.setdefault("updated_at", datetime.now(timezone.utc))
    async with self._session() as session:
        found = await session.execute(
            select(TopologyDecky).where(TopologyDecky.uuid == decky_uuid)
        )
        decky = found.scalar_one_or_none()
        if decky is None:
            raise ValueError(f"decky {decky_uuid!r} not found")
        owner_id = decky.topology_id
        if enforce_pending:
            await self._assert_pending(session, owner_id)
        if expected_version is not None:
            await self._check_and_bump_version(
                session, owner_id, expected_version
            )
        change = (
            update(TopologyDecky)
            .where(TopologyDecky.uuid == decky_uuid)
            .values(**changes)
        )
        await session.execute(change)
        await session.commit()
|
|
|
|
async def delete_topology_decky(
    self,
    decky_uuid: str,
    *,
    expected_version: Optional[int] = None,
) -> None:
    """Remove a decky and all of its edges from a pending topology.

    A missing decky is a silent no-op. The owning topology must be
    pending (``_assert_pending`` raises otherwise); with
    ``expected_version`` its version is checked and bumped.
    """
    async with self._session() as session:
        found = await session.execute(
            select(TopologyDecky).where(TopologyDecky.uuid == decky_uuid)
        )
        decky = found.scalar_one_or_none()
        if decky is None:
            return
        owner_id = decky.topology_id
        await self._assert_pending(session, owner_id)
        if expected_version is not None:
            await self._check_and_bump_version(
                session, owner_id, expected_version
            )
        # Edges first, then the decky row they reference.
        bind = {"u": decky_uuid}
        await session.execute(
            text("DELETE FROM topology_edges WHERE decky_uuid = :u"), bind
        )
        await session.execute(
            text("DELETE FROM topology_deckies WHERE uuid = :u"), bind
        )
        await session.commit()
|
|
|
|
async def list_topology_deckies(
    self, topology_id: str
) -> list[dict[str, Any]]:
    """Return the deckies of ``topology_id`` as JSON-friendly dicts, by name.

    ``services`` and ``decky_config`` are JSON-decoded when stored as
    strings.
    """
    query = (
        select(TopologyDecky)
        .where(TopologyDecky.topology_id == topology_id)
        .order_by(asc(TopologyDecky.name))
    )
    async with self._session() as session:
        fetched = await session.execute(query)
        deckies: list[dict[str, Any]] = []
        for record in fetched.scalars().all():
            dumped = record.model_dump(mode="json")
            deckies.append(
                self._deserialize_json_fields(
                    dumped, ("services", "decky_config")
                )
            )
        return deckies
|
|
|
|
async def add_topology_edge(
    self,
    data: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
) -> str:
    """Insert a ``TopologyEdge`` built from ``data`` and return its id.

    ``data`` must contain ``topology_id``. When ``expected_version`` is
    given, the topology version is checked and bumped in the same
    transaction before the edge is created.
    """
    async with self._session() as session:
        owner_id = data["topology_id"]
        await self._check_and_bump_version(session, owner_id, expected_version)
        edge = TopologyEdge(**data)
        session.add(edge)
        await session.commit()
        await session.refresh(edge)
        return edge.id
|
|
|
|
async def delete_topology_edge(
    self,
    edge_id: str,
    *,
    expected_version: Optional[int] = None,
) -> None:
    """Remove a single edge from a pending topology.

    A missing edge is a silent no-op. The owning topology must be pending
    (``_assert_pending`` raises otherwise); with ``expected_version`` its
    version is checked and bumped.
    """
    async with self._session() as session:
        found = await session.execute(
            select(TopologyEdge).where(TopologyEdge.id == edge_id)
        )
        edge = found.scalar_one_or_none()
        if edge is None:
            return
        owner_id = edge.topology_id
        await self._assert_pending(session, owner_id)
        if expected_version is not None:
            await self._check_and_bump_version(
                session, owner_id, expected_version
            )
        removal = text("DELETE FROM topology_edges WHERE id = :e")
        await session.execute(removal, {"e": edge_id})
        await session.commit()
|
|
|
|
async def list_topology_edges(
    self, topology_id: str
) -> list[dict[str, Any]]:
    """Return every edge of ``topology_id`` as a JSON-friendly dict.

    No ordering is applied to the query.
    """
    query = select(TopologyEdge).where(TopologyEdge.topology_id == topology_id)
    async with self._session() as session:
        fetched = await session.execute(query)
        edges: list[dict[str, Any]] = []
        for record in fetched.scalars().all():
            edges.append(record.model_dump(mode="json"))
        return edges
|
|
|
|
async def list_topology_status_events(
    self, topology_id: str, limit: int = 100
) -> list[dict[str, Any]]:
    """Return up to ``limit`` status events for a topology, newest first.

    Events are ordered by ``at`` descending and returned as JSON-friendly
    dicts.
    """
    query = (
        select(TopologyStatusEvent)
        .where(TopologyStatusEvent.topology_id == topology_id)
        .order_by(desc(TopologyStatusEvent.at))
        .limit(limit)
    )
    async with self._session() as session:
        fetched = await session.execute(query)
        events: list[dict[str, Any]] = []
        for record in fetched.scalars().all():
            events.append(record.model_dump(mode="json"))
        return events
|
|
|
|
# ---------------- topology_mutations (live reconciler queue) ----------------
|
|
|
|
async def enqueue_topology_mutation(
    self,
    topology_id: str,
    op: str,
    payload: dict[str, Any],
    *,
    expected_version: Optional[int] = None,
) -> str:
    """Queue a mutation for a topology and return the new row's id.

    When ``expected_version`` is given the topology version is checked and
    bumped in the same transaction. ``payload`` is stored JSON-encoded.

    Meant for topologies that are ``active|degraded``; the reconciler picks
    queued rows up on its next tick.
    """
    async with self._session() as session:
        await self._check_and_bump_version(
            session, topology_id, expected_version
        )
        encoded_payload = orjson.dumps(payload).decode()
        mutation = TopologyMutation(
            topology_id=topology_id,
            op=op,
            payload=encoded_payload,
        )
        session.add(mutation)
        await session.commit()
        await session.refresh(mutation)
        return mutation.id
|
|
|
|
async def claim_next_mutation(
    self, topology_id: str
) -> Optional[dict[str, Any]]:
    """Atomically claim the oldest pending mutation for ``topology_id``.

    Correctness-critical: the claim is ONE SQL statement. Splitting it
    into SELECT-then-UPDATE would let two racing watch-loops both see
    the same ``pending`` row and both transition it to ``applying`` —
    double-executing the op. With the single
    ``UPDATE ... WHERE id = (SELECT ... LIMIT 1) AND state='pending'``
    pattern the loser's UPDATE matches zero rows and this method returns
    ``None`` — the expected, non-error outcome under contention.

    Returns:
        The claimed row as a JSON-friendly dict, or ``None`` when nothing
        was claimed (no pending row, or the race was lost).
    """
    async with self._session() as session:
        # Single-statement atomic claim. The inner SELECT picks the
        # oldest pending row; the outer UPDATE re-checks state so a
        # second racer that also saw that id finds state='applying'
        # and matches zero rows.
        sql = text(
            """
            UPDATE topology_mutations
            SET state = 'applying'
            WHERE id = (
                SELECT id FROM topology_mutations
                WHERE topology_id = :t AND state = 'pending'
                ORDER BY requested_at ASC
                LIMIT 1
            )
            AND state = 'pending'
            """
        )
        result = await session.execute(sql, {"t": topology_id})
        if result.rowcount == 0:
            await session.commit()
            return None
        # Re-read the row we just claimed. No racer can transition an
        # ``applying`` row back to ``pending``.
        # NOTE(review): this picks the *oldest* ``applying`` row for the
        # topology, which is the row claimed above only if no other row
        # for this topology is already ``applying`` (e.g. one left behind
        # by an interrupted apply) — confirm that invariant holds.
        sel = await session.execute(
            select(TopologyMutation)
            .where(TopologyMutation.topology_id == topology_id)
            .where(TopologyMutation.state == "applying")
            .order_by(asc(TopologyMutation.requested_at))
            .limit(1)
        )
        row = sel.scalar_one_or_none()
        await session.commit()
        if row is None:
            return None
        return row.model_dump(mode="json")
|
|
|
|
async def mark_mutation_applied(self, mutation_id: str) -> None:
    """Set a mutation's state to ``applied`` and stamp ``applied_at``.

    ``applied_at`` is the current UTC time as an ISO-8601 string. No check
    is made that the row exists or what state it was in.
    """
    statement = text(
        "UPDATE topology_mutations "
        "SET state = 'applied', applied_at = :at "
        "WHERE id = :i"
    )
    async with self._session() as session:
        bind = {
            "at": datetime.now(timezone.utc).isoformat(),
            "i": mutation_id,
        }
        await session.execute(statement, bind)
        await session.commit()
|
|
|
|
async def mark_mutation_failed(
    self, mutation_id: str, reason: str
) -> None:
    """Set a mutation's state to ``failed`` and record why.

    ``applied_at`` is set to the current UTC time as an ISO-8601 string
    and ``reason`` is stored verbatim. No check is made that the row
    exists or what state it was in.
    """
    statement = text(
        "UPDATE topology_mutations "
        "SET state = 'failed', applied_at = :at, reason = :r "
        "WHERE id = :i"
    )
    async with self._session() as session:
        bind = {
            "at": datetime.now(timezone.utc).isoformat(),
            "r": reason,
            "i": mutation_id,
        }
        await session.execute(statement, bind)
        await session.commit()
|
|
|
|
async def list_topology_mutations(
    self,
    topology_id: str,
    state: Optional[str] = None,
) -> list[dict[str, Any]]:
    """Return a topology's mutations as JSON-friendly dicts, newest first.

    Rows are ordered by ``requested_at`` descending. ``state`` filters
    whenever it is not ``None``.
    """
    async with self._session() as session:
        query = (
            select(TopologyMutation)
            .where(TopologyMutation.topology_id == topology_id)
            .order_by(desc(TopologyMutation.requested_at))
        )
        if state is not None:
            query = query.where(TopologyMutation.state == state)
        fetched = await session.execute(query)
        mutations: list[dict[str, Any]] = []
        for record in fetched.scalars().all():
            mutations.append(record.model_dump(mode="json"))
        return mutations
|
|
|
|
async def has_pending_topology_mutation(self) -> bool:
    """Report whether any live topology has a pending mutation.

    "Live" means the topology's status is ``active`` or ``degraded``. The
    query stops at the first match (``LIMIT 1``), which makes this a cheap
    guard for the watch loop: ``False`` means the reconciler path can be
    skipped.

    The original note says the ``ix_topology_mutations_state_topology``
    composite index keeps this lookup cheap at scale.
    """
    probe = text(
        "SELECT 1 FROM topology_mutations "
        "WHERE state = 'pending' "
        "AND topology_id IN ("
        " SELECT id FROM topologies "
        " WHERE status IN ('active', 'degraded')"
        ") LIMIT 1"
    )
    async with self._session() as session:
        outcome = await session.execute(probe)
        first_hit = outcome.first()
        return first_hit is not None
|
|
|
|
async def list_live_topology_ids(self) -> list[str]:
    """Return the ids of all topologies whose status is ``active`` or ``degraded``."""
    live_statuses = ["active", "degraded"]
    query = select(Topology.id).where(Topology.status.in_(live_statuses))
    async with self._session() as session:
        fetched = await session.execute(query)
        return list(fetched.scalars().all())
|