diff --git a/decnet/swarm/client.py b/decnet/swarm/client.py index ad67123b..0c9a4c4d 100644 --- a/decnet/swarm/client.py +++ b/decnet/swarm/client.py @@ -16,7 +16,10 @@ The ``host`` is a SwarmHost dict returned by the repository. """ from __future__ import annotations +import asyncio +import hashlib import pathlib +import socket import ssl from dataclasses import dataclass from typing import Any, Optional @@ -29,6 +32,24 @@ from decnet.swarm import pki log = get_logger("swarm.client") + +class FingerprintMismatchError(RuntimeError): + """Raised when the worker presents a cert whose SHA-256 fingerprint + does not match ``SwarmHost.client_cert_fingerprint``. + + Existence of this error class is the contract that lets the deployer + distinguish "wrong worker on the wire" (security event, fail loud) + from generic transport errors (retryable, mark slice failed).""" + + def __init__(self, host: str, expected: str, actual: str) -> None: + super().__init__( + f"agent {host}: cert fingerprint mismatch " + f"(expected={expected[:16]}…, got={actual[:16]}…)" + ) + self.host = host + self.expected = expected + self.actual = actual + # How long a single HTTP operation can take. Deploy is the long pole — # docker compose up pulls images, builds contexts, etc. Tune via env in a # later iteration if the default proves too short. @@ -99,6 +120,8 @@ class AgentClient: self._port = int(host.get("agent_port") or 8765) self._host_uuid = host.get("uuid") self._host_name = host.get("name") + fp = host.get("client_cert_fingerprint") + self._expected_fingerprint = fp.lower() if isinstance(fp, str) else None else: if address is None or agent_port is None: raise ValueError( @@ -108,6 +131,7 @@ class AgentClient: self._port = int(agent_port) self._host_uuid = None self._host_name = None + self._expected_fingerprint = None self._identity = identity or ensure_master_identity() self._verify_hostname = verify_hostname @@ -135,8 +159,52 @@ class AgentClient: timeout=timeout, ) + def _fetch_peer_fingerprint(self) -> str: + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_cert_chain( + str(self._identity.cert_path), str(self._identity.key_path) + ) + ctx.load_verify_locations(cafile=str(self._identity.ca_cert_path)) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.check_hostname = self._verify_hostname + sock = socket.create_connection((self._address, self._port), timeout=10.0) + try: + server_hostname = self._address if self._verify_hostname else None + with ctx.wrap_socket(sock, server_hostname=server_hostname) as ssock: + der = ssock.getpeercert(binary_form=True) + finally: + try: + sock.close() + except OSError: + pass + if not der: + raise FingerprintMismatchError( + f"{self._address}:{self._port}", self._expected_fingerprint or "", "" + ) + return hashlib.sha256(der).hexdigest().lower() + + async def _verify_pin(self) -> None: + if not self._expected_fingerprint: + # No pin known for this host (legacy enrollments / explicit address ctor). + # Fall through to CA-only validation. Enrollment writes the fingerprint, + # so any production host added via `swarm enroll` will have one. + return + actual = await asyncio.to_thread(self._fetch_peer_fingerprint) + if actual != self._expected_fingerprint: + raise FingerprintMismatchError( + f"{self._address}:{self._port}", + self._expected_fingerprint, + actual, + ) + async def __aenter__(self) -> "AgentClient": self._client = self._build_client(_TIMEOUT_CONTROL) + try: + await self._verify_pin() + except BaseException: + await self._client.aclose() + self._client = None + raise return self async def __aexit__(self, *exc: Any) -> None: diff --git a/tests/swarm/test_client_agent_roundtrip.py b/tests/swarm/test_client_agent_roundtrip.py index 59216c7c..4b351bac 100644 --- a/tests/swarm/test_client_agent_roundtrip.py +++ b/tests/swarm/test_client_agent_roundtrip.py @@ -99,6 +99,49 @@ async def test_client_health_roundtrip(tmp_path: pathlib.Path) -> None: thread.join(timeout=5) +@pytest.mark.asyncio +async def test_fingerprint_pin_accepts_matching_cert(tmp_path: pathlib.Path) -> None: + """AgentClient with the correct expected fingerprint connects normally.""" + port = _free_port() + server, thread, master_id = _start_agent(tmp_path, port) + try: + worker_cert_pem = (tmp_path / "agent" / "worker.crt").read_bytes() + expected = pki.fingerprint(worker_cert_pem) + host = { + "uuid": "h1", + "name": "worker-test", + "address": "127.0.0.1", + "agent_port": port, + "client_cert_fingerprint": expected, + } + async with swarm_client.AgentClient(host=host, identity=master_id) as agent: + assert await agent.health() == {"status": "ok"} + finally: + server.should_exit = True + thread.join(timeout=5) + + +@pytest.mark.asyncio +async def test_fingerprint_pin_rejects_mismatch(tmp_path: pathlib.Path) -> None: + """A wrong expected fingerprint must raise FingerprintMismatchError.""" + port = _free_port() + server, thread, master_id = _start_agent(tmp_path, port) + try: + host = { + "uuid": "h1", + "name": "worker-test", + "address": "127.0.0.1", + "agent_port": port, + "client_cert_fingerprint": "0" * 64, + } + with pytest.raises(swarm_client.FingerprintMismatchError): + async with swarm_client.AgentClient(host=host, identity=master_id): + pass + finally: + server.should_exit = True + thread.join(timeout=5) + + @pytest.mark.asyncio async def test_impostor_client_cannot_connect(tmp_path: pathlib.Path) -> None: """A client whose cert was issued by a DIFFERENT CA must be rejected."""