From 172a002d412ab707ffe5d26b6c03d927137a46d1 Mon Sep 17 00:00:00 2001 From: anti Date: Wed, 15 Apr 2026 12:50:41 -0400 Subject: [PATCH] refactor: implement database backend factory for SQLite and MySQL - Add `get_repository()` factory function to select DB implementation at runtime via DECNET_DB_TYPE env var - Extract BaseRepository abstract interface from SQLiteRepository - Update dependencies to use factory-based repository injection - Add DECNET_DB_TYPE env var support (defaults to sqlite) - Refactor models and repository base class for cross-dialect compatibility --- decnet/env.py | 6 +++ decnet/web/api.py | 2 +- decnet/web/db/factory.py | 29 ++++++++----- decnet/web/db/models.py | 81 ++++++++++++++++++++++++++++++++----- decnet/web/db/repository.py | 39 +++++++++++++++++- decnet/web/dependencies.py | 43 +++++++++++++++++++- 6 files changed, 177 insertions(+), 23 deletions(-) diff --git a/decnet/env.py b/decnet/env.py index eb57d3d..8afa5c2 100644 --- a/decnet/env.py +++ b/decnet/env.py @@ -59,6 +59,12 @@ DECNET_DEVELOPER: bool = os.environ.get("DECNET_DEVELOPER", "False").lower() == # Database Options DECNET_DB_TYPE: str = os.environ.get("DECNET_DB_TYPE", "sqlite").lower() DECNET_DB_URL: Optional[str] = os.environ.get("DECNET_DB_URL") +# MySQL component vars (used only when DECNET_DB_URL is not set) +DECNET_DB_HOST: str = os.environ.get("DECNET_DB_HOST", "localhost") +DECNET_DB_PORT: int = _port("DECNET_DB_PORT", 3306) if os.environ.get("DECNET_DB_PORT") else 3306 +DECNET_DB_NAME: str = os.environ.get("DECNET_DB_NAME", "decnet") +DECNET_DB_USER: str = os.environ.get("DECNET_DB_USER", "decnet") +DECNET_DB_PASSWORD: Optional[str] = os.environ.get("DECNET_DB_PASSWORD") # CORS — comma-separated list of allowed origins for the web dashboard API. # Defaults to the configured web host/port. Override with DECNET_CORS_ORIGINS if needed. diff --git a/decnet/web/api.py b/decnet/web/api.py index bbac49b..8d044f5 100644 --- a/decnet/web/api.py +++ b/decnet/web/api.py @@ -14,7 +14,7 @@ from decnet.logging import get_logger from decnet.web.dependencies import repo from decnet.collector import log_collector_worker from decnet.web.ingester import log_ingestion_worker -from decnet.web.attacker_worker import attacker_profile_worker +from decnet.profiler import attacker_profile_worker from decnet.web.router import api_router log = get_logger("api") diff --git a/decnet/web/db/factory.py b/decnet/web/db/factory.py index b98884e..2030be1 100644 --- a/decnet/web/db/factory.py +++ b/decnet/web/db/factory.py @@ -1,18 +1,29 @@ +""" +Repository factory — selects a :class:`BaseRepository` implementation based on +``DECNET_DB_TYPE`` (``sqlite`` or ``mysql``). +""" +from __future__ import annotations + +import os from typing import Any -from decnet.env import os + from decnet.web.db.repository import BaseRepository + def get_repository(**kwargs: Any) -> BaseRepository: - """Factory function to instantiate the correct repository implementation based on environment.""" + """Instantiate the repository implementation selected by ``DECNET_DB_TYPE``. + + Keyword arguments are forwarded to the concrete implementation: + + * SQLite accepts ``db_path``. + * MySQL accepts ``url`` and engine tuning knobs (``pool_size``, …). + """ db_type = os.environ.get("DECNET_DB_TYPE", "sqlite").lower() if db_type == "sqlite": from decnet.web.db.sqlite.repository import SQLiteRepository return SQLiteRepository(**kwargs) - elif db_type == "mysql": - # Placeholder for future implementation - # from decnet.web.db.mysql.repository import MySQLRepository - # return MySQLRepository() - raise NotImplementedError("MySQL support is planned but not yet implemented.") - else: - raise ValueError(f"Unsupported database type: {db_type}") + if db_type == "mysql": + from decnet.web.db.mysql.repository import MySQLRepository + return MySQLRepository(**kwargs) + raise ValueError(f"Unsupported database type: {db_type}") diff --git a/decnet/web/db/models.py b/decnet/web/db/models.py index a8ac6d7..8104801 100644 --- a/decnet/web/db/models.py +++ b/decnet/web/db/models.py @@ -1,6 +1,14 @@ from datetime import datetime, timezone from typing import Literal, Optional, Any, List, Annotated +from sqlalchemy import Column, Text +from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlmodel import SQLModel, Field + +# Use on columns that accumulate over an attacker's lifetime (commands, +# fingerprints, state blobs). TEXT on MySQL caps at 64 KiB; MEDIUMTEXT +# stretches to 16 MiB. SQLite has no fixed-width text types so Text() +# stays unchanged there. +_BIG_TEXT = Text().with_variant(MEDIUMTEXT(), "mysql") from pydantic import BaseModel, ConfigDict, Field as PydanticField, BeforeValidator from decnet.models import IniContent @@ -30,9 +38,11 @@ class Log(SQLModel, table=True): service: str = Field(index=True) event_type: str = Field(index=True) attacker_ip: str = Field(index=True) - raw_line: str - fields: str - msg: Optional[str] = None + # Long-text columns — use TEXT so MySQL DDL doesn't truncate to VARCHAR(255). + # TEXT is equivalent to plain text in SQLite. + raw_line: str = Field(sa_column=Column("raw_line", Text, nullable=False)) + fields: str = Field(sa_column=Column("fields", Text, nullable=False)) + msg: Optional[str] = Field(default=None, sa_column=Column("msg", Text, nullable=True)) class Bounty(SQLModel, table=True): __tablename__ = "bounty" @@ -42,13 +52,15 @@ class Bounty(SQLModel, table=True): service: str = Field(index=True) attacker_ip: str = Field(index=True) bounty_type: str = Field(index=True) - payload: str + payload: str = Field(sa_column=Column("payload", Text, nullable=False)) class State(SQLModel, table=True): __tablename__ = "state" key: str = Field(primary_key=True) - value: str # Stores JSON serialized DecnetConfig or other state blobs + # JSON-serialized DecnetConfig or other state blobs — can be large as + # deckies/services accumulate. MEDIUMTEXT on MySQL (16 MiB ceiling). + value: str = Field(sa_column=Column("value", _BIG_TEXT, nullable=False)) class Attacker(SQLModel, table=True): @@ -60,14 +72,63 @@ class Attacker(SQLModel, table=True): event_count: int = Field(default=0) service_count: int = Field(default=0) decky_count: int = Field(default=0) - services: str = Field(default="[]") # JSON list[str] - deckies: str = Field(default="[]") # JSON list[str], first-contact ordered - traversal_path: Optional[str] = None # "decky-01 → decky-03 → decky-05" + # JSON blobs — these grow over the attacker's lifetime. Use MEDIUMTEXT on + # MySQL (16 MiB) for the fields that accumulate (fingerprints, commands, + # and the deckies/services lists that are unbounded in principle). + services: str = Field( + default="[]", sa_column=Column("services", _BIG_TEXT, nullable=False, default="[]") + ) # JSON list[str] + deckies: str = Field( + default="[]", sa_column=Column("deckies", _BIG_TEXT, nullable=False, default="[]") + ) # JSON list[str], first-contact ordered + traversal_path: Optional[str] = Field( + default=None, sa_column=Column("traversal_path", Text, nullable=True) + ) # "decky-01 → decky-03 → decky-05" is_traversal: bool = Field(default=False) bounty_count: int = Field(default=0) credential_count: int = Field(default=0) - fingerprints: str = Field(default="[]") # JSON list[dict] — bounty fingerprints - commands: str = Field(default="[]") # JSON list[dict] — commands per service/decky + fingerprints: str = Field( + default="[]", sa_column=Column("fingerprints", _BIG_TEXT, nullable=False, default="[]") + ) # JSON list[dict] — bounty fingerprints + commands: str = Field( + default="[]", sa_column=Column("commands", _BIG_TEXT, nullable=False, default="[]") + ) # JSON list[dict] — commands per service/decky + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), index=True + ) + + +class AttackerBehavior(SQLModel, table=True): + """ + Timing & behavioral profile for an attacker, joined to Attacker by uuid. + + Kept in a separate table so the core Attacker row stays narrow and + behavior data can be updated independently (e.g. as the sniffer observes + more packets) without touching the event-count aggregates. + """ + __tablename__ = "attacker_behavior" + attacker_uuid: str = Field(primary_key=True, foreign_key="attackers.uuid") + # OS / TCP stack fingerprint (rolled up from sniffer events) + os_guess: Optional[str] = None + hop_distance: Optional[int] = None + tcp_fingerprint: str = Field( + default="{}", + sa_column=Column("tcp_fingerprint", Text, nullable=False, default="{}"), + ) # JSON: window, wscale, mss, options_sig + retransmit_count: int = Field(default=0) + # Behavioral (derived by the profiler from log-event timing) + behavior_class: Optional[str] = None # beaconing | interactive | scanning | mixed | unknown + beacon_interval_s: Optional[float] = None + beacon_jitter_pct: Optional[float] = None + tool_guess: Optional[str] = None # cobalt_strike | sliver | havoc | mythic + timing_stats: str = Field( + default="{}", + sa_column=Column("timing_stats", Text, nullable=False, default="{}"), + ) # JSON: mean/median/stdev/min/max IAT + phase_sequence: str = Field( + default="{}", + sa_column=Column("phase_sequence", Text, nullable=False, default="{}"), + ) # JSON: recon_end/exfil_start/latency updated_at: datetime = Field( default_factory=lambda: datetime.now(timezone.utc), index=True ) diff --git a/decnet/web/db/repository.py b/decnet/web/db/repository.py index b5ac989..97ba167 100644 --- a/decnet/web/db/repository.py +++ b/decnet/web/db/repository.py @@ -60,6 +60,26 @@ class BaseRepository(ABC): """Update a user's password and change the must_change_password flag.""" pass + @abstractmethod + async def list_users(self) -> list[dict[str, Any]]: + """Retrieve all users (caller must strip password_hash before returning to clients).""" + pass + + @abstractmethod + async def delete_user(self, uuid: str) -> bool: + """Delete a user by UUID. Returns True if user was found and deleted.""" + pass + + @abstractmethod + async def update_user_role(self, uuid: str, role: str) -> None: + """Update a user's role.""" + pass + + @abstractmethod + async def purge_logs_and_bounties(self) -> dict[str, int]: + """Delete all logs, bounties, and attacker profiles. Returns counts of deleted rows.""" + pass + @abstractmethod async def add_bounty(self, bounty_data: dict[str, Any]) -> None: """Add a new harvested artifact (bounty) to the database.""" @@ -117,8 +137,23 @@ class BaseRepository(ABC): pass @abstractmethod - async def upsert_attacker(self, data: dict[str, Any]) -> None: - """Insert or replace an attacker profile record.""" + async def upsert_attacker(self, data: dict[str, Any]) -> str: + """Insert or replace an attacker profile record. Returns the row's UUID.""" + pass + + @abstractmethod + async def upsert_attacker_behavior(self, attacker_uuid: str, data: dict[str, Any]) -> None: + """Insert or replace the behavioral/fingerprint row for an attacker.""" + pass + + @abstractmethod + async def get_attacker_behavior(self, attacker_uuid: str) -> Optional[dict[str, Any]]: + """Retrieve the behavioral/fingerprint row for an attacker UUID.""" + pass + + @abstractmethod + async def get_behaviors_for_ips(self, ips: set[str]) -> dict[str, dict[str, Any]]: + """Bulk-fetch behavior rows keyed by attacker IP (JOIN to attackers).""" pass @abstractmethod diff --git a/decnet/web/dependencies.py b/decnet/web/dependencies.py index 99a6d39..2ecfa0d 100644 --- a/decnet/web/dependencies.py +++ b/decnet/web/dependencies.py @@ -1,7 +1,7 @@ from typing import Any, Optional import jwt -from fastapi import HTTPException, status, Request +from fastapi import Depends, HTTPException, status, Request from fastapi.security import OAuth2PasswordBearer from decnet.web.auth import ALGORITHM, SECRET_KEY @@ -96,3 +96,44 @@ async def get_current_user_unchecked(request: Request) -> str: Use only for endpoints that must remain reachable with the flag set (e.g. change-password). """ return await _decode_token(request) + + +# --------------------------------------------------------------------------- +# Role-based access control +# --------------------------------------------------------------------------- + +def require_role(*allowed_roles: str): + """Factory that returns a FastAPI dependency enforcing role membership. + + The returned dependency chains from ``get_current_user`` (JWT + must_change_password) + then verifies the user's role is in *allowed_roles*. Returns the full user dict so + endpoints can inspect ``user["uuid"]``, ``user["role"]``, etc. without a second lookup. + """ + async def _check(current_user: str = Depends(get_current_user)) -> dict: + user = await repo.get_user_by_uuid(current_user) + if not user or user["role"] not in allowed_roles: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Insufficient permissions", + ) + return user + return _check + + +def require_stream_role(*allowed_roles: str): + """Like ``require_role`` but for SSE endpoints that accept a query-param token.""" + async def _check(request: Request, token: Optional[str] = None) -> dict: + user_uuid = await get_stream_user(request, token) + user = await repo.get_user_by_uuid(user_uuid) + if not user or user["role"] not in allowed_roles: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Insufficient permissions", + ) + return user + return _check + + +require_admin = require_role("admin") +require_viewer = require_role("viewer", "admin") +require_stream_viewer = require_stream_role("viewer", "admin")