diff --git a/decnet/web/_uvicorn_tls_scope.py b/decnet/web/_uvicorn_tls_scope.py new file mode 100644 index 0000000..f68ad56 --- /dev/null +++ b/decnet/web/_uvicorn_tls_scope.py @@ -0,0 +1,72 @@ +"""Inject the TLS peer cert into ASGI scope — uvicorn ≤ 0.44 does not. + +Uvicorn's h11/httptools HTTP protocols build the ASGI ``scope`` dict +without any ``extensions.tls`` entry, so per-request cert pinning +handlers (like POST /swarm/heartbeat) can't see the client cert that +CERT_REQUIRED already validated at handshake. + +We patch ``RequestResponseCycle.__init__`` on both protocol modules to +read the peer cert off the asyncio transport (which *does* carry it) +and write the DER bytes into +``scope["extensions"]["tls"]["client_cert_chain"]``. This is the same +key the ASGI TLS extension proposal uses, so the application code will +keep working unchanged if a future uvicorn populates it natively. + +Import this module once at app startup time (before uvicorn starts +accepting connections). Idempotent — subsequent imports are no-ops. +""" +from __future__ import annotations + +from typing import Any + + +_PATCHED = False + + +def _wrap_cycle_init(cycle_cls) -> None: + original = cycle_cls.__init__ + + def _patched_init(self, *args: Any, **kwargs: Any) -> None: + original(self, *args, **kwargs) + transport = kwargs.get("transport") or getattr(self, "transport", None) + if transport is None: + return + ssl_obj = transport.get_extra_info("ssl_object") + if ssl_obj is None: + return + try: + der = ssl_obj.getpeercert(binary_form=True) + except Exception: + return + if not der: + return + # scope is a mutable dict uvicorn stores here; Starlette forwards + # it to handlers as request.scope. Use setdefault so we don't clobber + # any future native extension entries from uvicorn itself. + scope = self.scope + extensions = scope.setdefault("extensions", {}) + extensions.setdefault("tls", {"client_cert_chain": [der]}) + + cycle_cls.__init__ = _patched_init + + +def install() -> None: + """Patch uvicorn's HTTP cycle classes. Safe to call multiple times.""" + global _PATCHED + if _PATCHED: + return + try: + from uvicorn.protocols.http import h11_impl + _wrap_cycle_init(h11_impl.RequestResponseCycle) + except Exception: # nosec B110 - optional uvicorn impl may be unavailable + pass + try: + from uvicorn.protocols.http import httptools_impl + _wrap_cycle_init(httptools_impl.RequestResponseCycle) + except Exception: # nosec B110 - optional uvicorn impl may be unavailable + pass + _PATCHED = True + + +# Auto-install on import so simply importing this module patches uvicorn. +install() diff --git a/decnet/web/swarm_api.py b/decnet/web/swarm_api.py index 669252c..43ffeb6 100644 --- a/decnet/web/swarm_api.py +++ b/decnet/web/swarm_api.py @@ -16,6 +16,8 @@ shared DB. """ from __future__ import annotations +from decnet.web import _uvicorn_tls_scope # noqa: F401 # patches uvicorn on import + from contextlib import asynccontextmanager from typing import AsyncGenerator diff --git a/tests/swarm/test_uvicorn_tls_scope.py b/tests/swarm/test_uvicorn_tls_scope.py new file mode 100644 index 0000000..d15f95d --- /dev/null +++ b/tests/swarm/test_uvicorn_tls_scope.py @@ -0,0 +1,77 @@ +"""Regression tests for the uvicorn TLS scope monkey-patch.""" +from __future__ import annotations + +from typing import Any + +import pytest + + +class _FakeSSLObject: + def __init__(self, der: bytes) -> None: + self._der = der + + def getpeercert(self, binary_form: bool = False) -> bytes: + assert binary_form is True + return self._der + + +class _FakeTransport: + def __init__(self, ssl_obj: Any = None) -> None: + self._ssl = ssl_obj + + def get_extra_info(self, key: str) -> Any: + if key == "ssl_object": + return self._ssl + return None + + +def _make_cycle_cls(): + class Cycle: + def __init__(self, scope: dict, transport: Any = None) -> None: + self.scope = scope + self.transport = transport + return Cycle + + +def test_wrap_cycle_injects_cert_into_scope() -> None: + from decnet.web._uvicorn_tls_scope import _wrap_cycle_init + + Cycle = _make_cycle_cls() + _wrap_cycle_init(Cycle) + + scope: dict = {"type": "http"} + transport = _FakeTransport(_FakeSSLObject(b"\x30\x82der")) + Cycle(scope, transport=transport) + + assert scope["extensions"]["tls"]["client_cert_chain"] == [b"\x30\x82der"] + + +def test_wrap_cycle_noop_when_no_ssl() -> None: + from decnet.web._uvicorn_tls_scope import _wrap_cycle_init + + Cycle = _make_cycle_cls() + _wrap_cycle_init(Cycle) + + scope: dict = {"type": "http"} + Cycle(scope, transport=_FakeTransport(ssl_obj=None)) + + assert "extensions" not in scope or "tls" not in scope.get("extensions", {}) + + +def test_wrap_cycle_noop_when_empty_der() -> None: + from decnet.web._uvicorn_tls_scope import _wrap_cycle_init + + Cycle = _make_cycle_cls() + _wrap_cycle_init(Cycle) + + scope: dict = {"type": "http"} + Cycle(scope, transport=_FakeTransport(_FakeSSLObject(b""))) + + assert "extensions" not in scope or "tls" not in scope.get("extensions", {}) + + +def test_install_is_idempotent() -> None: + from decnet.web import _uvicorn_tls_scope as mod + + mod.install() + mod.install() # second call must not double-wrap