from typing import Any, Optional

import jwt
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer

from decnet.web.auth import ALGORITHM, SECRET_KEY
from decnet.web.db.repository import BaseRepository
from decnet.web.db.factory import get_repository

# Shared repository singleton
_repo: Optional[BaseRepository] = None


def get_repo() -> BaseRepository:
    """FastAPI dependency to inject the configured repository.

    The repository is built by ``get_repository()`` on the first call and the
    same instance is returned on every subsequent call.
    """
    global _repo
    if _repo is None:
        _repo = get_repository()
    return _repo


# Resolved eagerly at import time, so importing this module constructs the
# repository. Kept as a module attribute (presumably imported/patched elsewhere).
repo = get_repo()

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _unauthorized() -> HTTPException:
    """Build the 401 raised for any missing, malformed or rejected token."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )


def _password_change_required() -> HTTPException:
    """Build the 403 raised while a user's ``must_change_password`` flag is set."""
    return HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Password change required before accessing this resource",
    )


def _insufficient_permissions() -> HTTPException:
    """Build the 403 raised when a user is missing or has a disallowed role."""
    return HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Insufficient permissions",
    )


def _bearer_from_header(request: Request) -> Optional[str]:
    """Return the credentials of a ``Bearer`` Authorization header, if any.

    Returns ``None`` when the header is absent, has no space separator, or uses
    a different scheme. The scheme name is compared case-insensitively, since
    RFC 7235 §2.1 defines auth-scheme tokens as case-insensitive. The result
    may be an empty string when the header is just ``"Bearer "``; callers treat
    that as a present-but-invalid token.
    """
    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return None
    scheme, sep, credentials = auth_header.partition(" ")
    if not sep or scheme.lower() != "bearer":
        return None
    # Tolerate extra whitespace between the scheme and the token.
    return credentials.strip()


def _uuid_from_jwt(token: Optional[str]) -> str:
    """Validate *token* with ``jwt.decode`` and return its ``uuid`` claim.

    Raises:
        HTTPException: 401 if *token* is ``None``/empty, if ``jwt.decode``
            rejects it (any ``jwt.PyJWTError``), or if the decoded payload has
            no ``uuid`` claim.
    """
    if not token:
        raise _unauthorized()
    try:
        payload: dict[str, Any] = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as exc:
        raise _unauthorized() from exc
    user_uuid: Optional[str] = payload.get("uuid")
    if user_uuid is None:
        raise _unauthorized()
    return user_uuid


# ---------------------------------------------------------------------------
# Authentication dependencies
# ---------------------------------------------------------------------------

async def get_stream_user(request: Request, token: Optional[str] = None) -> str:
    """Auth dependency for SSE endpoints — accepts Bearer header OR ?token= query param.

    EventSource does not support custom headers, so the query-string fallback
    is intentional here only. A Bearer Authorization header, when present,
    takes precedence over the ``token`` query parameter.

    This only authenticates; it does not enforce ``must_change_password``.

    Returns:
        The ``uuid`` claim of the validated token.

    Raises:
        HTTPException: 401 if no usable token is supplied or it fails validation.
    """
    header_token = _bearer_from_header(request)
    return _uuid_from_jwt(header_token if header_token is not None else token)


async def _decode_token(request: Request) -> str:
    """Decode and validate a Bearer JWT from the Authorization header, returning the user UUID.

    Raises:
        HTTPException: 401 if the header is missing, not ``Bearer``, or the
            token fails validation.
    """
    return _uuid_from_jwt(_bearer_from_header(request))


async def get_current_user(request: Request) -> str:
    """Auth dependency — enforces must_change_password.

    Returns:
        The authenticated user's UUID.

    Raises:
        HTTPException: 401 for an invalid token; 403 if the user's
            ``must_change_password`` flag is set.
    """
    user_uuid = await _decode_token(request)
    user = await repo.get_user_by_uuid(user_uuid)
    # NOTE(review): if no user exists for this UUID the token is still accepted
    # and the UUID returned; role-gated dependencies reject that case. Confirm
    # this is intended for endpoints that depend on get_current_user directly.
    if user and user.get("must_change_password"):
        raise _password_change_required()
    return user_uuid


async def get_current_user_unchecked(request: Request) -> str:
    """Auth dependency — skips must_change_password enforcement.

    Use only for endpoints that must remain reachable with the flag set
    (e.g. change-password).
    """
    return await _decode_token(request)


# ---------------------------------------------------------------------------
# Role-based access control
# ---------------------------------------------------------------------------

def require_role(*allowed_roles: str):
    """Factory that returns a FastAPI dependency enforcing role membership.

    The returned dependency chains from ``get_current_user`` (JWT +
    must_change_password) then verifies the user's role is in *allowed_roles*.
    It returns the full user dict so endpoints can inspect ``user["uuid"]``,
    ``user["role"]``, etc. without a lookup of their own. (The check itself
    fetches the user again after ``get_current_user``.)

    Raises (from the dependency):
        HTTPException: 401/403 from ``get_current_user``; 403 if the user does
            not exist or their role is not allowed.
    """
    async def _check(current_user: str = Depends(get_current_user)) -> dict:
        user = await repo.get_user_by_uuid(current_user)
        if not user or user["role"] not in allowed_roles:
            raise _insufficient_permissions()
        return user

    return _check


def require_stream_role(*allowed_roles: str):
    """Like ``require_role`` but for SSE endpoints that accept a query-param token.

    Enforces the same checks as ``require_role``, in the same order:
    token validity (401), ``must_change_password`` (403), then role
    membership (403). Returns the full user dict.
    """
    async def _check(request: Request, token: Optional[str] = None) -> dict:
        user_uuid = await get_stream_user(request, token)
        user = await repo.get_user_by_uuid(user_uuid)
        # Mirror get_current_user: previously SSE endpoints skipped this check,
        # letting users with a pending forced password change open streams.
        if user and user.get("must_change_password"):
            raise _password_change_required()
        if not user or user["role"] not in allowed_roles:
            raise _insufficient_permissions()
        return user

    return _check


require_admin = require_role("admin")
require_viewer = require_role("viewer", "admin")
require_stream_viewer = require_stream_role("viewer", "admin")