feat(swarm): pin worker cert SHA-256 fingerprint per host
AgentClient now verifies the worker's TLS cert fingerprint against SwarmHost.client_cert_fingerprint at __aenter__ time, on top of CA validation. Required before fanning master-orchestrated topology deploys out across multiple swarm hosts: CA pinning alone allows any cert signed by the master CA, which is too coarse once a single deploy can target N hosts. Mismatch raises FingerprintMismatchError so callers can distinguish "wrong worker on the wire" from a transport hiccup.
This commit is contained in:
@@ -16,7 +16,10 @@ The ``host`` is a SwarmHost dict returned by the repository.
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import socket
|
||||||
import ssl
|
import ssl
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
@@ -29,6 +32,24 @@ from decnet.swarm import pki
|
|||||||
|
|
||||||
log = get_logger("swarm.client")
|
log = get_logger("swarm.client")
|
||||||
|
|
||||||
|
|
||||||
|
class FingerprintMismatchError(RuntimeError):
|
||||||
|
"""Raised when the worker presents a cert whose SHA-256 fingerprint
|
||||||
|
does not match ``SwarmHost.client_cert_fingerprint``.
|
||||||
|
|
||||||
|
Existence of this error class is the contract that lets the deployer
|
||||||
|
distinguish "wrong worker on the wire" (security event, fail loud)
|
||||||
|
from generic transport errors (retryable, mark slice failed)."""
|
||||||
|
|
||||||
|
def __init__(self, host: str, expected: str, actual: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"agent {host}: cert fingerprint mismatch "
|
||||||
|
f"(expected={expected[:16]}…, got={actual[:16]}…)"
|
||||||
|
)
|
||||||
|
self.host = host
|
||||||
|
self.expected = expected
|
||||||
|
self.actual = actual
|
||||||
|
|
||||||
# How long a single HTTP operation can take. Deploy is the long pole —
|
# How long a single HTTP operation can take. Deploy is the long pole —
|
||||||
# docker compose up pulls images, builds contexts, etc. Tune via env in a
|
# docker compose up pulls images, builds contexts, etc. Tune via env in a
|
||||||
# later iteration if the default proves too short.
|
# later iteration if the default proves too short.
|
||||||
@@ -99,6 +120,8 @@ class AgentClient:
|
|||||||
self._port = int(host.get("agent_port") or 8765)
|
self._port = int(host.get("agent_port") or 8765)
|
||||||
self._host_uuid = host.get("uuid")
|
self._host_uuid = host.get("uuid")
|
||||||
self._host_name = host.get("name")
|
self._host_name = host.get("name")
|
||||||
|
fp = host.get("client_cert_fingerprint")
|
||||||
|
self._expected_fingerprint = fp.lower() if isinstance(fp, str) else None
|
||||||
else:
|
else:
|
||||||
if address is None or agent_port is None:
|
if address is None or agent_port is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -108,6 +131,7 @@ class AgentClient:
|
|||||||
self._port = int(agent_port)
|
self._port = int(agent_port)
|
||||||
self._host_uuid = None
|
self._host_uuid = None
|
||||||
self._host_name = None
|
self._host_name = None
|
||||||
|
self._expected_fingerprint = None
|
||||||
|
|
||||||
self._identity = identity or ensure_master_identity()
|
self._identity = identity or ensure_master_identity()
|
||||||
self._verify_hostname = verify_hostname
|
self._verify_hostname = verify_hostname
|
||||||
@@ -135,8 +159,52 @@ class AgentClient:
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _fetch_peer_fingerprint(self) -> str:
|
||||||
|
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
|
ctx.load_cert_chain(
|
||||||
|
str(self._identity.cert_path), str(self._identity.key_path)
|
||||||
|
)
|
||||||
|
ctx.load_verify_locations(cafile=str(self._identity.ca_cert_path))
|
||||||
|
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||||
|
ctx.check_hostname = self._verify_hostname
|
||||||
|
sock = socket.create_connection((self._address, self._port), timeout=10.0)
|
||||||
|
try:
|
||||||
|
server_hostname = self._address if self._verify_hostname else None
|
||||||
|
with ctx.wrap_socket(sock, server_hostname=server_hostname) as ssock:
|
||||||
|
der = ssock.getpeercert(binary_form=True)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
sock.close()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
if not der:
|
||||||
|
raise FingerprintMismatchError(
|
||||||
|
f"{self._address}:{self._port}", self._expected_fingerprint or "", ""
|
||||||
|
)
|
||||||
|
return hashlib.sha256(der).hexdigest().lower()
|
||||||
|
|
||||||
|
async def _verify_pin(self) -> None:
|
||||||
|
if not self._expected_fingerprint:
|
||||||
|
# No pin known for this host (legacy enrollments / explicit address ctor).
|
||||||
|
# Fall through to CA-only validation. Enrollment writes the fingerprint,
|
||||||
|
# so any production host added via `swarm enroll` will have one.
|
||||||
|
return
|
||||||
|
actual = await asyncio.to_thread(self._fetch_peer_fingerprint)
|
||||||
|
if actual != self._expected_fingerprint:
|
||||||
|
raise FingerprintMismatchError(
|
||||||
|
f"{self._address}:{self._port}",
|
||||||
|
self._expected_fingerprint,
|
||||||
|
actual,
|
||||||
|
)
|
||||||
|
|
||||||
async def __aenter__(self) -> "AgentClient":
|
async def __aenter__(self) -> "AgentClient":
|
||||||
self._client = self._build_client(_TIMEOUT_CONTROL)
|
self._client = self._build_client(_TIMEOUT_CONTROL)
|
||||||
|
try:
|
||||||
|
await self._verify_pin()
|
||||||
|
except BaseException:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
raise
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, *exc: Any) -> None:
|
async def __aexit__(self, *exc: Any) -> None:
|
||||||
|
|||||||
@@ -99,6 +99,49 @@ async def test_client_health_roundtrip(tmp_path: pathlib.Path) -> None:
|
|||||||
thread.join(timeout=5)
|
thread.join(timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fingerprint_pin_accepts_matching_cert(tmp_path: pathlib.Path) -> None:
|
||||||
|
"""AgentClient with the correct expected fingerprint connects normally."""
|
||||||
|
port = _free_port()
|
||||||
|
server, thread, master_id = _start_agent(tmp_path, port)
|
||||||
|
try:
|
||||||
|
worker_cert_pem = (tmp_path / "agent" / "worker.crt").read_bytes()
|
||||||
|
expected = pki.fingerprint(worker_cert_pem)
|
||||||
|
host = {
|
||||||
|
"uuid": "h1",
|
||||||
|
"name": "worker-test",
|
||||||
|
"address": "127.0.0.1",
|
||||||
|
"agent_port": port,
|
||||||
|
"client_cert_fingerprint": expected,
|
||||||
|
}
|
||||||
|
async with swarm_client.AgentClient(host=host, identity=master_id) as agent:
|
||||||
|
assert await agent.health() == {"status": "ok"}
|
||||||
|
finally:
|
||||||
|
server.should_exit = True
|
||||||
|
thread.join(timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fingerprint_pin_rejects_mismatch(tmp_path: pathlib.Path) -> None:
|
||||||
|
"""A wrong expected fingerprint must raise FingerprintMismatchError."""
|
||||||
|
port = _free_port()
|
||||||
|
server, thread, master_id = _start_agent(tmp_path, port)
|
||||||
|
try:
|
||||||
|
host = {
|
||||||
|
"uuid": "h1",
|
||||||
|
"name": "worker-test",
|
||||||
|
"address": "127.0.0.1",
|
||||||
|
"agent_port": port,
|
||||||
|
"client_cert_fingerprint": "0" * 64,
|
||||||
|
}
|
||||||
|
with pytest.raises(swarm_client.FingerprintMismatchError):
|
||||||
|
async with swarm_client.AgentClient(host=host, identity=master_id):
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
server.should_exit = True
|
||||||
|
thread.join(timeout=5)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_impostor_client_cannot_connect(tmp_path: pathlib.Path) -> None:
|
async def test_impostor_client_cannot_connect(tmp_path: pathlib.Path) -> None:
|
||||||
"""A client whose cert was issued by a DIFFERENT CA must be rejected."""
|
"""A client whose cert was issued by a DIFFERENT CA must be rejected."""
|
||||||
|
|||||||
Reference in New Issue
Block a user