From 7765b36c50176762d0cfcf74a527880594aaf426 Mon Sep 17 00:00:00 2001 From: anti Date: Sat, 18 Apr 2026 21:40:21 -0400 Subject: [PATCH] feat(updater): remote self-update daemon with auto-rollback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a separate `decnet updater` daemon on each worker that owns the agent's release directory and installs tarball pushes from the master over mTLS. A normal `/update` never touches the updater itself, so the updater is always a known-good rescuer if a bad agent push breaks /health — the rotation is reversed and the agent restarted against the previous release. `POST /update-self` handles updater upgrades explicitly (no auto-rollback). - decnet/updater/: executor, FastAPI app, uvicorn launcher - decnet/swarm/updater_client.py, tar_tree.py: master-side push - cli: `decnet updater`, `decnet swarm update [--host|--all] [--include-self] [--dry-run]`, `--updater` on `swarm enroll` - enrollment API issues a second cert (CN=updater@) signed by the same CA; SwarmHost records updater_cert_fingerprint - tests: executor, app, CLI, tar tree, enroll-with-updater (37 new) - wiki: Remote-Updates page + sidebar + SWARM-Mode cross-link --- decnet/cli.py | 171 ++++++++- decnet/swarm/tar_tree.py | 97 +++++ decnet/swarm/updater_client.py | 124 ++++++ decnet/updater/__init__.py | 10 + decnet/updater/app.py | 139 +++++++ decnet/updater/executor.py | 416 +++++++++++++++++++++ decnet/updater/server.py | 86 +++++ decnet/web/db/models.py | 17 + decnet/web/router/swarm/api_enroll_host.py | 25 +- pyproject.toml | 3 +- tests/swarm/test_cli_swarm_update.py | 192 ++++++++++ tests/swarm/test_swarm_api.py | 30 ++ tests/swarm/test_tar_tree.py | 75 ++++ tests/updater/__init__.py | 0 tests/updater/test_updater_app.py | 138 +++++++ tests/updater/test_updater_executor.py | 295 +++++++++++++++ 16 files changed, 1814 insertions(+), 4 deletions(-) create mode 100644 decnet/swarm/tar_tree.py create mode 100644 decnet/swarm/updater_client.py create mode 100644 decnet/updater/__init__.py create mode 100644 decnet/updater/app.py create mode 100644 decnet/updater/executor.py create mode 100644 decnet/updater/server.py create mode 100644 tests/swarm/test_cli_swarm_update.py create mode 100644 tests/swarm/test_tar_tree.py create mode 100644 tests/updater/__init__.py create mode 100644 tests/updater/test_updater_app.py create mode 100644 tests/updater/test_updater_executor.py diff --git a/decnet/cli.py b/decnet/cli.py index ecff922..d6e7083 100644 --- a/decnet/cli.py +++ b/decnet/cli.py @@ -187,6 +187,43 @@ def agent( raise typer.Exit(rc) +@app.command() +def updater( + port: int = typer.Option(8766, "--port", help="Port for the self-updater daemon"), + host: str = typer.Option("0.0.0.0", "--host", help="Bind address for the updater"), # nosec B104 + updater_dir: Optional[str] = typer.Option(None, "--updater-dir", help="Updater cert bundle dir (default: ~/.decnet/updater)"), + install_dir: Optional[str] = typer.Option(None, "--install-dir", help="Release install root (default: /opt/decnet)"), + agent_dir: Optional[str] = typer.Option(None, "--agent-dir", help="Worker agent cert bundle (for local /health probes; default: ~/.decnet/agent)"), + daemon: bool = typer.Option(False, "--daemon", "-d", help="Detach to background as a daemon process"), +) -> None: + """Run the DECNET self-updater (requires a bundle in ~/.decnet/updater/).""" + import pathlib as _pathlib + from decnet.swarm import pki as _pki + from decnet.updater import server as _upd_server + + resolved_updater = _pathlib.Path(updater_dir) if updater_dir else _upd_server.DEFAULT_UPDATER_DIR + resolved_install = _pathlib.Path(install_dir) if install_dir else _pathlib.Path("/opt/decnet") + resolved_agent = _pathlib.Path(agent_dir) if agent_dir else _pki.DEFAULT_AGENT_DIR + + if daemon: + log.info("updater daemonizing host=%s port=%d", host, port) + _daemonize() + + log.info( + "updater command invoked host=%s port=%d updater_dir=%s install_dir=%s", + host, port, resolved_updater, resolved_install, + ) + console.print(f"[green]Starting DECNET self-updater on {host}:{port} (mTLS)...[/]") + rc = _upd_server.run( + host, port, + updater_dir=resolved_updater, + install_dir=resolved_install, + agent_dir=resolved_agent, + ) + if rc != 0: + raise typer.Exit(rc) + + @app.command() def listener( bind_host: str = typer.Option("0.0.0.0", "--host", help="Bind address for the master syslog-TLS listener"), # nosec B104 @@ -393,6 +430,7 @@ def swarm_enroll( sans: Optional[str] = typer.Option(None, "--sans", help="Comma-separated extra SANs for the worker cert"), notes: Optional[str] = typer.Option(None, "--notes", help="Free-form operator notes"), out_dir: Optional[str] = typer.Option(None, "--out-dir", help="Write the bundle (ca.crt/worker.crt/worker.key) to this dir for scp"), + updater: bool = typer.Option(False, "--updater", help="Also issue an updater-identity cert (CN=updater@) for the remote self-updater"), url: Optional[str] = typer.Option(None, "--url", help="Override swarm controller URL (default: 127.0.0.1:8770)"), ) -> None: """Issue a mTLS bundle for a new worker and register it in the swarm.""" @@ -403,6 +441,8 @@ def swarm_enroll( body["sans"] = [s.strip() for s in sans.split(",") if s.strip()] if notes: body["notes"] = notes + if updater: + body["issue_updater_bundle"] = True resp = _http_request("POST", _swarmctl_base_url(url) + "/swarm/enroll", json_body=body) data = resp.json() @@ -410,6 +450,9 @@ def swarm_enroll( console.print(f"[green]Enrolled worker:[/] {data['name']} " f"[dim]uuid=[/]{data['host_uuid']} " f"[dim]fingerprint=[/]{data['fingerprint']}") + if data.get("updater"): + console.print(f"[green] + updater identity[/] " + f"[dim]fingerprint=[/]{data['updater']['fingerprint']}") if out_dir: target = _pathlib.Path(out_dir).expanduser() @@ -422,8 +465,22 @@ def swarm_enroll( (target / leaf).chmod(0o600) except OSError: pass - console.print(f"[cyan]Bundle written to[/] {target}") - console.print("[dim]Ship this directory to the worker at ~/.decnet/agent/ (or wherever `decnet agent --agent-dir` points).[/]") + console.print(f"[cyan]Agent bundle written to[/] {target}") + + if data.get("updater"): + upd_target = target.parent / f"{target.name}-updater" + upd_target.mkdir(parents=True, exist_ok=True) + (upd_target / "ca.crt").write_text(data["ca_cert_pem"]) + (upd_target / "updater.crt").write_text(data["updater"]["updater_cert_pem"]) + (upd_target / "updater.key").write_text(data["updater"]["updater_key_pem"]) + try: + (upd_target / "updater.key").chmod(0o600) + except OSError: + pass + console.print(f"[cyan]Updater bundle written to[/] {upd_target}") + console.print("[dim]Ship the agent dir to ~/.decnet/agent/ and the updater dir to ~/.decnet/updater/ on the worker.[/]") + else: + console.print("[dim]Ship this directory to the worker at ~/.decnet/agent/ (or wherever `decnet agent --agent-dir` points).[/]") else: console.print("[yellow]No --out-dir given — bundle PEMs are in the JSON response; persist them before leaving this shell.[/]") @@ -494,6 +551,116 @@ def swarm_check( console.print(table) +@swarm_app.command("update") +def swarm_update( + host: Optional[str] = typer.Option(None, "--host", help="Target worker (name or UUID). Omit with --all."), + all_hosts: bool = typer.Option(False, "--all", help="Push to every enrolled worker."), + include_self: bool = typer.Option(False, "--include-self", help="Also push to each updater's /update-self after a successful agent update."), + root: Optional[str] = typer.Option(None, "--root", help="Source tree to tar (default: CWD)."), + exclude: list[str] = typer.Option([], "--exclude", help="Additional exclude glob. Repeatable."), + updater_port: int = typer.Option(8766, "--updater-port", help="Port the workers' updater listens on."), + dry_run: bool = typer.Option(False, "--dry-run", help="Build the tarball and print stats; no network."), + url: Optional[str] = typer.Option(None, "--url", help="Override swarm controller URL."), +) -> None: + """Push the current working tree to workers' self-updaters (with auto-rollback on failure).""" + import asyncio + import pathlib as _pathlib + + from decnet.swarm.tar_tree import tar_working_tree, detect_git_sha + from decnet.swarm.updater_client import UpdaterClient + + if not (host or all_hosts): + console.print("[red]Supply --host or --all.[/]") + raise typer.Exit(2) + if host and all_hosts: + console.print("[red]--host and --all are mutually exclusive.[/]") + raise typer.Exit(2) + + base = _swarmctl_base_url(url) + resp = _http_request("GET", base + "/swarm/hosts") + rows = resp.json() + if host: + targets = [r for r in rows if r.get("name") == host or r.get("uuid") == host] + if not targets: + console.print(f"[red]No enrolled worker matching '{host}'.[/]") + raise typer.Exit(1) + else: + targets = [r for r in rows if r.get("status") != "decommissioned"] + if not targets: + console.print("[dim]No targets.[/]") + return + + tree_root = _pathlib.Path(root) if root else _pathlib.Path.cwd() + sha = detect_git_sha(tree_root) + console.print(f"[dim]Tarring[/] {tree_root} [dim]sha={sha or '(not a git repo)'}[/]") + tarball = tar_working_tree(tree_root, extra_excludes=exclude) + console.print(f"[dim]Tarball size:[/] {len(tarball):,} bytes") + + if dry_run: + console.print("[yellow]--dry-run: not pushing.[/]") + for t in targets: + console.print(f" would push to [cyan]{t.get('name')}[/] at {t.get('address')}:{updater_port}") + return + + async def _push_one(h: dict) -> dict: + name = h.get("name") or h.get("uuid") + out: dict = {"name": name, "address": h.get("address"), "agent": None, "self": None} + try: + async with UpdaterClient(h, updater_port=updater_port) as u: + r = await u.update(tarball, sha=sha) + out["agent"] = {"status": r.status_code, "body": r.json() if r.content else {}} + if r.status_code == 200 and include_self: + # Agent first, updater second — see plan. + rs = await u.update_self(tarball, sha=sha) + # Connection-drop is expected for update-self. + out["self"] = {"status": rs.status_code, "body": rs.json() if rs.content else {}} + except Exception as exc: # noqa: BLE001 + out["error"] = f"{type(exc).__name__}: {exc}" + return out + + async def _push_all() -> list[dict]: + return await asyncio.gather(*(_push_one(t) for t in targets)) + + results = asyncio.run(_push_all()) + + table = Table(title="DECNET swarm update") + for col in ("host", "address", "agent", "self", "detail"): + table.add_column(col) + any_failure = False + for r in results: + agent = r.get("agent") or {} + selff = r.get("self") or {} + err = r.get("error") + if err: + any_failure = True + table.add_row(r["name"], r.get("address") or "", "[red]error[/]", "—", err) + continue + a_status = agent.get("status") + if a_status == 200: + agent_cell = "[green]updated[/]" + elif a_status == 409: + agent_cell = "[yellow]rolled-back[/]" + any_failure = True + else: + agent_cell = f"[red]{a_status}[/]" + any_failure = True + if not include_self: + self_cell = "—" + elif selff.get("status") == 200 or selff.get("status") is None: + self_cell = "[green]ok[/]" if selff else "[dim]skipped[/]" + else: + self_cell = f"[red]{selff.get('status')}[/]" + detail = "" + body = agent.get("body") or {} + if isinstance(body, dict): + detail = body.get("release", {}).get("sha") or body.get("detail", {}).get("error") or "" + table.add_row(r["name"], r.get("address") or "", agent_cell, self_cell, str(detail)[:80]) + console.print(table) + + if any_failure: + raise typer.Exit(1) + + @swarm_app.command("deckies") def swarm_deckies( host: Optional[str] = typer.Option(None, "--host", help="Filter by worker name or UUID"), diff --git a/decnet/swarm/tar_tree.py b/decnet/swarm/tar_tree.py new file mode 100644 index 0000000..ab5b7b9 --- /dev/null +++ b/decnet/swarm/tar_tree.py @@ -0,0 +1,97 @@ +"""Build a gzipped tarball of the master's working tree for pushing to workers. + +Always excludes the obvious large / secret / churn paths: ``.venv/``, +``__pycache__/``, ``.git/``, ``wiki-checkout/``, ``*.db*``, ``*.log``. The +caller can supply additional exclude globs. + +Deliberately does NOT invoke git — the tree is what the operator has on +disk (staged + unstaged + untracked). That's the whole point; the scp +workflow we're replacing also shipped the live tree. +""" +from __future__ import annotations + +import fnmatch +import io +import pathlib +import tarfile +from typing import Iterable, Optional + +DEFAULT_EXCLUDES = ( + ".venv", ".venv/*", + "**/.venv/*", + "__pycache__", "**/__pycache__", "**/__pycache__/*", + ".git", ".git/*", + "wiki-checkout", "wiki-checkout/*", + "*.pyc", "*.pyo", + "*.db", "*.db-wal", "*.db-shm", + "*.log", + ".pytest_cache", ".pytest_cache/*", + ".mypy_cache", ".mypy_cache/*", + ".tox", ".tox/*", + "*.egg-info", "*.egg-info/*", + "decnet-state.json", + "master.log", "master.json", + "decnet.db*", +) + + +def _is_excluded(rel: str, patterns: Iterable[str]) -> bool: + parts = pathlib.PurePosixPath(rel).parts + for pat in patterns: + if fnmatch.fnmatch(rel, pat): + return True + # Also match the pattern against every leading subpath — this is + # what catches nested `.venv/...` without forcing callers to spell + # out every `**/` glob. + for i in range(1, len(parts) + 1): + if fnmatch.fnmatch("/".join(parts[:i]), pat): + return True + return False + + +def tar_working_tree( + root: pathlib.Path, + extra_excludes: Optional[Iterable[str]] = None, +) -> bytes: + """Return the gzipped tarball bytes of ``root``. + + Entries are added with paths relative to ``root`` (no leading ``/``, + no ``..``). The updater rejects unsafe paths on the receiving side. + """ + patterns = list(DEFAULT_EXCLUDES) + list(extra_excludes or ()) + buf = io.BytesIO() + + with tarfile.open(fileobj=buf, mode="w:gz") as tar: + for path in sorted(root.rglob("*")): + rel = path.relative_to(root).as_posix() + if _is_excluded(rel, patterns): + continue + if path.is_symlink(): + # Symlinks inside a repo tree are rare and often break + # portability; skip them rather than ship dangling links. + continue + if path.is_dir(): + continue + tar.add(path, arcname=rel, recursive=False) + + return buf.getvalue() + + +def detect_git_sha(root: pathlib.Path) -> str: + """Best-effort ``HEAD`` sha. Returns ``""`` if not a git repo.""" + head = root / ".git" / "HEAD" + if not head.is_file(): + return "" + try: + ref = head.read_text().strip() + except OSError: + return "" + if ref.startswith("ref: "): + ref_path = root / ".git" / ref[5:] + if ref_path.is_file(): + try: + return ref_path.read_text().strip() + except OSError: + return "" + return "" + return ref diff --git a/decnet/swarm/updater_client.py b/decnet/swarm/updater_client.py new file mode 100644 index 0000000..753c558 --- /dev/null +++ b/decnet/swarm/updater_client.py @@ -0,0 +1,124 @@ +"""Master-side HTTP client for the worker's self-updater daemon. + +Sibling of ``AgentClient``: same mTLS identity (same DECNET CA, same +master client cert) but targets the updater's port (default 8766) and +speaks the multipart upload protocol the updater's ``/update`` endpoint +expects. + +Kept as its own module — not a subclass of ``AgentClient`` — because the +timeouts and failure semantics are genuinely different: pip install + +agent probe can take a minute on a slow VM, and ``/update-self`` drops +the connection on purpose (the updater re-execs itself mid-response). +""" +from __future__ import annotations + +import ssl +from typing import Any, Optional + +import httpx + +from decnet.logging import get_logger +from decnet.swarm.client import MasterIdentity, ensure_master_identity + +log = get_logger("swarm.updater_client") + +_TIMEOUT_UPDATE = httpx.Timeout(connect=10.0, read=180.0, write=120.0, pool=5.0) +_TIMEOUT_CONTROL = httpx.Timeout(connect=5.0, read=30.0, write=10.0, pool=5.0) + + +class UpdaterClient: + """Async client targeting a worker's ``decnet updater`` daemon.""" + + def __init__( + self, + host: dict[str, Any] | None = None, + *, + address: Optional[str] = None, + updater_port: int = 8766, + identity: Optional[MasterIdentity] = None, + ): + if host is not None: + self._address = host["address"] + self._host_name = host.get("name") + else: + if address is None: + raise ValueError("UpdaterClient requires host dict or address") + self._address = address + self._host_name = None + self._port = updater_port + self._identity = identity or ensure_master_identity() + self._client: Optional[httpx.AsyncClient] = None + + def _build_client(self, timeout: httpx.Timeout) -> httpx.AsyncClient: + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_cert_chain( + str(self._identity.cert_path), str(self._identity.key_path), + ) + ctx.load_verify_locations(cafile=str(self._identity.ca_cert_path)) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.check_hostname = False + return httpx.AsyncClient( + base_url=f"https://{self._address}:{self._port}", + verify=ctx, + timeout=timeout, + ) + + async def __aenter__(self) -> "UpdaterClient": + self._client = self._build_client(_TIMEOUT_CONTROL) + return self + + async def __aexit__(self, *exc: Any) -> None: + if self._client: + await self._client.aclose() + self._client = None + + def _require(self) -> httpx.AsyncClient: + if self._client is None: + raise RuntimeError("UpdaterClient used outside `async with` block") + return self._client + + # --------------------------------------------------------------- RPCs + + async def health(self) -> dict[str, Any]: + r = await self._require().get("/health") + r.raise_for_status() + return r.json() + + async def releases(self) -> dict[str, Any]: + r = await self._require().get("/releases") + r.raise_for_status() + return r.json() + + async def update(self, tarball: bytes, sha: str = "") -> httpx.Response: + """POST /update. Returns the Response so the caller can distinguish + 200 / 409 / 500 — each means something different. + """ + self._require().timeout = _TIMEOUT_UPDATE + try: + r = await self._require().post( + "/update", + files={"tarball": ("tree.tgz", tarball, "application/gzip")}, + data={"sha": sha}, + ) + finally: + self._require().timeout = _TIMEOUT_CONTROL + return r + + async def update_self(self, tarball: bytes, sha: str = "") -> httpx.Response: + """POST /update-self. The updater re-execs itself, so the connection + usually drops mid-response; that's not an error. Callers should then + poll /health until the new SHA appears. + """ + self._require().timeout = _TIMEOUT_UPDATE + try: + r = await self._require().post( + "/update-self", + files={"tarball": ("tree.tgz", tarball, "application/gzip")}, + data={"sha": sha, "confirm_self": "true"}, + ) + finally: + self._require().timeout = _TIMEOUT_CONTROL + return r + + async def rollback(self) -> httpx.Response: + return await self._require().post("/rollback") diff --git a/decnet/updater/__init__.py b/decnet/updater/__init__.py new file mode 100644 index 0000000..b586e1f --- /dev/null +++ b/decnet/updater/__init__.py @@ -0,0 +1,10 @@ +"""DECNET self-updater daemon. + +Runs on each worker alongside ``decnet agent``. Receives working-tree +tarballs from the master and owns the agent's lifecycle: snapshot → +install → restart → probe → auto-rollback on failure. + +Deliberately separate process, separate venv, separate mTLS cert so that +a broken ``decnet agent`` push can always be rolled back by the updater +that shipped it. See ``wiki/Remote-Updates.md``. +""" diff --git a/decnet/updater/app.py b/decnet/updater/app.py new file mode 100644 index 0000000..5c5d879 --- /dev/null +++ b/decnet/updater/app.py @@ -0,0 +1,139 @@ +"""Updater FastAPI app — mTLS-protected endpoints for self-update. + +Mirrors the shape of ``decnet/agent/app.py``: bare FastAPI, docs disabled, +handlers delegate to ``decnet.updater.executor``. + +Mounted by uvicorn via ``decnet.updater.server`` with ``--ssl-cert-reqs 2``; +the CN on the peer cert tells us which endpoints are legal (``updater@*`` +only — agent certs are rejected). +""" +from __future__ import annotations + +import os as _os +import pathlib + +from fastapi import FastAPI, File, Form, HTTPException, UploadFile +from pydantic import BaseModel + +from decnet.logging import get_logger +from decnet.swarm import pki +from decnet.updater import executor as _exec + +log = get_logger("updater.app") + +app = FastAPI( + title="DECNET Self-Updater", + version="0.1.0", + docs_url=None, + redoc_url=None, + openapi_url=None, +) + + +class _Config: + install_dir: pathlib.Path = pathlib.Path( + _os.environ.get("DECNET_UPDATER_INSTALL_DIR") or str(_exec.DEFAULT_INSTALL_DIR) + ) + updater_install_dir: pathlib.Path = pathlib.Path( + _os.environ.get("DECNET_UPDATER_UPDATER_DIR") + or str(_exec.DEFAULT_INSTALL_DIR / "updater") + ) + agent_dir: pathlib.Path = pathlib.Path( + _os.environ.get("DECNET_UPDATER_AGENT_DIR") or str(pki.DEFAULT_AGENT_DIR) + ) + + +def configure( + install_dir: pathlib.Path, + updater_install_dir: pathlib.Path, + agent_dir: pathlib.Path, +) -> None: + """Inject paths from the server launcher; must be called before serving.""" + _Config.install_dir = install_dir + _Config.updater_install_dir = updater_install_dir + _Config.agent_dir = agent_dir + + +# ------------------------------------------------------------------- schemas + +class RollbackResult(BaseModel): + status: str + release: dict + probe: str + + +class ReleasesResponse(BaseModel): + releases: list[dict] + + +# -------------------------------------------------------------------- routes + +@app.get("/health") +async def health() -> dict: + return { + "status": "ok", + "role": "updater", + "releases": [r.to_dict() for r in _exec.list_releases(_Config.install_dir)], + } + + +@app.get("/releases") +async def releases() -> dict: + return {"releases": [r.to_dict() for r in _exec.list_releases(_Config.install_dir)]} + + +@app.post("/update") +async def update( + tarball: UploadFile = File(..., description="tar.gz of the working tree"), + sha: str = Form("", description="git SHA of the tree for provenance"), +) -> dict: + body = await tarball.read() + try: + return _exec.run_update( + body, sha=sha or None, + install_dir=_Config.install_dir, agent_dir=_Config.agent_dir, + ) + except _exec.UpdateError as exc: + status = 409 if exc.rolled_back else 500 + raise HTTPException( + status_code=status, + detail={"error": str(exc), "stderr": exc.stderr, "rolled_back": exc.rolled_back}, + ) from exc + + +@app.post("/update-self") +async def update_self( + tarball: UploadFile = File(...), + sha: str = Form(""), + confirm_self: str = Form("", description="Must be 'true' to proceed"), +) -> dict: + if confirm_self.lower() != "true": + raise HTTPException( + status_code=400, + detail="self-update requires confirm_self=true (no auto-rollback)", + ) + body = await tarball.read() + try: + return _exec.run_update_self( + body, sha=sha or None, + updater_install_dir=_Config.updater_install_dir, + ) + except _exec.UpdateError as exc: + raise HTTPException( + status_code=500, + detail={"error": str(exc), "stderr": exc.stderr}, + ) from exc + + +@app.post("/rollback") +async def rollback() -> dict: + try: + return _exec.run_rollback( + install_dir=_Config.install_dir, agent_dir=_Config.agent_dir, + ) + except _exec.UpdateError as exc: + status = 404 if "no previous" in str(exc) else 500 + raise HTTPException( + status_code=status, + detail={"error": str(exc), "stderr": exc.stderr}, + ) from exc diff --git a/decnet/updater/executor.py b/decnet/updater/executor.py new file mode 100644 index 0000000..8f1813d --- /dev/null +++ b/decnet/updater/executor.py @@ -0,0 +1,416 @@ +"""Update/rollback orchestrator for the DECNET self-updater. + +Directory layout owned by this module (root = ``install_dir``): + + / + current -> releases/active (symlink; atomic swap == promotion) + releases/ + active/ (working tree; has its own .venv) + prev/ (last good snapshot; restored on failure) + active.new/ (staging; only exists mid-update) + agent.pid (PID of the agent process we spawned) + +Rollback semantics: if the agent doesn't come back healthy after an update, +we swap the symlink back to ``prev``, restart the agent, and return the +captured pip/agent stderr to the caller. + +Seams for tests — every subprocess call goes through a module-level hook +(`_run_pip`, `_spawn_agent`, `_probe_agent`) so tests can monkeypatch them +without actually touching the filesystem's Python toolchain. +""" +from __future__ import annotations + +import dataclasses +import os +import pathlib +import shutil +import signal +import ssl +import subprocess # nosec B404 +import sys +import tarfile +import time +from datetime import datetime, timezone +from typing import Any, Callable, Optional + +import httpx + +from decnet.logging import get_logger +from decnet.swarm import pki + +log = get_logger("updater.executor") + +DEFAULT_INSTALL_DIR = pathlib.Path("/opt/decnet") +AGENT_PROBE_URL = "https://127.0.0.1:8765/health" +AGENT_PROBE_ATTEMPTS = 10 +AGENT_PROBE_BACKOFF_S = 1.0 +AGENT_RESTART_GRACE_S = 10.0 + + +# ------------------------------------------------------------------- errors + +class UpdateError(RuntimeError): + """Raised when an update fails but the install dir is consistent. + + Carries the captured stderr so the master gets actionable output. + """ + + def __init__(self, message: str, *, stderr: str = "", rolled_back: bool = False): + super().__init__(message) + self.stderr = stderr + self.rolled_back = rolled_back + + +# -------------------------------------------------------------------- types + +@dataclasses.dataclass(frozen=True) +class Release: + slot: str + sha: Optional[str] + installed_at: Optional[datetime] + + def to_dict(self) -> dict[str, Any]: + return { + "slot": self.slot, + "sha": self.sha, + "installed_at": self.installed_at.isoformat() if self.installed_at else None, + } + + +# ---------------------------------------------------------------- internals + +def _releases_dir(install_dir: pathlib.Path) -> pathlib.Path: + return install_dir / "releases" + + +def _active_dir(install_dir: pathlib.Path) -> pathlib.Path: + return _releases_dir(install_dir) / "active" + + +def _prev_dir(install_dir: pathlib.Path) -> pathlib.Path: + return _releases_dir(install_dir) / "prev" + + +def _staging_dir(install_dir: pathlib.Path) -> pathlib.Path: + return _releases_dir(install_dir) / "active.new" + + +def _current_symlink(install_dir: pathlib.Path) -> pathlib.Path: + return install_dir / "current" + + +def _pid_file(install_dir: pathlib.Path) -> pathlib.Path: + return install_dir / "agent.pid" + + +def _manifest_file(release: pathlib.Path) -> pathlib.Path: + return release / ".decnet-release.json" + + +def _venv_python(release: pathlib.Path) -> pathlib.Path: + return release / ".venv" / "bin" / "python" + + +# ------------------------------------------------------------------- public + +def read_release(release: pathlib.Path) -> Release: + """Read the release manifest sidecar; tolerate absence.""" + slot = release.name + mf = _manifest_file(release) + if not mf.is_file(): + return Release(slot=slot, sha=None, installed_at=None) + import json + + try: + data = json.loads(mf.read_text()) + except (json.JSONDecodeError, OSError): + return Release(slot=slot, sha=None, installed_at=None) + ts = data.get("installed_at") + return Release( + slot=slot, + sha=data.get("sha"), + installed_at=datetime.fromisoformat(ts) if ts else None, + ) + + +def list_releases(install_dir: pathlib.Path) -> list[Release]: + out: list[Release] = [] + for slot_dir in (_active_dir(install_dir), _prev_dir(install_dir)): + if slot_dir.is_dir(): + out.append(read_release(slot_dir)) + return out + + +def clean_stale_staging(install_dir: pathlib.Path) -> None: + """Remove a half-extracted ``active.new`` left by a crashed update.""" + staging = _staging_dir(install_dir) + if staging.exists(): + log.warning("removing stale staging dir %s", staging) + shutil.rmtree(staging, ignore_errors=True) + + +def extract_tarball(tarball_bytes: bytes, dest: pathlib.Path) -> None: + """Extract a gzipped tarball into ``dest`` (must not pre-exist). + + Rejects absolute paths and ``..`` traversal in the archive. + """ + import io + + dest.mkdir(parents=True, exist_ok=False) + with tarfile.open(fileobj=io.BytesIO(tarball_bytes), mode="r:gz") as tar: + for member in tar.getmembers(): + name = member.name + if name.startswith("/") or ".." in pathlib.PurePosixPath(name).parts: + raise UpdateError(f"unsafe path in tarball: {name!r}") + tar.extractall(dest) # nosec B202 — validated above + + +# ---------------------------------------------------------------- seams + +def _run_pip(release: pathlib.Path) -> subprocess.CompletedProcess: + """Create a venv in ``release/.venv`` and pip install -e . into it. + + Monkeypatched in tests so the test suite never shells out. + """ + venv_dir = release / ".venv" + if not venv_dir.exists(): + subprocess.run( # nosec B603 + [sys.executable, "-m", "venv", str(venv_dir)], + check=True, capture_output=True, text=True, + ) + py = _venv_python(release) + return subprocess.run( # nosec B603 + [str(py), "-m", "pip", "install", "-e", str(release)], + check=False, capture_output=True, text=True, + ) + + +def _spawn_agent(install_dir: pathlib.Path) -> int: + """Launch ``decnet agent --daemon`` using the current-symlinked venv. + + Returns the new PID. Monkeypatched in tests. + """ + py = _venv_python(_current_symlink(install_dir).resolve()) + proc = subprocess.Popen( # nosec B603 + [str(py), "-m", "decnet", "agent", "--daemon"], + start_new_session=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + _pid_file(install_dir).write_text(str(proc.pid)) + return proc.pid + + +def _stop_agent(install_dir: pathlib.Path, grace: float = AGENT_RESTART_GRACE_S) -> None: + """SIGTERM the PID we spawned; SIGKILL if it doesn't exit in ``grace`` s.""" + pid_file = _pid_file(install_dir) + if not pid_file.is_file(): + return + try: + pid = int(pid_file.read_text().strip()) + except (ValueError, OSError): + return + try: + os.kill(pid, signal.SIGTERM) + except ProcessLookupError: + return + deadline = time.monotonic() + grace + while time.monotonic() < deadline: + try: + os.kill(pid, 0) + except ProcessLookupError: + return + time.sleep(0.2) + try: + os.kill(pid, signal.SIGKILL) + except ProcessLookupError: + pass + + +def _probe_agent( + agent_dir: pathlib.Path = pki.DEFAULT_AGENT_DIR, + url: str = AGENT_PROBE_URL, + attempts: int = AGENT_PROBE_ATTEMPTS, + backoff_s: float = AGENT_PROBE_BACKOFF_S, +) -> tuple[bool, str]: + """Local mTLS health probe against the agent. Returns (ok, detail).""" + worker_key = agent_dir / "worker.key" + worker_crt = agent_dir / "worker.crt" + ca = agent_dir / "ca.crt" + if not (worker_key.is_file() and worker_crt.is_file() and ca.is_file()): + return False, f"no mTLS bundle at {agent_dir}" + ctx = ssl.create_default_context(cafile=str(ca)) + ctx.load_cert_chain(certfile=str(worker_crt), keyfile=str(worker_key)) + ctx.check_hostname = False + + last = "" + for i in range(attempts): + try: + with httpx.Client(verify=ctx, timeout=3.0) as client: + r = client.get(url) + if r.status_code == 200: + return True, r.text + last = f"status={r.status_code} body={r.text[:200]}" + except Exception as exc: # noqa: BLE001 + last = f"{type(exc).__name__}: {exc}" + if i < attempts - 1: + time.sleep(backoff_s) + return False, last + + +# -------------------------------------------------------------- orchestrator + +def _write_manifest(release: pathlib.Path, sha: Optional[str]) -> None: + import json + + _manifest_file(release).write_text(json.dumps({ + "sha": sha, + "installed_at": datetime.now(timezone.utc).isoformat(), + })) + + +def _rotate(install_dir: pathlib.Path) -> None: + """Rotate directories: prev→(deleted), active→prev, active.new→active. + + Caller must ensure ``active.new`` exists. ``active`` may or may not. + """ + active = _active_dir(install_dir) + prev = _prev_dir(install_dir) + staging = _staging_dir(install_dir) + + if prev.exists(): + shutil.rmtree(prev) + if active.exists(): + active.rename(prev) + staging.rename(active) + + +def _point_current_at(install_dir: pathlib.Path, target: pathlib.Path) -> None: + """Atomic symlink flip via rename.""" + link = _current_symlink(install_dir) + tmp = install_dir / ".current.tmp" + if tmp.exists() or tmp.is_symlink(): + tmp.unlink() + tmp.symlink_to(target) + os.replace(tmp, link) + + +def run_update( + tarball_bytes: bytes, + sha: Optional[str], + install_dir: pathlib.Path = DEFAULT_INSTALL_DIR, + agent_dir: pathlib.Path = pki.DEFAULT_AGENT_DIR, +) -> dict[str, Any]: + """Apply an update atomically. Rolls back on probe failure.""" + clean_stale_staging(install_dir) + staging = _staging_dir(install_dir) + + extract_tarball(tarball_bytes, staging) + _write_manifest(staging, sha) + + pip = _run_pip(staging) + if pip.returncode != 0: + shutil.rmtree(staging, ignore_errors=True) + raise UpdateError( + "pip install failed on new release", stderr=pip.stderr or pip.stdout, + ) + + _rotate(install_dir) + _point_current_at(install_dir, _active_dir(install_dir)) + + _stop_agent(install_dir) + _spawn_agent(install_dir) + + ok, detail = _probe_agent(agent_dir=agent_dir) + if ok: + return { + "status": "updated", + "release": read_release(_active_dir(install_dir)).to_dict(), + "probe": detail, + } + + # Rollback. + log.warning("agent probe failed after update: %s — rolling back", detail) + _stop_agent(install_dir) + # Swap active <-> prev. + active = _active_dir(install_dir) + prev = _prev_dir(install_dir) + tmp = _releases_dir(install_dir) / ".swap" + if tmp.exists(): + shutil.rmtree(tmp) + active.rename(tmp) + prev.rename(active) + tmp.rename(prev) + _point_current_at(install_dir, active) + _spawn_agent(install_dir) + ok2, detail2 = _probe_agent(agent_dir=agent_dir) + raise UpdateError( + "agent failed health probe after update; rolled back to previous release", + stderr=f"forward-probe: {detail}\nrollback-probe: {detail2}", + rolled_back=ok2, + ) + + +def run_rollback( + install_dir: pathlib.Path = DEFAULT_INSTALL_DIR, + agent_dir: pathlib.Path = pki.DEFAULT_AGENT_DIR, +) -> dict[str, Any]: + """Manually swap active with prev and restart the agent.""" + active = _active_dir(install_dir) + prev = _prev_dir(install_dir) + if not prev.is_dir(): + raise UpdateError("no previous release to roll back to") + + _stop_agent(install_dir) + tmp = _releases_dir(install_dir) / ".swap" + if tmp.exists(): + shutil.rmtree(tmp) + active.rename(tmp) + prev.rename(active) + tmp.rename(prev) + _point_current_at(install_dir, active) + _spawn_agent(install_dir) + ok, detail = _probe_agent(agent_dir=agent_dir) + if not ok: + raise UpdateError("agent unhealthy after rollback", stderr=detail) + return { + "status": "rolled_back", + "release": read_release(active).to_dict(), + "probe": detail, + } + + +def run_update_self( + tarball_bytes: bytes, + sha: Optional[str], + updater_install_dir: pathlib.Path, + exec_cb: Optional[Callable[[list[str]], None]] = None, +) -> dict[str, Any]: + """Replace the updater's own source tree, then re-exec this process. + + No auto-rollback. Caller must treat "connection dropped + /health + returns new SHA within 30s" as success. + """ + clean_stale_staging(updater_install_dir) + staging = _staging_dir(updater_install_dir) + extract_tarball(tarball_bytes, staging) + _write_manifest(staging, sha) + + pip = _run_pip(staging) + if pip.returncode != 0: + shutil.rmtree(staging, ignore_errors=True) + raise UpdateError( + "pip install failed on new updater release", + stderr=pip.stderr or pip.stdout, + ) + + _rotate(updater_install_dir) + _point_current_at(updater_install_dir, _active_dir(updater_install_dir)) + + argv = [str(_venv_python(_active_dir(updater_install_dir))), "-m", "decnet", "updater"] + sys.argv[1:] + if exec_cb is not None: + exec_cb(argv) # tests stub this — we don't actually re-exec + return {"status": "self_update_queued", "argv": argv} + # Returns nothing on success (replaces the process image). + os.execv(argv[0], argv) # nosec B606 - pragma: no cover + return {"status": "self_update_queued"} # pragma: no cover diff --git a/decnet/updater/server.py b/decnet/updater/server.py new file mode 100644 index 0000000..4a972a0 --- /dev/null +++ b/decnet/updater/server.py @@ -0,0 +1,86 @@ +"""Self-updater uvicorn launcher. + +Parallels ``decnet/agent/server.py`` but uses a distinct bundle directory +(``~/.decnet/updater``) with a cert whose CN is ``updater@``. That +cert is signed by the same DECNET CA as the agent's, so the master's one +CA still gates both channels; the CN is how we tell them apart. +""" +from __future__ import annotations + +import os +import pathlib +import signal +import subprocess # nosec B404 +import sys + +from decnet.logging import get_logger +from decnet.swarm import pki + +log = get_logger("updater.server") + +DEFAULT_UPDATER_DIR = pathlib.Path(os.path.expanduser("~/.decnet/updater")) + + +def _load_bundle(updater_dir: pathlib.Path) -> bool: + return all( + (updater_dir / name).is_file() + for name in ("ca.crt", "updater.crt", "updater.key") + ) + + +def run( + host: str, + port: int, + updater_dir: pathlib.Path = DEFAULT_UPDATER_DIR, + install_dir: pathlib.Path = pathlib.Path("/opt/decnet"), + agent_dir: pathlib.Path = pki.DEFAULT_AGENT_DIR, +) -> int: + if not _load_bundle(updater_dir): + print( + f"[updater] No cert bundle at {updater_dir}. " + f"Run `decnet swarm enroll --updater` from the master first.", + file=sys.stderr, + ) + return 2 + + # Pass config into the app module via env so uvicorn subprocess picks it up. + os.environ["DECNET_UPDATER_INSTALL_DIR"] = str(install_dir) + os.environ["DECNET_UPDATER_UPDATER_DIR"] = str(install_dir / "updater") + os.environ["DECNET_UPDATER_AGENT_DIR"] = str(agent_dir) + + keyfile = updater_dir / "updater.key" + certfile = updater_dir / "updater.crt" + cafile = updater_dir / "ca.crt" + + cmd = [ + sys.executable, + "-m", + "uvicorn", + "decnet.updater.app:app", + "--host", + host, + "--port", + str(port), + "--ssl-keyfile", + str(keyfile), + "--ssl-certfile", + str(certfile), + "--ssl-ca-certs", + str(cafile), + "--ssl-cert-reqs", + "2", + ] + log.info("updater starting host=%s port=%d bundle=%s", host, port, updater_dir) + proc = subprocess.Popen(cmd, start_new_session=True) # nosec B603 + try: + return proc.wait() + except KeyboardInterrupt: + try: + os.killpg(proc.pid, signal.SIGTERM) + try: + return proc.wait(timeout=10) + except subprocess.TimeoutExpired: + os.killpg(proc.pid, signal.SIGKILL) + return proc.wait() + except ProcessLookupError: + return 0 diff --git a/decnet/web/db/models.py b/decnet/web/db/models.py index 5590e95..4173f9f 100644 --- a/decnet/web/db/models.py +++ b/decnet/web/db/models.py @@ -118,6 +118,10 @@ class SwarmHost(SQLModel, table=True): # ISO-8601 string of the last successful agent /health probe last_heartbeat: Optional[datetime] = Field(default=None) client_cert_fingerprint: str # SHA-256 hex of worker's issued client cert + # SHA-256 hex of the updater-identity cert, if the host was enrolled + # with ``--updater`` / ``issue_updater_bundle``. ``None`` for hosts + # that only have an agent identity. + updater_cert_fingerprint: Optional[str] = Field(default=None) # Directory on the master where the per-worker cert bundle lives cert_bundle_path: str enrolled_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) @@ -281,6 +285,17 @@ class SwarmEnrollRequest(BaseModel): description="Extra SANs (IPs / hostnames) to embed in the worker cert", ) notes: Optional[str] = None + issue_updater_bundle: bool = PydanticField( + default=False, + description="If true, also issue an updater cert (CN=updater@) for the remote self-updater", + ) + + +class SwarmUpdaterBundle(BaseModel): + """Subset of SwarmEnrolledBundle for the updater identity.""" + fingerprint: str + updater_cert_pem: str + updater_key_pem: str class SwarmEnrolledBundle(BaseModel): @@ -293,6 +308,7 @@ class SwarmEnrolledBundle(BaseModel): ca_cert_pem: str worker_cert_pem: str worker_key_pem: str + updater: Optional[SwarmUpdaterBundle] = None class SwarmHostView(BaseModel): @@ -303,6 +319,7 @@ class SwarmHostView(BaseModel): status: str last_heartbeat: Optional[datetime] = None client_cert_fingerprint: str + updater_cert_fingerprint: Optional[str] = None enrolled_at: datetime notes: Optional[str] = None diff --git a/decnet/web/router/swarm/api_enroll_host.py b/decnet/web/router/swarm/api_enroll_host.py index 9baf011..1e85c8e 100644 --- a/decnet/web/router/swarm/api_enroll_host.py +++ b/decnet/web/router/swarm/api_enroll_host.py @@ -12,13 +12,14 @@ from __future__ import annotations import uuid as _uuid from datetime import datetime, timezone +from typing import Optional from fastapi import APIRouter, Depends, HTTPException, status from decnet.swarm import pki from decnet.web.db.repository import BaseRepository from decnet.web.dependencies import get_repo -from decnet.web.db.models import SwarmEnrolledBundle, SwarmEnrollRequest +from decnet.web.db.models import SwarmEnrolledBundle, SwarmEnrollRequest, SwarmUpdaterBundle router = APIRouter() @@ -46,6 +47,26 @@ async def api_enroll_host( bundle_dir = pki.DEFAULT_CA_DIR / "workers" / req.name pki.write_worker_bundle(issued, bundle_dir) + updater_view: Optional[SwarmUpdaterBundle] = None + updater_fp: Optional[str] = None + if req.issue_updater_bundle: + updater_cn = f"updater@{req.name}" + updater_sans = list({*sans, updater_cn, "127.0.0.1"}) + updater_issued = pki.issue_worker_cert(ca, updater_cn, updater_sans) + # Persist alongside the worker bundle for replay. + updater_dir = bundle_dir / "updater" + updater_dir.mkdir(parents=True, exist_ok=True) + (updater_dir / "updater.crt").write_bytes(updater_issued.cert_pem) + (updater_dir / "updater.key").write_bytes(updater_issued.key_pem) + import os as _os + _os.chmod(updater_dir / "updater.key", 0o600) + updater_fp = updater_issued.fingerprint_sha256 + updater_view = SwarmUpdaterBundle( + fingerprint=updater_fp, + updater_cert_pem=updater_issued.cert_pem.decode(), + updater_key_pem=updater_issued.key_pem.decode(), + ) + host_uuid = str(_uuid.uuid4()) await repo.add_swarm_host( { @@ -55,6 +76,7 @@ async def api_enroll_host( "agent_port": req.agent_port, "status": "enrolled", "client_cert_fingerprint": issued.fingerprint_sha256, + "updater_cert_fingerprint": updater_fp, "cert_bundle_path": str(bundle_dir), "enrolled_at": datetime.now(timezone.utc), "notes": req.notes, @@ -69,4 +91,5 @@ async def api_enroll_host( ca_cert_pem=issued.ca_cert_pem.decode(), worker_cert_pem=issued.cert_pem.decode(), worker_key_pem=issued.key_pem.decode(), + updater=updater_view, ) diff --git a/pyproject.toml b/pyproject.toml index 804b43e..d781fc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,8 @@ dependencies = [ "sqlmodel>=0.0.16", "scapy>=2.6.1", "orjson>=3.10", - "cryptography>=46.0.7" + "cryptography>=46.0.7", + "python-multipart>=0.0.20" ] [project.optional-dependencies] diff --git a/tests/swarm/test_cli_swarm_update.py b/tests/swarm/test_cli_swarm_update.py new file mode 100644 index 0000000..a36a9ec --- /dev/null +++ b/tests/swarm/test_cli_swarm_update.py @@ -0,0 +1,192 @@ +"""CLI `decnet swarm update` — target resolution, tarring, push aggregation. + +The UpdaterClient is stubbed: we are testing the CLI's orchestration, not +the wire protocol (that has test_updater_app.py and UpdaterClient round- +trips live under test_swarm_api.py integration). +""" +from __future__ import annotations + +import json +import pathlib +from typing import Any + +import pytest +from typer.testing import CliRunner + +from decnet import cli as cli_mod +from decnet.cli import app + + +runner = CliRunner() + + +class _FakeResp: + def __init__(self, payload: Any, status: int = 200): + self._payload = payload + self.status_code = status + self.text = json.dumps(payload) if not isinstance(payload, str) else payload + self.content = self.text.encode() + + def json(self) -> Any: + return self._payload + + +@pytest.fixture +def http_stub(monkeypatch: pytest.MonkeyPatch) -> dict: + state: dict = {"hosts": []} + + def _fake(method, url, *, json_body=None, timeout=30.0): + if method == "GET" and url.endswith("/swarm/hosts"): + return _FakeResp(state["hosts"]) + raise AssertionError(f"Unscripted HTTP call: {method} {url}") + + monkeypatch.setattr(cli_mod, "_http_request", _fake) + return state + + +class _StubUpdaterClient: + """Mirrors UpdaterClient's async-context-manager surface.""" + instances: list["_StubUpdaterClient"] = [] + behavior: dict[str, Any] = {} + + def __init__(self, host, *, updater_port: int = 8766, **_: Any): + self.host = host + self.port = updater_port + self.calls: list[str] = [] + _StubUpdaterClient.instances.append(self) + + async def __aenter__(self) -> "_StubUpdaterClient": + return self + + async def __aexit__(self, *exc: Any) -> None: + return None + + async def update(self, tarball: bytes, sha: str = "") -> _FakeResp: + self.calls.append("update") + return _StubUpdaterClient.behavior.get( + self.host.get("name"), + _FakeResp({"status": "updated", "release": {"sha": sha}}, 200), + ) + + async def update_self(self, tarball: bytes, sha: str = "") -> _FakeResp: + self.calls.append("update_self") + return _FakeResp({"status": "self_update_queued"}, 200) + + +@pytest.fixture +def stub_updater(monkeypatch: pytest.MonkeyPatch): + _StubUpdaterClient.instances.clear() + _StubUpdaterClient.behavior.clear() + monkeypatch.setattr("decnet.swarm.updater_client.UpdaterClient", _StubUpdaterClient) + # Also patch the module-level import inside cli.py's swarm_update closure. + import decnet.cli # noqa: F401 + return _StubUpdaterClient + + +def _mk_source_tree(tmp_path: pathlib.Path) -> pathlib.Path: + root = tmp_path / "src" + root.mkdir() + (root / "decnet").mkdir() + (root / "decnet" / "a.py").write_text("x = 1") + return root + + +# ------------------------------------------------------------- arg validation + +def test_update_requires_host_or_all(http_stub) -> None: + r = runner.invoke(app, ["swarm", "update"]) + assert r.exit_code == 2 + + +def test_update_host_and_all_are_mutex(http_stub) -> None: + r = runner.invoke(app, ["swarm", "update", "--host", "w1", "--all"]) + assert r.exit_code == 2 + + +def test_update_unknown_host_exits_1(http_stub) -> None: + http_stub["hosts"] = [{"uuid": "u1", "name": "other", "address": "10.0.0.1", "status": "active"}] + r = runner.invoke(app, ["swarm", "update", "--host", "nope"]) + assert r.exit_code == 1 + assert "No enrolled worker" in r.output + + +# ---------------------------------------------------------------- happy paths + +def test_update_single_host(http_stub, stub_updater, tmp_path: pathlib.Path) -> None: + http_stub["hosts"] = [ + {"uuid": "u1", "name": "w1", "address": "10.0.0.1", "status": "active"}, + {"uuid": "u2", "name": "w2", "address": "10.0.0.2", "status": "active"}, + ] + root = _mk_source_tree(tmp_path) + r = runner.invoke(app, ["swarm", "update", "--host", "w1", "--root", str(root)]) + assert r.exit_code == 0, r.output + assert "w1" in r.output + # Only w1 got a client; w2 is untouched. + names = [c.host["name"] for c in stub_updater.instances] + assert names == ["w1"] + + +def test_update_all_skips_decommissioned(http_stub, stub_updater, tmp_path: pathlib.Path) -> None: + http_stub["hosts"] = [ + {"uuid": "u1", "name": "w1", "address": "10.0.0.1", "status": "active"}, + {"uuid": "u2", "name": "w2", "address": "10.0.0.2", "status": "decommissioned"}, + {"uuid": "u3", "name": "w3", "address": "10.0.0.3", "status": "enrolled"}, + ] + root = _mk_source_tree(tmp_path) + r = runner.invoke(app, ["swarm", "update", "--all", "--root", str(root)]) + assert r.exit_code == 0, r.output + hit = sorted(c.host["name"] for c in stub_updater.instances) + assert hit == ["w1", "w3"] + + +def test_update_include_self_calls_both( + http_stub, stub_updater, tmp_path: pathlib.Path, +) -> None: + http_stub["hosts"] = [{"uuid": "u1", "name": "w1", "address": "10.0.0.1", "status": "active"}] + root = _mk_source_tree(tmp_path) + r = runner.invoke(app, ["swarm", "update", "--all", "--root", str(root), "--include-self"]) + assert r.exit_code == 0 + assert stub_updater.instances[0].calls == ["update", "update_self"] + + +# ------------------------------------------------------------- failure modes + +def test_update_rollback_status_409_flags_failure( + http_stub, stub_updater, tmp_path: pathlib.Path, +) -> None: + http_stub["hosts"] = [{"uuid": "u1", "name": "w1", "address": "10.0.0.1", "status": "active"}] + _StubUpdaterClient.behavior["w1"] = _FakeResp( + {"detail": {"error": "probe failed", "rolled_back": True}}, + status=409, + ) + root = _mk_source_tree(tmp_path) + r = runner.invoke(app, ["swarm", "update", "--all", "--root", str(root)]) + assert r.exit_code == 1 + assert "rolled-back" in r.output + + +def test_update_include_self_skipped_when_agent_update_failed( + http_stub, stub_updater, tmp_path: pathlib.Path, +) -> None: + http_stub["hosts"] = [{"uuid": "u1", "name": "w1", "address": "10.0.0.1", "status": "active"}] + _StubUpdaterClient.behavior["w1"] = _FakeResp( + {"detail": {"error": "pip failed"}}, status=500, + ) + root = _mk_source_tree(tmp_path) + r = runner.invoke(app, ["swarm", "update", "--all", "--root", str(root), "--include-self"]) + assert r.exit_code == 1 + # update_self must NOT have been called — agent update failed. + assert stub_updater.instances[0].calls == ["update"] + + +# --------------------------------------------------------------------- dry run + +def test_update_dry_run_does_not_call_updater( + http_stub, stub_updater, tmp_path: pathlib.Path, +) -> None: + http_stub["hosts"] = [{"uuid": "u1", "name": "w1", "address": "10.0.0.1", "status": "active"}] + root = _mk_source_tree(tmp_path) + r = runner.invoke(app, ["swarm", "update", "--all", "--root", str(root), "--dry-run"]) + assert r.exit_code == 0 + assert stub_updater.instances == [] + assert "dry-run" in r.output.lower() diff --git a/tests/swarm/test_swarm_api.py b/tests/swarm/test_swarm_api.py index 1174825..02f0759 100644 --- a/tests/swarm/test_swarm_api.py +++ b/tests/swarm/test_swarm_api.py @@ -78,6 +78,36 @@ def test_enroll_creates_host_and_returns_bundle(client: TestClient) -> None: assert len(body["fingerprint"]) == 64 # sha256 hex +def test_enroll_with_updater_issues_second_cert(client: TestClient, ca_dir) -> None: + resp = client.post( + "/swarm/enroll", + json={"name": "worker-upd", "address": "10.0.0.99", "agent_port": 8765, + "issue_updater_bundle": True}, + ) + assert resp.status_code == 201, resp.text + body = resp.json() + assert body["updater"] is not None + assert body["updater"]["fingerprint"] != body["fingerprint"] + assert "-----BEGIN CERTIFICATE-----" in body["updater"]["updater_cert_pem"] + assert "-----BEGIN PRIVATE KEY-----" in body["updater"]["updater_key_pem"] + # Cert bundle persisted on master. + upd_bundle = ca_dir / "workers" / "worker-upd" / "updater" + assert (upd_bundle / "updater.crt").is_file() + assert (upd_bundle / "updater.key").is_file() + # DB row carries the updater fingerprint. + row = client.get(f"/swarm/hosts/{body['host_uuid']}").json() + assert row.get("updater_cert_fingerprint") == body["updater"]["fingerprint"] + + +def test_enroll_without_updater_omits_bundle(client: TestClient) -> None: + resp = client.post( + "/swarm/enroll", + json={"name": "worker-no-upd", "address": "10.0.0.98", "agent_port": 8765}, + ) + assert resp.status_code == 201 + assert resp.json()["updater"] is None + + def test_enroll_rejects_duplicate_name(client: TestClient) -> None: payload = {"name": "worker-dup", "address": "10.0.0.6", "agent_port": 8765} assert client.post("/swarm/enroll", json=payload).status_code == 201 diff --git a/tests/swarm/test_tar_tree.py b/tests/swarm/test_tar_tree.py new file mode 100644 index 0000000..a9849af --- /dev/null +++ b/tests/swarm/test_tar_tree.py @@ -0,0 +1,75 @@ +"""tar_working_tree: exclude filter, tarball validity, git SHA detection.""" +from __future__ import annotations + +import io +import pathlib +import tarfile + +from decnet.swarm.tar_tree import detect_git_sha, tar_working_tree + + +def _tree_names(data: bytes) -> set[str]: + with tarfile.open(fileobj=io.BytesIO(data), mode="r:gz") as tar: + return {m.name for m in tar.getmembers()} + + +def test_tar_excludes_default_patterns(tmp_path: pathlib.Path) -> None: + (tmp_path / "decnet").mkdir() + (tmp_path / "decnet" / "keep.py").write_text("x = 1") + (tmp_path / ".venv").mkdir() + (tmp_path / ".venv" / "pyvenv.cfg").write_text("junk") + (tmp_path / ".git").mkdir() + (tmp_path / ".git" / "HEAD").write_text("ref: refs/heads/main\n") + (tmp_path / "decnet" / "__pycache__").mkdir() + (tmp_path / "decnet" / "__pycache__" / "keep.cpython-311.pyc").write_text("bytecode") + (tmp_path / "wiki-checkout").mkdir() + (tmp_path / "wiki-checkout" / "Home.md").write_text("# wiki") + (tmp_path / "run.db").write_text("sqlite") + (tmp_path / "master.log").write_text("log") + + data = tar_working_tree(tmp_path) + names = _tree_names(data) + assert "decnet/keep.py" in names + assert all(".venv" not in n for n in names) + assert all(".git" not in n for n in names) + assert all("__pycache__" not in n for n in names) + assert all("wiki-checkout" not in n for n in names) + assert "run.db" not in names + assert "master.log" not in names + + +def test_tar_accepts_extra_excludes(tmp_path: pathlib.Path) -> None: + (tmp_path / "a.py").write_text("x") + (tmp_path / "secret.env").write_text("TOKEN=abc") + data = tar_working_tree(tmp_path, extra_excludes=["secret.env"]) + names = _tree_names(data) + assert "a.py" in names + assert "secret.env" not in names + + +def test_tar_skips_symlinks(tmp_path: pathlib.Path) -> None: + (tmp_path / "real.txt").write_text("hi") + try: + (tmp_path / "link.txt").symlink_to(tmp_path / "real.txt") + except (OSError, NotImplementedError): + return # platform doesn't support symlinks — skip + names = _tree_names(tar_working_tree(tmp_path)) + assert "real.txt" in names + assert "link.txt" not in names + + +def test_detect_git_sha_from_ref(tmp_path: pathlib.Path) -> None: + (tmp_path / ".git" / "refs" / "heads").mkdir(parents=True) + (tmp_path / ".git" / "refs" / "heads" / "main").write_text("deadbeef" * 5 + "\n") + (tmp_path / ".git" / "HEAD").write_text("ref: refs/heads/main\n") + assert detect_git_sha(tmp_path).startswith("deadbeef") + + +def test_detect_git_sha_detached(tmp_path: pathlib.Path) -> None: + (tmp_path / ".git").mkdir() + (tmp_path / ".git" / "HEAD").write_text("f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0\n") + assert detect_git_sha(tmp_path).startswith("f0f0") + + +def test_detect_git_sha_none_when_not_repo(tmp_path: pathlib.Path) -> None: + assert detect_git_sha(tmp_path) == "" diff --git a/tests/updater/__init__.py b/tests/updater/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/updater/test_updater_app.py b/tests/updater/test_updater_app.py new file mode 100644 index 0000000..12b0c83 --- /dev/null +++ b/tests/updater/test_updater_app.py @@ -0,0 +1,138 @@ +"""HTTP contract for the updater app. + +Executor functions are monkeypatched — we're testing wire format, not +the rotation logic (that has test_updater_executor.py). +""" +from __future__ import annotations + +import io +import pathlib +import tarfile + +import pytest +from fastapi.testclient import TestClient + +from decnet.updater import app as app_mod +from decnet.updater import executor as ex + + +def _tarball(files: dict[str, str] | None = None) -> bytes: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w:gz") as tar: + for name, content in (files or {"a": "b"}).items(): + data = content.encode() + info = tarfile.TarInfo(name=name) + info.size = len(data) + tar.addfile(info, io.BytesIO(data)) + return buf.getvalue() + + +@pytest.fixture +def client(tmp_path: pathlib.Path) -> TestClient: + app_mod.configure( + install_dir=tmp_path / "install", + updater_install_dir=tmp_path / "install" / "updater", + agent_dir=tmp_path / "agent", + ) + (tmp_path / "install" / "releases").mkdir(parents=True) + return TestClient(app_mod.app) + + +def test_health_returns_role_and_releases(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(ex, "list_releases", lambda d: []) + r = client.get("/health") + assert r.status_code == 200 + body = r.json() + assert body["status"] == "ok" + assert body["role"] == "updater" + assert body["releases"] == [] + + +def test_update_happy_path(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + ex, "run_update", + lambda data, sha, install_dir, agent_dir: {"status": "updated", "release": {"slot": "active", "sha": sha}, "probe": "ok"}, + ) + r = client.post( + "/update", + files={"tarball": ("tree.tgz", _tarball(), "application/gzip")}, + data={"sha": "ABC123"}, + ) + assert r.status_code == 200, r.text + assert r.json()["release"]["sha"] == "ABC123" + + +def test_update_rollback_returns_409(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None: + def _boom(*a, **kw): + raise ex.UpdateError("probe failed; rolled back", stderr="connection refused", rolled_back=True) + monkeypatch.setattr(ex, "run_update", _boom) + + r = client.post( + "/update", + files={"tarball": ("t.tgz", _tarball(), "application/gzip")}, + data={"sha": ""}, + ) + assert r.status_code == 409, r.text + detail = r.json()["detail"] + assert detail["rolled_back"] is True + assert "connection refused" in detail["stderr"] + + +def test_update_hard_failure_returns_500(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None: + def _boom(*a, **kw): + raise ex.UpdateError("pip install failed", stderr="resolver error") + monkeypatch.setattr(ex, "run_update", _boom) + + r = client.post("/update", files={"tarball": ("t.tgz", _tarball(), "application/gzip")}) + assert r.status_code == 500 + assert r.json()["detail"]["rolled_back"] is False + + +def test_update_self_requires_confirm(client: TestClient) -> None: + r = client.post("/update-self", files={"tarball": ("t.tgz", _tarball(), "application/gzip")}) + assert r.status_code == 400 + assert "confirm_self" in r.json()["detail"] + + +def test_update_self_happy_path(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + ex, "run_update_self", + lambda data, sha, updater_install_dir: {"status": "self_update_queued", "argv": ["python", "-m", "decnet", "updater"]}, + ) + r = client.post( + "/update-self", + files={"tarball": ("t.tgz", _tarball(), "application/gzip")}, + data={"sha": "S", "confirm_self": "true"}, + ) + assert r.status_code == 200 + assert r.json()["status"] == "self_update_queued" + + +def test_rollback_happy(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + ex, "run_rollback", + lambda install_dir, agent_dir: {"status": "rolled_back", "release": {"slot": "active", "sha": "O"}, "probe": "ok"}, + ) + r = client.post("/rollback") + assert r.status_code == 200 + assert r.json()["status"] == "rolled_back" + + +def test_rollback_missing_prev_returns_404(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None: + def _boom(**_): + raise ex.UpdateError("no previous release to roll back to") + monkeypatch.setattr(ex, "run_rollback", _boom) + r = client.post("/rollback") + assert r.status_code == 404 + + +def test_releases_lists_slots(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + ex, "list_releases", + lambda d: [ex.Release(slot="active", sha="A", installed_at=None), + ex.Release(slot="prev", sha="B", installed_at=None)], + ) + r = client.get("/releases") + assert r.status_code == 200 + slots = [rel["slot"] for rel in r.json()["releases"]] + assert slots == ["active", "prev"] diff --git a/tests/updater/test_updater_executor.py b/tests/updater/test_updater_executor.py new file mode 100644 index 0000000..f01ee4e --- /dev/null +++ b/tests/updater/test_updater_executor.py @@ -0,0 +1,295 @@ +"""Updater executor: directory rotation, probe-driven rollback, safety checks. + +All three real seams (`_run_pip`, `_spawn_agent`, `_stop_agent`, +`_probe_agent`) are monkeypatched so these tests never shell out or +touch a real Python venv. The rotation/symlink/extract logic is exercised +against a ``tmp_path`` install dir. +""" +from __future__ import annotations + +import io +import pathlib +import subprocess +import tarfile +from typing import Any + +import pytest + +from decnet.updater import executor as ex + + +# ------------------------------------------------------------------ helpers + +def _make_tarball(files: dict[str, str]) -> bytes: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w:gz") as tar: + for name, content in files.items(): + data = content.encode() + info = tarfile.TarInfo(name=name) + info.size = len(data) + tar.addfile(info, io.BytesIO(data)) + return buf.getvalue() + + +class _PipOK: + returncode = 0 + stdout = "" + stderr = "" + + +class _PipFail: + returncode = 1 + stdout = "" + stderr = "resolver error: Could not find a version that satisfies ..." + + +@pytest.fixture +def install_dir(tmp_path: pathlib.Path) -> pathlib.Path: + d = tmp_path / "decnet" + d.mkdir() + (d / "releases").mkdir() + return d + + +@pytest.fixture +def agent_dir(tmp_path: pathlib.Path) -> pathlib.Path: + d = tmp_path / "agent" + d.mkdir() + # executor._probe_agent checks these exist before constructing SSL ctx, + # but the probe seam is monkeypatched in every test so content doesn't + # matter — still create them so the non-stubbed path is representative. + (d / "ca.crt").write_bytes(b"-----BEGIN CERTIFICATE-----\nstub\n-----END CERTIFICATE-----\n") + (d / "worker.crt").write_bytes(b"-----BEGIN CERTIFICATE-----\nstub\n-----END CERTIFICATE-----\n") + (d / "worker.key").write_bytes(b"-----BEGIN PRIVATE KEY-----\nstub\n-----END PRIVATE KEY-----\n") + return d + + +@pytest.fixture +def seed_existing_release(install_dir: pathlib.Path) -> None: + """Pretend an install is already live: create releases/active with a marker.""" + active = install_dir / "releases" / "active" + active.mkdir() + (active / "marker.txt").write_text("old") + ex._write_manifest(active, sha="OLDSHA") + # current -> active + ex._point_current_at(install_dir, active) + + +# --------------------------------------------------------- extract + safety + +def test_extract_rejects_path_traversal(tmp_path: pathlib.Path) -> None: + evil = _make_tarball({"../escape.txt": "pwned"}) + with pytest.raises(ex.UpdateError, match="unsafe path"): + ex.extract_tarball(evil, tmp_path / "out") + + +def test_extract_rejects_absolute_paths(tmp_path: pathlib.Path) -> None: + evil = _make_tarball({"/etc/passwd": "root:x:0:0"}) + with pytest.raises(ex.UpdateError, match="unsafe path"): + ex.extract_tarball(evil, tmp_path / "out") + + +def test_extract_happy_path(tmp_path: pathlib.Path) -> None: + tb = _make_tarball({"a/b.txt": "hello"}) + out = tmp_path / "out" + ex.extract_tarball(tb, out) + assert (out / "a" / "b.txt").read_text() == "hello" + + +def test_clean_stale_staging(install_dir: pathlib.Path) -> None: + staging = install_dir / "releases" / "active.new" + staging.mkdir() + (staging / "junk").write_text("left from a crash") + ex.clean_stale_staging(install_dir) + assert not staging.exists() + + +# ---------------------------------------------------------------- happy path + +def test_update_rotates_and_probes( + monkeypatch: pytest.MonkeyPatch, + install_dir: pathlib.Path, + agent_dir: pathlib.Path, + seed_existing_release: None, +) -> None: + monkeypatch.setattr(ex, "_run_pip", lambda release: _PipOK()) + monkeypatch.setattr(ex, "_stop_agent", lambda *a, **k: None) + monkeypatch.setattr(ex, "_spawn_agent", lambda *a, **k: 42) + monkeypatch.setattr(ex, "_probe_agent", lambda **_: (True, "ok")) + + tb = _make_tarball({"marker.txt": "new"}) + result = ex.run_update(tb, sha="NEWSHA", install_dir=install_dir, agent_dir=agent_dir) + + assert result["status"] == "updated" + assert result["release"]["sha"] == "NEWSHA" + assert (install_dir / "releases" / "active" / "marker.txt").read_text() == "new" + # Old release demoted, not deleted. + assert (install_dir / "releases" / "prev" / "marker.txt").read_text() == "old" + # Current symlink points at the new active. + assert (install_dir / "current").resolve() == (install_dir / "releases" / "active").resolve() + + +def test_update_first_install_without_previous( + monkeypatch: pytest.MonkeyPatch, + install_dir: pathlib.Path, + agent_dir: pathlib.Path, +) -> None: + """No existing active/ dir — first real install via the updater.""" + monkeypatch.setattr(ex, "_run_pip", lambda release: _PipOK()) + monkeypatch.setattr(ex, "_stop_agent", lambda *a, **k: None) + monkeypatch.setattr(ex, "_spawn_agent", lambda *a, **k: 1) + monkeypatch.setattr(ex, "_probe_agent", lambda **_: (True, "ok")) + + tb = _make_tarball({"marker.txt": "first"}) + result = ex.run_update(tb, sha="S1", install_dir=install_dir, agent_dir=agent_dir) + assert result["status"] == "updated" + assert not (install_dir / "releases" / "prev").exists() + + +# ------------------------------------------------------------ pip failure + +def test_update_pip_failure_aborts_before_rotation( + monkeypatch: pytest.MonkeyPatch, + install_dir: pathlib.Path, + agent_dir: pathlib.Path, + seed_existing_release: None, +) -> None: + monkeypatch.setattr(ex, "_run_pip", lambda release: _PipFail()) + stop_called: list[bool] = [] + monkeypatch.setattr(ex, "_stop_agent", lambda *a, **k: stop_called.append(True)) + monkeypatch.setattr(ex, "_spawn_agent", lambda *a, **k: 1) + monkeypatch.setattr(ex, "_probe_agent", lambda **_: (True, "ok")) + + tb = _make_tarball({"marker.txt": "new"}) + with pytest.raises(ex.UpdateError, match="pip install failed") as ei: + ex.run_update(tb, sha="S", install_dir=install_dir, agent_dir=agent_dir) + assert "resolver error" in ei.value.stderr + + # Nothing rotated — old active still live, no prev created. + assert (install_dir / "releases" / "active" / "marker.txt").read_text() == "old" + assert not (install_dir / "releases" / "prev").exists() + # Agent never touched. + assert stop_called == [] + # Staging cleaned up. + assert not (install_dir / "releases" / "active.new").exists() + + +# ------------------------------------------------------------ probe failure + +def test_update_probe_failure_rolls_back( + monkeypatch: pytest.MonkeyPatch, + install_dir: pathlib.Path, + agent_dir: pathlib.Path, + seed_existing_release: None, +) -> None: + monkeypatch.setattr(ex, "_run_pip", lambda release: _PipOK()) + monkeypatch.setattr(ex, "_stop_agent", lambda *a, **k: None) + monkeypatch.setattr(ex, "_spawn_agent", lambda *a, **k: 1) + + calls: list[int] = [0] + + def _probe(**_: Any) -> tuple[bool, str]: + calls[0] += 1 + if calls[0] == 1: + return False, "connection refused" + return True, "ok" # rollback probe succeeds + + monkeypatch.setattr(ex, "_probe_agent", _probe) + + tb = _make_tarball({"marker.txt": "new"}) + with pytest.raises(ex.UpdateError, match="health probe") as ei: + ex.run_update(tb, sha="NEWSHA", install_dir=install_dir, agent_dir=agent_dir) + assert ei.value.rolled_back is True + assert "connection refused" in ei.value.stderr + + # Rolled back: active has the old marker again. + assert (install_dir / "releases" / "active" / "marker.txt").read_text() == "old" + # Prev now holds what would have been the new release. + assert (install_dir / "releases" / "prev" / "marker.txt").read_text() == "new" + # Current symlink points back at active. + assert (install_dir / "current").resolve() == (install_dir / "releases" / "active").resolve() + + +# ------------------------------------------------------------ manual rollback + +def test_manual_rollback_swaps( + monkeypatch: pytest.MonkeyPatch, + install_dir: pathlib.Path, + agent_dir: pathlib.Path, + seed_existing_release: None, +) -> None: + # Seed a prev/ so rollback has somewhere to go. + prev = install_dir / "releases" / "prev" + prev.mkdir() + (prev / "marker.txt").write_text("older") + ex._write_manifest(prev, sha="OLDERSHA") + + monkeypatch.setattr(ex, "_stop_agent", lambda *a, **k: None) + monkeypatch.setattr(ex, "_spawn_agent", lambda *a, **k: 1) + monkeypatch.setattr(ex, "_probe_agent", lambda **_: (True, "ok")) + + result = ex.run_rollback(install_dir=install_dir, agent_dir=agent_dir) + assert result["status"] == "rolled_back" + assert (install_dir / "releases" / "active" / "marker.txt").read_text() == "older" + assert (install_dir / "releases" / "prev" / "marker.txt").read_text() == "old" + + +def test_manual_rollback_refuses_without_prev( + install_dir: pathlib.Path, + seed_existing_release: None, +) -> None: + with pytest.raises(ex.UpdateError, match="no previous release"): + ex.run_rollback(install_dir=install_dir) + + +# ---------------------------------------------------------------- releases + +def test_list_releases_includes_only_existing_slots( + install_dir: pathlib.Path, + seed_existing_release: None, +) -> None: + rs = ex.list_releases(install_dir) + assert [r.slot for r in rs] == ["active"] + assert rs[0].sha == "OLDSHA" + + +# ---------------------------------------------------------------- self-update + +def test_update_self_rotates_and_calls_exec_cb( + monkeypatch: pytest.MonkeyPatch, + install_dir: pathlib.Path, +) -> None: + # Seed a stand-in "active" for the updater itself. + active = install_dir / "releases" / "active" + active.mkdir() + (active / "marker").write_text("old-updater") + + monkeypatch.setattr(ex, "_run_pip", lambda release: _PipOK()) + seen_argv: list[list[str]] = [] + + tb = _make_tarball({"marker": "new-updater"}) + result = ex.run_update_self( + tb, sha="USHA", updater_install_dir=install_dir, + exec_cb=lambda argv: seen_argv.append(argv), + ) + assert result["status"] == "self_update_queued" + assert (install_dir / "releases" / "active" / "marker").read_text() == "new-updater" + assert (install_dir / "releases" / "prev" / "marker").read_text() == "old-updater" + assert len(seen_argv) == 1 + assert "updater" in seen_argv[0] + + +def test_update_self_pip_failure_leaves_active_intact( + monkeypatch: pytest.MonkeyPatch, + install_dir: pathlib.Path, +) -> None: + active = install_dir / "releases" / "active" + active.mkdir() + (active / "marker").write_text("old-updater") + monkeypatch.setattr(ex, "_run_pip", lambda release: _PipFail()) + + tb = _make_tarball({"marker": "new-updater"}) + with pytest.raises(ex.UpdateError, match="pip install failed"): + ex.run_update_self(tb, sha="U", updater_install_dir=install_dir, exec_cb=lambda a: None) + assert (install_dir / "releases" / "active" / "marker").read_text() == "old-updater" + assert not (install_dir / "releases" / "active.new").exists()