"""Master-side HTTP client that talks to a worker's DECNET agent. All traffic is mTLS: the master presents a cert issued by its own CA (which workers trust) and the master validates the worker's cert against the same CA. In practice the "client cert" the master shows is just another cert signed by itself — the master is both the CA and the sole control-plane client. Usage: async with AgentClient(host) as agent: await agent.deploy(config) status = await agent.status() The ``host`` is a SwarmHost dict returned by the repository. """ from __future__ import annotations import asyncio import hashlib import pathlib import socket import ssl from dataclasses import dataclass from typing import Any, Optional import httpx from decnet.config import DecnetConfig from decnet.logging import get_logger from decnet.swarm import pki log = get_logger("swarm.client") class FingerprintMismatchError(RuntimeError): """Raised when the worker presents a cert whose SHA-256 fingerprint does not match ``SwarmHost.client_cert_fingerprint``. Existence of this error class is the contract that lets the deployer distinguish "wrong worker on the wire" (security event, fail loud) from generic transport errors (retryable, mark slice failed).""" def __init__(self, host: str, expected: str, actual: str) -> None: super().__init__( f"agent {host}: cert fingerprint mismatch " f"(expected={expected[:16]}…, got={actual[:16]}…)" ) self.host = host self.expected = expected self.actual = actual # How long a single HTTP operation can take. Deploy is the long pole — # docker compose up pulls images, builds contexts, etc. Tune via env in a # later iteration if the default proves too short. _TIMEOUT_DEPLOY = httpx.Timeout(connect=10.0, read=600.0, write=30.0, pool=5.0) _TIMEOUT_CONTROL = httpx.Timeout(connect=5.0, read=15.0, write=5.0, pool=5.0) # Topology apply pulls images + runs compose on the agent — same ball-park # as a fleet deploy. 
# Teardown is faster but still long enough we can't
# reuse the control timeout.
_TIMEOUT_TOPOLOGY_APPLY = httpx.Timeout(connect=10.0, read=600.0, write=30.0, pool=5.0)
_TIMEOUT_TOPOLOGY_TEARDOWN = httpx.Timeout(connect=10.0, read=300.0, write=30.0, pool=5.0)


@dataclass(frozen=True)
class MasterIdentity:
    """Paths to the master's own mTLS client bundle.

    The master uses ONE master-client cert to talk to every worker. It is
    signed by the DECNET CA, the same CA that signs worker certs. The bundle
    lives in ``<ca_dir>/master/``, written by ``ensure_master_identity``
    (``~/.decnet/ca/master/`` with the default CA dir).
    """

    key_path: pathlib.Path      # private key presented to workers
    cert_path: pathlib.Path     # master client cert, signed by the DECNET CA
    ca_cert_path: pathlib.Path  # CA cert used to validate worker certs


def ensure_master_identity(
    ca_dir: pathlib.Path = pki.DEFAULT_CA_DIR,
) -> MasterIdentity:
    """Create (or load) the master's own client cert.

    Called once by the swarm controller on startup and by the CLI before any
    master→worker call. It is idempotent: the CA is ensured first, and a
    ``decnet-master`` cert is issued and written only when no bundle exists
    yet under ``ca_dir / "master"``.

    Args:
        ca_dir: Directory holding the DECNET CA.

    Returns:
        The paths of the (possibly freshly written) master bundle.
    """
    ca = pki.ensure_ca(ca_dir)
    master_dir = ca_dir / "master"
    bundle = pki.load_worker_bundle(master_dir)
    if bundle is None:
        # The master identity reuses the worker-bundle layout, hence the
        # worker.* file names below.
        issued = pki.issue_worker_cert(ca, "decnet-master", ["127.0.0.1", "decnet-master"])
        pki.write_worker_bundle(issued, master_dir)
    return MasterIdentity(
        key_path=master_dir / "worker.key",
        cert_path=master_dir / "worker.crt",
        ca_cert_path=master_dir / "ca.crt",
    )


class AgentClient:
    """Thin async wrapper around the worker agent's HTTP API.

    Use it as an async context manager:

    * Entering builds the mTLS HTTP client and, when a fingerprint is pinned
      for the host, verifies the worker's cert against it.
    * Exiting closes the client.

    Calling an RPC outside the ``async with`` block raises ``RuntimeError``.
    """

    def __init__(
        self,
        host: dict[str, Any] | None = None,
        *,
        address: Optional[str] = None,
        agent_port: Optional[int] = None,
        identity: Optional[MasterIdentity] = None,
        verify_hostname: Optional[bool] = None,
    ):
        """Either pass a SwarmHost dict, or explicit address/port.

        ``host`` keys read here:

        * ``address`` (required).
        * ``agent_port`` (8765 when missing or falsy).
        * ``uuid`` and ``name``.
        * ``client_cert_fingerprint``: the SHA-256 pin for the worker's
          cert. It is lower-cased, and ignored when not a string.

        The explicit ``address``/``agent_port`` form carries no pin, so the
        worker is validated by the CA alone.

        ``identity`` defaults to ``ensure_master_identity()``, which may
        create the CA and master bundle on disk.

        ``verify_hostname`` defers to ``DECNET_VERIFY_HOSTNAME`` when the
        caller doesn't pass an explicit value. Production deploys flip the
        env var on, so the worker's cert SAN must match the address the
        master connects to, on top of the existing CA + fingerprint pin. It
        defaults to False so dev/test enrollments with mismatched SANs keep
        working unchanged.

        Raises:
            ValueError: If neither ``host`` nor both ``address`` and
                ``agent_port`` are given.
            KeyError: If ``host`` has no ``address`` key.
        """
        if verify_hostname is None:
            from decnet.env import DECNET_VERIFY_HOSTNAME

            verify_hostname = DECNET_VERIFY_HOSTNAME
        if host is not None:
            self._address = host["address"]
            self._port = int(host.get("agent_port") or 8765)
            self._host_uuid = host.get("uuid")
            self._host_name = host.get("name")
            fp = host.get("client_cert_fingerprint")
            self._expected_fingerprint = fp.lower() if isinstance(fp, str) else None
        else:
            if address is None or agent_port is None:
                raise ValueError(
                    "AgentClient requires either a host dict or address+agent_port"
                )
            self._address = address
            self._port = int(agent_port)
            self._host_uuid = None
            self._host_name = None
            self._expected_fingerprint = None
        self._identity = identity or ensure_master_identity()
        self._verify_hostname = verify_hostname
        self._client: Optional[httpx.AsyncClient] = None

    # --------------------------------------------------------------- lifecycle

    def _build_ssl_context(self) -> ssl.SSLContext:
        """Build the mTLS context shared by the HTTP client and the pin probe.

        The context is built by hand because ``httpx.create_ssl_context``
        layers on purpose/ALPN/default-CA logic that doesn't compose with
        private-CA mTLS in all combinations. A bare SSLContext is predictable.
        """
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.load_cert_chain(
            str(self._identity.cert_path), str(self._identity.key_path)
        )
        ctx.load_verify_locations(cafile=str(self._identity.ca_cert_path))
        ctx.verify_mode = ssl.CERT_REQUIRED
        # Trust is anchored on the private CA (plus the optional fingerprint
        # pin). Hostname/SAN matching is opt-in: workers enroll with
        # arbitrary SANs (IPs, hostnames), and by default we don't force
        # operators to keep those in sync with the URL the master uses.
        ctx.check_hostname = self._verify_hostname
        return ctx

    def _url_host(self) -> str:
        """Return the address formatted for the authority part of a URL.

        IPv6 literals must be wrapped in brackets (RFC 3986 §3.2.2), or the
        ``:port`` suffix becomes ambiguous. Hostnames and IPv4 addresses
        contain no ':' and pass through unchanged.
        """
        addr = self._address
        if ":" in addr and not addr.startswith("["):
            return f"[{addr}]"
        return addr

    def _build_client(self, timeout: httpx.Timeout) -> httpx.AsyncClient:
        """Create the mTLS ``httpx.AsyncClient`` with ``timeout`` as its default."""
        return httpx.AsyncClient(
            base_url=f"https://{self._url_host()}:{self._port}",
            verify=self._build_ssl_context(),
            timeout=timeout,
        )

    def _fetch_peer_fingerprint(self) -> str:
        """Return the lowercase hex SHA-256 of the worker's DER-encoded cert.

        This is blocking: it opens its own TLS connection, so call it through
        ``asyncio.to_thread``.

        Raises:
            FingerprintMismatchError: If the peer presented no certificate.
            OSError / ssl.SSLError: Connection and handshake errors
                propagate unchanged.
        """
        ctx = self._build_ssl_context()
        sock = socket.create_connection((self._address, self._port), timeout=10.0)
        try:
            # SNI / hostname checking only when hostname verification is on.
            server_hostname = self._address if self._verify_hostname else None
            with ctx.wrap_socket(sock, server_hostname=server_hostname) as ssock:
                der = ssock.getpeercert(binary_form=True)
        finally:
            # Best-effort close of the raw socket. It is harmless if
            # wrap_socket already took ownership of the descriptor.
            try:
                sock.close()
            except OSError:
                pass
        if not der:
            raise FingerprintMismatchError(
                f"{self._address}:{self._port}", self._expected_fingerprint or "", ""
            )
        return hashlib.sha256(der).hexdigest().lower()

    async def _verify_pin(self) -> None:
        """Compare the worker's cert fingerprint with the pinned one, if any.

        NOTE(review): the fingerprint is read on a separate probe connection.
        The connections httpx later opens are validated by the CA (and
        optionally the hostname) only, not re-pinned.

        Raises:
            FingerprintMismatchError: If a pin is set and does not match.
        """
        if not self._expected_fingerprint:
            # No pin known for this host (legacy enrollments / explicit address
            # ctor), so fall through to CA-only validation. Enrollment writes
            # the fingerprint, so any production host added via `swarm enroll`
            # will have one.
            return
        actual = await asyncio.to_thread(self._fetch_peer_fingerprint)
        if actual != self._expected_fingerprint:
            raise FingerprintMismatchError(
                f"{self._address}:{self._port}",
                self._expected_fingerprint,
                actual,
            )

    async def __aenter__(self) -> "AgentClient":
        self._client = self._build_client(_TIMEOUT_CONTROL)
        try:
            await self._verify_pin()
        except BaseException:
            # Don't leak the freshly built client when the pin check (or a
            # cancellation) aborts entry.
            await self._client.aclose()
            self._client = None
            raise
        return self

    async def __aexit__(self, *exc: Any) -> None:
        if self._client:
            await self._client.aclose()
            self._client = None

    def _require_client(self) -> httpx.AsyncClient:
        if self._client is None:
            raise RuntimeError("AgentClient used outside `async with` block")
        return self._client

    async def _request(
        self,
        method: str,
        path: str,
        *,
        body: Any = None,
        timeout: Optional[httpx.Timeout] = None,
    ) -> dict[str, Any]:
        """Send one request to the agent and return its decoded JSON body.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            path: Path relative to the agent's base URL.
            body: JSON payload; omitted from the request when ``None``.
            timeout: Timeout for THIS request only; the client default
                applies when ``None``. It is passed per request instead of
                by mutating the shared client, so concurrent calls on one
                AgentClient never see each other's timeouts.

        Raises:
            RuntimeError: If used outside the ``async with`` block.
            httpx.HTTPStatusError: On a non-2xx response.
        """
        client = self._require_client()
        kwargs: dict[str, Any] = {}
        if body is not None:
            kwargs["json"] = body
        if timeout is not None:
            # Never forward None: httpx treats timeout=None as "no timeout".
            kwargs["timeout"] = timeout
        resp = await client.request(method, path, **kwargs)
        resp.raise_for_status()
        return resp.json()

    # ----------------------------------------------------------------- RPCs

    async def health(self) -> dict[str, Any]:
        return await self._request("GET", "/health")

    async def status(self) -> dict[str, Any]:
        return await self._request("GET", "/status")

    async def deploy(
        self,
        config: DecnetConfig,
        *,
        dry_run: bool = False,
        no_cache: bool = False,
    ) -> dict[str, Any]:
        """Ship ``config`` to the agent's ``/deploy`` endpoint.

        This uses the long deploy timeout for this request only.
        """
        body = {
            "config": config.model_dump(mode="json"),
            "dry_run": dry_run,
            "no_cache": no_cache,
        }
        return await self._request("POST", "/deploy", body=body, timeout=_TIMEOUT_DEPLOY)

    async def teardown(self, decky_id: Optional[str] = None) -> dict[str, Any]:
        """POST ``/teardown`` with the given ``decky_id`` (sent as null when None)."""
        return await self._request("POST", "/teardown", body={"decky_id": decky_id})

    async def self_destruct(self) -> dict[str, Any]:
        """Trigger the worker to stop services and wipe its install."""
        return await self._request("POST", "/self-destruct")

    # ------------------------------------------------------------ topology

    async def apply_topology(
        self,
        hydrated: dict[str, Any],
        version_hash: str,
    ) -> dict[str, Any]:
        """Push a hydrated topology to the agent for local materialisation.

        The agent independently computes ``canonical_hash(hydrated)`` and
        returns 400 if it disagrees with *version_hash*. That is how we catch
        serialisation drift before half-creating bridges.
        """
        return await self._request(
            "POST",
            "/topology/apply",
            body={"hydrated": hydrated, "version_hash": version_hash},
            timeout=_TIMEOUT_TOPOLOGY_APPLY,
        )

    async def teardown_topology(self, topology_id: str) -> dict[str, Any]:
        """Ask the agent to dismantle the named topology."""
        return await self._request(
            "POST",
            "/topology/teardown",
            body={"topology_id": topology_id},
            timeout=_TIMEOUT_TOPOLOGY_TEARDOWN,
        )

    async def get_topology_state(self) -> dict[str, Any]:
        """Snapshot of the agent's applied topology + live docker state."""
        return await self._request("GET", "/topology/state")

    # -------------------------------------------------------------- diagnostics

    def __repr__(self) -> str:
        return (
            f"AgentClient(name={self._host_name!r}, "
            f"address={self._address!r}, port={self._port})"
        )