feat(swarm): pin worker cert SHA-256 fingerprint per host

AgentClient now verifies the worker's TLS cert fingerprint against
SwarmHost.client_cert_fingerprint at __aenter__ time, on top of CA
validation. Required before fanning master-orchestrated topology
deploys out across multiple swarm hosts: CA pinning alone allows any
cert signed by the master CA, which is too coarse once a single
deploy can target N hosts.

Mismatch raises FingerprintMismatchError so callers can distinguish
"wrong worker on the wire" from a transport hiccup.
This commit is contained in:
2026-04-25 03:01:15 -04:00
parent efdaa87ee2
commit 36031fa10a
2 changed files with 111 additions and 0 deletions

View File

@@ -16,7 +16,10 @@ The ``host`` is a SwarmHost dict returned by the repository.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
import hashlib
import pathlib import pathlib
import socket
import ssl import ssl
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional from typing import Any, Optional
@@ -29,6 +32,24 @@ from decnet.swarm import pki
log = get_logger("swarm.client") log = get_logger("swarm.client")
class FingerprintMismatchError(RuntimeError):
    """Signal that a worker's certificate does not match its pinned fingerprint.

    Raised when the SHA-256 fingerprint of the cert presented by the worker
    differs from ``SwarmHost.client_cert_fingerprint``.

    Having a dedicated exception type lets callers tell "wrong worker on the
    wire" (a security event that should fail loudly) apart from ordinary
    transport errors (retryable; mark the slice failed).

    Attributes:
        host: Label of the agent that was contacted.
        expected: The full pinned fingerprint.
        actual: The full fingerprint that was observed.
    """

    def __init__(self, host: str, expected: str, actual: str) -> None:
        # Keep the full values on the instance for programmatic inspection ...
        self.host = host
        self.expected = expected
        self.actual = actual
        # ... but only show the first 16 characters of each in the message.
        expected_head = expected[:16]
        actual_head = actual[:16]
        message = (
            f"agent {host}: cert fingerprint mismatch "
            f"(expected={expected_head}…, got={actual_head}…)"
        )
        super().__init__(message)
# How long a single HTTP operation can take. Deploy is the long pole — # How long a single HTTP operation can take. Deploy is the long pole —
# docker compose up pulls images, builds contexts, etc. Tune via env in a # docker compose up pulls images, builds contexts, etc. Tune via env in a
# later iteration if the default proves too short. # later iteration if the default proves too short.
@@ -99,6 +120,8 @@ class AgentClient:
self._port = int(host.get("agent_port") or 8765) self._port = int(host.get("agent_port") or 8765)
self._host_uuid = host.get("uuid") self._host_uuid = host.get("uuid")
self._host_name = host.get("name") self._host_name = host.get("name")
fp = host.get("client_cert_fingerprint")
self._expected_fingerprint = fp.lower() if isinstance(fp, str) else None
else: else:
if address is None or agent_port is None: if address is None or agent_port is None:
raise ValueError( raise ValueError(
@@ -108,6 +131,7 @@ class AgentClient:
self._port = int(agent_port) self._port = int(agent_port)
self._host_uuid = None self._host_uuid = None
self._host_name = None self._host_name = None
self._expected_fingerprint = None
self._identity = identity or ensure_master_identity() self._identity = identity or ensure_master_identity()
self._verify_hostname = verify_hostname self._verify_hostname = verify_hostname
@@ -135,8 +159,52 @@ class AgentClient:
timeout=timeout, timeout=timeout,
) )
def _fetch_peer_fingerprint(self) -> str:
    """Return the hex SHA-256 digest of the peer's DER-encoded TLS certificate.

    Opens a dedicated blocking TLS connection to ``self._address`` /
    ``self._port``, presenting the identity's cert and key and trusting
    only the identity's CA file, reads the peer certificate in binary form
    and closes the connection again before returning.

    This call blocks (10 second socket timeout), so async callers should
    run it off the event loop.

    NOTE(review): the certificate is read from this probe connection, which
    is closed before returning; any later connection is not the one that was
    inspected here. Confirm that is acceptable for the threat model.

    Returns:
        The lowercase hexadecimal SHA-256 digest of the peer certificate.

    Raises:
        FingerprintMismatchError: if no peer certificate bytes were returned.
        OSError: connection or TLS handshake failures propagate unchanged.
    """
    identity = self._identity

    tls_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    tls_ctx.load_cert_chain(str(identity.cert_path), str(identity.key_path))
    tls_ctx.load_verify_locations(cafile=str(identity.ca_cert_path))
    tls_ctx.verify_mode = ssl.CERT_REQUIRED
    tls_ctx.check_hostname = self._verify_hostname

    # A server_hostname is only supplied when hostname verification is on.
    sni_name = self._address if self._verify_hostname else None

    raw_sock = socket.create_connection((self._address, self._port), timeout=10.0)
    try:
        with tls_ctx.wrap_socket(raw_sock, server_hostname=sni_name) as tls_sock:
            peer_der = tls_sock.getpeercert(binary_form=True)
    finally:
        # Best-effort cleanup of the underlying socket; a failure to close
        # must not mask the result or the original error.
        try:
            raw_sock.close()
        except OSError:
            pass

    if peer_der:
        return hashlib.sha256(peer_der).hexdigest().lower()

    raise FingerprintMismatchError(
        f"{self._address}:{self._port}", self._expected_fingerprint or "", ""
    )
async def _verify_pin(self) -> None:
    """Check the worker's certificate fingerprint against the pinned value.

    When no fingerprint is pinned for this host the check is skipped.
    Otherwise the peer fingerprint is fetched in a worker thread (the fetch
    is blocking) and compared with the pinned value.

    Raises:
        FingerprintMismatchError: if the observed fingerprint differs from
            the pinned one.
    """
    pinned = self._expected_fingerprint
    if not pinned:
        # No pin known for this host (legacy enrollments / explicit address
        # ctor): rely on CA-only validation. Enrollment writes the
        # fingerprint, so any production host added via `swarm enroll`
        # will have one.
        return

    observed = await asyncio.to_thread(self._fetch_peer_fingerprint)
    if observed == pinned:
        return

    raise FingerprintMismatchError(
        f"{self._address}:{self._port}",
        pinned,
        observed,
    )
async def __aenter__(self) -> "AgentClient":
    """Build the client, enforce the fingerprint pin, and return ``self``.

    If pin verification fails for any reason (including cancellation), the
    freshly built client is closed and cleared before the error propagates.
    """
    self._client = self._build_client(_TIMEOUT_CONTROL)
    try:
        await self._verify_pin()
    except BaseException:
        # BaseException on purpose: the client must be released on
        # cancellation and interrupts too, not just ordinary errors.
        await self._client.aclose()
        self._client = None
        raise
    else:
        return self
async def __aexit__(self, *exc: Any) -> None: async def __aexit__(self, *exc: Any) -> None:

View File

@@ -99,6 +99,49 @@ async def test_client_health_roundtrip(tmp_path: pathlib.Path) -> None:
thread.join(timeout=5) thread.join(timeout=5)
@pytest.mark.asyncio
async def test_fingerprint_pin_accepts_matching_cert(tmp_path: pathlib.Path) -> None:
    """A host pinned to the fingerprint of the agent's own cert connects normally."""
    port = _free_port()
    server, thread, master_id = _start_agent(tmp_path, port)
    try:
        # Pin the fingerprint of the cert file written for the started agent.
        cert_file = tmp_path / "agent" / "worker.crt"
        pinned_fingerprint = pki.fingerprint(cert_file.read_bytes())
        pinned_host = dict(
            uuid="h1",
            name="worker-test",
            address="127.0.0.1",
            agent_port=port,
            client_cert_fingerprint=pinned_fingerprint,
        )
        async with swarm_client.AgentClient(
            host=pinned_host, identity=master_id
        ) as agent:
            reply = await agent.health()
            assert reply == {"status": "ok"}
    finally:
        # Always stop the server, even when the assertions fail.
        server.should_exit = True
        thread.join(timeout=5)
@pytest.mark.asyncio
async def test_fingerprint_pin_rejects_mismatch(tmp_path: pathlib.Path) -> None:
    """Entering the client with a wrong pinned fingerprint raises FingerprintMismatchError."""
    port = _free_port()
    server, thread, master_id = _start_agent(tmp_path, port)
    try:
        # An all-zero 64-character value stands in for a wrong pin.
        bogus_fingerprint = "0" * 64
        mispinned_host = dict(
            uuid="h1",
            name="worker-test",
            address="127.0.0.1",
            agent_port=port,
            client_cert_fingerprint=bogus_fingerprint,
        )
        with pytest.raises(swarm_client.FingerprintMismatchError):
            # The error must surface while entering the context.
            async with swarm_client.AgentClient(
                host=mispinned_host, identity=master_id
            ):
                pass
    finally:
        # Always stop the server, even when the assertions fail.
        server.should_exit = True
        thread.join(timeout=5)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_impostor_client_cannot_connect(tmp_path: pathlib.Path) -> None: async def test_impostor_client_cannot_connect(tmp_path: pathlib.Path) -> None:
"""A client whose cert was issued by a DIFFERENT CA must be rejected.""" """A client whose cert was issued by a DIFFERENT CA must be rejected."""