feat(api/sse): per-user connection cap + viewer-safe invariant
New decnet/web/sse_limits.py provides sse_connection_slot, an async context manager that counts live SSE connections per user UUID and raises 429 when a per-user cap is exceeded (default 5, override via DECNET_SSE_MAX_PER_USER). Wired into both SSE generators as their first async with, so the cap check fires before any stream data is yielded. The cap must sit inside the generator — StreamingResponse returns before the generator body runs, so a handler-level wrapper would release the slot immediately. Put prefetch + slot + loop all under the one async with. Also documents F6/I (role leakage) as mitigated-by-construction via handler docstrings: every event type on both streams wraps data already reachable via viewer-gated REST, so no per-event filter is needed until a new event family is introduced. The invariant is written into the handler docstrings so a future PR can't silently add admin-only events. Resolves THREAT_MODEL F6/I and F6/D.
This commit is contained in:
@@ -10,6 +10,7 @@ from decnet.env import DECNET_DEVELOPER
|
||||
from decnet.logging import get_logger
|
||||
from decnet.telemetry import traced as _traced, get_tracer as _get_tracer
|
||||
from decnet.web.dependencies import require_stream_viewer, repo
|
||||
from decnet.web.sse_limits import sse_connection_slot
|
||||
|
||||
log = get_logger("api")
|
||||
|
||||
@@ -52,7 +53,8 @@ def _build_trace_links(logs: list[dict]) -> list:
|
||||
},
|
||||
401: {"description": "Could not validate credentials"},
|
||||
403: {"description": "Insufficient permissions"},
|
||||
422: {"description": "Validation error"}
|
||||
422: {"description": "Validation error"},
|
||||
429: {"description": "Per-user SSE connection cap reached"},
|
||||
},
|
||||
)
|
||||
@_traced("api.stream_events")
|
||||
@@ -65,77 +67,79 @@ async def stream_events(
|
||||
max_output: Optional[int] = Query(None, alias="maxOutput"),
|
||||
user: dict = Depends(require_stream_viewer)
|
||||
) -> StreamingResponse:
|
||||
|
||||
# Prefetch the initial snapshot before entering the streaming generator.
|
||||
# With asyncmy (pure async TCP I/O), the first DB await inside the generator
|
||||
# fires immediately after the ASGI layer sends the keepalive chunk — the HTTP
|
||||
# write and the MySQL read compete for asyncio I/O callbacks and the MySQL
|
||||
# callback can stall. Running these here (normal async context, no streaming)
|
||||
# avoids that race entirely. aiosqlite is immune because it runs SQLite in a
|
||||
# thread, decoupled from the event loop's I/O scheduler.
|
||||
_start_id = last_event_id if last_event_id != 0 else await repo.get_max_log_id()
|
||||
_initial_stats = await repo.get_stats_summary()
|
||||
_initial_histogram = await repo.get_log_histogram(
|
||||
search=search, start_time=start_time, end_time=end_time, interval_minutes=15,
|
||||
)
|
||||
# Event types emitted on this stream: logs, stats, histogram.
|
||||
# All three are viewer-safe — same data is reachable via /logs and
|
||||
# /stats (viewer-gated REST). Adding a new event family here
|
||||
# requires a threat-model review for F6/I (role leakage).
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
last_id = _start_id
|
||||
stats_interval_sec = 10
|
||||
loops_since_stats = 0
|
||||
emitted_chunks = 0
|
||||
try:
|
||||
yield ": keepalive\n\n" # flush headers immediately
|
||||
async with sse_connection_slot(user["uuid"]):
|
||||
# Prefetch the initial snapshot before the first yield.
|
||||
# With asyncmy (pure async TCP I/O), a DB await AFTER the first
|
||||
# yield races with the HTTP write callback; running DB reads
|
||||
# here (pre-yield, normal coroutine context) avoids that.
|
||||
# aiosqlite is immune because SQLite runs on a worker thread.
|
||||
_start_id = last_event_id if last_event_id != 0 else await repo.get_max_log_id()
|
||||
_initial_stats = await repo.get_stats_summary()
|
||||
_initial_histogram = await repo.get_log_histogram(
|
||||
search=search, start_time=start_time, end_time=end_time, interval_minutes=15,
|
||||
)
|
||||
last_id = _start_id
|
||||
stats_interval_sec = 10
|
||||
loops_since_stats = 0
|
||||
emitted_chunks = 0
|
||||
try:
|
||||
yield ": keepalive\n\n" # flush headers immediately
|
||||
|
||||
# Emit pre-fetched initial snapshot — no DB calls in generator until the loop
|
||||
yield f"event: message\ndata: {orjson.dumps({'type': 'stats', 'data': _initial_stats}).decode()}\n\n"
|
||||
yield f"event: message\ndata: {orjson.dumps({'type': 'histogram', 'data': _initial_histogram}).decode()}\n\n"
|
||||
# Emit pre-fetched initial snapshot — no DB calls in generator until the loop
|
||||
yield f"event: message\ndata: {orjson.dumps({'type': 'stats', 'data': _initial_stats}).decode()}\n\n"
|
||||
yield f"event: message\ndata: {orjson.dumps({'type': 'histogram', 'data': _initial_histogram}).decode()}\n\n"
|
||||
|
||||
while True:
|
||||
if DECNET_DEVELOPER and max_output is not None:
|
||||
emitted_chunks += 1
|
||||
if emitted_chunks > max_output:
|
||||
log.debug("Developer mode: max_output reached (%d), closing stream", max_output)
|
||||
while True:
|
||||
if DECNET_DEVELOPER and max_output is not None:
|
||||
emitted_chunks += 1
|
||||
if emitted_chunks > max_output:
|
||||
log.debug("Developer mode: max_output reached (%d), closing stream", max_output)
|
||||
break
|
||||
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
new_logs = await repo.get_logs_after_id(
|
||||
last_id, limit=50, search=search,
|
||||
start_time=start_time, end_time=end_time,
|
||||
)
|
||||
if new_logs:
|
||||
last_id = max(entry["id"] for entry in new_logs)
|
||||
# Create a span linking back to the ingestion traces
|
||||
# stored in each log row, closing the pipeline gap.
|
||||
_links = _build_trace_links(new_logs)
|
||||
_tracer = _get_tracer("sse")
|
||||
with _tracer.start_as_current_span(
|
||||
"sse.emit_logs", links=_links,
|
||||
attributes={"log_count": len(new_logs)},
|
||||
):
|
||||
yield f"event: message\ndata: {orjson.dumps({'type': 'logs', 'data': new_logs}).decode()}\n\n"
|
||||
loops_since_stats = stats_interval_sec
|
||||
|
||||
if loops_since_stats >= stats_interval_sec:
|
||||
stats = await repo.get_stats_summary()
|
||||
yield f"event: message\ndata: {orjson.dumps({'type': 'stats', 'data': stats}).decode()}\n\n"
|
||||
histogram = await repo.get_log_histogram(
|
||||
search=search, start_time=start_time,
|
||||
end_time=end_time, interval_minutes=15,
|
||||
new_logs = await repo.get_logs_after_id(
|
||||
last_id, limit=50, search=search,
|
||||
start_time=start_time, end_time=end_time,
|
||||
)
|
||||
yield f"event: message\ndata: {orjson.dumps({'type': 'histogram', 'data': histogram}).decode()}\n\n"
|
||||
loops_since_stats = 0
|
||||
if new_logs:
|
||||
last_id = max(entry["id"] for entry in new_logs)
|
||||
# Create a span linking back to the ingestion traces
|
||||
# stored in each log row, closing the pipeline gap.
|
||||
_links = _build_trace_links(new_logs)
|
||||
_tracer = _get_tracer("sse")
|
||||
with _tracer.start_as_current_span(
|
||||
"sse.emit_logs", links=_links,
|
||||
attributes={"log_count": len(new_logs)},
|
||||
):
|
||||
yield f"event: message\ndata: {orjson.dumps({'type': 'logs', 'data': new_logs}).decode()}\n\n"
|
||||
loops_since_stats = stats_interval_sec
|
||||
|
||||
loops_since_stats += 1
|
||||
if loops_since_stats >= stats_interval_sec:
|
||||
stats = await repo.get_stats_summary()
|
||||
yield f"event: message\ndata: {orjson.dumps({'type': 'stats', 'data': stats}).decode()}\n\n"
|
||||
histogram = await repo.get_log_histogram(
|
||||
search=search, start_time=start_time,
|
||||
end_time=end_time, interval_minutes=15,
|
||||
)
|
||||
yield f"event: message\ndata: {orjson.dumps({'type': 'histogram', 'data': histogram}).decode()}\n\n"
|
||||
loops_since_stats = 0
|
||||
|
||||
await asyncio.sleep(1)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
log.exception("SSE stream error for user %s", last_event_id)
|
||||
yield f"event: error\ndata: {orjson.dumps({'type': 'error', 'message': 'Stream interrupted'}).decode()}\n\n"
|
||||
loops_since_stats += 1
|
||||
|
||||
await asyncio.sleep(1)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
log.exception("SSE stream error for user %s", last_event_id)
|
||||
yield f"event: error\ndata: {orjson.dumps({'type': 'error', 'message': 'Stream interrupted'}).decode()}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
|
||||
@@ -26,6 +26,7 @@ from decnet.bus.app import get_app_bus
|
||||
from decnet.logging import get_logger
|
||||
from decnet.telemetry import traced as _traced
|
||||
from decnet.web.dependencies import repo, require_stream_viewer
|
||||
from decnet.web.sse_limits import sse_connection_slot
|
||||
|
||||
from ._guards import get_topology_or_404
|
||||
|
||||
@@ -53,14 +54,20 @@ def _format_sse(event_name: str, data: dict) -> str:
|
||||
401: {"description": "Could not validate credentials"},
|
||||
403: {"description": "Insufficient permissions"},
|
||||
404: {"description": "Topology not found"},
|
||||
429: {"description": "Per-user SSE connection cap reached"},
|
||||
},
|
||||
)
|
||||
@_traced("api.topology.events")
|
||||
async def api_topology_events(
|
||||
topology_id: str,
|
||||
request: Request,
|
||||
_user: dict = Depends(require_stream_viewer),
|
||||
user: dict = Depends(require_stream_viewer),
|
||||
) -> StreamingResponse:
|
||||
# Event types emitted: snapshot, status, mutation.{enqueued,
|
||||
# applying,applied,failed}. All wrap bus events whose payload is
|
||||
# also reachable via viewer-gated REST (GET /topologies/{id},
|
||||
# GET /topologies/{id}/mutations). Adding a new event family here
|
||||
# requires a threat-model review for F6/I (role leakage).
|
||||
topo = await get_topology_or_404(topology_id)
|
||||
snapshot_status = topo["status"]
|
||||
in_flight: list[dict] = []
|
||||
@@ -68,64 +75,65 @@ async def api_topology_events(
|
||||
in_flight.extend(await repo.list_topology_mutations(topology_id, state=state))
|
||||
|
||||
async def generator() -> AsyncGenerator[str, None]:
|
||||
# Flush headers immediately so the browser's EventSource sees a
|
||||
# live connection before the first real event arrives.
|
||||
yield ": keepalive\n\n"
|
||||
async with sse_connection_slot(user["uuid"]):
|
||||
# Flush headers immediately so the browser's EventSource sees a
|
||||
# live connection before the first real event arrives.
|
||||
yield ": keepalive\n\n"
|
||||
|
||||
# One-shot snapshot — pair the current topology status with any
|
||||
# mutations the mutator is still holding, so the client buffer
|
||||
# can render an accurate "already in flight" state.
|
||||
yield _format_sse("snapshot", {
|
||||
"topology_id": topology_id,
|
||||
"status": snapshot_status,
|
||||
"in_flight": in_flight,
|
||||
})
|
||||
# One-shot snapshot — pair the current topology status with any
|
||||
# mutations the mutator is still holding, so the client buffer
|
||||
# can render an accurate "already in flight" state.
|
||||
yield _format_sse("snapshot", {
|
||||
"topology_id": topology_id,
|
||||
"status": snapshot_status,
|
||||
"in_flight": in_flight,
|
||||
})
|
||||
|
||||
bus = await get_app_bus()
|
||||
if bus is None:
|
||||
# Bus disabled (NullBus) or unreachable. The snapshot is
|
||||
# still useful; we idle on keepalives so the client stays
|
||||
# connected and will re-poll on its own timers.
|
||||
while not await request.is_disconnected():
|
||||
try:
|
||||
await asyncio.sleep(_KEEPALIVE_SECS)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
yield ": keepalive\n\n"
|
||||
return
|
||||
|
||||
sub = bus.subscribe(f"{_topics.TOPOLOGY}.{topology_id}.>")
|
||||
try:
|
||||
async with sub:
|
||||
sub_iter = sub.__aiter__()
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
next_task = asyncio.ensure_future(sub_iter.__anext__())
|
||||
bus = await get_app_bus()
|
||||
if bus is None:
|
||||
# Bus disabled (NullBus) or unreachable. The snapshot is
|
||||
# still useful; we idle on keepalives so the client stays
|
||||
# connected and will re-poll on its own timers.
|
||||
while not await request.is_disconnected():
|
||||
try:
|
||||
event = await asyncio.wait_for(next_task, timeout=_KEEPALIVE_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
next_task.cancel()
|
||||
yield ": keepalive\n\n"
|
||||
continue
|
||||
except StopAsyncIteration:
|
||||
await asyncio.sleep(_KEEPALIVE_SECS)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
# Map the bus event onto an SSE ``event:`` name that
|
||||
# the frontend can switch on without parsing topics.
|
||||
yield _format_sse(
|
||||
_sse_name_for(event.topic),
|
||||
{
|
||||
"topic": event.topic,
|
||||
"type": event.type,
|
||||
"ts": event.ts,
|
||||
"payload": event.payload,
|
||||
},
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
log.exception("topology events stream crashed topology_id=%s", topology_id)
|
||||
yield _format_sse("error", {"message": "Stream interrupted"})
|
||||
yield ": keepalive\n\n"
|
||||
return
|
||||
|
||||
sub = bus.subscribe(f"{_topics.TOPOLOGY}.{topology_id}.>")
|
||||
try:
|
||||
async with sub:
|
||||
sub_iter = sub.__aiter__()
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
next_task = asyncio.ensure_future(sub_iter.__anext__())
|
||||
try:
|
||||
event = await asyncio.wait_for(next_task, timeout=_KEEPALIVE_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
next_task.cancel()
|
||||
yield ": keepalive\n\n"
|
||||
continue
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
# Map the bus event onto an SSE ``event:`` name that
|
||||
# the frontend can switch on without parsing topics.
|
||||
yield _format_sse(
|
||||
_sse_name_for(event.topic),
|
||||
{
|
||||
"topic": event.topic,
|
||||
"type": event.type,
|
||||
"ts": event.ts,
|
||||
"payload": event.payload,
|
||||
},
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
log.exception("topology events stream crashed topology_id=%s", topology_id)
|
||||
yield _format_sse("error", {"message": "Stream interrupted"})
|
||||
|
||||
return StreamingResponse(
|
||||
generator(),
|
||||
|
||||
65
decnet/web/sse_limits.py
Normal file
65
decnet/web/sse_limits.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Per-user concurrent SSE connection gate.
|
||||
|
||||
SSE connections are long-lived — a client that opens one per tab
|
||||
forever can exhaust API workers. Module-level dict + async lock keeps
|
||||
the fast path cheap (a dict lookup) while the lock keeps check-and-
|
||||
increment atomic across concurrent handshakes.
|
||||
|
||||
The slot must wrap the generator's own lifetime, not just the handler
|
||||
call, because StreamingResponse returns before the generator body
|
||||
runs. Call it as the first statement inside the generator — an
|
||||
HTTPException raised before the first yield bubbles back to the client
|
||||
as a normal HTTP response.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
DEFAULT_CAP = 5
|
||||
_MAX_PER_USER = int(os.environ.get("DECNET_SSE_MAX_PER_USER", DEFAULT_CAP))
|
||||
_counts: dict[str, int] = defaultdict(int)
|
||||
_lock: asyncio.Lock | None = None
|
||||
|
||||
|
||||
def _get_lock() -> asyncio.Lock:
|
||||
global _lock
|
||||
if _lock is None:
|
||||
_lock = asyncio.Lock()
|
||||
return _lock
|
||||
|
||||
|
||||
def _reset_for_tests() -> None:
|
||||
"""Clear counters + lock between tests. The lock is rebuilt lazily
|
||||
so a fixture can reset state without worrying about event-loop
|
||||
binding from a previous test."""
|
||||
global _lock
|
||||
_counts.clear()
|
||||
_lock = None
|
||||
|
||||
|
||||
def current_count(user_uuid: str) -> int:
|
||||
"""Snapshot helper — tests and diagnostics only."""
|
||||
return _counts.get(user_uuid, 0)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_connection_slot(user_uuid: str):
|
||||
async with _get_lock():
|
||||
if _counts[user_uuid] >= _MAX_PER_USER:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f"SSE connection limit ({_MAX_PER_USER}) reached",
|
||||
)
|
||||
_counts[user_uuid] += 1
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
async with _get_lock():
|
||||
_counts[user_uuid] -= 1
|
||||
if _counts[user_uuid] <= 0:
|
||||
del _counts[user_uuid]
|
||||
Reference in New Issue
Block a user