fix(security): close MEDIUM ASVS findings — JWT pinning, SSE tickets, SSRF, mTLS pin, rate limits + correctness bugs
Auth (V2.1.1/V3.1.2, V2.1.3, V3.1.1): - Pin JWT iss/aud/typ at mint and require+verify them at decode; revocation (jti denylist + tokens_valid_from) still enforced. - Change-password now requires min_length=12. - SSE auth moves off JWT-in-URL to a single-use 60s opaque ticket (POST /auth/sse-ticket); raw JWT in query no longer authenticates a stream. Removed dead fail-open get_stream_user helper. Egress (V5.1.1, V9.1.1/V14.1.3): - Webhook delivery + CRUD reject SSRF destinations (private/loopback/link-local/ metadata, IPv4-mapped, multi-A-record) via resolved-IP validation, pin to the vetted IP, and never auto-follow redirects. Opt-out via DECNET_WEBHOOK_ALLOW_PRIVATE. - UpdaterClient pins the worker leaf cert SHA-256 against the stored per-host fingerprint (fail closed on missing/mismatch); DECNET_VERIFY_HOSTNAME now defaults True. Hardening (V13.1.3, V4.1.4, V13.1.2): - Rate-limit change-password (5/min), enroll-bundle (10/min), webhook-create (20/min), host-delete (20/min) via the existing slowapi limiter. - Correct false 'global auth middleware' comment; document enroll-bundle proxy trust. Correctness (BUG-7..11): - BUG-7 unbound bus in finally; BUG-8 apply_ceiling clamps to min(base,ceiling); BUG-9 commit before emit; BUG-10 multi-actor rearm for sub-threshold identities; BUG-11 normalize naive timestamps to UTC. Already-closed (no change): V14.1.1, V2.1.2/V3.1.3, V5.1.2. Tests added for every fix; unanimous adversarial review.
This commit is contained in:
@@ -344,11 +344,11 @@ async def tick_multi_actor(
|
|||||||
for entry in candidates:
|
for entry in candidates:
|
||||||
identity_uuid = str(entry["identity_uuid"])
|
identity_uuid = str(entry["identity_uuid"])
|
||||||
primitives: list[str] = sorted(entry.get("primitives") or [])
|
primitives: list[str] = sorted(entry.get("primitives") or [])
|
||||||
seen_now.add(identity_uuid)
|
|
||||||
if len(primitives) < _T.MULTI_ACTOR_MIN_PRIMITIVES:
|
if len(primitives) < _T.MULTI_ACTOR_MIN_PRIMITIVES:
|
||||||
# Repo already filters to >= 2 today; defensive against
|
# Repo already filters to >= 2 today; defensive against
|
||||||
# future schema drift.
|
# future schema drift.
|
||||||
continue
|
continue
|
||||||
|
seen_now.add(identity_uuid)
|
||||||
signature = frozenset(primitives)
|
signature = frozenset(primitives)
|
||||||
if last_fired.get(identity_uuid) == signature:
|
if last_fired.get(identity_uuid) == signature:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -139,13 +139,20 @@ def record_fingerprint(
|
|||||||
"ts": ts.isoformat(),
|
"ts": ts.isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if publish_fn is not None:
|
|
||||||
publish_fn(_ROTATED_EVENT_TYPE, payload)
|
|
||||||
if syslog_fn is not None:
|
|
||||||
syslog_fn(_ROTATED_EVENT_TYPE, payload)
|
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if publish_fn is not None:
|
||||||
|
publish_fn(_ROTATED_EVENT_TYPE, payload)
|
||||||
|
if syslog_fn is not None:
|
||||||
|
syslog_fn(_ROTATED_EVENT_TYPE, payload)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
import logging as _logging
|
||||||
|
_logging.getLogger(__name__).warning(
|
||||||
|
"fingerprint_rotation: post-commit emit failed (state already durable)",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
return RotationOutcome(
|
return RotationOutcome(
|
||||||
kind="rotated",
|
kind="rotated",
|
||||||
old_hash=old_hash,
|
old_hash=old_hash,
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
# RFC 5424 line structure
|
# RFC 5424 line structure
|
||||||
@@ -159,6 +159,8 @@ def parse_line(line: str) -> LogEvent | None:
|
|||||||
timestamp = datetime.fromisoformat(ts_raw)
|
timestamp = datetime.fromisoformat(ts_raw)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
|
if timestamp.tzinfo is None:
|
||||||
|
timestamp = timestamp.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
fields = _parse_sd_params(sd_rest)
|
fields = _parse_sd_params(sd_rest)
|
||||||
if sd_rest.startswith("-"):
|
if sd_rest.startswith("-"):
|
||||||
|
|||||||
@@ -204,13 +204,26 @@ _cors_raw: str = os.environ.get("DECNET_CORS_ORIGINS", _cors_default)
|
|||||||
DECNET_CORS_ORIGINS: list[str] = [o.strip() for o in _cors_raw.split(",") if o.strip()]
|
DECNET_CORS_ORIGINS: list[str] = [o.strip() for o in _cors_raw.split(",") if o.strip()]
|
||||||
|
|
||||||
|
|
||||||
# Master→worker mTLS hostname verification. Off by default because legacy
|
# Master→worker mTLS hostname verification. ON by default — the worker's
|
||||||
# enrollments were issued certs with operator-supplied SAN lists that may
|
# cert SAN must match the address the master connects to, on top of CA +
|
||||||
# not match the URL the master uses to connect; set to "true" on a fresh
|
# SHA-256 fingerprint pinning. Operators with legacy enrollments whose
|
||||||
# production deploy where you control enrollment to get TLS hostname checks
|
# operator-supplied SAN lists don't match the connect URL can opt OUT
|
||||||
# on top of CA + fingerprint pinning.
|
# explicitly with DECNET_VERIFY_HOSTNAME=false, but that is a downgrade:
|
||||||
|
# it drops SAN binding and leans entirely on CA + per-host pinning.
|
||||||
DECNET_VERIFY_HOSTNAME: bool = (
|
DECNET_VERIFY_HOSTNAME: bool = (
|
||||||
os.environ.get("DECNET_VERIFY_HOSTNAME", "false").lower() == "true"
|
os.environ.get("DECNET_VERIFY_HOSTNAME", "true").lower() == "true"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Webhook egress SSRF guard. By default DECNET refuses to deliver a webhook
|
||||||
|
# to a private (RFC1918), loopback, link-local (incl. 169.254.169.254 cloud
|
||||||
|
# metadata), unspecified, reserved, or multicast destination, and rejects
|
||||||
|
# such URLs at registration time. Operators who genuinely need to target an
|
||||||
|
# internal receiver (e.g. an on-box SIEM) opt IN explicitly by setting
|
||||||
|
# DECNET_WEBHOOK_ALLOW_PRIVATE=true. Fails closed: anything other than the
|
||||||
|
# literal "true" leaves the guard fully enabled.
|
||||||
|
DECNET_WEBHOOK_ALLOW_PRIVATE: bool = (
|
||||||
|
os.environ.get("DECNET_WEBHOOK_ALLOW_PRIVATE", "false").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -498,6 +498,7 @@ async def _run_smtp_probe_listener(
|
|||||||
probe_limit times — if not, forward via the master's real internet
|
probe_limit times — if not, forward via the master's real internet
|
||||||
connection and store a probe_relay bounty with the result.
|
connection and store a probe_relay bounty with the result.
|
||||||
"""
|
"""
|
||||||
|
bus = None
|
||||||
try:
|
try:
|
||||||
bus = get_bus(client_name="orchestrator-probe")
|
bus = get_bus(client_name="orchestrator-probe")
|
||||||
await bus.connect()
|
await bus.connect()
|
||||||
@@ -515,8 +516,9 @@ async def _run_smtp_probe_listener(
|
|||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
logger.warning("smtp probe listener: bus unavailable: %s", exc)
|
logger.warning("smtp probe listener: bus unavailable: %s", exc)
|
||||||
finally:
|
finally:
|
||||||
with contextlib.suppress(Exception):
|
if bus is not None:
|
||||||
await bus.close()
|
with contextlib.suppress(Exception):
|
||||||
|
await bus.close()
|
||||||
|
|
||||||
|
|
||||||
async def _handle_probe_pending(repo: BaseRepository, payload: dict) -> None:
|
async def _handle_probe_pending(repo: BaseRepository, payload: dict) -> None:
|
||||||
|
|||||||
@@ -112,11 +112,11 @@ class AgentClient:
|
|||||||
"""Either pass a SwarmHost dict, or explicit address/port.
|
"""Either pass a SwarmHost dict, or explicit address/port.
|
||||||
|
|
||||||
``verify_hostname`` defers to ``DECNET_VERIFY_HOSTNAME`` when the
|
``verify_hostname`` defers to ``DECNET_VERIFY_HOSTNAME`` when the
|
||||||
caller doesn't pass an explicit value — production deploys flip
|
caller doesn't pass an explicit value — the worker's cert SAN must
|
||||||
the env var on so the worker's cert SAN must match the address
|
match the address the master connects to, on top of the existing CA
|
||||||
the master connects to, on top of the existing CA + fingerprint
|
+ fingerprint pin. Defaults to True; operators opt out explicitly
|
||||||
pin. Defaults to False so dev/test enrollments with mismatched
|
via ``DECNET_VERIFY_HOSTNAME=false`` for dev/test enrollments with
|
||||||
SANs keep working unchanged.
|
mismatched SANs.
|
||||||
"""
|
"""
|
||||||
if verify_hostname is None:
|
if verify_hostname is None:
|
||||||
from decnet.env import DECNET_VERIFY_HOSTNAME
|
from decnet.env import DECNET_VERIFY_HOSTNAME
|
||||||
@@ -155,9 +155,10 @@ class AgentClient:
|
|||||||
)
|
)
|
||||||
ctx.load_verify_locations(cafile=str(self._identity.ca_cert_path))
|
ctx.load_verify_locations(cafile=str(self._identity.ca_cert_path))
|
||||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||||
# Pin by CA + cert chain, not by DNS — workers enroll with arbitrary
|
# Pin by CA + cert chain; hostname verification is on by default
|
||||||
# SANs (IPs, hostnames) and we don't want to force operators to keep
|
# (DECNET_VERIFY_HOSTNAME=true) so the cert SAN must match the
|
||||||
# those in sync with whatever URL the master happens to use.
|
# master's connect address. Operators set the env var to false only
|
||||||
|
# for dev/test enrollments with mismatched SANs.
|
||||||
ctx.check_hostname = self._verify_hostname
|
ctx.check_hostname = self._verify_hostname
|
||||||
return httpx.AsyncClient(
|
return httpx.AsyncClient(
|
||||||
base_url=f"https://{self._address}:{self._port}",
|
base_url=f"https://{self._address}:{self._port}",
|
||||||
|
|||||||
@@ -13,14 +13,20 @@ the connection on purpose (the updater re-execs itself mid-response).
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import socket
|
||||||
import ssl
|
import ssl
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from decnet.logging import get_logger
|
from decnet.logging import get_logger
|
||||||
from decnet.swarm.client import MasterIdentity, ensure_master_identity
|
from decnet.swarm.client import (
|
||||||
|
FingerprintMismatchError,
|
||||||
|
MasterIdentity,
|
||||||
|
ensure_master_identity,
|
||||||
|
)
|
||||||
|
|
||||||
log = get_logger("swarm.updater_client")
|
log = get_logger("swarm.updater_client")
|
||||||
|
|
||||||
@@ -47,11 +53,19 @@ class UpdaterClient:
|
|||||||
if host is not None:
|
if host is not None:
|
||||||
self._address = host["address"]
|
self._address = host["address"]
|
||||||
self._host_name = host.get("name")
|
self._host_name = host.get("name")
|
||||||
|
# SHA-256 of the worker's UPDATER leaf cert, recorded at enroll
|
||||||
|
# time (api_enroll_host.py writes ``updater_cert_fingerprint``).
|
||||||
|
# This is a distinct identity from the agent cert AgentClient
|
||||||
|
# pins — the updater channel pip-installs code as root, so it
|
||||||
|
# gets its own pin against its own cert.
|
||||||
|
fp = host.get("updater_cert_fingerprint")
|
||||||
|
self._expected_fingerprint = fp.lower() if isinstance(fp, str) else None
|
||||||
else:
|
else:
|
||||||
if address is None:
|
if address is None:
|
||||||
raise ValueError("UpdaterClient requires host dict or address")
|
raise ValueError("UpdaterClient requires host dict or address")
|
||||||
self._address = address
|
self._address = address
|
||||||
self._host_name = None
|
self._host_name = None
|
||||||
|
self._expected_fingerprint = None
|
||||||
self._port = updater_port
|
self._port = updater_port
|
||||||
self._identity = identity or ensure_master_identity()
|
self._identity = identity or ensure_master_identity()
|
||||||
self._client: Optional[httpx.AsyncClient] = None
|
self._client: Optional[httpx.AsyncClient] = None
|
||||||
@@ -70,8 +84,64 @@ class UpdaterClient:
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _fetch_peer_fingerprint(self) -> str:
|
||||||
|
"""Open a throwaway TLS connection to the updater port and return the
|
||||||
|
SHA-256 hex of the leaf cert it presents. Mirrors
|
||||||
|
``AgentClient._fetch_peer_fingerprint`` exactly."""
|
||||||
|
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
|
ctx.load_cert_chain(
|
||||||
|
str(self._identity.cert_path), str(self._identity.key_path),
|
||||||
|
)
|
||||||
|
ctx.load_verify_locations(cafile=str(self._identity.ca_cert_path))
|
||||||
|
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||||
|
ctx.check_hostname = self._verify_hostname
|
||||||
|
sock = socket.create_connection((self._address, self._port), timeout=10.0)
|
||||||
|
try:
|
||||||
|
server_hostname = self._address if self._verify_hostname else None
|
||||||
|
with ctx.wrap_socket(sock, server_hostname=server_hostname) as ssock:
|
||||||
|
der = ssock.getpeercert(binary_form=True)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
sock.close()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
if not der:
|
||||||
|
raise FingerprintMismatchError(
|
||||||
|
f"{self._address}:{self._port}", self._expected_fingerprint or "", ""
|
||||||
|
)
|
||||||
|
return hashlib.sha256(der).hexdigest().lower()
|
||||||
|
|
||||||
|
async def _verify_pin(self) -> None:
|
||||||
|
"""Fail closed unless the updater leaf cert SHA-256 matches the pin.
|
||||||
|
|
||||||
|
Unlike ``AgentClient`` (which falls through to CA-only when no pin is
|
||||||
|
recorded), the updater channel pip-installs code as root — so a host
|
||||||
|
with NO recorded ``updater_cert_fingerprint`` is rejected outright
|
||||||
|
rather than accepted on CA validity alone. A missing pin means the
|
||||||
|
host was never enrolled with an updater identity; we refuse to drive
|
||||||
|
code into it."""
|
||||||
|
if not self._expected_fingerprint:
|
||||||
|
raise FingerprintMismatchError(
|
||||||
|
f"{self._address}:{self._port}",
|
||||||
|
"<no updater_cert_fingerprint recorded for host>",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
actual = await asyncio.to_thread(self._fetch_peer_fingerprint)
|
||||||
|
if actual != self._expected_fingerprint:
|
||||||
|
raise FingerprintMismatchError(
|
||||||
|
f"{self._address}:{self._port}",
|
||||||
|
self._expected_fingerprint,
|
||||||
|
actual,
|
||||||
|
)
|
||||||
|
|
||||||
async def __aenter__(self) -> "UpdaterClient":
|
async def __aenter__(self) -> "UpdaterClient":
|
||||||
self._client = self._build_client(_TIMEOUT_CONTROL)
|
self._client = self._build_client(_TIMEOUT_CONTROL)
|
||||||
|
try:
|
||||||
|
await self._verify_pin()
|
||||||
|
except BaseException:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
raise
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, *exc: Any) -> None:
|
async def __aexit__(self, *exc: Any) -> None:
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ def apply_ceiling(base: float, state: "RuleState") -> float:
|
|||||||
"""Apply the operator's confidence ceiling, downward only.
|
"""Apply the operator's confidence ceiling, downward only.
|
||||||
|
|
||||||
A ``clipped`` state with ``confidence_max < 1.0`` clamps the emitted
|
A ``clipped`` state with ``confidence_max < 1.0`` clamps the emitted
|
||||||
confidence to ``min(base, base * ceiling)``. Any other state is a
|
confidence to ``min(base, ceiling)``. Any other state is a
|
||||||
no-op. The clamp is downward by construction — operator clips can
|
no-op. The clamp is downward by construction — operator clips can
|
||||||
never raise a rule's confidence above its YAML-declared base, per
|
never raise a rule's confidence above its YAML-declared base, per
|
||||||
TTP_TAGGING.md §"Confidence model".
|
TTP_TAGGING.md §"Confidence model".
|
||||||
@@ -50,7 +50,7 @@ def apply_ceiling(base: float, state: "RuleState") -> float:
|
|||||||
ceiling = state.confidence_max
|
ceiling = state.confidence_max
|
||||||
if ceiling is None or ceiling >= 1.0:
|
if ceiling is None or ceiling >= 1.0:
|
||||||
return base
|
return base
|
||||||
return min(base, base * ceiling)
|
return min(base, ceiling)
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["is_active", "apply_ceiling"]
|
__all__ = ["is_active", "apply_ceiling"]
|
||||||
|
|||||||
@@ -11,6 +11,14 @@ SECRET_KEY: str = DECNET_JWT_SECRET
|
|||||||
ALGORITHM: str = "HS256"
|
ALGORITHM: str = "HS256"
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = DECNET_JWT_EXP_MINUTES
|
ACCESS_TOKEN_EXPIRE_MINUTES: int = DECNET_JWT_EXP_MINUTES
|
||||||
|
|
||||||
|
# Pinned issuer/audience/type so a token signed with DECNET_JWT_SECRET for any
|
||||||
|
# OTHER purpose (or by a future co-tenant of the secret) is not accepted by the
|
||||||
|
# dashboard verifier. Issuance stamps these; _decode_payload requires + verifies
|
||||||
|
# them. Keep these two modules in lockstep — they are a single trust contract.
|
||||||
|
JWT_ISSUER: str = "decnet"
|
||||||
|
JWT_AUDIENCE: str = "decnet-dashboard"
|
||||||
|
JWT_TYPE: str = "access"
|
||||||
|
|
||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
return bcrypt.checkpw(
|
return bcrypt.checkpw(
|
||||||
@@ -45,5 +53,10 @@ def create_access_token(data: dict[str, Any], expires_delta: Optional[timedelta]
|
|||||||
|
|
||||||
_to_encode.update({"exp": _expire})
|
_to_encode.update({"exp": _expire})
|
||||||
_to_encode.update({"iat": datetime.now(timezone.utc)})
|
_to_encode.update({"iat": datetime.now(timezone.utc)})
|
||||||
|
# Pin issuer / audience / token-type so the verifier can reject tokens
|
||||||
|
# minted for any other purpose with the same shared secret.
|
||||||
|
_to_encode.setdefault("iss", JWT_ISSUER)
|
||||||
|
_to_encode.setdefault("aud", JWT_AUDIENCE)
|
||||||
|
_to_encode.setdefault("typ", JWT_TYPE)
|
||||||
_encoded_jwt: str = jwt.encode(_to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
_encoded_jwt: str = jwt.encode(_to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
return _encoded_jwt
|
return _encoded_jwt
|
||||||
|
|||||||
@@ -50,7 +50,18 @@ class LoginRequest(BaseModel):
|
|||||||
|
|
||||||
class ChangePasswordRequest(BaseModel):
|
class ChangePasswordRequest(BaseModel):
|
||||||
old_password: str = PydanticField(..., max_length=72)
|
old_password: str = PydanticField(..., max_length=72)
|
||||||
new_password: str = PydanticField(..., max_length=72)
|
# min_length=12 aligns with the DECNET_ADMIN_PASSWORD >=12 policy. The
|
||||||
|
# forced first-login flow routes through /auth/change-password, so without a
|
||||||
|
# floor a seeded admin could clear must_change_password with a 1-char secret.
|
||||||
|
new_password: str = PydanticField(..., min_length=12, max_length=72)
|
||||||
|
|
||||||
|
|
||||||
|
class SSETicketResponse(BaseModel):
|
||||||
|
"""Single-use, short-lived opaque ticket the dashboard exchanges its header
|
||||||
|
JWT for, then passes to an SSE endpoint as ?ticket= (EventSource cannot set
|
||||||
|
an Authorization header). See decnet.web.dependencies SSE ticket store."""
|
||||||
|
ticket: str
|
||||||
|
expires_in: int
|
||||||
|
|
||||||
|
|
||||||
# --- Configuration Models ---
|
# --- Configuration Models ---
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import secrets
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
@@ -8,7 +9,13 @@ import jwt
|
|||||||
from fastapi import HTTPException, status, Request
|
from fastapi import HTTPException, status, Request
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
|
||||||
from decnet.web.auth import ALGORITHM, SECRET_KEY
|
from decnet.web.auth import (
|
||||||
|
ALGORITHM,
|
||||||
|
JWT_AUDIENCE,
|
||||||
|
JWT_ISSUER,
|
||||||
|
JWT_TYPE,
|
||||||
|
SECRET_KEY,
|
||||||
|
)
|
||||||
from decnet.web.db.repository import BaseRepository
|
from decnet.web.db.repository import BaseRepository
|
||||||
from decnet.web.db.factory import get_repository
|
from decnet.web.db.factory import get_repository
|
||||||
|
|
||||||
@@ -168,13 +175,30 @@ def _epoch(value: Any) -> float:
|
|||||||
|
|
||||||
|
|
||||||
def _decode_payload(token: str) -> dict[str, Any]:
|
def _decode_payload(token: str) -> dict[str, Any]:
|
||||||
"""Decode + signature/expiry-verify a raw JWT, or raise 401."""
|
"""Decode + signature/expiry-verify a raw JWT, or raise 401.
|
||||||
|
|
||||||
|
Beyond signature + expiry, this pins the issuer and audience and requires
|
||||||
|
the registered claims to be present, so a token minted with the same shared
|
||||||
|
secret for a different purpose (or omitting exp/iat/iss/aud) is rejected.
|
||||||
|
``uuid`` (not ``sub``) is this app's identity claim, so it is in ``require``.
|
||||||
|
``typ`` is a custom payload claim PyJWT does not validate natively, so it is
|
||||||
|
checked explicitly below.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
payload: dict[str, Any] = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload: dict[str, Any] = jwt.decode(
|
||||||
|
token,
|
||||||
|
SECRET_KEY,
|
||||||
|
algorithms=[ALGORITHM],
|
||||||
|
audience=JWT_AUDIENCE,
|
||||||
|
issuer=JWT_ISSUER,
|
||||||
|
options={"require": ["exp", "iat", "iss", "aud", "uuid"]},
|
||||||
|
)
|
||||||
except jwt.PyJWTError:
|
except jwt.PyJWTError:
|
||||||
raise _CREDENTIALS_EXCEPTION
|
raise _CREDENTIALS_EXCEPTION
|
||||||
if payload.get("uuid") is None:
|
if payload.get("uuid") is None:
|
||||||
raise _CREDENTIALS_EXCEPTION
|
raise _CREDENTIALS_EXCEPTION
|
||||||
|
if payload.get("typ") != JWT_TYPE:
|
||||||
|
raise _CREDENTIALS_EXCEPTION
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
||||||
@@ -236,17 +260,70 @@ async def get_token_claims(request: Request) -> dict[str, Any]:
|
|||||||
return _decode_payload(token)
|
return _decode_payload(token)
|
||||||
|
|
||||||
|
|
||||||
async def get_stream_user(request: Request, token: Optional[str] = None) -> str:
|
# ---------------------------------------------------------------------------
|
||||||
"""Auth dependency for SSE endpoints — accepts Bearer header OR ?token= query param.
|
# SSE stream tickets (V3.1.1)
|
||||||
EventSource does not support custom headers, so the query-string fallback is intentional here only.
|
# ---------------------------------------------------------------------------
|
||||||
|
# EventSource cannot set an Authorization header, so SSE auth historically rode
|
||||||
|
# in ?token=<JWT>, leaking the full-lifetime bearer into access/proxy logs,
|
||||||
|
# browser history, and Referer. Instead the client exchanges its header JWT for
|
||||||
|
# a single-use, short-lived OPAQUE ticket via POST /api/v1/auth/sse-ticket and
|
||||||
|
# connects with ?ticket=<opaque>. The JWT never appears in any URL.
|
||||||
|
#
|
||||||
|
# Security-boundary store — FAIL CLOSED. The map is keyed on the opaque ticket
|
||||||
|
# and holds (expiry_monotonic, bound_identity). Redemption validates presence +
|
||||||
|
# freshness, then DELETES the entry (single-use). Unknown / expired / reused
|
||||||
|
# tickets all resolve to 401.
|
||||||
|
#
|
||||||
|
# This is a MODULE-LEVEL dict: tickets live only in the process that minted
|
||||||
|
# them. A multi-process / multi-worker deployment needs a SHARED store (Redis,
|
||||||
|
# DB) so a ticket minted on worker A can be redeemed on worker B — out of scope
|
||||||
|
# here, deliberately. No background sweeper daemon (project rule: library, not
|
||||||
|
# new worker); expiry is enforced opportunistically on every redeem + mint.
|
||||||
|
_SSE_TICKET_TTL = 60.0 # seconds
|
||||||
|
_sse_tickets: dict[str, tuple[float, dict[str, Any]]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_sse_tickets() -> None:
|
||||||
|
"""Test hook: drop all outstanding stream tickets."""
|
||||||
|
_sse_tickets.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def _sweep_sse_tickets(now: Optional[float] = None) -> None:
|
||||||
|
"""Opportunistic eviction of expired tickets. O(n) over a tiny map (tickets
|
||||||
|
are single-use and 60s-lived), called on every mint/redeem — no daemon."""
|
||||||
|
_now = time.monotonic() if now is None else now
|
||||||
|
expired = [t for t, (exp, _) in _sse_tickets.items() if exp <= _now]
|
||||||
|
for t in expired:
|
||||||
|
_sse_tickets.pop(t, None)
|
||||||
|
|
||||||
|
|
||||||
|
def mint_sse_ticket(user_uuid: str, role: str) -> str:
|
||||||
|
"""Mint a single-use, 60s opaque SSE ticket bound to ``user_uuid``+``role``.
|
||||||
|
|
||||||
|
Called by POST /auth/sse-ticket AFTER the header JWT has been validated, so
|
||||||
|
the bound identity is already trusted. Returns the opaque token the client
|
||||||
|
passes as ?ticket=. Sweeps expired entries on the way in.
|
||||||
"""
|
"""
|
||||||
resolved = _bearer_from_header(request) or token
|
_sweep_sse_tickets()
|
||||||
if not resolved:
|
ticket = secrets.token_urlsafe(32)
|
||||||
|
expiry = time.monotonic() + _SSE_TICKET_TTL
|
||||||
|
_sse_tickets[ticket] = (expiry, {"uuid": user_uuid, "role": role})
|
||||||
|
return ticket
|
||||||
|
|
||||||
|
|
||||||
|
def _redeem_sse_ticket(ticket: str) -> dict[str, Any]:
|
||||||
|
"""Redeem a stream ticket: validate exists + unexpired, then DELETE it
|
||||||
|
(single-use). Returns the bound ``{"uuid","role"}`` identity or raises 401.
|
||||||
|
Fail closed: unknown / expired / already-redeemed all raise."""
|
||||||
|
now = time.monotonic()
|
||||||
|
_sweep_sse_tickets(now)
|
||||||
|
entry = _sse_tickets.pop(ticket, None) # pop = single-use, even on expiry
|
||||||
|
if entry is None:
|
||||||
raise _CREDENTIALS_EXCEPTION
|
raise _CREDENTIALS_EXCEPTION
|
||||||
# Decode-only: returns the uuid. Revocation/role enforcement happens in
|
expiry, identity = entry
|
||||||
# require_stream_role (the sole production caller), which runs the full
|
if expiry <= now:
|
||||||
# _resolve_token path. Kept thin so its decode contract stays unit-testable.
|
raise _CREDENTIALS_EXCEPTION
|
||||||
return _decode_payload(resolved)["uuid"]
|
return identity
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(request: Request) -> str:
|
async def get_current_user(request: Request) -> str:
|
||||||
@@ -298,18 +375,35 @@ def require_role(*allowed_roles: str):
|
|||||||
|
|
||||||
|
|
||||||
def require_stream_role(*allowed_roles: str):
|
def require_stream_role(*allowed_roles: str):
|
||||||
"""Like ``require_role`` but for SSE endpoints that accept a query-param token."""
|
"""Like ``require_role`` but for SSE endpoints.
|
||||||
async def _check(request: Request, token: Optional[str] = None) -> dict:
|
|
||||||
resolved = _bearer_from_header(request) or token
|
Two ingress paths:
|
||||||
if not resolved:
|
* Bearer header → full ``_resolve_token`` (revocation + cutoff enforced).
|
||||||
|
* ?ticket=<opaque> → single-use stream ticket minted by /auth/sse-ticket,
|
||||||
|
which already validated the header JWT and bound the uuid+role. The
|
||||||
|
ticket carries no jti, so the per-token denylist cannot apply here; the
|
||||||
|
60s single-use lifetime is the bounded exposure we accept for SSE.
|
||||||
|
|
||||||
|
Raw ?token=<JWT> is intentionally NOT accepted (V3.1.1)."""
|
||||||
|
async def _check(request: Request, ticket: Optional[str] = None) -> dict:
|
||||||
|
header_token = _bearer_from_header(request)
|
||||||
|
if header_token:
|
||||||
|
_user_uuid, user = await _resolve_token(header_token)
|
||||||
|
if user["role"] not in allowed_roles:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Insufficient permissions",
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
if not ticket:
|
||||||
raise _CREDENTIALS_EXCEPTION
|
raise _CREDENTIALS_EXCEPTION
|
||||||
_user_uuid, user = await _resolve_token(resolved)
|
identity = _redeem_sse_ticket(ticket)
|
||||||
if user["role"] not in allowed_roles:
|
if identity["role"] not in allowed_roles:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Insufficient permissions",
|
detail="Insufficient permissions",
|
||||||
)
|
)
|
||||||
return user
|
return identity
|
||||||
return _check
|
return _check
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from fastapi import APIRouter
|
|||||||
from .auth.api_login import router as login_router
|
from .auth.api_login import router as login_router
|
||||||
from .auth.api_change_pass import router as change_pass_router
|
from .auth.api_change_pass import router as change_pass_router
|
||||||
from .auth.api_logout import router as logout_router
|
from .auth.api_logout import router as logout_router
|
||||||
|
from .auth.api_sse_ticket import router as sse_ticket_router
|
||||||
from .logs.api_get_logs import router as logs_router
|
from .logs.api_get_logs import router as logs_router
|
||||||
from .logs.api_get_histogram import router as histogram_router
|
from .logs.api_get_histogram import router as histogram_router
|
||||||
from .bounty.api_get_bounties import router as bounty_router
|
from .bounty.api_get_bounties import router as bounty_router
|
||||||
@@ -75,9 +76,12 @@ from .ttp.api_export_navigator import router as ttp_navigator_router
|
|||||||
from .ttp.api_get_groups_for_technique import router as ttp_groups_for_technique_router
|
from .ttp.api_get_groups_for_technique import router as ttp_groups_for_technique_router
|
||||||
|
|
||||||
api_router = APIRouter(
|
api_router = APIRouter(
|
||||||
# Every route under /api/v1 is auth-guarded (either by an explicit
|
# Auth is enforced PER ROUTE via explicit ``require_*`` Depends (see
|
||||||
# require_* Depends or by the global auth middleware). Document 401/403
|
# decnet.web.dependencies) — there is NO global auth middleware. A route
|
||||||
# here so the OpenAPI schema reflects reality for contract tests.
|
# without a require_* dependency is unauthenticated BY DESIGN; the only such
|
||||||
|
# routes are /health (liveness) and /auth/login (credential exchange).
|
||||||
|
# The 401/403 entries below are documented here so the OpenAPI schema
|
||||||
|
# reflects reality for contract tests, not because a middleware applies them.
|
||||||
responses={
|
responses={
|
||||||
400: {"description": "Malformed request body"},
|
400: {"description": "Malformed request body"},
|
||||||
401: {"description": "Missing or invalid credentials"},
|
401: {"description": "Missing or invalid credentials"},
|
||||||
@@ -91,6 +95,7 @@ api_router = APIRouter(
|
|||||||
api_router.include_router(login_router)
|
api_router.include_router(login_router)
|
||||||
api_router.include_router(change_pass_router)
|
api_router.include_router(change_pass_router)
|
||||||
api_router.include_router(logout_router)
|
api_router.include_router(logout_router)
|
||||||
|
api_router.include_router(sse_ticket_router)
|
||||||
|
|
||||||
# Logs & Analytics
|
# Logs & Analytics
|
||||||
api_router.include_router(logs_router)
|
api_router.include_router(logs_router)
|
||||||
|
|||||||
@@ -10,8 +10,9 @@ stream's attacker. Emits a one-shot snapshot on connect (latest
|
|||||||
observation per primitive) so the panel hydrates immediately.
|
observation per primitive) so the panel hydrates immediately.
|
||||||
|
|
||||||
Authorization mirrors :mod:`decnet.web.router.topology.api_events` —
|
Authorization mirrors :mod:`decnet.web.router.topology.api_events` —
|
||||||
JWT via the ``?token=`` query parameter (EventSource can't set
|
a single-use opaque ticket via the ``?ticket=`` query parameter
|
||||||
arbitrary headers) + ``require_stream_viewer`` role gate. The 404
|
(EventSource can't set arbitrary headers) + ``require_stream_viewer``
|
||||||
|
role gate. The 404
|
||||||
fires after auth so an existence probe can't leak an attacker UUID
|
fires after auth so an existence probe can't leak an attacker UUID
|
||||||
to an unauthenticated caller.
|
to an unauthenticated caller.
|
||||||
|
|
||||||
|
|||||||
@@ -2,12 +2,13 @@
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
|
||||||
from decnet.telemetry import traced as _traced
|
from decnet.telemetry import traced as _traced
|
||||||
from decnet.web.auth import ahash_password, averify_password
|
from decnet.web.auth import ahash_password, averify_password
|
||||||
from decnet.web.dependencies import get_current_user_unchecked, invalidate_user_cache, repo
|
from decnet.web.dependencies import get_current_user_unchecked, invalidate_user_cache, repo
|
||||||
from decnet.web.db.models import ChangePasswordRequest, MessageResponse
|
from decnet.web.db.models import ChangePasswordRequest, MessageResponse
|
||||||
|
from decnet.web.limiter import limiter
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -19,19 +20,21 @@ router = APIRouter()
|
|||||||
responses={
|
responses={
|
||||||
400: {"description": "Bad Request (e.g. malformed JSON)"},
|
400: {"description": "Bad Request (e.g. malformed JSON)"},
|
||||||
401: {"description": "Could not validate credentials"},
|
401: {"description": "Could not validate credentials"},
|
||||||
422: {"description": "Validation error"}
|
422: {"description": "Validation error"},
|
||||||
|
429: {"description": "Too many password-change attempts — retry after the window resets"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@limiter.limit("5/minute")
|
||||||
@_traced("api.change_password")
|
@_traced("api.change_password")
|
||||||
async def change_password(request: ChangePasswordRequest, current_user: str = Depends(get_current_user_unchecked)) -> dict[str, str]:
|
async def change_password(request: Request, body: ChangePasswordRequest, current_user: str = Depends(get_current_user_unchecked)) -> dict[str, str]:
|
||||||
_user: Optional[dict[str, Any]] = await repo.get_user_by_uuid(current_user)
|
_user: Optional[dict[str, Any]] = await repo.get_user_by_uuid(current_user)
|
||||||
if not _user or not await averify_password(request.old_password, _user["password_hash"]):
|
if not _user or not await averify_password(body.old_password, _user["password_hash"]):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Incorrect old password",
|
detail="Incorrect old password",
|
||||||
)
|
)
|
||||||
|
|
||||||
_new_hash: str = await ahash_password(request.new_password)
|
_new_hash: str = await ahash_password(body.new_password)
|
||||||
await repo.update_user_password(current_user, _new_hash, must_change_password=False)
|
await repo.update_user_password(current_user, _new_hash, must_change_password=False)
|
||||||
# Changing a password revokes every existing session for this user (incl.
|
# Changing a password revokes every existing session for this user (incl.
|
||||||
# the current one): the caller's next request 401s and re-authenticates.
|
# the current one): the caller's next request 401s and re-authenticates.
|
||||||
|
|||||||
39
decnet/web/router/auth/api_sse_ticket.py
Normal file
39
decnet/web/router/auth/api_sse_ticket.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
"""Mint a single-use, short-lived SSE stream ticket (V3.1.1).
|
||||||
|
|
||||||
|
EventSource cannot send an Authorization header, so SSE auth used to ride in
|
||||||
|
``?token=<JWT>`` — leaking the full-lifetime bearer into access/proxy logs,
|
||||||
|
browser history, and Referer. This endpoint lets an already-authenticated
|
||||||
|
client (gated by the NORMAL header JWT via ``require_viewer``) exchange that
|
||||||
|
header credential for an opaque ``secrets.token_urlsafe(32)`` ticket, valid for
|
||||||
|
60s and single-use, which it then passes to the SSE endpoint as ``?ticket=``.
|
||||||
|
The JWT never appears in any URL.
|
||||||
|
|
||||||
|
The ticket store lives in-process (decnet.web.dependencies); multi-process
|
||||||
|
deployments need a shared store — out of scope, see that module's note.
|
||||||
|
"""
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
from decnet.telemetry import traced as _traced
|
||||||
|
from decnet.web.dependencies import mint_sse_ticket, require_viewer, _SSE_TICKET_TTL
|
||||||
|
from decnet.web.db.models.auth import SSETicketResponse
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/auth/sse-ticket",
|
||||||
|
tags=["Authentication"],
|
||||||
|
response_model=SSETicketResponse,
|
||||||
|
responses={
|
||||||
|
400: {"description": "Malformed request body"},
|
||||||
|
401: {"description": "Missing or invalid credentials"},
|
||||||
|
403: {"description": "Authenticated but not authorized"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@_traced("api.sse_ticket")
|
||||||
|
async def mint_stream_ticket(user: dict = Depends(require_viewer)) -> SSETicketResponse:
|
||||||
|
"""Exchange the presented header JWT for a single-use 60s SSE ticket bound to
|
||||||
|
this user's uuid + role. Any authenticated (viewer or admin) user may mint."""
|
||||||
|
ticket = mint_sse_ticket(user["uuid"], user["role"])
|
||||||
|
return SSETicketResponse(ticket=ticket, expires_in=int(_SSE_TICKET_TTL))
|
||||||
@@ -6,8 +6,9 @@ request and forwards each matching event as a Server-Sent Event.
|
|||||||
Emits a one-shot snapshot on connect (current paginated campaign
|
Emits a one-shot snapshot on connect (current paginated campaign
|
||||||
list).
|
list).
|
||||||
|
|
||||||
Mirror of :mod:`decnet.web.router.identities.api_events`. Auth: JWT
|
Mirror of :mod:`decnet.web.router.identities.api_events`. Auth:
|
||||||
via ``?token=`` query param + ``require_stream_viewer`` role.
|
single-use opaque ticket via ``?ticket=`` query param +
|
||||||
|
``require_stream_viewer`` role.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|||||||
@@ -8,8 +8,9 @@ Server-Sent Event to the browser. Emits a one-shot snapshot on connect
|
|||||||
fetch to initialise.
|
fetch to initialise.
|
||||||
|
|
||||||
Authorization mirrors :mod:`decnet.web.router.topology.api_events` — a
|
Authorization mirrors :mod:`decnet.web.router.topology.api_events` — a
|
||||||
JWT passed via the ``?token=`` query parameter (EventSource can't set
|
single-use opaque ticket passed via the ``?ticket=`` query parameter
|
||||||
arbitrary headers) + ``require_stream_viewer`` role gate.
|
(EventSource can't set arbitrary headers) + ``require_stream_viewer``
|
||||||
|
role gate.
|
||||||
|
|
||||||
The endpoint is broadly scoped (every identity event, not per-uuid)
|
The endpoint is broadly scoped (every identity event, not per-uuid)
|
||||||
because both ``AttackerDetail`` and ``IdentityDetail`` need the same
|
because both ``AttackerDetail`` and ``IdentityDetail`` need the same
|
||||||
|
|||||||
@@ -12,12 +12,13 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
|
||||||
from decnet.logging import get_logger
|
from decnet.logging import get_logger
|
||||||
from decnet.swarm.client import AgentClient
|
from decnet.swarm.client import AgentClient
|
||||||
from decnet.web.db.repository import BaseRepository
|
from decnet.web.db.repository import BaseRepository
|
||||||
from decnet.web.dependencies import get_repo, require_admin
|
from decnet.web.dependencies import get_repo, require_admin
|
||||||
|
from decnet.web.limiter import limiter
|
||||||
from decnet.web.router.swarm._mtls import PeerCert, require_operator_cert
|
from decnet.web.router.swarm._mtls import PeerCert, require_operator_cert
|
||||||
|
|
||||||
log = get_logger("swarm.decommission")
|
log = get_logger("swarm.decommission")
|
||||||
@@ -32,10 +33,13 @@ router = APIRouter()
|
|||||||
401: {"description": "Missing or invalid admin JWT"},
|
401: {"description": "Missing or invalid admin JWT"},
|
||||||
403: {"description": "Authenticated user is not an admin, or operator cert missing"},
|
403: {"description": "Authenticated user is not an admin, or operator cert missing"},
|
||||||
404: {"description": "No host with this UUID is enrolled"},
|
404: {"description": "No host with this UUID is enrolled"},
|
||||||
|
429: {"description": "Too many decommission requests — retry after the window resets"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@limiter.limit("20/minute")
|
||||||
async def api_decommission_host(
|
async def api_decommission_host(
|
||||||
uuid: str,
|
uuid: str,
|
||||||
|
request: Request,
|
||||||
repo: BaseRepository = Depends(get_repo),
|
repo: BaseRepository = Depends(get_repo),
|
||||||
_admin: dict = Depends(require_admin),
|
_admin: dict = Depends(require_admin),
|
||||||
_operator: PeerCert = Depends(require_operator_cert),
|
_operator: PeerCert = Depends(require_operator_cert),
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from decnet.swarm.bundle_builder import build_tarball, render_bootstrap
|
|||||||
from decnet.web.db.models.swarm import EnrollBundleRequest, EnrollBundleResponse
|
from decnet.web.db.models.swarm import EnrollBundleRequest, EnrollBundleResponse
|
||||||
from decnet.web.db.repository import BaseRepository
|
from decnet.web.db.repository import BaseRepository
|
||||||
from decnet.web.dependencies import get_repo, require_admin
|
from decnet.web.dependencies import get_repo, require_admin
|
||||||
|
from decnet.web.limiter import limiter
|
||||||
|
|
||||||
log = get_logger("swarm_mgmt.enroll_bundle")
|
log = get_logger("swarm_mgmt.enroll_bundle")
|
||||||
|
|
||||||
@@ -117,8 +118,10 @@ async def _lookup_live(token: str) -> _Bundle:
|
|||||||
403: {"description": "Insufficient permissions"},
|
403: {"description": "Insufficient permissions"},
|
||||||
409: {"description": "A worker with this name is already enrolled"},
|
409: {"description": "A worker with this name is already enrolled"},
|
||||||
422: {"description": "Request body validation error"},
|
422: {"description": "Request body validation error"},
|
||||||
|
429: {"description": "Too many enroll-bundle requests — retry after the window resets"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@limiter.limit("10/minute")
|
||||||
async def create_enroll_bundle(
|
async def create_enroll_bundle(
|
||||||
req: EnrollBundleRequest,
|
req: EnrollBundleRequest,
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -251,6 +254,14 @@ async def get_payload(
|
|||||||
# The agent's first connect-back — its source IP is the reachable address
|
# The agent's first connect-back — its source IP is the reachable address
|
||||||
# the master will later use to probe it. Backfill the SwarmHost row here
|
# the master will later use to probe it. Backfill the SwarmHost row here
|
||||||
# so the operator sees the real address instead of an empty placeholder.
|
# so the operator sees the real address instead of an empty placeholder.
|
||||||
|
#
|
||||||
|
# PROXY TRUST WARNING: `request.client.host` is the TCP peer's IP.
|
||||||
|
# If this endpoint sits behind a TCP-terminating reverse proxy (nginx,
|
||||||
|
# HAProxy, etc.) the recorded address will be the proxy's IP, not the
|
||||||
|
# agent's. Either bind the API directly on the network reachable by
|
||||||
|
# agents, or configure the proxy to preserve the original source IP
|
||||||
|
# (e.g. PROXY Protocol on a loopback listener, *not* X-Forwarded-For
|
||||||
|
# which is trivially spoofable). See THREAT_MODEL.md §DA-08.
|
||||||
client_host = request.client.host if request.client else ""
|
client_host = request.client.host if request.client else ""
|
||||||
if client_host:
|
if client_host:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -8,8 +8,9 @@ a Server-Sent Event to the browser. Emits a one-shot snapshot on connect
|
|||||||
separate fetch to initialise the "pending" buffer.
|
separate fetch to initialise the "pending" buffer.
|
||||||
|
|
||||||
Authorization matches :mod:`decnet.web.router.stream.api_stream_events`
|
Authorization matches :mod:`decnet.web.router.stream.api_stream_events`
|
||||||
— a JWT passed via the ``?token=`` query parameter (EventSource can't
|
— a single-use opaque ticket passed via the ``?ticket=`` query
|
||||||
set arbitrary headers) + ``require_stream_viewer`` role gate. The
|
parameter (EventSource can't set arbitrary headers) +
|
||||||
|
``require_stream_viewer`` role gate. The
|
||||||
per-topology 404 is enforced after auth so existence probes can't leak
|
per-topology 404 is enforced after auth so existence probes can't leak
|
||||||
a topology id to an unauthenticated caller.
|
a topology id to an unauthenticated caller.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import secrets
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
|
||||||
from decnet.bus import topics as _topics
|
from decnet.bus import topics as _topics
|
||||||
from decnet.bus.app import get_app_bus
|
from decnet.bus.app import get_app_bus
|
||||||
@@ -22,13 +22,28 @@ from decnet.web.db.models import (
|
|||||||
)
|
)
|
||||||
from decnet.web.db.models.webhooks import _row_to_response_dict
|
from decnet.web.db.models.webhooks import _row_to_response_dict
|
||||||
from decnet.web.dependencies import repo, require_admin
|
from decnet.web.dependencies import repo, require_admin
|
||||||
|
from decnet.web.limiter import limiter
|
||||||
from decnet.webhook.enums import merge_patterns
|
from decnet.webhook.enums import merge_patterns
|
||||||
|
from decnet.webhook.ssrf import WebhookDestinationError, validate_webhook_url
|
||||||
|
|
||||||
log = get_logger("api.webhooks")
|
log = get_logger("api.webhooks")
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_url_or_422(url: str) -> None:
|
||||||
|
"""Reject a webhook URL that resolves to a forbidden destination.
|
||||||
|
|
||||||
|
Runs the same SSRF guard the delivery path enforces, but at
|
||||||
|
registration time so a bad URL is surfaced to the operator as a clear
|
||||||
|
422 instead of being silently dropped on every delivery attempt.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
validate_webhook_url(url)
|
||||||
|
except WebhookDestinationError as e:
|
||||||
|
raise HTTPException(status_code=422, detail=str(e)) from e
|
||||||
|
|
||||||
|
|
||||||
async def _notify_subscriptions_changed() -> None:
|
async def _notify_subscriptions_changed() -> None:
|
||||||
"""Publish `system.webhook.subscriptions_changed` on the bus.
|
"""Publish `system.webhook.subscriptions_changed` on the bus.
|
||||||
|
|
||||||
@@ -60,10 +75,14 @@ def _row_to_response(row: dict[str, Any]) -> WebhookResponse:
|
|||||||
responses={
|
responses={
|
||||||
400: {"description": "At least one of simple_events / topic_patterns required"},
|
400: {"description": "At least one of simple_events / topic_patterns required"},
|
||||||
409: {"description": "Name already in use"},
|
409: {"description": "Name already in use"},
|
||||||
|
422: {"description": "URL resolves to a forbidden (internal) destination"},
|
||||||
|
429: {"description": "Too many webhook-create requests — retry after the window resets"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@limiter.limit("20/minute")
|
||||||
@_traced("api.webhook.create")
|
@_traced("api.webhook.create")
|
||||||
async def api_create_webhook(
|
async def api_create_webhook(
|
||||||
|
request: Request,
|
||||||
req: WebhookCreateRequest,
|
req: WebhookCreateRequest,
|
||||||
admin: dict = Depends(require_admin),
|
admin: dict = Depends(require_admin),
|
||||||
) -> WebhookCreateResponse:
|
) -> WebhookCreateResponse:
|
||||||
@@ -78,6 +97,8 @@ async def api_create_webhook(
|
|||||||
if existing:
|
if existing:
|
||||||
raise HTTPException(status_code=409, detail="Webhook name already exists")
|
raise HTTPException(status_code=409, detail="Webhook name already exists")
|
||||||
|
|
||||||
|
_validate_url_or_422(str(req.url))
|
||||||
|
|
||||||
# Auto-generate a URL-safe secret if the caller didn't provide one.
|
# Auto-generate a URL-safe secret if the caller didn't provide one.
|
||||||
# 32 bytes of os-entropy is the same ballpark as a CSRF token.
|
# 32 bytes of os-entropy is the same ballpark as a CSRF token.
|
||||||
secret = req.secret or secrets.token_urlsafe(32)
|
secret = req.secret or secrets.token_urlsafe(32)
|
||||||
@@ -146,6 +167,7 @@ async def api_get_webhook(
|
|||||||
400: {"description": "Empty or invalid patch"},
|
400: {"description": "Empty or invalid patch"},
|
||||||
404: {"description": "Webhook not found"},
|
404: {"description": "Webhook not found"},
|
||||||
409: {"description": "Name already in use"},
|
409: {"description": "Name already in use"},
|
||||||
|
422: {"description": "URL resolves to a forbidden (internal) destination"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@_traced("api.webhook.update")
|
@_traced("api.webhook.update")
|
||||||
@@ -167,6 +189,7 @@ async def api_update_webhook(
|
|||||||
patch["name"] = req.name
|
patch["name"] = req.name
|
||||||
|
|
||||||
if req.url is not None:
|
if req.url is not None:
|
||||||
|
_validate_url_or_422(str(req.url))
|
||||||
patch["url"] = str(req.url)
|
patch["url"] = str(req.url)
|
||||||
|
|
||||||
if req.secret is not None:
|
if req.secret is not None:
|
||||||
|
|||||||
@@ -22,6 +22,11 @@ import httpx
|
|||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
from decnet.logging import get_logger
|
from decnet.logging import get_logger
|
||||||
|
from decnet.webhook.ssrf import (
|
||||||
|
ValidatedDestination,
|
||||||
|
WebhookDestinationError,
|
||||||
|
validate_webhook_url,
|
||||||
|
)
|
||||||
|
|
||||||
log = get_logger("webhook.client")
|
log = get_logger("webhook.client")
|
||||||
|
|
||||||
@@ -121,6 +126,51 @@ def _jittered(delay: float) -> float:
|
|||||||
return delay * random.uniform(_JITTER_LOW, _JITTER_HIGH) # nosec B311
|
return delay * random.uniform(_JITTER_LOW, _JITTER_HIGH) # nosec B311
|
||||||
|
|
||||||
|
|
||||||
|
def _build_pinned_request(
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
url: str,
|
||||||
|
dest: ValidatedDestination,
|
||||||
|
body: bytes,
|
||||||
|
headers: dict[str, str],
|
||||||
|
) -> httpx.Request:
|
||||||
|
"""Build a POST request pinned to a validated IP.
|
||||||
|
|
||||||
|
Defeats DNS rebinding: instead of letting httpx re-resolve the hostname
|
||||||
|
at connect time (which an attacker-controlled DNS could flip to an
|
||||||
|
internal IP after our check passed), we point the connection at one of
|
||||||
|
the IPs we already validated, while preserving the original ``Host``
|
||||||
|
header and TLS SNI so the receiver and certificate validation still see
|
||||||
|
the real hostname.
|
||||||
|
"""
|
||||||
|
pinned_ip = dest.ip_addresses[0]
|
||||||
|
# httpx brackets IPv6 hosts itself — pass the bare IP.
|
||||||
|
pinned_url = httpx.URL(url).copy_with(host=pinned_ip)
|
||||||
|
|
||||||
|
req_headers = dict(headers)
|
||||||
|
# Preserve virtual-host routing on the receiver.
|
||||||
|
req_headers.setdefault("Host", _host_header(dest.host, dest.port, dest.scheme))
|
||||||
|
|
||||||
|
# Keep TLS SNI + cert hostname validation bound to the real host, not
|
||||||
|
# the bare IP we connect to.
|
||||||
|
extensions = {"sni_hostname": dest.host} if dest.scheme == "https" else {}
|
||||||
|
|
||||||
|
return client.build_request(
|
||||||
|
"POST",
|
||||||
|
pinned_url,
|
||||||
|
content=body,
|
||||||
|
headers=req_headers,
|
||||||
|
extensions=extensions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _host_header(host: str, port: int, scheme: str) -> str:
|
||||||
|
default_port = 443 if scheme == "https" else 80
|
||||||
|
host_part = f"[{host}]" if ":" in host else host
|
||||||
|
if port == default_port:
|
||||||
|
return host_part
|
||||||
|
return f"{host_part}:{port}"
|
||||||
|
|
||||||
|
|
||||||
async def deliver(
|
async def deliver(
|
||||||
sub: dict[str, Any],
|
sub: dict[str, Any],
|
||||||
event: Any,
|
event: Any,
|
||||||
@@ -148,6 +198,15 @@ async def deliver(
|
|||||||
headers = _build_headers(sub["secret"], body, topic, eid)
|
headers = _build_headers(sub["secret"], body, topic, eid)
|
||||||
url = sub["url"]
|
url = sub["url"]
|
||||||
|
|
||||||
|
# SSRF guard: resolve + validate the destination before any connect.
|
||||||
|
# Fail closed and treat a forbidden destination as terminal (no retry —
|
||||||
|
# the URL itself is the problem, not a transient network condition).
|
||||||
|
try:
|
||||||
|
dest = validate_webhook_url(url)
|
||||||
|
except WebhookDestinationError as e:
|
||||||
|
log.warning("webhook delivery blocked by SSRF guard: %s", e)
|
||||||
|
return DeliveryResult(ok=False, status_code=None, error=str(e), attempts=0)
|
||||||
|
|
||||||
owns_client = client is None
|
owns_client = client is None
|
||||||
if client is None:
|
if client is None:
|
||||||
client = httpx.AsyncClient(timeout=timeout_s)
|
client = httpx.AsyncClient(timeout=timeout_s)
|
||||||
@@ -157,7 +216,8 @@ async def deliver(
|
|||||||
try:
|
try:
|
||||||
for attempt in range(1, max_attempts + 1):
|
for attempt in range(1, max_attempts + 1):
|
||||||
try:
|
try:
|
||||||
resp = await client.post(url, content=body, headers=headers)
|
request = _build_pinned_request(client, url, dest, body, headers)
|
||||||
|
resp = await client.send(request, follow_redirects=False)
|
||||||
last_status = resp.status_code
|
last_status = resp.status_code
|
||||||
if 200 <= resp.status_code < 300:
|
if 200 <= resp.status_code < 300:
|
||||||
return DeliveryResult(
|
return DeliveryResult(
|
||||||
|
|||||||
151
decnet/webhook/ssrf.py
Normal file
151
decnet/webhook/ssrf.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
"""SSRF egress guard for outbound webhook delivery.
|
||||||
|
|
||||||
|
Admin-supplied webhook URLs are attacker-influenceable (anyone able to
|
||||||
|
write a subscription row). Without a destination check the master can be
|
||||||
|
pointed at internal services — cloud metadata (169.254.169.254), the
|
||||||
|
loopback API, RFC1918 hosts — turning the egress path into an SSRF
|
||||||
|
primitive.
|
||||||
|
|
||||||
|
This module resolves the URL host to concrete IPs and rejects any that
|
||||||
|
are private / loopback / link-local / unspecified / reserved / multicast,
|
||||||
|
and rejects non-http(s) schemes. It returns the *validated* IP set so the
|
||||||
|
caller can connect to a checked address rather than re-resolving (which a
|
||||||
|
DNS-rebinding attacker could flip between the validation and the connect).
|
||||||
|
|
||||||
|
Fail closed: the guard is fully active unless the operator explicitly opts
|
||||||
|
out via ``DECNET_WEBHOOK_ALLOW_PRIVATE=true``.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
|
import socket
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import urlsplit
|
||||||
|
|
||||||
|
_ALLOWED_SCHEMES = frozenset({"http", "https"})
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookDestinationError(ValueError):
|
||||||
|
"""Raised when a webhook URL resolves to a forbidden destination.
|
||||||
|
|
||||||
|
Subclasses ``ValueError`` so the CRUD layer can turn it into a 422 and
|
||||||
|
the delivery layer can treat it as a terminal (non-retryable) failure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ValidatedDestination:
|
||||||
|
"""Result of a successful guard check.
|
||||||
|
|
||||||
|
``ip_addresses`` is the set of validated literal IPs the URL host
|
||||||
|
resolved to. Connecting to one of these (instead of re-resolving the
|
||||||
|
hostname) closes the DNS-rebinding window.
|
||||||
|
"""
|
||||||
|
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
scheme: str
|
||||||
|
ip_addresses: tuple[str, ...]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_forbidden(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||||
|
"""Block anything that is not a routable public address.
|
||||||
|
|
||||||
|
``is_global`` is the inverse of the union we care about, but we spell
|
||||||
|
out the categories so the intent (and the audit mapping) is explicit
|
||||||
|
and so we also catch reserved/multicast that ``is_private`` misses.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
ip.is_private # RFC1918 10/8, 172.16/12, 192.168/16, fc00::/7
|
||||||
|
or ip.is_loopback # 127/8, ::1
|
||||||
|
or ip.is_link_local # 169.254/16 (incl. 169.254.169.254), fe80::/10
|
||||||
|
or ip.is_unspecified # 0.0.0.0, ::
|
||||||
|
or ip.is_reserved
|
||||||
|
or ip.is_multicast
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
# IPv4-mapped IPv6 (::ffff:a.b.c.d) hides a v4 address from the checks
|
||||||
|
# above; unwrap and re-check so 127.0.0.1 can't sneak in as ::ffff:7f00:1.
|
||||||
|
mapped = getattr(ip, "ipv4_mapped", None)
|
||||||
|
if mapped is not None:
|
||||||
|
return _is_forbidden(mapped)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve(host: str, port: int) -> tuple[str, ...]:
|
||||||
|
"""Resolve *host* to the set of literal IPs it points at.
|
||||||
|
|
||||||
|
A bare IP literal short-circuits getaddrinfo. DNS failures raise
|
||||||
|
``WebhookDestinationError`` (fail closed — we never deliver to a host
|
||||||
|
we couldn't resolve and check)."""
|
||||||
|
try:
|
||||||
|
ipaddress.ip_address(host)
|
||||||
|
return (host,)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
infos = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)
|
||||||
|
except socket.gaierror as exc:
|
||||||
|
raise WebhookDestinationError(
|
||||||
|
f"webhook host {host!r} did not resolve: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
addrs = {str(info[4][0]) for info in infos}
|
||||||
|
if not addrs:
|
||||||
|
raise WebhookDestinationError(f"webhook host {host!r} resolved to nothing")
|
||||||
|
return tuple(sorted(addrs))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_webhook_url(url: str, *, allow_private: Optional[bool] = None) -> ValidatedDestination:
|
||||||
|
"""Validate *url* as a safe webhook egress destination.
|
||||||
|
|
||||||
|
Raises ``WebhookDestinationError`` on a bad scheme, missing host, a host
|
||||||
|
that won't resolve, or any resolved address that is private / loopback /
|
||||||
|
link-local / unspecified / reserved / multicast.
|
||||||
|
|
||||||
|
``allow_private`` defaults to the ``DECNET_WEBHOOK_ALLOW_PRIVATE`` env
|
||||||
|
flag (resolved lazily so tests can monkeypatch the env module). When
|
||||||
|
True the IP-category checks are skipped, but scheme + resolvability are
|
||||||
|
still enforced.
|
||||||
|
"""
|
||||||
|
if allow_private is None:
|
||||||
|
from decnet.env import DECNET_WEBHOOK_ALLOW_PRIVATE
|
||||||
|
|
||||||
|
allow_private = DECNET_WEBHOOK_ALLOW_PRIVATE
|
||||||
|
|
||||||
|
parts = urlsplit(url)
|
||||||
|
scheme = parts.scheme.lower()
|
||||||
|
if scheme not in _ALLOWED_SCHEMES:
|
||||||
|
raise WebhookDestinationError(
|
||||||
|
f"webhook URL scheme {scheme!r} is not allowed (use http/https)"
|
||||||
|
)
|
||||||
|
|
||||||
|
host = parts.hostname
|
||||||
|
if not host:
|
||||||
|
raise WebhookDestinationError("webhook URL has no host")
|
||||||
|
|
||||||
|
port = parts.port or (443 if scheme == "https" else 80)
|
||||||
|
|
||||||
|
resolved = _resolve(host, port)
|
||||||
|
|
||||||
|
if not allow_private:
|
||||||
|
for addr in resolved:
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(addr)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise WebhookDestinationError(
|
||||||
|
f"webhook host {host!r} resolved to non-IP {addr!r}"
|
||||||
|
) from exc
|
||||||
|
if _is_forbidden(ip):
|
||||||
|
raise WebhookDestinationError(
|
||||||
|
f"webhook host {host!r} resolves to forbidden address {addr} "
|
||||||
|
"(private/loopback/link-local/reserved). Set "
|
||||||
|
"DECNET_WEBHOOK_ALLOW_PRIVATE=true to permit internal targets."
|
||||||
|
)
|
||||||
|
|
||||||
|
return ValidatedDestination(
|
||||||
|
host=host, port=port, scheme=scheme, ip_addresses=resolved
|
||||||
|
)
|
||||||
@@ -18,7 +18,14 @@ import pytest
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
||||||
from decnet.web.auth import ALGORITHM, SECRET_KEY, get_password_hash
|
from decnet.web.auth import (
|
||||||
|
ALGORITHM,
|
||||||
|
JWT_AUDIENCE,
|
||||||
|
JWT_ISSUER,
|
||||||
|
JWT_TYPE,
|
||||||
|
SECRET_KEY,
|
||||||
|
get_password_hash,
|
||||||
|
)
|
||||||
from decnet.web.db.models import User
|
from decnet.web.db.models import User
|
||||||
from decnet.web.dependencies import repo
|
from decnet.web.dependencies import repo
|
||||||
|
|
||||||
@@ -54,7 +61,17 @@ def _aged_token(uuid: str, *, seconds_old: int = 30) -> str:
|
|||||||
change sets to 'now', so it is deterministically revoked once bumped."""
|
change sets to 'now', so it is deterministically revoked once bumped."""
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
return jwt.encode(
|
return jwt.encode(
|
||||||
{"uuid": uuid, "jti": f"aged-{uuid}", "iat": now - seconds_old, "exp": now + 3600},
|
{
|
||||||
|
"uuid": uuid,
|
||||||
|
"jti": f"aged-{uuid}",
|
||||||
|
"iat": now - seconds_old,
|
||||||
|
"exp": now + 3600,
|
||||||
|
# The verifier now pins issuer/audience/type (V2.1.1 / V3.1.2); a
|
||||||
|
# manually-encoded token must carry them or decode rejects it.
|
||||||
|
"iss": JWT_ISSUER,
|
||||||
|
"aud": JWT_AUDIENCE,
|
||||||
|
"typ": JWT_TYPE,
|
||||||
|
},
|
||||||
SECRET_KEY, algorithm=ALGORITHM,
|
SECRET_KEY, algorithm=ALGORITHM,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import pytest
|
|||||||
from hypothesis import given, strategies as st, settings
|
from hypothesis import given, strategies as st, settings
|
||||||
import httpx
|
import httpx
|
||||||
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
||||||
|
from decnet.web.limiter import limiter as _limiter
|
||||||
from ..conftest import _FUZZ_SETTINGS
|
from ..conftest import _FUZZ_SETTINGS
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -57,6 +58,46 @@ async def test_fuzz_change_password(client: httpx.AsyncClient, old_password: str
|
|||||||
json=_payload,
|
json=_payload,
|
||||||
headers={"Authorization": f"Bearer {_token}"}
|
headers={"Authorization": f"Bearer {_token}"}
|
||||||
)
|
)
|
||||||
assert _response.status_code in (200, 401, 422)
|
# 400: schema-guard middleware rejects bad length/shape (e.g. a
|
||||||
|
# new_password below the 12-char floor) before the handler runs.
|
||||||
|
assert _response.status_code in (200, 400, 401, 422)
|
||||||
except (UnicodeEncodeError, json.JSONDecodeError):
|
except (UnicodeEncodeError, json.JSONDecodeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Rate-limit enforcement ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_change_password_rate_limit_trips_after_5(client: httpx.AsyncClient) -> None:
|
||||||
|
"""5 change-password attempts from one IP → 6th returns 429."""
|
||||||
|
login_resp = await client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"username": DECNET_ADMIN_USER, "password": DECNET_ADMIN_PASSWORD},
|
||||||
|
)
|
||||||
|
token = login_resp.json()["access_token"]
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
r = await client.post(
|
||||||
|
"/api/v1/auth/change-password",
|
||||||
|
json={"old_password": f"wrong-{i}", "new_password": "does-not-matter-x!"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
# 401 (bad old password) or 429 if the limiter fires — either is fine
|
||||||
|
assert r.status_code in (401, 429), f"attempt {i}: got {r.status_code}"
|
||||||
|
|
||||||
|
# The 6th attempt must trip the rate limiter (limit is 5/minute).
|
||||||
|
r = await client.post(
|
||||||
|
"/api/v1/auth/change-password",
|
||||||
|
json={"old_password": "still-wrong", "new_password": "does-not-matter-x!"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert r.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_change_password_route_has_rate_limit_decorator() -> None:
|
||||||
|
"""Contract test: change_password handler must be wrapped by slowapi."""
|
||||||
|
from decnet.web.router.auth import api_change_pass as _mod
|
||||||
|
|
||||||
|
assert getattr(_mod.change_password, "__wrapped__", None) is not None
|
||||||
|
|||||||
111
tests/api/auth/test_sse_ticket.py
Normal file
111
tests/api/auth/test_sse_ticket.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
"""SSE stream tickets (V3.1.1) + change-password min-length (V2.1.3).
|
||||||
|
|
||||||
|
The ticket store is a security boundary: single-use, 60s, fail-closed. These
|
||||||
|
cover the mint→redeem happy path, single-use reuse rejection, expiry rejection,
|
||||||
|
the endpoint round-trip, and the V3.1.1 invariant that a raw JWT in the SSE
|
||||||
|
query string is no longer accepted.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from decnet.env import DECNET_ADMIN_USER, DECNET_ADMIN_PASSWORD
|
||||||
|
from decnet.web.auth import create_access_token
|
||||||
|
from decnet.web import dependencies as deps
|
||||||
|
|
||||||
|
|
||||||
|
# ── ticket store unit tests ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_mint_then_redeem_happy_path() -> None:
|
||||||
|
deps._reset_sse_tickets()
|
||||||
|
ticket = deps.mint_sse_ticket("user-1", "viewer")
|
||||||
|
identity = deps._redeem_sse_ticket(ticket)
|
||||||
|
assert identity == {"uuid": "user-1", "role": "viewer"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_ticket_is_single_use() -> None:
|
||||||
|
deps._reset_sse_tickets()
|
||||||
|
ticket = deps.mint_sse_ticket("user-1", "admin")
|
||||||
|
deps._redeem_sse_ticket(ticket) # first redeem consumes it
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
deps._redeem_sse_ticket(ticket)
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_ticket_rejected() -> None:
|
||||||
|
deps._reset_sse_tickets()
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
deps._redeem_sse_ticket("never-minted")
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_expired_ticket_rejected() -> None:
|
||||||
|
deps._reset_sse_tickets()
|
||||||
|
# Mint, then jam the entry's expiry into the past so redeem fails closed.
|
||||||
|
ticket = deps.mint_sse_ticket("user-1", "viewer")
|
||||||
|
exp, identity = deps._sse_tickets[ticket]
|
||||||
|
deps._sse_tickets[ticket] = (exp - deps._SSE_TICKET_TTL - 1, identity)
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
deps._redeem_sse_ticket(ticket)
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── endpoint round-trip ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_sse_ticket_endpoint_requires_auth(client: httpx.AsyncClient) -> None:
|
||||||
|
resp = await client.post("/api/v1/auth/sse-ticket")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_sse_ticket_endpoint_mints_and_redeems(
|
||||||
|
client: httpx.AsyncClient, auth_token: str
|
||||||
|
) -> None:
|
||||||
|
deps._reset_sse_tickets()
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/auth/sse-ticket",
|
||||||
|
headers={"Authorization": f"Bearer {auth_token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200, resp.text
|
||||||
|
body = resp.json()
|
||||||
|
assert body["expires_in"] == 60
|
||||||
|
ticket = body["ticket"]
|
||||||
|
assert ticket and "." not in ticket # opaque, not a JWT
|
||||||
|
# The minted ticket redeems to a bound identity exactly once.
|
||||||
|
identity = deps._redeem_sse_ticket(ticket)
|
||||||
|
assert "uuid" in identity and identity["role"] in ("admin", "viewer")
|
||||||
|
|
||||||
|
|
||||||
|
def test_raw_jwt_in_sse_query_rejected() -> None:
|
||||||
|
"""V3.1.1: a raw JWT is not a valid opaque ticket — _redeem_sse_ticket rejects
|
||||||
|
any token that wasn't minted by mint_sse_ticket (unknown key → 401)."""
|
||||||
|
deps._reset_sse_tickets()
|
||||||
|
token = create_access_token({"uuid": "leaked", "jti": "x"})
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
deps._redeem_sse_ticket(token)
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── V2.1.3 change-password min length ────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_change_password_below_min_length_rejected(
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
) -> None:
|
||||||
|
resp = await client.post("/api/v1/auth/login", json={
|
||||||
|
"username": DECNET_ADMIN_USER, "password": DECNET_ADMIN_PASSWORD,
|
||||||
|
})
|
||||||
|
token = resp.json()["access_token"]
|
||||||
|
# 11 chars — one below the 12-char floor. The request-validation layer
|
||||||
|
# rejects the bad length before any auth/logic runs; DECNET's schema-guard
|
||||||
|
# middleware surfaces length violations as 400 (not the raw 422).
|
||||||
|
r = await client.post(
|
||||||
|
"/api/v1/auth/change-password",
|
||||||
|
json={"old_password": DECNET_ADMIN_PASSWORD, "new_password": "short123456"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert r.status_code in (400, 422), r.text
|
||||||
@@ -379,3 +379,31 @@ async def test_host_row_persisted_after_enroll(client, auth_token):
|
|||||||
assert row is not None
|
assert row is not None
|
||||||
assert row["name"] == "eta"
|
assert row["name"] == "eta"
|
||||||
assert row["status"] == "enrolled"
|
assert row["status"] == "enrolled"
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Rate-limit enforcement ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_enroll_bundle_rate_limit_trips_after_10(client, auth_token):
|
||||||
|
"""10 enroll-bundle POSTs from one IP → 11th returns 429.
|
||||||
|
|
||||||
|
Each request uses a unique agent name (otherwise the 2nd hits the 409
|
||||||
|
duplicate-name guard before the rate check fires). The limiter is
|
||||||
|
10/minute for this endpoint.
|
||||||
|
"""
|
||||||
|
for i in range(10):
|
||||||
|
r = await _post(client, auth_token, agent_name=f"rl-node-{i}")
|
||||||
|
# 201 (created) or 429 if limiter fires early — accept both.
|
||||||
|
assert r.status_code in (201, 429), f"attempt {i}: got {r.status_code}"
|
||||||
|
|
||||||
|
r = await _post(client, auth_token, agent_name="rl-node-overflow")
|
||||||
|
assert r.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_enroll_bundle_route_has_rate_limit_decorator() -> None:
|
||||||
|
"""Contract test: create_enroll_bundle must be wrapped by slowapi."""
|
||||||
|
from decnet.web.router.swarm_mgmt import api_enroll_bundle as _mod
|
||||||
|
|
||||||
|
assert getattr(_mod.create_enroll_bundle, "__wrapped__", None) is not None
|
||||||
|
|||||||
@@ -9,6 +9,80 @@ import pytest
|
|||||||
PATH = "/api/v1/webhooks/"
|
PATH = "/api/v1/webhooks/"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _public_dns(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Resolve hostnames to a public IP so the registration-time SSRF guard
|
||||||
|
passes for the functional CRUD cases without touching the network.
|
||||||
|
|
||||||
|
IP-literal URLs (e.g. the loopback-rejection test) don't hit DNS, so
|
||||||
|
this stub doesn't mask them.
|
||||||
|
"""
|
||||||
|
import socket
|
||||||
|
|
||||||
|
from decnet.webhook import ssrf
|
||||||
|
|
||||||
|
def fake_getaddrinfo(host, port, *a, **k):
|
||||||
|
return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", port))]
|
||||||
|
|
||||||
|
monkeypatch.setattr(ssrf.socket, "getaddrinfo", fake_getaddrinfo)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_rejects_loopback_url(
|
||||||
|
client: httpx.AsyncClient, auth_token: str
|
||||||
|
):
|
||||||
|
res = await client.post(
|
||||||
|
PATH,
|
||||||
|
json={
|
||||||
|
"name": "wh-ssrf",
|
||||||
|
"url": "http://127.0.0.1:8080/inbound",
|
||||||
|
"topic_patterns": ["system.>"],
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {auth_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 422, res.text
|
||||||
|
assert "forbidden" in res.text.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_rejects_metadata_url(
|
||||||
|
client: httpx.AsyncClient, auth_token: str
|
||||||
|
):
|
||||||
|
res = await client.post(
|
||||||
|
PATH,
|
||||||
|
json={
|
||||||
|
"name": "wh-meta",
|
||||||
|
"url": "http://169.254.169.254/latest/meta-data/",
|
||||||
|
"topic_patterns": ["system.>"],
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {auth_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 422, res.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_rejects_loopback_url(
|
||||||
|
client: httpx.AsyncClient, auth_token: str
|
||||||
|
):
|
||||||
|
create = await client.post(
|
||||||
|
PATH,
|
||||||
|
json={
|
||||||
|
"name": "wh-upd-ssrf",
|
||||||
|
"url": "https://good.example/x",
|
||||||
|
"topic_patterns": ["system.>"],
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {auth_token}"},
|
||||||
|
)
|
||||||
|
assert create.status_code == 201, create.text
|
||||||
|
uuid = create.json()["uuid"]
|
||||||
|
res = await client.patch(
|
||||||
|
f"{PATH}{uuid}",
|
||||||
|
json={"url": "http://10.0.0.1/x"},
|
||||||
|
headers={"Authorization": f"Bearer {auth_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 422, res.text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_requires_patterns(client: httpx.AsyncClient, auth_token: str):
|
async def test_create_requires_patterns(client: httpx.AsyncClient, auth_token: str):
|
||||||
res = await client.post(
|
res = await client.post(
|
||||||
|
|||||||
@@ -223,3 +223,62 @@ async def test_independent_dedup_per_identity(
|
|||||||
seen = {c["payload"]["identity_uuid"] for c in captured}
|
seen = {c["payload"]["identity_uuid"] for c in captured}
|
||||||
assert seen == {iuid_a, iuid_b}
|
assert seen == {iuid_a, iuid_b}
|
||||||
await bus.close()
|
await bus.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_rearms_for_sub_threshold_identity_in_candidates(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""BUG-10 regression: seen_now.add() must run AFTER the threshold guard.
|
||||||
|
|
||||||
|
If an identity is returned by the repo with < MULTI_ACTOR_MIN_PRIMITIVES
|
||||||
|
(defensive path) it must NOT be added to seen_now. That means it stays
|
||||||
|
absent from seen_now → gets removed from last_fired on the stale-rearm
|
||||||
|
sweep → re-fires when primitives climb back above threshold.
|
||||||
|
|
||||||
|
Before fix: seen_now.add() ran before the continue, so the identity
|
||||||
|
was treated as present-and-seen even though it was below threshold,
|
||||||
|
and last_fired was never cleared → no rearm.
|
||||||
|
"""
|
||||||
|
bus = FakeBus()
|
||||||
|
await bus.connect()
|
||||||
|
captured: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def cap(_b, t, p, *, event_type=""):
|
||||||
|
captured.append({"topic": t, "payload": p})
|
||||||
|
|
||||||
|
monkeypatch.setattr(_aw, "publish_safely", cap)
|
||||||
|
|
||||||
|
iuid = "test-rearm-uuid"
|
||||||
|
|
||||||
|
class _StubRepo:
|
||||||
|
def __init__(self, entries: list[dict]) -> None:
|
||||||
|
self._entries = entries
|
||||||
|
|
||||||
|
async def list_multi_actor_identities(self) -> list[dict]:
|
||||||
|
return list(self._entries)
|
||||||
|
|
||||||
|
# First tick: identity fires with 2 primitives.
|
||||||
|
repo_above = _StubRepo([
|
||||||
|
{"identity_uuid": iuid, "primitives": ["prim.a", "prim.b"]},
|
||||||
|
])
|
||||||
|
last_fired: dict[str, Any] = {}
|
||||||
|
await _aw.tick_multi_actor(bus, repo_above, last_fired) # type: ignore[arg-type]
|
||||||
|
assert len(captured) == 1
|
||||||
|
assert iuid in last_fired
|
||||||
|
|
||||||
|
# Second tick: identity returned by repo but with only 1 primitive
|
||||||
|
# (sub-threshold defensive path). last_fired[iuid] must be cleared.
|
||||||
|
repo_below = _StubRepo([
|
||||||
|
{"identity_uuid": iuid, "primitives": ["prim.a"]},
|
||||||
|
])
|
||||||
|
await _aw.tick_multi_actor(bus, repo_below, last_fired) # type: ignore[arg-type]
|
||||||
|
assert iuid not in last_fired, (
|
||||||
|
"sub-threshold identity must be removed from last_fired so it re-arms"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Third tick: identity climbs back above threshold — must re-fire.
|
||||||
|
await _aw.tick_multi_actor(bus, repo_above, last_fired) # type: ignore[arg-type]
|
||||||
|
assert len(captured) == 2, "identity must re-fire after rearm"
|
||||||
|
|
||||||
|
await bus.close()
|
||||||
|
|||||||
@@ -235,3 +235,69 @@ def test_multiple_rotations_increment_counter(engine, now):
|
|||||||
row = session.exec(select(AttackerFingerprintState)).one()
|
row = session.exec(select(AttackerFingerprintState)).one()
|
||||||
assert row.rotation_count == 2
|
assert row.rotation_count == 2
|
||||||
assert row.last_hash == "h3"
|
assert row.last_hash == "h3"
|
||||||
|
|
||||||
|
|
||||||
|
def test_emit_after_commit_raising_publish_does_not_lose_row(engine, now) -> None:
|
||||||
|
"""BUG-9 regression: publish_fn is called AFTER session.commit().
|
||||||
|
|
||||||
|
A raising publish_fn must not roll back / lose the committed rotation
|
||||||
|
row. Before fix, publish was called before commit so a raise in
|
||||||
|
publish_fn left the session without a commit and the state row was lost.
|
||||||
|
"""
|
||||||
|
later = now + timedelta(hours=1)
|
||||||
|
|
||||||
|
call_order: list[str] = []
|
||||||
|
|
||||||
|
class _OrderRecorder:
|
||||||
|
def __call__(self, event_type: str, payload: dict) -> None:
|
||||||
|
call_order.append("emit")
|
||||||
|
raise RuntimeError("downstream unavailable")
|
||||||
|
|
||||||
|
publish = _OrderRecorder()
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
# Patch session.commit to record ordering.
|
||||||
|
original_commit = session.commit
|
||||||
|
|
||||||
|
def _recording_commit() -> None:
|
||||||
|
call_order.append("commit")
|
||||||
|
original_commit()
|
||||||
|
|
||||||
|
session.commit = _recording_commit # type: ignore[method-assign]
|
||||||
|
|
||||||
|
_seed_attacker(session)
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
original_commit2 = session.commit
|
||||||
|
|
||||||
|
def _recording_commit2() -> None:
|
||||||
|
call_order.append("commit")
|
||||||
|
original_commit2()
|
||||||
|
|
||||||
|
session.commit = _recording_commit2 # type: ignore[method-assign]
|
||||||
|
|
||||||
|
# first_sighting — no publish yet
|
||||||
|
record_fingerprint(
|
||||||
|
session,
|
||||||
|
attacker_ip="1.2.3.4", port=22, probe_type="hassh",
|
||||||
|
new_hash="h1", ts=now,
|
||||||
|
)
|
||||||
|
call_order.clear()
|
||||||
|
|
||||||
|
# rotation — publish_fn raises after commit
|
||||||
|
outcome = record_fingerprint(
|
||||||
|
session,
|
||||||
|
attacker_ip="1.2.3.4", port=22, probe_type="hassh",
|
||||||
|
new_hash="h2", ts=later,
|
||||||
|
publish_fn=publish,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert outcome.kind == "rotated"
|
||||||
|
# commit must come before emit
|
||||||
|
assert call_order.index("commit") < call_order.index("emit")
|
||||||
|
|
||||||
|
# The rotation row must be persisted despite publish raising
|
||||||
|
with Session(engine) as session:
|
||||||
|
row = session.exec(select(AttackerFingerprintState)).one()
|
||||||
|
assert row.last_hash == "h2"
|
||||||
|
assert row.rotation_count == 1
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ must agree with the collector's ``parse_rfc5424`` so that
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import timezone
|
||||||
|
|
||||||
from decnet.correlation.parser import parse_line
|
from decnet.correlation.parser import parse_line
|
||||||
|
|
||||||
|
|
||||||
@@ -71,3 +73,41 @@ def test_outer_msgid_set_does_not_recurse() -> None:
|
|||||||
assert e.event_type == "auth_attempt"
|
assert e.event_type == "auth_attempt"
|
||||||
assert e.decky == "omega-decky"
|
assert e.decky == "omega-decky"
|
||||||
assert e.service == "auth-helper"
|
assert e.service == "auth-helper"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# BUG-11 regression: naive datetime normalization
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_NAIVE_TS_LINE = (
|
||||||
|
"<14>1 2026-05-02T06:22:48.089309 omega-decky smtp - disconnect "
|
||||||
|
"[relay@55555 src_ip=\"10.0.0.1\"]"
|
||||||
|
)
|
||||||
|
|
||||||
|
_AWARE_TS_LINE = (
|
||||||
|
"<14>1 2026-05-02T06:22:48.089309+00:00 omega-decky smtp - disconnect "
|
||||||
|
"[relay@55555 src_ip=\"10.0.0.2\"]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_naive_timestamp_normalized_to_utc() -> None:
|
||||||
|
"""BUG-11 regression: a log line with a naïve ISO timestamp (no tz offset)
|
||||||
|
must parse to a tz-aware UTC datetime so it sorts alongside aware ones
|
||||||
|
without TypeError. Before fix, fromisoformat returned a naïve datetime
|
||||||
|
which crashed min/max/sort with aware datetimes downstream."""
|
||||||
|
e = parse_line(_NAIVE_TS_LINE)
|
||||||
|
assert e is not None
|
||||||
|
assert e.timestamp.tzinfo is not None
|
||||||
|
assert e.timestamp.tzinfo == timezone.utc
|
||||||
|
|
||||||
|
|
||||||
|
def test_naive_and_aware_timestamps_sortable_together() -> None:
|
||||||
|
"""A naïve-source entry and an aware-source entry must compare
|
||||||
|
without raising TypeError."""
|
||||||
|
naive_entry = parse_line(_NAIVE_TS_LINE)
|
||||||
|
aware_entry = parse_line(_AWARE_TS_LINE)
|
||||||
|
assert naive_entry is not None
|
||||||
|
assert aware_entry is not None
|
||||||
|
# min/max would raise TypeError pre-fix
|
||||||
|
earliest = min(naive_entry.timestamp, aware_entry.timestamp)
|
||||||
|
assert earliest is not None
|
||||||
|
|||||||
@@ -276,6 +276,29 @@ async def test_one_tick_email_branch_records_orchestrator_email(
|
|||||||
assert ev.payload["mail_decky_uuid"] == mail_decky.uuid
|
assert ev.payload["mail_decky_uuid"] == mail_decky.uuid
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_smtp_probe_listener_get_bus_raises_no_unbound_error(
|
||||||
|
repo, monkeypatch,
|
||||||
|
) -> None:
|
||||||
|
"""BUG-7 regression: if get_bus() raises, the finally block must not
|
||||||
|
produce an UnboundLocalError on ``bus``; the function must return
|
||||||
|
cleanly (RuntimeError is logged+swallowed by the outer except handler)."""
|
||||||
|
import asyncio
|
||||||
|
from decnet.orchestrator import worker as _w
|
||||||
|
|
||||||
|
def bad_get_bus(**_kw):
|
||||||
|
raise RuntimeError("bus factory unavailable")
|
||||||
|
|
||||||
|
monkeypatch.setattr(_w, "get_bus", bad_get_bus)
|
||||||
|
|
||||||
|
shutdown = asyncio.Event()
|
||||||
|
shutdown.set()
|
||||||
|
|
||||||
|
# Before fix: UnboundLocalError escaped from finally because ``bus``
|
||||||
|
# was never assigned. After fix: completes without any exception.
|
||||||
|
await _w._run_smtp_probe_listener(repo, shutdown)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_tick_is_noop_when_no_running_deckies(repo, fake_bus, monkeypatch):
|
async def test_tick_is_noop_when_no_running_deckies(repo, fake_bus, monkeypatch):
|
||||||
called = False
|
called = False
|
||||||
|
|||||||
@@ -345,9 +345,18 @@ def test_expired_state_treated_as_disabled_by_is_active() -> None:
|
|||||||
def test_apply_ceiling_only_clamps_clipped() -> None:
|
def test_apply_ceiling_only_clamps_clipped() -> None:
|
||||||
from decnet.ttp.impl._state import apply_ceiling
|
from decnet.ttp.impl._state import apply_ceiling
|
||||||
|
|
||||||
|
# ceiling is ignored unless state is clipped
|
||||||
enabled = RuleState(state="enabled", confidence_max=0.5)
|
enabled = RuleState(state="enabled", confidence_max=0.5)
|
||||||
assert apply_ceiling(0.9, enabled) == 0.9 # ceiling ignored unless clipped
|
assert apply_ceiling(0.9, enabled) == 0.9
|
||||||
|
|
||||||
|
# clipped + base > ceiling → clamped to ceiling (not scaled)
|
||||||
clipped = RuleState(state="clipped", confidence_max=0.5)
|
clipped = RuleState(state="clipped", confidence_max=0.5)
|
||||||
assert apply_ceiling(0.9, clipped) == pytest.approx(0.45)
|
assert apply_ceiling(0.9, clipped) == pytest.approx(0.5)
|
||||||
|
|
||||||
|
# clipped + base <= ceiling → base passes through unchanged
|
||||||
|
clipped_below = RuleState(state="clipped", confidence_max=0.8)
|
||||||
|
assert apply_ceiling(0.6, clipped_below) == pytest.approx(0.6)
|
||||||
|
|
||||||
|
# clipped + no ceiling declared → base passes through
|
||||||
clipped_no_max = RuleState(state="clipped", confidence_max=None)
|
clipped_no_max = RuleState(state="clipped", confidence_max=None)
|
||||||
assert apply_ceiling(0.9, clipped_no_max) == 0.9
|
assert apply_ceiling(0.9, clipped_no_max) == 0.9
|
||||||
|
|||||||
194
tests/updater/test_updater_client_pin.py
Normal file
194
tests/updater/test_updater_client_pin.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
"""UpdaterClient SHA-256 leaf-cert pinning (master->worker updater channel).
|
||||||
|
|
||||||
|
The updater channel pip-installs code as root, so it pins the worker's
|
||||||
|
updater leaf cert against ``SwarmHost.updater_cert_fingerprint`` and fails
|
||||||
|
closed on mismatch OR a missing recorded fingerprint.
|
||||||
|
|
||||||
|
We don't need the real updater ASGI app: ``UpdaterClient.__aenter__`` runs
|
||||||
|
``_verify_pin`` which opens its own throwaway TLS connection to extract the
|
||||||
|
peer leaf cert before any RPC. A minimal threaded mTLS socket server that
|
||||||
|
simply completes the handshake is enough to exercise the pin.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pathlib
|
||||||
|
import socket
|
||||||
|
import ssl
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from decnet.swarm import client as swarm_client
|
||||||
|
from decnet.swarm import pki
|
||||||
|
from decnet.swarm.updater_client import UpdaterClient
|
||||||
|
|
||||||
|
|
||||||
|
def _free_port() -> int:
|
||||||
|
s = socket.socket()
|
||||||
|
s.bind(("127.0.0.1", 0))
|
||||||
|
port = s.getsockname()[1]
|
||||||
|
s.close()
|
||||||
|
return port
|
||||||
|
|
||||||
|
|
||||||
|
class _MiniTLSServer:
|
||||||
|
"""Threaded mTLS server that accepts a connection, completes the
|
||||||
|
handshake (presenting the worker leaf cert), then closes."""
|
||||||
|
|
||||||
|
def __init__(self, worker_dir: pathlib.Path, port: int) -> None:
|
||||||
|
self._ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
|
self._ctx.load_cert_chain(
|
||||||
|
str(worker_dir / "worker.crt"), str(worker_dir / "worker.key")
|
||||||
|
)
|
||||||
|
self._ctx.load_verify_locations(cafile=str(worker_dir / "ca.crt"))
|
||||||
|
self._ctx.verify_mode = ssl.CERT_REQUIRED
|
||||||
|
self._sock = socket.socket()
|
||||||
|
self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
self._sock.bind(("127.0.0.1", port))
|
||||||
|
self._sock.listen(8)
|
||||||
|
self._sock.settimeout(0.5)
|
||||||
|
self._stop = threading.Event()
|
||||||
|
self._thread = threading.Thread(target=self._serve, daemon=True)
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def _serve(self) -> None:
|
||||||
|
while not self._stop.is_set():
|
||||||
|
try:
|
||||||
|
conn, _ = self._sock.accept()
|
||||||
|
except socket.timeout:
|
||||||
|
continue
|
||||||
|
except OSError:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
tls = self._ctx.wrap_socket(conn, server_side=True)
|
||||||
|
try:
|
||||||
|
tls.recv(64)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
tls.close()
|
||||||
|
except OSError:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
self._stop.set()
|
||||||
|
try:
|
||||||
|
self._sock.close()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
self._thread.join(timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def updater_env(tmp_path: pathlib.Path):
|
||||||
|
ca_dir = tmp_path / "ca"
|
||||||
|
pki.ensure_ca(ca_dir)
|
||||||
|
worker_dir = tmp_path / "updater"
|
||||||
|
pki.write_worker_bundle(
|
||||||
|
pki.issue_worker_cert(pki.load_ca(ca_dir), "updater-test", ["127.0.0.1"]),
|
||||||
|
worker_dir,
|
||||||
|
)
|
||||||
|
master_id = swarm_client.ensure_master_identity(ca_dir)
|
||||||
|
|
||||||
|
port = _free_port()
|
||||||
|
server = _MiniTLSServer(worker_dir, port)
|
||||||
|
server.start()
|
||||||
|
# Give the listener a moment.
|
||||||
|
time.sleep(0.1)
|
||||||
|
try:
|
||||||
|
yield worker_dir, port, master_id
|
||||||
|
finally:
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pin_accepts_matching_fingerprint(updater_env) -> None:
|
||||||
|
worker_dir, port, master_id = updater_env
|
||||||
|
expected = pki.fingerprint((worker_dir / "worker.crt").read_bytes())
|
||||||
|
host = {
|
||||||
|
"uuid": "h1",
|
||||||
|
"name": "updater-test",
|
||||||
|
"address": "127.0.0.1",
|
||||||
|
"updater_cert_fingerprint": expected,
|
||||||
|
}
|
||||||
|
async with UpdaterClient(
|
||||||
|
host=host, updater_port=port, identity=master_id
|
||||||
|
) as u:
|
||||||
|
# Entering the context already ran _verify_pin successfully.
|
||||||
|
assert u._expected_fingerprint == expected.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pin_rejects_mismatch(updater_env) -> None:
|
||||||
|
_worker_dir, port, master_id = updater_env
|
||||||
|
host = {
|
||||||
|
"uuid": "h1",
|
||||||
|
"name": "updater-test",
|
||||||
|
"address": "127.0.0.1",
|
||||||
|
"updater_cert_fingerprint": "0" * 64,
|
||||||
|
}
|
||||||
|
with pytest.raises(swarm_client.FingerprintMismatchError):
|
||||||
|
async with UpdaterClient(host=host, updater_port=port, identity=master_id):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pin_rejects_missing_fingerprint(updater_env) -> None:
|
||||||
|
"""Fail closed: a host with no recorded updater fingerprint is refused
|
||||||
|
(unlike AgentClient, the updater channel never falls through to CA-only)."""
|
||||||
|
_worker_dir, port, master_id = updater_env
|
||||||
|
host = {
|
||||||
|
"uuid": "h1",
|
||||||
|
"name": "updater-test",
|
||||||
|
"address": "127.0.0.1",
|
||||||
|
"updater_cert_fingerprint": None,
|
||||||
|
}
|
||||||
|
with pytest.raises(swarm_client.FingerprintMismatchError):
|
||||||
|
async with UpdaterClient(host=host, updater_port=port, identity=master_id):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_hostname_defaults_to_env_flag(monkeypatch) -> None:
|
||||||
|
"""The verify_hostname kwarg defaults to DECNET_VERIFY_HOSTNAME, which
|
||||||
|
now defaults to True (operators opt OUT explicitly)."""
|
||||||
|
import decnet.env as env
|
||||||
|
|
||||||
|
monkeypatch.setattr(env, "DECNET_VERIFY_HOSTNAME", True)
|
||||||
|
c_default = UpdaterClient(address="127.0.0.1", updater_port=9)
|
||||||
|
assert c_default._verify_hostname is True
|
||||||
|
|
||||||
|
monkeypatch.setattr(env, "DECNET_VERIFY_HOSTNAME", False)
|
||||||
|
c_off = UpdaterClient(address="127.0.0.1", updater_port=9)
|
||||||
|
assert c_off._verify_hostname is False
|
||||||
|
|
||||||
|
# Explicit kwarg overrides the env default.
|
||||||
|
c_explicit = UpdaterClient(
|
||||||
|
address="127.0.0.1", updater_port=9, verify_hostname=True
|
||||||
|
)
|
||||||
|
assert c_explicit._verify_hostname is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_client_constructs_with_flag(updater_env) -> None:
|
||||||
|
"""_build_client must construct a client for both flag values without
|
||||||
|
error; check_hostname is wired from self._verify_hostname (verified via
|
||||||
|
the live handshake in the pin tests above, which use verify_hostname
|
||||||
|
from the env default)."""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
_worker_dir, port, master_id = updater_env
|
||||||
|
for flag in (True, False):
|
||||||
|
c = UpdaterClient(
|
||||||
|
address="127.0.0.1", updater_port=port, identity=master_id,
|
||||||
|
verify_hostname=flag,
|
||||||
|
)
|
||||||
|
built = c._build_client(httpx.Timeout(5.0))
|
||||||
|
assert isinstance(built, httpx.AsyncClient)
|
||||||
|
assert c._verify_hostname is flag
|
||||||
|
await built.aclose()
|
||||||
@@ -76,57 +76,6 @@ class TestGetCurrentUser:
|
|||||||
await get_current_user(request)
|
await get_current_user(request)
|
||||||
|
|
||||||
|
|
||||||
# ── get_stream_user ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestGetStreamUser:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_bearer_header(self):
|
|
||||||
from decnet.web.dependencies import get_stream_user
|
|
||||||
token = create_access_token({"uuid": "stream-uuid"})
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers = {"Authorization": f"Bearer {token}"}
|
|
||||||
result = await get_stream_user(request, token=None)
|
|
||||||
assert result == "stream-uuid"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_param_fallback(self):
|
|
||||||
from decnet.web.dependencies import get_stream_user
|
|
||||||
token = create_access_token({"uuid": "query-uuid"})
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers = {}
|
|
||||||
result = await get_stream_user(request, token=token)
|
|
||||||
assert result == "query-uuid"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_no_token_raises(self):
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from decnet.web.dependencies import get_stream_user
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers = {}
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
|
||||||
await get_stream_user(request, token=None)
|
|
||||||
assert exc_info.value.status_code == 401
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_invalid_token_raises(self):
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from decnet.web.dependencies import get_stream_user
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers = {}
|
|
||||||
with pytest.raises(HTTPException):
|
|
||||||
await get_stream_user(request, token="bad-token")
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_missing_uuid_raises(self):
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from decnet.web.dependencies import get_stream_user
|
|
||||||
token = create_access_token({"sub": "no-uuid"})
|
|
||||||
request = MagicMock()
|
|
||||||
request.headers = {"Authorization": f"Bearer {token}"}
|
|
||||||
with pytest.raises(HTTPException):
|
|
||||||
await get_stream_user(request, token=None)
|
|
||||||
|
|
||||||
|
|
||||||
# ── web/api.py lifespan ──────────────────────────────────────────────────────
|
# ── web/api.py lifespan ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
class TestLifespan:
|
class TestLifespan:
|
||||||
|
|||||||
@@ -30,6 +30,23 @@ def _sub(url: str = "https://webhook.example/inbound", secret: str = "s" * 32) -
|
|||||||
return {"uuid": "w1", "url": url, "secret": secret}
|
return {"uuid": "w1", "url": url, "secret": secret}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _public_dns(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Resolve every hostname to a routable public IP so the SSRF guard
|
||||||
|
passes for the HMAC/retry behavioral tests without touching the network.
|
||||||
|
|
||||||
|
SSRF-specific tests below override this with their own resolution.
|
||||||
|
"""
|
||||||
|
import socket
|
||||||
|
|
||||||
|
from decnet.webhook import ssrf
|
||||||
|
|
||||||
|
def fake_getaddrinfo(host, port, *a, **k):
|
||||||
|
return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", port))]
|
||||||
|
|
||||||
|
monkeypatch.setattr(ssrf.socket, "getaddrinfo", fake_getaddrinfo)
|
||||||
|
|
||||||
|
|
||||||
def test_sign_matches_known_vector():
|
def test_sign_matches_known_vector():
|
||||||
body = b'{"hello":"world"}'
|
body = b'{"hello":"world"}'
|
||||||
secret = "0123456789abcdef"
|
secret = "0123456789abcdef"
|
||||||
@@ -144,3 +161,141 @@ async def test_deliver_receiver_can_verify_signature():
|
|||||||
).hexdigest()
|
).hexdigest()
|
||||||
)
|
)
|
||||||
assert captured["sig"] == expected
|
assert captured["sig"] == expected
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------- SSRF egress guard ----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_to(monkeypatch, ip: str) -> None:
|
||||||
|
import socket as _socket
|
||||||
|
|
||||||
|
from decnet.webhook import ssrf
|
||||||
|
|
||||||
|
def fake(host, port, *a, **k):
|
||||||
|
return [(_socket.AF_INET, _socket.SOCK_STREAM, 6, "", (ip, port))]
|
||||||
|
|
||||||
|
monkeypatch.setattr(ssrf.socket, "getaddrinfo", fake)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"url",
|
||||||
|
[
|
||||||
|
"https://127.0.0.1/inbound", # loopback literal
|
||||||
|
"https://169.254.169.254/latest/meta-data", # cloud metadata
|
||||||
|
"https://10.1.2.3/inbound", # RFC1918 literal
|
||||||
|
"https://192.168.1.5/x", # RFC1918 literal
|
||||||
|
"https://[::1]/x", # IPv6 loopback
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deliver_blocks_forbidden_ip_literal(url):
|
||||||
|
sent = {"n": 0}
|
||||||
|
|
||||||
|
async def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
sent["n"] += 1
|
||||||
|
return httpx.Response(200)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
async with httpx.AsyncClient(transport=transport) as client:
|
||||||
|
result = await deliver(_sub(url=url), _EVENT, retry_schedule=[], client=client)
|
||||||
|
assert result.ok is False
|
||||||
|
assert result.attempts == 0 # never left the guard
|
||||||
|
assert sent["n"] == 0 # transport never hit
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deliver_blocks_hostname_resolving_to_private(monkeypatch):
|
||||||
|
_resolve_to(monkeypatch, "10.0.0.7")
|
||||||
|
sent = {"n": 0}
|
||||||
|
|
||||||
|
async def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
sent["n"] += 1
|
||||||
|
return httpx.Response(200)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
async with httpx.AsyncClient(transport=transport) as client:
|
||||||
|
result = await deliver(
|
||||||
|
_sub(url="https://rebind.evil.example/x"), _EVENT,
|
||||||
|
retry_schedule=[], client=client,
|
||||||
|
)
|
||||||
|
assert result.ok is False
|
||||||
|
assert sent["n"] == 0
|
||||||
|
assert "forbidden" in (result.error or "").lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deliver_blocks_non_http_scheme():
|
||||||
|
result = await deliver(
|
||||||
|
_sub(url="file:///etc/passwd"), _EVENT, retry_schedule=[],
|
||||||
|
)
|
||||||
|
assert result.ok is False
|
||||||
|
assert "scheme" in (result.error or "").lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deliver_public_url_passes(monkeypatch):
|
||||||
|
_resolve_to(monkeypatch, "93.184.216.34")
|
||||||
|
|
||||||
|
async def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(200)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
async with httpx.AsyncClient(transport=transport) as client:
|
||||||
|
result = await deliver(
|
||||||
|
_sub(url="https://good.example/inbound"), _EVENT,
|
||||||
|
retry_schedule=[], client=client,
|
||||||
|
)
|
||||||
|
assert result.ok is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deliver_allow_private_escape_hatch(monkeypatch):
|
||||||
|
# Operator opt-in flips the guard off for internal targets.
|
||||||
|
import decnet.env as env
|
||||||
|
|
||||||
|
monkeypatch.setattr(env, "DECNET_WEBHOOK_ALLOW_PRIVATE", True)
|
||||||
|
|
||||||
|
async def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(200)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
async with httpx.AsyncClient(transport=transport) as client:
|
||||||
|
result = await deliver(
|
||||||
|
_sub(url="https://127.0.0.1/inbound"), _EVENT,
|
||||||
|
retry_schedule=[], client=client,
|
||||||
|
)
|
||||||
|
assert result.ok is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deliver_does_not_follow_redirect_to_internal(monkeypatch):
|
||||||
|
"""A 302 pointing at an IMDS address must never be followed.
|
||||||
|
|
||||||
|
deliver() sets follow_redirects=False on every send() call regardless of
|
||||||
|
the injected client's config, so the response is the raw 302 and the
|
||||||
|
internal IP is never contacted.
|
||||||
|
"""
|
||||||
|
requests_seen: list[str] = []
|
||||||
|
|
||||||
|
async def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
requests_seen.append(str(request.url))
|
||||||
|
# First request: public host returns a redirect to the cloud metadata IP.
|
||||||
|
return httpx.Response(
|
||||||
|
302,
|
||||||
|
headers={"Location": "http://169.254.169.254/latest/meta-data/"},
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
# Deliberately build the client with follow_redirects=True to prove that
|
||||||
|
# deliver() overrides it at the send() level.
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
transport=transport, follow_redirects=True
|
||||||
|
) as client:
|
||||||
|
result = await deliver(_sub(), _EVENT, retry_schedule=[], client=client)
|
||||||
|
|
||||||
|
# Only the initial request to the public host should have been made.
|
||||||
|
assert len(requests_seen) == 1
|
||||||
|
assert "169.254.169.254" not in requests_seen[0]
|
||||||
|
# deliver() treats the 302 as a non-retryable non-2xx.
|
||||||
|
assert result.ok is False
|
||||||
|
assert result.status_code == 302
|
||||||
|
|||||||
@@ -19,6 +19,20 @@ from decnet.webhook.worker import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _public_dns(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Resolve the test webhook host to a public IP so the egress SSRF guard
|
||||||
|
passes for these integration tests without touching the network."""
|
||||||
|
import socket
|
||||||
|
|
||||||
|
from decnet.webhook import ssrf
|
||||||
|
|
||||||
|
def fake_getaddrinfo(host, port, *a, **k):
|
||||||
|
return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", port))]
|
||||||
|
|
||||||
|
monkeypatch.setattr(ssrf.socket, "getaddrinfo", fake_getaddrinfo)
|
||||||
|
|
||||||
|
|
||||||
def _sub(
|
def _sub(
|
||||||
uuid: str,
|
uuid: str,
|
||||||
name: str,
|
name: str,
|
||||||
|
|||||||
Reference in New Issue
Block a user