fix(types): P2 — wire _MixinBase + col() across sqlmodel_repo; suppress pydantic/SQLModel column typing false positives

- Add _MixinBase abstract class to _helpers.py: declares _session(),
  _deserialize_attacker(), _assert_pending(), _check_and_bump_version(),
  and list_running_topology_deckies() so mypy can see cross-mixin contracts
- Add _require(val, msg) helper for narrowing T | None → T
- Inherit _MixinBase in all 26 leaf mixin classes
- Wrap SQLAlchemy column method calls (.is_(), .like(), .notin_(), .in_(),
  .contains()) with col() from sqlmodel — fixes attr-defined false positives
  caused by pydantic plugin typing class-level fields as Python value types
- Wrap select(Model.field) with select(col(Model.field)) for column projections
- Add pyproject.toml [[tool.mypy.overrides]] to disable arg-type in
  sqlmodel_repo.*: pydantic plugin resolves .where(Model.field == v) as
  where(bool), a false positive; call-arg still catches real argument errors
- Remove 9 stale # type: ignore comments (logging, helpers, credentials)
- Fix telemetry.py traced() overload no-redef + misc
- Fix logs.py datetime/str operator and nullable PK comparison with col()
- sqlmodel_repo/ now has 0 mypy errors
This commit is contained in:
2026-05-01 00:49:18 -04:00
parent d777a1c4e0
commit 614780f144
30 changed files with 221 additions and 100 deletions

View File

@@ -28,7 +28,7 @@ class _ComponentFilter(logging.Filter):
self.component = component
def filter(self, record: logging.LogRecord) -> bool:
record.decnet_component = self.component # type: ignore[attr-defined]
record.decnet_component = self.component
return True
@@ -49,14 +49,14 @@ class _TraceContextFilter(logging.Filter):
span = trace.get_current_span()
ctx = span.get_span_context()
if ctx and ctx.trace_id:
record.otel_trace_id = format(ctx.trace_id, "032x") # type: ignore[attr-defined]
record.otel_span_id = format(ctx.span_id, "016x") # type: ignore[attr-defined]
record.otel_trace_id = format(ctx.trace_id, "032x")
record.otel_span_id = format(ctx.span_id, "016x")
else:
record.otel_trace_id = "0" # type: ignore[attr-defined]
record.otel_span_id = "0" # type: ignore[attr-defined]
record.otel_trace_id = "0"
record.otel_span_id = "0"
except Exception:
record.otel_trace_id = "0" # type: ignore[attr-defined]
record.otel_span_id = "0" # type: ignore[attr-defined]
record.otel_trace_id = "0"
record.otel_span_id = "0"
return True

View File

@@ -138,7 +138,7 @@ def traced(fn: F) -> F: ...
def traced(name: str) -> Callable[[F], F]: ...
def traced(fn: Any = None, *, name: str | None = None) -> Any:
def traced(fn: Any = None, *, name: str | None = None) -> Any: # type: ignore[misc]
"""Decorator that wraps a function in an OTEL span.
Usage::
@@ -168,9 +168,9 @@ def traced(fn: Any = None, *, name: str | None = None) -> Any:
# Called as @traced (no arguments)
return _wrap(fn, None)
# Fallback: @traced() with no args
def decorator(f: F) -> F:
def _fallback_decorator(f: F) -> F:
return _wrap(f, name)
return decorator
return _fallback_decorator
def _wrap(fn: F, span_name: str | None) -> F:

View File

@@ -12,14 +12,60 @@ from __future__ import annotations
import asyncio
import json
from abc import abstractmethod
from contextlib import asynccontextmanager
from typing import Any
from typing import Any, Optional, TypeVar
import orjson
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from decnet.logging import get_logger
T = TypeVar("T")
def _require(val: T | None, msg: str) -> T:
"""Narrow ``X | None`` to ``X``, raising ``ValueError`` if None."""
if val is None:
raise ValueError(msg)
return val
class _MixinBase:
"""Typing base for all repo mixins.
Declares the contract that ``SQLModelRepository`` satisfies at runtime
via MRO composition. Without this, mypy checks each mixin in isolation
and cannot see ``_session`` or cross-mixin helpers.
"""
@abstractmethod
def _session(self):
"""Return a cancellation-safe async session context manager."""
raise NotImplementedError
@staticmethod
def _deserialize_attacker(d: dict[str, Any]) -> dict[str, Any]:
"""Stub — concrete impl on AttackersCoreMixin via MRO."""
return d
async def _assert_pending(self, session: AsyncSession, topology_id: str) -> None:
"""Stub — concrete impl on TopologyCoreMixin via MRO."""
raise NotImplementedError
async def _check_and_bump_version(
self,
session: AsyncSession,
topology_id: str,
expected_version: Optional[int],
) -> None:
"""Stub — concrete impl on TopologyCoreMixin via MRO."""
raise NotImplementedError
async def list_running_topology_deckies(self) -> list[dict[str, Any]]:
"""Stub — concrete impl on TopologyDeckiesMixin via MRO."""
raise NotImplementedError
_log = get_logger("db.pool")
# Hold strong refs to in-flight cleanup tasks so they aren't GC'd mid-run.
@@ -66,7 +112,7 @@ def _detach_close(session: AsyncSession) -> None:
task = loop.create_task(_cleanup())
_cleanup_tasks.add(task)
# Consume any exception to silence "Task exception was never retrieved".
task.add_done_callback(lambda t: (_cleanup_tasks.discard(t), t.exception()))
task.add_done_callback(lambda t: (_cleanup_tasks.discard(t), t.exception())) # type: ignore[func-returns-value]
@asynccontextmanager

View File

@@ -13,11 +13,14 @@ from datetime import datetime, timezone
from typing import Any, Optional
from sqlalchemy import desc, or_, select
from sqlmodel import col
from decnet.web.db.models import Attacker, AttackerIntel
class AttackerIntelMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class AttackerIntelMixin(_MixinBase):
"""Mixin: methods composed onto ``SQLModelRepository``.
Expects ``self._session()`` from the base.
@@ -82,13 +85,13 @@ class AttackerIntelMixin:
now = datetime.now(timezone.utc)
async with self._session() as session:
stmt = (
select(Attacker.uuid, Attacker.ip)
select(col(Attacker.uuid), col(Attacker.ip))
.outerjoin(
AttackerIntel, AttackerIntel.attacker_uuid == Attacker.uuid,
)
.where(
or_(
AttackerIntel.uuid.is_(None),
col(AttackerIntel.uuid).is_(None),
AttackerIntel.expires_at < now,
)
)

View File

@@ -12,11 +12,14 @@ import uuid as _uuid
from typing import Any, List, Optional
from sqlalchemy import desc, func, outerjoin, select
from sqlmodel import col
from decnet.web.db.models import Attacker, AttackerIntel
class AttackersCoreMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class AttackersCoreMixin(_MixinBase):
@staticmethod
def _deserialize_attacker(d: dict[str, Any]) -> dict[str, Any]:
for key in ("services", "deckies", "fingerprints", "commands"):
@@ -63,16 +66,16 @@ class AttackersCoreMixin:
sort_by: str = "recent",
service: Optional[str] = None,
) -> List[dict[str, Any]]:
order = {
order: Any = {
"active": desc(Attacker.event_count),
"traversals": desc(Attacker.is_traversal),
}.get(sort_by, desc(Attacker.last_seen))
statement = select(Attacker).order_by(order).offset(offset).limit(limit)
if search:
statement = statement.where(Attacker.ip.like(f"%{search}%"))
statement = statement.where(col(Attacker.ip).like(f"%{search}%"))
if service:
statement = statement.where(Attacker.services.like(f'%"{service}"%'))
statement = statement.where(col(Attacker.services).like(f'%"{service}"%'))
async with self._session() as session:
result = await session.execute(statement)
@@ -121,9 +124,9 @@ class AttackersCoreMixin:
) -> int:
statement = select(func.count()).select_from(Attacker)
if search:
statement = statement.where(Attacker.ip.like(f"%{search}%"))
statement = statement.where(col(Attacker.ip).like(f"%{search}%"))
if service:
statement = statement.where(Attacker.services.like(f'%"{service}"%'))
statement = statement.where(col(Attacker.services).like(f'%"{service}"%'))
async with self._session() as session:
result = await session.execute(statement)

View File

@@ -10,11 +10,14 @@ import json
from typing import Any, Optional
from sqlalchemy import desc, func, select
from sqlmodel import col
from decnet.web.db.models import Attacker, Bounty, Log
class AttackerActivityMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class AttackerActivityMixin(_MixinBase):
async def get_attacker_commands(
self,
uuid: str,
@@ -24,7 +27,7 @@ class AttackerActivityMixin:
) -> dict[str, Any]:
async with self._session() as session:
result = await session.execute(
select(Attacker.commands).where(Attacker.uuid == uuid)
select(col(Attacker.commands)).where(Attacker.uuid == uuid)
)
raw = result.scalar_one_or_none()
if raw is None:
@@ -52,13 +55,13 @@ class AttackerActivityMixin:
"""
async with self._session() as session:
ip_res = await session.execute(
select(Attacker.ip).where(Attacker.uuid == attacker_uuid)
select(col(Attacker.ip)).where(Attacker.uuid == attacker_uuid)
)
ip = ip_res.scalar_one_or_none()
if not ip:
return []
rows = await session.execute(
select(Log.service, Log.event_type)
select(col(Log.service), col(Log.event_type))
.where(Log.attacker_ip == ip)
.distinct()
)
@@ -75,7 +78,7 @@ class AttackerActivityMixin:
rotation detection."""
async with self._session() as session:
ip_res = await session.execute(
select(Attacker.ip).where(Attacker.uuid == attacker_uuid)
select(col(Attacker.ip)).where(Attacker.uuid == attacker_uuid)
)
ip = ip_res.scalar_one_or_none()
if not ip:
@@ -104,7 +107,7 @@ class AttackerActivityMixin:
"""Cheap COUNT(*) for XFF-rotation detection."""
async with self._session() as session:
ip_res = await session.execute(
select(Attacker.ip).where(Attacker.uuid == attacker_uuid)
select(col(Attacker.ip)).where(Attacker.uuid == attacker_uuid)
)
ip = ip_res.scalar_one_or_none()
if not ip:
@@ -126,7 +129,7 @@ class AttackerActivityMixin:
"""
async with self._session() as session:
ip_res = await session.execute(
select(Attacker.ip).where(Attacker.uuid == uuid)
select(col(Attacker.ip)).where(Attacker.uuid == uuid)
)
ip = ip_res.scalar_one_or_none()
if not ip:
@@ -150,7 +153,7 @@ class AttackerActivityMixin:
"""
async with self._session() as session:
ip_res = await session.execute(
select(Attacker.ip).where(Attacker.uuid == uuid)
select(col(Attacker.ip)).where(Attacker.uuid == uuid)
)
ip = ip_res.scalar_one_or_none()
if not ip:
@@ -176,7 +179,7 @@ class AttackerActivityMixin:
rows = await session.execute(
select(Log)
.where(Log.event_type == "session_recorded")
.where(Log.fields.contains(needle))
.where(col(Log.fields).contains(needle))
.limit(1)
)
row = rows.scalars().first()
@@ -192,7 +195,7 @@ class AttackerActivityMixin:
"""
async with self._session() as session:
ip_res = await session.execute(
select(Attacker.ip).where(Attacker.uuid == uuid)
select(col(Attacker.ip)).where(Attacker.uuid == uuid)
)
ip = ip_res.scalar_one_or_none()
if not ip:

View File

@@ -7,11 +7,14 @@ from datetime import datetime, timezone
from typing import Any, Optional
from sqlalchemy import select
from sqlmodel import col
from decnet.web.db.models import Attacker, AttackerBehavior
class AttackerBehaviorMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class AttackerBehaviorMixin(_MixinBase):
async def upsert_attacker_behavior(
self,
attacker_uuid: str,
@@ -56,9 +59,9 @@ class AttackerBehaviorMixin:
return {}
async with self._session() as session:
result = await session.execute(
select(Attacker.ip, AttackerBehavior)
select(col(Attacker.ip), AttackerBehavior)
.join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid)
.where(Attacker.ip.in_(ips))
.where(col(Attacker.ip).in_(ips))
)
out: dict[str, dict[str, Any]] = {}
for ip, row in result.all():

View File

@@ -9,7 +9,9 @@ from sqlalchemy import select
from decnet.web.db.models import SessionProfile
class SessionProfilesMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class SessionProfilesMixin(_MixinBase):
async def upsert_session_profile(
self,
sid: str,

View File

@@ -10,7 +10,9 @@ from sqlalchemy import desc, func, select
from decnet.web.db.models import SmtpTarget
class SmtpTargetsMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class SmtpTargetsMixin(_MixinBase):
async def increment_smtp_target(self, attacker_uuid: str, domain: str) -> None:
"""Upsert an (attacker_uuid, domain) pair and bump count + last_seen.

View File

@@ -8,7 +8,9 @@ from sqlalchemy import select, update
from decnet.web.db.models import User
class AuthMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class AuthMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``. Expects ``self._session()``.
``_ensure_admin_user`` stays in the package ``__init__`` so the

View File

@@ -7,12 +7,15 @@ from typing import Any, List, Optional
import orjson
from sqlalchemy import asc, desc, func, or_, select, text
from sqlmodel import col
from sqlmodel.sql.expression import SelectOfScalar
from decnet.web.db.models import Bounty
class BountiesMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class BountiesMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``."""
async def purge_logs_and_bounties(self) -> dict[str, int]:
@@ -40,7 +43,7 @@ class BountiesMixin:
async with self._session() as session:
dup = await session.execute(
select(Bounty.id).where(
select(col(Bounty.id)).where(
Bounty.bounty_type == data.get("bounty_type"),
Bounty.attacker_ip == data.get("attacker_ip"),
Bounty.payload == data.get("payload"),
@@ -63,10 +66,10 @@ class BountiesMixin:
lk = f"%{search}%"
statement = statement.where(
or_(
Bounty.decky.like(lk),
Bounty.service.like(lk),
Bounty.attacker_ip.like(lk),
Bounty.payload.like(lk),
col(Bounty.decky).like(lk),
col(Bounty.service).like(lk),
col(Bounty.attacker_ip).like(lk),
col(Bounty.payload).like(lk),
)
)
return statement
@@ -126,7 +129,7 @@ class BountiesMixin:
async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]:
async with self._session() as session:
result = await session.execute(
select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp))
select(Bounty).where(col(Bounty.attacker_ip).in_(ips)).order_by(asc(Bounty.timestamp))
)
grouped: dict[str, List[dict[str, Any]]] = defaultdict(list)
for item in result.scalars().all():

View File

@@ -11,11 +11,14 @@ from datetime import datetime, timezone
from typing import Any, Optional
from sqlalchemy import desc, func, select, update
from sqlmodel import col
from decnet.web.db.models import AttackerIdentity, Campaign
class CampaignsMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class CampaignsMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``."""
async def get_campaign_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
@@ -41,7 +44,7 @@ class CampaignsMixin:
) -> list[dict[str, Any]]:
statement = (
select(Campaign)
.where(Campaign.merged_into_uuid.is_(None))
.where(col(Campaign.merged_into_uuid).is_(None))
.order_by(desc(Campaign.updated_at))
.offset(offset)
.limit(limit)
@@ -54,7 +57,7 @@ class CampaignsMixin:
statement = (
select(func.count())
.select_from(Campaign)
.where(Campaign.merged_into_uuid.is_(None))
.where(col(Campaign.merged_into_uuid).is_(None))
)
async with self._session() as session:
result = await session.execute(statement)
@@ -91,7 +94,7 @@ class CampaignsMixin:
# graph reads. Narrow on purpose — future denormalised
# projections (commands_by_phase from log mining, decky-set
# aggregates) can land here without churning callers.
statement = select(
statement = select( # type: ignore[call-overload, misc]
AttackerIdentity.uuid,
AttackerIdentity.campaign_id,
AttackerIdentity.merged_into_uuid,

View File

@@ -10,7 +10,9 @@ from sqlalchemy import desc, func, select, update
from decnet.web.db.models import CanaryBlob, CanaryToken, CanaryTrigger
class CanaryMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class CanaryMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``."""
async def upsert_canary_blob(self, data: dict[str, Any]) -> dict[str, Any]:

View File

@@ -6,12 +6,15 @@ from datetime import datetime, timezone
from typing import Any, List, Optional
from sqlalchemy import desc, func, or_, select, update
from sqlmodel import col
from sqlmodel.sql.expression import SelectOfScalar
from decnet.web.db.models import Credential
class CredentialsCoreMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class CredentialsCoreMixin(_MixinBase):
async def upsert_credential(self, data: dict[str, Any]) -> int:
"""Upsert a credential attempt; returns the row id.
@@ -37,7 +40,7 @@ class CredentialsCoreMixin:
Credential.secret_sha256 == payload["secret_sha256"],
# NULL == NULL is False under SQL — branch the predicate.
(Credential.principal == principal) if principal is not None
else Credential.principal.is_(None),
else col(Credential.principal).is_(None),
)
existing = (await session.execute(stmt)).scalar_one_or_none()
now = datetime.now(timezone.utc)
@@ -48,7 +51,7 @@ class CredentialsCoreMixin:
existing.outcome = payload["outcome"]
session.add(existing)
await session.commit()
return existing.id # type: ignore[return-value]
return existing.id
row = Credential(
attacker_ip=payload["attacker_ip"],
decky_name=payload["decky_name"],
@@ -84,10 +87,10 @@ class CredentialsCoreMixin:
lk = f"%{search}%"
statement = statement.where(
or_(
Credential.decky_name.like(lk),
Credential.service.like(lk),
Credential.principal.like(lk),
Credential.secret_printable.like(lk),
col(Credential.decky_name).like(lk),
col(Credential.service).like(lk),
col(Credential.principal).like(lk),
col(Credential.secret_printable).like(lk),
)
)
return statement
@@ -188,7 +191,7 @@ class CredentialsCoreMixin:
update(Credential)
.where(
Credential.attacker_ip == attacker_ip,
Credential.attacker_uuid.is_(None),
col(Credential.attacker_uuid).is_(None),
)
.values(attacker_uuid=attacker_uuid)
)

View File

@@ -9,11 +9,14 @@ from datetime import datetime, timezone
from typing import Any, List, Optional
from sqlalchemy import desc, func, select
from sqlmodel import col
from decnet.web.db.models import Credential, CredentialReuse
class CredentialReuseMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class CredentialReuseMixin(_MixinBase):
@staticmethod
def _merge_unique(existing_json: str, value: Optional[str]) -> tuple[str, bool]:
"""Append ``value`` to a JSON list[str] column if not present.
@@ -117,7 +120,7 @@ class CredentialReuseMixin:
Credential.secret_sha256 == secret_sha256,
Credential.secret_kind == secret_kind,
(Credential.principal == principal) if principal is not None
else Credential.principal.is_(None),
else col(Credential.principal).is_(None),
)
)
target_count = (await session.execute(stmt)).scalar() or 0
@@ -150,7 +153,7 @@ class CredentialReuseMixin:
).label("target_count")
async with self._session() as session:
group_stmt = (
select(
select( # type: ignore[call-overload]
Credential.secret_sha256,
Credential.secret_kind,
Credential.principal,
@@ -171,7 +174,7 @@ class CredentialReuseMixin:
Credential.secret_kind == kind,
(Credential.principal == principal)
if principal is not None
else Credential.principal.is_(None),
else col(Credential.principal).is_(None),
)
rows = (await session.execute(cred_stmt)).scalars().all()
out.append({
@@ -253,13 +256,13 @@ class CredentialReuseMixin:
sha_set = {r["secret_sha256"] for r in rows}
if not sha_set:
return
stmt = select(
stmt = select( # type: ignore[call-overload]
Credential.secret_sha256,
Credential.secret_kind,
Credential.principal,
Credential.secret_printable,
Credential.secret_b64,
).where(Credential.secret_sha256.in_(sha_set))
).where(col(Credential.secret_sha256).in_(sha_set))
secret_map: dict[
tuple[str, str, Optional[str]],
tuple[Optional[str], Optional[str]],

View File

@@ -11,7 +11,9 @@ from sqlalchemy import asc, select, text
from decnet.web.db.models import DeckyShard
class DeckiesMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class DeckiesMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``."""
async def upsert_decky_shard(self, data: dict[str, Any]) -> None:

View File

@@ -8,10 +8,13 @@ import orjson
from sqlalchemy import asc, select, text, update
from decnet.web.db.models import DeckyShard, FleetDecky, LOCAL_HOST_SENTINEL
from decnet.web.db.sqlmodel_repo._helpers import _deserialize_json_fields
from decnet.web.db.sqlmodel_repo._helpers import (
_MixinBase,
_deserialize_json_fields
)
class FleetMixin:
class FleetMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``.
``list_running_deckies`` aggregates topology + fleet + swarm-shard

View File

@@ -11,11 +11,14 @@ from datetime import datetime, timezone
from typing import Any, Optional
from sqlalchemy import desc, func, select, update
from sqlmodel import col
from decnet.web.db.models import Attacker, AttackerIdentity
class IdentitiesMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class IdentitiesMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``.
``self._deserialize_attacker`` resolves through ``AttackersMixin``
@@ -51,7 +54,7 @@ class IdentitiesMixin:
# and a future "merged into" endpoint when we need it.
statement = (
select(AttackerIdentity)
.where(AttackerIdentity.merged_into_uuid.is_(None))
.where(col(AttackerIdentity.merged_into_uuid).is_(None))
.order_by(desc(AttackerIdentity.updated_at))
.offset(offset)
.limit(limit)
@@ -64,7 +67,7 @@ class IdentitiesMixin:
statement = (
select(func.count())
.select_from(AttackerIdentity)
.where(AttackerIdentity.merged_into_uuid.is_(None))
.where(col(AttackerIdentity.merged_into_uuid).is_(None))
)
async with self._session() as session:
result = await session.execute(statement)
@@ -105,7 +108,7 @@ class IdentitiesMixin:
# joined from logs, c2 endpoints aggregated from sessions) can
# land here without churning every caller. ``fingerprints`` is
# the raw JSON list — the clusterer parses for JA3 / HASSH.
statement = select(
statement = select( # type: ignore[call-overload]
Attacker.uuid, Attacker.asn, Attacker.identity_id, Attacker.fingerprints,
).order_by(Attacker.first_seen)
if limit is not None:

View File

@@ -15,13 +15,16 @@ from typing import Any, List, Optional
import orjson
from sqlalchemy import asc, desc, func, or_, select, text
from sqlmodel import col
from sqlmodel.sql.expression import SelectOfScalar
from decnet.config import load_state
from decnet.web.db.models import Log, TopologyDecky
class LogsMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class LogsMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``."""
@staticmethod
@@ -61,9 +64,9 @@ class LogsMixin:
end_time: Optional[str],
) -> SelectOfScalar:
if start_time:
statement = statement.where(Log.timestamp >= start_time)
statement = statement.where(col(Log.timestamp) >= start_time)
if end_time:
statement = statement.where(Log.timestamp <= end_time)
statement = statement.where(col(Log.timestamp) <= end_time)
if search:
try:
@@ -95,10 +98,10 @@ class LogsMixin:
lk = f"%{token}%"
statement = statement.where(
or_(
Log.raw_line.like(lk),
Log.decky.like(lk),
Log.service.like(lk),
Log.attacker_ip.like(lk),
col(Log.raw_line).like(lk),
col(Log.decky).like(lk),
col(Log.service).like(lk),
col(Log.attacker_ip).like(lk),
)
)
return statement
@@ -148,7 +151,7 @@ class LogsMixin:
end_time: Optional[str] = None,
) -> List[dict]:
statement = (
select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit)
select(Log).where(col(Log.id) > last_id).order_by(asc(Log.id)).limit(limit)
)
statement = self._apply_filters(statement, search, start_time, end_time)

View File

@@ -7,11 +7,14 @@ from typing import Any, Optional
from sqlalchemy import delete as sa_delete
from sqlalchemy import desc, func, or_, select
from sqlmodel import col
from decnet.web.db.models import OrchestratorEmail, OrchestratorEvent
class OrchestratorMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class OrchestratorMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``."""
async def record_orchestrator_event(self, data: dict[str, Any]) -> str:
@@ -62,11 +65,11 @@ class OrchestratorMixin:
deleted = 0
async with self._session() as session:
dst_rows = await session.execute(
select(OrchestratorEvent.dst_decky_uuid).distinct()
select(col(OrchestratorEvent.dst_decky_uuid)).distinct()
)
for (dst,) in dst_rows.all():
keep = await session.execute(
select(OrchestratorEvent.uuid)
select(col(OrchestratorEvent.uuid))
.where(OrchestratorEvent.dst_decky_uuid == dst)
.order_by(desc(OrchestratorEvent.ts))
.limit(per_dst_cap)
@@ -76,7 +79,7 @@ class OrchestratorMixin:
continue
stmt = sa_delete(OrchestratorEvent).where(
OrchestratorEvent.dst_decky_uuid == dst,
OrchestratorEvent.uuid.notin_(keep_uuids),
col(OrchestratorEvent.uuid).notin_(keep_uuids),
)
res = await session.execute(stmt)
deleted += res.rowcount or 0
@@ -156,7 +159,7 @@ class OrchestratorMixin:
(OrchestratorEmail.sender_email == recipient_email)
& (OrchestratorEmail.recipient_email == sender_email),
),
OrchestratorEmail.success.is_(True),
col(OrchestratorEmail.success).is_(True),
)
.order_by(desc(OrchestratorEmail.ts))
.limit(limit)
@@ -169,11 +172,11 @@ class OrchestratorMixin:
deleted = 0
async with self._session() as session:
decky_rows = await session.execute(
select(OrchestratorEmail.mail_decky_uuid).distinct()
select(col(OrchestratorEmail.mail_decky_uuid)).distinct()
)
for (mail_uuid,) in decky_rows.all():
keep = await session.execute(
select(OrchestratorEmail.uuid)
select(col(OrchestratorEmail.uuid))
.where(OrchestratorEmail.mail_decky_uuid == mail_uuid)
.order_by(desc(OrchestratorEmail.ts))
.limit(per_decky_cap)
@@ -183,7 +186,7 @@ class OrchestratorMixin:
continue
stmt = sa_delete(OrchestratorEmail).where(
OrchestratorEmail.mail_decky_uuid == mail_uuid,
OrchestratorEmail.uuid.notin_(keep_uuids),
col(OrchestratorEmail.uuid).notin_(keep_uuids),
)
res = await session.execute(stmt)
deleted += res.rowcount or 0

View File

@@ -10,7 +10,9 @@ from decnet.web.db.models import RealismConfig, SyntheticFile
from decnet.web.db.models.realism import SYNTHETIC_FILE_BODY_LIMIT
class RealismMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class RealismMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``."""
async def record_synthetic_file(self, data: dict[str, Any]) -> str:

View File

@@ -8,7 +8,9 @@ from sqlalchemy import asc, select, text, update
from decnet.web.db.models import SwarmHost
class SwarmMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class SwarmMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``. Expects ``self._session()``."""
async def add_swarm_host(self, data: dict[str, Any]) -> None:

View File

@@ -11,7 +11,9 @@ from sqlalchemy import select
from decnet.web.db.models import TarpitRule
class TarpitMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class TarpitMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``."""
async def set_tarpit_rule(self, data: dict[str, Any]) -> None:

View File

@@ -9,10 +9,15 @@ from sqlalchemy import desc, func, select, text
from decnet.web.db.models import Topology, TopologyStatusEvent
from decnet.web.db.models.topology import TopologySummary
from decnet.web.db.sqlmodel_repo._helpers import _serialize_json_fields
from sqlmodel import col
from decnet.web.db.sqlmodel_repo._helpers import (
_MixinBase,
_serialize_json_fields
)
class TopologyCoreMixin:
class TopologyCoreMixin(_MixinBase):
"""Topologies CRUD + ``_assert_pending`` / ``_check_and_bump_version``.
The two private helpers live here because every other topology
@@ -184,8 +189,8 @@ class TopologyCoreMixin:
"""Return ids of topologies currently in ``active|degraded``."""
async with self._session() as session:
result = await session.execute(
select(Topology.id).where(
Topology.status.in_(["active", "degraded"])
select(col(Topology.id)).where(
col(Topology.status).in_(["active", "degraded"])
)
)
return [r for r in result.scalars().all()]

View File

@@ -10,12 +10,13 @@ from sqlalchemy import asc, select, text, update
from decnet.web.db.models import TopologyDecky
from decnet.web.db.models.topology import DeckyRow
from decnet.web.db.sqlmodel_repo._helpers import (
_MixinBase,
_deserialize_json_fields,
_serialize_json_fields,
)
class TopologyDeckiesMixin:
class TopologyDeckiesMixin(_MixinBase):
"""``self._assert_pending`` / ``self._check_and_bump_version`` resolve
through ``TopologyCoreMixin`` via MRO."""

View File

@@ -9,7 +9,9 @@ from decnet.web.db.models import TopologyEdge, TopologyStatusEvent
from decnet.web.db.models.topology import EdgeRow
class TopologyEdgesMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class TopologyEdgesMixin(_MixinBase):
"""``self._assert_pending`` / ``self._check_and_bump_version`` resolve
through ``TopologyCoreMixin`` via MRO."""

View File

@@ -9,7 +9,9 @@ from decnet.web.db.models import LAN, TopologyEdge
from decnet.web.db.models.topology import LANRow
class LansMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class LansMixin(_MixinBase):
"""``self._assert_pending`` / ``self._check_and_bump_version`` resolve
through ``TopologyCoreMixin`` via MRO."""

View File

@@ -10,7 +10,9 @@ from sqlalchemy import asc, desc, select, text
from decnet.web.db.models import TopologyMutation
class TopologyMutationsMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class TopologyMutationsMixin(_MixinBase):
"""``self._check_and_bump_version`` resolves through
``TopologyCoreMixin`` via MRO."""

View File

@@ -5,11 +5,14 @@ from datetime import datetime, timezone
from typing import Any, Optional
from sqlalchemy import select, update
from sqlmodel import col
from decnet.web.db.models import WebhookSubscription
class WebhooksMixin:
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class WebhooksMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``."""
async def create_webhook_subscription(self, data: dict[str, Any]) -> None:
@@ -43,7 +46,7 @@ class WebhooksMixin:
async with self._session() as session:
stmt = select(WebhookSubscription)
if enabled_only:
stmt = stmt.where(WebhookSubscription.enabled.is_(True))
stmt = stmt.where(col(WebhookSubscription.enabled).is_(True))
stmt = stmt.order_by(WebhookSubscription.created_at)
result = await session.execute(stmt)
return [r.model_dump() for r in result.scalars().all()]
@@ -100,7 +103,7 @@ class WebhooksMixin:
# the counter informs the circuit-breaker heuristic, not a
# correctness invariant.
result = await session.execute(
select(WebhookSubscription.consecutive_failures).where(
select(col(WebhookSubscription.consecutive_failures)).where(
WebhookSubscription.uuid == uuid
)
)

View File

@@ -154,3 +154,11 @@ ignore_missing_imports = true
check_untyped_defs = true
warn_redundant_casts = true
warn_unused_ignores = true
[[tool.mypy.overrides]]
# The pydantic plugin types SQLModel class-level field descriptors as their
# Python value types (str, bool, …) instead of InstrumentedAttribute. Every
# .where(Model.field == value) then becomes where(bool) — a false positive.
# Suppressing arg-type here; genuine argument errors are caught by call-arg.
module = "decnet.web.db.sqlmodel_repo.*"
disable_error_code = ["arg-type"]