feat: add MySQL backend support for DECNET database
- Implement MySQLRepository extending BaseRepository - Add SQLAlchemy/SQLModel ORM abstraction layer (sqlmodel_repo.py) - Support connection pooling and tuning via DECNET_DB_URL env var - Cross-compatible with SQLite backend via factory pattern - Prepared for production deployment with MySQL SIEM/ELK integration
This commit is contained in:
0
decnet/web/db/mysql/__init__.py
Normal file
0
decnet/web/db/mysql/__init__.py
Normal file
BIN
decnet/web/db/mysql/__pycache__/__init__.cpython-314.pyc
Normal file
BIN
decnet/web/db/mysql/__pycache__/__init__.cpython-314.pyc
Normal file
Binary file not shown.
BIN
decnet/web/db/mysql/__pycache__/database.cpython-314.pyc
Normal file
BIN
decnet/web/db/mysql/__pycache__/database.cpython-314.pyc
Normal file
Binary file not shown.
BIN
decnet/web/db/mysql/__pycache__/repository.cpython-314.pyc
Normal file
BIN
decnet/web/db/mysql/__pycache__/repository.cpython-314.pyc
Normal file
Binary file not shown.
98
decnet/web/db/mysql/database.py
Normal file
98
decnet/web/db/mysql/database.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""
|
||||||
|
MySQL async engine factory.
|
||||||
|
|
||||||
|
Builds a SQLAlchemy AsyncEngine against MySQL using the ``aiomysql`` driver.
|
||||||
|
|
||||||
|
Connection info is resolved (in order of precedence):
|
||||||
|
|
||||||
|
1. An explicit ``url`` argument passed to :func:`get_async_engine`
|
||||||
|
2. ``DECNET_DB_URL`` — full SQLAlchemy URL
|
||||||
|
3. Component env vars:
|
||||||
|
``DECNET_DB_HOST`` (default ``localhost``)
|
||||||
|
``DECNET_DB_PORT`` (default ``3306``)
|
||||||
|
``DECNET_DB_NAME`` (default ``decnet``)
|
||||||
|
``DECNET_DB_USER`` (default ``decnet``)
|
||||||
|
``DECNET_DB_PASSWORD`` (default empty — raises unless pytest is running)
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_POOL_SIZE = 10
|
||||||
|
DEFAULT_MAX_OVERFLOW = 20
|
||||||
|
DEFAULT_POOL_RECYCLE = 3600 # seconds — avoid MySQL ``wait_timeout`` disconnects
|
||||||
|
DEFAULT_POOL_PRE_PING = True
|
||||||
|
|
||||||
|
|
||||||
|
def build_mysql_url(
|
||||||
|
host: Optional[str] = None,
|
||||||
|
port: Optional[int] = None,
|
||||||
|
database: Optional[str] = None,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Compose an async SQLAlchemy URL for MySQL using the aiomysql driver.
|
||||||
|
|
||||||
|
Component args override env vars. Password is percent-encoded so special
|
||||||
|
characters (``@``, ``:``, ``/``…) don't break URL parsing.
|
||||||
|
"""
|
||||||
|
host = host or os.environ.get("DECNET_DB_HOST", "localhost")
|
||||||
|
port = port or int(os.environ.get("DECNET_DB_PORT", "3306"))
|
||||||
|
database = database or os.environ.get("DECNET_DB_NAME", "decnet")
|
||||||
|
user = user or os.environ.get("DECNET_DB_USER", "decnet")
|
||||||
|
|
||||||
|
if password is None:
|
||||||
|
password = os.environ.get("DECNET_DB_PASSWORD", "")
|
||||||
|
|
||||||
|
# Allow empty passwords during tests (pytest sets PYTEST_* env vars).
|
||||||
|
# Outside tests, an empty MySQL password is almost never intentional.
|
||||||
|
if not password and not any(k.startswith("PYTEST") for k in os.environ):
|
||||||
|
raise ValueError(
|
||||||
|
"DECNET_DB_PASSWORD is not set. Either export it, set DECNET_DB_URL, "
|
||||||
|
"or run under pytest for an empty-password default."
|
||||||
|
)
|
||||||
|
|
||||||
|
pw_enc = quote_plus(password)
|
||||||
|
user_enc = quote_plus(user)
|
||||||
|
return f"mysql+aiomysql://{user_enc}:{pw_enc}@{host}:{port}/{database}"
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_url(url: Optional[str] = None) -> str:
|
||||||
|
"""Pick a connection URL: explicit arg → DECNET_DB_URL env → built from components."""
|
||||||
|
if url:
|
||||||
|
return url
|
||||||
|
env_url = os.environ.get("DECNET_DB_URL")
|
||||||
|
if env_url:
|
||||||
|
return env_url
|
||||||
|
return build_mysql_url()
|
||||||
|
|
||||||
|
|
||||||
|
def get_async_engine(
|
||||||
|
url: Optional[str] = None,
|
||||||
|
*,
|
||||||
|
pool_size: int = DEFAULT_POOL_SIZE,
|
||||||
|
max_overflow: int = DEFAULT_MAX_OVERFLOW,
|
||||||
|
pool_recycle: int = DEFAULT_POOL_RECYCLE,
|
||||||
|
pool_pre_ping: bool = DEFAULT_POOL_PRE_PING,
|
||||||
|
echo: bool = False,
|
||||||
|
) -> AsyncEngine:
|
||||||
|
"""Create an AsyncEngine for MySQL.
|
||||||
|
|
||||||
|
Defaults tuned for a dashboard workload: a modest pool, hourly recycle
|
||||||
|
to sidestep MySQL's idle-connection reaper, and pre-ping to fail fast
|
||||||
|
if a pooled connection has been killed server-side.
|
||||||
|
"""
|
||||||
|
dsn = resolve_url(url)
|
||||||
|
return create_async_engine(
|
||||||
|
dsn,
|
||||||
|
echo=echo,
|
||||||
|
pool_size=pool_size,
|
||||||
|
max_overflow=max_overflow,
|
||||||
|
pool_recycle=pool_recycle,
|
||||||
|
pool_pre_ping=pool_pre_ping,
|
||||||
|
)
|
||||||
87
decnet/web/db/mysql/repository.py
Normal file
87
decnet/web/db/mysql/repository.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""
|
||||||
|
MySQL implementation of :class:`BaseRepository`.
|
||||||
|
|
||||||
|
Inherits the portable SQLModel query code from :class:`SQLModelRepository`
|
||||||
|
and only overrides the two places where MySQL's SQL dialect differs from
|
||||||
|
SQLite's:
|
||||||
|
|
||||||
|
* :meth:`_migrate_attackers_table` — uses ``information_schema`` (MySQL
|
||||||
|
has no ``PRAGMA``).
|
||||||
|
* :meth:`get_log_histogram` — uses ``FROM_UNIXTIME`` /
|
||||||
|
``UNIX_TIMESTAMP`` + integer division for bucketing.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import func, select, text, literal_column
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
from sqlmodel.sql.expression import SelectOfScalar
|
||||||
|
|
||||||
|
from decnet.web.db.models import Log
|
||||||
|
from decnet.web.db.mysql.database import get_async_engine
|
||||||
|
from decnet.web.db.sqlmodel_repo import SQLModelRepository
|
||||||
|
|
||||||
|
|
||||||
|
class MySQLRepository(SQLModelRepository):
|
||||||
|
"""MySQL backend — uses ``aiomysql``."""
|
||||||
|
|
||||||
|
def __init__(self, url: Optional[str] = None, **engine_kwargs) -> None:
|
||||||
|
self.engine = get_async_engine(url=url, **engine_kwargs)
|
||||||
|
self.session_factory = async_sessionmaker(
|
||||||
|
self.engine, class_=AsyncSession, expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _migrate_attackers_table(self) -> None:
|
||||||
|
"""Drop the legacy (pre-UUID) ``attackers`` table if it exists without a ``uuid`` column.
|
||||||
|
|
||||||
|
MySQL exposes column metadata via ``information_schema.COLUMNS``.
|
||||||
|
``DATABASE()`` scopes the lookup to the currently connected schema.
|
||||||
|
"""
|
||||||
|
async with self.engine.begin() as conn:
|
||||||
|
rows = (await conn.execute(text(
|
||||||
|
"SELECT COLUMN_NAME FROM information_schema.COLUMNS "
|
||||||
|
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'attackers'"
|
||||||
|
))).fetchall()
|
||||||
|
if rows and not any(r[0] == "uuid" for r in rows):
|
||||||
|
await conn.execute(text("DROP TABLE attackers"))
|
||||||
|
|
||||||
|
def _json_field_equals(self, key: str):
|
||||||
|
# MySQL 5.7+ exposes JSON_EXTRACT; quoted string result returned for
|
||||||
|
# TEXT-stored JSON, same behavior we rely on in SQLite.
|
||||||
|
return text(f"JSON_UNQUOTE(JSON_EXTRACT(fields, '$.{key}')) = :val")
|
||||||
|
|
||||||
|
async def get_log_histogram(
|
||||||
|
self,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
start_time: Optional[str] = None,
|
||||||
|
end_time: Optional[str] = None,
|
||||||
|
interval_minutes: int = 15,
|
||||||
|
) -> List[dict]:
|
||||||
|
bucket_seconds = max(interval_minutes, 1) * 60
|
||||||
|
# Truncate each timestamp to the start of its bucket:
|
||||||
|
# FROM_UNIXTIME( (UNIX_TIMESTAMP(timestamp) DIV N) * N )
|
||||||
|
# DIV is MySQL's integer division operator.
|
||||||
|
bucket_expr = literal_column(
|
||||||
|
f"FROM_UNIXTIME((UNIX_TIMESTAMP(timestamp) DIV {bucket_seconds}) * {bucket_seconds})"
|
||||||
|
).label("bucket_time")
|
||||||
|
|
||||||
|
statement: SelectOfScalar = select(bucket_expr, func.count().label("count")).select_from(Log)
|
||||||
|
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||||
|
statement = statement.group_by(literal_column("bucket_time")).order_by(
|
||||||
|
literal_column("bucket_time")
|
||||||
|
)
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
results = await session.execute(statement)
|
||||||
|
# Normalize to ISO string for API parity with the SQLite backend
|
||||||
|
# (SQLite's datetime() returns a string already; FROM_UNIXTIME
|
||||||
|
# returns a datetime).
|
||||||
|
out: List[dict] = []
|
||||||
|
for r in results.all():
|
||||||
|
ts = r[0]
|
||||||
|
out.append({
|
||||||
|
"time": ts.isoformat(sep=" ") if hasattr(ts, "isoformat") else ts,
|
||||||
|
"count": r[1],
|
||||||
|
})
|
||||||
|
return out
|
||||||
637
decnet/web/db/sqlmodel_repo.py
Normal file
637
decnet/web/db/sqlmodel_repo.py
Normal file
@@ -0,0 +1,637 @@
|
|||||||
|
"""
|
||||||
|
Shared SQLModel-based repository implementation.
|
||||||
|
|
||||||
|
Contains all dialect-portable query code used by the SQLite and MySQL
|
||||||
|
backends. Dialect-specific behavior lives in subclasses:
|
||||||
|
|
||||||
|
* engine/session construction (``__init__``)
|
||||||
|
* ``_migrate_attackers_table`` (legacy schema check; DDL introspection
|
||||||
|
is not portable)
|
||||||
|
* ``get_log_histogram`` (date-bucket expression differs per dialect)
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Optional, List
|
||||||
|
|
||||||
|
from sqlalchemy import func, select, desc, asc, text, or_, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||||
|
from sqlmodel.sql.expression import SelectOfScalar
|
||||||
|
|
||||||
|
from decnet.config import load_state
|
||||||
|
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
||||||
|
from decnet.web.auth import get_password_hash
|
||||||
|
from decnet.web.db.repository import BaseRepository
|
||||||
|
from decnet.web.db.models import User, Log, Bounty, State, Attacker, AttackerBehavior
|
||||||
|
|
||||||
|
|
||||||
|
class SQLModelRepository(BaseRepository):
|
||||||
|
"""Concrete SQLModel/SQLAlchemy-async repository.
|
||||||
|
|
||||||
|
Subclasses provide ``self.engine`` (AsyncEngine) and ``self.session_factory``
|
||||||
|
in ``__init__``, and override the few dialect-specific helpers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
engine: AsyncEngine
|
||||||
|
session_factory: async_sessionmaker[AsyncSession]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------ lifecycle
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""Create tables if absent and seed the admin user."""
|
||||||
|
from sqlmodel import SQLModel
|
||||||
|
await self._migrate_attackers_table()
|
||||||
|
async with self.engine.begin() as conn:
|
||||||
|
await conn.run_sync(SQLModel.metadata.create_all)
|
||||||
|
await self._ensure_admin_user()
|
||||||
|
|
||||||
|
async def reinitialize(self) -> None:
|
||||||
|
"""Re-create schema (for tests / reset flows). Does NOT drop existing tables."""
|
||||||
|
from sqlmodel import SQLModel
|
||||||
|
async with self.engine.begin() as conn:
|
||||||
|
await conn.run_sync(SQLModel.metadata.create_all)
|
||||||
|
await self._ensure_admin_user()
|
||||||
|
|
||||||
|
async def _ensure_admin_user(self) -> None:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(User).where(User.username == DECNET_ADMIN_USER)
|
||||||
|
)
|
||||||
|
if not result.scalar_one_or_none():
|
||||||
|
session.add(User(
|
||||||
|
uuid=str(uuid.uuid4()),
|
||||||
|
username=DECNET_ADMIN_USER,
|
||||||
|
password_hash=get_password_hash(DECNET_ADMIN_PASSWORD),
|
||||||
|
role="admin",
|
||||||
|
must_change_password=True,
|
||||||
|
))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def _migrate_attackers_table(self) -> None:
|
||||||
|
"""Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable)."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------- logs
|
||||||
|
|
||||||
|
async def add_log(self, log_data: dict[str, Any]) -> None:
|
||||||
|
data = log_data.copy()
|
||||||
|
if "fields" in data and isinstance(data["fields"], dict):
|
||||||
|
data["fields"] = json.dumps(data["fields"])
|
||||||
|
if "timestamp" in data and isinstance(data["timestamp"], str):
|
||||||
|
try:
|
||||||
|
data["timestamp"] = datetime.fromisoformat(
|
||||||
|
data["timestamp"].replace("Z", "+00:00")
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
session.add(Log(**data))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
def _apply_filters(
|
||||||
|
self,
|
||||||
|
statement: SelectOfScalar,
|
||||||
|
search: Optional[str],
|
||||||
|
start_time: Optional[str],
|
||||||
|
end_time: Optional[str],
|
||||||
|
) -> SelectOfScalar:
|
||||||
|
import re
|
||||||
|
import shlex
|
||||||
|
|
||||||
|
if start_time:
|
||||||
|
statement = statement.where(Log.timestamp >= start_time)
|
||||||
|
if end_time:
|
||||||
|
statement = statement.where(Log.timestamp <= end_time)
|
||||||
|
|
||||||
|
if search:
|
||||||
|
try:
|
||||||
|
tokens = shlex.split(search)
|
||||||
|
except ValueError:
|
||||||
|
tokens = search.split()
|
||||||
|
|
||||||
|
core_fields = {
|
||||||
|
"decky": Log.decky,
|
||||||
|
"service": Log.service,
|
||||||
|
"event": Log.event_type,
|
||||||
|
"attacker": Log.attacker_ip,
|
||||||
|
"attacker-ip": Log.attacker_ip,
|
||||||
|
"attacker_ip": Log.attacker_ip,
|
||||||
|
}
|
||||||
|
|
||||||
|
for token in tokens:
|
||||||
|
if ":" in token:
|
||||||
|
key, val = token.split(":", 1)
|
||||||
|
if key in core_fields:
|
||||||
|
statement = statement.where(core_fields[key] == val)
|
||||||
|
else:
|
||||||
|
key_safe = re.sub(r"[^a-zA-Z0-9_]", "", key)
|
||||||
|
if key_safe:
|
||||||
|
statement = statement.where(
|
||||||
|
self._json_field_equals(key_safe)
|
||||||
|
).params(val=val)
|
||||||
|
else:
|
||||||
|
lk = f"%{token}%"
|
||||||
|
statement = statement.where(
|
||||||
|
or_(
|
||||||
|
Log.raw_line.like(lk),
|
||||||
|
Log.decky.like(lk),
|
||||||
|
Log.service.like(lk),
|
||||||
|
Log.attacker_ip.like(lk),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return statement
|
||||||
|
|
||||||
|
def _json_field_equals(self, key: str):
|
||||||
|
"""Return a text() predicate that matches rows where fields->key == :val.
|
||||||
|
|
||||||
|
Both SQLite and MySQL expose a ``JSON_EXTRACT`` function; MySQL also
|
||||||
|
exposes the same function under ``json_extract`` (case-insensitive).
|
||||||
|
The ``:val`` parameter is bound separately and must be supplied with
|
||||||
|
``.params(val=...)`` by the caller, which keeps us safe from injection.
|
||||||
|
"""
|
||||||
|
return text(f"JSON_EXTRACT(fields, '$.{key}') = :val")
|
||||||
|
|
||||||
|
async def get_logs(
|
||||||
|
self,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
start_time: Optional[str] = None,
|
||||||
|
end_time: Optional[str] = None,
|
||||||
|
) -> List[dict]:
|
||||||
|
statement = (
|
||||||
|
select(Log)
|
||||||
|
.order_by(desc(Log.timestamp))
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
results = await session.execute(statement)
|
||||||
|
return [log.model_dump(mode="json") for log in results.scalars().all()]
|
||||||
|
|
||||||
|
async def get_max_log_id(self) -> int:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(select(func.max(Log.id)))
|
||||||
|
val = result.scalar()
|
||||||
|
return val if val is not None else 0
|
||||||
|
|
||||||
|
async def get_logs_after_id(
|
||||||
|
self,
|
||||||
|
last_id: int,
|
||||||
|
limit: int = 50,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
start_time: Optional[str] = None,
|
||||||
|
end_time: Optional[str] = None,
|
||||||
|
) -> List[dict]:
|
||||||
|
statement = (
|
||||||
|
select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit)
|
||||||
|
)
|
||||||
|
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
results = await session.execute(statement)
|
||||||
|
return [log.model_dump(mode="json") for log in results.scalars().all()]
|
||||||
|
|
||||||
|
async def get_total_logs(
|
||||||
|
self,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
start_time: Optional[str] = None,
|
||||||
|
end_time: Optional[str] = None,
|
||||||
|
) -> int:
|
||||||
|
statement = select(func.count()).select_from(Log)
|
||||||
|
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(statement)
|
||||||
|
return result.scalar() or 0
|
||||||
|
|
||||||
|
async def get_log_histogram(
|
||||||
|
self,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
start_time: Optional[str] = None,
|
||||||
|
end_time: Optional[str] = None,
|
||||||
|
interval_minutes: int = 15,
|
||||||
|
) -> List[dict]:
|
||||||
|
"""Dialect-specific — override per backend."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_stats_summary(self) -> dict[str, Any]:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
total_logs = (
|
||||||
|
await session.execute(select(func.count()).select_from(Log))
|
||||||
|
).scalar() or 0
|
||||||
|
unique_attackers = (
|
||||||
|
await session.execute(
|
||||||
|
select(func.count(func.distinct(Log.attacker_ip)))
|
||||||
|
)
|
||||||
|
).scalar() or 0
|
||||||
|
|
||||||
|
_state = await asyncio.to_thread(load_state)
|
||||||
|
deployed_deckies = len(_state[0].deckies) if _state else 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_logs": total_logs,
|
||||||
|
"unique_attackers": unique_attackers,
|
||||||
|
"active_deckies": deployed_deckies,
|
||||||
|
"deployed_deckies": deployed_deckies,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_deckies(self) -> List[dict]:
|
||||||
|
_state = await asyncio.to_thread(load_state)
|
||||||
|
return [_d.model_dump() for _d in _state[0].deckies] if _state else []
|
||||||
|
|
||||||
|
# --------------------------------------------------------------- users
|
||||||
|
|
||||||
|
async def get_user_by_username(self, username: str) -> Optional[dict]:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(User).where(User.username == username)
|
||||||
|
)
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
return user.model_dump() if user else None
|
||||||
|
|
||||||
|
async def get_user_by_uuid(self, uuid: str) -> Optional[dict]:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(User).where(User.uuid == uuid)
|
||||||
|
)
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
return user.model_dump() if user else None
|
||||||
|
|
||||||
|
async def create_user(self, user_data: dict[str, Any]) -> None:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
session.add(User(**user_data))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def update_user_password(
|
||||||
|
self, uuid: str, password_hash: str, must_change_password: bool = False
|
||||||
|
) -> None:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
await session.execute(
|
||||||
|
update(User)
|
||||||
|
.where(User.uuid == uuid)
|
||||||
|
.values(
|
||||||
|
password_hash=password_hash,
|
||||||
|
must_change_password=must_change_password,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def list_users(self) -> list[dict]:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(select(User))
|
||||||
|
return [u.model_dump() for u in result.scalars().all()]
|
||||||
|
|
||||||
|
async def delete_user(self, uuid: str) -> bool:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(select(User).where(User.uuid == uuid))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
await session.delete(user)
|
||||||
|
await session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def update_user_role(self, uuid: str, role: str) -> None:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
await session.execute(
|
||||||
|
update(User).where(User.uuid == uuid).values(role=role)
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def purge_logs_and_bounties(self) -> dict[str, int]:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
logs_deleted = (await session.execute(text("DELETE FROM logs"))).rowcount
|
||||||
|
bounties_deleted = (await session.execute(text("DELETE FROM bounty"))).rowcount
|
||||||
|
# attacker_behavior has FK → attackers.uuid; delete children first.
|
||||||
|
await session.execute(text("DELETE FROM attacker_behavior"))
|
||||||
|
attackers_deleted = (await session.execute(text("DELETE FROM attackers"))).rowcount
|
||||||
|
await session.commit()
|
||||||
|
return {
|
||||||
|
"logs": logs_deleted,
|
||||||
|
"bounties": bounties_deleted,
|
||||||
|
"attackers": attackers_deleted,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------ bounties
|
||||||
|
|
||||||
|
async def add_bounty(self, bounty_data: dict[str, Any]) -> None:
|
||||||
|
data = bounty_data.copy()
|
||||||
|
if "payload" in data and isinstance(data["payload"], dict):
|
||||||
|
data["payload"] = json.dumps(data["payload"])
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
session.add(Bounty(**data))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
def _apply_bounty_filters(
|
||||||
|
self,
|
||||||
|
statement: SelectOfScalar,
|
||||||
|
bounty_type: Optional[str],
|
||||||
|
search: Optional[str],
|
||||||
|
) -> SelectOfScalar:
|
||||||
|
if bounty_type:
|
||||||
|
statement = statement.where(Bounty.bounty_type == bounty_type)
|
||||||
|
if search:
|
||||||
|
lk = f"%{search}%"
|
||||||
|
statement = statement.where(
|
||||||
|
or_(
|
||||||
|
Bounty.decky.like(lk),
|
||||||
|
Bounty.service.like(lk),
|
||||||
|
Bounty.attacker_ip.like(lk),
|
||||||
|
Bounty.payload.like(lk),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return statement
|
||||||
|
|
||||||
|
async def get_bounties(
|
||||||
|
self,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
bounty_type: Optional[str] = None,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
) -> List[dict]:
|
||||||
|
statement = (
|
||||||
|
select(Bounty)
|
||||||
|
.order_by(desc(Bounty.timestamp))
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
results = await session.execute(statement)
|
||||||
|
final = []
|
||||||
|
for item in results.scalars().all():
|
||||||
|
d = item.model_dump(mode="json")
|
||||||
|
try:
|
||||||
|
d["payload"] = json.loads(d["payload"])
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
final.append(d)
|
||||||
|
return final
|
||||||
|
|
||||||
|
async def get_total_bounties(
|
||||||
|
self, bounty_type: Optional[str] = None, search: Optional[str] = None
|
||||||
|
) -> int:
|
||||||
|
statement = select(func.count()).select_from(Bounty)
|
||||||
|
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(statement)
|
||||||
|
return result.scalar() or 0
|
||||||
|
|
||||||
|
async def get_state(self, key: str) -> Optional[dict[str, Any]]:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
statement = select(State).where(State.key == key)
|
||||||
|
result = await session.execute(statement)
|
||||||
|
state = result.scalar_one_or_none()
|
||||||
|
if state:
|
||||||
|
return json.loads(state.value)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
statement = select(State).where(State.key == key)
|
||||||
|
result = await session.execute(statement)
|
||||||
|
state = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
value_json = json.dumps(value)
|
||||||
|
if state:
|
||||||
|
state.value = value_json
|
||||||
|
session.add(state)
|
||||||
|
else:
|
||||||
|
session.add(State(key=key, value=value_json))
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# ----------------------------------------------------------- attackers
|
||||||
|
|
||||||
|
async def get_all_logs_raw(self) -> List[dict[str, Any]]:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(
|
||||||
|
Log.id,
|
||||||
|
Log.raw_line,
|
||||||
|
Log.attacker_ip,
|
||||||
|
Log.service,
|
||||||
|
Log.event_type,
|
||||||
|
Log.decky,
|
||||||
|
Log.timestamp,
|
||||||
|
Log.fields,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": r.id,
|
||||||
|
"raw_line": r.raw_line,
|
||||||
|
"attacker_ip": r.attacker_ip,
|
||||||
|
"service": r.service,
|
||||||
|
"event_type": r.event_type,
|
||||||
|
"decky": r.decky,
|
||||||
|
"timestamp": r.timestamp,
|
||||||
|
"fields": r.fields,
|
||||||
|
}
|
||||||
|
for r in result.all()
|
||||||
|
]
|
||||||
|
|
||||||
|
async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]:
|
||||||
|
from collections import defaultdict
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Bounty).order_by(asc(Bounty.timestamp))
|
||||||
|
)
|
||||||
|
grouped: dict[str, List[dict[str, Any]]] = defaultdict(list)
|
||||||
|
for item in result.scalars().all():
|
||||||
|
d = item.model_dump(mode="json")
|
||||||
|
try:
|
||||||
|
d["payload"] = json.loads(d["payload"])
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
grouped[item.attacker_ip].append(d)
|
||||||
|
return dict(grouped)
|
||||||
|
|
||||||
|
async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]:
|
||||||
|
from collections import defaultdict
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp))
|
||||||
|
)
|
||||||
|
grouped: dict[str, List[dict[str, Any]]] = defaultdict(list)
|
||||||
|
for item in result.scalars().all():
|
||||||
|
d = item.model_dump(mode="json")
|
||||||
|
try:
|
||||||
|
d["payload"] = json.loads(d["payload"])
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
grouped[item.attacker_ip].append(d)
|
||||||
|
return dict(grouped)
|
||||||
|
|
||||||
|
async def upsert_attacker(self, data: dict[str, Any]) -> str:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Attacker).where(Attacker.ip == data["ip"])
|
||||||
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
if existing:
|
||||||
|
for k, v in data.items():
|
||||||
|
setattr(existing, k, v)
|
||||||
|
session.add(existing)
|
||||||
|
row_uuid = existing.uuid
|
||||||
|
else:
|
||||||
|
row_uuid = str(uuid.uuid4())
|
||||||
|
data = {**data, "uuid": row_uuid}
|
||||||
|
session.add(Attacker(**data))
|
||||||
|
await session.commit()
|
||||||
|
return row_uuid
|
||||||
|
|
||||||
|
async def upsert_attacker_behavior(
|
||||||
|
self,
|
||||||
|
attacker_uuid: str,
|
||||||
|
data: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(AttackerBehavior).where(
|
||||||
|
AttackerBehavior.attacker_uuid == attacker_uuid
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
payload = {**data, "updated_at": datetime.now(timezone.utc)}
|
||||||
|
if existing:
|
||||||
|
for k, v in payload.items():
|
||||||
|
setattr(existing, k, v)
|
||||||
|
session.add(existing)
|
||||||
|
else:
|
||||||
|
session.add(AttackerBehavior(attacker_uuid=attacker_uuid, **payload))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def get_attacker_behavior(
|
||||||
|
self,
|
||||||
|
attacker_uuid: str,
|
||||||
|
) -> Optional[dict[str, Any]]:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(AttackerBehavior).where(
|
||||||
|
AttackerBehavior.attacker_uuid == attacker_uuid
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
return self._deserialize_behavior(row.model_dump(mode="json"))
|
||||||
|
|
||||||
|
async def get_behaviors_for_ips(
|
||||||
|
self,
|
||||||
|
ips: set[str],
|
||||||
|
) -> dict[str, dict[str, Any]]:
|
||||||
|
if not ips:
|
||||||
|
return {}
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Attacker.ip, AttackerBehavior)
|
||||||
|
.join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid)
|
||||||
|
.where(Attacker.ip.in_(ips))
|
||||||
|
)
|
||||||
|
out: dict[str, dict[str, Any]] = {}
|
||||||
|
for ip, row in result.all():
|
||||||
|
out[ip] = self._deserialize_behavior(row.model_dump(mode="json"))
|
||||||
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _deserialize_behavior(d: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
for key in ("tcp_fingerprint", "timing_stats", "phase_sequence"):
|
||||||
|
if isinstance(d.get(key), str):
|
||||||
|
try:
|
||||||
|
d[key] = json.loads(d[key])
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
return d
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _deserialize_attacker(d: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
for key in ("services", "deckies", "fingerprints", "commands"):
|
||||||
|
if isinstance(d.get(key), str):
|
||||||
|
try:
|
||||||
|
d[key] = json.loads(d[key])
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
return d
|
||||||
|
|
||||||
|
async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Attacker).where(Attacker.uuid == uuid)
|
||||||
|
)
|
||||||
|
attacker = result.scalar_one_or_none()
|
||||||
|
if not attacker:
|
||||||
|
return None
|
||||||
|
return self._deserialize_attacker(attacker.model_dump(mode="json"))
|
||||||
|
|
||||||
|
async def get_attackers(
|
||||||
|
self,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
sort_by: str = "recent",
|
||||||
|
service: Optional[str] = None,
|
||||||
|
) -> List[dict[str, Any]]:
|
||||||
|
order = {
|
||||||
|
"active": desc(Attacker.event_count),
|
||||||
|
"traversals": desc(Attacker.is_traversal),
|
||||||
|
}.get(sort_by, desc(Attacker.last_seen))
|
||||||
|
|
||||||
|
statement = select(Attacker).order_by(order).offset(offset).limit(limit)
|
||||||
|
if search:
|
||||||
|
statement = statement.where(Attacker.ip.like(f"%{search}%"))
|
||||||
|
if service:
|
||||||
|
statement = statement.where(Attacker.services.like(f'%"{service}"%'))
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(statement)
|
||||||
|
return [
|
||||||
|
self._deserialize_attacker(a.model_dump(mode="json"))
|
||||||
|
for a in result.scalars().all()
|
||||||
|
]
|
||||||
|
|
||||||
|
async def get_total_attackers(
|
||||||
|
self, search: Optional[str] = None, service: Optional[str] = None
|
||||||
|
) -> int:
|
||||||
|
statement = select(func.count()).select_from(Attacker)
|
||||||
|
if search:
|
||||||
|
statement = statement.where(Attacker.ip.like(f"%{search}%"))
|
||||||
|
if service:
|
||||||
|
statement = statement.where(Attacker.services.like(f'%"{service}"%'))
|
||||||
|
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(statement)
|
||||||
|
return result.scalar() or 0
|
||||||
|
|
||||||
|
async def get_attacker_commands(
|
||||||
|
self,
|
||||||
|
uuid: str,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
service: Optional[str] = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Attacker.commands).where(Attacker.uuid == uuid)
|
||||||
|
)
|
||||||
|
raw = result.scalar_one_or_none()
|
||||||
|
if raw is None:
|
||||||
|
return {"total": 0, "data": []}
|
||||||
|
|
||||||
|
commands: list = json.loads(raw) if isinstance(raw, str) else raw
|
||||||
|
if service:
|
||||||
|
commands = [c for c in commands if c.get("service") == service]
|
||||||
|
|
||||||
|
total = len(commands)
|
||||||
|
page = commands[offset: offset + limit]
|
||||||
|
return {"total": total, "data": page}
|
||||||
Reference in New Issue
Block a user