feat: Phase 1 — JA3/JA3S sniffer, Attacker model, profile worker
Add passive TLS fingerprinting via a sniffer container on the MACVLAN interface, plus the Attacker table and periodic rebuild worker that correlates per-IP profiles from Log + Bounty + CorrelationEngine. - templates/sniffer/: Scapy sniffer with pure-Python TLS parser; emits tls_client_hello / tls_session RFC 5424 lines with ja3, ja3s, sni, alpn, raw_ciphers, raw_extensions; GREASE filtered per RFC 8701 - decnet/services/sniffer.py: service plugin (no ports, NET_RAW/NET_ADMIN) - decnet/web/db/models.py: Attacker SQLModel table + AttackersResponse - decnet/web/db/repository.py: 5 new abstract methods - decnet/web/db/sqlite/repository.py: implement all 5 (upsert, pagination, sort by recent/active/traversals, bounty grouping) - decnet/web/attacker_worker.py: 30s periodic rebuild via CorrelationEngine; extracts commands from log fields, merges fingerprint bounties - decnet/web/api.py: wire attacker_profile_worker into lifespan - decnet/web/ingester.py: extract JA3 bounty (fingerprint_type=ja3) - development/DEVELOPMENT.md: full attacker intelligence collection roadmap - pyproject.toml: scapy>=2.6.1 added to dev deps - tests: test_sniffer_ja3.py (40+ vectors), test_attacker_worker.py, test_base_repo.py / test_web_api.py updated for new surface
This commit is contained in:
515
tests/test_attacker_worker.py
Normal file
515
tests/test_attacker_worker.py
Normal file
@@ -0,0 +1,515 @@
|
||||
"""
|
||||
Tests for decnet/web/attacker_worker.py
|
||||
|
||||
Covers:
|
||||
- _rebuild(): CorrelationEngine integration, traversal detection, upsert calls
|
||||
- _extract_commands(): command harvesting from raw log rows
|
||||
- _build_record(): record assembly from engine events + bounties
|
||||
- _first_contact_deckies(): ordering for single-decky attackers
|
||||
- attacker_profile_worker(): cancellation and error handling
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from decnet.logging.syslog_formatter import SEVERITY_INFO, format_rfc5424
|
||||
from decnet.web.attacker_worker import (
|
||||
_build_record,
|
||||
_extract_commands,
|
||||
_first_contact_deckies,
|
||||
_rebuild,
|
||||
attacker_profile_worker,
|
||||
)
|
||||
from decnet.correlation.parser import LogEvent
|
||||
|
||||
# ─── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
_TS1 = "2026-04-04T10:00:00+00:00"
|
||||
_TS2 = "2026-04-04T10:05:00+00:00"
|
||||
_TS3 = "2026-04-04T10:10:00+00:00"
|
||||
|
||||
_DT1 = datetime.fromisoformat(_TS1)
|
||||
_DT2 = datetime.fromisoformat(_TS2)
|
||||
_DT3 = datetime.fromisoformat(_TS3)
|
||||
|
||||
|
||||
def _make_raw_line(
|
||||
service: str = "ssh",
|
||||
hostname: str = "decky-01",
|
||||
event_type: str = "connection",
|
||||
src_ip: str = "1.2.3.4",
|
||||
timestamp: str = _TS1,
|
||||
**extra: str,
|
||||
) -> str:
|
||||
return format_rfc5424(
|
||||
service=service,
|
||||
hostname=hostname,
|
||||
event_type=event_type,
|
||||
severity=SEVERITY_INFO,
|
||||
timestamp=datetime.fromisoformat(timestamp),
|
||||
src_ip=src_ip,
|
||||
**extra,
|
||||
)
|
||||
|
||||
|
||||
def _make_log_row(
|
||||
raw_line: str = "",
|
||||
attacker_ip: str = "1.2.3.4",
|
||||
service: str = "ssh",
|
||||
event_type: str = "connection",
|
||||
decky: str = "decky-01",
|
||||
timestamp: datetime = _DT1,
|
||||
fields: str = "{}",
|
||||
) -> dict:
|
||||
if not raw_line:
|
||||
raw_line = _make_raw_line(
|
||||
service=service,
|
||||
hostname=decky,
|
||||
event_type=event_type,
|
||||
src_ip=attacker_ip,
|
||||
timestamp=timestamp.isoformat(),
|
||||
)
|
||||
return {
|
||||
"id": 1,
|
||||
"raw_line": raw_line,
|
||||
"attacker_ip": attacker_ip,
|
||||
"service": service,
|
||||
"event_type": event_type,
|
||||
"decky": decky,
|
||||
"timestamp": timestamp,
|
||||
"fields": fields,
|
||||
}
|
||||
|
||||
|
||||
def _make_repo(logs=None, bounties=None):
|
||||
repo = MagicMock()
|
||||
repo.get_all_logs_raw = AsyncMock(return_value=logs or [])
|
||||
repo.get_all_bounties_by_ip = AsyncMock(return_value=bounties or {})
|
||||
repo.upsert_attacker = AsyncMock()
|
||||
return repo
|
||||
|
||||
|
||||
def _make_log_event(
|
||||
ip: str,
|
||||
decky: str,
|
||||
service: str = "ssh",
|
||||
event_type: str = "connection",
|
||||
timestamp: datetime = _DT1,
|
||||
) -> LogEvent:
|
||||
return LogEvent(
|
||||
timestamp=timestamp,
|
||||
decky=decky,
|
||||
service=service,
|
||||
event_type=event_type,
|
||||
attacker_ip=ip,
|
||||
fields={},
|
||||
raw="",
|
||||
)
|
||||
|
||||
|
||||
# ─── _first_contact_deckies ───────────────────────────────────────────────────
|
||||
|
||||
class TestFirstContactDeckies:
|
||||
def test_single_decky(self):
|
||||
events = [_make_log_event("1.1.1.1", "decky-01", timestamp=_DT1)]
|
||||
assert _first_contact_deckies(events) == ["decky-01"]
|
||||
|
||||
def test_multiple_deckies_ordered_by_first_contact(self):
|
||||
events = [
|
||||
_make_log_event("1.1.1.1", "decky-02", timestamp=_DT2),
|
||||
_make_log_event("1.1.1.1", "decky-01", timestamp=_DT1),
|
||||
]
|
||||
assert _first_contact_deckies(events) == ["decky-01", "decky-02"]
|
||||
|
||||
def test_revisit_does_not_duplicate(self):
|
||||
events = [
|
||||
_make_log_event("1.1.1.1", "decky-01", timestamp=_DT1),
|
||||
_make_log_event("1.1.1.1", "decky-02", timestamp=_DT2),
|
||||
_make_log_event("1.1.1.1", "decky-01", timestamp=_DT3), # revisit
|
||||
]
|
||||
result = _first_contact_deckies(events)
|
||||
assert result == ["decky-01", "decky-02"]
|
||||
assert result.count("decky-01") == 1
|
||||
|
||||
|
||||
# ─── _extract_commands ────────────────────────────────────────────────────────
|
||||
|
||||
class TestExtractCommands:
|
||||
def _row(self, ip, event_type, fields):
|
||||
return _make_log_row(
|
||||
attacker_ip=ip,
|
||||
event_type=event_type,
|
||||
service="ssh",
|
||||
decky="decky-01",
|
||||
fields=json.dumps(fields),
|
||||
)
|
||||
|
||||
def test_extracts_command_field(self):
|
||||
rows = [self._row("1.1.1.1", "command", {"command": "id"})]
|
||||
result = _extract_commands(rows, "1.1.1.1")
|
||||
assert len(result) == 1
|
||||
assert result[0]["command"] == "id"
|
||||
assert result[0]["service"] == "ssh"
|
||||
assert result[0]["decky"] == "decky-01"
|
||||
|
||||
def test_extracts_query_field(self):
|
||||
rows = [self._row("2.2.2.2", "query", {"query": "SELECT * FROM users"})]
|
||||
result = _extract_commands(rows, "2.2.2.2")
|
||||
assert len(result) == 1
|
||||
assert result[0]["command"] == "SELECT * FROM users"
|
||||
|
||||
def test_extracts_input_field(self):
|
||||
rows = [self._row("3.3.3.3", "input", {"input": "ls -la"})]
|
||||
result = _extract_commands(rows, "3.3.3.3")
|
||||
assert len(result) == 1
|
||||
assert result[0]["command"] == "ls -la"
|
||||
|
||||
def test_non_command_event_type_ignored(self):
|
||||
rows = [self._row("1.1.1.1", "connection", {"command": "id"})]
|
||||
result = _extract_commands(rows, "1.1.1.1")
|
||||
assert result == []
|
||||
|
||||
def test_wrong_ip_ignored(self):
|
||||
rows = [self._row("9.9.9.9", "command", {"command": "whoami"})]
|
||||
result = _extract_commands(rows, "1.1.1.1")
|
||||
assert result == []
|
||||
|
||||
def test_no_command_field_skipped(self):
|
||||
rows = [self._row("1.1.1.1", "command", {"other": "stuff"})]
|
||||
result = _extract_commands(rows, "1.1.1.1")
|
||||
assert result == []
|
||||
|
||||
def test_invalid_json_fields_skipped(self):
|
||||
row = _make_log_row(
|
||||
attacker_ip="1.1.1.1",
|
||||
event_type="command",
|
||||
fields="not valid json",
|
||||
)
|
||||
result = _extract_commands([row], "1.1.1.1")
|
||||
assert result == []
|
||||
|
||||
def test_multiple_commands_all_extracted(self):
|
||||
rows = [
|
||||
self._row("5.5.5.5", "command", {"command": "id"}),
|
||||
self._row("5.5.5.5", "command", {"command": "uname -a"}),
|
||||
]
|
||||
result = _extract_commands(rows, "5.5.5.5")
|
||||
assert len(result) == 2
|
||||
cmds = {r["command"] for r in result}
|
||||
assert cmds == {"id", "uname -a"}
|
||||
|
||||
def test_timestamp_serialized_to_string(self):
|
||||
rows = [self._row("1.1.1.1", "command", {"command": "pwd"})]
|
||||
result = _extract_commands(rows, "1.1.1.1")
|
||||
assert isinstance(result[0]["timestamp"], str)
|
||||
|
||||
|
||||
# ─── _build_record ────────────────────────────────────────────────────────────
|
||||
|
||||
class TestBuildRecord:
|
||||
def _events(self, ip="1.1.1.1"):
|
||||
return [
|
||||
_make_log_event(ip, "decky-01", "ssh", "conn", _DT1),
|
||||
_make_log_event(ip, "decky-01", "http", "req", _DT2),
|
||||
]
|
||||
|
||||
def test_basic_fields(self):
|
||||
events = self._events()
|
||||
record = _build_record("1.1.1.1", events, None, [], [])
|
||||
assert record["ip"] == "1.1.1.1"
|
||||
assert record["event_count"] == 2
|
||||
assert record["service_count"] == 2
|
||||
assert record["decky_count"] == 1
|
||||
|
||||
def test_first_last_seen(self):
|
||||
events = self._events()
|
||||
record = _build_record("1.1.1.1", events, None, [], [])
|
||||
assert record["first_seen"] == _DT1
|
||||
assert record["last_seen"] == _DT2
|
||||
|
||||
def test_services_json_sorted(self):
|
||||
events = self._events()
|
||||
record = _build_record("1.1.1.1", events, None, [], [])
|
||||
services = json.loads(record["services"])
|
||||
assert sorted(services) == services
|
||||
|
||||
def test_no_traversal(self):
|
||||
events = self._events()
|
||||
record = _build_record("1.1.1.1", events, None, [], [])
|
||||
assert record["is_traversal"] is False
|
||||
assert record["traversal_path"] is None
|
||||
|
||||
def test_with_traversal(self):
|
||||
from decnet.correlation.graph import AttackerTraversal, TraversalHop
|
||||
hops = [
|
||||
TraversalHop(_DT1, "decky-01", "ssh", "conn"),
|
||||
TraversalHop(_DT2, "decky-02", "http", "req"),
|
||||
]
|
||||
t = AttackerTraversal("1.1.1.1", hops)
|
||||
events = [
|
||||
_make_log_event("1.1.1.1", "decky-01", timestamp=_DT1),
|
||||
_make_log_event("1.1.1.1", "decky-02", timestamp=_DT2),
|
||||
]
|
||||
record = _build_record("1.1.1.1", events, t, [], [])
|
||||
assert record["is_traversal"] is True
|
||||
assert record["traversal_path"] == "decky-01 → decky-02"
|
||||
deckies = json.loads(record["deckies"])
|
||||
assert deckies == ["decky-01", "decky-02"]
|
||||
|
||||
def test_bounty_counts(self):
|
||||
events = self._events()
|
||||
bounties = [
|
||||
{"bounty_type": "credential", "attacker_ip": "1.1.1.1"},
|
||||
{"bounty_type": "credential", "attacker_ip": "1.1.1.1"},
|
||||
{"bounty_type": "fingerprint", "attacker_ip": "1.1.1.1"},
|
||||
]
|
||||
record = _build_record("1.1.1.1", events, None, bounties, [])
|
||||
assert record["bounty_count"] == 3
|
||||
assert record["credential_count"] == 2
|
||||
fps = json.loads(record["fingerprints"])
|
||||
assert len(fps) == 1
|
||||
assert fps[0]["bounty_type"] == "fingerprint"
|
||||
|
||||
def test_commands_serialized(self):
|
||||
events = self._events()
|
||||
cmds = [{"service": "ssh", "decky": "decky-01", "command": "id", "timestamp": "2026-04-04T10:00:00"}]
|
||||
record = _build_record("1.1.1.1", events, None, [], cmds)
|
||||
parsed = json.loads(record["commands"])
|
||||
assert len(parsed) == 1
|
||||
assert parsed[0]["command"] == "id"
|
||||
|
||||
def test_updated_at_is_utc_datetime(self):
|
||||
events = self._events()
|
||||
record = _build_record("1.1.1.1", events, None, [], [])
|
||||
assert isinstance(record["updated_at"], datetime)
|
||||
assert record["updated_at"].tzinfo is not None
|
||||
|
||||
|
||||
# ─── _rebuild ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestRebuild:
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_logs_no_upsert(self):
|
||||
repo = _make_repo(logs=[])
|
||||
await _rebuild(repo)
|
||||
repo.upsert_attacker.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_attacker_upserted(self):
|
||||
raw = _make_raw_line("ssh", "decky-01", "connection", "10.0.0.1", _TS1)
|
||||
row = _make_log_row(raw_line=raw, attacker_ip="10.0.0.1")
|
||||
repo = _make_repo(logs=[row])
|
||||
await _rebuild(repo)
|
||||
repo.upsert_attacker.assert_awaited_once()
|
||||
record = repo.upsert_attacker.call_args[0][0]
|
||||
assert record["ip"] == "10.0.0.1"
|
||||
assert record["event_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_attackers_all_upserted(self):
|
||||
rows = [
|
||||
_make_log_row(
|
||||
raw_line=_make_raw_line("ssh", "decky-01", "conn", ip, _TS1),
|
||||
attacker_ip=ip,
|
||||
)
|
||||
for ip in ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
|
||||
]
|
||||
repo = _make_repo(logs=rows)
|
||||
await _rebuild(repo)
|
||||
assert repo.upsert_attacker.await_count == 3
|
||||
upserted_ips = {c[0][0]["ip"] for c in repo.upsert_attacker.call_args_list}
|
||||
assert upserted_ips == {"1.1.1.1", "2.2.2.2", "3.3.3.3"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_traversal_detected_across_two_deckies(self):
|
||||
rows = [
|
||||
_make_log_row(
|
||||
raw_line=_make_raw_line("ssh", "decky-01", "conn", "5.5.5.5", _TS1),
|
||||
attacker_ip="5.5.5.5", decky="decky-01",
|
||||
),
|
||||
_make_log_row(
|
||||
raw_line=_make_raw_line("http", "decky-02", "req", "5.5.5.5", _TS2),
|
||||
attacker_ip="5.5.5.5", decky="decky-02",
|
||||
),
|
||||
]
|
||||
repo = _make_repo(logs=rows)
|
||||
await _rebuild(repo)
|
||||
record = repo.upsert_attacker.call_args[0][0]
|
||||
assert record["is_traversal"] is True
|
||||
assert "decky-01" in record["traversal_path"]
|
||||
assert "decky-02" in record["traversal_path"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_decky_not_traversal(self):
|
||||
rows = [
|
||||
_make_log_row(
|
||||
raw_line=_make_raw_line("ssh", "decky-01", "conn", "7.7.7.7", _TS1),
|
||||
attacker_ip="7.7.7.7",
|
||||
),
|
||||
_make_log_row(
|
||||
raw_line=_make_raw_line("http", "decky-01", "req", "7.7.7.7", _TS2),
|
||||
attacker_ip="7.7.7.7",
|
||||
),
|
||||
]
|
||||
repo = _make_repo(logs=rows)
|
||||
await _rebuild(repo)
|
||||
record = repo.upsert_attacker.call_args[0][0]
|
||||
assert record["is_traversal"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bounties_merged_into_record(self):
|
||||
raw = _make_raw_line("ssh", "decky-01", "conn", "8.8.8.8", _TS1)
|
||||
repo = _make_repo(
|
||||
logs=[_make_log_row(raw_line=raw, attacker_ip="8.8.8.8")],
|
||||
bounties={"8.8.8.8": [
|
||||
{"bounty_type": "credential", "attacker_ip": "8.8.8.8", "payload": {}},
|
||||
{"bounty_type": "fingerprint", "attacker_ip": "8.8.8.8", "payload": {"ja3": "abc"}},
|
||||
]},
|
||||
)
|
||||
await _rebuild(repo)
|
||||
record = repo.upsert_attacker.call_args[0][0]
|
||||
assert record["bounty_count"] == 2
|
||||
assert record["credential_count"] == 1
|
||||
fps = json.loads(record["fingerprints"])
|
||||
assert len(fps) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commands_extracted_during_rebuild(self):
|
||||
raw = _make_raw_line("ssh", "decky-01", "command", "9.9.9.9", _TS1)
|
||||
row = _make_log_row(
|
||||
raw_line=raw,
|
||||
attacker_ip="9.9.9.9",
|
||||
event_type="command",
|
||||
fields=json.dumps({"command": "cat /etc/passwd"}),
|
||||
)
|
||||
repo = _make_repo(logs=[row])
|
||||
await _rebuild(repo)
|
||||
record = repo.upsert_attacker.call_args[0][0]
|
||||
commands = json.loads(record["commands"])
|
||||
assert len(commands) == 1
|
||||
assert commands[0]["command"] == "cat /etc/passwd"
|
||||
|
||||
|
||||
# ─── attacker_profile_worker ──────────────────────────────────────────────────
|
||||
|
||||
class TestAttackerProfileWorker:
|
||||
@pytest.mark.asyncio
|
||||
async def test_worker_cancels_cleanly(self):
|
||||
repo = _make_repo()
|
||||
task = asyncio.create_task(attacker_profile_worker(repo))
|
||||
await asyncio.sleep(0)
|
||||
task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_worker_handles_rebuild_error_without_crashing(self):
|
||||
repo = _make_repo()
|
||||
_call_count = 0
|
||||
|
||||
async def fake_sleep(secs):
|
||||
nonlocal _call_count
|
||||
_call_count += 1
|
||||
if _call_count >= 2:
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
async def bad_rebuild(_repo):
|
||||
raise RuntimeError("DB exploded")
|
||||
|
||||
with patch("decnet.web.attacker_worker.asyncio.sleep", side_effect=fake_sleep):
|
||||
with patch("decnet.web.attacker_worker._rebuild", side_effect=bad_rebuild):
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await attacker_profile_worker(repo)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_worker_calls_rebuild_after_sleep(self):
|
||||
repo = _make_repo()
|
||||
_call_count = 0
|
||||
|
||||
async def fake_sleep(secs):
|
||||
nonlocal _call_count
|
||||
_call_count += 1
|
||||
if _call_count >= 2:
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
rebuild_calls = []
|
||||
|
||||
async def mock_rebuild(_repo):
|
||||
rebuild_calls.append(True)
|
||||
|
||||
with patch("decnet.web.attacker_worker.asyncio.sleep", side_effect=fake_sleep):
|
||||
with patch("decnet.web.attacker_worker._rebuild", side_effect=mock_rebuild):
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await attacker_profile_worker(repo)
|
||||
|
||||
assert len(rebuild_calls) >= 1
|
||||
|
||||
|
||||
# ─── JA3 bounty extraction from ingester ─────────────────────────────────────
|
||||
|
||||
class TestJA3BountyExtraction:
|
||||
@pytest.mark.asyncio
|
||||
async def test_ja3_bounty_extracted_from_sniffer_event(self):
|
||||
from decnet.web.ingester import _extract_bounty
|
||||
repo = MagicMock()
|
||||
repo.add_bounty = AsyncMock()
|
||||
log_data = {
|
||||
"decky": "decky-01",
|
||||
"service": "sniffer",
|
||||
"attacker_ip": "10.0.0.5",
|
||||
"event_type": "tls_client_hello",
|
||||
"fields": {
|
||||
"ja3": "abc123def456abc123def456abc12345",
|
||||
"ja3s": None,
|
||||
"tls_version": "TLS 1.3",
|
||||
"sni": "example.com",
|
||||
"alpn": "h2",
|
||||
"dst_port": "443",
|
||||
"raw_ciphers": "4865-4866",
|
||||
"raw_extensions": "0-23-65281",
|
||||
},
|
||||
}
|
||||
await _extract_bounty(repo, log_data)
|
||||
repo.add_bounty.assert_awaited_once()
|
||||
bounty = repo.add_bounty.call_args[0][0]
|
||||
assert bounty["bounty_type"] == "fingerprint"
|
||||
assert bounty["payload"]["fingerprint_type"] == "ja3"
|
||||
assert bounty["payload"]["ja3"] == "abc123def456abc123def456abc12345"
|
||||
assert bounty["payload"]["tls_version"] == "TLS 1.3"
|
||||
assert bounty["payload"]["sni"] == "example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_sniffer_service_with_ja3_field_ignored(self):
|
||||
from decnet.web.ingester import _extract_bounty
|
||||
repo = MagicMock()
|
||||
repo.add_bounty = AsyncMock()
|
||||
log_data = {
|
||||
"service": "http",
|
||||
"attacker_ip": "10.0.0.6",
|
||||
"event_type": "request",
|
||||
"fields": {"ja3": "somehash"},
|
||||
}
|
||||
await _extract_bounty(repo, log_data)
|
||||
# Credential/UA checks run, but JA3 should not fire for non-sniffer
|
||||
calls = [c[0][0]["bounty_type"] for c in repo.add_bounty.call_args_list]
|
||||
assert "ja3" not in str(calls)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sniffer_without_ja3_no_bounty(self):
|
||||
from decnet.web.ingester import _extract_bounty
|
||||
repo = MagicMock()
|
||||
repo.add_bounty = AsyncMock()
|
||||
log_data = {
|
||||
"service": "sniffer",
|
||||
"attacker_ip": "10.0.0.7",
|
||||
"event_type": "startup",
|
||||
"fields": {"msg": "started"},
|
||||
}
|
||||
await _extract_bounty(repo, log_data)
|
||||
repo.add_bounty.assert_not_awaited()
|
||||
@@ -21,6 +21,11 @@ class DummyRepo(BaseRepository):
|
||||
async def get_total_bounties(self, **kw): await super().get_total_bounties(**kw)
|
||||
async def get_state(self, k): await super().get_state(k)
|
||||
async def set_state(self, k, v): await super().set_state(k, v)
|
||||
async def get_all_logs_raw(self): await super().get_all_logs_raw()
|
||||
async def get_all_bounties_by_ip(self): await super().get_all_bounties_by_ip()
|
||||
async def upsert_attacker(self, d): await super().upsert_attacker(d)
|
||||
async def get_attackers(self, **kw): await super().get_attackers(**kw)
|
||||
async def get_total_attackers(self, **kw): await super().get_total_attackers(**kw)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_repo_coverage():
|
||||
@@ -41,3 +46,8 @@ async def test_base_repo_coverage():
|
||||
await dr.get_total_bounties()
|
||||
await dr.get_state("k")
|
||||
await dr.set_state("k", "v")
|
||||
await dr.get_all_logs_raw()
|
||||
await dr.get_all_bounties_by_ip()
|
||||
await dr.upsert_attacker({})
|
||||
await dr.get_attackers()
|
||||
await dr.get_total_attackers()
|
||||
|
||||
437
tests/test_sniffer_ja3.py
Normal file
437
tests/test_sniffer_ja3.py
Normal file
@@ -0,0 +1,437 @@
|
||||
"""
|
||||
Unit tests for the JA3/JA3S parsing logic in templates/sniffer/server.py.
|
||||
|
||||
Imports the parser functions directly via sys.path manipulation, with
|
||||
decnet_logging mocked out (it's a container-side stub at template build time).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# ─── Import sniffer module with mocked decnet_logging ─────────────────────────
|
||||
|
||||
_SNIFFER_DIR = str(Path(__file__).parent.parent / "templates" / "sniffer")
|
||||
|
||||
def _load_sniffer():
|
||||
"""Load templates/sniffer/server.py with decnet_logging stubbed out."""
|
||||
# Stub the decnet_logging module that server.py imports
|
||||
_stub = types.ModuleType("decnet_logging")
|
||||
_stub.SEVERITY_INFO = 6
|
||||
_stub.SEVERITY_WARNING = 4
|
||||
_stub.syslog_line = MagicMock(return_value="<134>1 fake")
|
||||
_stub.write_syslog_file = MagicMock()
|
||||
sys.modules.setdefault("decnet_logging", _stub)
|
||||
|
||||
if _SNIFFER_DIR not in sys.path:
|
||||
sys.path.insert(0, _SNIFFER_DIR)
|
||||
|
||||
import importlib
|
||||
if "server" in sys.modules:
|
||||
return sys.modules["server"]
|
||||
import server as _srv
|
||||
return _srv
|
||||
|
||||
_srv = _load_sniffer()
|
||||
|
||||
_parse_client_hello = _srv._parse_client_hello
|
||||
_parse_server_hello = _srv._parse_server_hello
|
||||
_ja3 = _srv._ja3
|
||||
_ja3s = _srv._ja3s
|
||||
_is_grease = _srv._is_grease
|
||||
_filter_grease = _srv._filter_grease
|
||||
_tls_version_str = _srv._tls_version_str
|
||||
|
||||
|
||||
# ─── TLS byte builder helpers ─────────────────────────────────────────────────
|
||||
|
||||
def _build_client_hello(
|
||||
version: int = 0x0303,
|
||||
cipher_suites: list[int] | None = None,
|
||||
extensions_bytes: bytes = b"",
|
||||
) -> bytes:
|
||||
"""Build a minimal valid TLS ClientHello byte sequence."""
|
||||
if cipher_suites is None:
|
||||
cipher_suites = [0x002F, 0x0035] # AES-128-SHA, AES-256-SHA
|
||||
|
||||
random_bytes = b"\xAB" * 32
|
||||
session_id = b"\x00" # no session id
|
||||
cs_bytes = b"".join(struct.pack("!H", c) for c in cipher_suites)
|
||||
cs_len = struct.pack("!H", len(cs_bytes))
|
||||
compression = b"\x01\x00" # 1 method: null
|
||||
|
||||
if extensions_bytes:
|
||||
ext_block = struct.pack("!H", len(extensions_bytes)) + extensions_bytes
|
||||
else:
|
||||
ext_block = b"\x00\x00"
|
||||
|
||||
body = (
|
||||
struct.pack("!H", version)
|
||||
+ random_bytes
|
||||
+ session_id
|
||||
+ cs_len
|
||||
+ cs_bytes
|
||||
+ compression
|
||||
+ ext_block
|
||||
)
|
||||
|
||||
hs_header = b"\x01" + struct.pack("!I", len(body))[1:] # type + 3-byte len
|
||||
record_payload = hs_header + body
|
||||
record = b"\x16\x03\x01" + struct.pack("!H", len(record_payload)) + record_payload
|
||||
return record
|
||||
|
||||
|
||||
def _build_extension(ext_type: int, data: bytes) -> bytes:
|
||||
return struct.pack("!HH", ext_type, len(data)) + data
|
||||
|
||||
|
||||
def _build_sni_extension(hostname: str) -> bytes:
|
||||
name_bytes = hostname.encode()
|
||||
# server_name: type(1) + len(2) + name
|
||||
entry = b"\x00" + struct.pack("!H", len(name_bytes)) + name_bytes
|
||||
# server_name_list: len(2) + entries
|
||||
lst = struct.pack("!H", len(entry)) + entry
|
||||
return _build_extension(0x0000, lst)
|
||||
|
||||
|
||||
def _build_supported_groups_extension(groups: list[int]) -> bytes:
|
||||
grp_bytes = b"".join(struct.pack("!H", g) for g in groups)
|
||||
data = struct.pack("!H", len(grp_bytes)) + grp_bytes
|
||||
return _build_extension(0x000A, data)
|
||||
|
||||
|
||||
def _build_ec_point_formats_extension(formats: list[int]) -> bytes:
|
||||
pf = bytes(formats)
|
||||
data = bytes([len(pf)]) + pf
|
||||
return _build_extension(0x000B, data)
|
||||
|
||||
|
||||
def _build_alpn_extension(protocols: list[str]) -> bytes:
|
||||
proto_bytes = b""
|
||||
for p in protocols:
|
||||
pb = p.encode()
|
||||
proto_bytes += bytes([len(pb)]) + pb
|
||||
data = struct.pack("!H", len(proto_bytes)) + proto_bytes
|
||||
return _build_extension(0x0010, data)
|
||||
|
||||
|
||||
def _build_server_hello(
|
||||
version: int = 0x0303,
|
||||
cipher_suite: int = 0x002F,
|
||||
extensions_bytes: bytes = b"",
|
||||
) -> bytes:
|
||||
random_bytes = b"\xCD" * 32
|
||||
session_id = b"\x00"
|
||||
compression = b"\x00"
|
||||
|
||||
if extensions_bytes:
|
||||
ext_block = struct.pack("!H", len(extensions_bytes)) + extensions_bytes
|
||||
else:
|
||||
ext_block = b"\x00\x00"
|
||||
|
||||
body = (
|
||||
struct.pack("!H", version)
|
||||
+ random_bytes
|
||||
+ session_id
|
||||
+ struct.pack("!H", cipher_suite)
|
||||
+ compression
|
||||
+ ext_block
|
||||
)
|
||||
|
||||
hs_header = b"\x02" + struct.pack("!I", len(body))[1:]
|
||||
record_payload = hs_header + body
|
||||
return b"\x16\x03\x01" + struct.pack("!H", len(record_payload)) + record_payload
|
||||
|
||||
|
||||
# ─── GREASE tests ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestGrease:
|
||||
def test_known_grease_values_detected(self):
|
||||
for v in [0x0A0A, 0x1A1A, 0x2A2A, 0x3A3A, 0x4A4A, 0x5A5A,
|
||||
0x6A6A, 0x7A7A, 0x8A8A, 0x9A9A, 0xAAAA, 0xBABA,
|
||||
0xCACA, 0xDADA, 0xEAEA, 0xFAFA]:
|
||||
assert _is_grease(v), f"0x{v:04x} should be GREASE"
|
||||
|
||||
def test_non_grease_values_not_detected(self):
|
||||
for v in [0x002F, 0x0035, 0x1301, 0x000A, 0xFFFF]:
|
||||
assert not _is_grease(v), f"0x{v:04x} should not be GREASE"
|
||||
|
||||
def test_filter_grease_removes_grease(self):
|
||||
values = [0x0A0A, 0x002F, 0x1A1A, 0x0035]
|
||||
result = _filter_grease(values)
|
||||
assert result == [0x002F, 0x0035]
|
||||
|
||||
def test_filter_grease_preserves_all_non_grease(self):
|
||||
values = [0x002F, 0x0035, 0x1301]
|
||||
assert _filter_grease(values) == values
|
||||
|
||||
|
||||
# ─── ClientHello parsing tests ────────────────────────────────────────────────
|
||||
|
||||
class TestParseClientHello:
|
||||
def test_minimal_client_hello_parsed(self):
|
||||
data = _build_client_hello()
|
||||
result = _parse_client_hello(data)
|
||||
assert result is not None
|
||||
assert result["tls_version"] == 0x0303
|
||||
assert result["cipher_suites"] == [0x002F, 0x0035]
|
||||
assert result["extensions"] == []
|
||||
assert result["supported_groups"] == []
|
||||
assert result["ec_point_formats"] == []
|
||||
assert result["sni"] == ""
|
||||
assert result["alpn"] == []
|
||||
|
||||
def test_wrong_record_type_returns_none(self):
|
||||
data = _build_client_hello()
|
||||
bad = b"\x14" + data[1:] # change record type to ChangeCipherSpec
|
||||
assert _parse_client_hello(bad) is None
|
||||
|
||||
def test_wrong_handshake_type_returns_none(self):
|
||||
data = _build_client_hello()
|
||||
# Byte at offset 5 is the handshake type
|
||||
bad = data[:5] + b"\x02" + data[6:] # ServerHello type
|
||||
assert _parse_client_hello(bad) is None
|
||||
|
||||
def test_too_short_returns_none(self):
|
||||
assert _parse_client_hello(b"\x16\x03\x01") is None
|
||||
assert _parse_client_hello(b"") is None
|
||||
|
||||
def test_non_tls_returns_none(self):
|
||||
assert _parse_client_hello(b"GET / HTTP/1.1\r\n") is None
|
||||
|
||||
def test_grease_cipher_suites_filtered(self):
|
||||
data = _build_client_hello(cipher_suites=[0x0A0A, 0x002F, 0x1A1A, 0x0035])
|
||||
result = _parse_client_hello(data)
|
||||
assert result is not None
|
||||
assert 0x0A0A not in result["cipher_suites"]
|
||||
assert 0x1A1A not in result["cipher_suites"]
|
||||
assert result["cipher_suites"] == [0x002F, 0x0035]
|
||||
|
||||
def test_sni_extension_extracted(self):
|
||||
ext = _build_sni_extension("example.com")
|
||||
data = _build_client_hello(extensions_bytes=ext)
|
||||
result = _parse_client_hello(data)
|
||||
assert result is not None
|
||||
assert result["sni"] == "example.com"
|
||||
|
||||
def test_supported_groups_extracted(self):
|
||||
ext = _build_supported_groups_extension([0x001D, 0x0017, 0x0018])
|
||||
data = _build_client_hello(extensions_bytes=ext)
|
||||
result = _parse_client_hello(data)
|
||||
assert result is not None
|
||||
assert result["supported_groups"] == [0x001D, 0x0017, 0x0018]
|
||||
|
||||
def test_grease_in_supported_groups_filtered(self):
|
||||
ext = _build_supported_groups_extension([0x0A0A, 0x001D])
|
||||
data = _build_client_hello(extensions_bytes=ext)
|
||||
result = _parse_client_hello(data)
|
||||
assert result is not None
|
||||
assert 0x0A0A not in result["supported_groups"]
|
||||
assert 0x001D in result["supported_groups"]
|
||||
|
||||
def test_ec_point_formats_extracted(self):
|
||||
ext = _build_ec_point_formats_extension([0x00, 0x01])
|
||||
data = _build_client_hello(extensions_bytes=ext)
|
||||
result = _parse_client_hello(data)
|
||||
assert result is not None
|
||||
assert result["ec_point_formats"] == [0x00, 0x01]
|
||||
|
||||
def test_alpn_extension_extracted(self):
|
||||
ext = _build_alpn_extension(["h2", "http/1.1"])
|
||||
data = _build_client_hello(extensions_bytes=ext)
|
||||
result = _parse_client_hello(data)
|
||||
assert result is not None
|
||||
assert result["alpn"] == ["h2", "http/1.1"]
|
||||
|
||||
def test_multiple_extensions_extracted(self):
|
||||
sni = _build_sni_extension("target.local")
|
||||
grps = _build_supported_groups_extension([0x001D])
|
||||
combined = sni + grps
|
||||
data = _build_client_hello(extensions_bytes=combined)
|
||||
result = _parse_client_hello(data)
|
||||
assert result is not None
|
||||
assert result["sni"] == "target.local"
|
||||
assert 0x001D in result["supported_groups"]
|
||||
# Extension type IDs recorded (SNI=0, supported_groups=10)
|
||||
assert 0x0000 in result["extensions"]
|
||||
assert 0x000A in result["extensions"]
|
||||
|
||||
|
||||
# ─── ServerHello parsing tests ────────────────────────────────────────────────
|
||||
|
||||
class TestParseServerHello:
|
||||
def test_minimal_server_hello_parsed(self):
|
||||
data = _build_server_hello()
|
||||
result = _parse_server_hello(data)
|
||||
assert result is not None
|
||||
assert result["tls_version"] == 0x0303
|
||||
assert result["cipher_suite"] == 0x002F
|
||||
assert result["extensions"] == []
|
||||
|
||||
def test_wrong_record_type_returns_none(self):
|
||||
data = _build_server_hello()
|
||||
bad = b"\x15" + data[1:]
|
||||
assert _parse_server_hello(bad) is None
|
||||
|
||||
def test_wrong_handshake_type_returns_none(self):
|
||||
data = _build_server_hello()
|
||||
bad = data[:5] + b"\x01" + data[6:] # ClientHello type
|
||||
assert _parse_server_hello(bad) is None
|
||||
|
||||
def test_too_short_returns_none(self):
|
||||
assert _parse_server_hello(b"") is None
|
||||
|
||||
def test_server_hello_extension_types_recorded(self):
|
||||
# Build a ServerHello with a generic extension (type=0xFF01)
|
||||
ext_data = _build_extension(0xFF01, b"\x00")
|
||||
data = _build_server_hello(extensions_bytes=ext_data)
|
||||
result = _parse_server_hello(data)
|
||||
assert result is not None
|
||||
assert 0xFF01 in result["extensions"]
|
||||
|
||||
def test_grease_extension_in_server_hello_filtered(self):
|
||||
ext_data = _build_extension(0x0A0A, b"\x00")
|
||||
data = _build_server_hello(extensions_bytes=ext_data)
|
||||
result = _parse_server_hello(data)
|
||||
assert result is not None
|
||||
assert 0x0A0A not in result["extensions"]
|
||||
|
||||
|
||||
# ─── JA3 hash tests ───────────────────────────────────────────────────────────
|
||||
|
||||
class TestJA3:
|
||||
def test_ja3_returns_32_char_hex(self):
|
||||
data = _build_client_hello()
|
||||
ch = _parse_client_hello(data)
|
||||
_, ja3_hash = _ja3(ch)
|
||||
assert len(ja3_hash) == 32
|
||||
assert all(c in "0123456789abcdef" for c in ja3_hash)
|
||||
|
||||
def test_ja3_known_hash(self):
|
||||
# Minimal ClientHello: TLS 1.2, ciphers [47, 53], no extensions
|
||||
ch = {
|
||||
"tls_version": 0x0303, # 771
|
||||
"cipher_suites": [0x002F, 0x0035], # 47, 53
|
||||
"extensions": [],
|
||||
"supported_groups": [],
|
||||
"ec_point_formats": [],
|
||||
"sni": "",
|
||||
"alpn": [],
|
||||
}
|
||||
ja3_str, ja3_hash = _ja3(ch)
|
||||
assert ja3_str == "771,47-53,,,"
|
||||
expected = hashlib.md5(b"771,47-53,,,").hexdigest()
|
||||
assert ja3_hash == expected
|
||||
|
||||
def test_ja3_same_input_same_hash(self):
|
||||
data = _build_client_hello()
|
||||
ch = _parse_client_hello(data)
|
||||
_, h1 = _ja3(ch)
|
||||
_, h2 = _ja3(ch)
|
||||
assert h1 == h2
|
||||
|
||||
def test_ja3_different_ciphers_different_hash(self):
|
||||
ch1 = {"tls_version": 0x0303, "cipher_suites": [47], "extensions": [],
|
||||
"supported_groups": [], "ec_point_formats": [], "sni": "", "alpn": []}
|
||||
ch2 = {"tls_version": 0x0303, "cipher_suites": [53], "extensions": [],
|
||||
"supported_groups": [], "ec_point_formats": [], "sni": "", "alpn": []}
|
||||
_, h1 = _ja3(ch1)
|
||||
_, h2 = _ja3(ch2)
|
||||
assert h1 != h2
|
||||
|
||||
def test_ja3_empty_lists_produce_valid_string(self):
|
||||
ch = {"tls_version": 0x0303, "cipher_suites": [], "extensions": [],
|
||||
"supported_groups": [], "ec_point_formats": [], "sni": "", "alpn": []}
|
||||
ja3_str, ja3_hash = _ja3(ch)
|
||||
assert ja3_str == "771,,,,"
|
||||
assert len(ja3_hash) == 32
|
||||
|
||||
|
||||
# ─── JA3S hash tests ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestJA3S:
|
||||
def test_ja3s_returns_32_char_hex(self):
|
||||
data = _build_server_hello()
|
||||
sh = _parse_server_hello(data)
|
||||
_, ja3s_hash = _ja3s(sh)
|
||||
assert len(ja3s_hash) == 32
|
||||
assert all(c in "0123456789abcdef" for c in ja3s_hash)
|
||||
|
||||
def test_ja3s_known_hash(self):
|
||||
sh = {"tls_version": 0x0303, "cipher_suite": 0x002F, "extensions": []}
|
||||
ja3s_str, ja3s_hash = _ja3s(sh)
|
||||
assert ja3s_str == "771,47,"
|
||||
expected = hashlib.md5(b"771,47,").hexdigest()
|
||||
assert ja3s_hash == expected
|
||||
|
||||
def test_ja3s_different_cipher_different_hash(self):
|
||||
sh1 = {"tls_version": 0x0303, "cipher_suite": 0x002F, "extensions": []}
|
||||
sh2 = {"tls_version": 0x0303, "cipher_suite": 0x0035, "extensions": []}
|
||||
_, h1 = _ja3s(sh1)
|
||||
_, h2 = _ja3s(sh2)
|
||||
assert h1 != h2
|
||||
|
||||
|
||||
# ─── TLS version string tests ─────────────────────────────────────────────────
|
||||
|
||||
class TestTLSVersionStr:
|
||||
def test_tls12(self):
|
||||
assert _tls_version_str(0x0303) == "TLS 1.2"
|
||||
|
||||
def test_tls13(self):
|
||||
assert _tls_version_str(0x0304) == "TLS 1.3"
|
||||
|
||||
def test_tls11(self):
|
||||
assert _tls_version_str(0x0302) == "TLS 1.1"
|
||||
|
||||
def test_tls10(self):
|
||||
assert _tls_version_str(0x0301) == "TLS 1.0"
|
||||
|
||||
def test_unknown_version(self):
|
||||
result = _tls_version_str(0xABCD)
|
||||
assert "0xabcd" in result.lower()
|
||||
|
||||
|
||||
# ─── Full round-trip: parse bytes → JA3/JA3S ──────────────────────────────────
|
||||
|
||||
class TestRoundTrip:
|
||||
def test_client_hello_bytes_to_ja3(self):
|
||||
ciphers = [0x1301, 0x1302, 0x002F]
|
||||
sni_ext = _build_sni_extension("attacker.c2.com")
|
||||
data = _build_client_hello(cipher_suites=ciphers, extensions_bytes=sni_ext)
|
||||
ch = _parse_client_hello(data)
|
||||
assert ch is not None
|
||||
ja3_str, ja3_hash = _ja3(ch)
|
||||
assert "4865-4866-47" in ja3_str # ciphers: 0x1301=4865, 0x1302=4866, 0x002F=47
|
||||
assert len(ja3_hash) == 32
|
||||
assert ch["sni"] == "attacker.c2.com"
|
||||
|
||||
def test_server_hello_bytes_to_ja3s(self):
|
||||
data = _build_server_hello(cipher_suite=0x1301)
|
||||
sh = _parse_server_hello(data)
|
||||
assert sh is not None
|
||||
ja3s_str, ja3s_hash = _ja3s(sh)
|
||||
assert "4865" in ja3s_str # 0x1301 = 4865
|
||||
assert len(ja3s_hash) == 32
|
||||
|
||||
def test_grease_client_hello_filtered_before_hash(self):
|
||||
"""GREASE ciphers must be stripped before JA3 is computed."""
|
||||
ciphers_with_grease = [0x0A0A, 0x002F, 0xFAFA, 0x0035]
|
||||
data = _build_client_hello(cipher_suites=ciphers_with_grease)
|
||||
ch = _parse_client_hello(data)
|
||||
_, ja3_hash = _ja3(ch)
|
||||
|
||||
# Reference: build without GREASE
|
||||
ciphers_clean = [0x002F, 0x0035]
|
||||
data_clean = _build_client_hello(cipher_suites=ciphers_clean)
|
||||
ch_clean = _parse_client_hello(data_clean)
|
||||
_, ja3_hash_clean = _ja3(ch_clean)
|
||||
|
||||
assert ja3_hash == ja3_hash_clean
|
||||
@@ -128,8 +128,9 @@ class TestLifespan:
|
||||
with patch("decnet.web.api.repo", mock_repo):
|
||||
with patch("decnet.web.api.log_ingestion_worker", return_value=asyncio.sleep(0)):
|
||||
with patch("decnet.web.api.log_collector_worker", return_value=asyncio.sleep(0)):
|
||||
async with lifespan(mock_app):
|
||||
mock_repo.initialize.assert_awaited_once()
|
||||
with patch("decnet.web.api.attacker_profile_worker", return_value=asyncio.sleep(0)):
|
||||
async with lifespan(mock_app):
|
||||
mock_repo.initialize.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_db_retry(self):
|
||||
@@ -150,5 +151,6 @@ class TestLifespan:
|
||||
with patch("decnet.web.api.asyncio.sleep", new_callable=AsyncMock):
|
||||
with patch("decnet.web.api.log_ingestion_worker", return_value=asyncio.sleep(0)):
|
||||
with patch("decnet.web.api.log_collector_worker", return_value=asyncio.sleep(0)):
|
||||
async with lifespan(mock_app):
|
||||
assert _call_count == 3
|
||||
with patch("decnet.web.api.attacker_profile_worker", return_value=asyncio.sleep(0)):
|
||||
async with lifespan(mock_app):
|
||||
assert _call_count == 3
|
||||
|
||||
Reference in New Issue
Block a user