diff --git a/decnet/web/db/models/__init__.py b/decnet/web/db/models/__init__.py index f36ec4af..7c37dea8 100644 --- a/decnet/web/db/models/__init__.py +++ b/decnet/web/db/models/__init__.py @@ -38,6 +38,7 @@ from .auth import ( GlobalMutationIntervalRequest, LoginRequest, ResetUserPasswordRequest, + RevokedToken, Token, UpdateUserRoleRequest, User, @@ -254,6 +255,7 @@ __all__ = [ "GlobalMutationIntervalRequest", "LoginRequest", "ResetUserPasswordRequest", + "RevokedToken", "Token", "UpdateUserRoleRequest", "User", diff --git a/decnet/web/db/models/auth.py b/decnet/web/db/models/auth.py index 2ab64bae..269133ac 100644 --- a/decnet/web/db/models/auth.py +++ b/decnet/web/db/models/auth.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: AGPL-3.0-or-later """Auth + user-management tables and DTOs.""" -from typing import List, Literal +from datetime import datetime, timezone +from typing import List, Literal, Optional from pydantic import BaseModel, Field as PydanticField from sqlmodel import Field, SQLModel @@ -13,6 +14,25 @@ class User(SQLModel, table=True): password_hash: str role: str = Field(default="viewer") must_change_password: bool = Field(default=False) + # Bulk session-revocation cutoff: any token whose ``iat`` predates this + # instant is rejected. Bumped to "now" on password change, role change, + # and admin password reset. NULL means no bulk revocation has occurred. + tokens_valid_from: Optional[datetime] = Field(default=None) + + +class RevokedToken(SQLModel, table=True): + """A single JWT explicitly revoked via logout, keyed on its ``jti``. + + This denylist holds only explicitly-revoked, not-yet-expired tokens, so it + stays tiny — ``revoke_token`` opportunistically prunes rows past expiry on + every insert. Bulk "log out everywhere" events use ``User.tokens_valid_from`` + instead, because there is no per-user registry of live ``jti``s to enumerate. + """ + __tablename__ = "revoked_tokens" + jti: str = Field(primary_key=True) + user_uuid: str = Field(index=True) # User.uuid; no FK (independent audit row) + expires_at: datetime = Field(index=True) # token exp; row is prunable past this + revoked_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) # --- API Request/Response Models (Pydantic) --- diff --git a/decnet/web/db/repository.py b/decnet/web/db/repository.py index 41c2971a..7325b1f4 100644 --- a/decnet/web/db/repository.py +++ b/decnet/web/db/repository.py @@ -114,6 +114,25 @@ class BaseRepository(ABC): """Update a user's role.""" pass + @abstractmethod + async def revoke_token(self, jti: str, user_uuid: str, expires_at: datetime) -> None: + """Add a token's ``jti`` to the logout denylist. + + Implementations also prune rows whose ``expires_at`` has passed, so the + denylist never outgrows the set of live-but-revoked tokens. + """ + pass + + @abstractmethod + async def is_token_revoked(self, jti: str) -> bool: + """True if ``jti`` is currently on the logout denylist.""" + pass + + @abstractmethod + async def set_tokens_valid_from(self, user_uuid: str, ts: datetime) -> None: + """Bulk-revoke: reject every token for this user issued before ``ts``.""" + pass + @abstractmethod async def purge_logs_and_bounties(self) -> dict[str, int]: """Delete all logs, bounties, and attacker profiles. Returns counts of deleted rows.""" diff --git a/decnet/web/db/sqlmodel_repo/__init__.py b/decnet/web/db/sqlmodel_repo/__init__.py index 8c153454..ac9aafe9 100644 --- a/decnet/web/db/sqlmodel_repo/__init__.py +++ b/decnet/web/db/sqlmodel_repo/__init__.py @@ -14,6 +14,7 @@ from __future__ import annotations import asyncio import json +import os import orjson import uuid @@ -57,6 +58,11 @@ from decnet.web.db.sqlmodel_repo.tarpit import TarpitMixin from decnet.web.db.sqlmodel_repo.ttp import TTPMixin from decnet.web.db.sqlmodel_repo.webhooks import WebhooksMixin +# Fixed principal the schemathesis contract harness mints its token for; seeded +# only under DECNET_CONTRACT_TEST (see _ensure_contract_user). Kept in sync with +# tests/api/test_schemathesis.py. +CONTRACT_TEST_USER_UUID = "00000000-0000-0000-0000-000000000001" + class SQLModelRepository( AttackerIntelMixin, @@ -105,6 +111,7 @@ class SQLModelRepository( async with self.engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) await self._ensure_admin_user() + await self._ensure_contract_user() async def reinitialize(self) -> None: """Re-create schema (for tests / reset flows). Does NOT drop existing tables.""" @@ -112,6 +119,7 @@ class SQLModelRepository( async with self.engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) await self._ensure_admin_user() + await self._ensure_contract_user() async def _ensure_admin_user(self) -> None: async with self._session() as session: @@ -137,6 +145,28 @@ class SQLModelRepository( session.add(existing) await session.commit() + async def _ensure_contract_user(self) -> None: + """Seed the fixed-uuid principal the schemathesis contract/fuzz harness + authenticates as. Gated on DECNET_CONTRACT_TEST so it NEVER runs in a + real deployment. Since the post-revocation auth path now requires the + token's user to exist (and not be revoked), the harness's locally-minted + fixed-uuid token must resolve to a live, admin, non-revoked user. The + password hash is random and unusable, so /auth/login can never + authenticate as this principal — only the minted token works.""" + if os.environ.get("DECNET_CONTRACT_TEST") != "true": + return + async with self._session() as session: + if await session.get(User, CONTRACT_TEST_USER_UUID) is not None: + return + session.add(User( + uuid=CONTRACT_TEST_USER_UUID, + username="contract-test", + password_hash=get_password_hash(uuid.uuid4().hex), + role="admin", + must_change_password=False, + )) + await session.commit() + async def _migrate_attackers_table(self) -> None: """Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable).""" return None diff --git a/decnet/web/db/sqlmodel_repo/auth.py b/decnet/web/db/sqlmodel_repo/auth.py index 4a2731bc..09e081e0 100644 --- a/decnet/web/db/sqlmodel_repo/auth.py +++ b/decnet/web/db/sqlmodel_repo/auth.py @@ -2,11 +2,12 @@ """User CRUD.""" from __future__ import annotations +from datetime import datetime, timezone from typing import Any, Optional -from sqlalchemy import select, update +from sqlalchemy import delete, select, update -from decnet.web.db.models import User +from decnet.web.db.models import RevokedToken, User from decnet.web.db.sqlmodel_repo._helpers import _MixinBase @@ -75,3 +76,29 @@ class AuthMixin(_MixinBase): update(User).where(User.uuid == uuid).values(role=role) ) await session.commit() + + async def revoke_token(self, jti: str, user_uuid: str, expires_at: datetime) -> None: + async with self._session() as session: + # Opportunistic prune — the denylist only needs unexpired tokens, so + # purge stale rows on every insert instead of a separate vacuum job. + await session.execute( + delete(RevokedToken).where( + RevokedToken.expires_at < datetime.now(timezone.utc) + ) + ) + if await session.get(RevokedToken, jti) is None: + session.add( + RevokedToken(jti=jti, user_uuid=user_uuid, expires_at=expires_at) + ) + await session.commit() + + async def is_token_revoked(self, jti: str) -> bool: + async with self._session() as session: + return await session.get(RevokedToken, jti) is not None + + async def set_tokens_valid_from(self, user_uuid: str, ts: datetime) -> None: + async with self._session() as session: + await session.execute( + update(User).where(User.uuid == user_uuid).values(tokens_valid_from=ts) + ) + await session.commit() diff --git a/decnet/web/dependencies.py b/decnet/web/dependencies.py index 99982705..78e82359 100644 --- a/decnet/web/dependencies.py +++ b/decnet/web/dependencies.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: AGPL-3.0-or-later import asyncio import time +from datetime import datetime, timezone from typing import Any, Optional import jwt @@ -43,13 +44,24 @@ _USERNAME_TTL = 5.0 _username_cache: dict[str, tuple[dict[str, Any], float]] = {} _username_cache_lock: Optional[asyncio.Lock] = None +# Denylist membership cache for revoked jti lookups. Same 10s envelope as the +# user cache: a token revoked elsewhere stops working within _REVOKED_TTL. In +# this process we drop the stale entry on revoke (see invalidate_token_cache), +# so logout is immediate locally; the TTL only bounds cross-worker staleness. +_REVOKED_TTL = 10.0 +_revoked_cache: dict[str, tuple[bool, float]] = {} +_revoked_cache_lock: Optional[asyncio.Lock] = None + def _reset_user_cache() -> None: global _user_cache, _user_cache_lock, _username_cache, _username_cache_lock + global _revoked_cache, _revoked_cache_lock _user_cache = {} _user_cache_lock = None _username_cache = {} _username_cache_lock = None + _revoked_cache = {} + _revoked_cache_lock = None def invalidate_user_cache(user_uuid: Optional[str] = None) -> None: @@ -66,6 +78,16 @@ def invalidate_user_cache(user_uuid: Optional[str] = None) -> None: _username_cache.clear() +def invalidate_token_cache(jti: Optional[str] = None) -> None: + """Drop a single jti (or the whole denylist cache) so the next request + re-reads revocation state from the DB. Called right after ``revoke_token`` + so a logged-out token stops working immediately in this process.""" + if jti is None: + _revoked_cache.clear() + else: + _revoked_cache.pop(jti, None) + + async def get_user_by_username_cached(username: str) -> Optional[dict[str, Any]]: """Cached read of get_user_by_username for the login path. @@ -108,6 +130,24 @@ async def _get_user_cached(user_uuid: str) -> Optional[dict[str, Any]]: return user +async def _is_revoked_cached(jti: str) -> bool: + global _revoked_cache_lock + entry = _revoked_cache.get(jti) + now = time.monotonic() + if entry is not None and now - entry[1] < _REVOKED_TTL: + return entry[0] + if _revoked_cache_lock is None: + _revoked_cache_lock = asyncio.Lock() + async with _revoked_cache_lock: + entry = _revoked_cache.get(jti) + now = time.monotonic() + if entry is not None and now - entry[1] < _REVOKED_TTL: + return entry[0] + revoked = await repo.is_token_revoked(jti) + _revoked_cache[jti] = (revoked, time.monotonic()) + return revoked + + _CREDENTIALS_EXCEPTION = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", @@ -115,16 +155,53 @@ _CREDENTIALS_EXCEPTION = HTTPException( ) -def _jwt_to_uuid(token: str) -> str: - """Decode a raw JWT string and return the user UUID, or raise 401.""" +def _epoch(value: Any) -> float: + """Coerce a JWT ``iat`` (int seconds) or a stored datetime to UTC epoch + seconds so the two can be compared regardless of source. Naive datetimes + (SQLite round-trips lose tzinfo) are treated as the UTC we wrote.""" + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, datetime): + aware = value.replace(tzinfo=timezone.utc) if value.tzinfo is None else value + return aware.timestamp() + raise _CREDENTIALS_EXCEPTION + + +def _decode_payload(token: str) -> dict[str, Any]: + """Decode + signature/expiry-verify a raw JWT, or raise 401.""" try: payload: dict[str, Any] = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - user_uuid: Optional[str] = payload.get("uuid") - if user_uuid is None: - raise _CREDENTIALS_EXCEPTION - return user_uuid except jwt.PyJWTError: raise _CREDENTIALS_EXCEPTION + if payload.get("uuid") is None: + raise _CREDENTIALS_EXCEPTION + return payload + + +async def _resolve_token(token: str) -> tuple[str, dict[str, Any]]: + """Decode a token, load its user, and enforce revocation. Returns + ``(user_uuid, user_dict)`` or raises 401. Single chokepoint so every auth + path (header, SSE query param, role gates) shares identical revocation + semantics.""" + payload = _decode_payload(token) + user_uuid: str = payload["uuid"] + user = await _get_user_cached(user_uuid) + if not user: + # Unknown / deleted user — also covers the user-delete revocation case. + raise _CREDENTIALS_EXCEPTION + # 1. Legacy tokens minted before jti existed cannot be revoked — fail closed + # so a deploy of this feature forces exactly one re-login. + jti = payload.get("jti") + if not jti: + raise _CREDENTIALS_EXCEPTION + # 2. Bulk cutoff: password/role change moves tokens_valid_from forward. + cutoff = user.get("tokens_valid_from") + if cutoff is not None and _epoch(payload.get("iat", 0)) < _epoch(cutoff): + raise _CREDENTIALS_EXCEPTION + # 3. Single-token denylist (logout). + if await _is_revoked_cached(jti): + raise _CREDENTIALS_EXCEPTION + return user_uuid, user def _bearer_from_header(request: Request) -> Optional[str]: @@ -134,6 +211,24 @@ def _bearer_from_header(request: Request) -> Optional[str]: return None +async def _resolve_request(request: Request) -> tuple[str, dict[str, Any]]: + """Bearer-header variant of :func:`_resolve_token`.""" + token = _bearer_from_header(request) + if not token: + raise _CREDENTIALS_EXCEPTION + return await _resolve_token(token) + + +def get_token_claims(request: Request) -> dict[str, Any]: + """Return the validated claims of the presented Bearer token (decode + + signature + revocation checks). Used by logout, which needs the token's own + ``jti``/``exp`` to denylist *this* session even for must_change users.""" + token = _bearer_from_header(request) + if not token: + raise _CREDENTIALS_EXCEPTION + return _decode_payload(token) + + async def get_stream_user(request: Request, token: Optional[str] = None) -> str: """Auth dependency for SSE endpoints — accepts Bearer header OR ?token= query param. EventSource does not support custom headers, so the query-string fallback is intentional here only. @@ -141,22 +236,16 @@ async def get_stream_user(request: Request, token: Optional[str] = None) -> str: resolved = _bearer_from_header(request) or token if not resolved: raise _CREDENTIALS_EXCEPTION - return _jwt_to_uuid(resolved) - - -async def _decode_token(request: Request) -> str: - """Decode and validate a Bearer JWT, returning the user UUID.""" - token = _bearer_from_header(request) - if not token: - raise _CREDENTIALS_EXCEPTION - return _jwt_to_uuid(token) + # Decode-only: returns the uuid. Revocation/role enforcement happens in + # require_stream_role (the sole production caller), which runs the full + # _resolve_token path. Kept thin so its decode contract stays unit-testable. + return _decode_payload(resolved)["uuid"] async def get_current_user(request: Request) -> str: """Auth dependency — enforces must_change_password.""" - _user_uuid = await _decode_token(request) - _user = await _get_user_cached(_user_uuid) - if _user and _user.get("must_change_password"): + _user_uuid, _user = await _resolve_request(request) + if _user.get("must_change_password"): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Password change required before accessing this resource", @@ -165,10 +254,12 @@ async def get_current_user(request: Request) -> str: async def get_current_user_unchecked(request: Request) -> str: - """Auth dependency — skips must_change_password enforcement. + """Auth dependency — skips must_change_password enforcement (but still + enforces signature, user existence, and revocation). Use only for endpoints that must remain reachable with the flag set (e.g. change-password). """ - return await _decode_token(request) + _user_uuid, _user = await _resolve_request(request) + return _user_uuid # --------------------------------------------------------------------------- @@ -184,14 +275,7 @@ def require_role(*allowed_roles: str): endpoints can inspect ``user["uuid"]``, ``user["role"]``, etc. """ async def _check(request: Request) -> dict: - user_uuid = await _decode_token(request) - user = await _get_user_cached(user_uuid) - if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) + _user_uuid, user = await _resolve_request(request) if user.get("must_change_password"): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -209,9 +293,11 @@ def require_role(*allowed_roles: str): def require_stream_role(*allowed_roles: str): """Like ``require_role`` but for SSE endpoints that accept a query-param token.""" async def _check(request: Request, token: Optional[str] = None) -> dict: - user_uuid = await get_stream_user(request, token) - user = await _get_user_cached(user_uuid) - if not user or user["role"] not in allowed_roles: + resolved = _bearer_from_header(request) or token + if not resolved: + raise _CREDENTIALS_EXCEPTION + _user_uuid, user = await _resolve_token(resolved) + if user["role"] not in allowed_roles: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient permissions", diff --git a/decnet/web/router/auth/api_login.py b/decnet/web/router/auth/api_login.py index aeb7a693..b6d635b1 100644 --- a/decnet/web/router/auth/api_login.py +++ b/decnet/web/router/auth/api_login.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: AGPL-3.0-or-later from datetime import timedelta from typing import Any, Optional +from uuid import uuid4 from fastapi import APIRouter, HTTPException, Request, status @@ -52,9 +53,11 @@ async def login(request: Request, payload: LoginRequest) -> dict[str, Any]: ) _access_token_expires: timedelta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - # Token uses uuid instead of sub + # Token uses uuid instead of sub; jti is the per-token id the denylist + # keys on (logout). create_access_token stamps exp + iat. _access_token: str = create_access_token( - data={"uuid": _user["uuid"]}, expires_delta=_access_token_expires + data={"uuid": _user["uuid"], "jti": uuid4().hex}, + expires_delta=_access_token_expires, ) return { "access_token": _access_token, diff --git a/tests/api/auth/test_token_revocation.py b/tests/api/auth/test_token_revocation.py new file mode 100644 index 00000000..ae405b3b --- /dev/null +++ b/tests/api/auth/test_token_revocation.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: AGPL-3.0-or-later +"""JWT revocation foundation (WI1): jti claim, denylist, and bulk cutoff. + +These exercise the centralized validate path in decnet.web.dependencies through +real HTTP requests, plus the three repository primitives directly. The wiring +into logout / password-change lives in later work items; here we drive the +mechanism by calling the repo + cache helpers the way those endpoints will. +""" +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +import jwt +import pytest + +from decnet.web.auth import create_access_token +from decnet.web.dependencies import ( + invalidate_token_cache, + invalidate_user_cache, + repo, +) + +PROTECTED = "/api/v1/attackers?limit=1" # auth-gated; 200 for an authed viewer/admin + + +def _claims(token: str) -> dict: + return jwt.decode(token, options={"verify_signature": False}) + + +def _auth(token: str) -> dict[str, str]: + return {"Authorization": f"Bearer {token}"} + + +# --------------------------------------------------------------------------- # +# Token shape # +# --------------------------------------------------------------------------- # + +@pytest.mark.asyncio +async def test_login_token_carries_jti_and_iat(client, auth_token): + claims = _claims(auth_token) + assert claims.get("jti"), "login token must carry a jti for the denylist" + assert "iat" in claims and "exp" in claims + + +@pytest.mark.asyncio +async def test_valid_token_is_accepted(client, auth_token): + r = await client.get(PROTECTED, headers=_auth(auth_token)) + assert r.status_code == 200, r.text + + +# --------------------------------------------------------------------------- # +# Fail-closed cases # +# --------------------------------------------------------------------------- # + +@pytest.mark.asyncio +async def test_legacy_token_without_jti_is_rejected(client, auth_token): + # A token minted before this feature (no jti) cannot be revoked, so it is + # refused outright — one forced re-login on deploy. + uuid = _claims(auth_token)["uuid"] + legacy = create_access_token({"uuid": uuid}) # no jti + r = await client.get(PROTECTED, headers=_auth(legacy)) + assert r.status_code == 401 + + +@pytest.mark.asyncio +async def test_token_for_unknown_user_is_rejected(client): + ghost = create_access_token({"uuid": "no-such-user", "jti": "ghost"}) + r = await client.get(PROTECTED, headers=_auth(ghost)) + assert r.status_code == 401 + + +@pytest.mark.asyncio +async def test_revoked_jti_is_rejected(client, auth_token): + claims = _claims(auth_token) + # Sanity: works before revocation. + assert (await client.get(PROTECTED, headers=_auth(auth_token))).status_code == 200 + # Denylist this token's jti the way logout will. + await repo.revoke_token( + claims["jti"], claims["uuid"], + datetime.now(timezone.utc) + timedelta(hours=1), + ) + invalidate_token_cache(claims["jti"]) + r = await client.get(PROTECTED, headers=_auth(auth_token)) + assert r.status_code == 401 + + +@pytest.mark.asyncio +async def test_iat_before_cutoff_is_rejected(client, auth_token): + claims = _claims(auth_token) + assert (await client.get(PROTECTED, headers=_auth(auth_token))).status_code == 200 + # Move the bulk cutoff past this token's iat (what password/role change does). + await repo.set_tokens_valid_from( + claims["uuid"], datetime.now(timezone.utc) + timedelta(hours=1), + ) + invalidate_user_cache(claims["uuid"]) + r = await client.get(PROTECTED, headers=_auth(auth_token)) + assert r.status_code == 401 + + +@pytest.mark.asyncio +async def test_token_issued_after_cutoff_still_works(client, auth_token): + # A cutoff in the PAST must not revoke a token issued now. + claims = _claims(auth_token) + await repo.set_tokens_valid_from( + claims["uuid"], datetime.now(timezone.utc) - timedelta(hours=1), + ) + invalidate_user_cache(claims["uuid"]) + r = await client.get(PROTECTED, headers=_auth(auth_token)) + assert r.status_code == 200, r.text + + +# --------------------------------------------------------------------------- # +# Repository primitives # +# --------------------------------------------------------------------------- # + +@pytest.mark.asyncio +async def test_is_token_revoked_roundtrip(client): + exp = datetime.now(timezone.utc) + timedelta(hours=1) + assert await repo.is_token_revoked("jti-a") is False + await repo.revoke_token("jti-a", "user-1", exp) + assert await repo.is_token_revoked("jti-a") is True + # Idempotent — re-revoking the same jti does not raise. + await repo.revoke_token("jti-a", "user-1", exp) + assert await repo.is_token_revoked("jti-a") is True + + +@pytest.mark.asyncio +async def test_revoke_token_prunes_expired_rows(client): + past = datetime.now(timezone.utc) - timedelta(hours=1) + future = datetime.now(timezone.utc) + timedelta(hours=1) + await repo.revoke_token("expired-jti", "user-1", past) + # Inserting a fresh revocation prunes the already-expired row. + await repo.revoke_token("live-jti", "user-1", future) + assert await repo.is_token_revoked("expired-jti") is False + assert await repo.is_token_revoked("live-jti") is True + + +@pytest.mark.asyncio +async def test_set_tokens_valid_from_persists(client, auth_token): + uuid = _claims(auth_token)["uuid"] + ts = datetime.now(timezone.utc) + await repo.set_tokens_valid_from(uuid, ts) + user = await repo.get_user_by_uuid(uuid) + assert user is not None and user["tokens_valid_from"] is not None diff --git a/tests/api/test_schemathesis.py b/tests/api/test_schemathesis.py index 692ebf87..a2362e5b 100644 --- a/tests/api/test_schemathesis.py +++ b/tests/api/test_schemathesis.py @@ -47,7 +47,11 @@ pytestmark = pytest.mark.xdist_group("schemathesis") import decnet.web.auth decnet.web.auth.SECRET_KEY = TEST_SECRET -TEST_TOKEN = create_access_token({"uuid": "00000000-0000-0000-0000-000000000001"}) +# jti is mandatory post token-revocation; the matching user is seeded by the +# server under DECNET_CONTRACT_TEST (sqlmodel_repo._ensure_contract_user). +TEST_TOKEN = create_access_token( + {"uuid": "00000000-0000-0000-0000-000000000001", "jti": "contract-test-jti"} +) ALL_CHECKS = ( not_a_server_error, diff --git a/tests/db/test_base_repo.py b/tests/db/test_base_repo.py index 77e991d6..92825ca7 100644 --- a/tests/db/test_base_repo.py +++ b/tests/db/test_base_repo.py @@ -72,6 +72,9 @@ class DummyRepo(BaseRepository): async def list_users(self): await super().list_users() async def delete_user(self, u): await super().delete_user(u) async def update_user_role(self, u, r): await super().update_user_role(u, r) + async def revoke_token(self, j, u, e): await super().revoke_token(j, u, e) + async def is_token_revoked(self, j): await super().is_token_revoked(j); return False + async def set_tokens_valid_from(self, u, ts): await super().set_tokens_valid_from(u, ts) async def purge_logs_and_bounties(self): await super().purge_logs_and_bounties() async def get_attacker_artifacts(self, uuid): await super().get_attacker_artifacts(uuid) async def get_attacker_transcripts(self, uuid): await super().get_attacker_transcripts(uuid) @@ -275,6 +278,10 @@ async def test_base_repo_coverage(): # is ``pass`` (returns None), the rest raise NotImplementedError. from datetime import datetime, timezone await dr.get_log_histogram() + # Token-revocation surface (JWT denylist + bulk cutoff). + await dr.revoke_token("jti-x", "user-x", datetime.now(timezone.utc)) + await dr.is_token_revoked("jti-x") + await dr.set_tokens_valid_from("user-x", datetime.now(timezone.utc)) with pytest.raises(NotImplementedError): await dr.has_observations_for_evidence("shard:x#1") with pytest.raises(NotImplementedError): diff --git a/tests/web/test_web_api.py b/tests/web/test_web_api.py index 34f3a8aa..6dafe561 100644 --- a/tests/web/test_web_api.py +++ b/tests/web/test_web_api.py @@ -16,11 +16,22 @@ from decnet.web.auth import create_access_token class TestGetCurrentUser: @pytest.mark.asyncio async def test_valid_token(self): + # Post token-revocation, get_current_user resolves the user and checks + # the denylist, so a valid token must carry a jti, name a live user, and + # not be revoked. + from decnet.web import dependencies as deps from decnet.web.dependencies import get_current_user - token = create_access_token({"uuid": "test-uuid-123"}) + deps._reset_user_cache() + token = create_access_token({"uuid": "test-uuid-123", "jti": "jti-1"}) request = MagicMock() request.headers = {"Authorization": f"Bearer {token}"} - result = await get_current_user(request) + user = { + "uuid": "test-uuid-123", "role": "viewer", + "must_change_password": False, "tokens_valid_from": None, + } + with patch.object(deps.repo, "get_user_by_uuid", AsyncMock(return_value=user)), \ + patch.object(deps.repo, "is_token_revoked", AsyncMock(return_value=False)): + result = await get_current_user(request) assert result == "test-uuid-123" @pytest.mark.asyncio