refactor: implement database backend factory for SQLite and MySQL

- Add `get_repository()` factory function to select DB implementation at runtime via DECNET_DB_TYPE env var
- Extract BaseRepository abstract interface from SQLiteRepository
- Update dependencies to use factory-based repository injection
- Add DECNET_DB_TYPE env var support (defaults to sqlite)
- Refactor models and repository base class for cross-dialect compatibility
This commit is contained in:
2026-04-15 12:50:41 -04:00
parent f6cb90ee66
commit 172a002d41
6 changed files with 177 additions and 23 deletions

View File

@@ -59,6 +59,12 @@ DECNET_DEVELOPER: bool = os.environ.get("DECNET_DEVELOPER", "False").lower() ==
# Database Options # Database Options
DECNET_DB_TYPE: str = os.environ.get("DECNET_DB_TYPE", "sqlite").lower() DECNET_DB_TYPE: str = os.environ.get("DECNET_DB_TYPE", "sqlite").lower()
DECNET_DB_URL: Optional[str] = os.environ.get("DECNET_DB_URL") DECNET_DB_URL: Optional[str] = os.environ.get("DECNET_DB_URL")
# MySQL component vars (used only when DECNET_DB_URL is not set).
# NOTE(review): presumably the MySQL backend assembles these into a connection
# URL — the consumer is not visible here; confirm.
DECNET_DB_HOST: str = os.environ.get("DECNET_DB_HOST", "localhost")
# _port() is only invoked when DECNET_DB_PORT is set to a non-empty value;
# an unset or empty variable falls straight through to 3306.
DECNET_DB_PORT: int = _port("DECNET_DB_PORT", 3306) if os.environ.get("DECNET_DB_PORT") else 3306
DECNET_DB_NAME: str = os.environ.get("DECNET_DB_NAME", "decnet")
DECNET_DB_USER: str = os.environ.get("DECNET_DB_USER", "decnet")
# No default: stays None when DECNET_DB_PASSWORD is unset.
DECNET_DB_PASSWORD: Optional[str] = os.environ.get("DECNET_DB_PASSWORD")
# CORS — comma-separated list of allowed origins for the web dashboard API. # CORS — comma-separated list of allowed origins for the web dashboard API.
# Defaults to the configured web host/port. Override with DECNET_CORS_ORIGINS if needed. # Defaults to the configured web host/port. Override with DECNET_CORS_ORIGINS if needed.

View File

@@ -14,7 +14,7 @@ from decnet.logging import get_logger
from decnet.web.dependencies import repo from decnet.web.dependencies import repo
from decnet.collector import log_collector_worker from decnet.collector import log_collector_worker
from decnet.web.ingester import log_ingestion_worker from decnet.web.ingester import log_ingestion_worker
from decnet.web.attacker_worker import attacker_profile_worker from decnet.profiler import attacker_profile_worker
from decnet.web.router import api_router from decnet.web.router import api_router
log = get_logger("api") log = get_logger("api")

View File

@@ -1,18 +1,29 @@
"""
Repository factory — selects a :class:`BaseRepository` implementation based on
``DECNET_DB_TYPE`` (``sqlite`` or ``mysql``).
"""
from __future__ import annotations
import os
from typing import Any from typing import Any
from decnet.env import os
from decnet.web.db.repository import BaseRepository from decnet.web.db.repository import BaseRepository
def get_repository(**kwargs: Any) -> BaseRepository:
    """Build the repository backend named by the ``DECNET_DB_TYPE`` env var.

    The variable is read at call time, lower-cased, and defaults to
    ``sqlite`` when unset. All keyword arguments are passed through untouched
    to the selected implementation:

    * SQLite accepts ``db_path``.
    * MySQL accepts ``url`` and engine tuning knobs (``pool_size``, …).

    Raises:
        ValueError: if ``DECNET_DB_TYPE`` is neither ``sqlite`` nor ``mysql``.
    """
    backend = os.environ.get("DECNET_DB_TYPE", "sqlite").lower()

    if backend not in ("sqlite", "mysql"):
        raise ValueError(f"Unsupported database type: {backend}")

    # Each backend module is imported inside its own branch, so only the
    # selected implementation's module is imported.
    if backend == "mysql":
        from decnet.web.db.mysql.repository import MySQLRepository

        return MySQLRepository(**kwargs)

    from decnet.web.db.sqlite.repository import SQLiteRepository

    return SQLiteRepository(**kwargs)

View File

@@ -1,6 +1,14 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Literal, Optional, Any, List, Annotated from typing import Literal, Optional, Any, List, Annotated
from sqlalchemy import Column, Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlmodel import SQLModel, Field from sqlmodel import SQLModel, Field
# Use on columns that accumulate over an attacker's lifetime (commands,
# fingerprints, state blobs). TEXT on MySQL caps at 64 KiB; MEDIUMTEXT
# stretches to 16 MiB. SQLite has no fixed-width text types so Text()
# stays unchanged there.
_BIG_TEXT = Text().with_variant(MEDIUMTEXT(), "mysql")
from pydantic import BaseModel, ConfigDict, Field as PydanticField, BeforeValidator from pydantic import BaseModel, ConfigDict, Field as PydanticField, BeforeValidator
from decnet.models import IniContent from decnet.models import IniContent
@@ -30,9 +38,11 @@ class Log(SQLModel, table=True):
service: str = Field(index=True) service: str = Field(index=True)
event_type: str = Field(index=True) event_type: str = Field(index=True)
attacker_ip: str = Field(index=True) attacker_ip: str = Field(index=True)
raw_line: str # Long-text columns — use TEXT so MySQL DDL doesn't truncate to VARCHAR(255).
fields: str # TEXT is equivalent to plain text in SQLite.
msg: Optional[str] = None raw_line: str = Field(sa_column=Column("raw_line", Text, nullable=False))
fields: str = Field(sa_column=Column("fields", Text, nullable=False))
msg: Optional[str] = Field(default=None, sa_column=Column("msg", Text, nullable=True))
class Bounty(SQLModel, table=True): class Bounty(SQLModel, table=True):
__tablename__ = "bounty" __tablename__ = "bounty"
@@ -42,13 +52,15 @@ class Bounty(SQLModel, table=True):
service: str = Field(index=True) service: str = Field(index=True)
attacker_ip: str = Field(index=True) attacker_ip: str = Field(index=True)
bounty_type: str = Field(index=True) bounty_type: str = Field(index=True)
payload: str payload: str = Field(sa_column=Column("payload", Text, nullable=False))
class State(SQLModel, table=True): class State(SQLModel, table=True):
__tablename__ = "state" __tablename__ = "state"
key: str = Field(primary_key=True) key: str = Field(primary_key=True)
value: str # Stores JSON serialized DecnetConfig or other state blobs # JSON-serialized DecnetConfig or other state blobs — can be large as
# deckies/services accumulate. MEDIUMTEXT on MySQL (16 MiB ceiling).
value: str = Field(sa_column=Column("value", _BIG_TEXT, nullable=False))
class Attacker(SQLModel, table=True): class Attacker(SQLModel, table=True):
@@ -60,14 +72,63 @@ class Attacker(SQLModel, table=True):
event_count: int = Field(default=0) event_count: int = Field(default=0)
service_count: int = Field(default=0) service_count: int = Field(default=0)
decky_count: int = Field(default=0) decky_count: int = Field(default=0)
services: str = Field(default="[]") # JSON list[str] # JSON blobs — these grow over the attacker's lifetime. Use MEDIUMTEXT on
deckies: str = Field(default="[]") # JSON list[str], first-contact ordered # MySQL (16 MiB) for the fields that accumulate (fingerprints, commands,
traversal_path: Optional[str] = None # "decky-01 → decky-03 → decky-05" # and the deckies/services lists that are unbounded in principle).
services: str = Field(
default="[]", sa_column=Column("services", _BIG_TEXT, nullable=False, default="[]")
) # JSON list[str]
deckies: str = Field(
default="[]", sa_column=Column("deckies", _BIG_TEXT, nullable=False, default="[]")
) # JSON list[str], first-contact ordered
traversal_path: Optional[str] = Field(
default=None, sa_column=Column("traversal_path", Text, nullable=True)
) # "decky-01 → decky-03 → decky-05"
is_traversal: bool = Field(default=False) is_traversal: bool = Field(default=False)
bounty_count: int = Field(default=0) bounty_count: int = Field(default=0)
credential_count: int = Field(default=0) credential_count: int = Field(default=0)
fingerprints: str = Field(default="[]") # JSON list[dict] — bounty fingerprints fingerprints: str = Field(
commands: str = Field(default="[]") # JSON list[dict] — commands per service/decky default="[]", sa_column=Column("fingerprints", _BIG_TEXT, nullable=False, default="[]")
) # JSON list[dict] — bounty fingerprints
commands: str = Field(
default="[]", sa_column=Column("commands", _BIG_TEXT, nullable=False, default="[]")
) # JSON list[dict] — commands per service/decky
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc), index=True
)
class AttackerBehavior(SQLModel, table=True):
    """Per-attacker timing and behavioral profile, joined to Attacker by UUID.

    Stored in its own table so the core Attacker row stays narrow and the
    behavior data can be updated on its own (e.g. as the sniffer observes
    more packets) without touching the event-count aggregates.
    """

    __tablename__ = "attacker_behavior"

    attacker_uuid: str = Field(primary_key=True, foreign_key="attackers.uuid")

    # --- OS / TCP stack fingerprint (rolled up from sniffer events) ---
    os_guess: Optional[str] = None
    hop_distance: Optional[int] = None
    # JSON: window, wscale, mss, options_sig
    tcp_fingerprint: str = Field(
        sa_column=Column("tcp_fingerprint", Text, nullable=False, default="{}"),
        default="{}",
    )
    retransmit_count: int = Field(default=0)

    # --- Behavioral (derived by the profiler from log-event timing) ---
    # beaconing | interactive | scanning | mixed | unknown
    behavior_class: Optional[str] = None
    beacon_interval_s: Optional[float] = None
    beacon_jitter_pct: Optional[float] = None
    # cobalt_strike | sliver | havoc | mythic
    tool_guess: Optional[str] = None
    # JSON: mean/median/stdev/min/max IAT
    timing_stats: str = Field(
        sa_column=Column("timing_stats", Text, nullable=False, default="{}"),
        default="{}",
    )
    # JSON: recon_end/exfil_start/latency
    phase_sequence: str = Field(
        sa_column=Column("phase_sequence", Text, nullable=False, default="{}"),
        default="{}",
    )

    # Timezone-aware UTC timestamp taken when the row object is created.
    updated_at: datetime = Field(
        index=True,
        default_factory=lambda: datetime.now(timezone.utc),
    )

View File

@@ -60,6 +60,26 @@ class BaseRepository(ABC):
"""Update a user's password and change the must_change_password flag.""" """Update a user's password and change the must_change_password flag."""
pass pass
@abstractmethod
async def list_users(self) -> list[dict[str, Any]]:
    """Retrieve all users.

    Returns:
        One dict per user. Callers must strip ``password_hash`` from each
        record before returning it to clients.
    """
    pass
@abstractmethod
async def delete_user(self, uuid: str) -> bool:
    """Delete a user by UUID.

    Args:
        uuid: UUID of the user to delete.

    Returns:
        True if the user was found and deleted.
        NOTE(review): presumably False when no such user exists — confirm
        against the concrete implementations.
    """
    pass
@abstractmethod
async def update_user_role(self, uuid: str, role: str) -> None:
    """Update a user's role.

    Args:
        uuid: UUID of the user to modify.
        role: New role to store for that user.
    """
    pass
@abstractmethod
async def purge_logs_and_bounties(self) -> dict[str, int]:
    """Delete all logs, bounties, and attacker profiles.

    Returns:
        Counts of deleted rows.
        NOTE(review): the dict keys are not specified here — presumably one
        entry per purged table; confirm against the concrete implementations.
    """
    pass
@abstractmethod @abstractmethod
async def add_bounty(self, bounty_data: dict[str, Any]) -> None: async def add_bounty(self, bounty_data: dict[str, Any]) -> None:
"""Add a new harvested artifact (bounty) to the database.""" """Add a new harvested artifact (bounty) to the database."""
@@ -117,8 +137,23 @@ class BaseRepository(ABC):
pass pass
@abstractmethod @abstractmethod
async def upsert_attacker(self, data: dict[str, Any]) -> str:
    """Insert or replace an attacker profile record.

    Args:
        data: Attacker profile fields to store.

    Returns:
        The UUID of the affected row.
    """
@abstractmethod
async def upsert_attacker_behavior(self, attacker_uuid: str, data: dict[str, Any]) -> None:
    """Insert or replace the behavioral/fingerprint row for an attacker.

    Args:
        attacker_uuid: UUID of the attacker the row belongs to.
        data: Behavior/fingerprint fields to store for that attacker.
    """
    pass
@abstractmethod
async def get_attacker_behavior(self, attacker_uuid: str) -> Optional[dict[str, Any]]:
    """Retrieve the behavioral/fingerprint row for an attacker UUID.

    Args:
        attacker_uuid: UUID of the attacker to look up.

    Returns:
        The behavior row as a dict.
        NOTE(review): presumably None when the attacker has no behavior row —
        confirm against the concrete implementations.
    """
    pass
@abstractmethod
async def get_behaviors_for_ips(self, ips: set[str]) -> dict[str, dict[str, Any]]:
    """Bulk-fetch behavior rows for a set of attacker IPs (JOIN to attackers).

    Args:
        ips: Attacker IP addresses to look up.

    Returns:
        Behavior rows keyed by attacker IP.
    """
@abstractmethod @abstractmethod

View File

@@ -1,7 +1,7 @@
from typing import Any, Optional from typing import Any, Optional
import jwt import jwt
from fastapi import HTTPException, status, Request from fastapi import Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from decnet.web.auth import ALGORITHM, SECRET_KEY from decnet.web.auth import ALGORITHM, SECRET_KEY
@@ -96,3 +96,44 @@ async def get_current_user_unchecked(request: Request) -> str:
Use only for endpoints that must remain reachable with the flag set (e.g. change-password). Use only for endpoints that must remain reachable with the flag set (e.g. change-password).
""" """
return await _decode_token(request) return await _decode_token(request)
# ---------------------------------------------------------------------------
# Role-based access control
# ---------------------------------------------------------------------------
def require_role(*allowed_roles: str):
    """Build a FastAPI dependency that admits only users in *allowed_roles*.

    The dependency chains from ``get_current_user`` (JWT + must_change_password),
    loads the user record from the repository, and raises HTTP 403 when the
    user cannot be found or the user's role is not one of *allowed_roles*.

    On success it returns the full user dict, so endpoints can read
    ``user["uuid"]``, ``user["role"]``, etc. without a second lookup.
    """

    async def _check(current_user: str = Depends(get_current_user)) -> dict:
        user_record = await repo.get_user_by_uuid(current_user)
        role_ok = bool(user_record) and user_record["role"] in allowed_roles
        if role_ok:
            return user_record
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Insufficient permissions",
        )

    return _check
def require_stream_role(*allowed_roles: str):
    """Build the SSE counterpart of ``require_role``.

    The returned dependency accepts a query-param token, resolves the user
    UUID through ``get_stream_user``, loads the user record from the
    repository, and raises HTTP 403 when the user cannot be found or the
    user's role is not one of *allowed_roles*. On success it returns the full
    user dict.
    """

    async def _check(request: Request, token: Optional[str] = None) -> dict:
        stream_user_uuid = await get_stream_user(request, token)
        user_record = await repo.get_user_by_uuid(stream_user_uuid)
        role_ok = bool(user_record) and user_record["role"] in allowed_roles
        if role_ok:
            return user_record
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Insufficient permissions",
        )

    return _check
# Ready-made dependencies for the common role checks.
# Admin-only endpoints.
require_admin = require_role("admin")
# Endpoints open to viewers; admins are accepted as well.
require_viewer = require_role("viewer", "admin")
# Same viewer/admin check for SSE endpoints that accept a query-param token.
require_stream_viewer = require_stream_role("viewer", "admin")