refactor: migrate database to SQLModel and implement modular DB structure
This commit is contained in:
75
decnet/web/db/models.py
Normal file
75
decnet/web/db/models.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any, List
|
||||
from sqlmodel import SQLModel, Field, Column, JSON
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
|
||||
# --- Database Tables (SQLModel) ---
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
__tablename__ = "users"
|
||||
uuid: str = Field(primary_key=True)
|
||||
username: str = Field(index=True, unique=True)
|
||||
password_hash: str
|
||||
role: str = Field(default="viewer")
|
||||
must_change_password: bool = Field(default=False)
|
||||
|
||||
class Log(SQLModel, table=True):
|
||||
__tablename__ = "logs"
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow, index=True)
|
||||
decky: str = Field(index=True)
|
||||
service: str = Field(index=True)
|
||||
event_type: str = Field(index=True)
|
||||
attacker_ip: str = Field(index=True)
|
||||
raw_line: str
|
||||
fields: str
|
||||
msg: Optional[str] = None
|
||||
|
||||
class Bounty(SQLModel, table=True):
|
||||
__tablename__ = "bounty"
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow, index=True)
|
||||
decky: str = Field(index=True)
|
||||
service: str = Field(index=True)
|
||||
attacker_ip: str = Field(index=True)
|
||||
bounty_type: str = Field(index=True)
|
||||
payload: str
|
||||
|
||||
# --- API Request/Response Models (Pydantic) ---
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
must_change_password: bool = False
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str = PydanticField(..., max_length=72)
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
old_password: str = PydanticField(..., max_length=72)
|
||||
new_password: str = PydanticField(..., max_length=72)
|
||||
|
||||
class LogsResponse(BaseModel):
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
data: List[dict[str, Any]]
|
||||
|
||||
class BountyResponse(BaseModel):
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
data: List[dict[str, Any]]
|
||||
|
||||
class StatsResponse(BaseModel):
|
||||
total_logs: int
|
||||
unique_attackers: int
|
||||
active_deckies: int
|
||||
deployed_deckies: int
|
||||
|
||||
class MutateIntervalRequest(BaseModel):
|
||||
mutate_interval: Optional[int] = None
|
||||
|
||||
class DeployIniRequest(BaseModel):
|
||||
ini_content: str = PydanticField(..., min_length=5, max_length=512 * 1024)
|
||||
82
decnet/web/db/repository.py
Normal file
82
decnet/web/db/repository.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class BaseRepository(ABC):
|
||||
"""Abstract base class for DECNET web dashboard data storage."""
|
||||
|
||||
@abstractmethod
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the database schema."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def add_log(self, log_data: dict[str, Any]) -> None:
|
||||
"""Add a new log entry to the database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_logs(
|
||||
self,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
search: Optional[str] = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Retrieve paginated log entries."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_total_logs(self, search: Optional[str] = None) -> int:
|
||||
"""Retrieve the total count of logs, optionally filtered by search."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_stats_summary(self) -> dict[str, Any]:
|
||||
"""Retrieve high-level dashboard metrics."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_deckies(self) -> list[dict[str, Any]]:
|
||||
"""Retrieve the list of currently deployed deckies."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_username(self, username: str) -> Optional[dict[str, Any]]:
|
||||
"""Retrieve a user by their username."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
|
||||
"""Retrieve a user by their UUID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_user(self, user_data: dict[str, Any]) -> None:
|
||||
"""Create a new dashboard user."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def update_user_password(self, uuid: str, password_hash: str, must_change_password: bool = False) -> None:
|
||||
"""Update a user's password and change the must_change_password flag."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def add_bounty(self, bounty_data: dict[str, Any]) -> None:
|
||||
"""Add a new harvested artifact (bounty) to the database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_bounties(
|
||||
self,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
bounty_type: Optional[str] = None,
|
||||
search: Optional[str] = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Retrieve paginated bounty entries."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_total_bounties(self, bounty_type: Optional[str] = None, search: Optional[str] = None) -> int:
|
||||
"""Retrieve the total count of bounties, optionally filtered."""
|
||||
pass
|
||||
30
decnet/web/db/sqlite/database.py
Normal file
30
decnet/web/db/sqlite/database.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy import create_engine
|
||||
from sqlmodel import SQLModel
|
||||
from pathlib import Path
|
||||
|
||||
# We need both sync and async engines for SQLite
|
||||
# Sync for initialization (DDL) and async for standard queries
|
||||
|
||||
def get_async_engine(db_path: str):
|
||||
# aiosqlite driver for async access
|
||||
return create_async_engine(f"sqlite+aiosqlite:///{db_path}", echo=False, connect_args={"uri": True})
|
||||
|
||||
def get_sync_engine(db_path: str):
|
||||
return create_engine(f"sqlite:///{db_path}", echo=False, connect_args={"uri": True})
|
||||
|
||||
def init_db(db_path: str):
|
||||
"""Synchronously create all tables."""
|
||||
engine = get_sync_engine(db_path)
|
||||
# Ensure WAL mode is set
|
||||
with engine.connect() as conn:
|
||||
conn.exec_driver_sql("PRAGMA journal_mode=WAL")
|
||||
conn.exec_driver_sql("PRAGMA synchronous=NORMAL")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
async def get_session(engine) -> AsyncSession:
|
||||
async_session = async_sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
351
decnet/web/db/sqlite/repository.py
Normal file
351
decnet/web/db/sqlite/repository.py
Normal file
@@ -0,0 +1,351 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, List
|
||||
|
||||
from sqlalchemy import func, select, desc, asc, text, or_, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlmodel import col
|
||||
|
||||
from decnet.config import load_state, _ROOT
|
||||
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
||||
from decnet.web.auth import get_password_hash
|
||||
from decnet.web.db.repository import BaseRepository
|
||||
from decnet.web.db.models import User, Log, Bounty
|
||||
from decnet.web.db.sqlite.database import get_async_engine, init_db
|
||||
|
||||
|
||||
class SQLiteRepository(BaseRepository):
|
||||
"""SQLite implementation using SQLModel and SQLAlchemy Async."""
|
||||
|
||||
def __init__(self, db_path: str = str(_ROOT / "decnet.db")) -> None:
|
||||
self.db_path = db_path
|
||||
self.engine = get_async_engine(db_path)
|
||||
self.session_factory = async_sessionmaker(
|
||||
self.engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
self._initialize_sync()
|
||||
|
||||
def _initialize_sync(self) -> None:
|
||||
"""Initialize the database schema synchronously."""
|
||||
init_db(self.db_path)
|
||||
|
||||
# Ensure default admin exists via sync SQLAlchemy connection
|
||||
from sqlalchemy import select
|
||||
from decnet.web.db.sqlite.database import get_sync_engine
|
||||
engine = get_sync_engine(self.db_path)
|
||||
with engine.connect() as conn:
|
||||
# Ensure admin exists via sync SQLAlchemy connection
|
||||
from sqlalchemy import text
|
||||
result = conn.execute(text("SELECT uuid FROM users WHERE username = :u"), {"u": DECNET_ADMIN_USER})
|
||||
if not result.fetchone():
|
||||
print(f"DEBUG: Creating admin user '{DECNET_ADMIN_USER}' with password '{DECNET_ADMIN_PASSWORD}'")
|
||||
conn.execute(
|
||||
text("INSERT INTO users (uuid, username, password_hash, role, must_change_password) "
|
||||
"VALUES (:uuid, :u, :p, :r, :m)"),
|
||||
{
|
||||
"uuid": str(uuid.uuid4()),
|
||||
"u": DECNET_ADMIN_USER,
|
||||
"p": get_password_hash(DECNET_ADMIN_PASSWORD),
|
||||
"r": "admin",
|
||||
"m": 1
|
||||
}
|
||||
)
|
||||
conn.commit()
|
||||
else:
|
||||
print(f"DEBUG: Admin user '{DECNET_ADMIN_USER}' already exists")
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Async warm-up / verification."""
|
||||
async with self.session_factory() as session:
|
||||
await session.exec(text("SELECT 1"))
|
||||
|
||||
def reinitialize(self) -> None:
|
||||
self._initialize_sync()
|
||||
|
||||
async def add_log(self, log_data: dict[str, Any]) -> None:
|
||||
# Convert dict to model
|
||||
data = log_data.copy()
|
||||
if "fields" in data and isinstance(data["fields"], dict):
|
||||
data["fields"] = json.dumps(data["fields"])
|
||||
|
||||
if "timestamp" in data and isinstance(data["timestamp"], str):
|
||||
try:
|
||||
data["timestamp"] = datetime.fromisoformat(data["timestamp"].replace('Z', '+00:00'))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
log = Log(**data)
|
||||
async with self.session_factory() as session:
|
||||
session.add(log)
|
||||
await session.commit()
|
||||
|
||||
def _apply_filters(self, statement, search: Optional[str], start_time: Optional[str], end_time: Optional[str]):
|
||||
import shlex
|
||||
import re
|
||||
|
||||
if start_time:
|
||||
statement = statement.where(Log.timestamp >= start_time)
|
||||
if end_time:
|
||||
statement = statement.where(Log.timestamp <= end_time)
|
||||
|
||||
if search:
|
||||
try:
|
||||
tokens = shlex.split(search)
|
||||
except ValueError:
|
||||
tokens = search.split(" ")
|
||||
|
||||
core_fields = {
|
||||
"decky": Log.decky,
|
||||
"service": Log.service,
|
||||
"event": Log.event_type,
|
||||
"attacker": Log.attacker_ip,
|
||||
"attacker-ip": Log.attacker_ip,
|
||||
"attacker_ip": Log.attacker_ip
|
||||
}
|
||||
|
||||
for token in tokens:
|
||||
if ":" in token:
|
||||
key, val = token.split(":", 1)
|
||||
if key in core_fields:
|
||||
statement = statement.where(core_fields[key] == val)
|
||||
else:
|
||||
key_safe = re.sub(r'[^a-zA-Z0-9_]', '', key)
|
||||
# SQLite json_extract via text()
|
||||
statement = statement.where(text(f"json_extract(fields, '$.{key_safe}') = :val")).params(val=val)
|
||||
else:
|
||||
lk = f"%{token}%"
|
||||
statement = statement.where(
|
||||
or_(
|
||||
Log.raw_line.like(lk),
|
||||
Log.decky.like(lk),
|
||||
Log.service.like(lk),
|
||||
Log.attacker_ip.like(lk)
|
||||
)
|
||||
)
|
||||
return statement
|
||||
|
||||
async def get_logs(
|
||||
self,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
search: Optional[str] = None,
|
||||
start_time: Optional[str] = None,
|
||||
end_time: Optional[str] = None
|
||||
) -> List[dict]:
|
||||
statement = select(Log).order_by(desc(Log.timestamp)).offset(offset).limit(limit)
|
||||
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||
|
||||
async with self.session_factory() as session:
|
||||
results = await session.execute(statement)
|
||||
return [log.dict() for log in results.scalars().all()]
|
||||
|
||||
async def get_max_log_id(self) -> int:
|
||||
async with self.session_factory() as session:
|
||||
result = await session.execute(select(func.max(Log.id)))
|
||||
val = result.scalar()
|
||||
return val if val is not None else 0
|
||||
|
||||
async def get_logs_after_id(
|
||||
self,
|
||||
last_id: int,
|
||||
limit: int = 50,
|
||||
search: Optional[str] = None,
|
||||
start_time: Optional[str] = None,
|
||||
end_time: Optional[str] = None
|
||||
) -> List[dict]:
|
||||
statement = select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit)
|
||||
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||
|
||||
async with self.session_factory() as session:
|
||||
results = await session.execute(statement)
|
||||
return [log.dict() for log in results.scalars().all()]
|
||||
|
||||
async def get_total_logs(
|
||||
self,
|
||||
search: Optional[str] = None,
|
||||
start_time: Optional[str] = None,
|
||||
end_time: Optional[str] = None
|
||||
) -> int:
|
||||
statement = select(func.count()).select_from(Log)
|
||||
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||
|
||||
async with self.session_factory() as session:
|
||||
result = await session.execute(statement)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_log_histogram(
|
||||
self,
|
||||
search: Optional[str] = None,
|
||||
start_time: Optional[str] = None,
|
||||
end_time: Optional[str] = None,
|
||||
interval_minutes: int = 15
|
||||
) -> List[dict]:
|
||||
# raw SQL for time bucketing as it is very engine specific
|
||||
_where_stmt = select(Log)
|
||||
_where_stmt = self._apply_filters(_where_stmt, search, start_time, end_time)
|
||||
|
||||
# Extract WHERE clause from compiled statement
|
||||
# For simplicity in this migration, we'll use a semi-raw approach for the complex histogram query
|
||||
# but bind parameters from the filtered statement
|
||||
|
||||
# SQLite specific bucket logic
|
||||
bucket_expr = f"(strftime('%s', timestamp) / {interval_minutes * 60}) * {interval_minutes * 60}"
|
||||
|
||||
# We'll use session.execute with a text statement for the grouping but reuse the WHERE logic if possible
|
||||
# Or just build the query fully
|
||||
|
||||
where_clause, params = self._build_where_clause_legacy(search, start_time, end_time)
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
datetime({bucket_expr}, 'unixepoch') as bucket_time,
|
||||
COUNT(*) as count
|
||||
FROM logs
|
||||
{where_clause}
|
||||
GROUP BY bucket_time
|
||||
ORDER BY bucket_time ASC
|
||||
"""
|
||||
|
||||
async with self.session_factory() as session:
|
||||
results = await session.execute(text(query), params)
|
||||
return [{"time": r[0], "count": r[1]} for r in results.all()]
|
||||
|
||||
def _build_where_clause_legacy(self, search, start_time, end_time):
|
||||
# Re-using the logic from the previous iteration for the raw query part
|
||||
import shlex
|
||||
import re
|
||||
where_clauses = []
|
||||
params = {}
|
||||
|
||||
if start_time:
|
||||
where_clauses.append("timestamp >= :start_time")
|
||||
params["start_time"] = start_time
|
||||
if end_time:
|
||||
where_clauses.append("timestamp <= :end_time")
|
||||
params["end_time"] = end_time
|
||||
|
||||
if search:
|
||||
try: tokens = shlex.split(search)
|
||||
except: tokens = search.split(" ")
|
||||
core_fields = {"decky": "decky", "service": "service", "event": "event_type", "attacker": "attacker_ip"}
|
||||
for i, token in enumerate(tokens):
|
||||
if ":" in token:
|
||||
k, v = token.split(":", 1)
|
||||
if k in core_fields:
|
||||
where_clauses.append(f"{core_fields[k]} = :val_{i}")
|
||||
params[f"val_{i}"] = v
|
||||
else:
|
||||
ks = re.sub(r'[^a-zA-Z0-9_]', '', k)
|
||||
where_clauses.append(f"json_extract(fields, '$.{ks}') = :val_{i}")
|
||||
params[f"val_{i}"] = v
|
||||
else:
|
||||
where_clauses.append(f"(raw_line LIKE :lk_{i} OR decky LIKE :lk_{i} OR service LIKE :lk_{i} OR attacker_ip LIKE :lk_{i})")
|
||||
params[f"lk_{i}"] = f"%{token}%"
|
||||
|
||||
where = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
|
||||
return where, params
|
||||
|
||||
async def get_stats_summary(self) -> dict[str, Any]:
|
||||
async with self.session_factory() as session:
|
||||
total_logs = (await session.execute(select(func.count()).select_from(Log))).scalar() or 0
|
||||
unique_attackers = (await session.execute(select(func.count(func.distinct(Log.attacker_ip))))).scalar() or 0
|
||||
active_deckies = (await session.execute(select(func.count(func.distinct(Log.decky))))).scalar() or 0
|
||||
|
||||
_state = load_state()
|
||||
deployed_deckies = len(_state[0].deckies) if _state else 0
|
||||
|
||||
return {
|
||||
"total_logs": total_logs,
|
||||
"unique_attackers": unique_attackers,
|
||||
"active_deckies": active_deckies,
|
||||
"deployed_deckies": deployed_deckies
|
||||
}
|
||||
|
||||
async def get_deckies(self) -> List[dict]:
|
||||
_state = load_state()
|
||||
return [_d.model_dump() for _d in _state[0].deckies] if _state else []
|
||||
|
||||
async def get_user_by_username(self, username: str) -> Optional[dict]:
|
||||
async with self.session_factory() as session:
|
||||
statement = select(User).where(User.username == username)
|
||||
results = await session.execute(statement)
|
||||
user = results.scalar_one_or_none()
|
||||
return user.dict() if user else None
|
||||
|
||||
async def get_user_by_uuid(self, uuid: str) -> Optional[dict]:
|
||||
async with self.session_factory() as session:
|
||||
statement = select(User).where(User.uuid == uuid)
|
||||
results = await session.execute(statement)
|
||||
user = results.scalar_one_or_none()
|
||||
return user.dict() if user else None
|
||||
|
||||
async def create_user(self, user_data: dict[str, Any]) -> None:
|
||||
user = User(**user_data)
|
||||
async with self.session_factory() as session:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
async def update_user_password(self, uuid: str, password_hash: str, must_change_password: bool = False) -> None:
|
||||
async with self.session_factory() as session:
|
||||
statement = update(User).where(User.uuid == uuid).values(
|
||||
password_hash=password_hash,
|
||||
must_change_password=must_change_password
|
||||
)
|
||||
await session.execute(statement)
|
||||
await session.commit()
|
||||
|
||||
async def add_bounty(self, bounty_data: dict[str, Any]) -> None:
|
||||
data = bounty_data.copy()
|
||||
if "payload" in data and isinstance(data["payload"], dict):
|
||||
data["payload"] = json.dumps(data["payload"])
|
||||
|
||||
bounty = Bounty(**data)
|
||||
async with self.session_factory() as session:
|
||||
session.add(bounty)
|
||||
await session.commit()
|
||||
|
||||
def _apply_bounty_filters(self, statement, bounty_type: Optional[str], search: Optional[str]):
|
||||
if bounty_type:
|
||||
statement = statement.where(Bounty.bounty_type == bounty_type)
|
||||
if search:
|
||||
lk = f"%{search}%"
|
||||
statement = statement.where(
|
||||
or_(
|
||||
Bounty.decky.like(lk),
|
||||
Bounty.service.like(lk),
|
||||
Bounty.attacker_ip.like(lk),
|
||||
Bounty.payload.like(lk)
|
||||
)
|
||||
)
|
||||
return statement
|
||||
|
||||
async def get_bounties(
|
||||
self,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
bounty_type: Optional[str] = None,
|
||||
search: Optional[str] = None
|
||||
) -> List[dict]:
|
||||
statement = select(Bounty).order_by(desc(Bounty.timestamp)).offset(offset).limit(limit)
|
||||
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
||||
|
||||
async with self.session_factory() as session:
|
||||
results = await session.execute(statement)
|
||||
items = results.scalars().all()
|
||||
final = []
|
||||
for item in items:
|
||||
d = item.dict()
|
||||
try: d["payload"] = json.loads(d["payload"])
|
||||
except: pass
|
||||
final.append(d)
|
||||
return final
|
||||
|
||||
async def get_total_bounties(self, bounty_type: Optional[str] = None, search: Optional[str] = None) -> int:
|
||||
statement = select(func.count()).select_from(Bounty)
|
||||
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
||||
|
||||
async with self.session_factory() as session:
|
||||
result = await session.execute(statement)
|
||||
return result.scalar() or 0
|
||||
Reference in New Issue
Block a user