diff --git a/decnet/services/sniffer.py b/decnet/services/sniffer.py new file mode 100644 index 0000000..6bf9c44 --- /dev/null +++ b/decnet/services/sniffer.py @@ -0,0 +1,40 @@ +from pathlib import Path +from decnet.services.base import BaseService + +TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates" / "sniffer" + + +class SnifferService(BaseService): + """ + Passive network sniffer deployed alongside deckies on the MACVLAN. + + Captures TLS handshakes in promiscuous mode and extracts JA3/JA3S hashes + plus connection metadata. Requires NET_RAW + NET_ADMIN capabilities. + No inbound ports — purely passive. + """ + + name = "sniffer" + ports: list[int] = [] + default_image = "build" + + def compose_fragment( + self, + decky_name: str, + log_target: str | None = None, + service_cfg: dict | None = None, + ) -> dict: + fragment: dict = { + "build": {"context": str(TEMPLATES_DIR)}, + "container_name": f"{decky_name}-sniffer", + "restart": "unless-stopped", + "cap_add": ["NET_RAW", "NET_ADMIN"], + "environment": { + "NODE_NAME": decky_name, + }, + } + if log_target: + fragment["environment"]["LOG_TARGET"] = log_target + return fragment + + def dockerfile_context(self) -> Path | None: + return TEMPLATES_DIR diff --git a/decnet/web/api.py b/decnet/web/api.py index 4eabe79..f1cfbb7 100644 --- a/decnet/web/api.py +++ b/decnet/web/api.py @@ -14,16 +14,18 @@ from decnet.logging import get_logger from decnet.web.dependencies import repo from decnet.collector import log_collector_worker from decnet.web.ingester import log_ingestion_worker +from decnet.web.attacker_worker import attacker_profile_worker from decnet.web.router import api_router log = get_logger("api") ingestion_task: Optional[asyncio.Task[Any]] = None collector_task: Optional[asyncio.Task[Any]] = None +attacker_task: Optional[asyncio.Task[Any]] = None @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: - global ingestion_task, collector_task + global ingestion_task, 
collector_task, attacker_task log.info("API startup initialising database") for attempt in range(1, 6): @@ -51,13 +53,18 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: log.debug("API startup collector worker started log_file=%s", _log_file) elif not _log_file: log.warning("DECNET_INGEST_LOG_FILE not set — Docker log collection disabled.") + + # Start attacker profile rebuild worker + if attacker_task is None or attacker_task.done(): + attacker_task = asyncio.create_task(attacker_profile_worker(repo)) + log.debug("API startup attacker profile worker started") else: log.info("Contract Test Mode: skipping background worker startup") yield log.info("API shutdown cancelling background tasks") - for task in (ingestion_task, collector_task): + for task in (ingestion_task, collector_task, attacker_task): if task and not task.done(): task.cancel() try: diff --git a/decnet/web/attacker_worker.py b/decnet/web/attacker_worker.py new file mode 100644 index 0000000..7d207fa --- /dev/null +++ b/decnet/web/attacker_worker.py @@ -0,0 +1,176 @@ +""" +Attacker profile builder — background worker. + +Periodically rebuilds the `attackers` table by: + 1. Feeding all stored Log.raw_line values through the CorrelationEngine + (which parses RFC 5424 and tracks per-IP event histories + traversals). + 2. Merging with the Bounty table (fingerprints, credentials). + 3. Extracting commands executed per IP from the structured log fields. + 4. Upserting one Attacker record per observed IP. + +Runs every _REBUILD_INTERVAL seconds. Full rebuild each cycle — simple and +correct at honeypot log volumes. 
+""" + +from __future__ import annotations + +import asyncio +import json +from datetime import datetime, timezone +from typing import Any + +from decnet.correlation.engine import CorrelationEngine +from decnet.correlation.parser import LogEvent +from decnet.logging import get_logger +from decnet.web.db.repository import BaseRepository + +logger = get_logger("attacker_worker") + +_REBUILD_INTERVAL = 30 # seconds + +# Event types that indicate active command/query execution (not just connection/scan) +_COMMAND_EVENT_TYPES = frozenset({ + "command", "exec", "query", "input", "shell_input", + "execute", "run", "sql_query", "redis_command", +}) + +# Fields that carry the executed command/query text +_COMMAND_FIELDS = ("command", "query", "input", "line", "sql", "cmd") + + +async def attacker_profile_worker(repo: BaseRepository) -> None: + """Periodically rebuilds the Attacker table. Designed to run as an asyncio Task.""" + logger.info("attacker profile worker started interval=%ds", _REBUILD_INTERVAL) + while True: + await asyncio.sleep(_REBUILD_INTERVAL) + try: + await _rebuild(repo) + except Exception as exc: + logger.error("attacker worker: rebuild failed: %s", exc) + + +async def _rebuild(repo: BaseRepository) -> None: + all_logs = await repo.get_all_logs_raw() + if not all_logs: + return + + # Feed raw RFC 5424 lines into the CorrelationEngine + engine = CorrelationEngine() + for row in all_logs: + engine.ingest(row["raw_line"]) + + if not engine._events: + return + + traversal_map = {t.attacker_ip: t for t in engine.traversals(min_deckies=2)} + all_bounties = await repo.get_all_bounties_by_ip() + + count = 0 + for ip, events in engine._events.items(): + traversal = traversal_map.get(ip) + bounties = all_bounties.get(ip, []) + commands = _extract_commands(all_logs, ip) + + record = _build_record(ip, events, traversal, bounties, commands) + await repo.upsert_attacker(record) + count += 1 + + logger.debug("attacker worker: rebuilt %d profiles", count) + + +def 
_build_record( + ip: str, + events: list[LogEvent], + traversal: Any, + bounties: list[dict[str, Any]], + commands: list[dict[str, Any]], +) -> dict[str, Any]: + services = sorted({e.service for e in events}) + deckies = ( + traversal.deckies + if traversal + else _first_contact_deckies(events) + ) + fingerprints = [b for b in bounties if b.get("bounty_type") == "fingerprint"] + credential_count = sum(1 for b in bounties if b.get("bounty_type") == "credential") + + return { + "ip": ip, + "first_seen": min(e.timestamp for e in events), + "last_seen": max(e.timestamp for e in events), + "event_count": len(events), + "service_count": len(services), + "decky_count": len({e.decky for e in events}), + "services": json.dumps(services), + "deckies": json.dumps(deckies), + "traversal_path": traversal.path if traversal else None, + "is_traversal": traversal is not None, + "bounty_count": len(bounties), + "credential_count": credential_count, + "fingerprints": json.dumps(fingerprints), + "commands": json.dumps(commands), + "updated_at": datetime.now(timezone.utc), + } + + +def _first_contact_deckies(events: list[LogEvent]) -> list[str]: + """Return unique deckies in first-contact order (for non-traversal attackers).""" + seen: list[str] = [] + for e in sorted(events, key=lambda x: x.timestamp): + if e.decky not in seen: + seen.append(e.decky) + return seen + + +def _extract_commands( + all_logs: list[dict[str, Any]], ip: str +) -> list[dict[str, Any]]: + """ + Extract executed commands for a given attacker IP from raw log rows. + + Looks for rows where: + - attacker_ip matches + - event_type is a known command-execution type + - fields JSON contains a command-like key + + Returns a list of {service, decky, command, timestamp} dicts. 
+ """ + commands: list[dict[str, Any]] = [] + for row in all_logs: + if row.get("attacker_ip") != ip: + continue + if row.get("event_type") not in _COMMAND_EVENT_TYPES: + continue + + raw_fields = row.get("fields") + if not raw_fields: + continue + + # fields is stored as a JSON string in the DB row + if isinstance(raw_fields, str): + try: + fields = json.loads(raw_fields) + except (json.JSONDecodeError, ValueError): + continue + else: + fields = raw_fields + + cmd_text: str | None = None + for key in _COMMAND_FIELDS: + val = fields.get(key) + if val: + cmd_text = str(val) + break + + if not cmd_text: + continue + + ts = row.get("timestamp") + commands.append({ + "service": row.get("service", ""), + "decky": row.get("decky", ""), + "command": cmd_text, + "timestamp": ts.isoformat() if isinstance(ts, datetime) else str(ts), + }) + + return commands diff --git a/decnet/web/db/models.py b/decnet/web/db/models.py index 681db23..a8e18d1 100644 --- a/decnet/web/db/models.py +++ b/decnet/web/db/models.py @@ -50,6 +50,27 @@ class State(SQLModel, table=True): key: str = Field(primary_key=True) value: str # Stores JSON serialized DecnetConfig or other state blobs + +class Attacker(SQLModel, table=True): + __tablename__ = "attackers" + ip: str = Field(primary_key=True) + first_seen: datetime = Field(index=True) + last_seen: datetime = Field(index=True) + event_count: int = Field(default=0) + service_count: int = Field(default=0) + decky_count: int = Field(default=0) + services: str = Field(default="[]") # JSON list[str] + deckies: str = Field(default="[]") # JSON list[str], first-contact ordered + traversal_path: Optional[str] = None # "decky-01 → decky-03 → decky-05" + is_traversal: bool = Field(default=False) + bounty_count: int = Field(default=0) + credential_count: int = Field(default=0) + fingerprints: str = Field(default="[]") # JSON list[dict] — bounty fingerprints + commands: str = Field(default="[]") # JSON list[dict] — commands per service/decky + updated_at: 
datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), index=True + ) + # --- API Request/Response Models (Pydantic) --- class Token(BaseModel): @@ -77,6 +98,12 @@ class BountyResponse(BaseModel): offset: int data: List[dict[str, Any]] +class AttackersResponse(BaseModel): + total: int + limit: int + offset: int + data: List[dict[str, Any]] + class StatsResponse(BaseModel): total_logs: int unique_attackers: int diff --git a/decnet/web/db/repository.py b/decnet/web/db/repository.py index 08a6259..7fcfdaa 100644 --- a/decnet/web/db/repository.py +++ b/decnet/web/db/repository.py @@ -90,3 +90,34 @@ class BaseRepository(ABC): async def set_state(self, key: str, value: Any) -> None: """Store a specific state entry by key.""" pass + + @abstractmethod + async def get_all_logs_raw(self) -> list[dict[str, Any]]: + """Retrieve all log rows with fields needed by the attacker profile worker.""" + pass + + @abstractmethod + async def get_all_bounties_by_ip(self) -> dict[str, list[dict[str, Any]]]: + """Retrieve all bounty rows grouped by attacker_ip.""" + pass + + @abstractmethod + async def upsert_attacker(self, data: dict[str, Any]) -> None: + """Insert or replace an attacker profile record.""" + pass + + @abstractmethod + async def get_attackers( + self, + limit: int = 50, + offset: int = 0, + search: Optional[str] = None, + sort_by: str = "recent", + ) -> list[dict[str, Any]]: + """Retrieve paginated attacker profile records.""" + pass + + @abstractmethod + async def get_total_attackers(self, search: Optional[str] = None) -> int: + """Retrieve the total count of attacker profile records, optionally filtered.""" + pass diff --git a/decnet/web/db/sqlite/repository.py b/decnet/web/db/sqlite/repository.py index b4768eb..49606cf 100644 --- a/decnet/web/db/sqlite/repository.py +++ b/decnet/web/db/sqlite/repository.py @@ -12,7 +12,7 @@ from decnet.config import load_state, _ROOT from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD from decnet.web.auth 
import get_password_hash from decnet.web.db.repository import BaseRepository -from decnet.web.db.models import User, Log, Bounty, State +from decnet.web.db.models import User, Log, Bounty, State, Attacker from decnet.web.db.sqlite.database import get_async_engine @@ -371,3 +371,92 @@ class SQLiteRepository(BaseRepository): session.add(new_state) await session.commit() + + # --------------------------------------------------------------- attackers + + async def get_all_logs_raw(self) -> List[dict[str, Any]]: + async with self.session_factory() as session: + result = await session.execute( + select( + Log.id, + Log.raw_line, + Log.attacker_ip, + Log.service, + Log.event_type, + Log.decky, + Log.timestamp, + Log.fields, + ) + ) + return [ + { + "id": r.id, + "raw_line": r.raw_line, + "attacker_ip": r.attacker_ip, + "service": r.service, + "event_type": r.event_type, + "decky": r.decky, + "timestamp": r.timestamp, + "fields": r.fields, + } + for r in result.all() + ] + + async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]: + from collections import defaultdict + async with self.session_factory() as session: + result = await session.execute( + select(Bounty).order_by(asc(Bounty.timestamp)) + ) + grouped: dict[str, List[dict[str, Any]]] = defaultdict(list) + for item in result.scalars().all(): + d = item.model_dump(mode="json") + try: + d["payload"] = json.loads(d["payload"]) + except (json.JSONDecodeError, TypeError): + pass + grouped[item.attacker_ip].append(d) + return dict(grouped) + + async def upsert_attacker(self, data: dict[str, Any]) -> None: + async with self.session_factory() as session: + result = await session.execute( + select(Attacker).where(Attacker.ip == data["ip"]) + ) + existing = result.scalar_one_or_none() + if existing: + for k, v in data.items(): + setattr(existing, k, v) + session.add(existing) + else: + session.add(Attacker(**data)) + await session.commit() + + async def get_attackers( + self, + limit: int = 50, + offset: 
int = 0, + search: Optional[str] = None, + sort_by: str = "recent", + ) -> List[dict[str, Any]]: + order = { + "active": desc(Attacker.event_count), + "traversals": desc(Attacker.is_traversal), + }.get(sort_by, desc(Attacker.last_seen)) + + statement = select(Attacker).order_by(order).offset(offset).limit(limit) + if search: + statement = statement.where(Attacker.ip.like(f"%{search}%")) + + async with self.session_factory() as session: + result = await session.execute(statement) + return [a.model_dump(mode="json") for a in result.scalars().all()] + + async def get_total_attackers(self, search: Optional[str] = None) -> int: + statement = select(func.count()).select_from(Attacker) + if search: + statement = statement.where(Attacker.ip.like(f"%{search}%")) + + async with self.session_factory() as session: + result = await session.execute(statement) + return result.scalar() or 0 diff --git a/decnet/web/ingester.py b/decnet/web/ingester.py index 675b418..9427b90 100644 --- a/decnet/web/ingester.py +++ b/decnet/web/ingester.py @@ -130,3 +130,24 @@ async def _extract_bounty(repo: BaseRepository, log_data: dict[str, Any]) -> Non # 4. SSH client banner fingerprint (deferred — requires asyncssh server) # Fires on: service=ssh, event_type=client_banner, fields.client_banner + + # 5. 
JA3/JA3S TLS fingerprint from sniffer container + _ja3 = _fields.get("ja3") + if _ja3 and log_data.get("service") == "sniffer": + await repo.add_bounty({ + "decky": log_data.get("decky"), + "service": "sniffer", + "attacker_ip": log_data.get("attacker_ip"), + "bounty_type": "fingerprint", + "payload": { + "fingerprint_type": "ja3", + "ja3": _ja3, + "ja3s": _fields.get("ja3s"), + "tls_version": _fields.get("tls_version"), + "sni": _fields.get("sni") or None, + "alpn": _fields.get("alpn") or None, + "dst_port": _fields.get("dst_port"), + "raw_ciphers": _fields.get("raw_ciphers"), + "raw_extensions": _fields.get("raw_extensions"), + }, + }) diff --git a/development/DEVELOPMENT.md b/development/DEVELOPMENT.md index 681068f..76739b5 100644 --- a/development/DEVELOPMENT.md +++ b/development/DEVELOPMENT.md @@ -45,7 +45,7 @@ ## Core / Hardening -- [x] **Attacker fingerprinting** — HTTP User-Agent and VNC client version stored as `fingerprint` bounties. TLS JA3/JA4 and TCP window sizes require pcap (out of scope). SSH client banner deferred pending asyncssh server. +- [~] **Attacker fingerprinting** — HTTP User-Agent, VNC client version stored as `fingerprint` bounties. JA3/JA3S in progress (sniffer container). HASSH, JA4+, TCP stack, JARM planned (see Attacker Intelligence section). - [ ] **Canary tokens** — Embed fake AWS keys and honeydocs into decky filesystems. - [ ] **Tarpit mode** — Slow down attackers by drip-feeding bytes or delaying responses. - [x] **Dynamic decky mutation** — Rotate exposed services or OS fingerprints over time. @@ -84,6 +84,55 @@ - [ ] **Realistic web apps** — Fake WordPress, Grafana, and phpMyAdmin templates. - [ ] **OT/ICS profiles** — Expanded Modbus, DNP3, and BACnet support. 
+## Attacker Intelligence Collection +*Goal: Build the richest possible attacker profile from passive observation across all 26 services.* + +### TLS/SSL Fingerprinting (via sniffer container) +- [x] **JA3/JA3S** — TLS ClientHello/ServerHello fingerprint hashes +- [ ] **JA4+ family** — JA4, JA4S, JA4H, JA4L (latency/geo estimation via RTT) +- [ ] **JARM** — Active server fingerprint; identifies C2 framework from TLS server behavior +- [ ] **CYU** — GQUIC (Google QUIC) ClientHello fingerprint hash +- [ ] **TLS session resumption behavior** — Identifies tooling by how it handles session tickets +- [ ] **Certificate details** — CN, SANs, issuer, validity period, self-signed flag (attacker-run servers) + +### Timing & Behavioral +- [ ] **Inter-packet arrival times** — OS TCP stack fingerprint + beaconing interval detection +- [ ] **TTL values** — Rough OS / hop-distance inference +- [ ] **TCP window size & scaling** — p0f-style OS fingerprinting +- [ ] **Retransmission patterns** — Identify lossy paths / throttled connections +- [ ] **Beacon jitter variance** — Attribute tooling: Cobalt Strike vs. Sliver vs. Havoc have distinct profiles +- [ ] **C2 check-in cadence** — Detect beaconing vs. interactive sessions +- [ ] **Data exfil timing** — Behavioral sequencing relative to recon phase + +### Protocol Fingerprinting +- [ ] **TCP/IP stack** — ISN patterns, DF bit, ToS/DSCP, IP ID sequence (random/incremental/zero) +- [ ] **HASSH / HASSHServer** — SSH KEX algo, cipher, MAC order → tool fingerprint +- [ ] **HTTP/2 fingerprint** — GREASE values, settings frame order, header pseudo-field ordering +- [ ] **QUIC fingerprint** — Connection ID length, transport parameters order +- [ ] **DNS behavior** — Query patterns, recursion flags, EDNS0 options, resolver fingerprint +- [ ] **HTTP header ordering** — Tool-specific capitalization and ordering quirks + +### Network Topology Leakage +- [ ] **X-Forwarded-For mismatches** — Detect VPN/proxy slip vs. 
actual source IP +- [ ] **ICMP error messages** — Internal IP leakage from misconfigured attacker infra +- [ ] **IPv6 link-local leakage** — IPv6 addrs leaked even over IPv4 VPN (common opsec fail) +- [ ] **mDNS/LLMNR leakage** — Attacker hostname/device info from misconfigured systems + +### Geolocation & Infrastructure +- [ ] **ASN lookup** — Source IP autonomous system number and org name +- [ ] **BGP prefix / RPKI validity** — Route origin legitimacy +- [ ] **PTR records** — rDNS for attacker IPs (catches infra with forgotten reverse DNS) +- [ ] **Latency triangulation** — JA4L RTT estimates for rough geolocation + +### Service-Level Behavioral Profiling +- [ ] **Commands executed** — Full command log per session (SSH, Telnet, FTP, Redis, DB services) +- [ ] **Services actively interacted with** — Distinguish port scans from live exploitation attempts +- [ ] **Tooling attribution** — Byte-sequence signatures from known C2 frameworks in handshakes +- [ ] **Credential reuse patterns** — Same username/password tried across multiple deckies/services +- [ ] **Payload signatures** — Hash and classify uploaded files, shellcode, exploit payloads + +--- + ## Developer Experience - [x] **API Fuzzing** — Property-based testing for all web endpoints. 
diff --git a/pyproject.toml b/pyproject.toml index fb47df0..ac445d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dev = [ "psycopg2-binary>=2.9.11", "paho-mqtt>=2.1.0", "pymongo>=4.16.0", + "scapy>=2.6.1", ] [project.scripts] diff --git a/templates/sniffer/Dockerfile b/templates/sniffer/Dockerfile new file mode 100644 index 0000000..c6a9702 --- /dev/null +++ b/templates/sniffer/Dockerfile @@ -0,0 +1,12 @@ +ARG BASE_IMAGE=debian:bookworm-slim +FROM ${BASE_IMAGE} + +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3 python3-pip libpcap-dev \ + && rm -rf /var/lib/apt/lists/* + +RUN pip3 install --no-cache-dir --break-system-packages "scapy==2.6.1" + +COPY decnet_logging.py server.py /opt/ + +ENTRYPOINT ["python3", "/opt/server.py"] diff --git a/templates/sniffer/decnet_logging.py b/templates/sniffer/decnet_logging.py new file mode 100644 index 0000000..5a64442 --- /dev/null +++ b/templates/sniffer/decnet_logging.py @@ -0,0 +1 @@ +# Placeholder — replaced by the deployer with the shared base template before docker build. diff --git a/templates/sniffer/server.py b/templates/sniffer/server.py new file mode 100644 index 0000000..53c3b79 --- /dev/null +++ b/templates/sniffer/server.py @@ -0,0 +1,392 @@ +#!/usr/bin/env python3 +""" +DECNET passive TLS sniffer. + +Captures TLS handshakes on the MACVLAN interface (shared network namespace +with the decky base container). Extracts JA3/JA3S fingerprints and connection +metadata, then emits structured RFC 5424 log lines to stdout for the +host-side collector to ingest. + +Requires: NET_RAW + NET_ADMIN capabilities (set in compose fragment). + +JA3 — MD5(SSLVersion,Ciphers,Extensions,EllipticCurves,ECPointFormats) +JA3S — MD5(SSLVersion,Cipher,Extensions) + +GREASE values (RFC 8701) are excluded from all lists before hashing. 
+""" + +from __future__ import annotations + +import hashlib +import os +import struct +import time +from typing import Any + +from scapy.layers.inet import IP, TCP +from scapy.sendrecv import sniff + +from decnet_logging import SEVERITY_INFO, SEVERITY_WARNING, syslog_line, write_syslog_file + +# ─── Configuration ──────────────────────────────────────────────────────────── + +NODE_NAME: str = os.environ.get("NODE_NAME", "decky-sniffer") +SERVICE_NAME: str = "sniffer" + +# Session TTL in seconds — drop half-open sessions after this +_SESSION_TTL: float = 60.0 + +# GREASE values per RFC 8701 — 0x0A0A, 0x1A1A, 0x2A2A, ..., 0xFAFA +_GREASE: frozenset[int] = frozenset(0x0A0A + i * 0x1010 for i in range(16)) + +# TLS record / handshake type constants +_TLS_RECORD_HANDSHAKE: int = 0x16 +_TLS_HT_CLIENT_HELLO: int = 0x01 +_TLS_HT_SERVER_HELLO: int = 0x02 + +# TLS extension types we extract for metadata +_EXT_SNI: int = 0x0000 +_EXT_SUPPORTED_GROUPS: int = 0x000A +_EXT_EC_POINT_FORMATS: int = 0x000B +_EXT_ALPN: int = 0x0010 +_EXT_SESSION_TICKET: int = 0x0023 + +# ─── Session tracking ───────────────────────────────────────────────────────── + +# Key: (src_ip, src_port, dst_ip, dst_port) — forward 4-tuple from ClientHello +# Value: parsed ClientHello metadata dict +_sessions: dict[tuple[str, int, str, int], dict[str, Any]] = {} +_session_ts: dict[tuple[str, int, str, int], float] = {} + + +# ─── GREASE helpers ─────────────────────────────────────────────────────────── + +def _is_grease(value: int) -> bool: + return value in _GREASE + + +def _filter_grease(values: list[int]) -> list[int]: + return [v for v in values if not _is_grease(v)] + + +# ─── Pure-Python TLS record parser ──────────────────────────────────────────── + +def _parse_client_hello(data: bytes) -> dict[str, Any] | None: + """ + Parse a TLS ClientHello from raw bytes (starting at TLS record header). + Returns a dict of parsed fields, or None if not a valid ClientHello. 
+ """ + try: + if len(data) < 6: + return None + # TLS record header: content_type(1) version(2) length(2) + if data[0] != _TLS_RECORD_HANDSHAKE: + return None + record_len = struct.unpack_from("!H", data, 3)[0] + if len(data) < 5 + record_len: + return None + + # Handshake header: type(1) length(3) + hs = data[5:] + if hs[0] != _TLS_HT_CLIENT_HELLO: + return None + + hs_len = struct.unpack_from("!I", b"\x00" + hs[1:4])[0] + body = hs[4: 4 + hs_len] + if len(body) < 34: + return None + + pos = 0 + # ClientHello version (2 bytes) — used for JA3 + tls_version = struct.unpack_from("!H", body, pos)[0] + pos += 2 + + # Random (32 bytes) + pos += 32 + + # Session ID + session_id_len = body[pos] + pos += 1 + session_id_len + + # Cipher Suites + cs_len = struct.unpack_from("!H", body, pos)[0] + pos += 2 + cipher_suites = [ + struct.unpack_from("!H", body, pos + i * 2)[0] + for i in range(cs_len // 2) + ] + pos += cs_len + + # Compression Methods + comp_len = body[pos] + pos += 1 + comp_len + + # Extensions + extensions: list[int] = [] + supported_groups: list[int] = [] + ec_point_formats: list[int] = [] + sni: str = "" + alpn: list[str] = [] + + if pos + 2 <= len(body): + ext_total = struct.unpack_from("!H", body, pos)[0] + pos += 2 + ext_end = pos + ext_total + + while pos + 4 <= ext_end: + ext_type = struct.unpack_from("!H", body, pos)[0] + ext_len = struct.unpack_from("!H", body, pos + 2)[0] + ext_data = body[pos + 4: pos + 4 + ext_len] + pos += 4 + ext_len + + if not _is_grease(ext_type): + extensions.append(ext_type) + + if ext_type == _EXT_SNI and len(ext_data) > 5: + # server_name_list_length(2) type(1) name_length(2) name + sni = ext_data[5:].decode("ascii", errors="replace") + + elif ext_type == _EXT_SUPPORTED_GROUPS and len(ext_data) >= 2: + grp_len = struct.unpack_from("!H", ext_data, 0)[0] + supported_groups = [ + struct.unpack_from("!H", ext_data, 2 + i * 2)[0] + for i in range(grp_len // 2) + ] + + elif ext_type == _EXT_EC_POINT_FORMATS and len(ext_data) >= 
1: + pf_len = ext_data[0] + ec_point_formats = list(ext_data[1: 1 + pf_len]) + + elif ext_type == _EXT_ALPN and len(ext_data) >= 2: + proto_list_len = struct.unpack_from("!H", ext_data, 0)[0] + ap = 2 + while ap < 2 + proto_list_len: + plen = ext_data[ap] + alpn.append(ext_data[ap + 1: ap + 1 + plen].decode("ascii", errors="replace")) + ap += 1 + plen + + filtered_ciphers = _filter_grease(cipher_suites) + filtered_groups = _filter_grease(supported_groups) + + return { + "tls_version": tls_version, + "cipher_suites": filtered_ciphers, + "extensions": extensions, + "supported_groups": filtered_groups, + "ec_point_formats": ec_point_formats, + "sni": sni, + "alpn": alpn, + } + + except Exception: + return None + + +def _parse_server_hello(data: bytes) -> dict[str, Any] | None: + """ + Parse a TLS ServerHello from raw bytes. + Returns dict with tls_version, cipher_suite, extensions, or None. + """ + try: + if len(data) < 6 or data[0] != _TLS_RECORD_HANDSHAKE: + return None + + hs = data[5:] + if hs[0] != _TLS_HT_SERVER_HELLO: + return None + + hs_len = struct.unpack_from("!I", b"\x00" + hs[1:4])[0] + body = hs[4: 4 + hs_len] + if len(body) < 35: + return None + + pos = 0 + tls_version = struct.unpack_from("!H", body, pos)[0] + pos += 2 + + # Random (32 bytes) + pos += 32 + + # Session ID + session_id_len = body[pos] + pos += 1 + session_id_len + + if pos + 2 > len(body): + return None + + cipher_suite = struct.unpack_from("!H", body, pos)[0] + pos += 2 + + # Compression method (1 byte) + pos += 1 + + extensions: list[int] = [] + if pos + 2 <= len(body): + ext_total = struct.unpack_from("!H", body, pos)[0] + pos += 2 + ext_end = pos + ext_total + while pos + 4 <= ext_end: + ext_type = struct.unpack_from("!H", body, pos)[0] + ext_len = struct.unpack_from("!H", body, pos + 2)[0] + pos += 4 + ext_len + if not _is_grease(ext_type): + extensions.append(ext_type) + + return { + "tls_version": tls_version, + "cipher_suite": cipher_suite, + "extensions": extensions, + } + + 
except Exception: + return None + + +# ─── JA3 / JA3S computation ─────────────────────────────────────────────────── + +def _tls_version_str(version: int) -> str: + return { + 0x0301: "TLS 1.0", + 0x0302: "TLS 1.1", + 0x0303: "TLS 1.2", + 0x0304: "TLS 1.3", + 0x0200: "SSL 2.0", + 0x0300: "SSL 3.0", + }.get(version, f"0x{version:04x}") + + +def _ja3(ch: dict[str, Any]) -> tuple[str, str]: + """Return (ja3_string, ja3_hash) for a parsed ClientHello.""" + parts = [ + str(ch["tls_version"]), + "-".join(str(c) for c in ch["cipher_suites"]), + "-".join(str(e) for e in ch["extensions"]), + "-".join(str(g) for g in ch["supported_groups"]), + "-".join(str(p) for p in ch["ec_point_formats"]), + ] + ja3_str = ",".join(parts) + return ja3_str, hashlib.md5(ja3_str.encode()).hexdigest() + + +def _ja3s(sh: dict[str, Any]) -> tuple[str, str]: + """Return (ja3s_string, ja3s_hash) for a parsed ServerHello.""" + parts = [ + str(sh["tls_version"]), + str(sh["cipher_suite"]), + "-".join(str(e) for e in sh["extensions"]), + ] + ja3s_str = ",".join(parts) + return ja3s_str, hashlib.md5(ja3s_str.encode()).hexdigest() + + +# ─── Session cleanup ───────────────────────────────────────────────────────── + +def _cleanup_sessions() -> None: + now = time.monotonic() + stale = [k for k, ts in _session_ts.items() if now - ts > _SESSION_TTL] + for k in stale: + _sessions.pop(k, None) + _session_ts.pop(k, None) + + +# ─── Logging helpers ───────────────────────────────────────────────────────── + +def _log(event_type: str, severity: int = SEVERITY_INFO, **fields: Any) -> None: + line = syslog_line(SERVICE_NAME, NODE_NAME, event_type, severity=severity, **fields) + write_syslog_file(line) + + +# ─── Packet callback ───────────────────────────────────────────────────────── + +def _on_packet(pkt: Any) -> None: + if not (pkt.haslayer(IP) and pkt.haslayer(TCP)): + return + + ip = pkt[IP] + tcp = pkt[TCP] + + payload = bytes(tcp.payload) + if not payload: + return + + src_ip: str = ip.src + dst_ip: str 
= ip.dst + src_port: int = tcp.sport + dst_port: int = tcp.dport + + # TLS record check + if payload[0] != _TLS_RECORD_HANDSHAKE: + return + + # Attempt ClientHello parse + ch = _parse_client_hello(payload) + if ch is not None: + _cleanup_sessions() + + key = (src_ip, src_port, dst_ip, dst_port) + ja3_str, ja3_hash = _ja3(ch) + + _sessions[key] = { + "ja3": ja3_hash, + "ja3_str": ja3_str, + "tls_version": ch["tls_version"], + "cipher_suites": ch["cipher_suites"], + "extensions": ch["extensions"], + "sni": ch["sni"], + "alpn": ch["alpn"], + } + _session_ts[key] = time.monotonic() + + _log( + "tls_client_hello", + src_ip=src_ip, + src_port=str(src_port), + dst_ip=dst_ip, + dst_port=str(dst_port), + ja3=ja3_hash, + tls_version=_tls_version_str(ch["tls_version"]), + sni=ch["sni"] or "", + alpn=",".join(ch["alpn"]), + raw_ciphers="-".join(str(c) for c in ch["cipher_suites"]), + raw_extensions="-".join(str(e) for e in ch["extensions"]), + ) + return + + # Attempt ServerHello parse + sh = _parse_server_hello(payload) + if sh is not None: + # Reverse 4-tuple to find the matching ClientHello + rev_key = (dst_ip, dst_port, src_ip, src_port) + ch_data = _sessions.pop(rev_key, None) + _session_ts.pop(rev_key, None) + + ja3s_str, ja3s_hash = _ja3s(sh) + + fields: dict[str, Any] = { + "src_ip": dst_ip, # original attacker is now the destination + "src_port": str(dst_port), + "dst_ip": src_ip, + "dst_port": str(src_port), + "ja3s": ja3s_hash, + "tls_version": _tls_version_str(sh["tls_version"]), + } + + if ch_data: + fields["ja3"] = ch_data["ja3"] + fields["sni"] = ch_data["sni"] or "" + fields["alpn"] = ",".join(ch_data["alpn"]) + fields["raw_ciphers"] = "-".join(str(c) for c in ch_data["cipher_suites"]) + fields["raw_extensions"] = "-".join(str(e) for e in ch_data["extensions"]) + + _log("tls_session", severity=SEVERITY_WARNING, **fields) + + +# ─── Entry point ───────────────────────────────────────────────────────────── + +if __name__ == "__main__": + _log("startup", 
msg=f"sniffer started node={NODE_NAME}") + sniff( + filter="tcp", + prn=_on_packet, + store=False, + ) diff --git a/tests/test_attacker_worker.py b/tests/test_attacker_worker.py new file mode 100644 index 0000000..57f44fe --- /dev/null +++ b/tests/test_attacker_worker.py @@ -0,0 +1,515 @@ +""" +Tests for decnet/web/attacker_worker.py + +Covers: +- _rebuild(): CorrelationEngine integration, traversal detection, upsert calls +- _extract_commands(): command harvesting from raw log rows +- _build_record(): record assembly from engine events + bounties +- _first_contact_deckies(): ordering for single-decky attackers +- attacker_profile_worker(): cancellation and error handling +""" + +from __future__ import annotations + +import asyncio +import json +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from decnet.logging.syslog_formatter import SEVERITY_INFO, format_rfc5424 +from decnet.web.attacker_worker import ( + _build_record, + _extract_commands, + _first_contact_deckies, + _rebuild, + attacker_profile_worker, +) +from decnet.correlation.parser import LogEvent + +# ─── Helpers ────────────────────────────────────────────────────────────────── + +_TS1 = "2026-04-04T10:00:00+00:00" +_TS2 = "2026-04-04T10:05:00+00:00" +_TS3 = "2026-04-04T10:10:00+00:00" + +_DT1 = datetime.fromisoformat(_TS1) +_DT2 = datetime.fromisoformat(_TS2) +_DT3 = datetime.fromisoformat(_TS3) + + +def _make_raw_line( + service: str = "ssh", + hostname: str = "decky-01", + event_type: str = "connection", + src_ip: str = "1.2.3.4", + timestamp: str = _TS1, + **extra: str, +) -> str: + return format_rfc5424( + service=service, + hostname=hostname, + event_type=event_type, + severity=SEVERITY_INFO, + timestamp=datetime.fromisoformat(timestamp), + src_ip=src_ip, + **extra, + ) + + +def _make_log_row( + raw_line: str = "", + attacker_ip: str = "1.2.3.4", + service: str = "ssh", + event_type: str = "connection", + decky: str = "decky-01", + 
timestamp: datetime = _DT1, + fields: str = "{}", +) -> dict: + if not raw_line: + raw_line = _make_raw_line( + service=service, + hostname=decky, + event_type=event_type, + src_ip=attacker_ip, + timestamp=timestamp.isoformat(), + ) + return { + "id": 1, + "raw_line": raw_line, + "attacker_ip": attacker_ip, + "service": service, + "event_type": event_type, + "decky": decky, + "timestamp": timestamp, + "fields": fields, + } + + +def _make_repo(logs=None, bounties=None): + repo = MagicMock() + repo.get_all_logs_raw = AsyncMock(return_value=logs or []) + repo.get_all_bounties_by_ip = AsyncMock(return_value=bounties or {}) + repo.upsert_attacker = AsyncMock() + return repo + + +def _make_log_event( + ip: str, + decky: str, + service: str = "ssh", + event_type: str = "connection", + timestamp: datetime = _DT1, +) -> LogEvent: + return LogEvent( + timestamp=timestamp, + decky=decky, + service=service, + event_type=event_type, + attacker_ip=ip, + fields={}, + raw="", + ) + + +# ─── _first_contact_deckies ─────────────────────────────────────────────────── + +class TestFirstContactDeckies: + def test_single_decky(self): + events = [_make_log_event("1.1.1.1", "decky-01", timestamp=_DT1)] + assert _first_contact_deckies(events) == ["decky-01"] + + def test_multiple_deckies_ordered_by_first_contact(self): + events = [ + _make_log_event("1.1.1.1", "decky-02", timestamp=_DT2), + _make_log_event("1.1.1.1", "decky-01", timestamp=_DT1), + ] + assert _first_contact_deckies(events) == ["decky-01", "decky-02"] + + def test_revisit_does_not_duplicate(self): + events = [ + _make_log_event("1.1.1.1", "decky-01", timestamp=_DT1), + _make_log_event("1.1.1.1", "decky-02", timestamp=_DT2), + _make_log_event("1.1.1.1", "decky-01", timestamp=_DT3), # revisit + ] + result = _first_contact_deckies(events) + assert result == ["decky-01", "decky-02"] + assert result.count("decky-01") == 1 + + +# ─── _extract_commands ──────────────────────────────────────────────────────── + +class 
TestExtractCommands: + def _row(self, ip, event_type, fields): + return _make_log_row( + attacker_ip=ip, + event_type=event_type, + service="ssh", + decky="decky-01", + fields=json.dumps(fields), + ) + + def test_extracts_command_field(self): + rows = [self._row("1.1.1.1", "command", {"command": "id"})] + result = _extract_commands(rows, "1.1.1.1") + assert len(result) == 1 + assert result[0]["command"] == "id" + assert result[0]["service"] == "ssh" + assert result[0]["decky"] == "decky-01" + + def test_extracts_query_field(self): + rows = [self._row("2.2.2.2", "query", {"query": "SELECT * FROM users"})] + result = _extract_commands(rows, "2.2.2.2") + assert len(result) == 1 + assert result[0]["command"] == "SELECT * FROM users" + + def test_extracts_input_field(self): + rows = [self._row("3.3.3.3", "input", {"input": "ls -la"})] + result = _extract_commands(rows, "3.3.3.3") + assert len(result) == 1 + assert result[0]["command"] == "ls -la" + + def test_non_command_event_type_ignored(self): + rows = [self._row("1.1.1.1", "connection", {"command": "id"})] + result = _extract_commands(rows, "1.1.1.1") + assert result == [] + + def test_wrong_ip_ignored(self): + rows = [self._row("9.9.9.9", "command", {"command": "whoami"})] + result = _extract_commands(rows, "1.1.1.1") + assert result == [] + + def test_no_command_field_skipped(self): + rows = [self._row("1.1.1.1", "command", {"other": "stuff"})] + result = _extract_commands(rows, "1.1.1.1") + assert result == [] + + def test_invalid_json_fields_skipped(self): + row = _make_log_row( + attacker_ip="1.1.1.1", + event_type="command", + fields="not valid json", + ) + result = _extract_commands([row], "1.1.1.1") + assert result == [] + + def test_multiple_commands_all_extracted(self): + rows = [ + self._row("5.5.5.5", "command", {"command": "id"}), + self._row("5.5.5.5", "command", {"command": "uname -a"}), + ] + result = _extract_commands(rows, "5.5.5.5") + assert len(result) == 2 + cmds = {r["command"] for r in result} 
+ assert cmds == {"id", "uname -a"} + + def test_timestamp_serialized_to_string(self): + rows = [self._row("1.1.1.1", "command", {"command": "pwd"})] + result = _extract_commands(rows, "1.1.1.1") + assert isinstance(result[0]["timestamp"], str) + + +# ─── _build_record ──────────────────────────────────────────────────────────── + +class TestBuildRecord: + def _events(self, ip="1.1.1.1"): + return [ + _make_log_event(ip, "decky-01", "ssh", "conn", _DT1), + _make_log_event(ip, "decky-01", "http", "req", _DT2), + ] + + def test_basic_fields(self): + events = self._events() + record = _build_record("1.1.1.1", events, None, [], []) + assert record["ip"] == "1.1.1.1" + assert record["event_count"] == 2 + assert record["service_count"] == 2 + assert record["decky_count"] == 1 + + def test_first_last_seen(self): + events = self._events() + record = _build_record("1.1.1.1", events, None, [], []) + assert record["first_seen"] == _DT1 + assert record["last_seen"] == _DT2 + + def test_services_json_sorted(self): + events = self._events() + record = _build_record("1.1.1.1", events, None, [], []) + services = json.loads(record["services"]) + assert sorted(services) == services + + def test_no_traversal(self): + events = self._events() + record = _build_record("1.1.1.1", events, None, [], []) + assert record["is_traversal"] is False + assert record["traversal_path"] is None + + def test_with_traversal(self): + from decnet.correlation.graph import AttackerTraversal, TraversalHop + hops = [ + TraversalHop(_DT1, "decky-01", "ssh", "conn"), + TraversalHop(_DT2, "decky-02", "http", "req"), + ] + t = AttackerTraversal("1.1.1.1", hops) + events = [ + _make_log_event("1.1.1.1", "decky-01", timestamp=_DT1), + _make_log_event("1.1.1.1", "decky-02", timestamp=_DT2), + ] + record = _build_record("1.1.1.1", events, t, [], []) + assert record["is_traversal"] is True + assert record["traversal_path"] == "decky-01 → decky-02" + deckies = json.loads(record["deckies"]) + assert deckies == 
["decky-01", "decky-02"] + + def test_bounty_counts(self): + events = self._events() + bounties = [ + {"bounty_type": "credential", "attacker_ip": "1.1.1.1"}, + {"bounty_type": "credential", "attacker_ip": "1.1.1.1"}, + {"bounty_type": "fingerprint", "attacker_ip": "1.1.1.1"}, + ] + record = _build_record("1.1.1.1", events, None, bounties, []) + assert record["bounty_count"] == 3 + assert record["credential_count"] == 2 + fps = json.loads(record["fingerprints"]) + assert len(fps) == 1 + assert fps[0]["bounty_type"] == "fingerprint" + + def test_commands_serialized(self): + events = self._events() + cmds = [{"service": "ssh", "decky": "decky-01", "command": "id", "timestamp": "2026-04-04T10:00:00"}] + record = _build_record("1.1.1.1", events, None, [], cmds) + parsed = json.loads(record["commands"]) + assert len(parsed) == 1 + assert parsed[0]["command"] == "id" + + def test_updated_at_is_utc_datetime(self): + events = self._events() + record = _build_record("1.1.1.1", events, None, [], []) + assert isinstance(record["updated_at"], datetime) + assert record["updated_at"].tzinfo is not None + + +# ─── _rebuild ───────────────────────────────────────────────────────────────── + +class TestRebuild: + @pytest.mark.asyncio + async def test_empty_logs_no_upsert(self): + repo = _make_repo(logs=[]) + await _rebuild(repo) + repo.upsert_attacker.assert_not_awaited() + + @pytest.mark.asyncio + async def test_single_attacker_upserted(self): + raw = _make_raw_line("ssh", "decky-01", "connection", "10.0.0.1", _TS1) + row = _make_log_row(raw_line=raw, attacker_ip="10.0.0.1") + repo = _make_repo(logs=[row]) + await _rebuild(repo) + repo.upsert_attacker.assert_awaited_once() + record = repo.upsert_attacker.call_args[0][0] + assert record["ip"] == "10.0.0.1" + assert record["event_count"] == 1 + + @pytest.mark.asyncio + async def test_multiple_attackers_all_upserted(self): + rows = [ + _make_log_row( + raw_line=_make_raw_line("ssh", "decky-01", "conn", ip, _TS1), + attacker_ip=ip, + 
) + for ip in ["1.1.1.1", "2.2.2.2", "3.3.3.3"] + ] + repo = _make_repo(logs=rows) + await _rebuild(repo) + assert repo.upsert_attacker.await_count == 3 + upserted_ips = {c[0][0]["ip"] for c in repo.upsert_attacker.call_args_list} + assert upserted_ips == {"1.1.1.1", "2.2.2.2", "3.3.3.3"} + + @pytest.mark.asyncio + async def test_traversal_detected_across_two_deckies(self): + rows = [ + _make_log_row( + raw_line=_make_raw_line("ssh", "decky-01", "conn", "5.5.5.5", _TS1), + attacker_ip="5.5.5.5", decky="decky-01", + ), + _make_log_row( + raw_line=_make_raw_line("http", "decky-02", "req", "5.5.5.5", _TS2), + attacker_ip="5.5.5.5", decky="decky-02", + ), + ] + repo = _make_repo(logs=rows) + await _rebuild(repo) + record = repo.upsert_attacker.call_args[0][0] + assert record["is_traversal"] is True + assert "decky-01" in record["traversal_path"] + assert "decky-02" in record["traversal_path"] + + @pytest.mark.asyncio + async def test_single_decky_not_traversal(self): + rows = [ + _make_log_row( + raw_line=_make_raw_line("ssh", "decky-01", "conn", "7.7.7.7", _TS1), + attacker_ip="7.7.7.7", + ), + _make_log_row( + raw_line=_make_raw_line("http", "decky-01", "req", "7.7.7.7", _TS2), + attacker_ip="7.7.7.7", + ), + ] + repo = _make_repo(logs=rows) + await _rebuild(repo) + record = repo.upsert_attacker.call_args[0][0] + assert record["is_traversal"] is False + + @pytest.mark.asyncio + async def test_bounties_merged_into_record(self): + raw = _make_raw_line("ssh", "decky-01", "conn", "8.8.8.8", _TS1) + repo = _make_repo( + logs=[_make_log_row(raw_line=raw, attacker_ip="8.8.8.8")], + bounties={"8.8.8.8": [ + {"bounty_type": "credential", "attacker_ip": "8.8.8.8", "payload": {}}, + {"bounty_type": "fingerprint", "attacker_ip": "8.8.8.8", "payload": {"ja3": "abc"}}, + ]}, + ) + await _rebuild(repo) + record = repo.upsert_attacker.call_args[0][0] + assert record["bounty_count"] == 2 + assert record["credential_count"] == 1 + fps = json.loads(record["fingerprints"]) + assert 
len(fps) == 1 + + @pytest.mark.asyncio + async def test_commands_extracted_during_rebuild(self): + raw = _make_raw_line("ssh", "decky-01", "command", "9.9.9.9", _TS1) + row = _make_log_row( + raw_line=raw, + attacker_ip="9.9.9.9", + event_type="command", + fields=json.dumps({"command": "cat /etc/passwd"}), + ) + repo = _make_repo(logs=[row]) + await _rebuild(repo) + record = repo.upsert_attacker.call_args[0][0] + commands = json.loads(record["commands"]) + assert len(commands) == 1 + assert commands[0]["command"] == "cat /etc/passwd" + + +# ─── attacker_profile_worker ────────────────────────────────────────────────── + +class TestAttackerProfileWorker: + @pytest.mark.asyncio + async def test_worker_cancels_cleanly(self): + repo = _make_repo() + task = asyncio.create_task(attacker_profile_worker(repo)) + await asyncio.sleep(0) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + @pytest.mark.asyncio + async def test_worker_handles_rebuild_error_without_crashing(self): + repo = _make_repo() + _call_count = 0 + + async def fake_sleep(secs): + nonlocal _call_count + _call_count += 1 + if _call_count >= 2: + raise asyncio.CancelledError() + + async def bad_rebuild(_repo): + raise RuntimeError("DB exploded") + + with patch("decnet.web.attacker_worker.asyncio.sleep", side_effect=fake_sleep): + with patch("decnet.web.attacker_worker._rebuild", side_effect=bad_rebuild): + with pytest.raises(asyncio.CancelledError): + await attacker_profile_worker(repo) + + @pytest.mark.asyncio + async def test_worker_calls_rebuild_after_sleep(self): + repo = _make_repo() + _call_count = 0 + + async def fake_sleep(secs): + nonlocal _call_count + _call_count += 1 + if _call_count >= 2: + raise asyncio.CancelledError() + + rebuild_calls = [] + + async def mock_rebuild(_repo): + rebuild_calls.append(True) + + with patch("decnet.web.attacker_worker.asyncio.sleep", side_effect=fake_sleep): + with patch("decnet.web.attacker_worker._rebuild", side_effect=mock_rebuild): + 
with pytest.raises(asyncio.CancelledError): + await attacker_profile_worker(repo) + + assert len(rebuild_calls) >= 1 + + +# ─── JA3 bounty extraction from ingester ───────────────────────────────────── + +class TestJA3BountyExtraction: + @pytest.mark.asyncio + async def test_ja3_bounty_extracted_from_sniffer_event(self): + from decnet.web.ingester import _extract_bounty + repo = MagicMock() + repo.add_bounty = AsyncMock() + log_data = { + "decky": "decky-01", + "service": "sniffer", + "attacker_ip": "10.0.0.5", + "event_type": "tls_client_hello", + "fields": { + "ja3": "abc123def456abc123def456abc12345", + "ja3s": None, + "tls_version": "TLS 1.3", + "sni": "example.com", + "alpn": "h2", + "dst_port": "443", + "raw_ciphers": "4865-4866", + "raw_extensions": "0-23-65281", + }, + } + await _extract_bounty(repo, log_data) + repo.add_bounty.assert_awaited_once() + bounty = repo.add_bounty.call_args[0][0] + assert bounty["bounty_type"] == "fingerprint" + assert bounty["payload"]["fingerprint_type"] == "ja3" + assert bounty["payload"]["ja3"] == "abc123def456abc123def456abc12345" + assert bounty["payload"]["tls_version"] == "TLS 1.3" + assert bounty["payload"]["sni"] == "example.com" + + @pytest.mark.asyncio + async def test_non_sniffer_service_with_ja3_field_ignored(self): + from decnet.web.ingester import _extract_bounty + repo = MagicMock() + repo.add_bounty = AsyncMock() + log_data = { + "service": "http", + "attacker_ip": "10.0.0.6", + "event_type": "request", + "fields": {"ja3": "somehash"}, + } + await _extract_bounty(repo, log_data) + # Credential/UA checks run, but JA3 should not fire for non-sniffer + calls = [c[0][0]["bounty_type"] for c in repo.add_bounty.call_args_list] + assert "ja3" not in str(calls) + + @pytest.mark.asyncio + async def test_sniffer_without_ja3_no_bounty(self): + from decnet.web.ingester import _extract_bounty + repo = MagicMock() + repo.add_bounty = AsyncMock() + log_data = { + "service": "sniffer", + "attacker_ip": "10.0.0.7", + 
"event_type": "startup", + "fields": {"msg": "started"}, + } + await _extract_bounty(repo, log_data) + repo.add_bounty.assert_not_awaited() diff --git a/tests/test_base_repo.py b/tests/test_base_repo.py index efa7787..5ba51db 100644 --- a/tests/test_base_repo.py +++ b/tests/test_base_repo.py @@ -21,6 +21,11 @@ class DummyRepo(BaseRepository): async def get_total_bounties(self, **kw): await super().get_total_bounties(**kw) async def get_state(self, k): await super().get_state(k) async def set_state(self, k, v): await super().set_state(k, v) + async def get_all_logs_raw(self): await super().get_all_logs_raw() + async def get_all_bounties_by_ip(self): await super().get_all_bounties_by_ip() + async def upsert_attacker(self, d): await super().upsert_attacker(d) + async def get_attackers(self, **kw): await super().get_attackers(**kw) + async def get_total_attackers(self, **kw): await super().get_total_attackers(**kw) @pytest.mark.asyncio async def test_base_repo_coverage(): @@ -41,3 +46,8 @@ async def test_base_repo_coverage(): await dr.get_total_bounties() await dr.get_state("k") await dr.set_state("k", "v") + await dr.get_all_logs_raw() + await dr.get_all_bounties_by_ip() + await dr.upsert_attacker({}) + await dr.get_attackers() + await dr.get_total_attackers() diff --git a/tests/test_sniffer_ja3.py b/tests/test_sniffer_ja3.py new file mode 100644 index 0000000..b0e053b --- /dev/null +++ b/tests/test_sniffer_ja3.py @@ -0,0 +1,437 @@ +""" +Unit tests for the JA3/JA3S parsing logic in templates/sniffer/server.py. + +Imports the parser functions directly via sys.path manipulation, with +decnet_logging mocked out (it's a container-side stub at template build time). 
+""" + +from __future__ import annotations + +import hashlib +import struct +import sys +import types +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +# ─── Import sniffer module with mocked decnet_logging ───────────────────────── + +_SNIFFER_DIR = str(Path(__file__).parent.parent / "templates" / "sniffer") + +def _load_sniffer(): + """Load templates/sniffer/server.py with decnet_logging stubbed out.""" + # Stub the decnet_logging module that server.py imports + _stub = types.ModuleType("decnet_logging") + _stub.SEVERITY_INFO = 6 + _stub.SEVERITY_WARNING = 4 + _stub.syslog_line = MagicMock(return_value="<134>1 fake") + _stub.write_syslog_file = MagicMock() + sys.modules.setdefault("decnet_logging", _stub) + + if _SNIFFER_DIR not in sys.path: + sys.path.insert(0, _SNIFFER_DIR) + + import importlib + if "server" in sys.modules: + return sys.modules["server"] + import server as _srv + return _srv + +_srv = _load_sniffer() + +_parse_client_hello = _srv._parse_client_hello +_parse_server_hello = _srv._parse_server_hello +_ja3 = _srv._ja3 +_ja3s = _srv._ja3s +_is_grease = _srv._is_grease +_filter_grease = _srv._filter_grease +_tls_version_str = _srv._tls_version_str + + +# ─── TLS byte builder helpers ───────────────────────────────────────────────── + +def _build_client_hello( + version: int = 0x0303, + cipher_suites: list[int] | None = None, + extensions_bytes: bytes = b"", +) -> bytes: + """Build a minimal valid TLS ClientHello byte sequence.""" + if cipher_suites is None: + cipher_suites = [0x002F, 0x0035] # AES-128-SHA, AES-256-SHA + + random_bytes = b"\xAB" * 32 + session_id = b"\x00" # no session id + cs_bytes = b"".join(struct.pack("!H", c) for c in cipher_suites) + cs_len = struct.pack("!H", len(cs_bytes)) + compression = b"\x01\x00" # 1 method: null + + if extensions_bytes: + ext_block = struct.pack("!H", len(extensions_bytes)) + extensions_bytes + else: + ext_block = b"\x00\x00" + + body = ( + struct.pack("!H", version) + + 
random_bytes + + session_id + + cs_len + + cs_bytes + + compression + + ext_block + ) + + hs_header = b"\x01" + struct.pack("!I", len(body))[1:] # type + 3-byte len + record_payload = hs_header + body + record = b"\x16\x03\x01" + struct.pack("!H", len(record_payload)) + record_payload + return record + + +def _build_extension(ext_type: int, data: bytes) -> bytes: + return struct.pack("!HH", ext_type, len(data)) + data + + +def _build_sni_extension(hostname: str) -> bytes: + name_bytes = hostname.encode() + # server_name: type(1) + len(2) + name + entry = b"\x00" + struct.pack("!H", len(name_bytes)) + name_bytes + # server_name_list: len(2) + entries + lst = struct.pack("!H", len(entry)) + entry + return _build_extension(0x0000, lst) + + +def _build_supported_groups_extension(groups: list[int]) -> bytes: + grp_bytes = b"".join(struct.pack("!H", g) for g in groups) + data = struct.pack("!H", len(grp_bytes)) + grp_bytes + return _build_extension(0x000A, data) + + +def _build_ec_point_formats_extension(formats: list[int]) -> bytes: + pf = bytes(formats) + data = bytes([len(pf)]) + pf + return _build_extension(0x000B, data) + + +def _build_alpn_extension(protocols: list[str]) -> bytes: + proto_bytes = b"" + for p in protocols: + pb = p.encode() + proto_bytes += bytes([len(pb)]) + pb + data = struct.pack("!H", len(proto_bytes)) + proto_bytes + return _build_extension(0x0010, data) + + +def _build_server_hello( + version: int = 0x0303, + cipher_suite: int = 0x002F, + extensions_bytes: bytes = b"", +) -> bytes: + random_bytes = b"\xCD" * 32 + session_id = b"\x00" + compression = b"\x00" + + if extensions_bytes: + ext_block = struct.pack("!H", len(extensions_bytes)) + extensions_bytes + else: + ext_block = b"\x00\x00" + + body = ( + struct.pack("!H", version) + + random_bytes + + session_id + + struct.pack("!H", cipher_suite) + + compression + + ext_block + ) + + hs_header = b"\x02" + struct.pack("!I", len(body))[1:] + record_payload = hs_header + body + return 
b"\x16\x03\x01" + struct.pack("!H", len(record_payload)) + record_payload + + +# ─── GREASE tests ───────────────────────────────────────────────────────────── + +class TestGrease: + def test_known_grease_values_detected(self): + for v in [0x0A0A, 0x1A1A, 0x2A2A, 0x3A3A, 0x4A4A, 0x5A5A, + 0x6A6A, 0x7A7A, 0x8A8A, 0x9A9A, 0xAAAA, 0xBABA, + 0xCACA, 0xDADA, 0xEAEA, 0xFAFA]: + assert _is_grease(v), f"0x{v:04x} should be GREASE" + + def test_non_grease_values_not_detected(self): + for v in [0x002F, 0x0035, 0x1301, 0x000A, 0xFFFF]: + assert not _is_grease(v), f"0x{v:04x} should not be GREASE" + + def test_filter_grease_removes_grease(self): + values = [0x0A0A, 0x002F, 0x1A1A, 0x0035] + result = _filter_grease(values) + assert result == [0x002F, 0x0035] + + def test_filter_grease_preserves_all_non_grease(self): + values = [0x002F, 0x0035, 0x1301] + assert _filter_grease(values) == values + + +# ─── ClientHello parsing tests ──────────────────────────────────────────────── + +class TestParseClientHello: + def test_minimal_client_hello_parsed(self): + data = _build_client_hello() + result = _parse_client_hello(data) + assert result is not None + assert result["tls_version"] == 0x0303 + assert result["cipher_suites"] == [0x002F, 0x0035] + assert result["extensions"] == [] + assert result["supported_groups"] == [] + assert result["ec_point_formats"] == [] + assert result["sni"] == "" + assert result["alpn"] == [] + + def test_wrong_record_type_returns_none(self): + data = _build_client_hello() + bad = b"\x14" + data[1:] # change record type to ChangeCipherSpec + assert _parse_client_hello(bad) is None + + def test_wrong_handshake_type_returns_none(self): + data = _build_client_hello() + # Byte at offset 5 is the handshake type + bad = data[:5] + b"\x02" + data[6:] # ServerHello type + assert _parse_client_hello(bad) is None + + def test_too_short_returns_none(self): + assert _parse_client_hello(b"\x16\x03\x01") is None + assert _parse_client_hello(b"") is None + + def 
test_non_tls_returns_none(self): + assert _parse_client_hello(b"GET / HTTP/1.1\r\n") is None + + def test_grease_cipher_suites_filtered(self): + data = _build_client_hello(cipher_suites=[0x0A0A, 0x002F, 0x1A1A, 0x0035]) + result = _parse_client_hello(data) + assert result is not None + assert 0x0A0A not in result["cipher_suites"] + assert 0x1A1A not in result["cipher_suites"] + assert result["cipher_suites"] == [0x002F, 0x0035] + + def test_sni_extension_extracted(self): + ext = _build_sni_extension("example.com") + data = _build_client_hello(extensions_bytes=ext) + result = _parse_client_hello(data) + assert result is not None + assert result["sni"] == "example.com" + + def test_supported_groups_extracted(self): + ext = _build_supported_groups_extension([0x001D, 0x0017, 0x0018]) + data = _build_client_hello(extensions_bytes=ext) + result = _parse_client_hello(data) + assert result is not None + assert result["supported_groups"] == [0x001D, 0x0017, 0x0018] + + def test_grease_in_supported_groups_filtered(self): + ext = _build_supported_groups_extension([0x0A0A, 0x001D]) + data = _build_client_hello(extensions_bytes=ext) + result = _parse_client_hello(data) + assert result is not None + assert 0x0A0A not in result["supported_groups"] + assert 0x001D in result["supported_groups"] + + def test_ec_point_formats_extracted(self): + ext = _build_ec_point_formats_extension([0x00, 0x01]) + data = _build_client_hello(extensions_bytes=ext) + result = _parse_client_hello(data) + assert result is not None + assert result["ec_point_formats"] == [0x00, 0x01] + + def test_alpn_extension_extracted(self): + ext = _build_alpn_extension(["h2", "http/1.1"]) + data = _build_client_hello(extensions_bytes=ext) + result = _parse_client_hello(data) + assert result is not None + assert result["alpn"] == ["h2", "http/1.1"] + + def test_multiple_extensions_extracted(self): + sni = _build_sni_extension("target.local") + grps = _build_supported_groups_extension([0x001D]) + combined = sni + grps 
+ data = _build_client_hello(extensions_bytes=combined) + result = _parse_client_hello(data) + assert result is not None + assert result["sni"] == "target.local" + assert 0x001D in result["supported_groups"] + # Extension type IDs recorded (SNI=0, supported_groups=10) + assert 0x0000 in result["extensions"] + assert 0x000A in result["extensions"] + + +# ─── ServerHello parsing tests ──────────────────────────────────────────────── + +class TestParseServerHello: + def test_minimal_server_hello_parsed(self): + data = _build_server_hello() + result = _parse_server_hello(data) + assert result is not None + assert result["tls_version"] == 0x0303 + assert result["cipher_suite"] == 0x002F + assert result["extensions"] == [] + + def test_wrong_record_type_returns_none(self): + data = _build_server_hello() + bad = b"\x15" + data[1:] + assert _parse_server_hello(bad) is None + + def test_wrong_handshake_type_returns_none(self): + data = _build_server_hello() + bad = data[:5] + b"\x01" + data[6:] # ClientHello type + assert _parse_server_hello(bad) is None + + def test_too_short_returns_none(self): + assert _parse_server_hello(b"") is None + + def test_server_hello_extension_types_recorded(self): + # Build a ServerHello with a generic extension (type=0xFF01) + ext_data = _build_extension(0xFF01, b"\x00") + data = _build_server_hello(extensions_bytes=ext_data) + result = _parse_server_hello(data) + assert result is not None + assert 0xFF01 in result["extensions"] + + def test_grease_extension_in_server_hello_filtered(self): + ext_data = _build_extension(0x0A0A, b"\x00") + data = _build_server_hello(extensions_bytes=ext_data) + result = _parse_server_hello(data) + assert result is not None + assert 0x0A0A not in result["extensions"] + + +# ─── JA3 hash tests ─────────────────────────────────────────────────────────── + +class TestJA3: + def test_ja3_returns_32_char_hex(self): + data = _build_client_hello() + ch = _parse_client_hello(data) + _, ja3_hash = _ja3(ch) + assert 
len(ja3_hash) == 32 + assert all(c in "0123456789abcdef" for c in ja3_hash) + + def test_ja3_known_hash(self): + # Minimal ClientHello: TLS 1.2, ciphers [47, 53], no extensions + ch = { + "tls_version": 0x0303, # 771 + "cipher_suites": [0x002F, 0x0035], # 47, 53 + "extensions": [], + "supported_groups": [], + "ec_point_formats": [], + "sni": "", + "alpn": [], + } + ja3_str, ja3_hash = _ja3(ch) + assert ja3_str == "771,47-53,,," + expected = hashlib.md5(b"771,47-53,,,").hexdigest() + assert ja3_hash == expected + + def test_ja3_same_input_same_hash(self): + data = _build_client_hello() + ch = _parse_client_hello(data) + _, h1 = _ja3(ch) + _, h2 = _ja3(ch) + assert h1 == h2 + + def test_ja3_different_ciphers_different_hash(self): + ch1 = {"tls_version": 0x0303, "cipher_suites": [47], "extensions": [], + "supported_groups": [], "ec_point_formats": [], "sni": "", "alpn": []} + ch2 = {"tls_version": 0x0303, "cipher_suites": [53], "extensions": [], + "supported_groups": [], "ec_point_formats": [], "sni": "", "alpn": []} + _, h1 = _ja3(ch1) + _, h2 = _ja3(ch2) + assert h1 != h2 + + def test_ja3_empty_lists_produce_valid_string(self): + ch = {"tls_version": 0x0303, "cipher_suites": [], "extensions": [], + "supported_groups": [], "ec_point_formats": [], "sni": "", "alpn": []} + ja3_str, ja3_hash = _ja3(ch) + assert ja3_str == "771,,,," + assert len(ja3_hash) == 32 + + +# ─── JA3S hash tests ────────────────────────────────────────────────────────── + +class TestJA3S: + def test_ja3s_returns_32_char_hex(self): + data = _build_server_hello() + sh = _parse_server_hello(data) + _, ja3s_hash = _ja3s(sh) + assert len(ja3s_hash) == 32 + assert all(c in "0123456789abcdef" for c in ja3s_hash) + + def test_ja3s_known_hash(self): + sh = {"tls_version": 0x0303, "cipher_suite": 0x002F, "extensions": []} + ja3s_str, ja3s_hash = _ja3s(sh) + assert ja3s_str == "771,47," + expected = hashlib.md5(b"771,47,").hexdigest() + assert ja3s_hash == expected + + def 
test_ja3s_different_cipher_different_hash(self): + sh1 = {"tls_version": 0x0303, "cipher_suite": 0x002F, "extensions": []} + sh2 = {"tls_version": 0x0303, "cipher_suite": 0x0035, "extensions": []} + _, h1 = _ja3s(sh1) + _, h2 = _ja3s(sh2) + assert h1 != h2 + + +# ─── TLS version string tests ───────────────────────────────────────────────── + +class TestTLSVersionStr: + def test_tls12(self): + assert _tls_version_str(0x0303) == "TLS 1.2" + + def test_tls13(self): + assert _tls_version_str(0x0304) == "TLS 1.3" + + def test_tls11(self): + assert _tls_version_str(0x0302) == "TLS 1.1" + + def test_tls10(self): + assert _tls_version_str(0x0301) == "TLS 1.0" + + def test_unknown_version(self): + result = _tls_version_str(0xABCD) + assert "0xabcd" in result.lower() + + +# ─── Full round-trip: parse bytes → JA3/JA3S ────────────────────────────────── + +class TestRoundTrip: + def test_client_hello_bytes_to_ja3(self): + ciphers = [0x1301, 0x1302, 0x002F] + sni_ext = _build_sni_extension("attacker.c2.com") + data = _build_client_hello(cipher_suites=ciphers, extensions_bytes=sni_ext) + ch = _parse_client_hello(data) + assert ch is not None + ja3_str, ja3_hash = _ja3(ch) + assert "4865-4866-47" in ja3_str # ciphers: 0x1301=4865, 0x1302=4866, 0x002F=47 + assert len(ja3_hash) == 32 + assert ch["sni"] == "attacker.c2.com" + + def test_server_hello_bytes_to_ja3s(self): + data = _build_server_hello(cipher_suite=0x1301) + sh = _parse_server_hello(data) + assert sh is not None + ja3s_str, ja3s_hash = _ja3s(sh) + assert "4865" in ja3s_str # 0x1301 = 4865 + assert len(ja3s_hash) == 32 + + def test_grease_client_hello_filtered_before_hash(self): + """GREASE ciphers must be stripped before JA3 is computed.""" + ciphers_with_grease = [0x0A0A, 0x002F, 0xFAFA, 0x0035] + data = _build_client_hello(cipher_suites=ciphers_with_grease) + ch = _parse_client_hello(data) + _, ja3_hash = _ja3(ch) + + # Reference: build without GREASE + ciphers_clean = [0x002F, 0x0035] + data_clean = 
_build_client_hello(cipher_suites=ciphers_clean) + ch_clean = _parse_client_hello(data_clean) + _, ja3_hash_clean = _ja3(ch_clean) + + assert ja3_hash == ja3_hash_clean diff --git a/tests/test_web_api.py b/tests/test_web_api.py index 0879c23..b07a1f8 100644 --- a/tests/test_web_api.py +++ b/tests/test_web_api.py @@ -128,8 +128,9 @@ class TestLifespan: with patch("decnet.web.api.repo", mock_repo): with patch("decnet.web.api.log_ingestion_worker", return_value=asyncio.sleep(0)): with patch("decnet.web.api.log_collector_worker", return_value=asyncio.sleep(0)): - async with lifespan(mock_app): - mock_repo.initialize.assert_awaited_once() + with patch("decnet.web.api.attacker_profile_worker", return_value=asyncio.sleep(0)): + async with lifespan(mock_app): + mock_repo.initialize.assert_awaited_once() @pytest.mark.asyncio async def test_lifespan_db_retry(self): @@ -150,5 +151,6 @@ class TestLifespan: with patch("decnet.web.api.asyncio.sleep", new_callable=AsyncMock): with patch("decnet.web.api.log_ingestion_worker", return_value=asyncio.sleep(0)): with patch("decnet.web.api.log_collector_worker", return_value=asyncio.sleep(0)): - async with lifespan(mock_app): - assert _call_count == 3 + with patch("decnet.web.api.attacker_profile_worker", return_value=asyncio.sleep(0)): + async with lifespan(mock_app): + assert _call_count == 3