diff --git a/decnet/web/db/sqlmodel_repo/__init__.py b/decnet/web/db/sqlmodel_repo/__init__.py
index beac5e96..bf429892 100644
--- a/decnet/web/db/sqlmodel_repo/__init__.py
+++ b/decnet/web/db/sqlmodel_repo/__init__.py
@@ -68,6 +68,7 @@ from decnet.web.db.sqlmodel_repo.auth import AuthMixin
 from decnet.web.db.sqlmodel_repo.bounties import BountiesMixin
 from decnet.web.db.sqlmodel_repo.deckies import DeckiesMixin
 from decnet.web.db.sqlmodel_repo.fleet import FleetMixin
+from decnet.web.db.sqlmodel_repo.logs import LogsMixin
 from decnet.web.db.sqlmodel_repo.swarm import SwarmMixin
 from decnet.web.db.sqlmodel_repo.webhooks import WebhooksMixin
 
@@ -78,6 +79,7 @@ class SQLModelRepository(
     BountiesMixin,
     DeckiesMixin,
     FleetMixin,
+    LogsMixin,
     SwarmMixin,
     WebhooksMixin,
     BaseRepository,
@@ -146,199 +148,6 @@ class SQLModelRepository(
         rows. Override per dialect — DDL introspection is non-portable."""
         return None
 
-    # ---------------------------------------------------------------- logs
-
-    @staticmethod
-    def _normalize_log_row(log_data: dict[str, Any]) -> dict[str, Any]:
-        data = log_data.copy()
-        if "fields" in data and isinstance(data["fields"], dict):
-            data["fields"] = orjson.dumps(data["fields"]).decode()
-        if "timestamp" in data and isinstance(data["timestamp"], str):
-            try:
-                data["timestamp"] = datetime.fromisoformat(
-                    data["timestamp"].replace("Z", "+00:00")
-                )
-            except ValueError:
-                pass
-        return data
-
-    async def add_log(self, log_data: dict[str, Any]) -> None:
-        data = self._normalize_log_row(log_data)
-        async with self._session() as session:
-            session.add(Log(**data))
-            await session.commit()
-
-    async def add_logs(self, log_entries: list[dict[str, Any]]) -> None:
-        """Bulk insert — one session, one commit for the whole batch."""
-        if not log_entries:
-            return
-        _rows = [Log(**self._normalize_log_row(e)) for e in log_entries]
-        async with self._session() as session:
-            session.add_all(_rows)
-            await session.commit()
-
-    def _apply_filters(
-        self,
-        statement: SelectOfScalar,
-        search: Optional[str],
-        start_time: Optional[str],
-        end_time: Optional[str],
-    ) -> SelectOfScalar:
-        import re
-        import shlex
-
-        if start_time:
-            statement = statement.where(Log.timestamp >= start_time)
-        if end_time:
-            statement = statement.where(Log.timestamp <= end_time)
-
-        if search:
-            try:
-                tokens = shlex.split(search)
-            except ValueError:
-                tokens = search.split()
-
-            core_fields = {
-                "decky": Log.decky,
-                "service": Log.service,
-                "event": Log.event_type,
-                "attacker": Log.attacker_ip,
-                "attacker-ip": Log.attacker_ip,
-                "attacker_ip": Log.attacker_ip,
-            }
-
-            for token in tokens:
-                if ":" in token:
-                    key, val = token.split(":", 1)
-                    if key in core_fields:
-                        statement = statement.where(core_fields[key] == val)
-                    else:
-                        key_safe = re.sub(r"[^a-zA-Z0-9_]", "", key)
-                        if key_safe:
-                            statement = statement.where(
-                                self._json_field_equals(key_safe)
-                            ).params(val=val)
-                else:
-                    lk = f"%{token}%"
-                    statement = statement.where(
-                        or_(
-                            Log.raw_line.like(lk),
-                            Log.decky.like(lk),
-                            Log.service.like(lk),
-                            Log.attacker_ip.like(lk),
-                        )
-                    )
-        return statement
-
-    def _json_field_equals(self, key: str):
-        """Return a text() predicate that matches rows where fields->key == :val.
-
-        Both SQLite and MySQL expose a ``JSON_EXTRACT`` function; MySQL also
-        exposes the same function under ``json_extract`` (case-insensitive).
-        The ``:val`` parameter is bound separately and must be supplied with
-        ``.params(val=...)`` by the caller, which keeps us safe from injection.
-        """
-        return text(f"JSON_EXTRACT(fields, '$.{key}') = :val")
-
-    async def get_logs(
-        self,
-        limit: int = 50,
-        offset: int = 0,
-        search: Optional[str] = None,
-        start_time: Optional[str] = None,
-        end_time: Optional[str] = None,
-    ) -> List[dict]:
-        statement = (
-            select(Log)
-            .order_by(desc(Log.timestamp))
-            .offset(offset)
-            .limit(limit)
-        )
-        statement = self._apply_filters(statement, search, start_time, end_time)
-
-        async with self._session() as session:
-            results = await session.execute(statement)
-            return [log.model_dump(mode="json") for log in results.scalars().all()]
-
-    async def get_max_log_id(self) -> int:
-        async with self._session() as session:
-            result = await session.execute(select(func.max(Log.id)))
-            val = result.scalar()
-            return val if val is not None else 0
-
-    async def get_logs_after_id(
-        self,
-        last_id: int,
-        limit: int = 50,
-        search: Optional[str] = None,
-        start_time: Optional[str] = None,
-        end_time: Optional[str] = None,
-    ) -> List[dict]:
-        statement = (
-            select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit)
-        )
-        statement = self._apply_filters(statement, search, start_time, end_time)
-
-        async with self._session() as session:
-            results = await session.execute(statement)
-            return [log.model_dump(mode="json") for log in results.scalars().all()]
-
-    async def get_total_logs(
-        self,
-        search: Optional[str] = None,
-        start_time: Optional[str] = None,
-        end_time: Optional[str] = None,
-    ) -> int:
-        statement = select(func.count()).select_from(Log)
-        statement = self._apply_filters(statement, search, start_time, end_time)
-
-        async with self._session() as session:
-            result = await session.execute(statement)
-            return result.scalar() or 0
-
-    async def get_log_histogram(
-        self,
-        search: Optional[str] = None,
-        start_time: Optional[str] = None,
-        end_time: Optional[str] = None,
-        interval_minutes: int = 15,
-    ) -> List[dict]:
-        """Dialect-specific — override per backend."""
-        raise NotImplementedError
-
-    async def get_stats_summary(self) -> dict[str, Any]:
-        async with self._session() as session:
-            total_logs = (
-                await session.execute(select(func.count()).select_from(Log))
-            ).scalar() or 0
-            unique_attackers = (
-                await session.execute(
-                    select(func.count(func.distinct(Log.attacker_ip)))
-                )
-            ).scalar() or 0
-            topo_total = (
-                await session.execute(select(func.count()).select_from(TopologyDecky))
-            ).scalar() or 0
-            topo_running = (
-                await session.execute(
-                    select(func.count())
-                    .select_from(TopologyDecky)
-                    .where(TopologyDecky.state == "running")
-                )
-            ).scalar() or 0
-
-        _state = await asyncio.to_thread(load_state)
-        fleet_deckies = len(_state[0].deckies) if _state else 0
-
-        return {
-            "total_logs": total_logs,
-            "unique_attackers": unique_attackers,
-            # Fleet state file doesn't track per-decky runtime; treat all
-            # fleet rows as active and add MazeNET running rows on top.
-            "active_deckies": fleet_deckies + topo_running,
-            "deployed_deckies": fleet_deckies + topo_total,
-        }
-
     async def get_deckies(self) -> List[dict]:
         _state = await asyncio.to_thread(load_state)
         return [_d.model_dump() for _d in _state[0].deckies] if _state else []
diff --git a/decnet/web/db/sqlmodel_repo/logs.py b/decnet/web/db/sqlmodel_repo/logs.py
new file mode 100644
index 00000000..4d1ae1a4
--- /dev/null
+++ b/decnet/web/db/sqlmodel_repo/logs.py
@@ -0,0 +1,240 @@
+"""Log ingestion, query, and the stats summary endpoint.
+
+``get_log_histogram`` is the per-dialect override point; the abstract
+default raises NotImplementedError. ``get_stats_summary`` joins log
+counts, topology-decky counts, and the on-disk fleet state into a
+single dashboard payload.
+"""
+from __future__ import annotations
+
+import asyncio
+import re
+import shlex
+from datetime import datetime
+from typing import Any, List, Optional
+
+import orjson
+from sqlalchemy import asc, bindparam, desc, func, or_, select, text
+from sqlmodel.sql.expression import SelectOfScalar
+
+from decnet.config import load_state
+from decnet.web.db.models import Log, TopologyDecky
+
+
+class LogsMixin:
+    """Mixin: composed onto ``SQLModelRepository``."""
+
+    @staticmethod
+    def _normalize_log_row(log_data: dict[str, Any]) -> dict[str, Any]:
+        """Return a copy of ``log_data`` ready to be passed to ``Log(**data)``.
+
+        A dict ``fields`` value is serialised to a JSON string, and a string
+        ``timestamp`` is parsed with ``datetime.fromisoformat`` after ``Z``
+        is rewritten to ``+00:00``. The input dict itself is not modified.
+        """
+        data = log_data.copy()
+        if "fields" in data and isinstance(data["fields"], dict):
+            data["fields"] = orjson.dumps(data["fields"]).decode()
+        if "timestamp" in data and isinstance(data["timestamp"], str):
+            try:
+                data["timestamp"] = datetime.fromisoformat(
+                    data["timestamp"].replace("Z", "+00:00")
+                )
+            except ValueError:
+                # Best effort: leave an unparseable timestamp untouched.
+                pass
+        return data
+
+    async def add_log(self, log_data: dict[str, Any]) -> None:
+        """Insert a single normalised log row and commit."""
+        data = self._normalize_log_row(log_data)
+        async with self._session() as session:
+            session.add(Log(**data))
+            await session.commit()
+
+    async def add_logs(self, log_entries: list[dict[str, Any]]) -> None:
+        """Bulk insert — one session, one commit for the whole batch."""
+        if not log_entries:
+            return
+        _rows = [Log(**self._normalize_log_row(e)) for e in log_entries]
+        async with self._session() as session:
+            session.add_all(_rows)
+            await session.commit()
+
+    def _apply_filters(
+        self,
+        statement: SelectOfScalar,
+        search: Optional[str],
+        start_time: Optional[str],
+        end_time: Optional[str],
+    ) -> SelectOfScalar:
+        """Add time-range and search predicates to ``statement``.
+
+        ``search`` is tokenised with ``shlex.split`` (falling back to a
+        plain whitespace split if ``shlex`` rejects it). A ``key:value``
+        token is an equality test on a core column or, failing that, on a
+        key of the JSON ``fields`` column; any other token is a substring
+        match across ``raw_line``, ``decky``, ``service`` and
+        ``attacker_ip``. Each predicate is added with its own ``.where()``.
+        """
+        if start_time:
+            statement = statement.where(Log.timestamp >= start_time)
+        if end_time:
+            statement = statement.where(Log.timestamp <= end_time)
+
+        if search:
+            try:
+                tokens = shlex.split(search)
+            except ValueError:
+                tokens = search.split()
+
+            core_fields = {
+                "decky": Log.decky,
+                "service": Log.service,
+                "event": Log.event_type,
+                "attacker": Log.attacker_ip,
+                "attacker-ip": Log.attacker_ip,
+                "attacker_ip": Log.attacker_ip,
+            }
+
+            for token in tokens:
+                if ":" in token:
+                    key, val = token.split(":", 1)
+                    if key in core_fields:
+                        statement = statement.where(core_fields[key] == val)
+                    else:
+                        key_safe = re.sub(r"[^a-zA-Z0-9_]", "", key)
+                        if key_safe:
+                            # Bind per predicate with a unique parameter so
+                            # several JSON tokens in one search don't
+                            # overwrite each other's ``:val``.
+                            statement = statement.where(
+                                self._json_field_equals(key_safe).bindparams(
+                                    bindparam("val", value=val, unique=True)
+                                )
+                            )
+                else:
+                    lk = f"%{token}%"
+                    statement = statement.where(
+                        or_(
+                            Log.raw_line.like(lk),
+                            Log.decky.like(lk),
+                            Log.service.like(lk),
+                            Log.attacker_ip.like(lk),
+                        )
+                    )
+        return statement
+
+    def _json_field_equals(self, key: str):
+        """Return a text() predicate that matches rows where fields->key == :val.
+
+        Both SQLite and MySQL expose a ``JSON_EXTRACT`` function; MySQL also
+        exposes the same function under ``json_extract`` (case-insensitive).
+        The ``:val`` parameter is bound by the caller via ``.bindparams()``
+        rather than interpolated, which keeps us safe from injection.
+        """
+        return text(f"JSON_EXTRACT(fields, '$.{key}') = :val")
+
+    async def get_logs(
+        self,
+        limit: int = 50,
+        offset: int = 0,
+        search: Optional[str] = None,
+        start_time: Optional[str] = None,
+        end_time: Optional[str] = None,
+    ) -> List[dict]:
+        """Return one page of matching logs, newest timestamp first."""
+        statement = (
+            select(Log)
+            .order_by(desc(Log.timestamp))
+            .offset(offset)
+            .limit(limit)
+        )
+        statement = self._apply_filters(statement, search, start_time, end_time)
+
+        async with self._session() as session:
+            results = await session.execute(statement)
+            return [log.model_dump(mode="json") for log in results.scalars().all()]
+
+    async def get_max_log_id(self) -> int:
+        """Return the highest log id, or 0 when there are no rows."""
+        async with self._session() as session:
+            result = await session.execute(select(func.max(Log.id)))
+            val = result.scalar()
+            return val if val is not None else 0
+
+    async def get_logs_after_id(
+        self,
+        last_id: int,
+        limit: int = 50,
+        search: Optional[str] = None,
+        start_time: Optional[str] = None,
+        end_time: Optional[str] = None,
+    ) -> List[dict]:
+        """Return up to ``limit`` logs with id > ``last_id``, ascending by id."""
+        statement = (
+            select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit)
+        )
+        statement = self._apply_filters(statement, search, start_time, end_time)
+
+        async with self._session() as session:
+            results = await session.execute(statement)
+            return [log.model_dump(mode="json") for log in results.scalars().all()]
+
+    async def get_total_logs(
+        self,
+        search: Optional[str] = None,
+        start_time: Optional[str] = None,
+        end_time: Optional[str] = None,
+    ) -> int:
+        """Count the logs matching the given filters."""
+        statement = select(func.count()).select_from(Log)
+        statement = self._apply_filters(statement, search, start_time, end_time)
+
+        async with self._session() as session:
+            result = await session.execute(statement)
+            return result.scalar() or 0
+
+    async def get_log_histogram(
+        self,
+        search: Optional[str] = None,
+        start_time: Optional[str] = None,
+        end_time: Optional[str] = None,
+        interval_minutes: int = 15,
+    ) -> List[dict]:
+        """Dialect-specific — override per backend."""
+        raise NotImplementedError
+
+    async def get_stats_summary(self) -> dict[str, Any]:
+        """Return log, attacker and decky counts for the dashboard."""
+        async with self._session() as session:
+            total_logs = (
+                await session.execute(select(func.count()).select_from(Log))
+            ).scalar() or 0
+            unique_attackers = (
+                await session.execute(
+                    select(func.count(func.distinct(Log.attacker_ip)))
+                )
+            ).scalar() or 0
+            topo_total = (
+                await session.execute(select(func.count()).select_from(TopologyDecky))
+            ).scalar() or 0
+            topo_running = (
+                await session.execute(
+                    select(func.count())
+                    .select_from(TopologyDecky)
+                    .where(TopologyDecky.state == "running")
+                )
+            ).scalar() or 0
+
+        _state = await asyncio.to_thread(load_state)
+        fleet_deckies = len(_state[0].deckies) if _state else 0
+
+        return {
+            "total_logs": total_logs,
+            "unique_attackers": unique_attackers,
+            # Fleet state file doesn't track per-decky runtime; treat all
+            # fleet rows as active and add MazeNET running rows on top.
+            "active_deckies": fleet_deckies + topo_running,
+            "deployed_deckies": fleet_deckies + topo_total,
+        }