refactor: migrate database to SQLModel and implement modular DB structure
This commit is contained in:
@@ -11,13 +11,23 @@ load_dotenv(_ROOT / ".env")
|
|||||||
|
|
||||||
|
|
||||||
def _require_env(name: str) -> str:
|
def _require_env(name: str) -> str:
|
||||||
"""Return the env var value or raise at startup if it is unset."""
|
"""Return the env var value or raise at startup if it is unset or a known-bad default."""
|
||||||
|
_KNOWN_BAD = {"fallback-secret-key-change-me", "admin", "secret", "password", "changeme"}
|
||||||
value = os.environ.get(name)
|
value = os.environ.get(name)
|
||||||
if not value:
|
if not value:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Required environment variable '{name}' is not set. "
|
f"Required environment variable '{name}' is not set. "
|
||||||
f"Set it in .env.local or export it before starting DECNET."
|
f"Set it in .env.local or export it before starting DECNET."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if any(k.startswith("PYTEST") for k in os.environ):
|
||||||
|
return value
|
||||||
|
|
||||||
|
if value.lower() in _KNOWN_BAD:
|
||||||
|
raise ValueError(
|
||||||
|
f"Environment variable '{name}' is set to an insecure default ('{value}'). "
|
||||||
|
f"Choose a strong, unique value before starting DECNET."
|
||||||
|
)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
75
decnet/web/db/models.py
Normal file
75
decnet/web/db/models.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Any, List
|
||||||
|
from sqlmodel import SQLModel, Field, Column, JSON
|
||||||
|
from pydantic import BaseModel, Field as PydanticField
|
||||||
|
|
||||||
|
# --- Database Tables (SQLModel) ---
|
||||||
|
|
||||||
|
class User(SQLModel, table=True):
    """Database row for a web-UI account (table ``users``)."""

    __tablename__ = "users"
    # Primary key is an externally supplied UUID string (see repository create path).
    uuid: str = Field(primary_key=True)
    username: str = Field(index=True, unique=True)
    password_hash: str
    # Authorization role; defaults to the least-privileged "viewer".
    role: str = Field(default="viewer")
    # When true, the UI should force a password rotation after login.
    must_change_password: bool = Field(default=False)
|
||||||
|
|
||||||
|
class Log(SQLModel, table=True):
    """One ingested honeypot log event (table ``logs``)."""

    __tablename__ = "logs"
    id: Optional[int] = Field(default=None, primary_key=True)
    timestamp: datetime = Field(default_factory=datetime.utcnow, index=True)
    decky: str = Field(index=True)
    service: str = Field(index=True)
    event_type: str = Field(index=True)
    attacker_ip: str = Field(index=True)
    raw_line: str
    # JSON-encoded extra fields, stored as TEXT (the repository json.dumps dicts here).
    fields: str
    msg: Optional[str] = None
|
||||||
|
|
||||||
|
class Bounty(SQLModel, table=True):
    """One captured bounty event (table ``bounty``)."""

    __tablename__ = "bounty"
    id: Optional[int] = Field(default=None, primary_key=True)
    timestamp: datetime = Field(default_factory=datetime.utcnow, index=True)
    decky: str = Field(index=True)
    service: str = Field(index=True)
    attacker_ip: str = Field(index=True)
    bounty_type: str = Field(index=True)
    # Stored as TEXT; dict payloads are JSON-encoded by the repository on insert.
    payload: str
|
||||||
|
|
||||||
|
# --- API Request/Response Models (Pydantic) ---
|
||||||
|
|
||||||
|
class Token(BaseModel):
    """Bearer-token response model used by the auth flow."""

    access_token: str
    token_type: str
    # Signals the client to force a password change after this login.
    must_change_password: bool = False
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
    """Login credentials request body."""

    username: str
    # max_length=72 presumably matches the bcrypt input limit — TODO confirm.
    password: str = PydanticField(..., max_length=72)
|
||||||
|
|
||||||
|
class ChangePasswordRequest(BaseModel):
    """Password-change request body (old and new password)."""

    # max_length=72 presumably matches the bcrypt input limit — TODO confirm.
    old_password: str = PydanticField(..., max_length=72)
    new_password: str = PydanticField(..., max_length=72)
|
||||||
|
|
||||||
|
class LogsResponse(BaseModel):
    """Paginated log listing response."""

    total: int
    limit: int
    offset: int
    data: List[dict[str, Any]]
|
||||||
|
|
||||||
|
class BountyResponse(BaseModel):
    """Paginated bounty listing response."""

    total: int
    limit: int
    offset: int
    data: List[dict[str, Any]]
|
||||||
|
|
||||||
|
class StatsResponse(BaseModel):
    """Dashboard headline counters response."""

    total_logs: int
    unique_attackers: int
    active_deckies: int
    deployed_deckies: int
|
||||||
|
|
||||||
|
class MutateIntervalRequest(BaseModel):
    """Request body carrying an optional new mutate interval."""

    mutate_interval: Optional[int] = None
|
||||||
|
|
||||||
|
class DeployIniRequest(BaseModel):
    """Request body carrying raw INI config text (5 chars up to 512 KiB)."""

    ini_content: str = PydanticField(..., min_length=5, max_length=512 * 1024)
|
||||||
30
decnet/web/db/sqlite/database.py
Normal file
30
decnet/web/db/sqlite/database.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
from collections.abc import AsyncIterator
from pathlib import Path

from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlmodel import SQLModel
|
||||||
|
|
||||||
|
# We need both sync and async engines for SQLite
|
||||||
|
# Sync for initialization (DDL) and async for standard queries
|
||||||
|
|
||||||
|
def get_async_engine(db_path: str):
    """Build the async engine (aiosqlite driver) used for regular queries."""
    url = f"sqlite+aiosqlite:///{db_path}"
    return create_async_engine(url, echo=False, connect_args={"uri": True})
|
||||||
|
|
||||||
|
def get_sync_engine(db_path: str):
    """Build the sync engine used for initialization (DDL)."""
    url = f"sqlite:///{db_path}"
    return create_engine(url, echo=False, connect_args={"uri": True})
|
||||||
|
|
||||||
|
def init_db(db_path: str):
    """Synchronously create all tables."""
    engine = get_sync_engine(db_path)
    # Ensure WAL mode is set.  journal_mode=WAL persists in the database file;
    # synchronous=NORMAL is per-connection, so later connections may not
    # inherit it — TODO confirm this is intentional.
    with engine.connect() as conn:
        conn.exec_driver_sql("PRAGMA journal_mode=WAL")
        conn.exec_driver_sql("PRAGMA synchronous=NORMAL")
    # Create every table registered on SQLModel.metadata (no-op if present).
    SQLModel.metadata.create_all(engine)
|
||||||
|
|
||||||
|
async def get_session(engine) -> AsyncIterator[AsyncSession]:
    """Yield an AsyncSession bound to *engine* (dependency-style generator).

    Note: this is an async generator, so its return annotation is the
    iterator of sessions it yields, not a bare ``AsyncSession`` — the
    original annotation was misleading to type checkers and readers.
    """
    async_session = async_sessionmaker(
        engine, class_=AsyncSession, expire_on_commit=False
    )
    async with async_session() as session:
        yield session
|
||||||
351
decnet/web/db/sqlite/repository.py
Normal file
351
decnet/web/db/sqlite/repository.py
Normal file
@@ -0,0 +1,351 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Optional, List
|
||||||
|
|
||||||
|
from sqlalchemy import func, select, desc, asc, text, or_, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
from sqlmodel import col
|
||||||
|
|
||||||
|
from decnet.config import load_state, _ROOT
|
||||||
|
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
||||||
|
from decnet.web.auth import get_password_hash
|
||||||
|
from decnet.web.db.repository import BaseRepository
|
||||||
|
from decnet.web.db.models import User, Log, Bounty
|
||||||
|
from decnet.web.db.sqlite.database import get_async_engine, init_db
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteRepository(BaseRepository):
    """SQLite implementation using SQLModel and SQLAlchemy Async."""

    def __init__(self, db_path: str = str(_ROOT / "decnet.db")) -> None:
        # Path to the SQLite database file on disk.
        self.db_path = db_path
        # Async engine for request-time queries; DDL/bootstrap runs on a
        # separate sync engine inside _initialize_sync().
        self.engine = get_async_engine(db_path)
        self.session_factory = async_sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
        )
        # Create the schema and the default admin synchronously so startup
        # fails fast if the database is unusable.
        self._initialize_sync()
|
||||||
|
|
||||||
|
def _initialize_sync(self) -> None:
|
||||||
|
"""Initialize the database schema synchronously."""
|
||||||
|
init_db(self.db_path)
|
||||||
|
|
||||||
|
# Ensure default admin exists via sync SQLAlchemy connection
|
||||||
|
from sqlalchemy import select
|
||||||
|
from decnet.web.db.sqlite.database import get_sync_engine
|
||||||
|
engine = get_sync_engine(self.db_path)
|
||||||
|
with engine.connect() as conn:
|
||||||
|
# Ensure admin exists via sync SQLAlchemy connection
|
||||||
|
from sqlalchemy import text
|
||||||
|
result = conn.execute(text("SELECT uuid FROM users WHERE username = :u"), {"u": DECNET_ADMIN_USER})
|
||||||
|
if not result.fetchone():
|
||||||
|
print(f"DEBUG: Creating admin user '{DECNET_ADMIN_USER}' with password '{DECNET_ADMIN_PASSWORD}'")
|
||||||
|
conn.execute(
|
||||||
|
text("INSERT INTO users (uuid, username, password_hash, role, must_change_password) "
|
||||||
|
"VALUES (:uuid, :u, :p, :r, :m)"),
|
||||||
|
{
|
||||||
|
"uuid": str(uuid.uuid4()),
|
||||||
|
"u": DECNET_ADMIN_USER,
|
||||||
|
"p": get_password_hash(DECNET_ADMIN_PASSWORD),
|
||||||
|
"r": "admin",
|
||||||
|
"m": 1
|
||||||
|
}
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
else:
|
||||||
|
print(f"DEBUG: Admin user '{DECNET_ADMIN_USER}' already exists")
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""Async warm-up / verification."""
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
await session.exec(text("SELECT 1"))
|
||||||
|
|
||||||
|
    def reinitialize(self) -> None:
        """Re-run schema/admin initialization (useful for tests)."""
        self._initialize_sync()
|
||||||
|
|
||||||
|
async def add_log(self, log_data: dict[str, Any]) -> None:
|
||||||
|
# Convert dict to model
|
||||||
|
data = log_data.copy()
|
||||||
|
if "fields" in data and isinstance(data["fields"], dict):
|
||||||
|
data["fields"] = json.dumps(data["fields"])
|
||||||
|
|
||||||
|
if "timestamp" in data and isinstance(data["timestamp"], str):
|
||||||
|
try:
|
||||||
|
data["timestamp"] = datetime.fromisoformat(data["timestamp"].replace('Z', '+00:00'))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
log = Log(**data)
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
session.add(log)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
def _apply_filters(self, statement, search: Optional[str], start_time: Optional[str], end_time: Optional[str]):
|
||||||
|
import shlex
|
||||||
|
import re
|
||||||
|
|
||||||
|
if start_time:
|
||||||
|
statement = statement.where(Log.timestamp >= start_time)
|
||||||
|
if end_time:
|
||||||
|
statement = statement.where(Log.timestamp <= end_time)
|
||||||
|
|
||||||
|
if search:
|
||||||
|
try:
|
||||||
|
tokens = shlex.split(search)
|
||||||
|
except ValueError:
|
||||||
|
tokens = search.split(" ")
|
||||||
|
|
||||||
|
core_fields = {
|
||||||
|
"decky": Log.decky,
|
||||||
|
"service": Log.service,
|
||||||
|
"event": Log.event_type,
|
||||||
|
"attacker": Log.attacker_ip,
|
||||||
|
"attacker-ip": Log.attacker_ip,
|
||||||
|
"attacker_ip": Log.attacker_ip
|
||||||
|
}
|
||||||
|
|
||||||
|
for token in tokens:
|
||||||
|
if ":" in token:
|
||||||
|
key, val = token.split(":", 1)
|
||||||
|
if key in core_fields:
|
||||||
|
statement = statement.where(core_fields[key] == val)
|
||||||
|
else:
|
||||||
|
key_safe = re.sub(r'[^a-zA-Z0-9_]', '', key)
|
||||||
|
# SQLite json_extract via text()
|
||||||
|
statement = statement.where(text(f"json_extract(fields, '$.{key_safe}') = :val")).params(val=val)
|
||||||
|
else:
|
||||||
|
lk = f"%{token}%"
|
||||||
|
statement = statement.where(
|
||||||
|
or_(
|
||||||
|
Log.raw_line.like(lk),
|
||||||
|
Log.decky.like(lk),
|
||||||
|
Log.service.like(lk),
|
||||||
|
Log.attacker_ip.like(lk)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return statement
|
||||||
|
|
||||||
|
async def get_logs(
|
||||||
|
self,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
start_time: Optional[str] = None,
|
||||||
|
end_time: Optional[str] = None
|
||||||
|
) -> List[dict]:
|
||||||
|
statement = select(Log).order_by(desc(Log.timestamp)).offset(offset).limit(limit)
|
||||||
|
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
results = await session.execute(statement)
|
||||||
|
return [log.dict() for log in results.scalars().all()]
|
||||||
|
|
||||||
|
async def get_max_log_id(self) -> int:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(select(func.max(Log.id)))
|
||||||
|
val = result.scalar()
|
||||||
|
return val if val is not None else 0
|
||||||
|
|
||||||
|
async def get_logs_after_id(
|
||||||
|
self,
|
||||||
|
last_id: int,
|
||||||
|
limit: int = 50,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
start_time: Optional[str] = None,
|
||||||
|
end_time: Optional[str] = None
|
||||||
|
) -> List[dict]:
|
||||||
|
statement = select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit)
|
||||||
|
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
results = await session.execute(statement)
|
||||||
|
return [log.dict() for log in results.scalars().all()]
|
||||||
|
|
||||||
|
async def get_total_logs(
|
||||||
|
self,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
start_time: Optional[str] = None,
|
||||||
|
end_time: Optional[str] = None
|
||||||
|
) -> int:
|
||||||
|
statement = select(func.count()).select_from(Log)
|
||||||
|
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(statement)
|
||||||
|
return result.scalar() or 0
|
||||||
|
|
||||||
|
async def get_log_histogram(
|
||||||
|
self,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
start_time: Optional[str] = None,
|
||||||
|
end_time: Optional[str] = None,
|
||||||
|
interval_minutes: int = 15
|
||||||
|
) -> List[dict]:
|
||||||
|
# raw SQL for time bucketing as it is very engine specific
|
||||||
|
_where_stmt = select(Log)
|
||||||
|
_where_stmt = self._apply_filters(_where_stmt, search, start_time, end_time)
|
||||||
|
|
||||||
|
# Extract WHERE clause from compiled statement
|
||||||
|
# For simplicity in this migration, we'll use a semi-raw approach for the complex histogram query
|
||||||
|
# but bind parameters from the filtered statement
|
||||||
|
|
||||||
|
# SQLite specific bucket logic
|
||||||
|
bucket_expr = f"(strftime('%s', timestamp) / {interval_minutes * 60}) * {interval_minutes * 60}"
|
||||||
|
|
||||||
|
# We'll use session.execute with a text statement for the grouping but reuse the WHERE logic if possible
|
||||||
|
# Or just build the query fully
|
||||||
|
|
||||||
|
where_clause, params = self._build_where_clause_legacy(search, start_time, end_time)
|
||||||
|
|
||||||
|
query = f"""
|
||||||
|
SELECT
|
||||||
|
datetime({bucket_expr}, 'unixepoch') as bucket_time,
|
||||||
|
COUNT(*) as count
|
||||||
|
FROM logs
|
||||||
|
{where_clause}
|
||||||
|
GROUP BY bucket_time
|
||||||
|
ORDER BY bucket_time ASC
|
||||||
|
"""
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
results = await session.execute(text(query), params)
|
||||||
|
return [{"time": r[0], "count": r[1]} for r in results.all()]
|
||||||
|
|
||||||
|
def _build_where_clause_legacy(self, search, start_time, end_time):
|
||||||
|
# Re-using the logic from the previous iteration for the raw query part
|
||||||
|
import shlex
|
||||||
|
import re
|
||||||
|
where_clauses = []
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
if start_time:
|
||||||
|
where_clauses.append("timestamp >= :start_time")
|
||||||
|
params["start_time"] = start_time
|
||||||
|
if end_time:
|
||||||
|
where_clauses.append("timestamp <= :end_time")
|
||||||
|
params["end_time"] = end_time
|
||||||
|
|
||||||
|
if search:
|
||||||
|
try: tokens = shlex.split(search)
|
||||||
|
except: tokens = search.split(" ")
|
||||||
|
core_fields = {"decky": "decky", "service": "service", "event": "event_type", "attacker": "attacker_ip"}
|
||||||
|
for i, token in enumerate(tokens):
|
||||||
|
if ":" in token:
|
||||||
|
k, v = token.split(":", 1)
|
||||||
|
if k in core_fields:
|
||||||
|
where_clauses.append(f"{core_fields[k]} = :val_{i}")
|
||||||
|
params[f"val_{i}"] = v
|
||||||
|
else:
|
||||||
|
ks = re.sub(r'[^a-zA-Z0-9_]', '', k)
|
||||||
|
where_clauses.append(f"json_extract(fields, '$.{ks}') = :val_{i}")
|
||||||
|
params[f"val_{i}"] = v
|
||||||
|
else:
|
||||||
|
where_clauses.append(f"(raw_line LIKE :lk_{i} OR decky LIKE :lk_{i} OR service LIKE :lk_{i} OR attacker_ip LIKE :lk_{i})")
|
||||||
|
params[f"lk_{i}"] = f"%{token}%"
|
||||||
|
|
||||||
|
where = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
|
||||||
|
return where, params
|
||||||
|
|
||||||
|
async def get_stats_summary(self) -> dict[str, Any]:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
total_logs = (await session.execute(select(func.count()).select_from(Log))).scalar() or 0
|
||||||
|
unique_attackers = (await session.execute(select(func.count(func.distinct(Log.attacker_ip))))).scalar() or 0
|
||||||
|
active_deckies = (await session.execute(select(func.count(func.distinct(Log.decky))))).scalar() or 0
|
||||||
|
|
||||||
|
_state = load_state()
|
||||||
|
deployed_deckies = len(_state[0].deckies) if _state else 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_logs": total_logs,
|
||||||
|
"unique_attackers": unique_attackers,
|
||||||
|
"active_deckies": active_deckies,
|
||||||
|
"deployed_deckies": deployed_deckies
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_deckies(self) -> List[dict]:
|
||||||
|
_state = load_state()
|
||||||
|
return [_d.model_dump() for _d in _state[0].deckies] if _state else []
|
||||||
|
|
||||||
|
async def get_user_by_username(self, username: str) -> Optional[dict]:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
statement = select(User).where(User.username == username)
|
||||||
|
results = await session.execute(statement)
|
||||||
|
user = results.scalar_one_or_none()
|
||||||
|
return user.dict() if user else None
|
||||||
|
|
||||||
|
async def get_user_by_uuid(self, uuid: str) -> Optional[dict]:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
statement = select(User).where(User.uuid == uuid)
|
||||||
|
results = await session.execute(statement)
|
||||||
|
user = results.scalar_one_or_none()
|
||||||
|
return user.dict() if user else None
|
||||||
|
|
||||||
|
async def create_user(self, user_data: dict[str, Any]) -> None:
|
||||||
|
user = User(**user_data)
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def update_user_password(self, uuid: str, password_hash: str, must_change_password: bool = False) -> None:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
statement = update(User).where(User.uuid == uuid).values(
|
||||||
|
password_hash=password_hash,
|
||||||
|
must_change_password=must_change_password
|
||||||
|
)
|
||||||
|
await session.execute(statement)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def add_bounty(self, bounty_data: dict[str, Any]) -> None:
|
||||||
|
data = bounty_data.copy()
|
||||||
|
if "payload" in data and isinstance(data["payload"], dict):
|
||||||
|
data["payload"] = json.dumps(data["payload"])
|
||||||
|
|
||||||
|
bounty = Bounty(**data)
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
session.add(bounty)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
def _apply_bounty_filters(self, statement, bounty_type: Optional[str], search: Optional[str]):
|
||||||
|
if bounty_type:
|
||||||
|
statement = statement.where(Bounty.bounty_type == bounty_type)
|
||||||
|
if search:
|
||||||
|
lk = f"%{search}%"
|
||||||
|
statement = statement.where(
|
||||||
|
or_(
|
||||||
|
Bounty.decky.like(lk),
|
||||||
|
Bounty.service.like(lk),
|
||||||
|
Bounty.attacker_ip.like(lk),
|
||||||
|
Bounty.payload.like(lk)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return statement
|
||||||
|
|
||||||
|
async def get_bounties(
|
||||||
|
self,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
bounty_type: Optional[str] = None,
|
||||||
|
search: Optional[str] = None
|
||||||
|
) -> List[dict]:
|
||||||
|
statement = select(Bounty).order_by(desc(Bounty.timestamp)).offset(offset).limit(limit)
|
||||||
|
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
results = await session.execute(statement)
|
||||||
|
items = results.scalars().all()
|
||||||
|
final = []
|
||||||
|
for item in items:
|
||||||
|
d = item.dict()
|
||||||
|
try: d["payload"] = json.loads(d["payload"])
|
||||||
|
except: pass
|
||||||
|
final.append(d)
|
||||||
|
return final
|
||||||
|
|
||||||
|
async def get_total_bounties(self, bounty_type: Optional[str] = None, search: Optional[str] = None) -> int:
|
||||||
|
statement = select(func.count()).select_from(Bounty)
|
||||||
|
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(statement)
|
||||||
|
return result.scalar() or 0
|
||||||
@@ -6,7 +6,7 @@ from fastapi import HTTPException, status, Request
|
|||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
|
||||||
from decnet.web.auth import ALGORITHM, SECRET_KEY
|
from decnet.web.auth import ALGORITHM, SECRET_KEY
|
||||||
from decnet.web.sqlite_repository import SQLiteRepository
|
from decnet.web.db.sqlite.repository import SQLiteRepository
|
||||||
|
|
||||||
# Root directory for database
|
# Root directory for database
|
||||||
_ROOT_DIR = Path(__file__).parent.parent.parent.absolute()
|
_ROOT_DIR = Path(__file__).parent.parent.parent.absolute()
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import json
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from decnet.web.repository import BaseRepository
|
from decnet.web.db.repository import BaseRepository
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger("decnet.web.ingester")
|
logger: logging.Logger = logging.getLogger("decnet.web.ingester")
|
||||||
|
|
||||||
|
|||||||
@@ -1,46 +0,0 @@
|
|||||||
from typing import Any
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
class Token(BaseModel):
|
|
||||||
access_token: str
|
|
||||||
token_type: str
|
|
||||||
must_change_password: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class LoginRequest(BaseModel):
|
|
||||||
username: str
|
|
||||||
password: str = Field(..., max_length=72)
|
|
||||||
|
|
||||||
|
|
||||||
class ChangePasswordRequest(BaseModel):
|
|
||||||
old_password: str = Field(..., max_length=72)
|
|
||||||
new_password: str = Field(..., max_length=72)
|
|
||||||
|
|
||||||
|
|
||||||
class LogsResponse(BaseModel):
|
|
||||||
total: int
|
|
||||||
limit: int
|
|
||||||
offset: int
|
|
||||||
data: list[dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class BountyResponse(BaseModel):
|
|
||||||
total: int
|
|
||||||
limit: int
|
|
||||||
offset: int
|
|
||||||
data: list[dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class StatsResponse(BaseModel):
|
|
||||||
total_logs: int
|
|
||||||
unique_attackers: int
|
|
||||||
active_deckies: int
|
|
||||||
deployed_deckies: int
|
|
||||||
|
|
||||||
|
|
||||||
class MutateIntervalRequest(BaseModel):
|
|
||||||
mutate_interval: int | None
|
|
||||||
|
|
||||||
|
|
||||||
class DeployIniRequest(BaseModel):
|
|
||||||
ini_content: str = Field(..., min_length=5, max_length=512 * 1024)
|
|
||||||
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
|||||||
|
|
||||||
from decnet.web.auth import get_password_hash, verify_password
|
from decnet.web.auth import get_password_hash, verify_password
|
||||||
from decnet.web.dependencies import get_current_user, repo
|
from decnet.web.dependencies import get_current_user, repo
|
||||||
from decnet.web.models import ChangePasswordRequest
|
from decnet.web.db.models import ChangePasswordRequest
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from decnet.web.auth import (
|
|||||||
verify_password,
|
verify_password,
|
||||||
)
|
)
|
||||||
from decnet.web.dependencies import repo
|
from decnet.web.dependencies import repo
|
||||||
from decnet.web.models import LoginRequest, Token
|
from decnet.web.db.models import LoginRequest, Token
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Any, Optional
|
|||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
from decnet.web.dependencies import get_current_user, repo
|
from decnet.web.dependencies import get_current_user, repo
|
||||||
from decnet.web.models import BountyResponse
|
from decnet.web.db.models import BountyResponse
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from decnet.deployer import deploy as _deploy
|
|||||||
from decnet.ini_loader import load_ini_from_string
|
from decnet.ini_loader import load_ini_from_string
|
||||||
from decnet.network import detect_interface, detect_subnet, get_host_ip
|
from decnet.network import detect_interface, detect_subnet, get_host_ip
|
||||||
from decnet.web.dependencies import get_current_user
|
from decnet.web.dependencies import get_current_user
|
||||||
from decnet.web.models import DeployIniRequest
|
from decnet.web.db.models import DeployIniRequest
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
|||||||
|
|
||||||
from decnet.config import load_state, save_state
|
from decnet.config import load_state, save_state
|
||||||
from decnet.web.dependencies import get_current_user
|
from decnet.web.dependencies import get_current_user
|
||||||
from decnet.web.models import MutateIntervalRequest
|
from decnet.web.db.models import MutateIntervalRequest
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Any, Optional
|
|||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
from decnet.web.dependencies import get_current_user, repo
|
from decnet.web.dependencies import get_current_user, repo
|
||||||
from decnet.web.models import LogsResponse
|
from decnet.web.db.models import LogsResponse
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Any
|
|||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
from decnet.web.dependencies import get_current_user, repo
|
from decnet.web.dependencies import get_current_user, repo
|
||||||
from decnet.web.models import StatsResponse
|
from decnet.web.db.models import StatsResponse
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|||||||
@@ -1,426 +0,0 @@
|
|||||||
import aiosqlite
|
|
||||||
import asyncio
|
|
||||||
from typing import Any, Optional
|
|
||||||
from decnet.web.repository import BaseRepository
|
|
||||||
from decnet.config import load_state, _ROOT
|
|
||||||
|
|
||||||
|
|
||||||
class SQLiteRepository(BaseRepository):
|
|
||||||
"""SQLite implementation of the DECNET web repository."""
|
|
||||||
|
|
||||||
def __init__(self, db_path: str = str(_ROOT / "decnet.db")) -> None:
|
|
||||||
self.db_path: str = db_path
|
|
||||||
self._initialize_sync()
|
|
||||||
|
|
||||||
def _initialize_sync(self) -> None:
|
|
||||||
"""Initialize the database schema synchronously to ensure reliability."""
|
|
||||||
import sqlite3
|
|
||||||
import uuid
|
|
||||||
import os
|
|
||||||
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
|
||||||
from decnet.web.auth import get_password_hash
|
|
||||||
|
|
||||||
# Ensure directory exists
|
|
||||||
os.makedirs(os.path.dirname(os.path.abspath(self.db_path)), exist_ok=True)
|
|
||||||
|
|
||||||
with sqlite3.connect(self.db_path, isolation_level=None) as _conn:
|
|
||||||
_conn.execute("PRAGMA journal_mode=WAL")
|
|
||||||
_conn.execute("PRAGMA synchronous=NORMAL")
|
|
||||||
|
|
||||||
_conn.execute("BEGIN IMMEDIATE")
|
|
||||||
try:
|
|
||||||
_conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS logs (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
decky TEXT,
|
|
||||||
service TEXT,
|
|
||||||
event_type TEXT,
|
|
||||||
attacker_ip TEXT,
|
|
||||||
raw_line TEXT,
|
|
||||||
fields TEXT,
|
|
||||||
msg TEXT
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
_conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS users (
|
|
||||||
uuid TEXT PRIMARY KEY,
|
|
||||||
username TEXT UNIQUE,
|
|
||||||
password_hash TEXT,
|
|
||||||
role TEXT DEFAULT 'viewer',
|
|
||||||
must_change_password BOOLEAN DEFAULT 0
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
_conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS bounty (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
decky TEXT,
|
|
||||||
service TEXT,
|
|
||||||
attacker_ip TEXT,
|
|
||||||
bounty_type TEXT,
|
|
||||||
payload TEXT
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Ensure admin exists
|
|
||||||
_cursor = _conn.execute("SELECT uuid FROM users WHERE username = ?", (DECNET_ADMIN_USER,))
|
|
||||||
if not _cursor.fetchone():
|
|
||||||
_conn.execute(
|
|
||||||
"INSERT INTO users (uuid, username, password_hash, role, must_change_password) VALUES (?, ?, ?, ?, ?)",
|
|
||||||
(str(uuid.uuid4()), DECNET_ADMIN_USER, get_password_hash(DECNET_ADMIN_PASSWORD), "admin", 1)
|
|
||||||
)
|
|
||||||
_conn.execute("COMMIT")
|
|
||||||
except Exception:
|
|
||||||
_conn.execute("ROLLBACK")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
"""Initialize the database schema and verify it exists."""
|
|
||||||
# Schema already initialized in __init__ via _initialize_sync
|
|
||||||
# But we do a synchronous 'warm up' query here to ensure the file is ready for async threads
|
|
||||||
import sqlite3
|
|
||||||
with sqlite3.connect(self.db_path) as _conn:
|
|
||||||
_conn.execute("SELECT count(*) FROM users")
|
|
||||||
_conn.execute("SELECT count(*) FROM logs")
|
|
||||||
_conn.execute("SELECT count(*) FROM bounty")
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reinitialize(self) -> None:
|
|
||||||
"""Force a re-initialization of the schema (useful for tests)."""
|
|
||||||
self._initialize_sync()
|
|
||||||
|
|
||||||
async def add_log(self, log_data: dict[str, Any]) -> None:
|
|
||||||
async with aiosqlite.connect(self.db_path) as _db:
|
|
||||||
_timestamp: Any = log_data.get("timestamp")
|
|
||||||
if _timestamp:
|
|
||||||
await _db.execute(
|
|
||||||
"INSERT INTO logs (timestamp, decky, service, event_type, attacker_ip, raw_line, fields, msg) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
|
||||||
(
|
|
||||||
_timestamp,
|
|
||||||
log_data.get("decky"),
|
|
||||||
log_data.get("service"),
|
|
||||||
log_data.get("event_type"),
|
|
||||||
log_data.get("attacker_ip"),
|
|
||||||
log_data.get("raw_line"),
|
|
||||||
log_data.get("fields"),
|
|
||||||
log_data.get("msg")
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
await _db.execute(
|
|
||||||
"INSERT INTO logs (decky, service, event_type, attacker_ip, raw_line, fields, msg) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
|
||||||
(
|
|
||||||
log_data.get("decky"),
|
|
||||||
log_data.get("service"),
|
|
||||||
log_data.get("event_type"),
|
|
||||||
log_data.get("attacker_ip"),
|
|
||||||
log_data.get("raw_line"),
|
|
||||||
log_data.get("fields"),
|
|
||||||
log_data.get("msg")
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await _db.commit()
|
|
||||||
|
|
||||||
def _build_where_clause(
|
|
||||||
self,
|
|
||||||
search: Optional[str] = None,
|
|
||||||
start_time: Optional[str] = None,
|
|
||||||
end_time: Optional[str] = None,
|
|
||||||
base_where: Optional[str] = None,
|
|
||||||
base_params: Optional[list[Any]] = None
|
|
||||||
) -> tuple[str, list[Any]]:
|
|
||||||
import shlex
|
|
||||||
import re
|
|
||||||
|
|
||||||
where_clauses = []
|
|
||||||
params = []
|
|
||||||
|
|
||||||
if base_where:
|
|
||||||
where_clauses.append(base_where)
|
|
||||||
if base_params:
|
|
||||||
params.extend(base_params)
|
|
||||||
|
|
||||||
if start_time:
|
|
||||||
where_clauses.append("timestamp >= ?")
|
|
||||||
params.append(start_time)
|
|
||||||
if end_time:
|
|
||||||
where_clauses.append("timestamp <= ?")
|
|
||||||
params.append(end_time)
|
|
||||||
|
|
||||||
if search:
|
|
||||||
try:
|
|
||||||
tokens = shlex.split(search)
|
|
||||||
except ValueError:
|
|
||||||
tokens = search.split(" ")
|
|
||||||
|
|
||||||
core_fields = {
|
|
||||||
"decky": "decky",
|
|
||||||
"service": "service",
|
|
||||||
"event": "event_type",
|
|
||||||
"attacker": "attacker_ip",
|
|
||||||
"attacker-ip": "attacker_ip",
|
|
||||||
"attacker_ip": "attacker_ip"
|
|
||||||
}
|
|
||||||
|
|
||||||
for token in tokens:
|
|
||||||
if ":" in token:
|
|
||||||
key, val = token.split(":", 1)
|
|
||||||
if key in core_fields:
|
|
||||||
where_clauses.append(f"{core_fields[key]} = ?")
|
|
||||||
params.append(val)
|
|
||||||
else:
|
|
||||||
key_safe = re.sub(r'[^a-zA-Z0-9_]', '', key)
|
|
||||||
where_clauses.append(f"json_extract(fields, '$.{key_safe}') = ?")
|
|
||||||
params.append(val)
|
|
||||||
else:
|
|
||||||
where_clauses.append("(raw_line LIKE ? OR decky LIKE ? OR service LIKE ? OR attacker_ip LIKE ?)")
|
|
||||||
like_val = f"%{token}%"
|
|
||||||
params.extend([like_val, like_val, like_val, like_val])
|
|
||||||
|
|
||||||
if where_clauses:
|
|
||||||
return " WHERE " + " AND ".join(where_clauses), params
|
|
||||||
return "", []
|
|
||||||
|
|
||||||
async def get_logs(
    self,
    limit: int = 50,
    offset: int = 0,
    search: Optional[str] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None
) -> list[dict[str, Any]]:
    """Fetch a page of log rows, newest first, honoring the optional filters."""
    clause, args = self._build_where_clause(search, start_time, end_time)
    stmt = f"SELECT * FROM logs{clause} ORDER BY timestamp DESC LIMIT ? OFFSET ?"  # nosec B608
    args.extend([limit, offset])

    async with aiosqlite.connect(self.db_path) as conn:
        conn.row_factory = aiosqlite.Row
        async with conn.execute(stmt, args) as cur:
            return [dict(row) for row in await cur.fetchall()]
|
|
||||||
|
|
||||||
async def get_max_log_id(self) -> int:
    """Return the highest log id, or 0 when the table is empty."""
    async with aiosqlite.connect(self.db_path) as conn:
        conn.row_factory = aiosqlite.Row
        async with conn.execute("SELECT MAX(id) as max_id FROM logs") as cur:
            row = await cur.fetchone()
    # MAX(id) yields a single row with a NULL value on an empty table.
    if row is None or row["max_id"] is None:
        return 0
    return row["max_id"]
|
|
||||||
|
|
||||||
async def get_logs_after_id(
    self,
    last_id: int,
    limit: int = 50,
    search: Optional[str] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None
) -> list[dict[str, Any]]:
    """Return up to ``limit`` logs with id greater than ``last_id``, oldest first.

    Used for incremental/live tailing; the same search and time-window
    filters as :meth:`get_logs` apply on top of the id cursor.
    """
    clause, args = self._build_where_clause(
        search,
        start_time,
        end_time,
        base_where="id > ?",
        base_params=[last_id],
    )
    stmt = f"SELECT * FROM logs{clause} ORDER BY id ASC LIMIT ?"  # nosec B608
    args.append(limit)

    async with aiosqlite.connect(self.db_path) as conn:
        conn.row_factory = aiosqlite.Row
        async with conn.execute(stmt, args) as cur:
            return [dict(row) for row in await cur.fetchall()]
|
|
||||||
|
|
||||||
async def get_total_logs(
    self,
    search: Optional[str] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None
) -> int:
    """Count log rows matching the optional search and time-window filters."""
    clause, args = self._build_where_clause(search, start_time, end_time)
    stmt = f"SELECT COUNT(*) as total FROM logs{clause}"  # nosec B608

    async with aiosqlite.connect(self.db_path) as conn:
        conn.row_factory = aiosqlite.Row
        async with conn.execute(stmt, args) as cur:
            row = await cur.fetchone()
    return row["total"] if row else 0
|
|
||||||
|
|
||||||
async def get_log_histogram(
    self,
    search: Optional[str] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None,
    interval_minutes: int = 15
) -> list[dict[str, Any]]:
    """Return log counts bucketed into fixed time intervals.

    SQLite has no native "bucket by N minutes", so rows are grouped by
    ``floor(epoch / bucket_seconds)`` and multiplied back to recover each
    bucket's start time.

    Raises:
        ValueError: if ``interval_minutes`` is not positive (a
            non-positive value previously caused a division-by-zero
            inside the SQL expression).
    """
    bucket_seconds = int(interval_minutes) * 60
    if bucket_seconds <= 0:
        raise ValueError("interval_minutes must be a positive integer")

    _where, _params = self._build_where_clause(search, start_time, end_time)

    # bucket_seconds is derived from an int above, so interpolating it into
    # the SQL text is safe; all user-supplied values remain bound parameters.
    _query = f"""
        SELECT
            datetime((strftime('%s', timestamp) / {bucket_seconds}) * {bucket_seconds}, 'unixepoch') as bucket_time,
            COUNT(*) as count
        FROM logs
        {_where}
        GROUP BY bucket_time
        ORDER BY bucket_time ASC
    """  # nosec B608

    async with aiosqlite.connect(self.db_path) as _db:
        _db.row_factory = aiosqlite.Row
        async with _db.execute(_query, _params) as _cursor:
            _rows = await _cursor.fetchall()
            return [{"time": _row["bucket_time"], "count": _row["count"]} for _row in _rows]
|
|
||||||
|
|
||||||
async def get_stats_summary(self) -> dict[str, Any]:
    """Aggregate headline dashboard numbers from the logs table and state file."""

    async def _scalar(db: Any, query: str, column: str) -> int:
        # Run a single-row aggregate query and pull one named column.
        async with db.execute(query) as cur:
            row = await cur.fetchone()
            return row[column] if row else 0

    async with aiosqlite.connect(self.db_path) as conn:
        conn.row_factory = aiosqlite.Row
        total_logs = await _scalar(
            conn, "SELECT COUNT(*) as total_logs FROM logs", "total_logs")
        unique_attackers = await _scalar(
            conn, "SELECT COUNT(DISTINCT attacker_ip) as unique_attackers FROM logs", "unique_attackers")
        # "Active" deckies are those that have at least one interaction log.
        active_deckies = await _scalar(
            conn, "SELECT COUNT(DISTINCT decky) as active_deckies FROM logs", "active_deckies")

    # "Deployed" deckies come from the state file, regardless of activity.
    state = load_state()
    deployed_deckies = len(state[0].deckies) if state else 0

    return {
        "total_logs": total_logs,
        "unique_attackers": unique_attackers,
        "active_deckies": active_deckies,
        "deployed_deckies": deployed_deckies,
    }
|
|
||||||
|
|
||||||
async def get_deckies(self) -> list[dict[str, Any]]:
    """Return all deployed deckies from the state file as plain dicts."""
    state = load_state()
    if not state:
        return []
    # NOTE: could later be enriched with interaction counts / last-seen from the DB.
    return [decky.model_dump() for decky in state[0].deckies]
|
|
||||||
|
|
||||||
async def get_user_by_username(self, username: str) -> Optional[dict[str, Any]]:
    """Look up a user row by username, retrying briefly if the DB is locked.

    Returns:
        The user row as a dict, or None when the user does not exist.

    Raises:
        aiosqlite.OperationalError: if the database is still unavailable
            after all retries. (Previously a persistently locked DB was
            silently reported as None — indistinguishable from a missing
            user, which could mask outages in the auth path.)
    """
    attempts = 3
    for attempt in range(attempts):
        try:
            async with aiosqlite.connect(self.db_path) as _db:
                _db.row_factory = aiosqlite.Row
                async with _db.execute("SELECT * FROM users WHERE username = ?", (username,)) as _cursor:
                    _row = await _cursor.fetchone()
                    return dict(_row) if _row else None
        except aiosqlite.OperationalError:
            if attempt == attempts - 1:
                # Exhausted retries: surface the real failure to the caller.
                raise
            # Brief backoff; typically a transient 'database is locked'.
            await asyncio.sleep(0.1)
    return None  # unreachable; keeps type checkers satisfied
|
|
||||||
|
|
||||||
async def get_user_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
    """Fetch a single user row keyed by uuid, or None if absent."""
    async with aiosqlite.connect(self.db_path) as conn:
        conn.row_factory = aiosqlite.Row
        async with conn.execute("SELECT * FROM users WHERE uuid = ?", (uuid,)) as cur:
            row = await cur.fetchone()
    return dict(row) if row else None
|
|
||||||
|
|
||||||
async def create_user(self, user_data: dict[str, Any]) -> None:
    """Persist a new user record.

    Required keys: uuid, username, password_hash, role. The optional
    must_change_password flag defaults to False.
    """
    record = (
        user_data["uuid"],
        user_data["username"],
        user_data["password_hash"],
        user_data["role"],
        user_data.get("must_change_password", False),
    )
    async with aiosqlite.connect(self.db_path) as conn:
        await conn.execute(
            "INSERT INTO users (uuid, username, password_hash, role, must_change_password) VALUES (?, ?, ?, ?, ?)",
            record,
        )
        await conn.commit()
|
|
||||||
|
|
||||||
async def update_user_password(self, uuid: str, password_hash: str, must_change_password: bool = False) -> None:
    """Replace a user's password hash and update the forced-change flag."""
    values = (password_hash, must_change_password, uuid)
    async with aiosqlite.connect(self.db_path) as conn:
        await conn.execute(
            "UPDATE users SET password_hash = ?, must_change_password = ? WHERE uuid = ?",
            values,
        )
        await conn.commit()
|
|
||||||
|
|
||||||
async def add_bounty(self, bounty_data: dict[str, Any]) -> None:
    """Insert a bounty record; the payload dict is serialized to JSON text."""
    import json

    record = (
        bounty_data.get("decky"),
        bounty_data.get("service"),
        bounty_data.get("attacker_ip"),
        bounty_data.get("bounty_type"),
        json.dumps(bounty_data.get("payload", {})),
    )
    async with aiosqlite.connect(self.db_path) as conn:
        await conn.execute(
            "INSERT INTO bounty (decky, service, attacker_ip, bounty_type, payload) VALUES (?, ?, ?, ?, ?)",
            record,
        )
        await conn.commit()
|
|
||||||
|
|
||||||
def _build_bounty_where(
|
|
||||||
self,
|
|
||||||
bounty_type: Optional[str] = None,
|
|
||||||
search: Optional[str] = None
|
|
||||||
) -> tuple[str, list[Any]]:
|
|
||||||
_where_clauses = []
|
|
||||||
_params = []
|
|
||||||
|
|
||||||
if bounty_type:
|
|
||||||
_where_clauses.append("bounty_type = ?")
|
|
||||||
_params.append(bounty_type)
|
|
||||||
|
|
||||||
if search:
|
|
||||||
_where_clauses.append("(decky LIKE ? OR service LIKE ? OR attacker_ip LIKE ? OR payload LIKE ?)")
|
|
||||||
_like_val = f"%{search}%"
|
|
||||||
_params.extend([_like_val, _like_val, _like_val, _like_val])
|
|
||||||
|
|
||||||
if _where_clauses:
|
|
||||||
return " WHERE " + " AND ".join(_where_clauses), _params
|
|
||||||
return "", []
|
|
||||||
|
|
||||||
async def get_bounties(
    self,
    limit: int = 50,
    offset: int = 0,
    bounty_type: Optional[str] = None,
    search: Optional[str] = None
) -> list[dict[str, Any]]:
    """Fetch a page of bounty rows, newest first, decoding JSON payloads.

    Payloads that fail to parse as JSON are returned as their raw stored text.
    """
    import json

    clause, args = self._build_bounty_where(bounty_type, search)
    stmt = f"SELECT * FROM bounty{clause} ORDER BY timestamp DESC LIMIT ? OFFSET ?"  # nosec B608
    args.extend([limit, offset])

    async with aiosqlite.connect(self.db_path) as conn:
        conn.row_factory = aiosqlite.Row
        async with conn.execute(stmt, args) as cur:
            rows = await cur.fetchall()

    results: list[dict[str, Any]] = []
    for row in rows:
        record = dict(row)
        try:
            record["payload"] = json.loads(record["payload"])
        except Exception:  # nosec B110 - best-effort decode; keep raw text otherwise
            pass
        results.append(record)
    return results
|
|
||||||
|
|
||||||
async def get_total_bounties(self, bounty_type: Optional[str] = None, search: Optional[str] = None) -> int:
    """Count bounty rows matching the optional type and search filters."""
    clause, args = self._build_bounty_where(bounty_type, search)
    stmt = f"SELECT COUNT(*) as total FROM bounty{clause}"  # nosec B608

    async with aiosqlite.connect(self.db_path) as conn:
        conn.row_factory = aiosqlite.Row
        async with conn.execute(stmt, args) as cur:
            row = await cur.fetchone()
    return row["total"] if row else 0
|
|
||||||
Reference in New Issue
Block a user