refactor(swarm): extract _shard_payload helper and promote _dispatch to module-level

This commit is contained in:
2026-04-30 20:25:38 -04:00
parent c648d8b04e
commit e124f9e296

View File

@@ -57,27 +57,33 @@ def _worker_config(
return base.model_copy(update=updates) return base.model_copy(update=updates)
async def dispatch_decnet_config( def _shard_payload(
d: DeckyConfig,
host_uuid: str,
state: str,
error: str | None,
) -> dict[str, Any]:
return {
"decky_name": d.name,
"host_uuid": host_uuid,
"services": json.dumps(d.services),
"decky_config": d.model_dump_json(),
"decky_ip": d.ip,
"state": state,
"last_error": error,
"updated_at": datetime.now(timezone.utc),
}
async def _dispatch(
host_uuid: str,
shard: list[DeckyConfig],
hosts: dict[str, dict[str, Any]],
config: DecnetConfig, config: DecnetConfig,
repo: BaseRepository, repo: BaseRepository,
dry_run: bool = False, dry_run: bool,
no_cache: bool = False, no_cache: bool,
) -> SwarmDeployResponse: ) -> SwarmHostResult:
"""Shard ``config`` by ``host_uuid`` and dispatch to each worker in parallel.
Shared between POST /swarm/deploy (explicit swarm call) and the auto-swarm
branch of POST /deckies/deploy.
"""
buckets = _shard_by_host(config)
hosts: dict[str, dict[str, Any]] = {}
for host_uuid in buckets:
row = await repo.get_swarm_host_by_uuid(host_uuid)
if row is None:
raise HTTPException(status_code=404, detail=f"unknown host_uuid: {host_uuid}")
hosts[host_uuid] = row
async def _dispatch(host_uuid: str, shard: list[DeckyConfig]) -> SwarmHostResult:
host = hosts[host_uuid] host = hosts[host_uuid]
cfg = _worker_config(config, shard, host) cfg = _worker_config(config, shard, host)
try: try:
@@ -85,16 +91,7 @@ async def dispatch_decnet_config(
body = await agent.deploy(cfg, dry_run=dry_run, no_cache=no_cache) body = await agent.deploy(cfg, dry_run=dry_run, no_cache=no_cache)
for d in shard: for d in shard:
await repo.upsert_decky_shard( await repo.upsert_decky_shard(
{ _shard_payload(d, host_uuid, "running" if not dry_run else "pending", None)
"decky_name": d.name,
"host_uuid": host_uuid,
"services": json.dumps(d.services),
"decky_config": d.model_dump_json(),
"decky_ip": d.ip,
"state": "running" if not dry_run else "pending",
"last_error": None,
"updated_at": datetime.now(timezone.utc),
}
) )
await repo.update_swarm_host(host_uuid, {"status": "active"}) await repo.update_swarm_host(host_uuid, {"status": "active"})
return SwarmHostResult(host_uuid=host_uuid, host_name=host["name"], ok=True, detail=body) return SwarmHostResult(host_uuid=host_uuid, host_name=host["name"], ok=True, detail=body)
@@ -116,21 +113,36 @@ async def dispatch_decnet_config(
rstate = runtime.get(d.name) or {} rstate = runtime.get(d.name) or {}
is_up = bool(rstate.get("running")) is_up = bool(rstate.get("running"))
await repo.upsert_decky_shard( await repo.upsert_decky_shard(
{ _shard_payload(d, host_uuid, "running" if is_up else "failed", None if is_up else str(exc)[:512])
"decky_name": d.name,
"host_uuid": host_uuid,
"services": json.dumps(d.services),
"decky_config": d.model_dump_json(),
"decky_ip": d.ip,
"state": "running" if is_up else "failed",
"last_error": None if is_up else str(exc)[:512],
"updated_at": datetime.now(timezone.utc),
}
) )
return SwarmHostResult(host_uuid=host_uuid, host_name=host["name"], ok=False, detail=str(exc)) return SwarmHostResult(host_uuid=host_uuid, host_name=host["name"], ok=False, detail=str(exc))
async def dispatch_decnet_config(
config: DecnetConfig,
repo: BaseRepository,
dry_run: bool = False,
no_cache: bool = False,
) -> SwarmDeployResponse:
"""Shard ``config`` by ``host_uuid`` and dispatch to each worker in parallel.
Shared between POST /swarm/deploy (explicit swarm call) and the auto-swarm
branch of POST /deckies/deploy.
"""
buckets = _shard_by_host(config)
hosts: dict[str, dict[str, Any]] = {}
for host_uuid in buckets:
row = await repo.get_swarm_host_by_uuid(host_uuid)
if row is None:
raise HTTPException(status_code=404, detail=f"unknown host_uuid: {host_uuid}")
hosts[host_uuid] = row
results = await asyncio.gather( results = await asyncio.gather(
*(_dispatch(uuid_, shard) for uuid_, shard in buckets.items()) *(
_dispatch(uuid_, shard, hosts, config, repo, dry_run, no_cache)
for uuid_, shard in buckets.items()
)
) )
return SwarmDeployResponse(results=list(results)) return SwarmDeployResponse(results=list(results))