feat(swarm): agent→master heartbeat with per-host cert pinning
New POST /swarm/heartbeat on the swarm controller. Workers post every
~30s with the output of executor.status(); the master bumps
SwarmHost.last_heartbeat and re-upserts each DeckyShard with a fresh
DeckyConfig snapshot and runtime-derived state (running/degraded).
Security: CA-signed mTLS alone is not sufficient — a decommissioned
worker's still-valid cert could resurrect ghost shards. The endpoint
extracts the presented peer cert (primary: scope["extensions"]["tls"],
fallback: transport.get_extra_info("ssl_object")) and SHA-256-pins it
to the SwarmHost.client_cert_fingerprint stored for the claimed
host_uuid. Extraction is factored into _extract_peer_fingerprint so
tests can exercise both uvicorn scope shapes and the both-unavailable
fail-closed path without mocking uvicorn's TLS pipeline.
Adds get_swarm_host_by_fingerprint to the repo interface (SQLModel
impl reuses the indexed client_cert_fingerprint column).
This commit is contained in:
@@ -211,6 +211,9 @@ class BaseRepository(ABC):
|
||||
async def get_swarm_host_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
    """Look up a single swarm host by its UUID.

    The base class provides no storage; it always raises
    ``NotImplementedError``. Per the signature, concrete repositories
    return the host as a dict, or ``None`` when no host matches.
    """
    raise NotImplementedError
|
||||
|
||||
async def get_swarm_host_by_fingerprint(self, fingerprint: str) -> Optional[dict[str, Any]]:
    """Look up a single swarm host by its client certificate fingerprint.

    The base class provides no storage; it always raises
    ``NotImplementedError``. Per the signature, concrete repositories
    return the host as a dict, or ``None`` when no host matches.
    """
    raise NotImplementedError
|
||||
|
||||
async def list_swarm_hosts(self, status: Optional[str] = None) -> list[dict[str, Any]]:
    """List swarm hosts, optionally narrowed by *status*.

    The base class provides no storage; it always raises
    ``NotImplementedError``. Concrete repositories return one dict per host.
    """
    raise NotImplementedError
|
||||
|
||||
|
||||
@@ -782,6 +782,14 @@ class SQLModelRepository(BaseRepository):
|
||||
row = result.scalar_one_or_none()
|
||||
return row.model_dump(mode="json") if row else None
|
||||
|
||||
async def get_swarm_host_by_fingerprint(self, fingerprint: str) -> Optional[dict[str, Any]]:
    """Fetch the swarm host whose ``client_cert_fingerprint`` equals *fingerprint*.

    The comparison is an exact match on the stored column value. Returns
    the row serialised with ``model_dump(mode="json")``, or ``None`` when
    no host carries that fingerprint.
    """
    query = select(SwarmHost).where(SwarmHost.client_cert_fingerprint == fingerprint)
    async with self._session() as session:
        # scalar_one_or_none() raises if more than one row matches.
        host = (await session.execute(query)).scalar_one_or_none()
        if not host:
            return None
        return host.model_dump(mode="json")
|
||||
|
||||
async def list_swarm_hosts(self, status: Optional[str] = None) -> list[dict[str, Any]]:
|
||||
statement = select(SwarmHost).order_by(asc(SwarmHost.name))
|
||||
if status:
|
||||
|
||||
@@ -15,6 +15,7 @@ from .api_deploy_swarm import router as deploy_swarm_router
|
||||
from .api_teardown_swarm import router as teardown_swarm_router
|
||||
from .api_get_swarm_health import router as get_swarm_health_router
|
||||
from .api_check_hosts import router as check_hosts_router
|
||||
from .api_heartbeat import router as heartbeat_router
|
||||
from .api_list_deckies import router as list_deckies_router
|
||||
|
||||
swarm_router = APIRouter(prefix="/swarm")
|
||||
@@ -33,3 +34,4 @@ swarm_router.include_router(list_deckies_router)
|
||||
# Health
|
||||
swarm_router.include_router(get_swarm_health_router)
|
||||
swarm_router.include_router(check_hosts_router)
|
||||
swarm_router.include_router(heartbeat_router)
|
||||
|
||||
138
decnet/web/router/swarm/api_heartbeat.py
Normal file
138
decnet/web/router/swarm/api_heartbeat.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""POST /swarm/heartbeat — agent→master liveness + decky snapshot refresh.
|
||||
|
||||
Workers call this every ~30 s with the output of ``executor.status()``.
|
||||
The master bumps ``SwarmHost.last_heartbeat`` and re-upserts each
|
||||
``DeckyShard`` with the fresh ``DeckyConfig`` snapshot + runtime-derived
|
||||
state so the dashboard stays current without a master-pull probe.
|
||||
|
||||
Security: CA-signed mTLS is necessary but not sufficient — a
|
||||
decommissioned worker's still-valid cert must not resurrect ghost
|
||||
shards. We pin the presented peer cert's SHA-256 to the
|
||||
``client_cert_fingerprint`` stored for the claimed ``host_uuid``.
|
||||
Mismatch or missing peer cert → 403; unknown host (including a decommissioned host whose row was deleted) → 404.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from decnet.config import DeckyConfig
|
||||
from decnet.logging import get_logger
|
||||
from decnet.web.db.repository import BaseRepository
|
||||
from decnet.web.dependencies import get_repo
|
||||
|
||||
log = get_logger("swarm.heartbeat")
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class HeartbeatRequest(BaseModel):
    """Request body accepted by ``POST /swarm/heartbeat``."""

    # UUID the worker claims to be; the handler checks the presented peer
    # cert against the fingerprint stored for this host.
    host_uuid: str
    # Optional agent version string; the handler in this module does not read it.
    agent_version: Optional[str] = None
    # Worker status payload; the handler reads the "deployed", "deckies"
    # and "runtime" keys from it.
    status: dict[str, Any]
|
||||
|
||||
|
||||
def _extract_peer_fingerprint(scope: dict[str, Any]) -> Optional[str]:
|
||||
"""Pull the peer cert's SHA-256 fingerprint from an ASGI scope.
|
||||
|
||||
Tries two extraction paths because uvicorn has historically stashed
|
||||
the TLS peer cert in different scope keys across versions:
|
||||
|
||||
1. Primary: ``scope["extensions"]["tls"]["client_cert_chain"][0]``
|
||||
(uvicorn ≥ 0.30 ASGI TLS extension).
|
||||
2. Fallback: the transport object's ``ssl_object.getpeercert(binary_form=True)``
|
||||
(older uvicorn builds + some other servers).
|
||||
|
||||
Returns the lowercase hex SHA-256 of the DER-encoded cert, or None
|
||||
when neither path yields bytes. The endpoint fails closed on None.
|
||||
"""
|
||||
peer_der: Optional[bytes] = None
|
||||
source = "none"
|
||||
|
||||
try:
|
||||
chain = scope.get("extensions", {}).get("tls", {}).get("client_cert_chain")
|
||||
if chain:
|
||||
peer_der = chain[0]
|
||||
source = "primary"
|
||||
except Exception:
|
||||
peer_der = None
|
||||
|
||||
if peer_der is None:
|
||||
transport = scope.get("transport")
|
||||
try:
|
||||
ssl_obj = transport.get_extra_info("ssl_object") if transport else None
|
||||
if ssl_obj is not None:
|
||||
peer_der = ssl_obj.getpeercert(binary_form=True)
|
||||
if peer_der:
|
||||
source = "fallback"
|
||||
except Exception:
|
||||
peer_der = None
|
||||
|
||||
if not peer_der:
|
||||
log.debug("heartbeat: peer cert extraction failed via none")
|
||||
return None
|
||||
|
||||
log.debug("heartbeat: peer cert extraction succeeded via %s", source)
|
||||
return hashlib.sha256(peer_der).hexdigest().lower()
|
||||
|
||||
|
||||
async def _verify_peer_matches_host(
    request: Request, host_uuid: str, repo: BaseRepository
) -> dict[str, Any]:
    """Check that the presented peer cert is the one pinned for *host_uuid*.

    Returns:
        The host row from ``repo.get_swarm_host_by_uuid``.

    Raises:
        HTTPException: 404 when the host is unknown; 403 when no peer cert
            could be extracted, when the host has no stored fingerprint,
            or when the fingerprints differ.
    """
    host_row = await repo.get_swarm_host_by_uuid(host_uuid)
    if host_row is None:
        raise HTTPException(status_code=404, detail="unknown host")

    presented = _extract_peer_fingerprint(request.scope)
    if presented is None:
        # Fail closed: no cert means no identity to pin against.
        raise HTTPException(status_code=403, detail="peer cert unavailable")

    pinned = (host_row.get("client_cert_fingerprint") or "").lower()
    if pinned and presented == pinned:
        return host_row
    raise HTTPException(status_code=403, detail="cert fingerprint mismatch")
|
||||
|
||||
|
||||
@router.post("/heartbeat", status_code=204, tags=["Swarm Health"])
async def heartbeat(
    req: HeartbeatRequest,
    request: Request,
    repo: BaseRepository = Depends(get_repo),
) -> None:
    """Record a worker heartbeat and refresh its decky shard snapshots.

    Steps:
      1. Verify the presented peer cert is pinned to ``req.host_uuid``
         (raises ``HTTPException`` 404/403 on failure).
      2. Mark the host ``active`` and set ``last_heartbeat`` to now (UTC).
      3. If the status payload reports ``deployed``, upsert one shard per
         decky with state ``running`` or ``degraded``.

    The status payload is an untyped dict, so its ``runtime`` and
    ``deckies`` members are shape-checked before use. A malformed member
    is treated as absent rather than failing the request after the host
    row has already been updated.
    """
    await _verify_peer_matches_host(request, req.host_uuid, repo)

    now = datetime.now(timezone.utc)
    # NOTE(review): status is set to "active" unconditionally — confirm no
    # other status value is meant to survive a heartbeat.
    await repo.update_swarm_host(
        req.host_uuid,
        {"status": "active", "last_heartbeat": now},
    )

    status_body = req.status or {}
    if not status_body.get("deployed"):
        return

    runtime = status_body.get("runtime")
    if not isinstance(runtime, dict):
        runtime = {}
    deckies = status_body.get("deckies")
    if not isinstance(deckies, list):
        deckies = []

    for decky_dict in deckies:
        try:
            d = DeckyConfig(**decky_dict)
        except Exception:
            # One bad entry must not block the remaining shards.
            log.exception("heartbeat: skipping malformed decky payload host=%s", req.host_uuid)
            continue
        rstate = runtime.get(d.name)
        # A missing or non-dict runtime entry counts as "not running".
        is_up = isinstance(rstate, dict) and bool(rstate.get("running"))
        await repo.upsert_decky_shard(
            {
                "decky_name": d.name,
                # Always the verified host, never the host_uuid inside the payload.
                "host_uuid": req.host_uuid,
                "services": json.dumps(d.services),
                "decky_config": d.model_dump_json(),
                "decky_ip": d.ip,
                "state": "running" if is_up else "degraded",
                "last_error": None,
                "last_seen": now,
                "updated_at": now,
            }
        )
|
||||
300
tests/swarm/test_heartbeat.py
Normal file
300
tests/swarm/test_heartbeat.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""Tests for POST /swarm/heartbeat — cert pinning + shard snapshot refresh."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import pathlib
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from decnet.web.db.factory import get_repository
|
||||
from decnet.web.dependencies import get_repo
|
||||
from decnet.web.router.swarm import api_heartbeat as hb_mod
|
||||
|
||||
|
||||
# ------------------------- shared fixtures (mirror test_swarm_api.py) ---
|
||||
|
||||
|
||||
@pytest.fixture
def ca_dir(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch) -> pathlib.Path:
    """Point the swarm PKI at a per-test CA directory and return its path."""
    from decnet.swarm import pki
    from decnet.swarm import client as swarm_client
    from decnet.web.router.swarm import api_enroll_host as enroll_mod

    ca_path = tmp_path / "ca"
    monkeypatch.setattr(pki, "DEFAULT_CA_DIR", ca_path)
    # Both modules must see the same patched pki module.
    for consumer in (swarm_client, enroll_mod):
        monkeypatch.setattr(consumer, "pki", pki)
    return ca_path
|
||||
|
||||
|
||||
@pytest.fixture
def repo(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch):
    """Build a repository backed by a per-test database file.

    The instance is also patched in as the ``repo`` attribute of
    ``decnet.web.dependencies`` and ``decnet.web.swarm_api``.
    """
    test_repo = get_repository(db_path=str(tmp_path / "hb.db"))
    import decnet.web.dependencies as deps
    import decnet.web.swarm_api as swarm_api_mod

    for module in (deps, swarm_api_mod):
        monkeypatch.setattr(module, "repo", test_repo)
    return test_repo
|
||||
|
||||
|
||||
@pytest.fixture
def client(repo, ca_dir: pathlib.Path):
    """Yield a ``TestClient`` for the swarm app with ``get_repo`` overridden.

    The override hands every request the per-test ``repo`` fixture; all
    dependency overrides are cleared once the client context has exited.
    """
    from decnet.web.swarm_api import app

    async def _use_test_repo() -> Any:
        return repo

    app.dependency_overrides[get_repo] = _use_test_repo
    with TestClient(app) as test_client:
        yield test_client
    app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _enroll(client: TestClient, name: str, address: str = "10.0.0.5") -> dict:
    """Enroll a worker via ``POST /swarm/enroll`` and return the JSON response.

    Asserts the endpoint answered 201; agent_port is fixed at 8765.
    """
    payload = {"name": name, "address": address, "agent_port": 8765}
    response = client.post("/swarm/enroll", json=payload)
    assert response.status_code == 201, response.text
    return response.json()
|
||||
|
||||
|
||||
def _pin_fingerprint(monkeypatch: pytest.MonkeyPatch, fp: str | None) -> None:
    """Replace ``_extract_peer_fingerprint`` in the endpoint module with a
    stub that always returns ``fp``, so no live TLS peer is needed."""

    def _stub(scope):
        return fp

    monkeypatch.setattr(hb_mod, "_extract_peer_fingerprint", _stub)
|
||||
|
||||
|
||||
def _status_body(deckies: list[dict], runtime: dict[str, dict]) -> dict:
|
||||
return {
|
||||
"deployed": True,
|
||||
"mode": "swarm",
|
||||
"compose_path": "/run/decnet/compose.yml",
|
||||
"deckies": deckies,
|
||||
"runtime": runtime,
|
||||
}
|
||||
|
||||
|
||||
def _decky_payload(name: str = "decky-01", ip: str = "10.0.0.50") -> dict:
|
||||
return {
|
||||
"name": name,
|
||||
"hostname": f"{name}.lan",
|
||||
"distro": "debian-bookworm",
|
||||
"ip": ip,
|
||||
"services": ["ssh"],
|
||||
"base_image": "debian:bookworm-slim",
|
||||
"service_config": {"ssh": {"port": 22}},
|
||||
"mutate_interval": 3600,
|
||||
"last_mutated": 0.0,
|
||||
"archetype": "generic",
|
||||
"host_uuid": None,
|
||||
}
|
||||
|
||||
|
||||
# ------------------------- _extract_peer_fingerprint unit tests ---------
|
||||
|
||||
|
||||
def test_extract_primary_path_returns_fingerprint() -> None:
    """A cert in the ASGI TLS extension chain is hashed into the fingerprint."""
    cert_bytes = b"\x30\x82test-cert-bytes"
    expected = hashlib.sha256(cert_bytes).hexdigest()
    asgi_scope = {"extensions": {"tls": {"client_cert_chain": [cert_bytes]}}}
    assert hb_mod._extract_peer_fingerprint(asgi_scope) == expected
|
||||
|
||||
|
||||
def test_extract_fallback_path_when_primary_absent() -> None:
    """Without the TLS extension, the transport's ssl_object supplies the cert."""
    cert_bytes = b"\x30\x82fallback-bytes"
    tls_socket = MagicMock()
    tls_socket.getpeercert.return_value = cert_bytes
    fake_transport = MagicMock()
    fake_transport.get_extra_info.return_value = tls_socket

    result = hb_mod._extract_peer_fingerprint({"transport": fake_transport})

    assert result == hashlib.sha256(cert_bytes).hexdigest()
    fake_transport.get_extra_info.assert_called_with("ssl_object")
    tls_socket.getpeercert.assert_called_with(binary_form=True)
|
||||
|
||||
|
||||
def test_extract_returns_none_when_both_paths_empty() -> None:
    """An empty scope (no extensions, no transport) yields None — the
    fail-closed signal for the endpoint."""
    empty_scope: dict = {}
    assert hb_mod._extract_peer_fingerprint(empty_scope) is None
|
||||
|
||||
|
||||
def test_extract_returns_none_when_transport_ssl_object_missing() -> None:
    """A transport that reports no ssl_object yields None."""
    plain_transport = MagicMock()
    plain_transport.get_extra_info.return_value = None
    result = hb_mod._extract_peer_fingerprint({"transport": plain_transport})
    assert result is None
|
||||
|
||||
|
||||
# ------------------------- endpoint behaviour --------------------------
|
||||
|
||||
|
||||
def test_heartbeat_happy_path_primary_extraction(
    client: TestClient, repo, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A heartbeat with the pinned fingerprint updates the host and its shard."""
    enrolled = _enroll(client, "worker-a")
    host_uuid = enrolled["host_uuid"]
    _pin_fingerprint(monkeypatch, enrolled["fingerprint"])

    status = _status_body(
        [_decky_payload("decky-01")],
        {"decky-01": {"running": True}},
    )
    resp = client.post(
        "/swarm/heartbeat",
        json={"host_uuid": host_uuid, "agent_version": "1.2.3", "status": status},
    )
    assert resp.status_code == 204, resp.text

    async def _check_persisted() -> None:
        host_row = await repo.get_swarm_host_by_uuid(host_uuid)
        assert host_row["last_heartbeat"] is not None
        assert host_row["status"] == "active"

        shards = await repo.list_decky_shards(host_uuid)
        assert len(shards) == 1
        shard = shards[0]
        assert shard["decky_name"] == "decky-01"
        assert shard["decky_ip"] == "10.0.0.50"
        assert shard["state"] == "running"
        assert shard["last_seen"] is not None
        # snapshot flattening from list_decky_shards
        assert shard["hostname"] == "decky-01.lan"
        assert shard["archetype"] == "generic"
        assert shard["service_config"] == {"ssh": {"port": 22}}

    # NOTE(review): get_event_loop() without a running loop is deprecated;
    # any move to asyncio.run() has to be suite-wide, because asyncio.run()
    # clears the current loop that sibling tests rely on.
    asyncio.get_event_loop().run_until_complete(_check_persisted())
|
||||
|
||||
|
||||
def test_heartbeat_fallback_extraction_path_also_accepted(
    client: TestClient, repo, monkeypatch: pytest.MonkeyPatch
) -> None:
    """The endpoint accepts a matching fingerprint whichever path produced it.

    Guards against uvicorn-version drift where only the fallback slot is
    populated. The extractor is stubbed here, so the real fallback code is
    covered by the ``_extract_peer_fingerprint`` unit tests instead.
    """
    enrolled = _enroll(client, "worker-b", "10.0.0.6")
    _pin_fingerprint(monkeypatch, enrolled["fingerprint"])

    payload = {
        "host_uuid": enrolled["host_uuid"],
        "status": {"deployed": False, "deckies": []},
    }
    resp = client.post("/swarm/heartbeat", json=payload)
    assert resp.status_code == 204
|
||||
|
||||
|
||||
def test_heartbeat_unknown_host_returns_404(
    client: TestClient, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A heartbeat for a host_uuid that was never enrolled is rejected with 404."""
    _pin_fingerprint(monkeypatch, "a" * 64)
    payload = {"host_uuid": "does-not-exist", "status": {"deployed": False}}
    resp = client.post("/swarm/heartbeat", json=payload)
    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_heartbeat_fingerprint_mismatch_returns_403(
    client: TestClient, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A peer cert that is not the enrolled host's fingerprint is rejected with 403."""
    enrolled = _enroll(client, "worker-c", "10.0.0.7")
    wrong_fp = "b" * 64  # not the host's fingerprint
    _pin_fingerprint(monkeypatch, wrong_fp)

    payload = {"host_uuid": enrolled["host_uuid"], "status": {"deployed": False}}
    resp = client.post("/swarm/heartbeat", json=payload)

    assert resp.status_code == 403
    assert "mismatch" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_heartbeat_no_peer_cert_fails_closed(
    client: TestClient, monkeypatch: pytest.MonkeyPatch
) -> None:
    """When no peer cert can be extracted the endpoint answers 403, never 2xx."""
    enrolled = _enroll(client, "worker-d", "10.0.0.8")
    _pin_fingerprint(monkeypatch, None)

    payload = {"host_uuid": enrolled["host_uuid"], "status": {"deployed": False}}
    resp = client.post("/swarm/heartbeat", json=payload)

    assert resp.status_code == 403
    assert "unavailable" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_heartbeat_decommissioned_host_returns_404(
    client: TestClient, repo, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A deleted host cannot heartbeat, even with its original fingerprint.

    Enrol, remember the fingerprint, delete the host row, then replay the
    heartbeat. The cert is still CA-signed, but the host lookup returns
    None → 404, which keeps a decommissioned worker from resurrecting
    ghost shards.
    """
    enrolled = _enroll(client, "worker-e", "10.0.0.9")
    host_uuid = enrolled["host_uuid"]
    original_fp = enrolled["fingerprint"]

    async def _remove_host() -> None:
        deleted = await repo.delete_swarm_host(host_uuid)
        assert deleted

    # NOTE(review): get_event_loop() without a running loop is deprecated;
    # any move to asyncio.run() has to be suite-wide.
    asyncio.get_event_loop().run_until_complete(_remove_host())

    _pin_fingerprint(monkeypatch, original_fp)
    payload = {"host_uuid": host_uuid, "status": {"deployed": False}}
    resp = client.post("/swarm/heartbeat", json=payload)
    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_heartbeat_deployed_false_bumps_host_but_writes_no_shards(
    client: TestClient, repo, monkeypatch: pytest.MonkeyPatch
) -> None:
    """With ``deployed: False`` the host heartbeat is recorded but no shard is written."""
    enrolled = _enroll(client, "worker-f", "10.0.0.10")
    host_uuid = enrolled["host_uuid"]
    _pin_fingerprint(monkeypatch, enrolled["fingerprint"])

    payload = {
        "host_uuid": host_uuid,
        "status": {"deployed": False, "deckies": []},
    }
    resp = client.post("/swarm/heartbeat", json=payload)
    assert resp.status_code == 204

    async def _check_persisted() -> None:
        host_row = await repo.get_swarm_host_by_uuid(host_uuid)
        assert host_row["last_heartbeat"] is not None
        assert await repo.list_decky_shards(host_uuid) == []

    # NOTE(review): get_event_loop() without a running loop is deprecated;
    # any move to asyncio.run() has to be suite-wide.
    asyncio.get_event_loop().run_until_complete(_check_persisted())
|
||||
|
||||
|
||||
def test_heartbeat_decky_missing_from_runtime_is_degraded(
    client: TestClient, repo, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A decky with no entry in the runtime map is stored as ``degraded``."""
    enrolled = _enroll(client, "worker-g", "10.0.0.11")
    host_uuid = enrolled["host_uuid"]
    _pin_fingerprint(monkeypatch, enrolled["fingerprint"])

    deckies = [_decky_payload("decky-01"), _decky_payload("decky-02", "10.0.0.51")]
    runtime = {"decky-01": {"running": True}}  # decky-02 absent
    resp = client.post(
        "/swarm/heartbeat",
        json={"host_uuid": host_uuid, "status": _status_body(deckies, runtime)},
    )
    assert resp.status_code == 204

    async def _check_states() -> None:
        shards = await repo.list_decky_shards(host_uuid)
        state_by_name = {shard["decky_name"]: shard["state"] for shard in shards}
        assert state_by_name["decky-01"] == "running"
        assert state_by_name["decky-02"] == "degraded"

    # NOTE(review): get_event_loop() without a running loop is deprecated;
    # any move to asyncio.run() has to be suite-wide.
    asyncio.get_event_loop().run_until_complete(_check_states())
|
||||
Reference in New Issue
Block a user