diff --git a/decnet/web/db/mysql/__init__.py b/decnet/web/db/mysql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/decnet/web/db/mysql/__pycache__/__init__.cpython-314.pyc b/decnet/web/db/mysql/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000..9e6b21b Binary files /dev/null and b/decnet/web/db/mysql/__pycache__/__init__.cpython-314.pyc differ diff --git a/decnet/web/db/mysql/__pycache__/database.cpython-314.pyc b/decnet/web/db/mysql/__pycache__/database.cpython-314.pyc new file mode 100644 index 0000000..e7ecf39 Binary files /dev/null and b/decnet/web/db/mysql/__pycache__/database.cpython-314.pyc differ diff --git a/decnet/web/db/mysql/__pycache__/repository.cpython-314.pyc b/decnet/web/db/mysql/__pycache__/repository.cpython-314.pyc new file mode 100644 index 0000000..704b501 Binary files /dev/null and b/decnet/web/db/mysql/__pycache__/repository.cpython-314.pyc differ diff --git a/decnet/web/db/mysql/database.py b/decnet/web/db/mysql/database.py new file mode 100644 index 0000000..73a4185 --- /dev/null +++ b/decnet/web/db/mysql/database.py @@ -0,0 +1,98 @@ +""" +MySQL async engine factory. + +Builds a SQLAlchemy AsyncEngine against MySQL using the ``aiomysql`` driver. + +Connection info is resolved (in order of precedence): + +1. An explicit ``url`` argument passed to :func:`get_async_engine` +2. ``DECNET_DB_URL`` — full SQLAlchemy URL +3. Component env vars: + ``DECNET_DB_HOST`` (default ``localhost``) + ``DECNET_DB_PORT`` (default ``3306``) + ``DECNET_DB_NAME`` (default ``decnet``) + ``DECNET_DB_USER`` (default ``decnet``) + ``DECNET_DB_PASSWORD`` (default empty — raises unless pytest is running) +""" +from __future__ import annotations + +import os +from typing import Optional +from urllib.parse import quote_plus + +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + + +DEFAULT_POOL_SIZE = 10 +DEFAULT_MAX_OVERFLOW = 20 +DEFAULT_POOL_RECYCLE = 3600 # seconds — avoid MySQL ``wait_timeout`` disconnects +DEFAULT_POOL_PRE_PING = True + + +def build_mysql_url( + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + user: Optional[str] = None, + password: Optional[str] = None, +) -> str: + """Compose an async SQLAlchemy URL for MySQL using the aiomysql driver. + + Component args override env vars. Password is percent-encoded so special + characters (``@``, ``:``, ``/``…) don't break URL parsing. + """ + host = host or os.environ.get("DECNET_DB_HOST", "localhost") + port = port or int(os.environ.get("DECNET_DB_PORT", "3306")) + database = database or os.environ.get("DECNET_DB_NAME", "decnet") + user = user or os.environ.get("DECNET_DB_USER", "decnet") + + if password is None: + password = os.environ.get("DECNET_DB_PASSWORD", "") + + # Allow empty passwords during tests (pytest sets PYTEST_* env vars). + # Outside tests, an empty MySQL password is almost never intentional. + if not password and not any(k.startswith("PYTEST") for k in os.environ): + raise ValueError( + "DECNET_DB_PASSWORD is not set. Either export it, set DECNET_DB_URL, " + "or run under pytest for an empty-password default." + ) + + pw_enc = quote_plus(password) + user_enc = quote_plus(user) + return f"mysql+aiomysql://{user_enc}:{pw_enc}@{host}:{port}/{database}" + + +def resolve_url(url: Optional[str] = None) -> str: + """Pick a connection URL: explicit arg → DECNET_DB_URL env → built from components.""" + if url: + return url + env_url = os.environ.get("DECNET_DB_URL") + if env_url: + return env_url + return build_mysql_url() + + +def get_async_engine( + url: Optional[str] = None, + *, + pool_size: int = DEFAULT_POOL_SIZE, + max_overflow: int = DEFAULT_MAX_OVERFLOW, + pool_recycle: int = DEFAULT_POOL_RECYCLE, + pool_pre_ping: bool = DEFAULT_POOL_PRE_PING, + echo: bool = False, +) -> AsyncEngine: + """Create an AsyncEngine for MySQL. + + Defaults tuned for a dashboard workload: a modest pool, hourly recycle + to sidestep MySQL's idle-connection reaper, and pre-ping to fail fast + if a pooled connection has been killed server-side. + """ + dsn = resolve_url(url) + return create_async_engine( + dsn, + echo=echo, + pool_size=pool_size, + max_overflow=max_overflow, + pool_recycle=pool_recycle, + pool_pre_ping=pool_pre_ping, + ) diff --git a/decnet/web/db/mysql/repository.py b/decnet/web/db/mysql/repository.py new file mode 100644 index 0000000..533b061 --- /dev/null +++ b/decnet/web/db/mysql/repository.py @@ -0,0 +1,87 @@ +""" +MySQL implementation of :class:`BaseRepository`. + +Inherits the portable SQLModel query code from :class:`SQLModelRepository` +and only overrides the two places where MySQL's SQL dialect differs from +SQLite's: + +* :meth:`_migrate_attackers_table` — uses ``information_schema`` (MySQL + has no ``PRAGMA``). +* :meth:`get_log_histogram` — uses ``FROM_UNIXTIME`` / + ``UNIX_TIMESTAMP`` + integer division for bucketing. +""" +from __future__ import annotations + +from typing import List, Optional + +from sqlalchemy import func, select, text, literal_column +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from sqlmodel.sql.expression import SelectOfScalar + +from decnet.web.db.models import Log +from decnet.web.db.mysql.database import get_async_engine +from decnet.web.db.sqlmodel_repo import SQLModelRepository + + +class MySQLRepository(SQLModelRepository): + """MySQL backend — uses ``aiomysql``.""" + + def __init__(self, url: Optional[str] = None, **engine_kwargs) -> None: + self.engine = get_async_engine(url=url, **engine_kwargs) + self.session_factory = async_sessionmaker( + self.engine, class_=AsyncSession, expire_on_commit=False + ) + + async def _migrate_attackers_table(self) -> None: + """Drop the legacy (pre-UUID) ``attackers`` table if it exists without a ``uuid`` column. + + MySQL exposes column metadata via ``information_schema.COLUMNS``. + ``DATABASE()`` scopes the lookup to the currently connected schema. + """ + async with self.engine.begin() as conn: + rows = (await conn.execute(text( + "SELECT COLUMN_NAME FROM information_schema.COLUMNS " + "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'attackers'" + ))).fetchall() + if rows and not any(r[0] == "uuid" for r in rows): + await conn.execute(text("DROP TABLE attackers")) + + def _json_field_equals(self, key: str): + # MySQL 5.7+ exposes JSON_EXTRACT; quoted string result returned for + # TEXT-stored JSON, same behavior we rely on in SQLite. + return text(f"JSON_UNQUOTE(JSON_EXTRACT(fields, '$.{key}')) = :val") + + async def get_log_histogram( + self, + search: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + interval_minutes: int = 15, + ) -> List[dict]: + bucket_seconds = max(interval_minutes, 1) * 60 + # Truncate each timestamp to the start of its bucket: + # FROM_UNIXTIME( (UNIX_TIMESTAMP(timestamp) DIV N) * N ) + # DIV is MySQL's integer division operator. + bucket_expr = literal_column( + f"FROM_UNIXTIME((UNIX_TIMESTAMP(timestamp) DIV {bucket_seconds}) * {bucket_seconds})" + ).label("bucket_time") + + statement: SelectOfScalar = select(bucket_expr, func.count().label("count")).select_from(Log) + statement = self._apply_filters(statement, search, start_time, end_time) + statement = statement.group_by(literal_column("bucket_time")).order_by( + literal_column("bucket_time") + ) + + async with self.session_factory() as session: + results = await session.execute(statement) + # Normalize to ISO string for API parity with the SQLite backend + # (SQLite's datetime() returns a string already; FROM_UNIXTIME + # returns a datetime). + out: List[dict] = [] + for r in results.all(): + ts = r[0] + out.append({ + "time": ts.isoformat(sep=" ") if hasattr(ts, "isoformat") else ts, + "count": r[1], + }) + return out diff --git a/decnet/web/db/sqlmodel_repo.py b/decnet/web/db/sqlmodel_repo.py new file mode 100644 index 0000000..e50b652 --- /dev/null +++ b/decnet/web/db/sqlmodel_repo.py @@ -0,0 +1,637 @@ +""" +Shared SQLModel-based repository implementation. + +Contains all dialect-portable query code used by the SQLite and MySQL +backends. Dialect-specific behavior lives in subclasses: + +* engine/session construction (``__init__``) +* ``_migrate_attackers_table`` (legacy schema check; DDL introspection + is not portable) +* ``get_log_histogram`` (date-bucket expression differs per dialect) +""" +from __future__ import annotations + +import asyncio +import json +import uuid +from datetime import datetime, timezone +from typing import Any, Optional, List + +from sqlalchemy import func, select, desc, asc, text, or_, update +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker +from sqlmodel.sql.expression import SelectOfScalar + +from decnet.config import load_state +from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD +from decnet.web.auth import get_password_hash +from decnet.web.db.repository import BaseRepository +from decnet.web.db.models import User, Log, Bounty, State, Attacker, AttackerBehavior + + +class SQLModelRepository(BaseRepository): + """Concrete SQLModel/SQLAlchemy-async repository. + + Subclasses provide ``self.engine`` (AsyncEngine) and ``self.session_factory`` + in ``__init__``, and override the few dialect-specific helpers. + """ + + engine: AsyncEngine + session_factory: async_sessionmaker[AsyncSession] + + # ------------------------------------------------------------ lifecycle + + async def initialize(self) -> None: + """Create tables if absent and seed the admin user.""" + from sqlmodel import SQLModel + await self._migrate_attackers_table() + async with self.engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + await self._ensure_admin_user() + + async def reinitialize(self) -> None: + """Re-create schema (for tests / reset flows). Does NOT drop existing tables.""" + from sqlmodel import SQLModel + async with self.engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + await self._ensure_admin_user() + + async def _ensure_admin_user(self) -> None: + async with self.session_factory() as session: + result = await session.execute( + select(User).where(User.username == DECNET_ADMIN_USER) + ) + if not result.scalar_one_or_none(): + session.add(User( + uuid=str(uuid.uuid4()), + username=DECNET_ADMIN_USER, + password_hash=get_password_hash(DECNET_ADMIN_PASSWORD), + role="admin", + must_change_password=True, + )) + await session.commit() + + async def _migrate_attackers_table(self) -> None: + """Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable).""" + return None + + # ---------------------------------------------------------------- logs + + async def add_log(self, log_data: dict[str, Any]) -> None: + data = log_data.copy() + if "fields" in data and isinstance(data["fields"], dict): + data["fields"] = json.dumps(data["fields"]) + if "timestamp" in data and isinstance(data["timestamp"], str): + try: + data["timestamp"] = datetime.fromisoformat( + data["timestamp"].replace("Z", "+00:00") + ) + except ValueError: + pass + + async with self.session_factory() as session: + session.add(Log(**data)) + await session.commit() + + def _apply_filters( + self, + statement: SelectOfScalar, + search: Optional[str], + start_time: Optional[str], + end_time: Optional[str], + ) -> SelectOfScalar: + import re + import shlex + + if start_time: + statement = statement.where(Log.timestamp >= start_time) + if end_time: + statement = statement.where(Log.timestamp <= end_time) + + if search: + try: + tokens = shlex.split(search) + except ValueError: + tokens = search.split() + + core_fields = { + "decky": Log.decky, + "service": Log.service, + "event": Log.event_type, + "attacker": Log.attacker_ip, + "attacker-ip": Log.attacker_ip, + "attacker_ip": Log.attacker_ip, + } + + for token in tokens: + if ":" in token: + key, val = token.split(":", 1) + if key in core_fields: + statement = statement.where(core_fields[key] == val) + else: + key_safe = re.sub(r"[^a-zA-Z0-9_]", "", key) + if key_safe: + statement = statement.where( + self._json_field_equals(key_safe) + ).params(val=val) + else: + lk = f"%{token}%" + statement = statement.where( + or_( + Log.raw_line.like(lk), + Log.decky.like(lk), + Log.service.like(lk), + Log.attacker_ip.like(lk), + ) + ) + return statement + + def _json_field_equals(self, key: str): + """Return a text() predicate that matches rows where fields->key == :val. + + Both SQLite and MySQL expose a ``JSON_EXTRACT`` function; MySQL also + exposes the same function under ``json_extract`` (case-insensitive). + The ``:val`` parameter is bound separately and must be supplied with + ``.params(val=...)`` by the caller, which keeps us safe from injection. + """ + return text(f"JSON_EXTRACT(fields, '$.{key}') = :val") + + async def get_logs( + self, + limit: int = 50, + offset: int = 0, + search: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + ) -> List[dict]: + statement = ( + select(Log) + .order_by(desc(Log.timestamp)) + .offset(offset) + .limit(limit) + ) + statement = self._apply_filters(statement, search, start_time, end_time) + + async with self.session_factory() as session: + results = await session.execute(statement) + return [log.model_dump(mode="json") for log in results.scalars().all()] + + async def get_max_log_id(self) -> int: + async with self.session_factory() as session: + result = await session.execute(select(func.max(Log.id))) + val = result.scalar() + return val if val is not None else 0 + + async def get_logs_after_id( + self, + last_id: int, + limit: int = 50, + search: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + ) -> List[dict]: + statement = ( + select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit) + ) + statement = self._apply_filters(statement, search, start_time, end_time) + + async with self.session_factory() as session: + results = await session.execute(statement) + return [log.model_dump(mode="json") for log in results.scalars().all()] + + async def get_total_logs( + self, + search: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + ) -> int: + statement = select(func.count()).select_from(Log) + statement = self._apply_filters(statement, search, start_time, end_time) + + async with self.session_factory() as session: + result = await session.execute(statement) + return result.scalar() or 0 + + async def get_log_histogram( + self, + search: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + interval_minutes: int = 15, + ) -> List[dict]: + """Dialect-specific — override per backend.""" + raise NotImplementedError + + async def get_stats_summary(self) -> dict[str, Any]: + async with self.session_factory() as session: + total_logs = ( + await session.execute(select(func.count()).select_from(Log)) + ).scalar() or 0 + unique_attackers = ( + await session.execute( + select(func.count(func.distinct(Log.attacker_ip))) + ) + ).scalar() or 0 + + _state = await asyncio.to_thread(load_state) + deployed_deckies = len(_state[0].deckies) if _state else 0 + + return { + "total_logs": total_logs, + "unique_attackers": unique_attackers, + "active_deckies": deployed_deckies, + "deployed_deckies": deployed_deckies, + } + + async def get_deckies(self) -> List[dict]: + _state = await asyncio.to_thread(load_state) + return [_d.model_dump() for _d in _state[0].deckies] if _state else [] + + # --------------------------------------------------------------- users + + async def get_user_by_username(self, username: str) -> Optional[dict]: + async with self.session_factory() as session: + result = await session.execute( + select(User).where(User.username == username) + ) + user = result.scalar_one_or_none() + return user.model_dump() if user else None + + async def get_user_by_uuid(self, uuid: str) -> Optional[dict]: + async with self.session_factory() as session: + result = await session.execute( + select(User).where(User.uuid == uuid) + ) + user = result.scalar_one_or_none() + return user.model_dump() if user else None + + async def create_user(self, user_data: dict[str, Any]) -> None: + async with self.session_factory() as session: + session.add(User(**user_data)) + await session.commit() + + async def update_user_password( + self, uuid: str, password_hash: str, must_change_password: bool = False + ) -> None: + async with self.session_factory() as session: + await session.execute( + update(User) + .where(User.uuid == uuid) + .values( + password_hash=password_hash, + must_change_password=must_change_password, + ) + ) + await session.commit() + + async def list_users(self) -> list[dict]: + async with self.session_factory() as session: + result = await session.execute(select(User)) + return [u.model_dump() for u in result.scalars().all()] + + async def delete_user(self, uuid: str) -> bool: + async with self.session_factory() as session: + result = await session.execute(select(User).where(User.uuid == uuid)) + user = result.scalar_one_or_none() + if not user: + return False + await session.delete(user) + await session.commit() + return True + + async def update_user_role(self, uuid: str, role: str) -> None: + async with self.session_factory() as session: + await session.execute( + update(User).where(User.uuid == uuid).values(role=role) + ) + await session.commit() + + async def purge_logs_and_bounties(self) -> dict[str, int]: + async with self.session_factory() as session: + logs_deleted = (await session.execute(text("DELETE FROM logs"))).rowcount + bounties_deleted = (await session.execute(text("DELETE FROM bounty"))).rowcount + # attacker_behavior has FK → attackers.uuid; delete children first. + await session.execute(text("DELETE FROM attacker_behavior")) + attackers_deleted = (await session.execute(text("DELETE FROM attackers"))).rowcount + await session.commit() + return { + "logs": logs_deleted, + "bounties": bounties_deleted, + "attackers": attackers_deleted, + } + + # ------------------------------------------------------------ bounties + + async def add_bounty(self, bounty_data: dict[str, Any]) -> None: + data = bounty_data.copy() + if "payload" in data and isinstance(data["payload"], dict): + data["payload"] = json.dumps(data["payload"]) + + async with self.session_factory() as session: + session.add(Bounty(**data)) + await session.commit() + + def _apply_bounty_filters( + self, + statement: SelectOfScalar, + bounty_type: Optional[str], + search: Optional[str], + ) -> SelectOfScalar: + if bounty_type: + statement = statement.where(Bounty.bounty_type == bounty_type) + if search: + lk = f"%{search}%" + statement = statement.where( + or_( + Bounty.decky.like(lk), + Bounty.service.like(lk), + Bounty.attacker_ip.like(lk), + Bounty.payload.like(lk), + ) + ) + return statement + + async def get_bounties( + self, + limit: int = 50, + offset: int = 0, + bounty_type: Optional[str] = None, + search: Optional[str] = None, + ) -> List[dict]: + statement = ( + select(Bounty) + .order_by(desc(Bounty.timestamp)) + .offset(offset) + .limit(limit) + ) + statement = self._apply_bounty_filters(statement, bounty_type, search) + + async with self.session_factory() as session: + results = await session.execute(statement) + final = [] + for item in results.scalars().all(): + d = item.model_dump(mode="json") + try: + d["payload"] = json.loads(d["payload"]) + except (json.JSONDecodeError, TypeError): + pass + final.append(d) + return final + + async def get_total_bounties( + self, bounty_type: Optional[str] = None, search: Optional[str] = None + ) -> int: + statement = select(func.count()).select_from(Bounty) + statement = self._apply_bounty_filters(statement, bounty_type, search) + + async with self.session_factory() as session: + result = await session.execute(statement) + return result.scalar() or 0 + + async def get_state(self, key: str) -> Optional[dict[str, Any]]: + async with self.session_factory() as session: + statement = select(State).where(State.key == key) + result = await session.execute(statement) + state = result.scalar_one_or_none() + if state: + return json.loads(state.value) + return None + + async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401 + async with self.session_factory() as session: + statement = select(State).where(State.key == key) + result = await session.execute(statement) + state = result.scalar_one_or_none() + + value_json = json.dumps(value) + if state: + state.value = value_json + session.add(state) + else: + session.add(State(key=key, value=value_json)) + + await session.commit() + + # ----------------------------------------------------------- attackers + + async def get_all_logs_raw(self) -> List[dict[str, Any]]: + async with self.session_factory() as session: + result = await session.execute( + select( + Log.id, + Log.raw_line, + Log.attacker_ip, + Log.service, + Log.event_type, + Log.decky, + Log.timestamp, + Log.fields, + ) + ) + return [ + { + "id": r.id, + "raw_line": r.raw_line, + "attacker_ip": r.attacker_ip, + "service": r.service, + "event_type": r.event_type, + "decky": r.decky, + "timestamp": r.timestamp, + "fields": r.fields, + } + for r in result.all() + ] + + async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]: + from collections import defaultdict + async with self.session_factory() as session: + result = await session.execute( + select(Bounty).order_by(asc(Bounty.timestamp)) + ) + grouped: dict[str, List[dict[str, Any]]] = defaultdict(list) + for item in result.scalars().all(): + d = item.model_dump(mode="json") + try: + d["payload"] = json.loads(d["payload"]) + except (json.JSONDecodeError, TypeError): + pass + grouped[item.attacker_ip].append(d) + return dict(grouped) + + async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]: + from collections import defaultdict + async with self.session_factory() as session: + result = await session.execute( + select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp)) + ) + grouped: dict[str, List[dict[str, Any]]] = defaultdict(list) + for item in result.scalars().all(): + d = item.model_dump(mode="json") + try: + d["payload"] = json.loads(d["payload"]) + except (json.JSONDecodeError, TypeError): + pass + grouped[item.attacker_ip].append(d) + return dict(grouped) + + async def upsert_attacker(self, data: dict[str, Any]) -> str: + async with self.session_factory() as session: + result = await session.execute( + select(Attacker).where(Attacker.ip == data["ip"]) + ) + existing = result.scalar_one_or_none() + if existing: + for k, v in data.items(): + setattr(existing, k, v) + session.add(existing) + row_uuid = existing.uuid + else: + row_uuid = str(uuid.uuid4()) + data = {**data, "uuid": row_uuid} + session.add(Attacker(**data)) + await session.commit() + return row_uuid + + async def upsert_attacker_behavior( + self, + attacker_uuid: str, + data: dict[str, Any], + ) -> None: + async with self.session_factory() as session: + result = await session.execute( + select(AttackerBehavior).where( + AttackerBehavior.attacker_uuid == attacker_uuid + ) + ) + existing = result.scalar_one_or_none() + payload = {**data, "updated_at": datetime.now(timezone.utc)} + if existing: + for k, v in payload.items(): + setattr(existing, k, v) + session.add(existing) + else: + session.add(AttackerBehavior(attacker_uuid=attacker_uuid, **payload)) + await session.commit() + + async def get_attacker_behavior( + self, + attacker_uuid: str, + ) -> Optional[dict[str, Any]]: + async with self.session_factory() as session: + result = await session.execute( + select(AttackerBehavior).where( + AttackerBehavior.attacker_uuid == attacker_uuid + ) + ) + row = result.scalar_one_or_none() + if not row: + return None + return self._deserialize_behavior(row.model_dump(mode="json")) + + async def get_behaviors_for_ips( + self, + ips: set[str], + ) -> dict[str, dict[str, Any]]: + if not ips: + return {} + async with self.session_factory() as session: + result = await session.execute( + select(Attacker.ip, AttackerBehavior) + .join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid) + .where(Attacker.ip.in_(ips)) + ) + out: dict[str, dict[str, Any]] = {} + for ip, row in result.all(): + out[ip] = self._deserialize_behavior(row.model_dump(mode="json")) + return out + + @staticmethod + def _deserialize_behavior(d: dict[str, Any]) -> dict[str, Any]: + for key in ("tcp_fingerprint", "timing_stats", "phase_sequence"): + if isinstance(d.get(key), str): + try: + d[key] = json.loads(d[key]) + except (json.JSONDecodeError, TypeError): + pass + return d + + @staticmethod + def _deserialize_attacker(d: dict[str, Any]) -> dict[str, Any]: + for key in ("services", "deckies", "fingerprints", "commands"): + if isinstance(d.get(key), str): + try: + d[key] = json.loads(d[key]) + except (json.JSONDecodeError, TypeError): + pass + return d + + async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]: + async with self.session_factory() as session: + result = await session.execute( + select(Attacker).where(Attacker.uuid == uuid) + ) + attacker = result.scalar_one_or_none() + if not attacker: + return None + return self._deserialize_attacker(attacker.model_dump(mode="json")) + + async def get_attackers( + self, + limit: int = 50, + offset: int = 0, + search: Optional[str] = None, + sort_by: str = "recent", + service: Optional[str] = None, + ) -> List[dict[str, Any]]: + order = { + "active": desc(Attacker.event_count), + "traversals": desc(Attacker.is_traversal), + }.get(sort_by, desc(Attacker.last_seen)) + + statement = select(Attacker).order_by(order).offset(offset).limit(limit) + if search: + statement = statement.where(Attacker.ip.like(f"%{search}%")) + if service: + statement = statement.where(Attacker.services.like(f'%"{service}"%')) + + async with self.session_factory() as session: + result = await session.execute(statement) + return [ + self._deserialize_attacker(a.model_dump(mode="json")) + for a in result.scalars().all() + ] + + async def get_total_attackers( + self, search: Optional[str] = None, service: Optional[str] = None + ) -> int: + statement = select(func.count()).select_from(Attacker) + if search: + statement = statement.where(Attacker.ip.like(f"%{search}%")) + if service: + statement = statement.where(Attacker.services.like(f'%"{service}"%')) + + async with self.session_factory() as session: + result = await session.execute(statement) + return result.scalar() or 0 + + async def get_attacker_commands( + self, + uuid: str, + limit: int = 50, + offset: int = 0, + service: Optional[str] = None, + ) -> dict[str, Any]: + async with self.session_factory() as session: + result = await session.execute( + select(Attacker.commands).where(Attacker.uuid == uuid) + ) + raw = result.scalar_one_or_none() + if raw is None: + return {"total": 0, "data": []} + + commands: list = json.loads(raw) if isinstance(raw, str) else raw + if service: + commands = [c for c in commands if c.get("service") == service] + + total = len(commands) + page = commands[offset: offset + limit] + return {"total": total, "data": page}