feat(auth): jti claim and token-revocation store

Stateless JWTs had no revocation path: a stolen token stayed valid for
its full 24h even after the victim changed their password, and there was
no logout. This lays the foundation for revoking them.

- User.tokens_valid_from: per-user bulk-revocation cutoff (compared against
  the token's iat). RevokedToken(jti PK, exp): single-token denylist, pruned
  opportunistically on insert so it never outgrows live-but-revoked tokens.
- login() now mints a jti; create_access_token already stamps iat/exp.
- repo.revoke_token / is_token_revoked / set_tokens_valid_from (abstract +
  shared sqlmodel impl + DummyRepo coverage stubs).
- Centralized validate path in dependencies.py: every auth dependency now
  resolves the user and fails closed on (1) missing jti (legacy/pre-deploy
  token -> one forced re-login), (2) iat before the cutoff, (3) a denylisted
  jti. Denylist lookups ride a 10s membership cache mirroring the user cache.
- Contract/fuzz harness seeds its fixed-uuid principal under
  DECNET_CONTRACT_TEST so its minted token resolves to a live admin user.
This commit is contained in:
2026-05-30 18:18:41 -04:00
parent fdb6507c6f
commit 698ecaa322
11 changed files with 392 additions and 39 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
import asyncio
import time
from datetime import datetime, timezone
from typing import Any, Optional
import jwt
@@ -43,13 +44,24 @@ _USERNAME_TTL = 5.0
_username_cache: dict[str, tuple[dict[str, Any], float]] = {}
_username_cache_lock: Optional[asyncio.Lock] = None
# Denylist membership cache for revoked jti lookups. Same 10s envelope as the
# user cache: a token revoked elsewhere stops working within _REVOKED_TTL. In
# this process we drop the stale entry on revoke (see invalidate_token_cache),
# so logout is immediate locally; the TTL only bounds cross-worker staleness.
_REVOKED_TTL = 10.0
_revoked_cache: dict[str, tuple[bool, float]] = {}
_revoked_cache_lock: Optional[asyncio.Lock] = None
def _reset_user_cache() -> None:
global _user_cache, _user_cache_lock, _username_cache, _username_cache_lock
global _revoked_cache, _revoked_cache_lock
_user_cache = {}
_user_cache_lock = None
_username_cache = {}
_username_cache_lock = None
_revoked_cache = {}
_revoked_cache_lock = None
def invalidate_user_cache(user_uuid: Optional[str] = None) -> None:
@@ -66,6 +78,16 @@ def invalidate_user_cache(user_uuid: Optional[str] = None) -> None:
_username_cache.clear()
def invalidate_token_cache(jti: Optional[str] = None) -> None:
"""Drop a single jti (or the whole denylist cache) so the next request
re-reads revocation state from the DB. Called right after ``revoke_token``
so a logged-out token stops working immediately in this process."""
if jti is None:
_revoked_cache.clear()
else:
_revoked_cache.pop(jti, None)
async def get_user_by_username_cached(username: str) -> Optional[dict[str, Any]]:
"""Cached read of get_user_by_username for the login path.
@@ -108,6 +130,24 @@ async def _get_user_cached(user_uuid: str) -> Optional[dict[str, Any]]:
return user
async def _is_revoked_cached(jti: str) -> bool:
global _revoked_cache_lock
entry = _revoked_cache.get(jti)
now = time.monotonic()
if entry is not None and now - entry[1] < _REVOKED_TTL:
return entry[0]
if _revoked_cache_lock is None:
_revoked_cache_lock = asyncio.Lock()
async with _revoked_cache_lock:
entry = _revoked_cache.get(jti)
now = time.monotonic()
if entry is not None and now - entry[1] < _REVOKED_TTL:
return entry[0]
revoked = await repo.is_token_revoked(jti)
_revoked_cache[jti] = (revoked, time.monotonic())
return revoked
_CREDENTIALS_EXCEPTION = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
@@ -115,16 +155,53 @@ _CREDENTIALS_EXCEPTION = HTTPException(
)
def _jwt_to_uuid(token: str) -> str:
"""Decode a raw JWT string and return the user UUID, or raise 401."""
def _epoch(value: Any) -> float:
"""Coerce a JWT ``iat`` (int seconds) or a stored datetime to UTC epoch
seconds so the two can be compared regardless of source. Naive datetimes
(SQLite round-trips lose tzinfo) are treated as the UTC we wrote."""
if isinstance(value, (int, float)):
return float(value)
if isinstance(value, datetime):
aware = value.replace(tzinfo=timezone.utc) if value.tzinfo is None else value
return aware.timestamp()
raise _CREDENTIALS_EXCEPTION
def _decode_payload(token: str) -> dict[str, Any]:
"""Decode + signature/expiry-verify a raw JWT, or raise 401."""
try:
payload: dict[str, Any] = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_uuid: Optional[str] = payload.get("uuid")
if user_uuid is None:
raise _CREDENTIALS_EXCEPTION
return user_uuid
except jwt.PyJWTError:
raise _CREDENTIALS_EXCEPTION
if payload.get("uuid") is None:
raise _CREDENTIALS_EXCEPTION
return payload
async def _resolve_token(token: str) -> tuple[str, dict[str, Any]]:
"""Decode a token, load its user, and enforce revocation. Returns
``(user_uuid, user_dict)`` or raises 401. Single chokepoint so every auth
path (header, SSE query param, role gates) shares identical revocation
semantics."""
payload = _decode_payload(token)
user_uuid: str = payload["uuid"]
user = await _get_user_cached(user_uuid)
if not user:
# Unknown / deleted user — also covers the user-delete revocation case.
raise _CREDENTIALS_EXCEPTION
# 1. Legacy tokens minted before jti existed cannot be revoked — fail closed
# so a deploy of this feature forces exactly one re-login.
jti = payload.get("jti")
if not jti:
raise _CREDENTIALS_EXCEPTION
# 2. Bulk cutoff: password/role change moves tokens_valid_from forward.
cutoff = user.get("tokens_valid_from")
if cutoff is not None and _epoch(payload.get("iat", 0)) < _epoch(cutoff):
raise _CREDENTIALS_EXCEPTION
# 3. Single-token denylist (logout).
if await _is_revoked_cached(jti):
raise _CREDENTIALS_EXCEPTION
return user_uuid, user
def _bearer_from_header(request: Request) -> Optional[str]:
@@ -134,6 +211,24 @@ def _bearer_from_header(request: Request) -> Optional[str]:
return None
async def _resolve_request(request: Request) -> tuple[str, dict[str, Any]]:
"""Bearer-header variant of :func:`_resolve_token`."""
token = _bearer_from_header(request)
if not token:
raise _CREDENTIALS_EXCEPTION
return await _resolve_token(token)
def get_token_claims(request: Request) -> dict[str, Any]:
"""Return the validated claims of the presented Bearer token (decode +
signature + revocation checks). Used by logout, which needs the token's own
``jti``/``exp`` to denylist *this* session even for must_change users."""
token = _bearer_from_header(request)
if not token:
raise _CREDENTIALS_EXCEPTION
return _decode_payload(token)
async def get_stream_user(request: Request, token: Optional[str] = None) -> str:
"""Auth dependency for SSE endpoints — accepts Bearer header OR ?token= query param.
EventSource does not support custom headers, so the query-string fallback is intentional here only.
@@ -141,22 +236,16 @@ async def get_stream_user(request: Request, token: Optional[str] = None) -> str:
resolved = _bearer_from_header(request) or token
if not resolved:
raise _CREDENTIALS_EXCEPTION
return _jwt_to_uuid(resolved)
async def _decode_token(request: Request) -> str:
"""Decode and validate a Bearer JWT, returning the user UUID."""
token = _bearer_from_header(request)
if not token:
raise _CREDENTIALS_EXCEPTION
return _jwt_to_uuid(token)
# Decode-only: returns the uuid. Revocation/role enforcement happens in
# require_stream_role (the sole production caller), which runs the full
# _resolve_token path. Kept thin so its decode contract stays unit-testable.
return _decode_payload(resolved)["uuid"]
async def get_current_user(request: Request) -> str:
"""Auth dependency — enforces must_change_password."""
_user_uuid = await _decode_token(request)
_user = await _get_user_cached(_user_uuid)
if _user and _user.get("must_change_password"):
_user_uuid, _user = await _resolve_request(request)
if _user.get("must_change_password"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Password change required before accessing this resource",
@@ -165,10 +254,12 @@ async def get_current_user(request: Request) -> str:
async def get_current_user_unchecked(request: Request) -> str:
"""Auth dependency — skips must_change_password enforcement.
"""Auth dependency — skips must_change_password enforcement (but still
enforces signature, user existence, and revocation).
Use only for endpoints that must remain reachable with the flag set (e.g. change-password).
"""
return await _decode_token(request)
_user_uuid, _user = await _resolve_request(request)
return _user_uuid
# ---------------------------------------------------------------------------
@@ -184,14 +275,7 @@ def require_role(*allowed_roles: str):
endpoints can inspect ``user["uuid"]``, ``user["role"]``, etc.
"""
async def _check(request: Request) -> dict:
user_uuid = await _decode_token(request)
user = await _get_user_cached(user_uuid)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
_user_uuid, user = await _resolve_request(request)
if user.get("must_change_password"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -209,9 +293,11 @@ def require_role(*allowed_roles: str):
def require_stream_role(*allowed_roles: str):
"""Like ``require_role`` but for SSE endpoints that accept a query-param token."""
async def _check(request: Request, token: Optional[str] = None) -> dict:
user_uuid = await get_stream_user(request, token)
user = await _get_user_cached(user_uuid)
if not user or user["role"] not in allowed_roles:
resolved = _bearer_from_header(request) or token
if not resolved:
raise _CREDENTIALS_EXCEPTION
_user_uuid, user = await _resolve_token(resolved)
if user["role"] not in allowed_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions",