""" Shared SQLModel-based repository implementation. Contains all dialect-portable query code used by the SQLite and MySQL backends. Dialect-specific behavior lives in subclasses: * engine/session construction (``__init__``) * ``_migrate_attackers_table`` (legacy schema check; DDL introspection is not portable) * ``get_log_histogram`` (date-bucket expression differs per dialect) """ from __future__ import annotations import asyncio import json import uuid from datetime import datetime, timezone from typing import Any, Optional, List from sqlalchemy import func, select, desc, asc, text, or_, update from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker from sqlmodel.sql.expression import SelectOfScalar from decnet.config import load_state from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD from decnet.web.auth import get_password_hash from decnet.web.db.repository import BaseRepository from decnet.web.db.models import User, Log, Bounty, State, Attacker, AttackerBehavior class SQLModelRepository(BaseRepository): """Concrete SQLModel/SQLAlchemy-async repository. Subclasses provide ``self.engine`` (AsyncEngine) and ``self.session_factory`` in ``__init__``, and override the few dialect-specific helpers. """ engine: AsyncEngine session_factory: async_sessionmaker[AsyncSession] # ------------------------------------------------------------ lifecycle async def initialize(self) -> None: """Create tables if absent and seed the admin user.""" from sqlmodel import SQLModel await self._migrate_attackers_table() async with self.engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) await self._ensure_admin_user() async def reinitialize(self) -> None: """Re-create schema (for tests / reset flows). Does NOT drop existing tables.""" from sqlmodel import SQLModel async with self.engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) await self._ensure_admin_user() async def _ensure_admin_user(self) -> None: async with self.session_factory() as session: result = await session.execute( select(User).where(User.username == DECNET_ADMIN_USER) ) if not result.scalar_one_or_none(): session.add(User( uuid=str(uuid.uuid4()), username=DECNET_ADMIN_USER, password_hash=get_password_hash(DECNET_ADMIN_PASSWORD), role="admin", must_change_password=True, )) await session.commit() async def _migrate_attackers_table(self) -> None: """Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable).""" return None # ---------------------------------------------------------------- logs async def add_log(self, log_data: dict[str, Any]) -> None: data = log_data.copy() if "fields" in data and isinstance(data["fields"], dict): data["fields"] = json.dumps(data["fields"]) if "timestamp" in data and isinstance(data["timestamp"], str): try: data["timestamp"] = datetime.fromisoformat( data["timestamp"].replace("Z", "+00:00") ) except ValueError: pass async with self.session_factory() as session: session.add(Log(**data)) await session.commit() def _apply_filters( self, statement: SelectOfScalar, search: Optional[str], start_time: Optional[str], end_time: Optional[str], ) -> SelectOfScalar: import re import shlex if start_time: statement = statement.where(Log.timestamp >= start_time) if end_time: statement = statement.where(Log.timestamp <= end_time) if search: try: tokens = shlex.split(search) except ValueError: tokens = search.split() core_fields = { "decky": Log.decky, "service": Log.service, "event": Log.event_type, "attacker": Log.attacker_ip, "attacker-ip": Log.attacker_ip, "attacker_ip": Log.attacker_ip, } for token in tokens: if ":" in token: key, val = token.split(":", 1) if key in core_fields: statement = statement.where(core_fields[key] == val) else: key_safe = re.sub(r"[^a-zA-Z0-9_]", "", key) if key_safe: statement = statement.where( self._json_field_equals(key_safe) ).params(val=val) else: lk = f"%{token}%" statement = statement.where( or_( Log.raw_line.like(lk), Log.decky.like(lk), Log.service.like(lk), Log.attacker_ip.like(lk), ) ) return statement def _json_field_equals(self, key: str): """Return a text() predicate that matches rows where fields->key == :val. Both SQLite and MySQL expose a ``JSON_EXTRACT`` function; MySQL also exposes the same function under ``json_extract`` (case-insensitive). The ``:val`` parameter is bound separately and must be supplied with ``.params(val=...)`` by the caller, which keeps us safe from injection. """ return text(f"JSON_EXTRACT(fields, '$.{key}') = :val") async def get_logs( self, limit: int = 50, offset: int = 0, search: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, ) -> List[dict]: statement = ( select(Log) .order_by(desc(Log.timestamp)) .offset(offset) .limit(limit) ) statement = self._apply_filters(statement, search, start_time, end_time) async with self.session_factory() as session: results = await session.execute(statement) return [log.model_dump(mode="json") for log in results.scalars().all()] async def get_max_log_id(self) -> int: async with self.session_factory() as session: result = await session.execute(select(func.max(Log.id))) val = result.scalar() return val if val is not None else 0 async def get_logs_after_id( self, last_id: int, limit: int = 50, search: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, ) -> List[dict]: statement = ( select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit) ) statement = self._apply_filters(statement, search, start_time, end_time) async with self.session_factory() as session: results = await session.execute(statement) return [log.model_dump(mode="json") for log in results.scalars().all()] async def get_total_logs( self, search: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, ) -> int: statement = select(func.count()).select_from(Log) statement = self._apply_filters(statement, search, start_time, end_time) async with self.session_factory() as session: result = await session.execute(statement) return result.scalar() or 0 async def get_log_histogram( self, search: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, interval_minutes: int = 15, ) -> List[dict]: """Dialect-specific — override per backend.""" raise NotImplementedError async def get_stats_summary(self) -> dict[str, Any]: async with self.session_factory() as session: total_logs = ( await session.execute(select(func.count()).select_from(Log)) ).scalar() or 0 unique_attackers = ( await session.execute( select(func.count(func.distinct(Log.attacker_ip))) ) ).scalar() or 0 _state = await asyncio.to_thread(load_state) deployed_deckies = len(_state[0].deckies) if _state else 0 return { "total_logs": total_logs, "unique_attackers": unique_attackers, "active_deckies": deployed_deckies, "deployed_deckies": deployed_deckies, } async def get_deckies(self) -> List[dict]: _state = await asyncio.to_thread(load_state) return [_d.model_dump() for _d in _state[0].deckies] if _state else [] # --------------------------------------------------------------- users async def get_user_by_username(self, username: str) -> Optional[dict]: async with self.session_factory() as session: result = await session.execute( select(User).where(User.username == username) ) user = result.scalar_one_or_none() return user.model_dump() if user else None async def get_user_by_uuid(self, uuid: str) -> Optional[dict]: async with self.session_factory() as session: result = await session.execute( select(User).where(User.uuid == uuid) ) user = result.scalar_one_or_none() return user.model_dump() if user else None async def create_user(self, user_data: dict[str, Any]) -> None: async with self.session_factory() as session: session.add(User(**user_data)) await session.commit() async def update_user_password( self, uuid: str, password_hash: str, must_change_password: bool = False ) -> None: async with self.session_factory() as session: await session.execute( update(User) .where(User.uuid == uuid) .values( password_hash=password_hash, must_change_password=must_change_password, ) ) await session.commit() async def list_users(self) -> list[dict]: async with self.session_factory() as session: result = await session.execute(select(User)) return [u.model_dump() for u in result.scalars().all()] async def delete_user(self, uuid: str) -> bool: async with self.session_factory() as session: result = await session.execute(select(User).where(User.uuid == uuid)) user = result.scalar_one_or_none() if not user: return False await session.delete(user) await session.commit() return True async def update_user_role(self, uuid: str, role: str) -> None: async with self.session_factory() as session: await session.execute( update(User).where(User.uuid == uuid).values(role=role) ) await session.commit() async def purge_logs_and_bounties(self) -> dict[str, int]: async with self.session_factory() as session: logs_deleted = (await session.execute(text("DELETE FROM logs"))).rowcount bounties_deleted = (await session.execute(text("DELETE FROM bounty"))).rowcount # attacker_behavior has FK → attackers.uuid; delete children first. await session.execute(text("DELETE FROM attacker_behavior")) attackers_deleted = (await session.execute(text("DELETE FROM attackers"))).rowcount await session.commit() return { "logs": logs_deleted, "bounties": bounties_deleted, "attackers": attackers_deleted, } # ------------------------------------------------------------ bounties async def add_bounty(self, bounty_data: dict[str, Any]) -> None: data = bounty_data.copy() if "payload" in data and isinstance(data["payload"], dict): data["payload"] = json.dumps(data["payload"]) async with self.session_factory() as session: session.add(Bounty(**data)) await session.commit() def _apply_bounty_filters( self, statement: SelectOfScalar, bounty_type: Optional[str], search: Optional[str], ) -> SelectOfScalar: if bounty_type: statement = statement.where(Bounty.bounty_type == bounty_type) if search: lk = f"%{search}%" statement = statement.where( or_( Bounty.decky.like(lk), Bounty.service.like(lk), Bounty.attacker_ip.like(lk), Bounty.payload.like(lk), ) ) return statement async def get_bounties( self, limit: int = 50, offset: int = 0, bounty_type: Optional[str] = None, search: Optional[str] = None, ) -> List[dict]: statement = ( select(Bounty) .order_by(desc(Bounty.timestamp)) .offset(offset) .limit(limit) ) statement = self._apply_bounty_filters(statement, bounty_type, search) async with self.session_factory() as session: results = await session.execute(statement) final = [] for item in results.scalars().all(): d = item.model_dump(mode="json") try: d["payload"] = json.loads(d["payload"]) except (json.JSONDecodeError, TypeError): pass final.append(d) return final async def get_total_bounties( self, bounty_type: Optional[str] = None, search: Optional[str] = None ) -> int: statement = select(func.count()).select_from(Bounty) statement = self._apply_bounty_filters(statement, bounty_type, search) async with self.session_factory() as session: result = await session.execute(statement) return result.scalar() or 0 async def get_state(self, key: str) -> Optional[dict[str, Any]]: async with self.session_factory() as session: statement = select(State).where(State.key == key) result = await session.execute(statement) state = result.scalar_one_or_none() if state: return json.loads(state.value) return None async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401 async with self.session_factory() as session: statement = select(State).where(State.key == key) result = await session.execute(statement) state = result.scalar_one_or_none() value_json = json.dumps(value) if state: state.value = value_json session.add(state) else: session.add(State(key=key, value=value_json)) await session.commit() # ----------------------------------------------------------- attackers async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]: from collections import defaultdict async with self.session_factory() as session: result = await session.execute( select(Bounty).order_by(asc(Bounty.timestamp)) ) grouped: dict[str, List[dict[str, Any]]] = defaultdict(list) for item in result.scalars().all(): d = item.model_dump(mode="json") try: d["payload"] = json.loads(d["payload"]) except (json.JSONDecodeError, TypeError): pass grouped[item.attacker_ip].append(d) return dict(grouped) async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]: from collections import defaultdict async with self.session_factory() as session: result = await session.execute( select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp)) ) grouped: dict[str, List[dict[str, Any]]] = defaultdict(list) for item in result.scalars().all(): d = item.model_dump(mode="json") try: d["payload"] = json.loads(d["payload"]) except (json.JSONDecodeError, TypeError): pass grouped[item.attacker_ip].append(d) return dict(grouped) async def upsert_attacker(self, data: dict[str, Any]) -> str: async with self.session_factory() as session: result = await session.execute( select(Attacker).where(Attacker.ip == data["ip"]) ) existing = result.scalar_one_or_none() if existing: for k, v in data.items(): setattr(existing, k, v) session.add(existing) row_uuid = existing.uuid else: row_uuid = str(uuid.uuid4()) data = {**data, "uuid": row_uuid} session.add(Attacker(**data)) await session.commit() return row_uuid async def upsert_attacker_behavior( self, attacker_uuid: str, data: dict[str, Any], ) -> None: async with self.session_factory() as session: result = await session.execute( select(AttackerBehavior).where( AttackerBehavior.attacker_uuid == attacker_uuid ) ) existing = result.scalar_one_or_none() payload = {**data, "updated_at": datetime.now(timezone.utc)} if existing: for k, v in payload.items(): setattr(existing, k, v) session.add(existing) else: session.add(AttackerBehavior(attacker_uuid=attacker_uuid, **payload)) await session.commit() async def get_attacker_behavior( self, attacker_uuid: str, ) -> Optional[dict[str, Any]]: async with self.session_factory() as session: result = await session.execute( select(AttackerBehavior).where( AttackerBehavior.attacker_uuid == attacker_uuid ) ) row = result.scalar_one_or_none() if not row: return None return self._deserialize_behavior(row.model_dump(mode="json")) async def get_behaviors_for_ips( self, ips: set[str], ) -> dict[str, dict[str, Any]]: if not ips: return {} async with self.session_factory() as session: result = await session.execute( select(Attacker.ip, AttackerBehavior) .join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid) .where(Attacker.ip.in_(ips)) ) out: dict[str, dict[str, Any]] = {} for ip, row in result.all(): out[ip] = self._deserialize_behavior(row.model_dump(mode="json")) return out @staticmethod def _deserialize_behavior(d: dict[str, Any]) -> dict[str, Any]: for key in ("tcp_fingerprint", "timing_stats", "phase_sequence"): if isinstance(d.get(key), str): try: d[key] = json.loads(d[key]) except (json.JSONDecodeError, TypeError): pass return d @staticmethod def _deserialize_attacker(d: dict[str, Any]) -> dict[str, Any]: for key in ("services", "deckies", "fingerprints", "commands"): if isinstance(d.get(key), str): try: d[key] = json.loads(d[key]) except (json.JSONDecodeError, TypeError): pass return d async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]: async with self.session_factory() as session: result = await session.execute( select(Attacker).where(Attacker.uuid == uuid) ) attacker = result.scalar_one_or_none() if not attacker: return None return self._deserialize_attacker(attacker.model_dump(mode="json")) async def get_attackers( self, limit: int = 50, offset: int = 0, search: Optional[str] = None, sort_by: str = "recent", service: Optional[str] = None, ) -> List[dict[str, Any]]: order = { "active": desc(Attacker.event_count), "traversals": desc(Attacker.is_traversal), }.get(sort_by, desc(Attacker.last_seen)) statement = select(Attacker).order_by(order).offset(offset).limit(limit) if search: statement = statement.where(Attacker.ip.like(f"%{search}%")) if service: statement = statement.where(Attacker.services.like(f'%"{service}"%')) async with self.session_factory() as session: result = await session.execute(statement) return [ self._deserialize_attacker(a.model_dump(mode="json")) for a in result.scalars().all() ] async def get_total_attackers( self, search: Optional[str] = None, service: Optional[str] = None ) -> int: statement = select(func.count()).select_from(Attacker) if search: statement = statement.where(Attacker.ip.like(f"%{search}%")) if service: statement = statement.where(Attacker.services.like(f'%"{service}"%')) async with self.session_factory() as session: result = await session.execute(statement) return result.scalar() or 0 async def get_attacker_commands( self, uuid: str, limit: int = 50, offset: int = 0, service: Optional[str] = None, ) -> dict[str, Any]: async with self.session_factory() as session: result = await session.execute( select(Attacker.commands).where(Attacker.uuid == uuid) ) raw = result.scalar_one_or_none() if raw is None: return {"total": 0, "data": []} commands: list = json.loads(raw) if isinstance(raw, str) else raw if service: commands = [c for c in commands if c.get("service") == service] total = len(commands) page = commands[offset: offset + limit] return {"total": total, "data": page}