refactor: implement database backend factory for SQLite and MySQL

- Add `get_repository()` factory function to select DB implementation at runtime via DECNET_DB_TYPE env var
- Extract BaseRepository abstract interface from SQLiteRepository
- Update dependencies to use factory-based repository injection
- Add DECNET_DB_TYPE env var support (defaults to sqlite)
- Refactor models and repository base class for cross-dialect compatibility
This commit is contained in:
2026-04-15 12:50:41 -04:00
parent f6cb90ee66
commit 172a002d41
6 changed files with 177 additions and 23 deletions

View File

@@ -59,6 +59,12 @@ DECNET_DEVELOPER: bool = os.environ.get("DECNET_DEVELOPER", "False").lower() ==
# Database Options
DECNET_DB_TYPE: str = os.environ.get("DECNET_DB_TYPE", "sqlite").lower()
DECNET_DB_URL: Optional[str] = os.environ.get("DECNET_DB_URL")
# MySQL component vars (used only when DECNET_DB_URL is not set)
DECNET_DB_HOST: str = os.environ.get("DECNET_DB_HOST", "localhost")
DECNET_DB_PORT: int = _port("DECNET_DB_PORT", 3306) if os.environ.get("DECNET_DB_PORT") else 3306
DECNET_DB_NAME: str = os.environ.get("DECNET_DB_NAME", "decnet")
DECNET_DB_USER: str = os.environ.get("DECNET_DB_USER", "decnet")
DECNET_DB_PASSWORD: Optional[str] = os.environ.get("DECNET_DB_PASSWORD")
# CORS — comma-separated list of allowed origins for the web dashboard API.
# Defaults to the configured web host/port. Override with DECNET_CORS_ORIGINS if needed.

View File

@@ -14,7 +14,7 @@ from decnet.logging import get_logger
from decnet.web.dependencies import repo
from decnet.collector import log_collector_worker
from decnet.web.ingester import log_ingestion_worker
from decnet.web.attacker_worker import attacker_profile_worker
from decnet.profiler import attacker_profile_worker
from decnet.web.router import api_router
log = get_logger("api")

View File

@@ -1,18 +1,29 @@
"""
Repository factory — selects a :class:`BaseRepository` implementation based on
``DECNET_DB_TYPE`` (``sqlite`` or ``mysql``).
"""
from __future__ import annotations
import os
from typing import Any
from decnet.env import os
from decnet.web.db.repository import BaseRepository
def get_repository(**kwargs: Any) -> BaseRepository:
"""Factory function to instantiate the correct repository implementation based on environment."""
"""Instantiate the repository implementation selected by ``DECNET_DB_TYPE``.
Keyword arguments are forwarded to the concrete implementation:
* SQLite accepts ``db_path``.
* MySQL accepts ``url`` and engine tuning knobs (``pool_size``, …).
"""
db_type = os.environ.get("DECNET_DB_TYPE", "sqlite").lower()
if db_type == "sqlite":
from decnet.web.db.sqlite.repository import SQLiteRepository
return SQLiteRepository(**kwargs)
elif db_type == "mysql":
# Placeholder for future implementation
# from decnet.web.db.mysql.repository import MySQLRepository
# return MySQLRepository()
raise NotImplementedError("MySQL support is planned but not yet implemented.")
else:
if db_type == "mysql":
from decnet.web.db.mysql.repository import MySQLRepository
return MySQLRepository(**kwargs)
raise ValueError(f"Unsupported database type: {db_type}")

View File

@@ -1,6 +1,14 @@
from datetime import datetime, timezone
from typing import Literal, Optional, Any, List, Annotated
from sqlalchemy import Column, Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlmodel import SQLModel, Field
# Use on columns that accumulate over an attacker's lifetime (commands,
# fingerprints, state blobs). TEXT on MySQL caps at 64 KiB; MEDIUMTEXT
# stretches to 16 MiB. SQLite has no fixed-width text types so Text()
# stays unchanged there.
_BIG_TEXT = Text().with_variant(MEDIUMTEXT(), "mysql")
from pydantic import BaseModel, ConfigDict, Field as PydanticField, BeforeValidator
from decnet.models import IniContent
@@ -30,9 +38,11 @@ class Log(SQLModel, table=True):
service: str = Field(index=True)
event_type: str = Field(index=True)
attacker_ip: str = Field(index=True)
raw_line: str
fields: str
msg: Optional[str] = None
# Long-text columns — use TEXT so MySQL DDL doesn't truncate to VARCHAR(255).
# TEXT is equivalent to plain text in SQLite.
raw_line: str = Field(sa_column=Column("raw_line", Text, nullable=False))
fields: str = Field(sa_column=Column("fields", Text, nullable=False))
msg: Optional[str] = Field(default=None, sa_column=Column("msg", Text, nullable=True))
class Bounty(SQLModel, table=True):
__tablename__ = "bounty"
@@ -42,13 +52,15 @@ class Bounty(SQLModel, table=True):
service: str = Field(index=True)
attacker_ip: str = Field(index=True)
bounty_type: str = Field(index=True)
payload: str
payload: str = Field(sa_column=Column("payload", Text, nullable=False))
class State(SQLModel, table=True):
__tablename__ = "state"
key: str = Field(primary_key=True)
value: str # Stores JSON serialized DecnetConfig or other state blobs
# JSON-serialized DecnetConfig or other state blobs — can be large as
# deckies/services accumulate. MEDIUMTEXT on MySQL (16 MiB ceiling).
value: str = Field(sa_column=Column("value", _BIG_TEXT, nullable=False))
class Attacker(SQLModel, table=True):
@@ -60,14 +72,63 @@ class Attacker(SQLModel, table=True):
event_count: int = Field(default=0)
service_count: int = Field(default=0)
decky_count: int = Field(default=0)
services: str = Field(default="[]") # JSON list[str]
deckies: str = Field(default="[]") # JSON list[str], first-contact ordered
traversal_path: Optional[str] = None # "decky-01 → decky-03 → decky-05"
# JSON blobs — these grow over the attacker's lifetime. Use MEDIUMTEXT on
# MySQL (16 MiB) for the fields that accumulate (fingerprints, commands,
# and the deckies/services lists that are unbounded in principle).
services: str = Field(
default="[]", sa_column=Column("services", _BIG_TEXT, nullable=False, default="[]")
) # JSON list[str]
deckies: str = Field(
default="[]", sa_column=Column("deckies", _BIG_TEXT, nullable=False, default="[]")
) # JSON list[str], first-contact ordered
traversal_path: Optional[str] = Field(
default=None, sa_column=Column("traversal_path", Text, nullable=True)
) # "decky-01 → decky-03 → decky-05"
is_traversal: bool = Field(default=False)
bounty_count: int = Field(default=0)
credential_count: int = Field(default=0)
fingerprints: str = Field(default="[]") # JSON list[dict] — bounty fingerprints
commands: str = Field(default="[]") # JSON list[dict] — commands per service/decky
fingerprints: str = Field(
default="[]", sa_column=Column("fingerprints", _BIG_TEXT, nullable=False, default="[]")
) # JSON list[dict] — bounty fingerprints
commands: str = Field(
default="[]", sa_column=Column("commands", _BIG_TEXT, nullable=False, default="[]")
) # JSON list[dict] — commands per service/decky
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc), index=True
)
class AttackerBehavior(SQLModel, table=True):
"""
Timing & behavioral profile for an attacker, joined to Attacker by uuid.
Kept in a separate table so the core Attacker row stays narrow and
behavior data can be updated independently (e.g. as the sniffer observes
more packets) without touching the event-count aggregates.
"""
__tablename__ = "attacker_behavior"
attacker_uuid: str = Field(primary_key=True, foreign_key="attackers.uuid")
# OS / TCP stack fingerprint (rolled up from sniffer events)
os_guess: Optional[str] = None
hop_distance: Optional[int] = None
tcp_fingerprint: str = Field(
default="{}",
sa_column=Column("tcp_fingerprint", Text, nullable=False, default="{}"),
) # JSON: window, wscale, mss, options_sig
retransmit_count: int = Field(default=0)
# Behavioral (derived by the profiler from log-event timing)
behavior_class: Optional[str] = None # beaconing | interactive | scanning | mixed | unknown
beacon_interval_s: Optional[float] = None
beacon_jitter_pct: Optional[float] = None
tool_guess: Optional[str] = None # cobalt_strike | sliver | havoc | mythic
timing_stats: str = Field(
default="{}",
sa_column=Column("timing_stats", Text, nullable=False, default="{}"),
) # JSON: mean/median/stdev/min/max IAT
phase_sequence: str = Field(
default="{}",
sa_column=Column("phase_sequence", Text, nullable=False, default="{}"),
) # JSON: recon_end/exfil_start/latency
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc), index=True
)

View File

@@ -60,6 +60,26 @@ class BaseRepository(ABC):
"""Update a user's password and change the must_change_password flag."""
pass
@abstractmethod
async def list_users(self) -> list[dict[str, Any]]:
"""Retrieve all users (caller must strip password_hash before returning to clients)."""
pass
@abstractmethod
async def delete_user(self, uuid: str) -> bool:
"""Delete a user by UUID. Returns True if user was found and deleted."""
pass
@abstractmethod
async def update_user_role(self, uuid: str, role: str) -> None:
"""Update a user's role."""
pass
@abstractmethod
async def purge_logs_and_bounties(self) -> dict[str, int]:
"""Delete all logs, bounties, and attacker profiles. Returns counts of deleted rows."""
pass
@abstractmethod
async def add_bounty(self, bounty_data: dict[str, Any]) -> None:
"""Add a new harvested artifact (bounty) to the database."""
@@ -117,8 +137,23 @@ class BaseRepository(ABC):
pass
@abstractmethod
async def upsert_attacker(self, data: dict[str, Any]) -> None:
"""Insert or replace an attacker profile record."""
async def upsert_attacker(self, data: dict[str, Any]) -> str:
"""Insert or replace an attacker profile record. Returns the row's UUID."""
pass
@abstractmethod
async def upsert_attacker_behavior(self, attacker_uuid: str, data: dict[str, Any]) -> None:
"""Insert or replace the behavioral/fingerprint row for an attacker."""
pass
@abstractmethod
async def get_attacker_behavior(self, attacker_uuid: str) -> Optional[dict[str, Any]]:
"""Retrieve the behavioral/fingerprint row for an attacker UUID."""
pass
@abstractmethod
async def get_behaviors_for_ips(self, ips: set[str]) -> dict[str, dict[str, Any]]:
"""Bulk-fetch behavior rows keyed by attacker IP (JOIN to attackers)."""
pass
@abstractmethod

View File

@@ -1,7 +1,7 @@
from typing import Any, Optional
import jwt
from fastapi import HTTPException, status, Request
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer
from decnet.web.auth import ALGORITHM, SECRET_KEY
@@ -96,3 +96,44 @@ async def get_current_user_unchecked(request: Request) -> str:
Use only for endpoints that must remain reachable with the flag set (e.g. change-password).
"""
return await _decode_token(request)
# ---------------------------------------------------------------------------
# Role-based access control
# ---------------------------------------------------------------------------
def require_role(*allowed_roles: str):
"""Factory that returns a FastAPI dependency enforcing role membership.
The returned dependency chains from ``get_current_user`` (JWT + must_change_password)
then verifies the user's role is in *allowed_roles*. Returns the full user dict so
endpoints can inspect ``user["uuid"]``, ``user["role"]``, etc. without a second lookup.
"""
async def _check(current_user: str = Depends(get_current_user)) -> dict:
user = await repo.get_user_by_uuid(current_user)
if not user or user["role"] not in allowed_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions",
)
return user
return _check
def require_stream_role(*allowed_roles: str):
"""Like ``require_role`` but for SSE endpoints that accept a query-param token."""
async def _check(request: Request, token: Optional[str] = None) -> dict:
user_uuid = await get_stream_user(request, token)
user = await repo.get_user_by_uuid(user_uuid)
if not user or user["role"] not in allowed_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions",
)
return user
return _check
require_admin = require_role("admin")
require_viewer = require_role("viewer", "admin")
require_stream_viewer = require_stream_role("viewer", "admin")