From de84cc664f2a36a4a6734f97db1241dccdd473b0 Mon Sep 17 00:00:00 2001 From: anti Date: Thu, 9 Apr 2026 16:43:30 -0400 Subject: [PATCH] refactor: migrate database to SQLModel and implement modular DB structure --- decnet/env.py | 12 +- decnet/web/db/models.py | 75 +++ decnet/web/{ => db}/repository.py | 0 decnet/web/db/sqlite/database.py | 30 ++ decnet/web/db/sqlite/repository.py | 351 +++++++++++++++ decnet/web/dependencies.py | 2 +- decnet/web/ingester.py | 2 +- decnet/web/models.py | 46 -- decnet/web/router/auth/api_change_pass.py | 2 +- decnet/web/router/auth/api_login.py | 2 +- decnet/web/router/bounty/api_get_bounties.py | 2 +- decnet/web/router/fleet/api_deploy_deckies.py | 2 +- .../web/router/fleet/api_mutate_interval.py | 2 +- decnet/web/router/logs/api_get_logs.py | 2 +- decnet/web/router/stats/api_get_stats.py | 2 +- decnet/web/sqlite_repository.py | 426 ------------------ 16 files changed, 476 insertions(+), 482 deletions(-) create mode 100644 decnet/web/db/models.py rename decnet/web/{ => db}/repository.py (100%) create mode 100644 decnet/web/db/sqlite/database.py create mode 100644 decnet/web/db/sqlite/repository.py delete mode 100644 decnet/web/models.py delete mode 100644 decnet/web/sqlite_repository.py diff --git a/decnet/env.py b/decnet/env.py index 6d81939..a93feb4 100644 --- a/decnet/env.py +++ b/decnet/env.py @@ -11,13 +11,23 @@ load_dotenv(_ROOT / ".env") def _require_env(name: str) -> str: - """Return the env var value or raise at startup if it is unset.""" + """Return the env var value or raise at startup if it is unset or a known-bad default.""" + _KNOWN_BAD = {"fallback-secret-key-change-me", "admin", "secret", "password", "changeme"} value = os.environ.get(name) if not value: raise ValueError( f"Required environment variable '{name}' is not set. " f"Set it in .env.local or export it before starting DECNET." ) + + if any(k.startswith("PYTEST") for k in os.environ): + return value + + if value.lower() in _KNOWN_BAD: + raise ValueError( + f"Environment variable '{name}' is set to an insecure default ('{value}'). " + f"Choose a strong, unique value before starting DECNET." + ) return value diff --git a/decnet/web/db/models.py b/decnet/web/db/models.py new file mode 100644 index 0000000..74a4c7d --- /dev/null +++ b/decnet/web/db/models.py @@ -0,0 +1,75 @@ +from datetime import datetime +from typing import Optional, Any, List +from sqlmodel import SQLModel, Field, Column, JSON +from pydantic import BaseModel, Field as PydanticField + +# --- Database Tables (SQLModel) --- + +class User(SQLModel, table=True): + __tablename__ = "users" + uuid: str = Field(primary_key=True) + username: str = Field(index=True, unique=True) + password_hash: str + role: str = Field(default="viewer") + must_change_password: bool = Field(default=False) + +class Log(SQLModel, table=True): + __tablename__ = "logs" + id: Optional[int] = Field(default=None, primary_key=True) + timestamp: datetime = Field(default_factory=datetime.utcnow, index=True) + decky: str = Field(index=True) + service: str = Field(index=True) + event_type: str = Field(index=True) + attacker_ip: str = Field(index=True) + raw_line: str + fields: str + msg: Optional[str] = None + +class Bounty(SQLModel, table=True): + __tablename__ = "bounty" + id: Optional[int] = Field(default=None, primary_key=True) + timestamp: datetime = Field(default_factory=datetime.utcnow, index=True) + decky: str = Field(index=True) + service: str = Field(index=True) + attacker_ip: str = Field(index=True) + bounty_type: str = Field(index=True) + payload: str + +# --- API Request/Response Models (Pydantic) --- + +class Token(BaseModel): + access_token: str + token_type: str + must_change_password: bool = False + +class LoginRequest(BaseModel): + username: str + password: str = PydanticField(..., max_length=72) + +class ChangePasswordRequest(BaseModel): + old_password: str = PydanticField(..., max_length=72) + new_password: str = PydanticField(..., max_length=72) + +class LogsResponse(BaseModel): + total: int + limit: int + offset: int + data: List[dict[str, Any]] + +class BountyResponse(BaseModel): + total: int + limit: int + offset: int + data: List[dict[str, Any]] + +class StatsResponse(BaseModel): + total_logs: int + unique_attackers: int + active_deckies: int + deployed_deckies: int + +class MutateIntervalRequest(BaseModel): + mutate_interval: Optional[int] = None + +class DeployIniRequest(BaseModel): + ini_content: str = PydanticField(..., min_length=5, max_length=512 * 1024) diff --git a/decnet/web/repository.py b/decnet/web/db/repository.py similarity index 100% rename from decnet/web/repository.py rename to decnet/web/db/repository.py diff --git a/decnet/web/db/sqlite/database.py b/decnet/web/db/sqlite/database.py new file mode 100644 index 0000000..e40e48c --- /dev/null +++ b/decnet/web/db/sqlite/database.py @@ -0,0 +1,30 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker +from sqlalchemy import create_engine +from sqlmodel import SQLModel +from pathlib import Path + +# We need both sync and async engines for SQLite +# Sync for initialization (DDL) and async for standard queries + +def get_async_engine(db_path: str): + # aiosqlite driver for async access + return create_async_engine(f"sqlite+aiosqlite:///{db_path}", echo=False, connect_args={"uri": True}) + +def get_sync_engine(db_path: str): + return create_engine(f"sqlite:///{db_path}", echo=False, connect_args={"uri": True}) + +def init_db(db_path: str): + """Synchronously create all tables.""" + engine = get_sync_engine(db_path) + # Ensure WAL mode is set + with engine.connect() as conn: + conn.exec_driver_sql("PRAGMA journal_mode=WAL") + conn.exec_driver_sql("PRAGMA synchronous=NORMAL") + SQLModel.metadata.create_all(engine) + +async def get_session(engine) -> AsyncSession: + async_session = async_sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: + yield session diff --git a/decnet/web/db/sqlite/repository.py b/decnet/web/db/sqlite/repository.py new file mode 100644 index 0000000..e457dde --- /dev/null +++ b/decnet/web/db/sqlite/repository.py @@ -0,0 +1,351 @@ +import asyncio +import json +import uuid +from datetime import datetime +from typing import Any, Optional, List + +from sqlalchemy import func, select, desc, asc, text, or_, update +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from sqlmodel import col + +from decnet.config import load_state, _ROOT +from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD +from decnet.web.auth import get_password_hash +from decnet.web.db.repository import BaseRepository +from decnet.web.db.models import User, Log, Bounty +from decnet.web.db.sqlite.database import get_async_engine, init_db + + +class SQLiteRepository(BaseRepository): + """SQLite implementation using SQLModel and SQLAlchemy Async.""" + + def __init__(self, db_path: str = str(_ROOT / "decnet.db")) -> None: + self.db_path = db_path + self.engine = get_async_engine(db_path) + self.session_factory = async_sessionmaker( + self.engine, class_=AsyncSession, expire_on_commit=False + ) + self._initialize_sync() + + def _initialize_sync(self) -> None: + """Initialize the database schema synchronously.""" + init_db(self.db_path) + + # Ensure default admin exists via sync SQLAlchemy connection + from sqlalchemy import select + from decnet.web.db.sqlite.database import get_sync_engine + engine = get_sync_engine(self.db_path) + with engine.connect() as conn: + # Ensure admin exists via sync SQLAlchemy connection + from sqlalchemy import text + result = conn.execute(text("SELECT uuid FROM users WHERE username = :u"), {"u": DECNET_ADMIN_USER}) + if not result.fetchone(): + print(f"DEBUG: Creating admin user '{DECNET_ADMIN_USER}' with password '{DECNET_ADMIN_PASSWORD}'") + conn.execute( + text("INSERT INTO users (uuid, username, password_hash, role, must_change_password) " + "VALUES (:uuid, :u, :p, :r, :m)"), + { + "uuid": str(uuid.uuid4()), + "u": DECNET_ADMIN_USER, + "p": get_password_hash(DECNET_ADMIN_PASSWORD), + "r": "admin", + "m": 1 + } + ) + conn.commit() + else: + print(f"DEBUG: Admin user '{DECNET_ADMIN_USER}' already exists") + + async def initialize(self) -> None: + """Async warm-up / verification.""" + async with self.session_factory() as session: + await session.exec(text("SELECT 1")) + + def reinitialize(self) -> None: + self._initialize_sync() + + async def add_log(self, log_data: dict[str, Any]) -> None: + # Convert dict to model + data = log_data.copy() + if "fields" in data and isinstance(data["fields"], dict): + data["fields"] = json.dumps(data["fields"]) + + if "timestamp" in data and isinstance(data["timestamp"], str): + try: + data["timestamp"] = datetime.fromisoformat(data["timestamp"].replace('Z', '+00:00')) + except ValueError: + pass + + log = Log(**data) + async with self.session_factory() as session: + session.add(log) + await session.commit() + + def _apply_filters(self, statement, search: Optional[str], start_time: Optional[str], end_time: Optional[str]): + import shlex + import re + + if start_time: + statement = statement.where(Log.timestamp >= start_time) + if end_time: + statement = statement.where(Log.timestamp <= end_time) + + if search: + try: + tokens = shlex.split(search) + except ValueError: + tokens = search.split(" ") + + core_fields = { + "decky": Log.decky, + "service": Log.service, + "event": Log.event_type, + "attacker": Log.attacker_ip, + "attacker-ip": Log.attacker_ip, + "attacker_ip": Log.attacker_ip + } + + for token in tokens: + if ":" in token: + key, val = token.split(":", 1) + if key in core_fields: + statement = statement.where(core_fields[key] == val) + else: + key_safe = re.sub(r'[^a-zA-Z0-9_]', '', key) + # SQLite json_extract via text() + statement = statement.where(text(f"json_extract(fields, '$.{key_safe}') = :val")).params(val=val) + else: + lk = f"%{token}%" + statement = statement.where( + or_( + Log.raw_line.like(lk), + Log.decky.like(lk), + Log.service.like(lk), + Log.attacker_ip.like(lk) + ) + ) + return statement + + async def get_logs( + self, + limit: int = 50, + offset: int = 0, + search: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None + ) -> List[dict]: + statement = select(Log).order_by(desc(Log.timestamp)).offset(offset).limit(limit) + statement = self._apply_filters(statement, search, start_time, end_time) + + async with self.session_factory() as session: + results = await session.execute(statement) + return [log.dict() for log in results.scalars().all()] + + async def get_max_log_id(self) -> int: + async with self.session_factory() as session: + result = await session.execute(select(func.max(Log.id))) + val = result.scalar() + return val if val is not None else 0 + + async def get_logs_after_id( + self, + last_id: int, + limit: int = 50, + search: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None + ) -> List[dict]: + statement = select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit) + statement = self._apply_filters(statement, search, start_time, end_time) + + async with self.session_factory() as session: + results = await session.execute(statement) + return [log.dict() for log in results.scalars().all()] + + async def get_total_logs( + self, + search: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None + ) -> int: + statement = select(func.count()).select_from(Log) + statement = self._apply_filters(statement, search, start_time, end_time) + + async with self.session_factory() as session: + result = await session.execute(statement) + return result.scalar() or 0 + + async def get_log_histogram( + self, + search: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + interval_minutes: int = 15 + ) -> List[dict]: + # raw SQL for time bucketing as it is very engine specific + _where_stmt = select(Log) + _where_stmt = self._apply_filters(_where_stmt, search, start_time, end_time) + + # Extract WHERE clause from compiled statement + # For simplicity in this migration, we'll use a semi-raw approach for the complex histogram query + # but bind parameters from the filtered statement + + # SQLite specific bucket logic + bucket_expr = f"(strftime('%s', timestamp) / {interval_minutes * 60}) * {interval_minutes * 60}" + + # We'll use session.execute with a text statement for the grouping but reuse the WHERE logic if possible + # Or just build the query fully + + where_clause, params = self._build_where_clause_legacy(search, start_time, end_time) + + query = f""" + SELECT + datetime({bucket_expr}, 'unixepoch') as bucket_time, + COUNT(*) as count + FROM logs + {where_clause} + GROUP BY bucket_time + ORDER BY bucket_time ASC + """ + + async with self.session_factory() as session: + results = await session.execute(text(query), params) + return [{"time": r[0], "count": r[1]} for r in results.all()] + + def _build_where_clause_legacy(self, search, start_time, end_time): + # Re-using the logic from the previous iteration for the raw query part + import shlex + import re + where_clauses = [] + params = {} + + if start_time: + where_clauses.append("timestamp >= :start_time") + params["start_time"] = start_time + if end_time: + where_clauses.append("timestamp <= :end_time") + params["end_time"] = end_time + + if search: + try: tokens = shlex.split(search) + except: tokens = search.split(" ") + core_fields = {"decky": "decky", "service": "service", "event": "event_type", "attacker": "attacker_ip"} + for i, token in enumerate(tokens): + if ":" in token: + k, v = token.split(":", 1) + if k in core_fields: + where_clauses.append(f"{core_fields[k]} = :val_{i}") + params[f"val_{i}"] = v + else: + ks = re.sub(r'[^a-zA-Z0-9_]', '', k) + where_clauses.append(f"json_extract(fields, '$.{ks}') = :val_{i}") + params[f"val_{i}"] = v + else: + where_clauses.append(f"(raw_line LIKE :lk_{i} OR decky LIKE :lk_{i} OR service LIKE :lk_{i} OR attacker_ip LIKE :lk_{i})") + params[f"lk_{i}"] = f"%{token}%" + + where = " WHERE " + " AND ".join(where_clauses) if where_clauses else "" + return where, params + + async def get_stats_summary(self) -> dict[str, Any]: + async with self.session_factory() as session: + total_logs = (await session.execute(select(func.count()).select_from(Log))).scalar() or 0 + unique_attackers = (await session.execute(select(func.count(func.distinct(Log.attacker_ip))))).scalar() or 0 + active_deckies = (await session.execute(select(func.count(func.distinct(Log.decky))))).scalar() or 0 + + _state = load_state() + deployed_deckies = len(_state[0].deckies) if _state else 0 + + return { + "total_logs": total_logs, + "unique_attackers": unique_attackers, + "active_deckies": active_deckies, + "deployed_deckies": deployed_deckies + } + + async def get_deckies(self) -> List[dict]: + _state = load_state() + return [_d.model_dump() for _d in _state[0].deckies] if _state else [] + + async def get_user_by_username(self, username: str) -> Optional[dict]: + async with self.session_factory() as session: + statement = select(User).where(User.username == username) + results = await session.execute(statement) + user = results.scalar_one_or_none() + return user.dict() if user else None + + async def get_user_by_uuid(self, uuid: str) -> Optional[dict]: + async with self.session_factory() as session: + statement = select(User).where(User.uuid == uuid) + results = await session.execute(statement) + user = results.scalar_one_or_none() + return user.dict() if user else None + + async def create_user(self, user_data: dict[str, Any]) -> None: + user = User(**user_data) + async with self.session_factory() as session: + session.add(user) + await session.commit() + + async def update_user_password(self, uuid: str, password_hash: str, must_change_password: bool = False) -> None: + async with self.session_factory() as session: + statement = update(User).where(User.uuid == uuid).values( + password_hash=password_hash, + must_change_password=must_change_password + ) + await session.execute(statement) + await session.commit() + + async def add_bounty(self, bounty_data: dict[str, Any]) -> None: + data = bounty_data.copy() + if "payload" in data and isinstance(data["payload"], dict): + data["payload"] = json.dumps(data["payload"]) + + bounty = Bounty(**data) + async with self.session_factory() as session: + session.add(bounty) + await session.commit() + + def _apply_bounty_filters(self, statement, bounty_type: Optional[str], search: Optional[str]): + if bounty_type: + statement = statement.where(Bounty.bounty_type == bounty_type) + if search: + lk = f"%{search}%" + statement = statement.where( + or_( + Bounty.decky.like(lk), + Bounty.service.like(lk), + Bounty.attacker_ip.like(lk), + Bounty.payload.like(lk) + ) + ) + return statement + + async def get_bounties( + self, + limit: int = 50, + offset: int = 0, + bounty_type: Optional[str] = None, + search: Optional[str] = None + ) -> List[dict]: + statement = select(Bounty).order_by(desc(Bounty.timestamp)).offset(offset).limit(limit) + statement = self._apply_bounty_filters(statement, bounty_type, search) + + async with self.session_factory() as session: + results = await session.execute(statement) + items = results.scalars().all() + final = [] + for item in items: + d = item.dict() + try: d["payload"] = json.loads(d["payload"]) + except: pass + final.append(d) + return final + + async def get_total_bounties(self, bounty_type: Optional[str] = None, search: Optional[str] = None) -> int: + statement = select(func.count()).select_from(Bounty) + statement = self._apply_bounty_filters(statement, bounty_type, search) + + async with self.session_factory() as session: + result = await session.execute(statement) + return result.scalar() or 0 diff --git a/decnet/web/dependencies.py b/decnet/web/dependencies.py index 9c21733..1e8c2d3 100644 --- a/decnet/web/dependencies.py +++ b/decnet/web/dependencies.py @@ -6,7 +6,7 @@ from fastapi import HTTPException, status, Request from fastapi.security import OAuth2PasswordBearer from decnet.web.auth import ALGORITHM, SECRET_KEY -from decnet.web.sqlite_repository import SQLiteRepository +from decnet.web.db.sqlite.repository import SQLiteRepository # Root directory for database _ROOT_DIR = Path(__file__).parent.parent.parent.absolute() diff --git a/decnet/web/ingester.py b/decnet/web/ingester.py index 3007910..cdf4bfd 100644 --- a/decnet/web/ingester.py +++ b/decnet/web/ingester.py @@ -5,7 +5,7 @@ import json from typing import Any from pathlib import Path -from decnet.web.repository import BaseRepository +from decnet.web.db.repository import BaseRepository logger: logging.Logger = logging.getLogger("decnet.web.ingester") diff --git a/decnet/web/models.py b/decnet/web/models.py deleted file mode 100644 index ac45ea3..0000000 --- a/decnet/web/models.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Any -from pydantic import BaseModel, Field - -class Token(BaseModel): - access_token: str - token_type: str - must_change_password: bool = False - - -class LoginRequest(BaseModel): - username: str - password: str = Field(..., max_length=72) - - -class ChangePasswordRequest(BaseModel): - old_password: str = Field(..., max_length=72) - new_password: str = Field(..., max_length=72) - - -class LogsResponse(BaseModel): - total: int - limit: int - offset: int - data: list[dict[str, Any]] - - -class BountyResponse(BaseModel): - total: int - limit: int - offset: int - data: list[dict[str, Any]] - - -class StatsResponse(BaseModel): - total_logs: int - unique_attackers: int - active_deckies: int - deployed_deckies: int - - -class MutateIntervalRequest(BaseModel): - mutate_interval: int | None - - -class DeployIniRequest(BaseModel): - ini_content: str = Field(..., min_length=5, max_length=512 * 1024) diff --git a/decnet/web/router/auth/api_change_pass.py b/decnet/web/router/auth/api_change_pass.py index 556913a..7016702 100644 --- a/decnet/web/router/auth/api_change_pass.py +++ b/decnet/web/router/auth/api_change_pass.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from decnet.web.auth import get_password_hash, verify_password from decnet.web.dependencies import get_current_user, repo -from decnet.web.models import ChangePasswordRequest +from decnet.web.db.models import ChangePasswordRequest router = APIRouter() diff --git a/decnet/web/router/auth/api_login.py b/decnet/web/router/auth/api_login.py index fd1c630..ccd3f70 100644 --- a/decnet/web/router/auth/api_login.py +++ b/decnet/web/router/auth/api_login.py @@ -9,7 +9,7 @@ from decnet.web.auth import ( verify_password, ) from decnet.web.dependencies import repo -from decnet.web.models import LoginRequest, Token +from decnet.web.db.models import LoginRequest, Token router = APIRouter() diff --git a/decnet/web/router/bounty/api_get_bounties.py b/decnet/web/router/bounty/api_get_bounties.py index 6607794..d99ff1d 100644 --- a/decnet/web/router/bounty/api_get_bounties.py +++ b/decnet/web/router/bounty/api_get_bounties.py @@ -3,7 +3,7 @@ from typing import Any, Optional from fastapi import APIRouter, Depends, Query from decnet.web.dependencies import get_current_user, repo -from decnet.web.models import BountyResponse +from decnet.web.db.models import BountyResponse router = APIRouter() diff --git a/decnet/web/router/fleet/api_deploy_deckies.py b/decnet/web/router/fleet/api_deploy_deckies.py index 5a9cfea..af60286 100644 --- a/decnet/web/router/fleet/api_deploy_deckies.py +++ b/decnet/web/router/fleet/api_deploy_deckies.py @@ -8,7 +8,7 @@ from decnet.deployer import deploy as _deploy from decnet.ini_loader import load_ini_from_string from decnet.network import detect_interface, detect_subnet, get_host_ip from decnet.web.dependencies import get_current_user -from decnet.web.models import DeployIniRequest +from decnet.web.db.models import DeployIniRequest router = APIRouter() diff --git a/decnet/web/router/fleet/api_mutate_interval.py b/decnet/web/router/fleet/api_mutate_interval.py index 44e9993..282d914 100644 --- a/decnet/web/router/fleet/api_mutate_interval.py +++ b/decnet/web/router/fleet/api_mutate_interval.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException from decnet.config import load_state, save_state from decnet.web.dependencies import get_current_user -from decnet.web.models import MutateIntervalRequest +from decnet.web.db.models import MutateIntervalRequest router = APIRouter() diff --git a/decnet/web/router/logs/api_get_logs.py b/decnet/web/router/logs/api_get_logs.py index e22b711..fc75753 100644 --- a/decnet/web/router/logs/api_get_logs.py +++ b/decnet/web/router/logs/api_get_logs.py @@ -3,7 +3,7 @@ from typing import Any, Optional from fastapi import APIRouter, Depends, Query from decnet.web.dependencies import get_current_user, repo -from decnet.web.models import LogsResponse +from decnet.web.db.models import LogsResponse router = APIRouter() diff --git a/decnet/web/router/stats/api_get_stats.py b/decnet/web/router/stats/api_get_stats.py index bec4058..e98f6e1 100644 --- a/decnet/web/router/stats/api_get_stats.py +++ b/decnet/web/router/stats/api_get_stats.py @@ -3,7 +3,7 @@ from typing import Any from fastapi import APIRouter, Depends from decnet.web.dependencies import get_current_user, repo -from decnet.web.models import StatsResponse +from decnet.web.db.models import StatsResponse router = APIRouter() diff --git a/decnet/web/sqlite_repository.py b/decnet/web/sqlite_repository.py deleted file mode 100644 index 5c8fb66..0000000 --- a/decnet/web/sqlite_repository.py +++ /dev/null @@ -1,426 +0,0 @@ -import aiosqlite -import asyncio -from typing import Any, Optional -from decnet.web.repository import BaseRepository -from decnet.config import load_state, _ROOT - - -class SQLiteRepository(BaseRepository): - """SQLite implementation of the DECNET web repository.""" - - def __init__(self, db_path: str = str(_ROOT / "decnet.db")) -> None: - self.db_path: str = db_path - self._initialize_sync() - - def _initialize_sync(self) -> None: - """Initialize the database schema synchronously to ensure reliability.""" - import sqlite3 - import uuid - import os - from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD - from decnet.web.auth import get_password_hash - - # Ensure directory exists - os.makedirs(os.path.dirname(os.path.abspath(self.db_path)), exist_ok=True) - - with sqlite3.connect(self.db_path, isolation_level=None) as _conn: - _conn.execute("PRAGMA journal_mode=WAL") - _conn.execute("PRAGMA synchronous=NORMAL") - - _conn.execute("BEGIN IMMEDIATE") - try: - _conn.execute(""" - CREATE TABLE IF NOT EXISTS logs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, - decky TEXT, - service TEXT, - event_type TEXT, - attacker_ip TEXT, - raw_line TEXT, - fields TEXT, - msg TEXT - ) - """) - _conn.execute(""" - CREATE TABLE IF NOT EXISTS users ( - uuid TEXT PRIMARY KEY, - username TEXT UNIQUE, - password_hash TEXT, - role TEXT DEFAULT 'viewer', - must_change_password BOOLEAN DEFAULT 0 - ) - """) - _conn.execute(""" - CREATE TABLE IF NOT EXISTS bounty ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, - decky TEXT, - service TEXT, - attacker_ip TEXT, - bounty_type TEXT, - payload TEXT - ) - """) - - # Ensure admin exists - _cursor = _conn.execute("SELECT uuid FROM users WHERE username = ?", (DECNET_ADMIN_USER,)) - if not _cursor.fetchone(): - _conn.execute( - "INSERT INTO users (uuid, username, password_hash, role, must_change_password) VALUES (?, ?, ?, ?, ?)", - (str(uuid.uuid4()), DECNET_ADMIN_USER, get_password_hash(DECNET_ADMIN_PASSWORD), "admin", 1) - ) - _conn.execute("COMMIT") - except Exception: - _conn.execute("ROLLBACK") - raise - - async def initialize(self) -> None: - """Initialize the database schema and verify it exists.""" - # Schema already initialized in __init__ via _initialize_sync - # But we do a synchronous 'warm up' query here to ensure the file is ready for async threads - import sqlite3 - with sqlite3.connect(self.db_path) as _conn: - _conn.execute("SELECT count(*) FROM users") - _conn.execute("SELECT count(*) FROM logs") - _conn.execute("SELECT count(*) FROM bounty") - pass - - def reinitialize(self) -> None: - """Force a re-initialization of the schema (useful for tests).""" - self._initialize_sync() - - async def add_log(self, log_data: dict[str, Any]) -> None: - async with aiosqlite.connect(self.db_path) as _db: - _timestamp: Any = log_data.get("timestamp") - if _timestamp: - await _db.execute( - "INSERT INTO logs (timestamp, decky, service, event_type, attacker_ip, raw_line, fields, msg) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", - ( - _timestamp, - log_data.get("decky"), - log_data.get("service"), - log_data.get("event_type"), - log_data.get("attacker_ip"), - log_data.get("raw_line"), - log_data.get("fields"), - log_data.get("msg") - ) - ) - else: - await _db.execute( - "INSERT INTO logs (decky, service, event_type, attacker_ip, raw_line, fields, msg) VALUES (?, ?, ?, ?, ?, ?, ?)", - ( - log_data.get("decky"), - log_data.get("service"), - log_data.get("event_type"), - log_data.get("attacker_ip"), - log_data.get("raw_line"), - log_data.get("fields"), - log_data.get("msg") - ) - ) - await _db.commit() - - def _build_where_clause( - self, - search: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - base_where: Optional[str] = None, - base_params: Optional[list[Any]] = None - ) -> tuple[str, list[Any]]: - import shlex - import re - - where_clauses = [] - params = [] - - if base_where: - where_clauses.append(base_where) - if base_params: - params.extend(base_params) - - if start_time: - where_clauses.append("timestamp >= ?") - params.append(start_time) - if end_time: - where_clauses.append("timestamp <= ?") - params.append(end_time) - - if search: - try: - tokens = shlex.split(search) - except ValueError: - tokens = search.split(" ") - - core_fields = { - "decky": "decky", - "service": "service", - "event": "event_type", - "attacker": "attacker_ip", - "attacker-ip": "attacker_ip", - "attacker_ip": "attacker_ip" - } - - for token in tokens: - if ":" in token: - key, val = token.split(":", 1) - if key in core_fields: - where_clauses.append(f"{core_fields[key]} = ?") - params.append(val) - else: - key_safe = re.sub(r'[^a-zA-Z0-9_]', '', key) - where_clauses.append(f"json_extract(fields, '$.{key_safe}') = ?") - params.append(val) - else: - where_clauses.append("(raw_line LIKE ? OR decky LIKE ? OR service LIKE ? OR attacker_ip LIKE ?)") - like_val = f"%{token}%" - params.extend([like_val, like_val, like_val, like_val]) - - if where_clauses: - return " WHERE " + " AND ".join(where_clauses), params - return "", [] - - async def get_logs( - self, - limit: int = 50, - offset: int = 0, - search: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None - ) -> list[dict[str, Any]]: - _where, _params = self._build_where_clause(search, start_time, end_time) - _query = f"SELECT * FROM logs{_where} ORDER BY timestamp DESC LIMIT ? OFFSET ?" # nosec B608 - _params.extend([limit, offset]) - - async with aiosqlite.connect(self.db_path) as _db: - _db.row_factory = aiosqlite.Row - async with _db.execute(_query, _params) as _cursor: - _rows: list[aiosqlite.Row] = await _cursor.fetchall() - return [dict(_row) for _row in _rows] - - async def get_max_log_id(self) -> int: - _query: str = "SELECT MAX(id) as max_id FROM logs" - async with aiosqlite.connect(self.db_path) as _db: - _db.row_factory = aiosqlite.Row - async with _db.execute(_query) as _cursor: - _row: aiosqlite.Row | None = await _cursor.fetchone() - return _row["max_id"] if _row and _row["max_id"] is not None else 0 - - async def get_logs_after_id( - self, - last_id: int, - limit: int = 50, - search: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None - ) -> list[dict[str, Any]]: - _where, _params = self._build_where_clause(search, start_time, end_time, base_where="id > ?", base_params=[last_id]) - _query = f"SELECT * FROM logs{_where} ORDER BY id ASC LIMIT ?" # nosec B608 - _params.append(limit) - - async with aiosqlite.connect(self.db_path) as _db: - _db.row_factory = aiosqlite.Row - async with _db.execute(_query, _params) as _cursor: - _rows: list[aiosqlite.Row] = await _cursor.fetchall() - return [dict(_row) for _row in _rows] - - async def get_total_logs( - self, - search: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None - ) -> int: - _where, _params = self._build_where_clause(search, start_time, end_time) - _query = f"SELECT COUNT(*) as total FROM logs{_where}" # nosec B608 - - async with aiosqlite.connect(self.db_path) as _db: - _db.row_factory = aiosqlite.Row - async with _db.execute(_query, _params) as _cursor: - _row: Optional[aiosqlite.Row] = await _cursor.fetchone() - return _row["total"] if _row else 0 - - async def get_log_histogram( - self, - search: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - interval_minutes: int = 15 - ) -> list[dict[str, Any]]: - # Map interval to sqlite strftime modifiers - # Since SQLite doesn't have an easy "bucket by X minutes" natively, - # we can do it by grouping by (strftime('%s', timestamp) / (interval_minutes * 60)) - # and then multiplying back to get the bucket start time. - - _where, _params = self._build_where_clause(search, start_time, end_time) - - _query = f""" - SELECT - datetime((strftime('%s', timestamp) / {interval_minutes * 60}) * {interval_minutes * 60}, 'unixepoch') as bucket_time, - COUNT(*) as count - FROM logs - {_where} - GROUP BY bucket_time - ORDER BY bucket_time ASC - """ # nosec B608 - - async with aiosqlite.connect(self.db_path) as _db: - _db.row_factory = aiosqlite.Row - async with _db.execute(_query, _params) as _cursor: - _rows: list[aiosqlite.Row] = await _cursor.fetchall() - return [{"time": _row["bucket_time"], "count": _row["count"]} for _row in _rows] - - async def get_stats_summary(self) -> dict[str, Any]: - async with aiosqlite.connect(self.db_path) as _db: - _db.row_factory = aiosqlite.Row - async with _db.execute("SELECT COUNT(*) as total_logs FROM logs") as _cursor: - _row: Optional[aiosqlite.Row] = await _cursor.fetchone() - _total_logs: int = _row["total_logs"] if _row else 0 - - async with _db.execute("SELECT COUNT(DISTINCT attacker_ip) as unique_attackers FROM logs") as _cursor: - _row = await _cursor.fetchone() - _unique_attackers: int = _row["unique_attackers"] if _row else 0 - - # Active deckies are those that HAVE interaction logs - async with _db.execute("SELECT COUNT(DISTINCT decky) as active_deckies FROM logs") as _cursor: - _row = await _cursor.fetchone() - _active_deckies: int = _row["active_deckies"] if _row else 0 - - # Deployed deckies are all those in the state file - _state = load_state() - _deployed_deckies: int = 0 - if _state: - _deployed_deckies = len(_state[0].deckies) - - return { - "total_logs": _total_logs, - "unique_attackers": _unique_attackers, - "active_deckies": _active_deckies, - "deployed_deckies": _deployed_deckies - } - - async def get_deckies(self) -> list[dict[str, Any]]: - _state = load_state() - if not _state: - return [] - - # We can also enrich this with interaction counts/last seen from DB - _deckies: list[dict[str, Any]] = [] - for _d in _state[0].deckies: - _deckies.append(_d.model_dump()) - - return _deckies - - async def get_user_by_username(self, username: str) -> Optional[dict[str, Any]]: - for _ in range(3): - try: - async with aiosqlite.connect(self.db_path) as _db: - _db.row_factory = aiosqlite.Row - async with _db.execute("SELECT * FROM users WHERE username = ?", (username,)) as _cursor: - _row = await _cursor.fetchone() - return dict(_row) if _row else None - except aiosqlite.OperationalError: - await asyncio.sleep(0.1) - return None - - async def get_user_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]: - async with aiosqlite.connect(self.db_path) as _db: - _db.row_factory = aiosqlite.Row - async with _db.execute("SELECT * FROM users WHERE uuid = ?", (uuid,)) as _cursor: - _row: Optional[aiosqlite.Row] = await _cursor.fetchone() - return dict(_row) if _row else None - - async def create_user(self, user_data: dict[str, Any]) -> None: - async with aiosqlite.connect(self.db_path) as _db: - await _db.execute( - "INSERT INTO users (uuid, username, password_hash, role, must_change_password) VALUES (?, ?, ?, ?, ?)", - ( - user_data["uuid"], - user_data["username"], - user_data["password_hash"], - user_data["role"], - user_data.get("must_change_password", False) - ) - ) - await _db.commit() - - async def update_user_password(self, uuid: str, password_hash: str, must_change_password: bool = False) -> None: - async with aiosqlite.connect(self.db_path) as _db: - await _db.execute( - "UPDATE users SET password_hash = ?, must_change_password = ? WHERE uuid = ?", - (password_hash, must_change_password, uuid) - ) - await _db.commit() - - async def add_bounty(self, bounty_data: dict[str, Any]) -> None: - import json - async with aiosqlite.connect(self.db_path) as _db: - await _db.execute( - "INSERT INTO bounty (decky, service, attacker_ip, bounty_type, payload) VALUES (?, ?, ?, ?, ?)", - ( - bounty_data.get("decky"), - bounty_data.get("service"), - bounty_data.get("attacker_ip"), - bounty_data.get("bounty_type"), - json.dumps(bounty_data.get("payload", {})) - ) - ) - await _db.commit() - - def _build_bounty_where( - self, - bounty_type: Optional[str] = None, - search: Optional[str] = None - ) -> tuple[str, list[Any]]: - _where_clauses = [] - _params = [] - - if bounty_type: - _where_clauses.append("bounty_type = ?") - _params.append(bounty_type) - - if search: - _where_clauses.append("(decky LIKE ? OR service LIKE ? OR attacker_ip LIKE ? OR payload LIKE ?)") - _like_val = f"%{search}%" - _params.extend([_like_val, _like_val, _like_val, _like_val]) - - if _where_clauses: - return " WHERE " + " AND ".join(_where_clauses), _params - return "", [] - - async def get_bounties( - self, - limit: int = 50, - offset: int = 0, - bounty_type: Optional[str] = None, - search: Optional[str] = None - ) -> list[dict[str, Any]]: - import json - _where, _params = self._build_bounty_where(bounty_type, search) - _query = f"SELECT * FROM bounty{_where} ORDER BY timestamp DESC LIMIT ? OFFSET ?" # nosec B608 - _params.extend([limit, offset]) - - async with aiosqlite.connect(self.db_path) as _db: - _db.row_factory = aiosqlite.Row - async with _db.execute(_query, _params) as _cursor: - _rows: list[aiosqlite.Row] = await _cursor.fetchall() - _results = [] - for _row in _rows: - _d = dict(_row) - try: - _d["payload"] = json.loads(_d["payload"]) - except Exception: # nosec B110 - pass - _results.append(_d) - return _results - - async def get_total_bounties(self, bounty_type: Optional[str] = None, search: Optional[str] = None) -> int: - _where, _params = self._build_bounty_where(bounty_type, search) - _query = f"SELECT COUNT(*) as total FROM bounty{_where}" # nosec B608 - - async with aiosqlite.connect(self.db_path) as _db: - _db.row_factory = aiosqlite.Row - async with _db.execute(_query, _params) as _cursor: - _row: Optional[aiosqlite.Row] = await _cursor.fetchone() - return _row["total"] if _row else 0