refactor(prober): ActiveProbe ABC + ActiveProbeMeta registry

Replace _jarm_phase / _hassh_phase / _tcpfp_phase boilerplate (3×~50
lines of identical port-iteration logic) with a metaclass-registered ABC.
Adding a new port-iterating active probe is now one class + three methods.

- decnet/prober/base.py: ActiveProbeMeta auto-registers subclasses by
  probe_name; ActiveProbe ABC enforces run/syslog_fields/publish_payload
  with env-driven DECNET_PROBE_PORTS_<NAME> port override.
- decnet/prober/probes/{jarm,hassh,tcpfp}.py: concrete probe classes.
- decnet/prober/worker.py: single _run_probe driver replaces the three
  phase functions; _probe_cycle iterates ActiveProbeMeta.all(); drops
  the ports=/ssh_ports=/tcpfp_ports= kwargs from prober_worker.
- IPv6 leak and TLS cert capture stay as special cases (different call
  shapes; intentionally outside the registry).
- tests/prober/test_active_probe_registry.py: registry contents, sort
  order, priority-10 override, ABC contract per probe class.
- tests/prober/test_run_probe_driver.py: dedup, success, None-skip,
  exception, rotation, publish paths for _run_probe.
- tests/prober/test_prober_worker.py: updated patch targets and
  _probe_cycle call sites; port control via monkeypatch.setattr.
This commit is contained in:
2026-05-17 23:16:35 -04:00
parent 3977f06374
commit 916b21b652
9 changed files with 810 additions and 339 deletions

View File

@@ -0,0 +1,80 @@
"""Tests for ActiveProbeMeta registry and ActiveProbe ABC contract."""
from __future__ import annotations
from typing import Any
import pytest
from decnet.prober.base import ActiveProbe, ActiveProbeMeta
import decnet.prober.probes as _probes # noqa: F401 — ensure probes are registered
@pytest.fixture(autouse=True)
def _restore_registry():
"""Snapshot and restore the registry around each test so throwaway probes don't leak."""
snapshot = dict(ActiveProbeMeta._registry)
yield
ActiveProbeMeta._registry.clear()
ActiveProbeMeta._registry.update(snapshot)
class TestRegistryContents:
def test_all_three_probes_registered(self):
names = {cls.probe_name for cls in ActiveProbeMeta.all()}
assert names == {"jarm", "hassh", "tcpfp"}
def test_sorted_by_priority_then_name(self):
order = [cls.probe_name for cls in ActiveProbeMeta.all()]
assert order == ["hassh", "jarm", "tcpfp"] # all priority=100, alphabetical
def test_priority10_probe_sorts_first(self):
class _FastProbe(ActiveProbe):
probe_name = "_fast_test_probe"
default_ports = [9999]
event_type = "_fast_event"
priority = 10
def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None:
return None
def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]:
return {}, ""
def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]:
return {}
order = [cls.probe_name for cls in ActiveProbeMeta.all()]
assert order[0] == "_fast_test_probe"
assert set(order[1:]) == {"hassh", "jarm", "tcpfp"}
def test_base_class_not_registered(self):
assert "ActiveProbe" not in ActiveProbeMeta._registry
assert None not in ActiveProbeMeta._registry.values()
class TestProbeABCContract:
@pytest.mark.parametrize("probe_cls", list(ActiveProbeMeta.all()))
def test_instantiable(self, probe_cls: type[ActiveProbe]):
instance = probe_cls()
assert isinstance(instance, ActiveProbe)
@pytest.mark.parametrize("probe_cls", list(ActiveProbeMeta.all()))
def test_has_required_class_attrs(self, probe_cls: type[ActiveProbe]):
assert isinstance(probe_cls.probe_name, str) and probe_cls.probe_name
assert isinstance(probe_cls.default_ports, list) and probe_cls.default_ports
assert isinstance(probe_cls.event_type, str) and probe_cls.event_type
assert isinstance(probe_cls.priority, int)
@pytest.mark.parametrize("probe_cls", list(ActiveProbeMeta.all()))
def test_ports_property_reflects_default(self, probe_cls: type[ActiveProbe]):
instance = probe_cls()
assert instance.ports == probe_cls.default_ports
@pytest.mark.parametrize("probe_cls", list(ActiveProbeMeta.all()))
def test_implements_abstract_methods(self, probe_cls: type[ActiveProbe]):
assert callable(getattr(probe_cls, "run"))
assert callable(getattr(probe_cls, "syslog_fields"))
assert callable(getattr(probe_cls, "publish_payload"))

View File

@@ -12,10 +12,10 @@ from unittest.mock import MagicMock, patch
import pytest
from decnet.prober.jarm import JARM_EMPTY_HASH
from decnet.prober.probes.hassh import HasshProbe
from decnet.prober.probes.jarm import JarmProbe
from decnet.prober.probes.tcpfp import TcpfpProbe
from decnet.prober.worker import (
DEFAULT_PROBE_PORTS,
DEFAULT_SSH_PORTS,
DEFAULT_TCPFP_PORTS,
_discover_attackers,
_probe_cycle,
_write_event,
@@ -109,13 +109,18 @@ class TestDiscoverAttackers:
class TestProbeCycleJARM:
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.fetch_leaf_cert", return_value=None)
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_probes_new_ips(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, mock_cert: MagicMock,
tmp_path: Path):
mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [443, 8443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = "c0c" * 10 + "a" * 32 # fake 62-char hash
mock_hassh.return_value = None
mock_tcpfp.return_value = None
@@ -125,19 +130,24 @@ class TestProbeCycleJARM:
targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [443, 8443], [], [], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert mock_jarm.call_count == 2 # two ports
assert 443 in probed["10.0.0.1"]["jarm"]
assert 8443 in probed["10.0.0.1"]["jarm"]
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.fetch_leaf_cert", return_value=None)
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_skips_already_probed_ports(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, mock_cert: MagicMock,
tmp_path: Path):
mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [443, 8443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = "c0c" * 10 + "a" * 32
mock_hassh.return_value = None
mock_tcpfp.return_value = None
@@ -147,17 +157,22 @@ class TestProbeCycleJARM:
targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {"10.0.0.1": {"jarm": {443}}}
_probe_cycle(targets, probed, [443, 8443], [], [], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
# Should only probe 8443 (443 already done)
assert mock_jarm.call_count == 1
mock_jarm.assert_called_once_with("10.0.0.1", 8443, timeout=1.0)
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_empty_hash_not_logged(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path):
mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None
mock_tcpfp.return_value = None
@@ -167,18 +182,23 @@ class TestProbeCycleJARM:
targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [443], [], [], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert 443 in probed["10.0.0.1"]["jarm"]
if json_path.exists():
content = json_path.read_text()
assert "jarm_fingerprint" not in content
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_exception_marks_port_probed(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path):
mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.side_effect = OSError("Connection refused")
mock_hassh.return_value = None
mock_tcpfp.return_value = None
@@ -188,15 +208,20 @@ class TestProbeCycleJARM:
targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [443], [], [], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert 443 in probed["10.0.0.1"]["jarm"]
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_skips_ip_with_all_ports_done(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path):
mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [443, 8443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
log_path = tmp_path / "decnet.log"
json_path = tmp_path / "decnet.json"
@@ -205,7 +230,7 @@ class TestProbeCycleJARM:
"10.0.0.1": {"jarm": {443, 8443}, "hassh": set(), "tcpfp": set()},
}
_probe_cycle(targets, probed, [443, 8443], [], [], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert mock_jarm.call_count == 0
@@ -214,11 +239,16 @@ class TestProbeCycleJARM:
class TestProbeCycleHASSH:
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_probes_ssh_ports(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path):
mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [22, 2222])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = {
"hassh_server": "a" * 32,
@@ -235,17 +265,22 @@ class TestProbeCycleHASSH:
targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [], [22, 2222], [], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert mock_hassh.call_count == 2
assert 22 in probed["10.0.0.1"]["hassh"]
assert 2222 in probed["10.0.0.1"]["hassh"]
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_hassh_writes_event(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path):
mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [22])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = {
"hassh_server": "b" * 32,
@@ -262,7 +297,7 @@ class TestProbeCycleHASSH:
targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [], [22], [], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert json_path.exists()
content = json_path.read_text()
@@ -271,11 +306,16 @@ class TestProbeCycleHASSH:
assert record["fields"]["hassh_server_hash"] == "b" * 32
assert record["fields"]["ssh_banner"] == "SSH-2.0-Paramiko_3.0"
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_hassh_none_not_logged(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path):
mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [22])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None # No SSH server
mock_tcpfp.return_value = None
@@ -285,18 +325,23 @@ class TestProbeCycleHASSH:
targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [], [22], [], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert 22 in probed["10.0.0.1"]["hassh"]
if json_path.exists():
content = json_path.read_text()
assert "hassh_fingerprint" not in content
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_hassh_skips_already_probed(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path):
mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [22, 2222])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH
mock_tcpfp.return_value = None
log_path = tmp_path / "decnet.log"
@@ -305,16 +350,21 @@ class TestProbeCycleHASSH:
targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {"10.0.0.1": {"hassh": {22}}}
_probe_cycle(targets, probed, [], [22, 2222], [], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert mock_hassh.call_count == 1 # only 2222
mock_hassh.assert_called_once_with("10.0.0.1", 2222, timeout=1.0)
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_hassh_exception_marks_probed(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path):
mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [22])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.side_effect = OSError("Connection refused")
mock_tcpfp.return_value = None
@@ -324,7 +374,7 @@ class TestProbeCycleHASSH:
targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [], [22], [], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert 22 in probed["10.0.0.1"]["hassh"]
@@ -333,11 +383,16 @@ class TestProbeCycleHASSH:
class TestProbeCycleTCPFP:
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_probes_tcpfp_ports(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path):
mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [80, 443])
mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None
mock_tcpfp.return_value = {
@@ -354,17 +409,22 @@ class TestProbeCycleTCPFP:
targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [], [], [80, 443], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert mock_tcpfp.call_count == 2
assert 80 in probed["10.0.0.1"]["tcpfp"]
assert 443 in probed["10.0.0.1"]["tcpfp"]
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_tcpfp_writes_event_with_all_fields(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path):
mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [443])
mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None
mock_tcpfp.return_value = {
@@ -381,7 +441,7 @@ class TestProbeCycleTCPFP:
targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [], [], [443], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
content = json_path.read_text()
assert "tcpfp_fingerprint" in content
@@ -391,11 +451,16 @@ class TestProbeCycleTCPFP:
assert record["fields"]["window_size"] == "8192"
assert record["fields"]["options_order"] == "M,N,W,N,N,S"
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_tcpfp_none_not_logged(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path):
mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [443])
mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None
mock_tcpfp.return_value = None
@@ -405,7 +470,7 @@ class TestProbeCycleTCPFP:
targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [], [], [443], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert 443 in probed["10.0.0.1"]["tcpfp"]
if json_path.exists():
@@ -417,12 +482,17 @@ class TestProbeCycleTCPFP:
class TestProbeTypeIsolation:
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_jarm_does_not_mark_hassh(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path):
mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
"""JARM probing port 2222 should not mark HASSH port 2222 as done."""
monkeypatch.setattr(JarmProbe, "default_ports", [2222])
monkeypatch.setattr(HasshProbe, "default_ports", [2222])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None
mock_tcpfp.return_value = None
@@ -432,8 +502,7 @@ class TestProbeTypeIsolation:
targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {}
# Probe with JARM on 2222 and HASSH on 2222
_probe_cycle(targets, probed, [2222], [2222], [], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
# Both should be called
assert mock_jarm.call_count == 1
@@ -441,11 +510,16 @@ class TestProbeTypeIsolation:
assert 2222 in probed["10.0.0.1"]["jarm"]
assert 2222 in probed["10.0.0.1"]["hassh"]
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_all_three_probes_run(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path):
mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [443])
monkeypatch.setattr(HasshProbe, "default_ports", [22])
monkeypatch.setattr(TcpfpProbe, "default_ports", [80])
mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None
mock_tcpfp.return_value = None
@@ -455,7 +529,7 @@ class TestProbeTypeIsolation:
targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [443], [22], [80], log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert mock_jarm.call_count == 1
assert mock_hassh.call_count == 1
@@ -490,20 +564,26 @@ class TestWriteEvent:
class TestProbeCycleTLSCert:
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.fetch_leaf_cert")
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_cert_event_emitted_after_successful_jarm(
self,
mock_jarm: MagicMock,
mock_hassh: MagicMock,
mock_tcpfp: MagicMock,
mock_cert: MagicMock,
mock_ipv6: MagicMock,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
):
"""A non-empty JARM hash should trigger a follow-up cert fetch and
write a tls_certificate event with all parsed fields."""
monkeypatch.setattr(JarmProbe, "default_ports", [443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = "c0c" * 10 + "a" * 32
mock_hassh.return_value = None
mock_tcpfp.return_value = None
@@ -519,7 +599,7 @@ class TestProbeCycleTLSCert:
log_path = tmp_path / "decnet.log"
json_path = tmp_path / "decnet.json"
_probe_cycle({"10.0.0.1"}, {}, [443], [], [], log_path, json_path, timeout=1.0)
_probe_cycle({"10.0.0.1"}, {}, log_path, json_path, timeout=1.0)
mock_cert.assert_called_once_with("10.0.0.1", 443, timeout=1.0)
records = [
@@ -539,69 +619,87 @@ class TestProbeCycleTLSCert:
assert f["sans"] == "evil.example.com,c2.example.com"
assert f["cert_sha256"] == "ab" * 32
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.fetch_leaf_cert")
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_cert_fetch_skipped_on_empty_jarm(
self,
mock_jarm: MagicMock,
mock_hassh: MagicMock,
mock_tcpfp: MagicMock,
mock_cert: MagicMock,
mock_ipv6: MagicMock,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
):
"""JARM_EMPTY_HASH means the port doesn't speak TLS; skip cert fetch."""
monkeypatch.setattr(JarmProbe, "default_ports", [443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None
mock_tcpfp.return_value = None
log_path = tmp_path / "decnet.log"
json_path = tmp_path / "decnet.json"
_probe_cycle({"10.0.0.1"}, {}, [443], [], [], log_path, json_path, timeout=1.0)
_probe_cycle({"10.0.0.1"}, {}, log_path, json_path, timeout=1.0)
mock_cert.assert_not_called()
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.fetch_leaf_cert", return_value=None)
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_cert_fetch_failure_silent(
self,
mock_jarm: MagicMock,
mock_hassh: MagicMock,
mock_tcpfp: MagicMock,
mock_cert: MagicMock,
mock_ipv6: MagicMock,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
):
"""fetch_leaf_cert returning None must not write a cert event."""
monkeypatch.setattr(JarmProbe, "default_ports", [443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = "c0c" * 10 + "a" * 32
mock_hassh.return_value = None
mock_tcpfp.return_value = None
log_path = tmp_path / "decnet.log"
json_path = tmp_path / "decnet.json"
_probe_cycle({"10.0.0.1"}, {}, [443], [], [], log_path, json_path, timeout=1.0)
_probe_cycle({"10.0.0.1"}, {}, log_path, json_path, timeout=1.0)
mock_cert.assert_called_once_with("10.0.0.1", 443, timeout=1.0)
if json_path.exists():
content = json_path.read_text()
assert "tls_certificate" not in content
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.fetch_leaf_cert")
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_cert_fetch_crash_does_not_break_phase(
self,
mock_jarm: MagicMock,
mock_hassh: MagicMock,
mock_tcpfp: MagicMock,
mock_cert: MagicMock,
mock_ipv6: MagicMock,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
):
"""If fetch_leaf_cert throws despite its contract, the JARM phase
must keep moving to the next port without crashing."""
monkeypatch.setattr(JarmProbe, "default_ports", [443, 8443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = "c0c" * 10 + "a" * 32
mock_hassh.return_value = None
mock_tcpfp.return_value = None
@@ -609,25 +707,30 @@ class TestProbeCycleTLSCert:
log_path = tmp_path / "decnet.log"
json_path = tmp_path / "decnet.json"
_probe_cycle({"10.0.0.1"}, {}, [443, 8443], [], [], log_path, json_path, timeout=1.0)
_probe_cycle({"10.0.0.1"}, {}, log_path, json_path, timeout=1.0)
# Both ports still marked probed despite the cert-side crash.
from decnet.prober.worker import _probe_cycle as _ # re-import safety
assert mock_cert.call_count == 2
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.fetch_leaf_cert")
@patch("decnet.prober.worker.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server")
@patch("decnet.prober.worker.jarm_hash")
@patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_cert_publish_fn_called(
self,
mock_jarm: MagicMock,
mock_hassh: MagicMock,
mock_tcpfp: MagicMock,
mock_cert: MagicMock,
mock_ipv6: MagicMock,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
):
"""publish_fn must receive a 'tls_certificate' event when capture succeeds."""
monkeypatch.setattr(JarmProbe, "default_ports", [443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = "c0c" * 10 + "a" * 32
mock_hassh.return_value = None
mock_tcpfp.return_value = None
@@ -646,7 +749,7 @@ class TestProbeCycleTLSCert:
published.append((kind, payload))
_probe_cycle(
{"10.0.0.1"}, {}, [443], [], [],
{"10.0.0.1"}, {},
tmp_path / "decnet.log", tmp_path / "decnet.json",
timeout=1.0, publish_fn=publish,
)

View File

@@ -0,0 +1,244 @@
"""Unit tests for the _run_probe generic driver."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock
import pytest
from decnet.prober.base import ActiveProbe, ActiveProbeMeta
from decnet.prober.worker import _run_probe
@pytest.fixture(autouse=True)
def _restore_registry():
snapshot = dict(ActiveProbeMeta._registry)
yield
ActiveProbeMeta._registry.clear()
ActiveProbeMeta._registry.update(snapshot)
def _make_probe(
probe_name: str = "test_probe",
default_ports: list[int] | None = None,
run_return: dict[str, Any] | None = None,
run_side_effect: Exception | None = None,
rotation_type: str | None = "test",
rotation_hash_key: str | None = "hash",
) -> ActiveProbe:
"""Build a concrete ActiveProbe subclass for testing and return an instance."""
_pname = probe_name
_ports = default_ports or [1234]
_result = run_return
_exc = run_side_effect
_rtype = rotation_type
_rkey = rotation_hash_key
class _TestProbe(ActiveProbe):
probe_name = _pname # type: ignore[assignment]
default_ports = _ports # type: ignore[assignment]
event_type = f"{_pname}_event"
rotation_type = _rtype # type: ignore[assignment]
rotation_hash_key = _rkey
priority = 100
def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None:
if _exc is not None:
raise _exc
return _result
def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]:
return {"hash": result.get("hash", "")}, f"{_pname} {ip}:{port}"
def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]:
return {"attacker_ip": ip, "port": port, "hash": result.get("hash", "")}
return _TestProbe()
class TestRunProbeDedup:
def test_skips_already_probed_port(self, tmp_path: Path):
probe = _make_probe(default_ports=[80, 443], run_return={"hash": "abc"})
ip_probed: dict[str, set[int]] = {"test_probe": {80}}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=None, record_rotation=None)
assert 80 in ip_probed["test_probe"] # was already there
assert 443 in ip_probed["test_probe"] # newly probed
def test_initializes_done_set_if_missing(self, tmp_path: Path):
probe = _make_probe(default_ports=[22], run_return=None)
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=None, record_rotation=None)
assert "test_probe" in ip_probed
assert 22 in ip_probed["test_probe"]
class TestRunProbeSuccessPath:
def test_writes_event_on_success(self, tmp_path: Path):
probe = _make_probe(default_ports=[443], run_return={"hash": "deadbeef"})
ip_probed: dict[str, set[int]] = {}
json_path = tmp_path / "events.json"
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "events.log", json_path,
timeout=1.0, publish_fn=None, record_rotation=None)
assert json_path.exists()
record = json.loads(json_path.read_text().strip())
assert record["event_type"] == "test_probe_event"
assert record["fields"]["target_ip"] == "1.2.3.4"
assert record["fields"]["target_port"] == "443"
assert record["fields"]["hash"] == "deadbeef"
def test_calls_publish_fn_on_success(self, tmp_path: Path):
probe = _make_probe(default_ports=[443], run_return={"hash": "cafebabe"})
published: list[tuple[str, dict]] = []
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=lambda k, v: published.append((k, v)),
record_rotation=None)
assert len(published) == 1
assert published[0][0] == "test_probe"
assert published[0][1]["attacker_ip"] == "1.2.3.4"
assert published[0][1]["hash"] == "cafebabe"
def test_calls_record_rotation_when_configured(self, tmp_path: Path):
probe = _make_probe(default_ports=[443], run_return={"hash": "rotateme"},
rotation_type="test", rotation_hash_key="hash")
mock_rotation = MagicMock()
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=None, record_rotation=mock_rotation)
mock_rotation.assert_called_once_with("1.2.3.4", 443, "test", "rotateme")
def test_skips_rotation_when_rotation_type_none(self, tmp_path: Path):
probe = _make_probe(default_ports=[443], run_return={"hash": "x"},
rotation_type=None, rotation_hash_key=None)
mock_rotation = MagicMock()
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=None, record_rotation=mock_rotation)
mock_rotation.assert_not_called()
def test_skips_rotation_when_rotation_hash_key_none(self, tmp_path: Path):
probe = _make_probe(default_ports=[443], run_return={"hash": "x"},
rotation_type="test", rotation_hash_key=None)
mock_rotation = MagicMock()
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=None, record_rotation=mock_rotation)
mock_rotation.assert_not_called()
class TestRunProbeNoneResult:
def test_none_suppresses_event(self, tmp_path: Path):
probe = _make_probe(default_ports=[443], run_return=None)
ip_probed: dict[str, set[int]] = {}
json_path = tmp_path / "events.json"
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "events.log", json_path,
timeout=1.0, publish_fn=None, record_rotation=None)
assert 443 in ip_probed["test_probe"]
assert not json_path.exists()
def test_none_suppresses_publish(self, tmp_path: Path):
probe = _make_probe(default_ports=[443], run_return=None)
published: list = []
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=lambda k, v: published.append((k, v)),
record_rotation=None)
assert len(published) == 0
class TestRunProbeExceptionPath:
def test_exception_marks_port_done(self, tmp_path: Path):
probe = _make_probe(default_ports=[443],
run_side_effect=OSError("Connection refused"))
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=None, record_rotation=None)
assert 443 in ip_probed["test_probe"]
def test_exception_writes_prober_error_event(self, tmp_path: Path):
probe = _make_probe(default_ports=[443],
run_side_effect=OSError("refused"))
ip_probed: dict[str, set[int]] = {}
json_path = tmp_path / "events.json"
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "events.log", json_path,
timeout=1.0, publish_fn=None, record_rotation=None)
assert json_path.exists()
record = json.loads(json_path.read_text().strip())
assert record["event_type"] == "prober_error"
assert record["fields"]["target_ip"] == "1.2.3.4"
assert "refused" in record["fields"]["error"]
def test_exception_does_not_publish(self, tmp_path: Path):
probe = _make_probe(default_ports=[443],
run_side_effect=RuntimeError("boom"))
published: list = []
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=lambda k, v: published.append((k, v)),
record_rotation=None)
assert len(published) == 0
def test_continues_remaining_ports_after_exception(self, tmp_path: Path):
call_count = 0
class _CountProbe(ActiveProbe):
probe_name = "_count_probe"
default_ports = [80, 443, 8080]
event_type = "_count_event"
priority = 100
def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None:
nonlocal call_count
call_count += 1
if port == 443:
raise OSError("refused")
return None
def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]:
return {}, ""
def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]:
return {}
probe = _CountProbe()
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=None, record_rotation=None)
assert call_count == 3 # all three ports attempted
assert {80, 443, 8080} == ip_probed["_count_probe"]