refactor(swarm): extract _shard_payload helper and promote _dispatch to module-level

This commit is contained in:
2026-04-30 20:25:38 -04:00
parent c648d8b04e
commit e124f9e296

View File

@@ -57,27 +57,33 @@ def _worker_config(
return base.model_copy(update=updates)
async def dispatch_decnet_config(
def _shard_payload(
d: DeckyConfig,
host_uuid: str,
state: str,
error: str | None,
) -> dict[str, Any]:
return {
"decky_name": d.name,
"host_uuid": host_uuid,
"services": json.dumps(d.services),
"decky_config": d.model_dump_json(),
"decky_ip": d.ip,
"state": state,
"last_error": error,
"updated_at": datetime.now(timezone.utc),
}
async def _dispatch(
host_uuid: str,
shard: list[DeckyConfig],
hosts: dict[str, dict[str, Any]],
config: DecnetConfig,
repo: BaseRepository,
dry_run: bool = False,
no_cache: bool = False,
) -> SwarmDeployResponse:
"""Shard ``config`` by ``host_uuid`` and dispatch to each worker in parallel.
Shared between POST /swarm/deploy (explicit swarm call) and the auto-swarm
branch of POST /deckies/deploy.
"""
buckets = _shard_by_host(config)
hosts: dict[str, dict[str, Any]] = {}
for host_uuid in buckets:
row = await repo.get_swarm_host_by_uuid(host_uuid)
if row is None:
raise HTTPException(status_code=404, detail=f"unknown host_uuid: {host_uuid}")
hosts[host_uuid] = row
async def _dispatch(host_uuid: str, shard: list[DeckyConfig]) -> SwarmHostResult:
dry_run: bool,
no_cache: bool,
) -> SwarmHostResult:
host = hosts[host_uuid]
cfg = _worker_config(config, shard, host)
try:
@@ -85,16 +91,7 @@ async def dispatch_decnet_config(
body = await agent.deploy(cfg, dry_run=dry_run, no_cache=no_cache)
for d in shard:
await repo.upsert_decky_shard(
{
"decky_name": d.name,
"host_uuid": host_uuid,
"services": json.dumps(d.services),
"decky_config": d.model_dump_json(),
"decky_ip": d.ip,
"state": "running" if not dry_run else "pending",
"last_error": None,
"updated_at": datetime.now(timezone.utc),
}
_shard_payload(d, host_uuid, "running" if not dry_run else "pending", None)
)
await repo.update_swarm_host(host_uuid, {"status": "active"})
return SwarmHostResult(host_uuid=host_uuid, host_name=host["name"], ok=True, detail=body)
@@ -116,21 +113,36 @@ async def dispatch_decnet_config(
rstate = runtime.get(d.name) or {}
is_up = bool(rstate.get("running"))
await repo.upsert_decky_shard(
{
"decky_name": d.name,
"host_uuid": host_uuid,
"services": json.dumps(d.services),
"decky_config": d.model_dump_json(),
"decky_ip": d.ip,
"state": "running" if is_up else "failed",
"last_error": None if is_up else str(exc)[:512],
"updated_at": datetime.now(timezone.utc),
}
_shard_payload(d, host_uuid, "running" if is_up else "failed", None if is_up else str(exc)[:512])
)
return SwarmHostResult(host_uuid=host_uuid, host_name=host["name"], ok=False, detail=str(exc))
async def dispatch_decnet_config(
config: DecnetConfig,
repo: BaseRepository,
dry_run: bool = False,
no_cache: bool = False,
) -> SwarmDeployResponse:
"""Shard ``config`` by ``host_uuid`` and dispatch to each worker in parallel.
Shared between POST /swarm/deploy (explicit swarm call) and the auto-swarm
branch of POST /deckies/deploy.
"""
buckets = _shard_by_host(config)
hosts: dict[str, dict[str, Any]] = {}
for host_uuid in buckets:
row = await repo.get_swarm_host_by_uuid(host_uuid)
if row is None:
raise HTTPException(status_code=404, detail=f"unknown host_uuid: {host_uuid}")
hosts[host_uuid] = row
results = await asyncio.gather(
*(_dispatch(uuid_, shard) for uuid_, shard in buckets.items())
*(
_dispatch(uuid_, shard, hosts, config, repo, dry_run, no_cache)
for uuid_, shard in buckets.items()
)
)
return SwarmDeployResponse(results=list(results))