merge testing->tomerge/main #7
171
decnet/cli.py
171
decnet/cli.py
@@ -187,6 +187,43 @@ def agent(
|
|||||||
raise typer.Exit(rc)
|
raise typer.Exit(rc)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def updater(
|
||||||
|
port: int = typer.Option(8766, "--port", help="Port for the self-updater daemon"),
|
||||||
|
host: str = typer.Option("0.0.0.0", "--host", help="Bind address for the updater"), # nosec B104
|
||||||
|
updater_dir: Optional[str] = typer.Option(None, "--updater-dir", help="Updater cert bundle dir (default: ~/.decnet/updater)"),
|
||||||
|
install_dir: Optional[str] = typer.Option(None, "--install-dir", help="Release install root (default: /opt/decnet)"),
|
||||||
|
agent_dir: Optional[str] = typer.Option(None, "--agent-dir", help="Worker agent cert bundle (for local /health probes; default: ~/.decnet/agent)"),
|
||||||
|
daemon: bool = typer.Option(False, "--daemon", "-d", help="Detach to background as a daemon process"),
|
||||||
|
) -> None:
|
||||||
|
"""Run the DECNET self-updater (requires a bundle in ~/.decnet/updater/)."""
|
||||||
|
import pathlib as _pathlib
|
||||||
|
from decnet.swarm import pki as _pki
|
||||||
|
from decnet.updater import server as _upd_server
|
||||||
|
|
||||||
|
resolved_updater = _pathlib.Path(updater_dir) if updater_dir else _upd_server.DEFAULT_UPDATER_DIR
|
||||||
|
resolved_install = _pathlib.Path(install_dir) if install_dir else _pathlib.Path("/opt/decnet")
|
||||||
|
resolved_agent = _pathlib.Path(agent_dir) if agent_dir else _pki.DEFAULT_AGENT_DIR
|
||||||
|
|
||||||
|
if daemon:
|
||||||
|
log.info("updater daemonizing host=%s port=%d", host, port)
|
||||||
|
_daemonize()
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
"updater command invoked host=%s port=%d updater_dir=%s install_dir=%s",
|
||||||
|
host, port, resolved_updater, resolved_install,
|
||||||
|
)
|
||||||
|
console.print(f"[green]Starting DECNET self-updater on {host}:{port} (mTLS)...[/]")
|
||||||
|
rc = _upd_server.run(
|
||||||
|
host, port,
|
||||||
|
updater_dir=resolved_updater,
|
||||||
|
install_dir=resolved_install,
|
||||||
|
agent_dir=resolved_agent,
|
||||||
|
)
|
||||||
|
if rc != 0:
|
||||||
|
raise typer.Exit(rc)
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def listener(
|
def listener(
|
||||||
bind_host: str = typer.Option("0.0.0.0", "--host", help="Bind address for the master syslog-TLS listener"), # nosec B104
|
bind_host: str = typer.Option("0.0.0.0", "--host", help="Bind address for the master syslog-TLS listener"), # nosec B104
|
||||||
@@ -393,6 +430,7 @@ def swarm_enroll(
|
|||||||
sans: Optional[str] = typer.Option(None, "--sans", help="Comma-separated extra SANs for the worker cert"),
|
sans: Optional[str] = typer.Option(None, "--sans", help="Comma-separated extra SANs for the worker cert"),
|
||||||
notes: Optional[str] = typer.Option(None, "--notes", help="Free-form operator notes"),
|
notes: Optional[str] = typer.Option(None, "--notes", help="Free-form operator notes"),
|
||||||
out_dir: Optional[str] = typer.Option(None, "--out-dir", help="Write the bundle (ca.crt/worker.crt/worker.key) to this dir for scp"),
|
out_dir: Optional[str] = typer.Option(None, "--out-dir", help="Write the bundle (ca.crt/worker.crt/worker.key) to this dir for scp"),
|
||||||
|
updater: bool = typer.Option(False, "--updater", help="Also issue an updater-identity cert (CN=updater@<name>) for the remote self-updater"),
|
||||||
url: Optional[str] = typer.Option(None, "--url", help="Override swarm controller URL (default: 127.0.0.1:8770)"),
|
url: Optional[str] = typer.Option(None, "--url", help="Override swarm controller URL (default: 127.0.0.1:8770)"),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Issue a mTLS bundle for a new worker and register it in the swarm."""
|
"""Issue a mTLS bundle for a new worker and register it in the swarm."""
|
||||||
@@ -403,6 +441,8 @@ def swarm_enroll(
|
|||||||
body["sans"] = [s.strip() for s in sans.split(",") if s.strip()]
|
body["sans"] = [s.strip() for s in sans.split(",") if s.strip()]
|
||||||
if notes:
|
if notes:
|
||||||
body["notes"] = notes
|
body["notes"] = notes
|
||||||
|
if updater:
|
||||||
|
body["issue_updater_bundle"] = True
|
||||||
|
|
||||||
resp = _http_request("POST", _swarmctl_base_url(url) + "/swarm/enroll", json_body=body)
|
resp = _http_request("POST", _swarmctl_base_url(url) + "/swarm/enroll", json_body=body)
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
@@ -410,6 +450,9 @@ def swarm_enroll(
|
|||||||
console.print(f"[green]Enrolled worker:[/] {data['name']} "
|
console.print(f"[green]Enrolled worker:[/] {data['name']} "
|
||||||
f"[dim]uuid=[/]{data['host_uuid']} "
|
f"[dim]uuid=[/]{data['host_uuid']} "
|
||||||
f"[dim]fingerprint=[/]{data['fingerprint']}")
|
f"[dim]fingerprint=[/]{data['fingerprint']}")
|
||||||
|
if data.get("updater"):
|
||||||
|
console.print(f"[green] + updater identity[/] "
|
||||||
|
f"[dim]fingerprint=[/]{data['updater']['fingerprint']}")
|
||||||
|
|
||||||
if out_dir:
|
if out_dir:
|
||||||
target = _pathlib.Path(out_dir).expanduser()
|
target = _pathlib.Path(out_dir).expanduser()
|
||||||
@@ -422,8 +465,22 @@ def swarm_enroll(
|
|||||||
(target / leaf).chmod(0o600)
|
(target / leaf).chmod(0o600)
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
console.print(f"[cyan]Bundle written to[/] {target}")
|
console.print(f"[cyan]Agent bundle written to[/] {target}")
|
||||||
console.print("[dim]Ship this directory to the worker at ~/.decnet/agent/ (or wherever `decnet agent --agent-dir` points).[/]")
|
|
||||||
|
if data.get("updater"):
|
||||||
|
upd_target = target.parent / f"{target.name}-updater"
|
||||||
|
upd_target.mkdir(parents=True, exist_ok=True)
|
||||||
|
(upd_target / "ca.crt").write_text(data["ca_cert_pem"])
|
||||||
|
(upd_target / "updater.crt").write_text(data["updater"]["updater_cert_pem"])
|
||||||
|
(upd_target / "updater.key").write_text(data["updater"]["updater_key_pem"])
|
||||||
|
try:
|
||||||
|
(upd_target / "updater.key").chmod(0o600)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
console.print(f"[cyan]Updater bundle written to[/] {upd_target}")
|
||||||
|
console.print("[dim]Ship the agent dir to ~/.decnet/agent/ and the updater dir to ~/.decnet/updater/ on the worker.[/]")
|
||||||
|
else:
|
||||||
|
console.print("[dim]Ship this directory to the worker at ~/.decnet/agent/ (or wherever `decnet agent --agent-dir` points).[/]")
|
||||||
else:
|
else:
|
||||||
console.print("[yellow]No --out-dir given — bundle PEMs are in the JSON response; persist them before leaving this shell.[/]")
|
console.print("[yellow]No --out-dir given — bundle PEMs are in the JSON response; persist them before leaving this shell.[/]")
|
||||||
|
|
||||||
@@ -494,6 +551,116 @@ def swarm_check(
|
|||||||
console.print(table)
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
@swarm_app.command("update")
|
||||||
|
def swarm_update(
|
||||||
|
host: Optional[str] = typer.Option(None, "--host", help="Target worker (name or UUID). Omit with --all."),
|
||||||
|
all_hosts: bool = typer.Option(False, "--all", help="Push to every enrolled worker."),
|
||||||
|
include_self: bool = typer.Option(False, "--include-self", help="Also push to each updater's /update-self after a successful agent update."),
|
||||||
|
root: Optional[str] = typer.Option(None, "--root", help="Source tree to tar (default: CWD)."),
|
||||||
|
exclude: list[str] = typer.Option([], "--exclude", help="Additional exclude glob. Repeatable."),
|
||||||
|
updater_port: int = typer.Option(8766, "--updater-port", help="Port the workers' updater listens on."),
|
||||||
|
dry_run: bool = typer.Option(False, "--dry-run", help="Build the tarball and print stats; no network."),
|
||||||
|
url: Optional[str] = typer.Option(None, "--url", help="Override swarm controller URL."),
|
||||||
|
) -> None:
|
||||||
|
"""Push the current working tree to workers' self-updaters (with auto-rollback on failure)."""
|
||||||
|
import asyncio
|
||||||
|
import pathlib as _pathlib
|
||||||
|
|
||||||
|
from decnet.swarm.tar_tree import tar_working_tree, detect_git_sha
|
||||||
|
from decnet.swarm.updater_client import UpdaterClient
|
||||||
|
|
||||||
|
if not (host or all_hosts):
|
||||||
|
console.print("[red]Supply --host <name> or --all.[/]")
|
||||||
|
raise typer.Exit(2)
|
||||||
|
if host and all_hosts:
|
||||||
|
console.print("[red]--host and --all are mutually exclusive.[/]")
|
||||||
|
raise typer.Exit(2)
|
||||||
|
|
||||||
|
base = _swarmctl_base_url(url)
|
||||||
|
resp = _http_request("GET", base + "/swarm/hosts")
|
||||||
|
rows = resp.json()
|
||||||
|
if host:
|
||||||
|
targets = [r for r in rows if r.get("name") == host or r.get("uuid") == host]
|
||||||
|
if not targets:
|
||||||
|
console.print(f"[red]No enrolled worker matching '{host}'.[/]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
else:
|
||||||
|
targets = [r for r in rows if r.get("status") != "decommissioned"]
|
||||||
|
if not targets:
|
||||||
|
console.print("[dim]No targets.[/]")
|
||||||
|
return
|
||||||
|
|
||||||
|
tree_root = _pathlib.Path(root) if root else _pathlib.Path.cwd()
|
||||||
|
sha = detect_git_sha(tree_root)
|
||||||
|
console.print(f"[dim]Tarring[/] {tree_root} [dim]sha={sha or '(not a git repo)'}[/]")
|
||||||
|
tarball = tar_working_tree(tree_root, extra_excludes=exclude)
|
||||||
|
console.print(f"[dim]Tarball size:[/] {len(tarball):,} bytes")
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
console.print("[yellow]--dry-run: not pushing.[/]")
|
||||||
|
for t in targets:
|
||||||
|
console.print(f" would push to [cyan]{t.get('name')}[/] at {t.get('address')}:{updater_port}")
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _push_one(h: dict) -> dict:
|
||||||
|
name = h.get("name") or h.get("uuid")
|
||||||
|
out: dict = {"name": name, "address": h.get("address"), "agent": None, "self": None}
|
||||||
|
try:
|
||||||
|
async with UpdaterClient(h, updater_port=updater_port) as u:
|
||||||
|
r = await u.update(tarball, sha=sha)
|
||||||
|
out["agent"] = {"status": r.status_code, "body": r.json() if r.content else {}}
|
||||||
|
if r.status_code == 200 and include_self:
|
||||||
|
# Agent first, updater second — see plan.
|
||||||
|
rs = await u.update_self(tarball, sha=sha)
|
||||||
|
# Connection-drop is expected for update-self.
|
||||||
|
out["self"] = {"status": rs.status_code, "body": rs.json() if rs.content else {}}
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
out["error"] = f"{type(exc).__name__}: {exc}"
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _push_all() -> list[dict]:
|
||||||
|
return await asyncio.gather(*(_push_one(t) for t in targets))
|
||||||
|
|
||||||
|
results = asyncio.run(_push_all())
|
||||||
|
|
||||||
|
table = Table(title="DECNET swarm update")
|
||||||
|
for col in ("host", "address", "agent", "self", "detail"):
|
||||||
|
table.add_column(col)
|
||||||
|
any_failure = False
|
||||||
|
for r in results:
|
||||||
|
agent = r.get("agent") or {}
|
||||||
|
selff = r.get("self") or {}
|
||||||
|
err = r.get("error")
|
||||||
|
if err:
|
||||||
|
any_failure = True
|
||||||
|
table.add_row(r["name"], r.get("address") or "", "[red]error[/]", "—", err)
|
||||||
|
continue
|
||||||
|
a_status = agent.get("status")
|
||||||
|
if a_status == 200:
|
||||||
|
agent_cell = "[green]updated[/]"
|
||||||
|
elif a_status == 409:
|
||||||
|
agent_cell = "[yellow]rolled-back[/]"
|
||||||
|
any_failure = True
|
||||||
|
else:
|
||||||
|
agent_cell = f"[red]{a_status}[/]"
|
||||||
|
any_failure = True
|
||||||
|
if not include_self:
|
||||||
|
self_cell = "—"
|
||||||
|
elif selff.get("status") == 200 or selff.get("status") is None:
|
||||||
|
self_cell = "[green]ok[/]" if selff else "[dim]skipped[/]"
|
||||||
|
else:
|
||||||
|
self_cell = f"[red]{selff.get('status')}[/]"
|
||||||
|
detail = ""
|
||||||
|
body = agent.get("body") or {}
|
||||||
|
if isinstance(body, dict):
|
||||||
|
detail = body.get("release", {}).get("sha") or body.get("detail", {}).get("error") or ""
|
||||||
|
table.add_row(r["name"], r.get("address") or "", agent_cell, self_cell, str(detail)[:80])
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
if any_failure:
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
|
||||||
@swarm_app.command("deckies")
|
@swarm_app.command("deckies")
|
||||||
def swarm_deckies(
|
def swarm_deckies(
|
||||||
host: Optional[str] = typer.Option(None, "--host", help="Filter by worker name or UUID"),
|
host: Optional[str] = typer.Option(None, "--host", help="Filter by worker name or UUID"),
|
||||||
|
|||||||
97
decnet/swarm/tar_tree.py
Normal file
97
decnet/swarm/tar_tree.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""Build a gzipped tarball of the master's working tree for pushing to workers.
|
||||||
|
|
||||||
|
Always excludes the obvious large / secret / churn paths: ``.venv/``,
|
||||||
|
``__pycache__/``, ``.git/``, ``wiki-checkout/``, ``*.db*``, ``*.log``. The
|
||||||
|
caller can supply additional exclude globs.
|
||||||
|
|
||||||
|
Deliberately does NOT invoke git — the tree is what the operator has on
|
||||||
|
disk (staged + unstaged + untracked). That's the whole point; the scp
|
||||||
|
workflow we're replacing also shipped the live tree.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import fnmatch
|
||||||
|
import io
|
||||||
|
import pathlib
|
||||||
|
import tarfile
|
||||||
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
|
DEFAULT_EXCLUDES = (
|
||||||
|
".venv", ".venv/*",
|
||||||
|
"**/.venv/*",
|
||||||
|
"__pycache__", "**/__pycache__", "**/__pycache__/*",
|
||||||
|
".git", ".git/*",
|
||||||
|
"wiki-checkout", "wiki-checkout/*",
|
||||||
|
"*.pyc", "*.pyo",
|
||||||
|
"*.db", "*.db-wal", "*.db-shm",
|
||||||
|
"*.log",
|
||||||
|
".pytest_cache", ".pytest_cache/*",
|
||||||
|
".mypy_cache", ".mypy_cache/*",
|
||||||
|
".tox", ".tox/*",
|
||||||
|
"*.egg-info", "*.egg-info/*",
|
||||||
|
"decnet-state.json",
|
||||||
|
"master.log", "master.json",
|
||||||
|
"decnet.db*",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_excluded(rel: str, patterns: Iterable[str]) -> bool:
|
||||||
|
parts = pathlib.PurePosixPath(rel).parts
|
||||||
|
for pat in patterns:
|
||||||
|
if fnmatch.fnmatch(rel, pat):
|
||||||
|
return True
|
||||||
|
# Also match the pattern against every leading subpath — this is
|
||||||
|
# what catches nested `.venv/...` without forcing callers to spell
|
||||||
|
# out every `**/` glob.
|
||||||
|
for i in range(1, len(parts) + 1):
|
||||||
|
if fnmatch.fnmatch("/".join(parts[:i]), pat):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def tar_working_tree(
|
||||||
|
root: pathlib.Path,
|
||||||
|
extra_excludes: Optional[Iterable[str]] = None,
|
||||||
|
) -> bytes:
|
||||||
|
"""Return the gzipped tarball bytes of ``root``.
|
||||||
|
|
||||||
|
Entries are added with paths relative to ``root`` (no leading ``/``,
|
||||||
|
no ``..``). The updater rejects unsafe paths on the receiving side.
|
||||||
|
"""
|
||||||
|
patterns = list(DEFAULT_EXCLUDES) + list(extra_excludes or ())
|
||||||
|
buf = io.BytesIO()
|
||||||
|
|
||||||
|
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
|
||||||
|
for path in sorted(root.rglob("*")):
|
||||||
|
rel = path.relative_to(root).as_posix()
|
||||||
|
if _is_excluded(rel, patterns):
|
||||||
|
continue
|
||||||
|
if path.is_symlink():
|
||||||
|
# Symlinks inside a repo tree are rare and often break
|
||||||
|
# portability; skip them rather than ship dangling links.
|
||||||
|
continue
|
||||||
|
if path.is_dir():
|
||||||
|
continue
|
||||||
|
tar.add(path, arcname=rel, recursive=False)
|
||||||
|
|
||||||
|
return buf.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def detect_git_sha(root: pathlib.Path) -> str:
|
||||||
|
"""Best-effort ``HEAD`` sha. Returns ``""`` if not a git repo."""
|
||||||
|
head = root / ".git" / "HEAD"
|
||||||
|
if not head.is_file():
|
||||||
|
return ""
|
||||||
|
try:
|
||||||
|
ref = head.read_text().strip()
|
||||||
|
except OSError:
|
||||||
|
return ""
|
||||||
|
if ref.startswith("ref: "):
|
||||||
|
ref_path = root / ".git" / ref[5:]
|
||||||
|
if ref_path.is_file():
|
||||||
|
try:
|
||||||
|
return ref_path.read_text().strip()
|
||||||
|
except OSError:
|
||||||
|
return ""
|
||||||
|
return ""
|
||||||
|
return ref
|
||||||
124
decnet/swarm/updater_client.py
Normal file
124
decnet/swarm/updater_client.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
"""Master-side HTTP client for the worker's self-updater daemon.
|
||||||
|
|
||||||
|
Sibling of ``AgentClient``: same mTLS identity (same DECNET CA, same
|
||||||
|
master client cert) but targets the updater's port (default 8766) and
|
||||||
|
speaks the multipart upload protocol the updater's ``/update`` endpoint
|
||||||
|
expects.
|
||||||
|
|
||||||
|
Kept as its own module — not a subclass of ``AgentClient`` — because the
|
||||||
|
timeouts and failure semantics are genuinely different: pip install +
|
||||||
|
agent probe can take a minute on a slow VM, and ``/update-self`` drops
|
||||||
|
the connection on purpose (the updater re-execs itself mid-response).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ssl
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from decnet.logging import get_logger
|
||||||
|
from decnet.swarm.client import MasterIdentity, ensure_master_identity
|
||||||
|
|
||||||
|
log = get_logger("swarm.updater_client")
|
||||||
|
|
||||||
|
_TIMEOUT_UPDATE = httpx.Timeout(connect=10.0, read=180.0, write=120.0, pool=5.0)
|
||||||
|
_TIMEOUT_CONTROL = httpx.Timeout(connect=5.0, read=30.0, write=10.0, pool=5.0)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdaterClient:
|
||||||
|
"""Async client targeting a worker's ``decnet updater`` daemon."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: dict[str, Any] | None = None,
|
||||||
|
*,
|
||||||
|
address: Optional[str] = None,
|
||||||
|
updater_port: int = 8766,
|
||||||
|
identity: Optional[MasterIdentity] = None,
|
||||||
|
):
|
||||||
|
if host is not None:
|
||||||
|
self._address = host["address"]
|
||||||
|
self._host_name = host.get("name")
|
||||||
|
else:
|
||||||
|
if address is None:
|
||||||
|
raise ValueError("UpdaterClient requires host dict or address")
|
||||||
|
self._address = address
|
||||||
|
self._host_name = None
|
||||||
|
self._port = updater_port
|
||||||
|
self._identity = identity or ensure_master_identity()
|
||||||
|
self._client: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
|
def _build_client(self, timeout: httpx.Timeout) -> httpx.AsyncClient:
|
||||||
|
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
|
ctx.load_cert_chain(
|
||||||
|
str(self._identity.cert_path), str(self._identity.key_path),
|
||||||
|
)
|
||||||
|
ctx.load_verify_locations(cafile=str(self._identity.ca_cert_path))
|
||||||
|
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||||
|
ctx.check_hostname = False
|
||||||
|
return httpx.AsyncClient(
|
||||||
|
base_url=f"https://{self._address}:{self._port}",
|
||||||
|
verify=ctx,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "UpdaterClient":
|
||||||
|
self._client = self._build_client(_TIMEOUT_CONTROL)
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *exc: Any) -> None:
|
||||||
|
if self._client:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
def _require(self) -> httpx.AsyncClient:
|
||||||
|
if self._client is None:
|
||||||
|
raise RuntimeError("UpdaterClient used outside `async with` block")
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
# --------------------------------------------------------------- RPCs
|
||||||
|
|
||||||
|
async def health(self) -> dict[str, Any]:
|
||||||
|
r = await self._require().get("/health")
|
||||||
|
r.raise_for_status()
|
||||||
|
return r.json()
|
||||||
|
|
||||||
|
async def releases(self) -> dict[str, Any]:
|
||||||
|
r = await self._require().get("/releases")
|
||||||
|
r.raise_for_status()
|
||||||
|
return r.json()
|
||||||
|
|
||||||
|
async def update(self, tarball: bytes, sha: str = "") -> httpx.Response:
|
||||||
|
"""POST /update. Returns the Response so the caller can distinguish
|
||||||
|
200 / 409 / 500 — each means something different.
|
||||||
|
"""
|
||||||
|
self._require().timeout = _TIMEOUT_UPDATE
|
||||||
|
try:
|
||||||
|
r = await self._require().post(
|
||||||
|
"/update",
|
||||||
|
files={"tarball": ("tree.tgz", tarball, "application/gzip")},
|
||||||
|
data={"sha": sha},
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._require().timeout = _TIMEOUT_CONTROL
|
||||||
|
return r
|
||||||
|
|
||||||
|
async def update_self(self, tarball: bytes, sha: str = "") -> httpx.Response:
|
||||||
|
"""POST /update-self. The updater re-execs itself, so the connection
|
||||||
|
usually drops mid-response; that's not an error. Callers should then
|
||||||
|
poll /health until the new SHA appears.
|
||||||
|
"""
|
||||||
|
self._require().timeout = _TIMEOUT_UPDATE
|
||||||
|
try:
|
||||||
|
r = await self._require().post(
|
||||||
|
"/update-self",
|
||||||
|
files={"tarball": ("tree.tgz", tarball, "application/gzip")},
|
||||||
|
data={"sha": sha, "confirm_self": "true"},
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._require().timeout = _TIMEOUT_CONTROL
|
||||||
|
return r
|
||||||
|
|
||||||
|
async def rollback(self) -> httpx.Response:
|
||||||
|
return await self._require().post("/rollback")
|
||||||
10
decnet/updater/__init__.py
Normal file
10
decnet/updater/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""DECNET self-updater daemon.
|
||||||
|
|
||||||
|
Runs on each worker alongside ``decnet agent``. Receives working-tree
|
||||||
|
tarballs from the master and owns the agent's lifecycle: snapshot →
|
||||||
|
install → restart → probe → auto-rollback on failure.
|
||||||
|
|
||||||
|
Deliberately separate process, separate venv, separate mTLS cert so that
|
||||||
|
a broken ``decnet agent`` push can always be rolled back by the updater
|
||||||
|
that shipped it. See ``wiki/Remote-Updates.md``.
|
||||||
|
"""
|
||||||
139
decnet/updater/app.py
Normal file
139
decnet/updater/app.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""Updater FastAPI app — mTLS-protected endpoints for self-update.
|
||||||
|
|
||||||
|
Mirrors the shape of ``decnet/agent/app.py``: bare FastAPI, docs disabled,
|
||||||
|
handlers delegate to ``decnet.updater.executor``.
|
||||||
|
|
||||||
|
Mounted by uvicorn via ``decnet.updater.server`` with ``--ssl-cert-reqs 2``;
|
||||||
|
the CN on the peer cert tells us which endpoints are legal (``updater@*``
|
||||||
|
only — agent certs are rejected).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os as _os
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from decnet.logging import get_logger
|
||||||
|
from decnet.swarm import pki
|
||||||
|
from decnet.updater import executor as _exec
|
||||||
|
|
||||||
|
log = get_logger("updater.app")
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="DECNET Self-Updater",
|
||||||
|
version="0.1.0",
|
||||||
|
docs_url=None,
|
||||||
|
redoc_url=None,
|
||||||
|
openapi_url=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _Config:
|
||||||
|
install_dir: pathlib.Path = pathlib.Path(
|
||||||
|
_os.environ.get("DECNET_UPDATER_INSTALL_DIR") or str(_exec.DEFAULT_INSTALL_DIR)
|
||||||
|
)
|
||||||
|
updater_install_dir: pathlib.Path = pathlib.Path(
|
||||||
|
_os.environ.get("DECNET_UPDATER_UPDATER_DIR")
|
||||||
|
or str(_exec.DEFAULT_INSTALL_DIR / "updater")
|
||||||
|
)
|
||||||
|
agent_dir: pathlib.Path = pathlib.Path(
|
||||||
|
_os.environ.get("DECNET_UPDATER_AGENT_DIR") or str(pki.DEFAULT_AGENT_DIR)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def configure(
|
||||||
|
install_dir: pathlib.Path,
|
||||||
|
updater_install_dir: pathlib.Path,
|
||||||
|
agent_dir: pathlib.Path,
|
||||||
|
) -> None:
|
||||||
|
"""Inject paths from the server launcher; must be called before serving."""
|
||||||
|
_Config.install_dir = install_dir
|
||||||
|
_Config.updater_install_dir = updater_install_dir
|
||||||
|
_Config.agent_dir = agent_dir
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------- schemas
|
||||||
|
|
||||||
|
class RollbackResult(BaseModel):
|
||||||
|
status: str
|
||||||
|
release: dict
|
||||||
|
probe: str
|
||||||
|
|
||||||
|
|
||||||
|
class ReleasesResponse(BaseModel):
|
||||||
|
releases: list[dict]
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------- routes
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health() -> dict:
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"role": "updater",
|
||||||
|
"releases": [r.to_dict() for r in _exec.list_releases(_Config.install_dir)],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/releases")
|
||||||
|
async def releases() -> dict:
|
||||||
|
return {"releases": [r.to_dict() for r in _exec.list_releases(_Config.install_dir)]}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/update")
|
||||||
|
async def update(
|
||||||
|
tarball: UploadFile = File(..., description="tar.gz of the working tree"),
|
||||||
|
sha: str = Form("", description="git SHA of the tree for provenance"),
|
||||||
|
) -> dict:
|
||||||
|
body = await tarball.read()
|
||||||
|
try:
|
||||||
|
return _exec.run_update(
|
||||||
|
body, sha=sha or None,
|
||||||
|
install_dir=_Config.install_dir, agent_dir=_Config.agent_dir,
|
||||||
|
)
|
||||||
|
except _exec.UpdateError as exc:
|
||||||
|
status = 409 if exc.rolled_back else 500
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status,
|
||||||
|
detail={"error": str(exc), "stderr": exc.stderr, "rolled_back": exc.rolled_back},
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/update-self")
|
||||||
|
async def update_self(
|
||||||
|
tarball: UploadFile = File(...),
|
||||||
|
sha: str = Form(""),
|
||||||
|
confirm_self: str = Form("", description="Must be 'true' to proceed"),
|
||||||
|
) -> dict:
|
||||||
|
if confirm_self.lower() != "true":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="self-update requires confirm_self=true (no auto-rollback)",
|
||||||
|
)
|
||||||
|
body = await tarball.read()
|
||||||
|
try:
|
||||||
|
return _exec.run_update_self(
|
||||||
|
body, sha=sha or None,
|
||||||
|
updater_install_dir=_Config.updater_install_dir,
|
||||||
|
)
|
||||||
|
except _exec.UpdateError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail={"error": str(exc), "stderr": exc.stderr},
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/rollback")
|
||||||
|
async def rollback() -> dict:
|
||||||
|
try:
|
||||||
|
return _exec.run_rollback(
|
||||||
|
install_dir=_Config.install_dir, agent_dir=_Config.agent_dir,
|
||||||
|
)
|
||||||
|
except _exec.UpdateError as exc:
|
||||||
|
status = 404 if "no previous" in str(exc) else 500
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status,
|
||||||
|
detail={"error": str(exc), "stderr": exc.stderr},
|
||||||
|
) from exc
|
||||||
416
decnet/updater/executor.py
Normal file
416
decnet/updater/executor.py
Normal file
@@ -0,0 +1,416 @@
|
|||||||
|
"""Update/rollback orchestrator for the DECNET self-updater.
|
||||||
|
|
||||||
|
Directory layout owned by this module (root = ``install_dir``):
|
||||||
|
|
||||||
|
<install_dir>/
|
||||||
|
current -> releases/active (symlink; atomic swap == promotion)
|
||||||
|
releases/
|
||||||
|
active/ (working tree; has its own .venv)
|
||||||
|
prev/ (last good snapshot; restored on failure)
|
||||||
|
active.new/ (staging; only exists mid-update)
|
||||||
|
agent.pid (PID of the agent process we spawned)
|
||||||
|
|
||||||
|
Rollback semantics: if the agent doesn't come back healthy after an update,
|
||||||
|
we swap the symlink back to ``prev``, restart the agent, and return the
|
||||||
|
captured pip/agent stderr to the caller.
|
||||||
|
|
||||||
|
Seams for tests — every subprocess call goes through a module-level hook
|
||||||
|
(`_run_pip`, `_spawn_agent`, `_probe_agent`) so tests can monkeypatch them
|
||||||
|
without actually touching the filesystem's Python toolchain.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import shutil
|
||||||
|
import signal
|
||||||
|
import ssl
|
||||||
|
import subprocess # nosec B404
|
||||||
|
import sys
|
||||||
|
import tarfile
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from decnet.logging import get_logger
|
||||||
|
from decnet.swarm import pki
|
||||||
|
|
||||||
|
log = get_logger("updater.executor")
|
||||||
|
|
||||||
|
DEFAULT_INSTALL_DIR = pathlib.Path("/opt/decnet")
|
||||||
|
AGENT_PROBE_URL = "https://127.0.0.1:8765/health"
|
||||||
|
AGENT_PROBE_ATTEMPTS = 10
|
||||||
|
AGENT_PROBE_BACKOFF_S = 1.0
|
||||||
|
AGENT_RESTART_GRACE_S = 10.0
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------- errors
|
||||||
|
|
||||||
|
class UpdateError(RuntimeError):
|
||||||
|
"""Raised when an update fails but the install dir is consistent.
|
||||||
|
|
||||||
|
Carries the captured stderr so the master gets actionable output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message: str, *, stderr: str = "", rolled_back: bool = False):
|
||||||
|
super().__init__(message)
|
||||||
|
self.stderr = stderr
|
||||||
|
self.rolled_back = rolled_back
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------- types
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class Release:
|
||||||
|
slot: str
|
||||||
|
sha: Optional[str]
|
||||||
|
installed_at: Optional[datetime]
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"slot": self.slot,
|
||||||
|
"sha": self.sha,
|
||||||
|
"installed_at": self.installed_at.isoformat() if self.installed_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------- internals
|
||||||
|
|
||||||
|
def _releases_dir(install_dir: pathlib.Path) -> pathlib.Path:
|
||||||
|
return install_dir / "releases"
|
||||||
|
|
||||||
|
|
||||||
|
def _active_dir(install_dir: pathlib.Path) -> pathlib.Path:
|
||||||
|
return _releases_dir(install_dir) / "active"
|
||||||
|
|
||||||
|
|
||||||
|
def _prev_dir(install_dir: pathlib.Path) -> pathlib.Path:
|
||||||
|
return _releases_dir(install_dir) / "prev"
|
||||||
|
|
||||||
|
|
||||||
|
def _staging_dir(install_dir: pathlib.Path) -> pathlib.Path:
|
||||||
|
return _releases_dir(install_dir) / "active.new"
|
||||||
|
|
||||||
|
|
||||||
|
def _current_symlink(install_dir: pathlib.Path) -> pathlib.Path:
|
||||||
|
return install_dir / "current"
|
||||||
|
|
||||||
|
|
||||||
|
def _pid_file(install_dir: pathlib.Path) -> pathlib.Path:
|
||||||
|
return install_dir / "agent.pid"
|
||||||
|
|
||||||
|
|
||||||
|
def _manifest_file(release: pathlib.Path) -> pathlib.Path:
|
||||||
|
return release / ".decnet-release.json"
|
||||||
|
|
||||||
|
|
||||||
|
def _venv_python(release: pathlib.Path) -> pathlib.Path:
|
||||||
|
return release / ".venv" / "bin" / "python"
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------- public
|
||||||
|
|
||||||
|
def read_release(release: pathlib.Path) -> Release:
    """Read the release manifest sidecar; tolerate absence or corruption.

    Returns a ``Release`` with ``sha``/``installed_at`` set to None whenever
    the manifest is missing, unreadable, malformed JSON, not a JSON object,
    or carries a bad timestamp — a corrupt sidecar must never make slot
    listing raise.
    """
    slot = release.name
    mf = _manifest_file(release)
    if not mf.is_file():
        return Release(slot=slot, sha=None, installed_at=None)
    import json

    try:
        data = json.loads(mf.read_text())
    except (json.JSONDecodeError, OSError):
        return Release(slot=slot, sha=None, installed_at=None)
    if not isinstance(data, dict):
        # Manifest replaced by non-object JSON — treat as absent.
        return Release(slot=slot, sha=None, installed_at=None)
    ts = data.get("installed_at")
    try:
        installed = datetime.fromisoformat(ts) if ts else None
    except (TypeError, ValueError):
        # A bad timestamp should not invalidate the rest of the manifest.
        installed = None
    return Release(slot=slot, sha=data.get("sha"), installed_at=installed)
|
||||||
|
|
||||||
|
|
||||||
|
def list_releases(install_dir: pathlib.Path) -> list[Release]:
    """Manifest data for whichever of the active/prev slots currently exist."""
    slots = (_active_dir(install_dir), _prev_dir(install_dir))
    return [read_release(d) for d in slots if d.is_dir()]
|
||||||
|
|
||||||
|
|
||||||
|
def clean_stale_staging(install_dir: pathlib.Path) -> None:
    """Remove a half-extracted ``active.new`` left by a crashed update."""
    staging = _staging_dir(install_dir)
    if not staging.exists():
        return
    log.warning("removing stale staging dir %s", staging)
    # Best-effort: a failed delete here should not block the new update.
    shutil.rmtree(staging, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_tarball(tarball_bytes: bytes, dest: pathlib.Path) -> None:
    """Extract a gzipped tarball into ``dest`` (must not pre-exist).

    The tarball arrives over the wire and is untrusted: rejects absolute
    paths and ``..`` traversal in member names, symlink/hardlink members
    whose target escapes the extraction root, and device/FIFO members.

    Raises UpdateError on any unsafe member; OSError if ``dest`` exists.
    """
    import io

    def _unsafe(p: str) -> bool:
        # Absolute, or any ".." component after POSIX-splitting the path.
        return p.startswith("/") or ".." in pathlib.PurePosixPath(p).parts

    dest.mkdir(parents=True, exist_ok=False)
    with tarfile.open(fileobj=io.BytesIO(tarball_bytes), mode="r:gz") as tar:
        for member in tar.getmembers():
            name = member.name
            if _unsafe(name):
                raise UpdateError(f"unsafe path in tarball: {name!r}")
            if member.isdev():
                # Block/char devices and FIFOs have no place in a release.
                raise UpdateError(f"special file in tarball: {name!r}")
            if member.issym() or member.islnk():
                target = member.linkname
                if member.issym():
                    # Symlink targets resolve relative to the member's dir.
                    target = str(pathlib.PurePosixPath(name).parent / target)
                if _unsafe(target):
                    raise UpdateError(
                        f"unsafe link in tarball: {name!r} -> {member.linkname!r}"
                    )
        tar.extractall(dest)  # nosec B202 — members validated above
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------- seams
|
||||||
|
|
||||||
|
def _run_pip(release: pathlib.Path) -> subprocess.CompletedProcess:
    """Create a venv in ``release/.venv`` and pip install -e . into it.

    Monkeypatched in tests so the test suite never shells out.
    """
    venv_dir = release / ".venv"
    if not venv_dir.exists():
        venv_cmd = [sys.executable, "-m", "venv", str(venv_dir)]
        subprocess.run(venv_cmd, check=True, capture_output=True, text=True)  # nosec B603
    pip_cmd = [str(_venv_python(release)), "-m", "pip", "install", "-e", str(release)]
    # check=False: the caller inspects returncode and surfaces stderr itself.
    return subprocess.run(pip_cmd, check=False, capture_output=True, text=True)  # nosec B603
|
||||||
|
|
||||||
|
|
||||||
|
def _spawn_agent(install_dir: pathlib.Path) -> int:
    """Launch ``decnet agent --daemon`` using the current-symlinked venv.

    Returns the new PID. Monkeypatched in tests.
    """
    release = _current_symlink(install_dir).resolve()
    cmd = [str(_venv_python(release)), "-m", "decnet", "agent", "--daemon"]
    proc = subprocess.Popen(  # nosec B603
        cmd,
        start_new_session=True,  # detach: outlive the updater process
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    # Record the PID so _stop_agent can find it on the next update.
    _pid_file(install_dir).write_text(str(proc.pid))
    return proc.pid
|
||||||
|
|
||||||
|
|
||||||
|
def _stop_agent(install_dir: pathlib.Path, grace: float = AGENT_RESTART_GRACE_S) -> None:
    """SIGTERM the PID we spawned; SIGKILL if it doesn't exit in ``grace`` s."""
    pid_file = _pid_file(install_dir)
    if not pid_file.is_file():
        # Nothing recorded — either never spawned or already cleaned up.
        return
    try:
        pid = int(pid_file.read_text().strip())
    except (ValueError, OSError):
        # Corrupt or unreadable PID file: nothing we can safely signal.
        return
    try:
        os.kill(pid, signal.SIGTERM)
    except ProcessLookupError:
        # Process already gone.
        return
    # Poll for exit: os.kill(pid, 0) raises ProcessLookupError once the
    # process has disappeared, without sending a signal.
    deadline = time.monotonic() + grace
    while time.monotonic() < deadline:
        try:
            os.kill(pid, 0)
        except ProcessLookupError:
            return
        time.sleep(0.2)
    # Grace expired — force kill. The process may exit concurrently, so a
    # late ProcessLookupError is expected and ignored.
    try:
        os.kill(pid, signal.SIGKILL)
    except ProcessLookupError:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
def _probe_agent(
    agent_dir: pathlib.Path = pki.DEFAULT_AGENT_DIR,
    url: str = AGENT_PROBE_URL,
    attempts: int = AGENT_PROBE_ATTEMPTS,
    backoff_s: float = AGENT_PROBE_BACKOFF_S,
) -> tuple[bool, str]:
    """Local mTLS health probe against the agent. Returns (ok, detail).

    ``detail`` is the response body on success, otherwise the last failure
    (status+truncated body, or exception repr) seen across all attempts.
    """
    worker_key = agent_dir / "worker.key"
    worker_crt = agent_dir / "worker.crt"
    ca = agent_dir / "ca.crt"
    if not (worker_key.is_file() and worker_crt.is_file() and ca.is_file()):
        return False, f"no mTLS bundle at {agent_dir}"
    # Client-auth TLS: verify the agent against our CA, present worker cert.
    ctx = ssl.create_default_context(cafile=str(ca))
    ctx.load_cert_chain(certfile=str(worker_crt), keyfile=str(worker_key))
    # NOTE(review): hostname checking disabled — presumably because the probe
    # targets loopback where the cert CN won't match; confirm that is the intent.
    ctx.check_hostname = False

    last = ""
    for i in range(attempts):
        try:
            with httpx.Client(verify=ctx, timeout=3.0) as client:
                r = client.get(url)
                if r.status_code == 200:
                    return True, r.text
                last = f"status={r.status_code} body={r.text[:200]}"
        except Exception as exc:  # noqa: BLE001
            # Connection refused / TLS errors are expected while the agent boots.
            last = f"{type(exc).__name__}: {exc}"
        if i < attempts - 1:
            time.sleep(backoff_s)
    return False, last
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------------------------------- orchestrator
|
||||||
|
|
||||||
|
def _write_manifest(release: pathlib.Path, sha: Optional[str]) -> None:
    """Write the release sidecar recording the source SHA and UTC install time."""
    import json

    payload = {
        "sha": sha,
        "installed_at": datetime.now(timezone.utc).isoformat(),
    }
    _manifest_file(release).write_text(json.dumps(payload))
|
||||||
|
|
||||||
|
|
||||||
|
def _rotate(install_dir: pathlib.Path) -> None:
    """Rotate directories: prev→(deleted), active→prev, active.new→active.

    Caller must ensure ``active.new`` exists. ``active`` may or may not.
    Order matters: free the prev slot first, demote active, then promote
    staging — each step is a rename, so a crash leaves at most one slot
    mid-move rather than losing both releases.
    """
    active = _active_dir(install_dir)
    prev = _prev_dir(install_dir)
    staging = _staging_dir(install_dir)

    if prev.exists():
        shutil.rmtree(prev)
    if active.exists():
        active.rename(prev)
    staging.rename(active)
|
||||||
|
|
||||||
|
|
||||||
|
def _point_current_at(install_dir: pathlib.Path, target: pathlib.Path) -> None:
    """Atomic symlink flip via rename."""
    link = _current_symlink(install_dir)
    tmp = install_dir / ".current.tmp"
    # is_symlink() catches a dangling leftover link that exists() (which
    # follows the link) would report as absent.
    if tmp.exists() or tmp.is_symlink():
        tmp.unlink()
    tmp.symlink_to(target)
    # os.replace is atomic on POSIX: readers never observe a missing link.
    os.replace(tmp, link)
|
||||||
|
|
||||||
|
|
||||||
|
def run_update(
    tarball_bytes: bytes,
    sha: Optional[str],
    install_dir: pathlib.Path = DEFAULT_INSTALL_DIR,
    agent_dir: pathlib.Path = pki.DEFAULT_AGENT_DIR,
) -> dict[str, Any]:
    """Apply an update atomically. Rolls back on probe failure.

    Pipeline: extract → manifest → pip install → rotate slots → flip the
    ``current`` symlink → restart the agent → health probe. On a failed
    probe the slots are swapped back, the old agent restarted, and
    UpdateError raised with ``rolled_back`` reporting whether the restored
    release passed its own probe.

    Raises:
        UpdateError: pip failure on the new release (staging removed, slots
            untouched) or post-update probe failure (after rollback).
    """
    clean_stale_staging(install_dir)
    staging = _staging_dir(install_dir)

    extract_tarball(tarball_bytes, staging)
    _write_manifest(staging, sha)

    pip = _run_pip(staging)
    if pip.returncode != 0:
        # Nothing was rotated yet — discarding staging fully undoes the attempt.
        shutil.rmtree(staging, ignore_errors=True)
        raise UpdateError(
            "pip install failed on new release", stderr=pip.stderr or pip.stdout,
        )

    _rotate(install_dir)
    _point_current_at(install_dir, _active_dir(install_dir))

    _stop_agent(install_dir)
    _spawn_agent(install_dir)

    ok, detail = _probe_agent(agent_dir=agent_dir)
    if ok:
        return {
            "status": "updated",
            "release": read_release(_active_dir(install_dir)).to_dict(),
            "probe": detail,
        }

    # Rollback.
    log.warning("agent probe failed after update: %s — rolling back", detail)
    _stop_agent(install_dir)
    # Swap active <-> prev via a scratch dir (no atomic two-way rename).
    active = _active_dir(install_dir)
    prev = _prev_dir(install_dir)
    tmp = _releases_dir(install_dir) / ".swap"
    if tmp.exists():
        shutil.rmtree(tmp)
    active.rename(tmp)
    prev.rename(active)
    tmp.rename(prev)
    _point_current_at(install_dir, active)
    _spawn_agent(install_dir)
    # Probe the restored release so the caller knows whether rollback worked.
    ok2, detail2 = _probe_agent(agent_dir=agent_dir)
    raise UpdateError(
        "agent failed health probe after update; rolled back to previous release",
        stderr=f"forward-probe: {detail}\nrollback-probe: {detail2}",
        rolled_back=ok2,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def run_rollback(
    install_dir: pathlib.Path = DEFAULT_INSTALL_DIR,
    agent_dir: pathlib.Path = pki.DEFAULT_AGENT_DIR,
) -> dict[str, Any]:
    """Manually swap active with prev and restart the agent.

    Raises:
        UpdateError: no prev slot exists, or the restored agent fails its
            health probe.

    NOTE(review): the three-rename swap duplicates the rollback branch of
    ``run_update`` — a shared ``_swap_slots`` helper would keep them in sync.
    """
    active = _active_dir(install_dir)
    prev = _prev_dir(install_dir)
    if not prev.is_dir():
        raise UpdateError("no previous release to roll back to")

    _stop_agent(install_dir)
    # Swap active <-> prev via a scratch dir (no atomic two-way rename).
    tmp = _releases_dir(install_dir) / ".swap"
    if tmp.exists():
        shutil.rmtree(tmp)
    active.rename(tmp)
    prev.rename(active)
    tmp.rename(prev)
    _point_current_at(install_dir, active)
    _spawn_agent(install_dir)
    ok, detail = _probe_agent(agent_dir=agent_dir)
    if not ok:
        raise UpdateError("agent unhealthy after rollback", stderr=detail)
    return {
        "status": "rolled_back",
        "release": read_release(active).to_dict(),
        "probe": detail,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def run_update_self(
    tarball_bytes: bytes,
    sha: Optional[str],
    updater_install_dir: pathlib.Path,
    exec_cb: Optional[Callable[[list[str]], None]] = None,
) -> dict[str, Any]:
    """Replace the updater's own source tree, then re-exec this process.

    No auto-rollback. Caller must treat "connection dropped + /health
    returns new SHA within 30s" as success.

    Args:
        exec_cb: test seam — when provided it is called with the re-exec
            argv instead of actually calling ``os.execv``.

    Raises:
        UpdateError: pip install failed on the new updater release.
    """
    clean_stale_staging(updater_install_dir)
    staging = _staging_dir(updater_install_dir)
    extract_tarball(tarball_bytes, staging)
    _write_manifest(staging, sha)

    pip = _run_pip(staging)
    if pip.returncode != 0:
        # Slots untouched so far — removing staging fully undoes the attempt.
        shutil.rmtree(staging, ignore_errors=True)
        raise UpdateError(
            "pip install failed on new updater release",
            stderr=pip.stderr or pip.stdout,
        )

    _rotate(updater_install_dir)
    _point_current_at(updater_install_dir, _active_dir(updater_install_dir))

    # Re-exec with the new venv's interpreter, preserving our CLI flags.
    argv = [str(_venv_python(_active_dir(updater_install_dir))), "-m", "decnet", "updater"] + sys.argv[1:]
    if exec_cb is not None:
        exec_cb(argv)  # tests stub this — we don't actually re-exec
        return {"status": "self_update_queued", "argv": argv}
    # Returns nothing on success (replaces the process image).
    os.execv(argv[0], argv)  # nosec B606 - pragma: no cover
    return {"status": "self_update_queued"}  # pragma: no cover
|
||||||
86
decnet/updater/server.py
Normal file
86
decnet/updater/server.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""Self-updater uvicorn launcher.
|
||||||
|
|
||||||
|
Parallels ``decnet/agent/server.py`` but uses a distinct bundle directory
|
||||||
|
(``~/.decnet/updater``) with a cert whose CN is ``updater@<host>``. That
|
||||||
|
cert is signed by the same DECNET CA as the agent's, so the master's one
|
||||||
|
CA still gates both channels; the CN is how we tell them apart.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import signal
|
||||||
|
import subprocess # nosec B404
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from decnet.logging import get_logger
|
||||||
|
from decnet.swarm import pki
|
||||||
|
|
||||||
|
log = get_logger("updater.server")
|
||||||
|
|
||||||
|
DEFAULT_UPDATER_DIR = pathlib.Path(os.path.expanduser("~/.decnet/updater"))
|
||||||
|
|
||||||
|
|
||||||
|
def _load_bundle(updater_dir: pathlib.Path) -> bool:
|
||||||
|
return all(
|
||||||
|
(updater_dir / name).is_file()
|
||||||
|
for name in ("ca.crt", "updater.crt", "updater.key")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run(
    host: str,
    port: int,
    updater_dir: pathlib.Path = DEFAULT_UPDATER_DIR,
    install_dir: pathlib.Path = pathlib.Path("/opt/decnet"),
    agent_dir: pathlib.Path = pki.DEFAULT_AGENT_DIR,
) -> int:
    """Launch the updater FastAPI app under uvicorn with mandatory mTLS.

    Returns uvicorn's exit code; 2 when the cert bundle is absent. Blocks
    until the subprocess exits; Ctrl-C escalates SIGTERM → SIGKILL on the
    child's process group.
    """
    if not _load_bundle(updater_dir):
        print(
            f"[updater] No cert bundle at {updater_dir}. "
            f"Run `decnet swarm enroll --updater` from the master first.",
            file=sys.stderr,
        )
        return 2

    # Pass config into the app module via env so uvicorn subprocess picks it up.
    os.environ["DECNET_UPDATER_INSTALL_DIR"] = str(install_dir)
    # NOTE(review): this derives from install_dir, ignoring the updater_dir
    # parameter used for the TLS bundle below — looks like it should be
    # str(updater_dir); confirm against decnet.updater.app's expectations.
    os.environ["DECNET_UPDATER_UPDATER_DIR"] = str(install_dir / "updater")
    os.environ["DECNET_UPDATER_AGENT_DIR"] = str(agent_dir)

    keyfile = updater_dir / "updater.key"
    certfile = updater_dir / "updater.crt"
    cafile = updater_dir / "ca.crt"

    cmd = [
        sys.executable,
        "-m",
        "uvicorn",
        "decnet.updater.app:app",
        "--host",
        host,
        "--port",
        str(port),
        "--ssl-keyfile",
        str(keyfile),
        "--ssl-certfile",
        str(certfile),
        "--ssl-ca-certs",
        str(cafile),
        # 2 == ssl.CERT_REQUIRED: clients MUST present a CA-signed cert.
        "--ssl-cert-reqs",
        "2",
    ]
    log.info("updater starting host=%s port=%d bundle=%s", host, port, updater_dir)
    # start_new_session makes the child its own process-group leader so the
    # killpg calls below reach uvicorn's workers too.
    proc = subprocess.Popen(cmd, start_new_session=True)  # nosec B603
    try:
        return proc.wait()
    except KeyboardInterrupt:
        try:
            os.killpg(proc.pid, signal.SIGTERM)
            try:
                return proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                os.killpg(proc.pid, signal.SIGKILL)
                return proc.wait()
        except ProcessLookupError:
            # Group already gone — treat as a clean shutdown.
            return 0
|
||||||
@@ -118,6 +118,10 @@ class SwarmHost(SQLModel, table=True):
|
|||||||
# ISO-8601 string of the last successful agent /health probe
|
# ISO-8601 string of the last successful agent /health probe
|
||||||
last_heartbeat: Optional[datetime] = Field(default=None)
|
last_heartbeat: Optional[datetime] = Field(default=None)
|
||||||
client_cert_fingerprint: str # SHA-256 hex of worker's issued client cert
|
client_cert_fingerprint: str # SHA-256 hex of worker's issued client cert
|
||||||
|
# SHA-256 hex of the updater-identity cert, if the host was enrolled
|
||||||
|
# with ``--updater`` / ``issue_updater_bundle``. ``None`` for hosts
|
||||||
|
# that only have an agent identity.
|
||||||
|
updater_cert_fingerprint: Optional[str] = Field(default=None)
|
||||||
# Directory on the master where the per-worker cert bundle lives
|
# Directory on the master where the per-worker cert bundle lives
|
||||||
cert_bundle_path: str
|
cert_bundle_path: str
|
||||||
enrolled_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
enrolled_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
@@ -281,6 +285,17 @@ class SwarmEnrollRequest(BaseModel):
|
|||||||
description="Extra SANs (IPs / hostnames) to embed in the worker cert",
|
description="Extra SANs (IPs / hostnames) to embed in the worker cert",
|
||||||
)
|
)
|
||||||
notes: Optional[str] = None
|
notes: Optional[str] = None
|
||||||
|
issue_updater_bundle: bool = PydanticField(
|
||||||
|
default=False,
|
||||||
|
description="If true, also issue an updater cert (CN=updater@<name>) for the remote self-updater",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SwarmUpdaterBundle(BaseModel):
|
||||||
|
"""Subset of SwarmEnrolledBundle for the updater identity."""
|
||||||
|
fingerprint: str
|
||||||
|
updater_cert_pem: str
|
||||||
|
updater_key_pem: str
|
||||||
|
|
||||||
|
|
||||||
class SwarmEnrolledBundle(BaseModel):
|
class SwarmEnrolledBundle(BaseModel):
|
||||||
@@ -293,6 +308,7 @@ class SwarmEnrolledBundle(BaseModel):
|
|||||||
ca_cert_pem: str
|
ca_cert_pem: str
|
||||||
worker_cert_pem: str
|
worker_cert_pem: str
|
||||||
worker_key_pem: str
|
worker_key_pem: str
|
||||||
|
updater: Optional[SwarmUpdaterBundle] = None
|
||||||
|
|
||||||
|
|
||||||
class SwarmHostView(BaseModel):
|
class SwarmHostView(BaseModel):
|
||||||
@@ -303,6 +319,7 @@ class SwarmHostView(BaseModel):
|
|||||||
status: str
|
status: str
|
||||||
last_heartbeat: Optional[datetime] = None
|
last_heartbeat: Optional[datetime] = None
|
||||||
client_cert_fingerprint: str
|
client_cert_fingerprint: str
|
||||||
|
updater_cert_fingerprint: Optional[str] = None
|
||||||
enrolled_at: datetime
|
enrolled_at: datetime
|
||||||
notes: Optional[str] = None
|
notes: Optional[str] = None
|
||||||
|
|
||||||
|
|||||||
@@ -12,13 +12,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
|
||||||
from decnet.swarm import pki
|
from decnet.swarm import pki
|
||||||
from decnet.web.db.repository import BaseRepository
|
from decnet.web.db.repository import BaseRepository
|
||||||
from decnet.web.dependencies import get_repo
|
from decnet.web.dependencies import get_repo
|
||||||
from decnet.web.db.models import SwarmEnrolledBundle, SwarmEnrollRequest
|
from decnet.web.db.models import SwarmEnrolledBundle, SwarmEnrollRequest, SwarmUpdaterBundle
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -46,6 +47,26 @@ async def api_enroll_host(
|
|||||||
bundle_dir = pki.DEFAULT_CA_DIR / "workers" / req.name
|
bundle_dir = pki.DEFAULT_CA_DIR / "workers" / req.name
|
||||||
pki.write_worker_bundle(issued, bundle_dir)
|
pki.write_worker_bundle(issued, bundle_dir)
|
||||||
|
|
||||||
|
updater_view: Optional[SwarmUpdaterBundle] = None
|
||||||
|
updater_fp: Optional[str] = None
|
||||||
|
if req.issue_updater_bundle:
|
||||||
|
updater_cn = f"updater@{req.name}"
|
||||||
|
updater_sans = list({*sans, updater_cn, "127.0.0.1"})
|
||||||
|
updater_issued = pki.issue_worker_cert(ca, updater_cn, updater_sans)
|
||||||
|
# Persist alongside the worker bundle for replay.
|
||||||
|
updater_dir = bundle_dir / "updater"
|
||||||
|
updater_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
(updater_dir / "updater.crt").write_bytes(updater_issued.cert_pem)
|
||||||
|
(updater_dir / "updater.key").write_bytes(updater_issued.key_pem)
|
||||||
|
import os as _os
|
||||||
|
_os.chmod(updater_dir / "updater.key", 0o600)
|
||||||
|
updater_fp = updater_issued.fingerprint_sha256
|
||||||
|
updater_view = SwarmUpdaterBundle(
|
||||||
|
fingerprint=updater_fp,
|
||||||
|
updater_cert_pem=updater_issued.cert_pem.decode(),
|
||||||
|
updater_key_pem=updater_issued.key_pem.decode(),
|
||||||
|
)
|
||||||
|
|
||||||
host_uuid = str(_uuid.uuid4())
|
host_uuid = str(_uuid.uuid4())
|
||||||
await repo.add_swarm_host(
|
await repo.add_swarm_host(
|
||||||
{
|
{
|
||||||
@@ -55,6 +76,7 @@ async def api_enroll_host(
|
|||||||
"agent_port": req.agent_port,
|
"agent_port": req.agent_port,
|
||||||
"status": "enrolled",
|
"status": "enrolled",
|
||||||
"client_cert_fingerprint": issued.fingerprint_sha256,
|
"client_cert_fingerprint": issued.fingerprint_sha256,
|
||||||
|
"updater_cert_fingerprint": updater_fp,
|
||||||
"cert_bundle_path": str(bundle_dir),
|
"cert_bundle_path": str(bundle_dir),
|
||||||
"enrolled_at": datetime.now(timezone.utc),
|
"enrolled_at": datetime.now(timezone.utc),
|
||||||
"notes": req.notes,
|
"notes": req.notes,
|
||||||
@@ -69,4 +91,5 @@ async def api_enroll_host(
|
|||||||
ca_cert_pem=issued.ca_cert_pem.decode(),
|
ca_cert_pem=issued.ca_cert_pem.decode(),
|
||||||
worker_cert_pem=issued.cert_pem.decode(),
|
worker_cert_pem=issued.cert_pem.decode(),
|
||||||
worker_key_pem=issued.key_pem.decode(),
|
worker_key_pem=issued.key_pem.decode(),
|
||||||
|
updater=updater_view,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ dependencies = [
|
|||||||
"sqlmodel>=0.0.16",
|
"sqlmodel>=0.0.16",
|
||||||
"scapy>=2.6.1",
|
"scapy>=2.6.1",
|
||||||
"orjson>=3.10",
|
"orjson>=3.10",
|
||||||
"cryptography>=46.0.7"
|
"cryptography>=46.0.7",
|
||||||
|
"python-multipart>=0.0.20"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
192
tests/swarm/test_cli_swarm_update.py
Normal file
192
tests/swarm/test_cli_swarm_update.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
"""CLI `decnet swarm update` — target resolution, tarring, push aggregation.
|
||||||
|
|
||||||
|
The UpdaterClient is stubbed: we are testing the CLI's orchestration, not
|
||||||
|
the wire protocol (that has test_updater_app.py and UpdaterClient round-
|
||||||
|
trips live under test_swarm_api.py integration).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import pathlib
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from decnet import cli as cli_mod
|
||||||
|
from decnet.cli import app
|
||||||
|
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResp:
|
||||||
|
def __init__(self, payload: Any, status: int = 200):
|
||||||
|
self._payload = payload
|
||||||
|
self.status_code = status
|
||||||
|
self.text = json.dumps(payload) if not isinstance(payload, str) else payload
|
||||||
|
self.content = self.text.encode()
|
||||||
|
|
||||||
|
def json(self) -> Any:
|
||||||
|
return self._payload
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def http_stub(monkeypatch: pytest.MonkeyPatch) -> dict:
    """Scripted HTTP layer: only GET …/swarm/hosts is answered.

    Returns mutable ``state``; tests set state["hosts"] to control the roster.
    Any other request is a test bug and fails loudly.
    """
    state: dict = {"hosts": []}

    def _fake(method, url, *, json_body=None, timeout=30.0):
        if method == "GET" and url.endswith("/swarm/hosts"):
            return _FakeResp(state["hosts"])
        raise AssertionError(f"Unscripted HTTP call: {method} {url}")

    monkeypatch.setattr(cli_mod, "_http_request", _fake)
    return state
|
||||||
|
|
||||||
|
|
||||||
|
class _StubUpdaterClient:
    """Mirrors UpdaterClient's async-context-manager surface."""

    # Class-level registries shared across a test; cleared by the
    # stub_updater fixture before each test.
    instances: list["_StubUpdaterClient"] = []
    behavior: dict[str, Any] = {}

    def __init__(self, host, *, updater_port: int = 8766, **_: Any):
        self.host = host
        self.port = updater_port
        # Method names in call order, for asserting sequencing.
        self.calls: list[str] = []
        _StubUpdaterClient.instances.append(self)

    async def __aenter__(self) -> "_StubUpdaterClient":
        return self

    async def __aexit__(self, *exc: Any) -> None:
        return None

    async def update(self, tarball: bytes, sha: str = "") -> _FakeResp:
        # Per-host scripted response via ``behavior``; default is success.
        self.calls.append("update")
        return _StubUpdaterClient.behavior.get(
            self.host.get("name"),
            _FakeResp({"status": "updated", "release": {"sha": sha}}, 200),
        )

    async def update_self(self, tarball: bytes, sha: str = "") -> _FakeResp:
        self.calls.append("update_self")
        return _FakeResp({"status": "self_update_queued"}, 200)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def stub_updater(monkeypatch: pytest.MonkeyPatch):
    """Swap UpdaterClient for _StubUpdaterClient and reset its registries."""
    _StubUpdaterClient.instances.clear()
    _StubUpdaterClient.behavior.clear()
    monkeypatch.setattr("decnet.swarm.updater_client.UpdaterClient", _StubUpdaterClient)
    # Also patch the module-level import inside cli.py's swarm_update closure.
    import decnet.cli  # noqa: F401
    return _StubUpdaterClient
|
||||||
|
|
||||||
|
|
||||||
|
def _mk_source_tree(tmp_path: pathlib.Path) -> pathlib.Path:
|
||||||
|
root = tmp_path / "src"
|
||||||
|
root.mkdir()
|
||||||
|
(root / "decnet").mkdir()
|
||||||
|
(root / "decnet" / "a.py").write_text("x = 1")
|
||||||
|
return root
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------- arg validation
|
||||||
|
|
||||||
|
def test_update_requires_host_or_all(http_stub) -> None:
    """Neither --host nor --all: usage error (exit 2)."""
    r = runner.invoke(app, ["swarm", "update"])
    assert r.exit_code == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_host_and_all_are_mutex(http_stub) -> None:
    """--host and --all together: usage error (exit 2)."""
    r = runner.invoke(app, ["swarm", "update", "--host", "w1", "--all"])
    assert r.exit_code == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_unknown_host_exits_1(http_stub) -> None:
    """A --host name not in the enrolled roster is a runtime error (exit 1)."""
    http_stub["hosts"] = [{"uuid": "u1", "name": "other", "address": "10.0.0.1", "status": "active"}]
    r = runner.invoke(app, ["swarm", "update", "--host", "nope"])
    assert r.exit_code == 1
    assert "No enrolled worker" in r.output
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------- happy paths
|
||||||
|
|
||||||
|
def test_update_single_host(http_stub, stub_updater, tmp_path: pathlib.Path) -> None:
    """--host targets exactly the named worker."""
    http_stub["hosts"] = [
        {"uuid": "u1", "name": "w1", "address": "10.0.0.1", "status": "active"},
        {"uuid": "u2", "name": "w2", "address": "10.0.0.2", "status": "active"},
    ]
    root = _mk_source_tree(tmp_path)
    r = runner.invoke(app, ["swarm", "update", "--host", "w1", "--root", str(root)])
    assert r.exit_code == 0, r.output
    assert "w1" in r.output
    # Only w1 got a client; w2 is untouched.
    names = [c.host["name"] for c in stub_updater.instances]
    assert names == ["w1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_all_skips_decommissioned(http_stub, stub_updater, tmp_path: pathlib.Path) -> None:
    """--all pushes to active/enrolled workers but skips decommissioned ones."""
    http_stub["hosts"] = [
        {"uuid": "u1", "name": "w1", "address": "10.0.0.1", "status": "active"},
        {"uuid": "u2", "name": "w2", "address": "10.0.0.2", "status": "decommissioned"},
        {"uuid": "u3", "name": "w3", "address": "10.0.0.3", "status": "enrolled"},
    ]
    root = _mk_source_tree(tmp_path)
    r = runner.invoke(app, ["swarm", "update", "--all", "--root", str(root)])
    assert r.exit_code == 0, r.output
    hit = sorted(c.host["name"] for c in stub_updater.instances)
    assert hit == ["w1", "w3"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_include_self_calls_both(
    http_stub, stub_updater, tmp_path: pathlib.Path,
) -> None:
    """--include-self triggers update_self after a successful agent update."""
    http_stub["hosts"] = [{"uuid": "u1", "name": "w1", "address": "10.0.0.1", "status": "active"}]
    root = _mk_source_tree(tmp_path)
    r = runner.invoke(app, ["swarm", "update", "--all", "--root", str(root), "--include-self"])
    assert r.exit_code == 0
    assert stub_updater.instances[0].calls == ["update", "update_self"]
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------- failure modes
|
||||||
|
|
||||||
|
def test_update_rollback_status_409_flags_failure(
    http_stub, stub_updater, tmp_path: pathlib.Path,
) -> None:
    """A 409 rolled_back response surfaces as exit 1 with a rollback notice."""
    http_stub["hosts"] = [{"uuid": "u1", "name": "w1", "address": "10.0.0.1", "status": "active"}]
    _StubUpdaterClient.behavior["w1"] = _FakeResp(
        {"detail": {"error": "probe failed", "rolled_back": True}},
        status=409,
    )
    root = _mk_source_tree(tmp_path)
    r = runner.invoke(app, ["swarm", "update", "--all", "--root", str(root)])
    assert r.exit_code == 1
    assert "rolled-back" in r.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_include_self_skipped_when_agent_update_failed(
    http_stub, stub_updater, tmp_path: pathlib.Path,
) -> None:
    """A failed agent update short-circuits the self-update step."""
    http_stub["hosts"] = [{"uuid": "u1", "name": "w1", "address": "10.0.0.1", "status": "active"}]
    _StubUpdaterClient.behavior["w1"] = _FakeResp(
        {"detail": {"error": "pip failed"}}, status=500,
    )
    root = _mk_source_tree(tmp_path)
    r = runner.invoke(app, ["swarm", "update", "--all", "--root", str(root), "--include-self"])
    assert r.exit_code == 1
    # update_self must NOT have been called — agent update failed.
    assert stub_updater.instances[0].calls == ["update"]
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------- dry run
|
||||||
|
|
||||||
|
def test_update_dry_run_does_not_call_updater(
    http_stub, stub_updater, tmp_path: pathlib.Path,
) -> None:
    """--dry-run reports the plan without instantiating any updater client."""
    http_stub["hosts"] = [{"uuid": "u1", "name": "w1", "address": "10.0.0.1", "status": "active"}]
    root = _mk_source_tree(tmp_path)
    r = runner.invoke(app, ["swarm", "update", "--all", "--root", str(root), "--dry-run"])
    assert r.exit_code == 0
    assert stub_updater.instances == []
    assert "dry-run" in r.output.lower()
|
||||||
@@ -78,6 +78,36 @@ def test_enroll_creates_host_and_returns_bundle(client: TestClient) -> None:
|
|||||||
assert len(body["fingerprint"]) == 64 # sha256 hex
|
assert len(body["fingerprint"]) == 64 # sha256 hex
|
||||||
|
|
||||||
|
|
||||||
|
def test_enroll_with_updater_issues_second_cert(client: TestClient, ca_dir) -> None:
    """issue_updater_bundle=True returns a distinct updater cert, persists it
    on the master, and records its fingerprint on the host row."""
    resp = client.post(
        "/swarm/enroll",
        json={"name": "worker-upd", "address": "10.0.0.99", "agent_port": 8765,
              "issue_updater_bundle": True},
    )
    assert resp.status_code == 201, resp.text
    body = resp.json()
    assert body["updater"] is not None
    # Distinct identity: updater cert != worker cert.
    assert body["updater"]["fingerprint"] != body["fingerprint"]
    assert "-----BEGIN CERTIFICATE-----" in body["updater"]["updater_cert_pem"]
    assert "-----BEGIN PRIVATE KEY-----" in body["updater"]["updater_key_pem"]
    # Cert bundle persisted on master.
    upd_bundle = ca_dir / "workers" / "worker-upd" / "updater"
    assert (upd_bundle / "updater.crt").is_file()
    assert (upd_bundle / "updater.key").is_file()
    # DB row carries the updater fingerprint.
    row = client.get(f"/swarm/hosts/{body['host_uuid']}").json()
    assert row.get("updater_cert_fingerprint") == body["updater"]["fingerprint"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_enroll_without_updater_omits_bundle(client: TestClient) -> None:
    """Plain enrollment (no issue_updater_bundle flag) returns updater=None."""
    enroll = client.post(
        "/swarm/enroll",
        json={"name": "worker-no-upd", "address": "10.0.0.98", "agent_port": 8765},
    )
    assert enroll.status_code == 201
    assert enroll.json()["updater"] is None
|
||||||
|
|
||||||
|
|
||||||
def test_enroll_rejects_duplicate_name(client: TestClient) -> None:
|
def test_enroll_rejects_duplicate_name(client: TestClient) -> None:
|
||||||
payload = {"name": "worker-dup", "address": "10.0.0.6", "agent_port": 8765}
|
payload = {"name": "worker-dup", "address": "10.0.0.6", "agent_port": 8765}
|
||||||
assert client.post("/swarm/enroll", json=payload).status_code == 201
|
assert client.post("/swarm/enroll", json=payload).status_code == 201
|
||||||
|
|||||||
75
tests/swarm/test_tar_tree.py
Normal file
75
tests/swarm/test_tar_tree.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""tar_working_tree: exclude filter, tarball validity, git SHA detection."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import pathlib
|
||||||
|
import tarfile
|
||||||
|
|
||||||
|
from decnet.swarm.tar_tree import detect_git_sha, tar_working_tree
|
||||||
|
|
||||||
|
|
||||||
|
def _tree_names(data: bytes) -> set[str]:
|
||||||
|
with tarfile.open(fileobj=io.BytesIO(data), mode="r:gz") as tar:
|
||||||
|
return {m.name for m in tar.getmembers()}
|
||||||
|
|
||||||
|
|
||||||
|
def test_tar_excludes_default_patterns(tmp_path: pathlib.Path) -> None:
    """Default excludes drop venv, VCS, caches, wiki checkout, db and log files."""
    src = tmp_path
    (src / "decnet").mkdir()
    (src / "decnet" / "keep.py").write_text("x = 1")
    (src / ".venv").mkdir()
    (src / ".venv" / "pyvenv.cfg").write_text("junk")
    (src / ".git").mkdir()
    (src / ".git" / "HEAD").write_text("ref: refs/heads/main\n")
    (src / "decnet" / "__pycache__").mkdir()
    (src / "decnet" / "__pycache__" / "keep.cpython-311.pyc").write_text("bytecode")
    (src / "wiki-checkout").mkdir()
    (src / "wiki-checkout" / "Home.md").write_text("# wiki")
    (src / "run.db").write_text("sqlite")
    (src / "master.log").write_text("log")

    names = _tree_names(tar_working_tree(src))

    assert "decnet/keep.py" in names
    # None of the excluded directory names may appear anywhere in a member path.
    for banned in (".venv", ".git", "__pycache__", "wiki-checkout"):
        assert all(banned not in name for name in names)
    assert "run.db" not in names
    assert "master.log" not in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_tar_accepts_extra_excludes(tmp_path: pathlib.Path) -> None:
    """Caller-supplied extra_excludes are honoured on top of the defaults."""
    (tmp_path / "a.py").write_text("x")
    (tmp_path / "secret.env").write_text("TOKEN=abc")
    names = _tree_names(tar_working_tree(tmp_path, extra_excludes=["secret.env"]))
    assert "a.py" in names
    assert "secret.env" not in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_tar_skips_symlinks(tmp_path: pathlib.Path) -> None:
    """Symlinked entries are excluded from the tarball; regular files ship.

    Fix: the original silently ``return``-ed when the platform cannot create
    symlinks, which made the test count as *passed* without asserting anything.
    Reporting it as an explicit skip keeps the suite's results honest.
    """
    import pytest  # local import: this module otherwise has no direct pytest use

    (tmp_path / "real.txt").write_text("hi")
    try:
        (tmp_path / "link.txt").symlink_to(tmp_path / "real.txt")
    except (OSError, NotImplementedError):
        pytest.skip("platform does not support symlinks")
    names = _tree_names(tar_working_tree(tmp_path))
    assert "real.txt" in names
    assert "link.txt" not in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_git_sha_from_ref(tmp_path: pathlib.Path) -> None:
    """A HEAD pointing at a ref resolves to the sha stored in that ref file."""
    heads = tmp_path / ".git" / "refs" / "heads"
    heads.mkdir(parents=True)
    (heads / "main").write_text("deadbeef" * 5 + "\n")
    (tmp_path / ".git" / "HEAD").write_text("ref: refs/heads/main\n")
    assert detect_git_sha(tmp_path).startswith("deadbeef")
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_git_sha_detached(tmp_path: pathlib.Path) -> None:
    """A detached HEAD file contains the sha directly (no ref indirection)."""
    git_dir = tmp_path / ".git"
    git_dir.mkdir()
    (git_dir / "HEAD").write_text("f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0\n")
    assert detect_git_sha(tmp_path).startswith("f0f0")
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_git_sha_none_when_not_repo(tmp_path: pathlib.Path) -> None:
    """Outside a git repository the detected sha is the empty string."""
    assert detect_git_sha(tmp_path) == ""
|
||||||
0
tests/updater/__init__.py
Normal file
0
tests/updater/__init__.py
Normal file
138
tests/updater/test_updater_app.py
Normal file
138
tests/updater/test_updater_app.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""HTTP contract for the updater app.
|
||||||
|
|
||||||
|
Executor functions are monkeypatched — we're testing wire format, not
|
||||||
|
the rotation logic (that has test_updater_executor.py).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import pathlib
|
||||||
|
import tarfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from decnet.updater import app as app_mod
|
||||||
|
from decnet.updater import executor as ex
|
||||||
|
|
||||||
|
|
||||||
|
def _tarball(files: dict[str, str] | None = None) -> bytes:
|
||||||
|
buf = io.BytesIO()
|
||||||
|
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
|
||||||
|
for name, content in (files or {"a": "b"}).items():
|
||||||
|
data = content.encode()
|
||||||
|
info = tarfile.TarInfo(name=name)
|
||||||
|
info.size = len(data)
|
||||||
|
tar.addfile(info, io.BytesIO(data))
|
||||||
|
return buf.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def client(tmp_path: pathlib.Path) -> TestClient:
    """TestClient against the updater app with all paths rooted in tmp_path."""
    install_root = tmp_path / "install"
    app_mod.configure(
        install_dir=install_root,
        updater_install_dir=install_root / "updater",
        agent_dir=tmp_path / "agent",
    )
    # The app expects the releases directory to already exist.
    (install_root / "releases").mkdir(parents=True)
    return TestClient(app_mod.app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_health_returns_role_and_releases(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None:
    """/health reports ok status, the updater role, and the release list."""
    monkeypatch.setattr(ex, "list_releases", lambda d: [])
    response = client.get("/health")
    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "ok"
    assert payload["role"] == "updater"
    assert payload["releases"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_happy_path(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None:
    """POST /update forwards the sha and relays the executor's release info."""
    def _fake_update(data, sha, install_dir, agent_dir):
        return {"status": "updated", "release": {"slot": "active", "sha": sha}, "probe": "ok"}

    monkeypatch.setattr(ex, "run_update", _fake_update)
    response = client.post(
        "/update",
        files={"tarball": ("tree.tgz", _tarball(), "application/gzip")},
        data={"sha": "ABC123"},
    )
    assert response.status_code == 200, response.text
    assert response.json()["release"]["sha"] == "ABC123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_rollback_returns_409(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None:
    """A rolled-back update maps to 409 and carries stderr in the error detail."""
    def _fail(*args, **kwargs):
        raise ex.UpdateError("probe failed; rolled back", stderr="connection refused", rolled_back=True)

    monkeypatch.setattr(ex, "run_update", _fail)

    response = client.post(
        "/update",
        files={"tarball": ("t.tgz", _tarball(), "application/gzip")},
        data={"sha": ""},
    )
    assert response.status_code == 409, response.text
    detail = response.json()["detail"]
    assert detail["rolled_back"] is True
    assert "connection refused" in detail["stderr"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_hard_failure_returns_500(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None:
    """An UpdateError without rollback maps to HTTP 500."""
    def _fail(*args, **kwargs):
        raise ex.UpdateError("pip install failed", stderr="resolver error")

    monkeypatch.setattr(ex, "run_update", _fail)

    response = client.post("/update", files={"tarball": ("t.tgz", _tarball(), "application/gzip")})
    assert response.status_code == 500
    assert response.json()["detail"]["rolled_back"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_self_requires_confirm(client: TestClient) -> None:
    """/update-self refuses to run unless confirm_self is supplied."""
    response = client.post("/update-self", files={"tarball": ("t.tgz", _tarball(), "application/gzip")})
    assert response.status_code == 400
    assert "confirm_self" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_self_happy_path(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None:
    """With confirm_self=true, /update-self queues the self-update."""
    def _fake_self_update(data, sha, updater_install_dir):
        return {"status": "self_update_queued", "argv": ["python", "-m", "decnet", "updater"]}

    monkeypatch.setattr(ex, "run_update_self", _fake_self_update)
    response = client.post(
        "/update-self",
        files={"tarball": ("t.tgz", _tarball(), "application/gzip")},
        data={"sha": "S", "confirm_self": "true"},
    )
    assert response.status_code == 200
    assert response.json()["status"] == "self_update_queued"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rollback_happy(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None:
    """POST /rollback relays the executor's rolled_back result."""
    def _fake_rollback(install_dir, agent_dir):
        return {"status": "rolled_back", "release": {"slot": "active", "sha": "O"}, "probe": "ok"}

    monkeypatch.setattr(ex, "run_rollback", _fake_rollback)
    response = client.post("/rollback")
    assert response.status_code == 200
    assert response.json()["status"] == "rolled_back"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rollback_missing_prev_returns_404(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None:
    """Rolling back when no previous release exists is a 404."""
    def _fail(**_kwargs):
        raise ex.UpdateError("no previous release to roll back to")

    monkeypatch.setattr(ex, "run_rollback", _fail)
    response = client.post("/rollback")
    assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_releases_lists_slots(client: TestClient, monkeypatch: pytest.MonkeyPatch) -> None:
    """/releases returns the executor's release slots in order."""
    stubbed = [
        ex.Release(slot="active", sha="A", installed_at=None),
        ex.Release(slot="prev", sha="B", installed_at=None),
    ]
    monkeypatch.setattr(ex, "list_releases", lambda d: stubbed)
    response = client.get("/releases")
    assert response.status_code == 200
    assert [rel["slot"] for rel in response.json()["releases"]] == ["active", "prev"]
|
||||||
295
tests/updater/test_updater_executor.py
Normal file
295
tests/updater/test_updater_executor.py
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
"""Updater executor: directory rotation, probe-driven rollback, safety checks.
|
||||||
|
|
||||||
|
All three real seams (`_run_pip`, `_spawn_agent`, `_stop_agent`,
|
||||||
|
`_probe_agent`) are monkeypatched so these tests never shell out or
|
||||||
|
touch a real Python venv. The rotation/symlink/extract logic is exercised
|
||||||
|
against a ``tmp_path`` install dir.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import pathlib
|
||||||
|
import subprocess
|
||||||
|
import tarfile
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from decnet.updater import executor as ex
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ helpers
|
||||||
|
|
||||||
|
def _make_tarball(files: dict[str, str]) -> bytes:
|
||||||
|
buf = io.BytesIO()
|
||||||
|
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
|
||||||
|
for name, content in files.items():
|
||||||
|
data = content.encode()
|
||||||
|
info = tarfile.TarInfo(name=name)
|
||||||
|
info.size = len(data)
|
||||||
|
tar.addfile(info, io.BytesIO(data))
|
||||||
|
return buf.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
class _PipOK:
    """Stand-in for a successful pip subprocess result (returncode 0)."""
    returncode = 0
    stdout = ""
    stderr = ""
|
||||||
|
|
||||||
|
|
||||||
|
class _PipFail:
    """Stand-in for a failed pip subprocess result (nonzero returncode, resolver error on stderr)."""
    returncode = 1
    stdout = ""
    stderr = "resolver error: Could not find a version that satisfies ..."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def install_dir(tmp_path: pathlib.Path) -> pathlib.Path:
    """Fresh install root containing an empty releases/ directory."""
    root = tmp_path / "decnet"
    (root / "releases").mkdir(parents=True)
    return root
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def agent_dir(tmp_path: pathlib.Path) -> pathlib.Path:
    """Agent cert-bundle directory populated with placeholder PEM files.

    executor._probe_agent checks these files exist before constructing an
    SSL context, but every test stubs the probe seam, so the contents are
    never parsed — they only make the non-stubbed path representative.
    """
    bundle = tmp_path / "agent"
    bundle.mkdir()
    stub_cert = b"-----BEGIN CERTIFICATE-----\nstub\n-----END CERTIFICATE-----\n"
    (bundle / "ca.crt").write_bytes(stub_cert)
    (bundle / "worker.crt").write_bytes(stub_cert)
    (bundle / "worker.key").write_bytes(
        b"-----BEGIN PRIVATE KEY-----\nstub\n-----END PRIVATE KEY-----\n"
    )
    return bundle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def seed_existing_release(install_dir: pathlib.Path) -> None:
    """Simulate a live install: a releases/active slot with current -> active."""
    live = install_dir / "releases" / "active"
    live.mkdir()
    (live / "marker.txt").write_text("old")
    ex._write_manifest(live, sha="OLDSHA")
    ex._point_current_at(install_dir, live)  # current -> active
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------- extract + safety
|
||||||
|
|
||||||
|
def test_extract_rejects_path_traversal(tmp_path: pathlib.Path) -> None:
    """Tarball members containing ../ components must be refused."""
    hostile = _make_tarball({"../escape.txt": "pwned"})
    destination = tmp_path / "out"
    with pytest.raises(ex.UpdateError, match="unsafe path"):
        ex.extract_tarball(hostile, destination)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_rejects_absolute_paths(tmp_path: pathlib.Path) -> None:
    """Tarball members with absolute paths must be refused."""
    hostile = _make_tarball({"/etc/passwd": "root:x:0:0"})
    destination = tmp_path / "out"
    with pytest.raises(ex.UpdateError, match="unsafe path"):
        ex.extract_tarball(hostile, destination)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_happy_path(tmp_path: pathlib.Path) -> None:
    """A benign tarball extracts with its directory structure intact."""
    destination = tmp_path / "out"
    ex.extract_tarball(_make_tarball({"a/b.txt": "hello"}), destination)
    assert (destination / "a" / "b.txt").read_text() == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_clean_stale_staging(install_dir: pathlib.Path) -> None:
    """Leftover active.new staging directories from a crash are removed."""
    stale = install_dir / "releases" / "active.new"
    stale.mkdir()
    (stale / "junk").write_text("left from a crash")
    ex.clean_stale_staging(install_dir)
    assert not stale.exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------- happy path
|
||||||
|
|
||||||
|
def test_update_rotates_and_probes(
    monkeypatch: pytest.MonkeyPatch,
    install_dir: pathlib.Path,
    agent_dir: pathlib.Path,
    seed_existing_release: None,
) -> None:
    """A successful update promotes the new tree and demotes the old one."""
    monkeypatch.setattr(ex, "_run_pip", lambda release: _PipOK())
    monkeypatch.setattr(ex, "_stop_agent", lambda *a, **k: None)
    monkeypatch.setattr(ex, "_spawn_agent", lambda *a, **k: 42)
    monkeypatch.setattr(ex, "_probe_agent", lambda **_: (True, "ok"))

    outcome = ex.run_update(
        _make_tarball({"marker.txt": "new"}),
        sha="NEWSHA", install_dir=install_dir, agent_dir=agent_dir,
    )

    assert outcome["status"] == "updated"
    assert outcome["release"]["sha"] == "NEWSHA"
    releases = install_dir / "releases"
    assert (releases / "active" / "marker.txt").read_text() == "new"
    # Old release demoted, not deleted.
    assert (releases / "prev" / "marker.txt").read_text() == "old"
    # Current symlink points at the new active.
    assert (install_dir / "current").resolve() == (releases / "active").resolve()
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_first_install_without_previous(
    monkeypatch: pytest.MonkeyPatch,
    install_dir: pathlib.Path,
    agent_dir: pathlib.Path,
) -> None:
    """No existing active/ dir — the first real install via the updater."""
    monkeypatch.setattr(ex, "_run_pip", lambda release: _PipOK())
    monkeypatch.setattr(ex, "_stop_agent", lambda *a, **k: None)
    monkeypatch.setattr(ex, "_spawn_agent", lambda *a, **k: 1)
    monkeypatch.setattr(ex, "_probe_agent", lambda **_: (True, "ok"))

    outcome = ex.run_update(
        _make_tarball({"marker.txt": "first"}),
        sha="S1", install_dir=install_dir, agent_dir=agent_dir,
    )
    assert outcome["status"] == "updated"
    # Nothing to demote on a first install, so no prev/ slot appears.
    assert not (install_dir / "releases" / "prev").exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------ pip failure
|
||||||
|
|
||||||
|
def test_update_pip_failure_aborts_before_rotation(
    monkeypatch: pytest.MonkeyPatch,
    install_dir: pathlib.Path,
    agent_dir: pathlib.Path,
    seed_existing_release: None,
) -> None:
    """A pip failure leaves the live release untouched and never stops the agent."""
    monkeypatch.setattr(ex, "_run_pip", lambda release: _PipFail())
    stop_calls: list[bool] = []
    monkeypatch.setattr(ex, "_stop_agent", lambda *a, **k: stop_calls.append(True))
    monkeypatch.setattr(ex, "_spawn_agent", lambda *a, **k: 1)
    monkeypatch.setattr(ex, "_probe_agent", lambda **_: (True, "ok"))

    with pytest.raises(ex.UpdateError, match="pip install failed") as excinfo:
        ex.run_update(
            _make_tarball({"marker.txt": "new"}),
            sha="S", install_dir=install_dir, agent_dir=agent_dir,
        )
    assert "resolver error" in excinfo.value.stderr

    releases = install_dir / "releases"
    # Nothing rotated — old active still live, no prev created.
    assert (releases / "active" / "marker.txt").read_text() == "old"
    assert not (releases / "prev").exists()
    # Agent never touched.
    assert stop_calls == []
    # Staging cleaned up.
    assert not (releases / "active.new").exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------ probe failure
|
||||||
|
|
||||||
|
def test_update_probe_failure_rolls_back(
    monkeypatch: pytest.MonkeyPatch,
    install_dir: pathlib.Path,
    agent_dir: pathlib.Path,
    seed_existing_release: None,
) -> None:
    """If the post-update health probe fails, the old release is restored."""
    monkeypatch.setattr(ex, "_run_pip", lambda release: _PipOK())
    monkeypatch.setattr(ex, "_stop_agent", lambda *a, **k: None)
    monkeypatch.setattr(ex, "_spawn_agent", lambda *a, **k: 1)

    probe_count = {"n": 0}

    def _probe(**_: Any) -> tuple[bool, str]:
        # First probe (new release) fails; subsequent probes (rollback) succeed.
        probe_count["n"] += 1
        if probe_count["n"] == 1:
            return False, "connection refused"
        return True, "ok"

    monkeypatch.setattr(ex, "_probe_agent", _probe)

    with pytest.raises(ex.UpdateError, match="health probe") as excinfo:
        ex.run_update(
            _make_tarball({"marker.txt": "new"}),
            sha="NEWSHA", install_dir=install_dir, agent_dir=agent_dir,
        )
    assert excinfo.value.rolled_back is True
    assert "connection refused" in excinfo.value.stderr

    releases = install_dir / "releases"
    # Rolled back: active has the old marker again.
    assert (releases / "active" / "marker.txt").read_text() == "old"
    # Prev now holds what would have been the new release.
    assert (releases / "prev" / "marker.txt").read_text() == "new"
    # Current symlink points back at active.
    assert (install_dir / "current").resolve() == (releases / "active").resolve()
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------ manual rollback
|
||||||
|
|
||||||
|
def test_manual_rollback_swaps(
    monkeypatch: pytest.MonkeyPatch,
    install_dir: pathlib.Path,
    agent_dir: pathlib.Path,
    seed_existing_release: None,
) -> None:
    """run_rollback exchanges the active and prev release slots."""
    # Seed a prev/ so rollback has somewhere to go.
    previous = install_dir / "releases" / "prev"
    previous.mkdir()
    (previous / "marker.txt").write_text("older")
    ex._write_manifest(previous, sha="OLDERSHA")

    monkeypatch.setattr(ex, "_stop_agent", lambda *a, **k: None)
    monkeypatch.setattr(ex, "_spawn_agent", lambda *a, **k: 1)
    monkeypatch.setattr(ex, "_probe_agent", lambda **_: (True, "ok"))

    outcome = ex.run_rollback(install_dir=install_dir, agent_dir=agent_dir)
    assert outcome["status"] == "rolled_back"
    releases = install_dir / "releases"
    assert (releases / "active" / "marker.txt").read_text() == "older"
    assert (releases / "prev" / "marker.txt").read_text() == "old"
|
||||||
|
|
||||||
|
|
||||||
|
def test_manual_rollback_refuses_without_prev(
    install_dir: pathlib.Path,
    seed_existing_release: None,
) -> None:
    """Without a prev/ slot, rollback raises instead of guessing."""
    with pytest.raises(ex.UpdateError, match="no previous release"):
        ex.run_rollback(install_dir=install_dir)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------- releases
|
||||||
|
|
||||||
|
def test_list_releases_includes_only_existing_slots(
    install_dir: pathlib.Path,
    seed_existing_release: None,
) -> None:
    """Only slots that actually exist on disk are reported."""
    found = ex.list_releases(install_dir)
    assert [release.slot for release in found] == ["active"]
    assert found[0].sha == "OLDSHA"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------- self-update
|
||||||
|
|
||||||
|
def test_update_self_rotates_and_calls_exec_cb(
    monkeypatch: pytest.MonkeyPatch,
    install_dir: pathlib.Path,
) -> None:
    """Self-update rotates the updater's own release and re-execs via exec_cb."""
    # Seed a stand-in "active" for the updater itself.
    active = install_dir / "releases" / "active"
    active.mkdir()
    (active / "marker").write_text("old-updater")

    monkeypatch.setattr(ex, "_run_pip", lambda release: _PipOK())
    recorded_argv: list[list[str]] = []

    outcome = ex.run_update_self(
        _make_tarball({"marker": "new-updater"}),
        sha="USHA", updater_install_dir=install_dir,
        exec_cb=lambda argv: recorded_argv.append(argv),
    )
    assert outcome["status"] == "self_update_queued"
    releases = install_dir / "releases"
    assert (releases / "active" / "marker").read_text() == "new-updater"
    assert (releases / "prev" / "marker").read_text() == "old-updater"
    # The exec callback fires exactly once with an updater argv.
    assert len(recorded_argv) == 1
    assert "updater" in recorded_argv[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_self_pip_failure_leaves_active_intact(
    monkeypatch: pytest.MonkeyPatch,
    install_dir: pathlib.Path,
) -> None:
    """A pip failure during self-update must not rotate the live updater."""
    active = install_dir / "releases" / "active"
    active.mkdir()
    (active / "marker").write_text("old-updater")
    monkeypatch.setattr(ex, "_run_pip", lambda release: _PipFail())

    with pytest.raises(ex.UpdateError, match="pip install failed"):
        ex.run_update_self(
            _make_tarball({"marker": "new-updater"}),
            sha="U", updater_install_dir=install_dir, exec_cb=lambda a: None,
        )
    releases = install_dir / "releases"
    assert (releases / "active" / "marker").read_text() == "old-updater"
    assert not (releases / "active.new").exists()
|
||||||
Reference in New Issue
Block a user