fix: clean up db layer — model_dump, timezone-aware timestamps, unified histogram, async load_state

This commit is contained in:
2026-04-09 18:46:35 -04:00
parent dbf6d13b95
commit 0166d0d559
2 changed files with 162 additions and 156 deletions

View File

@@ -1,6 +1,6 @@
from datetime import datetime from datetime import datetime, timezone
from typing import Optional, Any, List from typing import Optional, Any, List
from sqlmodel import SQLModel, Field, Column, JSON from sqlmodel import SQLModel, Field
from pydantic import BaseModel, Field as PydanticField from pydantic import BaseModel, Field as PydanticField
# --- Database Tables (SQLModel) --- # --- Database Tables (SQLModel) ---
@@ -16,7 +16,7 @@ class User(SQLModel, table=True):
class Log(SQLModel, table=True): class Log(SQLModel, table=True):
__tablename__ = "logs" __tablename__ = "logs"
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
timestamp: datetime = Field(default_factory=datetime.utcnow, index=True) timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), index=True)
decky: str = Field(index=True) decky: str = Field(index=True)
service: str = Field(index=True) service: str = Field(index=True)
event_type: str = Field(index=True) event_type: str = Field(index=True)
@@ -28,7 +28,7 @@ class Log(SQLModel, table=True):
class Bounty(SQLModel, table=True): class Bounty(SQLModel, table=True):
__tablename__ = "bounty" __tablename__ = "bounty"
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
timestamp: datetime = Field(default_factory=datetime.utcnow, index=True) timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), index=True)
decky: str = Field(index=True) decky: str = Field(index=True)
service: str = Field(index=True) service: str = Field(index=True)
attacker_ip: str = Field(index=True) attacker_ip: str = Field(index=True)

View File

@@ -4,9 +4,8 @@ import uuid
from datetime import datetime from datetime import datetime
from typing import Any, Optional, List from typing import Any, Optional, List
from sqlalchemy import func, select, desc, asc, text, or_, update from sqlalchemy import func, select, desc, asc, text, or_, update, literal_column
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlmodel import col
from decnet.config import load_state, _ROOT from decnet.config import load_state, _ROOT
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
@@ -30,60 +29,82 @@ class SQLiteRepository(BaseRepository):
def _initialize_sync(self) -> None: def _initialize_sync(self) -> None:
"""Initialize the database schema synchronously.""" """Initialize the database schema synchronously."""
init_db(self.db_path) init_db(self.db_path)
# Ensure default admin exists via sync SQLAlchemy connection
from sqlalchemy import select
from decnet.web.db.sqlite.database import get_sync_engine from decnet.web.db.sqlite.database import get_sync_engine
engine = get_sync_engine(self.db_path) engine = get_sync_engine(self.db_path)
with engine.connect() as conn: with engine.connect() as conn:
# Ensure admin exists via sync SQLAlchemy connection result = conn.execute(
from sqlalchemy import text text("SELECT uuid FROM users WHERE username = :u"),
result = conn.execute(text("SELECT uuid FROM users WHERE username = :u"), {"u": DECNET_ADMIN_USER}) {"u": DECNET_ADMIN_USER},
)
if not result.fetchone(): if not result.fetchone():
print(f"DEBUG: Creating admin user '{DECNET_ADMIN_USER}' with password '{DECNET_ADMIN_PASSWORD}'")
conn.execute( conn.execute(
text("INSERT INTO users (uuid, username, password_hash, role, must_change_password) " text(
"VALUES (:uuid, :u, :p, :r, :m)"), "INSERT INTO users (uuid, username, password_hash, role, must_change_password) "
"VALUES (:uuid, :u, :p, :r, :m)"
),
{ {
"uuid": str(uuid.uuid4()), "uuid": str(uuid.uuid4()),
"u": DECNET_ADMIN_USER, "u": DECNET_ADMIN_USER,
"p": get_password_hash(DECNET_ADMIN_PASSWORD), "p": get_password_hash(DECNET_ADMIN_PASSWORD),
"r": "admin", "r": "admin",
"m": 1 "m": 1,
} },
) )
conn.commit() conn.commit()
else:
print(f"DEBUG: Admin user '{DECNET_ADMIN_USER}' already exists")
async def initialize(self) -> None: async def initialize(self) -> None:
"""Async warm-up / verification.""" """Async warm-up / verification."""
async with self.session_factory() as session: async with self.session_factory() as session:
await session.exec(text("SELECT 1")) await session.execute(text("SELECT 1"))
def reinitialize(self) -> None: async def reinitialize(self) -> None:
self._initialize_sync() """Initialize the database schema asynchronously (useful for tests)."""
from sqlmodel import SQLModel
async with self.engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
async with self.session_factory() as session:
result = await session.execute(
select(User).where(User.username == DECNET_ADMIN_USER)
)
if not result.scalar_one_or_none():
session.add(User(
uuid=str(uuid.uuid4()),
username=DECNET_ADMIN_USER,
password_hash=get_password_hash(DECNET_ADMIN_PASSWORD),
role="admin",
must_change_password=True,
))
await session.commit()
# ------------------------------------------------------------------ logs
async def add_log(self, log_data: dict[str, Any]) -> None: async def add_log(self, log_data: dict[str, Any]) -> None:
# Convert dict to model
data = log_data.copy() data = log_data.copy()
if "fields" in data and isinstance(data["fields"], dict): if "fields" in data and isinstance(data["fields"], dict):
data["fields"] = json.dumps(data["fields"]) data["fields"] = json.dumps(data["fields"])
if "timestamp" in data and isinstance(data["timestamp"], str): if "timestamp" in data and isinstance(data["timestamp"], str):
try: try:
data["timestamp"] = datetime.fromisoformat(data["timestamp"].replace('Z', '+00:00')) data["timestamp"] = datetime.fromisoformat(
data["timestamp"].replace("Z", "+00:00")
)
except ValueError: except ValueError:
pass pass
log = Log(**data)
async with self.session_factory() as session: async with self.session_factory() as session:
session.add(log) session.add(Log(**data))
await session.commit() await session.commit()
def _apply_filters(self, statement, search: Optional[str], start_time: Optional[str], end_time: Optional[str]): def _apply_filters(
import shlex self,
statement,
search: Optional[str],
start_time: Optional[str],
end_time: Optional[str],
):
import re import re
import shlex
if start_time: if start_time:
statement = statement.where(Log.timestamp >= start_time) statement = statement.where(Log.timestamp >= start_time)
@@ -94,7 +115,7 @@ class SQLiteRepository(BaseRepository):
try: try:
tokens = shlex.split(search) tokens = shlex.split(search)
except ValueError: except ValueError:
tokens = search.split(" ") tokens = search.split()
core_fields = { core_fields = {
"decky": Log.decky, "decky": Log.decky,
@@ -102,7 +123,7 @@ class SQLiteRepository(BaseRepository):
"event": Log.event_type, "event": Log.event_type,
"attacker": Log.attacker_ip, "attacker": Log.attacker_ip,
"attacker-ip": Log.attacker_ip, "attacker-ip": Log.attacker_ip,
"attacker_ip": Log.attacker_ip "attacker_ip": Log.attacker_ip,
} }
for token in tokens: for token in tokens:
@@ -111,9 +132,10 @@ class SQLiteRepository(BaseRepository):
if key in core_fields: if key in core_fields:
statement = statement.where(core_fields[key] == val) statement = statement.where(core_fields[key] == val)
else: else:
key_safe = re.sub(r'[^a-zA-Z0-9_]', '', key) key_safe = re.sub(r"[^a-zA-Z0-9_]", "", key)
# SQLite json_extract via text() statement = statement.where(
statement = statement.where(text(f"json_extract(fields, '$.{key_safe}') = :val")).params(val=val) text(f"json_extract(fields, '$.{key_safe}') = :val")
).params(val=val)
else: else:
lk = f"%{token}%" lk = f"%{token}%"
statement = statement.where( statement = statement.where(
@@ -121,25 +143,30 @@ class SQLiteRepository(BaseRepository):
Log.raw_line.like(lk), Log.raw_line.like(lk),
Log.decky.like(lk), Log.decky.like(lk),
Log.service.like(lk), Log.service.like(lk),
Log.attacker_ip.like(lk) Log.attacker_ip.like(lk),
) )
) )
return statement return statement
async def get_logs( async def get_logs(
self, self,
limit: int = 50, limit: int = 50,
offset: int = 0, offset: int = 0,
search: Optional[str] = None, search: Optional[str] = None,
start_time: Optional[str] = None, start_time: Optional[str] = None,
end_time: Optional[str] = None end_time: Optional[str] = None,
) -> List[dict]: ) -> List[dict]:
statement = select(Log).order_by(desc(Log.timestamp)).offset(offset).limit(limit) statement = (
select(Log)
.order_by(desc(Log.timestamp))
.offset(offset)
.limit(limit)
)
statement = self._apply_filters(statement, search, start_time, end_time) statement = self._apply_filters(statement, search, start_time, end_time)
async with self.session_factory() as session: async with self.session_factory() as session:
results = await session.execute(statement) results = await session.execute(statement)
return [log.dict() for log in results.scalars().all()] return [log.model_dump() for log in results.scalars().all()]
async def get_max_log_id(self) -> int: async def get_max_log_id(self) -> int:
async with self.session_factory() as session: async with self.session_factory() as session:
@@ -148,25 +175,27 @@ class SQLiteRepository(BaseRepository):
return val if val is not None else 0 return val if val is not None else 0
async def get_logs_after_id( async def get_logs_after_id(
self, self,
last_id: int, last_id: int,
limit: int = 50, limit: int = 50,
search: Optional[str] = None, search: Optional[str] = None,
start_time: Optional[str] = None, start_time: Optional[str] = None,
end_time: Optional[str] = None end_time: Optional[str] = None,
) -> List[dict]: ) -> List[dict]:
statement = select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit) statement = (
select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit)
)
statement = self._apply_filters(statement, search, start_time, end_time) statement = self._apply_filters(statement, search, start_time, end_time)
async with self.session_factory() as session: async with self.session_factory() as session:
results = await session.execute(statement) results = await session.execute(statement)
return [log.dict() for log in results.scalars().all()] return [log.model_dump() for log in results.scalars().all()]
async def get_total_logs( async def get_total_logs(
self, self,
search: Optional[str] = None, search: Optional[str] = None,
start_time: Optional[str] = None, start_time: Optional[str] = None,
end_time: Optional[str] = None end_time: Optional[str] = None,
) -> int: ) -> int:
statement = select(func.count()).select_from(Log) statement = select(func.count()).select_from(Log)
statement = self._apply_filters(statement, search, start_time, end_time) statement = self._apply_filters(statement, search, start_time, end_time)
@@ -180,130 +209,99 @@ class SQLiteRepository(BaseRepository):
search: Optional[str] = None, search: Optional[str] = None,
start_time: Optional[str] = None, start_time: Optional[str] = None,
end_time: Optional[str] = None, end_time: Optional[str] = None,
interval_minutes: int = 15 interval_minutes: int = 15,
) -> List[dict]: ) -> List[dict]:
# raw SQL for time bucketing as it is very engine specific bucket_seconds = interval_minutes * 60
_where_stmt = select(Log) bucket_expr = literal_column(
_where_stmt = self._apply_filters(_where_stmt, search, start_time, end_time) f"datetime((strftime('%s', timestamp) / {bucket_seconds}) * {bucket_seconds}, 'unixepoch')"
).label("bucket_time")
# Extract WHERE clause from compiled statement
# For simplicity in this migration, we'll use a semi-raw approach for the complex histogram query
# but bind parameters from the filtered statement
# SQLite specific bucket logic
bucket_expr = f"(strftime('%s', timestamp) / {interval_minutes * 60}) * {interval_minutes * 60}"
# We'll use session.execute with a text statement for the grouping but reuse the WHERE logic if possible
# Or just build the query fully
where_clause, params = self._build_where_clause_legacy(search, start_time, end_time)
query = f"""
SELECT
datetime({bucket_expr}, 'unixepoch') as bucket_time,
COUNT(*) as count
FROM logs
{where_clause}
GROUP BY bucket_time
ORDER BY bucket_time ASC
"""
async with self.session_factory() as session:
results = await session.execute(text(query), params)
return [{"time": r[0], "count": r[1]} for r in results.all()]
def _build_where_clause_legacy(self, search, start_time, end_time): statement = select(bucket_expr, func.count().label("count")).select_from(Log)
# Re-using the logic from the previous iteration for the raw query part statement = self._apply_filters(statement, search, start_time, end_time)
import shlex statement = statement.group_by(literal_column("bucket_time")).order_by(
import re literal_column("bucket_time")
where_clauses = [] )
params = {}
async with self.session_factory() as session:
if start_time: results = await session.execute(statement)
where_clauses.append("timestamp >= :start_time") return [{"time": r[0], "count": r[1]} for r in results.all()]
params["start_time"] = start_time
if end_time:
where_clauses.append("timestamp <= :end_time")
params["end_time"] = end_time
if search:
try: tokens = shlex.split(search)
except: tokens = search.split(" ")
core_fields = {"decky": "decky", "service": "service", "event": "event_type", "attacker": "attacker_ip"}
for i, token in enumerate(tokens):
if ":" in token:
k, v = token.split(":", 1)
if k in core_fields:
where_clauses.append(f"{core_fields[k]} = :val_{i}")
params[f"val_{i}"] = v
else:
ks = re.sub(r'[^a-zA-Z0-9_]', '', k)
where_clauses.append(f"json_extract(fields, '$.{ks}') = :val_{i}")
params[f"val_{i}"] = v
else:
where_clauses.append(f"(raw_line LIKE :lk_{i} OR decky LIKE :lk_{i} OR service LIKE :lk_{i} OR attacker_ip LIKE :lk_{i})")
params[f"lk_{i}"] = f"%{token}%"
where = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
return where, params
async def get_stats_summary(self) -> dict[str, Any]: async def get_stats_summary(self) -> dict[str, Any]:
async with self.session_factory() as session: async with self.session_factory() as session:
total_logs = (await session.execute(select(func.count()).select_from(Log))).scalar() or 0 total_logs = (
unique_attackers = (await session.execute(select(func.count(func.distinct(Log.attacker_ip))))).scalar() or 0 await session.execute(select(func.count()).select_from(Log))
active_deckies = (await session.execute(select(func.count(func.distinct(Log.decky))))).scalar() or 0 ).scalar() or 0
unique_attackers = (
_state = load_state() await session.execute(
deployed_deckies = len(_state[0].deckies) if _state else 0 select(func.count(func.distinct(Log.attacker_ip)))
)
).scalar() or 0
active_deckies = (
await session.execute(
select(func.count(func.distinct(Log.decky)))
)
).scalar() or 0
return { _state = await asyncio.to_thread(load_state)
"total_logs": total_logs, deployed_deckies = len(_state[0].deckies) if _state else 0
"unique_attackers": unique_attackers,
"active_deckies": active_deckies, return {
"deployed_deckies": deployed_deckies "total_logs": total_logs,
} "unique_attackers": unique_attackers,
"active_deckies": active_deckies,
"deployed_deckies": deployed_deckies,
}
async def get_deckies(self) -> List[dict]: async def get_deckies(self) -> List[dict]:
_state = load_state() _state = await asyncio.to_thread(load_state)
return [_d.model_dump() for _d in _state[0].deckies] if _state else [] return [_d.model_dump() for _d in _state[0].deckies] if _state else []
# ------------------------------------------------------------------ users
async def get_user_by_username(self, username: str) -> Optional[dict]: async def get_user_by_username(self, username: str) -> Optional[dict]:
async with self.session_factory() as session: async with self.session_factory() as session:
statement = select(User).where(User.username == username) result = await session.execute(
results = await session.execute(statement) select(User).where(User.username == username)
user = results.scalar_one_or_none() )
return user.dict() if user else None user = result.scalar_one_or_none()
return user.model_dump() if user else None
async def get_user_by_uuid(self, uuid: str) -> Optional[dict]: async def get_user_by_uuid(self, uuid: str) -> Optional[dict]:
async with self.session_factory() as session: async with self.session_factory() as session:
statement = select(User).where(User.uuid == uuid) result = await session.execute(
results = await session.execute(statement) select(User).where(User.uuid == uuid)
user = results.scalar_one_or_none() )
return user.dict() if user else None user = result.scalar_one_or_none()
return user.model_dump() if user else None
async def create_user(self, user_data: dict[str, Any]) -> None: async def create_user(self, user_data: dict[str, Any]) -> None:
user = User(**user_data)
async with self.session_factory() as session: async with self.session_factory() as session:
session.add(user) session.add(User(**user_data))
await session.commit() await session.commit()
async def update_user_password(self, uuid: str, password_hash: str, must_change_password: bool = False) -> None: async def update_user_password(
self, uuid: str, password_hash: str, must_change_password: bool = False
) -> None:
async with self.session_factory() as session: async with self.session_factory() as session:
statement = update(User).where(User.uuid == uuid).values( await session.execute(
password_hash=password_hash, update(User)
must_change_password=must_change_password .where(User.uuid == uuid)
.values(
password_hash=password_hash,
must_change_password=must_change_password,
)
) )
await session.execute(statement)
await session.commit() await session.commit()
# ---------------------------------------------------------------- bounties
async def add_bounty(self, bounty_data: dict[str, Any]) -> None: async def add_bounty(self, bounty_data: dict[str, Any]) -> None:
data = bounty_data.copy() data = bounty_data.copy()
if "payload" in data and isinstance(data["payload"], dict): if "payload" in data and isinstance(data["payload"], dict):
data["payload"] = json.dumps(data["payload"]) data["payload"] = json.dumps(data["payload"])
bounty = Bounty(**data)
async with self.session_factory() as session: async with self.session_factory() as session:
session.add(bounty) session.add(Bounty(**data))
await session.commit() await session.commit()
def _apply_bounty_filters(self, statement, bounty_type: Optional[str], search: Optional[str]): def _apply_bounty_filters(self, statement, bounty_type: Optional[str], search: Optional[str]):
@@ -316,33 +314,41 @@ class SQLiteRepository(BaseRepository):
Bounty.decky.like(lk), Bounty.decky.like(lk),
Bounty.service.like(lk), Bounty.service.like(lk),
Bounty.attacker_ip.like(lk), Bounty.attacker_ip.like(lk),
Bounty.payload.like(lk) Bounty.payload.like(lk),
) )
) )
return statement return statement
async def get_bounties( async def get_bounties(
self, self,
limit: int = 50, limit: int = 50,
offset: int = 0, offset: int = 0,
bounty_type: Optional[str] = None, bounty_type: Optional[str] = None,
search: Optional[str] = None search: Optional[str] = None,
) -> List[dict]: ) -> List[dict]:
statement = select(Bounty).order_by(desc(Bounty.timestamp)).offset(offset).limit(limit) statement = (
select(Bounty)
.order_by(desc(Bounty.timestamp))
.offset(offset)
.limit(limit)
)
statement = self._apply_bounty_filters(statement, bounty_type, search) statement = self._apply_bounty_filters(statement, bounty_type, search)
async with self.session_factory() as session: async with self.session_factory() as session:
results = await session.execute(statement) results = await session.execute(statement)
items = results.scalars().all()
final = [] final = []
for item in items: for item in results.scalars().all():
d = item.dict() d = item.model_dump()
try: d["payload"] = json.loads(d["payload"]) try:
except: pass d["payload"] = json.loads(d["payload"])
except (json.JSONDecodeError, TypeError):
pass
final.append(d) final.append(d)
return final return final
async def get_total_bounties(self, bounty_type: Optional[str] = None, search: Optional[str] = None) -> int: async def get_total_bounties(
self, bounty_type: Optional[str] = None, search: Optional[str] = None
) -> int:
statement = select(func.count()).select_from(Bounty) statement = select(func.count()).select_from(Bounty)
statement = self._apply_bounty_filters(statement, bounty_type, search) statement = self._apply_bounty_filters(statement, bounty_type, search)