feat(swarm): master AgentClient — mTLS httpx wrapper around worker API

decnet.swarm.client exposes:
- MasterIdentity / ensure_master_identity(): the master's own CA-signed
  client bundle, issued once into ~/.decnet/ca/master/.
- AgentClient: async-context httpx wrapper that talks to a worker agent
  over mTLS. health/status/deploy/teardown methods mirror the agent API.

SSL context is built from a bare ssl.SSLContext(PROTOCOL_TLS_CLIENT)
instead of httpx.create_ssl_context — the latter layers on default-CA
and purpose logic that broke private-CA mTLS. Server cert is pinned by
CA + chain, not DNS (workers enroll with arbitrary SANs).

tests/swarm/test_client_agent_roundtrip.py spins uvicorn in-process
with real certs on disk and verifies:
- A CA-signed master client passes health + status calls.
- An impostor whose cert comes from a different CA cannot connect.
This commit is contained in:
2026-04-18 19:08:36 -04:00
parent 8257bcc031
commit 0c77cdab32
2 changed files with 321 additions and 0 deletions

194
decnet/swarm/client.py Normal file
View File

@@ -0,0 +1,194 @@
"""Master-side HTTP client that talks to a worker's DECNET agent.
All traffic is mTLS: the master presents a cert issued by its own CA (which
workers trust) and the master validates the worker's cert against the same
CA. In practice the "client cert" the master shows is just another cert
signed by itself — the master is both the CA and the sole control-plane
client.
Usage:

    async with AgentClient(host) as agent:
        await agent.deploy(config)
        status = await agent.status()

The ``host`` is a SwarmHost dict returned by the repository.
"""
from __future__ import annotations
import pathlib
import ssl
from dataclasses import dataclass
from typing import Any, Optional
import httpx
from decnet.config import DecnetConfig
from decnet.logging import get_logger
from decnet.swarm import pki
# Module-level logger for the master-side swarm client.
log = get_logger("swarm.client")
# How long a single HTTP operation can take, split per phase
# (connect / read / write / pool acquisition).
#
# _TIMEOUT_DEPLOY is applied to the /deploy request only. Deploy is the long
# pole — docker compose up pulls images, builds contexts, etc. — hence the
# generous read budget. Tune via env in a later iteration if the default
# proves too short.
_TIMEOUT_DEPLOY = httpx.Timeout(connect=10.0, read=600.0, write=30.0, pool=5.0)
# _TIMEOUT_CONTROL is the default timeout of the client that
# AgentClient.__aenter__ builds, so it governs every other call.
_TIMEOUT_CONTROL = httpx.Timeout(connect=5.0, read=15.0, write=5.0, pool=5.0)
@dataclass(frozen=True)
class MasterIdentity:
    """Filesystem locations of the master's own mTLS client bundle.

    A single master-client certificate is presented to every worker. It is
    signed by the DECNET CA (the same CA that signs worker certs), and
    ``ensure_master_identity`` stores it under ``~/.decnet/ca/master/``.
    Instances are immutable.
    """

    # Private key that pairs with ``cert_path``.
    key_path: pathlib.Path
    # Client certificate the master presents to workers.
    cert_path: pathlib.Path
    # CA certificate used as the trust anchor when validating a worker.
    ca_cert_path: pathlib.Path
def ensure_master_identity(
    ca_dir: pathlib.Path = pki.DEFAULT_CA_DIR,
) -> MasterIdentity:
    """Load the master's client cert bundle, issuing it first if absent.

    Called once by the swarm controller on startup and by the CLI before
    any master→worker call. The CA under ``ca_dir`` is ensured first; if no
    bundle can be loaded from ``ca_dir / "master"``, a ``decnet-master``
    cert is issued by that CA and written there. Later calls reuse the
    bundle already on disk, so the function is idempotent.

    Args:
        ca_dir: Directory holding the CA; the master bundle lives in its
            ``master`` subdirectory.

    Returns:
        A ``MasterIdentity`` pointing at ``worker.key``, ``worker.crt`` and
        ``ca.crt`` inside the master bundle directory.
    """
    authority = pki.ensure_ca(ca_dir)
    bundle_dir = ca_dir / "master"

    if pki.load_worker_bundle(bundle_dir) is None:
        # First run: mint the master's client cert and persist it.
        pki.write_worker_bundle(
            pki.issue_worker_cert(
                authority, "decnet-master", ["127.0.0.1", "decnet-master"]
            ),
            bundle_dir,
        )

    return MasterIdentity(
        ca_cert_path=bundle_dir / "ca.crt",
        cert_path=bundle_dir / "worker.crt",
        key_path=bundle_dir / "worker.key",
    )
class AgentClient:
    """Thin async wrapper around the worker agent's HTTP API.

    Use it as an async context manager: the underlying ``httpx.AsyncClient``
    exists only between ``__aenter__`` and ``__aexit__``, and every RPC
    method raises ``RuntimeError`` outside that window.
    """

    def __init__(
        self,
        host: dict[str, Any] | None = None,
        *,
        address: Optional[str] = None,
        agent_port: Optional[int] = None,
        identity: Optional[MasterIdentity] = None,
        verify_hostname: bool = False,
    ) -> None:
        """Either pass a SwarmHost dict, or explicit address/port.

        ``verify_hostname`` stays False by default because the worker's
        cert SAN is populated from the operator-supplied address list, not
        from modern TLS hostname-verification semantics. The mTLS client
        cert + CA pinning are what authenticate the peer.

        Args:
            host: SwarmHost dict. ``address`` is required; ``agent_port``
                falls back to 8765 when missing or falsy; ``uuid`` and
                ``name`` are optional.
            address: Worker address, used only when ``host`` is None.
            agent_port: Worker agent port, used only when ``host`` is None.
            identity: Client bundle to present. When omitted (or falsy),
                ``ensure_master_identity()`` supplies it.
            verify_hostname: Whether TLS hostname checking is enabled.

        Raises:
            ValueError: If ``host`` is None and ``address`` or
                ``agent_port`` is missing.
        """
        if host is not None:
            self._address = host["address"]
            self._port = int(host.get("agent_port") or 8765)
            self._host_uuid = host.get("uuid")
            self._host_name = host.get("name")
        else:
            if address is None or agent_port is None:
                raise ValueError(
                    "AgentClient requires either a host dict or address+agent_port"
                )
            self._address = address
            self._port = int(agent_port)
            self._host_uuid = None
            self._host_name = None

        self._identity = identity or ensure_master_identity()
        self._verify_hostname = verify_hostname
        self._client: Optional[httpx.AsyncClient] = None

    # --------------------------------------------------------------- lifecycle

    def _build_client(self, timeout: httpx.Timeout) -> httpx.AsyncClient:
        """Create an ``httpx.AsyncClient`` bound to this worker over mTLS.

        Args:
            timeout: Default timeout for requests made through the client.

        Returns:
            A client whose base URL is ``https://<address>:<port>`` and
            whose TLS context presents the master's cert and trusts only
            the identity's CA certificate.
        """
        # Build the SSL context manually — httpx.create_ssl_context layers on
        # purpose/ALPN/default-CA logic that doesn't compose with private-CA
        # mTLS in all combinations. A bare SSLContext is predictable.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.load_cert_chain(
            str(self._identity.cert_path), str(self._identity.key_path)
        )
        ctx.load_verify_locations(cafile=str(self._identity.ca_cert_path))
        ctx.verify_mode = ssl.CERT_REQUIRED
        # Pin by CA + cert chain, not by DNS — workers enroll with arbitrary
        # SANs (IPs, hostnames) and we don't want to force operators to keep
        # those in sync with whatever URL the master happens to use.
        ctx.check_hostname = self._verify_hostname
        return httpx.AsyncClient(
            base_url=f"https://{self._address}:{self._port}",
            verify=ctx,
            timeout=timeout,
        )

    async def __aenter__(self) -> "AgentClient":
        """Open the HTTP client with the control-plane timeout."""
        self._client = self._build_client(_TIMEOUT_CONTROL)
        return self

    async def __aexit__(self, *exc: Any) -> None:
        """Close the HTTP client, if one is open."""
        # Drop the reference before closing so that a failing aclose() cannot
        # leave this object holding a half-closed client.
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()

    def _require_client(self) -> httpx.AsyncClient:
        """Return the open client.

        Raises:
            RuntimeError: If called outside the ``async with`` block.
        """
        if self._client is None:
            raise RuntimeError("AgentClient used outside `async with` block")
        return self._client

    # ----------------------------------------------------------------- RPCs

    async def health(self) -> dict[str, Any]:
        """GET ``/health`` and return the decoded JSON body.

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx response.
        """
        resp = await self._require_client().get("/health")
        resp.raise_for_status()
        return resp.json()

    async def status(self) -> dict[str, Any]:
        """GET ``/status`` and return the decoded JSON body.

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx response.
        """
        resp = await self._require_client().get("/status")
        resp.raise_for_status()
        return resp.json()

    async def deploy(
        self,
        config: DecnetConfig,
        *,
        dry_run: bool = False,
        no_cache: bool = False,
    ) -> dict[str, Any]:
        """POST ``/deploy`` with the serialized config and flags.

        Args:
            config: Deployment config, sent as ``model_dump(mode="json")``.
            dry_run: Forwarded verbatim in the request body.
            no_cache: Forwarded verbatim in the request body.

        Returns:
            The decoded JSON response body.

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx response.
        """
        body = {
            "config": config.model_dump(mode="json"),
            "dry_run": dry_run,
            "no_cache": no_cache,
        }
        # Apply the long deploy timeout to this request only. Overriding it
        # per request (instead of mutating the shared client's timeout and
        # restoring it afterwards) keeps overlapping calls on the same client
        # from inheriting the long timeout or having it reset mid-flight.
        resp = await self._require_client().post(
            "/deploy", json=body, timeout=_TIMEOUT_DEPLOY
        )
        resp.raise_for_status()
        return resp.json()

    async def teardown(self, decky_id: Optional[str] = None) -> dict[str, Any]:
        """POST ``/teardown`` and return the decoded JSON body.

        Args:
            decky_id: Sent as the ``decky_id`` field of the request body;
                ``None`` is sent as JSON null.

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx response.
        """
        resp = await self._require_client().post(
            "/teardown", json={"decky_id": decky_id}
        )
        resp.raise_for_status()
        return resp.json()

    # -------------------------------------------------------------- diagnostics

    def __repr__(self) -> str:
        """Return a debug representation with host name, address and port."""
        return (
            f"AgentClient(name={self._host_name!r}, "
            f"address={self._address!r}, port={self._port})"
        )

View File

@@ -0,0 +1,127 @@
"""End-to-end test: AgentClient talks to a live worker agent over mTLS.
Spins up uvicorn in-process on an ephemeral port with real cert files on
disk. Confirms:
1. The health endpoint works when the client presents a CA-signed cert.
2. An impostor client (cert signed by a different CA) is rejected at TLS
time.
"""
from __future__ import annotations
import asyncio
import pathlib
import socket
import threading
import time
import ssl
import httpx
import pytest
import uvicorn
from decnet.agent.app import app as agent_app
from decnet.swarm import client as swarm_client
from decnet.swarm import pki
def _free_port() -> int:
s = socket.socket()
s.bind(("127.0.0.1", 0))
port = s.getsockname()[1]
s.close()
return port
def _start_agent(
    tmp_path: pathlib.Path, port: int
) -> tuple[uvicorn.Server, threading.Thread, swarm_client.MasterIdentity]:
    """Provision a CA, sign a worker cert + a master cert, start uvicorn.

    The agent app is served over TLS on 127.0.0.1:``port`` from a daemon
    thread, with client certificates required.

    Args:
        tmp_path: Scratch directory; the CA goes in ``ca/`` and the worker
            bundle in ``agent/``.
        port: TCP port for the server to listen on.

    Returns:
        ``(server, thread, master_identity)``. To stop the server the caller
        sets ``server.should_exit = True`` and joins ``thread``.

    Raises:
        RuntimeError: If the server thread exits before startup or the
            server is not up within 5 seconds.
    """
    ca_dir = tmp_path / "ca"
    pki.ensure_ca(ca_dir)

    # Worker bundle
    worker_dir = tmp_path / "agent"
    pki.write_worker_bundle(
        pki.issue_worker_cert(pki.load_ca(ca_dir), "worker-test", ["127.0.0.1"]),
        worker_dir,
    )

    # Master identity (used by AgentClient as a client cert)
    master_id = swarm_client.ensure_master_identity(ca_dir)

    config = uvicorn.Config(
        agent_app,
        host="127.0.0.1",
        port=port,
        log_level="warning",
        ssl_keyfile=str(worker_dir / "worker.key"),
        ssl_certfile=str(worker_dir / "worker.crt"),
        ssl_ca_certs=str(worker_dir / "ca.crt"),
        # Require a client certificate: this is what makes it *mutual* TLS.
        ssl_cert_reqs=ssl.CERT_REQUIRED,
    )
    server = uvicorn.Server(config)

    def _run() -> None:
        # The server gets its own event loop on this thread.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(server.serve())
        finally:
            loop.close()

    thread = threading.Thread(target=_run, daemon=True)
    thread.start()

    # Wait for the server to be listening. A monotonic clock keeps the
    # deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + 5
    while time.monotonic() < deadline:
        if server.started:
            return server, thread, master_id
        if not thread.is_alive():
            # serve() already returned or raised; waiting longer cannot help.
            raise RuntimeError("agent server thread exited before startup")
        time.sleep(0.05)

    # Ask the half-started server to stop so it does not outlive the test.
    server.should_exit = True
    raise RuntimeError("agent did not start within 5s")
@pytest.mark.asyncio
async def test_client_health_roundtrip(tmp_path: pathlib.Path) -> None:
    """A CA-signed master client can call health and status over mTLS."""
    listen_port = _free_port()
    server, server_thread, identity = _start_agent(tmp_path, listen_port)
    try:
        client = swarm_client.AgentClient(
            address="127.0.0.1", agent_port=listen_port, identity=identity
        )
        async with client as agent:
            health_body = await agent.health()
            assert health_body == {"status": "ok"}

            status_body = await agent.status()
            assert "deployed" in status_body
    finally:
        # Always stop the in-process server, even when an assertion failed.
        server.should_exit = True
        server_thread.join(timeout=5)
@pytest.mark.asyncio
async def test_impostor_client_cannot_connect(tmp_path: pathlib.Path) -> None:
    """A client whose cert was issued by a DIFFERENT CA must be rejected.

    The impostor trusts the real CA, so it accepts the server's certificate.
    The only thing left that can fail is the server refusing the impostor's
    client cert, which is the behavior under test.
    """
    port = _free_port()
    server, thread, master_id = _start_agent(tmp_path, port)
    try:
        evil_ca = pki.generate_ca("Evil CA")
        evil_dir = tmp_path / "evil"
        pki.write_worker_bundle(
            pki.issue_worker_cert(evil_ca, "evil-master", ["127.0.0.1"]), evil_dir
        )

        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.load_cert_chain(str(evil_dir / "worker.crt"), str(evil_dir / "worker.key"))
        # Trust the real CA, not the evil bundle's ca.crt. With the evil CA
        # as the only trust anchor the client would presumably reject the
        # server's cert by itself, and the test would pass without the server
        # ever checking the client cert.
        ctx.load_verify_locations(cafile=str(master_id.ca_cert_path))
        ctx.verify_mode = ssl.CERT_REQUIRED
        ctx.check_hostname = False

        async with httpx.AsyncClient(
            base_url=f"https://127.0.0.1:{port}", verify=ctx, timeout=5.0
        ) as ac:
            # Depending on TLS version, the server's rejection can surface
            # during the handshake or on the first write/read afterwards.
            with pytest.raises(
                (
                    httpx.ConnectError,
                    httpx.ReadError,
                    httpx.WriteError,
                    httpx.RemoteProtocolError,
                )
            ):
                await ac.get("/health")
    finally:
        server.should_exit = True
        thread.join(timeout=5)