From 148e51011c571976037fb60cb53ca23df1c8f2f2 Mon Sep 17 00:00:00 2001 From: anti Date: Sun, 19 Apr 2026 21:37:15 -0400 Subject: [PATCH] =?UTF-8?q?feat(swarm):=20agent=E2=86=92master=20heartbeat?= =?UTF-8?q?=20with=20per-host=20cert=20pinning?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New POST /swarm/heartbeat on the swarm controller. Workers post every ~30s with the output of executor.status(); the master bumps SwarmHost.last_heartbeat and re-upserts each DeckyShard with a fresh DeckyConfig snapshot and runtime-derived state (running/degraded). Security: CA-signed mTLS alone is not sufficient — a decommissioned worker's still-valid cert could resurrect ghost shards. The endpoint extracts the presented peer cert (primary: scope["extensions"]["tls"], fallback: transport.get_extra_info("ssl_object")) and SHA-256-pins it to the SwarmHost.client_cert_fingerprint stored for the claimed host_uuid. Extraction is factored into _extract_peer_fingerprint so tests can exercise both uvicorn scope shapes and the both-unavailable fail-closed path without mocking uvicorn's TLS pipeline. Adds get_swarm_host_by_fingerprint to the repo interface (SQLModel impl reuses the indexed client_cert_fingerprint column). 
async def get_swarm_host_by_fingerprint(self, fingerprint: str) -> Optional[dict[str, Any]]:
    """Look up a swarm host by its stored client-certificate fingerprint.

    Selects the ``SwarmHost`` row whose ``client_cert_fingerprint`` equals
    *fingerprint* (exact match; no case normalisation is done here).

    Returns:
        The row serialised with ``model_dump(mode="json")``, or ``None``
        when no row matches.
    """
    query = select(SwarmHost).where(SwarmHost.client_cert_fingerprint == fingerprint)
    async with self._session() as db:
        host = (await db.execute(query)).scalar_one_or_none()
        if not host:
            return None
        return host.model_dump(mode="json")
.api_deploy_swarm import router as deploy_swarm_router from .api_teardown_swarm import router as teardown_swarm_router from .api_get_swarm_health import router as get_swarm_health_router from .api_check_hosts import router as check_hosts_router +from .api_heartbeat import router as heartbeat_router from .api_list_deckies import router as list_deckies_router swarm_router = APIRouter(prefix="/swarm") @@ -33,3 +34,4 @@ swarm_router.include_router(list_deckies_router) # Health swarm_router.include_router(get_swarm_health_router) swarm_router.include_router(check_hosts_router) +swarm_router.include_router(heartbeat_router) diff --git a/decnet/web/router/swarm/api_heartbeat.py b/decnet/web/router/swarm/api_heartbeat.py new file mode 100644 index 0000000..f49e580 --- /dev/null +++ b/decnet/web/router/swarm/api_heartbeat.py @@ -0,0 +1,138 @@ +"""POST /swarm/heartbeat — agent→master liveness + decky snapshot refresh. + +Workers call this every ~30 s with the output of ``executor.status()``. +The master bumps ``SwarmHost.last_heartbeat`` and re-upserts each +``DeckyShard`` with the fresh ``DeckyConfig`` snapshot + runtime-derived +state so the dashboard stays current without a master-pull probe. + +Security: CA-signed mTLS is necessary but not sufficient — a +decommissioned worker's still-valid cert must not resurrect ghost +shards. We pin the presented peer cert's SHA-256 to the +``client_cert_fingerprint`` stored for the claimed ``host_uuid``. +Mismatch (or decommissioned host) → 403. 
+""" +from __future__ import annotations + +import hashlib +import json +from datetime import datetime, timezone +from typing import Any, Optional + +from fastapi import APIRouter, Depends, HTTPException, Request +from pydantic import BaseModel + +from decnet.config import DeckyConfig +from decnet.logging import get_logger +from decnet.web.db.repository import BaseRepository +from decnet.web.dependencies import get_repo + +log = get_logger("swarm.heartbeat") + +router = APIRouter() + + +class HeartbeatRequest(BaseModel): + host_uuid: str + agent_version: Optional[str] = None + status: dict[str, Any] + + +def _extract_peer_fingerprint(scope: dict[str, Any]) -> Optional[str]: + """Pull the peer cert's SHA-256 fingerprint from an ASGI scope. + + Tries two extraction paths because uvicorn has historically stashed + the TLS peer cert in different scope keys across versions: + + 1. Primary: ``scope["extensions"]["tls"]["client_cert_chain"][0]`` + (uvicorn ≥ 0.30 ASGI TLS extension). + 2. Fallback: the transport object's ``ssl_object.getpeercert(binary_form=True)`` + (older uvicorn builds + some other servers). + + Returns the lowercase hex SHA-256 of the DER-encoded cert, or None + when neither path yields bytes. The endpoint fails closed on None. 
+ """ + peer_der: Optional[bytes] = None + source = "none" + + try: + chain = scope.get("extensions", {}).get("tls", {}).get("client_cert_chain") + if chain: + peer_der = chain[0] + source = "primary" + except Exception: + peer_der = None + + if peer_der is None: + transport = scope.get("transport") + try: + ssl_obj = transport.get_extra_info("ssl_object") if transport else None + if ssl_obj is not None: + peer_der = ssl_obj.getpeercert(binary_form=True) + if peer_der: + source = "fallback" + except Exception: + peer_der = None + + if not peer_der: + log.debug("heartbeat: peer cert extraction failed via none") + return None + + log.debug("heartbeat: peer cert extraction succeeded via %s", source) + return hashlib.sha256(peer_der).hexdigest().lower() + + +async def _verify_peer_matches_host( + request: Request, host_uuid: str, repo: BaseRepository +) -> dict[str, Any]: + host = await repo.get_swarm_host_by_uuid(host_uuid) + if host is None: + raise HTTPException(status_code=404, detail="unknown host") + fp = _extract_peer_fingerprint(request.scope) + if fp is None: + raise HTTPException(status_code=403, detail="peer cert unavailable") + expected = (host.get("client_cert_fingerprint") or "").lower() + if not expected or fp != expected: + raise HTTPException(status_code=403, detail="cert fingerprint mismatch") + return host + + +@router.post("/heartbeat", status_code=204, tags=["Swarm Health"]) +async def heartbeat( + req: HeartbeatRequest, + request: Request, + repo: BaseRepository = Depends(get_repo), +) -> None: + await _verify_peer_matches_host(request, req.host_uuid, repo) + + now = datetime.now(timezone.utc) + await repo.update_swarm_host( + req.host_uuid, + {"status": "active", "last_heartbeat": now}, + ) + + status_body = req.status or {} + if not status_body.get("deployed"): + return + + runtime = status_body.get("runtime") or {} + for decky_dict in status_body.get("deckies") or []: + try: + d = DeckyConfig(**decky_dict) + except Exception: + 
log.exception("heartbeat: skipping malformed decky payload host=%s", req.host_uuid) + continue + rstate = runtime.get(d.name) or {} + is_up = bool(rstate.get("running")) + await repo.upsert_decky_shard( + { + "decky_name": d.name, + "host_uuid": req.host_uuid, + "services": json.dumps(d.services), + "decky_config": d.model_dump_json(), + "decky_ip": d.ip, + "state": "running" if is_up else "degraded", + "last_error": None, + "last_seen": now, + "updated_at": now, + } + ) diff --git a/tests/swarm/test_heartbeat.py b/tests/swarm/test_heartbeat.py new file mode 100644 index 0000000..445987a --- /dev/null +++ b/tests/swarm/test_heartbeat.py @@ -0,0 +1,300 @@ +"""Tests for POST /swarm/heartbeat — cert pinning + shard snapshot refresh.""" +from __future__ import annotations + +import asyncio +import hashlib +import pathlib +from typing import Any +from unittest.mock import MagicMock + +import pytest +from fastapi.testclient import TestClient + +from decnet.web.db.factory import get_repository +from decnet.web.dependencies import get_repo +from decnet.web.router.swarm import api_heartbeat as hb_mod + + +# ------------------------- shared fixtures (mirror test_swarm_api.py) --- + + +@pytest.fixture +def ca_dir(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch) -> pathlib.Path: + ca = tmp_path / "ca" + from decnet.swarm import pki + from decnet.swarm import client as swarm_client + from decnet.web.router.swarm import api_enroll_host as enroll_mod + + monkeypatch.setattr(pki, "DEFAULT_CA_DIR", ca) + monkeypatch.setattr(swarm_client, "pki", pki) + monkeypatch.setattr(enroll_mod, "pki", pki) + return ca + + +@pytest.fixture +def repo(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch): + r = get_repository(db_path=str(tmp_path / "hb.db")) + import decnet.web.dependencies as deps + import decnet.web.swarm_api as swarm_api_mod + + monkeypatch.setattr(deps, "repo", r) + monkeypatch.setattr(swarm_api_mod, "repo", r) + return r + + +@pytest.fixture +def client(repo, 
ca_dir: pathlib.Path): + from decnet.web.swarm_api import app + + async def _override() -> Any: + return repo + + app.dependency_overrides[get_repo] = _override + with TestClient(app) as c: + yield c + app.dependency_overrides.clear() + + +def _enroll(client: TestClient, name: str, address: str = "10.0.0.5") -> dict: + resp = client.post( + "/swarm/enroll", + json={"name": name, "address": address, "agent_port": 8765}, + ) + assert resp.status_code == 201, resp.text + return resp.json() + + +def _pin_fingerprint(monkeypatch: pytest.MonkeyPatch, fp: str | None) -> None: + """Force ``_extract_peer_fingerprint`` to return ``fp`` inside the + endpoint module so we don't need a live TLS peer.""" + monkeypatch.setattr(hb_mod, "_extract_peer_fingerprint", lambda scope: fp) + + +def _status_body(deckies: list[dict], runtime: dict[str, dict]) -> dict: + return { + "deployed": True, + "mode": "swarm", + "compose_path": "/run/decnet/compose.yml", + "deckies": deckies, + "runtime": runtime, + } + + +def _decky_payload(name: str = "decky-01", ip: str = "10.0.0.50") -> dict: + return { + "name": name, + "hostname": f"{name}.lan", + "distro": "debian-bookworm", + "ip": ip, + "services": ["ssh"], + "base_image": "debian:bookworm-slim", + "service_config": {"ssh": {"port": 22}}, + "mutate_interval": 3600, + "last_mutated": 0.0, + "archetype": "generic", + "host_uuid": None, + } + + +# ------------------------- _extract_peer_fingerprint unit tests --------- + + +def test_extract_primary_path_returns_fingerprint() -> None: + der = b"\x30\x82test-cert-bytes" + scope = {"extensions": {"tls": {"client_cert_chain": [der]}}} + assert hb_mod._extract_peer_fingerprint(scope) == hashlib.sha256(der).hexdigest() + + +def test_extract_fallback_path_when_primary_absent() -> None: + der = b"\x30\x82fallback-bytes" + ssl_obj = MagicMock() + ssl_obj.getpeercert.return_value = der + transport = MagicMock() + transport.get_extra_info.return_value = ssl_obj + scope = {"transport": transport} + + fp = 
def test_extract_returns_none_when_transport_ssl_object_missing() -> None:
    """A transport with no ``ssl_object`` yields no fingerprint."""
    fake_transport = MagicMock()
    # get_extra_info("ssl_object") answering None models a non-TLS transport.
    fake_transport.get_extra_info.return_value = None
    fingerprint = hb_mod._extract_peer_fingerprint({"transport": fake_transport})
    assert fingerprint is None
def test_heartbeat_fingerprint_mismatch_returns_403(
    client: TestClient, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A peer cert that is not the enrolled host's pinned cert gets 403."""
    enrolled = _enroll(client, "worker-c", "10.0.0.7")
    # Any fingerprint other than the one issued at enrolment.
    _pin_fingerprint(monkeypatch, "b" * 64)

    payload = {"host_uuid": enrolled["host_uuid"], "status": {"deployed": False}}
    resp = client.post("/swarm/heartbeat", json=payload)

    assert resp.status_code == 403
    assert "mismatch" in resp.json()["detail"]
def test_heartbeat_deployed_false_bumps_host_but_writes_no_shards(
    client: TestClient, repo, monkeypatch: pytest.MonkeyPatch
) -> None:
    """``deployed: False`` updates the host's heartbeat but creates no shards."""
    enrolled = _enroll(client, "worker-f", "10.0.0.10")
    host_uuid = enrolled["host_uuid"]
    _pin_fingerprint(monkeypatch, enrolled["fingerprint"])

    payload = {
        "host_uuid": host_uuid,
        "status": {"deployed": False, "deckies": []},
    }
    resp = client.post("/swarm/heartbeat", json=payload)
    assert resp.status_code == 204

    async def _check_db() -> None:
        host_row = await repo.get_swarm_host_by_uuid(host_uuid)
        assert host_row["last_heartbeat"] is not None
        stored_shards = await repo.list_decky_shards(host_uuid)
        assert stored_shards == []

    asyncio.get_event_loop().run_until_complete(_check_db())
== "running" + assert by["decky-02"]["state"] == "degraded" + + asyncio.get_event_loop().run_until_complete(_verify())