feat(swarm): pin worker cert SHA-256 fingerprint per host
AgentClient now verifies the worker's TLS cert fingerprint against SwarmHost.client_cert_fingerprint at __aenter__ time, on top of CA validation. Required before fanning master-orchestrated topology deploys out across multiple swarm hosts: CA pinning alone allows any cert signed by the master CA, which is too coarse once a single deploy can target N hosts. Mismatch raises FingerprintMismatchError so callers can distinguish "wrong worker on the wire" from a transport hiccup.
This commit is contained in:
@@ -16,7 +16,10 @@ The ``host`` is a SwarmHost dict returned by the repository.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import pathlib
|
||||
import socket
|
||||
import ssl
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
@@ -29,6 +32,24 @@ from decnet.swarm import pki
|
||||
|
||||
log = get_logger("swarm.client")
|
||||
|
||||
|
||||
class FingerprintMismatchError(RuntimeError):
    """Raised when the worker presents a cert whose SHA-256 fingerprint
    does not match ``SwarmHost.client_cert_fingerprint``.

    Existence of this error class is the contract that lets the deployer
    distinguish "wrong worker on the wire" (security event, fail loud)
    from generic transport errors (retryable, mark slice failed).

    The human-readable message only carries the first 16 characters of
    each fingerprint; the untruncated values stay available on the
    instance.

    Attributes:
        host: Label of the agent the check was made against, as passed in
            by the caller.
        expected: The fingerprint the caller expected to see.
        actual: The fingerprint that was actually observed (may be empty).
    """

    def __init__(self, host: str, expected: str, actual: str) -> None:
        super().__init__(
            f"agent {host}: cert fingerprint mismatch "
            f"(expected={expected[:16]}…, got={actual[:16]}…)"
        )
        self.host = host
        self.expected = expected
        self.actual = actual

    def __reduce__(self) -> tuple:
        """Rebuild from the three constructor arguments.

        ``BaseException.__reduce__`` reconstructs with ``cls(*self.args)``,
        and ``self.args`` holds only the formatted message. Without this
        override, ``copy.copy`` / ``pickle`` would call the three-argument
        ``__init__`` with a single value and fail with ``TypeError``.
        """
        return (
            type(self),
            (self.host, self.expected, self.actual),
            # Carry the instance dict so any extra attributes survive too.
            self.__dict__,
        )
|
||||
|
||||
# How long a single HTTP operation can take. Deploy is the long pole —
|
||||
# docker compose up pulls images, builds contexts, etc. Tune via env in a
|
||||
# later iteration if the default proves too short.
|
||||
@@ -99,6 +120,8 @@ class AgentClient:
|
||||
self._port = int(host.get("agent_port") or 8765)
|
||||
self._host_uuid = host.get("uuid")
|
||||
self._host_name = host.get("name")
|
||||
fp = host.get("client_cert_fingerprint")
|
||||
self._expected_fingerprint = fp.lower() if isinstance(fp, str) else None
|
||||
else:
|
||||
if address is None or agent_port is None:
|
||||
raise ValueError(
|
||||
@@ -108,6 +131,7 @@ class AgentClient:
|
||||
self._port = int(agent_port)
|
||||
self._host_uuid = None
|
||||
self._host_name = None
|
||||
self._expected_fingerprint = None
|
||||
|
||||
self._identity = identity or ensure_master_identity()
|
||||
self._verify_hostname = verify_hostname
|
||||
@@ -135,8 +159,52 @@ class AgentClient:
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def _fetch_peer_fingerprint(self) -> str:
    """Return the SHA-256 hex fingerprint of the cert the agent presents.

    Opens a dedicated, blocking TCP+TLS connection to
    ``(self._address, self._port)`` with a 10 second socket timeout,
    presenting this client's own cert/key and validating the peer against
    the identity's CA file. The peer certificate is read in DER form and
    hashed; nothing else is sent on the connection.

    Returns:
        Lower-case hexadecimal SHA-256 digest of the peer's DER cert.

    Raises:
        FingerprintMismatchError: If the handshake succeeds but no peer
            certificate bytes are available.
        OSError / ssl.SSLError: Propagated unchanged from connect or
            handshake failures.

    NOTE(review): this probe uses its own socket. Whether the HTTP client
    used for real requests re-checks the pin on its own connections is
    not visible from this method — confirm.
    """
    identity = self._identity
    endpoint = (self._address, self._port)

    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    tls_context.load_cert_chain(str(identity.cert_path), str(identity.key_path))
    tls_context.load_verify_locations(cafile=str(identity.ca_cert_path))
    tls_context.verify_mode = ssl.CERT_REQUIRED
    tls_context.check_hostname = self._verify_hostname

    raw_sock = socket.create_connection(endpoint, timeout=10.0)
    try:
        # A server name is only supplied when hostname checking is on.
        sni = self._address if self._verify_hostname else None
        with tls_context.wrap_socket(raw_sock, server_hostname=sni) as tls_sock:
            peer_der = tls_sock.getpeercert(binary_form=True)
    finally:
        # Best-effort close of the original socket object; a failure here
        # must not mask the handshake result or error.
        try:
            raw_sock.close()
        except OSError:
            pass

    if peer_der:
        return hashlib.sha256(peer_der).hexdigest().lower()

    # No certificate bytes: report as a mismatch with an empty "actual".
    raise FingerprintMismatchError(
        f"{self._address}:{self._port}", self._expected_fingerprint or "", ""
    )
|
||||
|
||||
async def _verify_pin(self) -> None:
    """Compare the agent's live cert fingerprint against the pinned one.

    When ``self._expected_fingerprint`` is empty or ``None`` the check is
    skipped entirely. Otherwise the blocking fingerprint probe is run in a
    worker thread and its result must equal the pinned value exactly.

    Raises:
        FingerprintMismatchError: If the observed fingerprint differs from
            the pinned one.
    """
    if self._expected_fingerprint:
        observed = await asyncio.to_thread(self._fetch_peer_fingerprint)
        if observed == self._expected_fingerprint:
            return
        endpoint = f"{self._address}:{self._port}"
        raise FingerprintMismatchError(
            endpoint, self._expected_fingerprint, observed
        )
    # No pin known for this host (legacy enrollments / explicit address ctor).
    # Fall through to CA-only validation. Enrollment writes the fingerprint,
    # so any production host added via `swarm enroll` will have one.
|
||||
|
||||
async def __aenter__(self) -> "AgentClient":
    """Build the HTTP client, verify the cert pin, and return ``self``.

    The client is created first with the control timeout. If pin
    verification fails for any reason (including cancellation), the
    client is closed and cleared before the original exception continues
    to propagate, so a failed enter leaves no open client behind.
    """
    self._client = self._build_client(_TIMEOUT_CONTROL)
    pin_verified = False
    try:
        await self._verify_pin()
        pin_verified = True
    finally:
        if not pin_verified:
            # Something was raised by the pin check: release the client.
            await self._client.aclose()
            self._client = None
    return self
|
||||
|
||||
async def __aexit__(self, *exc: Any) -> None:
|
||||
|
||||
@@ -99,6 +99,49 @@ async def test_client_health_roundtrip(tmp_path: pathlib.Path) -> None:
|
||||
thread.join(timeout=5)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fingerprint_pin_accepts_matching_cert(tmp_path: pathlib.Path) -> None:
    """A host dict pinned to the worker cert's own fingerprint connects.

    The pin is computed with ``pki.fingerprint`` from the ``worker.crt``
    found under ``tmp_path / "agent"``; a health call through the pinned
    client must return the OK payload.
    """
    agent_port = _free_port()
    server, thread, master_id = _start_agent(tmp_path, agent_port)
    try:
        cert_file = tmp_path / "agent" / "worker.crt"
        pinned_host = {
            "uuid": "h1",
            "name": "worker-test",
            "address": "127.0.0.1",
            "agent_port": agent_port,
            "client_cert_fingerprint": pki.fingerprint(cert_file.read_bytes()),
        }
        client = swarm_client.AgentClient(host=pinned_host, identity=master_id)
        async with client as agent:
            health = await agent.health()
            assert health == {"status": "ok"}
    finally:
        # Always signal the server to exit and wait for its thread.
        server.should_exit = True
        thread.join(timeout=5)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fingerprint_pin_rejects_mismatch(tmp_path: pathlib.Path) -> None:
    """Entering the client with a bogus pin raises FingerprintMismatchError.

    The pinned value is 64 zero characters, and the error must surface
    from ``async with`` itself, before the body runs.
    """
    agent_port = _free_port()
    server, thread, master_id = _start_agent(tmp_path, agent_port)
    try:
        bogus_pin = "0" * 64
        mispinned_host = {
            "uuid": "h1",
            "name": "worker-test",
            "address": "127.0.0.1",
            "agent_port": agent_port,
            "client_cert_fingerprint": bogus_pin,
        }
        with pytest.raises(swarm_client.FingerprintMismatchError):
            client = swarm_client.AgentClient(host=mispinned_host, identity=master_id)
            async with client:
                pass
    finally:
        # Always signal the server to exit and wait for its thread.
        server.should_exit = True
        thread.join(timeout=5)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impostor_client_cannot_connect(tmp_path: pathlib.Path) -> None:
|
||||
"""A client whose cert was issued by a DIFFERENT CA must be rejected."""
|
||||
|
||||
Reference in New Issue
Block a user