refactor(tests): move flat tests/*.py into per-subsystem subfolders
Groups every flat test_*.py under the module it exercises, matching the
existing tests/{profiler,sniffer,prober,collector,correlation,cli,web,
topology,swarm,bus,updater,api,docker,geoip,...} layout. New folders:
services/, fleet/, config/, logging/, db/ (+ db/mysql/), telemetry/,
mutator/, core/.
Path-dependent __file__ references bumped an extra .parent in four
files that moved one level deeper:
- tests/sniffer/test_sniffer_ja3.py (template path)
- tests/services/test_ssh_capture_emit.py (template path)
- tests/cli/test_mode_gating.py (REPO root)
- tests/web/test_env_lazy_jwt.py (repo var)
Also drops two SQLite runtime artifacts (test_decnet.db-{shm,wal}) that
were leaking into the repo from a previous test run.
Fixes two test_service_isolation cases that patched asyncio.sleep (no
longer on the profiler main-loop hot path — same pre-existing bug I
fixed earlier in test_attacker_worker.py) by patching asyncio.wait_for
and passing interval=0.
This commit is contained in:
734
tests/profiler/test_attacker_worker.py
Normal file
734
tests/profiler/test_attacker_worker.py
Normal file
@@ -0,0 +1,734 @@
|
||||
"""
|
||||
Tests for decnet/profiler/worker.py
|
||||
|
||||
Covers:
|
||||
- cold start (_incremental_update() on an uninitialized state): full build on first run, cursor persistence
|
||||
- _incremental_update(): delta processing, affected-IP-only updates
|
||||
- _update_profiles(): traversal detection, bounty merging
|
||||
- _extract_commands_from_events(): command harvesting from LogEvent objects
|
||||
- _build_record(): record assembly from engine events + bounties
|
||||
- _first_contact_deckies(): ordering for single-decky attackers
|
||||
- attacker_profile_worker(): cancellation and error handling
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from decnet.correlation.parser import LogEvent
|
||||
from decnet.logging.syslog_formatter import SEVERITY_INFO, format_rfc5424
|
||||
from decnet.profiler.worker import (
|
||||
_BATCH_SIZE,
|
||||
_STATE_KEY,
|
||||
_WorkerState,
|
||||
_build_record,
|
||||
_extract_commands_from_events,
|
||||
_first_contact_deckies,
|
||||
_incremental_update,
|
||||
_update_profiles,
|
||||
attacker_profile_worker,
|
||||
)
|
||||
|
||||
# ─── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _cancel_after(ticks: int):
|
||||
"""Cancel the worker loop after N ``asyncio.wait_for`` calls.
|
||||
|
||||
The worker's main tick is ``asyncio.wait_for(shutdown.wait(), timeout=
|
||||
interval)``. These tests want to let the loop body run a few times
|
||||
then unwind; patching wait_for is the natural knob. On call N the
|
||||
patched wait_for raises ``CancelledError``, which bubbles up through
|
||||
the worker and satisfies the test's ``pytest.raises`` assertion.
|
||||
|
||||
Earlier revisions patched ``asyncio.sleep`` — that hasn't been on the
|
||||
worker's hot path since the event-driven shutdown refactor, so the
|
||||
sleep patch silently no-op'd and the tests hung on the real 30 s
|
||||
``wait_for`` timeout.
|
||||
"""
|
||||
call_count = 0
|
||||
real_wait_for = asyncio.wait_for
|
||||
|
||||
async def _patched(awaitable, timeout):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count >= ticks:
|
||||
# Close the coroutine so asyncio doesn't warn about it never
|
||||
# being awaited.
|
||||
if asyncio.iscoroutine(awaitable):
|
||||
awaitable.close()
|
||||
raise asyncio.CancelledError()
|
||||
return await real_wait_for(awaitable, timeout)
|
||||
|
||||
return patch("decnet.profiler.worker.asyncio.wait_for", side_effect=_patched)
|
||||
|
||||
|
||||
_TS1 = "2026-04-04T10:00:00+00:00"
|
||||
_TS2 = "2026-04-04T10:05:00+00:00"
|
||||
_TS3 = "2026-04-04T10:10:00+00:00"
|
||||
|
||||
_DT1 = datetime.fromisoformat(_TS1)
|
||||
_DT2 = datetime.fromisoformat(_TS2)
|
||||
_DT3 = datetime.fromisoformat(_TS3)
|
||||
|
||||
|
||||
def _make_raw_line(
    service: str = "ssh",
    hostname: str = "decky-01",
    event_type: str = "connection",
    src_ip: str = "1.2.3.4",
    timestamp: str = _TS1,
    **extra: str,
) -> str:
    """Format one synthetic event as an RFC 5424 line via ``format_rfc5424``.

    ``timestamp`` is an ISO-8601 string and is parsed to a ``datetime``
    before formatting. The severity is always ``SEVERITY_INFO``. Any
    ``extra`` keyword arguments are forwarded unchanged after ``src_ip``.
    """
    parsed_timestamp = datetime.fromisoformat(timestamp)
    return format_rfc5424(
        service=service,
        hostname=hostname,
        event_type=event_type,
        severity=SEVERITY_INFO,
        timestamp=parsed_timestamp,
        src_ip=src_ip,
        **extra,
    )
|
||||
|
||||
|
||||
def _make_log_row(
    row_id: int = 1,
    raw_line: str = "",
    attacker_ip: str = "1.2.3.4",
    service: str = "ssh",
    event_type: str = "connection",
    decky: str = "decky-01",
    timestamp: datetime = _DT1,
    fields: str = "{}",
) -> dict:
    """Build a dict shaped like the log rows the mocked repo returns.

    When ``raw_line`` is empty, a matching line is generated with
    ``_make_raw_line`` from the other arguments; a non-empty ``raw_line`` is
    stored as given.
    """
    line = raw_line or _make_raw_line(
        service=service,
        hostname=decky,
        event_type=event_type,
        src_ip=attacker_ip,
        timestamp=timestamp.isoformat(),
    )
    row = {
        "id": row_id,
        "raw_line": line,
        "attacker_ip": attacker_ip,
        "service": service,
        "event_type": event_type,
        "decky": decky,
        "timestamp": timestamp,
        "fields": fields,
    }
    return row
|
||||
|
||||
|
||||
def _make_repo(logs=None, bounties=None, bounties_for_ips=None, max_log_id=0, saved_state=None):
|
||||
repo = MagicMock()
|
||||
repo.get_all_bounties_by_ip = AsyncMock(return_value=bounties or {})
|
||||
repo.get_bounties_for_ips = AsyncMock(return_value=bounties_for_ips or {})
|
||||
repo.get_max_log_id = AsyncMock(return_value=max_log_id)
|
||||
# Return provided logs on first call (simulating a single page < BATCH_SIZE), then [] to end loop
|
||||
_log_pages = [logs or [], []]
|
||||
repo.get_logs_after_id = AsyncMock(side_effect=_log_pages)
|
||||
repo.get_state = AsyncMock(return_value=saved_state)
|
||||
repo.set_state = AsyncMock()
|
||||
repo.upsert_attacker = AsyncMock(return_value="mock-uuid")
|
||||
repo.upsert_attacker_behavior = AsyncMock()
|
||||
return repo
|
||||
|
||||
|
||||
def _make_log_event(
    ip: str,
    decky: str,
    service: str = "ssh",
    event_type: str = "connection",
    timestamp: datetime = _DT1,
    fields: dict | None = None,
) -> LogEvent:
    """Construct a ``LogEvent`` for ``ip`` on ``decky`` with an empty ``raw``.

    ``fields`` falls back to an empty dict when omitted (or otherwise falsy).
    """
    event_fields = fields or {}
    return LogEvent(
        timestamp=timestamp,
        decky=decky,
        service=service,
        event_type=event_type,
        attacker_ip=ip,
        fields=event_fields,
        raw="",
    )
|
||||
|
||||
|
||||
# ─── _first_contact_deckies ───────────────────────────────────────────────────
|
||||
|
||||
class TestFirstContactDeckies:
    """``_first_contact_deckies``: one entry per decky, earliest contact first."""

    def test_single_decky(self):
        """A lone event produces a one-element list."""
        only_event = _make_log_event("1.1.1.1", "decky-01", timestamp=_DT1)
        assert _first_contact_deckies([only_event]) == ["decky-01"]

    def test_multiple_deckies_ordered_by_first_contact(self):
        """Result order follows timestamps even when input is out of order."""
        later_hit = _make_log_event("1.1.1.1", "decky-02", timestamp=_DT2)
        earlier_hit = _make_log_event("1.1.1.1", "decky-01", timestamp=_DT1)
        ordered = _first_contact_deckies([later_hit, earlier_hit])
        assert ordered == ["decky-01", "decky-02"]

    def test_revisit_does_not_duplicate(self):
        """Returning to an already-seen decky does not add a second entry."""
        visits = [
            ("decky-01", _DT1),
            ("decky-02", _DT2),
            ("decky-01", _DT3),  # revisit
        ]
        events = [
            _make_log_event("1.1.1.1", decky, timestamp=seen_at)
            for decky, seen_at in visits
        ]
        result = _first_contact_deckies(events)
        assert result == ["decky-01", "decky-02"]
        assert result.count("decky-01") == 1
|
||||
|
||||
|
||||
# ─── _extract_commands_from_events ───────────────────────────────────────────
|
||||
|
||||
class TestExtractCommandsFromEvents:
    """``_extract_commands_from_events``: which events yield command entries."""

    def test_extracts_command_field(self):
        """A ``command`` event contributes its command, service and decky."""
        event = _make_log_event(
            "1.1.1.1", "decky-01",
            service="ssh", event_type="command", timestamp=_DT1,
            fields={"command": "id"},
        )
        extracted = _extract_commands_from_events([event])
        assert len(extracted) == 1
        entry = extracted[0]
        assert entry["command"] == "id"
        assert entry["service"] == "ssh"
        assert entry["decky"] == "decky-01"

    def test_extracts_query_field(self):
        """A ``query`` field is reported under the ``command`` key."""
        event = _make_log_event(
            "2.2.2.2", "decky-01",
            service="mysql", event_type="query", timestamp=_DT1,
            fields={"query": "SELECT * FROM users"},
        )
        extracted = _extract_commands_from_events([event])
        assert len(extracted) == 1
        assert extracted[0]["command"] == "SELECT * FROM users"

    def test_extracts_input_field(self):
        """An ``input`` field is reported under the ``command`` key."""
        event = _make_log_event(
            "3.3.3.3", "decky-01",
            service="ssh", event_type="input", timestamp=_DT1,
            fields={"input": "ls -la"},
        )
        extracted = _extract_commands_from_events([event])
        assert len(extracted) == 1
        assert extracted[0]["command"] == "ls -la"

    def test_non_command_event_type_ignored(self):
        """A ``connection`` event is skipped even if it has a command field."""
        event = _make_log_event(
            "1.1.1.1", "decky-01",
            service="ssh", event_type="connection", timestamp=_DT1,
            fields={"command": "id"},
        )
        assert _extract_commands_from_events([event]) == []

    def test_no_command_field_skipped(self):
        """A ``command`` event without a recognised field yields nothing."""
        event = _make_log_event(
            "1.1.1.1", "decky-01",
            service="ssh", event_type="command", timestamp=_DT1,
            fields={"other": "stuff"},
        )
        assert _extract_commands_from_events([event]) == []

    def test_multiple_commands_all_extracted(self):
        """Every qualifying event in the input is represented in the output."""
        events = [
            _make_log_event("5.5.5.5", "decky-01", "ssh", "command", seen_at, {"command": text})
            for seen_at, text in ((_DT1, "id"), (_DT2, "uname -a"))
        ]
        extracted = _extract_commands_from_events(events)
        assert len(extracted) == 2
        assert {entry["command"] for entry in extracted} == {"id", "uname -a"}

    def test_timestamp_serialized_to_string(self):
        """The entry's timestamp is a ``str``, not a ``datetime``."""
        event = _make_log_event(
            "1.1.1.1", "decky-01",
            service="ssh", event_type="command", timestamp=_DT1,
            fields={"command": "pwd"},
        )
        extracted = _extract_commands_from_events([event])
        assert isinstance(extracted[0]["timestamp"], str)
|
||||
|
||||
|
||||
# ─── _build_record ────────────────────────────────────────────────────────────
|
||||
|
||||
class TestBuildRecord:
    """``_build_record``: shape and contents of the assembled attacker record."""

    def _events(self, ip="1.1.1.1"):
        """Two events for ``ip`` on one decky, using two different services."""
        ssh_event = _make_log_event(ip, "decky-01", "ssh", "conn", _DT1)
        http_event = _make_log_event(ip, "decky-01", "http", "req", _DT2)
        return [ssh_event, http_event]

    def test_basic_fields(self):
        """IP and the event/service/decky counters reflect the input events."""
        record = _build_record("1.1.1.1", self._events(), None, [], [])
        assert record["ip"] == "1.1.1.1"
        assert record["event_count"] == 2
        assert record["service_count"] == 2
        assert record["decky_count"] == 1

    def test_first_last_seen(self):
        """``first_seen`` / ``last_seen`` are the earliest and latest timestamps."""
        record = _build_record("1.1.1.1", self._events(), None, [], [])
        assert record["first_seen"] == _DT1
        assert record["last_seen"] == _DT2

    def test_services_json_sorted(self):
        """``services`` is a JSON list already in sorted order."""
        record = _build_record("1.1.1.1", self._events(), None, [], [])
        decoded = json.loads(record["services"])
        assert sorted(decoded) == decoded

    def test_no_traversal(self):
        """Without a traversal object the traversal fields stay empty."""
        record = _build_record("1.1.1.1", self._events(), None, [], [])
        assert record["is_traversal"] is False
        assert record["traversal_path"] is None

    def test_with_traversal(self):
        """A two-hop traversal sets the flag, the path string and the deckies."""
        from decnet.correlation.graph import AttackerTraversal, TraversalHop

        traversal = AttackerTraversal(
            "1.1.1.1",
            [
                TraversalHop(_DT1, "decky-01", "ssh", "conn"),
                TraversalHop(_DT2, "decky-02", "http", "req"),
            ],
        )
        events = [
            _make_log_event("1.1.1.1", "decky-01", timestamp=_DT1),
            _make_log_event("1.1.1.1", "decky-02", timestamp=_DT2),
        ]
        record = _build_record("1.1.1.1", events, traversal, [], [])
        assert record["is_traversal"] is True
        assert record["traversal_path"] == "decky-01 → decky-02"
        assert json.loads(record["deckies"]) == ["decky-01", "decky-02"]

    def test_bounty_counts(self):
        """Bounties are counted in total and per credential; fingerprints are kept."""
        bounties = [
            {"bounty_type": "credential", "attacker_ip": "1.1.1.1"},
            {"bounty_type": "credential", "attacker_ip": "1.1.1.1"},
            {"bounty_type": "fingerprint", "attacker_ip": "1.1.1.1"},
        ]
        record = _build_record("1.1.1.1", self._events(), None, bounties, [])
        assert record["bounty_count"] == 3
        assert record["credential_count"] == 2
        fingerprints = json.loads(record["fingerprints"])
        assert len(fingerprints) == 1
        assert fingerprints[0]["bounty_type"] == "fingerprint"

    def test_commands_serialized(self):
        """The supplied command list round-trips through the JSON ``commands`` field."""
        commands = [
            {
                "service": "ssh",
                "decky": "decky-01",
                "command": "id",
                "timestamp": "2026-04-04T10:00:00",
            }
        ]
        record = _build_record("1.1.1.1", self._events(), None, [], commands)
        decoded = json.loads(record["commands"])
        assert len(decoded) == 1
        assert decoded[0]["command"] == "id"

    def test_updated_at_is_utc_datetime(self):
        """``updated_at`` is a timezone-aware ``datetime``."""
        record = _build_record("1.1.1.1", self._events(), None, [], [])
        stamp = record["updated_at"]
        assert isinstance(stamp, datetime)
        assert stamp.tzinfo is not None
|
||||
|
||||
|
||||
# ─── cold start via _incremental_update (uninitialized state) ────────────────
|
||||
|
||||
class TestColdStart:
    """First ``_incremental_update`` pass on a fresh ``_WorkerState`` (cold start)."""

    @pytest.mark.asyncio
    async def test_cold_start_builds_all_profiles(self):
        """Each attacker IP in the log rows is upserted and the cursor saved."""
        attacker_ips = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
        rows = [
            _make_log_row(
                row_id=row_id,
                raw_line=_make_raw_line("ssh", "decky-01", "conn", ip, _TS1),
                attacker_ip=ip,
            )
            for row_id, ip in enumerate(attacker_ips, start=1)
        ]
        repo = _make_repo(logs=rows, max_log_id=3)
        state = _WorkerState()

        await _incremental_update(repo, state)

        assert state.initialized is True
        assert state.last_log_id == 3
        assert repo.upsert_attacker.await_count == 3
        upserted_ips = {
            call.args[0]["ip"] for call in repo.upsert_attacker.call_args_list
        }
        assert upserted_ips == {"1.1.1.1", "2.2.2.2", "3.3.3.3"}
        repo.set_state.assert_awaited_with(_STATE_KEY, {"last_log_id": 3})

    @pytest.mark.asyncio
    async def test_cold_start_empty_db(self):
        """With no rows the state is still initialised and persisted."""
        repo = _make_repo(logs=[], max_log_id=0)
        state = _WorkerState()

        await _incremental_update(repo, state)

        assert state.initialized is True
        assert state.last_log_id == 0
        repo.upsert_attacker.assert_not_awaited()
        repo.set_state.assert_awaited()

    @pytest.mark.asyncio
    async def test_cold_start_traversal_detected(self):
        """One IP touching two deckies is recorded as a traversal."""
        hits = [
            (1, "ssh", "decky-01", "conn", _TS1),
            (2, "http", "decky-02", "req", _TS2),
        ]
        rows = [
            _make_log_row(
                row_id=row_id,
                raw_line=_make_raw_line(service, decky, event_type, "5.5.5.5", stamp),
                attacker_ip="5.5.5.5",
                decky=decky,
            )
            for row_id, service, decky, event_type, stamp in hits
        ]
        repo = _make_repo(logs=rows, max_log_id=2)
        state = _WorkerState()

        await _incremental_update(repo, state)

        record = repo.upsert_attacker.call_args.args[0]
        assert record["is_traversal"] is True
        assert "decky-01" in record["traversal_path"]
        assert "decky-02" in record["traversal_path"]

    @pytest.mark.asyncio
    async def test_cold_start_bounties_merged(self):
        """Bounties returned for the IP are counted in the upserted record."""
        line = _make_raw_line("ssh", "decky-01", "conn", "8.8.8.8", _TS1)
        ip_bounties = [
            {"bounty_type": "credential", "attacker_ip": "8.8.8.8", "payload": {}},
            {"bounty_type": "fingerprint", "attacker_ip": "8.8.8.8", "payload": {"ja3": "abc"}},
        ]
        repo = _make_repo(
            logs=[_make_log_row(row_id=1, raw_line=line, attacker_ip="8.8.8.8")],
            max_log_id=1,
            bounties_for_ips={"8.8.8.8": ip_bounties},
        )
        state = _WorkerState()

        await _incremental_update(repo, state)

        record = repo.upsert_attacker.call_args.args[0]
        assert record["bounty_count"] == 2
        assert record["credential_count"] == 1

    @pytest.mark.asyncio
    async def test_cold_start_commands_extracted(self):
        """A command event in the log ends up in the record's ``commands`` JSON."""
        line = _make_raw_line(
            "ssh", "decky-01", "command", "9.9.9.9", _TS1, command="cat /etc/passwd"
        )
        command_row = _make_log_row(
            row_id=1,
            raw_line=line,
            attacker_ip="9.9.9.9",
            event_type="command",
            fields=json.dumps({"command": "cat /etc/passwd"}),
        )
        repo = _make_repo(logs=[command_row], max_log_id=1)
        state = _WorkerState()

        await _incremental_update(repo, state)

        record = repo.upsert_attacker.call_args.args[0]
        commands = json.loads(record["commands"])
        assert len(commands) == 1
        assert commands[0]["command"] == "cat /etc/passwd"
|
||||
|
||||
|
||||
# ─── _incremental_update ────────────────────────────────────────────────────
|
||||
|
||||
class TestIncrementalUpdate:
    """``_incremental_update`` on an already-initialised worker state."""

    @pytest.mark.asyncio
    async def test_no_new_logs_skips_upsert(self):
        """No rows: nothing is upserted and the cursor is re-saved unchanged."""
        repo = _make_repo()
        state = _WorkerState(initialized=True, last_log_id=10)

        await _incremental_update(repo, state)

        repo.upsert_attacker.assert_not_awaited()
        repo.set_state.assert_awaited_with(_STATE_KEY, {"last_log_id": 10})

    @pytest.mark.asyncio
    async def test_only_affected_ips_upserted(self):
        """Pre-populate engine with IP-A, then feed new logs only for IP-B."""
        state = _WorkerState(initialized=True, last_log_id=5)
        # IP-A is already known to the engine before this cycle.
        state.engine.ingest(_make_raw_line("ssh", "decky-01", "conn", "1.1.1.1", _TS1))

        # The new batch mentions IP-B only.
        ip_b_row = _make_log_row(
            row_id=6,
            raw_line=_make_raw_line("ssh", "decky-01", "conn", "2.2.2.2", _TS2),
            attacker_ip="2.2.2.2",
        )
        repo = _make_repo()
        repo.get_logs_after_id = AsyncMock(return_value=[ip_b_row])

        await _incremental_update(repo, state)

        assert repo.upsert_attacker.await_count == 1
        assert repo.upsert_attacker.call_args.args[0]["ip"] == "2.2.2.2"

    @pytest.mark.asyncio
    async def test_merges_with_existing_engine_state(self):
        """Engine has 2 events for IP. New batch adds 1 more. Record should show event_count=3."""
        state = _WorkerState(initialized=True, last_log_id=2)
        for line in (
            _make_raw_line("ssh", "decky-01", "conn", "1.1.1.1", _TS1),
            _make_raw_line("http", "decky-01", "req", "1.1.1.1", _TS2),
        ):
            state.engine.ingest(line)

        third_row = _make_log_row(
            row_id=3,
            raw_line=_make_raw_line("ftp", "decky-01", "login", "1.1.1.1", _TS3),
            attacker_ip="1.1.1.1",
        )
        repo = _make_repo()
        repo.get_logs_after_id = AsyncMock(return_value=[third_row])

        await _incremental_update(repo, state)

        record = repo.upsert_attacker.call_args.args[0]
        assert record["event_count"] == 3
        assert record["ip"] == "1.1.1.1"

    @pytest.mark.asyncio
    async def test_cursor_persisted_after_update(self):
        """The highest processed row id becomes the in-memory and saved cursor."""
        latest_row = _make_log_row(
            row_id=42,
            raw_line=_make_raw_line("ssh", "decky-01", "conn", "1.1.1.1", _TS1),
            attacker_ip="1.1.1.1",
        )
        repo = _make_repo()
        repo.get_logs_after_id = AsyncMock(return_value=[latest_row])
        state = _WorkerState(initialized=True, last_log_id=41)

        await _incremental_update(repo, state)

        assert state.last_log_id == 42
        repo.set_state.assert_awaited_with(_STATE_KEY, {"last_log_id": 42})

    @pytest.mark.asyncio
    async def test_traversal_detected_across_cycles(self):
        """IP hits decky-01 during cold start, decky-02 in incremental → traversal."""
        state = _WorkerState(initialized=True, last_log_id=1)
        state.engine.ingest(_make_raw_line("ssh", "decky-01", "conn", "5.5.5.5", _TS1))

        second_hop_row = _make_log_row(
            row_id=2,
            raw_line=_make_raw_line("http", "decky-02", "req", "5.5.5.5", _TS2),
            attacker_ip="5.5.5.5",
        )
        repo = _make_repo()
        repo.get_logs_after_id = AsyncMock(return_value=[second_hop_row])

        await _incremental_update(repo, state)

        record = repo.upsert_attacker.call_args.args[0]
        assert record["is_traversal"] is True
        assert "decky-01" in record["traversal_path"]
        assert "decky-02" in record["traversal_path"]

    @pytest.mark.asyncio
    async def test_batch_loop_processes_all(self):
        """First batch returns BATCH_SIZE rows, second returns fewer — all processed."""
        full_batch = [
            _make_log_row(
                row_id=index + 1,
                raw_line=_make_raw_line("ssh", "decky-01", "conn", f"10.0.0.{index}", _TS1),
                attacker_ip=f"10.0.0.{index}",
            )
            for index in range(_BATCH_SIZE)
        ]
        short_batch = [
            _make_log_row(
                row_id=_BATCH_SIZE + 1,
                raw_line=_make_raw_line("ssh", "decky-01", "conn", "10.0.1.1", _TS2),
                attacker_ip="10.0.1.1",
            ),
        ]
        pending_pages = iter([full_batch, short_batch])

        async def mock_get_logs(last_id, limit=_BATCH_SIZE):
            # Hand out the two prepared pages in order, then empty pages.
            return next(pending_pages, [])

        repo = _make_repo()
        repo.get_logs_after_id = AsyncMock(side_effect=mock_get_logs)
        state = _WorkerState(initialized=True, last_log_id=0)

        await _incremental_update(repo, state)

        assert state.last_log_id == _BATCH_SIZE + 1
        assert repo.upsert_attacker.await_count == _BATCH_SIZE + 1

    @pytest.mark.asyncio
    async def test_bounties_fetched_only_for_affected_ips(self):
        """Bounties are looked up once, for exactly the IPs in the new rows."""
        new_rows = [
            _make_log_row(
                row_id=row_id,
                raw_line=_make_raw_line("ssh", "decky-01", "conn", ip, stamp),
                attacker_ip=ip,
            )
            for row_id, ip, stamp in ((1, "1.1.1.1", _TS1), (2, "2.2.2.2", _TS2))
        ]
        repo = _make_repo()
        repo.get_logs_after_id = AsyncMock(return_value=new_rows)
        state = _WorkerState(initialized=True, last_log_id=0)

        await _incremental_update(repo, state)

        repo.get_bounties_for_ips.assert_awaited_once()
        requested_ips = repo.get_bounties_for_ips.call_args.args[0]
        assert requested_ips == {"1.1.1.1", "2.2.2.2"}

    @pytest.mark.asyncio
    async def test_uninitialized_state_runs_full_cursor_sweep(self):
        """A default state is initialised, advanced and produces one upsert."""
        only_row = _make_log_row(
            row_id=1,
            raw_line=_make_raw_line("ssh", "decky-01", "conn", "1.1.1.1", _TS1),
            attacker_ip="1.1.1.1",
        )
        repo = _make_repo(logs=[only_row], max_log_id=1)
        state = _WorkerState()

        await _incremental_update(repo, state)

        assert state.initialized is True
        assert state.last_log_id == 1
        repo.upsert_attacker.assert_awaited_once()
|
||||
|
||||
|
||||
# ─── attacker_profile_worker ────────────────────────────────────────────────
|
||||
|
||||
class TestAttackerProfileWorker:
    """``attacker_profile_worker``: cancellation, error tolerance, cursor restore."""

    @staticmethod
    async def _states_seen_by_update(repo):
        """Run the worker briefly and return what ``_incremental_update`` received.

        Each entry is ``(state.last_log_id, state.initialized)`` as observed
        at call time. The loop is stopped via ``_cancel_after(2)``.
        """
        seen = []

        async def _record(_repo, state):
            seen.append((state.last_log_id, state.initialized))

        with _cancel_after(2), patch(
            "decnet.profiler.worker._incremental_update", side_effect=_record
        ), pytest.raises(asyncio.CancelledError):
            await attacker_profile_worker(repo, interval=0)
        return seen

    @pytest.mark.asyncio
    async def test_worker_cancels_cleanly(self):
        """Cancelling the worker task surfaces ``CancelledError`` to the awaiter."""
        worker_task = asyncio.create_task(attacker_profile_worker(_make_repo()))
        # Yield once so the task gets a chance to start before it is cancelled.
        await asyncio.sleep(0)
        worker_task.cancel()
        with pytest.raises(asyncio.CancelledError):
            await worker_task

    @pytest.mark.asyncio
    async def test_worker_handles_update_error_without_crashing(self):
        """A failing update does not escape; only the forced cancel ends the loop."""
        repo = _make_repo()

        async def bad_update(_repo, _state):
            raise RuntimeError("DB exploded")

        with _cancel_after(2), patch(
            "decnet.profiler.worker._incremental_update", side_effect=bad_update
        ), pytest.raises(asyncio.CancelledError):
            await attacker_profile_worker(repo, interval=0)

    @pytest.mark.asyncio
    async def test_worker_calls_update_after_sleep(self):
        """The update routine runs at least once before the loop is cancelled."""
        repo = _make_repo()
        update_calls = []

        async def mock_update(_repo, _state):
            update_calls.append(True)

        with _cancel_after(2), patch(
            "decnet.profiler.worker._incremental_update", side_effect=mock_update
        ), pytest.raises(asyncio.CancelledError):
            await attacker_profile_worker(repo, interval=0)

        assert len(update_calls) >= 1

    @pytest.mark.asyncio
    async def test_cursor_restored_from_db_on_startup(self):
        """Worker loads saved last_log_id from DB and passes it to _incremental_update."""
        repo = _make_repo(saved_state={"last_log_id": 99})

        captured_states = await self._states_seen_by_update(repo)

        assert captured_states, "_incremental_update never called"
        restored_id, initialized = captured_states[0]
        assert restored_id == 99
        assert initialized is True

    @pytest.mark.asyncio
    async def test_no_saved_cursor_starts_from_zero(self):
        """When get_state returns None, worker starts fresh from log ID 0."""
        repo = _make_repo(saved_state=None)

        captured_states = await self._states_seen_by_update(repo)

        assert captured_states, "_incremental_update never called"
        restored_id, initialized = captured_states[0]
        assert restored_id == 0
        assert initialized is False
|
||||
|
||||
|
||||
# ─── JA3 bounty extraction from ingester ─────────────────────────────────────
|
||||
|
||||
class TestJA3BountyExtraction:
    """``decnet.web.ingester._extract_bounty``: JA3 fingerprint bounty handling."""

    @pytest.mark.asyncio
    async def test_ja3_bounty_extracted_from_sniffer_event(self):
        """A sniffer TLS ClientHello with a JA3 hash yields one fingerprint bounty."""
        from decnet.web.ingester import _extract_bounty

        repo = MagicMock()
        repo.add_bounty = AsyncMock()
        log_data = {
            "decky": "decky-01",
            "service": "sniffer",
            "attacker_ip": "10.0.0.5",
            "event_type": "tls_client_hello",
            "fields": {
                "ja3": "abc123def456abc123def456abc12345",
                "ja3s": None,
                "tls_version": "TLS 1.3",
                "sni": "example.com",
                "alpn": "h2",
                "dst_port": "443",
                "raw_ciphers": "4865-4866",
                "raw_extensions": "0-23-65281",
            },
        }

        await _extract_bounty(repo, log_data)

        repo.add_bounty.assert_awaited_once()
        bounty = repo.add_bounty.call_args[0][0]
        assert bounty["bounty_type"] == "fingerprint"
        assert bounty["payload"]["fingerprint_type"] == "ja3"
        assert bounty["payload"]["ja3"] == "abc123def456abc123def456abc12345"
        assert bounty["payload"]["tls_version"] == "TLS 1.3"
        assert bounty["payload"]["sni"] == "example.com"

    @pytest.mark.asyncio
    async def test_non_sniffer_service_with_ja3_field_ignored(self):
        """A ``ja3`` field on a non-sniffer event must not create a JA3 bounty."""
        from decnet.web.ingester import _extract_bounty

        repo = MagicMock()
        repo.add_bounty = AsyncMock()
        log_data = {
            "service": "http",
            "attacker_ip": "10.0.0.6",
            "event_type": "request",
            "fields": {"ja3": "somehash"},
        }

        await _extract_bounty(repo, log_data)

        # Credential/UA checks run, but JA3 should not fire for non-sniffer.
        # A JA3 bounty is recorded with bounty_type "fingerprint" and
        # payload["fingerprint_type"] == "ja3" (see the sniffer test above),
        # so the payload is what has to be inspected: looking for "ja3" among
        # the bounty_type values could never fail.
        fingerprint_types = [
            (call.args[0].get("payload") or {}).get("fingerprint_type")
            for call in repo.add_bounty.call_args_list
        ]
        assert "ja3" not in fingerprint_types

    @pytest.mark.asyncio
    async def test_sniffer_without_ja3_no_bounty(self):
        """A sniffer event carrying no JA3 data records no bounty at all."""
        from decnet.web.ingester import _extract_bounty

        repo = MagicMock()
        repo.add_bounty = AsyncMock()
        log_data = {
            "service": "sniffer",
            "attacker_ip": "10.0.0.7",
            "event_type": "startup",
            "fields": {"msg": "started"},
        }

        await _extract_bounty(repo, log_data)

        repo.add_bounty.assert_not_awaited()
|
||||
616
tests/profiler/test_profiler_behavioral.py
Normal file
616
tests/profiler/test_profiler_behavioral.py
Normal file
@@ -0,0 +1,616 @@
|
||||
"""
|
||||
Unit tests for the profiler behavioral/timing analyzer.
|
||||
|
||||
Covers:
|
||||
- timing_stats: mean/median/stdev/cv on synthetic event streams
|
||||
- classify_behavior: beaconing / interactive / scanning / brute_force /
|
||||
slow_scan / mixed / unknown
|
||||
- guess_tools: C2 attribution, list return, multi-match
|
||||
- detect_tools_from_headers: Nmap NSE, Gophish, unknown headers
|
||||
- phase_sequence: recon → exfil latency detection
|
||||
- sniffer_rollup: OS-guess mode + TTL fallback, hop median (zeros excluded),
|
||||
retransmit sum
|
||||
- build_behavior_record: composite output shape (JSON-encoded subfields,
|
||||
tool_guesses list)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from decnet.correlation.parser import LogEvent
|
||||
from decnet.profiler.behavioral import (
|
||||
build_behavior_record,
|
||||
classify_behavior,
|
||||
detect_tools_from_headers,
|
||||
guess_tool,
|
||||
guess_tools,
|
||||
phase_sequence,
|
||||
sniffer_rollup,
|
||||
timing_stats,
|
||||
)
|
||||
|
||||
|
||||
# ─── Helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
_BASE = datetime(2026, 4, 15, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _mk(
    ts_offset_s: float,
    event_type: str = "connection",
    service: str = "ssh",
    decky: str = "decky-01",
    fields: dict | None = None,
    ip: str = "10.0.0.7",
) -> LogEvent:
    """Return a synthetic ``LogEvent`` stamped ``ts_offset_s`` seconds after ``_BASE``.

    ``fields`` defaults to an empty dict and ``raw`` is always empty.
    """
    stamped_at = _BASE + timedelta(seconds=ts_offset_s)
    event_fields = fields or {}
    return LogEvent(
        timestamp=stamped_at,
        decky=decky,
        service=service,
        event_type=event_type,
        attacker_ip=ip,
        fields=event_fields,
        raw="",
    )
|
||||
|
||||
|
||||
def _regular_beacon(count: int, interval_s: float, jitter_s: float = 0.0) -> list[LogEvent]:
    """
    Build events whose inter-arrival times alternate around ``interval_s``.

    The first event sits at offset 0; each following gap is
    ``interval_s + jitter_s`` on odd steps and ``interval_s - jitter_s`` on
    even steps. One event is always produced, even when ``count`` is below 1.

    Resulting statistics (exact for an even number of gaps, otherwise
    approximate because the +/- jitter terms do not fully cancel):
    - mean IAT ~= interval_s
    - stdev IAT ~= jitter_s
    - coefficient of variation ~= jitter_s / interval_s
    """
    offsets = [0.0]
    for step in range(1, count):
        signed_jitter = jitter_s if step % 2 else -jitter_s
        offsets.append(offsets[-1] + (interval_s + signed_jitter))
    return [_mk(offset) for offset in offsets]
|
||||
|
||||
|
||||
# ─── timing_stats ───────────────────────────────────────────────────────────
|
||||
|
||||
class TestTimingStats:
    """``timing_stats``: summary values for empty, single and periodic streams."""

    def test_empty_returns_nulls(self):
        """No events: zero count and ``None`` for the IAT-derived values."""
        stats = timing_stats([])
        assert stats["event_count"] == 0
        assert stats["mean_iat_s"] is None
        assert stats["cv"] is None

    def test_single_event(self):
        """One event: zero duration and no mean inter-arrival time."""
        stats = timing_stats([_mk(0)])
        assert stats["event_count"] == 1
        assert stats["duration_s"] == 0.0
        assert stats["mean_iat_s"] is None

    def test_regular_cadence_cv_is_zero(self):
        """Perfectly periodic events give the exact mean and zero spread."""
        stats = timing_stats(_regular_beacon(count=10, interval_s=60.0))
        assert stats["event_count"] == 10
        assert stats["mean_iat_s"] == 60.0
        assert stats["cv"] == 0.0
        assert stats["stdev_iat_s"] == 0.0

    def test_jittered_cadence(self):
        """Alternating +/-12 s jitter on a 60 s cadence gives a cv near 20%."""
        jittered = _regular_beacon(count=20, interval_s=60.0, jitter_s=12.0)
        stats = timing_stats(jittered)
        # Mean is close to 60, cv ~20% (jitter 12 / interval 60)
        assert abs(stats["mean_iat_s"] - 60.0) < 2.0
        assert stats["cv"] is not None
        assert 0.10 < stats["cv"] < 0.50
|
||||
|
||||
|
||||
# ─── classify_behavior ──────────────────────────────────────────────────────
|
||||
|
||||
class TestClassifyBehavior:
    """Behaviour labels assigned by ``classify_behavior`` from timing stats."""

    def test_unknown_if_too_few(self):
        """Two events are not enough to classify anything."""
        stats = timing_stats(_regular_beacon(count=2, interval_s=60.0))
        assert classify_behavior(stats, services_count=1) == "unknown"

    def test_beaconing_regular_cadence(self):
        """A tight 60 s cadence (±3 s) on one service is a beacon."""
        beacon = _regular_beacon(count=10, interval_s=60.0, jitter_s=3.0)
        stats = timing_stats(beacon)
        assert classify_behavior(stats, services_count=1) == "beaconing"

    def test_interactive_fast_irregular(self):
        # Sub-second bursts separated by multi-second pauses: roughly a human
        # hitting keys with think time in between.
        offsets = [0, 0.2, 0.5, 1.0, 5.0, 5.1, 5.3, 10.0, 10.1, 10.2, 12.0]
        stats = timing_stats([_mk(offset) for offset in offsets])
        assert classify_behavior(stats, services_count=1) == "interactive"

    def test_scanning_many_services_fast(self):
        # 10 events, 0.2 s apart, cycling through 5 services.
        svcs = ["ssh", "http", "smb", "ftp", "rdp"]
        sweep = [_mk(i * 0.2, service=svcs[i % 5]) for i in range(10)]
        stats = timing_stats(sweep)
        assert classify_behavior(stats, services_count=5) == "scanning"

    def test_scanning_fast_single_service_is_brute_force(self):
        # Very fast, regular bursts confined to one service are brute force;
        # "scanning" is reserved for a multi-service sweep.
        burst = [_mk(i * 0.5) for i in range(8)]
        stats = timing_stats(burst)
        assert classify_behavior(stats, services_count=1) == "brute_force"

    def test_brute_force(self):
        # 10 login-like attempts on one service, exactly 2 s apart
        # (mean = 2 s, cv = 0).
        attempts = [_mk(i * 2.0) for i in range(10)]
        stats = timing_stats(attempts)
        assert classify_behavior(stats, services_count=1) == "brute_force"

    def test_slow_scan(self):
        # Touches 3 services at a 15 s pace: low-and-slow reconnaissance.
        svcs = ["ssh", "rdp", "smb"]
        probes = [_mk(i * 15.0, service=svcs[i % 3]) for i in range(6)]
        stats = timing_stats(probes)
        assert classify_behavior(stats, services_count=3) == "slow_scan"

    def test_mixed_fallback(self):
        # Moderate count and cadence on a single service. cv is around 0.5,
        # too loose for beaconing, and the ~20 s mean is slow for interactive.
        loose = _regular_beacon(count=6, interval_s=20.0, jitter_s=10.0)
        stats = timing_stats(loose)
        label = classify_behavior(stats, services_count=1)
        assert label in ("mixed", "interactive")  # either is acceptable
|
||||
|
||||
|
||||
# ─── guess_tools ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestGuessTools:
    """Timing-signature matching via ``guess_tools`` (list-returning API)."""

    def test_cobalt_strike(self):
        matches = guess_tools(mean_iat_s=60.0, cv=0.20)
        assert "cobalt_strike" in matches

    def test_havoc(self):
        matches = guess_tools(mean_iat_s=45.0, cv=0.10)
        assert "havoc" in matches

    def test_mythic(self):
        matches = guess_tools(mean_iat_s=30.0, cv=0.15)
        assert "mythic" in matches

    def test_no_match_outside_tolerance(self):
        """A 5 s cadence is outside every signature's tolerance band."""
        assert guess_tools(mean_iat_s=5.0, cv=0.10) == []

    def test_none_when_stats_missing(self):
        """Missing mean or cv means nothing can be matched."""
        assert guess_tools(None, None) == []
        assert guess_tools(60.0, None) == []

    def test_multiple_matches_all_returned(self):
        # Cobalt (60±8s, cv 0.20±0.05) and Sliver (60±10s, cv 0.30±0.08)
        # both accept cv=0.25 at 60s, so both must be reported.
        matches = guess_tools(mean_iat_s=60.0, cv=0.25)
        assert "cobalt_strike" in matches
        assert "sliver" in matches

    def test_returns_list(self):
        matches = guess_tools(mean_iat_s=60.0, cv=0.20)
        assert isinstance(matches, list)
|
||||
|
||||
|
||||
class TestGuessToolLegacy:
    """The deprecated single-string alias ``guess_tool`` must keep working."""

    def test_cobalt_strike(self):
        label = guess_tool(mean_iat_s=60.0, cv=0.20)
        assert label == "cobalt_strike"

    def test_havoc(self):
        label = guess_tool(mean_iat_s=45.0, cv=0.10)
        assert label == "havoc"

    def test_mythic(self):
        label = guess_tool(mean_iat_s=30.0, cv=0.15)
        assert label == "mythic"

    def test_no_match_outside_tolerance(self):
        assert guess_tool(mean_iat_s=5.0, cv=0.10) is None

    def test_none_when_stats_missing(self):
        assert guess_tool(None, None) is None
        assert guess_tool(60.0, None) is None

    def test_ambiguous_returns_none(self):
        # Two signatures match at 60 s / cv 0.25, so the legacy function
        # reports None rather than picking one.
        label = guess_tool(mean_iat_s=60.0, cv=0.25)
        assert label is None
|
||||
|
||||
|
||||
# ─── detect_tools_from_headers ───────────────────────────────────────────────
|
||||
|
||||
class TestDetectToolsFromHeaders:
    """Tool detection from HTTP request headers (``detect_tools_from_headers``)."""

    def _http_event(self, headers: dict, offset_s: float = 0) -> LogEvent:
        """Build an HTTP "request" event whose headers are JSON-encoded."""
        encoded = json.dumps(headers)
        return _mk(offset_s, event_type="request",
                   service="http", fields={"headers": encoded})

    def test_nmap_nse_user_agent(self):
        nse_agent = ("Mozilla/5.0 (compatible; Nmap Scripting Engine; "
                     "https://nmap.org/book/nse.html)")
        event = self._http_event({"User-Agent": nse_agent})
        assert "nmap" in detect_tools_from_headers([event])

    def test_gophish_x_mailer(self):
        event = self._http_event({"X-Mailer": "gophish"})
        assert "gophish" in detect_tools_from_headers([event])

    def test_sqlmap_user_agent(self):
        event = self._http_event(
            {"User-Agent": "sqlmap/1.7.9#stable (https://sqlmap.org)"}
        )
        assert "sqlmap" in detect_tools_from_headers([event])

    def test_curl_anchor_pattern(self):
        event = self._http_event({"User-Agent": "curl/8.1.2"})
        assert "curl" in detect_tools_from_headers([event])

    def test_curl_anchor_no_false_positive(self):
        # "not-curl/something" should NOT match the anchored ^curl/ pattern.
        event = self._http_event({"User-Agent": "not-curl/1.0"})
        assert "curl" not in detect_tools_from_headers([event])

    def test_header_keys_case_insensitive(self):
        # Header key spelled in lower case (not the canonical "User-Agent")
        # should still match.
        event = self._http_event({"user-agent": "Nikto/2.1.6"})
        assert "nikto" in detect_tools_from_headers([event])

    def test_multiple_tools_in_one_session(self):
        session = [
            self._http_event({"User-Agent": "Nmap Scripting Engine"}, 0),
            self._http_event({"X-Mailer": "gophish"}, 10),
        ]
        detected = detect_tools_from_headers(session)
        assert "nmap" in detected
        assert "gophish" in detected

    def test_no_request_events_returns_empty(self):
        """Non-request events carry no headers to inspect."""
        only_connection = [_mk(0, event_type="connection")]
        assert detect_tools_from_headers(only_connection) == []

    def test_unknown_ua_returns_empty(self):
        event = self._http_event({"User-Agent": "Mozilla/5.0 (Windows NT 10.0)"})
        assert detect_tools_from_headers([event]) == []

    def test_deduplication(self):
        # Same tool seen in two requests: it must be listed once.
        session = [
            self._http_event({"User-Agent": "sqlmap/1.0"}, offset)
            for offset in (0, 5)
        ]
        detected = detect_tools_from_headers(session)
        assert detected.count("sqlmap") == 1

    def test_json_string_headers(self):
        # Post-fix format: headers stored as a JSON string (not a dict).
        json_headers = '{"User-Agent": "Nmap Scripting Engine"}'
        event = _mk(0, event_type="request", service="http",
                    fields={"headers": json_headers})
        assert "nmap" in detect_tools_from_headers([event])

    def test_python_repr_headers_fallback(self):
        # Legacy format: headers stored as Python repr string (str(dict)).
        repr_headers = "{'User-Agent': 'Nmap Scripting Engine'}"
        event = _mk(0, event_type="request", service="http",
                    fields={"headers": repr_headers})
        assert "nmap" in detect_tools_from_headers([event])
|
||||
|
||||
|
||||
# ─── phase_sequence ────────────────────────────────────────────────────────
|
||||
|
||||
class TestPhaseSequence:
    """Recon/exfil phase boundaries reported by ``phase_sequence``."""

    def test_recon_then_exfil(self):
        """Recon-type events through t=20, then exec/download from t=120."""
        timeline = [
            (0, "scan"),
            (10, "login_attempt"),
            (20, "auth_failure"),
            (120, "exec"),
            (150, "download"),
        ]
        events = [_mk(offset, event_type=kind) for offset, kind in timeline]
        phases = phase_sequence(events)
        assert phases["recon_end_ts"] is not None
        assert phases["exfil_start_ts"] is not None
        assert phases["exfil_latency_s"] == 100.0  # 120 - 20

    def test_no_exfil(self):
        """Only scan events: no exfil start and therefore no latency."""
        events = [_mk(offset, event_type="scan") for offset in (0, 10)]
        phases = phase_sequence(events)
        assert phases["exfil_start_ts"] is None
        assert phases["exfil_latency_s"] is None

    def test_large_payload_counted(self):
        """Two of the three transfers below are expected to count as large."""
        two_mib = _mk(0, event_type="download", fields={"bytes": "2097152"})
        tiny = _mk(10, event_type="download", fields={"bytes": "500"})
        ten_mib = _mk(20, event_type="upload", fields={"size": "10485760"})
        phases = phase_sequence([two_mib, tiny, ten_mib])
        assert phases["large_payload_count"] == 2
|
||||
|
||||
|
||||
# ─── sniffer_rollup ─────────────────────────────────────────────────────────
|
||||
|
||||
class TestSnifferRollup:
    """Aggregation of sniffer/prober fingerprint events by ``sniffer_rollup``."""

    def test_os_mode(self):
        """Two linux sightings outvote one windows sighting."""
        linux_fields = {"os_guess": "linux", "hop_distance": "3",
                        "window": "29200", "mss": "1460"}
        windows_fields = {"os_guess": "windows", "hop_distance": "8",
                          "window": "64240", "mss": "1460"}
        events = [
            # Each event gets its own copy of the field mapping.
            _mk(0, event_type="tcp_syn_fingerprint", fields=dict(linux_fields)),
            _mk(5, event_type="tcp_syn_fingerprint", fields=dict(linux_fields)),
            _mk(10, event_type="tcp_syn_fingerprint", fields=windows_fields),
        ]
        rollup = sniffer_rollup(events)
        assert rollup["os_guess"] == "linux"  # mode
        # Median of [3, 3, 8] = 3
        assert rollup["hop_distance"] == 3
        # Latest fingerprint snapshot wins
        assert rollup["tcp_fingerprint"]["window"] == 64240

    def test_retransmits_summed(self):
        samples = ((0, "2"), (10, "5"), (20, "0"))
        events = [
            _mk(offset, event_type="tcp_flow_timing",
                fields={"retransmits": retransmits})
            for offset, retransmits in samples
        ]
        rollup = sniffer_rollup(events)
        assert rollup["retransmit_count"] == 7

    def test_empty(self):
        rollup = sniffer_rollup([])
        assert rollup["os_guess"] is None
        assert rollup["hop_distance"] is None
        assert rollup["retransmit_count"] == 0

    def test_ttl_fallback_linux(self):
        # p0f returns "unknown", so the TTL of 64 should decide: "linux".
        event = _mk(0, event_type="tcp_syn_fingerprint",
                    fields={"os_guess": "unknown", "ttl": "64", "window": "29200"})
        rollup = sniffer_rollup([event])
        assert rollup["os_guess"] == "linux"

    def test_ttl_fallback_windows(self):
        event = _mk(0, event_type="tcp_syn_fingerprint",
                    fields={"os_guess": "unknown", "ttl": "128", "window": "64240"})
        rollup = sniffer_rollup([event])
        assert rollup["os_guess"] == "windows"

    def test_ttl_fallback_embedded(self):
        event = _mk(0, event_type="tcp_syn_fingerprint",
                    fields={"os_guess": "unknown", "ttl": "255", "window": "1024"})
        rollup = sniffer_rollup([event])
        assert rollup["os_guess"] == "embedded"

    def test_hop_distance_zero_excluded(self):
        # Hop distance "0" should not be included in the median calculation.
        events = [
            _mk(offset, event_type="tcp_syn_fingerprint",
                fields={"os_guess": "linux", "hop_distance": "0"})
            for offset in (0, 5)
        ]
        rollup = sniffer_rollup(events)
        assert rollup["hop_distance"] is None

    def test_hop_distance_missing_excluded(self):
        # No hop_distance field at all: the rolled-up hop_distance is None.
        event = _mk(0, event_type="tcp_syn_fingerprint",
                    fields={"os_guess": "linux", "window": "29200"})
        rollup = sniffer_rollup([event])
        assert rollup["hop_distance"] is None

    def test_p0f_label_takes_priority_over_ttl(self):
        # When p0f gives a non-unknown label, TTL fallback must NOT override it.
        event = _mk(0, event_type="tcp_syn_fingerprint",
                    fields={"os_guess": "macos_ios", "ttl": "64", "window": "65535"})
        rollup = sniffer_rollup([event])
        assert rollup["os_guess"] == "macos_ios"

    def test_prober_tcpfp_os_from_ttl(self):
        # Active-probe event: TTL=121 maps to a windows OS guess.
        probe_fields = {"ttl": "121", "window_size": "64240", "mss": "1460",
                        "window_scale": "8", "sack_ok": "1", "timestamp": "0",
                        "options_order": "M,N,W,N,N,S"}
        event = _mk(0, event_type="tcpfp_fingerprint", fields=probe_fields)
        rollup = sniffer_rollup([event])
        assert rollup["os_guess"] == "windows"

    def test_prober_tcpfp_hop_distance_derived(self):
        # TTL=121 against the windows initial TTL of 128 gives hop_distance=7.
        probe_fields = {"ttl": "121", "window_size": "64240", "mss": "1460",
                        "window_scale": "8", "sack_ok": "1", "timestamp": "0",
                        "options_order": "M,N,W,N,N,S"}
        event = _mk(0, event_type="tcpfp_fingerprint", fields=probe_fields)
        rollup = sniffer_rollup([event])
        assert rollup["hop_distance"] == 7

    def test_prober_tcpfp_tcp_fingerprint_fields(self):
        # Prober field names (window_size, window_scale, etc.) are mapped
        # onto the rollup's tcp_fingerprint keys.
        probe_fields = {"ttl": "60", "window_size": "29200", "mss": "1460",
                        "window_scale": "7", "sack_ok": "1", "timestamp": "1",
                        "options_order": "M,N,W,N,N,T,S,E"}
        event = _mk(0, event_type="tcpfp_fingerprint", fields=probe_fields)
        fingerprint = sniffer_rollup([event])["tcp_fingerprint"]
        assert fingerprint["window"] == 29200
        assert fingerprint["wscale"] == 7
        assert fingerprint["mss"] == 1460
        assert fingerprint["has_sack"] is True
        assert fingerprint["has_timestamps"] is True
        assert fingerprint["options_sig"] == "M,N,W,N,N,T,S,E"

    def test_hassh_kex_order_raw_collected(self):
        # Prober hassh_fingerprint events contribute their raw kex_algorithms
        # list (one entry per distinct string, deduplicated).
        kex_a = "curve25519-sha256,ecdh-sha2-nistp256,diffie-hellman-group14-sha1"
        kex_b = "curve25519-sha256@libssh.org,diffie-hellman-group-exchange-sha256"
        observations = (
            (0, kex_a, "x"),
            (5, kex_a, "x"),  # dup
            (10, kex_b, "y"),
        )
        events = [
            _mk(offset, event_type="hassh_fingerprint",
                fields={"kex_algorithms": kex, "hassh_server_hash": digest})
            for offset, kex, digest in observations
        ]
        rollup = sniffer_rollup(events)
        assert rollup["kex_order_raw"] == [kex_a, kex_b]

    def test_kex_order_raw_empty_when_no_hassh(self):
        event = _mk(0, event_type="tcp_syn_fingerprint",
                    fields={"os_guess": "linux", "hop_distance": "3"})
        rollup = sniffer_rollup([event])
        assert rollup["kex_order_raw"] == []

    def test_ssh_client_banners_collected(self):
        # Sniffer ssh_client_banner events accumulate the attacker's observed
        # SSH identification strings, deduplicated in observation order.
        ban_a = "SSH-2.0-OpenSSH_9.2p1 Debian-2"
        ban_b = "SSH-2.0-libssh2_1.10.0"
        observations = (
            (0, ban_a),
            (1, ban_a),  # dup
            (2, ban_b),
        )
        events = [
            _mk(offset, event_type="ssh_client_banner",
                fields={"ssh_version": banner})
            for offset, banner in observations
        ]
        rollup = sniffer_rollup(events)
        assert rollup["ssh_client_banners"] == [ban_a, ban_b]

    def test_ssh_client_banners_empty_when_none(self):
        event = _mk(0, event_type="tcp_syn_fingerprint",
                    fields={"os_guess": "linux"})
        rollup = sniffer_rollup([event])
        assert rollup["ssh_client_banners"] == []
|
||||
|
||||
|
||||
# ─── build_behavior_record (composite) ──────────────────────────────────────
|
||||
|
||||
class TestBuildBehaviorRecord:
    """Composite record assembled by ``build_behavior_record``."""

    def test_beaconing_with_cobalt_strike_match(self):
        # 60 s interval with 20% jitter: the cobalt strike default profile.
        beacon = _regular_beacon(count=20, interval_s=60.0, jitter_s=12.0)
        record = build_behavior_record(beacon)
        assert record["behavior_class"] == "beaconing"
        assert record["beacon_interval_s"] is not None
        assert 50 < record["beacon_interval_s"] < 70
        assert record["beacon_jitter_pct"] is not None
        assert "cobalt_strike" in json.loads(record["tool_guesses"])

    def test_json_fields_are_strings(self):
        record = build_behavior_record(_regular_beacon(count=5, interval_s=60.0))
        # timing_stats, phase_sequence, tcp_fingerprint, tool_guesses must be
        # JSON strings; the first three only need to parse.
        for key in ("timing_stats", "phase_sequence", "tcp_fingerprint"):
            assert isinstance(record[key], str)
            json.loads(record[key])
        assert isinstance(record["tool_guesses"], str)
        assert isinstance(json.loads(record["tool_guesses"]), list)

    def test_non_beaconing_has_null_beacon_fields(self):
        # Scanning behavior: no beacon interval or jitter should be reported.
        svcs = ["ssh", "http", "smb", "ftp", "rdp"]
        sweep = [_mk(i * 0.2, service=svcs[i % 5]) for i in range(10)]
        record = build_behavior_record(sweep)
        assert record["behavior_class"] == "scanning"
        assert record["beacon_interval_s"] is None
        assert record["beacon_jitter_pct"] is None

    def test_header_tools_merged_into_tool_guesses(self):
        # Header-detected tools (nmap) and timing-detected tools
        # (cobalt_strike) both feed tool_guesses. The two detectors are
        # checked on separate records: one built from the beacon alone and
        # one built from the beacon plus an HTTP request.
        beacon_events = _regular_beacon(count=20, interval_s=60.0, jitter_s=12.0)
        # The HTTP event reuses the first beacon's timestamp. Per the original
        # author's note, the resulting zero IAT is expected to be filtered out
        # so the beacon cadence is not disturbed.
        http_event = _mk(
            0,
            event_type="request",
            service="http",
            fields={"headers": json.dumps({"User-Agent": "Nmap Scripting Engine"})},
        )
        timing_only = build_behavior_record(beacon_events)
        with_request = build_behavior_record(beacon_events + [http_event])
        # Header detection works once the request is present.
        assert "nmap" in json.loads(with_request["tool_guesses"])
        # Timing detection works on the beacon alone.
        assert "cobalt_strike" in json.loads(timing_only["tool_guesses"])

    def test_tool_guesses_empty_list_when_no_match(self):
        # 5-minute intervals match no timing signature.
        slow = [_mk(i * 300.0) for i in range(5)]
        record = build_behavior_record(slow)
        assert json.loads(record["tool_guesses"]) == []

    def test_kex_order_raw_persisted_as_json(self):
        kex = "curve25519-sha256,ecdh-sha2-nistp256"
        event = _mk(0, event_type="hassh_fingerprint",
                    fields={"kex_algorithms": kex, "hassh_server_hash": "abc"})
        record = build_behavior_record([event])
        assert isinstance(record["kex_order_raw"], str)
        assert json.loads(record["kex_order_raw"]) == [kex]

    def test_kex_order_raw_null_when_no_hassh(self):
        beacon = _regular_beacon(count=5, interval_s=60.0)
        record = build_behavior_record(beacon)
        assert record["kex_order_raw"] is None

    def test_ssh_client_banners_persisted_as_json(self):
        banner = "SSH-2.0-OpenSSH_9.2p1"
        event = _mk(0, event_type="ssh_client_banner",
                    fields={"ssh_version": banner})
        record = build_behavior_record([event])
        assert isinstance(record["ssh_client_banners"], str)
        assert json.loads(record["ssh_client_banners"]) == [banner]

    def test_ssh_client_banners_null_when_none(self):
        beacon = _regular_beacon(count=5, interval_s=60.0)
        record = build_behavior_record(beacon)
        assert record["ssh_client_banners"] is None

    def test_nmap_promoted_from_tcp_fingerprint(self):
        # p0f identifies nmap from the TCP handshake, so it must appear in
        # tool_guesses even when no HTTP request events are present.
        targets = ((0, "ssh"), (1, "smb"))
        events = [
            _mk(offset, event_type="tcp_syn_fingerprint", service=svc,
                fields={"os_guess": "nmap", "window": "31337", "ttl": "58"})
            for offset, svc in targets
        ]
        record = build_behavior_record(events)
        assert "nmap" in json.loads(record["tool_guesses"])
|
||||
55
tests/profiler/test_session_profile.py
Normal file
55
tests/profiler/test_session_profile.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Tests for the session_profile table + repo helpers (SIGNAL_CAPTURE_AUDIT gap #2).
|
||||
|
||||
Pre-v1 the ingestion job that populates keystroke-dynamics features is
|
||||
deferred; this suite exercises the empty-write path (one row per session,
|
||||
all feature columns NULL) and round-trips a filled row so future work can
|
||||
land without re-discovering the schema.
|
||||
"""
|
||||
import pytest
|
||||
from decnet.web.db.factory import get_repository
|
||||
|
||||
|
||||
@pytest.fixture
async def repo(tmp_path):
    """Return an initialized repository backed by a DB file under tmp_path."""
    db_file = tmp_path / "session_profile.db"
    repository = get_repository(db_path=str(db_file))
    await repository.initialize()
    return repository
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_empty_write_path_ships_null_features(repo):
    """An empty profile payload stores a row whose feature columns are NULL."""
    # Session close writes `{}`: schema_version defaults to 1 and every
    # feature column stays NULL.
    session_id = "sid-1"
    await repo.upsert_session_profile(session_id, {})
    stored = await repo.get_session_profile(session_id)
    assert stored is not None
    assert stored["sid"] == "sid-1"
    assert stored["schema_version"] == 1
    for feature in ("kd_iki_mean", "kd_digraph_simhash", "total_keystrokes"):
        assert stored[feature] is None
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_upsert_replaces_existing(repo):
    """A second upsert for the same sid overwrites the earlier (empty) row."""
    await repo.upsert_session_profile("sid-2", {})
    filled = {
        "kd_iki_mean": 0.120,
        "kd_iki_p95": 0.450,
        "total_keystrokes": 512,
        "session_duration_s": 61.3,
    }
    await repo.upsert_session_profile("sid-2", filled)
    row = await repo.get_session_profile("sid-2")
    # Guard first, as the empty-write test does: a missing row should fail
    # as a clear assertion rather than a TypeError from subscripting None.
    assert row is not None
    # Float columns are compared with a tolerance.
    assert row["kd_iki_mean"] == pytest.approx(0.120)
    assert row["kd_iki_p95"] == pytest.approx(0.450)
    assert row["total_keystrokes"] == 512
    assert row["session_duration_s"] == pytest.approx(61.3)
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_get_missing_returns_none(repo):
    """Looking up a sid that was never written yields None."""
    missing = await repo.get_session_profile("does-not-exist")
    assert missing is None
|
||||
Reference in New Issue
Block a user