Files
DECNET/decnet/swarm/client.py
anti 0c77cdab32 feat(swarm): master AgentClient — mTLS httpx wrapper around worker API
decnet.swarm.client exposes:
- MasterIdentity / ensure_master_identity(): the master's own CA-signed
  client bundle, issued once into ~/.decnet/ca/master/.
- AgentClient: async-context httpx wrapper that talks to a worker agent
  over mTLS. health/status/deploy/teardown methods mirror the agent API.

SSL context is built from a bare ssl.SSLContext(PROTOCOL_TLS_CLIENT)
instead of httpx.create_ssl_context — the latter layers on default-CA
and purpose logic that broke private-CA mTLS. Server cert is pinned by
CA + chain, not DNS (workers enroll with arbitrary SANs).

tests/swarm/test_client_agent_roundtrip.py spins uvicorn in-process
with real certs on disk and verifies:
- A CA-signed master client passes health + status calls.
- An impostor whose cert comes from a different CA cannot connect.
2026-04-18 19:08:36 -04:00

195 lines
6.8 KiB
Python

"""Master-side HTTP client that talks to a worker's DECNET agent.
All traffic is mTLS: the master presents a cert issued by its own CA (which
workers trust) and the master validates the worker's cert against the same
CA. In practice the "client cert" the master shows is just another cert
signed by itself — the master is both the CA and the sole control-plane
client.
Usage::

    async with AgentClient(host) as agent:
        await agent.deploy(config)
        status = await agent.status()

The ``host`` is a SwarmHost dict returned by the repository; explicit
``address=`` / ``agent_port=`` keyword arguments work as an alternative.
"""
from __future__ import annotations
import pathlib
import ssl
from dataclasses import dataclass
from typing import Any, Optional
import httpx
from decnet.config import DecnetConfig
from decnet.logging import get_logger
from decnet.swarm import pki
log = get_logger("swarm.client")

# Per-operation HTTP timeouts, in seconds.
# Deploy is the long pole: docker compose up pulls images, builds contexts,
# etc., so its read budget is ten minutes. The control timeout is the client
# default for every other RPC. Tune via env in a later iteration if the
# defaults prove too short.
_TIMEOUT_DEPLOY = httpx.Timeout(
    connect=10.0,
    read=600.0,
    write=30.0,
    pool=5.0,
)
_TIMEOUT_CONTROL = httpx.Timeout(
    connect=5.0,
    read=15.0,
    write=5.0,
    pool=5.0,
)
@dataclass(frozen=True, eq=True)
class MasterIdentity:
    """Locations of the master's own mTLS client bundle on disk.

    The master presents ONE client certificate to every worker. That cert
    is signed by the DECNET CA, the same CA that signs worker certs.
    ``ensure_master_identity`` stores it under ``~/.decnet/ca/master/`` by
    default. Instances are immutable.
    """

    # Private key belonging to the master's client certificate.
    key_path: pathlib.Path
    # CA-signed master client certificate presented during the TLS handshake.
    cert_path: pathlib.Path
    # DECNET CA certificate; the only trust anchor for verifying worker certs.
    ca_cert_path: pathlib.Path
def ensure_master_identity(
    ca_dir: pathlib.Path = pki.DEFAULT_CA_DIR,
) -> MasterIdentity:
    """Return the master's mTLS client bundle, issuing it on first use.

    Ensures the DECNET CA exists under *ca_dir* via ``pki.ensure_ca``. It
    then looks for a previously written bundle in ``<ca_dir>/master``. If
    none is found, it issues a ``decnet-master`` cert (SANs ``127.0.0.1``
    and ``decnet-master``) from that CA and writes it there. An existing
    bundle is reused as-is, so the call is idempotent.

    Called once by the swarm controller on startup and by the CLI before
    any master->worker call.

    The bundle goes through the worker-bundle helpers. The hard-coded file
    names below (``worker.key`` / ``worker.crt`` / ``ca.crt``) presumably
    mirror what ``pki.write_worker_bundle`` emits — verify if pki changes.

    Args:
        ca_dir: Root of the CA directory (default ``pki.DEFAULT_CA_DIR``).

    Returns:
        A ``MasterIdentity`` pointing at the key, cert and CA cert files.
    """
    authority = pki.ensure_ca(ca_dir)
    bundle_dir = ca_dir / "master"

    if pki.load_worker_bundle(bundle_dir) is None:
        # First run: mint the single client cert the master shows every worker.
        fresh = pki.issue_worker_cert(
            authority, "decnet-master", ["127.0.0.1", "decnet-master"]
        )
        pki.write_worker_bundle(fresh, bundle_dir)

    return MasterIdentity(
        key_path=bundle_dir / "worker.key",
        cert_path=bundle_dir / "worker.crt",
        ca_cert_path=bundle_dir / "ca.crt",
    )
class AgentClient:
    """Thin async wrapper around the worker agent's HTTP API.

    Use it as an async context manager. Entering opens an mTLS
    ``httpx.AsyncClient`` aimed at ``https://<address>:<port>``; exiting
    closes it. Each RPC calls ``raise_for_status()`` on the response and
    returns the decoded JSON body.
    """

    def __init__(
        self,
        host: dict[str, Any] | None = None,
        *,
        address: Optional[str] = None,
        agent_port: Optional[int] = None,
        identity: Optional[MasterIdentity] = None,
        verify_hostname: bool = False,
    ):
        """Either pass a SwarmHost dict, or explicit address/port.

        Args:
            host: SwarmHost dict.
                - ``host["address"]`` is required.
                - ``agent_port`` falls back to 8765 when absent or falsy.
                - ``uuid`` and ``name`` are optional.
                When ``host`` is given, the ``address``/``agent_port``
                keywords are ignored.
            address: Worker address (only used when ``host`` is None).
            agent_port: Worker agent port (only used when ``host`` is None).
            identity: Master mTLS bundle. Defaults to
                ``ensure_master_identity()``, which may create it on disk.
            verify_hostname: Turn on TLS hostname checking against the
                worker certificate's SANs.

        ``verify_hostname`` stays False by default because the worker's
        cert SAN is populated from the operator-supplied address list, not
        from modern TLS hostname-verification semantics. The mTLS client
        cert + CA pinning are what authenticate the peer.

        Raises:
            ValueError: if ``host`` is None and ``address`` or
                ``agent_port`` is missing.
        """
        if host is not None:
            self._address = host["address"]
            self._port = int(host.get("agent_port") or 8765)
            self._host_uuid = host.get("uuid")
            self._host_name = host.get("name")
        else:
            if address is None or agent_port is None:
                raise ValueError(
                    "AgentClient requires either a host dict or address+agent_port"
                )
            self._address = address
            self._port = int(agent_port)
            self._host_uuid = None
            self._host_name = None
        self._identity = identity or ensure_master_identity()
        self._verify_hostname = verify_hostname
        # Only set between __aenter__ and __aexit__.
        self._client: Optional[httpx.AsyncClient] = None

    # --------------------------------------------------------------- lifecycle
    def _build_client(self, timeout: httpx.Timeout) -> httpx.AsyncClient:
        """Create an ``httpx.AsyncClient`` that speaks mTLS to this worker.

        The client presents the master cert/key from ``self._identity``. It
        trusts only the CA at ``self._identity.ca_cert_path``; no system
        default CAs are loaded.
        """
        # Build the SSL context manually — httpx.create_ssl_context layers on
        # purpose/ALPN/default-CA logic that doesn't compose with private-CA
        # mTLS in all combinations. A bare SSLContext is predictable.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        # Explicit protocol floor. PROTOCOL_TLS_CLIENT already defaults to
        # TLS 1.2 on Python 3.10+, but older interpreters may still permit
        # TLS 1.0/1.1 depending on the system OpenSSL policy.
        ctx.minimum_version = ssl.TLSVersion.TLSv1_2
        ctx.load_cert_chain(
            str(self._identity.cert_path), str(self._identity.key_path)
        )
        ctx.load_verify_locations(cafile=str(self._identity.ca_cert_path))
        ctx.verify_mode = ssl.CERT_REQUIRED
        # Pin by CA + cert chain, not by DNS — workers enroll with arbitrary
        # SANs (IPs, hostnames) and we don't want to force operators to keep
        # those in sync with whatever URL the master happens to use.
        ctx.check_hostname = self._verify_hostname
        return httpx.AsyncClient(
            base_url=f"https://{self._address}:{self._port}",
            verify=ctx,
            timeout=timeout,
        )

    async def __aenter__(self) -> "AgentClient":
        """Open the underlying HTTP client with the short control timeout."""
        self._client = self._build_client(_TIMEOUT_CONTROL)
        return self

    async def __aexit__(self, *exc: Any) -> None:
        """Close the underlying HTTP client, if one is open."""
        # Detach first so the instance never keeps a half-closed client,
        # even if aclose() raises.
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()

    def _require_client(self) -> httpx.AsyncClient:
        """Return the open client, or raise if used outside ``async with``."""
        if self._client is None:
            raise RuntimeError("AgentClient used outside `async with` block")
        return self._client

    # ----------------------------------------------------------------- RPCs
    async def health(self) -> dict[str, Any]:
        """GET ``/health`` and return the decoded JSON body."""
        resp = await self._require_client().get("/health")
        resp.raise_for_status()
        return resp.json()

    async def status(self) -> dict[str, Any]:
        """GET ``/status`` and return the decoded JSON body."""
        resp = await self._require_client().get("/status")
        resp.raise_for_status()
        return resp.json()

    async def deploy(
        self,
        config: DecnetConfig,
        *,
        dry_run: bool = False,
        no_cache: bool = False,
    ) -> dict[str, Any]:
        """POST ``/deploy`` with the serialized config and return the JSON body.

        Args:
            config: Deployment config, sent as ``config.model_dump(mode="json")``.
            dry_run: Forwarded verbatim to the agent.
            no_cache: Forwarded verbatim to the agent.
        """
        body = {
            "config": config.model_dump(mode="json"),
            "dry_run": dry_run,
            "no_cache": no_cache,
        }
        # Long deploy timeout for THIS request only. It is passed per-request
        # rather than swapped onto the shared client's ``.timeout``. Swapping
        # let concurrent control calls inherit the 600 s read budget, and let
        # overlapping deploys restore each other's value, which could leave
        # the client stuck on the deploy timeout.
        resp = await self._require_client().post(
            "/deploy", json=body, timeout=_TIMEOUT_DEPLOY
        )
        resp.raise_for_status()
        return resp.json()

    async def teardown(self, decky_id: Optional[str] = None) -> dict[str, Any]:
        """POST ``/teardown`` for ``decky_id`` and return the JSON body.

        ``decky_id`` is sent as-is, including ``None``. What the agent does
        with a null id is decided server-side; it is not visible here.
        """
        resp = await self._require_client().post(
            "/teardown", json={"decky_id": decky_id}
        )
        resp.raise_for_status()
        return resp.json()

    # -------------------------------------------------------------- diagnostics
    def __repr__(self) -> str:
        return (
            f"AgentClient(name={self._host_name!r}, "
            f"address={self._address!r}, port={self._port})"
        )