feat(api/sse): per-user connection cap + viewer-safe invariant

New decnet/web/sse_limits.py provides sse_connection_slot, an async
context manager that counts live SSE connections per user UUID and
raises 429 when a per-user cap is exceeded (default 5, override via
DECNET_SSE_MAX_PER_USER). Wired into both SSE generators as their
first async with, so the cap check fires before any stream data is
yielded.

The cap must sit inside the generator — StreamingResponse returns
before the generator body runs, so a handler-level wrapper would
release the slot immediately. Put prefetch + slot + loop all under
the one async with.

Also documents F6/I (role leakage) as mitigated-by-construction via
handler docstrings: every event type on both streams wraps data
already reachable via viewer-gated REST, so no per-event filter is
needed until a new event family is introduced. The invariant is
written into the handler docstrings so a future PR can't silently
add admin-only events.

Resolves THREAT_MODEL F6/I and F6/D.
This commit is contained in:
2026-04-24 15:01:20 -04:00
parent df84981954
commit 162f7c1194
7 changed files with 271 additions and 123 deletions

View File

@@ -10,6 +10,7 @@ from decnet.env import DECNET_DEVELOPER
from decnet.logging import get_logger
from decnet.telemetry import traced as _traced, get_tracer as _get_tracer
from decnet.web.dependencies import require_stream_viewer, repo
from decnet.web.sse_limits import sse_connection_slot
log = get_logger("api")
@@ -52,7 +53,8 @@ def _build_trace_links(logs: list[dict]) -> list:
},
401: {"description": "Could not validate credentials"},
403: {"description": "Insufficient permissions"},
422: {"description": "Validation error"}
422: {"description": "Validation error"},
429: {"description": "Per-user SSE connection cap reached"},
},
)
@_traced("api.stream_events")
@@ -65,77 +67,79 @@ async def stream_events(
max_output: Optional[int] = Query(None, alias="maxOutput"),
user: dict = Depends(require_stream_viewer)
) -> StreamingResponse:
# Prefetch the initial snapshot before entering the streaming generator.
# With asyncmy (pure async TCP I/O), the first DB await inside the generator
# fires immediately after the ASGI layer sends the keepalive chunk — the HTTP
# write and the MySQL read compete for asyncio I/O callbacks and the MySQL
# callback can stall. Running these here (normal async context, no streaming)
# avoids that race entirely. aiosqlite is immune because it runs SQLite in a
# thread, decoupled from the event loop's I/O scheduler.
_start_id = last_event_id if last_event_id != 0 else await repo.get_max_log_id()
_initial_stats = await repo.get_stats_summary()
_initial_histogram = await repo.get_log_histogram(
search=search, start_time=start_time, end_time=end_time, interval_minutes=15,
)
# Event types emitted on this stream: logs, stats, histogram.
# All three are viewer-safe — same data is reachable via /logs and
# /stats (viewer-gated REST). Adding a new event family here
# requires a threat-model review for F6/I (role leakage).
async def event_generator() -> AsyncGenerator[str, None]:
last_id = _start_id
stats_interval_sec = 10
loops_since_stats = 0
emitted_chunks = 0
try:
yield ": keepalive\n\n" # flush headers immediately
async with sse_connection_slot(user["uuid"]):
# Prefetch the initial snapshot before the first yield.
# With asyncmy (pure async TCP I/O), a DB await AFTER the first
# yield races with the HTTP write callback; running DB reads
# here (pre-yield, normal coroutine context) avoids that.
# aiosqlite is immune because SQLite runs on a worker thread.
_start_id = last_event_id if last_event_id != 0 else await repo.get_max_log_id()
_initial_stats = await repo.get_stats_summary()
_initial_histogram = await repo.get_log_histogram(
search=search, start_time=start_time, end_time=end_time, interval_minutes=15,
)
last_id = _start_id
stats_interval_sec = 10
loops_since_stats = 0
emitted_chunks = 0
try:
yield ": keepalive\n\n" # flush headers immediately
# Emit pre-fetched initial snapshot — no DB calls in generator until the loop
yield f"event: message\ndata: {orjson.dumps({'type': 'stats', 'data': _initial_stats}).decode()}\n\n"
yield f"event: message\ndata: {orjson.dumps({'type': 'histogram', 'data': _initial_histogram}).decode()}\n\n"
# Emit pre-fetched initial snapshot — no DB calls in generator until the loop
yield f"event: message\ndata: {orjson.dumps({'type': 'stats', 'data': _initial_stats}).decode()}\n\n"
yield f"event: message\ndata: {orjson.dumps({'type': 'histogram', 'data': _initial_histogram}).decode()}\n\n"
while True:
if DECNET_DEVELOPER and max_output is not None:
emitted_chunks += 1
if emitted_chunks > max_output:
log.debug("Developer mode: max_output reached (%d), closing stream", max_output)
while True:
if DECNET_DEVELOPER and max_output is not None:
emitted_chunks += 1
if emitted_chunks > max_output:
log.debug("Developer mode: max_output reached (%d), closing stream", max_output)
break
if await request.is_disconnected():
break
if await request.is_disconnected():
break
new_logs = await repo.get_logs_after_id(
last_id, limit=50, search=search,
start_time=start_time, end_time=end_time,
)
if new_logs:
last_id = max(entry["id"] for entry in new_logs)
# Create a span linking back to the ingestion traces
# stored in each log row, closing the pipeline gap.
_links = _build_trace_links(new_logs)
_tracer = _get_tracer("sse")
with _tracer.start_as_current_span(
"sse.emit_logs", links=_links,
attributes={"log_count": len(new_logs)},
):
yield f"event: message\ndata: {orjson.dumps({'type': 'logs', 'data': new_logs}).decode()}\n\n"
loops_since_stats = stats_interval_sec
if loops_since_stats >= stats_interval_sec:
stats = await repo.get_stats_summary()
yield f"event: message\ndata: {orjson.dumps({'type': 'stats', 'data': stats}).decode()}\n\n"
histogram = await repo.get_log_histogram(
search=search, start_time=start_time,
end_time=end_time, interval_minutes=15,
new_logs = await repo.get_logs_after_id(
last_id, limit=50, search=search,
start_time=start_time, end_time=end_time,
)
yield f"event: message\ndata: {orjson.dumps({'type': 'histogram', 'data': histogram}).decode()}\n\n"
loops_since_stats = 0
if new_logs:
last_id = max(entry["id"] for entry in new_logs)
# Create a span linking back to the ingestion traces
# stored in each log row, closing the pipeline gap.
_links = _build_trace_links(new_logs)
_tracer = _get_tracer("sse")
with _tracer.start_as_current_span(
"sse.emit_logs", links=_links,
attributes={"log_count": len(new_logs)},
):
yield f"event: message\ndata: {orjson.dumps({'type': 'logs', 'data': new_logs}).decode()}\n\n"
loops_since_stats = stats_interval_sec
loops_since_stats += 1
if loops_since_stats >= stats_interval_sec:
stats = await repo.get_stats_summary()
yield f"event: message\ndata: {orjson.dumps({'type': 'stats', 'data': stats}).decode()}\n\n"
histogram = await repo.get_log_histogram(
search=search, start_time=start_time,
end_time=end_time, interval_minutes=15,
)
yield f"event: message\ndata: {orjson.dumps({'type': 'histogram', 'data': histogram}).decode()}\n\n"
loops_since_stats = 0
await asyncio.sleep(1)
except asyncio.CancelledError:
pass
except Exception:
log.exception("SSE stream error for user %s", last_event_id)
yield f"event: error\ndata: {orjson.dumps({'type': 'error', 'message': 'Stream interrupted'}).decode()}\n\n"
loops_since_stats += 1
await asyncio.sleep(1)
except asyncio.CancelledError:
pass
except Exception:
log.exception("SSE stream error for user %s", last_event_id)
yield f"event: error\ndata: {orjson.dumps({'type': 'error', 'message': 'Stream interrupted'}).decode()}\n\n"
return StreamingResponse(
event_generator(),

View File

@@ -26,6 +26,7 @@ from decnet.bus.app import get_app_bus
from decnet.logging import get_logger
from decnet.telemetry import traced as _traced
from decnet.web.dependencies import repo, require_stream_viewer
from decnet.web.sse_limits import sse_connection_slot
from ._guards import get_topology_or_404
@@ -53,14 +54,20 @@ def _format_sse(event_name: str, data: dict) -> str:
401: {"description": "Could not validate credentials"},
403: {"description": "Insufficient permissions"},
404: {"description": "Topology not found"},
429: {"description": "Per-user SSE connection cap reached"},
},
)
@_traced("api.topology.events")
async def api_topology_events(
topology_id: str,
request: Request,
_user: dict = Depends(require_stream_viewer),
user: dict = Depends(require_stream_viewer),
) -> StreamingResponse:
# Event types emitted: snapshot, status, mutation.{enqueued,
# applying,applied,failed}. All wrap bus events whose payload is
# also reachable via viewer-gated REST (GET /topologies/{id},
# GET /topologies/{id}/mutations). Adding a new event family here
# requires a threat-model review for F6/I (role leakage).
topo = await get_topology_or_404(topology_id)
snapshot_status = topo["status"]
in_flight: list[dict] = []
@@ -68,64 +75,65 @@ async def api_topology_events(
in_flight.extend(await repo.list_topology_mutations(topology_id, state=state))
async def generator() -> AsyncGenerator[str, None]:
# Flush headers immediately so the browser's EventSource sees a
# live connection before the first real event arrives.
yield ": keepalive\n\n"
async with sse_connection_slot(user["uuid"]):
# Flush headers immediately so the browser's EventSource sees a
# live connection before the first real event arrives.
yield ": keepalive\n\n"
# One-shot snapshot — pair the current topology status with any
# mutations the mutator is still holding, so the client buffer
# can render an accurate "already in flight" state.
yield _format_sse("snapshot", {
"topology_id": topology_id,
"status": snapshot_status,
"in_flight": in_flight,
})
# One-shot snapshot — pair the current topology status with any
# mutations the mutator is still holding, so the client buffer
# can render an accurate "already in flight" state.
yield _format_sse("snapshot", {
"topology_id": topology_id,
"status": snapshot_status,
"in_flight": in_flight,
})
bus = await get_app_bus()
if bus is None:
# Bus disabled (NullBus) or unreachable. The snapshot is
# still useful; we idle on keepalives so the client stays
# connected and will re-poll on its own timers.
while not await request.is_disconnected():
try:
await asyncio.sleep(_KEEPALIVE_SECS)
except asyncio.CancelledError:
break
yield ": keepalive\n\n"
return
sub = bus.subscribe(f"{_topics.TOPOLOGY}.{topology_id}.>")
try:
async with sub:
sub_iter = sub.__aiter__()
while True:
if await request.is_disconnected():
break
next_task = asyncio.ensure_future(sub_iter.__anext__())
bus = await get_app_bus()
if bus is None:
# Bus disabled (NullBus) or unreachable. The snapshot is
# still useful; we idle on keepalives so the client stays
# connected and will re-poll on its own timers.
while not await request.is_disconnected():
try:
event = await asyncio.wait_for(next_task, timeout=_KEEPALIVE_SECS)
except asyncio.TimeoutError:
next_task.cancel()
yield ": keepalive\n\n"
continue
except StopAsyncIteration:
await asyncio.sleep(_KEEPALIVE_SECS)
except asyncio.CancelledError:
break
# Map the bus event onto an SSE ``event:`` name that
# the frontend can switch on without parsing topics.
yield _format_sse(
_sse_name_for(event.topic),
{
"topic": event.topic,
"type": event.type,
"ts": event.ts,
"payload": event.payload,
},
)
except asyncio.CancelledError:
pass
except Exception:
log.exception("topology events stream crashed topology_id=%s", topology_id)
yield _format_sse("error", {"message": "Stream interrupted"})
yield ": keepalive\n\n"
return
sub = bus.subscribe(f"{_topics.TOPOLOGY}.{topology_id}.>")
try:
async with sub:
sub_iter = sub.__aiter__()
while True:
if await request.is_disconnected():
break
next_task = asyncio.ensure_future(sub_iter.__anext__())
try:
event = await asyncio.wait_for(next_task, timeout=_KEEPALIVE_SECS)
except asyncio.TimeoutError:
next_task.cancel()
yield ": keepalive\n\n"
continue
except StopAsyncIteration:
break
# Map the bus event onto an SSE ``event:`` name that
# the frontend can switch on without parsing topics.
yield _format_sse(
_sse_name_for(event.topic),
{
"topic": event.topic,
"type": event.type,
"ts": event.ts,
"payload": event.payload,
},
)
except asyncio.CancelledError:
pass
except Exception:
log.exception("topology events stream crashed topology_id=%s", topology_id)
yield _format_sse("error", {"message": "Stream interrupted"})
return StreamingResponse(
generator(),