diff --git a/decnet/web/db/models.py b/decnet/web/db/models.py index 74a4c7d..4a758d1 100644 --- a/decnet/web/db/models.py +++ b/decnet/web/db/models.py @@ -1,6 +1,6 @@ -from datetime import datetime +from datetime import datetime, timezone from typing import Optional, Any, List -from sqlmodel import SQLModel, Field, Column, JSON +from sqlmodel import SQLModel, Field from pydantic import BaseModel, Field as PydanticField # --- Database Tables (SQLModel) --- @@ -16,7 +16,7 @@ class User(SQLModel, table=True): class Log(SQLModel, table=True): __tablename__ = "logs" id: Optional[int] = Field(default=None, primary_key=True) - timestamp: datetime = Field(default_factory=datetime.utcnow, index=True) + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), index=True) decky: str = Field(index=True) service: str = Field(index=True) event_type: str = Field(index=True) @@ -28,7 +28,7 @@ class Log(SQLModel, table=True): class Bounty(SQLModel, table=True): __tablename__ = "bounty" id: Optional[int] = Field(default=None, primary_key=True) - timestamp: datetime = Field(default_factory=datetime.utcnow, index=True) + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), index=True) decky: str = Field(index=True) service: str = Field(index=True) attacker_ip: str = Field(index=True) diff --git a/decnet/web/db/sqlite/repository.py b/decnet/web/db/sqlite/repository.py index e457dde..f2a7b79 100644 --- a/decnet/web/db/sqlite/repository.py +++ b/decnet/web/db/sqlite/repository.py @@ -4,9 +4,8 @@ import uuid from datetime import datetime from typing import Any, Optional, List -from sqlalchemy import func, select, desc, asc, text, or_, update +from sqlalchemy import func, select, desc, asc, text, or_, update, literal_column from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from sqlmodel import col from decnet.config import load_state, _ROOT from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD @@ -30,60 +29,82 @@ class SQLiteRepository(BaseRepository): def _initialize_sync(self) -> None: """Initialize the database schema synchronously.""" init_db(self.db_path) - - # Ensure default admin exists via sync SQLAlchemy connection - from sqlalchemy import select + from decnet.web.db.sqlite.database import get_sync_engine engine = get_sync_engine(self.db_path) with engine.connect() as conn: - # Ensure admin exists via sync SQLAlchemy connection - from sqlalchemy import text - result = conn.execute(text("SELECT uuid FROM users WHERE username = :u"), {"u": DECNET_ADMIN_USER}) + result = conn.execute( + text("SELECT uuid FROM users WHERE username = :u"), + {"u": DECNET_ADMIN_USER}, + ) if not result.fetchone(): - print(f"DEBUG: Creating admin user '{DECNET_ADMIN_USER}' with password '{DECNET_ADMIN_PASSWORD}'") conn.execute( - text("INSERT INTO users (uuid, username, password_hash, role, must_change_password) " - "VALUES (:uuid, :u, :p, :r, :m)"), + text( + "INSERT INTO users (uuid, username, password_hash, role, must_change_password) " + "VALUES (:uuid, :u, :p, :r, :m)" + ), { "uuid": str(uuid.uuid4()), "u": DECNET_ADMIN_USER, "p": get_password_hash(DECNET_ADMIN_PASSWORD), "r": "admin", - "m": 1 - } + "m": 1, + }, ) conn.commit() - else: - print(f"DEBUG: Admin user '{DECNET_ADMIN_USER}' already exists") async def initialize(self) -> None: """Async warm-up / verification.""" async with self.session_factory() as session: - await session.exec(text("SELECT 1")) + await session.execute(text("SELECT 1")) - def reinitialize(self) -> None: - self._initialize_sync() + async def reinitialize(self) -> None: + """Initialize the database schema asynchronously (useful for tests).""" + from sqlmodel import SQLModel + async with self.engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + + async with self.session_factory() as session: + result = await session.execute( + select(User).where(User.username == DECNET_ADMIN_USER) + ) + if not result.scalar_one_or_none(): + session.add(User( + uuid=str(uuid.uuid4()), + username=DECNET_ADMIN_USER, + password_hash=get_password_hash(DECNET_ADMIN_PASSWORD), + role="admin", + must_change_password=True, + )) + await session.commit() + + # ------------------------------------------------------------------ logs async def add_log(self, log_data: dict[str, Any]) -> None: - # Convert dict to model data = log_data.copy() if "fields" in data and isinstance(data["fields"], dict): data["fields"] = json.dumps(data["fields"]) - if "timestamp" in data and isinstance(data["timestamp"], str): try: - data["timestamp"] = datetime.fromisoformat(data["timestamp"].replace('Z', '+00:00')) + data["timestamp"] = datetime.fromisoformat( + data["timestamp"].replace("Z", "+00:00") + ) except ValueError: pass - log = Log(**data) async with self.session_factory() as session: - session.add(log) + session.add(Log(**data)) await session.commit() - def _apply_filters(self, statement, search: Optional[str], start_time: Optional[str], end_time: Optional[str]): - import shlex + def _apply_filters( + self, + statement, + search: Optional[str], + start_time: Optional[str], + end_time: Optional[str], + ): import re + import shlex if start_time: statement = statement.where(Log.timestamp >= start_time) @@ -94,7 +115,7 @@ class SQLiteRepository(BaseRepository): try: tokens = shlex.split(search) except ValueError: - tokens = search.split(" ") + tokens = search.split() core_fields = { "decky": Log.decky, @@ -102,7 +123,7 @@ class SQLiteRepository(BaseRepository): "event": Log.event_type, "attacker": Log.attacker_ip, "attacker-ip": Log.attacker_ip, - "attacker_ip": Log.attacker_ip + "attacker_ip": Log.attacker_ip, } for token in tokens: @@ -111,9 +132,10 @@ class SQLiteRepository(BaseRepository): if key in core_fields: statement = statement.where(core_fields[key] == val) else: - key_safe = re.sub(r'[^a-zA-Z0-9_]', '', key) - # SQLite json_extract via text() - statement = statement.where(text(f"json_extract(fields, '$.{key_safe}') = :val")).params(val=val) + key_safe = re.sub(r"[^a-zA-Z0-9_]", "", key) + statement = statement.where( + text(f"json_extract(fields, '$.{key_safe}') = :val") + ).params(val=val) else: lk = f"%{token}%" statement = statement.where( @@ -121,25 +143,30 @@ class SQLiteRepository(BaseRepository): Log.raw_line.like(lk), Log.decky.like(lk), Log.service.like(lk), - Log.attacker_ip.like(lk) + Log.attacker_ip.like(lk), ) ) return statement async def get_logs( - self, - limit: int = 50, - offset: int = 0, + self, + limit: int = 50, + offset: int = 0, search: Optional[str] = None, start_time: Optional[str] = None, - end_time: Optional[str] = None + end_time: Optional[str] = None, ) -> List[dict]: - statement = select(Log).order_by(desc(Log.timestamp)).offset(offset).limit(limit) + statement = ( + select(Log) + .order_by(desc(Log.timestamp)) + .offset(offset) + .limit(limit) + ) statement = self._apply_filters(statement, search, start_time, end_time) async with self.session_factory() as session: results = await session.execute(statement) - return [log.dict() for log in results.scalars().all()] + return [log.model_dump() for log in results.scalars().all()] async def get_max_log_id(self) -> int: async with self.session_factory() as session: @@ -148,25 +175,27 @@ class SQLiteRepository(BaseRepository): return val if val is not None else 0 async def get_logs_after_id( - self, - last_id: int, - limit: int = 50, + self, + last_id: int, + limit: int = 50, search: Optional[str] = None, start_time: Optional[str] = None, - end_time: Optional[str] = None + end_time: Optional[str] = None, ) -> List[dict]: - statement = select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit) + statement = ( + select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit) + ) statement = self._apply_filters(statement, search, start_time, end_time) async with self.session_factory() as session: results = await session.execute(statement) - return [log.dict() for log in results.scalars().all()] + return [log.model_dump() for log in results.scalars().all()] async def get_total_logs( - self, + self, search: Optional[str] = None, start_time: Optional[str] = None, - end_time: Optional[str] = None + end_time: Optional[str] = None, ) -> int: statement = select(func.count()).select_from(Log) statement = self._apply_filters(statement, search, start_time, end_time) @@ -180,130 +209,99 @@ class SQLiteRepository(BaseRepository): search: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, - interval_minutes: int = 15 + interval_minutes: int = 15, ) -> List[dict]: - # raw SQL for time bucketing as it is very engine specific - _where_stmt = select(Log) - _where_stmt = self._apply_filters(_where_stmt, search, start_time, end_time) - - # Extract WHERE clause from compiled statement - # For simplicity in this migration, we'll use a semi-raw approach for the complex histogram query - # but bind parameters from the filtered statement - - # SQLite specific bucket logic - bucket_expr = f"(strftime('%s', timestamp) / {interval_minutes * 60}) * {interval_minutes * 60}" - - # We'll use session.execute with a text statement for the grouping but reuse the WHERE logic if possible - # Or just build the query fully - - where_clause, params = self._build_where_clause_legacy(search, start_time, end_time) - - query = f""" - SELECT - datetime({bucket_expr}, 'unixepoch') as bucket_time, - COUNT(*) as count - FROM logs - {where_clause} - GROUP BY bucket_time - ORDER BY bucket_time ASC - """ - - async with self.session_factory() as session: - results = await session.execute(text(query), params) - return [{"time": r[0], "count": r[1]} for r in results.all()] + bucket_seconds = interval_minutes * 60 + bucket_expr = literal_column( + f"datetime((strftime('%s', timestamp) / {bucket_seconds}) * {bucket_seconds}, 'unixepoch')" + ).label("bucket_time") - def _build_where_clause_legacy(self, search, start_time, end_time): - # Re-using the logic from the previous iteration for the raw query part - import shlex - import re - where_clauses = [] - params = {} - - if start_time: - where_clauses.append("timestamp >= :start_time") - params["start_time"] = start_time - if end_time: - where_clauses.append("timestamp <= :end_time") - params["end_time"] = end_time - - if search: - try: tokens = shlex.split(search) - except: tokens = search.split(" ") - core_fields = {"decky": "decky", "service": "service", "event": "event_type", "attacker": "attacker_ip"} - for i, token in enumerate(tokens): - if ":" in token: - k, v = token.split(":", 1) - if k in core_fields: - where_clauses.append(f"{core_fields[k]} = :val_{i}") - params[f"val_{i}"] = v - else: - ks = re.sub(r'[^a-zA-Z0-9_]', '', k) - where_clauses.append(f"json_extract(fields, '$.{ks}') = :val_{i}") - params[f"val_{i}"] = v - else: - where_clauses.append(f"(raw_line LIKE :lk_{i} OR decky LIKE :lk_{i} OR service LIKE :lk_{i} OR attacker_ip LIKE :lk_{i})") - params[f"lk_{i}"] = f"%{token}%" - - where = " WHERE " + " AND ".join(where_clauses) if where_clauses else "" - return where, params + statement = select(bucket_expr, func.count().label("count")).select_from(Log) + statement = self._apply_filters(statement, search, start_time, end_time) + statement = statement.group_by(literal_column("bucket_time")).order_by( + literal_column("bucket_time") + ) + + async with self.session_factory() as session: + results = await session.execute(statement) + return [{"time": r[0], "count": r[1]} for r in results.all()] async def get_stats_summary(self) -> dict[str, Any]: async with self.session_factory() as session: - total_logs = (await session.execute(select(func.count()).select_from(Log))).scalar() or 0 - unique_attackers = (await session.execute(select(func.count(func.distinct(Log.attacker_ip))))).scalar() or 0 - active_deckies = (await session.execute(select(func.count(func.distinct(Log.decky))))).scalar() or 0 - - _state = load_state() - deployed_deckies = len(_state[0].deckies) if _state else 0 + total_logs = ( + await session.execute(select(func.count()).select_from(Log)) + ).scalar() or 0 + unique_attackers = ( + await session.execute( + select(func.count(func.distinct(Log.attacker_ip))) + ) + ).scalar() or 0 + active_deckies = ( + await session.execute( + select(func.count(func.distinct(Log.decky))) + ) + ).scalar() or 0 - return { - "total_logs": total_logs, - "unique_attackers": unique_attackers, - "active_deckies": active_deckies, - "deployed_deckies": deployed_deckies - } + _state = await asyncio.to_thread(load_state) + deployed_deckies = len(_state[0].deckies) if _state else 0 + + return { + "total_logs": total_logs, + "unique_attackers": unique_attackers, + "active_deckies": active_deckies, + "deployed_deckies": deployed_deckies, + } async def get_deckies(self) -> List[dict]: - _state = load_state() + _state = await asyncio.to_thread(load_state) return [_d.model_dump() for _d in _state[0].deckies] if _state else [] + # ------------------------------------------------------------------ users + async def get_user_by_username(self, username: str) -> Optional[dict]: async with self.session_factory() as session: - statement = select(User).where(User.username == username) - results = await session.execute(statement) - user = results.scalar_one_or_none() - return user.dict() if user else None + result = await session.execute( + select(User).where(User.username == username) + ) + user = result.scalar_one_or_none() + return user.model_dump() if user else None async def get_user_by_uuid(self, uuid: str) -> Optional[dict]: async with self.session_factory() as session: - statement = select(User).where(User.uuid == uuid) - results = await session.execute(statement) - user = results.scalar_one_or_none() - return user.dict() if user else None + result = await session.execute( + select(User).where(User.uuid == uuid) + ) + user = result.scalar_one_or_none() + return user.model_dump() if user else None async def create_user(self, user_data: dict[str, Any]) -> None: - user = User(**user_data) async with self.session_factory() as session: - session.add(user) + session.add(User(**user_data)) await session.commit() - async def update_user_password(self, uuid: str, password_hash: str, must_change_password: bool = False) -> None: + async def update_user_password( + self, uuid: str, password_hash: str, must_change_password: bool = False + ) -> None: async with self.session_factory() as session: - statement = update(User).where(User.uuid == uuid).values( - password_hash=password_hash, - must_change_password=must_change_password + await session.execute( + update(User) + .where(User.uuid == uuid) + .values( + password_hash=password_hash, + must_change_password=must_change_password, + ) ) - await session.execute(statement) await session.commit() + # ---------------------------------------------------------------- bounties + async def add_bounty(self, bounty_data: dict[str, Any]) -> None: data = bounty_data.copy() if "payload" in data and isinstance(data["payload"], dict): data["payload"] = json.dumps(data["payload"]) - - bounty = Bounty(**data) + async with self.session_factory() as session: - session.add(bounty) + session.add(Bounty(**data)) await session.commit() def _apply_bounty_filters(self, statement, bounty_type: Optional[str], search: Optional[str]): @@ -316,33 +314,41 @@ class SQLiteRepository(BaseRepository): Bounty.decky.like(lk), Bounty.service.like(lk), Bounty.attacker_ip.like(lk), - Bounty.payload.like(lk) + Bounty.payload.like(lk), ) ) return statement async def get_bounties( - self, - limit: int = 50, - offset: int = 0, + self, + limit: int = 50, + offset: int = 0, bounty_type: Optional[str] = None, - search: Optional[str] = None + search: Optional[str] = None, ) -> List[dict]: - statement = select(Bounty).order_by(desc(Bounty.timestamp)).offset(offset).limit(limit) + statement = ( + select(Bounty) + .order_by(desc(Bounty.timestamp)) + .offset(offset) + .limit(limit) + ) statement = self._apply_bounty_filters(statement, bounty_type, search) async with self.session_factory() as session: results = await session.execute(statement) - items = results.scalars().all() final = [] - for item in items: - d = item.dict() - try: d["payload"] = json.loads(d["payload"]) - except: pass + for item in results.scalars().all(): + d = item.model_dump() + try: + d["payload"] = json.loads(d["payload"]) + except (json.JSONDecodeError, TypeError): + pass final.append(d) return final - async def get_total_bounties(self, bounty_type: Optional[str] = None, search: Optional[str] = None) -> int: + async def get_total_bounties( + self, bounty_type: Optional[str] = None, search: Optional[str] = None + ) -> int: statement = select(func.count()).select_from(Bounty) statement = self._apply_bounty_filters(statement, bounty_type, search)