refactor: implement database backend factory for SQLite and MySQL
- Add `get_repository()` factory function to select DB implementation at runtime via DECNET_DB_TYPE env var - Extract BaseRepository abstract interface from SQLiteRepository - Update dependencies to use factory-based repository injection - Add DECNET_DB_TYPE env var support (defaults to sqlite) - Refactor models and repository base class for cross-dialect compatibility
This commit is contained in:
@@ -1,18 +1,29 @@
|
||||
"""
|
||||
Repository factory — selects a :class:`BaseRepository` implementation based on
|
||||
``DECNET_DB_TYPE`` (``sqlite`` or ``mysql``).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
from decnet.env import os
|
||||
|
||||
from decnet.web.db.repository import BaseRepository
|
||||
|
||||
|
||||
def get_repository(**kwargs: Any) -> BaseRepository:
|
||||
"""Factory function to instantiate the correct repository implementation based on environment."""
|
||||
"""Instantiate the repository implementation selected by ``DECNET_DB_TYPE``.
|
||||
|
||||
Keyword arguments are forwarded to the concrete implementation:
|
||||
|
||||
* SQLite accepts ``db_path``.
|
||||
* MySQL accepts ``url`` and engine tuning knobs (``pool_size``, …).
|
||||
"""
|
||||
db_type = os.environ.get("DECNET_DB_TYPE", "sqlite").lower()
|
||||
|
||||
if db_type == "sqlite":
|
||||
from decnet.web.db.sqlite.repository import SQLiteRepository
|
||||
return SQLiteRepository(**kwargs)
|
||||
elif db_type == "mysql":
|
||||
# Placeholder for future implementation
|
||||
# from decnet.web.db.mysql.repository import MySQLRepository
|
||||
# return MySQLRepository()
|
||||
raise NotImplementedError("MySQL support is planned but not yet implemented.")
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {db_type}")
|
||||
if db_type == "mysql":
|
||||
from decnet.web.db.mysql.repository import MySQLRepository
|
||||
return MySQLRepository(**kwargs)
|
||||
raise ValueError(f"Unsupported database type: {db_type}")
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Optional, Any, List, Annotated
|
||||
from sqlalchemy import Column, Text
|
||||
from sqlalchemy.dialects.mysql import MEDIUMTEXT
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
# Use on columns that accumulate over an attacker's lifetime (commands,
|
||||
# fingerprints, state blobs). TEXT on MySQL caps at 64 KiB; MEDIUMTEXT
|
||||
# stretches to 16 MiB. SQLite has no fixed-width text types so Text()
|
||||
# stays unchanged there.
|
||||
_BIG_TEXT = Text().with_variant(MEDIUMTEXT(), "mysql")
|
||||
from pydantic import BaseModel, ConfigDict, Field as PydanticField, BeforeValidator
|
||||
from decnet.models import IniContent
|
||||
|
||||
@@ -30,9 +38,11 @@ class Log(SQLModel, table=True):
|
||||
service: str = Field(index=True)
|
||||
event_type: str = Field(index=True)
|
||||
attacker_ip: str = Field(index=True)
|
||||
raw_line: str
|
||||
fields: str
|
||||
msg: Optional[str] = None
|
||||
# Long-text columns — use TEXT so MySQL DDL doesn't truncate to VARCHAR(255).
|
||||
# TEXT is equivalent to plain text in SQLite.
|
||||
raw_line: str = Field(sa_column=Column("raw_line", Text, nullable=False))
|
||||
fields: str = Field(sa_column=Column("fields", Text, nullable=False))
|
||||
msg: Optional[str] = Field(default=None, sa_column=Column("msg", Text, nullable=True))
|
||||
|
||||
class Bounty(SQLModel, table=True):
|
||||
__tablename__ = "bounty"
|
||||
@@ -42,13 +52,15 @@ class Bounty(SQLModel, table=True):
|
||||
service: str = Field(index=True)
|
||||
attacker_ip: str = Field(index=True)
|
||||
bounty_type: str = Field(index=True)
|
||||
payload: str
|
||||
payload: str = Field(sa_column=Column("payload", Text, nullable=False))
|
||||
|
||||
|
||||
class State(SQLModel, table=True):
|
||||
__tablename__ = "state"
|
||||
key: str = Field(primary_key=True)
|
||||
value: str # Stores JSON serialized DecnetConfig or other state blobs
|
||||
# JSON-serialized DecnetConfig or other state blobs — can be large as
|
||||
# deckies/services accumulate. MEDIUMTEXT on MySQL (16 MiB ceiling).
|
||||
value: str = Field(sa_column=Column("value", _BIG_TEXT, nullable=False))
|
||||
|
||||
|
||||
class Attacker(SQLModel, table=True):
|
||||
@@ -60,14 +72,63 @@ class Attacker(SQLModel, table=True):
|
||||
event_count: int = Field(default=0)
|
||||
service_count: int = Field(default=0)
|
||||
decky_count: int = Field(default=0)
|
||||
services: str = Field(default="[]") # JSON list[str]
|
||||
deckies: str = Field(default="[]") # JSON list[str], first-contact ordered
|
||||
traversal_path: Optional[str] = None # "decky-01 → decky-03 → decky-05"
|
||||
# JSON blobs — these grow over the attacker's lifetime. Use MEDIUMTEXT on
|
||||
# MySQL (16 MiB) for the fields that accumulate (fingerprints, commands,
|
||||
# and the deckies/services lists that are unbounded in principle).
|
||||
services: str = Field(
|
||||
default="[]", sa_column=Column("services", _BIG_TEXT, nullable=False, default="[]")
|
||||
) # JSON list[str]
|
||||
deckies: str = Field(
|
||||
default="[]", sa_column=Column("deckies", _BIG_TEXT, nullable=False, default="[]")
|
||||
) # JSON list[str], first-contact ordered
|
||||
traversal_path: Optional[str] = Field(
|
||||
default=None, sa_column=Column("traversal_path", Text, nullable=True)
|
||||
) # "decky-01 → decky-03 → decky-05"
|
||||
is_traversal: bool = Field(default=False)
|
||||
bounty_count: int = Field(default=0)
|
||||
credential_count: int = Field(default=0)
|
||||
fingerprints: str = Field(default="[]") # JSON list[dict] — bounty fingerprints
|
||||
commands: str = Field(default="[]") # JSON list[dict] — commands per service/decky
|
||||
fingerprints: str = Field(
|
||||
default="[]", sa_column=Column("fingerprints", _BIG_TEXT, nullable=False, default="[]")
|
||||
) # JSON list[dict] — bounty fingerprints
|
||||
commands: str = Field(
|
||||
default="[]", sa_column=Column("commands", _BIG_TEXT, nullable=False, default="[]")
|
||||
) # JSON list[dict] — commands per service/decky
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc), index=True
|
||||
)
|
||||
|
||||
|
||||
class AttackerBehavior(SQLModel, table=True):
|
||||
"""
|
||||
Timing & behavioral profile for an attacker, joined to Attacker by uuid.
|
||||
|
||||
Kept in a separate table so the core Attacker row stays narrow and
|
||||
behavior data can be updated independently (e.g. as the sniffer observes
|
||||
more packets) without touching the event-count aggregates.
|
||||
"""
|
||||
__tablename__ = "attacker_behavior"
|
||||
attacker_uuid: str = Field(primary_key=True, foreign_key="attackers.uuid")
|
||||
# OS / TCP stack fingerprint (rolled up from sniffer events)
|
||||
os_guess: Optional[str] = None
|
||||
hop_distance: Optional[int] = None
|
||||
tcp_fingerprint: str = Field(
|
||||
default="{}",
|
||||
sa_column=Column("tcp_fingerprint", Text, nullable=False, default="{}"),
|
||||
) # JSON: window, wscale, mss, options_sig
|
||||
retransmit_count: int = Field(default=0)
|
||||
# Behavioral (derived by the profiler from log-event timing)
|
||||
behavior_class: Optional[str] = None # beaconing | interactive | scanning | mixed | unknown
|
||||
beacon_interval_s: Optional[float] = None
|
||||
beacon_jitter_pct: Optional[float] = None
|
||||
tool_guess: Optional[str] = None # cobalt_strike | sliver | havoc | mythic
|
||||
timing_stats: str = Field(
|
||||
default="{}",
|
||||
sa_column=Column("timing_stats", Text, nullable=False, default="{}"),
|
||||
) # JSON: mean/median/stdev/min/max IAT
|
||||
phase_sequence: str = Field(
|
||||
default="{}",
|
||||
sa_column=Column("phase_sequence", Text, nullable=False, default="{}"),
|
||||
) # JSON: recon_end/exfil_start/latency
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc), index=True
|
||||
)
|
||||
|
||||
@@ -60,6 +60,26 @@ class BaseRepository(ABC):
|
||||
"""Update a user's password and change the must_change_password flag."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def list_users(self) -> list[dict[str, Any]]:
|
||||
"""Retrieve all users (caller must strip password_hash before returning to clients)."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_user(self, uuid: str) -> bool:
|
||||
"""Delete a user by UUID. Returns True if user was found and deleted."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def update_user_role(self, uuid: str, role: str) -> None:
|
||||
"""Update a user's role."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def purge_logs_and_bounties(self) -> dict[str, int]:
|
||||
"""Delete all logs, bounties, and attacker profiles. Returns counts of deleted rows."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def add_bounty(self, bounty_data: dict[str, Any]) -> None:
|
||||
"""Add a new harvested artifact (bounty) to the database."""
|
||||
@@ -117,8 +137,23 @@ class BaseRepository(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def upsert_attacker(self, data: dict[str, Any]) -> None:
|
||||
"""Insert or replace an attacker profile record."""
|
||||
async def upsert_attacker(self, data: dict[str, Any]) -> str:
|
||||
"""Insert or replace an attacker profile record. Returns the row's UUID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def upsert_attacker_behavior(self, attacker_uuid: str, data: dict[str, Any]) -> None:
|
||||
"""Insert or replace the behavioral/fingerprint row for an attacker."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_attacker_behavior(self, attacker_uuid: str) -> Optional[dict[str, Any]]:
|
||||
"""Retrieve the behavioral/fingerprint row for an attacker UUID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_behaviors_for_ips(self, ips: set[str]) -> dict[str, dict[str, Any]]:
|
||||
"""Bulk-fetch behavior rows keyed by attacker IP (JOIN to attackers)."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
Reference in New Issue
Block a user