merge: testing → main (reconcile 2-week divergence)

This commit is contained in:
2026-04-28 18:36:00 -04:00
parent 499836c9e4
commit 862e4dbb31
1235 changed files with 160255 additions and 7996 deletions

0
tests/db/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,70 @@
"""
Inspection-level tests for the MySQL-dialect SQL emitted by MySQLRepository.
We compile the SQLAlchemy statements against the MySQL dialect and assert on
the string form — no live MySQL server is required.
"""
import pytest
from sqlalchemy import func, select, literal_column
from sqlalchemy.dialects import mysql
from sqlmodel.sql.expression import SelectOfScalar
from decnet.web.db.models import Log
def _compile(stmt) -> str:
    """Render *stmt* as a MySQL-dialect SQL string.

    Bound values are inlined (``literal_binds``) so callers can assert on the
    final SQL text directly.
    """
    compiled = stmt.compile(
        dialect=mysql.dialect(),
        compile_kwargs={"literal_binds": True},
    )
    return str(compiled)
def test_mysql_histogram_uses_from_unixtime_bucket():
    """The MySQL dialect must bucket with UNIX_TIMESTAMP DIV N * N wrapped in FROM_UNIXTIME.

    The bucket expression is built here as a literal column, compiled for the
    MySQL dialect, and the resulting SQL text is inspected.
    """
    bucket_seconds = 900  # 15 min
    bucket_expr = literal_column(
        f"FROM_UNIXTIME((UNIX_TIMESTAMP(timestamp) DIV {bucket_seconds}) * {bucket_seconds})"
    ).label("bucket_time")
    # Two selected columns make this a plain Select rather than a
    # SelectOfScalar, so the statement is deliberately left unannotated.
    stmt = select(bucket_expr, func.count().label("count")).select_from(Log)
    sql = _compile(stmt)

    assert "FROM_UNIXTIME" in sql
    assert "UNIX_TIMESTAMP" in sql
    # Derive the expected fragment from bucket_seconds so the assertion cannot
    # drift from the expression built above.
    assert f"DIV {bucket_seconds}" in sql
    # Sanity: SQLite-only strftime must NOT appear in the MySQL-dialect output.
    assert "strftime" not in sql
    assert "unixepoch" not in sql
def test_mysql_json_unquote_predicate_shape():
    """The MySQL JSON filter is spelled JSON_UNQUOTE(JSON_EXTRACT(...)) with a :val bind."""
    from decnet.web.db.mysql.repository import MySQLRepository

    # __new__ skips __init__, so no engine or DB connection is created; only
    # _json_field_equals is exercised here.
    bare_repo = MySQLRepository.__new__(MySQLRepository)
    predicate = bare_repo._json_field_equals("username")

    # text() objects expose their literal SQL through .text
    expected_fragments = (
        "JSON_UNQUOTE",
        "JSON_EXTRACT(fields, '$.username')",
        ":val",
    )
    for fragment in expected_fragments:
        assert fragment in predicate.text
@pytest.mark.parametrize("key", ["user", "port", "sess_id"])
def test_mysql_json_predicate_safe_for_reasonable_keys(key):
    """Simple identifier-like keys appear verbatim as the quoted JSON path."""
    from decnet.web.db.mysql.repository import MySQLRepository

    bare_repo = MySQLRepository.__new__(MySQLRepository)
    expected_path = f"'$.{key}'"
    assert expected_path in bare_repo._json_field_equals(key).text
def test_sqlite_histogram_still_uses_strftime():
    """Regression guard: the SQLite histogram source must still mention strftime/unixepoch."""
    from decnet.web.db.sqlite.repository import SQLiteRepository
    import inspect

    histogram_source = inspect.getsource(SQLiteRepository.get_log_histogram)
    for marker in ("strftime", "unixepoch"):
        assert marker in histogram_source

View File

@@ -0,0 +1,234 @@
"""
Tests for MySQLRepository._migrate_column_types().
No live MySQL server required — the repository's engine and connection are
replaced with unittest.mock objects (MagicMock / AsyncMock).
The ALTER TABLE branch is tested by intercepting the information_schema
query result and asserting that the correct MODIFY COLUMN statements are
issued, and that none are issued when there is nothing left to migrate.
"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, call
from decnet.web.db.mysql.repository import MySQLRepository
# ── helpers ──────────────────────────────────────────────────────────────────
def _make_repo() -> MySQLRepository:
    """Return a bare MySQLRepository whose ``__init__`` never ran (no DB access)."""
    repo_cls = MySQLRepository
    return repo_cls.__new__(repo_cls)
# ── _migrate_column_types ─────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_migrate_column_types_issues_alter_for_text_columns():
    """Each TEXT column reported by information_schema gets one ALTER statement."""
    repo = _make_repo()

    # Rows handed back for the information_schema lookup: three TEXT columns.
    reported_columns = [
        ("attackers", "commands"),
        ("attackers", "fingerprints"),
        ("state", "value"),
    ]
    issued: list[str] = []

    async def fake_execute(stmt):
        rendered = str(stmt)
        if "information_schema" not in rendered:
            # Everything that is not the schema lookup is recorded as an ALTER.
            issued.append(rendered)
            return MagicMock()
        lookup = MagicMock()
        lookup.fetchall.return_value = reported_columns
        return lookup

    conn = AsyncMock()
    conn.execute.side_effect = fake_execute

    begin_ctx = AsyncMock()
    begin_ctx.__aenter__ = AsyncMock(return_value=conn)
    begin_ctx.__aexit__ = AsyncMock(return_value=False)

    repo.engine = MagicMock()
    repo.engine.begin.return_value = begin_ctx

    await repo._migrate_column_types()

    # One ALTER per reported TEXT column.
    assert len(issued) == 3
    for fragment in (
        "`commands` MEDIUMTEXT",
        "`fingerprints` MEDIUMTEXT",
        "`value` MEDIUMTEXT",
    ):
        assert any(fragment in s for s in issued)
    # NOT NULL must be carried over on every statement.
    assert all("NOT NULL" in s for s in issued)
@pytest.mark.asyncio
async def test_migrate_column_types_no_alter_when_already_mediumtext():
    """An empty information_schema result must not trigger any ALTER."""
    repo = _make_repo()
    issued: list[str] = []

    async def fake_execute(stmt):
        rendered = str(stmt)
        if "information_schema" not in rendered:
            issued.append(rendered)
            return MagicMock()
        lookup = MagicMock()
        lookup.fetchall.return_value = []  # nothing to migrate
        return lookup

    conn = AsyncMock()
    conn.execute.side_effect = fake_execute

    begin_ctx = AsyncMock()
    begin_ctx.__aenter__ = AsyncMock(return_value=conn)
    begin_ctx.__aexit__ = AsyncMock(return_value=False)

    repo.engine = MagicMock()
    repo.engine.begin.return_value = begin_ctx

    await repo._migrate_column_types()

    assert issued == [], "No ALTER TABLE should be issued when columns are already MEDIUMTEXT"
@pytest.mark.asyncio
async def test_migrate_column_types_idempotent_on_repeated_calls():
    """Calling _migrate_column_types twice is safe: second call is a no-op.

    The first information_schema lookup reports one TEXT column and every
    later lookup reports none. The test records each non-lookup statement so
    it can check that the second call issues no further ALTER.
    """
    repo = _make_repo()
    call_count = 0
    alter_stmts: list[str] = []

    async def fake_execute(stmt):
        nonlocal call_count
        sql = str(stmt)
        if "information_schema" in sql:
            result = MagicMock()
            # First lookup: one TEXT column; later lookups: zero (already migrated)
            call_count += 1
            result.fetchall.return_value = (
                [("attackers", "commands")] if call_count == 1 else []
            )
            return result
        # Record ALTER statements so the no-op claim can be asserted.
        alter_stmts.append(sql)
        return MagicMock()

    def _make_ctx():
        # A fresh context manager per engine.begin() call.
        fake_conn = AsyncMock()
        fake_conn.execute.side_effect = fake_execute
        ctx = AsyncMock()
        ctx.__aenter__ = AsyncMock(return_value=fake_conn)
        ctx.__aexit__ = AsyncMock(return_value=False)
        return ctx

    repo.engine = MagicMock()
    repo.engine.begin.side_effect = _make_ctx

    await repo._migrate_column_types()
    assert len(alter_stmts) == 1, f"Expected one ALTER on first call: {alter_stmts}"

    await repo._migrate_column_types()  # second call must not raise
    assert len(alter_stmts) == 1, f"Second call issued extra ALTERs: {alter_stmts}"
@pytest.mark.asyncio
async def test_migrate_column_types_default_clause_per_column():
    """Attacker columns carry DEFAULT '[]'; state.value carries no DEFAULT."""
    repo = _make_repo()

    reported_columns = [
        ("attackers", "commands"),
        ("attackers", "fingerprints"),
        ("attackers", "services"),
        ("attackers", "deckies"),
        ("state", "value"),
    ]
    issued: list[str] = []

    async def fake_execute(stmt):
        rendered = str(stmt)
        if "information_schema" not in rendered:
            issued.append(rendered)
            return MagicMock()
        lookup = MagicMock()
        lookup.fetchall.return_value = reported_columns
        return lookup

    conn = AsyncMock()
    conn.execute.side_effect = fake_execute

    begin_ctx = AsyncMock()
    begin_ctx.__aenter__ = AsyncMock(return_value=conn)
    begin_ctx.__aexit__ = AsyncMock(return_value=False)

    repo.engine = MagicMock()
    repo.engine.begin.return_value = begin_ctx

    await repo._migrate_column_types()

    # Split the recorded statements by the back-quoted table name they mention.
    for_attackers = [s for s in issued if "`attackers`" in s]
    for_state = [s for s in issued if "`state`" in s]

    assert len(for_attackers) == 4
    assert len(for_state) == 1

    for stmt in for_attackers:
        assert "DEFAULT '[]'" in stmt, f"Missing DEFAULT '[]' in: {stmt}"

    # state.value has no DEFAULT in the schema
    assert "DEFAULT" not in for_state[0], \
        f"Unexpected DEFAULT in state.value alter: {for_state[0]}"
# ── initialize override ───────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_mysql_initialize_calls_migrate_column_types():
    """initialize() must run the helpers in a fixed order.

    Expected order: attackers migration, session_profile migration (DEBT-036),
    column-type migration, then admin-user seeding.
    """
    repo = _make_repo()
    call_order: list[str] = []

    def _recording_stub(label: str):
        # Build an async no-arg stub that only records that it was awaited.
        async def _stub():
            call_order.append(label)
        return _stub

    repo._migrate_attackers_table = _recording_stub("migrate_attackers")
    repo._migrate_session_profile_table = _recording_stub("migrate_session_profile")
    repo._migrate_column_types = _recording_stub("migrate_column_types")
    repo._ensure_admin_user = _recording_stub("ensure_admin")

    # Stub engine.begin() so create_all is a no-op
    conn = AsyncMock()
    conn.run_sync = AsyncMock()
    begin_ctx = AsyncMock()
    begin_ctx.__aenter__ = AsyncMock(return_value=conn)
    begin_ctx.__aexit__ = AsyncMock(return_value=False)
    repo.engine = MagicMock()
    repo.engine.begin.return_value = begin_ctx

    await repo.initialize()

    expected_order = [
        "migrate_attackers",
        "migrate_session_profile",
        "migrate_column_types",
        "ensure_admin",
    ]
    assert call_order == expected_order, f"Unexpected call order: {call_order}"

View File

@@ -0,0 +1,78 @@
"""
Unit tests for decnet.web.db.mysql.database.build_mysql_url / resolve_url.
No MySQL server is required — these are pure URL-construction tests.
"""
import pytest
from decnet.web.db.mysql.database import build_mysql_url, resolve_url
def test_build_url_defaults(monkeypatch):
    """With every DECNET_DB_* variable unset, the default URL is produced."""
    env_names = (
        "DECNET_DB_HOST",
        "DECNET_DB_PORT",
        "DECNET_DB_NAME",
        "DECNET_DB_USER",
        "DECNET_DB_PASSWORD",
        "DECNET_DB_URL",
    )
    for name in env_names:
        monkeypatch.delenv(name, raising=False)
    # PYTEST_* is set by pytest itself, so empty password is allowed here.
    assert build_mysql_url() == "mysql+asyncmy://decnet:@localhost:3306/decnet"
def test_build_url_from_env(monkeypatch):
    """Each DECNET_DB_* component variable is reflected in the built URL."""
    env_values = {
        "DECNET_DB_HOST": "db.internal",
        "DECNET_DB_PORT": "3307",
        "DECNET_DB_NAME": "decnet_prod",
        "DECNET_DB_USER": "svc_decnet",
        "DECNET_DB_PASSWORD": "hunter2",
    }
    for name, value in env_values.items():
        monkeypatch.setenv(name, value)
    built = build_mysql_url()
    assert built == "mysql+asyncmy://svc_decnet:hunter2@db.internal:3307/decnet_prod"
def test_build_url_percent_encodes_password(monkeypatch):
    """Reserved characters in the password (@ : / ! #) are percent-encoded."""
    monkeypatch.setenv("DECNET_DB_PASSWORD", "p@ss:word/!#")
    built = build_mysql_url(user="u", host="h", port=3306, database="d")
    # @ → %40, : → %3A, / → %2F, ! → %21, # → %23
    encoded_password = "p%40ss%3Aword%2F%21%23"
    assert encoded_password in built
    assert built.startswith("mysql+asyncmy://u:")
    assert built.endswith("@h:3306/d")
def test_build_url_component_args_override_env(monkeypatch):
    """Explicit keyword arguments take precedence over DECNET_DB_* variables."""
    monkeypatch.setenv("DECNET_DB_HOST", "ignored")
    monkeypatch.setenv("DECNET_DB_PASSWORD", "env-pw")
    explicit = {
        "host": "arg.host",
        "user": "arg-user",
        "password": "arg-pw",
        "port": 9999,
        "database": "arg-db",
    }
    built = build_mysql_url(**explicit)
    assert built == "mysql+asyncmy://arg-user:arg-pw@arg.host:9999/arg-db"
def test_resolve_url_prefers_explicit_arg(monkeypatch):
    """A URL passed to resolve_url wins over DECNET_DB_URL."""
    monkeypatch.setenv("DECNET_DB_URL", "mysql+asyncmy://env-url/x")
    explicit_url = "mysql+asyncmy://explicit/y"
    resolved = resolve_url(explicit_url)
    assert resolved == "mysql+asyncmy://explicit/y"
def test_resolve_url_uses_env_url_before_components(monkeypatch):
    """DECNET_DB_URL is returned as-is even when component variables are set."""
    env_url = "mysql+asyncmy://env-user:env-pw@env-host/env-db"
    monkeypatch.setenv("DECNET_DB_URL", env_url)
    monkeypatch.setenv("DECNET_DB_HOST", "ignored.host")
    resolved = resolve_url()
    assert resolved == "mysql+asyncmy://env-user:env-pw@env-host/env-db"
def test_resolve_url_falls_back_to_components(monkeypatch):
    """Without DECNET_DB_URL, the URL is assembled from the component variables."""
    monkeypatch.delenv("DECNET_DB_URL", raising=False)
    for name, value in (("DECNET_DB_HOST", "fallback.host"), ("DECNET_DB_PASSWORD", "pw")):
        monkeypatch.setenv(name, value)
    resolved = resolve_url()
    assert "fallback.host" in resolved
    assert resolved.startswith("mysql+asyncmy://")
def test_build_url_requires_password_outside_pytest(monkeypatch):
    """With no password and no PYTEST* variables, build_mysql_url raises ValueError."""
    import os

    for name in ("DECNET_DB_URL", "DECNET_DB_PASSWORD"):
        monkeypatch.delenv(name, raising=False)

    # Strip every PYTEST_* env var so the safety check trips. The names are
    # collected first because delenv mutates os.environ.
    pytest_vars = [name for name in os.environ if name.startswith("PYTEST")]
    for name in pytest_vars:
        monkeypatch.delenv(name, raising=False)

    with pytest.raises(ValueError, match="DECNET_DB_PASSWORD is not set"):
        build_mysql_url()

202
tests/db/test_base_repo.py Normal file
View File

@@ -0,0 +1,202 @@
"""
Mock test for BaseRepository to ensure coverage of abstract pass lines.
"""
import pytest
from decnet.web.db.repository import BaseRepository
class DummyRepo(BaseRepository):
    """Concrete BaseRepository subclass used only to exercise the base-class bodies.

    Every override awaits the corresponding ``super()`` method with the
    arguments it received. Some overrides then return a placeholder value
    (``0``, ``""``, ``[]``, ``None`` or ``(0, [])``); the rest implicitly
    return ``None``.
    """
    async def initialize(self) -> None: await super().initialize()
    async def add_log(self, data): await super().add_log(data)
    async def get_logs(self, **kw): await super().get_logs(**kw)
    async def get_total_logs(self, **kw): await super().get_total_logs(**kw)
    async def get_stats_summary(self): await super().get_stats_summary()
    async def get_deckies(self): await super().get_deckies()
    async def get_user_by_username(self, u): await super().get_user_by_username(u)
    async def get_user_by_uuid(self, u): await super().get_user_by_uuid(u)
    async def create_user(self, d): await super().create_user(d)
    async def update_user_password(self, *a, **kw): await super().update_user_password(*a, **kw)
    async def add_bounty(self, d): await super().add_bounty(d)
    async def get_bounties(self, **kw): await super().get_bounties(**kw)
    async def get_total_bounties(self, **kw): await super().get_total_bounties(**kw)
    async def upsert_credential(self, d): await super().upsert_credential(d); return 0
    async def get_credentials(self, **kw): await super().get_credentials(**kw)
    async def get_total_credentials(self, **kw): await super().get_total_credentials(**kw)
    async def get_credentials_for_attacker(self, ip): await super().get_credentials_for_attacker(ip)
    async def get_credential_attempts_for_secret(self, h): await super().get_credential_attempts_for_secret(h)
    async def upsert_credential_reuse(self, **kw): await super().upsert_credential_reuse(**kw); return None
    async def list_credential_reuses(self, **kw): await super().list_credential_reuses(**kw); return (0, [])
    async def get_credential_reuse_by_id(self, i): await super().get_credential_reuse_by_id(i)
    async def update_credential_attacker_uuid(self, ip, u): await super().update_credential_attacker_uuid(ip, u); return 0
    async def get_state(self, k): await super().get_state(k)
    async def set_state(self, k, v): await super().set_state(k, v)
    async def get_max_log_id(self): await super().get_max_log_id()
    async def get_logs_after_id(self, last_id, limit=500): await super().get_logs_after_id(last_id, limit)
    async def get_all_bounties_by_ip(self): await super().get_all_bounties_by_ip()
    async def get_bounties_for_ips(self, ips): await super().get_bounties_for_ips(ips)
    async def upsert_attacker(self, d): await super().upsert_attacker(d); return ""
    async def upsert_attacker_behavior(self, u, d): await super().upsert_attacker_behavior(u, d)
    async def get_attacker_behavior(self, u): await super().get_attacker_behavior(u)
    async def get_behaviors_for_ips(self, ips): await super().get_behaviors_for_ips(ips)
    async def upsert_session_profile(self, sid, data): await super().upsert_session_profile(sid, data)
    async def get_session_profile(self, sid): await super().get_session_profile(sid)
    async def increment_smtp_target(self, u, d): await super().increment_smtp_target(u, d)
    async def list_smtp_targets(self, u): await super().list_smtp_targets(u)
    async def get_attacker_stored_mail(self, u): await super().get_attacker_stored_mail(u)
    async def smtp_target_seen(self, d): await super().smtp_target_seen(d)
    async def get_attacker_by_uuid(self, u): await super().get_attacker_by_uuid(u)
    async def get_attackers(self, **kw): await super().get_attackers(**kw)
    async def get_total_attackers(self, **kw): await super().get_total_attackers(**kw)
    async def get_attacker_commands(self, **kw): await super().get_attacker_commands(**kw)
    async def list_users(self): await super().list_users()
    async def delete_user(self, u): await super().delete_user(u)
    async def update_user_role(self, u, r): await super().update_user_role(u, r)
    async def purge_logs_and_bounties(self): await super().purge_logs_and_bounties()
    async def get_attacker_artifacts(self, uuid): await super().get_attacker_artifacts(uuid)
    async def get_attacker_transcripts(self, uuid): await super().get_attacker_transcripts(uuid)
    async def get_session_log(self, sid): await super().get_session_log(sid)
    # DEBT-041 / 3eb67c9 — attacker_intel re-key
    async def find_credential_reuse_candidates(self, min_targets=2): await super().find_credential_reuse_candidates(min_targets); return []
    async def get_attacker_intel_by_uuid(self, u): await super().get_attacker_intel_by_uuid(u)
    async def get_unenriched_attackers(self, limit=100): await super().get_unenriched_attackers(limit)
    async def upsert_attacker_intel(self, d): await super().upsert_attacker_intel(d); return ""
    # Identity resolution (this PR)
    async def get_identity_by_uuid(self, u): await super().get_identity_by_uuid(u)
    async def list_identities(self, limit=50, offset=0): await super().list_identities(limit, offset); return []
    async def count_identities(self): await super().count_identities(); return 0
    async def list_observations_for_identity(self, u, limit=50, offset=0): await super().list_observations_for_identity(u, limit, offset); return []
    async def count_observations_for_identity(self, u): await super().count_observations_for_identity(u); return 0
    async def list_attackers_for_clustering(self, limit=None): await super().list_attackers_for_clustering(limit); return []
    async def create_attacker_identity(self, row): await super().create_attacker_identity(row); return ""
    async def set_attacker_identity_id(self, a, i): await super().set_attacker_identity_id(a, i)
    async def list_all_identities(self): await super().list_all_identities(); return []
    async def update_identity_merged_into(self, u, w): await super().update_identity_merged_into(u, w)
    # Fingerprint arguments are keyword-only and forwarded by keyword.
    async def update_identity_fingerprints(self, u, *, ja3_hashes=None, hassh_hashes=None, tls_cert_sha256=None):
        await super().update_identity_fingerprints(u, ja3_hashes=ja3_hashes, hassh_hashes=hassh_hashes, tls_cert_sha256=tls_cert_sha256)
    # Campaign clustering (this PR)
    async def get_campaign_by_uuid(self, u): await super().get_campaign_by_uuid(u)
    async def list_campaigns(self, limit=50, offset=0): await super().list_campaigns(limit, offset); return []
    async def count_campaigns(self): await super().count_campaigns(); return 0
    async def list_identities_for_campaign(self, u, limit=50, offset=0): await super().list_identities_for_campaign(u, limit, offset); return []
    async def count_identities_for_campaign(self, u): await super().count_identities_for_campaign(u); return 0
    async def list_identities_for_clustering(self, limit=None): await super().list_identities_for_clustering(limit); return []
    async def create_campaign(self, row): await super().create_campaign(row); return ""
    async def set_identity_campaign_id(self, i, c): await super().set_identity_campaign_id(i, c)
    async def list_all_campaigns(self): await super().list_all_campaigns(); return []
    async def update_campaign_merged_into(self, u, w): await super().update_campaign_merged_into(u, w)
@pytest.mark.asyncio
async def test_base_repo_coverage():
    """Await every DummyRepo override once, then check the swarm/topology defaults.

    The first part only checks that each call completes; return values are
    not inspected. The second part expects NotImplementedError from each
    swarm/topology method, which DummyRepo does not override.
    """
    dr = DummyRepo()
    # Call all to hit 'pass' statements
    await dr.initialize()
    await dr.add_log({})
    await dr.get_logs()
    await dr.get_total_logs()
    await dr.get_stats_summary()
    await dr.get_deckies()
    await dr.get_user_by_username("a")
    await dr.get_user_by_uuid("a")
    await dr.create_user({})
    await dr.update_user_password("a", "b")
    await dr.add_bounty({})
    await dr.get_bounties()
    await dr.get_total_bounties()
    await dr.upsert_credential({})
    await dr.get_credentials()
    await dr.get_total_credentials()
    await dr.get_credentials_for_attacker("1.2.3.4")
    await dr.get_credential_attempts_for_secret("abc")
    await dr.upsert_credential_reuse(
        secret_sha256="x", secret_kind="plaintext", principal=None,
        attacker_uuid=None, attacker_ip="1.2.3.4", decky="d", service="ssh",
        attempt_count=1, ts=None,
    )
    await dr.list_credential_reuses()
    await dr.get_credential_reuse_by_id("a")
    await dr.update_credential_attacker_uuid("1.2.3.4", "u")
    await dr.get_state("k")
    await dr.set_state("k", "v")
    await dr.get_max_log_id()
    await dr.get_logs_after_id(0)
    await dr.get_all_bounties_by_ip()
    await dr.get_bounties_for_ips({"1.1.1.1"})
    await dr.upsert_attacker({})
    await dr.upsert_attacker_behavior("a", {})
    await dr.get_attacker_behavior("a")
    await dr.get_behaviors_for_ips({"1.1.1.1"})
    await dr.upsert_session_profile("sid", {})
    await dr.get_session_profile("sid")
    await dr.increment_smtp_target("uuid", "corp.com")
    await dr.list_smtp_targets("uuid")
    await dr.get_attacker_stored_mail("uuid")
    await dr.smtp_target_seen("corp.com")
    await dr.get_attacker_by_uuid("a")
    await dr.get_attackers()
    await dr.get_total_attackers()
    await dr.get_attacker_commands(uuid="a")
    await dr.list_users()
    await dr.delete_user("a")
    await dr.update_user_role("a", "admin")
    await dr.purge_logs_and_bounties()
    await dr.get_attacker_artifacts("a")
    await dr.get_attacker_transcripts("a")
    await dr.get_session_log("a")
    await dr.find_credential_reuse_candidates()
    await dr.get_attacker_intel_by_uuid("a")
    await dr.get_unenriched_attackers()
    await dr.upsert_attacker_intel({"attacker_uuid": "a", "attacker_ip": "1.1.1.1"})
    await dr.get_identity_by_uuid("a")
    await dr.list_identities()
    await dr.count_identities()
    await dr.list_observations_for_identity("a")
    await dr.count_observations_for_identity("a")
    await dr.list_attackers_for_clustering()
    await dr.create_attacker_identity({"uuid": "i"})
    await dr.set_attacker_identity_id("a", "i")
    await dr.list_all_identities()
    # The merge/link setters are called with both a target and None.
    await dr.update_identity_merged_into("a", "b")
    await dr.update_identity_merged_into("a", None)
    await dr.update_identity_fingerprints("a", ja3_hashes='["x"]', hassh_hashes=None, tls_cert_sha256='["y"]')
    await dr.get_campaign_by_uuid("a")
    await dr.list_campaigns()
    await dr.count_campaigns()
    await dr.list_identities_for_campaign("a")
    await dr.count_identities_for_campaign("a")
    await dr.list_identities_for_clustering()
    await dr.create_campaign({"uuid": "c"})
    await dr.set_identity_campaign_id("i", "c")
    await dr.set_identity_campaign_id("i", None)
    await dr.list_all_campaigns()
    await dr.update_campaign_merged_into("c", "d")
    await dr.update_campaign_merged_into("c", None)
    # Swarm methods: default NotImplementedError on BaseRepository. Covering
    # them here keeps the coverage contract honest for the swarm CRUD surface.
    # Each entry pairs a bound method with the positional args to call it with.
    for coro, args in [
        (dr.add_swarm_host, ({},)),
        (dr.get_swarm_host_by_name, ("w",)),
        (dr.get_swarm_host_by_uuid, ("u",)),
        (dr.list_swarm_hosts, ()),
        (dr.update_swarm_host, ("u", {})),
        (dr.delete_swarm_host, ("u",)),
        (dr.upsert_decky_shard, ({},)),
        (dr.list_decky_shards, ()),
        (dr.delete_decky_shards_for_host, ("u",)),
        (dr.create_topology, ({},)),
        (dr.get_topology, ("t",)),
        (dr.list_topologies, ()),
        (dr.update_topology_status, ("t", "active")),
        (dr.delete_topology_cascade, ("t",)),
        (dr.add_lan, ({},)),
        (dr.update_lan, ("l", {})),
        (dr.list_lans_for_topology, ("t",)),
        (dr.add_topology_decky, ({},)),
        (dr.update_topology_decky, ("d", {})),
        (dr.list_topology_deckies, ("t",)),
        (dr.add_topology_edge, ({},)),
        (dr.list_topology_edges, ("t",)),
        (dr.list_topology_status_events, ("t",)),
    ]:
        with pytest.raises(NotImplementedError):
            await coro(*args)

View File

@@ -0,0 +1,145 @@
"""Tests for the Campaign clustering repo methods on SQLModelRepository."""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from decnet.web.db.factory import get_repository
@pytest.fixture
async def repo(tmp_path):
    """Provide an initialized repository backed by ``campaigns.db`` under tmp_path."""
    db_file = tmp_path / "campaigns.db"
    instance = get_repository(db_path=str(db_file))
    await instance.initialize()
    return instance
async def _create_identity(repo, uuid: str, **kwargs) -> str:
    """Create an attacker identity row and return what the repository returns.

    ``first_seen_at`` / ``last_seen_at`` default to the current UTC time; the
    four list-like fields default to ``None`` when not supplied.
    """
    now = datetime.now(timezone.utc)
    row = {
        "uuid": uuid,
        "first_seen_at": kwargs.get("first_seen_at", now),
        "last_seen_at": kwargs.get("last_seen_at", now),
    }
    for optional_field in ("ja3_hashes", "hassh_hashes", "payload_simhashes", "c2_endpoints"):
        row[optional_field] = kwargs.get(optional_field)
    return await repo.create_attacker_identity(row)
@pytest.mark.asyncio
async def test_create_and_get_campaign(repo):
    """A created campaign can be read back with its uuid and confidence."""
    payload = {"uuid": "c1", "confidence": 0.8}
    await repo.create_campaign(payload)

    campaign = await repo.get_campaign_by_uuid("c1")

    assert campaign is not None
    assert campaign["uuid"] == "c1"
    assert campaign["confidence"] == 0.8
    assert campaign["merged_into_uuid"] is None
@pytest.mark.asyncio
async def test_get_campaign_follows_merge_chain(repo):
    """Looking up a merged-out campaign yields the campaign it was merged into."""
    for campaign_uuid in ("c1", "c2"):
        await repo.create_campaign({"uuid": campaign_uuid})
    await repo.update_campaign_merged_into("c2", "c1")

    # Querying the loser returns the winner.
    resolved = await repo.get_campaign_by_uuid("c2")
    assert resolved["uuid"] == "c1"
@pytest.mark.asyncio
async def test_list_and_count_excludes_merged_out(repo):
    """list_campaigns / count_campaigns skip campaigns that were merged away."""
    for campaign_uuid in ("c1", "c2"):
        await repo.create_campaign({"uuid": campaign_uuid})
    await repo.update_campaign_merged_into("c2", "c1")

    visible = await repo.list_campaigns()
    visible_uuids = {campaign["uuid"] for campaign in visible}
    assert visible_uuids == {"c1"}
    assert await repo.count_campaigns() == 1
@pytest.mark.asyncio
async def test_list_all_campaigns_includes_merged_out(repo):
    """list_all_campaigns returns merged-out campaigns as well."""
    for campaign_uuid in ("c1", "c2"):
        await repo.create_campaign({"uuid": campaign_uuid})
    await repo.update_campaign_merged_into("c2", "c1")

    everything = await repo.list_all_campaigns()
    assert {campaign["uuid"] for campaign in everything} == {"c1", "c2"}
@pytest.mark.asyncio
async def test_get_unknown_campaign_returns_none(repo):
    """An unknown campaign uuid resolves to None."""
    missing = await repo.get_campaign_by_uuid("nope")
    assert missing is None
@pytest.mark.asyncio
async def test_update_campaign_merged_into_can_revoke(repo):
    """Setting merged_into back to None makes the campaign resolve to itself again."""
    for campaign_uuid in ("c1", "c2"):
        await repo.create_campaign({"uuid": campaign_uuid})
    await repo.update_campaign_merged_into("c2", "c1")

    # Revoke
    await repo.update_campaign_merged_into("c2", None)

    restored = await repo.get_campaign_by_uuid("c2")
    assert restored["uuid"] == "c2"
    assert restored["merged_into_uuid"] is None
@pytest.mark.asyncio
async def test_set_identity_campaign_id_links_and_unlinks(repo):
    """An identity can be attached to a campaign and detached again with None."""
    await repo.create_campaign({"uuid": "c1"})
    await _create_identity(repo, "i1")

    # Link: the identity shows up in the campaign's listing and count.
    await repo.set_identity_campaign_id("i1", "c1")
    members = await repo.list_identities_for_campaign("c1")
    member_uuids = {identity["uuid"] for identity in members}
    assert member_uuids == {"i1"}
    assert await repo.count_identities_for_campaign("c1") == 1

    # Unlink: the count drops back to zero.
    await repo.set_identity_campaign_id("i1", None)
    assert await repo.count_identities_for_campaign("c1") == 0
@pytest.mark.asyncio
async def test_list_identities_for_clustering_projects_expected_fields(repo):
    """The clustering projection exposes the stored fingerprint fields unchanged."""
    stored = {
        "ja3_hashes": '["ja3-a"]',
        "hassh_hashes": '["hassh-a"]',
        "payload_simhashes": '["dead"]',
        "c2_endpoints": '["1.2.3.4:443"]',
    }
    await _create_identity(repo, "i1", **stored)

    projected = await repo.list_identities_for_clustering()
    assert len(projected) == 1

    identity = projected[0]
    assert identity["uuid"] == "i1"
    assert identity["ja3_hashes"] == '["ja3-a"]'
    assert identity["hassh_hashes"] == '["hassh-a"]'
    assert identity["payload_simhashes"] == '["dead"]'
    assert identity["c2_endpoints"] == '["1.2.3.4:443"]'
    assert identity["campaign_id"] is None
    assert identity["merged_into_uuid"] is None
    assert identity["first_seen_at"] is not None
@pytest.mark.asyncio
async def test_list_identities_for_clustering_respects_limit(repo):
    """``limit`` caps the number of rows; omitting it returns all of them."""
    for index in range(3):
        await _create_identity(repo, f"i{index}")

    limited = await repo.list_identities_for_clustering(limit=2)
    assert len(limited) == 2
    unlimited = await repo.list_identities_for_clustering()
    assert len(unlimited) == 3
@pytest.mark.asyncio
async def test_list_identities_for_campaign_paginates(repo):
    """Three linked identities split into pages of two and one."""
    await repo.create_campaign({"uuid": "c1"})
    for index in range(3):
        identity_uuid = f"i{index}"
        await _create_identity(repo, identity_uuid)
        await repo.set_identity_campaign_id(identity_uuid, "c1")

    first_page = await repo.list_identities_for_campaign("c1", limit=2, offset=0)
    assert len(first_page) == 2
    second_page = await repo.list_identities_for_campaign("c1", limit=2, offset=2)
    assert len(second_page) == 1

View File

@@ -0,0 +1,226 @@
"""CredentialReuse repo tests — upsert idempotency, list pagination, FK backfill."""
from __future__ import annotations
import hashlib
from pathlib import Path
import pytest
from decnet.web.db.factory import get_repository
@pytest.fixture
async def repo(tmp_path: Path):
    """Provide an initialized repository backed by ``reuse.db`` under tmp_path."""
    db_file = tmp_path / "reuse.db"
    instance = get_repository(db_path=str(db_file))
    await instance.initialize()
    return instance
def _sha256(s: str) -> str:
    """Return the hex SHA-256 digest of *s* encoded as UTF-8."""
    digest = hashlib.sha256()
    digest.update(s.encode("utf-8"))
    return digest.hexdigest()
async def _seed_credential(repo, **overrides):
    """Upsert one credential row built from defaults overlaid with *overrides*.

    Returns whatever ``repo.upsert_credential`` returns.
    """
    defaults = {
        "attacker_ip": "10.0.0.5",
        "decky_name": "decky-01",
        "service": "ssh",
        "principal": "root",
        "secret_sha256": _sha256("hunter2"),
        "secret_b64": "aHVudGVyMg==",
        "secret_printable": "hunter2",
        "fields": {},
    }
    payload = {**defaults, **overrides}
    return await repo.upsert_credential(payload)
@pytest.mark.anyio
async def test_upsert_inserts_first_observation(repo) -> None:
    """The first upsert for a secret reports an insert with a single target."""
    result = await repo.upsert_credential_reuse(
        secret_sha256=_sha256("hunter2"),
        secret_kind="plaintext",
        principal="root",
        attacker_uuid=None,
        attacker_ip="10.0.0.5",
        decky="decky-01",
        service="ssh",
        attempt_count=1,
    )
    assert result is not None
    assert result["inserted"] is True
    assert result["target_count"] == 1
    assert result["confidence"] == 1.0
@pytest.mark.anyio
async def test_upsert_grows_target_count_across_services(repo) -> None:
    """Same secret on two distinct (decky, service) pairs → target_count=2.

    target_count is recomputed from the credentials table, so actual
    Credential rows are seeded before the reuse upserts.
    """
    sha = _sha256("p4ssw0rd")
    targets = (("d1", "ssh"), ("d2", "ftp"))

    for decky, service in targets:
        await _seed_credential(repo, secret_sha256=sha, decky_name=decky, service=service)

    # Upsert once per target; only the result of the last upsert is inspected.
    out = None
    for decky, service in targets:
        out = await repo.upsert_credential_reuse(
            secret_sha256=sha, secret_kind="plaintext", principal="root",
            attacker_uuid=None, attacker_ip="10.0.0.5",
            decky=decky, service=service, attempt_count=1,
        )

    assert out["inserted"] is False
    assert out["changed"] is True
    assert out["target_count"] == 2
@pytest.mark.anyio
async def test_upsert_dedups_same_decky_service(repo) -> None:
    """Three upserts for one (decky, service) keep target_count at 1 and sum attempts."""
    sha = _sha256("samepw")
    await _seed_credential(repo, secret_sha256=sha)

    observation = dict(
        secret_sha256=sha, secret_kind="plaintext", principal="root",
        attacker_uuid=None, attacker_ip="10.0.0.5",
        decky="decky-01", service="ssh", attempt_count=1,
    )
    for _ in range(3):
        await repo.upsert_credential_reuse(**observation)

    listing = await repo.list_credential_reuses(min_target_count=1)
    rows = listing[1]
    assert len(rows) == 1
    only_row = rows[0]
    assert only_row["target_count"] == 1
    assert only_row["attempt_count"] == 3
@pytest.mark.anyio
async def test_upsert_merges_attacker_lists(repo) -> None:
    """Distinct attacker uuids and IPs accumulate on the same reuse row."""
    sha = _sha256("shared")
    await _seed_credential(repo, secret_sha256=sha, attacker_ip="1.1.1.1")
    await _seed_credential(
        repo, secret_sha256=sha, attacker_ip="2.2.2.2", decky_name="d2",
    )

    # (attacker_uuid, attacker_ip, decky) for each observation, in order.
    observations = (
        ("uuid-A", "1.1.1.1", "decky-01"),
        ("uuid-B", "2.2.2.2", "d2"),
    )
    for attacker_uuid, attacker_ip, decky in observations:
        await repo.upsert_credential_reuse(
            secret_sha256=sha, secret_kind="plaintext", principal="root",
            attacker_uuid=attacker_uuid, attacker_ip=attacker_ip,
            decky=decky, service="ssh", attempt_count=1,
        )

    listing = await repo.list_credential_reuses(min_target_count=1)
    first_row = listing[1][0]
    assert set(first_row["attacker_uuids"]) == {"uuid-A", "uuid-B"}
    assert set(first_row["attacker_ips"]) == {"1.1.1.1", "2.2.2.2"}
@pytest.mark.anyio
async def test_null_principal_uniqueness(repo) -> None:
    """Two upserts with principal=None land on one row, not two."""
    sha = _sha256("redis-auth")
    await _seed_credential(repo, secret_sha256=sha, service="redis", principal=None)

    observation = dict(
        secret_sha256=sha, secret_kind="plaintext", principal=None,
        attacker_uuid=None, attacker_ip="1.1.1.1",
        decky="decky-01", service="redis", attempt_count=1,
    )
    for _ in range(2):
        await repo.upsert_credential_reuse(**observation)

    listing = await repo.list_credential_reuses(min_target_count=1)
    rows = listing[1]
    assert len(rows) == 1
    assert rows[0]["principal"] is None
@pytest.mark.anyio
async def test_list_filters_by_min_target_count(repo) -> None:
    """A single-target finding is hidden at min_target_count=2, shown at 1."""
    digest = _sha256("only-once")
    await _seed_credential(repo, secret_sha256=digest)
    await repo.upsert_credential_reuse(
        secret_sha256=digest,
        secret_kind="plaintext",
        principal="root",
        attacker_uuid=None,
        attacker_ip="1.1.1.1",
        decky="decky-01",
        service="ssh",
        attempt_count=1,
    )

    # Threshold above the finding's target count: nothing is returned.
    strict_total, strict_rows = await repo.list_credential_reuses(min_target_count=2)
    assert strict_total == 0
    assert strict_rows == []

    # Threshold at the finding's target count: it is counted.
    loose_total, _rows = await repo.list_credential_reuses(min_target_count=1)
    assert loose_total == 1
@pytest.mark.anyio
async def test_list_pagination_orders_by_target_count_desc(repo) -> None:
    """The finding with the higher ``target_count`` is listed first."""
    narrow_sha = _sha256("a")
    wide_sha = _sha256("b")

    async def record(digest: str, decky: str, service: str) -> None:
        # One reuse sighting; everything but secret/decky/service is fixed.
        await repo.upsert_credential_reuse(
            secret_sha256=digest, secret_kind="plaintext", principal="root",
            attacker_uuid=None, attacker_ip="1.1.1.1",
            decky=decky, service=service, attempt_count=1,
        )

    # secret a → 1 target
    await _seed_credential(repo, secret_sha256=narrow_sha)
    await record(narrow_sha, "d1", "ssh")

    # secret b → 2 targets
    await _seed_credential(repo, secret_sha256=wide_sha, service="ssh")
    await _seed_credential(
        repo, secret_sha256=wide_sha, service="ftp", decky_name="d2",
    )
    await record(wide_sha, "decky-01", "ssh")
    await record(wide_sha, "d2", "ftp")

    total, findings = await repo.list_credential_reuses(min_target_count=1)
    assert total == 2
    assert findings[0]["secret_sha256"] == wide_sha  # higher target_count first
@pytest.mark.anyio
async def test_get_by_id_roundtrip(repo) -> None:
    """The id returned by an upsert can be used to fetch the same finding."""
    digest = _sha256("rt")
    await _seed_credential(repo, secret_sha256=digest)
    upserted = await repo.upsert_credential_reuse(
        secret_sha256=digest,
        secret_kind="plaintext",
        principal="root",
        attacker_uuid=None,
        attacker_ip="1.1.1.1",
        decky="decky-01",
        service="ssh",
        attempt_count=1,
    )
    finding_id = upserted["id"]
    loaded = await repo.get_credential_reuse_by_id(finding_id)
    assert loaded is not None
    assert loaded["id"] == upserted["id"]
    assert loaded["secret_sha256"] == digest
    # "deckies" must come back as a list, not some other type.
    assert isinstance(loaded["deckies"], list)
@pytest.mark.anyio
async def test_get_by_id_missing_returns_none(repo) -> None:
    """Fetching an id that was never stored yields ``None``."""
    missing = await repo.get_credential_reuse_by_id("nope")
    assert missing is None
@pytest.mark.anyio
async def test_update_credential_attacker_uuid_backfills_only_nulls(repo) -> None:
    """The profiler hook must backfill attacker_uuid only on rows where it
    is currently null — pre-existing UUIDs must not be overwritten.

    Two credentials are seeded for the same attacker IP. The first
    backfill must report 2 updated rows, and a second backfill with a
    different UUID must report 0 and leave the first UUID in place.
    """
    sha = _sha256("backfill")
    await _seed_credential(repo, secret_sha256=sha, attacker_ip="9.9.9.9")
    await _seed_credential(
        repo, secret_sha256=sha, attacker_ip="9.9.9.9",
        service="ftp", decky_name="d2",
    )

    # Backfill: both null, both should update.
    n = await repo.update_credential_attacker_uuid("9.9.9.9", "uuid-9")
    assert n == 2

    # Second call: both already set, nothing should change.
    n2 = await repo.update_credential_attacker_uuid("9.9.9.9", "uuid-other")
    assert n2 == 0

    rows = await repo.get_credentials_for_attacker("9.9.9.9")
    # all() is vacuously true on an empty list, so without this guard the
    # final assertion would pass even if the lookup returned nothing.
    assert rows, "expected the seeded credentials for 9.9.9.9 to be returned"
    assert all(r["attacker_uuid"] == "uuid-9" for r in rows)
@pytest.mark.anyio
async def test_update_credential_attacker_uuid_no_match(repo) -> None:
    """Backfilling an IP with no stored credentials reports zero updates."""
    updated = await repo.update_credential_attacker_uuid("0.0.0.0", "uuid-x")
    assert updated == 0

View File

@@ -0,0 +1,168 @@
"""Credential model + repo tests — upsert, dedup, cross-service reuse."""
from __future__ import annotations
import hashlib
from pathlib import Path
import pytest
from decnet.web.db.factory import get_repository
@pytest.fixture
async def repo(tmp_path: Path):
    """Build and initialize a repository stored at ``tmp_path / "creds.db"``."""
    database_file = tmp_path / "creds.db"
    repository = get_repository(db_path=str(database_file))
    await repository.initialize()
    return repository
def _sha256(s: str) -> str:
    """Return the hex SHA-256 digest of *s* encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(s.encode("utf-8"))
    return hasher.hexdigest()
@pytest.mark.anyio
async def test_upsert_inserts_then_dedups(repo) -> None:
    """Upserting an identical payload twice yields one row counted twice."""
    credential = dict(
        attacker_ip="10.0.0.5",
        decky_name="decky-01",
        service="ssh",
        principal="root",
        secret_sha256=_sha256("hunter2"),
        secret_b64="aHVudGVyMg==",
        secret_printable="hunter2",
        fields={"user": "root"},
    )
    first_id = await repo.upsert_credential(credential)
    second_id = await repo.upsert_credential(credential)
    # Both calls must resolve to the same stored row.
    assert first_id == second_id

    stored = await repo.get_credentials()
    assert len(stored) == 1
    only_row = stored[0]
    assert only_row["attempt_count"] == 2
    assert only_row["fields"] == {"user": "root"}  # preserved
@pytest.mark.anyio
async def test_different_principal_creates_new_row(repo) -> None:
    """Payloads that differ only in ``principal`` are stored as two rows."""
    shared = {
        "attacker_ip": "10.0.0.5",
        "decky_name": "decky-01",
        "service": "ssh",
        "secret_sha256": _sha256("hunter2"),
        "secret_b64": "aHVudGVyMg==",
        "secret_printable": "hunter2",
        "fields": {},
    }
    for who in ("root", "admin"):
        await repo.upsert_credential({**shared, "principal": who})
    stored = await repo.get_credentials()
    assert len(stored) == 2
@pytest.mark.anyio
async def test_null_principal_dedups_independently(repo) -> None:
    """``principal=None`` dedups with itself but not with ``'root'``."""
    shared = {
        "attacker_ip": "10.0.0.5",
        "decky_name": "decky-01",
        "service": "ssh",
        "secret_sha256": _sha256("hunter2"),
        "secret_b64": "aHVudGVyMg==",
        "secret_printable": "hunter2",
        "fields": {},
    }
    # The two None upserts should collapse; "root" gets its own row.
    for who in (None, None, "root"):
        await repo.upsert_credential({**shared, "principal": who})

    stored = await repo.get_credentials()
    assert len(stored) == 2
    anonymous = next(row for row in stored if row["principal"] is None)
    assert anonymous["attempt_count"] == 2
@pytest.mark.anyio
async def test_cross_service_reuse_query(repo) -> None:
    """One secret tried on SSH, FTP and SMTP is reported for all three."""
    plaintext = "hunter2"
    digest = _sha256(plaintext)
    # (service, decky, principal) per attempt; the secret is the same.
    attempts = (
        ("ssh", "decky-01", "root"),
        ("ftp", "decky-02", "anonymous"),
        ("smtp", "decky-03", "acme.com"),
    )
    payloads = [
        {
            "attacker_ip": "10.0.0.5",
            "decky_name": decky,
            "service": svc,
            "principal": principal,
            "secret_sha256": digest,
            "secret_b64": "aHVudGVyMg==",
            "secret_printable": plaintext,
            "fields": {},
        }
        for svc, decky, principal in attempts
    ]
    for payload in payloads:
        await repo.upsert_credential(payload)

    matches = await repo.get_credential_attempts_for_secret(digest)
    seen_services = {match["service"] for match in matches}
    assert seen_services == {"ssh", "ftp", "smtp"}
@pytest.mark.anyio
async def test_get_credentials_for_attacker(repo) -> None:
    """Only the credentials recorded for the requested IP are returned."""
    shared = {
        "decky_name": "decky-01",
        "service": "ssh",
        "principal": "root",
        "secret_sha256": _sha256("hunter2"),
        "secret_b64": "aHVudGVyMg==",
        "secret_printable": "hunter2",
        "fields": {},
    }
    for source_ip in ("10.0.0.5", "10.0.0.6"):
        await repo.upsert_credential({**shared, "attacker_ip": source_ip})

    matches = await repo.get_credentials_for_attacker("10.0.0.5")
    assert len(matches) == 1
    assert matches[0]["attacker_ip"] == "10.0.0.5"
@pytest.mark.anyio
async def test_secret_kind_dedups_independently(repo) -> None:
    """Identical sha256 and principal with a different secret_kind = new row.

    The same digest under two kinds (for example a plaintext password and
    a Postgres md5 challenge response) describes two semantically distinct
    credentials, even though an actual collision is statistically
    impossible. Dedup must therefore treat ``secret_kind`` as part of
    the key.
    """
    shared = {
        "attacker_ip": "10.0.0.5",
        "decky_name": "decky-01",
        "service": "ssh",
        "principal": "root",
        "secret_sha256": _sha256("hunter2"),
        "secret_b64": "aHVudGVyMg==",
        "fields": {},
    }
    for kind in ("plaintext", "postgres_md5_challenge"):
        await repo.upsert_credential({**shared, "secret_kind": kind})

    stored = await repo.get_credentials()
    assert len(stored) == 2
    stored_kinds = {row["secret_kind"] for row in stored}
    assert stored_kinds == {"plaintext", "postgres_md5_challenge"}
@pytest.mark.anyio
async def test_filters(repo) -> None:
    """The ``service`` filter narrows both the listing and the total."""
    digest = _sha256("a")
    # Two credentials that differ only in the service they were tried on.
    for svc in ("ssh", "ftp"):
        await repo.upsert_credential({
            "attacker_ip": "10.0.0.5", "decky_name": "decky-01", "service": svc,
            "principal": "root", "secret_sha256": digest,
            "secret_printable": "a", "fields": {},
        })

    ssh_only = await repo.get_credentials(service="ssh")
    assert len(ssh_only) == 1
    assert ssh_only[0]["service"] == "ssh"

    ssh_total = await repo.get_total_credentials(service="ssh")
    assert ssh_total == 1
    overall_total = await repo.get_total_credentials()
    assert overall_total == 2

44
tests/db/test_factory.py Normal file
View File

@@ -0,0 +1,44 @@
"""
Unit tests for the repository factory — dispatch on DECNET_DB_TYPE.
"""
import pytest
from decnet.web.db.factory import get_repository
from decnet.web.db.sqlite.repository import SQLiteRepository
from decnet.web.db.mysql.repository import MySQLRepository
def test_factory_defaults_to_sqlite(monkeypatch, tmp_path):
    """With ``DECNET_DB_TYPE`` unset the factory returns a SQLiteRepository."""
    monkeypatch.delenv("DECNET_DB_TYPE", raising=False)
    sqlite_file = tmp_path / "t.db"
    built = get_repository(db_path=str(sqlite_file))
    assert isinstance(built, SQLiteRepository)
def test_factory_sqlite_explicit(monkeypatch, tmp_path):
    """``DECNET_DB_TYPE=sqlite`` selects the SQLiteRepository."""
    monkeypatch.setenv("DECNET_DB_TYPE", "sqlite")
    sqlite_file = tmp_path / "t.db"
    built = get_repository(db_path=str(sqlite_file))
    assert isinstance(built, SQLiteRepository)
def test_factory_mysql_branch(monkeypatch):
    """The MySQL branch builds a MySQLRepository with no server running.

    SQLAlchemy creates engines lazily: no socket is opened before the
    first query, so constructing the repository here is safe.
    """
    environment = {
        "DECNET_DB_TYPE": "mysql",
        "DECNET_DB_URL": "mysql+asyncmy://u:p@127.0.0.1:3306/x",
    }
    for variable, value in environment.items():
        monkeypatch.setenv(variable, value)
    built = get_repository()
    assert isinstance(built, MySQLRepository)
def test_factory_is_case_insensitive(monkeypatch, tmp_path):
    """A mixed-case ``DECNET_DB_TYPE`` ("SQLite") still selects SQLite."""
    monkeypatch.setenv("DECNET_DB_TYPE", "SQLite")
    sqlite_file = tmp_path / "t.db"
    built = get_repository(db_path=str(sqlite_file))
    assert isinstance(built, SQLiteRepository)
def test_factory_rejects_unknown_type(monkeypatch):
    """An unrecognised ``DECNET_DB_TYPE`` makes the factory raise ValueError."""
    monkeypatch.setenv("DECNET_DB_TYPE", "cassandra")
    expect_rejection = pytest.raises(ValueError, match="Unsupported database type")
    with expect_rejection:
        get_repository()

View File

@@ -0,0 +1,206 @@
"""
Schema-only tests for the AttackerIdentity table and the
attackers.identity_id FK.
The identities table ships empty in this PR; the clusterer that
populates it is a separate downstream effort. These tests verify only
that the schema lands correctly:
* the table exists after metadata.create_all()
* attackers.identity_id is nullable and indexed
* the FK references attacker_identities.uuid
* an attacker row may be inserted with identity_id=NULL
* an identity row may be inserted with all clusterer-populated columns NULL
If any of these regress, all downstream API, frontend, and clusterer
work stops. See development/IDENTITY_RESOLUTION.md §Schema.
"""
from __future__ import annotations
import sqlite3
import uuid
from datetime import datetime, timezone
import pytest
from sqlalchemy import inspect
from sqlmodel import Session
from decnet.web.db.models import Attacker, AttackerIdentity
from decnet.web.db.sqlite.database import get_sync_engine, init_db
@pytest.fixture
def db_path(tmp_path) -> str:
    """Run ``init_db`` on ``tmp_path / "schema.db"`` and return its path."""
    schema_file = str(tmp_path / "schema.db")
    init_db(schema_file)
    return schema_file
def test_attacker_identities_table_exists(db_path: str) -> None:
    """The initialized database contains an ``attacker_identities`` table."""
    table_names = inspect(get_sync_engine(db_path)).get_table_names()
    assert "attacker_identities" in table_names
def test_attackers_identity_id_column_present_and_nullable(db_path: str) -> None:
    """``attackers.identity_id`` exists and is reported as nullable."""
    reflected = inspect(get_sync_engine(db_path))
    by_name = {}
    for column in reflected.get_columns("attackers"):
        by_name[column["name"]] = column
    assert "identity_id" in by_name, "attackers.identity_id column missing"
    identity_column = by_name["identity_id"]
    assert identity_column["nullable"] is True, (
        "attackers.identity_id must be nullable — clusterer hasn't run yet on existing rows"
    )
def test_attackers_identity_id_is_indexed(db_path: str) -> None:
    """At least one index on ``attackers`` covers ``identity_id``."""
    reflected = inspect(get_sync_engine(db_path))
    covered = set()
    for index in reflected.get_indexes("attackers"):
        covered.update(index["column_names"])
    assert "identity_id" in covered, (
        "attackers.identity_id needs an index for join performance "
        "(IdentityDetail aggregates by identity_id; without an index "
        "every lookup is a full scan)"
    )
def test_attackers_identity_id_fk_targets_attacker_identities(db_path: str) -> None:
    """The FK on ``attackers.identity_id`` points at ``attacker_identities.uuid``."""
    reflected = inspect(get_sync_engine(db_path))
    matching = []
    for foreign_key in reflected.get_foreign_keys("attackers"):
        if "identity_id" in foreign_key["constrained_columns"]:
            matching.append(foreign_key)
    assert matching, "no FK on attackers.identity_id"
    first_match = matching[0]
    assert first_match["referred_table"] == "attacker_identities"
    assert first_match["referred_columns"] == ["uuid"]
def test_identity_schema_version_default_is_1(db_path: str) -> None:
    """
    A freshly inserted identity must carry schema_version == 1.

    The version field is required from day one: V2 federation gossip
    will share identity vectors across operators, and changing the
    feature definitions without a version silently poisons receivers.
    """
    fresh_uuid = str(uuid.uuid4())
    with Session(get_sync_engine(db_path)) as session:
        row = AttackerIdentity(uuid=fresh_uuid)
        session.add(row)
        session.commit()
        # Reload so the value read below is the one the database stored.
        session.refresh(row)
        assert row.schema_version == 1
def test_attacker_can_be_inserted_with_null_identity_id(db_path: str) -> None:
    """
    An attacker row saved without identity_id keeps it as NULL.

    Existing code paths (profiler, correlator) upsert attackers without
    ever setting identity_id and MUST keep working unchanged; the column
    stays NULL until the clusterer runs.
    """
    with Session(get_sync_engine(db_path)) as session:
        seen_at = datetime.now(timezone.utc)
        attacker_values = {
            "uuid": str(uuid.uuid4()),
            "ip": "203.0.113.4",
            "first_seen": seen_at,
            "last_seen": seen_at,
        }
        observation = Attacker(**attacker_values)
        session.add(observation)
        session.commit()
        session.refresh(observation)
        assert observation.identity_id is None
def test_identity_with_all_clusterer_fields_null(db_path: str) -> None:
    """
    An identity created with only a uuid has every clusterer column NULL.

    The table ships empty, and the clusterer may later write rows with
    most fields unset (e.g. before fingerprint summaries have been
    computed), so each clusterer-populated column has to accept NULL.
    """
    clusterer_columns = (
        "campaign_id",
        "first_seen_at",
        "last_seen_at",
        "confidence",
        "ja3_hashes",
        "hassh_hashes",
        "payload_simhashes",
        "c2_endpoints",
        "kd_digraph_simhash",
        "merged_into_uuid",
        "notes",
    )
    with Session(get_sync_engine(db_path)) as session:
        row = AttackerIdentity(uuid=str(uuid.uuid4()))
        session.add(row)
        session.commit()
        session.refresh(row)
        for column in clusterer_columns:
            stored = getattr(row, column)
            assert stored is None, (
                f"AttackerIdentity.{column} must default to None — "
                "the table ships empty pre-clusterer"
            )
        # observation_count is denormalized; defaults to 0 (not NULL).
        assert row.observation_count == 0
def test_attacker_identity_link_round_trip(db_path: str) -> None:
    """
    Store an identity, attach an attacker to it through identity_id,
    and read the link back.

    This is a schema-level smoke test; it talks to the models directly
    and does not go through the production repo layer.
    """
    with Session(get_sync_engine(db_path)) as session:
        parent = AttackerIdentity(uuid=str(uuid.uuid4()))
        session.add(parent)
        session.commit()

        seen_at = datetime.now(timezone.utc)
        observation = Attacker(
            uuid=str(uuid.uuid4()),
            ip="203.0.113.5",
            first_seen=seen_at,
            last_seen=seen_at,
            identity_id=parent.uuid,
        )
        session.add(observation)
        session.commit()
        session.refresh(observation)
        assert observation.identity_id == parent.uuid
def test_identity_id_fk_constraint_blocks_orphans(db_path: str) -> None:
    """
    Inserting an attacker with identity_id pointing at a nonexistent
    identity must fail. The clusterer should never write an orphan
    link; the schema enforces that contract.

    SQLite's PRAGMA foreign_keys is off by default at the connection
    level; we enable it explicitly here so the test reflects the
    contract production code relies on (via the same PRAGMA on its
    connections).
    """
    # Local import: a sqlite3 connection used directly as a context manager
    # only commits or rolls back and never closes the handle. closing()
    # releases the connection when the block exits.
    from contextlib import closing

    # One timestamp serves first_seen, last_seen and updated_at; the
    # exact values are irrelevant because the insert has to be rejected.
    timestamp = datetime.now(timezone.utc).isoformat()
    with closing(sqlite3.connect(db_path)) as conn:
        conn.execute("PRAGMA foreign_keys = ON")
        with pytest.raises(sqlite3.IntegrityError):
            conn.execute(
                "INSERT INTO attackers (uuid, ip, first_seen, last_seen, "
                "event_count, service_count, decky_count, services, deckies, "
                "is_traversal, bounty_count, credential_count, fingerprints, "
                "commands, updated_at, identity_id) VALUES "
                "(?, ?, ?, ?, 0, 0, 0, '[]', '[]', 0, 0, 0, '[]', '[]', ?, ?)",
                (
                    str(uuid.uuid4()),
                    "203.0.113.6",
                    timestamp,
                    timestamp,
                    timestamp,
                    "ffffffff-ffff-ffff-ffff-ffffffffffff",  # nonexistent identity
                ),
            )
            # Kept inside the raises block so a violation reported only at
            # commit time would also satisfy the expectation.
            conn.commit()