chore(types): enable warn_return_any and cast all no-any-return sites

Turn on mypy warn_return_any (pyproject) and resolve the 84 resulting
[no-any-return] errors across 43 files with typing.cast() at the return
sites — runtime no-ops that make the declared return type explicit where a
dependency (SQLAlchemy scalar/first/one, httpx .json(), subprocess, docker
SDK) hands back Any. No behavior change: no DTO/table field types altered, no
validation/coercion calls added, every cast reflects the true runtime type.

Locks in return-type strictness so the class of bug where a function silently
widens to Any can't regress. mypy decnet/ clean; adversarially verified
behavior-preserving (84 casts 1:1 with prior returns).

Bump tornado 6.5.5 -> 6.5.7 (CVE-2026-49854, transitive via snakeviz).
This commit is contained in:
2026-06-12 18:21:22 -04:00
parent 337520c7ad
commit 721122a7ef
42 changed files with 128 additions and 124 deletions

View File

@@ -22,7 +22,7 @@ import asyncio
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, AsyncIterator
from typing import Any, AsyncIterator, cast
EVENT_SCHEMA_VERSION = 1
@@ -203,4 +203,4 @@ async def _next_or_stop(queue: "asyncio.Queue[Any]") -> Event:
item = await queue.get()
if item is _CLOSE_SENTINEL:
raise StopAsyncIteration
return item
return cast(Event, item)

View File

@@ -17,7 +17,7 @@ env-driven dispatch, optional telemetry wrapping). Callers MUST use
from __future__ import annotations
import os
from typing import Any
from typing import Any, cast
from decnet.bus.base import BaseBus
@@ -81,6 +81,6 @@ def _maybe_wrap_telemetry(bus: BaseBus) -> BaseBus:
except ImportError:
return bus
try:
return wrap_repository(bus)
return cast(BaseBus, wrap_repository(bus))
except Exception: # pragma: no cover - defensive
return bus

View File

@@ -14,7 +14,7 @@
from __future__ import annotations
import asyncio
from typing import Any
from typing import Any, cast
from decnet.bus.base import (
BaseBus,
@@ -51,7 +51,7 @@ class _FakeSubscription(Subscription):
item = await self._queue.get()
if item is _CLOSE_SENTINEL:
raise StopAsyncIteration
return item
return cast(Event, item)
async def _aclose(self) -> None:
self._bus._unregister(self)

View File

@@ -26,7 +26,7 @@ import asyncio
import contextlib
import os
import pathlib
from typing import Any
from typing import Any, cast
from decnet.bus import protocol
from decnet.bus.base import (
@@ -61,7 +61,7 @@ class _UnixSubscription(Subscription):
item = await self._queue.get()
if item is _CLOSE_SENTINEL:
raise StopAsyncIteration
return item
return cast(Event, item)
async def _aclose(self) -> None:
await self._bus._unregister(self)

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
import asyncio
from typing import Optional
from typing import Any, Optional, cast
import typer
from rich.console import Console
@@ -96,9 +96,9 @@ def _list() -> None:
"""List all topologies."""
_require_master_mode("topology list")
async def _go() -> list[dict]:
async def _go() -> list[dict[Any, Any]]:
repo = await _repo()
return await repo.list_topologies()
return cast(list[dict[Any, Any]], await repo.list_topologies())
rows = asyncio.run(_go())
if not rows:
@@ -140,7 +140,7 @@ def _show(topology_id: str = typer.Argument(..., help="Topology id")) -> None:
def _decky_name(d: dict) -> str:
cfg = d.get("decky_config") or {}
return cfg.get("name") or d.get("name") or d["uuid"]
return cast(str, cfg.get("name") or d.get("name") or d["uuid"])
deckies_by_name = {_decky_name(d): d for d in hydrated["deckies"]}
edges_by_lan: dict[str, list[dict]] = {}
@@ -296,9 +296,9 @@ def _mutate(
async def _go() -> str:
repo = await _repo()
return await repo.enqueue_topology_mutation(
return cast(str, await repo.enqueue_topology_mutation(
topology_id, op, payload, expected_version=expected_version,
)
))
mid = asyncio.run(_go())
_console.print(
@@ -319,9 +319,9 @@ def _mutations(
"""List queued/applied mutations for a topology."""
_require_master_mode("topology mutations")
async def _go() -> list[dict]:
async def _go() -> list[dict[Any, Any]]:
repo = await _repo()
return await repo.list_topology_mutations(topology_id, state=state)
return cast(list[dict[Any, Any]], await repo.list_topology_mutations(topology_id, state=state))
rows = asyncio.run(_go())
if not rows:

View File

@@ -12,7 +12,7 @@ import signal
import subprocess # nosec B404
import sys
from pathlib import Path
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, cast
import typer
from rich.console import Console
@@ -91,7 +91,7 @@ def _is_running(match_fn) -> int | None:
try:
cmd = proc.info["cmdline"]
if cmd and match_fn(cmd):
return proc.info["pid"]
return cast(int, proc.info["pid"])
except (psutil.NoSuchProcess, psutil.AccessDenied):
continue
return None

View File

@@ -25,7 +25,7 @@ from __future__ import annotations
from collections import Counter
from dataclasses import dataclass
from typing import Any, Sequence
from typing import Any, Sequence, cast
from decnet.correlation.attribution import _thresholds as _T
@@ -250,7 +250,7 @@ def _coef_of_variation(values: Sequence[float], mean: float) -> float:
stdev = variance ** 0.5
if mean == 0:
return 0.0 if stdev == 0 else 1e9
return stdev / abs(mean)
return cast(float, stdev / abs(mean))
def _safe_float(value: Any) -> float:

View File

@@ -9,6 +9,7 @@ import shutil
import subprocess # nosec B404
import time
from pathlib import Path
from typing import cast
import anyio
import docker
@@ -853,7 +854,7 @@ async def _resolve_swarm_host(repo, host_uuid: str) -> dict:
raise ValueError(
f"topology pinned to unknown swarm host {host_uuid!r}"
)
return host
return cast(dict, host)
async def _deploy_on_agent(repo, topology_id: str, hydrated: dict) -> None:

View File

@@ -24,7 +24,7 @@ from __future__ import annotations
import subprocess # nosec B404
from pathlib import Path
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, cast
import anyio
@@ -217,7 +217,7 @@ async def _topology_decky(
cfg = d.get("decky_config") or {}
name = cfg.get("name") or d.get("name")
if name == decky_name:
return d
return cast(dict[str, Any], d)
raise ServiceNotFoundError(
f"decky {decky_name!r} is not in topology {topology_id!r}"
)
@@ -343,7 +343,7 @@ def _fleet_state_or_raise() -> tuple[Any, Path]:
raise ServiceMutationError(
"no fleet state on disk — run `decnet up` first"
)
return state
return cast(tuple[Any, Path], state)
def _fleet_find_decky(config: Any, decky_name: str) -> Any:

View File

@@ -25,7 +25,7 @@ Design notes
from __future__ import annotations
import json
from typing import Any, Awaitable, Callable, Optional
from typing import Any, Awaitable, Callable, Optional, cast
from decnet.logging import get_logger
from decnet.topology.allocator import IPAllocator, reserved_subnets, SubnetAllocator
@@ -301,7 +301,7 @@ async def _live_topology_or_none(
topology_id,
)
return None
return topology
return cast(dict[str, Any], topology)
async def _rerender_compose(repo: Any, topology_id: str) -> None:

View File

@@ -12,6 +12,7 @@ Handles:
import os
import subprocess # nosec B404
from ipaddress import IPv4Address, IPv4Interface, IPv4Network
from typing import cast
import docker
@@ -38,7 +39,7 @@ def detect_interface() -> str:
for line in result.stdout.splitlines():
parts = line.split()
if "dev" in parts:
return parts[parts.index("dev") + 1]
return cast(str, parts[parts.index("dev") + 1])
raise RuntimeError("Could not auto-detect network interface. Use --interface.")
@@ -79,7 +80,7 @@ def get_host_ip(interface: str) -> str:
for line in result.stdout.splitlines():
line = line.strip()
if line.startswith("inet ") and not line.startswith("inet6"):
return line.split()[1].split("/")[0]
return cast(str, line.split()[1].split("/")[0])
raise RuntimeError(f"Could not determine host IP for interface {interface}.")
@@ -297,7 +298,7 @@ def create_bridge_network(
pools = (net.attrs.get("IPAM") or {}).get("Config") or []
cur = pools[0] if pools else {}
if net.attrs.get("Driver") == "bridge" and cur.get("Subnet") == subnet:
return net.id
return cast(str, net.id)
for cid in (net.attrs.get("Containers") or {}):
try:
net.disconnect(cid, force=True)
@@ -332,7 +333,7 @@ def create_bridge_network(
pool_configs=[docker.types.IPAMPool(subnet=subnet)],
),
)
return net.id
return cast(str, net.id)
def remove_bridge_network(client: docker.DockerClient, name: str) -> None:
@@ -480,7 +481,7 @@ def get_container_pid(container_name: str) -> int:
pid = container.attrs["State"]["Pid"]
if not pid:
raise LookupError(f"container {container_name!r} is not running (PID=0)")
return pid
return cast(int, pid)
def get_container_veth(container_name: str) -> str:
@@ -507,7 +508,7 @@ def get_container_veth(container_name: str) -> str:
if line.startswith(f"{peer_index}:"):
# Format: "42: veth3a4b5c@if41: <BROADCAST,...>"
iface = line.split(":")[1].strip().split("@")[0]
return iface
return cast(str, iface)
raise LookupError(
f"no host veth found for container {container_name!r} (peer ifindex {peer_index})"
)

View File

@@ -20,7 +20,7 @@ import json
import secrets
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Optional, Sequence
from typing import Any, Optional, Sequence, cast
from decnet.realism import personas_pool
from decnet.realism.personas import EmailPersona, parse_personas
@@ -256,7 +256,7 @@ def _persona_by_name(
for decky in enriched:
for persona in decky.get("_realism_personas") or []:
if persona.name == name:
return persona
return cast(EmailPersona, persona)
return None

View File

@@ -18,7 +18,7 @@ from __future__ import annotations
import hashlib
import socket
import ssl
from typing import Any
from typing import Any, cast
from cryptography import x509
from cryptography.hazmat.backends import default_backend
@@ -42,7 +42,7 @@ def _iso_utc(dt: Any) -> str:
and ``not_valid_before_utc`` (timezone-aware) — prefer the latter
when available so we always emit explicit-Z ISO strings.
"""
return dt.strftime("%Y-%m-%dT%H:%M:%SZ")
return cast(str, dt.strftime("%Y-%m-%dT%H:%M:%SZ"))
def _extract_sans(cert: x509.Certificate) -> list[str]:

View File

@@ -19,7 +19,7 @@ graph cycle-free and the env contract auditable in one place.
from __future__ import annotations
import os
from typing import Any
from typing import Any, cast
from decnet.realism.llm.base import LLMBackend
@@ -45,7 +45,7 @@ def get_llm(*, model: str | None = None, **kwargs: Any) -> LLMBackend:
from decnet.realism.llm.config import get_cached_backend
cached = get_cached_backend()
if cached is not None:
return cached
return cast(LLMBackend, cached)
backend_key = os.environ.get("DECNET_REALISM_LLM", "ollama").lower()

View File

@@ -17,7 +17,7 @@ import logging
import sqlite3
import urllib.request
from datetime import datetime, timezone
from typing import Optional
from typing import Any, Optional, cast
from decnet.rpki import cache as _cache
from decnet.rpki.base import RpkiResult, RpkiStatus, Validator
@@ -68,13 +68,13 @@ class RipeStatValidator(Validator):
)
raw = data.get("data", {}).get("status", "unknown")
if raw in ("valid", "invalid", "not-found"):
return raw
return cast(RpkiStatus, raw)
return "unknown"
def _fetch(self, url: str) -> dict:
def _fetch(self, url: str) -> dict[Any, Any]:
req = urllib.request.Request(url, headers={"User-Agent": _UA})
with urllib.request.urlopen(req, timeout=_TIMEOUT_S) as resp: # nosec B310 — HTTPS RIPE STAT base URL only; IP/ASN components are validated upstream
return json.loads(resp.read())
return cast(dict[Any, Any], json.loads(resp.read()))
def _store(
self, ip: str, asn: int, status: str, prefix: Optional[str]

View File

@@ -14,7 +14,7 @@ import hashlib
import struct
import time
from collections import deque
from typing import Any, Callable
from typing import Any, Callable, cast
from decnet.logging import get_logger
from decnet.prober.tcpfp import _extract_options_order
@@ -1058,18 +1058,18 @@ class SnifferEngine:
def _dedup_key_for(self, event_type: str, fields: dict[str, Any]) -> str:
if event_type == "tls_client_hello":
return fields.get("ja3", "") + "|" + fields.get("ja4", "")
return cast(str, fields.get("ja3", "") + "|" + fields.get("ja4", ""))
if event_type == "tls_session":
return (fields.get("ja3", "") + "|" + fields.get("ja3s", "") +
return cast(str, fields.get("ja3", "") + "|" + fields.get("ja3s", "") +
"|" + fields.get("ja4", "") + "|" + fields.get("ja4s", ""))
if event_type == "tls_certificate":
return fields.get("subject_cn", "") + "|" + fields.get("issuer", "")
return cast(str, fields.get("subject_cn", "") + "|" + fields.get("issuer", ""))
if event_type == "tcp_syn_fingerprint":
# Dedupe per (OS signature, options layout, sequence-pattern
# classification). Including ipid_class/isn_class lets each
# transition (unknown → random/incremental/zero/constant) emit
# exactly one fresh event as samples accumulate.
return (
return cast(str,
fields.get("os_guess", "")
+ "|" + fields.get("options_sig", "")
+ "|" + fields.get("ipid_class", "")
@@ -1080,14 +1080,14 @@ class SnifferEngine:
# excluded so a port scanner rotating source ports only produces
# one timing event per dedup window. Behavior cadence doesn't
# need per-ephemeral-port fidelity.
return fields.get("dst_ip", "") + "|" + fields.get("dst_port", "")
return cast(str, fields.get("dst_ip", "") + "|" + fields.get("dst_port", ""))
if event_type == "quic_client_hello":
return fields.get("src_ip", "") + "|" + fields.get("ja4_quic", "")
return cast(str, fields.get("src_ip", "") + "|" + fields.get("ja4_quic", ""))
if event_type == "http_request_fingerprint":
return fields.get("src_ip", "") + "|" + fields.get("ja4h", "")
return cast(str, fields.get("src_ip", "") + "|" + fields.get("ja4h", ""))
if event_type in ("http2_settings", "http3_settings"):
return fields.get("src_ip", "") + "|" + str(fields.get("settings_hash", ""))
return fields.get("mechanisms", fields.get("resumption", ""))
return cast(str, fields.get("src_ip", "") + "|" + str(fields.get("settings_hash", "")))
return cast(str, fields.get("mechanisms", fields.get("resumption", "")))
def _is_duplicate(self, event_type: str, fields: dict[str, Any]) -> bool:
if self._dedup_ttl <= 0:

View File

@@ -23,7 +23,7 @@ import pathlib
import socket
import ssl
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any, Optional, cast
import httpx
@@ -229,12 +229,12 @@ class AgentClient:
async def health(self) -> dict[str, Any]:
resp = await self._require_client().get("/health")
resp.raise_for_status()
return resp.json()
return cast(dict[str, Any], resp.json())
async def status(self) -> dict[str, Any]:
resp = await self._require_client().get("/status")
resp.raise_for_status()
return resp.json()
return cast(dict[str, Any], resp.json())
async def deploy(
self,
@@ -254,7 +254,7 @@ class AgentClient:
# need for the long deploy timeout here.
resp = await self._require_client().post("/deploy", json=body)
resp.raise_for_status()
return resp.json()
return cast(dict[str, Any], resp.json())
async def mutate(
self,
@@ -271,20 +271,20 @@ class AgentClient:
# Worker /mutate is async (202): control-timeout is right.
resp = await self._require_client().post("/mutate", json=body)
resp.raise_for_status()
return resp.json()
return cast(dict[str, Any], resp.json())
async def teardown(self, decky_id: Optional[str] = None) -> dict[str, Any]:
resp = await self._require_client().post(
"/teardown", json={"decky_id": decky_id}
)
resp.raise_for_status()
return resp.json()
return cast(dict[str, Any], resp.json())
async def self_destruct(self) -> dict[str, Any]:
"""Trigger the worker to stop services and wipe its install."""
resp = await self._require_client().post("/self-destruct")
resp.raise_for_status()
return resp.json()
return cast(dict[str, Any], resp.json())
# ------------------------------------------------------------ topology
@@ -309,7 +309,7 @@ class AgentClient:
finally:
self._require_client().timeout = old
resp.raise_for_status()
return resp.json()
return cast(dict[str, Any], resp.json())
async def teardown_topology(self, topology_id: str) -> dict[str, Any]:
"""Ask the agent to dismantle the named topology."""
@@ -323,13 +323,13 @@ class AgentClient:
finally:
self._require_client().timeout = old
resp.raise_for_status()
return resp.json()
return cast(dict[str, Any], resp.json())
async def get_topology_state(self) -> dict[str, Any]:
"""Snapshot of the agent's applied topology + live docker state."""
resp = await self._require_client().get("/topology/state")
resp.raise_for_status()
return resp.json()
return cast(dict[str, Any], resp.json())
# -------------------------------------------------------------- diagnostics

View File

@@ -17,7 +17,7 @@ import asyncio
import hashlib
import socket
import ssl
from typing import Any, Optional
from typing import Any, Optional, cast
import httpx
@@ -159,12 +159,12 @@ class UpdaterClient:
async def health(self) -> dict[str, Any]:
r = await self._require().get("/health")
r.raise_for_status()
return r.json()
return cast(dict[str, Any], r.json())
async def releases(self) -> dict[str, Any]:
r = await self._require().get("/releases")
r.raise_for_status()
return r.json()
return cast(dict[str, Any], r.json())
async def update(self, tarball: bytes, sha: str = "") -> httpx.Response:
"""POST /update. Returns the Response so the caller can distinguish

View File

@@ -147,7 +147,7 @@ _MONGO_SET_NAME = os.environ.get("MONGO_REPL_SET", "") # empty = standalone
def _new_objectid() -> bytes:
"""12-byte BSON ObjectId — fresh per call."""
return _seed.fresh_bytes(12)
return cast(bytes, _seed.fresh_bytes(12))
# Minimal BSON helpers
def _bson_str(key: str, val: str) -> bytes:

View File

@@ -205,7 +205,7 @@ def _seed_dict_to_rfc822(entry: dict) -> str | None:
date = str(entry.get("date") or "")
body = entry["body"]
if "\r\n\r\n" in body or "\n\n" in body:
return body # already a full RFC 822 message
return cast(str, body) # already a full RFC 822 message
return (
f"Date: {date}\r\n"
f"From: {from_name} <{from_addr}>\r\n"

View File

@@ -22,6 +22,7 @@ from __future__ import annotations
import asyncio
import os
import struct
from typing import cast
import instance_seed
from ntlmssp import find_ntlmssp, parse_type3
@@ -184,7 +185,7 @@ def _negotiate_response(message_id: int) -> bytes:
+ struct.pack("<H", 0) # SecurityBufferLength
+ struct.pack("<I", 0) # Reserved2
)
return _smb2_header(SMB2_NEGOTIATE, STATUS_SUCCESS, message_id) + body
return cast(bytes, _smb2_header(SMB2_NEGOTIATE, STATUS_SUCCESS, message_id) + body)
def _session_setup_response(message_id: int, session_id: int, sec_blob: bytes, status: int) -> bytes:

View File

@@ -28,7 +28,7 @@ import hashlib
import os
import struct
import time
from typing import Any
from typing import Any, cast
from scapy.layers.inet import IP, TCP
from scapy.sendrecv import sniff
@@ -841,14 +841,14 @@ _dedup_last_cleanup: float = 0.0
def _dedup_key_for(event_type: str, fields: dict[str, Any]) -> str:
"""Build a dedup fingerprint from the most significant fields."""
if event_type == "tls_client_hello":
return fields.get("ja3", "") + "|" + fields.get("ja4", "")
return cast(str, fields.get("ja3", "") + "|" + fields.get("ja4", ""))
if event_type == "tls_session":
return (fields.get("ja3", "") + "|" + fields.get("ja3s", "") +
return cast(str, fields.get("ja3", "") + "|" + fields.get("ja3s", "") +
"|" + fields.get("ja4", "") + "|" + fields.get("ja4s", ""))
if event_type == "tls_certificate":
return fields.get("subject_cn", "") + "|" + fields.get("issuer", "")
return cast(str, fields.get("subject_cn", "") + "|" + fields.get("issuer", ""))
# tls_resumption or unknown — dedup on mechanisms
return fields.get("mechanisms", fields.get("resumption", ""))
return cast(str, fields.get("mechanisms", fields.get("resumption", "")))
def _is_duplicate(event_type: str, fields: dict[str, Any]) -> bool:

View File

@@ -39,7 +39,7 @@ from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from threading import Lock
from typing import TYPE_CHECKING, Final
from typing import TYPE_CHECKING, Final, cast
if TYPE_CHECKING:
from mitreattack.stix20 import MitreAttackData
@@ -310,7 +310,7 @@ def subtechnique_parent_name(technique_id: str) -> str | None:
)
if not parents:
return None
return parents[0]["object"].name
return cast(str, parents[0]["object"].name)
def is_subtechnique(technique_id: str) -> bool:
@@ -335,7 +335,7 @@ def tactic_id_for_short_name(short_name: str) -> str | None:
return None
for ref in obj.get("external_references", []):
if ref.get("source_name") == "mitre-attack":
return ref.get("external_id")
return cast(str | None, ref.get("external_id"))
return None

View File

@@ -16,7 +16,7 @@ build_fleet_misp_collection → dict ({"response": [event, ...]})
from __future__ import annotations
import json
from typing import Any
from typing import Any, cast
from misp_stix_converter import ExternalSTIX2toMISPParser
@@ -35,7 +35,7 @@ def _parse_bundle(bundle: Any) -> dict[str, Any]:
event = parser.misp_events
if event is None:
return {}
return json.loads(event.to_json())
return cast(dict[str, Any], json.loads(event.to_json()))
def build_attacker_misp_event(

View File

@@ -8,7 +8,7 @@ import os
import traceback
import uuid
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Optional
from typing import Any, AsyncGenerator, Optional, cast
from fastapi import FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
@@ -306,7 +306,7 @@ class _ContentTypeMiddleware(BaseHTTPMiddleware):
status_code=415,
media_type="text/plain",
)
return await call_next(request)
return cast(StarletteResponse, await call_next(request))
app.add_middleware(_ContentTypeMiddleware)

View File

@@ -6,7 +6,7 @@ Repository factory — selects a :class:`BaseRepository` implementation based on
from __future__ import annotations
import os
from typing import Any
from typing import Any, cast
from decnet.web.db.repository import BaseRepository
@@ -32,4 +32,4 @@ def get_repository(**kwargs: Any) -> BaseRepository:
raise ValueError(f"Unsupported database type: {db_type}")
from decnet.telemetry import wrap_repository
return wrap_repository(repo)
return cast(BaseRepository, wrap_repository(repo))

View File

@@ -18,7 +18,7 @@ import os
import orjson
import uuid
from typing import Any, Optional, List
from typing import Any, Optional, List, cast
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
@@ -183,7 +183,7 @@ class SQLModelRepository(
result = await session.execute(statement)
state = result.scalar_one_or_none()
if state:
return json.loads(state.value)
return cast(dict[str, Any], json.loads(state.value))
return None
async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401

View File

@@ -10,7 +10,7 @@ from __future__ import annotations
import uuid as _uuid
from datetime import datetime, timezone
from typing import Any, Optional
from typing import Any, Optional, cast
from sqlalchemy import desc, or_, select
from sqlmodel import col
@@ -44,7 +44,7 @@ class AttackerIntelMixin(_MixinBase):
row_uuid = _uuid.uuid4().hex
session.add(AttackerIntel(uuid=row_uuid, **data))
await session.commit()
return row_uuid
return cast(str, row_uuid)
async def get_attacker_intel_row_by_uuid(
self,
@@ -54,7 +54,7 @@ class AttackerIntelMixin(_MixinBase):
result = await session.execute(
select(AttackerIntel).where(AttackerIntel.attacker_uuid == uuid)
)
return result.scalar_one_or_none()
return cast(Optional[AttackerIntel], result.scalar_one_or_none())
async def get_attacker_intel_by_uuid(
self,
@@ -67,7 +67,7 @@ class AttackerIntelMixin(_MixinBase):
row = result.scalar_one_or_none()
if not row:
return None
return row.model_dump(mode="json")
return cast(dict[str, Any], row.model_dump(mode="json"))
async def get_unenriched_attackers(
self, limit: int = 100,

View File

@@ -10,7 +10,7 @@ from __future__ import annotations
import json
import uuid as _uuid
from typing import Any, List, Optional
from typing import Any, List, Optional, cast
from sqlalchemy import desc, func, outerjoin, select
from sqlmodel import col
@@ -47,7 +47,7 @@ class AttackersCoreMixin(_MixinBase):
data = {**data, "uuid": row_uuid}
session.add(Attacker(**data))
await session.commit()
return row_uuid
return cast(str, row_uuid)
async def get_attacker_uuid_by_ip(self, ip: str) -> Optional[str]:
"""Return the :class:`Attacker` UUID for *ip*, or ``None``.
@@ -61,7 +61,7 @@ class AttackersCoreMixin(_MixinBase):
result = await session.execute(
select(col(Attacker.uuid)).where(Attacker.ip == ip)
)
return result.scalar_one_or_none()
return cast(Optional[str], result.scalar_one_or_none())
async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
async with self._session() as session:

View File

@@ -21,7 +21,7 @@ from __future__ import annotations
import uuid as _uuid
from datetime import datetime, timezone
from typing import Any, Optional
from typing import Any, Optional, cast
from sqlalchemy import func, select
from sqlmodel import col
@@ -66,7 +66,7 @@ class AttributionMixin(_MixinBase):
if attacker_row is None:
return None
if attacker_row.identity_id:
return attacker_row.identity_id
return cast(str, attacker_row.identity_id)
new_uuid = _uuid.uuid4().hex
now = datetime.now(timezone.utc)
session.add(

View File

@@ -9,7 +9,7 @@ the reads.
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, Optional
from typing import Any, Optional, cast
from sqlalchemy import desc, func, select, update
from sqlmodel import col
@@ -36,9 +36,9 @@ class CampaignsMixin(_MixinBase):
if campaign is None:
return None
if campaign.merged_into_uuid is None:
return campaign.model_dump(mode="json")
return cast(dict[str, Any], campaign.model_dump(mode="json"))
current_uuid = campaign.merged_into_uuid
return campaign.model_dump(mode="json")
return cast(dict[str, Any], campaign.model_dump(mode="json"))
async def list_campaigns(
self, limit: int = 50, offset: int = 0,

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Any, Optional
from typing import Any, Optional, cast
from sqlalchemy import desc, func, select, update
@@ -26,7 +26,7 @@ class CanaryMixin(_MixinBase):
)
row = existing.scalar_one_or_none()
if row:
return row.model_dump(mode="json")
return cast(dict[str, Any], row.model_dump(mode="json"))
row = CanaryBlob(**data)
session.add(row)
await session.commit()
@@ -155,7 +155,7 @@ class CanaryMixin(_MixinBase):
.values(state=state, last_error=last_error)
)
await session.commit()
return result.rowcount > 0
return cast(bool, result.rowcount > 0)
async def record_canary_trigger(self, data: dict[str, Any]) -> str:
# Persist the trigger row + bump the token's counters in the same
@@ -204,4 +204,4 @@ class CanaryMixin(_MixinBase):
.values(attacker_id=attacker_id)
)
await session.commit()
return result.rowcount > 0
return cast(bool, result.rowcount > 0)

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Any, List, Optional
from typing import Any, List, Optional, cast
from sqlalchemy import desc, func, or_, select, update
from sqlalchemy.exc import IntegrityError
@@ -67,7 +67,7 @@ class CredentialsCoreMixin(_MixinBase):
existing.outcome = payload["outcome"]
session.add(existing)
await session.commit()
return existing.id
return cast(int, existing.id)
row = Credential(
attacker_ip=payload["attacker_ip"],
decky_name=payload["decky_name"],
@@ -103,7 +103,7 @@ class CredentialsCoreMixin(_MixinBase):
existing2.outcome = payload["outcome"]
session2.add(existing2)
await session2.commit()
return existing2.id
return cast(int, existing2.id)
await session.refresh(row)
return row.id # type: ignore[return-value]

View File

@@ -7,7 +7,7 @@ from __future__ import annotations
import json
import uuid as _uuid
from datetime import datetime, timezone
from typing import Any, List, Optional
from typing import Any, List, Optional, cast
from sqlalchemy import desc, func, select
from sqlmodel import col
@@ -137,7 +137,7 @@ class CredentialReuseMixin(_MixinBase):
d = existing.model_dump(mode="json")
d["inserted"] = False
d["changed"] = changed
return d
return cast(dict[str, Any], d)
async def find_credential_reuse_candidates(
self, min_targets: int = 2
@@ -236,7 +236,7 @@ class CredentialReuseMixin(_MixinBase):
except (json.JSONDecodeError, TypeError):
d[key] = []
await self._enrich_with_secret(session, [d])
return d
return cast(dict[str, Any], d)
@staticmethod
async def _enrich_with_secret(

View File

@@ -9,7 +9,7 @@ the reads.
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, Optional
from typing import Any, Optional, cast
from sqlalchemy import desc, func, select, update
from sqlmodel import col
@@ -42,10 +42,10 @@ class IdentitiesMixin(_MixinBase):
if identity is None:
return None
if identity.merged_into_uuid is None:
return identity.model_dump(mode="json")
return cast(dict[str, Any], identity.model_dump(mode="json"))
current_uuid = identity.merged_into_uuid
# Hit the hop cap — surface what we have rather than recurse.
return identity.model_dump(mode="json")
return cast(dict[str, Any], identity.model_dump(mode="json"))
async def list_identities(
self, limit: int = 50, offset: int = 0,

View File

@@ -21,7 +21,7 @@ not validate values — that happens at construction time by the BEHAVE
from __future__ import annotations
import uuid as _uuid
from typing import Any, Optional
from typing import Any, Optional, cast
from sqlalchemy import desc, func, select
from sqlmodel import col
@@ -85,7 +85,7 @@ class ObservationsMixin(_MixinBase):
row_data = {**data, "id": row_id}
session.add(ObservationRow(**row_data))
await session.commit()
return row_id
return cast(str, row_id)
async def latest_observation_per_primitive(
self, attacker_uuid: str,
@@ -178,7 +178,7 @@ class ObservationsMixin(_MixinBase):
row = (await session.execute(stmt)).scalar_one_or_none()
if not row:
return None
return row.model_dump(mode="json")
return cast(dict[str, Any], row.model_dump(mode="json"))
async def observations_for_identity_primitive(
self, identity_uuid: str, primitive: str,

View File

@@ -13,7 +13,7 @@ caller is upgrading ``False/None`` to ``True``.
from __future__ import annotations
from datetime import datetime, timezone
from typing import Optional
from typing import Optional, cast
from sqlalchemy import select
@@ -106,4 +106,4 @@ class ObservedAttachmentsMixin(_MixinBase):
row.mal_hash_match_at = now
session.add(row)
await session.commit()
return row.uuid
return cast(str, row.uuid)

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Any, Optional
from typing import Any, Optional, cast
from sqlalchemy import desc, func, select, update
@@ -92,7 +92,7 @@ class RealismMixin(_MixinBase):
row = result.scalars().first()
if row is None:
return None
return row.model_dump(mode="json")
return cast(dict[str, Any], row.model_dump(mode="json"))
async def get_realism_config(
self, key: str,
@@ -103,7 +103,7 @@ class RealismMixin(_MixinBase):
row = result.scalars().first()
if row is None:
return None
return row.model_dump(mode="json")
return cast(dict[str, Any], row.model_dump(mode="json"))
async def set_realism_config(
self, key: str, value: str,
@@ -157,4 +157,4 @@ class RealismMixin(_MixinBase):
row = result.scalars().first()
if row is None:
return None
return row.model_dump(mode="json")
return cast(dict[str, Any], row.model_dump(mode="json"))

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import json
import uuid
from datetime import datetime, timezone
from typing import Any, Optional
from typing import Any, Optional, cast
from sqlalchemy import select
@@ -47,7 +47,7 @@ class TarpitMixin(_MixinBase):
return None
d = row.model_dump(mode="json")
d["ports"] = json.loads(d["ports"])
return d
return cast(dict[str, Any], d)
async def delete_tarpit_rule(self, decky_name: str) -> bool:
async with self._session() as session:

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, Optional
from typing import Any, Optional, cast
import orjson
from sqlalchemy import asc, desc, select, text
@@ -102,7 +102,7 @@ class TopologyMutationsMixin(_MixinBase):
_ = now
if row is None:
return None
return row.model_dump(mode="json")
return cast(dict[str, Any], row.model_dump(mode="json"))
async def mark_mutation_applied(self, mutation_id: str) -> None:
async with self._session() as session:

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, Optional
from typing import Any, Optional, cast
from sqlalchemy import select, update
from sqlmodel import col
@@ -65,7 +65,7 @@ class WebhooksMixin(_MixinBase):
.values(**patch)
)
await session.commit()
return result.rowcount > 0
return cast(bool, result.rowcount > 0)
async def delete_webhook_subscription(self, uuid: str) -> bool:
async with self._session() as session: