merge: testing → main (reconcile 2-week divergence)
This commit is contained in:
7
decnet/swarm/__init__.py
Normal file
7
decnet/swarm/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""DECNET SWARM — multihost deployment subsystem.
|
||||
|
||||
Components:
|
||||
* ``pki`` — X.509 CA + CSR signing used by all swarm mTLS channels
|
||||
* ``client`` — master-side HTTP client that talks to remote workers
|
||||
* ``log_forwarder``— worker-side syslog-over-TLS (RFC 5425) forwarder
|
||||
"""
|
||||
323
decnet/swarm/client.py
Normal file
323
decnet/swarm/client.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""Master-side HTTP client that talks to a worker's DECNET agent.
|
||||
|
||||
All traffic is mTLS: the master presents a cert issued by its own CA (which
|
||||
workers trust) and the master validates the worker's cert against the same
|
||||
CA. In practice the "client cert" the master shows is just another cert
|
||||
signed by itself — the master is both the CA and the sole control-plane
|
||||
client.
|
||||
|
||||
Usage:
|
||||
|
||||
async with AgentClient(host) as agent:
|
||||
await agent.deploy(config)
|
||||
status = await agent.status()
|
||||
|
||||
The ``host`` is a SwarmHost dict returned by the repository.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import pathlib
|
||||
import socket
|
||||
import ssl
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from decnet.config import DecnetConfig
|
||||
from decnet.logging import get_logger
|
||||
from decnet.swarm import pki
|
||||
|
||||
log = get_logger("swarm.client")
|
||||
|
||||
|
||||
class FingerprintMismatchError(RuntimeError):
|
||||
"""Raised when the worker presents a cert whose SHA-256 fingerprint
|
||||
does not match ``SwarmHost.client_cert_fingerprint``.
|
||||
|
||||
Existence of this error class is the contract that lets the deployer
|
||||
distinguish "wrong worker on the wire" (security event, fail loud)
|
||||
from generic transport errors (retryable, mark slice failed)."""
|
||||
|
||||
def __init__(self, host: str, expected: str, actual: str) -> None:
|
||||
super().__init__(
|
||||
f"agent {host}: cert fingerprint mismatch "
|
||||
f"(expected={expected[:16]}…, got={actual[:16]}…)"
|
||||
)
|
||||
self.host = host
|
||||
self.expected = expected
|
||||
self.actual = actual
|
||||
|
||||
# How long a single HTTP operation can take. Deploy is the long pole —
|
||||
# docker compose up pulls images, builds contexts, etc. Tune via env in a
|
||||
# later iteration if the default proves too short.
|
||||
_TIMEOUT_DEPLOY = httpx.Timeout(connect=10.0, read=600.0, write=30.0, pool=5.0)
|
||||
_TIMEOUT_CONTROL = httpx.Timeout(connect=5.0, read=15.0, write=5.0, pool=5.0)
|
||||
# Topology apply pulls images + runs compose on the agent — same ball-park
|
||||
# as a fleet deploy. Teardown is faster but still long enough we can't
|
||||
# reuse the control timeout.
|
||||
_TIMEOUT_TOPOLOGY_APPLY = httpx.Timeout(connect=10.0, read=600.0, write=30.0, pool=5.0)
|
||||
_TIMEOUT_TOPOLOGY_TEARDOWN = httpx.Timeout(connect=10.0, read=300.0, write=30.0, pool=5.0)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MasterIdentity:
|
||||
"""Paths to the master's own mTLS client bundle.
|
||||
|
||||
The master uses ONE master-client cert to talk to every worker. It is
|
||||
signed by the DECNET CA (same CA that signs worker certs). Stored
|
||||
under ``~/.decnet/ca/master/`` by ``ensure_master_identity``.
|
||||
"""
|
||||
key_path: pathlib.Path
|
||||
cert_path: pathlib.Path
|
||||
ca_cert_path: pathlib.Path
|
||||
|
||||
|
||||
def ensure_master_identity(
|
||||
ca_dir: pathlib.Path = pki.DEFAULT_CA_DIR,
|
||||
) -> MasterIdentity:
|
||||
"""Create (or load) the master's own client cert.
|
||||
|
||||
Called once by the swarm controller on startup and by the CLI before
|
||||
any master→worker call. Idempotent.
|
||||
"""
|
||||
ca = pki.ensure_ca(ca_dir)
|
||||
master_dir = ca_dir / "master"
|
||||
bundle = pki.load_worker_bundle(master_dir)
|
||||
if bundle is None:
|
||||
issued = pki.issue_worker_cert(ca, "decnet-master", ["127.0.0.1", "decnet-master"])
|
||||
pki.write_worker_bundle(issued, master_dir)
|
||||
return MasterIdentity(
|
||||
key_path=master_dir / "worker.key",
|
||||
cert_path=master_dir / "worker.crt",
|
||||
ca_cert_path=master_dir / "ca.crt",
|
||||
)
|
||||
|
||||
|
||||
class AgentClient:
|
||||
"""Thin async wrapper around the worker agent's HTTP API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: dict[str, Any] | None = None,
|
||||
*,
|
||||
address: Optional[str] = None,
|
||||
agent_port: Optional[int] = None,
|
||||
identity: Optional[MasterIdentity] = None,
|
||||
verify_hostname: Optional[bool] = None,
|
||||
):
|
||||
"""Either pass a SwarmHost dict, or explicit address/port.
|
||||
|
||||
``verify_hostname`` defers to ``DECNET_VERIFY_HOSTNAME`` when the
|
||||
caller doesn't pass an explicit value — production deploys flip
|
||||
the env var on so the worker's cert SAN must match the address
|
||||
the master connects to, on top of the existing CA + fingerprint
|
||||
pin. Defaults to False so dev/test enrollments with mismatched
|
||||
SANs keep working unchanged.
|
||||
"""
|
||||
if verify_hostname is None:
|
||||
from decnet.env import DECNET_VERIFY_HOSTNAME
|
||||
verify_hostname = DECNET_VERIFY_HOSTNAME
|
||||
if host is not None:
|
||||
self._address = host["address"]
|
||||
self._port = int(host.get("agent_port") or 8765)
|
||||
self._host_uuid = host.get("uuid")
|
||||
self._host_name = host.get("name")
|
||||
fp = host.get("client_cert_fingerprint")
|
||||
self._expected_fingerprint = fp.lower() if isinstance(fp, str) else None
|
||||
else:
|
||||
if address is None or agent_port is None:
|
||||
raise ValueError(
|
||||
"AgentClient requires either a host dict or address+agent_port"
|
||||
)
|
||||
self._address = address
|
||||
self._port = int(agent_port)
|
||||
self._host_uuid = None
|
||||
self._host_name = None
|
||||
self._expected_fingerprint = None
|
||||
|
||||
self._identity = identity or ensure_master_identity()
|
||||
self._verify_hostname = verify_hostname
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
# --------------------------------------------------------------- lifecycle
|
||||
|
||||
def _build_client(self, timeout: httpx.Timeout) -> httpx.AsyncClient:
|
||||
# Build the SSL context manually — httpx.create_ssl_context layers on
|
||||
# purpose/ALPN/default-CA logic that doesn't compose with private-CA
|
||||
# mTLS in all combinations. A bare SSLContext is predictable.
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.load_cert_chain(
|
||||
str(self._identity.cert_path), str(self._identity.key_path)
|
||||
)
|
||||
ctx.load_verify_locations(cafile=str(self._identity.ca_cert_path))
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
# Pin by CA + cert chain, not by DNS — workers enroll with arbitrary
|
||||
# SANs (IPs, hostnames) and we don't want to force operators to keep
|
||||
# those in sync with whatever URL the master happens to use.
|
||||
ctx.check_hostname = self._verify_hostname
|
||||
return httpx.AsyncClient(
|
||||
base_url=f"https://{self._address}:{self._port}",
|
||||
verify=ctx,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def _fetch_peer_fingerprint(self) -> str:
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.load_cert_chain(
|
||||
str(self._identity.cert_path), str(self._identity.key_path)
|
||||
)
|
||||
ctx.load_verify_locations(cafile=str(self._identity.ca_cert_path))
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
ctx.check_hostname = self._verify_hostname
|
||||
sock = socket.create_connection((self._address, self._port), timeout=10.0)
|
||||
try:
|
||||
server_hostname = self._address if self._verify_hostname else None
|
||||
with ctx.wrap_socket(sock, server_hostname=server_hostname) as ssock:
|
||||
der = ssock.getpeercert(binary_form=True)
|
||||
finally:
|
||||
try:
|
||||
sock.close()
|
||||
except OSError:
|
||||
pass
|
||||
if not der:
|
||||
raise FingerprintMismatchError(
|
||||
f"{self._address}:{self._port}", self._expected_fingerprint or "", ""
|
||||
)
|
||||
return hashlib.sha256(der).hexdigest().lower()
|
||||
|
||||
async def _verify_pin(self) -> None:
|
||||
if not self._expected_fingerprint:
|
||||
# No pin known for this host (legacy enrollments / explicit address ctor).
|
||||
# Fall through to CA-only validation. Enrollment writes the fingerprint,
|
||||
# so any production host added via `swarm enroll` will have one.
|
||||
return
|
||||
actual = await asyncio.to_thread(self._fetch_peer_fingerprint)
|
||||
if actual != self._expected_fingerprint:
|
||||
raise FingerprintMismatchError(
|
||||
f"{self._address}:{self._port}",
|
||||
self._expected_fingerprint,
|
||||
actual,
|
||||
)
|
||||
|
||||
async def __aenter__(self) -> "AgentClient":
|
||||
self._client = self._build_client(_TIMEOUT_CONTROL)
|
||||
try:
|
||||
await self._verify_pin()
|
||||
except BaseException:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
raise
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc: Any) -> None:
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
def _require_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
raise RuntimeError("AgentClient used outside `async with` block")
|
||||
return self._client
|
||||
|
||||
# ----------------------------------------------------------------- RPCs
|
||||
|
||||
async def health(self) -> dict[str, Any]:
|
||||
resp = await self._require_client().get("/health")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def status(self) -> dict[str, Any]:
|
||||
resp = await self._require_client().get("/status")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def deploy(
|
||||
self,
|
||||
config: DecnetConfig,
|
||||
*,
|
||||
dry_run: bool = False,
|
||||
no_cache: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
body = {
|
||||
"config": config.model_dump(mode="json"),
|
||||
"dry_run": dry_run,
|
||||
"no_cache": no_cache,
|
||||
}
|
||||
# Swap in a long-deploy timeout for this call only.
|
||||
old = self._require_client().timeout
|
||||
self._require_client().timeout = _TIMEOUT_DEPLOY
|
||||
try:
|
||||
resp = await self._require_client().post("/deploy", json=body)
|
||||
finally:
|
||||
self._require_client().timeout = old
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def teardown(self, decky_id: Optional[str] = None) -> dict[str, Any]:
|
||||
resp = await self._require_client().post(
|
||||
"/teardown", json={"decky_id": decky_id}
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def self_destruct(self) -> dict[str, Any]:
|
||||
"""Trigger the worker to stop services and wipe its install."""
|
||||
resp = await self._require_client().post("/self-destruct")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
# ------------------------------------------------------------ topology
|
||||
|
||||
async def apply_topology(
|
||||
self,
|
||||
hydrated: dict[str, Any],
|
||||
version_hash: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Push a hydrated topology to the agent for local materialisation.
|
||||
|
||||
The agent independently computes ``canonical_hash(hydrated)`` and
|
||||
returns 400 if it disagrees with *version_hash* — that's how we
|
||||
catch serialisation drift before half-creating bridges.
|
||||
"""
|
||||
old = self._require_client().timeout
|
||||
self._require_client().timeout = _TIMEOUT_TOPOLOGY_APPLY
|
||||
try:
|
||||
resp = await self._require_client().post(
|
||||
"/topology/apply",
|
||||
json={"hydrated": hydrated, "version_hash": version_hash},
|
||||
)
|
||||
finally:
|
||||
self._require_client().timeout = old
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def teardown_topology(self, topology_id: str) -> dict[str, Any]:
|
||||
"""Ask the agent to dismantle the named topology."""
|
||||
old = self._require_client().timeout
|
||||
self._require_client().timeout = _TIMEOUT_TOPOLOGY_TEARDOWN
|
||||
try:
|
||||
resp = await self._require_client().post(
|
||||
"/topology/teardown",
|
||||
json={"topology_id": topology_id},
|
||||
)
|
||||
finally:
|
||||
self._require_client().timeout = old
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def get_topology_state(self) -> dict[str, Any]:
|
||||
"""Snapshot of the agent's applied topology + live docker state."""
|
||||
resp = await self._require_client().get("/topology/state")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
# -------------------------------------------------------------- diagnostics
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"AgentClient(name={self._host_name!r}, "
|
||||
f"address={self._address!r}, port={self._port})"
|
||||
)
|
||||
318
decnet/swarm/log_forwarder.py
Normal file
318
decnet/swarm/log_forwarder.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""Worker-side syslog-over-TLS forwarder (RFC 5425).
|
||||
|
||||
Runs alongside the worker agent. Tails the worker's local RFC 5424 log
|
||||
file (written by the existing docker-collector) and ships each line to
|
||||
the master's listener on TCP 6514 using octet-counted framing over mTLS.
|
||||
Persists the last-forwarded byte offset in a tiny local SQLite so a
|
||||
master crash never causes loss or duplication.
|
||||
|
||||
Design constraints (from the plan, non-negotiable):
|
||||
* transport MUST be TLS — plaintext syslog is never acceptable between
|
||||
hosts; only loopback (decky → worker-local collector) may be plaintext;
|
||||
* mTLS — the listener pins the worker cert against the DECNET CA, so only
|
||||
enrolled workers can push logs;
|
||||
* offset persistence MUST be transactional w.r.t. the send — we only
|
||||
advance the offset after ``writer.drain()`` returns without error.
|
||||
|
||||
The forwarder is intentionally a standalone coroutine, not a worker
|
||||
inside the agent process. That keeps ``decnet agent`` crashes from
|
||||
losing the log tail, and vice versa.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import pathlib
|
||||
import sqlite3
|
||||
import ssl
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from decnet.bus.factory import get_bus
|
||||
from decnet.bus.publish import run_health_heartbeat
|
||||
from decnet.logging import get_logger
|
||||
from decnet.swarm import pki
|
||||
|
||||
log = get_logger("swarm.forwarder")
|
||||
|
||||
# RFC 5425 framing: "<octet-count> <syslog-msg>".
|
||||
# The message itself is a standard RFC 5424 line (no trailing newline).
|
||||
_FRAME_SEP = b" "
|
||||
|
||||
_INITIAL_BACKOFF = 1.0
|
||||
_MAX_BACKOFF = 30.0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ForwarderConfig:
|
||||
log_path: pathlib.Path # worker's RFC 5424 .log file
|
||||
master_host: str
|
||||
master_port: int = 6514
|
||||
agent_dir: pathlib.Path = pki.DEFAULT_AGENT_DIR
|
||||
state_db: Optional[pathlib.Path] = None # default: agent_dir / "forwarder.db"
|
||||
# Max unacked bytes to keep in the local buffer when master is down.
|
||||
# We bound the lag to avoid unbounded disk growth on catastrophic master
|
||||
# outage — older lines are surfaced as a warning and dropped by advancing
|
||||
# the offset.
|
||||
max_lag_bytes: int = 128 * 1024 * 1024 # 128 MiB
|
||||
|
||||
|
||||
# ------------------------------------------------------------ offset storage
|
||||
|
||||
|
||||
class _OffsetStore:
|
||||
"""Single-row SQLite offset tracker. Stdlib only — no ORM, no async."""
|
||||
|
||||
def __init__(self, db_path: pathlib.Path) -> None:
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._conn = sqlite3.connect(str(db_path))
|
||||
self._conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS forwarder_offset ("
|
||||
" key TEXT PRIMARY KEY, offset INTEGER NOT NULL)"
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def get(self, key: str = "default") -> int:
|
||||
row = self._conn.execute(
|
||||
"SELECT offset FROM forwarder_offset WHERE key=?", (key,)
|
||||
).fetchone()
|
||||
return int(row[0]) if row else 0
|
||||
|
||||
def set(self, offset: int, key: str = "default") -> None:
|
||||
self._conn.execute(
|
||||
"INSERT INTO forwarder_offset(key, offset) VALUES(?, ?) "
|
||||
"ON CONFLICT(key) DO UPDATE SET offset=excluded.offset",
|
||||
(key, offset),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def close(self) -> None:
|
||||
self._conn.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- TLS setup
|
||||
|
||||
|
||||
def build_worker_ssl_context(agent_dir: pathlib.Path) -> ssl.SSLContext:
|
||||
"""Client-side mTLS context for the forwarder.
|
||||
|
||||
Worker presents its agent bundle (same cert used for the control-plane
|
||||
HTTPS listener). The CA is the DECNET CA; we pin by CA, not hostname,
|
||||
because workers reach masters by operator-supplied address.
|
||||
"""
|
||||
bundle = pki.load_worker_bundle(agent_dir)
|
||||
if bundle is None:
|
||||
raise RuntimeError(
|
||||
f"no worker bundle at {agent_dir} — enroll from the master first"
|
||||
)
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.load_cert_chain(
|
||||
certfile=str(agent_dir / "worker.crt"),
|
||||
keyfile=str(agent_dir / "worker.key"),
|
||||
)
|
||||
ctx.load_verify_locations(cafile=str(agent_dir / "ca.crt"))
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
ctx.check_hostname = False
|
||||
return ctx
|
||||
|
||||
|
||||
# ----------------------------------------------------------- frame encoding
|
||||
|
||||
|
||||
def encode_frame(line: str) -> bytes:
|
||||
"""RFC 5425 octet-counted framing: ``"<N> <msg>"``.
|
||||
|
||||
``N`` is the byte length of the payload that follows (after the space).
|
||||
"""
|
||||
payload = line.rstrip("\n").encode("utf-8", errors="replace")
|
||||
return f"{len(payload)}".encode("ascii") + _FRAME_SEP + payload
|
||||
|
||||
|
||||
async def read_frame(reader: asyncio.StreamReader) -> Optional[bytes]:
|
||||
"""Read one octet-counted frame. Returns None on clean EOF."""
|
||||
# Read the ASCII length up to the first space. Bound the prefix so a
|
||||
# malicious peer can't force us to buffer unbounded bytes before we know
|
||||
# it's a valid frame.
|
||||
prefix = b""
|
||||
while True:
|
||||
c = await reader.read(1)
|
||||
if not c:
|
||||
return None if not prefix else b""
|
||||
if c == _FRAME_SEP:
|
||||
break
|
||||
if len(prefix) >= 10 or not c.isdigit():
|
||||
# RFC 5425 caps the length prefix at ~10 digits (< 4 GiB payload).
|
||||
raise ValueError(f"invalid octet-count prefix: {prefix!r}")
|
||||
prefix += c
|
||||
n = int(prefix)
|
||||
buf = await reader.readexactly(n)
|
||||
return buf
|
||||
|
||||
|
||||
# ----------------------------------------------------------------- main loop
|
||||
|
||||
|
||||
async def _send_batch(
|
||||
writer: asyncio.StreamWriter,
|
||||
offset: int,
|
||||
lines: list[tuple[int, str]],
|
||||
store: _OffsetStore,
|
||||
) -> int:
|
||||
"""Write every line as a frame, drain, then persist the last offset."""
|
||||
for _, line in lines:
|
||||
writer.write(encode_frame(line))
|
||||
await writer.drain()
|
||||
last_offset = lines[-1][0]
|
||||
store.set(last_offset)
|
||||
return last_offset
|
||||
|
||||
|
||||
async def run_forwarder(
|
||||
cfg: ForwarderConfig,
|
||||
*,
|
||||
poll_interval: float = 0.5,
|
||||
stop_event: Optional[asyncio.Event] = None,
|
||||
) -> None:
|
||||
"""Main forwarder loop. Run as a dedicated task.
|
||||
|
||||
Stops when ``stop_event`` is set (used by tests and clean shutdown).
|
||||
Exceptions trigger exponential backoff but are never fatal — the
|
||||
forwarder is expected to outlive transient master/network failures.
|
||||
"""
|
||||
state_db = cfg.state_db or (cfg.agent_dir / "forwarder.db")
|
||||
store = _OffsetStore(state_db)
|
||||
offset = store.get()
|
||||
backoff = _INITIAL_BACKOFF
|
||||
|
||||
log.info(
|
||||
"forwarder start log=%s master=%s:%d offset=%d",
|
||||
cfg.log_path, cfg.master_host, cfg.master_port, offset,
|
||||
)
|
||||
|
||||
# Host-local bus heartbeat (system.forwarder.health). Peers on the
|
||||
# same host can tail "is the log shipper alive" without hitting the
|
||||
# master. Bus-disabled path is a no-op loop.
|
||||
bus = None
|
||||
try:
|
||||
bus = get_bus(client_name="forwarder")
|
||||
await bus.connect()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("forwarder: bus unavailable, skipping heartbeat: %s", exc)
|
||||
bus = None
|
||||
|
||||
heartbeat_task = asyncio.create_task(
|
||||
run_health_heartbeat(bus, "forwarder"),
|
||||
name="forwarder-bus-heartbeat",
|
||||
)
|
||||
|
||||
try:
|
||||
while stop_event is None or not stop_event.is_set():
|
||||
try:
|
||||
ctx = build_worker_ssl_context(cfg.agent_dir)
|
||||
reader, writer = await asyncio.open_connection(
|
||||
cfg.master_host, cfg.master_port, ssl=ctx
|
||||
)
|
||||
log.info("forwarder connected master=%s:%d", cfg.master_host, cfg.master_port)
|
||||
backoff = _INITIAL_BACKOFF
|
||||
try:
|
||||
offset = await _pump(cfg, store, writer, offset, poll_interval, stop_event)
|
||||
finally:
|
||||
writer.close()
|
||||
try:
|
||||
await writer.wait_closed()
|
||||
except Exception: # nosec B110 — socket cleanup is best-effort
|
||||
pass
|
||||
# Keep reader alive until here to avoid "reader garbage
|
||||
# collected" warnings on some Python builds.
|
||||
del reader
|
||||
except (OSError, ssl.SSLError, ConnectionError) as exc:
|
||||
log.warning(
|
||||
"forwarder disconnected: %s — retrying in %.1fs", exc, backoff
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
_sleep_unless_stopped(backoff, stop_event), timeout=backoff + 1
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
backoff = min(_MAX_BACKOFF, backoff * 2)
|
||||
finally:
|
||||
heartbeat_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError, Exception):
|
||||
await heartbeat_task
|
||||
if bus is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await bus.close()
|
||||
store.close()
|
||||
log.info("forwarder stopped offset=%d", offset)
|
||||
|
||||
|
||||
async def _pump(
|
||||
cfg: ForwarderConfig,
|
||||
store: _OffsetStore,
|
||||
writer: asyncio.StreamWriter,
|
||||
offset: int,
|
||||
poll_interval: float,
|
||||
stop_event: Optional[asyncio.Event],
|
||||
) -> int:
|
||||
"""Read new lines since ``offset`` and ship them until disconnect."""
|
||||
while stop_event is None or not stop_event.is_set():
|
||||
if not cfg.log_path.exists():
|
||||
await _sleep_unless_stopped(poll_interval, stop_event)
|
||||
continue
|
||||
|
||||
stat = cfg.log_path.stat()
|
||||
if stat.st_size < offset:
|
||||
# truncated/rotated — reset.
|
||||
log.warning("forwarder log rotated — resetting offset=0")
|
||||
offset = 0
|
||||
store.set(0)
|
||||
if stat.st_size - offset > cfg.max_lag_bytes:
|
||||
# Catastrophic lag — skip ahead to cap local disk pressure.
|
||||
skip_to = stat.st_size - cfg.max_lag_bytes
|
||||
log.warning(
|
||||
"forwarder lag %d > cap %d — dropping oldest %d bytes",
|
||||
stat.st_size - offset, cfg.max_lag_bytes, skip_to - offset,
|
||||
)
|
||||
offset = skip_to
|
||||
store.set(offset)
|
||||
|
||||
if stat.st_size == offset:
|
||||
await _sleep_unless_stopped(poll_interval, stop_event)
|
||||
continue
|
||||
|
||||
batch: list[tuple[int, str]] = []
|
||||
with open(cfg.log_path, "r", encoding="utf-8", errors="replace") as f:
|
||||
f.seek(offset)
|
||||
while True:
|
||||
line = f.readline()
|
||||
if not line or not line.endswith("\n"):
|
||||
break
|
||||
offset_after = f.tell()
|
||||
batch.append((offset_after, line.rstrip("\n")))
|
||||
if len(batch) >= 500:
|
||||
break
|
||||
if batch:
|
||||
offset = await _send_batch(writer, offset, batch, store)
|
||||
return offset
|
||||
|
||||
|
||||
async def _sleep_unless_stopped(
|
||||
seconds: float, stop_event: Optional[asyncio.Event]
|
||||
) -> None:
|
||||
if stop_event is None:
|
||||
await asyncio.sleep(seconds)
|
||||
return
|
||||
try:
|
||||
await asyncio.wait_for(stop_event.wait(), timeout=seconds)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
|
||||
# Re-exported for CLI convenience
|
||||
DEFAULT_PORT = 6514
|
||||
|
||||
|
||||
def default_master_host() -> Optional[str]:
|
||||
return os.environ.get("DECNET_SWARM_MASTER_HOST")
|
||||
194
decnet/swarm/log_listener.py
Normal file
194
decnet/swarm/log_listener.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Master-side syslog-over-TLS listener (RFC 5425).
|
||||
|
||||
Accepts mTLS-authenticated worker connections on TCP 6514, reads
|
||||
octet-counted frames, parses each as an RFC 5424 line, and appends it to
|
||||
the master's local ingest log files. The existing log_ingestion_worker
|
||||
tails those files and inserts records into the master repo — worker
|
||||
provenance is embedded in the parsed record's ``source_worker`` field.
|
||||
|
||||
Design:
|
||||
* TLS is mandatory. No plaintext fallback. A peer without a CA-signed
|
||||
cert is rejected at the TLS handshake; nothing gets past the kernel.
|
||||
* The listener never trusts the syslog HOSTNAME field for provenance —
|
||||
that's attacker-supplied from the decky. The authoritative source is
|
||||
the peer cert's CN, which the CA controlled at enrollment.
|
||||
* Dropped connections are fine — the worker's forwarder holds the
|
||||
offset and resumes from the same byte on reconnect.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import pathlib
|
||||
import ssl
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.x509.oid import NameOID
|
||||
|
||||
from decnet.logging import get_logger
|
||||
from decnet.swarm import pki
|
||||
from decnet.swarm.log_forwarder import read_frame
|
||||
|
||||
log = get_logger("swarm.listener")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ListenerConfig:
|
||||
log_path: pathlib.Path # master's RFC 5424 .log (forensic sink)
|
||||
json_path: pathlib.Path # master's .json (ingester tails this)
|
||||
bind_host: str = "0.0.0.0" # nosec B104 — listener must bind publicly
|
||||
bind_port: int = 6514
|
||||
ca_dir: pathlib.Path = pki.DEFAULT_CA_DIR
|
||||
|
||||
|
||||
# --------------------------------------------------------- TLS context
|
||||
|
||||
|
||||
def build_listener_ssl_context(ca_dir: pathlib.Path) -> ssl.SSLContext:
|
||||
"""Server-side mTLS context: master presents its master cert; clients
|
||||
must present a cert signed by the DECNET CA."""
|
||||
master_dir = ca_dir / "master"
|
||||
ca_cert = master_dir / "ca.crt"
|
||||
cert = master_dir / "worker.crt" # master re-uses the 'worker' bundle layout
|
||||
key = master_dir / "worker.key"
|
||||
for p in (ca_cert, cert, key):
|
||||
if not p.exists():
|
||||
raise RuntimeError(
|
||||
f"master identity missing at {master_dir} — call ensure_master_identity first"
|
||||
)
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ctx.load_cert_chain(certfile=str(cert), keyfile=str(key))
|
||||
ctx.load_verify_locations(cafile=str(ca_cert))
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
return ctx
|
||||
|
||||
|
||||
# ---------------------------------------------------------- helpers
|
||||
|
||||
|
||||
def peer_cn(ssl_object: Optional[ssl.SSLObject]) -> str:
|
||||
"""Extract the CN from the TLS peer certificate (worker provenance).
|
||||
|
||||
Falls back to ``"unknown"`` on any parse error — we refuse to crash on
|
||||
malformed cert DNs and instead tag the message for later inspection.
|
||||
"""
|
||||
if ssl_object is None:
|
||||
return "unknown"
|
||||
der = ssl_object.getpeercert(binary_form=True)
|
||||
if der is None:
|
||||
return "unknown"
|
||||
try:
|
||||
cert = x509.load_der_x509_certificate(der)
|
||||
attrs = cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)
|
||||
return attrs[0].value if attrs else "unknown"
|
||||
except Exception: # nosec B110 — provenance is best-effort
|
||||
return "unknown"
|
||||
|
||||
|
||||
def fingerprint_from_ssl(ssl_object: Optional[ssl.SSLObject]) -> Optional[str]:
|
||||
if ssl_object is None:
|
||||
return None
|
||||
der = ssl_object.getpeercert(binary_form=True)
|
||||
if der is None:
|
||||
return None
|
||||
try:
|
||||
cert = x509.load_der_x509_certificate(der)
|
||||
return pki.fingerprint(cert.public_bytes(serialization.Encoding.PEM))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# --------------------------------------------------- per-connection handler
|
||||
|
||||
|
||||
async def _handle_connection(
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
cfg: ListenerConfig,
|
||||
) -> None:
|
||||
ssl_obj = writer.get_extra_info("ssl_object")
|
||||
cn = peer_cn(ssl_obj)
|
||||
peer = writer.get_extra_info("peername")
|
||||
log.info("listener accepted worker=%s peer=%s", cn, peer)
|
||||
|
||||
# Lazy import to avoid a circular dep if the collector pulls in logger setup.
|
||||
from decnet.collector.worker import parse_rfc5424
|
||||
|
||||
cfg.log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cfg.json_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
with open(cfg.log_path, "a", encoding="utf-8") as lf, open(
|
||||
cfg.json_path, "a", encoding="utf-8"
|
||||
) as jf:
|
||||
while True:
|
||||
try:
|
||||
frame = await read_frame(reader)
|
||||
except asyncio.IncompleteReadError:
|
||||
break
|
||||
except ValueError as exc:
|
||||
log.warning("listener bad frame worker=%s err=%s", cn, exc)
|
||||
break
|
||||
if frame is None:
|
||||
break
|
||||
if not frame:
|
||||
continue
|
||||
line = frame.decode("utf-8", errors="replace")
|
||||
lf.write(line + "\n")
|
||||
lf.flush()
|
||||
parsed = parse_rfc5424(line)
|
||||
if parsed is not None:
|
||||
parsed["source_worker"] = cn
|
||||
jf.write(json.dumps(parsed) + "\n")
|
||||
jf.flush()
|
||||
else:
|
||||
log.debug("listener malformed RFC5424 worker=%s snippet=%r", cn, line[:80])
|
||||
except Exception as exc:
|
||||
log.warning("listener connection error worker=%s err=%s", cn, exc)
|
||||
finally:
|
||||
writer.close()
|
||||
try:
|
||||
await writer.wait_closed()
|
||||
except Exception: # nosec B110 — socket cleanup is best-effort
|
||||
pass
|
||||
log.info("listener closed worker=%s", cn)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- server
|
||||
|
||||
|
||||
async def run_listener(
|
||||
cfg: ListenerConfig,
|
||||
*,
|
||||
stop_event: Optional[asyncio.Event] = None,
|
||||
) -> None:
|
||||
ctx = build_listener_ssl_context(cfg.ca_dir)
|
||||
|
||||
async def _client_cb(
|
||||
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
await _handle_connection(reader, writer, cfg)
|
||||
|
||||
server = await asyncio.start_server(
|
||||
_client_cb, host=cfg.bind_host, port=cfg.bind_port, ssl=ctx
|
||||
)
|
||||
sockets = server.sockets or ()
|
||||
log.info(
|
||||
"listener bound host=%s port=%d sockets=%d",
|
||||
cfg.bind_host, cfg.bind_port, len(sockets),
|
||||
)
|
||||
async with server:
|
||||
if stop_event is None:
|
||||
await server.serve_forever()
|
||||
else:
|
||||
serve_task = asyncio.create_task(server.serve_forever())
|
||||
await stop_event.wait()
|
||||
server.close()
|
||||
serve_task.cancel()
|
||||
try:
|
||||
await serve_task
|
||||
except (asyncio.CancelledError, Exception): # nosec B110
|
||||
pass
|
||||
323
decnet/swarm/pki.py
Normal file
323
decnet/swarm/pki.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""DECNET SWARM PKI — self-managed X.509 CA for master↔worker mTLS.
|
||||
|
||||
Used by:
|
||||
* the SWARM controller (master) to issue per-worker server+client certs at
|
||||
enrollment time,
|
||||
* the agent (worker) to present its mTLS identity for both the control-plane
|
||||
HTTPS endpoint and the syslog-over-TLS (RFC 5425) log forwarder,
|
||||
* the master-side syslog-TLS listener to authenticate inbound workers.
|
||||
|
||||
Storage layout (master):
|
||||
|
||||
~/.decnet/ca/
|
||||
ca.key (PEM, 0600 — the CA private key)
|
||||
ca.crt (PEM — self-signed root)
|
||||
workers/<worker-name>/
|
||||
client.crt (issued, signed by CA)
|
||||
|
||||
Worker layout (delivered by /enroll response):
|
||||
|
||||
~/.decnet/agent/
|
||||
ca.crt (master's CA — trust anchor)
|
||||
worker.key (worker's own private key)
|
||||
worker.crt (signed by master CA — used for both TLS
|
||||
server auth *and* syslog client auth)
|
||||
|
||||
The CA is a hard dependency only in swarm mode; unihost installs never
|
||||
touch this module.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as _dt
|
||||
import hashlib
|
||||
import ipaddress
|
||||
import os
|
||||
import pathlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.x509.oid import NameOID
|
||||
|
||||
DEFAULT_CA_DIR = pathlib.Path(os.path.expanduser("~/.decnet/ca"))
|
||||
DEFAULT_AGENT_DIR = pathlib.Path(os.path.expanduser("~/.decnet/agent"))
|
||||
DEFAULT_SWARMCTL_DIR = pathlib.Path(os.path.expanduser("~/.decnet/swarmctl"))
|
||||
|
||||
CA_KEY_BITS = 4096
|
||||
WORKER_KEY_BITS = 2048
|
||||
CA_VALIDITY_DAYS = 3650 # 10 years — internal CA
|
||||
WORKER_VALIDITY_DAYS = 825 # max permitted by modern TLS clients
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CABundle:
|
||||
"""The master's CA identity (key is secret, cert is published)."""
|
||||
|
||||
key_pem: bytes
|
||||
cert_pem: bytes
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IssuedCert:
|
||||
"""A signed worker certificate + its private key, handed to the worker
|
||||
exactly once during enrollment.
|
||||
"""
|
||||
|
||||
key_pem: bytes
|
||||
cert_pem: bytes
|
||||
ca_cert_pem: bytes
|
||||
fingerprint_sha256: str # hex, lowercase
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- CA ops
|
||||
|
||||
|
||||
def _pem_private(key: rsa.RSAPrivateKey) -> bytes:
|
||||
return key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
|
||||
|
||||
def _pem_cert(cert: x509.Certificate) -> bytes:
|
||||
return cert.public_bytes(serialization.Encoding.PEM)
|
||||
|
||||
|
||||
def generate_ca(common_name: str = "DECNET SWARM Root CA") -> CABundle:
|
||||
"""Generate a fresh self-signed CA. Does not touch disk."""
|
||||
key = rsa.generate_private_key(public_exponent=65537, key_size=CA_KEY_BITS)
|
||||
subject = issuer = x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "DECNET"),
|
||||
]
|
||||
)
|
||||
now = _dt.datetime.now(_dt.timezone.utc)
|
||||
cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(subject)
|
||||
.issuer_name(issuer)
|
||||
.public_key(key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now - _dt.timedelta(minutes=5))
|
||||
.not_valid_after(now + _dt.timedelta(days=CA_VALIDITY_DAYS))
|
||||
.add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
content_commitment=False,
|
||||
key_encipherment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
key_cert_sign=True,
|
||||
crl_sign=True,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.sign(private_key=key, algorithm=hashes.SHA256())
|
||||
)
|
||||
return CABundle(key_pem=_pem_private(key), cert_pem=_pem_cert(cert))
|
||||
|
||||
|
||||
def save_ca(bundle: CABundle, ca_dir: pathlib.Path = DEFAULT_CA_DIR) -> None:
|
||||
ca_dir.mkdir(parents=True, exist_ok=True)
|
||||
# 0700 on the dir, 0600 on the key — defence against casual reads.
|
||||
os.chmod(ca_dir, 0o700)
|
||||
key_path = ca_dir / "ca.key"
|
||||
cert_path = ca_dir / "ca.crt"
|
||||
key_path.write_bytes(bundle.key_pem)
|
||||
os.chmod(key_path, 0o600)
|
||||
cert_path.write_bytes(bundle.cert_pem)
|
||||
|
||||
|
||||
def load_ca(ca_dir: pathlib.Path = DEFAULT_CA_DIR) -> CABundle:
|
||||
key_pem = (ca_dir / "ca.key").read_bytes()
|
||||
cert_pem = (ca_dir / "ca.crt").read_bytes()
|
||||
return CABundle(key_pem=key_pem, cert_pem=cert_pem)
|
||||
|
||||
|
||||
def ensure_ca(ca_dir: pathlib.Path = DEFAULT_CA_DIR) -> CABundle:
|
||||
"""Load the CA if present, otherwise generate and persist a new one."""
|
||||
if (ca_dir / "ca.key").exists() and (ca_dir / "ca.crt").exists():
|
||||
return load_ca(ca_dir)
|
||||
bundle = generate_ca()
|
||||
save_ca(bundle, ca_dir)
|
||||
return bundle
|
||||
|
||||
|
||||
# --------------------------------------------------------------- cert issuance
|
||||
|
||||
|
||||
def _parse_san(value: str) -> x509.GeneralName:
|
||||
"""Parse a SAN entry as IP if possible, otherwise DNS."""
|
||||
try:
|
||||
return x509.IPAddress(ipaddress.ip_address(value))
|
||||
except ValueError:
|
||||
return x509.DNSName(value)
|
||||
|
||||
|
||||
def issue_worker_cert(
|
||||
ca: CABundle,
|
||||
worker_name: str,
|
||||
sans: list[str],
|
||||
validity_days: int = WORKER_VALIDITY_DAYS,
|
||||
) -> IssuedCert:
|
||||
"""Sign a freshly-generated worker keypair.
|
||||
|
||||
The cert is usable as BOTH a TLS server (agent's HTTPS endpoint) and a
|
||||
TLS client (syslog-over-TLS upstream to the master) — extended key usage
|
||||
covers both. ``sans`` should include every address/name the master or
|
||||
workers will use to reach this worker — typically the worker's IP plus
|
||||
its hostname.
|
||||
"""
|
||||
ca_key = serialization.load_pem_private_key(ca.key_pem, password=None)
|
||||
ca_cert = x509.load_pem_x509_certificate(ca.cert_pem)
|
||||
|
||||
worker_key = rsa.generate_private_key(public_exponent=65537, key_size=WORKER_KEY_BITS)
|
||||
subject = x509.Name(
|
||||
[
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, worker_name),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "DECNET"),
|
||||
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "swarm-worker"),
|
||||
]
|
||||
)
|
||||
now = _dt.datetime.now(_dt.timezone.utc)
|
||||
san_entries: list[x509.GeneralName] = [_parse_san(s) for s in sans] if sans else []
|
||||
# Always include the worker-name as a DNS SAN so cert pinning by CN-as-DNS
|
||||
# works even when the operator forgets to pass an explicit SAN list.
|
||||
if not any(
|
||||
isinstance(e, x509.DNSName) and e.value == worker_name for e in san_entries
|
||||
):
|
||||
san_entries.append(x509.DNSName(worker_name))
|
||||
|
||||
builder = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(subject)
|
||||
.issuer_name(ca_cert.subject)
|
||||
.public_key(worker_key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now - _dt.timedelta(minutes=5))
|
||||
.not_valid_after(now + _dt.timedelta(days=validity_days))
|
||||
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
content_commitment=False,
|
||||
key_encipherment=True,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
key_cert_sign=False,
|
||||
crl_sign=False,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.ExtendedKeyUsage(
|
||||
[
|
||||
x509.ObjectIdentifier("1.3.6.1.5.5.7.3.1"), # serverAuth
|
||||
x509.ObjectIdentifier("1.3.6.1.5.5.7.3.2"), # clientAuth
|
||||
]
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(x509.SubjectAlternativeName(san_entries), critical=False)
|
||||
)
|
||||
cert = builder.sign(private_key=ca_key, algorithm=hashes.SHA256())
|
||||
cert_pem = _pem_cert(cert)
|
||||
fp = hashlib.sha256(
|
||||
cert.public_bytes(serialization.Encoding.DER)
|
||||
).hexdigest()
|
||||
return IssuedCert(
|
||||
key_pem=_pem_private(worker_key),
|
||||
cert_pem=cert_pem,
|
||||
ca_cert_pem=ca.cert_pem,
|
||||
fingerprint_sha256=fp,
|
||||
)
|
||||
|
||||
|
||||
def write_worker_bundle(
|
||||
issued: IssuedCert,
|
||||
agent_dir: pathlib.Path = DEFAULT_AGENT_DIR,
|
||||
) -> None:
|
||||
"""Persist an issued bundle into the worker's agent directory."""
|
||||
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(agent_dir, 0o700)
|
||||
(agent_dir / "ca.crt").write_bytes(issued.ca_cert_pem)
|
||||
(agent_dir / "worker.crt").write_bytes(issued.cert_pem)
|
||||
key_path = agent_dir / "worker.key"
|
||||
key_path.write_bytes(issued.key_pem)
|
||||
os.chmod(key_path, 0o600)
|
||||
|
||||
|
||||
def load_worker_bundle(
|
||||
agent_dir: pathlib.Path = DEFAULT_AGENT_DIR,
|
||||
) -> Optional[IssuedCert]:
|
||||
"""Return the worker's bundle if enrolled; ``None`` otherwise."""
|
||||
ca = agent_dir / "ca.crt"
|
||||
crt = agent_dir / "worker.crt"
|
||||
key = agent_dir / "worker.key"
|
||||
if not (ca.exists() and crt.exists() and key.exists()):
|
||||
return None
|
||||
cert_pem = crt.read_bytes()
|
||||
cert = x509.load_pem_x509_certificate(cert_pem)
|
||||
fp = hashlib.sha256(
|
||||
cert.public_bytes(serialization.Encoding.DER)
|
||||
).hexdigest()
|
||||
return IssuedCert(
|
||||
key_pem=key.read_bytes(),
|
||||
cert_pem=cert_pem,
|
||||
ca_cert_pem=ca.read_bytes(),
|
||||
fingerprint_sha256=fp,
|
||||
)
|
||||
|
||||
|
||||
def ensure_swarmctl_cert(
|
||||
bind_host: str,
|
||||
ca_dir: pathlib.Path = DEFAULT_CA_DIR,
|
||||
swarmctl_dir: pathlib.Path = DEFAULT_SWARMCTL_DIR,
|
||||
extra_sans: Optional[list[str]] = None,
|
||||
) -> tuple[pathlib.Path, pathlib.Path, pathlib.Path]:
|
||||
"""Return (cert_path, key_path, ca_path), auto-issuing if missing.
|
||||
|
||||
Uses the existing DECNET CA (ensuring it exists first) so workers
|
||||
whose bundle already includes ``ca.crt`` can verify the swarmctl
|
||||
endpoint without additional trust configuration. Self-signed is
|
||||
intentionally not the default — a cert signed by the same CA the
|
||||
workers already trust is the friction-free path.
|
||||
|
||||
Callers that want BYOC should skip this and pass their own
|
||||
cert/key paths directly to uvicorn.
|
||||
"""
|
||||
swarmctl_dir.mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(swarmctl_dir, 0o700)
|
||||
cert_path = swarmctl_dir / "server.crt"
|
||||
key_path = swarmctl_dir / "server.key"
|
||||
ca_cert_path = ca_dir / "ca.crt"
|
||||
|
||||
if cert_path.exists() and key_path.exists() and ca_cert_path.exists():
|
||||
return cert_path, key_path, ca_cert_path
|
||||
|
||||
ca = ensure_ca(ca_dir)
|
||||
sans = list({bind_host, "127.0.0.1", "localhost", *(extra_sans or [])})
|
||||
issued = issue_worker_cert(ca, "swarmctl", sans)
|
||||
cert_path.write_bytes(issued.cert_pem)
|
||||
key_path.write_bytes(issued.key_pem)
|
||||
os.chmod(key_path, 0o600)
|
||||
# ensure_ca already wrote ca.crt under ca_dir, but save_ca is only
|
||||
# called on generate — re-mirror it here to guarantee the path exists.
|
||||
if not ca_cert_path.exists():
|
||||
ca_cert_path.write_bytes(ca.cert_pem)
|
||||
return cert_path, key_path, ca_cert_path
|
||||
|
||||
|
||||
def fingerprint(cert_pem: bytes) -> str:
|
||||
"""SHA-256 hex fingerprint of a cert (DER-encoded)."""
|
||||
cert = x509.load_pem_x509_certificate(cert_pem)
|
||||
return hashlib.sha256(cert.public_bytes(serialization.Encoding.DER)).hexdigest()
|
||||
97
decnet/swarm/tar_tree.py
Normal file
97
decnet/swarm/tar_tree.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Build a gzipped tarball of the master's working tree for pushing to workers.
|
||||
|
||||
Always excludes the obvious large / secret / churn paths: ``.venv/``,
|
||||
``__pycache__/``, ``.git/``, ``wiki-checkout/``, ``*.db*``, ``*.log``. The
|
||||
caller can supply additional exclude globs.
|
||||
|
||||
Deliberately does NOT invoke git — the tree is what the operator has on
|
||||
disk (staged + unstaged + untracked). That's the whole point; the scp
|
||||
workflow we're replacing also shipped the live tree.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import fnmatch
|
||||
import io
|
||||
import pathlib
|
||||
import tarfile
|
||||
from typing import Iterable, Optional
|
||||
|
||||
DEFAULT_EXCLUDES = (
|
||||
".venv", ".venv/*",
|
||||
"**/.venv/*",
|
||||
"__pycache__", "**/__pycache__", "**/__pycache__/*",
|
||||
".git", ".git/*",
|
||||
"wiki-checkout", "wiki-checkout/*",
|
||||
"*.pyc", "*.pyo",
|
||||
"*.db", "*.db-wal", "*.db-shm",
|
||||
"*.log",
|
||||
".pytest_cache", ".pytest_cache/*",
|
||||
".mypy_cache", ".mypy_cache/*",
|
||||
".tox", ".tox/*",
|
||||
"*.egg-info", "*.egg-info/*",
|
||||
"decnet-state.json",
|
||||
"master.log", "master.json",
|
||||
"decnet.db*",
|
||||
)
|
||||
|
||||
|
||||
def _is_excluded(rel: str, patterns: Iterable[str]) -> bool:
|
||||
parts = pathlib.PurePosixPath(rel).parts
|
||||
for pat in patterns:
|
||||
if fnmatch.fnmatch(rel, pat):
|
||||
return True
|
||||
# Also match the pattern against every leading subpath — this is
|
||||
# what catches nested `.venv/...` without forcing callers to spell
|
||||
# out every `**/` glob.
|
||||
for i in range(1, len(parts) + 1):
|
||||
if fnmatch.fnmatch("/".join(parts[:i]), pat):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def tar_working_tree(
|
||||
root: pathlib.Path,
|
||||
extra_excludes: Optional[Iterable[str]] = None,
|
||||
) -> bytes:
|
||||
"""Return the gzipped tarball bytes of ``root``.
|
||||
|
||||
Entries are added with paths relative to ``root`` (no leading ``/``,
|
||||
no ``..``). The updater rejects unsafe paths on the receiving side.
|
||||
"""
|
||||
patterns = list(DEFAULT_EXCLUDES) + list(extra_excludes or ())
|
||||
buf = io.BytesIO()
|
||||
|
||||
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
|
||||
for path in sorted(root.rglob("*")):
|
||||
rel = path.relative_to(root).as_posix()
|
||||
if _is_excluded(rel, patterns):
|
||||
continue
|
||||
if path.is_symlink():
|
||||
# Symlinks inside a repo tree are rare and often break
|
||||
# portability; skip them rather than ship dangling links.
|
||||
continue
|
||||
if path.is_dir():
|
||||
continue
|
||||
tar.add(path, arcname=rel, recursive=False)
|
||||
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def detect_git_sha(root: pathlib.Path) -> str:
|
||||
"""Best-effort ``HEAD`` sha. Returns ``""`` if not a git repo."""
|
||||
head = root / ".git" / "HEAD"
|
||||
if not head.is_file():
|
||||
return ""
|
||||
try:
|
||||
ref = head.read_text().strip()
|
||||
except OSError:
|
||||
return ""
|
||||
if ref.startswith("ref: "):
|
||||
ref_path = root / ".git" / ref[5:]
|
||||
if ref_path.is_file():
|
||||
try:
|
||||
return ref_path.read_text().strip()
|
||||
except OSError:
|
||||
return ""
|
||||
return ""
|
||||
return ref
|
||||
132
decnet/swarm/updater_client.py
Normal file
132
decnet/swarm/updater_client.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Master-side HTTP client for the worker's self-updater daemon.
|
||||
|
||||
Sibling of ``AgentClient``: same mTLS identity (same DECNET CA, same
|
||||
master client cert) but targets the updater's port (default 8766) and
|
||||
speaks the multipart upload protocol the updater's ``/update`` endpoint
|
||||
expects.
|
||||
|
||||
Kept as its own module — not a subclass of ``AgentClient`` — because the
|
||||
timeouts and failure semantics are genuinely different: pip install +
|
||||
agent probe can take a minute on a slow VM, and ``/update-self`` drops
|
||||
the connection on purpose (the updater re-execs itself mid-response).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import ssl
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from decnet.logging import get_logger
|
||||
from decnet.swarm.client import MasterIdentity, ensure_master_identity
|
||||
|
||||
log = get_logger("swarm.updater_client")
|
||||
|
||||
_TIMEOUT_UPDATE = httpx.Timeout(connect=10.0, read=180.0, write=120.0, pool=5.0)
|
||||
_TIMEOUT_CONTROL = httpx.Timeout(connect=5.0, read=30.0, write=10.0, pool=5.0)
|
||||
|
||||
|
||||
class UpdaterClient:
|
||||
"""Async client targeting a worker's ``decnet updater`` daemon."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: dict[str, Any] | None = None,
|
||||
*,
|
||||
address: Optional[str] = None,
|
||||
updater_port: int = 8766,
|
||||
identity: Optional[MasterIdentity] = None,
|
||||
verify_hostname: Optional[bool] = None,
|
||||
):
|
||||
if verify_hostname is None:
|
||||
from decnet.env import DECNET_VERIFY_HOSTNAME
|
||||
verify_hostname = DECNET_VERIFY_HOSTNAME
|
||||
self._verify_hostname = verify_hostname
|
||||
if host is not None:
|
||||
self._address = host["address"]
|
||||
self._host_name = host.get("name")
|
||||
else:
|
||||
if address is None:
|
||||
raise ValueError("UpdaterClient requires host dict or address")
|
||||
self._address = address
|
||||
self._host_name = None
|
||||
self._port = updater_port
|
||||
self._identity = identity or ensure_master_identity()
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
def _build_client(self, timeout: httpx.Timeout) -> httpx.AsyncClient:
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.load_cert_chain(
|
||||
str(self._identity.cert_path), str(self._identity.key_path),
|
||||
)
|
||||
ctx.load_verify_locations(cafile=str(self._identity.ca_cert_path))
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
ctx.check_hostname = self._verify_hostname
|
||||
return httpx.AsyncClient(
|
||||
base_url=f"https://{self._address}:{self._port}",
|
||||
verify=ctx,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def __aenter__(self) -> "UpdaterClient":
|
||||
self._client = self._build_client(_TIMEOUT_CONTROL)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc: Any) -> None:
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
def _require(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
raise RuntimeError("UpdaterClient used outside `async with` block")
|
||||
return self._client
|
||||
|
||||
# --------------------------------------------------------------- RPCs
|
||||
|
||||
async def health(self) -> dict[str, Any]:
|
||||
r = await self._require().get("/health")
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
async def releases(self) -> dict[str, Any]:
|
||||
r = await self._require().get("/releases")
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
async def update(self, tarball: bytes, sha: str = "") -> httpx.Response:
|
||||
"""POST /update. Returns the Response so the caller can distinguish
|
||||
200 / 409 / 500 — each means something different.
|
||||
"""
|
||||
sha256 = hashlib.sha256(tarball).hexdigest()
|
||||
self._require().timeout = _TIMEOUT_UPDATE
|
||||
try:
|
||||
r = await self._require().post(
|
||||
"/update",
|
||||
files={"tarball": ("tree.tgz", tarball, "application/gzip")},
|
||||
data={"sha": sha, "sha256": sha256},
|
||||
)
|
||||
finally:
|
||||
self._require().timeout = _TIMEOUT_CONTROL
|
||||
return r
|
||||
|
||||
async def update_self(self, tarball: bytes, sha: str = "") -> httpx.Response:
|
||||
"""POST /update-self. The updater re-execs itself, so the connection
|
||||
usually drops mid-response; that's not an error. Callers should then
|
||||
poll /health until the new SHA appears.
|
||||
"""
|
||||
sha256 = hashlib.sha256(tarball).hexdigest()
|
||||
self._require().timeout = _TIMEOUT_UPDATE
|
||||
try:
|
||||
r = await self._require().post(
|
||||
"/update-self",
|
||||
files={"tarball": ("tree.tgz", tarball, "application/gzip")},
|
||||
data={"sha": sha, "sha256": sha256, "confirm_self": "true"},
|
||||
)
|
||||
finally:
|
||||
self._require().timeout = _TIMEOUT_CONTROL
|
||||
return r
|
||||
|
||||
async def rollback(self) -> httpx.Response:
|
||||
return await self._require().post("/rollback")
|
||||
Reference in New Issue
Block a user