""" Shared SQLModel-based repository implementation. Contains all dialect-portable query code used by the SQLite and MySQL backends. Dialect-specific behavior lives in subclasses: * engine/session construction (``__init__``) * ``_migrate_attackers_table`` (legacy schema check; DDL introspection is not portable) * ``get_log_histogram`` (date-bucket expression differs per dialect) """ from __future__ import annotations import asyncio import json import uuid from datetime import datetime, timezone from typing import Any, Optional, List from sqlalchemy import func, select, desc, asc, text, or_, update from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker from sqlmodel.sql.expression import SelectOfScalar from decnet.config import load_state from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD from decnet.web.auth import get_password_hash from decnet.web.db.repository import BaseRepository from decnet.web.db.models import User, Log, Bounty, State, Attacker, AttackerBehavior from contextlib import asynccontextmanager from decnet.logging import get_logger _log = get_logger("db.pool") async def _force_close(session: AsyncSession) -> None: """Close a session, forcing connection invalidation if clean close fails. Shielded from cancellation and catches every exception class including CancelledError. If session.close() fails (corrupted connection), we invalidate the underlying connection so the pool discards it entirely rather than leaving it checked-out forever. """ try: await asyncio.shield(session.close()) except BaseException: # close() failed — connection is likely corrupted. # Try to invalidate the raw connection so the pool drops it. try: bind = session.get_bind() if hasattr(bind, "dispose"): pass # don't dispose the whole engine # The sync_session holds the connection record; invalidating # it tells the pool to discard rather than reuse. sync = session.sync_session if sync.is_active: sync.rollback() sync.close() except BaseException: _log.debug("force-close: fallback cleanup failed", exc_info=True) @asynccontextmanager async def _safe_session(factory: async_sessionmaker[AsyncSession]): """Session context manager that shields cleanup from cancellation. Under high concurrency, uvicorn cancels request tasks when clients disconnect. If a CancelledError hits during session.__aexit__, the underlying DB connection is orphaned — never returned to the pool. This wrapper ensures close() always completes, preventing the pool-drain death spiral. """ session = factory() try: yield session except BaseException: await _force_close(session) raise else: await _force_close(session) class SQLModelRepository(BaseRepository): """Concrete SQLModel/SQLAlchemy-async repository. Subclasses provide ``self.engine`` (AsyncEngine) and ``self.session_factory`` in ``__init__``, and override the few dialect-specific helpers. """ engine: AsyncEngine session_factory: async_sessionmaker[AsyncSession] def _session(self): """Return a cancellation-safe session context manager.""" return _safe_session(self.session_factory) # ------------------------------------------------------------ lifecycle async def initialize(self) -> None: """Create tables if absent and seed the admin user.""" from sqlmodel import SQLModel await self._migrate_attackers_table() async with self.engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) await self._ensure_admin_user() async def reinitialize(self) -> None: """Re-create schema (for tests / reset flows). Does NOT drop existing tables.""" from sqlmodel import SQLModel async with self.engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) await self._ensure_admin_user() async def _ensure_admin_user(self) -> None: async with self._session() as session: result = await session.execute( select(User).where(User.username == DECNET_ADMIN_USER) ) existing = result.scalar_one_or_none() if existing is None: session.add(User( uuid=str(uuid.uuid4()), username=DECNET_ADMIN_USER, password_hash=get_password_hash(DECNET_ADMIN_PASSWORD), role="admin", must_change_password=True, )) await session.commit() return # Self-heal env drift: if admin never finalized their password, # re-sync the hash from DECNET_ADMIN_PASSWORD. Otherwise leave # the user's chosen password alone. if existing.must_change_password: existing.password_hash = get_password_hash(DECNET_ADMIN_PASSWORD) session.add(existing) await session.commit() async def _migrate_attackers_table(self) -> None: """Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable).""" return None # ---------------------------------------------------------------- logs async def add_log(self, log_data: dict[str, Any]) -> None: data = log_data.copy() if "fields" in data and isinstance(data["fields"], dict): data["fields"] = json.dumps(data["fields"]) if "timestamp" in data and isinstance(data["timestamp"], str): try: data["timestamp"] = datetime.fromisoformat( data["timestamp"].replace("Z", "+00:00") ) except ValueError: pass async with self._session() as session: session.add(Log(**data)) await session.commit() def _apply_filters( self, statement: SelectOfScalar, search: Optional[str], start_time: Optional[str], end_time: Optional[str], ) -> SelectOfScalar: import re import shlex if start_time: statement = statement.where(Log.timestamp >= start_time) if end_time: statement = statement.where(Log.timestamp <= end_time) if search: try: tokens = shlex.split(search) except ValueError: tokens = search.split() core_fields = { "decky": Log.decky, "service": Log.service, "event": Log.event_type, "attacker": Log.attacker_ip, "attacker-ip": Log.attacker_ip, "attacker_ip": Log.attacker_ip, } for token in tokens: if ":" in token: key, val = token.split(":", 1) if key in core_fields: statement = statement.where(core_fields[key] == val) else: key_safe = re.sub(r"[^a-zA-Z0-9_]", "", key) if key_safe: statement = statement.where( self._json_field_equals(key_safe) ).params(val=val) else: lk = f"%{token}%" statement = statement.where( or_( Log.raw_line.like(lk), Log.decky.like(lk), Log.service.like(lk), Log.attacker_ip.like(lk), ) ) return statement def _json_field_equals(self, key: str): """Return a text() predicate that matches rows where fields->key == :val. Both SQLite and MySQL expose a ``JSON_EXTRACT`` function; MySQL also exposes the same function under ``json_extract`` (case-insensitive). The ``:val`` parameter is bound separately and must be supplied with ``.params(val=...)`` by the caller, which keeps us safe from injection. """ return text(f"JSON_EXTRACT(fields, '$.{key}') = :val") async def get_logs( self, limit: int = 50, offset: int = 0, search: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, ) -> List[dict]: statement = ( select(Log) .order_by(desc(Log.timestamp)) .offset(offset) .limit(limit) ) statement = self._apply_filters(statement, search, start_time, end_time) async with self._session() as session: results = await session.execute(statement) return [log.model_dump(mode="json") for log in results.scalars().all()] async def get_max_log_id(self) -> int: async with self._session() as session: result = await session.execute(select(func.max(Log.id))) val = result.scalar() return val if val is not None else 0 async def get_logs_after_id( self, last_id: int, limit: int = 50, search: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, ) -> List[dict]: statement = ( select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit) ) statement = self._apply_filters(statement, search, start_time, end_time) async with self._session() as session: results = await session.execute(statement) return [log.model_dump(mode="json") for log in results.scalars().all()] async def get_total_logs( self, search: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, ) -> int: statement = select(func.count()).select_from(Log) statement = self._apply_filters(statement, search, start_time, end_time) async with self._session() as session: result = await session.execute(statement) return result.scalar() or 0 async def get_log_histogram( self, search: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, interval_minutes: int = 15, ) -> List[dict]: """Dialect-specific — override per backend.""" raise NotImplementedError async def get_stats_summary(self) -> dict[str, Any]: async with self._session() as session: total_logs = ( await session.execute(select(func.count()).select_from(Log)) ).scalar() or 0 unique_attackers = ( await session.execute( select(func.count(func.distinct(Log.attacker_ip))) ) ).scalar() or 0 _state = await asyncio.to_thread(load_state) deployed_deckies = len(_state[0].deckies) if _state else 0 return { "total_logs": total_logs, "unique_attackers": unique_attackers, "active_deckies": deployed_deckies, "deployed_deckies": deployed_deckies, } async def get_deckies(self) -> List[dict]: _state = await asyncio.to_thread(load_state) return [_d.model_dump() for _d in _state[0].deckies] if _state else [] # --------------------------------------------------------------- users async def get_user_by_username(self, username: str) -> Optional[dict]: async with self._session() as session: result = await session.execute( select(User).where(User.username == username) ) user = result.scalar_one_or_none() return user.model_dump() if user else None async def get_user_by_uuid(self, uuid: str) -> Optional[dict]: async with self._session() as session: result = await session.execute( select(User).where(User.uuid == uuid) ) user = result.scalar_one_or_none() return user.model_dump() if user else None async def create_user(self, user_data: dict[str, Any]) -> None: async with self._session() as session: session.add(User(**user_data)) await session.commit() async def update_user_password( self, uuid: str, password_hash: str, must_change_password: bool = False ) -> None: async with self._session() as session: await session.execute( update(User) .where(User.uuid == uuid) .values( password_hash=password_hash, must_change_password=must_change_password, ) ) await session.commit() async def list_users(self) -> list[dict]: async with self._session() as session: result = await session.execute(select(User)) return [u.model_dump() for u in result.scalars().all()] async def delete_user(self, uuid: str) -> bool: async with self._session() as session: result = await session.execute(select(User).where(User.uuid == uuid)) user = result.scalar_one_or_none() if not user: return False await session.delete(user) await session.commit() return True async def update_user_role(self, uuid: str, role: str) -> None: async with self._session() as session: await session.execute( update(User).where(User.uuid == uuid).values(role=role) ) await session.commit() async def purge_logs_and_bounties(self) -> dict[str, int]: async with self._session() as session: logs_deleted = (await session.execute(text("DELETE FROM logs"))).rowcount bounties_deleted = (await session.execute(text("DELETE FROM bounty"))).rowcount # attacker_behavior has FK → attackers.uuid; delete children first. await session.execute(text("DELETE FROM attacker_behavior")) attackers_deleted = (await session.execute(text("DELETE FROM attackers"))).rowcount await session.commit() return { "logs": logs_deleted, "bounties": bounties_deleted, "attackers": attackers_deleted, } # ------------------------------------------------------------ bounties async def add_bounty(self, bounty_data: dict[str, Any]) -> None: data = bounty_data.copy() if "payload" in data and isinstance(data["payload"], dict): data["payload"] = json.dumps(data["payload"]) async with self._session() as session: dup = await session.execute( select(Bounty.id).where( Bounty.bounty_type == data.get("bounty_type"), Bounty.attacker_ip == data.get("attacker_ip"), Bounty.payload == data.get("payload"), ).limit(1) ) if dup.first() is not None: return session.add(Bounty(**data)) await session.commit() def _apply_bounty_filters( self, statement: SelectOfScalar, bounty_type: Optional[str], search: Optional[str], ) -> SelectOfScalar: if bounty_type: statement = statement.where(Bounty.bounty_type == bounty_type) if search: lk = f"%{search}%" statement = statement.where( or_( Bounty.decky.like(lk), Bounty.service.like(lk), Bounty.attacker_ip.like(lk), Bounty.payload.like(lk), ) ) return statement async def get_bounties( self, limit: int = 50, offset: int = 0, bounty_type: Optional[str] = None, search: Optional[str] = None, ) -> List[dict]: statement = ( select(Bounty) .order_by(desc(Bounty.timestamp)) .offset(offset) .limit(limit) ) statement = self._apply_bounty_filters(statement, bounty_type, search) async with self._session() as session: results = await session.execute(statement) final = [] for item in results.scalars().all(): d = item.model_dump(mode="json") try: d["payload"] = json.loads(d["payload"]) except (json.JSONDecodeError, TypeError): pass final.append(d) return final async def get_total_bounties( self, bounty_type: Optional[str] = None, search: Optional[str] = None ) -> int: statement = select(func.count()).select_from(Bounty) statement = self._apply_bounty_filters(statement, bounty_type, search) async with self._session() as session: result = await session.execute(statement) return result.scalar() or 0 async def get_state(self, key: str) -> Optional[dict[str, Any]]: async with self._session() as session: statement = select(State).where(State.key == key) result = await session.execute(statement) state = result.scalar_one_or_none() if state: return json.loads(state.value) return None async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401 async with self._session() as session: statement = select(State).where(State.key == key) result = await session.execute(statement) state = result.scalar_one_or_none() value_json = json.dumps(value) if state: state.value = value_json session.add(state) else: session.add(State(key=key, value=value_json)) await session.commit() # ----------------------------------------------------------- attackers async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]: from collections import defaultdict async with self._session() as session: result = await session.execute( select(Bounty).order_by(asc(Bounty.timestamp)) ) grouped: dict[str, List[dict[str, Any]]] = defaultdict(list) for item in result.scalars().all(): d = item.model_dump(mode="json") try: d["payload"] = json.loads(d["payload"]) except (json.JSONDecodeError, TypeError): pass grouped[item.attacker_ip].append(d) return dict(grouped) async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]: from collections import defaultdict async with self._session() as session: result = await session.execute( select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp)) ) grouped: dict[str, List[dict[str, Any]]] = defaultdict(list) for item in result.scalars().all(): d = item.model_dump(mode="json") try: d["payload"] = json.loads(d["payload"]) except (json.JSONDecodeError, TypeError): pass grouped[item.attacker_ip].append(d) return dict(grouped) async def upsert_attacker(self, data: dict[str, Any]) -> str: async with self._session() as session: result = await session.execute( select(Attacker).where(Attacker.ip == data["ip"]) ) existing = result.scalar_one_or_none() if existing: for k, v in data.items(): setattr(existing, k, v) session.add(existing) row_uuid = existing.uuid else: row_uuid = str(uuid.uuid4()) data = {**data, "uuid": row_uuid} session.add(Attacker(**data)) await session.commit() return row_uuid async def upsert_attacker_behavior( self, attacker_uuid: str, data: dict[str, Any], ) -> None: async with self._session() as session: result = await session.execute( select(AttackerBehavior).where( AttackerBehavior.attacker_uuid == attacker_uuid ) ) existing = result.scalar_one_or_none() payload = {**data, "updated_at": datetime.now(timezone.utc)} if existing: for k, v in payload.items(): setattr(existing, k, v) session.add(existing) else: session.add(AttackerBehavior(attacker_uuid=attacker_uuid, **payload)) await session.commit() async def get_attacker_behavior( self, attacker_uuid: str, ) -> Optional[dict[str, Any]]: async with self._session() as session: result = await session.execute( select(AttackerBehavior).where( AttackerBehavior.attacker_uuid == attacker_uuid ) ) row = result.scalar_one_or_none() if not row: return None return self._deserialize_behavior(row.model_dump(mode="json")) async def get_behaviors_for_ips( self, ips: set[str], ) -> dict[str, dict[str, Any]]: if not ips: return {} async with self._session() as session: result = await session.execute( select(Attacker.ip, AttackerBehavior) .join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid) .where(Attacker.ip.in_(ips)) ) out: dict[str, dict[str, Any]] = {} for ip, row in result.all(): out[ip] = self._deserialize_behavior(row.model_dump(mode="json")) return out @staticmethod def _deserialize_behavior(d: dict[str, Any]) -> dict[str, Any]: for key in ("tcp_fingerprint", "timing_stats", "phase_sequence"): if isinstance(d.get(key), str): try: d[key] = json.loads(d[key]) except (json.JSONDecodeError, TypeError): pass # Deserialize tool_guesses JSON array; normalise None → []. raw = d.get("tool_guesses") if isinstance(raw, str): try: parsed = json.loads(raw) d["tool_guesses"] = parsed if isinstance(parsed, list) else [parsed] except (json.JSONDecodeError, TypeError): d["tool_guesses"] = [] elif raw is None: d["tool_guesses"] = [] return d @staticmethod def _deserialize_attacker(d: dict[str, Any]) -> dict[str, Any]: for key in ("services", "deckies", "fingerprints", "commands"): if isinstance(d.get(key), str): try: d[key] = json.loads(d[key]) except (json.JSONDecodeError, TypeError): pass return d async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]: async with self._session() as session: result = await session.execute( select(Attacker).where(Attacker.uuid == uuid) ) attacker = result.scalar_one_or_none() if not attacker: return None return self._deserialize_attacker(attacker.model_dump(mode="json")) async def get_attackers( self, limit: int = 50, offset: int = 0, search: Optional[str] = None, sort_by: str = "recent", service: Optional[str] = None, ) -> List[dict[str, Any]]: order = { "active": desc(Attacker.event_count), "traversals": desc(Attacker.is_traversal), }.get(sort_by, desc(Attacker.last_seen)) statement = select(Attacker).order_by(order).offset(offset).limit(limit) if search: statement = statement.where(Attacker.ip.like(f"%{search}%")) if service: statement = statement.where(Attacker.services.like(f'%"{service}"%')) async with self._session() as session: result = await session.execute(statement) return [ self._deserialize_attacker(a.model_dump(mode="json")) for a in result.scalars().all() ] async def get_total_attackers( self, search: Optional[str] = None, service: Optional[str] = None ) -> int: statement = select(func.count()).select_from(Attacker) if search: statement = statement.where(Attacker.ip.like(f"%{search}%")) if service: statement = statement.where(Attacker.services.like(f'%"{service}"%')) async with self._session() as session: result = await session.execute(statement) return result.scalar() or 0 async def get_attacker_commands( self, uuid: str, limit: int = 50, offset: int = 0, service: Optional[str] = None, ) -> dict[str, Any]: async with self._session() as session: result = await session.execute( select(Attacker.commands).where(Attacker.uuid == uuid) ) raw = result.scalar_one_or_none() if raw is None: return {"total": 0, "data": []} commands: list = json.loads(raw) if isinstance(raw, str) else raw if service: commands = [c for c in commands if c.get("service") == service] total = len(commands) page = commands[offset: offset + limit] return {"total": total, "data": page}