feat: Phase 1 — JA3/JA3S sniffer, Attacker model, profile worker

Add passive TLS fingerprinting via a sniffer container on the MACVLAN
interface, plus the Attacker table and periodic rebuild worker that
correlates per-IP profiles from Log + Bounty + CorrelationEngine.

- templates/sniffer/: Scapy sniffer with pure-Python TLS parser;
  emits tls_client_hello / tls_session RFC 5424 lines with ja3, ja3s,
  sni, alpn, raw_ciphers, raw_extensions; GREASE filtered per RFC 8701
- decnet/services/sniffer.py: service plugin (no ports, NET_RAW/NET_ADMIN)
- decnet/web/db/models.py: Attacker SQLModel table + AttackersResponse
- decnet/web/db/repository.py: 5 new abstract methods
- decnet/web/db/sqlite/repository.py: implement all 5 (upsert, pagination,
  sort by recent/active/traversals, bounty grouping)
- decnet/web/attacker_worker.py: 30s periodic rebuild via CorrelationEngine;
  extracts commands from log fields, merges fingerprint bounties
- decnet/web/api.py: wire attacker_profile_worker into lifespan
- decnet/web/ingester.py: extract JA3 bounty (fingerprint_type=ja3)
- development/DEVELOPMENT.md: full attacker intelligence collection roadmap
- pyproject.toml: scapy>=2.6.1 added to dev deps
- tests: test_sniffer_ja3.py (40+ vectors), test_attacker_worker.py,
  test_base_repo.py / test_web_api.py updated for new surface
This commit is contained in:
2026-04-13 20:22:08 -04:00
parent c9be447a38
commit 3dc5b509f6
16 changed files with 1818 additions and 8 deletions

View File

@@ -0,0 +1,515 @@
"""
Tests for decnet/web/attacker_worker.py
Covers:
- _rebuild(): CorrelationEngine integration, traversal detection, upsert calls
- _extract_commands(): command harvesting from raw log rows
- _build_record(): record assembly from engine events + bounties
- _first_contact_deckies(): ordering for single-decky attackers
- attacker_profile_worker(): cancellation and error handling
"""
from __future__ import annotations
import asyncio
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from decnet.logging.syslog_formatter import SEVERITY_INFO, format_rfc5424
from decnet.web.attacker_worker import (
_build_record,
_extract_commands,
_first_contact_deckies,
_rebuild,
attacker_profile_worker,
)
from decnet.correlation.parser import LogEvent
# ─── Helpers ──────────────────────────────────────────────────────────────────
_TS1 = "2026-04-04T10:00:00+00:00"
_TS2 = "2026-04-04T10:05:00+00:00"
_TS3 = "2026-04-04T10:10:00+00:00"
_DT1 = datetime.fromisoformat(_TS1)
_DT2 = datetime.fromisoformat(_TS2)
_DT3 = datetime.fromisoformat(_TS3)
def _make_raw_line(
service: str = "ssh",
hostname: str = "decky-01",
event_type: str = "connection",
src_ip: str = "1.2.3.4",
timestamp: str = _TS1,
**extra: str,
) -> str:
return format_rfc5424(
service=service,
hostname=hostname,
event_type=event_type,
severity=SEVERITY_INFO,
timestamp=datetime.fromisoformat(timestamp),
src_ip=src_ip,
**extra,
)
def _make_log_row(
raw_line: str = "",
attacker_ip: str = "1.2.3.4",
service: str = "ssh",
event_type: str = "connection",
decky: str = "decky-01",
timestamp: datetime = _DT1,
fields: str = "{}",
) -> dict:
if not raw_line:
raw_line = _make_raw_line(
service=service,
hostname=decky,
event_type=event_type,
src_ip=attacker_ip,
timestamp=timestamp.isoformat(),
)
return {
"id": 1,
"raw_line": raw_line,
"attacker_ip": attacker_ip,
"service": service,
"event_type": event_type,
"decky": decky,
"timestamp": timestamp,
"fields": fields,
}
def _make_repo(logs=None, bounties=None):
repo = MagicMock()
repo.get_all_logs_raw = AsyncMock(return_value=logs or [])
repo.get_all_bounties_by_ip = AsyncMock(return_value=bounties or {})
repo.upsert_attacker = AsyncMock()
return repo
def _make_log_event(
ip: str,
decky: str,
service: str = "ssh",
event_type: str = "connection",
timestamp: datetime = _DT1,
) -> LogEvent:
return LogEvent(
timestamp=timestamp,
decky=decky,
service=service,
event_type=event_type,
attacker_ip=ip,
fields={},
raw="",
)
# ─── _first_contact_deckies ───────────────────────────────────────────────────
class TestFirstContactDeckies:
def test_single_decky(self):
events = [_make_log_event("1.1.1.1", "decky-01", timestamp=_DT1)]
assert _first_contact_deckies(events) == ["decky-01"]
def test_multiple_deckies_ordered_by_first_contact(self):
events = [
_make_log_event("1.1.1.1", "decky-02", timestamp=_DT2),
_make_log_event("1.1.1.1", "decky-01", timestamp=_DT1),
]
assert _first_contact_deckies(events) == ["decky-01", "decky-02"]
def test_revisit_does_not_duplicate(self):
events = [
_make_log_event("1.1.1.1", "decky-01", timestamp=_DT1),
_make_log_event("1.1.1.1", "decky-02", timestamp=_DT2),
_make_log_event("1.1.1.1", "decky-01", timestamp=_DT3), # revisit
]
result = _first_contact_deckies(events)
assert result == ["decky-01", "decky-02"]
assert result.count("decky-01") == 1
# ─── _extract_commands ────────────────────────────────────────────────────────
class TestExtractCommands:
def _row(self, ip, event_type, fields):
return _make_log_row(
attacker_ip=ip,
event_type=event_type,
service="ssh",
decky="decky-01",
fields=json.dumps(fields),
)
def test_extracts_command_field(self):
rows = [self._row("1.1.1.1", "command", {"command": "id"})]
result = _extract_commands(rows, "1.1.1.1")
assert len(result) == 1
assert result[0]["command"] == "id"
assert result[0]["service"] == "ssh"
assert result[0]["decky"] == "decky-01"
def test_extracts_query_field(self):
rows = [self._row("2.2.2.2", "query", {"query": "SELECT * FROM users"})]
result = _extract_commands(rows, "2.2.2.2")
assert len(result) == 1
assert result[0]["command"] == "SELECT * FROM users"
def test_extracts_input_field(self):
rows = [self._row("3.3.3.3", "input", {"input": "ls -la"})]
result = _extract_commands(rows, "3.3.3.3")
assert len(result) == 1
assert result[0]["command"] == "ls -la"
def test_non_command_event_type_ignored(self):
rows = [self._row("1.1.1.1", "connection", {"command": "id"})]
result = _extract_commands(rows, "1.1.1.1")
assert result == []
def test_wrong_ip_ignored(self):
rows = [self._row("9.9.9.9", "command", {"command": "whoami"})]
result = _extract_commands(rows, "1.1.1.1")
assert result == []
def test_no_command_field_skipped(self):
rows = [self._row("1.1.1.1", "command", {"other": "stuff"})]
result = _extract_commands(rows, "1.1.1.1")
assert result == []
def test_invalid_json_fields_skipped(self):
row = _make_log_row(
attacker_ip="1.1.1.1",
event_type="command",
fields="not valid json",
)
result = _extract_commands([row], "1.1.1.1")
assert result == []
def test_multiple_commands_all_extracted(self):
rows = [
self._row("5.5.5.5", "command", {"command": "id"}),
self._row("5.5.5.5", "command", {"command": "uname -a"}),
]
result = _extract_commands(rows, "5.5.5.5")
assert len(result) == 2
cmds = {r["command"] for r in result}
assert cmds == {"id", "uname -a"}
def test_timestamp_serialized_to_string(self):
rows = [self._row("1.1.1.1", "command", {"command": "pwd"})]
result = _extract_commands(rows, "1.1.1.1")
assert isinstance(result[0]["timestamp"], str)
# ─── _build_record ────────────────────────────────────────────────────────────
class TestBuildRecord:
def _events(self, ip="1.1.1.1"):
return [
_make_log_event(ip, "decky-01", "ssh", "conn", _DT1),
_make_log_event(ip, "decky-01", "http", "req", _DT2),
]
def test_basic_fields(self):
events = self._events()
record = _build_record("1.1.1.1", events, None, [], [])
assert record["ip"] == "1.1.1.1"
assert record["event_count"] == 2
assert record["service_count"] == 2
assert record["decky_count"] == 1
def test_first_last_seen(self):
events = self._events()
record = _build_record("1.1.1.1", events, None, [], [])
assert record["first_seen"] == _DT1
assert record["last_seen"] == _DT2
def test_services_json_sorted(self):
events = self._events()
record = _build_record("1.1.1.1", events, None, [], [])
services = json.loads(record["services"])
assert sorted(services) == services
def test_no_traversal(self):
events = self._events()
record = _build_record("1.1.1.1", events, None, [], [])
assert record["is_traversal"] is False
assert record["traversal_path"] is None
def test_with_traversal(self):
from decnet.correlation.graph import AttackerTraversal, TraversalHop
hops = [
TraversalHop(_DT1, "decky-01", "ssh", "conn"),
TraversalHop(_DT2, "decky-02", "http", "req"),
]
t = AttackerTraversal("1.1.1.1", hops)
events = [
_make_log_event("1.1.1.1", "decky-01", timestamp=_DT1),
_make_log_event("1.1.1.1", "decky-02", timestamp=_DT2),
]
record = _build_record("1.1.1.1", events, t, [], [])
assert record["is_traversal"] is True
assert record["traversal_path"] == "decky-01 → decky-02"
deckies = json.loads(record["deckies"])
assert deckies == ["decky-01", "decky-02"]
def test_bounty_counts(self):
events = self._events()
bounties = [
{"bounty_type": "credential", "attacker_ip": "1.1.1.1"},
{"bounty_type": "credential", "attacker_ip": "1.1.1.1"},
{"bounty_type": "fingerprint", "attacker_ip": "1.1.1.1"},
]
record = _build_record("1.1.1.1", events, None, bounties, [])
assert record["bounty_count"] == 3
assert record["credential_count"] == 2
fps = json.loads(record["fingerprints"])
assert len(fps) == 1
assert fps[0]["bounty_type"] == "fingerprint"
def test_commands_serialized(self):
events = self._events()
cmds = [{"service": "ssh", "decky": "decky-01", "command": "id", "timestamp": "2026-04-04T10:00:00"}]
record = _build_record("1.1.1.1", events, None, [], cmds)
parsed = json.loads(record["commands"])
assert len(parsed) == 1
assert parsed[0]["command"] == "id"
def test_updated_at_is_utc_datetime(self):
events = self._events()
record = _build_record("1.1.1.1", events, None, [], [])
assert isinstance(record["updated_at"], datetime)
assert record["updated_at"].tzinfo is not None
# ─── _rebuild ─────────────────────────────────────────────────────────────────
class TestRebuild:
@pytest.mark.asyncio
async def test_empty_logs_no_upsert(self):
repo = _make_repo(logs=[])
await _rebuild(repo)
repo.upsert_attacker.assert_not_awaited()
@pytest.mark.asyncio
async def test_single_attacker_upserted(self):
raw = _make_raw_line("ssh", "decky-01", "connection", "10.0.0.1", _TS1)
row = _make_log_row(raw_line=raw, attacker_ip="10.0.0.1")
repo = _make_repo(logs=[row])
await _rebuild(repo)
repo.upsert_attacker.assert_awaited_once()
record = repo.upsert_attacker.call_args[0][0]
assert record["ip"] == "10.0.0.1"
assert record["event_count"] == 1
@pytest.mark.asyncio
async def test_multiple_attackers_all_upserted(self):
rows = [
_make_log_row(
raw_line=_make_raw_line("ssh", "decky-01", "conn", ip, _TS1),
attacker_ip=ip,
)
for ip in ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
]
repo = _make_repo(logs=rows)
await _rebuild(repo)
assert repo.upsert_attacker.await_count == 3
upserted_ips = {c[0][0]["ip"] for c in repo.upsert_attacker.call_args_list}
assert upserted_ips == {"1.1.1.1", "2.2.2.2", "3.3.3.3"}
@pytest.mark.asyncio
async def test_traversal_detected_across_two_deckies(self):
rows = [
_make_log_row(
raw_line=_make_raw_line("ssh", "decky-01", "conn", "5.5.5.5", _TS1),
attacker_ip="5.5.5.5", decky="decky-01",
),
_make_log_row(
raw_line=_make_raw_line("http", "decky-02", "req", "5.5.5.5", _TS2),
attacker_ip="5.5.5.5", decky="decky-02",
),
]
repo = _make_repo(logs=rows)
await _rebuild(repo)
record = repo.upsert_attacker.call_args[0][0]
assert record["is_traversal"] is True
assert "decky-01" in record["traversal_path"]
assert "decky-02" in record["traversal_path"]
@pytest.mark.asyncio
async def test_single_decky_not_traversal(self):
rows = [
_make_log_row(
raw_line=_make_raw_line("ssh", "decky-01", "conn", "7.7.7.7", _TS1),
attacker_ip="7.7.7.7",
),
_make_log_row(
raw_line=_make_raw_line("http", "decky-01", "req", "7.7.7.7", _TS2),
attacker_ip="7.7.7.7",
),
]
repo = _make_repo(logs=rows)
await _rebuild(repo)
record = repo.upsert_attacker.call_args[0][0]
assert record["is_traversal"] is False
@pytest.mark.asyncio
async def test_bounties_merged_into_record(self):
raw = _make_raw_line("ssh", "decky-01", "conn", "8.8.8.8", _TS1)
repo = _make_repo(
logs=[_make_log_row(raw_line=raw, attacker_ip="8.8.8.8")],
bounties={"8.8.8.8": [
{"bounty_type": "credential", "attacker_ip": "8.8.8.8", "payload": {}},
{"bounty_type": "fingerprint", "attacker_ip": "8.8.8.8", "payload": {"ja3": "abc"}},
]},
)
await _rebuild(repo)
record = repo.upsert_attacker.call_args[0][0]
assert record["bounty_count"] == 2
assert record["credential_count"] == 1
fps = json.loads(record["fingerprints"])
assert len(fps) == 1
@pytest.mark.asyncio
async def test_commands_extracted_during_rebuild(self):
raw = _make_raw_line("ssh", "decky-01", "command", "9.9.9.9", _TS1)
row = _make_log_row(
raw_line=raw,
attacker_ip="9.9.9.9",
event_type="command",
fields=json.dumps({"command": "cat /etc/passwd"}),
)
repo = _make_repo(logs=[row])
await _rebuild(repo)
record = repo.upsert_attacker.call_args[0][0]
commands = json.loads(record["commands"])
assert len(commands) == 1
assert commands[0]["command"] == "cat /etc/passwd"
# ─── attacker_profile_worker ──────────────────────────────────────────────────
class TestAttackerProfileWorker:
@pytest.mark.asyncio
async def test_worker_cancels_cleanly(self):
repo = _make_repo()
task = asyncio.create_task(attacker_profile_worker(repo))
await asyncio.sleep(0)
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
@pytest.mark.asyncio
async def test_worker_handles_rebuild_error_without_crashing(self):
repo = _make_repo()
_call_count = 0
async def fake_sleep(secs):
nonlocal _call_count
_call_count += 1
if _call_count >= 2:
raise asyncio.CancelledError()
async def bad_rebuild(_repo):
raise RuntimeError("DB exploded")
with patch("decnet.web.attacker_worker.asyncio.sleep", side_effect=fake_sleep):
with patch("decnet.web.attacker_worker._rebuild", side_effect=bad_rebuild):
with pytest.raises(asyncio.CancelledError):
await attacker_profile_worker(repo)
@pytest.mark.asyncio
async def test_worker_calls_rebuild_after_sleep(self):
repo = _make_repo()
_call_count = 0
async def fake_sleep(secs):
nonlocal _call_count
_call_count += 1
if _call_count >= 2:
raise asyncio.CancelledError()
rebuild_calls = []
async def mock_rebuild(_repo):
rebuild_calls.append(True)
with patch("decnet.web.attacker_worker.asyncio.sleep", side_effect=fake_sleep):
with patch("decnet.web.attacker_worker._rebuild", side_effect=mock_rebuild):
with pytest.raises(asyncio.CancelledError):
await attacker_profile_worker(repo)
assert len(rebuild_calls) >= 1
# ─── JA3 bounty extraction from ingester ─────────────────────────────────────
class TestJA3BountyExtraction:
@pytest.mark.asyncio
async def test_ja3_bounty_extracted_from_sniffer_event(self):
from decnet.web.ingester import _extract_bounty
repo = MagicMock()
repo.add_bounty = AsyncMock()
log_data = {
"decky": "decky-01",
"service": "sniffer",
"attacker_ip": "10.0.0.5",
"event_type": "tls_client_hello",
"fields": {
"ja3": "abc123def456abc123def456abc12345",
"ja3s": None,
"tls_version": "TLS 1.3",
"sni": "example.com",
"alpn": "h2",
"dst_port": "443",
"raw_ciphers": "4865-4866",
"raw_extensions": "0-23-65281",
},
}
await _extract_bounty(repo, log_data)
repo.add_bounty.assert_awaited_once()
bounty = repo.add_bounty.call_args[0][0]
assert bounty["bounty_type"] == "fingerprint"
assert bounty["payload"]["fingerprint_type"] == "ja3"
assert bounty["payload"]["ja3"] == "abc123def456abc123def456abc12345"
assert bounty["payload"]["tls_version"] == "TLS 1.3"
assert bounty["payload"]["sni"] == "example.com"
@pytest.mark.asyncio
async def test_non_sniffer_service_with_ja3_field_ignored(self):
from decnet.web.ingester import _extract_bounty
repo = MagicMock()
repo.add_bounty = AsyncMock()
log_data = {
"service": "http",
"attacker_ip": "10.0.0.6",
"event_type": "request",
"fields": {"ja3": "somehash"},
}
await _extract_bounty(repo, log_data)
# Credential/UA checks run, but JA3 should not fire for non-sniffer
calls = [c[0][0]["bounty_type"] for c in repo.add_bounty.call_args_list]
assert "ja3" not in str(calls)
@pytest.mark.asyncio
async def test_sniffer_without_ja3_no_bounty(self):
from decnet.web.ingester import _extract_bounty
repo = MagicMock()
repo.add_bounty = AsyncMock()
log_data = {
"service": "sniffer",
"attacker_ip": "10.0.0.7",
"event_type": "startup",
"fields": {"msg": "started"},
}
await _extract_bounty(repo, log_data)
repo.add_bounty.assert_not_awaited()

View File

@@ -21,6 +21,11 @@ class DummyRepo(BaseRepository):
async def get_total_bounties(self, **kw): await super().get_total_bounties(**kw)
async def get_state(self, k): await super().get_state(k)
async def set_state(self, k, v): await super().set_state(k, v)
async def get_all_logs_raw(self): await super().get_all_logs_raw()
async def get_all_bounties_by_ip(self): await super().get_all_bounties_by_ip()
async def upsert_attacker(self, d): await super().upsert_attacker(d)
async def get_attackers(self, **kw): await super().get_attackers(**kw)
async def get_total_attackers(self, **kw): await super().get_total_attackers(**kw)
@pytest.mark.asyncio
async def test_base_repo_coverage():
@@ -41,3 +46,8 @@ async def test_base_repo_coverage():
await dr.get_total_bounties()
await dr.get_state("k")
await dr.set_state("k", "v")
await dr.get_all_logs_raw()
await dr.get_all_bounties_by_ip()
await dr.upsert_attacker({})
await dr.get_attackers()
await dr.get_total_attackers()

437
tests/test_sniffer_ja3.py Normal file
View File

@@ -0,0 +1,437 @@
"""
Unit tests for the JA3/JA3S parsing logic in templates/sniffer/server.py.
Imports the parser functions directly via sys.path manipulation, with
decnet_logging mocked out (it's a container-side stub at template build time).
"""
from __future__ import annotations
import hashlib
import struct
import sys
import types
from pathlib import Path
from unittest.mock import MagicMock
import pytest
# ─── Import sniffer module with mocked decnet_logging ─────────────────────────
_SNIFFER_DIR = str(Path(__file__).parent.parent / "templates" / "sniffer")
def _load_sniffer():
"""Load templates/sniffer/server.py with decnet_logging stubbed out."""
# Stub the decnet_logging module that server.py imports
_stub = types.ModuleType("decnet_logging")
_stub.SEVERITY_INFO = 6
_stub.SEVERITY_WARNING = 4
_stub.syslog_line = MagicMock(return_value="<134>1 fake")
_stub.write_syslog_file = MagicMock()
sys.modules.setdefault("decnet_logging", _stub)
if _SNIFFER_DIR not in sys.path:
sys.path.insert(0, _SNIFFER_DIR)
import importlib
if "server" in sys.modules:
return sys.modules["server"]
import server as _srv
return _srv
_srv = _load_sniffer()
_parse_client_hello = _srv._parse_client_hello
_parse_server_hello = _srv._parse_server_hello
_ja3 = _srv._ja3
_ja3s = _srv._ja3s
_is_grease = _srv._is_grease
_filter_grease = _srv._filter_grease
_tls_version_str = _srv._tls_version_str
# ─── TLS byte builder helpers ─────────────────────────────────────────────────
def _build_client_hello(
version: int = 0x0303,
cipher_suites: list[int] | None = None,
extensions_bytes: bytes = b"",
) -> bytes:
"""Build a minimal valid TLS ClientHello byte sequence."""
if cipher_suites is None:
cipher_suites = [0x002F, 0x0035] # AES-128-SHA, AES-256-SHA
random_bytes = b"\xAB" * 32
session_id = b"\x00" # no session id
cs_bytes = b"".join(struct.pack("!H", c) for c in cipher_suites)
cs_len = struct.pack("!H", len(cs_bytes))
compression = b"\x01\x00" # 1 method: null
if extensions_bytes:
ext_block = struct.pack("!H", len(extensions_bytes)) + extensions_bytes
else:
ext_block = b"\x00\x00"
body = (
struct.pack("!H", version)
+ random_bytes
+ session_id
+ cs_len
+ cs_bytes
+ compression
+ ext_block
)
hs_header = b"\x01" + struct.pack("!I", len(body))[1:] # type + 3-byte len
record_payload = hs_header + body
record = b"\x16\x03\x01" + struct.pack("!H", len(record_payload)) + record_payload
return record
def _build_extension(ext_type: int, data: bytes) -> bytes:
return struct.pack("!HH", ext_type, len(data)) + data
def _build_sni_extension(hostname: str) -> bytes:
name_bytes = hostname.encode()
# server_name: type(1) + len(2) + name
entry = b"\x00" + struct.pack("!H", len(name_bytes)) + name_bytes
# server_name_list: len(2) + entries
lst = struct.pack("!H", len(entry)) + entry
return _build_extension(0x0000, lst)
def _build_supported_groups_extension(groups: list[int]) -> bytes:
grp_bytes = b"".join(struct.pack("!H", g) for g in groups)
data = struct.pack("!H", len(grp_bytes)) + grp_bytes
return _build_extension(0x000A, data)
def _build_ec_point_formats_extension(formats: list[int]) -> bytes:
pf = bytes(formats)
data = bytes([len(pf)]) + pf
return _build_extension(0x000B, data)
def _build_alpn_extension(protocols: list[str]) -> bytes:
proto_bytes = b""
for p in protocols:
pb = p.encode()
proto_bytes += bytes([len(pb)]) + pb
data = struct.pack("!H", len(proto_bytes)) + proto_bytes
return _build_extension(0x0010, data)
def _build_server_hello(
version: int = 0x0303,
cipher_suite: int = 0x002F,
extensions_bytes: bytes = b"",
) -> bytes:
random_bytes = b"\xCD" * 32
session_id = b"\x00"
compression = b"\x00"
if extensions_bytes:
ext_block = struct.pack("!H", len(extensions_bytes)) + extensions_bytes
else:
ext_block = b"\x00\x00"
body = (
struct.pack("!H", version)
+ random_bytes
+ session_id
+ struct.pack("!H", cipher_suite)
+ compression
+ ext_block
)
hs_header = b"\x02" + struct.pack("!I", len(body))[1:]
record_payload = hs_header + body
return b"\x16\x03\x01" + struct.pack("!H", len(record_payload)) + record_payload
# ─── GREASE tests ─────────────────────────────────────────────────────────────
class TestGrease:
def test_known_grease_values_detected(self):
for v in [0x0A0A, 0x1A1A, 0x2A2A, 0x3A3A, 0x4A4A, 0x5A5A,
0x6A6A, 0x7A7A, 0x8A8A, 0x9A9A, 0xAAAA, 0xBABA,
0xCACA, 0xDADA, 0xEAEA, 0xFAFA]:
assert _is_grease(v), f"0x{v:04x} should be GREASE"
def test_non_grease_values_not_detected(self):
for v in [0x002F, 0x0035, 0x1301, 0x000A, 0xFFFF]:
assert not _is_grease(v), f"0x{v:04x} should not be GREASE"
def test_filter_grease_removes_grease(self):
values = [0x0A0A, 0x002F, 0x1A1A, 0x0035]
result = _filter_grease(values)
assert result == [0x002F, 0x0035]
def test_filter_grease_preserves_all_non_grease(self):
values = [0x002F, 0x0035, 0x1301]
assert _filter_grease(values) == values
# ─── ClientHello parsing tests ────────────────────────────────────────────────
class TestParseClientHello:
def test_minimal_client_hello_parsed(self):
data = _build_client_hello()
result = _parse_client_hello(data)
assert result is not None
assert result["tls_version"] == 0x0303
assert result["cipher_suites"] == [0x002F, 0x0035]
assert result["extensions"] == []
assert result["supported_groups"] == []
assert result["ec_point_formats"] == []
assert result["sni"] == ""
assert result["alpn"] == []
def test_wrong_record_type_returns_none(self):
data = _build_client_hello()
bad = b"\x14" + data[1:] # change record type to ChangeCipherSpec
assert _parse_client_hello(bad) is None
def test_wrong_handshake_type_returns_none(self):
data = _build_client_hello()
# Byte at offset 5 is the handshake type
bad = data[:5] + b"\x02" + data[6:] # ServerHello type
assert _parse_client_hello(bad) is None
def test_too_short_returns_none(self):
assert _parse_client_hello(b"\x16\x03\x01") is None
assert _parse_client_hello(b"") is None
def test_non_tls_returns_none(self):
assert _parse_client_hello(b"GET / HTTP/1.1\r\n") is None
def test_grease_cipher_suites_filtered(self):
data = _build_client_hello(cipher_suites=[0x0A0A, 0x002F, 0x1A1A, 0x0035])
result = _parse_client_hello(data)
assert result is not None
assert 0x0A0A not in result["cipher_suites"]
assert 0x1A1A not in result["cipher_suites"]
assert result["cipher_suites"] == [0x002F, 0x0035]
def test_sni_extension_extracted(self):
ext = _build_sni_extension("example.com")
data = _build_client_hello(extensions_bytes=ext)
result = _parse_client_hello(data)
assert result is not None
assert result["sni"] == "example.com"
def test_supported_groups_extracted(self):
ext = _build_supported_groups_extension([0x001D, 0x0017, 0x0018])
data = _build_client_hello(extensions_bytes=ext)
result = _parse_client_hello(data)
assert result is not None
assert result["supported_groups"] == [0x001D, 0x0017, 0x0018]
def test_grease_in_supported_groups_filtered(self):
ext = _build_supported_groups_extension([0x0A0A, 0x001D])
data = _build_client_hello(extensions_bytes=ext)
result = _parse_client_hello(data)
assert result is not None
assert 0x0A0A not in result["supported_groups"]
assert 0x001D in result["supported_groups"]
def test_ec_point_formats_extracted(self):
ext = _build_ec_point_formats_extension([0x00, 0x01])
data = _build_client_hello(extensions_bytes=ext)
result = _parse_client_hello(data)
assert result is not None
assert result["ec_point_formats"] == [0x00, 0x01]
def test_alpn_extension_extracted(self):
ext = _build_alpn_extension(["h2", "http/1.1"])
data = _build_client_hello(extensions_bytes=ext)
result = _parse_client_hello(data)
assert result is not None
assert result["alpn"] == ["h2", "http/1.1"]
def test_multiple_extensions_extracted(self):
sni = _build_sni_extension("target.local")
grps = _build_supported_groups_extension([0x001D])
combined = sni + grps
data = _build_client_hello(extensions_bytes=combined)
result = _parse_client_hello(data)
assert result is not None
assert result["sni"] == "target.local"
assert 0x001D in result["supported_groups"]
# Extension type IDs recorded (SNI=0, supported_groups=10)
assert 0x0000 in result["extensions"]
assert 0x000A in result["extensions"]
# ─── ServerHello parsing tests ────────────────────────────────────────────────
class TestParseServerHello:
def test_minimal_server_hello_parsed(self):
data = _build_server_hello()
result = _parse_server_hello(data)
assert result is not None
assert result["tls_version"] == 0x0303
assert result["cipher_suite"] == 0x002F
assert result["extensions"] == []
def test_wrong_record_type_returns_none(self):
data = _build_server_hello()
bad = b"\x15" + data[1:]
assert _parse_server_hello(bad) is None
def test_wrong_handshake_type_returns_none(self):
data = _build_server_hello()
bad = data[:5] + b"\x01" + data[6:] # ClientHello type
assert _parse_server_hello(bad) is None
def test_too_short_returns_none(self):
assert _parse_server_hello(b"") is None
def test_server_hello_extension_types_recorded(self):
# Build a ServerHello with a generic extension (type=0xFF01)
ext_data = _build_extension(0xFF01, b"\x00")
data = _build_server_hello(extensions_bytes=ext_data)
result = _parse_server_hello(data)
assert result is not None
assert 0xFF01 in result["extensions"]
def test_grease_extension_in_server_hello_filtered(self):
ext_data = _build_extension(0x0A0A, b"\x00")
data = _build_server_hello(extensions_bytes=ext_data)
result = _parse_server_hello(data)
assert result is not None
assert 0x0A0A not in result["extensions"]
# ─── JA3 hash tests ───────────────────────────────────────────────────────────
class TestJA3:
def test_ja3_returns_32_char_hex(self):
data = _build_client_hello()
ch = _parse_client_hello(data)
_, ja3_hash = _ja3(ch)
assert len(ja3_hash) == 32
assert all(c in "0123456789abcdef" for c in ja3_hash)
def test_ja3_known_hash(self):
# Minimal ClientHello: TLS 1.2, ciphers [47, 53], no extensions
ch = {
"tls_version": 0x0303, # 771
"cipher_suites": [0x002F, 0x0035], # 47, 53
"extensions": [],
"supported_groups": [],
"ec_point_formats": [],
"sni": "",
"alpn": [],
}
ja3_str, ja3_hash = _ja3(ch)
assert ja3_str == "771,47-53,,,"
expected = hashlib.md5(b"771,47-53,,,").hexdigest()
assert ja3_hash == expected
def test_ja3_same_input_same_hash(self):
data = _build_client_hello()
ch = _parse_client_hello(data)
_, h1 = _ja3(ch)
_, h2 = _ja3(ch)
assert h1 == h2
def test_ja3_different_ciphers_different_hash(self):
ch1 = {"tls_version": 0x0303, "cipher_suites": [47], "extensions": [],
"supported_groups": [], "ec_point_formats": [], "sni": "", "alpn": []}
ch2 = {"tls_version": 0x0303, "cipher_suites": [53], "extensions": [],
"supported_groups": [], "ec_point_formats": [], "sni": "", "alpn": []}
_, h1 = _ja3(ch1)
_, h2 = _ja3(ch2)
assert h1 != h2
def test_ja3_empty_lists_produce_valid_string(self):
ch = {"tls_version": 0x0303, "cipher_suites": [], "extensions": [],
"supported_groups": [], "ec_point_formats": [], "sni": "", "alpn": []}
ja3_str, ja3_hash = _ja3(ch)
assert ja3_str == "771,,,,"
assert len(ja3_hash) == 32
# ─── JA3S hash tests ──────────────────────────────────────────────────────────
class TestJA3S:
def test_ja3s_returns_32_char_hex(self):
data = _build_server_hello()
sh = _parse_server_hello(data)
_, ja3s_hash = _ja3s(sh)
assert len(ja3s_hash) == 32
assert all(c in "0123456789abcdef" for c in ja3s_hash)
def test_ja3s_known_hash(self):
sh = {"tls_version": 0x0303, "cipher_suite": 0x002F, "extensions": []}
ja3s_str, ja3s_hash = _ja3s(sh)
assert ja3s_str == "771,47,"
expected = hashlib.md5(b"771,47,").hexdigest()
assert ja3s_hash == expected
def test_ja3s_different_cipher_different_hash(self):
sh1 = {"tls_version": 0x0303, "cipher_suite": 0x002F, "extensions": []}
sh2 = {"tls_version": 0x0303, "cipher_suite": 0x0035, "extensions": []}
_, h1 = _ja3s(sh1)
_, h2 = _ja3s(sh2)
assert h1 != h2
# ─── TLS version string tests ─────────────────────────────────────────────────
class TestTLSVersionStr:
def test_tls12(self):
assert _tls_version_str(0x0303) == "TLS 1.2"
def test_tls13(self):
assert _tls_version_str(0x0304) == "TLS 1.3"
def test_tls11(self):
assert _tls_version_str(0x0302) == "TLS 1.1"
def test_tls10(self):
assert _tls_version_str(0x0301) == "TLS 1.0"
def test_unknown_version(self):
result = _tls_version_str(0xABCD)
assert "0xabcd" in result.lower()
# ─── Full round-trip: parse bytes → JA3/JA3S ──────────────────────────────────
class TestRoundTrip:
def test_client_hello_bytes_to_ja3(self):
ciphers = [0x1301, 0x1302, 0x002F]
sni_ext = _build_sni_extension("attacker.c2.com")
data = _build_client_hello(cipher_suites=ciphers, extensions_bytes=sni_ext)
ch = _parse_client_hello(data)
assert ch is not None
ja3_str, ja3_hash = _ja3(ch)
assert "4865-4866-47" in ja3_str # ciphers: 0x1301=4865, 0x1302=4866, 0x002F=47
assert len(ja3_hash) == 32
assert ch["sni"] == "attacker.c2.com"
def test_server_hello_bytes_to_ja3s(self):
data = _build_server_hello(cipher_suite=0x1301)
sh = _parse_server_hello(data)
assert sh is not None
ja3s_str, ja3s_hash = _ja3s(sh)
assert "4865" in ja3s_str # 0x1301 = 4865
assert len(ja3s_hash) == 32
def test_grease_client_hello_filtered_before_hash(self):
"""GREASE ciphers must be stripped before JA3 is computed."""
ciphers_with_grease = [0x0A0A, 0x002F, 0xFAFA, 0x0035]
data = _build_client_hello(cipher_suites=ciphers_with_grease)
ch = _parse_client_hello(data)
_, ja3_hash = _ja3(ch)
# Reference: build without GREASE
ciphers_clean = [0x002F, 0x0035]
data_clean = _build_client_hello(cipher_suites=ciphers_clean)
ch_clean = _parse_client_hello(data_clean)
_, ja3_hash_clean = _ja3(ch_clean)
assert ja3_hash == ja3_hash_clean

View File

@@ -128,8 +128,9 @@ class TestLifespan:
with patch("decnet.web.api.repo", mock_repo):
with patch("decnet.web.api.log_ingestion_worker", return_value=asyncio.sleep(0)):
with patch("decnet.web.api.log_collector_worker", return_value=asyncio.sleep(0)):
async with lifespan(mock_app):
mock_repo.initialize.assert_awaited_once()
with patch("decnet.web.api.attacker_profile_worker", return_value=asyncio.sleep(0)):
async with lifespan(mock_app):
mock_repo.initialize.assert_awaited_once()
@pytest.mark.asyncio
async def test_lifespan_db_retry(self):
@@ -150,5 +151,6 @@ class TestLifespan:
with patch("decnet.web.api.asyncio.sleep", new_callable=AsyncMock):
with patch("decnet.web.api.log_ingestion_worker", return_value=asyncio.sleep(0)):
with patch("decnet.web.api.log_collector_worker", return_value=asyncio.sleep(0)):
async with lifespan(mock_app):
assert _call_count == 3
with patch("decnet.web.api.attacker_profile_worker", return_value=asyncio.sleep(0)):
async with lifespan(mock_app):
assert _call_count == 3