diff --git a/decnet/prober/base.py b/decnet/prober/base.py new file mode 100644 index 00000000..0da7e9ce --- /dev/null +++ b/decnet/prober/base.py @@ -0,0 +1,90 @@ +""" +ActiveProbe ABC and metaclass registry for port-iterating active probes. + +Adding a new active probe = one class with three methods. +IPv6 leak and TLS cert capture are NOT part of this registry (different +call shapes); they stay as special cases in prober/worker.py. +""" + +from __future__ import annotations + +import os +from abc import ABCMeta, abstractmethod +from typing import Any + +from decnet.correlation.fingerprint_rotation import ProbeType + + +class ActiveProbeMeta(ABCMeta): + """Metaclass that auto-registers every ActiveProbe subclass by probe_name.""" + + _registry: dict[str, type[ActiveProbe]] = {} + + def __new__( + mcs, + name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + ) -> ActiveProbeMeta: + cls = super().__new__(mcs, name, bases, namespace) + if bases and getattr(cls, "probe_name", None): + mcs._registry[cls.probe_name] = cls # type: ignore[attr-defined,assignment] + return cls + + @classmethod + def all(mcs) -> list[type[ActiveProbe]]: + """Return registered probes sorted by (priority asc, probe_name asc).""" + return sorted(mcs._registry.values(), key=lambda c: (c.priority, c.probe_name)) + + +class ActiveProbe(metaclass=ActiveProbeMeta): + """Base class for all port-iterating active probes. + + Subclasses declare class-level attributes and implement three methods. + Registration is automatic via ActiveProbeMeta. + + Port override: set DECNET_PROBE_PORTS_ (comma-separated) to + override default_ports at runtime without touching the class. + """ + + probe_name: str + default_ports: list[int] + event_type: str + rotation_type: ProbeType | None = None + rotation_hash_key: str | None = None + priority: int = 100 + + def __init__(self) -> None: + env_key = f"DECNET_PROBE_PORTS_{self.probe_name.upper()}" + raw = os.environ.get(env_key, "").strip() + if raw: + try: + self._ports: list[int] = [int(p.strip()) for p in raw.split(",") if p.strip()] + except ValueError: + self._ports = list(self.default_ports) + else: + self._ports = list(self.default_ports) + + @property + def ports(self) -> list[int]: + return self._ports + + @abstractmethod + def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None: + """Execute the probe against ip:port. + + Return a result dict on success, or None to suppress emission (e.g. + empty JARM hash means the port doesn't speak TLS). + """ + + @abstractmethod + def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]: + """Return (sd_fields, human_msg) for _write_event. + + target_ip and target_port are injected by _run_probe; do not include + them in sd_fields. + """ + + @abstractmethod + def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]: + """Return the bus payload dict for attacker.fingerprinted events.""" diff --git a/decnet/prober/probes/__init__.py b/decnet/prober/probes/__init__.py new file mode 100644 index 00000000..1e62eaf4 --- /dev/null +++ b/decnet/prober/probes/__init__.py @@ -0,0 +1,4 @@ +# Import all probe modules to trigger ActiveProbeMeta registration. +from decnet.prober.probes.hassh import HasshProbe as HasshProbe +from decnet.prober.probes.jarm import JarmProbe as JarmProbe +from decnet.prober.probes.tcpfp import TcpfpProbe as TcpfpProbe diff --git a/decnet/prober/probes/hassh.py b/decnet/prober/probes/hassh.py new file mode 100644 index 00000000..28b1749e --- /dev/null +++ b/decnet/prober/probes/hassh.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import Any + +from decnet.prober.base import ActiveProbe +from decnet.prober.hassh import hassh_server +from decnet.telemetry import traced as _traced + +DEFAULT_PORTS: list[int] = [22, 2222, 22222, 2022] + + +class HasshProbe(ActiveProbe): + probe_name = "hassh" + default_ports = DEFAULT_PORTS + event_type = "hassh_fingerprint" + rotation_type = "hassh" + rotation_hash_key = "hassh_server" + priority = 100 + + @_traced("prober.hassh_probe") + def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None: + return hassh_server(ip, port, timeout=timeout) + + def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]: + fields = { + "hassh_server_hash": result["hassh_server"], + "ssh_banner": result["banner"], + "kex_algorithms": result["kex_algorithms"], + "encryption_s2c": result["encryption_s2c"], + "mac_s2c": result["mac_s2c"], + "compression_s2c": result["compression_s2c"], + } + return fields, f"HASSH {ip}:{port} = {result['hassh_server']}" + + def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]: + return { + "attacker_ip": ip, + "port": port, + "hassh_server": result["hassh_server"], + "ssh_banner": result["banner"], + } diff --git a/decnet/prober/probes/jarm.py b/decnet/prober/probes/jarm.py new file mode 100644 index 00000000..fc9c28c5 --- /dev/null +++ b/decnet/prober/probes/jarm.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import Any + +from decnet.prober.base import ActiveProbe +from decnet.prober.jarm import JARM_EMPTY_HASH, jarm_hash +from decnet.telemetry import traced as _traced + +DEFAULT_PORTS: list[int] = [443, 8443, 8080, 4443, 50050, 2222, 993, 995, 8888, 9001] + + +class JarmProbe(ActiveProbe): + probe_name = "jarm" + default_ports = DEFAULT_PORTS + event_type = "jarm_fingerprint" + rotation_type = "jarm" + rotation_hash_key = "jarm_hash" + priority = 100 + + @_traced("prober.jarm_probe") + def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None: + h = jarm_hash(ip, port, timeout=timeout) + if h == JARM_EMPTY_HASH: + return None + return {"jarm_hash": h} + + def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]: + h = result["jarm_hash"] + return {"jarm_hash": h}, f"JARM {ip}:{port} = {h}" + + def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]: + return {"attacker_ip": ip, "port": port, "jarm_hash": result["jarm_hash"]} diff --git a/decnet/prober/probes/tcpfp.py b/decnet/prober/probes/tcpfp.py new file mode 100644 index 00000000..69f4b157 --- /dev/null +++ b/decnet/prober/probes/tcpfp.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from typing import Any + +from decnet.prober.base import ActiveProbe +from decnet.prober.tcpfp import tcp_fingerprint +from decnet.telemetry import traced as _traced + +DEFAULT_PORTS: list[int] = [22, 80, 443, 8080, 8443, 445, 3389] + + +class TcpfpProbe(ActiveProbe): + probe_name = "tcpfp" + default_ports = DEFAULT_PORTS + event_type = "tcpfp_fingerprint" + rotation_type = "tcpfp" + rotation_hash_key = "tcpfp_hash" + priority = 100 + + @_traced("prober.tcpfp_probe") + def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None: + return tcp_fingerprint(ip, port, timeout=timeout) + + def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]: + fields = { + "tcpfp_hash": result["tcpfp_hash"], + "tcpfp_raw": result["tcpfp_raw"], + "ttl": str(result["ttl"]), + "window_size": str(result["window_size"]), + "df_bit": str(result["df_bit"]), + "mss": str(result["mss"]), + "window_scale": str(result["window_scale"]), + "sack_ok": str(result["sack_ok"]), + "timestamp": str(result["timestamp"]), + "options_order": result["options_order"], + "tos": str(result["tos"]), + "dscp": str(result["dscp"]), + "ecn": str(result["ecn"]), + "server_isn": str(result["server_isn"]), + } + return fields, f"TCPFP {ip}:{port} = {result['tcpfp_hash']}" + + def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]: + return { + "attacker_ip": ip, + "port": port, + "tcpfp_hash": result["tcpfp_hash"], + "ttl": result["ttl"], + "mss": result["mss"], + } diff --git a/decnet/prober/worker.py b/decnet/prober/worker.py index d070748b..76e7663c 100644 --- a/decnet/prober/worker.py +++ b/decnet/prober/worker.py @@ -43,9 +43,8 @@ from decnet.correlation.fingerprint_rotation import ( record_fingerprint, ) from decnet.logging import get_logger -from decnet.prober.hassh import hassh_server -from decnet.prober.jarm import JARM_EMPTY_HASH, jarm_hash -from decnet.prober.tcpfp import tcp_fingerprint +from decnet.prober.base import ActiveProbe, ActiveProbeMeta +import decnet.prober.probes as _probes # noqa: F401 — triggers metaclass registration from decnet.prober.tlscert import fetch_leaf_cert from decnet.telemetry import traced as _traced @@ -66,20 +65,6 @@ def _build_sync_engine() -> Engine: db_path = os.environ.get("DECNET_DB_PATH", str(_ROOT / "decnet.db")) return get_sync_engine(db_path) -# ─── Default ports per probe type ─────────────────────────────────────────── - -# JARM: common C2 callback / TLS server ports -DEFAULT_PROBE_PORTS: list[int] = [ - 443, 8443, 8080, 4443, 50050, 2222, 993, 995, 8888, 9001, -] - -# HASSHServer: common SSH server ports -DEFAULT_SSH_PORTS: list[int] = [22, 2222, 22222, 2022] - -# TCP/IP stack: probe on ports commonly open on attacker machines. -# Wide spread gives the best chance of a SYN-ACK for TTL/fingerprint extraction. -DEFAULT_TCPFP_PORTS: list[int] = [22, 80, 443, 8080, 8443, 445, 3389] - # ─── RFC 5424 formatting (inline, mirrors templates/*/decnet_logging.py) ───── _FACILITY_LOCAL0 = 16 @@ -259,94 +244,49 @@ ProbePublishFn = Callable[[str, dict[str, Any]], None] # performs the rotation-detection upsert + derived-event emission for the # DEBT-032 substrate-fingerprint flow. Optional; when None the prober # behaves exactly as before (raw fingerprint emit only, no rotation -# detection). Construction lives at worker startup so phase functions -# don't have to know about the DB engine. +# detection). Construction lives at worker startup so the probe driver +# doesn't have to know about the DB engine. RotationRecorderFn = Callable[[str, int, "ProbeType", str], None] -@_traced("prober.probe_cycle") -def _probe_cycle( - targets: set[str], - probed: dict[str, dict[str, set[int]]], - jarm_ports: list[int], - ssh_ports: list[int], - tcpfp_ports: list[int], - log_path: Path, - json_path: Path, - timeout: float = 5.0, - publish_fn: ProbePublishFn | None = None, - record_rotation: RotationRecorderFn | None = None, -) -> None: - """ - Probe all known attacker IPs with JARM, HASSH, and TCP/IP fingerprinting. - - Args: - targets: set of attacker IPs to probe - probed: dict mapping IP -> {probe_type -> set of ports already probed} - jarm_ports: TLS ports for JARM fingerprinting - ssh_ports: SSH ports for HASSHServer fingerprinting - tcpfp_ports: ports for TCP/IP stack fingerprinting - log_path: RFC 5424 log file - json_path: JSON log file - timeout: per-probe TCP timeout - """ - for ip in sorted(targets): - ip_probed = probed.setdefault(ip, {}) - - # Phase 1: JARM (TLS fingerprinting) - _jarm_phase(ip, ip_probed, jarm_ports, log_path, json_path, timeout, publish_fn, record_rotation) - - # Phase 2: HASSHServer (SSH fingerprinting) - _hassh_phase(ip, ip_probed, ssh_ports, log_path, json_path, timeout, publish_fn, record_rotation) - - # Phase 3: TCP/IP stack fingerprinting - _tcpfp_phase(ip, ip_probed, tcpfp_ports, log_path, json_path, timeout, publish_fn, record_rotation) - - # Phase 4: IPv6 link-local leak (active ICMPv6 solicitation; on-link only) - _ipv6_leak_phase(ip, ip_probed, log_path, json_path, timeout, publish_fn) - - -@_traced("prober.jarm_phase") -def _jarm_phase( +@_traced("prober.run_probe") +def _run_probe( + probe: ActiveProbe, ip: str, ip_probed: dict[str, set[int]], - ports: list[int], log_path: Path, json_path: Path, timeout: float, - publish_fn: ProbePublishFn | None = None, - record_rotation: RotationRecorderFn | None = None, + publish_fn: ProbePublishFn | None, + record_rotation: RotationRecorderFn | None, ) -> None: - """JARM-fingerprint an IP on the given TLS ports.""" - done = ip_probed.setdefault("jarm", set()) - for port in ports: + """Generic driver for any port-iterating ActiveProbe.""" + done = ip_probed.setdefault(probe.probe_name, set()) + for port in probe.ports: if port in done: continue try: - h = jarm_hash(ip, port, timeout=timeout) + result = probe.run(ip, port, timeout) done.add(port) - if h == JARM_EMPTY_HASH: + if result is None: continue + fields, msg = probe.syslog_fields(ip, port, result) _write_event( log_path, json_path, - "jarm_fingerprint", + probe.event_type, target_ip=ip, target_port=str(port), - jarm_hash=h, - msg=f"JARM {ip}:{port} = {h}", + msg=msg, + **fields, ) - logger.info("prober: JARM %s:%d = %s", ip, port, h) - if record_rotation is not None: - record_rotation(ip, port, "jarm", h) + logger.info("prober: %s %s:%d ok", probe.probe_name, ip, port) + if record_rotation is not None and probe.rotation_type and probe.rotation_hash_key: + record_rotation(ip, port, probe.rotation_type, result[probe.rotation_hash_key]) if publish_fn is not None: - publish_fn( - "jarm", - {"attacker_ip": ip, "port": port, "jarm_hash": h}, - ) - # Cert capture: a non-empty JARM hash proves the port speaks - # TLS, so a follow-up real handshake is worth attempting. - # Failures are silent — the next probe target must not stall. - _capture_tls_cert(ip, port, log_path, json_path, timeout, publish_fn) + publish_fn(probe.probe_name, probe.publish_payload(ip, port, result)) + if probe.probe_name == "jarm": + # A non-empty JARM hash proves TLS; attempt a real cert capture. + _capture_tls_cert(ip, port, log_path, json_path, timeout, publish_fn) except Exception as exc: done.add(port) _write_event( @@ -356,9 +296,34 @@ def _jarm_phase( target_ip=ip, target_port=str(port), error=str(exc), - msg=f"JARM probe failed for {ip}:{port}: {exc}", + msg=f"{probe.probe_name} probe failed for {ip}:{port}: {exc}", ) - logger.warning("prober: JARM probe failed %s:%d: %s", ip, port, exc) + logger.warning("prober: %s probe failed %s:%d: %s", probe.probe_name, ip, port, exc) + + +@_traced("prober.probe_cycle") +def _probe_cycle( + targets: set[str], + probed: dict[str, dict[str, set[int]]], + log_path: Path, + json_path: Path, + timeout: float = 5.0, + publish_fn: ProbePublishFn | None = None, + record_rotation: RotationRecorderFn | None = None, +) -> None: + """Probe all known attacker IPs via every registered ActiveProbe. + + Probes run in (priority, probe_name) order per ActiveProbeMeta.all(). + IPv6 leak runs last — it is not port-iterating and stays a special case. + """ + for ip in sorted(targets): + ip_probed = probed.setdefault(ip, {}) + + for probe_cls in ActiveProbeMeta.all(): + _run_probe(probe_cls(), ip, ip_probed, log_path, json_path, + timeout, publish_fn, record_rotation) + + _ipv6_leak_phase(ip, ip_probed, log_path, json_path, timeout, publish_fn) @_traced("prober.tls_cert_capture") @@ -415,137 +380,6 @@ def _capture_tls_cert( ) -@_traced("prober.hassh_phase") -def _hassh_phase( - ip: str, - ip_probed: dict[str, set[int]], - ports: list[int], - log_path: Path, - json_path: Path, - timeout: float, - publish_fn: ProbePublishFn | None = None, - record_rotation: RotationRecorderFn | None = None, -) -> None: - """HASSHServer-fingerprint an IP on the given SSH ports.""" - done = ip_probed.setdefault("hassh", set()) - for port in ports: - if port in done: - continue - try: - result = hassh_server(ip, port, timeout=timeout) - done.add(port) - if result is None: - continue - _write_event( - log_path, json_path, - "hassh_fingerprint", - target_ip=ip, - target_port=str(port), - hassh_server_hash=result["hassh_server"], - ssh_banner=result["banner"], - kex_algorithms=result["kex_algorithms"], - encryption_s2c=result["encryption_s2c"], - mac_s2c=result["mac_s2c"], - compression_s2c=result["compression_s2c"], - msg=f"HASSH {ip}:{port} = {result['hassh_server']}", - ) - logger.info("prober: HASSH %s:%d = %s", ip, port, result["hassh_server"]) - if record_rotation is not None: - record_rotation(ip, port, "hassh", result["hassh_server"]) - if publish_fn is not None: - publish_fn( - "hassh", - { - "attacker_ip": ip, - "port": port, - "hassh_server": result["hassh_server"], - "ssh_banner": result["banner"], - }, - ) - except Exception as exc: - done.add(port) - _write_event( - log_path, json_path, - "prober_error", - severity=_SEVERITY_WARNING, - target_ip=ip, - target_port=str(port), - error=str(exc), - msg=f"HASSH probe failed for {ip}:{port}: {exc}", - ) - logger.warning("prober: HASSH probe failed %s:%d: %s", ip, port, exc) - - -@_traced("prober.tcpfp_phase") -def _tcpfp_phase( - ip: str, - ip_probed: dict[str, set[int]], - ports: list[int], - log_path: Path, - json_path: Path, - timeout: float, - publish_fn: ProbePublishFn | None = None, - record_rotation: RotationRecorderFn | None = None, -) -> None: - """TCP/IP stack fingerprint an IP on the given ports.""" - done = ip_probed.setdefault("tcpfp", set()) - for port in ports: - if port in done: - continue - try: - result = tcp_fingerprint(ip, port, timeout=timeout) - done.add(port) - if result is None: - continue - _write_event( - log_path, json_path, - "tcpfp_fingerprint", - target_ip=ip, - target_port=str(port), - tcpfp_hash=result["tcpfp_hash"], - tcpfp_raw=result["tcpfp_raw"], - ttl=str(result["ttl"]), - window_size=str(result["window_size"]), - df_bit=str(result["df_bit"]), - mss=str(result["mss"]), - window_scale=str(result["window_scale"]), - sack_ok=str(result["sack_ok"]), - timestamp=str(result["timestamp"]), - options_order=result["options_order"], - tos=str(result["tos"]), - dscp=str(result["dscp"]), - ecn=str(result["ecn"]), - server_isn=str(result["server_isn"]), - msg=f"TCPFP {ip}:{port} = {result['tcpfp_hash']}", - ) - logger.info("prober: TCPFP %s:%d = %s", ip, port, result["tcpfp_hash"]) - if record_rotation is not None: - record_rotation(ip, port, "tcpfp", result["tcpfp_hash"]) - if publish_fn is not None: - publish_fn( - "tcpfp", - { - "attacker_ip": ip, - "port": port, - "tcpfp_hash": result["tcpfp_hash"], - "ttl": result["ttl"], - "mss": result["mss"], - }, - ) - except Exception as exc: - done.add(port) - _write_event( - log_path, json_path, - "prober_error", - severity=_SEVERITY_WARNING, - target_ip=ip, - target_port=str(port), - error=str(exc), - msg=f"TCPFP probe failed for {ip}:{port}: {exc}", - ) - logger.warning("prober: TCPFP probe failed %s:%d: %s", ip, port, exc) - - @_traced("prober.ipv6_leak_phase") def _ipv6_leak_phase( ip: str, @@ -622,49 +456,43 @@ async def prober_worker( log_file: str, interval: int = 300, timeout: float = 5.0, - ports: list[int] | None = None, - ssh_ports: list[int] | None = None, - tcpfp_ports: list[int] | None = None, ) -> None: """ Main entry point for the standalone prober process. Discovers attacker IPs automatically by tailing the JSON log file, - then fingerprints each IP via JARM, HASSH, and TCP/IP stack probes. + then fingerprints each IP via every registered ActiveProbe (JARM, + HASSH, TCP/IP stack) plus the IPv6 leak special case. + + Per-probe port lists are taken from each probe's ``default_ports`` + attribute. Override at runtime via DECNET_PROBE_PORTS_ + (comma-separated), e.g. DECNET_PROBE_PORTS_JARM="443,8443". Args: log_file: base path for log files (RFC 5424 to .log, JSON to .json) interval: seconds between probe cycles timeout: per-probe TCP timeout - ports: JARM TLS ports (defaults to DEFAULT_PROBE_PORTS) - ssh_ports: HASSH SSH ports (defaults to DEFAULT_SSH_PORTS) - tcpfp_ports: TCP fingerprint ports (defaults to DEFAULT_TCPFP_PORTS) """ - jarm_ports = ports or DEFAULT_PROBE_PORTS - hassh_ports = ssh_ports or DEFAULT_SSH_PORTS - tcp_ports = tcpfp_ports or DEFAULT_TCPFP_PORTS - - all_ports_str = ( - f"jarm={','.join(str(p) for p in jarm_ports)} " - f"ssh={','.join(str(p) for p in hassh_ports)} " - f"tcpfp={','.join(str(p) for p in tcp_ports)}" - ) - log_path = Path(log_file) json_path = log_path.with_suffix(".json") log_path.parent.mkdir(parents=True, exist_ok=True) + probe_summary = " ".join( + f"{cls.probe_name}={','.join(str(p) for p in cls().ports)}" + for cls in ActiveProbeMeta.all() + ) + logger.info( "prober started interval=%ds %s log=%s", - interval, all_ports_str, log_path, + interval, probe_summary, log_path, ) _write_event( log_path, json_path, "prober_startup", interval=str(interval), - probe_ports=all_ports_str, - msg=f"DECNET-PROBER started, interval {interval}s, {all_ports_str}", + probe_ports=probe_summary, + msg=f"DECNET-PROBER started, interval {interval}s, {probe_summary}", ) known_attackers: set[str] = set() @@ -776,7 +604,6 @@ async def prober_worker( if known_attackers: await asyncio.to_thread( _probe_cycle, known_attackers, probed, - jarm_ports, hassh_ports, tcp_ports, log_path, json_path, timeout, _publish_attacker, record_rotation, diff --git a/tests/prober/test_active_probe_registry.py b/tests/prober/test_active_probe_registry.py new file mode 100644 index 00000000..c2d57741 --- /dev/null +++ b/tests/prober/test_active_probe_registry.py @@ -0,0 +1,80 @@ +"""Tests for ActiveProbeMeta registry and ActiveProbe ABC contract.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from decnet.prober.base import ActiveProbe, ActiveProbeMeta +import decnet.prober.probes as _probes # noqa: F401 — ensure probes are registered + + +@pytest.fixture(autouse=True) +def _restore_registry(): + """Snapshot and restore the registry around each test so throwaway probes don't leak.""" + snapshot = dict(ActiveProbeMeta._registry) + yield + ActiveProbeMeta._registry.clear() + ActiveProbeMeta._registry.update(snapshot) + + +class TestRegistryContents: + + def test_all_three_probes_registered(self): + names = {cls.probe_name for cls in ActiveProbeMeta.all()} + assert names == {"jarm", "hassh", "tcpfp"} + + def test_sorted_by_priority_then_name(self): + order = [cls.probe_name for cls in ActiveProbeMeta.all()] + assert order == ["hassh", "jarm", "tcpfp"] # all priority=100, alphabetical + + def test_priority10_probe_sorts_first(self): + class _FastProbe(ActiveProbe): + probe_name = "_fast_test_probe" + default_ports = [9999] + event_type = "_fast_event" + priority = 10 + + def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None: + return None + + def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]: + return {}, "" + + def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]: + return {} + + order = [cls.probe_name for cls in ActiveProbeMeta.all()] + assert order[0] == "_fast_test_probe" + assert set(order[1:]) == {"hassh", "jarm", "tcpfp"} + + def test_base_class_not_registered(self): + assert "ActiveProbe" not in ActiveProbeMeta._registry + assert None not in ActiveProbeMeta._registry.values() + + +class TestProbeABCContract: + + @pytest.mark.parametrize("probe_cls", list(ActiveProbeMeta.all())) + def test_instantiable(self, probe_cls: type[ActiveProbe]): + instance = probe_cls() + assert isinstance(instance, ActiveProbe) + + @pytest.mark.parametrize("probe_cls", list(ActiveProbeMeta.all())) + def test_has_required_class_attrs(self, probe_cls: type[ActiveProbe]): + assert isinstance(probe_cls.probe_name, str) and probe_cls.probe_name + assert isinstance(probe_cls.default_ports, list) and probe_cls.default_ports + assert isinstance(probe_cls.event_type, str) and probe_cls.event_type + assert isinstance(probe_cls.priority, int) + + @pytest.mark.parametrize("probe_cls", list(ActiveProbeMeta.all())) + def test_ports_property_reflects_default(self, probe_cls: type[ActiveProbe]): + instance = probe_cls() + assert instance.ports == probe_cls.default_ports + + @pytest.mark.parametrize("probe_cls", list(ActiveProbeMeta.all())) + def test_implements_abstract_methods(self, probe_cls: type[ActiveProbe]): + assert callable(getattr(probe_cls, "run")) + assert callable(getattr(probe_cls, "syslog_fields")) + assert callable(getattr(probe_cls, "publish_payload")) diff --git a/tests/prober/test_prober_worker.py b/tests/prober/test_prober_worker.py index 69b4a73b..480e4ccc 100644 --- a/tests/prober/test_prober_worker.py +++ b/tests/prober/test_prober_worker.py @@ -12,10 +12,10 @@ from unittest.mock import MagicMock, patch import pytest from decnet.prober.jarm import JARM_EMPTY_HASH +from decnet.prober.probes.hassh import HasshProbe +from decnet.prober.probes.jarm import JarmProbe +from decnet.prober.probes.tcpfp import TcpfpProbe from decnet.prober.worker import ( - DEFAULT_PROBE_PORTS, - DEFAULT_SSH_PORTS, - DEFAULT_TCPFP_PORTS, _discover_attackers, _probe_cycle, _write_event, @@ -109,13 +109,18 @@ class TestDiscoverAttackers: class TestProbeCycleJARM: + @patch("decnet.prober.worker._ipv6_leak_phase") @patch("decnet.prober.worker.fetch_leaf_cert", return_value=None) - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_probes_new_ips(self, mock_jarm: MagicMock, mock_hassh: MagicMock, mock_tcpfp: MagicMock, mock_cert: MagicMock, - tmp_path: Path): + mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(JarmProbe, "default_ports", [443, 8443]) + monkeypatch.setattr(HasshProbe, "default_ports", []) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.return_value = "c0c" * 10 + "a" * 32 # fake 62-char hash mock_hassh.return_value = None mock_tcpfp.return_value = None @@ -125,19 +130,24 @@ class TestProbeCycleJARM: targets = {"10.0.0.1"} probed: dict[str, dict[str, set[int]]] = {} - _probe_cycle(targets, probed, [443, 8443], [], [], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) assert mock_jarm.call_count == 2 # two ports assert 443 in probed["10.0.0.1"]["jarm"] assert 8443 in probed["10.0.0.1"]["jarm"] + @patch("decnet.prober.worker._ipv6_leak_phase") @patch("decnet.prober.worker.fetch_leaf_cert", return_value=None) - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_skips_already_probed_ports(self, mock_jarm: MagicMock, mock_hassh: MagicMock, mock_tcpfp: MagicMock, mock_cert: MagicMock, - tmp_path: Path): + mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(JarmProbe, "default_ports", [443, 8443]) + monkeypatch.setattr(HasshProbe, "default_ports", []) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.return_value = "c0c" * 10 + "a" * 32 mock_hassh.return_value = None mock_tcpfp.return_value = None @@ -147,17 +157,22 @@ class TestProbeCycleJARM: targets = {"10.0.0.1"} probed: dict[str, dict[str, set[int]]] = {"10.0.0.1": {"jarm": {443}}} - _probe_cycle(targets, probed, [443, 8443], [], [], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) # Should only probe 8443 (443 already done) assert mock_jarm.call_count == 1 mock_jarm.assert_called_once_with("10.0.0.1", 8443, timeout=1.0) - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.worker._ipv6_leak_phase") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_empty_hash_not_logged(self, mock_jarm: MagicMock, mock_hassh: MagicMock, - mock_tcpfp: MagicMock, tmp_path: Path): + mock_tcpfp: MagicMock, mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(JarmProbe, "default_ports", [443]) + monkeypatch.setattr(HasshProbe, "default_ports", []) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.return_value = JARM_EMPTY_HASH mock_hassh.return_value = None mock_tcpfp.return_value = None @@ -167,18 +182,23 @@ class TestProbeCycleJARM: targets = {"10.0.0.1"} probed: dict[str, dict[str, set[int]]] = {} - _probe_cycle(targets, probed, [443], [], [], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) assert 443 in probed["10.0.0.1"]["jarm"] if json_path.exists(): content = json_path.read_text() assert "jarm_fingerprint" not in content - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.worker._ipv6_leak_phase") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_exception_marks_port_probed(self, mock_jarm: MagicMock, mock_hassh: MagicMock, - mock_tcpfp: MagicMock, tmp_path: Path): + mock_tcpfp: MagicMock, mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(JarmProbe, "default_ports", [443]) + monkeypatch.setattr(HasshProbe, "default_ports", []) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.side_effect = OSError("Connection refused") mock_hassh.return_value = None mock_tcpfp.return_value = None @@ -188,15 +208,20 @@ class TestProbeCycleJARM: targets = {"10.0.0.1"} probed: dict[str, dict[str, set[int]]] = {} - _probe_cycle(targets, probed, [443], [], [], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) assert 443 in probed["10.0.0.1"]["jarm"] - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.worker._ipv6_leak_phase") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_skips_ip_with_all_ports_done(self, mock_jarm: MagicMock, mock_hassh: MagicMock, - mock_tcpfp: MagicMock, tmp_path: Path): + mock_tcpfp: MagicMock, mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(JarmProbe, "default_ports", [443, 8443]) + monkeypatch.setattr(HasshProbe, "default_ports", []) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) log_path = tmp_path / "decnet.log" json_path = tmp_path / "decnet.json" @@ -205,7 +230,7 @@ class TestProbeCycleJARM: "10.0.0.1": {"jarm": {443, 8443}, "hassh": set(), "tcpfp": set()}, } - _probe_cycle(targets, probed, [443, 8443], [], [], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) assert mock_jarm.call_count == 0 @@ -214,11 +239,16 @@ class TestProbeCycleJARM: class TestProbeCycleHASSH: - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.worker._ipv6_leak_phase") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_probes_ssh_ports(self, mock_jarm: MagicMock, mock_hassh: MagicMock, - mock_tcpfp: MagicMock, tmp_path: Path): + mock_tcpfp: MagicMock, mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(JarmProbe, "default_ports", []) + monkeypatch.setattr(HasshProbe, "default_ports", [22, 2222]) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.return_value = JARM_EMPTY_HASH mock_hassh.return_value = { "hassh_server": "a" * 32, @@ -235,17 +265,22 @@ class TestProbeCycleHASSH: targets = {"10.0.0.1"} probed: dict[str, dict[str, set[int]]] = {} - _probe_cycle(targets, probed, [], [22, 2222], [], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) assert mock_hassh.call_count == 2 assert 22 in probed["10.0.0.1"]["hassh"] assert 2222 in probed["10.0.0.1"]["hassh"] - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.worker._ipv6_leak_phase") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_hassh_writes_event(self, mock_jarm: MagicMock, mock_hassh: MagicMock, - mock_tcpfp: MagicMock, tmp_path: Path): + mock_tcpfp: MagicMock, mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(JarmProbe, "default_ports", []) + monkeypatch.setattr(HasshProbe, "default_ports", [22]) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.return_value = JARM_EMPTY_HASH mock_hassh.return_value = { "hassh_server": "b" * 32, @@ -262,7 +297,7 @@ class TestProbeCycleHASSH: targets = {"10.0.0.1"} probed: dict[str, dict[str, set[int]]] = {} - _probe_cycle(targets, probed, [], [22], [], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) assert json_path.exists() content = json_path.read_text() @@ -271,11 +306,16 @@ class TestProbeCycleHASSH: assert record["fields"]["hassh_server_hash"] == "b" * 32 assert record["fields"]["ssh_banner"] == "SSH-2.0-Paramiko_3.0" - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.worker._ipv6_leak_phase") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_hassh_none_not_logged(self, mock_jarm: MagicMock, mock_hassh: MagicMock, - mock_tcpfp: MagicMock, tmp_path: Path): + mock_tcpfp: MagicMock, mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(JarmProbe, "default_ports", []) + monkeypatch.setattr(HasshProbe, "default_ports", [22]) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.return_value = JARM_EMPTY_HASH mock_hassh.return_value = None # No SSH server mock_tcpfp.return_value = None @@ -285,18 +325,23 @@ class TestProbeCycleHASSH: targets = {"10.0.0.1"} probed: dict[str, dict[str, set[int]]] = {} - _probe_cycle(targets, probed, [], [22], [], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) assert 22 in probed["10.0.0.1"]["hassh"] if json_path.exists(): content = json_path.read_text() assert "hassh_fingerprint" not in content - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.worker._ipv6_leak_phase") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_hassh_skips_already_probed(self, mock_jarm: MagicMock, mock_hassh: MagicMock, - mock_tcpfp: MagicMock, tmp_path: Path): + mock_tcpfp: MagicMock, mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(JarmProbe, "default_ports", []) + monkeypatch.setattr(HasshProbe, "default_ports", [22, 2222]) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.return_value = JARM_EMPTY_HASH mock_tcpfp.return_value = None log_path = tmp_path / "decnet.log" @@ -305,16 +350,21 @@ class TestProbeCycleHASSH: targets = {"10.0.0.1"} probed: dict[str, dict[str, set[int]]] = {"10.0.0.1": {"hassh": {22}}} - _probe_cycle(targets, probed, [], [22, 2222], [], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) assert mock_hassh.call_count == 1 # only 2222 mock_hassh.assert_called_once_with("10.0.0.1", 2222, timeout=1.0) - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.worker._ipv6_leak_phase") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_hassh_exception_marks_probed(self, mock_jarm: MagicMock, mock_hassh: MagicMock, - mock_tcpfp: MagicMock, tmp_path: Path): + mock_tcpfp: MagicMock, mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(JarmProbe, "default_ports", []) + monkeypatch.setattr(HasshProbe, "default_ports", [22]) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.return_value = JARM_EMPTY_HASH mock_hassh.side_effect = OSError("Connection refused") mock_tcpfp.return_value = None @@ -324,7 +374,7 @@ class TestProbeCycleHASSH: targets = {"10.0.0.1"} probed: dict[str, dict[str, set[int]]] = {} - _probe_cycle(targets, probed, [], [22], [], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) assert 22 in probed["10.0.0.1"]["hassh"] @@ -333,11 +383,16 @@ class TestProbeCycleHASSH: class TestProbeCycleTCPFP: - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.worker._ipv6_leak_phase") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_probes_tcpfp_ports(self, mock_jarm: MagicMock, mock_hassh: MagicMock, - mock_tcpfp: MagicMock, tmp_path: Path): + mock_tcpfp: MagicMock, mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(JarmProbe, "default_ports", []) + monkeypatch.setattr(HasshProbe, "default_ports", []) + monkeypatch.setattr(TcpfpProbe, "default_ports", [80, 443]) mock_jarm.return_value = JARM_EMPTY_HASH mock_hassh.return_value = None mock_tcpfp.return_value = { @@ -354,17 +409,22 @@ class TestProbeCycleTCPFP: targets = {"10.0.0.1"} probed: dict[str, dict[str, set[int]]] = {} - _probe_cycle(targets, probed, [], [], [80, 443], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) assert mock_tcpfp.call_count == 2 assert 80 in probed["10.0.0.1"]["tcpfp"] assert 443 in probed["10.0.0.1"]["tcpfp"] - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.worker._ipv6_leak_phase") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_tcpfp_writes_event_with_all_fields(self, mock_jarm: MagicMock, mock_hassh: MagicMock, - mock_tcpfp: MagicMock, tmp_path: Path): + mock_tcpfp: MagicMock, mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(JarmProbe, "default_ports", []) + monkeypatch.setattr(HasshProbe, "default_ports", []) + monkeypatch.setattr(TcpfpProbe, "default_ports", [443]) mock_jarm.return_value = JARM_EMPTY_HASH mock_hassh.return_value = None mock_tcpfp.return_value = { @@ -381,7 +441,7 @@ class TestProbeCycleTCPFP: targets = {"10.0.0.1"} probed: dict[str, dict[str, set[int]]] = {} - _probe_cycle(targets, probed, [], [], [443], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) content = json_path.read_text() assert "tcpfp_fingerprint" in content @@ -391,11 +451,16 @@ class TestProbeCycleTCPFP: assert record["fields"]["window_size"] == "8192" assert record["fields"]["options_order"] == "M,N,W,N,N,S" - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.worker._ipv6_leak_phase") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_tcpfp_none_not_logged(self, mock_jarm: MagicMock, mock_hassh: MagicMock, - mock_tcpfp: MagicMock, tmp_path: Path): + mock_tcpfp: MagicMock, mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(JarmProbe, "default_ports", []) + monkeypatch.setattr(HasshProbe, "default_ports", []) + monkeypatch.setattr(TcpfpProbe, "default_ports", [443]) mock_jarm.return_value = JARM_EMPTY_HASH mock_hassh.return_value = None mock_tcpfp.return_value = None @@ -405,7 +470,7 @@ class TestProbeCycleTCPFP: targets = {"10.0.0.1"} probed: dict[str, dict[str, set[int]]] = {} - _probe_cycle(targets, probed, [], [], [443], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) assert 443 in probed["10.0.0.1"]["tcpfp"] if json_path.exists(): @@ -417,12 +482,17 @@ class TestProbeCycleTCPFP: class TestProbeTypeIsolation: - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.worker._ipv6_leak_phase") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_jarm_does_not_mark_hassh(self, mock_jarm: MagicMock, mock_hassh: MagicMock, - mock_tcpfp: MagicMock, tmp_path: Path): + mock_tcpfp: MagicMock, mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): """JARM probing port 2222 should not mark HASSH port 2222 as done.""" + monkeypatch.setattr(JarmProbe, "default_ports", [2222]) + monkeypatch.setattr(HasshProbe, "default_ports", [2222]) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.return_value = JARM_EMPTY_HASH mock_hassh.return_value = None mock_tcpfp.return_value = None @@ -432,8 +502,7 @@ class TestProbeTypeIsolation: targets = {"10.0.0.1"} probed: dict[str, dict[str, set[int]]] = {} - # Probe with JARM on 2222 and HASSH on 2222 - _probe_cycle(targets, probed, [2222], [2222], [], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) # Both should be called assert mock_jarm.call_count == 1 @@ -441,11 +510,16 @@ class TestProbeTypeIsolation: assert 2222 in probed["10.0.0.1"]["jarm"] assert 2222 in probed["10.0.0.1"]["hassh"] - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.worker._ipv6_leak_phase") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_all_three_probes_run(self, mock_jarm: MagicMock, mock_hassh: MagicMock, - mock_tcpfp: MagicMock, tmp_path: Path): + mock_tcpfp: MagicMock, mock_ipv6: MagicMock, + tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(JarmProbe, "default_ports", [443]) + monkeypatch.setattr(HasshProbe, "default_ports", [22]) + monkeypatch.setattr(TcpfpProbe, "default_ports", [80]) mock_jarm.return_value = JARM_EMPTY_HASH mock_hassh.return_value = None mock_tcpfp.return_value = None @@ -455,7 +529,7 @@ class TestProbeTypeIsolation: targets = {"10.0.0.1"} probed: dict[str, dict[str, set[int]]] = {} - _probe_cycle(targets, probed, [443], [22], [80], log_path, json_path, timeout=1.0) + _probe_cycle(targets, probed, log_path, json_path, timeout=1.0) assert mock_jarm.call_count == 1 assert mock_hassh.call_count == 1 @@ -490,20 +564,26 @@ class TestWriteEvent: class TestProbeCycleTLSCert: + @patch("decnet.prober.worker._ipv6_leak_phase") @patch("decnet.prober.worker.fetch_leaf_cert") - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_cert_event_emitted_after_successful_jarm( self, mock_jarm: MagicMock, mock_hassh: MagicMock, mock_tcpfp: MagicMock, mock_cert: MagicMock, + mock_ipv6: MagicMock, tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, ): """A non-empty JARM hash should trigger a follow-up cert fetch and write a tls_certificate event with all parsed fields.""" + monkeypatch.setattr(JarmProbe, "default_ports", [443]) + monkeypatch.setattr(HasshProbe, "default_ports", []) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.return_value = "c0c" * 10 + "a" * 32 mock_hassh.return_value = None mock_tcpfp.return_value = None @@ -519,7 +599,7 @@ class TestProbeCycleTLSCert: log_path = tmp_path / "decnet.log" json_path = tmp_path / "decnet.json" - _probe_cycle({"10.0.0.1"}, {}, [443], [], [], log_path, json_path, timeout=1.0) + _probe_cycle({"10.0.0.1"}, {}, log_path, json_path, timeout=1.0) mock_cert.assert_called_once_with("10.0.0.1", 443, timeout=1.0) records = [ @@ -539,69 +619,87 @@ class TestProbeCycleTLSCert: assert f["sans"] == "evil.example.com,c2.example.com" assert f["cert_sha256"] == "ab" * 32 + @patch("decnet.prober.worker._ipv6_leak_phase") @patch("decnet.prober.worker.fetch_leaf_cert") - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_cert_fetch_skipped_on_empty_jarm( self, mock_jarm: MagicMock, mock_hassh: MagicMock, mock_tcpfp: MagicMock, mock_cert: MagicMock, + mock_ipv6: MagicMock, tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, ): """JARM_EMPTY_HASH means the port doesn't speak TLS; skip cert fetch.""" + monkeypatch.setattr(JarmProbe, "default_ports", [443]) + monkeypatch.setattr(HasshProbe, "default_ports", []) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.return_value = JARM_EMPTY_HASH mock_hassh.return_value = None mock_tcpfp.return_value = None log_path = tmp_path / "decnet.log" json_path = tmp_path / "decnet.json" - _probe_cycle({"10.0.0.1"}, {}, [443], [], [], log_path, json_path, timeout=1.0) + _probe_cycle({"10.0.0.1"}, {}, log_path, json_path, timeout=1.0) mock_cert.assert_not_called() + @patch("decnet.prober.worker._ipv6_leak_phase") @patch("decnet.prober.worker.fetch_leaf_cert", return_value=None) - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_cert_fetch_failure_silent( self, mock_jarm: MagicMock, mock_hassh: MagicMock, mock_tcpfp: MagicMock, mock_cert: MagicMock, + mock_ipv6: MagicMock, tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, ): """fetch_leaf_cert returning None must not write a cert event.""" + monkeypatch.setattr(JarmProbe, "default_ports", [443]) + monkeypatch.setattr(HasshProbe, "default_ports", []) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.return_value = "c0c" * 10 + "a" * 32 mock_hassh.return_value = None mock_tcpfp.return_value = None log_path = tmp_path / "decnet.log" json_path = tmp_path / "decnet.json" - _probe_cycle({"10.0.0.1"}, {}, [443], [], [], log_path, json_path, timeout=1.0) + _probe_cycle({"10.0.0.1"}, {}, log_path, json_path, timeout=1.0) mock_cert.assert_called_once_with("10.0.0.1", 443, timeout=1.0) if json_path.exists(): content = json_path.read_text() assert "tls_certificate" not in content + @patch("decnet.prober.worker._ipv6_leak_phase") @patch("decnet.prober.worker.fetch_leaf_cert") - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_cert_fetch_crash_does_not_break_phase( self, mock_jarm: MagicMock, mock_hassh: MagicMock, mock_tcpfp: MagicMock, mock_cert: MagicMock, + mock_ipv6: MagicMock, tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, ): """If fetch_leaf_cert throws despite its contract, the JARM phase must keep moving to the next port without crashing.""" + monkeypatch.setattr(JarmProbe, "default_ports", [443, 8443]) + monkeypatch.setattr(HasshProbe, "default_ports", []) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.return_value = "c0c" * 10 + "a" * 32 mock_hassh.return_value = None mock_tcpfp.return_value = None @@ -609,25 +707,30 @@ class TestProbeCycleTLSCert: log_path = tmp_path / "decnet.log" json_path = tmp_path / "decnet.json" - _probe_cycle({"10.0.0.1"}, {}, [443, 8443], [], [], log_path, json_path, timeout=1.0) + _probe_cycle({"10.0.0.1"}, {}, log_path, json_path, timeout=1.0) # Both ports still marked probed despite the cert-side crash. - from decnet.prober.worker import _probe_cycle as _ # re-import safety assert mock_cert.call_count == 2 + @patch("decnet.prober.worker._ipv6_leak_phase") @patch("decnet.prober.worker.fetch_leaf_cert") - @patch("decnet.prober.worker.tcp_fingerprint") - @patch("decnet.prober.worker.hassh_server") - @patch("decnet.prober.worker.jarm_hash") + @patch("decnet.prober.probes.tcpfp.tcp_fingerprint") + @patch("decnet.prober.probes.hassh.hassh_server") + @patch("decnet.prober.probes.jarm.jarm_hash") def test_cert_publish_fn_called( self, mock_jarm: MagicMock, mock_hassh: MagicMock, mock_tcpfp: MagicMock, mock_cert: MagicMock, + mock_ipv6: MagicMock, tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, ): """publish_fn must receive a 'tls_certificate' event when capture succeeds.""" + monkeypatch.setattr(JarmProbe, "default_ports", [443]) + monkeypatch.setattr(HasshProbe, "default_ports", []) + monkeypatch.setattr(TcpfpProbe, "default_ports", []) mock_jarm.return_value = "c0c" * 10 + "a" * 32 mock_hassh.return_value = None mock_tcpfp.return_value = None @@ -646,7 +749,7 @@ class TestProbeCycleTLSCert: published.append((kind, payload)) _probe_cycle( - {"10.0.0.1"}, {}, [443], [], [], + {"10.0.0.1"}, {}, tmp_path / "decnet.log", tmp_path / "decnet.json", timeout=1.0, publish_fn=publish, ) diff --git a/tests/prober/test_run_probe_driver.py b/tests/prober/test_run_probe_driver.py new file mode 100644 index 00000000..dc40f45b --- /dev/null +++ b/tests/prober/test_run_probe_driver.py @@ -0,0 +1,244 @@ +"""Unit tests for the _run_probe generic driver.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from decnet.prober.base import ActiveProbe, ActiveProbeMeta +from decnet.prober.worker import _run_probe + + +@pytest.fixture(autouse=True) +def _restore_registry(): + snapshot = dict(ActiveProbeMeta._registry) + yield + ActiveProbeMeta._registry.clear() + ActiveProbeMeta._registry.update(snapshot) + + +def _make_probe( + probe_name: str = "test_probe", + default_ports: list[int] | None = None, + run_return: dict[str, Any] | None = None, + run_side_effect: Exception | None = None, + rotation_type: str | None = "test", + rotation_hash_key: str | None = "hash", +) -> ActiveProbe: + """Build a concrete ActiveProbe subclass for testing and return an instance.""" + + _pname = probe_name + _ports = default_ports or [1234] + _result = run_return + _exc = run_side_effect + _rtype = rotation_type + _rkey = rotation_hash_key + + class _TestProbe(ActiveProbe): + probe_name = _pname # type: ignore[assignment] + default_ports = _ports # type: ignore[assignment] + event_type = f"{_pname}_event" + rotation_type = _rtype # type: ignore[assignment] + rotation_hash_key = _rkey + priority = 100 + + def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None: + if _exc is not None: + raise _exc + return _result + + def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]: + return {"hash": result.get("hash", "")}, f"{_pname} {ip}:{port}" + + def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]: + return {"attacker_ip": ip, "port": port, "hash": result.get("hash", "")} + + return _TestProbe() + + +class TestRunProbeDedup: + + def test_skips_already_probed_port(self, tmp_path: Path): + probe = _make_probe(default_ports=[80, 443], run_return={"hash": "abc"}) + ip_probed: dict[str, set[int]] = {"test_probe": {80}} + + _run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json", + timeout=1.0, publish_fn=None, record_rotation=None) + + assert 80 in ip_probed["test_probe"] # was already there + assert 443 in ip_probed["test_probe"] # newly probed + + def test_initializes_done_set_if_missing(self, tmp_path: Path): + probe = _make_probe(default_ports=[22], run_return=None) + ip_probed: dict[str, set[int]] = {} + + _run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json", + timeout=1.0, publish_fn=None, record_rotation=None) + + assert "test_probe" in ip_probed + assert 22 in ip_probed["test_probe"] + + +class TestRunProbeSuccessPath: + + def test_writes_event_on_success(self, tmp_path: Path): + probe = _make_probe(default_ports=[443], run_return={"hash": "deadbeef"}) + ip_probed: dict[str, set[int]] = {} + json_path = tmp_path / "events.json" + + _run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "events.log", json_path, + timeout=1.0, publish_fn=None, record_rotation=None) + + assert json_path.exists() + record = json.loads(json_path.read_text().strip()) + assert record["event_type"] == "test_probe_event" + assert record["fields"]["target_ip"] == "1.2.3.4" + assert record["fields"]["target_port"] == "443" + assert record["fields"]["hash"] == "deadbeef" + + def test_calls_publish_fn_on_success(self, tmp_path: Path): + probe = _make_probe(default_ports=[443], run_return={"hash": "cafebabe"}) + published: list[tuple[str, dict]] = [] + ip_probed: dict[str, set[int]] = {} + + _run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json", + timeout=1.0, publish_fn=lambda k, v: published.append((k, v)), + record_rotation=None) + + assert len(published) == 1 + assert published[0][0] == "test_probe" + assert published[0][1]["attacker_ip"] == "1.2.3.4" + assert published[0][1]["hash"] == "cafebabe" + + def test_calls_record_rotation_when_configured(self, tmp_path: Path): + probe = _make_probe(default_ports=[443], run_return={"hash": "rotateme"}, + rotation_type="test", rotation_hash_key="hash") + mock_rotation = MagicMock() + ip_probed: dict[str, set[int]] = {} + + _run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json", + timeout=1.0, publish_fn=None, record_rotation=mock_rotation) + + mock_rotation.assert_called_once_with("1.2.3.4", 443, "test", "rotateme") + + def test_skips_rotation_when_rotation_type_none(self, tmp_path: Path): + probe = _make_probe(default_ports=[443], run_return={"hash": "x"}, + rotation_type=None, rotation_hash_key=None) + mock_rotation = MagicMock() + ip_probed: dict[str, set[int]] = {} + + _run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json", + timeout=1.0, publish_fn=None, record_rotation=mock_rotation) + + mock_rotation.assert_not_called() + + def test_skips_rotation_when_rotation_hash_key_none(self, tmp_path: Path): + probe = _make_probe(default_ports=[443], run_return={"hash": "x"}, + rotation_type="test", rotation_hash_key=None) + mock_rotation = MagicMock() + ip_probed: dict[str, set[int]] = {} + + _run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json", + timeout=1.0, publish_fn=None, record_rotation=mock_rotation) + + mock_rotation.assert_not_called() + + +class TestRunProbeNoneResult: + + def test_none_suppresses_event(self, tmp_path: Path): + probe = _make_probe(default_ports=[443], run_return=None) + ip_probed: dict[str, set[int]] = {} + json_path = tmp_path / "events.json" + + _run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "events.log", json_path, + timeout=1.0, publish_fn=None, record_rotation=None) + + assert 443 in ip_probed["test_probe"] + assert not json_path.exists() + + def test_none_suppresses_publish(self, tmp_path: Path): + probe = _make_probe(default_ports=[443], run_return=None) + published: list = [] + ip_probed: dict[str, set[int]] = {} + + _run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json", + timeout=1.0, publish_fn=lambda k, v: published.append((k, v)), + record_rotation=None) + + assert len(published) == 0 + + +class TestRunProbeExceptionPath: + + def test_exception_marks_port_done(self, tmp_path: Path): + probe = _make_probe(default_ports=[443], + run_side_effect=OSError("Connection refused")) + ip_probed: dict[str, set[int]] = {} + + _run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json", + timeout=1.0, publish_fn=None, record_rotation=None) + + assert 443 in ip_probed["test_probe"] + + def test_exception_writes_prober_error_event(self, tmp_path: Path): + probe = _make_probe(default_ports=[443], + run_side_effect=OSError("refused")) + ip_probed: dict[str, set[int]] = {} + json_path = tmp_path / "events.json" + + _run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "events.log", json_path, + timeout=1.0, publish_fn=None, record_rotation=None) + + assert json_path.exists() + record = json.loads(json_path.read_text().strip()) + assert record["event_type"] == "prober_error" + assert record["fields"]["target_ip"] == "1.2.3.4" + assert "refused" in record["fields"]["error"] + + def test_exception_does_not_publish(self, tmp_path: Path): + probe = _make_probe(default_ports=[443], + run_side_effect=RuntimeError("boom")) + published: list = [] + ip_probed: dict[str, set[int]] = {} + + _run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json", + timeout=1.0, publish_fn=lambda k, v: published.append((k, v)), + record_rotation=None) + + assert len(published) == 0 + + def test_continues_remaining_ports_after_exception(self, tmp_path: Path): + call_count = 0 + + class _CountProbe(ActiveProbe): + probe_name = "_count_probe" + default_ports = [80, 443, 8080] + event_type = "_count_event" + priority = 100 + + def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None: + nonlocal call_count + call_count += 1 + if port == 443: + raise OSError("refused") + return None + + def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]: + return {}, "" + + def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]: + return {} + + probe = _CountProbe() + ip_probed: dict[str, set[int]] = {} + + _run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json", + timeout=1.0, publish_fn=None, record_rotation=None) + + assert call_count == 3 # all three ports attempted + assert {80, 443, 8080} == ip_probed["_count_probe"]