fix: add get_stream_user dependency for SSE endpoint; allow query-string token for EventSource

This commit is contained in:
2026-04-09 19:20:38 -04:00
parent f20e86826d
commit d2a569496d
2 changed files with 31 additions and 2 deletions

View File

@@ -18,6 +18,35 @@ repo = SQLiteRepository(db_path=str(DB_PATH))
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_stream_user(request: Request, token: Optional[str] = None) -> str:
"""Auth dependency for SSE endpoints — accepts Bearer header OR ?token= query param.
EventSource does not support custom headers, so the query-string fallback is intentional here only.
"""
_credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
auth_header = request.headers.get("Authorization")
resolved: str | None = (
auth_header.split(" ", 1)[1]
if auth_header and auth_header.startswith("Bearer ")
else token
)
if not resolved:
raise _credentials_exception
try:
_payload: dict[str, Any] = jwt.decode(resolved, SECRET_KEY, algorithms=[ALGORITHM])
_user_uuid: Optional[str] = _payload.get("uuid")
if _user_uuid is None:
raise _credentials_exception
return _user_uuid
except jwt.PyJWTError:
raise _credentials_exception
async def get_current_user(request: Request) -> str: async def get_current_user(request: Request) -> str:
_credentials_exception = HTTPException( _credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,

View File

@@ -6,7 +6,7 @@ from typing import AsyncGenerator, Optional
from fastapi import APIRouter, Depends, Query, Request from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from decnet.web.dependencies import get_current_user, repo from decnet.web.dependencies import get_stream_user, repo
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -21,7 +21,7 @@ async def stream_events(
search: Optional[str] = None, search: Optional[str] = None,
start_time: Optional[str] = None, start_time: Optional[str] = None,
end_time: Optional[str] = None, end_time: Optional[str] = None,
current_user: str = Depends(get_current_user) current_user: str = Depends(get_stream_user)
) -> StreamingResponse: ) -> StreamingResponse:
async def event_generator() -> AsyncGenerator[str, None]: async def event_generator() -> AsyncGenerator[str, None]: