chore(types): enable warn_return_any and cast all no-any-return sites

Turn on mypy warn_return_any (pyproject) and resolve the 84 resulting
[no-any-return] errors across 43 files with typing.cast() at the return
sites — runtime no-ops that make the declared return type explicit where a
dependency (SQLAlchemy scalar/first/one, httpx .json(), subprocess, docker
SDK) hands back Any. No behavior change: no DTO/table field types altered, no
validation/coercion calls added, every cast reflects the true runtime type.

Locks in return-type strictness so the class of bug where a function silently
widens to Any can't regress. mypy decnet/ clean; adversarially verified
behavior-preserving (84 casts 1:1 with prior returns).

Bump tornado 6.5.5 -> 6.5.7 (CVE-2026-49854, transitive via snakeviz).
This commit is contained in:
2026-06-12 18:21:22 -04:00
parent 337520c7ad
commit 721122a7ef
42 changed files with 128 additions and 124 deletions

View File

@@ -22,7 +22,7 @@ import asyncio
import time import time
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, AsyncIterator from typing import Any, AsyncIterator, cast
EVENT_SCHEMA_VERSION = 1 EVENT_SCHEMA_VERSION = 1
@@ -203,4 +203,4 @@ async def _next_or_stop(queue: "asyncio.Queue[Any]") -> Event:
item = await queue.get() item = await queue.get()
if item is _CLOSE_SENTINEL: if item is _CLOSE_SENTINEL:
raise StopAsyncIteration raise StopAsyncIteration
return item return cast(Event, item)

View File

@@ -17,7 +17,7 @@ env-driven dispatch, optional telemetry wrapping). Callers MUST use
from __future__ import annotations from __future__ import annotations
import os import os
from typing import Any from typing import Any, cast
from decnet.bus.base import BaseBus from decnet.bus.base import BaseBus
@@ -81,6 +81,6 @@ def _maybe_wrap_telemetry(bus: BaseBus) -> BaseBus:
except ImportError: except ImportError:
return bus return bus
try: try:
return wrap_repository(bus) return cast(BaseBus, wrap_repository(bus))
except Exception: # pragma: no cover - defensive except Exception: # pragma: no cover - defensive
return bus return bus

View File

@@ -14,7 +14,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from typing import Any from typing import Any, cast
from decnet.bus.base import ( from decnet.bus.base import (
BaseBus, BaseBus,
@@ -51,7 +51,7 @@ class _FakeSubscription(Subscription):
item = await self._queue.get() item = await self._queue.get()
if item is _CLOSE_SENTINEL: if item is _CLOSE_SENTINEL:
raise StopAsyncIteration raise StopAsyncIteration
return item return cast(Event, item)
async def _aclose(self) -> None: async def _aclose(self) -> None:
self._bus._unregister(self) self._bus._unregister(self)

View File

@@ -26,7 +26,7 @@ import asyncio
import contextlib import contextlib
import os import os
import pathlib import pathlib
from typing import Any from typing import Any, cast
from decnet.bus import protocol from decnet.bus import protocol
from decnet.bus.base import ( from decnet.bus.base import (
@@ -61,7 +61,7 @@ class _UnixSubscription(Subscription):
item = await self._queue.get() item = await self._queue.get()
if item is _CLOSE_SENTINEL: if item is _CLOSE_SENTINEL:
raise StopAsyncIteration raise StopAsyncIteration
return item return cast(Event, item)
async def _aclose(self) -> None: async def _aclose(self) -> None:
await self._bus._unregister(self) await self._bus._unregister(self)

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from typing import Optional from typing import Any, Optional, cast
import typer import typer
from rich.console import Console from rich.console import Console
@@ -96,9 +96,9 @@ def _list() -> None:
"""List all topologies.""" """List all topologies."""
_require_master_mode("topology list") _require_master_mode("topology list")
async def _go() -> list[dict]: async def _go() -> list[dict[Any, Any]]:
repo = await _repo() repo = await _repo()
return await repo.list_topologies() return cast(list[dict[Any, Any]], await repo.list_topologies())
rows = asyncio.run(_go()) rows = asyncio.run(_go())
if not rows: if not rows:
@@ -140,7 +140,7 @@ def _show(topology_id: str = typer.Argument(..., help="Topology id")) -> None:
def _decky_name(d: dict) -> str: def _decky_name(d: dict) -> str:
cfg = d.get("decky_config") or {} cfg = d.get("decky_config") or {}
return cfg.get("name") or d.get("name") or d["uuid"] return cast(str, cfg.get("name") or d.get("name") or d["uuid"])
deckies_by_name = {_decky_name(d): d for d in hydrated["deckies"]} deckies_by_name = {_decky_name(d): d for d in hydrated["deckies"]}
edges_by_lan: dict[str, list[dict]] = {} edges_by_lan: dict[str, list[dict]] = {}
@@ -296,9 +296,9 @@ def _mutate(
async def _go() -> str: async def _go() -> str:
repo = await _repo() repo = await _repo()
return await repo.enqueue_topology_mutation( return cast(str, await repo.enqueue_topology_mutation(
topology_id, op, payload, expected_version=expected_version, topology_id, op, payload, expected_version=expected_version,
) ))
mid = asyncio.run(_go()) mid = asyncio.run(_go())
_console.print( _console.print(
@@ -319,9 +319,9 @@ def _mutations(
"""List queued/applied mutations for a topology.""" """List queued/applied mutations for a topology."""
_require_master_mode("topology mutations") _require_master_mode("topology mutations")
async def _go() -> list[dict]: async def _go() -> list[dict[Any, Any]]:
repo = await _repo() repo = await _repo()
return await repo.list_topology_mutations(topology_id, state=state) return cast(list[dict[Any, Any]], await repo.list_topology_mutations(topology_id, state=state))
rows = asyncio.run(_go()) rows = asyncio.run(_go())
if not rows: if not rows:

View File

@@ -12,7 +12,7 @@ import signal
import subprocess # nosec B404 import subprocess # nosec B404
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Optional from typing import Any, Callable, Optional, cast
import typer import typer
from rich.console import Console from rich.console import Console
@@ -91,7 +91,7 @@ def _is_running(match_fn) -> int | None:
try: try:
cmd = proc.info["cmdline"] cmd = proc.info["cmdline"]
if cmd and match_fn(cmd): if cmd and match_fn(cmd):
return proc.info["pid"] return cast(int, proc.info["pid"])
except (psutil.NoSuchProcess, psutil.AccessDenied): except (psutil.NoSuchProcess, psutil.AccessDenied):
continue continue
return None return None

View File

@@ -25,7 +25,7 @@ from __future__ import annotations
from collections import Counter from collections import Counter
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Sequence from typing import Any, Sequence, cast
from decnet.correlation.attribution import _thresholds as _T from decnet.correlation.attribution import _thresholds as _T
@@ -250,7 +250,7 @@ def _coef_of_variation(values: Sequence[float], mean: float) -> float:
stdev = variance ** 0.5 stdev = variance ** 0.5
if mean == 0: if mean == 0:
return 0.0 if stdev == 0 else 1e9 return 0.0 if stdev == 0 else 1e9
return stdev / abs(mean) return cast(float, stdev / abs(mean))
def _safe_float(value: Any) -> float: def _safe_float(value: Any) -> float:

View File

@@ -9,6 +9,7 @@ import shutil
import subprocess # nosec B404 import subprocess # nosec B404
import time import time
from pathlib import Path from pathlib import Path
from typing import cast
import anyio import anyio
import docker import docker
@@ -853,7 +854,7 @@ async def _resolve_swarm_host(repo, host_uuid: str) -> dict:
raise ValueError( raise ValueError(
f"topology pinned to unknown swarm host {host_uuid!r}" f"topology pinned to unknown swarm host {host_uuid!r}"
) )
return host return cast(dict, host)
async def _deploy_on_agent(repo, topology_id: str, hydrated: dict) -> None: async def _deploy_on_agent(repo, topology_id: str, hydrated: dict) -> None:

View File

@@ -24,7 +24,7 @@ from __future__ import annotations
import subprocess # nosec B404 import subprocess # nosec B404
from pathlib import Path from pathlib import Path
from typing import Any, Literal, Optional from typing import Any, Literal, Optional, cast
import anyio import anyio
@@ -217,7 +217,7 @@ async def _topology_decky(
cfg = d.get("decky_config") or {} cfg = d.get("decky_config") or {}
name = cfg.get("name") or d.get("name") name = cfg.get("name") or d.get("name")
if name == decky_name: if name == decky_name:
return d return cast(dict[str, Any], d)
raise ServiceNotFoundError( raise ServiceNotFoundError(
f"decky {decky_name!r} is not in topology {topology_id!r}" f"decky {decky_name!r} is not in topology {topology_id!r}"
) )
@@ -343,7 +343,7 @@ def _fleet_state_or_raise() -> tuple[Any, Path]:
raise ServiceMutationError( raise ServiceMutationError(
"no fleet state on disk — run `decnet up` first" "no fleet state on disk — run `decnet up` first"
) )
return state return cast(tuple[Any, Path], state)
def _fleet_find_decky(config: Any, decky_name: str) -> Any: def _fleet_find_decky(config: Any, decky_name: str) -> Any:

View File

@@ -25,7 +25,7 @@ Design notes
from __future__ import annotations from __future__ import annotations
import json import json
from typing import Any, Awaitable, Callable, Optional from typing import Any, Awaitable, Callable, Optional, cast
from decnet.logging import get_logger from decnet.logging import get_logger
from decnet.topology.allocator import IPAllocator, reserved_subnets, SubnetAllocator from decnet.topology.allocator import IPAllocator, reserved_subnets, SubnetAllocator
@@ -301,7 +301,7 @@ async def _live_topology_or_none(
topology_id, topology_id,
) )
return None return None
return topology return cast(dict[str, Any], topology)
async def _rerender_compose(repo: Any, topology_id: str) -> None: async def _rerender_compose(repo: Any, topology_id: str) -> None:

View File

@@ -12,6 +12,7 @@ Handles:
import os import os
import subprocess # nosec B404 import subprocess # nosec B404
from ipaddress import IPv4Address, IPv4Interface, IPv4Network from ipaddress import IPv4Address, IPv4Interface, IPv4Network
from typing import cast
import docker import docker
@@ -38,7 +39,7 @@ def detect_interface() -> str:
for line in result.stdout.splitlines(): for line in result.stdout.splitlines():
parts = line.split() parts = line.split()
if "dev" in parts: if "dev" in parts:
return parts[parts.index("dev") + 1] return cast(str, parts[parts.index("dev") + 1])
raise RuntimeError("Could not auto-detect network interface. Use --interface.") raise RuntimeError("Could not auto-detect network interface. Use --interface.")
@@ -79,7 +80,7 @@ def get_host_ip(interface: str) -> str:
for line in result.stdout.splitlines(): for line in result.stdout.splitlines():
line = line.strip() line = line.strip()
if line.startswith("inet ") and not line.startswith("inet6"): if line.startswith("inet ") and not line.startswith("inet6"):
return line.split()[1].split("/")[0] return cast(str, line.split()[1].split("/")[0])
raise RuntimeError(f"Could not determine host IP for interface {interface}.") raise RuntimeError(f"Could not determine host IP for interface {interface}.")
@@ -297,7 +298,7 @@ def create_bridge_network(
pools = (net.attrs.get("IPAM") or {}).get("Config") or [] pools = (net.attrs.get("IPAM") or {}).get("Config") or []
cur = pools[0] if pools else {} cur = pools[0] if pools else {}
if net.attrs.get("Driver") == "bridge" and cur.get("Subnet") == subnet: if net.attrs.get("Driver") == "bridge" and cur.get("Subnet") == subnet:
return net.id return cast(str, net.id)
for cid in (net.attrs.get("Containers") or {}): for cid in (net.attrs.get("Containers") or {}):
try: try:
net.disconnect(cid, force=True) net.disconnect(cid, force=True)
@@ -332,7 +333,7 @@ def create_bridge_network(
pool_configs=[docker.types.IPAMPool(subnet=subnet)], pool_configs=[docker.types.IPAMPool(subnet=subnet)],
), ),
) )
return net.id return cast(str, net.id)
def remove_bridge_network(client: docker.DockerClient, name: str) -> None: def remove_bridge_network(client: docker.DockerClient, name: str) -> None:
@@ -480,7 +481,7 @@ def get_container_pid(container_name: str) -> int:
pid = container.attrs["State"]["Pid"] pid = container.attrs["State"]["Pid"]
if not pid: if not pid:
raise LookupError(f"container {container_name!r} is not running (PID=0)") raise LookupError(f"container {container_name!r} is not running (PID=0)")
return pid return cast(int, pid)
def get_container_veth(container_name: str) -> str: def get_container_veth(container_name: str) -> str:
@@ -507,7 +508,7 @@ def get_container_veth(container_name: str) -> str:
if line.startswith(f"{peer_index}:"): if line.startswith(f"{peer_index}:"):
# Format: "42: veth3a4b5c@if41: <BROADCAST,...>" # Format: "42: veth3a4b5c@if41: <BROADCAST,...>"
iface = line.split(":")[1].strip().split("@")[0] iface = line.split(":")[1].strip().split("@")[0]
return iface return cast(str, iface)
raise LookupError( raise LookupError(
f"no host veth found for container {container_name!r} (peer ifindex {peer_index})" f"no host veth found for container {container_name!r} (peer ifindex {peer_index})"
) )

View File

@@ -20,7 +20,7 @@ import json
import secrets import secrets
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Optional, Sequence from typing import Any, Optional, Sequence, cast
from decnet.realism import personas_pool from decnet.realism import personas_pool
from decnet.realism.personas import EmailPersona, parse_personas from decnet.realism.personas import EmailPersona, parse_personas
@@ -256,7 +256,7 @@ def _persona_by_name(
for decky in enriched: for decky in enriched:
for persona in decky.get("_realism_personas") or []: for persona in decky.get("_realism_personas") or []:
if persona.name == name: if persona.name == name:
return persona return cast(EmailPersona, persona)
return None return None

View File

@@ -18,7 +18,7 @@ from __future__ import annotations
import hashlib import hashlib
import socket import socket
import ssl import ssl
from typing import Any from typing import Any, cast
from cryptography import x509 from cryptography import x509
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
@@ -42,7 +42,7 @@ def _iso_utc(dt: Any) -> str:
and ``not_valid_before_utc`` (timezone-aware) — prefer the latter and ``not_valid_before_utc`` (timezone-aware) — prefer the latter
when available so we always emit explicit-Z ISO strings. when available so we always emit explicit-Z ISO strings.
""" """
return dt.strftime("%Y-%m-%dT%H:%M:%SZ") return cast(str, dt.strftime("%Y-%m-%dT%H:%M:%SZ"))
def _extract_sans(cert: x509.Certificate) -> list[str]: def _extract_sans(cert: x509.Certificate) -> list[str]:

View File

@@ -19,7 +19,7 @@ graph cycle-free and the env contract auditable in one place.
from __future__ import annotations from __future__ import annotations
import os import os
from typing import Any from typing import Any, cast
from decnet.realism.llm.base import LLMBackend from decnet.realism.llm.base import LLMBackend
@@ -45,7 +45,7 @@ def get_llm(*, model: str | None = None, **kwargs: Any) -> LLMBackend:
from decnet.realism.llm.config import get_cached_backend from decnet.realism.llm.config import get_cached_backend
cached = get_cached_backend() cached = get_cached_backend()
if cached is not None: if cached is not None:
return cached return cast(LLMBackend, cached)
backend_key = os.environ.get("DECNET_REALISM_LLM", "ollama").lower() backend_key = os.environ.get("DECNET_REALISM_LLM", "ollama").lower()

View File

@@ -17,7 +17,7 @@ import logging
import sqlite3 import sqlite3
import urllib.request import urllib.request
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional from typing import Any, Optional, cast
from decnet.rpki import cache as _cache from decnet.rpki import cache as _cache
from decnet.rpki.base import RpkiResult, RpkiStatus, Validator from decnet.rpki.base import RpkiResult, RpkiStatus, Validator
@@ -68,13 +68,13 @@ class RipeStatValidator(Validator):
) )
raw = data.get("data", {}).get("status", "unknown") raw = data.get("data", {}).get("status", "unknown")
if raw in ("valid", "invalid", "not-found"): if raw in ("valid", "invalid", "not-found"):
return raw return cast(RpkiStatus, raw)
return "unknown" return "unknown"
def _fetch(self, url: str) -> dict: def _fetch(self, url: str) -> dict[Any, Any]:
req = urllib.request.Request(url, headers={"User-Agent": _UA}) req = urllib.request.Request(url, headers={"User-Agent": _UA})
with urllib.request.urlopen(req, timeout=_TIMEOUT_S) as resp: # nosec B310 — HTTPS RIPE STAT base URL only; IP/ASN components are validated upstream with urllib.request.urlopen(req, timeout=_TIMEOUT_S) as resp: # nosec B310 — HTTPS RIPE STAT base URL only; IP/ASN components are validated upstream
return json.loads(resp.read()) return cast(dict[Any, Any], json.loads(resp.read()))
def _store( def _store(
self, ip: str, asn: int, status: str, prefix: Optional[str] self, ip: str, asn: int, status: str, prefix: Optional[str]

View File

@@ -14,7 +14,7 @@ import hashlib
import struct import struct
import time import time
from collections import deque from collections import deque
from typing import Any, Callable from typing import Any, Callable, cast
from decnet.logging import get_logger from decnet.logging import get_logger
from decnet.prober.tcpfp import _extract_options_order from decnet.prober.tcpfp import _extract_options_order
@@ -1058,18 +1058,18 @@ class SnifferEngine:
def _dedup_key_for(self, event_type: str, fields: dict[str, Any]) -> str: def _dedup_key_for(self, event_type: str, fields: dict[str, Any]) -> str:
if event_type == "tls_client_hello": if event_type == "tls_client_hello":
return fields.get("ja3", "") + "|" + fields.get("ja4", "") return cast(str, fields.get("ja3", "") + "|" + fields.get("ja4", ""))
if event_type == "tls_session": if event_type == "tls_session":
return (fields.get("ja3", "") + "|" + fields.get("ja3s", "") + return cast(str, fields.get("ja3", "") + "|" + fields.get("ja3s", "") +
"|" + fields.get("ja4", "") + "|" + fields.get("ja4s", "")) "|" + fields.get("ja4", "") + "|" + fields.get("ja4s", ""))
if event_type == "tls_certificate": if event_type == "tls_certificate":
return fields.get("subject_cn", "") + "|" + fields.get("issuer", "") return cast(str, fields.get("subject_cn", "") + "|" + fields.get("issuer", ""))
if event_type == "tcp_syn_fingerprint": if event_type == "tcp_syn_fingerprint":
# Dedupe per (OS signature, options layout, sequence-pattern # Dedupe per (OS signature, options layout, sequence-pattern
# classification). Including ipid_class/isn_class lets each # classification). Including ipid_class/isn_class lets each
# transition (unknown → random/incremental/zero/constant) emit # transition (unknown → random/incremental/zero/constant) emit
# exactly one fresh event as samples accumulate. # exactly one fresh event as samples accumulate.
return ( return cast(str,
fields.get("os_guess", "") fields.get("os_guess", "")
+ "|" + fields.get("options_sig", "") + "|" + fields.get("options_sig", "")
+ "|" + fields.get("ipid_class", "") + "|" + fields.get("ipid_class", "")
@@ -1080,14 +1080,14 @@ class SnifferEngine:
# excluded so a port scanner rotating source ports only produces # excluded so a port scanner rotating source ports only produces
# one timing event per dedup window. Behavior cadence doesn't # one timing event per dedup window. Behavior cadence doesn't
# need per-ephemeral-port fidelity. # need per-ephemeral-port fidelity.
return fields.get("dst_ip", "") + "|" + fields.get("dst_port", "") return cast(str, fields.get("dst_ip", "") + "|" + fields.get("dst_port", ""))
if event_type == "quic_client_hello": if event_type == "quic_client_hello":
return fields.get("src_ip", "") + "|" + fields.get("ja4_quic", "") return cast(str, fields.get("src_ip", "") + "|" + fields.get("ja4_quic", ""))
if event_type == "http_request_fingerprint": if event_type == "http_request_fingerprint":
return fields.get("src_ip", "") + "|" + fields.get("ja4h", "") return cast(str, fields.get("src_ip", "") + "|" + fields.get("ja4h", ""))
if event_type in ("http2_settings", "http3_settings"): if event_type in ("http2_settings", "http3_settings"):
return fields.get("src_ip", "") + "|" + str(fields.get("settings_hash", "")) return cast(str, fields.get("src_ip", "") + "|" + str(fields.get("settings_hash", "")))
return fields.get("mechanisms", fields.get("resumption", "")) return cast(str, fields.get("mechanisms", fields.get("resumption", "")))
def _is_duplicate(self, event_type: str, fields: dict[str, Any]) -> bool: def _is_duplicate(self, event_type: str, fields: dict[str, Any]) -> bool:
if self._dedup_ttl <= 0: if self._dedup_ttl <= 0:

View File

@@ -23,7 +23,7 @@ import pathlib
import socket import socket
import ssl import ssl
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional from typing import Any, Optional, cast
import httpx import httpx
@@ -229,12 +229,12 @@ class AgentClient:
async def health(self) -> dict[str, Any]: async def health(self) -> dict[str, Any]:
resp = await self._require_client().get("/health") resp = await self._require_client().get("/health")
resp.raise_for_status() resp.raise_for_status()
return resp.json() return cast(dict[str, Any], resp.json())
async def status(self) -> dict[str, Any]: async def status(self) -> dict[str, Any]:
resp = await self._require_client().get("/status") resp = await self._require_client().get("/status")
resp.raise_for_status() resp.raise_for_status()
return resp.json() return cast(dict[str, Any], resp.json())
async def deploy( async def deploy(
self, self,
@@ -254,7 +254,7 @@ class AgentClient:
# need for the long deploy timeout here. # need for the long deploy timeout here.
resp = await self._require_client().post("/deploy", json=body) resp = await self._require_client().post("/deploy", json=body)
resp.raise_for_status() resp.raise_for_status()
return resp.json() return cast(dict[str, Any], resp.json())
async def mutate( async def mutate(
self, self,
@@ -271,20 +271,20 @@ class AgentClient:
# Worker /mutate is async (202): control-timeout is right. # Worker /mutate is async (202): control-timeout is right.
resp = await self._require_client().post("/mutate", json=body) resp = await self._require_client().post("/mutate", json=body)
resp.raise_for_status() resp.raise_for_status()
return resp.json() return cast(dict[str, Any], resp.json())
async def teardown(self, decky_id: Optional[str] = None) -> dict[str, Any]: async def teardown(self, decky_id: Optional[str] = None) -> dict[str, Any]:
resp = await self._require_client().post( resp = await self._require_client().post(
"/teardown", json={"decky_id": decky_id} "/teardown", json={"decky_id": decky_id}
) )
resp.raise_for_status() resp.raise_for_status()
return resp.json() return cast(dict[str, Any], resp.json())
async def self_destruct(self) -> dict[str, Any]: async def self_destruct(self) -> dict[str, Any]:
"""Trigger the worker to stop services and wipe its install.""" """Trigger the worker to stop services and wipe its install."""
resp = await self._require_client().post("/self-destruct") resp = await self._require_client().post("/self-destruct")
resp.raise_for_status() resp.raise_for_status()
return resp.json() return cast(dict[str, Any], resp.json())
# ------------------------------------------------------------ topology # ------------------------------------------------------------ topology
@@ -309,7 +309,7 @@ class AgentClient:
finally: finally:
self._require_client().timeout = old self._require_client().timeout = old
resp.raise_for_status() resp.raise_for_status()
return resp.json() return cast(dict[str, Any], resp.json())
async def teardown_topology(self, topology_id: str) -> dict[str, Any]: async def teardown_topology(self, topology_id: str) -> dict[str, Any]:
"""Ask the agent to dismantle the named topology.""" """Ask the agent to dismantle the named topology."""
@@ -323,13 +323,13 @@ class AgentClient:
finally: finally:
self._require_client().timeout = old self._require_client().timeout = old
resp.raise_for_status() resp.raise_for_status()
return resp.json() return cast(dict[str, Any], resp.json())
async def get_topology_state(self) -> dict[str, Any]: async def get_topology_state(self) -> dict[str, Any]:
"""Snapshot of the agent's applied topology + live docker state.""" """Snapshot of the agent's applied topology + live docker state."""
resp = await self._require_client().get("/topology/state") resp = await self._require_client().get("/topology/state")
resp.raise_for_status() resp.raise_for_status()
return resp.json() return cast(dict[str, Any], resp.json())
# -------------------------------------------------------------- diagnostics # -------------------------------------------------------------- diagnostics

View File

@@ -17,7 +17,7 @@ import asyncio
import hashlib import hashlib
import socket import socket
import ssl import ssl
from typing import Any, Optional from typing import Any, Optional, cast
import httpx import httpx
@@ -159,12 +159,12 @@ class UpdaterClient:
async def health(self) -> dict[str, Any]: async def health(self) -> dict[str, Any]:
r = await self._require().get("/health") r = await self._require().get("/health")
r.raise_for_status() r.raise_for_status()
return r.json() return cast(dict[str, Any], r.json())
async def releases(self) -> dict[str, Any]: async def releases(self) -> dict[str, Any]:
r = await self._require().get("/releases") r = await self._require().get("/releases")
r.raise_for_status() r.raise_for_status()
return r.json() return cast(dict[str, Any], r.json())
async def update(self, tarball: bytes, sha: str = "") -> httpx.Response: async def update(self, tarball: bytes, sha: str = "") -> httpx.Response:
"""POST /update. Returns the Response so the caller can distinguish """POST /update. Returns the Response so the caller can distinguish

View File

@@ -147,7 +147,7 @@ _MONGO_SET_NAME = os.environ.get("MONGO_REPL_SET", "") # empty = standalone
def _new_objectid() -> bytes: def _new_objectid() -> bytes:
"""12-byte BSON ObjectId — fresh per call.""" """12-byte BSON ObjectId — fresh per call."""
return _seed.fresh_bytes(12) return cast(bytes, _seed.fresh_bytes(12))
# Minimal BSON helpers # Minimal BSON helpers
def _bson_str(key: str, val: str) -> bytes: def _bson_str(key: str, val: str) -> bytes:

View File

@@ -205,7 +205,7 @@ def _seed_dict_to_rfc822(entry: dict) -> str | None:
date = str(entry.get("date") or "") date = str(entry.get("date") or "")
body = entry["body"] body = entry["body"]
if "\r\n\r\n" in body or "\n\n" in body: if "\r\n\r\n" in body or "\n\n" in body:
return body # already a full RFC 822 message return cast(str, body) # already a full RFC 822 message
return ( return (
f"Date: {date}\r\n" f"Date: {date}\r\n"
f"From: {from_name} <{from_addr}>\r\n" f"From: {from_name} <{from_addr}>\r\n"

View File

@@ -22,6 +22,7 @@ from __future__ import annotations
import asyncio import asyncio
import os import os
import struct import struct
from typing import cast
import instance_seed import instance_seed
from ntlmssp import find_ntlmssp, parse_type3 from ntlmssp import find_ntlmssp, parse_type3
@@ -184,7 +185,7 @@ def _negotiate_response(message_id: int) -> bytes:
+ struct.pack("<H", 0) # SecurityBufferLength + struct.pack("<H", 0) # SecurityBufferLength
+ struct.pack("<I", 0) # Reserved2 + struct.pack("<I", 0) # Reserved2
) )
return _smb2_header(SMB2_NEGOTIATE, STATUS_SUCCESS, message_id) + body return cast(bytes, _smb2_header(SMB2_NEGOTIATE, STATUS_SUCCESS, message_id) + body)
def _session_setup_response(message_id: int, session_id: int, sec_blob: bytes, status: int) -> bytes: def _session_setup_response(message_id: int, session_id: int, sec_blob: bytes, status: int) -> bytes:

View File

@@ -28,7 +28,7 @@ import hashlib
import os import os
import struct import struct
import time import time
from typing import Any from typing import Any, cast
from scapy.layers.inet import IP, TCP from scapy.layers.inet import IP, TCP
from scapy.sendrecv import sniff from scapy.sendrecv import sniff
@@ -841,14 +841,14 @@ _dedup_last_cleanup: float = 0.0
def _dedup_key_for(event_type: str, fields: dict[str, Any]) -> str: def _dedup_key_for(event_type: str, fields: dict[str, Any]) -> str:
"""Build a dedup fingerprint from the most significant fields.""" """Build a dedup fingerprint from the most significant fields."""
if event_type == "tls_client_hello": if event_type == "tls_client_hello":
return fields.get("ja3", "") + "|" + fields.get("ja4", "") return cast(str, fields.get("ja3", "") + "|" + fields.get("ja4", ""))
if event_type == "tls_session": if event_type == "tls_session":
return (fields.get("ja3", "") + "|" + fields.get("ja3s", "") + return cast(str, fields.get("ja3", "") + "|" + fields.get("ja3s", "") +
"|" + fields.get("ja4", "") + "|" + fields.get("ja4s", "")) "|" + fields.get("ja4", "") + "|" + fields.get("ja4s", ""))
if event_type == "tls_certificate": if event_type == "tls_certificate":
return fields.get("subject_cn", "") + "|" + fields.get("issuer", "") return cast(str, fields.get("subject_cn", "") + "|" + fields.get("issuer", ""))
# tls_resumption or unknown — dedup on mechanisms # tls_resumption or unknown — dedup on mechanisms
return fields.get("mechanisms", fields.get("resumption", "")) return cast(str, fields.get("mechanisms", fields.get("resumption", "")))
def _is_duplicate(event_type: str, fields: dict[str, Any]) -> bool: def _is_duplicate(event_type: str, fields: dict[str, Any]) -> bool:

View File

@@ -39,7 +39,7 @@ from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Final from typing import TYPE_CHECKING, Final, cast
if TYPE_CHECKING: if TYPE_CHECKING:
from mitreattack.stix20 import MitreAttackData from mitreattack.stix20 import MitreAttackData
@@ -310,7 +310,7 @@ def subtechnique_parent_name(technique_id: str) -> str | None:
) )
if not parents: if not parents:
return None return None
return parents[0]["object"].name return cast(str, parents[0]["object"].name)
def is_subtechnique(technique_id: str) -> bool: def is_subtechnique(technique_id: str) -> bool:
@@ -335,7 +335,7 @@ def tactic_id_for_short_name(short_name: str) -> str | None:
return None return None
for ref in obj.get("external_references", []): for ref in obj.get("external_references", []):
if ref.get("source_name") == "mitre-attack": if ref.get("source_name") == "mitre-attack":
return ref.get("external_id") return cast(str | None, ref.get("external_id"))
return None return None

View File

@@ -16,7 +16,7 @@ build_fleet_misp_collection → dict ({"response": [event, ...]})
from __future__ import annotations from __future__ import annotations
import json import json
from typing import Any from typing import Any, cast
from misp_stix_converter import ExternalSTIX2toMISPParser from misp_stix_converter import ExternalSTIX2toMISPParser
@@ -35,7 +35,7 @@ def _parse_bundle(bundle: Any) -> dict[str, Any]:
event = parser.misp_events event = parser.misp_events
if event is None: if event is None:
return {} return {}
return json.loads(event.to_json()) return cast(dict[str, Any], json.loads(event.to_json()))
def build_attacker_misp_event( def build_attacker_misp_event(

View File

@@ -8,7 +8,7 @@ import os
import traceback import traceback
import uuid import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Optional from typing import Any, AsyncGenerator, Optional, cast
from fastapi import FastAPI, Request, status from fastapi import FastAPI, Request, status
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
@@ -306,7 +306,7 @@ class _ContentTypeMiddleware(BaseHTTPMiddleware):
status_code=415, status_code=415,
media_type="text/plain", media_type="text/plain",
) )
return await call_next(request) return cast(StarletteResponse, await call_next(request))
app.add_middleware(_ContentTypeMiddleware) app.add_middleware(_ContentTypeMiddleware)

View File

@@ -6,7 +6,7 @@ Repository factory — selects a :class:`BaseRepository` implementation based on
from __future__ import annotations from __future__ import annotations
import os import os
from typing import Any from typing import Any, cast
from decnet.web.db.repository import BaseRepository from decnet.web.db.repository import BaseRepository
@@ -32,4 +32,4 @@ def get_repository(**kwargs: Any) -> BaseRepository:
raise ValueError(f"Unsupported database type: {db_type}") raise ValueError(f"Unsupported database type: {db_type}")
from decnet.telemetry import wrap_repository from decnet.telemetry import wrap_repository
return wrap_repository(repo) return cast(BaseRepository, wrap_repository(repo))

View File

@@ -18,7 +18,7 @@ import os
import orjson import orjson
import uuid import uuid
from typing import Any, Optional, List from typing import Any, Optional, List, cast
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
@@ -183,7 +183,7 @@ class SQLModelRepository(
result = await session.execute(statement) result = await session.execute(statement)
state = result.scalar_one_or_none() state = result.scalar_one_or_none()
if state: if state:
return json.loads(state.value) return cast(dict[str, Any], json.loads(state.value))
return None return None
async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401 async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401

View File

@@ -10,7 +10,7 @@ from __future__ import annotations
import uuid as _uuid import uuid as _uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional, cast
from sqlalchemy import desc, or_, select from sqlalchemy import desc, or_, select
from sqlmodel import col from sqlmodel import col
@@ -44,7 +44,7 @@ class AttackerIntelMixin(_MixinBase):
row_uuid = _uuid.uuid4().hex row_uuid = _uuid.uuid4().hex
session.add(AttackerIntel(uuid=row_uuid, **data)) session.add(AttackerIntel(uuid=row_uuid, **data))
await session.commit() await session.commit()
return row_uuid return cast(str, row_uuid)
async def get_attacker_intel_row_by_uuid( async def get_attacker_intel_row_by_uuid(
self, self,
@@ -54,7 +54,7 @@ class AttackerIntelMixin(_MixinBase):
result = await session.execute( result = await session.execute(
select(AttackerIntel).where(AttackerIntel.attacker_uuid == uuid) select(AttackerIntel).where(AttackerIntel.attacker_uuid == uuid)
) )
return result.scalar_one_or_none() return cast(Optional[AttackerIntel], result.scalar_one_or_none())
async def get_attacker_intel_by_uuid( async def get_attacker_intel_by_uuid(
self, self,
@@ -67,7 +67,7 @@ class AttackerIntelMixin(_MixinBase):
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if not row: if not row:
return None return None
return row.model_dump(mode="json") return cast(dict[str, Any], row.model_dump(mode="json"))
async def get_unenriched_attackers( async def get_unenriched_attackers(
self, limit: int = 100, self, limit: int = 100,

View File

@@ -10,7 +10,7 @@ from __future__ import annotations
import json import json
import uuid as _uuid import uuid as _uuid
from typing import Any, List, Optional from typing import Any, List, Optional, cast
from sqlalchemy import desc, func, outerjoin, select from sqlalchemy import desc, func, outerjoin, select
from sqlmodel import col from sqlmodel import col
@@ -47,7 +47,7 @@ class AttackersCoreMixin(_MixinBase):
data = {**data, "uuid": row_uuid} data = {**data, "uuid": row_uuid}
session.add(Attacker(**data)) session.add(Attacker(**data))
await session.commit() await session.commit()
return row_uuid return cast(str, row_uuid)
async def get_attacker_uuid_by_ip(self, ip: str) -> Optional[str]: async def get_attacker_uuid_by_ip(self, ip: str) -> Optional[str]:
"""Return the :class:`Attacker` UUID for *ip*, or ``None``. """Return the :class:`Attacker` UUID for *ip*, or ``None``.
@@ -61,7 +61,7 @@ class AttackersCoreMixin(_MixinBase):
result = await session.execute( result = await session.execute(
select(col(Attacker.uuid)).where(Attacker.ip == ip) select(col(Attacker.uuid)).where(Attacker.ip == ip)
) )
return result.scalar_one_or_none() return cast(Optional[str], result.scalar_one_or_none())
async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]: async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
async with self._session() as session: async with self._session() as session:

View File

@@ -21,7 +21,7 @@ from __future__ import annotations
import uuid as _uuid import uuid as _uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional, cast
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlmodel import col from sqlmodel import col
@@ -66,7 +66,7 @@ class AttributionMixin(_MixinBase):
if attacker_row is None: if attacker_row is None:
return None return None
if attacker_row.identity_id: if attacker_row.identity_id:
return attacker_row.identity_id return cast(str, attacker_row.identity_id)
new_uuid = _uuid.uuid4().hex new_uuid = _uuid.uuid4().hex
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
session.add( session.add(

View File

@@ -9,7 +9,7 @@ the reads.
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional, cast
from sqlalchemy import desc, func, select, update from sqlalchemy import desc, func, select, update
from sqlmodel import col from sqlmodel import col
@@ -36,9 +36,9 @@ class CampaignsMixin(_MixinBase):
if campaign is None: if campaign is None:
return None return None
if campaign.merged_into_uuid is None: if campaign.merged_into_uuid is None:
return campaign.model_dump(mode="json") return cast(dict[str, Any], campaign.model_dump(mode="json"))
current_uuid = campaign.merged_into_uuid current_uuid = campaign.merged_into_uuid
return campaign.model_dump(mode="json") return cast(dict[str, Any], campaign.model_dump(mode="json"))
async def list_campaigns( async def list_campaigns(
self, limit: int = 50, offset: int = 0, self, limit: int = 50, offset: int = 0,

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import json import json
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional, cast
from sqlalchemy import desc, func, select, update from sqlalchemy import desc, func, select, update
@@ -26,7 +26,7 @@ class CanaryMixin(_MixinBase):
) )
row = existing.scalar_one_or_none() row = existing.scalar_one_or_none()
if row: if row:
return row.model_dump(mode="json") return cast(dict[str, Any], row.model_dump(mode="json"))
row = CanaryBlob(**data) row = CanaryBlob(**data)
session.add(row) session.add(row)
await session.commit() await session.commit()
@@ -155,7 +155,7 @@ class CanaryMixin(_MixinBase):
.values(state=state, last_error=last_error) .values(state=state, last_error=last_error)
) )
await session.commit() await session.commit()
return result.rowcount > 0 return cast(bool, result.rowcount > 0)
async def record_canary_trigger(self, data: dict[str, Any]) -> str: async def record_canary_trigger(self, data: dict[str, Any]) -> str:
# Persist the trigger row + bump the token's counters in the same # Persist the trigger row + bump the token's counters in the same
@@ -204,4 +204,4 @@ class CanaryMixin(_MixinBase):
.values(attacker_id=attacker_id) .values(attacker_id=attacker_id)
) )
await session.commit() await session.commit()
return result.rowcount > 0 return cast(bool, result.rowcount > 0)

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import json import json
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, List, Optional from typing import Any, List, Optional, cast
from sqlalchemy import desc, func, or_, select, update from sqlalchemy import desc, func, or_, select, update
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@@ -67,7 +67,7 @@ class CredentialsCoreMixin(_MixinBase):
existing.outcome = payload["outcome"] existing.outcome = payload["outcome"]
session.add(existing) session.add(existing)
await session.commit() await session.commit()
return existing.id return cast(int, existing.id)
row = Credential( row = Credential(
attacker_ip=payload["attacker_ip"], attacker_ip=payload["attacker_ip"],
decky_name=payload["decky_name"], decky_name=payload["decky_name"],
@@ -103,7 +103,7 @@ class CredentialsCoreMixin(_MixinBase):
existing2.outcome = payload["outcome"] existing2.outcome = payload["outcome"]
session2.add(existing2) session2.add(existing2)
await session2.commit() await session2.commit()
return existing2.id return cast(int, existing2.id)
await session.refresh(row) await session.refresh(row)
return row.id # type: ignore[return-value] return row.id # type: ignore[return-value]

View File

@@ -7,7 +7,7 @@ from __future__ import annotations
import json import json
import uuid as _uuid import uuid as _uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, List, Optional from typing import Any, List, Optional, cast
from sqlalchemy import desc, func, select from sqlalchemy import desc, func, select
from sqlmodel import col from sqlmodel import col
@@ -137,7 +137,7 @@ class CredentialReuseMixin(_MixinBase):
d = existing.model_dump(mode="json") d = existing.model_dump(mode="json")
d["inserted"] = False d["inserted"] = False
d["changed"] = changed d["changed"] = changed
return d return cast(dict[str, Any], d)
async def find_credential_reuse_candidates( async def find_credential_reuse_candidates(
self, min_targets: int = 2 self, min_targets: int = 2
@@ -236,7 +236,7 @@ class CredentialReuseMixin(_MixinBase):
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
d[key] = [] d[key] = []
await self._enrich_with_secret(session, [d]) await self._enrich_with_secret(session, [d])
return d return cast(dict[str, Any], d)
@staticmethod @staticmethod
async def _enrich_with_secret( async def _enrich_with_secret(

View File

@@ -9,7 +9,7 @@ the reads.
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional, cast
from sqlalchemy import desc, func, select, update from sqlalchemy import desc, func, select, update
from sqlmodel import col from sqlmodel import col
@@ -42,10 +42,10 @@ class IdentitiesMixin(_MixinBase):
if identity is None: if identity is None:
return None return None
if identity.merged_into_uuid is None: if identity.merged_into_uuid is None:
return identity.model_dump(mode="json") return cast(dict[str, Any], identity.model_dump(mode="json"))
current_uuid = identity.merged_into_uuid current_uuid = identity.merged_into_uuid
# Hit the hop cap — surface what we have rather than recurse. # Hit the hop cap — surface what we have rather than recurse.
return identity.model_dump(mode="json") return cast(dict[str, Any], identity.model_dump(mode="json"))
async def list_identities( async def list_identities(
self, limit: int = 50, offset: int = 0, self, limit: int = 50, offset: int = 0,

View File

@@ -21,7 +21,7 @@ not validate values — that happens at construction time by the BEHAVE
from __future__ import annotations from __future__ import annotations
import uuid as _uuid import uuid as _uuid
from typing import Any, Optional from typing import Any, Optional, cast
from sqlalchemy import desc, func, select from sqlalchemy import desc, func, select
from sqlmodel import col from sqlmodel import col
@@ -85,7 +85,7 @@ class ObservationsMixin(_MixinBase):
row_data = {**data, "id": row_id} row_data = {**data, "id": row_id}
session.add(ObservationRow(**row_data)) session.add(ObservationRow(**row_data))
await session.commit() await session.commit()
return row_id return cast(str, row_id)
async def latest_observation_per_primitive( async def latest_observation_per_primitive(
self, attacker_uuid: str, self, attacker_uuid: str,
@@ -178,7 +178,7 @@ class ObservationsMixin(_MixinBase):
row = (await session.execute(stmt)).scalar_one_or_none() row = (await session.execute(stmt)).scalar_one_or_none()
if not row: if not row:
return None return None
return row.model_dump(mode="json") return cast(dict[str, Any], row.model_dump(mode="json"))
async def observations_for_identity_primitive( async def observations_for_identity_primitive(
self, identity_uuid: str, primitive: str, self, identity_uuid: str, primitive: str,

View File

@@ -13,7 +13,7 @@ caller is upgrading ``False/None`` to ``True``.
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional from typing import Optional, cast
from sqlalchemy import select from sqlalchemy import select
@@ -106,4 +106,4 @@ class ObservedAttachmentsMixin(_MixinBase):
row.mal_hash_match_at = now row.mal_hash_match_at = now
session.add(row) session.add(row)
await session.commit() await session.commit()
return row.uuid return cast(str, row.uuid)

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Optional from typing import Any, Optional, cast
from sqlalchemy import desc, func, select, update from sqlalchemy import desc, func, select, update
@@ -92,7 +92,7 @@ class RealismMixin(_MixinBase):
row = result.scalars().first() row = result.scalars().first()
if row is None: if row is None:
return None return None
return row.model_dump(mode="json") return cast(dict[str, Any], row.model_dump(mode="json"))
async def get_realism_config( async def get_realism_config(
self, key: str, self, key: str,
@@ -103,7 +103,7 @@ class RealismMixin(_MixinBase):
row = result.scalars().first() row = result.scalars().first()
if row is None: if row is None:
return None return None
return row.model_dump(mode="json") return cast(dict[str, Any], row.model_dump(mode="json"))
async def set_realism_config( async def set_realism_config(
self, key: str, value: str, self, key: str, value: str,
@@ -157,4 +157,4 @@ class RealismMixin(_MixinBase):
row = result.scalars().first() row = result.scalars().first()
if row is None: if row is None:
return None return None
return row.model_dump(mode="json") return cast(dict[str, Any], row.model_dump(mode="json"))

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import json import json
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional, cast
from sqlalchemy import select from sqlalchemy import select
@@ -47,7 +47,7 @@ class TarpitMixin(_MixinBase):
return None return None
d = row.model_dump(mode="json") d = row.model_dump(mode="json")
d["ports"] = json.loads(d["ports"]) d["ports"] = json.loads(d["ports"])
return d return cast(dict[str, Any], d)
async def delete_tarpit_rule(self, decky_name: str) -> bool: async def delete_tarpit_rule(self, decky_name: str) -> bool:
async with self._session() as session: async with self._session() as session:

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional, cast
import orjson import orjson
from sqlalchemy import asc, desc, select, text from sqlalchemy import asc, desc, select, text
@@ -102,7 +102,7 @@ class TopologyMutationsMixin(_MixinBase):
_ = now _ = now
if row is None: if row is None:
return None return None
return row.model_dump(mode="json") return cast(dict[str, Any], row.model_dump(mode="json"))
async def mark_mutation_applied(self, mutation_id: str) -> None: async def mark_mutation_applied(self, mutation_id: str) -> None:
async with self._session() as session: async with self._session() as session:

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional, cast
from sqlalchemy import select, update from sqlalchemy import select, update
from sqlmodel import col from sqlmodel import col
@@ -65,7 +65,7 @@ class WebhooksMixin(_MixinBase):
.values(**patch) .values(**patch)
) )
await session.commit() await session.commit()
return result.rowcount > 0 return cast(bool, result.rowcount > 0)
async def delete_webhook_subscription(self, uuid: str) -> bool: async def delete_webhook_subscription(self, uuid: str) -> bool:
async with self._session() as session: async with self._session() as session:

View File

@@ -190,6 +190,7 @@ ignore_missing_imports = true
check_untyped_defs = true check_untyped_defs = true
warn_redundant_casts = true warn_redundant_casts = true
warn_unused_ignores = true warn_unused_ignores = true
warn_return_any = true
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
# The pydantic plugin types SQLModel class-level field descriptors as their # The pydantic plugin types SQLModel class-level field descriptors as their