refactor(prober): ActiveProbe ABC + ActiveProbeMeta registry

Replace _jarm_phase / _hassh_phase / _tcpfp_phase boilerplate (3×~50
lines of identical port-iteration logic) with a metaclass-registered ABC.
Adding a new port-iterating active probe is now one class + three methods.

- decnet/prober/base.py: ActiveProbeMeta auto-registers subclasses by
  probe_name; ActiveProbe ABC enforces run/syslog_fields/publish_payload
  with env-driven DECNET_PROBE_PORTS_<NAME> port override.
- decnet/prober/probes/{jarm,hassh,tcpfp}.py: concrete probe classes.
- decnet/prober/worker.py: single _run_probe driver replaces the three
  phase functions; _probe_cycle iterates ActiveProbeMeta.all(); drops
  the ports=/ssh_ports=/tcpfp_ports= kwargs from prober_worker.
- IPv6 leak and TLS cert capture stay as special cases (different call
  shapes; intentionally outside the registry).
- tests/prober/test_active_probe_registry.py: registry contents, sort
  order, priority-10 override, ABC contract per probe class.
- tests/prober/test_run_probe_driver.py: dedup, success, None-skip,
  exception, rotation, publish paths for _run_probe.
- tests/prober/test_prober_worker.py: updated patch targets and
  _probe_cycle call sites; port control via monkeypatch.setattr.
This commit is contained in:
2026-05-17 23:16:35 -04:00
parent 3977f06374
commit 916b21b652
9 changed files with 810 additions and 339 deletions

90
decnet/prober/base.py Normal file
View File

@@ -0,0 +1,90 @@
"""
ActiveProbe ABC and metaclass registry for port-iterating active probes.
Adding a new active probe = one class with three methods.
IPv6 leak and TLS cert capture are NOT part of this registry (different
call shapes); they stay as special cases in prober/worker.py.
"""
from __future__ import annotations
import os
from abc import ABCMeta, abstractmethod
from typing import Any
from decnet.correlation.fingerprint_rotation import ProbeType
class ActiveProbeMeta(ABCMeta):
"""Metaclass that auto-registers every ActiveProbe subclass by probe_name."""
_registry: dict[str, type[ActiveProbe]] = {}
def __new__(
mcs,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
) -> ActiveProbeMeta:
cls = super().__new__(mcs, name, bases, namespace)
if bases and getattr(cls, "probe_name", None):
mcs._registry[cls.probe_name] = cls # type: ignore[attr-defined,assignment]
return cls
@classmethod
def all(mcs) -> list[type[ActiveProbe]]:
"""Return registered probes sorted by (priority asc, probe_name asc)."""
return sorted(mcs._registry.values(), key=lambda c: (c.priority, c.probe_name))
class ActiveProbe(metaclass=ActiveProbeMeta):
"""Base class for all port-iterating active probes.
Subclasses declare class-level attributes and implement three methods.
Registration is automatic via ActiveProbeMeta.
Port override: set DECNET_PROBE_PORTS_<NAME_UPPER> (comma-separated) to
override default_ports at runtime without touching the class.
"""
probe_name: str
default_ports: list[int]
event_type: str
rotation_type: ProbeType | None = None
rotation_hash_key: str | None = None
priority: int = 100
def __init__(self) -> None:
env_key = f"DECNET_PROBE_PORTS_{self.probe_name.upper()}"
raw = os.environ.get(env_key, "").strip()
if raw:
try:
self._ports: list[int] = [int(p.strip()) for p in raw.split(",") if p.strip()]
except ValueError:
self._ports = list(self.default_ports)
else:
self._ports = list(self.default_ports)
@property
def ports(self) -> list[int]:
return self._ports
@abstractmethod
def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None:
"""Execute the probe against ip:port.
Return a result dict on success, or None to suppress emission (e.g.
empty JARM hash means the port doesn't speak TLS).
"""
@abstractmethod
def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]:
"""Return (sd_fields, human_msg) for _write_event.
target_ip and target_port are injected by _run_probe; do not include
them in sd_fields.
"""
@abstractmethod
def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]:
"""Return the bus payload dict for attacker.fingerprinted events."""

View File

@@ -0,0 +1,4 @@
# Import all probe modules to trigger ActiveProbeMeta registration.
from decnet.prober.probes.hassh import HasshProbe as HasshProbe
from decnet.prober.probes.jarm import JarmProbe as JarmProbe
from decnet.prober.probes.tcpfp import TcpfpProbe as TcpfpProbe

View File

@@ -0,0 +1,41 @@
from __future__ import annotations
from typing import Any
from decnet.prober.base import ActiveProbe
from decnet.prober.hassh import hassh_server
from decnet.telemetry import traced as _traced
DEFAULT_PORTS: list[int] = [22, 2222, 22222, 2022]
class HasshProbe(ActiveProbe):
probe_name = "hassh"
default_ports = DEFAULT_PORTS
event_type = "hassh_fingerprint"
rotation_type = "hassh"
rotation_hash_key = "hassh_server"
priority = 100
@_traced("prober.hassh_probe")
def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None:
return hassh_server(ip, port, timeout=timeout)
def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]:
fields = {
"hassh_server_hash": result["hassh_server"],
"ssh_banner": result["banner"],
"kex_algorithms": result["kex_algorithms"],
"encryption_s2c": result["encryption_s2c"],
"mac_s2c": result["mac_s2c"],
"compression_s2c": result["compression_s2c"],
}
return fields, f"HASSH {ip}:{port} = {result['hassh_server']}"
def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]:
return {
"attacker_ip": ip,
"port": port,
"hassh_server": result["hassh_server"],
"ssh_banner": result["banner"],
}

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from typing import Any
from decnet.prober.base import ActiveProbe
from decnet.prober.jarm import JARM_EMPTY_HASH, jarm_hash
from decnet.telemetry import traced as _traced
DEFAULT_PORTS: list[int] = [443, 8443, 8080, 4443, 50050, 2222, 993, 995, 8888, 9001]
class JarmProbe(ActiveProbe):
probe_name = "jarm"
default_ports = DEFAULT_PORTS
event_type = "jarm_fingerprint"
rotation_type = "jarm"
rotation_hash_key = "jarm_hash"
priority = 100
@_traced("prober.jarm_probe")
def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None:
h = jarm_hash(ip, port, timeout=timeout)
if h == JARM_EMPTY_HASH:
return None
return {"jarm_hash": h}
def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]:
h = result["jarm_hash"]
return {"jarm_hash": h}, f"JARM {ip}:{port} = {h}"
def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]:
return {"attacker_ip": ip, "port": port, "jarm_hash": result["jarm_hash"]}

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
from typing import Any
from decnet.prober.base import ActiveProbe
from decnet.prober.tcpfp import tcp_fingerprint
from decnet.telemetry import traced as _traced
DEFAULT_PORTS: list[int] = [22, 80, 443, 8080, 8443, 445, 3389]
class TcpfpProbe(ActiveProbe):
probe_name = "tcpfp"
default_ports = DEFAULT_PORTS
event_type = "tcpfp_fingerprint"
rotation_type = "tcpfp"
rotation_hash_key = "tcpfp_hash"
priority = 100
@_traced("prober.tcpfp_probe")
def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None:
return tcp_fingerprint(ip, port, timeout=timeout)
def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]:
fields = {
"tcpfp_hash": result["tcpfp_hash"],
"tcpfp_raw": result["tcpfp_raw"],
"ttl": str(result["ttl"]),
"window_size": str(result["window_size"]),
"df_bit": str(result["df_bit"]),
"mss": str(result["mss"]),
"window_scale": str(result["window_scale"]),
"sack_ok": str(result["sack_ok"]),
"timestamp": str(result["timestamp"]),
"options_order": result["options_order"],
"tos": str(result["tos"]),
"dscp": str(result["dscp"]),
"ecn": str(result["ecn"]),
"server_isn": str(result["server_isn"]),
}
return fields, f"TCPFP {ip}:{port} = {result['tcpfp_hash']}"
def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]:
return {
"attacker_ip": ip,
"port": port,
"tcpfp_hash": result["tcpfp_hash"],
"ttl": result["ttl"],
"mss": result["mss"],
}

View File

@@ -43,9 +43,8 @@ from decnet.correlation.fingerprint_rotation import (
record_fingerprint, record_fingerprint,
) )
from decnet.logging import get_logger from decnet.logging import get_logger
from decnet.prober.hassh import hassh_server from decnet.prober.base import ActiveProbe, ActiveProbeMeta
from decnet.prober.jarm import JARM_EMPTY_HASH, jarm_hash import decnet.prober.probes as _probes # noqa: F401 — triggers metaclass registration
from decnet.prober.tcpfp import tcp_fingerprint
from decnet.prober.tlscert import fetch_leaf_cert from decnet.prober.tlscert import fetch_leaf_cert
from decnet.telemetry import traced as _traced from decnet.telemetry import traced as _traced
@@ -66,20 +65,6 @@ def _build_sync_engine() -> Engine:
db_path = os.environ.get("DECNET_DB_PATH", str(_ROOT / "decnet.db")) db_path = os.environ.get("DECNET_DB_PATH", str(_ROOT / "decnet.db"))
return get_sync_engine(db_path) return get_sync_engine(db_path)
# ─── Default ports per probe type ───────────────────────────────────────────
# JARM: common C2 callback / TLS server ports
DEFAULT_PROBE_PORTS: list[int] = [
443, 8443, 8080, 4443, 50050, 2222, 993, 995, 8888, 9001,
]
# HASSHServer: common SSH server ports
DEFAULT_SSH_PORTS: list[int] = [22, 2222, 22222, 2022]
# TCP/IP stack: probe on ports commonly open on attacker machines.
# Wide spread gives the best chance of a SYN-ACK for TTL/fingerprint extraction.
DEFAULT_TCPFP_PORTS: list[int] = [22, 80, 443, 8080, 8443, 445, 3389]
# ─── RFC 5424 formatting (inline, mirrors templates/*/decnet_logging.py) ───── # ─── RFC 5424 formatting (inline, mirrors templates/*/decnet_logging.py) ─────
_FACILITY_LOCAL0 = 16 _FACILITY_LOCAL0 = 16
@@ -259,94 +244,49 @@ ProbePublishFn = Callable[[str, dict[str, Any]], None]
# performs the rotation-detection upsert + derived-event emission for the # performs the rotation-detection upsert + derived-event emission for the
# DEBT-032 substrate-fingerprint flow. Optional; when None the prober # DEBT-032 substrate-fingerprint flow. Optional; when None the prober
# behaves exactly as before (raw fingerprint emit only, no rotation # behaves exactly as before (raw fingerprint emit only, no rotation
# detection). Construction lives at worker startup so phase functions # detection). Construction lives at worker startup so the probe driver
# don't have to know about the DB engine. # doesn't have to know about the DB engine.
RotationRecorderFn = Callable[[str, int, "ProbeType", str], None] RotationRecorderFn = Callable[[str, int, "ProbeType", str], None]
@_traced("prober.probe_cycle") @_traced("prober.run_probe")
def _probe_cycle( def _run_probe(
targets: set[str], probe: ActiveProbe,
probed: dict[str, dict[str, set[int]]],
jarm_ports: list[int],
ssh_ports: list[int],
tcpfp_ports: list[int],
log_path: Path,
json_path: Path,
timeout: float = 5.0,
publish_fn: ProbePublishFn | None = None,
record_rotation: RotationRecorderFn | None = None,
) -> None:
"""
Probe all known attacker IPs with JARM, HASSH, and TCP/IP fingerprinting.
Args:
targets: set of attacker IPs to probe
probed: dict mapping IP -> {probe_type -> set of ports already probed}
jarm_ports: TLS ports for JARM fingerprinting
ssh_ports: SSH ports for HASSHServer fingerprinting
tcpfp_ports: ports for TCP/IP stack fingerprinting
log_path: RFC 5424 log file
json_path: JSON log file
timeout: per-probe TCP timeout
"""
for ip in sorted(targets):
ip_probed = probed.setdefault(ip, {})
# Phase 1: JARM (TLS fingerprinting)
_jarm_phase(ip, ip_probed, jarm_ports, log_path, json_path, timeout, publish_fn, record_rotation)
# Phase 2: HASSHServer (SSH fingerprinting)
_hassh_phase(ip, ip_probed, ssh_ports, log_path, json_path, timeout, publish_fn, record_rotation)
# Phase 3: TCP/IP stack fingerprinting
_tcpfp_phase(ip, ip_probed, tcpfp_ports, log_path, json_path, timeout, publish_fn, record_rotation)
# Phase 4: IPv6 link-local leak (active ICMPv6 solicitation; on-link only)
_ipv6_leak_phase(ip, ip_probed, log_path, json_path, timeout, publish_fn)
@_traced("prober.jarm_phase")
def _jarm_phase(
ip: str, ip: str,
ip_probed: dict[str, set[int]], ip_probed: dict[str, set[int]],
ports: list[int],
log_path: Path, log_path: Path,
json_path: Path, json_path: Path,
timeout: float, timeout: float,
publish_fn: ProbePublishFn | None = None, publish_fn: ProbePublishFn | None,
record_rotation: RotationRecorderFn | None = None, record_rotation: RotationRecorderFn | None,
) -> None: ) -> None:
"""JARM-fingerprint an IP on the given TLS ports.""" """Generic driver for any port-iterating ActiveProbe."""
done = ip_probed.setdefault("jarm", set()) done = ip_probed.setdefault(probe.probe_name, set())
for port in ports: for port in probe.ports:
if port in done: if port in done:
continue continue
try: try:
h = jarm_hash(ip, port, timeout=timeout) result = probe.run(ip, port, timeout)
done.add(port) done.add(port)
if h == JARM_EMPTY_HASH: if result is None:
continue continue
fields, msg = probe.syslog_fields(ip, port, result)
_write_event( _write_event(
log_path, json_path, log_path, json_path,
"jarm_fingerprint", probe.event_type,
target_ip=ip, target_ip=ip,
target_port=str(port), target_port=str(port),
jarm_hash=h, msg=msg,
msg=f"JARM {ip}:{port} = {h}", **fields,
) )
logger.info("prober: JARM %s:%d = %s", ip, port, h) logger.info("prober: %s %s:%d ok", probe.probe_name, ip, port)
if record_rotation is not None: if record_rotation is not None and probe.rotation_type and probe.rotation_hash_key:
record_rotation(ip, port, "jarm", h) record_rotation(ip, port, probe.rotation_type, result[probe.rotation_hash_key])
if publish_fn is not None: if publish_fn is not None:
publish_fn( publish_fn(probe.probe_name, probe.publish_payload(ip, port, result))
"jarm", if probe.probe_name == "jarm":
{"attacker_ip": ip, "port": port, "jarm_hash": h}, # A non-empty JARM hash proves TLS; attempt a real cert capture.
) _capture_tls_cert(ip, port, log_path, json_path, timeout, publish_fn)
# Cert capture: a non-empty JARM hash proves the port speaks
# TLS, so a follow-up real handshake is worth attempting.
# Failures are silent — the next probe target must not stall.
_capture_tls_cert(ip, port, log_path, json_path, timeout, publish_fn)
except Exception as exc: except Exception as exc:
done.add(port) done.add(port)
_write_event( _write_event(
@@ -356,9 +296,34 @@ def _jarm_phase(
target_ip=ip, target_ip=ip,
target_port=str(port), target_port=str(port),
error=str(exc), error=str(exc),
msg=f"JARM probe failed for {ip}:{port}: {exc}", msg=f"{probe.probe_name} probe failed for {ip}:{port}: {exc}",
) )
logger.warning("prober: JARM probe failed %s:%d: %s", ip, port, exc) logger.warning("prober: %s probe failed %s:%d: %s", probe.probe_name, ip, port, exc)
@_traced("prober.probe_cycle")
def _probe_cycle(
targets: set[str],
probed: dict[str, dict[str, set[int]]],
log_path: Path,
json_path: Path,
timeout: float = 5.0,
publish_fn: ProbePublishFn | None = None,
record_rotation: RotationRecorderFn | None = None,
) -> None:
"""Probe all known attacker IPs via every registered ActiveProbe.
Probes run in (priority, probe_name) order per ActiveProbeMeta.all().
IPv6 leak runs last — it is not port-iterating and stays a special case.
"""
for ip in sorted(targets):
ip_probed = probed.setdefault(ip, {})
for probe_cls in ActiveProbeMeta.all():
_run_probe(probe_cls(), ip, ip_probed, log_path, json_path,
timeout, publish_fn, record_rotation)
_ipv6_leak_phase(ip, ip_probed, log_path, json_path, timeout, publish_fn)
@_traced("prober.tls_cert_capture") @_traced("prober.tls_cert_capture")
@@ -415,137 +380,6 @@ def _capture_tls_cert(
) )
@_traced("prober.hassh_phase")
def _hassh_phase(
ip: str,
ip_probed: dict[str, set[int]],
ports: list[int],
log_path: Path,
json_path: Path,
timeout: float,
publish_fn: ProbePublishFn | None = None,
record_rotation: RotationRecorderFn | None = None,
) -> None:
"""HASSHServer-fingerprint an IP on the given SSH ports."""
done = ip_probed.setdefault("hassh", set())
for port in ports:
if port in done:
continue
try:
result = hassh_server(ip, port, timeout=timeout)
done.add(port)
if result is None:
continue
_write_event(
log_path, json_path,
"hassh_fingerprint",
target_ip=ip,
target_port=str(port),
hassh_server_hash=result["hassh_server"],
ssh_banner=result["banner"],
kex_algorithms=result["kex_algorithms"],
encryption_s2c=result["encryption_s2c"],
mac_s2c=result["mac_s2c"],
compression_s2c=result["compression_s2c"],
msg=f"HASSH {ip}:{port} = {result['hassh_server']}",
)
logger.info("prober: HASSH %s:%d = %s", ip, port, result["hassh_server"])
if record_rotation is not None:
record_rotation(ip, port, "hassh", result["hassh_server"])
if publish_fn is not None:
publish_fn(
"hassh",
{
"attacker_ip": ip,
"port": port,
"hassh_server": result["hassh_server"],
"ssh_banner": result["banner"],
},
)
except Exception as exc:
done.add(port)
_write_event(
log_path, json_path,
"prober_error",
severity=_SEVERITY_WARNING,
target_ip=ip,
target_port=str(port),
error=str(exc),
msg=f"HASSH probe failed for {ip}:{port}: {exc}",
)
logger.warning("prober: HASSH probe failed %s:%d: %s", ip, port, exc)
@_traced("prober.tcpfp_phase")
def _tcpfp_phase(
ip: str,
ip_probed: dict[str, set[int]],
ports: list[int],
log_path: Path,
json_path: Path,
timeout: float,
publish_fn: ProbePublishFn | None = None,
record_rotation: RotationRecorderFn | None = None,
) -> None:
"""TCP/IP stack fingerprint an IP on the given ports."""
done = ip_probed.setdefault("tcpfp", set())
for port in ports:
if port in done:
continue
try:
result = tcp_fingerprint(ip, port, timeout=timeout)
done.add(port)
if result is None:
continue
_write_event(
log_path, json_path,
"tcpfp_fingerprint",
target_ip=ip,
target_port=str(port),
tcpfp_hash=result["tcpfp_hash"],
tcpfp_raw=result["tcpfp_raw"],
ttl=str(result["ttl"]),
window_size=str(result["window_size"]),
df_bit=str(result["df_bit"]),
mss=str(result["mss"]),
window_scale=str(result["window_scale"]),
sack_ok=str(result["sack_ok"]),
timestamp=str(result["timestamp"]),
options_order=result["options_order"],
tos=str(result["tos"]),
dscp=str(result["dscp"]),
ecn=str(result["ecn"]),
server_isn=str(result["server_isn"]),
msg=f"TCPFP {ip}:{port} = {result['tcpfp_hash']}",
)
logger.info("prober: TCPFP %s:%d = %s", ip, port, result["tcpfp_hash"])
if record_rotation is not None:
record_rotation(ip, port, "tcpfp", result["tcpfp_hash"])
if publish_fn is not None:
publish_fn(
"tcpfp",
{
"attacker_ip": ip,
"port": port,
"tcpfp_hash": result["tcpfp_hash"],
"ttl": result["ttl"],
"mss": result["mss"],
},
)
except Exception as exc:
done.add(port)
_write_event(
log_path, json_path,
"prober_error",
severity=_SEVERITY_WARNING,
target_ip=ip,
target_port=str(port),
error=str(exc),
msg=f"TCPFP probe failed for {ip}:{port}: {exc}",
)
logger.warning("prober: TCPFP probe failed %s:%d: %s", ip, port, exc)
@_traced("prober.ipv6_leak_phase") @_traced("prober.ipv6_leak_phase")
def _ipv6_leak_phase( def _ipv6_leak_phase(
ip: str, ip: str,
@@ -622,49 +456,43 @@ async def prober_worker(
log_file: str, log_file: str,
interval: int = 300, interval: int = 300,
timeout: float = 5.0, timeout: float = 5.0,
ports: list[int] | None = None,
ssh_ports: list[int] | None = None,
tcpfp_ports: list[int] | None = None,
) -> None: ) -> None:
""" """
Main entry point for the standalone prober process. Main entry point for the standalone prober process.
Discovers attacker IPs automatically by tailing the JSON log file, Discovers attacker IPs automatically by tailing the JSON log file,
then fingerprints each IP via JARM, HASSH, and TCP/IP stack probes. then fingerprints each IP via every registered ActiveProbe (JARM,
HASSH, TCP/IP stack) plus the IPv6 leak special case.
Per-probe port lists are taken from each probe's ``default_ports``
attribute. Override at runtime via DECNET_PROBE_PORTS_<NAME_UPPER>
(comma-separated), e.g. DECNET_PROBE_PORTS_JARM="443,8443".
Args: Args:
log_file: base path for log files (RFC 5424 to .log, JSON to .json) log_file: base path for log files (RFC 5424 to .log, JSON to .json)
interval: seconds between probe cycles interval: seconds between probe cycles
timeout: per-probe TCP timeout timeout: per-probe TCP timeout
ports: JARM TLS ports (defaults to DEFAULT_PROBE_PORTS)
ssh_ports: HASSH SSH ports (defaults to DEFAULT_SSH_PORTS)
tcpfp_ports: TCP fingerprint ports (defaults to DEFAULT_TCPFP_PORTS)
""" """
jarm_ports = ports or DEFAULT_PROBE_PORTS
hassh_ports = ssh_ports or DEFAULT_SSH_PORTS
tcp_ports = tcpfp_ports or DEFAULT_TCPFP_PORTS
all_ports_str = (
f"jarm={','.join(str(p) for p in jarm_ports)} "
f"ssh={','.join(str(p) for p in hassh_ports)} "
f"tcpfp={','.join(str(p) for p in tcp_ports)}"
)
log_path = Path(log_file) log_path = Path(log_file)
json_path = log_path.with_suffix(".json") json_path = log_path.with_suffix(".json")
log_path.parent.mkdir(parents=True, exist_ok=True) log_path.parent.mkdir(parents=True, exist_ok=True)
probe_summary = " ".join(
f"{cls.probe_name}={','.join(str(p) for p in cls().ports)}"
for cls in ActiveProbeMeta.all()
)
logger.info( logger.info(
"prober started interval=%ds %s log=%s", "prober started interval=%ds %s log=%s",
interval, all_ports_str, log_path, interval, probe_summary, log_path,
) )
_write_event( _write_event(
log_path, json_path, log_path, json_path,
"prober_startup", "prober_startup",
interval=str(interval), interval=str(interval),
probe_ports=all_ports_str, probe_ports=probe_summary,
msg=f"DECNET-PROBER started, interval {interval}s, {all_ports_str}", msg=f"DECNET-PROBER started, interval {interval}s, {probe_summary}",
) )
known_attackers: set[str] = set() known_attackers: set[str] = set()
@@ -776,7 +604,6 @@ async def prober_worker(
if known_attackers: if known_attackers:
await asyncio.to_thread( await asyncio.to_thread(
_probe_cycle, known_attackers, probed, _probe_cycle, known_attackers, probed,
jarm_ports, hassh_ports, tcp_ports,
log_path, json_path, timeout, log_path, json_path, timeout,
_publish_attacker, _publish_attacker,
record_rotation, record_rotation,

View File

@@ -0,0 +1,80 @@
"""Tests for ActiveProbeMeta registry and ActiveProbe ABC contract."""
from __future__ import annotations
from typing import Any
import pytest
from decnet.prober.base import ActiveProbe, ActiveProbeMeta
import decnet.prober.probes as _probes # noqa: F401 — ensure probes are registered
@pytest.fixture(autouse=True)
def _restore_registry():
"""Snapshot and restore the registry around each test so throwaway probes don't leak."""
snapshot = dict(ActiveProbeMeta._registry)
yield
ActiveProbeMeta._registry.clear()
ActiveProbeMeta._registry.update(snapshot)
class TestRegistryContents:
def test_all_three_probes_registered(self):
names = {cls.probe_name for cls in ActiveProbeMeta.all()}
assert names == {"jarm", "hassh", "tcpfp"}
def test_sorted_by_priority_then_name(self):
order = [cls.probe_name for cls in ActiveProbeMeta.all()]
assert order == ["hassh", "jarm", "tcpfp"] # all priority=100, alphabetical
def test_priority10_probe_sorts_first(self):
class _FastProbe(ActiveProbe):
probe_name = "_fast_test_probe"
default_ports = [9999]
event_type = "_fast_event"
priority = 10
def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None:
return None
def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]:
return {}, ""
def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]:
return {}
order = [cls.probe_name for cls in ActiveProbeMeta.all()]
assert order[0] == "_fast_test_probe"
assert set(order[1:]) == {"hassh", "jarm", "tcpfp"}
def test_base_class_not_registered(self):
assert "ActiveProbe" not in ActiveProbeMeta._registry
assert None not in ActiveProbeMeta._registry.values()
class TestProbeABCContract:
@pytest.mark.parametrize("probe_cls", list(ActiveProbeMeta.all()))
def test_instantiable(self, probe_cls: type[ActiveProbe]):
instance = probe_cls()
assert isinstance(instance, ActiveProbe)
@pytest.mark.parametrize("probe_cls", list(ActiveProbeMeta.all()))
def test_has_required_class_attrs(self, probe_cls: type[ActiveProbe]):
assert isinstance(probe_cls.probe_name, str) and probe_cls.probe_name
assert isinstance(probe_cls.default_ports, list) and probe_cls.default_ports
assert isinstance(probe_cls.event_type, str) and probe_cls.event_type
assert isinstance(probe_cls.priority, int)
@pytest.mark.parametrize("probe_cls", list(ActiveProbeMeta.all()))
def test_ports_property_reflects_default(self, probe_cls: type[ActiveProbe]):
instance = probe_cls()
assert instance.ports == probe_cls.default_ports
@pytest.mark.parametrize("probe_cls", list(ActiveProbeMeta.all()))
def test_implements_abstract_methods(self, probe_cls: type[ActiveProbe]):
assert callable(getattr(probe_cls, "run"))
assert callable(getattr(probe_cls, "syslog_fields"))
assert callable(getattr(probe_cls, "publish_payload"))

View File

@@ -12,10 +12,10 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from decnet.prober.jarm import JARM_EMPTY_HASH from decnet.prober.jarm import JARM_EMPTY_HASH
from decnet.prober.probes.hassh import HasshProbe
from decnet.prober.probes.jarm import JarmProbe
from decnet.prober.probes.tcpfp import TcpfpProbe
from decnet.prober.worker import ( from decnet.prober.worker import (
DEFAULT_PROBE_PORTS,
DEFAULT_SSH_PORTS,
DEFAULT_TCPFP_PORTS,
_discover_attackers, _discover_attackers,
_probe_cycle, _probe_cycle,
_write_event, _write_event,
@@ -109,13 +109,18 @@ class TestDiscoverAttackers:
class TestProbeCycleJARM: class TestProbeCycleJARM:
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.fetch_leaf_cert", return_value=None) @patch("decnet.prober.worker.fetch_leaf_cert", return_value=None)
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.jarm.jarm_hash")
def test_probes_new_ips(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_probes_new_ips(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, mock_cert: MagicMock, mock_tcpfp: MagicMock, mock_cert: MagicMock,
tmp_path: Path): mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [443, 8443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = "c0c" * 10 + "a" * 32 # fake 62-char hash mock_jarm.return_value = "c0c" * 10 + "a" * 32 # fake 62-char hash
mock_hassh.return_value = None mock_hassh.return_value = None
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
@@ -125,19 +130,24 @@ class TestProbeCycleJARM:
targets = {"10.0.0.1"} targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {} probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [443, 8443], [], [], log_path, json_path, timeout=1.0) _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert mock_jarm.call_count == 2 # two ports assert mock_jarm.call_count == 2 # two ports
assert 443 in probed["10.0.0.1"]["jarm"] assert 443 in probed["10.0.0.1"]["jarm"]
assert 8443 in probed["10.0.0.1"]["jarm"] assert 8443 in probed["10.0.0.1"]["jarm"]
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.fetch_leaf_cert", return_value=None) @patch("decnet.prober.worker.fetch_leaf_cert", return_value=None)
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.jarm.jarm_hash")
def test_skips_already_probed_ports(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_skips_already_probed_ports(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, mock_cert: MagicMock, mock_tcpfp: MagicMock, mock_cert: MagicMock,
tmp_path: Path): mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [443, 8443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = "c0c" * 10 + "a" * 32 mock_jarm.return_value = "c0c" * 10 + "a" * 32
mock_hassh.return_value = None mock_hassh.return_value = None
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
@@ -147,17 +157,22 @@ class TestProbeCycleJARM:
targets = {"10.0.0.1"} targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {"10.0.0.1": {"jarm": {443}}} probed: dict[str, dict[str, set[int]]] = {"10.0.0.1": {"jarm": {443}}}
_probe_cycle(targets, probed, [443, 8443], [], [], log_path, json_path, timeout=1.0) _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
# Should only probe 8443 (443 already done) # Should only probe 8443 (443 already done)
assert mock_jarm.call_count == 1 assert mock_jarm.call_count == 1
mock_jarm.assert_called_once_with("10.0.0.1", 8443, timeout=1.0) mock_jarm.assert_called_once_with("10.0.0.1", 8443, timeout=1.0)
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_empty_hash_not_logged(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_empty_hash_not_logged(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path): mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None mock_hassh.return_value = None
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
@@ -167,18 +182,23 @@ class TestProbeCycleJARM:
targets = {"10.0.0.1"} targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {} probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [443], [], [], log_path, json_path, timeout=1.0) _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert 443 in probed["10.0.0.1"]["jarm"] assert 443 in probed["10.0.0.1"]["jarm"]
if json_path.exists(): if json_path.exists():
content = json_path.read_text() content = json_path.read_text()
assert "jarm_fingerprint" not in content assert "jarm_fingerprint" not in content
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_exception_marks_port_probed(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_exception_marks_port_probed(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path): mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.side_effect = OSError("Connection refused") mock_jarm.side_effect = OSError("Connection refused")
mock_hassh.return_value = None mock_hassh.return_value = None
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
@@ -188,15 +208,20 @@ class TestProbeCycleJARM:
targets = {"10.0.0.1"} targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {} probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [443], [], [], log_path, json_path, timeout=1.0) _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert 443 in probed["10.0.0.1"]["jarm"] assert 443 in probed["10.0.0.1"]["jarm"]
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_skips_ip_with_all_ports_done(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_skips_ip_with_all_ports_done(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path): mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [443, 8443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
log_path = tmp_path / "decnet.log" log_path = tmp_path / "decnet.log"
json_path = tmp_path / "decnet.json" json_path = tmp_path / "decnet.json"
@@ -205,7 +230,7 @@ class TestProbeCycleJARM:
"10.0.0.1": {"jarm": {443, 8443}, "hassh": set(), "tcpfp": set()}, "10.0.0.1": {"jarm": {443, 8443}, "hassh": set(), "tcpfp": set()},
} }
_probe_cycle(targets, probed, [443, 8443], [], [], log_path, json_path, timeout=1.0) _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert mock_jarm.call_count == 0 assert mock_jarm.call_count == 0
@@ -214,11 +239,16 @@ class TestProbeCycleJARM:
class TestProbeCycleHASSH: class TestProbeCycleHASSH:
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_probes_ssh_ports(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_probes_ssh_ports(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path): mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [22, 2222])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = { mock_hassh.return_value = {
"hassh_server": "a" * 32, "hassh_server": "a" * 32,
@@ -235,17 +265,22 @@ class TestProbeCycleHASSH:
targets = {"10.0.0.1"} targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {} probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [], [22, 2222], [], log_path, json_path, timeout=1.0) _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert mock_hassh.call_count == 2 assert mock_hassh.call_count == 2
assert 22 in probed["10.0.0.1"]["hassh"] assert 22 in probed["10.0.0.1"]["hassh"]
assert 2222 in probed["10.0.0.1"]["hassh"] assert 2222 in probed["10.0.0.1"]["hassh"]
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_hassh_writes_event(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_hassh_writes_event(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path): mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [22])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = { mock_hassh.return_value = {
"hassh_server": "b" * 32, "hassh_server": "b" * 32,
@@ -262,7 +297,7 @@ class TestProbeCycleHASSH:
targets = {"10.0.0.1"} targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {} probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [], [22], [], log_path, json_path, timeout=1.0) _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert json_path.exists() assert json_path.exists()
content = json_path.read_text() content = json_path.read_text()
@@ -271,11 +306,16 @@ class TestProbeCycleHASSH:
assert record["fields"]["hassh_server_hash"] == "b" * 32 assert record["fields"]["hassh_server_hash"] == "b" * 32
assert record["fields"]["ssh_banner"] == "SSH-2.0-Paramiko_3.0" assert record["fields"]["ssh_banner"] == "SSH-2.0-Paramiko_3.0"
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_hassh_none_not_logged(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_hassh_none_not_logged(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path): mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [22])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None # No SSH server mock_hassh.return_value = None # No SSH server
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
@@ -285,18 +325,23 @@ class TestProbeCycleHASSH:
targets = {"10.0.0.1"} targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {} probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [], [22], [], log_path, json_path, timeout=1.0) _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert 22 in probed["10.0.0.1"]["hassh"] assert 22 in probed["10.0.0.1"]["hassh"]
if json_path.exists(): if json_path.exists():
content = json_path.read_text() content = json_path.read_text()
assert "hassh_fingerprint" not in content assert "hassh_fingerprint" not in content
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_hassh_skips_already_probed(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_hassh_skips_already_probed(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path): mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [22, 2222])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH mock_jarm.return_value = JARM_EMPTY_HASH
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
log_path = tmp_path / "decnet.log" log_path = tmp_path / "decnet.log"
@@ -305,16 +350,21 @@ class TestProbeCycleHASSH:
targets = {"10.0.0.1"} targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {"10.0.0.1": {"hassh": {22}}} probed: dict[str, dict[str, set[int]]] = {"10.0.0.1": {"hassh": {22}}}
_probe_cycle(targets, probed, [], [22, 2222], [], log_path, json_path, timeout=1.0) _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert mock_hassh.call_count == 1 # only 2222 assert mock_hassh.call_count == 1 # only 2222
mock_hassh.assert_called_once_with("10.0.0.1", 2222, timeout=1.0) mock_hassh.assert_called_once_with("10.0.0.1", 2222, timeout=1.0)
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_hassh_exception_marks_probed(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_hassh_exception_marks_probed(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path): mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [22])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.side_effect = OSError("Connection refused") mock_hassh.side_effect = OSError("Connection refused")
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
@@ -324,7 +374,7 @@ class TestProbeCycleHASSH:
targets = {"10.0.0.1"} targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {} probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [], [22], [], log_path, json_path, timeout=1.0) _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert 22 in probed["10.0.0.1"]["hassh"] assert 22 in probed["10.0.0.1"]["hassh"]
@@ -333,11 +383,16 @@ class TestProbeCycleHASSH:
class TestProbeCycleTCPFP: class TestProbeCycleTCPFP:
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_probes_tcpfp_ports(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_probes_tcpfp_ports(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path): mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [80, 443])
mock_jarm.return_value = JARM_EMPTY_HASH mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None mock_hassh.return_value = None
mock_tcpfp.return_value = { mock_tcpfp.return_value = {
@@ -354,17 +409,22 @@ class TestProbeCycleTCPFP:
targets = {"10.0.0.1"} targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {} probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [], [], [80, 443], log_path, json_path, timeout=1.0) _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert mock_tcpfp.call_count == 2 assert mock_tcpfp.call_count == 2
assert 80 in probed["10.0.0.1"]["tcpfp"] assert 80 in probed["10.0.0.1"]["tcpfp"]
assert 443 in probed["10.0.0.1"]["tcpfp"] assert 443 in probed["10.0.0.1"]["tcpfp"]
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_tcpfp_writes_event_with_all_fields(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_tcpfp_writes_event_with_all_fields(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path): mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [443])
mock_jarm.return_value = JARM_EMPTY_HASH mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None mock_hassh.return_value = None
mock_tcpfp.return_value = { mock_tcpfp.return_value = {
@@ -381,7 +441,7 @@ class TestProbeCycleTCPFP:
targets = {"10.0.0.1"} targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {} probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [], [], [443], log_path, json_path, timeout=1.0) _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
content = json_path.read_text() content = json_path.read_text()
assert "tcpfp_fingerprint" in content assert "tcpfp_fingerprint" in content
@@ -391,11 +451,16 @@ class TestProbeCycleTCPFP:
assert record["fields"]["window_size"] == "8192" assert record["fields"]["window_size"] == "8192"
assert record["fields"]["options_order"] == "M,N,W,N,N,S" assert record["fields"]["options_order"] == "M,N,W,N,N,S"
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_tcpfp_none_not_logged(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_tcpfp_none_not_logged(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path): mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [443])
mock_jarm.return_value = JARM_EMPTY_HASH mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None mock_hassh.return_value = None
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
@@ -405,7 +470,7 @@ class TestProbeCycleTCPFP:
targets = {"10.0.0.1"} targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {} probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [], [], [443], log_path, json_path, timeout=1.0) _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert 443 in probed["10.0.0.1"]["tcpfp"] assert 443 in probed["10.0.0.1"]["tcpfp"]
if json_path.exists(): if json_path.exists():
@@ -417,12 +482,17 @@ class TestProbeCycleTCPFP:
class TestProbeTypeIsolation: class TestProbeTypeIsolation:
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_jarm_does_not_mark_hassh(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_jarm_does_not_mark_hassh(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path): mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
"""JARM probing port 2222 should not mark HASSH port 2222 as done.""" """JARM probing port 2222 should not mark HASSH port 2222 as done."""
monkeypatch.setattr(JarmProbe, "default_ports", [2222])
monkeypatch.setattr(HasshProbe, "default_ports", [2222])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None mock_hassh.return_value = None
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
@@ -432,8 +502,7 @@ class TestProbeTypeIsolation:
targets = {"10.0.0.1"} targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {} probed: dict[str, dict[str, set[int]]] = {}
# Probe with JARM on 2222 and HASSH on 2222 _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
_probe_cycle(targets, probed, [2222], [2222], [], log_path, json_path, timeout=1.0)
# Both should be called # Both should be called
assert mock_jarm.call_count == 1 assert mock_jarm.call_count == 1
@@ -441,11 +510,16 @@ class TestProbeTypeIsolation:
assert 2222 in probed["10.0.0.1"]["jarm"] assert 2222 in probed["10.0.0.1"]["jarm"]
assert 2222 in probed["10.0.0.1"]["hassh"] assert 2222 in probed["10.0.0.1"]["hassh"]
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.probes.jarm.jarm_hash")
def test_all_three_probes_run(self, mock_jarm: MagicMock, mock_hassh: MagicMock, def test_all_three_probes_run(self, mock_jarm: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, tmp_path: Path): mock_tcpfp: MagicMock, mock_ipv6: MagicMock,
tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(JarmProbe, "default_ports", [443])
monkeypatch.setattr(HasshProbe, "default_ports", [22])
monkeypatch.setattr(TcpfpProbe, "default_ports", [80])
mock_jarm.return_value = JARM_EMPTY_HASH mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None mock_hassh.return_value = None
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
@@ -455,7 +529,7 @@ class TestProbeTypeIsolation:
targets = {"10.0.0.1"} targets = {"10.0.0.1"}
probed: dict[str, dict[str, set[int]]] = {} probed: dict[str, dict[str, set[int]]] = {}
_probe_cycle(targets, probed, [443], [22], [80], log_path, json_path, timeout=1.0) _probe_cycle(targets, probed, log_path, json_path, timeout=1.0)
assert mock_jarm.call_count == 1 assert mock_jarm.call_count == 1
assert mock_hassh.call_count == 1 assert mock_hassh.call_count == 1
@@ -490,20 +564,26 @@ class TestWriteEvent:
class TestProbeCycleTLSCert: class TestProbeCycleTLSCert:
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.fetch_leaf_cert") @patch("decnet.prober.worker.fetch_leaf_cert")
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.jarm.jarm_hash")
def test_cert_event_emitted_after_successful_jarm( def test_cert_event_emitted_after_successful_jarm(
self, self,
mock_jarm: MagicMock, mock_jarm: MagicMock,
mock_hassh: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, mock_tcpfp: MagicMock,
mock_cert: MagicMock, mock_cert: MagicMock,
mock_ipv6: MagicMock,
tmp_path: Path, tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
): ):
"""A non-empty JARM hash should trigger a follow-up cert fetch and """A non-empty JARM hash should trigger a follow-up cert fetch and
write a tls_certificate event with all parsed fields.""" write a tls_certificate event with all parsed fields."""
monkeypatch.setattr(JarmProbe, "default_ports", [443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = "c0c" * 10 + "a" * 32 mock_jarm.return_value = "c0c" * 10 + "a" * 32
mock_hassh.return_value = None mock_hassh.return_value = None
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
@@ -519,7 +599,7 @@ class TestProbeCycleTLSCert:
log_path = tmp_path / "decnet.log" log_path = tmp_path / "decnet.log"
json_path = tmp_path / "decnet.json" json_path = tmp_path / "decnet.json"
_probe_cycle({"10.0.0.1"}, {}, [443], [], [], log_path, json_path, timeout=1.0) _probe_cycle({"10.0.0.1"}, {}, log_path, json_path, timeout=1.0)
mock_cert.assert_called_once_with("10.0.0.1", 443, timeout=1.0) mock_cert.assert_called_once_with("10.0.0.1", 443, timeout=1.0)
records = [ records = [
@@ -539,69 +619,87 @@ class TestProbeCycleTLSCert:
assert f["sans"] == "evil.example.com,c2.example.com" assert f["sans"] == "evil.example.com,c2.example.com"
assert f["cert_sha256"] == "ab" * 32 assert f["cert_sha256"] == "ab" * 32
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.fetch_leaf_cert") @patch("decnet.prober.worker.fetch_leaf_cert")
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.jarm.jarm_hash")
def test_cert_fetch_skipped_on_empty_jarm( def test_cert_fetch_skipped_on_empty_jarm(
self, self,
mock_jarm: MagicMock, mock_jarm: MagicMock,
mock_hassh: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, mock_tcpfp: MagicMock,
mock_cert: MagicMock, mock_cert: MagicMock,
mock_ipv6: MagicMock,
tmp_path: Path, tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
): ):
"""JARM_EMPTY_HASH means the port doesn't speak TLS; skip cert fetch.""" """JARM_EMPTY_HASH means the port doesn't speak TLS; skip cert fetch."""
monkeypatch.setattr(JarmProbe, "default_ports", [443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = JARM_EMPTY_HASH mock_jarm.return_value = JARM_EMPTY_HASH
mock_hassh.return_value = None mock_hassh.return_value = None
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
log_path = tmp_path / "decnet.log" log_path = tmp_path / "decnet.log"
json_path = tmp_path / "decnet.json" json_path = tmp_path / "decnet.json"
_probe_cycle({"10.0.0.1"}, {}, [443], [], [], log_path, json_path, timeout=1.0) _probe_cycle({"10.0.0.1"}, {}, log_path, json_path, timeout=1.0)
mock_cert.assert_not_called() mock_cert.assert_not_called()
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.fetch_leaf_cert", return_value=None) @patch("decnet.prober.worker.fetch_leaf_cert", return_value=None)
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.jarm.jarm_hash")
def test_cert_fetch_failure_silent( def test_cert_fetch_failure_silent(
self, self,
mock_jarm: MagicMock, mock_jarm: MagicMock,
mock_hassh: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, mock_tcpfp: MagicMock,
mock_cert: MagicMock, mock_cert: MagicMock,
mock_ipv6: MagicMock,
tmp_path: Path, tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
): ):
"""fetch_leaf_cert returning None must not write a cert event.""" """fetch_leaf_cert returning None must not write a cert event."""
monkeypatch.setattr(JarmProbe, "default_ports", [443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = "c0c" * 10 + "a" * 32 mock_jarm.return_value = "c0c" * 10 + "a" * 32
mock_hassh.return_value = None mock_hassh.return_value = None
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
log_path = tmp_path / "decnet.log" log_path = tmp_path / "decnet.log"
json_path = tmp_path / "decnet.json" json_path = tmp_path / "decnet.json"
_probe_cycle({"10.0.0.1"}, {}, [443], [], [], log_path, json_path, timeout=1.0) _probe_cycle({"10.0.0.1"}, {}, log_path, json_path, timeout=1.0)
mock_cert.assert_called_once_with("10.0.0.1", 443, timeout=1.0) mock_cert.assert_called_once_with("10.0.0.1", 443, timeout=1.0)
if json_path.exists(): if json_path.exists():
content = json_path.read_text() content = json_path.read_text()
assert "tls_certificate" not in content assert "tls_certificate" not in content
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.fetch_leaf_cert") @patch("decnet.prober.worker.fetch_leaf_cert")
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.jarm.jarm_hash")
def test_cert_fetch_crash_does_not_break_phase( def test_cert_fetch_crash_does_not_break_phase(
self, self,
mock_jarm: MagicMock, mock_jarm: MagicMock,
mock_hassh: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, mock_tcpfp: MagicMock,
mock_cert: MagicMock, mock_cert: MagicMock,
mock_ipv6: MagicMock,
tmp_path: Path, tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
): ):
"""If fetch_leaf_cert throws despite its contract, the JARM phase """If fetch_leaf_cert throws despite its contract, the JARM phase
must keep moving to the next port without crashing.""" must keep moving to the next port without crashing."""
monkeypatch.setattr(JarmProbe, "default_ports", [443, 8443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = "c0c" * 10 + "a" * 32 mock_jarm.return_value = "c0c" * 10 + "a" * 32
mock_hassh.return_value = None mock_hassh.return_value = None
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
@@ -609,25 +707,30 @@ class TestProbeCycleTLSCert:
log_path = tmp_path / "decnet.log" log_path = tmp_path / "decnet.log"
json_path = tmp_path / "decnet.json" json_path = tmp_path / "decnet.json"
_probe_cycle({"10.0.0.1"}, {}, [443, 8443], [], [], log_path, json_path, timeout=1.0) _probe_cycle({"10.0.0.1"}, {}, log_path, json_path, timeout=1.0)
# Both ports still marked probed despite the cert-side crash. # Both ports still marked probed despite the cert-side crash.
from decnet.prober.worker import _probe_cycle as _ # re-import safety
assert mock_cert.call_count == 2 assert mock_cert.call_count == 2
@patch("decnet.prober.worker._ipv6_leak_phase")
@patch("decnet.prober.worker.fetch_leaf_cert") @patch("decnet.prober.worker.fetch_leaf_cert")
@patch("decnet.prober.worker.tcp_fingerprint") @patch("decnet.prober.probes.tcpfp.tcp_fingerprint")
@patch("decnet.prober.worker.hassh_server") @patch("decnet.prober.probes.hassh.hassh_server")
@patch("decnet.prober.worker.jarm_hash") @patch("decnet.prober.probes.jarm.jarm_hash")
def test_cert_publish_fn_called( def test_cert_publish_fn_called(
self, self,
mock_jarm: MagicMock, mock_jarm: MagicMock,
mock_hassh: MagicMock, mock_hassh: MagicMock,
mock_tcpfp: MagicMock, mock_tcpfp: MagicMock,
mock_cert: MagicMock, mock_cert: MagicMock,
mock_ipv6: MagicMock,
tmp_path: Path, tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
): ):
"""publish_fn must receive a 'tls_certificate' event when capture succeeds.""" """publish_fn must receive a 'tls_certificate' event when capture succeeds."""
monkeypatch.setattr(JarmProbe, "default_ports", [443])
monkeypatch.setattr(HasshProbe, "default_ports", [])
monkeypatch.setattr(TcpfpProbe, "default_ports", [])
mock_jarm.return_value = "c0c" * 10 + "a" * 32 mock_jarm.return_value = "c0c" * 10 + "a" * 32
mock_hassh.return_value = None mock_hassh.return_value = None
mock_tcpfp.return_value = None mock_tcpfp.return_value = None
@@ -646,7 +749,7 @@ class TestProbeCycleTLSCert:
published.append((kind, payload)) published.append((kind, payload))
_probe_cycle( _probe_cycle(
{"10.0.0.1"}, {}, [443], [], [], {"10.0.0.1"}, {},
tmp_path / "decnet.log", tmp_path / "decnet.json", tmp_path / "decnet.log", tmp_path / "decnet.json",
timeout=1.0, publish_fn=publish, timeout=1.0, publish_fn=publish,
) )

View File

@@ -0,0 +1,244 @@
"""Unit tests for the _run_probe generic driver."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock
import pytest
from decnet.prober.base import ActiveProbe, ActiveProbeMeta
from decnet.prober.worker import _run_probe
@pytest.fixture(autouse=True)
def _restore_registry():
snapshot = dict(ActiveProbeMeta._registry)
yield
ActiveProbeMeta._registry.clear()
ActiveProbeMeta._registry.update(snapshot)
def _make_probe(
probe_name: str = "test_probe",
default_ports: list[int] | None = None,
run_return: dict[str, Any] | None = None,
run_side_effect: Exception | None = None,
rotation_type: str | None = "test",
rotation_hash_key: str | None = "hash",
) -> ActiveProbe:
"""Build a concrete ActiveProbe subclass for testing and return an instance."""
_pname = probe_name
_ports = default_ports or [1234]
_result = run_return
_exc = run_side_effect
_rtype = rotation_type
_rkey = rotation_hash_key
class _TestProbe(ActiveProbe):
probe_name = _pname # type: ignore[assignment]
default_ports = _ports # type: ignore[assignment]
event_type = f"{_pname}_event"
rotation_type = _rtype # type: ignore[assignment]
rotation_hash_key = _rkey
priority = 100
def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None:
if _exc is not None:
raise _exc
return _result
def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]:
return {"hash": result.get("hash", "")}, f"{_pname} {ip}:{port}"
def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]:
return {"attacker_ip": ip, "port": port, "hash": result.get("hash", "")}
return _TestProbe()
class TestRunProbeDedup:
def test_skips_already_probed_port(self, tmp_path: Path):
probe = _make_probe(default_ports=[80, 443], run_return={"hash": "abc"})
ip_probed: dict[str, set[int]] = {"test_probe": {80}}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=None, record_rotation=None)
assert 80 in ip_probed["test_probe"] # was already there
assert 443 in ip_probed["test_probe"] # newly probed
def test_initializes_done_set_if_missing(self, tmp_path: Path):
probe = _make_probe(default_ports=[22], run_return=None)
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=None, record_rotation=None)
assert "test_probe" in ip_probed
assert 22 in ip_probed["test_probe"]
class TestRunProbeSuccessPath:
def test_writes_event_on_success(self, tmp_path: Path):
probe = _make_probe(default_ports=[443], run_return={"hash": "deadbeef"})
ip_probed: dict[str, set[int]] = {}
json_path = tmp_path / "events.json"
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "events.log", json_path,
timeout=1.0, publish_fn=None, record_rotation=None)
assert json_path.exists()
record = json.loads(json_path.read_text().strip())
assert record["event_type"] == "test_probe_event"
assert record["fields"]["target_ip"] == "1.2.3.4"
assert record["fields"]["target_port"] == "443"
assert record["fields"]["hash"] == "deadbeef"
def test_calls_publish_fn_on_success(self, tmp_path: Path):
probe = _make_probe(default_ports=[443], run_return={"hash": "cafebabe"})
published: list[tuple[str, dict]] = []
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=lambda k, v: published.append((k, v)),
record_rotation=None)
assert len(published) == 1
assert published[0][0] == "test_probe"
assert published[0][1]["attacker_ip"] == "1.2.3.4"
assert published[0][1]["hash"] == "cafebabe"
def test_calls_record_rotation_when_configured(self, tmp_path: Path):
probe = _make_probe(default_ports=[443], run_return={"hash": "rotateme"},
rotation_type="test", rotation_hash_key="hash")
mock_rotation = MagicMock()
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=None, record_rotation=mock_rotation)
mock_rotation.assert_called_once_with("1.2.3.4", 443, "test", "rotateme")
def test_skips_rotation_when_rotation_type_none(self, tmp_path: Path):
probe = _make_probe(default_ports=[443], run_return={"hash": "x"},
rotation_type=None, rotation_hash_key=None)
mock_rotation = MagicMock()
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=None, record_rotation=mock_rotation)
mock_rotation.assert_not_called()
def test_skips_rotation_when_rotation_hash_key_none(self, tmp_path: Path):
probe = _make_probe(default_ports=[443], run_return={"hash": "x"},
rotation_type="test", rotation_hash_key=None)
mock_rotation = MagicMock()
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=None, record_rotation=mock_rotation)
mock_rotation.assert_not_called()
class TestRunProbeNoneResult:
def test_none_suppresses_event(self, tmp_path: Path):
probe = _make_probe(default_ports=[443], run_return=None)
ip_probed: dict[str, set[int]] = {}
json_path = tmp_path / "events.json"
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "events.log", json_path,
timeout=1.0, publish_fn=None, record_rotation=None)
assert 443 in ip_probed["test_probe"]
assert not json_path.exists()
def test_none_suppresses_publish(self, tmp_path: Path):
probe = _make_probe(default_ports=[443], run_return=None)
published: list = []
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=lambda k, v: published.append((k, v)),
record_rotation=None)
assert len(published) == 0
class TestRunProbeExceptionPath:
def test_exception_marks_port_done(self, tmp_path: Path):
probe = _make_probe(default_ports=[443],
run_side_effect=OSError("Connection refused"))
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=None, record_rotation=None)
assert 443 in ip_probed["test_probe"]
def test_exception_writes_prober_error_event(self, tmp_path: Path):
probe = _make_probe(default_ports=[443],
run_side_effect=OSError("refused"))
ip_probed: dict[str, set[int]] = {}
json_path = tmp_path / "events.json"
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "events.log", json_path,
timeout=1.0, publish_fn=None, record_rotation=None)
assert json_path.exists()
record = json.loads(json_path.read_text().strip())
assert record["event_type"] == "prober_error"
assert record["fields"]["target_ip"] == "1.2.3.4"
assert "refused" in record["fields"]["error"]
def test_exception_does_not_publish(self, tmp_path: Path):
probe = _make_probe(default_ports=[443],
run_side_effect=RuntimeError("boom"))
published: list = []
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=lambda k, v: published.append((k, v)),
record_rotation=None)
assert len(published) == 0
def test_continues_remaining_ports_after_exception(self, tmp_path: Path):
call_count = 0
class _CountProbe(ActiveProbe):
probe_name = "_count_probe"
default_ports = [80, 443, 8080]
event_type = "_count_event"
priority = 100
def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None:
nonlocal call_count
call_count += 1
if port == 443:
raise OSError("refused")
return None
def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]:
return {}, ""
def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]:
return {}
probe = _CountProbe()
ip_probed: dict[str, set[int]] = {}
_run_probe(probe, "1.2.3.4", ip_probed, tmp_path / "a.log", tmp_path / "a.json",
timeout=1.0, publish_fn=None, record_rotation=None)
assert call_count == 3 # all three ports attempted
assert {80, 443, 8080} == ip_probed["_count_probe"]