- Ingester now loads its byte offset from the DB on startup (key: `ingest_worker_position`) and saves it after each batch, preventing a full re-read on every API restart.
- On file truncation/rotation, the saved offset is reset to 0.
- Profiler worker now loads `last_log_id` from the DB on startup, so every restart becomes an incremental update instead of a full cold rebuild.
- Updated all affected tests to mock `get_state`/`set_state`; added new tests covering position restore, the `set_state` call, truncation reset, and the cursor-restore/cold-start paths.
734 lines
27 KiB
Python
734 lines
27 KiB
Python
"""
|
|
Tests for decnet/attacker/worker.py
|
|
|
|
Covers:
|
|
- _cold_start(): full build on first run, cursor persistence
|
|
- _incremental_update(): delta processing, affected-IP-only updates
|
|
- _update_profiles(): traversal detection, bounty merging
|
|
- _extract_commands_from_events(): command harvesting from LogEvent objects
|
|
- _build_record(): record assembly from engine events + bounties
|
|
- _first_contact_deckies(): ordering for single-decky attackers
|
|
- attacker_profile_worker(): cancellation and error handling
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from decnet.correlation.parser import LogEvent
|
|
from decnet.logging.syslog_formatter import SEVERITY_INFO, format_rfc5424
|
|
from decnet.profiler.worker import (
|
|
_BATCH_SIZE,
|
|
_STATE_KEY,
|
|
_WorkerState,
|
|
_build_record,
|
|
_extract_commands_from_events,
|
|
_first_contact_deckies,
|
|
_incremental_update,
|
|
_update_profiles,
|
|
attacker_profile_worker,
|
|
)
|
|
|
|
# ─── Helpers ──────────────────────────────────────────────────────────────────

# Three instants five minutes apart: ISO-8601 strings (fed to the syslog
# formatter) and the matching timezone-aware datetimes (for events / DB rows).
_TS1, _TS2, _TS3 = (
    "2026-04-04T10:00:00+00:00",
    "2026-04-04T10:05:00+00:00",
    "2026-04-04T10:10:00+00:00",
)

_DT1, _DT2, _DT3 = (datetime.fromisoformat(stamp) for stamp in (_TS1, _TS2, _TS3))
|
|
|
|
|
|
def _make_raw_line(
    service: str = "ssh",
    hostname: str = "decky-01",
    event_type: str = "connection",
    src_ip: str = "1.2.3.4",
    timestamp: str = _TS1,
    **extra: str,
) -> str:
    """Render one syslog line via ``format_rfc5424`` at ``SEVERITY_INFO``.

    *timestamp* is an ISO-8601 string and is parsed before formatting. Any
    *extra* keyword arguments (e.g. ``command=...``) are forwarded unchanged.
    """
    when = datetime.fromisoformat(timestamp)
    base_fields = dict(
        service=service,
        hostname=hostname,
        event_type=event_type,
        severity=SEVERITY_INFO,
        timestamp=when,
        src_ip=src_ip,
    )
    return format_rfc5424(**base_fields, **extra)
|
|
|
|
|
|
def _make_log_row(
    row_id: int = 1,
    raw_line: str = "",
    attacker_ip: str = "1.2.3.4",
    service: str = "ssh",
    event_type: str = "connection",
    decky: str = "decky-01",
    timestamp: datetime = _DT1,
    fields: str = "{}",
) -> dict:
    """Build a dict shaped like a row served by ``repo.get_logs_after_id``.

    If *raw_line* is empty, a matching syslog line is synthesized from the
    other arguments so the row is internally consistent.
    """
    line = raw_line or _make_raw_line(
        service=service,
        hostname=decky,
        event_type=event_type,
        src_ip=attacker_ip,
        timestamp=timestamp.isoformat(),
    )
    return dict(
        id=row_id,
        raw_line=line,
        attacker_ip=attacker_ip,
        service=service,
        event_type=event_type,
        decky=decky,
        timestamp=timestamp,
        fields=fields,
    )
|
|
|
|
|
|
def _make_repo(logs=None, bounties=None, bounties_for_ips=None, max_log_id=0, saved_state=None):
    """Build a ``MagicMock`` repository exposing the async methods the profiler uses.

    Args:
        logs: Rows returned by the first ``get_logs_after_id`` call (one page).
            Every later call returns ``[]`` so the paging loop terminates and
            repeated update cycles never exhaust the mock.
        bounties: Return value of ``get_all_bounties_by_ip`` (default ``{}``).
        bounties_for_ips: Return value of ``get_bounties_for_ips`` (default ``{}``).
        max_log_id: Return value of ``get_max_log_id``.
        saved_state: Return value of ``get_state`` (the persisted worker
            cursor), or ``None`` to simulate a first run.

    Returns:
        The configured ``MagicMock``; ``set_state``, ``upsert_attacker`` and
        ``upsert_attacker_behavior`` are ``AsyncMock`` spies.
    """
    repo = MagicMock()
    repo.get_all_bounties_by_ip = AsyncMock(return_value=bounties or {})
    repo.get_bounties_for_ips = AsyncMock(return_value=bounties_for_ips or {})
    repo.get_max_log_id = AsyncMock(return_value=max_log_id)

    # A list side_effect raises StopAsyncIteration once exhausted, so a third
    # call (extra worker cycle or paging step) would blow up. Serve the single
    # page once, then an empty page forever.
    pages = iter([logs or []])

    async def _get_logs_after_id(*_args, **_kwargs):
        return next(pages, [])

    repo.get_logs_after_id = AsyncMock(side_effect=_get_logs_after_id)
    repo.get_state = AsyncMock(return_value=saved_state)
    repo.set_state = AsyncMock()
    repo.upsert_attacker = AsyncMock(return_value="mock-uuid")
    repo.upsert_attacker_behavior = AsyncMock()
    return repo
|
|
|
|
|
|
def _make_log_event(
    ip: str,
    decky: str,
    service: str = "ssh",
    event_type: str = "connection",
    timestamp: datetime = _DT1,
    fields: dict | None = None,
) -> LogEvent:
    """Construct an already-parsed ``LogEvent`` directly (empty ``raw`` text)."""
    return LogEvent(
        attacker_ip=ip,
        decky=decky,
        service=service,
        event_type=event_type,
        timestamp=timestamp,
        fields=fields or {},
        raw="",
    )
|
|
|
|
|
|
# ─── _first_contact_deckies ───────────────────────────────────────────────────
|
|
|
|
class TestFirstContactDeckies:
    """_first_contact_deckies() lists deckies in first-contact order, once each."""

    def test_single_decky(self):
        only = _make_log_event("1.1.1.1", "decky-01", timestamp=_DT1)
        assert _first_contact_deckies([only]) == ["decky-01"]

    def test_multiple_deckies_ordered_by_first_contact(self):
        # Supplied out of chronological order on purpose.
        later = _make_log_event("1.1.1.1", "decky-02", timestamp=_DT2)
        earlier = _make_log_event("1.1.1.1", "decky-01", timestamp=_DT1)
        assert _first_contact_deckies([later, earlier]) == ["decky-01", "decky-02"]

    def test_revisit_does_not_duplicate(self):
        visits = (
            ("decky-01", _DT1),
            ("decky-02", _DT2),
            ("decky-01", _DT3),  # revisit
        )
        ordered = _first_contact_deckies(
            [_make_log_event("1.1.1.1", name, timestamp=when) for name, when in visits]
        )
        assert ordered == ["decky-01", "decky-02"]
        assert ordered.count("decky-01") == 1
|
|
|
|
|
|
# ─── _extract_commands_from_events ───────────────────────────────────────────
|
|
|
|
class TestExtractCommandsFromEvents:
    """Command harvesting from ``command`` / ``query`` / ``input`` fields."""

    @staticmethod
    def _harvest(ip, service, event_type, fields, when=_DT1):
        """Run the extractor over a single decky-01 event."""
        event = _make_log_event(ip, "decky-01", service, event_type, when, fields)
        return _extract_commands_from_events([event])

    def test_extracts_command_field(self):
        harvested = self._harvest("1.1.1.1", "ssh", "command", {"command": "id"})
        assert len(harvested) == 1
        entry = harvested[0]
        assert entry["command"] == "id"
        assert entry["service"] == "ssh"
        assert entry["decky"] == "decky-01"

    def test_extracts_query_field(self):
        harvested = self._harvest("2.2.2.2", "mysql", "query", {"query": "SELECT * FROM users"})
        assert len(harvested) == 1
        assert harvested[0]["command"] == "SELECT * FROM users"

    def test_extracts_input_field(self):
        harvested = self._harvest("3.3.3.3", "ssh", "input", {"input": "ls -la"})
        assert len(harvested) == 1
        assert harvested[0]["command"] == "ls -la"

    def test_non_command_event_type_ignored(self):
        assert self._harvest("1.1.1.1", "ssh", "connection", {"command": "id"}) == []

    def test_no_command_field_skipped(self):
        assert self._harvest("1.1.1.1", "ssh", "command", {"other": "stuff"}) == []

    def test_multiple_commands_all_extracted(self):
        events = [
            _make_log_event("5.5.5.5", "decky-01", "ssh", "command", when, {"command": cmd})
            for when, cmd in ((_DT1, "id"), (_DT2, "uname -a"))
        ]
        harvested = _extract_commands_from_events(events)
        assert len(harvested) == 2
        assert {entry["command"] for entry in harvested} == {"id", "uname -a"}

    def test_timestamp_serialized_to_string(self):
        harvested = self._harvest("1.1.1.1", "ssh", "command", {"command": "pwd"})
        assert isinstance(harvested[0]["timestamp"], str)
|
|
|
|
|
|
# ─── _build_record ────────────────────────────────────────────────────────────
|
|
|
|
class TestBuildRecord:
    """_build_record() assembles an attacker row from events, traversal and bounties."""

    @staticmethod
    def _two_events(ip="1.1.1.1"):
        """An ssh event at _DT1 and an http event at _DT2, both on decky-01."""
        return [
            _make_log_event(ip, "decky-01", "ssh", "conn", _DT1),
            _make_log_event(ip, "decky-01", "http", "req", _DT2),
        ]

    def _plain_record(self):
        """Record with no traversal, bounties or commands."""
        return _build_record("1.1.1.1", self._two_events(), None, [], [])

    def test_basic_fields(self):
        rec = self._plain_record()
        assert rec["ip"] == "1.1.1.1"
        assert rec["event_count"] == 2
        assert rec["service_count"] == 2
        assert rec["decky_count"] == 1

    def test_first_last_seen(self):
        rec = self._plain_record()
        assert rec["first_seen"] == _DT1
        assert rec["last_seen"] == _DT2

    def test_services_json_sorted(self):
        decoded = json.loads(self._plain_record()["services"])
        assert sorted(decoded) == decoded

    def test_no_traversal(self):
        rec = self._plain_record()
        assert rec["is_traversal"] is False
        assert rec["traversal_path"] is None

    def test_with_traversal(self):
        from decnet.correlation.graph import AttackerTraversal, TraversalHop

        traversal = AttackerTraversal(
            "1.1.1.1",
            [
                TraversalHop(_DT1, "decky-01", "ssh", "conn"),
                TraversalHop(_DT2, "decky-02", "http", "req"),
            ],
        )
        events = [
            _make_log_event("1.1.1.1", name, timestamp=when)
            for name, when in (("decky-01", _DT1), ("decky-02", _DT2))
        ]
        rec = _build_record("1.1.1.1", events, traversal, [], [])
        assert rec["is_traversal"] is True
        assert rec["traversal_path"] == "decky-01 → decky-02"
        assert json.loads(rec["deckies"]) == ["decky-01", "decky-02"]

    def test_bounty_counts(self):
        bounties = [
            {"bounty_type": kind, "attacker_ip": "1.1.1.1"}
            for kind in ("credential", "credential", "fingerprint")
        ]
        rec = _build_record("1.1.1.1", self._two_events(), None, bounties, [])
        assert rec["bounty_count"] == 3
        assert rec["credential_count"] == 2
        fingerprints = json.loads(rec["fingerprints"])
        assert len(fingerprints) == 1
        assert fingerprints[0]["bounty_type"] == "fingerprint"

    def test_commands_serialized(self):
        commands = [
            {"service": "ssh", "decky": "decky-01", "command": "id", "timestamp": "2026-04-04T10:00:00"}
        ]
        rec = _build_record("1.1.1.1", self._two_events(), None, [], commands)
        decoded = json.loads(rec["commands"])
        assert len(decoded) == 1
        assert decoded[0]["command"] == "id"

    def test_updated_at_is_utc_datetime(self):
        stamp = self._plain_record()["updated_at"]
        assert isinstance(stamp, datetime)
        assert stamp.tzinfo is not None
|
|
|
|
|
|
# ─── cold start via _incremental_update (uninitialized state) ────────────────
|
|
|
|
class TestColdStart:
    """_incremental_update() on a fresh _WorkerState performs the initial full sweep."""

    @pytest.mark.asyncio
    async def test_cold_start_builds_all_profiles(self):
        ips = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
        rows = [
            _make_log_row(
                row_id=n,
                raw_line=_make_raw_line("ssh", "decky-01", "conn", ip, _TS1),
                attacker_ip=ip,
            )
            for n, ip in enumerate(ips, start=1)
        ]
        repo = _make_repo(logs=rows, max_log_id=3)
        state = _WorkerState()

        await _incremental_update(repo, state)

        assert state.initialized is True
        assert state.last_log_id == 3
        assert repo.upsert_attacker.await_count == 3
        upserted = {invocation.args[0]["ip"] for invocation in repo.upsert_attacker.call_args_list}
        assert upserted == set(ips)
        repo.set_state.assert_awaited_with(_STATE_KEY, {"last_log_id": 3})

    @pytest.mark.asyncio
    async def test_cold_start_empty_db(self):
        state = _WorkerState()
        repo = _make_repo(logs=[], max_log_id=0)

        await _incremental_update(repo, state)

        assert state.initialized is True
        assert state.last_log_id == 0
        repo.upsert_attacker.assert_not_awaited()
        repo.set_state.assert_awaited()

    @pytest.mark.asyncio
    async def test_cold_start_traversal_detected(self):
        hits = (
            (1, "ssh", "decky-01", "conn", _TS1),
            (2, "http", "decky-02", "req", _TS2),
        )
        rows = [
            _make_log_row(
                row_id=row_id,
                raw_line=_make_raw_line(service, decky, kind, "5.5.5.5", when),
                attacker_ip="5.5.5.5",
                decky=decky,
            )
            for row_id, service, decky, kind, when in hits
        ]
        repo = _make_repo(logs=rows, max_log_id=2)

        await _incremental_update(repo, _WorkerState())

        rec = repo.upsert_attacker.call_args.args[0]
        assert rec["is_traversal"] is True
        path = rec["traversal_path"]
        assert "decky-01" in path
        assert "decky-02" in path

    @pytest.mark.asyncio
    async def test_cold_start_bounties_merged(self):
        ip = "8.8.8.8"
        row = _make_log_row(
            row_id=1,
            raw_line=_make_raw_line("ssh", "decky-01", "conn", ip, _TS1),
            attacker_ip=ip,
        )
        bounty_map = {
            ip: [
                {"bounty_type": "credential", "attacker_ip": ip, "payload": {}},
                {"bounty_type": "fingerprint", "attacker_ip": ip, "payload": {"ja3": "abc"}},
            ]
        }
        repo = _make_repo(logs=[row], max_log_id=1, bounties_for_ips=bounty_map)

        await _incremental_update(repo, _WorkerState())

        rec = repo.upsert_attacker.call_args.args[0]
        assert rec["bounty_count"] == 2
        assert rec["credential_count"] == 1

    @pytest.mark.asyncio
    async def test_cold_start_commands_extracted(self):
        row = _make_log_row(
            row_id=1,
            raw_line=_make_raw_line(
                "ssh", "decky-01", "command", "9.9.9.9", _TS1, command="cat /etc/passwd"
            ),
            attacker_ip="9.9.9.9",
            event_type="command",
            fields=json.dumps({"command": "cat /etc/passwd"}),
        )
        repo = _make_repo(logs=[row], max_log_id=1)

        await _incremental_update(repo, _WorkerState())

        harvested = json.loads(repo.upsert_attacker.call_args.args[0]["commands"])
        assert len(harvested) == 1
        assert harvested[0]["command"] == "cat /etc/passwd"
|
|
|
|
|
|
# ─── _incremental_update ────────────────────────────────────────────────────
|
|
|
|
class TestIncrementalUpdate:
    """Delta processing by _incremental_update() on an initialized _WorkerState."""

    @staticmethod
    def _repo_returning(*rows):
        """A mock repo whose get_logs_after_id always returns *rows*."""
        repo = _make_repo()
        repo.get_logs_after_id = AsyncMock(return_value=list(rows))
        return repo

    @pytest.mark.asyncio
    async def test_no_new_logs_skips_upsert(self):
        repo = _make_repo()
        state = _WorkerState(initialized=True, last_log_id=10)

        await _incremental_update(repo, state)

        repo.upsert_attacker.assert_not_awaited()
        repo.set_state.assert_awaited_with(_STATE_KEY, {"last_log_id": 10})

    @pytest.mark.asyncio
    async def test_only_affected_ips_upserted(self):
        """Engine already knows IP-A; the new batch mentions only IP-B."""
        state = _WorkerState(initialized=True, last_log_id=5)
        state.engine.ingest(_make_raw_line("ssh", "decky-01", "conn", "1.1.1.1", _TS1))

        repo = self._repo_returning(
            _make_log_row(
                row_id=6,
                raw_line=_make_raw_line("ssh", "decky-01", "conn", "2.2.2.2", _TS2),
                attacker_ip="2.2.2.2",
            )
        )

        await _incremental_update(repo, state)

        assert repo.upsert_attacker.await_count == 1
        assert repo.upsert_attacker.call_args.args[0]["ip"] == "2.2.2.2"

    @pytest.mark.asyncio
    async def test_merges_with_existing_engine_state(self):
        """Two events already in the engine plus one new row yield event_count == 3."""
        state = _WorkerState(initialized=True, last_log_id=2)
        for service, kind, when in (("ssh", "conn", _TS1), ("http", "req", _TS2)):
            state.engine.ingest(_make_raw_line(service, "decky-01", kind, "1.1.1.1", when))

        repo = self._repo_returning(
            _make_log_row(
                row_id=3,
                raw_line=_make_raw_line("ftp", "decky-01", "login", "1.1.1.1", _TS3),
                attacker_ip="1.1.1.1",
            )
        )

        await _incremental_update(repo, state)

        rec = repo.upsert_attacker.call_args.args[0]
        assert rec["event_count"] == 3
        assert rec["ip"] == "1.1.1.1"

    @pytest.mark.asyncio
    async def test_cursor_persisted_after_update(self):
        repo = self._repo_returning(
            _make_log_row(
                row_id=42,
                raw_line=_make_raw_line("ssh", "decky-01", "conn", "1.1.1.1", _TS1),
                attacker_ip="1.1.1.1",
            )
        )
        state = _WorkerState(initialized=True, last_log_id=41)

        await _incremental_update(repo, state)

        assert state.last_log_id == 42
        repo.set_state.assert_awaited_with(_STATE_KEY, {"last_log_id": 42})

    @pytest.mark.asyncio
    async def test_traversal_detected_across_cycles(self):
        """decky-01 seen in an earlier cycle + decky-02 in this one => traversal."""
        state = _WorkerState(initialized=True, last_log_id=1)
        state.engine.ingest(_make_raw_line("ssh", "decky-01", "conn", "5.5.5.5", _TS1))

        repo = self._repo_returning(
            _make_log_row(
                row_id=2,
                raw_line=_make_raw_line("http", "decky-02", "req", "5.5.5.5", _TS2),
                attacker_ip="5.5.5.5",
            )
        )

        await _incremental_update(repo, state)

        rec = repo.upsert_attacker.call_args.args[0]
        assert rec["is_traversal"] is True
        path = rec["traversal_path"]
        assert "decky-01" in path
        assert "decky-02" in path

    @pytest.mark.asyncio
    async def test_batch_loop_processes_all(self):
        """A full first page followed by a short second page is fully consumed."""
        batch_1 = [
            _make_log_row(
                row_id=n + 1,
                raw_line=_make_raw_line("ssh", "decky-01", "conn", f"10.0.0.{n}", _TS1),
                attacker_ip=f"10.0.0.{n}",
            )
            for n in range(_BATCH_SIZE)
        ]
        batch_2 = [
            _make_log_row(
                row_id=_BATCH_SIZE + 1,
                raw_line=_make_raw_line("ssh", "decky-01", "conn", "10.0.1.1", _TS2),
                attacker_ip="10.0.1.1",
            ),
        ]
        pages = iter([batch_1, batch_2])

        async def mock_get_logs(last_id, limit=_BATCH_SIZE):
            return next(pages, [])

        repo = _make_repo()
        repo.get_logs_after_id = AsyncMock(side_effect=mock_get_logs)
        state = _WorkerState(initialized=True, last_log_id=0)

        await _incremental_update(repo, state)

        assert state.last_log_id == _BATCH_SIZE + 1
        assert repo.upsert_attacker.await_count == _BATCH_SIZE + 1

    @pytest.mark.asyncio
    async def test_bounties_fetched_only_for_affected_ips(self):
        repo = self._repo_returning(
            *(
                _make_log_row(
                    row_id=n,
                    raw_line=_make_raw_line("ssh", "decky-01", "conn", ip, when),
                    attacker_ip=ip,
                )
                for n, ip, when in ((1, "1.1.1.1", _TS1), (2, "2.2.2.2", _TS2))
            )
        )
        state = _WorkerState(initialized=True, last_log_id=0)

        await _incremental_update(repo, state)

        repo.get_bounties_for_ips.assert_awaited_once()
        assert repo.get_bounties_for_ips.call_args.args[0] == {"1.1.1.1", "2.2.2.2"}

    @pytest.mark.asyncio
    async def test_uninitialized_state_runs_full_cursor_sweep(self):
        only_row = _make_log_row(
            row_id=1,
            raw_line=_make_raw_line("ssh", "decky-01", "conn", "1.1.1.1", _TS1),
            attacker_ip="1.1.1.1",
        )
        repo = _make_repo(logs=[only_row], max_log_id=1)
        state = _WorkerState()

        await _incremental_update(repo, state)

        assert state.initialized is True
        assert state.last_log_id == 1
        repo.upsert_attacker.assert_awaited_once()
|
|
|
|
|
|
# ─── attacker_profile_worker ────────────────────────────────────────────────
|
|
|
|
class TestAttackerProfileWorker:
    """Lifecycle of the long-running ``attacker_profile_worker`` loop."""

    @staticmethod
    def _cancel_on_sleep(nth: int = 2):
        """Return a fake ``asyncio.sleep`` that raises CancelledError on its *nth* call.

        This lets the worker loop run a bounded number of cycles, then stop the
        way a real task cancellation would.
        """
        calls = 0

        async def fake_sleep(_secs):
            nonlocal calls
            calls += 1
            if calls >= nth:
                raise asyncio.CancelledError()

        return fake_sleep

    async def _run_until_cancelled(self, repo, update):
        """Run the worker with ``asyncio.sleep`` and ``_incremental_update`` patched.

        Requires the worker to end with the fake sleep's CancelledError, i.e.
        no other exception escaped the loop first.
        """
        with patch("decnet.profiler.worker.asyncio.sleep", side_effect=self._cancel_on_sleep()):
            with patch("decnet.profiler.worker._incremental_update", side_effect=update):
                with pytest.raises(asyncio.CancelledError):
                    await attacker_profile_worker(repo)

    @pytest.mark.asyncio
    async def test_worker_cancels_cleanly(self):
        repo = _make_repo()
        task = asyncio.create_task(attacker_profile_worker(repo))
        await asyncio.sleep(0)
        task.cancel()
        with pytest.raises(asyncio.CancelledError):
            await task

    @pytest.mark.asyncio
    async def test_worker_handles_update_error_without_crashing(self):
        repo = _make_repo()
        failures = []

        async def bad_update(_repo, _state):
            failures.append(True)
            raise RuntimeError("DB exploded")

        # Ending with CancelledError (not RuntimeError) shows the loop handled
        # the failure and carried on to the next sleep.
        await self._run_until_cancelled(repo, bad_update)
        # Guard against a vacuous pass: the failing update must actually run.
        assert failures, "_incremental_update never called"

    @pytest.mark.asyncio
    async def test_worker_calls_update_after_sleep(self):
        repo = _make_repo()
        update_calls = []

        async def mock_update(_repo, _state):
            update_calls.append(True)

        await self._run_until_cancelled(repo, mock_update)

        assert len(update_calls) >= 1

    @pytest.mark.asyncio
    async def test_cursor_restored_from_db_on_startup(self):
        """Worker loads saved last_log_id from DB and passes it to _incremental_update."""
        repo = _make_repo(saved_state={"last_log_id": 99})
        captured_states = []

        async def mock_update(_repo, state):
            captured_states.append((state.last_log_id, state.initialized))

        await self._run_until_cancelled(repo, mock_update)

        assert captured_states, "_incremental_update never called"
        restored_id, initialized = captured_states[0]
        assert restored_id == 99
        assert initialized is True

    @pytest.mark.asyncio
    async def test_no_saved_cursor_starts_from_zero(self):
        """When get_state returns None, worker starts fresh from log ID 0."""
        repo = _make_repo(saved_state=None)
        captured_states = []

        async def mock_update(_repo, state):
            captured_states.append((state.last_log_id, state.initialized))

        await self._run_until_cancelled(repo, mock_update)

        assert captured_states, "_incremental_update never called"
        restored_id, initialized = captured_states[0]
        assert restored_id == 0
        assert initialized is False
|
|
|
|
|
|
# ─── JA3 bounty extraction from ingester ─────────────────────────────────────
|
|
|
|
class TestJA3BountyExtraction:
    """JA3 fingerprint bounties emitted by ``decnet.web.ingester._extract_bounty``."""

    @staticmethod
    def _ja3_bounties(repo):
        """Return every bounty passed to ``add_bounty`` that is a JA3 fingerprint.

        JA3 bounties use ``bounty_type == "fingerprint"`` and carry
        ``fingerprint_type == "ja3"`` in the payload, so the payload has to be
        inspected; ``bounty_type`` alone never mentions "ja3".
        """
        found = []
        for invocation in repo.add_bounty.call_args_list:
            bounty = invocation.args[0]
            payload = bounty.get("payload")
            if isinstance(payload, str):
                # Tolerate a JSON-encoded payload as well as a dict.
                try:
                    payload = json.loads(payload)
                except ValueError:
                    continue
            if (
                bounty.get("bounty_type") == "fingerprint"
                and isinstance(payload, dict)
                and payload.get("fingerprint_type") == "ja3"
            ):
                found.append(bounty)
        return found

    @pytest.mark.asyncio
    async def test_ja3_bounty_extracted_from_sniffer_event(self):
        from decnet.web.ingester import _extract_bounty

        repo = MagicMock()
        repo.add_bounty = AsyncMock()
        log_data = {
            "decky": "decky-01",
            "service": "sniffer",
            "attacker_ip": "10.0.0.5",
            "event_type": "tls_client_hello",
            "fields": {
                "ja3": "abc123def456abc123def456abc12345",
                "ja3s": None,
                "tls_version": "TLS 1.3",
                "sni": "example.com",
                "alpn": "h2",
                "dst_port": "443",
                "raw_ciphers": "4865-4866",
                "raw_extensions": "0-23-65281",
            },
        }
        await _extract_bounty(repo, log_data)
        repo.add_bounty.assert_awaited_once()
        bounty = repo.add_bounty.call_args[0][0]
        assert bounty["bounty_type"] == "fingerprint"
        assert bounty["payload"]["fingerprint_type"] == "ja3"
        assert bounty["payload"]["ja3"] == "abc123def456abc123def456abc12345"
        assert bounty["payload"]["tls_version"] == "TLS 1.3"
        assert bounty["payload"]["sni"] == "example.com"

    @pytest.mark.asyncio
    async def test_non_sniffer_service_with_ja3_field_ignored(self):
        from decnet.web.ingester import _extract_bounty

        repo = MagicMock()
        repo.add_bounty = AsyncMock()
        log_data = {
            "service": "http",
            "attacker_ip": "10.0.0.6",
            "event_type": "request",
            "fields": {"ja3": "somehash"},
        }
        await _extract_bounty(repo, log_data)
        # Credential/UA checks may still fire, but no JA3 fingerprint bounty
        # may be produced for a non-sniffer service. The previous check looked
        # for "ja3" in bounty_type values, which can never match.
        assert self._ja3_bounties(repo) == []

    @pytest.mark.asyncio
    async def test_sniffer_without_ja3_no_bounty(self):
        from decnet.web.ingester import _extract_bounty

        repo = MagicMock()
        repo.add_bounty = AsyncMock()
        log_data = {
            "service": "sniffer",
            "attacker_ip": "10.0.0.7",
            "event_type": "startup",
            "fields": {"msg": "started"},
        }
        await _extract_bounty(repo, log_data)
        repo.add_bounty.assert_not_awaited()
|