merge: testing → main (reconcile 2-week divergence)
This commit is contained in:
0
tests/db/__init__.py
Normal file
0
tests/db/__init__.py
Normal file
0
tests/db/mysql/__init__.py
Normal file
0
tests/db/mysql/__init__.py
Normal file
70
tests/db/mysql/test_mysql_histogram_sql.py
Normal file
70
tests/db/mysql/test_mysql_histogram_sql.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
Inspection-level tests for the MySQL-dialect SQL emitted by MySQLRepository.
|
||||
|
||||
We compile the SQLAlchemy statements against the MySQL dialect and assert on
|
||||
the string form — no live MySQL server is required.
|
||||
"""
|
||||
import pytest
|
||||
from sqlalchemy import func, select, literal_column
|
||||
from sqlalchemy.dialects import mysql
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
from decnet.web.db.models import Log
|
||||
|
||||
|
||||
def _compile(stmt) -> str:
|
||||
"""Compile a statement to MySQL-dialect SQL with literal values inlined."""
|
||||
return str(stmt.compile(
|
||||
dialect=mysql.dialect(),
|
||||
compile_kwargs={"literal_binds": True},
|
||||
))
|
||||
|
||||
|
||||
def test_mysql_histogram_uses_from_unixtime_bucket():
|
||||
"""The MySQL dialect must bucket with UNIX_TIMESTAMP DIV N * N wrapped in FROM_UNIXTIME."""
|
||||
bucket_seconds = 900 # 15 min
|
||||
bucket_expr = literal_column(
|
||||
f"FROM_UNIXTIME((UNIX_TIMESTAMP(timestamp) DIV {bucket_seconds}) * {bucket_seconds})"
|
||||
).label("bucket_time")
|
||||
stmt: SelectOfScalar = select(bucket_expr, func.count().label("count")).select_from(Log)
|
||||
|
||||
sql = _compile(stmt)
|
||||
assert "FROM_UNIXTIME" in sql
|
||||
assert "UNIX_TIMESTAMP" in sql
|
||||
assert "DIV 900" in sql
|
||||
# Sanity: SQLite-only strftime must NOT appear in the MySQL-dialect output.
|
||||
assert "strftime" not in sql
|
||||
assert "unixepoch" not in sql
|
||||
|
||||
|
||||
def test_mysql_json_unquote_predicate_shape():
|
||||
"""MySQL JSON filter uses JSON_UNQUOTE(JSON_EXTRACT(...))."""
|
||||
from decnet.web.db.mysql.repository import MySQLRepository
|
||||
|
||||
# Build a dummy instance without touching the engine. We only need _json_field_equals,
|
||||
# which is a pure function of the key.
|
||||
repo = MySQLRepository.__new__(MySQLRepository) # bypass __init__ / no DB connection
|
||||
predicate = repo._json_field_equals("username")
|
||||
|
||||
# text() objects carry their literal SQL in .text
|
||||
assert "JSON_UNQUOTE" in predicate.text
|
||||
assert "JSON_EXTRACT(fields, '$.username')" in predicate.text
|
||||
assert ":val" in predicate.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("key", ["user", "port", "sess_id"])
|
||||
def test_mysql_json_predicate_safe_for_reasonable_keys(key):
|
||||
"""Keys matching [A-Za-z0-9_]+ are inserted verbatim; verify no SQL breakage."""
|
||||
from decnet.web.db.mysql.repository import MySQLRepository
|
||||
repo = MySQLRepository.__new__(MySQLRepository)
|
||||
pred = repo._json_field_equals(key)
|
||||
assert f"'$.{key}'" in pred.text
|
||||
|
||||
|
||||
def test_sqlite_histogram_still_uses_strftime():
|
||||
"""Regression guard — SQLite implementation must keep its strftime-based bucket."""
|
||||
from decnet.web.db.sqlite.repository import SQLiteRepository
|
||||
import inspect
|
||||
src = inspect.getsource(SQLiteRepository.get_log_histogram)
|
||||
assert "strftime" in src
|
||||
assert "unixepoch" in src
|
||||
234
tests/db/mysql/test_mysql_migration.py
Normal file
234
tests/db/mysql/test_mysql_migration.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
Tests for MySQLRepository._migrate_column_types().
|
||||
|
||||
No live MySQL server required — uses an in-memory SQLite engine that exposes
|
||||
the same information_schema-style query surface via a mocked connection, plus
|
||||
an integration-style test using a real async engine over aiosqlite (which
|
||||
ignores the TEXT/MEDIUMTEXT distinction but verifies the ALTER path is called
|
||||
and idempotent).
|
||||
|
||||
The ALTER TABLE branch is tested via unittest.mock: we intercept the
|
||||
information_schema query result and assert the correct MODIFY COLUMN
|
||||
statements are issued.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, call
|
||||
|
||||
from decnet.web.db.mysql.repository import MySQLRepository
|
||||
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_repo() -> MySQLRepository:
|
||||
"""Construct a MySQLRepository without touching any real DB."""
|
||||
return MySQLRepository.__new__(MySQLRepository)
|
||||
|
||||
|
||||
# ── _migrate_column_types ─────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_column_types_issues_alter_for_text_columns():
|
||||
"""When information_schema reports TEXT columns, ALTER TABLE is called for each."""
|
||||
repo = _make_repo()
|
||||
|
||||
# Rows returned by the information_schema query: two TEXT columns found
|
||||
fake_rows = [
|
||||
("attackers", "commands"),
|
||||
("attackers", "fingerprints"),
|
||||
("state", "value"),
|
||||
]
|
||||
|
||||
exec_results: list[str] = []
|
||||
|
||||
async def fake_execute(stmt):
|
||||
sql = str(stmt)
|
||||
if "information_schema" in sql:
|
||||
result = MagicMock()
|
||||
result.fetchall.return_value = fake_rows
|
||||
return result
|
||||
# Capture ALTER TABLE calls
|
||||
exec_results.append(sql)
|
||||
return MagicMock()
|
||||
|
||||
fake_conn = AsyncMock()
|
||||
fake_conn.execute.side_effect = fake_execute
|
||||
|
||||
fake_ctx = AsyncMock()
|
||||
fake_ctx.__aenter__ = AsyncMock(return_value=fake_conn)
|
||||
fake_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
repo.engine = MagicMock()
|
||||
repo.engine.begin.return_value = fake_ctx
|
||||
|
||||
await repo._migrate_column_types()
|
||||
|
||||
# Three ALTER TABLE statements expected, one per TEXT column returned
|
||||
assert len(exec_results) == 3
|
||||
assert any("`commands` MEDIUMTEXT" in s for s in exec_results)
|
||||
assert any("`fingerprints` MEDIUMTEXT" in s for s in exec_results)
|
||||
assert any("`value` MEDIUMTEXT" in s for s in exec_results)
|
||||
# Verify NOT NULL is preserved
|
||||
assert all("NOT NULL" in s for s in exec_results)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_column_types_no_alter_when_already_mediumtext():
|
||||
"""When information_schema returns no TEXT rows, no ALTER is issued."""
|
||||
repo = _make_repo()
|
||||
|
||||
exec_results: list[str] = []
|
||||
|
||||
async def fake_execute(stmt):
|
||||
sql = str(stmt)
|
||||
if "information_schema" in sql:
|
||||
result = MagicMock()
|
||||
result.fetchall.return_value = [] # nothing to migrate
|
||||
return result
|
||||
exec_results.append(sql)
|
||||
return MagicMock()
|
||||
|
||||
fake_conn = AsyncMock()
|
||||
fake_conn.execute.side_effect = fake_execute
|
||||
|
||||
fake_ctx = AsyncMock()
|
||||
fake_ctx.__aenter__ = AsyncMock(return_value=fake_conn)
|
||||
fake_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
repo.engine = MagicMock()
|
||||
repo.engine.begin.return_value = fake_ctx
|
||||
|
||||
await repo._migrate_column_types()
|
||||
|
||||
assert exec_results == [], "No ALTER TABLE should be issued when columns are already MEDIUMTEXT"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_column_types_idempotent_on_repeated_calls():
|
||||
"""Calling _migrate_column_types twice is safe: second call is a no-op."""
|
||||
repo = _make_repo()
|
||||
call_count = 0
|
||||
|
||||
async def fake_execute(stmt):
|
||||
nonlocal call_count
|
||||
sql = str(stmt)
|
||||
if "information_schema" in sql:
|
||||
result = MagicMock()
|
||||
# First call: two TEXT columns; second call: zero (already migrated)
|
||||
call_count += 1
|
||||
result.fetchall.return_value = (
|
||||
[("attackers", "commands")] if call_count == 1 else []
|
||||
)
|
||||
return result
|
||||
return MagicMock()
|
||||
|
||||
def _make_ctx():
|
||||
fake_conn = AsyncMock()
|
||||
fake_conn.execute.side_effect = fake_execute
|
||||
ctx = AsyncMock()
|
||||
ctx.__aenter__ = AsyncMock(return_value=fake_conn)
|
||||
ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
return ctx
|
||||
|
||||
repo.engine = MagicMock()
|
||||
repo.engine.begin.side_effect = _make_ctx
|
||||
|
||||
await repo._migrate_column_types()
|
||||
await repo._migrate_column_types() # second call must not raise
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_column_types_default_clause_per_column():
|
||||
"""Each attacker column gets DEFAULT '[]'; state.value gets no DEFAULT."""
|
||||
repo = _make_repo()
|
||||
|
||||
all_text_rows = [
|
||||
("attackers", "commands"),
|
||||
("attackers", "fingerprints"),
|
||||
("attackers", "services"),
|
||||
("attackers", "deckies"),
|
||||
("state", "value"),
|
||||
]
|
||||
alter_stmts: list[str] = []
|
||||
|
||||
async def fake_execute(stmt):
|
||||
sql = str(stmt)
|
||||
if "information_schema" in sql:
|
||||
result = MagicMock()
|
||||
result.fetchall.return_value = all_text_rows
|
||||
return result
|
||||
alter_stmts.append(sql)
|
||||
return MagicMock()
|
||||
|
||||
fake_conn = AsyncMock()
|
||||
fake_conn.execute.side_effect = fake_execute
|
||||
|
||||
fake_ctx = AsyncMock()
|
||||
fake_ctx.__aenter__ = AsyncMock(return_value=fake_conn)
|
||||
fake_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
repo.engine = MagicMock()
|
||||
repo.engine.begin.return_value = fake_ctx
|
||||
|
||||
await repo._migrate_column_types()
|
||||
|
||||
attacker_alters = [s for s in alter_stmts if "`attackers`" in s]
|
||||
state_alters = [s for s in alter_stmts if "`state`" in s]
|
||||
|
||||
assert len(attacker_alters) == 4
|
||||
assert len(state_alters) == 1
|
||||
|
||||
for stmt in attacker_alters:
|
||||
assert "DEFAULT '[]'" in stmt, f"Missing DEFAULT '[]' in: {stmt}"
|
||||
|
||||
# state.value has no DEFAULT in the schema
|
||||
assert "DEFAULT" not in state_alters[0], \
|
||||
f"Unexpected DEFAULT in state.value alter: {state_alters[0]}"
|
||||
|
||||
|
||||
# ── initialize override ───────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mysql_initialize_calls_migrate_column_types():
|
||||
"""MySQLRepository.initialize() must invoke every migration helper in
|
||||
the right order: attackers first, then session_profile (DEBT-036),
|
||||
then column types, then seed the admin user."""
|
||||
repo = _make_repo()
|
||||
|
||||
call_order: list[str] = []
|
||||
|
||||
async def fake_migrate_attackers():
|
||||
call_order.append("migrate_attackers")
|
||||
|
||||
async def fake_migrate_session_profile():
|
||||
call_order.append("migrate_session_profile")
|
||||
|
||||
async def fake_migrate_column_types():
|
||||
call_order.append("migrate_column_types")
|
||||
|
||||
async def fake_ensure_admin():
|
||||
call_order.append("ensure_admin")
|
||||
|
||||
repo._migrate_attackers_table = fake_migrate_attackers
|
||||
repo._migrate_session_profile_table = fake_migrate_session_profile
|
||||
repo._migrate_column_types = fake_migrate_column_types
|
||||
repo._ensure_admin_user = fake_ensure_admin
|
||||
|
||||
# Stub engine.begin() so create_all is a no-op
|
||||
fake_conn = AsyncMock()
|
||||
fake_conn.run_sync = AsyncMock()
|
||||
fake_ctx = AsyncMock()
|
||||
fake_ctx.__aenter__ = AsyncMock(return_value=fake_conn)
|
||||
fake_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
repo.engine = MagicMock()
|
||||
repo.engine.begin.return_value = fake_ctx
|
||||
|
||||
await repo.initialize()
|
||||
|
||||
assert call_order == [
|
||||
"migrate_attackers",
|
||||
"migrate_session_profile",
|
||||
"migrate_column_types",
|
||||
"ensure_admin",
|
||||
], f"Unexpected call order: {call_order}"
|
||||
78
tests/db/mysql/test_mysql_url_builder.py
Normal file
78
tests/db/mysql/test_mysql_url_builder.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Unit tests for decnet.web.db.mysql.database.build_mysql_url / resolve_url.
|
||||
|
||||
No MySQL server is required — these are pure URL-construction tests.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from decnet.web.db.mysql.database import build_mysql_url, resolve_url
|
||||
|
||||
|
||||
def test_build_url_defaults(monkeypatch):
|
||||
for v in ("DECNET_DB_HOST", "DECNET_DB_PORT", "DECNET_DB_NAME",
|
||||
"DECNET_DB_USER", "DECNET_DB_PASSWORD", "DECNET_DB_URL"):
|
||||
monkeypatch.delenv(v, raising=False)
|
||||
# PYTEST_* is set by pytest itself, so empty password is allowed here.
|
||||
url = build_mysql_url()
|
||||
assert url == "mysql+asyncmy://decnet:@localhost:3306/decnet"
|
||||
|
||||
|
||||
def test_build_url_from_env(monkeypatch):
|
||||
monkeypatch.setenv("DECNET_DB_HOST", "db.internal")
|
||||
monkeypatch.setenv("DECNET_DB_PORT", "3307")
|
||||
monkeypatch.setenv("DECNET_DB_NAME", "decnet_prod")
|
||||
monkeypatch.setenv("DECNET_DB_USER", "svc_decnet")
|
||||
monkeypatch.setenv("DECNET_DB_PASSWORD", "hunter2")
|
||||
url = build_mysql_url()
|
||||
assert url == "mysql+asyncmy://svc_decnet:hunter2@db.internal:3307/decnet_prod"
|
||||
|
||||
|
||||
def test_build_url_percent_encodes_password(monkeypatch):
|
||||
"""Passwords with @ : / # etc must not break URL parsing."""
|
||||
monkeypatch.setenv("DECNET_DB_PASSWORD", "p@ss:word/!#")
|
||||
url = build_mysql_url(user="u", host="h", port=3306, database="d")
|
||||
# @ → %40, : → %3A, / → %2F, # → %23, ! → %21
|
||||
assert "p%40ss%3Aword%2F%21%23" in url
|
||||
assert url.startswith("mysql+asyncmy://u:")
|
||||
assert url.endswith("@h:3306/d")
|
||||
|
||||
|
||||
def test_build_url_component_args_override_env(monkeypatch):
|
||||
monkeypatch.setenv("DECNET_DB_HOST", "ignored")
|
||||
monkeypatch.setenv("DECNET_DB_PASSWORD", "env-pw")
|
||||
url = build_mysql_url(host="arg.host", user="arg-user", password="arg-pw",
|
||||
port=9999, database="arg-db")
|
||||
assert url == "mysql+asyncmy://arg-user:arg-pw@arg.host:9999/arg-db"
|
||||
|
||||
|
||||
def test_resolve_url_prefers_explicit_arg(monkeypatch):
|
||||
monkeypatch.setenv("DECNET_DB_URL", "mysql+asyncmy://env-url/x")
|
||||
assert resolve_url("mysql+asyncmy://explicit/y") == "mysql+asyncmy://explicit/y"
|
||||
|
||||
|
||||
def test_resolve_url_uses_env_url_before_components(monkeypatch):
|
||||
monkeypatch.setenv("DECNET_DB_URL", "mysql+asyncmy://env-user:env-pw@env-host/env-db")
|
||||
monkeypatch.setenv("DECNET_DB_HOST", "ignored.host")
|
||||
assert resolve_url() == "mysql+asyncmy://env-user:env-pw@env-host/env-db"
|
||||
|
||||
|
||||
def test_resolve_url_falls_back_to_components(monkeypatch):
|
||||
monkeypatch.delenv("DECNET_DB_URL", raising=False)
|
||||
monkeypatch.setenv("DECNET_DB_HOST", "fallback.host")
|
||||
monkeypatch.setenv("DECNET_DB_PASSWORD", "pw")
|
||||
url = resolve_url()
|
||||
assert "fallback.host" in url
|
||||
assert url.startswith("mysql+asyncmy://")
|
||||
|
||||
|
||||
def test_build_url_requires_password_outside_pytest(monkeypatch):
|
||||
"""Without a password and not in a pytest run, construction must fail loudly."""
|
||||
for v in ("DECNET_DB_URL", "DECNET_DB_PASSWORD"):
|
||||
monkeypatch.delenv(v, raising=False)
|
||||
# Strip every PYTEST_* env var so the safety check trips.
|
||||
import os
|
||||
for k in list(os.environ):
|
||||
if k.startswith("PYTEST"):
|
||||
monkeypatch.delenv(k, raising=False)
|
||||
with pytest.raises(ValueError, match="DECNET_DB_PASSWORD is not set"):
|
||||
build_mysql_url()
|
||||
202
tests/db/test_base_repo.py
Normal file
202
tests/db/test_base_repo.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
Mock test for BaseRepository to ensure coverage of abstract pass lines.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from decnet.web.db.repository import BaseRepository
|
||||
|
||||
class DummyRepo(BaseRepository):
|
||||
async def initialize(self) -> None: await super().initialize()
|
||||
async def add_log(self, data): await super().add_log(data)
|
||||
async def get_logs(self, **kw): await super().get_logs(**kw)
|
||||
async def get_total_logs(self, **kw): await super().get_total_logs(**kw)
|
||||
async def get_stats_summary(self): await super().get_stats_summary()
|
||||
async def get_deckies(self): await super().get_deckies()
|
||||
async def get_user_by_username(self, u): await super().get_user_by_username(u)
|
||||
async def get_user_by_uuid(self, u): await super().get_user_by_uuid(u)
|
||||
async def create_user(self, d): await super().create_user(d)
|
||||
async def update_user_password(self, *a, **kw): await super().update_user_password(*a, **kw)
|
||||
async def add_bounty(self, d): await super().add_bounty(d)
|
||||
async def get_bounties(self, **kw): await super().get_bounties(**kw)
|
||||
async def get_total_bounties(self, **kw): await super().get_total_bounties(**kw)
|
||||
async def upsert_credential(self, d): await super().upsert_credential(d); return 0
|
||||
async def get_credentials(self, **kw): await super().get_credentials(**kw)
|
||||
async def get_total_credentials(self, **kw): await super().get_total_credentials(**kw)
|
||||
async def get_credentials_for_attacker(self, ip): await super().get_credentials_for_attacker(ip)
|
||||
async def get_credential_attempts_for_secret(self, h): await super().get_credential_attempts_for_secret(h)
|
||||
async def upsert_credential_reuse(self, **kw): await super().upsert_credential_reuse(**kw); return None
|
||||
async def list_credential_reuses(self, **kw): await super().list_credential_reuses(**kw); return (0, [])
|
||||
async def get_credential_reuse_by_id(self, i): await super().get_credential_reuse_by_id(i)
|
||||
async def update_credential_attacker_uuid(self, ip, u): await super().update_credential_attacker_uuid(ip, u); return 0
|
||||
async def get_state(self, k): await super().get_state(k)
|
||||
async def set_state(self, k, v): await super().set_state(k, v)
|
||||
async def get_max_log_id(self): await super().get_max_log_id()
|
||||
async def get_logs_after_id(self, last_id, limit=500): await super().get_logs_after_id(last_id, limit)
|
||||
async def get_all_bounties_by_ip(self): await super().get_all_bounties_by_ip()
|
||||
async def get_bounties_for_ips(self, ips): await super().get_bounties_for_ips(ips)
|
||||
async def upsert_attacker(self, d): await super().upsert_attacker(d); return ""
|
||||
async def upsert_attacker_behavior(self, u, d): await super().upsert_attacker_behavior(u, d)
|
||||
async def get_attacker_behavior(self, u): await super().get_attacker_behavior(u)
|
||||
async def get_behaviors_for_ips(self, ips): await super().get_behaviors_for_ips(ips)
|
||||
async def upsert_session_profile(self, sid, data): await super().upsert_session_profile(sid, data)
|
||||
async def get_session_profile(self, sid): await super().get_session_profile(sid)
|
||||
async def increment_smtp_target(self, u, d): await super().increment_smtp_target(u, d)
|
||||
async def list_smtp_targets(self, u): await super().list_smtp_targets(u)
|
||||
async def get_attacker_stored_mail(self, u): await super().get_attacker_stored_mail(u)
|
||||
async def smtp_target_seen(self, d): await super().smtp_target_seen(d)
|
||||
async def get_attacker_by_uuid(self, u): await super().get_attacker_by_uuid(u)
|
||||
async def get_attackers(self, **kw): await super().get_attackers(**kw)
|
||||
async def get_total_attackers(self, **kw): await super().get_total_attackers(**kw)
|
||||
async def get_attacker_commands(self, **kw): await super().get_attacker_commands(**kw)
|
||||
async def list_users(self): await super().list_users()
|
||||
async def delete_user(self, u): await super().delete_user(u)
|
||||
async def update_user_role(self, u, r): await super().update_user_role(u, r)
|
||||
async def purge_logs_and_bounties(self): await super().purge_logs_and_bounties()
|
||||
async def get_attacker_artifacts(self, uuid): await super().get_attacker_artifacts(uuid)
|
||||
async def get_attacker_transcripts(self, uuid): await super().get_attacker_transcripts(uuid)
|
||||
async def get_session_log(self, sid): await super().get_session_log(sid)
|
||||
# DEBT-041 / 3eb67c9 — attacker_intel re-key
|
||||
async def find_credential_reuse_candidates(self, min_targets=2): await super().find_credential_reuse_candidates(min_targets); return []
|
||||
async def get_attacker_intel_by_uuid(self, u): await super().get_attacker_intel_by_uuid(u)
|
||||
async def get_unenriched_attackers(self, limit=100): await super().get_unenriched_attackers(limit)
|
||||
async def upsert_attacker_intel(self, d): await super().upsert_attacker_intel(d); return ""
|
||||
# Identity resolution (this PR)
|
||||
async def get_identity_by_uuid(self, u): await super().get_identity_by_uuid(u)
|
||||
async def list_identities(self, limit=50, offset=0): await super().list_identities(limit, offset); return []
|
||||
async def count_identities(self): await super().count_identities(); return 0
|
||||
async def list_observations_for_identity(self, u, limit=50, offset=0): await super().list_observations_for_identity(u, limit, offset); return []
|
||||
async def count_observations_for_identity(self, u): await super().count_observations_for_identity(u); return 0
|
||||
async def list_attackers_for_clustering(self, limit=None): await super().list_attackers_for_clustering(limit); return []
|
||||
async def create_attacker_identity(self, row): await super().create_attacker_identity(row); return ""
|
||||
async def set_attacker_identity_id(self, a, i): await super().set_attacker_identity_id(a, i)
|
||||
async def list_all_identities(self): await super().list_all_identities(); return []
|
||||
async def update_identity_merged_into(self, u, w): await super().update_identity_merged_into(u, w)
|
||||
async def update_identity_fingerprints(self, u, *, ja3_hashes=None, hassh_hashes=None, tls_cert_sha256=None):
|
||||
await super().update_identity_fingerprints(u, ja3_hashes=ja3_hashes, hassh_hashes=hassh_hashes, tls_cert_sha256=tls_cert_sha256)
|
||||
# Campaign clustering (this PR)
|
||||
async def get_campaign_by_uuid(self, u): await super().get_campaign_by_uuid(u)
|
||||
async def list_campaigns(self, limit=50, offset=0): await super().list_campaigns(limit, offset); return []
|
||||
async def count_campaigns(self): await super().count_campaigns(); return 0
|
||||
async def list_identities_for_campaign(self, u, limit=50, offset=0): await super().list_identities_for_campaign(u, limit, offset); return []
|
||||
async def count_identities_for_campaign(self, u): await super().count_identities_for_campaign(u); return 0
|
||||
async def list_identities_for_clustering(self, limit=None): await super().list_identities_for_clustering(limit); return []
|
||||
async def create_campaign(self, row): await super().create_campaign(row); return ""
|
||||
async def set_identity_campaign_id(self, i, c): await super().set_identity_campaign_id(i, c)
|
||||
async def list_all_campaigns(self): await super().list_all_campaigns(); return []
|
||||
async def update_campaign_merged_into(self, u, w): await super().update_campaign_merged_into(u, w)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_repo_coverage():
|
||||
dr = DummyRepo()
|
||||
# Call all to hit 'pass' statements
|
||||
await dr.initialize()
|
||||
await dr.add_log({})
|
||||
await dr.get_logs()
|
||||
await dr.get_total_logs()
|
||||
await dr.get_stats_summary()
|
||||
await dr.get_deckies()
|
||||
await dr.get_user_by_username("a")
|
||||
await dr.get_user_by_uuid("a")
|
||||
await dr.create_user({})
|
||||
await dr.update_user_password("a", "b")
|
||||
await dr.add_bounty({})
|
||||
await dr.get_bounties()
|
||||
await dr.get_total_bounties()
|
||||
await dr.upsert_credential({})
|
||||
await dr.get_credentials()
|
||||
await dr.get_total_credentials()
|
||||
await dr.get_credentials_for_attacker("1.2.3.4")
|
||||
await dr.get_credential_attempts_for_secret("abc")
|
||||
await dr.upsert_credential_reuse(
|
||||
secret_sha256="x", secret_kind="plaintext", principal=None,
|
||||
attacker_uuid=None, attacker_ip="1.2.3.4", decky="d", service="ssh",
|
||||
attempt_count=1, ts=None,
|
||||
)
|
||||
await dr.list_credential_reuses()
|
||||
await dr.get_credential_reuse_by_id("a")
|
||||
await dr.update_credential_attacker_uuid("1.2.3.4", "u")
|
||||
await dr.get_state("k")
|
||||
await dr.set_state("k", "v")
|
||||
await dr.get_max_log_id()
|
||||
await dr.get_logs_after_id(0)
|
||||
await dr.get_all_bounties_by_ip()
|
||||
await dr.get_bounties_for_ips({"1.1.1.1"})
|
||||
await dr.upsert_attacker({})
|
||||
await dr.upsert_attacker_behavior("a", {})
|
||||
await dr.get_attacker_behavior("a")
|
||||
await dr.get_behaviors_for_ips({"1.1.1.1"})
|
||||
await dr.upsert_session_profile("sid", {})
|
||||
await dr.get_session_profile("sid")
|
||||
await dr.increment_smtp_target("uuid", "corp.com")
|
||||
await dr.list_smtp_targets("uuid")
|
||||
await dr.get_attacker_stored_mail("uuid")
|
||||
await dr.smtp_target_seen("corp.com")
|
||||
await dr.get_attacker_by_uuid("a")
|
||||
await dr.get_attackers()
|
||||
await dr.get_total_attackers()
|
||||
await dr.get_attacker_commands(uuid="a")
|
||||
await dr.list_users()
|
||||
await dr.delete_user("a")
|
||||
await dr.update_user_role("a", "admin")
|
||||
await dr.purge_logs_and_bounties()
|
||||
await dr.get_attacker_artifacts("a")
|
||||
await dr.get_attacker_transcripts("a")
|
||||
await dr.get_session_log("a")
|
||||
await dr.find_credential_reuse_candidates()
|
||||
await dr.get_attacker_intel_by_uuid("a")
|
||||
await dr.get_unenriched_attackers()
|
||||
await dr.upsert_attacker_intel({"attacker_uuid": "a", "attacker_ip": "1.1.1.1"})
|
||||
await dr.get_identity_by_uuid("a")
|
||||
await dr.list_identities()
|
||||
await dr.count_identities()
|
||||
await dr.list_observations_for_identity("a")
|
||||
await dr.count_observations_for_identity("a")
|
||||
await dr.list_attackers_for_clustering()
|
||||
await dr.create_attacker_identity({"uuid": "i"})
|
||||
await dr.set_attacker_identity_id("a", "i")
|
||||
await dr.list_all_identities()
|
||||
await dr.update_identity_merged_into("a", "b")
|
||||
await dr.update_identity_merged_into("a", None)
|
||||
await dr.update_identity_fingerprints("a", ja3_hashes='["x"]', hassh_hashes=None, tls_cert_sha256='["y"]')
|
||||
await dr.get_campaign_by_uuid("a")
|
||||
await dr.list_campaigns()
|
||||
await dr.count_campaigns()
|
||||
await dr.list_identities_for_campaign("a")
|
||||
await dr.count_identities_for_campaign("a")
|
||||
await dr.list_identities_for_clustering()
|
||||
await dr.create_campaign({"uuid": "c"})
|
||||
await dr.set_identity_campaign_id("i", "c")
|
||||
await dr.set_identity_campaign_id("i", None)
|
||||
await dr.list_all_campaigns()
|
||||
await dr.update_campaign_merged_into("c", "d")
|
||||
await dr.update_campaign_merged_into("c", None)
|
||||
|
||||
# Swarm methods: default NotImplementedError on BaseRepository. Covering
|
||||
# them here keeps the coverage contract honest for the swarm CRUD surface.
|
||||
for coro, args in [
|
||||
(dr.add_swarm_host, ({},)),
|
||||
(dr.get_swarm_host_by_name, ("w",)),
|
||||
(dr.get_swarm_host_by_uuid, ("u",)),
|
||||
(dr.list_swarm_hosts, ()),
|
||||
(dr.update_swarm_host, ("u", {})),
|
||||
(dr.delete_swarm_host, ("u",)),
|
||||
(dr.upsert_decky_shard, ({},)),
|
||||
(dr.list_decky_shards, ()),
|
||||
(dr.delete_decky_shards_for_host, ("u",)),
|
||||
(dr.create_topology, ({},)),
|
||||
(dr.get_topology, ("t",)),
|
||||
(dr.list_topologies, ()),
|
||||
(dr.update_topology_status, ("t", "active")),
|
||||
(dr.delete_topology_cascade, ("t",)),
|
||||
(dr.add_lan, ({},)),
|
||||
(dr.update_lan, ("l", {})),
|
||||
(dr.list_lans_for_topology, ("t",)),
|
||||
(dr.add_topology_decky, ({},)),
|
||||
(dr.update_topology_decky, ("d", {})),
|
||||
(dr.list_topology_deckies, ("t",)),
|
||||
(dr.add_topology_edge, ({},)),
|
||||
(dr.list_topology_edges, ("t",)),
|
||||
(dr.list_topology_status_events, ("t",)),
|
||||
]:
|
||||
with pytest.raises(NotImplementedError):
|
||||
await coro(*args)
|
||||
145
tests/db/test_campaign_repo.py
Normal file
145
tests/db/test_campaign_repo.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Tests for the Campaign clustering repo methods on SQLModelRepository."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from decnet.web.db.factory import get_repository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def repo(tmp_path):
|
||||
r = get_repository(db_path=str(tmp_path / "campaigns.db"))
|
||||
await r.initialize()
|
||||
return r
|
||||
|
||||
|
||||
async def _create_identity(repo, uuid: str, **kwargs) -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
return await repo.create_attacker_identity({
|
||||
"uuid": uuid,
|
||||
"first_seen_at": kwargs.get("first_seen_at", now),
|
||||
"last_seen_at": kwargs.get("last_seen_at", now),
|
||||
"ja3_hashes": kwargs.get("ja3_hashes"),
|
||||
"hassh_hashes": kwargs.get("hassh_hashes"),
|
||||
"payload_simhashes": kwargs.get("payload_simhashes"),
|
||||
"c2_endpoints": kwargs.get("c2_endpoints"),
|
||||
})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_get_campaign(repo):
|
||||
await repo.create_campaign({"uuid": "c1", "confidence": 0.8})
|
||||
row = await repo.get_campaign_by_uuid("c1")
|
||||
assert row is not None
|
||||
assert row["uuid"] == "c1"
|
||||
assert row["confidence"] == 0.8
|
||||
assert row["merged_into_uuid"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_campaign_follows_merge_chain(repo):
|
||||
await repo.create_campaign({"uuid": "c1"})
|
||||
await repo.create_campaign({"uuid": "c2"})
|
||||
await repo.update_campaign_merged_into("c2", "c1")
|
||||
|
||||
# Querying the loser returns the winner.
|
||||
row = await repo.get_campaign_by_uuid("c2")
|
||||
assert row["uuid"] == "c1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_and_count_excludes_merged_out(repo):
|
||||
await repo.create_campaign({"uuid": "c1"})
|
||||
await repo.create_campaign({"uuid": "c2"})
|
||||
await repo.update_campaign_merged_into("c2", "c1")
|
||||
|
||||
listed = await repo.list_campaigns()
|
||||
assert {c["uuid"] for c in listed} == {"c1"}
|
||||
assert await repo.count_campaigns() == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_all_campaigns_includes_merged_out(repo):
|
||||
await repo.create_campaign({"uuid": "c1"})
|
||||
await repo.create_campaign({"uuid": "c2"})
|
||||
await repo.update_campaign_merged_into("c2", "c1")
|
||||
|
||||
all_campaigns = await repo.list_all_campaigns()
|
||||
assert {c["uuid"] for c in all_campaigns} == {"c1", "c2"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_unknown_campaign_returns_none(repo):
|
||||
assert await repo.get_campaign_by_uuid("nope") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_campaign_merged_into_can_revoke(repo):
|
||||
await repo.create_campaign({"uuid": "c1"})
|
||||
await repo.create_campaign({"uuid": "c2"})
|
||||
await repo.update_campaign_merged_into("c2", "c1")
|
||||
# Revoke
|
||||
await repo.update_campaign_merged_into("c2", None)
|
||||
|
||||
row = await repo.get_campaign_by_uuid("c2")
|
||||
assert row["uuid"] == "c2"
|
||||
assert row["merged_into_uuid"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_identity_campaign_id_links_and_unlinks(repo):
|
||||
await repo.create_campaign({"uuid": "c1"})
|
||||
await _create_identity(repo, "i1")
|
||||
|
||||
await repo.set_identity_campaign_id("i1", "c1")
|
||||
linked = await repo.list_identities_for_campaign("c1")
|
||||
assert {i["uuid"] for i in linked} == {"i1"}
|
||||
assert await repo.count_identities_for_campaign("c1") == 1
|
||||
|
||||
await repo.set_identity_campaign_id("i1", None)
|
||||
assert await repo.count_identities_for_campaign("c1") == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_identities_for_clustering_projects_expected_fields(repo):
|
||||
await _create_identity(
|
||||
repo, "i1",
|
||||
ja3_hashes='["ja3-a"]',
|
||||
hassh_hashes='["hassh-a"]',
|
||||
payload_simhashes='["dead"]',
|
||||
c2_endpoints='["1.2.3.4:443"]',
|
||||
)
|
||||
rows = await repo.list_identities_for_clustering()
|
||||
assert len(rows) == 1
|
||||
row = rows[0]
|
||||
assert row["uuid"] == "i1"
|
||||
assert row["ja3_hashes"] == '["ja3-a"]'
|
||||
assert row["hassh_hashes"] == '["hassh-a"]'
|
||||
assert row["payload_simhashes"] == '["dead"]'
|
||||
assert row["c2_endpoints"] == '["1.2.3.4:443"]'
|
||||
assert row["campaign_id"] is None
|
||||
assert row["merged_into_uuid"] is None
|
||||
assert row["first_seen_at"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_identities_for_clustering_respects_limit(repo):
|
||||
for n in range(3):
|
||||
await _create_identity(repo, f"i{n}")
|
||||
assert len(await repo.list_identities_for_clustering(limit=2)) == 2
|
||||
assert len(await repo.list_identities_for_clustering()) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_identities_for_campaign_paginates(repo):
|
||||
await repo.create_campaign({"uuid": "c1"})
|
||||
for n in range(3):
|
||||
await _create_identity(repo, f"i{n}")
|
||||
await repo.set_identity_campaign_id(f"i{n}", "c1")
|
||||
|
||||
page = await repo.list_identities_for_campaign("c1", limit=2, offset=0)
|
||||
assert len(page) == 2
|
||||
page2 = await repo.list_identities_for_campaign("c1", limit=2, offset=2)
|
||||
assert len(page2) == 1
|
||||
226
tests/db/test_credential_reuse.py
Normal file
226
tests/db/test_credential_reuse.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""CredentialReuse repo tests — upsert idempotency, list pagination, FK backfill."""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from decnet.web.db.factory import get_repository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def repo(tmp_path: Path):
|
||||
r = get_repository(db_path=str(tmp_path / "reuse.db"))
|
||||
await r.initialize()
|
||||
return r
|
||||
|
||||
|
||||
def _sha256(s: str) -> str:
|
||||
return hashlib.sha256(s.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
async def _seed_credential(repo, **overrides):
|
||||
base = {
|
||||
"attacker_ip": "10.0.0.5",
|
||||
"decky_name": "decky-01",
|
||||
"service": "ssh",
|
||||
"principal": "root",
|
||||
"secret_sha256": _sha256("hunter2"),
|
||||
"secret_b64": "aHVudGVyMg==",
|
||||
"secret_printable": "hunter2",
|
||||
"fields": {},
|
||||
}
|
||||
base.update(overrides)
|
||||
return await repo.upsert_credential(base)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_inserts_first_observation(repo) -> None:
|
||||
sha = _sha256("hunter2")
|
||||
out = await repo.upsert_credential_reuse(
|
||||
secret_sha256=sha, secret_kind="plaintext", principal="root",
|
||||
attacker_uuid=None, attacker_ip="10.0.0.5",
|
||||
decky="decky-01", service="ssh", attempt_count=1,
|
||||
)
|
||||
assert out is not None
|
||||
assert out["inserted"] is True
|
||||
assert out["target_count"] == 1
|
||||
assert out["confidence"] == 1.0
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_grows_target_count_across_services(repo) -> None:
|
||||
"""Same secret on two distinct (decky, service) pairs → target_count=2.
|
||||
|
||||
target_count is recomputed from the credentials table, so the test
|
||||
must seed actual Credential rows first.
|
||||
"""
|
||||
sha = _sha256("p4ssw0rd")
|
||||
await _seed_credential(repo, secret_sha256=sha, decky_name="d1", service="ssh")
|
||||
await _seed_credential(repo, secret_sha256=sha, decky_name="d2", service="ftp")
|
||||
|
||||
await repo.upsert_credential_reuse(
|
||||
secret_sha256=sha, secret_kind="plaintext", principal="root",
|
||||
attacker_uuid=None, attacker_ip="10.0.0.5",
|
||||
decky="d1", service="ssh", attempt_count=1,
|
||||
)
|
||||
out = await repo.upsert_credential_reuse(
|
||||
secret_sha256=sha, secret_kind="plaintext", principal="root",
|
||||
attacker_uuid=None, attacker_ip="10.0.0.5",
|
||||
decky="d2", service="ftp", attempt_count=1,
|
||||
)
|
||||
assert out["inserted"] is False
|
||||
assert out["changed"] is True
|
||||
assert out["target_count"] == 2
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_dedups_same_decky_service(repo) -> None:
|
||||
"""Repeated upserts for the same (decky, service) don't grow target_count."""
|
||||
sha = _sha256("samepw")
|
||||
await _seed_credential(repo, secret_sha256=sha)
|
||||
for _ in range(3):
|
||||
await repo.upsert_credential_reuse(
|
||||
secret_sha256=sha, secret_kind="plaintext", principal="root",
|
||||
attacker_uuid=None, attacker_ip="10.0.0.5",
|
||||
decky="decky-01", service="ssh", attempt_count=1,
|
||||
)
|
||||
rows = (await repo.list_credential_reuses(min_target_count=1))[1]
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["target_count"] == 1
|
||||
assert rows[0]["attempt_count"] == 3
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_merges_attacker_lists(repo) -> None:
|
||||
"""Distinct attacker_uuid/ip values accumulate into the JSON lists."""
|
||||
sha = _sha256("shared")
|
||||
await _seed_credential(repo, secret_sha256=sha, attacker_ip="1.1.1.1")
|
||||
await _seed_credential(
|
||||
repo, secret_sha256=sha, attacker_ip="2.2.2.2", decky_name="d2",
|
||||
)
|
||||
await repo.upsert_credential_reuse(
|
||||
secret_sha256=sha, secret_kind="plaintext", principal="root",
|
||||
attacker_uuid="uuid-A", attacker_ip="1.1.1.1",
|
||||
decky="decky-01", service="ssh", attempt_count=1,
|
||||
)
|
||||
await repo.upsert_credential_reuse(
|
||||
secret_sha256=sha, secret_kind="plaintext", principal="root",
|
||||
attacker_uuid="uuid-B", attacker_ip="2.2.2.2",
|
||||
decky="d2", service="ssh", attempt_count=1,
|
||||
)
|
||||
rows = (await repo.list_credential_reuses(min_target_count=1))[1]
|
||||
assert set(rows[0]["attacker_uuids"]) == {"uuid-A", "uuid-B"}
|
||||
assert set(rows[0]["attacker_ips"]) == {"1.1.1.1", "2.2.2.2"}
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_null_principal_uniqueness(repo) -> None:
|
||||
"""Two upserts with principal=None go to the same row, not two rows."""
|
||||
sha = _sha256("redis-auth")
|
||||
await _seed_credential(repo, secret_sha256=sha, service="redis", principal=None)
|
||||
for _ in range(2):
|
||||
await repo.upsert_credential_reuse(
|
||||
secret_sha256=sha, secret_kind="plaintext", principal=None,
|
||||
attacker_uuid=None, attacker_ip="1.1.1.1",
|
||||
decky="decky-01", service="redis", attempt_count=1,
|
||||
)
|
||||
rows = (await repo.list_credential_reuses(min_target_count=1))[1]
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["principal"] is None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_filters_by_min_target_count(repo) -> None:
|
||||
"""min_target_count=2 hides 1-target findings."""
|
||||
sha = _sha256("only-once")
|
||||
await _seed_credential(repo, secret_sha256=sha)
|
||||
await repo.upsert_credential_reuse(
|
||||
secret_sha256=sha, secret_kind="plaintext", principal="root",
|
||||
attacker_uuid=None, attacker_ip="1.1.1.1",
|
||||
decky="decky-01", service="ssh", attempt_count=1,
|
||||
)
|
||||
total, rows = await repo.list_credential_reuses(min_target_count=2)
|
||||
assert total == 0
|
||||
assert rows == []
|
||||
total, _ = await repo.list_credential_reuses(min_target_count=1)
|
||||
assert total == 1
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_pagination_orders_by_target_count_desc(repo) -> None:
|
||||
sha_a = _sha256("a")
|
||||
sha_b = _sha256("b")
|
||||
# secret a → 1 target
|
||||
await _seed_credential(repo, secret_sha256=sha_a)
|
||||
await repo.upsert_credential_reuse(
|
||||
secret_sha256=sha_a, secret_kind="plaintext", principal="root",
|
||||
attacker_uuid=None, attacker_ip="1.1.1.1",
|
||||
decky="d1", service="ssh", attempt_count=1,
|
||||
)
|
||||
# secret b → 2 targets
|
||||
await _seed_credential(repo, secret_sha256=sha_b, service="ssh")
|
||||
await _seed_credential(repo, secret_sha256=sha_b, service="ftp", decky_name="d2")
|
||||
await repo.upsert_credential_reuse(
|
||||
secret_sha256=sha_b, secret_kind="plaintext", principal="root",
|
||||
attacker_uuid=None, attacker_ip="1.1.1.1",
|
||||
decky="decky-01", service="ssh", attempt_count=1,
|
||||
)
|
||||
await repo.upsert_credential_reuse(
|
||||
secret_sha256=sha_b, secret_kind="plaintext", principal="root",
|
||||
attacker_uuid=None, attacker_ip="1.1.1.1",
|
||||
decky="d2", service="ftp", attempt_count=1,
|
||||
)
|
||||
total, rows = await repo.list_credential_reuses(min_target_count=1)
|
||||
assert total == 2
|
||||
assert rows[0]["secret_sha256"] == sha_b # higher target_count first
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_by_id_roundtrip(repo) -> None:
|
||||
sha = _sha256("rt")
|
||||
await _seed_credential(repo, secret_sha256=sha)
|
||||
out = await repo.upsert_credential_reuse(
|
||||
secret_sha256=sha, secret_kind="plaintext", principal="root",
|
||||
attacker_uuid=None, attacker_ip="1.1.1.1",
|
||||
decky="decky-01", service="ssh", attempt_count=1,
|
||||
)
|
||||
fetched = await repo.get_credential_reuse_by_id(out["id"])
|
||||
assert fetched is not None
|
||||
assert fetched["id"] == out["id"]
|
||||
assert fetched["secret_sha256"] == sha
|
||||
assert isinstance(fetched["deckies"], list)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_by_id_missing_returns_none(repo) -> None:
|
||||
assert await repo.get_credential_reuse_by_id("nope") is None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_credential_attacker_uuid_backfills_only_nulls(repo) -> None:
|
||||
"""The profiler hook must backfill attacker_uuid only on rows where it
|
||||
is currently null — pre-existing UUIDs must not be overwritten."""
|
||||
sha = _sha256("backfill")
|
||||
await _seed_credential(repo, secret_sha256=sha, attacker_ip="9.9.9.9")
|
||||
await _seed_credential(
|
||||
repo, secret_sha256=sha, attacker_ip="9.9.9.9",
|
||||
service="ftp", decky_name="d2",
|
||||
)
|
||||
# Backfill: both null, both should update.
|
||||
n = await repo.update_credential_attacker_uuid("9.9.9.9", "uuid-9")
|
||||
assert n == 2
|
||||
|
||||
# Second call: both already set, nothing should change.
|
||||
n2 = await repo.update_credential_attacker_uuid("9.9.9.9", "uuid-other")
|
||||
assert n2 == 0
|
||||
|
||||
rows = await repo.get_credentials_for_attacker("9.9.9.9")
|
||||
assert all(r["attacker_uuid"] == "uuid-9" for r in rows)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_credential_attacker_uuid_no_match(repo) -> None:
|
||||
n = await repo.update_credential_attacker_uuid("0.0.0.0", "uuid-x")
|
||||
assert n == 0
|
||||
168
tests/db/test_credentials.py
Normal file
168
tests/db/test_credentials.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Credential model + repo tests — upsert, dedup, cross-service reuse."""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from decnet.web.db.factory import get_repository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def repo(tmp_path: Path):
|
||||
r = get_repository(db_path=str(tmp_path / "creds.db"))
|
||||
await r.initialize()
|
||||
return r
|
||||
|
||||
|
||||
def _sha256(s: str) -> str:
|
||||
return hashlib.sha256(s.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_inserts_then_dedups(repo) -> None:
|
||||
"""Same dedup tuple twice → one row, attempt_count=2."""
|
||||
payload = {
|
||||
"attacker_ip": "10.0.0.5",
|
||||
"decky_name": "decky-01",
|
||||
"service": "ssh",
|
||||
"principal": "root",
|
||||
"secret_sha256": _sha256("hunter2"),
|
||||
"secret_b64": "aHVudGVyMg==",
|
||||
"secret_printable": "hunter2",
|
||||
"fields": {"user": "root"},
|
||||
}
|
||||
rid_a = await repo.upsert_credential(payload)
|
||||
rid_b = await repo.upsert_credential(payload)
|
||||
assert rid_a == rid_b
|
||||
rows = await repo.get_credentials()
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["attempt_count"] == 2
|
||||
assert rows[0]["fields"] == {"user": "root"} # preserved
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_different_principal_creates_new_row(repo) -> None:
|
||||
base = {
|
||||
"attacker_ip": "10.0.0.5",
|
||||
"decky_name": "decky-01",
|
||||
"service": "ssh",
|
||||
"secret_sha256": _sha256("hunter2"),
|
||||
"secret_b64": "aHVudGVyMg==",
|
||||
"secret_printable": "hunter2",
|
||||
"fields": {},
|
||||
}
|
||||
await repo.upsert_credential({**base, "principal": "root"})
|
||||
await repo.upsert_credential({**base, "principal": "admin"})
|
||||
rows = await repo.get_credentials()
|
||||
assert len(rows) == 2
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_null_principal_dedups_independently(repo) -> None:
|
||||
"""principal=None and principal='root' are different keys."""
|
||||
base = {
|
||||
"attacker_ip": "10.0.0.5",
|
||||
"decky_name": "decky-01",
|
||||
"service": "ssh",
|
||||
"secret_sha256": _sha256("hunter2"),
|
||||
"secret_b64": "aHVudGVyMg==",
|
||||
"secret_printable": "hunter2",
|
||||
"fields": {},
|
||||
}
|
||||
await repo.upsert_credential({**base, "principal": None})
|
||||
await repo.upsert_credential({**base, "principal": None}) # dedupes
|
||||
await repo.upsert_credential({**base, "principal": "root"})
|
||||
rows = await repo.get_credentials()
|
||||
assert len(rows) == 2
|
||||
null_row = next(r for r in rows if r["principal"] is None)
|
||||
assert null_row["attempt_count"] == 2
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cross_service_reuse_query(repo) -> None:
|
||||
"""Same secret across SSH + FTP + SMTP → reuse query returns all three."""
|
||||
secret = "hunter2"
|
||||
sha = _sha256(secret)
|
||||
services = [
|
||||
("ssh", "decky-01", "root"),
|
||||
("ftp", "decky-02", "anonymous"),
|
||||
("smtp", "decky-03", "acme.com"),
|
||||
]
|
||||
for svc, decky, principal in services:
|
||||
await repo.upsert_credential({
|
||||
"attacker_ip": "10.0.0.5",
|
||||
"decky_name": decky,
|
||||
"service": svc,
|
||||
"principal": principal,
|
||||
"secret_sha256": sha,
|
||||
"secret_b64": "aHVudGVyMg==",
|
||||
"secret_printable": secret,
|
||||
"fields": {},
|
||||
})
|
||||
reuse = await repo.get_credential_attempts_for_secret(sha)
|
||||
assert {r["service"] for r in reuse} == {"ssh", "ftp", "smtp"}
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_credentials_for_attacker(repo) -> None:
|
||||
base = {
|
||||
"decky_name": "decky-01",
|
||||
"service": "ssh",
|
||||
"principal": "root",
|
||||
"secret_sha256": _sha256("hunter2"),
|
||||
"secret_b64": "aHVudGVyMg==",
|
||||
"secret_printable": "hunter2",
|
||||
"fields": {},
|
||||
}
|
||||
await repo.upsert_credential({**base, "attacker_ip": "10.0.0.5"})
|
||||
await repo.upsert_credential({**base, "attacker_ip": "10.0.0.6"})
|
||||
rows = await repo.get_credentials_for_attacker("10.0.0.5")
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["attacker_ip"] == "10.0.0.5"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_secret_kind_dedups_independently(repo) -> None:
|
||||
"""Same sha256, same principal — different secret_kind = different row.
|
||||
|
||||
Two rows with the same content-addressable hash but different kinds
|
||||
represent fundamentally different credentials (e.g. a plaintext
|
||||
password that happens to hash to the same value as a Postgres
|
||||
md5 challenge response is statistically impossible but semantically
|
||||
distinct anyway). Dedup must respect the kind boundary."""
|
||||
base = {
|
||||
"attacker_ip": "10.0.0.5",
|
||||
"decky_name": "decky-01",
|
||||
"service": "ssh",
|
||||
"principal": "root",
|
||||
"secret_sha256": _sha256("hunter2"),
|
||||
"secret_b64": "aHVudGVyMg==",
|
||||
"fields": {},
|
||||
}
|
||||
await repo.upsert_credential({**base, "secret_kind": "plaintext"})
|
||||
await repo.upsert_credential({**base, "secret_kind": "postgres_md5_challenge"})
|
||||
rows = await repo.get_credentials()
|
||||
assert len(rows) == 2
|
||||
kinds = {r["secret_kind"] for r in rows}
|
||||
assert kinds == {"plaintext", "postgres_md5_challenge"}
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_filters(repo) -> None:
|
||||
base_secret = _sha256("a")
|
||||
await repo.upsert_credential({
|
||||
"attacker_ip": "10.0.0.5", "decky_name": "decky-01", "service": "ssh",
|
||||
"principal": "root", "secret_sha256": base_secret,
|
||||
"secret_printable": "a", "fields": {},
|
||||
})
|
||||
await repo.upsert_credential({
|
||||
"attacker_ip": "10.0.0.5", "decky_name": "decky-01", "service": "ftp",
|
||||
"principal": "root", "secret_sha256": base_secret,
|
||||
"secret_printable": "a", "fields": {},
|
||||
})
|
||||
rows = await repo.get_credentials(service="ssh")
|
||||
assert len(rows) == 1 and rows[0]["service"] == "ssh"
|
||||
assert await repo.get_total_credentials(service="ssh") == 1
|
||||
assert await repo.get_total_credentials() == 2
|
||||
44
tests/db/test_factory.py
Normal file
44
tests/db/test_factory.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
Unit tests for the repository factory — dispatch on DECNET_DB_TYPE.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from decnet.web.db.factory import get_repository
|
||||
from decnet.web.db.sqlite.repository import SQLiteRepository
|
||||
from decnet.web.db.mysql.repository import MySQLRepository
|
||||
|
||||
|
||||
def test_factory_defaults_to_sqlite(monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("DECNET_DB_TYPE", raising=False)
|
||||
repo = get_repository(db_path=str(tmp_path / "t.db"))
|
||||
assert isinstance(repo, SQLiteRepository)
|
||||
|
||||
|
||||
def test_factory_sqlite_explicit(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("DECNET_DB_TYPE", "sqlite")
|
||||
repo = get_repository(db_path=str(tmp_path / "t.db"))
|
||||
assert isinstance(repo, SQLiteRepository)
|
||||
|
||||
|
||||
def test_factory_mysql_branch(monkeypatch):
|
||||
"""MySQL branch must import and instantiate without a live server.
|
||||
|
||||
Engine creation is lazy in SQLAlchemy — no socket is opened until the
|
||||
first query — so the repository constructs cleanly here.
|
||||
"""
|
||||
monkeypatch.setenv("DECNET_DB_TYPE", "mysql")
|
||||
monkeypatch.setenv("DECNET_DB_URL", "mysql+asyncmy://u:p@127.0.0.1:3306/x")
|
||||
repo = get_repository()
|
||||
assert isinstance(repo, MySQLRepository)
|
||||
|
||||
|
||||
def test_factory_is_case_insensitive(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("DECNET_DB_TYPE", "SQLite")
|
||||
repo = get_repository(db_path=str(tmp_path / "t.db"))
|
||||
assert isinstance(repo, SQLiteRepository)
|
||||
|
||||
|
||||
def test_factory_rejects_unknown_type(monkeypatch):
|
||||
monkeypatch.setenv("DECNET_DB_TYPE", "cassandra")
|
||||
with pytest.raises(ValueError, match="Unsupported database type"):
|
||||
get_repository()
|
||||
206
tests/db/test_identity_schema.py
Normal file
206
tests/db/test_identity_schema.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Schema-only tests for the AttackerIdentity table and the
|
||||
attackers.identity_id FK.
|
||||
|
||||
The identities table ships empty in this PR; the clusterer that
|
||||
populates it is a separate downstream effort. These tests verify only
|
||||
that the schema lands correctly:
|
||||
|
||||
* the table exists after metadata.create_all()
|
||||
* attackers.identity_id is nullable and indexed
|
||||
* the FK references attacker_identities.uuid
|
||||
* an attacker row may be inserted with identity_id=NULL
|
||||
* an identity row may be inserted with all clusterer-populated columns NULL
|
||||
|
||||
If any of these regress, downstream API/frontend/clusterer work all
|
||||
stop. See development/IDENTITY_RESOLUTION.md §Schema.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import inspect
|
||||
from sqlmodel import Session
|
||||
|
||||
from decnet.web.db.models import Attacker, AttackerIdentity
|
||||
from decnet.web.db.sqlite.database import get_sync_engine, init_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_path(tmp_path) -> str:
|
||||
p = tmp_path / "schema.db"
|
||||
init_db(str(p))
|
||||
return str(p)
|
||||
|
||||
|
||||
def test_attacker_identities_table_exists(db_path: str) -> None:
|
||||
engine = get_sync_engine(db_path)
|
||||
inspector = inspect(engine)
|
||||
assert "attacker_identities" in inspector.get_table_names()
|
||||
|
||||
|
||||
def test_attackers_identity_id_column_present_and_nullable(db_path: str) -> None:
|
||||
engine = get_sync_engine(db_path)
|
||||
inspector = inspect(engine)
|
||||
columns = {c["name"]: c for c in inspector.get_columns("attackers")}
|
||||
assert "identity_id" in columns, "attackers.identity_id column missing"
|
||||
assert columns["identity_id"]["nullable"] is True, (
|
||||
"attackers.identity_id must be nullable — clusterer hasn't run yet on existing rows"
|
||||
)
|
||||
|
||||
|
||||
def test_attackers_identity_id_is_indexed(db_path: str) -> None:
|
||||
engine = get_sync_engine(db_path)
|
||||
inspector = inspect(engine)
|
||||
indexes = inspector.get_indexes("attackers")
|
||||
indexed_columns = {col for idx in indexes for col in idx["column_names"]}
|
||||
assert "identity_id" in indexed_columns, (
|
||||
"attackers.identity_id needs an index for join performance "
|
||||
"(IdentityDetail aggregates by identity_id; without an index "
|
||||
"every lookup is a full scan)"
|
||||
)
|
||||
|
||||
|
||||
def test_attackers_identity_id_fk_targets_attacker_identities(db_path: str) -> None:
|
||||
engine = get_sync_engine(db_path)
|
||||
inspector = inspect(engine)
|
||||
fks = inspector.get_foreign_keys("attackers")
|
||||
identity_fks = [
|
||||
fk for fk in fks if "identity_id" in fk["constrained_columns"]
|
||||
]
|
||||
assert identity_fks, "no FK on attackers.identity_id"
|
||||
assert identity_fks[0]["referred_table"] == "attacker_identities"
|
||||
assert identity_fks[0]["referred_columns"] == ["uuid"]
|
||||
|
||||
|
||||
def test_identity_schema_version_default_is_1(db_path: str) -> None:
|
||||
"""
|
||||
schema_version is non-negotiable from day one. Federation gossip
|
||||
in V2 will share identity vectors across operators; bumping the
|
||||
feature definitions without a version field silently poisons
|
||||
receivers. Default must be 1 on insert.
|
||||
"""
|
||||
engine = get_sync_engine(db_path)
|
||||
with Session(engine) as session:
|
||||
identity = AttackerIdentity(uuid=str(uuid.uuid4()))
|
||||
session.add(identity)
|
||||
session.commit()
|
||||
session.refresh(identity)
|
||||
assert identity.schema_version == 1
|
||||
|
||||
|
||||
def test_attacker_can_be_inserted_with_null_identity_id(db_path: str) -> None:
|
||||
"""
|
||||
Existing code paths (profiler, correlator) keep upserting attackers
|
||||
without setting identity_id. They MUST work unchanged — the
|
||||
identity_id column is nullable and remains NULL until the clusterer
|
||||
runs.
|
||||
"""
|
||||
engine = get_sync_engine(db_path)
|
||||
with Session(engine) as session:
|
||||
now = datetime.now(timezone.utc)
|
||||
att = Attacker(
|
||||
uuid=str(uuid.uuid4()),
|
||||
ip="203.0.113.4",
|
||||
first_seen=now,
|
||||
last_seen=now,
|
||||
)
|
||||
session.add(att)
|
||||
session.commit()
|
||||
session.refresh(att)
|
||||
assert att.identity_id is None
|
||||
|
||||
|
||||
def test_identity_with_all_clusterer_fields_null(db_path: str) -> None:
|
||||
"""
|
||||
The table ships empty; even when the clusterer eventually inserts
|
||||
rows, it may write a row with most fields null (e.g. before
|
||||
fingerprint summaries have been computed). Every clusterer-populated
|
||||
column must accept NULL.
|
||||
"""
|
||||
engine = get_sync_engine(db_path)
|
||||
with Session(engine) as session:
|
||||
identity = AttackerIdentity(uuid=str(uuid.uuid4()))
|
||||
session.add(identity)
|
||||
session.commit()
|
||||
session.refresh(identity)
|
||||
for field in (
|
||||
"campaign_id",
|
||||
"first_seen_at",
|
||||
"last_seen_at",
|
||||
"confidence",
|
||||
"ja3_hashes",
|
||||
"hassh_hashes",
|
||||
"payload_simhashes",
|
||||
"c2_endpoints",
|
||||
"kd_digraph_simhash",
|
||||
"merged_into_uuid",
|
||||
"notes",
|
||||
):
|
||||
assert getattr(identity, field) is None, (
|
||||
f"AttackerIdentity.{field} must default to None — "
|
||||
f"the table ships empty pre-clusterer"
|
||||
)
|
||||
# observation_count is denormalized; defaults to 0 (not NULL).
|
||||
assert identity.observation_count == 0
|
||||
|
||||
|
||||
def test_attacker_identity_link_round_trip(db_path: str) -> None:
|
||||
"""
|
||||
End-to-end: insert an identity, link an attacker observation to
|
||||
it via identity_id FK, query both sides. Smoke-tests the schema
|
||||
works as designed without invoking the production repo layer.
|
||||
"""
|
||||
engine = get_sync_engine(db_path)
|
||||
with Session(engine) as session:
|
||||
identity = AttackerIdentity(uuid=str(uuid.uuid4()))
|
||||
session.add(identity)
|
||||
session.commit()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
att = Attacker(
|
||||
uuid=str(uuid.uuid4()),
|
||||
ip="203.0.113.5",
|
||||
first_seen=now,
|
||||
last_seen=now,
|
||||
identity_id=identity.uuid,
|
||||
)
|
||||
session.add(att)
|
||||
session.commit()
|
||||
session.refresh(att)
|
||||
assert att.identity_id == identity.uuid
|
||||
|
||||
|
||||
def test_identity_id_fk_constraint_blocks_orphans(db_path: str) -> None:
|
||||
"""
|
||||
Inserting an attacker with identity_id pointing at a nonexistent
|
||||
identity must fail. The clusterer should never write an orphan
|
||||
link; the schema enforces that contract.
|
||||
|
||||
SQLite's PRAGMA foreign_keys is off by default at the connection
|
||||
level; we enable it explicitly here so the test reflects the
|
||||
contract production code relies on (via the same PRAGMA on its
|
||||
connections).
|
||||
"""
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
with pytest.raises(sqlite3.IntegrityError):
|
||||
conn.execute(
|
||||
"INSERT INTO attackers (uuid, ip, first_seen, last_seen, "
|
||||
"event_count, service_count, decky_count, services, deckies, "
|
||||
"is_traversal, bounty_count, credential_count, fingerprints, "
|
||||
"commands, updated_at, identity_id) VALUES "
|
||||
"(?, ?, ?, ?, 0, 0, 0, '[]', '[]', 0, 0, 0, '[]', '[]', ?, ?)",
|
||||
(
|
||||
str(uuid.uuid4()),
|
||||
"203.0.113.6",
|
||||
datetime.now(timezone.utc).isoformat(),
|
||||
datetime.now(timezone.utc).isoformat(),
|
||||
datetime.now(timezone.utc).isoformat(),
|
||||
"ffffffff-ffff-ffff-ffff-ffffffffffff", # nonexistent identity
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
Reference in New Issue
Block a user