feat(api/sse): per-user connection cap + viewer-safe invariant

New decnet/web/sse_limits.py provides sse_connection_slot, an async
context manager that counts live SSE connections per user UUID and
raises 429 when a per-user cap is exceeded (default 5, override via
DECNET_SSE_MAX_PER_USER). Wired into both SSE generators as their
first async with, so the cap check fires before any stream data is
yielded.

The cap must sit inside the generator — StreamingResponse returns
before the generator body runs, so a handler-level wrapper would
release the slot immediately. Put prefetch + slot + loop all under
the one async with.

Also documents F6/I (role leakage) as mitigated-by-construction via
handler docstrings: every event type on both streams wraps data
already reachable via viewer-gated REST, so no per-event filter is
needed until a new event family is introduced. The invariant is
written into the handler docstrings so a future PR can't silently
add admin-only events.

Resolves THREAT_MODEL F6/I and F6/D.
This commit is contained in:
2026-04-24 15:01:20 -04:00
parent df84981954
commit 162f7c1194
7 changed files with 271 additions and 123 deletions

View File

@@ -10,6 +10,7 @@ from decnet.env import DECNET_DEVELOPER
from decnet.logging import get_logger
from decnet.telemetry import traced as _traced, get_tracer as _get_tracer
from decnet.web.dependencies import require_stream_viewer, repo
from decnet.web.sse_limits import sse_connection_slot
log = get_logger("api")
@@ -52,7 +53,8 @@ def _build_trace_links(logs: list[dict]) -> list:
},
401: {"description": "Could not validate credentials"},
403: {"description": "Insufficient permissions"},
422: {"description": "Validation error"}
422: {"description": "Validation error"},
429: {"description": "Per-user SSE connection cap reached"},
},
)
@_traced("api.stream_events")
@@ -65,77 +67,79 @@ async def stream_events(
max_output: Optional[int] = Query(None, alias="maxOutput"),
user: dict = Depends(require_stream_viewer)
) -> StreamingResponse:
# Prefetch the initial snapshot before entering the streaming generator.
# With asyncmy (pure async TCP I/O), the first DB await inside the generator
# fires immediately after the ASGI layer sends the keepalive chunk — the HTTP
# write and the MySQL read compete for asyncio I/O callbacks and the MySQL
# callback can stall. Running these here (normal async context, no streaming)
# avoids that race entirely. aiosqlite is immune because it runs SQLite in a
# thread, decoupled from the event loop's I/O scheduler.
_start_id = last_event_id if last_event_id != 0 else await repo.get_max_log_id()
_initial_stats = await repo.get_stats_summary()
_initial_histogram = await repo.get_log_histogram(
search=search, start_time=start_time, end_time=end_time, interval_minutes=15,
)
# Event types emitted on this stream: logs, stats, histogram.
# All three are viewer-safe — same data is reachable via /logs and
# /stats (viewer-gated REST). Adding a new event family here
# requires a threat-model review for F6/I (role leakage).
async def event_generator() -> AsyncGenerator[str, None]:
last_id = _start_id
stats_interval_sec = 10
loops_since_stats = 0
emitted_chunks = 0
try:
yield ": keepalive\n\n" # flush headers immediately
async with sse_connection_slot(user["uuid"]):
# Prefetch the initial snapshot before the first yield.
# With asyncmy (pure async TCP I/O), a DB await AFTER the first
# yield races with the HTTP write callback; running DB reads
# here (pre-yield, normal coroutine context) avoids that.
# aiosqlite is immune because SQLite runs on a worker thread.
_start_id = last_event_id if last_event_id != 0 else await repo.get_max_log_id()
_initial_stats = await repo.get_stats_summary()
_initial_histogram = await repo.get_log_histogram(
search=search, start_time=start_time, end_time=end_time, interval_minutes=15,
)
last_id = _start_id
stats_interval_sec = 10
loops_since_stats = 0
emitted_chunks = 0
try:
yield ": keepalive\n\n" # flush headers immediately
# Emit pre-fetched initial snapshot — no DB calls in generator until the loop
yield f"event: message\ndata: {orjson.dumps({'type': 'stats', 'data': _initial_stats}).decode()}\n\n"
yield f"event: message\ndata: {orjson.dumps({'type': 'histogram', 'data': _initial_histogram}).decode()}\n\n"
# Emit pre-fetched initial snapshot — no DB calls in generator until the loop
yield f"event: message\ndata: {orjson.dumps({'type': 'stats', 'data': _initial_stats}).decode()}\n\n"
yield f"event: message\ndata: {orjson.dumps({'type': 'histogram', 'data': _initial_histogram}).decode()}\n\n"
while True:
if DECNET_DEVELOPER and max_output is not None:
emitted_chunks += 1
if emitted_chunks > max_output:
log.debug("Developer mode: max_output reached (%d), closing stream", max_output)
while True:
if DECNET_DEVELOPER and max_output is not None:
emitted_chunks += 1
if emitted_chunks > max_output:
log.debug("Developer mode: max_output reached (%d), closing stream", max_output)
break
if await request.is_disconnected():
break
if await request.is_disconnected():
break
new_logs = await repo.get_logs_after_id(
last_id, limit=50, search=search,
start_time=start_time, end_time=end_time,
)
if new_logs:
last_id = max(entry["id"] for entry in new_logs)
# Create a span linking back to the ingestion traces
# stored in each log row, closing the pipeline gap.
_links = _build_trace_links(new_logs)
_tracer = _get_tracer("sse")
with _tracer.start_as_current_span(
"sse.emit_logs", links=_links,
attributes={"log_count": len(new_logs)},
):
yield f"event: message\ndata: {orjson.dumps({'type': 'logs', 'data': new_logs}).decode()}\n\n"
loops_since_stats = stats_interval_sec
if loops_since_stats >= stats_interval_sec:
stats = await repo.get_stats_summary()
yield f"event: message\ndata: {orjson.dumps({'type': 'stats', 'data': stats}).decode()}\n\n"
histogram = await repo.get_log_histogram(
search=search, start_time=start_time,
end_time=end_time, interval_minutes=15,
new_logs = await repo.get_logs_after_id(
last_id, limit=50, search=search,
start_time=start_time, end_time=end_time,
)
yield f"event: message\ndata: {orjson.dumps({'type': 'histogram', 'data': histogram}).decode()}\n\n"
loops_since_stats = 0
if new_logs:
last_id = max(entry["id"] for entry in new_logs)
# Create a span linking back to the ingestion traces
# stored in each log row, closing the pipeline gap.
_links = _build_trace_links(new_logs)
_tracer = _get_tracer("sse")
with _tracer.start_as_current_span(
"sse.emit_logs", links=_links,
attributes={"log_count": len(new_logs)},
):
yield f"event: message\ndata: {orjson.dumps({'type': 'logs', 'data': new_logs}).decode()}\n\n"
loops_since_stats = stats_interval_sec
loops_since_stats += 1
if loops_since_stats >= stats_interval_sec:
stats = await repo.get_stats_summary()
yield f"event: message\ndata: {orjson.dumps({'type': 'stats', 'data': stats}).decode()}\n\n"
histogram = await repo.get_log_histogram(
search=search, start_time=start_time,
end_time=end_time, interval_minutes=15,
)
yield f"event: message\ndata: {orjson.dumps({'type': 'histogram', 'data': histogram}).decode()}\n\n"
loops_since_stats = 0
await asyncio.sleep(1)
except asyncio.CancelledError:
pass
except Exception:
log.exception("SSE stream error for user %s", last_event_id)
yield f"event: error\ndata: {orjson.dumps({'type': 'error', 'message': 'Stream interrupted'}).decode()}\n\n"
loops_since_stats += 1
await asyncio.sleep(1)
except asyncio.CancelledError:
pass
except Exception:
log.exception("SSE stream error for user %s", last_event_id)
yield f"event: error\ndata: {orjson.dumps({'type': 'error', 'message': 'Stream interrupted'}).decode()}\n\n"
return StreamingResponse(
event_generator(),

View File

@@ -26,6 +26,7 @@ from decnet.bus.app import get_app_bus
from decnet.logging import get_logger
from decnet.telemetry import traced as _traced
from decnet.web.dependencies import repo, require_stream_viewer
from decnet.web.sse_limits import sse_connection_slot
from ._guards import get_topology_or_404
@@ -53,14 +54,20 @@ def _format_sse(event_name: str, data: dict) -> str:
401: {"description": "Could not validate credentials"},
403: {"description": "Insufficient permissions"},
404: {"description": "Topology not found"},
429: {"description": "Per-user SSE connection cap reached"},
},
)
@_traced("api.topology.events")
async def api_topology_events(
topology_id: str,
request: Request,
_user: dict = Depends(require_stream_viewer),
user: dict = Depends(require_stream_viewer),
) -> StreamingResponse:
# Event types emitted: snapshot, status, mutation.{enqueued,
# applying,applied,failed}. All wrap bus events whose payload is
# also reachable via viewer-gated REST (GET /topologies/{id},
# GET /topologies/{id}/mutations). Adding a new event family here
# requires a threat-model review for F6/I (role leakage).
topo = await get_topology_or_404(topology_id)
snapshot_status = topo["status"]
in_flight: list[dict] = []
@@ -68,64 +75,65 @@ async def api_topology_events(
in_flight.extend(await repo.list_topology_mutations(topology_id, state=state))
async def generator() -> AsyncGenerator[str, None]:
# Flush headers immediately so the browser's EventSource sees a
# live connection before the first real event arrives.
yield ": keepalive\n\n"
async with sse_connection_slot(user["uuid"]):
# Flush headers immediately so the browser's EventSource sees a
# live connection before the first real event arrives.
yield ": keepalive\n\n"
# One-shot snapshot — pair the current topology status with any
# mutations the mutator is still holding, so the client buffer
# can render an accurate "already in flight" state.
yield _format_sse("snapshot", {
"topology_id": topology_id,
"status": snapshot_status,
"in_flight": in_flight,
})
# One-shot snapshot — pair the current topology status with any
# mutations the mutator is still holding, so the client buffer
# can render an accurate "already in flight" state.
yield _format_sse("snapshot", {
"topology_id": topology_id,
"status": snapshot_status,
"in_flight": in_flight,
})
bus = await get_app_bus()
if bus is None:
# Bus disabled (NullBus) or unreachable. The snapshot is
# still useful; we idle on keepalives so the client stays
# connected and will re-poll on its own timers.
while not await request.is_disconnected():
try:
await asyncio.sleep(_KEEPALIVE_SECS)
except asyncio.CancelledError:
break
yield ": keepalive\n\n"
return
sub = bus.subscribe(f"{_topics.TOPOLOGY}.{topology_id}.>")
try:
async with sub:
sub_iter = sub.__aiter__()
while True:
if await request.is_disconnected():
break
next_task = asyncio.ensure_future(sub_iter.__anext__())
bus = await get_app_bus()
if bus is None:
# Bus disabled (NullBus) or unreachable. The snapshot is
# still useful; we idle on keepalives so the client stays
# connected and will re-poll on its own timers.
while not await request.is_disconnected():
try:
event = await asyncio.wait_for(next_task, timeout=_KEEPALIVE_SECS)
except asyncio.TimeoutError:
next_task.cancel()
yield ": keepalive\n\n"
continue
except StopAsyncIteration:
await asyncio.sleep(_KEEPALIVE_SECS)
except asyncio.CancelledError:
break
# Map the bus event onto an SSE ``event:`` name that
# the frontend can switch on without parsing topics.
yield _format_sse(
_sse_name_for(event.topic),
{
"topic": event.topic,
"type": event.type,
"ts": event.ts,
"payload": event.payload,
},
)
except asyncio.CancelledError:
pass
except Exception:
log.exception("topology events stream crashed topology_id=%s", topology_id)
yield _format_sse("error", {"message": "Stream interrupted"})
yield ": keepalive\n\n"
return
sub = bus.subscribe(f"{_topics.TOPOLOGY}.{topology_id}.>")
try:
async with sub:
sub_iter = sub.__aiter__()
while True:
if await request.is_disconnected():
break
next_task = asyncio.ensure_future(sub_iter.__anext__())
try:
event = await asyncio.wait_for(next_task, timeout=_KEEPALIVE_SECS)
except asyncio.TimeoutError:
next_task.cancel()
yield ": keepalive\n\n"
continue
except StopAsyncIteration:
break
# Map the bus event onto an SSE ``event:`` name that
# the frontend can switch on without parsing topics.
yield _format_sse(
_sse_name_for(event.topic),
{
"topic": event.topic,
"type": event.type,
"ts": event.ts,
"payload": event.payload,
},
)
except asyncio.CancelledError:
pass
except Exception:
log.exception("topology events stream crashed topology_id=%s", topology_id)
yield _format_sse("error", {"message": "Stream interrupted"})
return StreamingResponse(
generator(),

65
decnet/web/sse_limits.py Normal file
View File

@@ -0,0 +1,65 @@
"""Per-user concurrent SSE connection gate.
SSE connections are long-lived — a client that opens one per tab
forever can exhaust API workers. Module-level dict + async lock keeps
the fast path cheap (a dict lookup) while the lock keeps check-and-
increment atomic across concurrent handshakes.
The slot must wrap the generator's own lifetime, not just the handler
call, because StreamingResponse returns before the generator body
runs. Call it as the first statement inside the generator — an
HTTPException raised before the first yield bubbles back to the client
as a normal HTTP response.
"""
from __future__ import annotations
import asyncio
import os
from collections import defaultdict
from contextlib import asynccontextmanager
from fastapi import HTTPException, status
DEFAULT_CAP = 5
_MAX_PER_USER = int(os.environ.get("DECNET_SSE_MAX_PER_USER", DEFAULT_CAP))
_counts: dict[str, int] = defaultdict(int)
_lock: asyncio.Lock | None = None
def _get_lock() -> asyncio.Lock:
global _lock
if _lock is None:
_lock = asyncio.Lock()
return _lock
def _reset_for_tests() -> None:
"""Clear counters + lock between tests. The lock is rebuilt lazily
so a fixture can reset state without worrying about event-loop
binding from a previous test."""
global _lock
_counts.clear()
_lock = None
def current_count(user_uuid: str) -> int:
"""Snapshot helper — tests and diagnostics only."""
return _counts.get(user_uuid, 0)
@asynccontextmanager
async def sse_connection_slot(user_uuid: str):
async with _get_lock():
if _counts[user_uuid] >= _MAX_PER_USER:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"SSE connection limit ({_MAX_PER_USER}) reached",
)
_counts[user_uuid] += 1
try:
yield
finally:
async with _get_lock():
_counts[user_uuid] -= 1
if _counts[user_uuid] <= 0:
del _counts[user_uuid]