refactor(prober): ActiveProbe ABC + ActiveProbeMeta registry

Replace _jarm_phase / _hassh_phase / _tcpfp_phase boilerplate (3×~50
lines of identical port-iteration logic) with a metaclass-registered ABC.
Adding a new port-iterating active probe is now one class + three methods.

- decnet/prober/base.py: ActiveProbeMeta auto-registers subclasses by
  probe_name; ActiveProbe ABC enforces run/syslog_fields/publish_payload
  with env-driven DECNET_PROBE_PORTS_<NAME> port override.
- decnet/prober/probes/{jarm,hassh,tcpfp}.py: concrete probe classes.
- decnet/prober/worker.py: single _run_probe driver replaces the three
  phase functions; _probe_cycle iterates ActiveProbeMeta.all(); drops
  the ports=/ssh_ports=/tcpfp_ports= kwargs from prober_worker.
- IPv6 leak and TLS cert capture stay as special cases (different call
  shapes; intentionally outside the registry).
- tests/prober/test_active_probe_registry.py: registry contents, sort
  order, priority-10 override, ABC contract per probe class.
- tests/prober/test_run_probe_driver.py: dedup, success, None-skip,
  exception, rotation, publish paths for _run_probe.
- tests/prober/test_prober_worker.py: updated patch targets and
  _probe_cycle call sites; port control via monkeypatch.setattr.
This commit is contained in:
2026-05-17 23:16:35 -04:00
parent 3977f06374
commit 916b21b652
9 changed files with 810 additions and 339 deletions

90
decnet/prober/base.py Normal file
View File

@@ -0,0 +1,90 @@
"""
ActiveProbe ABC and metaclass registry for port-iterating active probes.
Adding a new active probe = one class with three methods.
IPv6 leak and TLS cert capture are NOT part of this registry (different
call shapes); they stay as special cases in prober/worker.py.
"""
from __future__ import annotations
import os
from abc import ABCMeta, abstractmethod
from typing import Any
from decnet.correlation.fingerprint_rotation import ProbeType
class ActiveProbeMeta(ABCMeta):
    """Metaclass keeping a name -> class registry of ActiveProbe subclasses.

    Every class created through this metaclass that has at least one base
    and a truthy ``probe_name`` is stored under that name. A later class
    using an already-registered name replaces the earlier entry.
    """

    # One registry shared by every class built with this metaclass.
    _registry: dict[str, type[ActiveProbe]] = {}

    def __new__(
        mcs,
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, Any],
    ) -> ActiveProbeMeta:
        """Create the class and register it when it names a probe."""
        new_cls = super().__new__(mcs, name, bases, namespace)
        # A class with no bases is the root of the hierarchy, not a probe.
        if not bases:
            return new_cls
        if getattr(new_cls, "probe_name", None):
            mcs._registry[new_cls.probe_name] = new_cls  # type: ignore[attr-defined,assignment]
        return new_cls

    @classmethod
    def all(mcs) -> list[type[ActiveProbe]]:
        """Return registered probes sorted by (priority asc, probe_name asc)."""

        def _sort_key(probe_cls: type[ActiveProbe]) -> tuple[int, str]:
            return (probe_cls.priority, probe_cls.probe_name)

        return sorted(mcs._registry.values(), key=_sort_key)
class ActiveProbe(metaclass=ActiveProbeMeta):
    """Base class for all port-iterating active probes.

    Subclasses declare the class-level attributes below and implement
    ``run``, ``syslog_fields`` and ``publish_payload``. Registration is
    automatic via ActiveProbeMeta.

    Port override: set DECNET_PROBE_PORTS_<NAME_UPPER> (comma-separated) to
    override default_ports at runtime without touching the class. An
    override that cannot be parsed, names a port outside 1-65535, or lists
    no ports at all is ignored with a warning and default_ports is used.
    """

    # Registry key; also the suffix of the port-override environment variable.
    probe_name: str
    # Ports probed when no valid environment override is present.
    default_ports: list[int]
    # Event type the worker's _run_probe passes to _write_event on success.
    event_type: str
    # Both rotation attributes must be set for _run_probe to call the
    # rotation recorder; rotation_hash_key indexes into the run() result.
    rotation_type: ProbeType | None = None
    rotation_hash_key: str | None = None
    # Sort key for ActiveProbeMeta.all(): lower values sort first.
    priority: int = 100

    def __init__(self) -> None:
        """Resolve the effective port list (environment override or defaults)."""
        env_key = f"DECNET_PROBE_PORTS_{self.probe_name.upper()}"
        override = self._parse_port_override(env_key, os.environ.get(env_key, ""))
        self._ports: list[int] = (
            override if override is not None else list(self.default_ports)
        )

    @staticmethod
    def _parse_port_override(env_key: str, raw: str) -> list[int] | None:
        """Parse a comma-separated port override.

        Args:
            env_key: name of the environment variable, used in the warning.
            raw: the variable's raw value ("" when unset).

        Returns:
            The parsed ports, or None when the variable is unset/blank or
            the value is unusable (a warning is logged in the latter case).
        """
        import logging  # function-scope: this module defines no logger

        text = raw.strip()
        if not text:
            return None
        tokens = [tok.strip() for tok in text.split(",") if tok.strip()]
        try:
            ports: list[int] | None = [int(tok) for tok in tokens]
        except ValueError:
            ports = None
        # Reject empty lists (e.g. ",") and ports no TCP stack can use, so a
        # typo cannot silently disable the probe or fail on every target.
        if not ports or any(not 1 <= p <= 65535 for p in ports):
            logging.getLogger(__name__).warning(
                "ignoring invalid %s=%r; using default ports", env_key, raw
            )
            return None
        return ports

    @property
    def ports(self) -> list[int]:
        """Ports this probe instance iterates over."""
        return self._ports

    @abstractmethod
    def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None:
        """Execute the probe against ip:port.

        Return a result dict on success, or None to suppress emission (e.g.
        empty JARM hash means the port doesn't speak TLS).
        """

    @abstractmethod
    def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]:
        """Return (sd_fields, human_msg) for _write_event.

        target_ip and target_port are injected by _run_probe; do not include
        them in sd_fields.
        """

    @abstractmethod
    def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]:
        """Return the bus payload dict for attacker.fingerprinted events."""

View File

@@ -0,0 +1,4 @@
# Import all probe modules to trigger ActiveProbeMeta registration.
from decnet.prober.probes.hassh import HasshProbe as HasshProbe
from decnet.prober.probes.jarm import JarmProbe as JarmProbe
from decnet.prober.probes.tcpfp import TcpfpProbe as TcpfpProbe

View File

@@ -0,0 +1,41 @@
from __future__ import annotations
from typing import Any
from decnet.prober.base import ActiveProbe
from decnet.prober.hassh import hassh_server
from decnet.telemetry import traced as _traced
DEFAULT_PORTS: list[int] = [22, 2222, 22222, 2022]
class HasshProbe(ActiveProbe):
    """HASSHServer probe, registered under the name ``hassh``."""

    probe_name = "hassh"
    default_ports = DEFAULT_PORTS
    event_type = "hassh_fingerprint"
    rotation_type = "hassh"
    rotation_hash_key = "hassh_server"
    priority = 100

    @_traced("prober.hassh_probe")
    def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None:
        """Delegate to hassh_server and hand its result back unchanged."""
        fingerprint = hassh_server(ip, port, timeout=timeout)
        return fingerprint

    def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]:
        """Return the structured-data fields and the human-readable message."""
        # (emitted field name, key in the probe result), in emission order.
        field_sources = (
            ("hassh_server_hash", "hassh_server"),
            ("ssh_banner", "banner"),
            ("kex_algorithms", "kex_algorithms"),
            ("encryption_s2c", "encryption_s2c"),
            ("mac_s2c", "mac_s2c"),
            ("compression_s2c", "compression_s2c"),
        )
        sd_fields = {
            field_name: result[source_key]
            for field_name, source_key in field_sources
        }
        summary = f"HASSH {ip}:{port} = {result['hassh_server']}"
        return sd_fields, summary

    def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]:
        """Return the bus payload: target address, port, hash and banner."""
        payload: dict[str, Any] = {"attacker_ip": ip, "port": port}
        payload["hassh_server"] = result["hassh_server"]
        payload["ssh_banner"] = result["banner"]
        return payload

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from typing import Any
from decnet.prober.base import ActiveProbe
from decnet.prober.jarm import JARM_EMPTY_HASH, jarm_hash
from decnet.telemetry import traced as _traced
DEFAULT_PORTS: list[int] = [443, 8443, 8080, 4443, 50050, 2222, 993, 995, 8888, 9001]
class JarmProbe(ActiveProbe):
    """JARM probe, registered under the name ``jarm``."""

    probe_name = "jarm"
    default_ports = DEFAULT_PORTS
    event_type = "jarm_fingerprint"
    rotation_type = "jarm"
    rotation_hash_key = "jarm_hash"
    priority = 100

    @_traced("prober.jarm_probe")
    def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None:
        """Compute the JARM hash; return None when it equals JARM_EMPTY_HASH."""
        digest = jarm_hash(ip, port, timeout=timeout)
        return None if digest == JARM_EMPTY_HASH else {"jarm_hash": digest}

    def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]:
        """Return the structured-data fields and the human-readable message."""
        digest = result["jarm_hash"]
        sd_fields = {"jarm_hash": digest}
        summary = f"JARM {ip}:{port} = {digest}"
        return sd_fields, summary

    def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]:
        """Return the bus payload: target address, port and JARM hash."""
        payload: dict[str, Any] = {"attacker_ip": ip, "port": port}
        payload["jarm_hash"] = result["jarm_hash"]
        return payload

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
from typing import Any
from decnet.prober.base import ActiveProbe
from decnet.prober.tcpfp import tcp_fingerprint
from decnet.telemetry import traced as _traced
DEFAULT_PORTS: list[int] = [22, 80, 443, 8080, 8443, 445, 3389]
class TcpfpProbe(ActiveProbe):
    """TCP/IP stack fingerprint probe, registered under the name ``tcpfp``."""

    probe_name = "tcpfp"
    default_ports = DEFAULT_PORTS
    event_type = "tcpfp_fingerprint"
    rotation_type = "tcpfp"
    rotation_hash_key = "tcpfp_hash"
    priority = 100

    @_traced("prober.tcpfp_probe")
    def run(self, ip: str, port: int, timeout: float) -> dict[str, Any] | None:
        """Delegate to tcp_fingerprint and hand its result back unchanged."""
        fingerprint = tcp_fingerprint(ip, port, timeout=timeout)
        return fingerprint

    def syslog_fields(self, ip: str, port: int, result: dict[str, Any]) -> tuple[dict[str, Any], str]:
        """Return the structured-data fields and the human-readable message."""
        # Result keys copied into the event, in emission order.
        ordered_keys = (
            "tcpfp_hash",
            "tcpfp_raw",
            "ttl",
            "window_size",
            "df_bit",
            "mss",
            "window_scale",
            "sack_ok",
            "timestamp",
            "options_order",
            "tos",
            "dscp",
            "ecn",
            "server_isn",
        )
        # These are emitted as-is; every other value is passed through str().
        verbatim = ("tcpfp_hash", "tcpfp_raw", "options_order")
        sd_fields: dict[str, Any] = {}
        for key in ordered_keys:
            value = result[key]
            sd_fields[key] = value if key in verbatim else str(value)
        summary = f"TCPFP {ip}:{port} = {result['tcpfp_hash']}"
        return sd_fields, summary

    def publish_payload(self, ip: str, port: int, result: dict[str, Any]) -> dict[str, Any]:
        """Return the bus payload: target address, port, hash, TTL and MSS."""
        payload: dict[str, Any] = {"attacker_ip": ip, "port": port}
        for key in ("tcpfp_hash", "ttl", "mss"):
            payload[key] = result[key]
        return payload

View File

@@ -43,9 +43,8 @@ from decnet.correlation.fingerprint_rotation import (
record_fingerprint,
)
from decnet.logging import get_logger
from decnet.prober.hassh import hassh_server
from decnet.prober.jarm import JARM_EMPTY_HASH, jarm_hash
from decnet.prober.tcpfp import tcp_fingerprint
from decnet.prober.base import ActiveProbe, ActiveProbeMeta
import decnet.prober.probes as _probes # noqa: F401 — triggers metaclass registration
from decnet.prober.tlscert import fetch_leaf_cert
from decnet.telemetry import traced as _traced
@@ -66,20 +65,6 @@ def _build_sync_engine() -> Engine:
db_path = os.environ.get("DECNET_DB_PATH", str(_ROOT / "decnet.db"))
return get_sync_engine(db_path)
# ─── Default ports per probe type ───────────────────────────────────────────
# JARM: common C2 callback / TLS server ports
DEFAULT_PROBE_PORTS: list[int] = [
443, 8443, 8080, 4443, 50050, 2222, 993, 995, 8888, 9001,
]
# HASSHServer: common SSH server ports
DEFAULT_SSH_PORTS: list[int] = [22, 2222, 22222, 2022]
# TCP/IP stack: probe on ports commonly open on attacker machines.
# Wide spread gives the best chance of a SYN-ACK for TTL/fingerprint extraction.
DEFAULT_TCPFP_PORTS: list[int] = [22, 80, 443, 8080, 8443, 445, 3389]
# ─── RFC 5424 formatting (inline, mirrors templates/*/decnet_logging.py) ─────
_FACILITY_LOCAL0 = 16
@@ -259,94 +244,49 @@ ProbePublishFn = Callable[[str, dict[str, Any]], None]
# performs the rotation-detection upsert + derived-event emission for the
# DEBT-032 substrate-fingerprint flow. Optional; when None the prober
# behaves exactly as before (raw fingerprint emit only, no rotation
# detection). Construction lives at worker startup so phase functions
# don't have to know about the DB engine.
# detection). Construction lives at worker startup so the probe driver
# doesn't have to know about the DB engine.
RotationRecorderFn = Callable[[str, int, "ProbeType", str], None]
@_traced("prober.probe_cycle")
def _probe_cycle(
targets: set[str],
probed: dict[str, dict[str, set[int]]],
jarm_ports: list[int],
ssh_ports: list[int],
tcpfp_ports: list[int],
log_path: Path,
json_path: Path,
timeout: float = 5.0,
publish_fn: ProbePublishFn | None = None,
record_rotation: RotationRecorderFn | None = None,
) -> None:
"""
Probe all known attacker IPs with JARM, HASSH, and TCP/IP fingerprinting.
Args:
targets: set of attacker IPs to probe
probed: dict mapping IP -> {probe_type -> set of ports already probed}
jarm_ports: TLS ports for JARM fingerprinting
ssh_ports: SSH ports for HASSHServer fingerprinting
tcpfp_ports: ports for TCP/IP stack fingerprinting
log_path: RFC 5424 log file
json_path: JSON log file
timeout: per-probe TCP timeout
"""
for ip in sorted(targets):
ip_probed = probed.setdefault(ip, {})
# Phase 1: JARM (TLS fingerprinting)
_jarm_phase(ip, ip_probed, jarm_ports, log_path, json_path, timeout, publish_fn, record_rotation)
# Phase 2: HASSHServer (SSH fingerprinting)
_hassh_phase(ip, ip_probed, ssh_ports, log_path, json_path, timeout, publish_fn, record_rotation)
# Phase 3: TCP/IP stack fingerprinting
_tcpfp_phase(ip, ip_probed, tcpfp_ports, log_path, json_path, timeout, publish_fn, record_rotation)
# Phase 4: IPv6 link-local leak (active ICMPv6 solicitation; on-link only)
_ipv6_leak_phase(ip, ip_probed, log_path, json_path, timeout, publish_fn)
@_traced("prober.jarm_phase")
def _jarm_phase(
@_traced("prober.run_probe")
def _run_probe(
probe: ActiveProbe,
ip: str,
ip_probed: dict[str, set[int]],
ports: list[int],
log_path: Path,
json_path: Path,
timeout: float,
publish_fn: ProbePublishFn | None = None,
record_rotation: RotationRecorderFn | None = None,
publish_fn: ProbePublishFn | None,
record_rotation: RotationRecorderFn | None,
) -> None:
"""JARM-fingerprint an IP on the given TLS ports."""
done = ip_probed.setdefault("jarm", set())
for port in ports:
"""Generic driver for any port-iterating ActiveProbe."""
done = ip_probed.setdefault(probe.probe_name, set())
for port in probe.ports:
if port in done:
continue
try:
h = jarm_hash(ip, port, timeout=timeout)
result = probe.run(ip, port, timeout)
done.add(port)
if h == JARM_EMPTY_HASH:
if result is None:
continue
fields, msg = probe.syslog_fields(ip, port, result)
_write_event(
log_path, json_path,
"jarm_fingerprint",
probe.event_type,
target_ip=ip,
target_port=str(port),
jarm_hash=h,
msg=f"JARM {ip}:{port} = {h}",
msg=msg,
**fields,
)
logger.info("prober: JARM %s:%d = %s", ip, port, h)
if record_rotation is not None:
record_rotation(ip, port, "jarm", h)
logger.info("prober: %s %s:%d ok", probe.probe_name, ip, port)
if record_rotation is not None and probe.rotation_type and probe.rotation_hash_key:
record_rotation(ip, port, probe.rotation_type, result[probe.rotation_hash_key])
if publish_fn is not None:
publish_fn(
"jarm",
{"attacker_ip": ip, "port": port, "jarm_hash": h},
)
# Cert capture: a non-empty JARM hash proves the port speaks
# TLS, so a follow-up real handshake is worth attempting.
# Failures are silent — the next probe target must not stall.
_capture_tls_cert(ip, port, log_path, json_path, timeout, publish_fn)
publish_fn(probe.probe_name, probe.publish_payload(ip, port, result))
if probe.probe_name == "jarm":
# A non-empty JARM hash proves TLS; attempt a real cert capture.
_capture_tls_cert(ip, port, log_path, json_path, timeout, publish_fn)
except Exception as exc:
done.add(port)
_write_event(
@@ -356,9 +296,34 @@ def _jarm_phase(
target_ip=ip,
target_port=str(port),
error=str(exc),
msg=f"JARM probe failed for {ip}:{port}: {exc}",
msg=f"{probe.probe_name} probe failed for {ip}:{port}: {exc}",
)
logger.warning("prober: JARM probe failed %s:%d: %s", ip, port, exc)
logger.warning("prober: %s probe failed %s:%d: %s", probe.probe_name, ip, port, exc)
@_traced("prober.probe_cycle")
def _probe_cycle(
    targets: set[str],
    probed: dict[str, dict[str, set[int]]],
    log_path: Path,
    json_path: Path,
    timeout: float = 5.0,
    publish_fn: ProbePublishFn | None = None,
    record_rotation: RotationRecorderFn | None = None,
) -> None:
    """Probe all known attacker IPs via every registered ActiveProbe.

    Probes run in (priority, probe_name) order per ActiveProbeMeta.all().
    IPv6 leak runs last — it is not port-iterating and stays a special case.

    Args:
        targets: set of attacker IPs to probe
        probed: dict mapping IP -> {probe_name -> set of ports already probed};
            updated in place
        log_path: RFC 5424 log file
        json_path: JSON log file
        timeout: per-probe TCP timeout
        publish_fn: optional bus publisher, forwarded to each probe run
        record_rotation: optional rotation recorder, forwarded to each probe run
    """
    if not targets:
        return
    # A probe instance only holds its resolved port list, so sort the
    # registry and build the instances (which parse the environment) once
    # per cycle instead of once per (IP, probe) pair.
    probes = [probe_cls() for probe_cls in ActiveProbeMeta.all()]
    for ip in sorted(targets):
        ip_probed = probed.setdefault(ip, {})
        for probe in probes:
            _run_probe(probe, ip, ip_probed, log_path, json_path,
                       timeout, publish_fn, record_rotation)
        _ipv6_leak_phase(ip, ip_probed, log_path, json_path, timeout, publish_fn)
@_traced("prober.tls_cert_capture")
@@ -415,137 +380,6 @@ def _capture_tls_cert(
)
@_traced("prober.hassh_phase")
def _hassh_phase(
ip: str,
ip_probed: dict[str, set[int]],
ports: list[int],
log_path: Path,
json_path: Path,
timeout: float,
publish_fn: ProbePublishFn | None = None,
record_rotation: RotationRecorderFn | None = None,
) -> None:
"""HASSHServer-fingerprint an IP on the given SSH ports."""
done = ip_probed.setdefault("hassh", set())
for port in ports:
if port in done:
continue
try:
result = hassh_server(ip, port, timeout=timeout)
done.add(port)
if result is None:
continue
_write_event(
log_path, json_path,
"hassh_fingerprint",
target_ip=ip,
target_port=str(port),
hassh_server_hash=result["hassh_server"],
ssh_banner=result["banner"],
kex_algorithms=result["kex_algorithms"],
encryption_s2c=result["encryption_s2c"],
mac_s2c=result["mac_s2c"],
compression_s2c=result["compression_s2c"],
msg=f"HASSH {ip}:{port} = {result['hassh_server']}",
)
logger.info("prober: HASSH %s:%d = %s", ip, port, result["hassh_server"])
if record_rotation is not None:
record_rotation(ip, port, "hassh", result["hassh_server"])
if publish_fn is not None:
publish_fn(
"hassh",
{
"attacker_ip": ip,
"port": port,
"hassh_server": result["hassh_server"],
"ssh_banner": result["banner"],
},
)
except Exception as exc:
done.add(port)
_write_event(
log_path, json_path,
"prober_error",
severity=_SEVERITY_WARNING,
target_ip=ip,
target_port=str(port),
error=str(exc),
msg=f"HASSH probe failed for {ip}:{port}: {exc}",
)
logger.warning("prober: HASSH probe failed %s:%d: %s", ip, port, exc)
@_traced("prober.tcpfp_phase")
def _tcpfp_phase(
ip: str,
ip_probed: dict[str, set[int]],
ports: list[int],
log_path: Path,
json_path: Path,
timeout: float,
publish_fn: ProbePublishFn | None = None,
record_rotation: RotationRecorderFn | None = None,
) -> None:
"""TCP/IP stack fingerprint an IP on the given ports."""
done = ip_probed.setdefault("tcpfp", set())
for port in ports:
if port in done:
continue
try:
result = tcp_fingerprint(ip, port, timeout=timeout)
done.add(port)
if result is None:
continue
_write_event(
log_path, json_path,
"tcpfp_fingerprint",
target_ip=ip,
target_port=str(port),
tcpfp_hash=result["tcpfp_hash"],
tcpfp_raw=result["tcpfp_raw"],
ttl=str(result["ttl"]),
window_size=str(result["window_size"]),
df_bit=str(result["df_bit"]),
mss=str(result["mss"]),
window_scale=str(result["window_scale"]),
sack_ok=str(result["sack_ok"]),
timestamp=str(result["timestamp"]),
options_order=result["options_order"],
tos=str(result["tos"]),
dscp=str(result["dscp"]),
ecn=str(result["ecn"]),
server_isn=str(result["server_isn"]),
msg=f"TCPFP {ip}:{port} = {result['tcpfp_hash']}",
)
logger.info("prober: TCPFP %s:%d = %s", ip, port, result["tcpfp_hash"])
if record_rotation is not None:
record_rotation(ip, port, "tcpfp", result["tcpfp_hash"])
if publish_fn is not None:
publish_fn(
"tcpfp",
{
"attacker_ip": ip,
"port": port,
"tcpfp_hash": result["tcpfp_hash"],
"ttl": result["ttl"],
"mss": result["mss"],
},
)
except Exception as exc:
done.add(port)
_write_event(
log_path, json_path,
"prober_error",
severity=_SEVERITY_WARNING,
target_ip=ip,
target_port=str(port),
error=str(exc),
msg=f"TCPFP probe failed for {ip}:{port}: {exc}",
)
logger.warning("prober: TCPFP probe failed %s:%d: %s", ip, port, exc)
@_traced("prober.ipv6_leak_phase")
def _ipv6_leak_phase(
ip: str,
@@ -622,49 +456,43 @@ async def prober_worker(
log_file: str,
interval: int = 300,
timeout: float = 5.0,
ports: list[int] | None = None,
ssh_ports: list[int] | None = None,
tcpfp_ports: list[int] | None = None,
) -> None:
"""
Main entry point for the standalone prober process.
Discovers attacker IPs automatically by tailing the JSON log file,
then fingerprints each IP via JARM, HASSH, and TCP/IP stack probes.
then fingerprints each IP via every registered ActiveProbe (JARM,
HASSH, TCP/IP stack) plus the IPv6 leak special case.
Per-probe port lists are taken from each probe's ``default_ports``
attribute. Override at runtime via DECNET_PROBE_PORTS_<NAME_UPPER>
(comma-separated), e.g. DECNET_PROBE_PORTS_JARM="443,8443".
Args:
log_file: base path for log files (RFC 5424 to .log, JSON to .json)
interval: seconds between probe cycles
timeout: per-probe TCP timeout
ports: JARM TLS ports (defaults to DEFAULT_PROBE_PORTS)
ssh_ports: HASSH SSH ports (defaults to DEFAULT_SSH_PORTS)
tcpfp_ports: TCP fingerprint ports (defaults to DEFAULT_TCPFP_PORTS)
"""
jarm_ports = ports or DEFAULT_PROBE_PORTS
hassh_ports = ssh_ports or DEFAULT_SSH_PORTS
tcp_ports = tcpfp_ports or DEFAULT_TCPFP_PORTS
all_ports_str = (
f"jarm={','.join(str(p) for p in jarm_ports)} "
f"ssh={','.join(str(p) for p in hassh_ports)} "
f"tcpfp={','.join(str(p) for p in tcp_ports)}"
)
log_path = Path(log_file)
json_path = log_path.with_suffix(".json")
log_path.parent.mkdir(parents=True, exist_ok=True)
probe_summary = " ".join(
f"{cls.probe_name}={','.join(str(p) for p in cls().ports)}"
for cls in ActiveProbeMeta.all()
)
logger.info(
"prober started interval=%ds %s log=%s",
interval, all_ports_str, log_path,
interval, probe_summary, log_path,
)
_write_event(
log_path, json_path,
"prober_startup",
interval=str(interval),
probe_ports=all_ports_str,
msg=f"DECNET-PROBER started, interval {interval}s, {all_ports_str}",
probe_ports=probe_summary,
msg=f"DECNET-PROBER started, interval {interval}s, {probe_summary}",
)
known_attackers: set[str] = set()
@@ -776,7 +604,6 @@ async def prober_worker(
if known_attackers:
await asyncio.to_thread(
_probe_cycle, known_attackers, probed,
jarm_ports, hassh_ports, tcp_ports,
log_path, json_path, timeout,
_publish_attacker,
record_rotation,