fix: clean up db layer — model_dump, timezone-aware timestamps, unified histogram, async load_state

This commit is contained in:
2026-04-09 18:46:35 -04:00
parent dbf6d13b95
commit 0166d0d559
2 changed files with 162 additions and 156 deletions

View File

@@ -1,6 +1,6 @@
from datetime import datetime
from datetime import datetime, timezone
from typing import Optional, Any, List
from sqlmodel import SQLModel, Field, Column, JSON
from sqlmodel import SQLModel, Field
from pydantic import BaseModel, Field as PydanticField
# --- Database Tables (SQLModel) ---
@@ -16,7 +16,7 @@ class User(SQLModel, table=True):
class Log(SQLModel, table=True):
__tablename__ = "logs"
id: Optional[int] = Field(default=None, primary_key=True)
timestamp: datetime = Field(default_factory=datetime.utcnow, index=True)
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), index=True)
decky: str = Field(index=True)
service: str = Field(index=True)
event_type: str = Field(index=True)
@@ -28,7 +28,7 @@ class Log(SQLModel, table=True):
class Bounty(SQLModel, table=True):
__tablename__ = "bounty"
id: Optional[int] = Field(default=None, primary_key=True)
timestamp: datetime = Field(default_factory=datetime.utcnow, index=True)
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), index=True)
decky: str = Field(index=True)
service: str = Field(index=True)
attacker_ip: str = Field(index=True)

View File

@@ -4,9 +4,8 @@ import uuid
from datetime import datetime
from typing import Any, Optional, List
from sqlalchemy import func, select, desc, asc, text, or_, update
from sqlalchemy import func, select, desc, asc, text, or_, update, literal_column
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlmodel import col
from decnet.config import load_state, _ROOT
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
@@ -30,60 +29,82 @@ class SQLiteRepository(BaseRepository):
def _initialize_sync(self) -> None:
"""Initialize the database schema synchronously."""
init_db(self.db_path)
# Ensure default admin exists via sync SQLAlchemy connection
from sqlalchemy import select
from decnet.web.db.sqlite.database import get_sync_engine
engine = get_sync_engine(self.db_path)
with engine.connect() as conn:
# Ensure admin exists via sync SQLAlchemy connection
from sqlalchemy import text
result = conn.execute(text("SELECT uuid FROM users WHERE username = :u"), {"u": DECNET_ADMIN_USER})
result = conn.execute(
text("SELECT uuid FROM users WHERE username = :u"),
{"u": DECNET_ADMIN_USER},
)
if not result.fetchone():
print(f"DEBUG: Creating admin user '{DECNET_ADMIN_USER}' with password '{DECNET_ADMIN_PASSWORD}'")
conn.execute(
text("INSERT INTO users (uuid, username, password_hash, role, must_change_password) "
"VALUES (:uuid, :u, :p, :r, :m)"),
text(
"INSERT INTO users (uuid, username, password_hash, role, must_change_password) "
"VALUES (:uuid, :u, :p, :r, :m)"
),
{
"uuid": str(uuid.uuid4()),
"u": DECNET_ADMIN_USER,
"p": get_password_hash(DECNET_ADMIN_PASSWORD),
"r": "admin",
"m": 1
}
"m": 1,
},
)
conn.commit()
else:
print(f"DEBUG: Admin user '{DECNET_ADMIN_USER}' already exists")
async def initialize(self) -> None:
"""Async warm-up / verification."""
async with self.session_factory() as session:
await session.exec(text("SELECT 1"))
await session.execute(text("SELECT 1"))
def reinitialize(self) -> None:
self._initialize_sync()
async def reinitialize(self) -> None:
"""Initialize the database schema asynchronously (useful for tests)."""
from sqlmodel import SQLModel
async with self.engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
async with self.session_factory() as session:
result = await session.execute(
select(User).where(User.username == DECNET_ADMIN_USER)
)
if not result.scalar_one_or_none():
session.add(User(
uuid=str(uuid.uuid4()),
username=DECNET_ADMIN_USER,
password_hash=get_password_hash(DECNET_ADMIN_PASSWORD),
role="admin",
must_change_password=True,
))
await session.commit()
# ------------------------------------------------------------------ logs
async def add_log(self, log_data: dict[str, Any]) -> None:
# Convert dict to model
data = log_data.copy()
if "fields" in data and isinstance(data["fields"], dict):
data["fields"] = json.dumps(data["fields"])
if "timestamp" in data and isinstance(data["timestamp"], str):
try:
data["timestamp"] = datetime.fromisoformat(data["timestamp"].replace('Z', '+00:00'))
data["timestamp"] = datetime.fromisoformat(
data["timestamp"].replace("Z", "+00:00")
)
except ValueError:
pass
log = Log(**data)
async with self.session_factory() as session:
session.add(log)
session.add(Log(**data))
await session.commit()
def _apply_filters(self, statement, search: Optional[str], start_time: Optional[str], end_time: Optional[str]):
import shlex
def _apply_filters(
self,
statement,
search: Optional[str],
start_time: Optional[str],
end_time: Optional[str],
):
import re
import shlex
if start_time:
statement = statement.where(Log.timestamp >= start_time)
@@ -94,7 +115,7 @@ class SQLiteRepository(BaseRepository):
try:
tokens = shlex.split(search)
except ValueError:
tokens = search.split(" ")
tokens = search.split()
core_fields = {
"decky": Log.decky,
@@ -102,7 +123,7 @@ class SQLiteRepository(BaseRepository):
"event": Log.event_type,
"attacker": Log.attacker_ip,
"attacker-ip": Log.attacker_ip,
"attacker_ip": Log.attacker_ip
"attacker_ip": Log.attacker_ip,
}
for token in tokens:
@@ -111,9 +132,10 @@ class SQLiteRepository(BaseRepository):
if key in core_fields:
statement = statement.where(core_fields[key] == val)
else:
key_safe = re.sub(r'[^a-zA-Z0-9_]', '', key)
# SQLite json_extract via text()
statement = statement.where(text(f"json_extract(fields, '$.{key_safe}') = :val")).params(val=val)
key_safe = re.sub(r"[^a-zA-Z0-9_]", "", key)
statement = statement.where(
text(f"json_extract(fields, '$.{key_safe}') = :val")
).params(val=val)
else:
lk = f"%{token}%"
statement = statement.where(
@@ -121,25 +143,30 @@ class SQLiteRepository(BaseRepository):
Log.raw_line.like(lk),
Log.decky.like(lk),
Log.service.like(lk),
Log.attacker_ip.like(lk)
Log.attacker_ip.like(lk),
)
)
return statement
async def get_logs(
self,
limit: int = 50,
offset: int = 0,
self,
limit: int = 50,
offset: int = 0,
search: Optional[str] = None,
start_time: Optional[str] = None,
end_time: Optional[str] = None
end_time: Optional[str] = None,
) -> List[dict]:
statement = select(Log).order_by(desc(Log.timestamp)).offset(offset).limit(limit)
statement = (
select(Log)
.order_by(desc(Log.timestamp))
.offset(offset)
.limit(limit)
)
statement = self._apply_filters(statement, search, start_time, end_time)
async with self.session_factory() as session:
results = await session.execute(statement)
return [log.dict() for log in results.scalars().all()]
return [log.model_dump() for log in results.scalars().all()]
async def get_max_log_id(self) -> int:
async with self.session_factory() as session:
@@ -148,25 +175,27 @@ class SQLiteRepository(BaseRepository):
return val if val is not None else 0
async def get_logs_after_id(
self,
last_id: int,
limit: int = 50,
self,
last_id: int,
limit: int = 50,
search: Optional[str] = None,
start_time: Optional[str] = None,
end_time: Optional[str] = None
end_time: Optional[str] = None,
) -> List[dict]:
statement = select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit)
statement = (
select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit)
)
statement = self._apply_filters(statement, search, start_time, end_time)
async with self.session_factory() as session:
results = await session.execute(statement)
return [log.dict() for log in results.scalars().all()]
return [log.model_dump() for log in results.scalars().all()]
async def get_total_logs(
self,
self,
search: Optional[str] = None,
start_time: Optional[str] = None,
end_time: Optional[str] = None
end_time: Optional[str] = None,
) -> int:
statement = select(func.count()).select_from(Log)
statement = self._apply_filters(statement, search, start_time, end_time)
@@ -180,130 +209,99 @@ class SQLiteRepository(BaseRepository):
search: Optional[str] = None,
start_time: Optional[str] = None,
end_time: Optional[str] = None,
interval_minutes: int = 15
interval_minutes: int = 15,
) -> List[dict]:
# raw SQL for time bucketing as it is very engine specific
_where_stmt = select(Log)
_where_stmt = self._apply_filters(_where_stmt, search, start_time, end_time)
# Extract WHERE clause from compiled statement
# For simplicity in this migration, we'll use a semi-raw approach for the complex histogram query
# but bind parameters from the filtered statement
# SQLite specific bucket logic
bucket_expr = f"(strftime('%s', timestamp) / {interval_minutes * 60}) * {interval_minutes * 60}"
# We'll use session.execute with a text statement for the grouping but reuse the WHERE logic if possible
# Or just build the query fully
where_clause, params = self._build_where_clause_legacy(search, start_time, end_time)
query = f"""
SELECT
datetime({bucket_expr}, 'unixepoch') as bucket_time,
COUNT(*) as count
FROM logs
{where_clause}
GROUP BY bucket_time
ORDER BY bucket_time ASC
"""
async with self.session_factory() as session:
results = await session.execute(text(query), params)
return [{"time": r[0], "count": r[1]} for r in results.all()]
bucket_seconds = interval_minutes * 60
bucket_expr = literal_column(
f"datetime((strftime('%s', timestamp) / {bucket_seconds}) * {bucket_seconds}, 'unixepoch')"
).label("bucket_time")
def _build_where_clause_legacy(self, search, start_time, end_time):
# Re-using the logic from the previous iteration for the raw query part
import shlex
import re
where_clauses = []
params = {}
if start_time:
where_clauses.append("timestamp >= :start_time")
params["start_time"] = start_time
if end_time:
where_clauses.append("timestamp <= :end_time")
params["end_time"] = end_time
if search:
try: tokens = shlex.split(search)
except: tokens = search.split(" ")
core_fields = {"decky": "decky", "service": "service", "event": "event_type", "attacker": "attacker_ip"}
for i, token in enumerate(tokens):
if ":" in token:
k, v = token.split(":", 1)
if k in core_fields:
where_clauses.append(f"{core_fields[k]} = :val_{i}")
params[f"val_{i}"] = v
else:
ks = re.sub(r'[^a-zA-Z0-9_]', '', k)
where_clauses.append(f"json_extract(fields, '$.{ks}') = :val_{i}")
params[f"val_{i}"] = v
else:
where_clauses.append(f"(raw_line LIKE :lk_{i} OR decky LIKE :lk_{i} OR service LIKE :lk_{i} OR attacker_ip LIKE :lk_{i})")
params[f"lk_{i}"] = f"%{token}%"
where = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
return where, params
statement = select(bucket_expr, func.count().label("count")).select_from(Log)
statement = self._apply_filters(statement, search, start_time, end_time)
statement = statement.group_by(literal_column("bucket_time")).order_by(
literal_column("bucket_time")
)
async with self.session_factory() as session:
results = await session.execute(statement)
return [{"time": r[0], "count": r[1]} for r in results.all()]
async def get_stats_summary(self) -> dict[str, Any]:
async with self.session_factory() as session:
total_logs = (await session.execute(select(func.count()).select_from(Log))).scalar() or 0
unique_attackers = (await session.execute(select(func.count(func.distinct(Log.attacker_ip))))).scalar() or 0
active_deckies = (await session.execute(select(func.count(func.distinct(Log.decky))))).scalar() or 0
_state = load_state()
deployed_deckies = len(_state[0].deckies) if _state else 0
total_logs = (
await session.execute(select(func.count()).select_from(Log))
).scalar() or 0
unique_attackers = (
await session.execute(
select(func.count(func.distinct(Log.attacker_ip)))
)
).scalar() or 0
active_deckies = (
await session.execute(
select(func.count(func.distinct(Log.decky)))
)
).scalar() or 0
return {
"total_logs": total_logs,
"unique_attackers": unique_attackers,
"active_deckies": active_deckies,
"deployed_deckies": deployed_deckies
}
_state = await asyncio.to_thread(load_state)
deployed_deckies = len(_state[0].deckies) if _state else 0
return {
"total_logs": total_logs,
"unique_attackers": unique_attackers,
"active_deckies": active_deckies,
"deployed_deckies": deployed_deckies,
}
async def get_deckies(self) -> List[dict]:
_state = load_state()
_state = await asyncio.to_thread(load_state)
return [_d.model_dump() for _d in _state[0].deckies] if _state else []
# ------------------------------------------------------------------ users
async def get_user_by_username(self, username: str) -> Optional[dict]:
async with self.session_factory() as session:
statement = select(User).where(User.username == username)
results = await session.execute(statement)
user = results.scalar_one_or_none()
return user.dict() if user else None
result = await session.execute(
select(User).where(User.username == username)
)
user = result.scalar_one_or_none()
return user.model_dump() if user else None
async def get_user_by_uuid(self, uuid: str) -> Optional[dict]:
async with self.session_factory() as session:
statement = select(User).where(User.uuid == uuid)
results = await session.execute(statement)
user = results.scalar_one_or_none()
return user.dict() if user else None
result = await session.execute(
select(User).where(User.uuid == uuid)
)
user = result.scalar_one_or_none()
return user.model_dump() if user else None
async def create_user(self, user_data: dict[str, Any]) -> None:
user = User(**user_data)
async with self.session_factory() as session:
session.add(user)
session.add(User(**user_data))
await session.commit()
async def update_user_password(self, uuid: str, password_hash: str, must_change_password: bool = False) -> None:
async def update_user_password(
self, uuid: str, password_hash: str, must_change_password: bool = False
) -> None:
async with self.session_factory() as session:
statement = update(User).where(User.uuid == uuid).values(
password_hash=password_hash,
must_change_password=must_change_password
await session.execute(
update(User)
.where(User.uuid == uuid)
.values(
password_hash=password_hash,
must_change_password=must_change_password,
)
)
await session.execute(statement)
await session.commit()
# ---------------------------------------------------------------- bounties
async def add_bounty(self, bounty_data: dict[str, Any]) -> None:
data = bounty_data.copy()
if "payload" in data and isinstance(data["payload"], dict):
data["payload"] = json.dumps(data["payload"])
bounty = Bounty(**data)
async with self.session_factory() as session:
session.add(bounty)
session.add(Bounty(**data))
await session.commit()
def _apply_bounty_filters(self, statement, bounty_type: Optional[str], search: Optional[str]):
@@ -316,33 +314,41 @@ class SQLiteRepository(BaseRepository):
Bounty.decky.like(lk),
Bounty.service.like(lk),
Bounty.attacker_ip.like(lk),
Bounty.payload.like(lk)
Bounty.payload.like(lk),
)
)
return statement
async def get_bounties(
self,
limit: int = 50,
offset: int = 0,
self,
limit: int = 50,
offset: int = 0,
bounty_type: Optional[str] = None,
search: Optional[str] = None
search: Optional[str] = None,
) -> List[dict]:
statement = select(Bounty).order_by(desc(Bounty.timestamp)).offset(offset).limit(limit)
statement = (
select(Bounty)
.order_by(desc(Bounty.timestamp))
.offset(offset)
.limit(limit)
)
statement = self._apply_bounty_filters(statement, bounty_type, search)
async with self.session_factory() as session:
results = await session.execute(statement)
items = results.scalars().all()
final = []
for item in items:
d = item.dict()
try: d["payload"] = json.loads(d["payload"])
except: pass
for item in results.scalars().all():
d = item.model_dump()
try:
d["payload"] = json.loads(d["payload"])
except (json.JSONDecodeError, TypeError):
pass
final.append(d)
return final
async def get_total_bounties(self, bounty_type: Optional[str] = None, search: Optional[str] = None) -> int:
async def get_total_bounties(
self, bounty_type: Optional[str] = None, search: Optional[str] = None
) -> int:
statement = select(func.count()).select_from(Bounty)
statement = self._apply_bounty_filters(statement, bounty_type, search)