chore(types): enable warn_return_any and cast all no-any-return sites
Turn on mypy warn_return_any (pyproject) and resolve the 84 resulting [no-any-return] errors across 43 files with typing.cast() at the return sites — runtime no-ops that make the declared return type explicit where a dependency (SQLAlchemy scalar/first/one, httpx .json(), subprocess, docker SDK) hands back Any. No behavior change: no DTO/table field types altered, no validation/coercion calls added, every cast reflects the true runtime type. Locks in return-type strictness so the class of bug where a function silently widens to Any can't regress. mypy decnet/ clean; adversarially verified behavior-preserving (84 casts 1:1 with prior returns). Bump tornado 6.5.5 -> 6.5.7 (CVE-2026-49854, transitive via snakeviz).
This commit is contained in:
@@ -22,7 +22,7 @@ import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncIterator
|
||||
from typing import Any, AsyncIterator, cast
|
||||
|
||||
EVENT_SCHEMA_VERSION = 1
|
||||
|
||||
@@ -203,4 +203,4 @@ async def _next_or_stop(queue: "asyncio.Queue[Any]") -> Event:
|
||||
item = await queue.get()
|
||||
if item is _CLOSE_SENTINEL:
|
||||
raise StopAsyncIteration
|
||||
return item
|
||||
return cast(Event, item)
|
||||
|
||||
@@ -17,7 +17,7 @@ env-driven dispatch, optional telemetry wrapping). Callers MUST use
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from decnet.bus.base import BaseBus
|
||||
|
||||
@@ -81,6 +81,6 @@ def _maybe_wrap_telemetry(bus: BaseBus) -> BaseBus:
|
||||
except ImportError:
|
||||
return bus
|
||||
try:
|
||||
return wrap_repository(bus)
|
||||
return cast(BaseBus, wrap_repository(bus))
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return bus
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from decnet.bus.base import (
|
||||
BaseBus,
|
||||
@@ -51,7 +51,7 @@ class _FakeSubscription(Subscription):
|
||||
item = await self._queue.get()
|
||||
if item is _CLOSE_SENTINEL:
|
||||
raise StopAsyncIteration
|
||||
return item
|
||||
return cast(Event, item)
|
||||
|
||||
async def _aclose(self) -> None:
|
||||
self._bus._unregister(self)
|
||||
|
||||
@@ -26,7 +26,7 @@ import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import pathlib
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from decnet.bus import protocol
|
||||
from decnet.bus.base import (
|
||||
@@ -61,7 +61,7 @@ class _UnixSubscription(Subscription):
|
||||
item = await self._queue.get()
|
||||
if item is _CLOSE_SENTINEL:
|
||||
raise StopAsyncIteration
|
||||
return item
|
||||
return cast(Event, item)
|
||||
|
||||
async def _aclose(self) -> None:
|
||||
await self._bus._unregister(self)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
@@ -96,9 +96,9 @@ def _list() -> None:
|
||||
"""List all topologies."""
|
||||
_require_master_mode("topology list")
|
||||
|
||||
async def _go() -> list[dict]:
|
||||
async def _go() -> list[dict[Any, Any]]:
|
||||
repo = await _repo()
|
||||
return await repo.list_topologies()
|
||||
return cast(list[dict[Any, Any]], await repo.list_topologies())
|
||||
|
||||
rows = asyncio.run(_go())
|
||||
if not rows:
|
||||
@@ -140,7 +140,7 @@ def _show(topology_id: str = typer.Argument(..., help="Topology id")) -> None:
|
||||
|
||||
def _decky_name(d: dict) -> str:
|
||||
cfg = d.get("decky_config") or {}
|
||||
return cfg.get("name") or d.get("name") or d["uuid"]
|
||||
return cast(str, cfg.get("name") or d.get("name") or d["uuid"])
|
||||
|
||||
deckies_by_name = {_decky_name(d): d for d in hydrated["deckies"]}
|
||||
edges_by_lan: dict[str, list[dict]] = {}
|
||||
@@ -296,9 +296,9 @@ def _mutate(
|
||||
|
||||
async def _go() -> str:
|
||||
repo = await _repo()
|
||||
return await repo.enqueue_topology_mutation(
|
||||
return cast(str, await repo.enqueue_topology_mutation(
|
||||
topology_id, op, payload, expected_version=expected_version,
|
||||
)
|
||||
))
|
||||
|
||||
mid = asyncio.run(_go())
|
||||
_console.print(
|
||||
@@ -319,9 +319,9 @@ def _mutations(
|
||||
"""List queued/applied mutations for a topology."""
|
||||
_require_master_mode("topology mutations")
|
||||
|
||||
async def _go() -> list[dict]:
|
||||
async def _go() -> list[dict[Any, Any]]:
|
||||
repo = await _repo()
|
||||
return await repo.list_topology_mutations(topology_id, state=state)
|
||||
return cast(list[dict[Any, Any]], await repo.list_topology_mutations(topology_id, state=state))
|
||||
|
||||
rows = asyncio.run(_go())
|
||||
if not rows:
|
||||
|
||||
@@ -12,7 +12,7 @@ import signal
|
||||
import subprocess # nosec B404
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, cast
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
@@ -91,7 +91,7 @@ def _is_running(match_fn) -> int | None:
|
||||
try:
|
||||
cmd = proc.info["cmdline"]
|
||||
if cmd and match_fn(cmd):
|
||||
return proc.info["pid"]
|
||||
return cast(int, proc.info["pid"])
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
continue
|
||||
return None
|
||||
|
||||
@@ -25,7 +25,7 @@ from __future__ import annotations
|
||||
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Sequence
|
||||
from typing import Any, Sequence, cast
|
||||
|
||||
from decnet.correlation.attribution import _thresholds as _T
|
||||
|
||||
@@ -250,7 +250,7 @@ def _coef_of_variation(values: Sequence[float], mean: float) -> float:
|
||||
stdev = variance ** 0.5
|
||||
if mean == 0:
|
||||
return 0.0 if stdev == 0 else 1e9
|
||||
return stdev / abs(mean)
|
||||
return cast(float, stdev / abs(mean))
|
||||
|
||||
|
||||
def _safe_float(value: Any) -> float:
|
||||
|
||||
@@ -9,6 +9,7 @@ import shutil
|
||||
import subprocess # nosec B404
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
import docker
|
||||
@@ -853,7 +854,7 @@ async def _resolve_swarm_host(repo, host_uuid: str) -> dict:
|
||||
raise ValueError(
|
||||
f"topology pinned to unknown swarm host {host_uuid!r}"
|
||||
)
|
||||
return host
|
||||
return cast(dict, host)
|
||||
|
||||
|
||||
async def _deploy_on_agent(repo, topology_id: str, hydrated: dict) -> None:
|
||||
|
||||
@@ -24,7 +24,7 @@ from __future__ import annotations
|
||||
|
||||
import subprocess # nosec B404
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal, Optional, cast
|
||||
|
||||
import anyio
|
||||
|
||||
@@ -217,7 +217,7 @@ async def _topology_decky(
|
||||
cfg = d.get("decky_config") or {}
|
||||
name = cfg.get("name") or d.get("name")
|
||||
if name == decky_name:
|
||||
return d
|
||||
return cast(dict[str, Any], d)
|
||||
raise ServiceNotFoundError(
|
||||
f"decky {decky_name!r} is not in topology {topology_id!r}"
|
||||
)
|
||||
@@ -343,7 +343,7 @@ def _fleet_state_or_raise() -> tuple[Any, Path]:
|
||||
raise ServiceMutationError(
|
||||
"no fleet state on disk — run `decnet up` first"
|
||||
)
|
||||
return state
|
||||
return cast(tuple[Any, Path], state)
|
||||
|
||||
|
||||
def _fleet_find_decky(config: Any, decky_name: str) -> Any:
|
||||
|
||||
@@ -25,7 +25,7 @@ Design notes
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Awaitable, Callable, Optional
|
||||
from typing import Any, Awaitable, Callable, Optional, cast
|
||||
|
||||
from decnet.logging import get_logger
|
||||
from decnet.topology.allocator import IPAllocator, reserved_subnets, SubnetAllocator
|
||||
@@ -301,7 +301,7 @@ async def _live_topology_or_none(
|
||||
topology_id,
|
||||
)
|
||||
return None
|
||||
return topology
|
||||
return cast(dict[str, Any], topology)
|
||||
|
||||
|
||||
async def _rerender_compose(repo: Any, topology_id: str) -> None:
|
||||
|
||||
@@ -12,6 +12,7 @@ Handles:
|
||||
import os
|
||||
import subprocess # nosec B404
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network
|
||||
from typing import cast
|
||||
|
||||
import docker
|
||||
|
||||
@@ -38,7 +39,7 @@ def detect_interface() -> str:
|
||||
for line in result.stdout.splitlines():
|
||||
parts = line.split()
|
||||
if "dev" in parts:
|
||||
return parts[parts.index("dev") + 1]
|
||||
return cast(str, parts[parts.index("dev") + 1])
|
||||
raise RuntimeError("Could not auto-detect network interface. Use --interface.")
|
||||
|
||||
|
||||
@@ -79,7 +80,7 @@ def get_host_ip(interface: str) -> str:
|
||||
for line in result.stdout.splitlines():
|
||||
line = line.strip()
|
||||
if line.startswith("inet ") and not line.startswith("inet6"):
|
||||
return line.split()[1].split("/")[0]
|
||||
return cast(str, line.split()[1].split("/")[0])
|
||||
raise RuntimeError(f"Could not determine host IP for interface {interface}.")
|
||||
|
||||
|
||||
@@ -297,7 +298,7 @@ def create_bridge_network(
|
||||
pools = (net.attrs.get("IPAM") or {}).get("Config") or []
|
||||
cur = pools[0] if pools else {}
|
||||
if net.attrs.get("Driver") == "bridge" and cur.get("Subnet") == subnet:
|
||||
return net.id
|
||||
return cast(str, net.id)
|
||||
for cid in (net.attrs.get("Containers") or {}):
|
||||
try:
|
||||
net.disconnect(cid, force=True)
|
||||
@@ -332,7 +333,7 @@ def create_bridge_network(
|
||||
pool_configs=[docker.types.IPAMPool(subnet=subnet)],
|
||||
),
|
||||
)
|
||||
return net.id
|
||||
return cast(str, net.id)
|
||||
|
||||
|
||||
def remove_bridge_network(client: docker.DockerClient, name: str) -> None:
|
||||
@@ -480,7 +481,7 @@ def get_container_pid(container_name: str) -> int:
|
||||
pid = container.attrs["State"]["Pid"]
|
||||
if not pid:
|
||||
raise LookupError(f"container {container_name!r} is not running (PID=0)")
|
||||
return pid
|
||||
return cast(int, pid)
|
||||
|
||||
|
||||
def get_container_veth(container_name: str) -> str:
|
||||
@@ -507,7 +508,7 @@ def get_container_veth(container_name: str) -> str:
|
||||
if line.startswith(f"{peer_index}:"):
|
||||
# Format: "42: veth3a4b5c@if41: <BROADCAST,...>"
|
||||
iface = line.split(":")[1].strip().split("@")[0]
|
||||
return iface
|
||||
return cast(str, iface)
|
||||
raise LookupError(
|
||||
f"no host veth found for container {container_name!r} (peer ifindex {peer_index})"
|
||||
)
|
||||
|
||||
@@ -20,7 +20,7 @@ import json
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional, Sequence
|
||||
from typing import Any, Optional, Sequence, cast
|
||||
|
||||
from decnet.realism import personas_pool
|
||||
from decnet.realism.personas import EmailPersona, parse_personas
|
||||
@@ -256,7 +256,7 @@ def _persona_by_name(
|
||||
for decky in enriched:
|
||||
for persona in decky.get("_realism_personas") or []:
|
||||
if persona.name == name:
|
||||
return persona
|
||||
return cast(EmailPersona, persona)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import socket
|
||||
import ssl
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
@@ -42,7 +42,7 @@ def _iso_utc(dt: Any) -> str:
|
||||
and ``not_valid_before_utc`` (timezone-aware) — prefer the latter
|
||||
when available so we always emit explicit-Z ISO strings.
|
||||
"""
|
||||
return dt.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
return cast(str, dt.strftime("%Y-%m-%dT%H:%M:%SZ"))
|
||||
|
||||
|
||||
def _extract_sans(cert: x509.Certificate) -> list[str]:
|
||||
|
||||
@@ -19,7 +19,7 @@ graph cycle-free and the env contract auditable in one place.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from decnet.realism.llm.base import LLMBackend
|
||||
|
||||
@@ -45,7 +45,7 @@ def get_llm(*, model: str | None = None, **kwargs: Any) -> LLMBackend:
|
||||
from decnet.realism.llm.config import get_cached_backend
|
||||
cached = get_cached_backend()
|
||||
if cached is not None:
|
||||
return cached
|
||||
return cast(LLMBackend, cached)
|
||||
|
||||
backend_key = os.environ.get("DECNET_REALISM_LLM", "ollama").lower()
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import logging
|
||||
import sqlite3
|
||||
import urllib.request
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from decnet.rpki import cache as _cache
|
||||
from decnet.rpki.base import RpkiResult, RpkiStatus, Validator
|
||||
@@ -68,13 +68,13 @@ class RipeStatValidator(Validator):
|
||||
)
|
||||
raw = data.get("data", {}).get("status", "unknown")
|
||||
if raw in ("valid", "invalid", "not-found"):
|
||||
return raw
|
||||
return cast(RpkiStatus, raw)
|
||||
return "unknown"
|
||||
|
||||
def _fetch(self, url: str) -> dict:
|
||||
def _fetch(self, url: str) -> dict[Any, Any]:
|
||||
req = urllib.request.Request(url, headers={"User-Agent": _UA})
|
||||
with urllib.request.urlopen(req, timeout=_TIMEOUT_S) as resp: # nosec B310 — HTTPS RIPE STAT base URL only; IP/ASN components are validated upstream
|
||||
return json.loads(resp.read())
|
||||
return cast(dict[Any, Any], json.loads(resp.read()))
|
||||
|
||||
def _store(
|
||||
self, ip: str, asn: int, status: str, prefix: Optional[str]
|
||||
|
||||
@@ -14,7 +14,7 @@ import hashlib
|
||||
import struct
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
from decnet.logging import get_logger
|
||||
from decnet.prober.tcpfp import _extract_options_order
|
||||
@@ -1058,18 +1058,18 @@ class SnifferEngine:
|
||||
|
||||
def _dedup_key_for(self, event_type: str, fields: dict[str, Any]) -> str:
|
||||
if event_type == "tls_client_hello":
|
||||
return fields.get("ja3", "") + "|" + fields.get("ja4", "")
|
||||
return cast(str, fields.get("ja3", "") + "|" + fields.get("ja4", ""))
|
||||
if event_type == "tls_session":
|
||||
return (fields.get("ja3", "") + "|" + fields.get("ja3s", "") +
|
||||
return cast(str, fields.get("ja3", "") + "|" + fields.get("ja3s", "") +
|
||||
"|" + fields.get("ja4", "") + "|" + fields.get("ja4s", ""))
|
||||
if event_type == "tls_certificate":
|
||||
return fields.get("subject_cn", "") + "|" + fields.get("issuer", "")
|
||||
return cast(str, fields.get("subject_cn", "") + "|" + fields.get("issuer", ""))
|
||||
if event_type == "tcp_syn_fingerprint":
|
||||
# Dedupe per (OS signature, options layout, sequence-pattern
|
||||
# classification). Including ipid_class/isn_class lets each
|
||||
# transition (unknown → random/incremental/zero/constant) emit
|
||||
# exactly one fresh event as samples accumulate.
|
||||
return (
|
||||
return cast(str,
|
||||
fields.get("os_guess", "")
|
||||
+ "|" + fields.get("options_sig", "")
|
||||
+ "|" + fields.get("ipid_class", "")
|
||||
@@ -1080,14 +1080,14 @@ class SnifferEngine:
|
||||
# excluded so a port scanner rotating source ports only produces
|
||||
# one timing event per dedup window. Behavior cadence doesn't
|
||||
# need per-ephemeral-port fidelity.
|
||||
return fields.get("dst_ip", "") + "|" + fields.get("dst_port", "")
|
||||
return cast(str, fields.get("dst_ip", "") + "|" + fields.get("dst_port", ""))
|
||||
if event_type == "quic_client_hello":
|
||||
return fields.get("src_ip", "") + "|" + fields.get("ja4_quic", "")
|
||||
return cast(str, fields.get("src_ip", "") + "|" + fields.get("ja4_quic", ""))
|
||||
if event_type == "http_request_fingerprint":
|
||||
return fields.get("src_ip", "") + "|" + fields.get("ja4h", "")
|
||||
return cast(str, fields.get("src_ip", "") + "|" + fields.get("ja4h", ""))
|
||||
if event_type in ("http2_settings", "http3_settings"):
|
||||
return fields.get("src_ip", "") + "|" + str(fields.get("settings_hash", ""))
|
||||
return fields.get("mechanisms", fields.get("resumption", ""))
|
||||
return cast(str, fields.get("src_ip", "") + "|" + str(fields.get("settings_hash", "")))
|
||||
return cast(str, fields.get("mechanisms", fields.get("resumption", "")))
|
||||
|
||||
def _is_duplicate(self, event_type: str, fields: dict[str, Any]) -> bool:
|
||||
if self._dedup_ttl <= 0:
|
||||
|
||||
@@ -23,7 +23,7 @@ import pathlib
|
||||
import socket
|
||||
import ssl
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -229,12 +229,12 @@ class AgentClient:
|
||||
async def health(self) -> dict[str, Any]:
|
||||
resp = await self._require_client().get("/health")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
async def status(self) -> dict[str, Any]:
|
||||
resp = await self._require_client().get("/status")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
async def deploy(
|
||||
self,
|
||||
@@ -254,7 +254,7 @@ class AgentClient:
|
||||
# need for the long deploy timeout here.
|
||||
resp = await self._require_client().post("/deploy", json=body)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
async def mutate(
|
||||
self,
|
||||
@@ -271,20 +271,20 @@ class AgentClient:
|
||||
# Worker /mutate is async (202): control-timeout is right.
|
||||
resp = await self._require_client().post("/mutate", json=body)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
async def teardown(self, decky_id: Optional[str] = None) -> dict[str, Any]:
|
||||
resp = await self._require_client().post(
|
||||
"/teardown", json={"decky_id": decky_id}
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
async def self_destruct(self) -> dict[str, Any]:
|
||||
"""Trigger the worker to stop services and wipe its install."""
|
||||
resp = await self._require_client().post("/self-destruct")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
# ------------------------------------------------------------ topology
|
||||
|
||||
@@ -309,7 +309,7 @@ class AgentClient:
|
||||
finally:
|
||||
self._require_client().timeout = old
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
async def teardown_topology(self, topology_id: str) -> dict[str, Any]:
|
||||
"""Ask the agent to dismantle the named topology."""
|
||||
@@ -323,13 +323,13 @@ class AgentClient:
|
||||
finally:
|
||||
self._require_client().timeout = old
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
async def get_topology_state(self) -> dict[str, Any]:
|
||||
"""Snapshot of the agent's applied topology + live docker state."""
|
||||
resp = await self._require_client().get("/topology/state")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
return cast(dict[str, Any], resp.json())
|
||||
|
||||
# -------------------------------------------------------------- diagnostics
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import asyncio
|
||||
import hashlib
|
||||
import socket
|
||||
import ssl
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -159,12 +159,12 @@ class UpdaterClient:
|
||||
async def health(self) -> dict[str, Any]:
|
||||
r = await self._require().get("/health")
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
return cast(dict[str, Any], r.json())
|
||||
|
||||
async def releases(self) -> dict[str, Any]:
|
||||
r = await self._require().get("/releases")
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
return cast(dict[str, Any], r.json())
|
||||
|
||||
async def update(self, tarball: bytes, sha: str = "") -> httpx.Response:
|
||||
"""POST /update. Returns the Response so the caller can distinguish
|
||||
|
||||
@@ -147,7 +147,7 @@ _MONGO_SET_NAME = os.environ.get("MONGO_REPL_SET", "") # empty = standalone
|
||||
|
||||
def _new_objectid() -> bytes:
|
||||
"""12-byte BSON ObjectId — fresh per call."""
|
||||
return _seed.fresh_bytes(12)
|
||||
return cast(bytes, _seed.fresh_bytes(12))
|
||||
|
||||
# Minimal BSON helpers
|
||||
def _bson_str(key: str, val: str) -> bytes:
|
||||
|
||||
@@ -205,7 +205,7 @@ def _seed_dict_to_rfc822(entry: dict) -> str | None:
|
||||
date = str(entry.get("date") or "")
|
||||
body = entry["body"]
|
||||
if "\r\n\r\n" in body or "\n\n" in body:
|
||||
return body # already a full RFC 822 message
|
||||
return cast(str, body) # already a full RFC 822 message
|
||||
return (
|
||||
f"Date: {date}\r\n"
|
||||
f"From: {from_name} <{from_addr}>\r\n"
|
||||
|
||||
@@ -22,6 +22,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import os
|
||||
import struct
|
||||
from typing import cast
|
||||
|
||||
import instance_seed
|
||||
from ntlmssp import find_ntlmssp, parse_type3
|
||||
@@ -184,7 +185,7 @@ def _negotiate_response(message_id: int) -> bytes:
|
||||
+ struct.pack("<H", 0) # SecurityBufferLength
|
||||
+ struct.pack("<I", 0) # Reserved2
|
||||
)
|
||||
return _smb2_header(SMB2_NEGOTIATE, STATUS_SUCCESS, message_id) + body
|
||||
return cast(bytes, _smb2_header(SMB2_NEGOTIATE, STATUS_SUCCESS, message_id) + body)
|
||||
|
||||
|
||||
def _session_setup_response(message_id: int, session_id: int, sec_blob: bytes, status: int) -> bytes:
|
||||
|
||||
@@ -28,7 +28,7 @@ import hashlib
|
||||
import os
|
||||
import struct
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from scapy.layers.inet import IP, TCP
|
||||
from scapy.sendrecv import sniff
|
||||
@@ -841,14 +841,14 @@ _dedup_last_cleanup: float = 0.0
|
||||
def _dedup_key_for(event_type: str, fields: dict[str, Any]) -> str:
|
||||
"""Build a dedup fingerprint from the most significant fields."""
|
||||
if event_type == "tls_client_hello":
|
||||
return fields.get("ja3", "") + "|" + fields.get("ja4", "")
|
||||
return cast(str, fields.get("ja3", "") + "|" + fields.get("ja4", ""))
|
||||
if event_type == "tls_session":
|
||||
return (fields.get("ja3", "") + "|" + fields.get("ja3s", "") +
|
||||
return cast(str, fields.get("ja3", "") + "|" + fields.get("ja3s", "") +
|
||||
"|" + fields.get("ja4", "") + "|" + fields.get("ja4s", ""))
|
||||
if event_type == "tls_certificate":
|
||||
return fields.get("subject_cn", "") + "|" + fields.get("issuer", "")
|
||||
return cast(str, fields.get("subject_cn", "") + "|" + fields.get("issuer", ""))
|
||||
# tls_resumption or unknown — dedup on mechanisms
|
||||
return fields.get("mechanisms", fields.get("resumption", ""))
|
||||
return cast(str, fields.get("mechanisms", fields.get("resumption", "")))
|
||||
|
||||
|
||||
def _is_duplicate(event_type: str, fields: dict[str, Any]) -> bool:
|
||||
|
||||
@@ -39,7 +39,7 @@ from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Final
|
||||
from typing import TYPE_CHECKING, Final, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mitreattack.stix20 import MitreAttackData
|
||||
@@ -310,7 +310,7 @@ def subtechnique_parent_name(technique_id: str) -> str | None:
|
||||
)
|
||||
if not parents:
|
||||
return None
|
||||
return parents[0]["object"].name
|
||||
return cast(str, parents[0]["object"].name)
|
||||
|
||||
|
||||
def is_subtechnique(technique_id: str) -> bool:
|
||||
@@ -335,7 +335,7 @@ def tactic_id_for_short_name(short_name: str) -> str | None:
|
||||
return None
|
||||
for ref in obj.get("external_references", []):
|
||||
if ref.get("source_name") == "mitre-attack":
|
||||
return ref.get("external_id")
|
||||
return cast(str | None, ref.get("external_id"))
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ build_fleet_misp_collection → dict ({"response": [event, ...]})
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from misp_stix_converter import ExternalSTIX2toMISPParser
|
||||
|
||||
@@ -35,7 +35,7 @@ def _parse_bundle(bundle: Any) -> dict[str, Any]:
|
||||
event = parser.misp_events
|
||||
if event is None:
|
||||
return {}
|
||||
return json.loads(event.to_json())
|
||||
return cast(dict[str, Any], json.loads(event.to_json()))
|
||||
|
||||
|
||||
def build_attacker_misp_event(
|
||||
|
||||
@@ -8,7 +8,7 @@ import os
|
||||
import traceback
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
from typing import Any, AsyncGenerator, Optional, cast
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
@@ -306,7 +306,7 @@ class _ContentTypeMiddleware(BaseHTTPMiddleware):
|
||||
status_code=415,
|
||||
media_type="text/plain",
|
||||
)
|
||||
return await call_next(request)
|
||||
return cast(StarletteResponse, await call_next(request))
|
||||
|
||||
|
||||
app.add_middleware(_ContentTypeMiddleware)
|
||||
|
||||
@@ -6,7 +6,7 @@ Repository factory — selects a :class:`BaseRepository` implementation based on
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from decnet.web.db.repository import BaseRepository
|
||||
|
||||
@@ -32,4 +32,4 @@ def get_repository(**kwargs: Any) -> BaseRepository:
|
||||
raise ValueError(f"Unsupported database type: {db_type}")
|
||||
|
||||
from decnet.telemetry import wrap_repository
|
||||
return wrap_repository(repo)
|
||||
return cast(BaseRepository, wrap_repository(repo))
|
||||
|
||||
@@ -18,7 +18,7 @@ import os
|
||||
|
||||
import orjson
|
||||
import uuid
|
||||
from typing import Any, Optional, List
|
||||
from typing import Any, Optional, List, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||
@@ -183,7 +183,7 @@ class SQLModelRepository(
|
||||
result = await session.execute(statement)
|
||||
state = result.scalar_one_or_none()
|
||||
if state:
|
||||
return json.loads(state.value)
|
||||
return cast(dict[str, Any], json.loads(state.value))
|
||||
return None
|
||||
|
||||
async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401
|
||||
|
||||
@@ -10,7 +10,7 @@ from __future__ import annotations
|
||||
|
||||
import uuid as _uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import desc, or_, select
|
||||
from sqlmodel import col
|
||||
@@ -44,7 +44,7 @@ class AttackerIntelMixin(_MixinBase):
|
||||
row_uuid = _uuid.uuid4().hex
|
||||
session.add(AttackerIntel(uuid=row_uuid, **data))
|
||||
await session.commit()
|
||||
return row_uuid
|
||||
return cast(str, row_uuid)
|
||||
|
||||
async def get_attacker_intel_row_by_uuid(
|
||||
self,
|
||||
@@ -54,7 +54,7 @@ class AttackerIntelMixin(_MixinBase):
|
||||
result = await session.execute(
|
||||
select(AttackerIntel).where(AttackerIntel.attacker_uuid == uuid)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
return cast(Optional[AttackerIntel], result.scalar_one_or_none())
|
||||
|
||||
async def get_attacker_intel_by_uuid(
|
||||
self,
|
||||
@@ -67,7 +67,7 @@ class AttackerIntelMixin(_MixinBase):
|
||||
row = result.scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return row.model_dump(mode="json")
|
||||
return cast(dict[str, Any], row.model_dump(mode="json"))
|
||||
|
||||
async def get_unenriched_attackers(
|
||||
self, limit: int = 100,
|
||||
|
||||
@@ -10,7 +10,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid as _uuid
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
from sqlalchemy import desc, func, outerjoin, select
|
||||
from sqlmodel import col
|
||||
@@ -47,7 +47,7 @@ class AttackersCoreMixin(_MixinBase):
|
||||
data = {**data, "uuid": row_uuid}
|
||||
session.add(Attacker(**data))
|
||||
await session.commit()
|
||||
return row_uuid
|
||||
return cast(str, row_uuid)
|
||||
|
||||
async def get_attacker_uuid_by_ip(self, ip: str) -> Optional[str]:
|
||||
"""Return the :class:`Attacker` UUID for *ip*, or ``None``.
|
||||
@@ -61,7 +61,7 @@ class AttackersCoreMixin(_MixinBase):
|
||||
result = await session.execute(
|
||||
select(col(Attacker.uuid)).where(Attacker.ip == ip)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
return cast(Optional[str], result.scalar_one_or_none())
|
||||
|
||||
async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
|
||||
async with self._session() as session:
|
||||
|
||||
@@ -21,7 +21,7 @@ from __future__ import annotations
|
||||
|
||||
import uuid as _uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlmodel import col
|
||||
@@ -66,7 +66,7 @@ class AttributionMixin(_MixinBase):
|
||||
if attacker_row is None:
|
||||
return None
|
||||
if attacker_row.identity_id:
|
||||
return attacker_row.identity_id
|
||||
return cast(str, attacker_row.identity_id)
|
||||
new_uuid = _uuid.uuid4().hex
|
||||
now = datetime.now(timezone.utc)
|
||||
session.add(
|
||||
|
||||
@@ -9,7 +9,7 @@ the reads.
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import desc, func, select, update
|
||||
from sqlmodel import col
|
||||
@@ -36,9 +36,9 @@ class CampaignsMixin(_MixinBase):
|
||||
if campaign is None:
|
||||
return None
|
||||
if campaign.merged_into_uuid is None:
|
||||
return campaign.model_dump(mode="json")
|
||||
return cast(dict[str, Any], campaign.model_dump(mode="json"))
|
||||
current_uuid = campaign.merged_into_uuid
|
||||
return campaign.model_dump(mode="json")
|
||||
return cast(dict[str, Any], campaign.model_dump(mode="json"))
|
||||
|
||||
async def list_campaigns(
|
||||
self, limit: int = 50, offset: int = 0,
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import desc, func, select, update
|
||||
|
||||
@@ -26,7 +26,7 @@ class CanaryMixin(_MixinBase):
|
||||
)
|
||||
row = existing.scalar_one_or_none()
|
||||
if row:
|
||||
return row.model_dump(mode="json")
|
||||
return cast(dict[str, Any], row.model_dump(mode="json"))
|
||||
row = CanaryBlob(**data)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
@@ -155,7 +155,7 @@ class CanaryMixin(_MixinBase):
|
||||
.values(state=state, last_error=last_error)
|
||||
)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
return cast(bool, result.rowcount > 0)
|
||||
|
||||
async def record_canary_trigger(self, data: dict[str, Any]) -> str:
|
||||
# Persist the trigger row + bump the token's counters in the same
|
||||
@@ -204,4 +204,4 @@ class CanaryMixin(_MixinBase):
|
||||
.values(attacker_id=attacker_id)
|
||||
)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
return cast(bool, result.rowcount > 0)
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
from sqlalchemy import desc, func, or_, select, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@@ -67,7 +67,7 @@ class CredentialsCoreMixin(_MixinBase):
|
||||
existing.outcome = payload["outcome"]
|
||||
session.add(existing)
|
||||
await session.commit()
|
||||
return existing.id
|
||||
return cast(int, existing.id)
|
||||
row = Credential(
|
||||
attacker_ip=payload["attacker_ip"],
|
||||
decky_name=payload["decky_name"],
|
||||
@@ -103,7 +103,7 @@ class CredentialsCoreMixin(_MixinBase):
|
||||
existing2.outcome = payload["outcome"]
|
||||
session2.add(existing2)
|
||||
await session2.commit()
|
||||
return existing2.id
|
||||
return cast(int, existing2.id)
|
||||
await session.refresh(row)
|
||||
return row.id # type: ignore[return-value]
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import uuid as _uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
from sqlalchemy import desc, func, select
|
||||
from sqlmodel import col
|
||||
@@ -137,7 +137,7 @@ class CredentialReuseMixin(_MixinBase):
|
||||
d = existing.model_dump(mode="json")
|
||||
d["inserted"] = False
|
||||
d["changed"] = changed
|
||||
return d
|
||||
return cast(dict[str, Any], d)
|
||||
|
||||
async def find_credential_reuse_candidates(
|
||||
self, min_targets: int = 2
|
||||
@@ -236,7 +236,7 @@ class CredentialReuseMixin(_MixinBase):
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
d[key] = []
|
||||
await self._enrich_with_secret(session, [d])
|
||||
return d
|
||||
return cast(dict[str, Any], d)
|
||||
|
||||
@staticmethod
|
||||
async def _enrich_with_secret(
|
||||
|
||||
@@ -9,7 +9,7 @@ the reads.
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import desc, func, select, update
|
||||
from sqlmodel import col
|
||||
@@ -42,10 +42,10 @@ class IdentitiesMixin(_MixinBase):
|
||||
if identity is None:
|
||||
return None
|
||||
if identity.merged_into_uuid is None:
|
||||
return identity.model_dump(mode="json")
|
||||
return cast(dict[str, Any], identity.model_dump(mode="json"))
|
||||
current_uuid = identity.merged_into_uuid
|
||||
# Hit the hop cap — surface what we have rather than recurse.
|
||||
return identity.model_dump(mode="json")
|
||||
return cast(dict[str, Any], identity.model_dump(mode="json"))
|
||||
|
||||
async def list_identities(
|
||||
self, limit: int = 50, offset: int = 0,
|
||||
|
||||
@@ -21,7 +21,7 @@ not validate values — that happens at construction time by the BEHAVE
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as _uuid
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import desc, func, select
|
||||
from sqlmodel import col
|
||||
@@ -85,7 +85,7 @@ class ObservationsMixin(_MixinBase):
|
||||
row_data = {**data, "id": row_id}
|
||||
session.add(ObservationRow(**row_data))
|
||||
await session.commit()
|
||||
return row_id
|
||||
return cast(str, row_id)
|
||||
|
||||
async def latest_observation_per_primitive(
|
||||
self, attacker_uuid: str,
|
||||
@@ -178,7 +178,7 @@ class ObservationsMixin(_MixinBase):
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if not row:
|
||||
return None
|
||||
return row.model_dump(mode="json")
|
||||
return cast(dict[str, Any], row.model_dump(mode="json"))
|
||||
|
||||
async def observations_for_identity_primitive(
|
||||
self, identity_uuid: str, primitive: str,
|
||||
|
||||
@@ -13,7 +13,7 @@ caller is upgrading ``False/None`` to ``True``.
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -106,4 +106,4 @@ class ObservedAttachmentsMixin(_MixinBase):
|
||||
row.mal_hash_match_at = now
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
return row.uuid
|
||||
return cast(str, row.uuid)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import desc, func, select, update
|
||||
|
||||
@@ -92,7 +92,7 @@ class RealismMixin(_MixinBase):
|
||||
row = result.scalars().first()
|
||||
if row is None:
|
||||
return None
|
||||
return row.model_dump(mode="json")
|
||||
return cast(dict[str, Any], row.model_dump(mode="json"))
|
||||
|
||||
async def get_realism_config(
|
||||
self, key: str,
|
||||
@@ -103,7 +103,7 @@ class RealismMixin(_MixinBase):
|
||||
row = result.scalars().first()
|
||||
if row is None:
|
||||
return None
|
||||
return row.model_dump(mode="json")
|
||||
return cast(dict[str, Any], row.model_dump(mode="json"))
|
||||
|
||||
async def set_realism_config(
|
||||
self, key: str, value: str,
|
||||
@@ -157,4 +157,4 @@ class RealismMixin(_MixinBase):
|
||||
row = result.scalars().first()
|
||||
if row is None:
|
||||
return None
|
||||
return row.model_dump(mode="json")
|
||||
return cast(dict[str, Any], row.model_dump(mode="json"))
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -47,7 +47,7 @@ class TarpitMixin(_MixinBase):
|
||||
return None
|
||||
d = row.model_dump(mode="json")
|
||||
d["ports"] = json.loads(d["ports"])
|
||||
return d
|
||||
return cast(dict[str, Any], d)
|
||||
|
||||
async def delete_tarpit_rule(self, decky_name: str) -> bool:
|
||||
async with self._session() as session:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import asc, desc, select, text
|
||||
@@ -102,7 +102,7 @@ class TopologyMutationsMixin(_MixinBase):
|
||||
_ = now
|
||||
if row is None:
|
||||
return None
|
||||
return row.model_dump(mode="json")
|
||||
return cast(dict[str, Any], row.model_dump(mode="json"))
|
||||
|
||||
async def mark_mutation_applied(self, mutation_id: str) -> None:
|
||||
async with self._session() as session:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlmodel import col
|
||||
@@ -65,7 +65,7 @@ class WebhooksMixin(_MixinBase):
|
||||
.values(**patch)
|
||||
)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
return cast(bool, result.rowcount > 0)
|
||||
|
||||
async def delete_webhook_subscription(self, uuid: str) -> bool:
|
||||
async with self._session() as session:
|
||||
|
||||
Reference in New Issue
Block a user