fix(types): P2 — wire _MixinBase + col() across sqlmodel_repo; suppress pydantic/SQLModel column typing false positives

- Add _MixinBase abstract class to _helpers.py: declares _session(),
  _deserialize_attacker(), _assert_pending(), _check_and_bump_version(),
  and list_running_topology_deckies() so mypy can see cross-mixin contracts
- Add _require(val, msg) helper for narrowing T | None → T
- Inherit _MixinBase in all 26 leaf mixin classes
- Wrap SQLAlchemy column method calls (.is_(), .like(), .notin_(), .in_(),
  .contains()) with col() from sqlmodel — fixes attr-defined false positives
  caused by pydantic plugin typing class-level fields as Python value types
- Wrap select(Model.field) with select(col(Model.field)) for column projections
- Add pyproject.toml [[tool.mypy.overrides]] to disable arg-type in
  sqlmodel_repo.*: pydantic plugin resolves .where(Model.field == v) as
  where(bool), a false positive; call-arg still catches real argument errors
- Remove 9 stale # type: ignore comments (logging, helpers, credentials)
- Fix telemetry.py traced() overload no-redef + misc
- Fix logs.py datetime/str operator and nullable PK comparison with col()
- sqlmodel_repo/ now has 0 mypy errors
This commit is contained in:
2026-05-01 00:49:18 -04:00
parent d777a1c4e0
commit 614780f144
30 changed files with 221 additions and 100 deletions

View File

@@ -28,7 +28,7 @@ class _ComponentFilter(logging.Filter):
self.component = component self.component = component
def filter(self, record: logging.LogRecord) -> bool: def filter(self, record: logging.LogRecord) -> bool:
record.decnet_component = self.component # type: ignore[attr-defined] record.decnet_component = self.component
return True return True
@@ -49,14 +49,14 @@ class _TraceContextFilter(logging.Filter):
span = trace.get_current_span() span = trace.get_current_span()
ctx = span.get_span_context() ctx = span.get_span_context()
if ctx and ctx.trace_id: if ctx and ctx.trace_id:
record.otel_trace_id = format(ctx.trace_id, "032x") # type: ignore[attr-defined] record.otel_trace_id = format(ctx.trace_id, "032x")
record.otel_span_id = format(ctx.span_id, "016x") # type: ignore[attr-defined] record.otel_span_id = format(ctx.span_id, "016x")
else: else:
record.otel_trace_id = "0" # type: ignore[attr-defined] record.otel_trace_id = "0"
record.otel_span_id = "0" # type: ignore[attr-defined] record.otel_span_id = "0"
except Exception: except Exception:
record.otel_trace_id = "0" # type: ignore[attr-defined] record.otel_trace_id = "0"
record.otel_span_id = "0" # type: ignore[attr-defined] record.otel_span_id = "0"
return True return True

View File

@@ -138,7 +138,7 @@ def traced(fn: F) -> F: ...
def traced(name: str) -> Callable[[F], F]: ... def traced(name: str) -> Callable[[F], F]: ...
def traced(fn: Any = None, *, name: str | None = None) -> Any: def traced(fn: Any = None, *, name: str | None = None) -> Any: # type: ignore[misc]
"""Decorator that wraps a function in an OTEL span. """Decorator that wraps a function in an OTEL span.
Usage:: Usage::
@@ -168,9 +168,9 @@ def traced(fn: Any = None, *, name: str | None = None) -> Any:
# Called as @traced (no arguments) # Called as @traced (no arguments)
return _wrap(fn, None) return _wrap(fn, None)
# Fallback: @traced() with no args # Fallback: @traced() with no args
def decorator(f: F) -> F: def _fallback_decorator(f: F) -> F:
return _wrap(f, name) return _wrap(f, name)
return decorator return _fallback_decorator
def _wrap(fn: F, span_name: str | None) -> F: def _wrap(fn: F, span_name: str | None) -> F:

View File

@@ -12,14 +12,60 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
from abc import abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any from typing import Any, Optional, TypeVar
import orjson import orjson
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from decnet.logging import get_logger from decnet.logging import get_logger
T = TypeVar("T")
def _require(val: T | None, msg: str) -> T:
"""Narrow ``X | None`` to ``X``, raising ``ValueError`` if None."""
if val is None:
raise ValueError(msg)
return val
class _MixinBase:
"""Typing base for all repo mixins.
Declares the contract that ``SQLModelRepository`` satisfies at runtime
via MRO composition. Without this, mypy checks each mixin in isolation
and cannot see ``_session`` or cross-mixin helpers.
"""
@abstractmethod
def _session(self):
"""Return a cancellation-safe async session context manager."""
raise NotImplementedError
@staticmethod
def _deserialize_attacker(d: dict[str, Any]) -> dict[str, Any]:
"""Stub — concrete impl on AttackersCoreMixin via MRO."""
return d
async def _assert_pending(self, session: AsyncSession, topology_id: str) -> None:
"""Stub — concrete impl on TopologyCoreMixin via MRO."""
raise NotImplementedError
async def _check_and_bump_version(
self,
session: AsyncSession,
topology_id: str,
expected_version: Optional[int],
) -> None:
"""Stub — concrete impl on TopologyCoreMixin via MRO."""
raise NotImplementedError
async def list_running_topology_deckies(self) -> list[dict[str, Any]]:
"""Stub — concrete impl on TopologyDeckiesMixin via MRO."""
raise NotImplementedError
_log = get_logger("db.pool") _log = get_logger("db.pool")
# Hold strong refs to in-flight cleanup tasks so they aren't GC'd mid-run. # Hold strong refs to in-flight cleanup tasks so they aren't GC'd mid-run.
@@ -66,7 +112,7 @@ def _detach_close(session: AsyncSession) -> None:
task = loop.create_task(_cleanup()) task = loop.create_task(_cleanup())
_cleanup_tasks.add(task) _cleanup_tasks.add(task)
# Consume any exception to silence "Task exception was never retrieved". # Consume any exception to silence "Task exception was never retrieved".
task.add_done_callback(lambda t: (_cleanup_tasks.discard(t), t.exception())) task.add_done_callback(lambda t: (_cleanup_tasks.discard(t), t.exception())) # type: ignore[func-returns-value]
@asynccontextmanager @asynccontextmanager

View File

@@ -13,11 +13,14 @@ from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional
from sqlalchemy import desc, or_, select from sqlalchemy import desc, or_, select
from sqlmodel import col
from decnet.web.db.models import Attacker, AttackerIntel from decnet.web.db.models import Attacker, AttackerIntel
class AttackerIntelMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class AttackerIntelMixin(_MixinBase):
"""Mixin: methods composed onto ``SQLModelRepository``. """Mixin: methods composed onto ``SQLModelRepository``.
Expects ``self._session()`` from the base. Expects ``self._session()`` from the base.
@@ -82,13 +85,13 @@ class AttackerIntelMixin:
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
async with self._session() as session: async with self._session() as session:
stmt = ( stmt = (
select(Attacker.uuid, Attacker.ip) select(col(Attacker.uuid), col(Attacker.ip))
.outerjoin( .outerjoin(
AttackerIntel, AttackerIntel.attacker_uuid == Attacker.uuid, AttackerIntel, AttackerIntel.attacker_uuid == Attacker.uuid,
) )
.where( .where(
or_( or_(
AttackerIntel.uuid.is_(None), col(AttackerIntel.uuid).is_(None),
AttackerIntel.expires_at < now, AttackerIntel.expires_at < now,
) )
) )

View File

@@ -12,11 +12,14 @@ import uuid as _uuid
from typing import Any, List, Optional from typing import Any, List, Optional
from sqlalchemy import desc, func, outerjoin, select from sqlalchemy import desc, func, outerjoin, select
from sqlmodel import col
from decnet.web.db.models import Attacker, AttackerIntel from decnet.web.db.models import Attacker, AttackerIntel
class AttackersCoreMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class AttackersCoreMixin(_MixinBase):
@staticmethod @staticmethod
def _deserialize_attacker(d: dict[str, Any]) -> dict[str, Any]: def _deserialize_attacker(d: dict[str, Any]) -> dict[str, Any]:
for key in ("services", "deckies", "fingerprints", "commands"): for key in ("services", "deckies", "fingerprints", "commands"):
@@ -63,16 +66,16 @@ class AttackersCoreMixin:
sort_by: str = "recent", sort_by: str = "recent",
service: Optional[str] = None, service: Optional[str] = None,
) -> List[dict[str, Any]]: ) -> List[dict[str, Any]]:
order = { order: Any = {
"active": desc(Attacker.event_count), "active": desc(Attacker.event_count),
"traversals": desc(Attacker.is_traversal), "traversals": desc(Attacker.is_traversal),
}.get(sort_by, desc(Attacker.last_seen)) }.get(sort_by, desc(Attacker.last_seen))
statement = select(Attacker).order_by(order).offset(offset).limit(limit) statement = select(Attacker).order_by(order).offset(offset).limit(limit)
if search: if search:
statement = statement.where(Attacker.ip.like(f"%{search}%")) statement = statement.where(col(Attacker.ip).like(f"%{search}%"))
if service: if service:
statement = statement.where(Attacker.services.like(f'%"{service}"%')) statement = statement.where(col(Attacker.services).like(f'%"{service}"%'))
async with self._session() as session: async with self._session() as session:
result = await session.execute(statement) result = await session.execute(statement)
@@ -121,9 +124,9 @@ class AttackersCoreMixin:
) -> int: ) -> int:
statement = select(func.count()).select_from(Attacker) statement = select(func.count()).select_from(Attacker)
if search: if search:
statement = statement.where(Attacker.ip.like(f"%{search}%")) statement = statement.where(col(Attacker.ip).like(f"%{search}%"))
if service: if service:
statement = statement.where(Attacker.services.like(f'%"{service}"%')) statement = statement.where(col(Attacker.services).like(f'%"{service}"%'))
async with self._session() as session: async with self._session() as session:
result = await session.execute(statement) result = await session.execute(statement)

View File

@@ -10,11 +10,14 @@ import json
from typing import Any, Optional from typing import Any, Optional
from sqlalchemy import desc, func, select from sqlalchemy import desc, func, select
from sqlmodel import col
from decnet.web.db.models import Attacker, Bounty, Log from decnet.web.db.models import Attacker, Bounty, Log
class AttackerActivityMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class AttackerActivityMixin(_MixinBase):
async def get_attacker_commands( async def get_attacker_commands(
self, self,
uuid: str, uuid: str,
@@ -24,7 +27,7 @@ class AttackerActivityMixin:
) -> dict[str, Any]: ) -> dict[str, Any]:
async with self._session() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(Attacker.commands).where(Attacker.uuid == uuid) select(col(Attacker.commands)).where(Attacker.uuid == uuid)
) )
raw = result.scalar_one_or_none() raw = result.scalar_one_or_none()
if raw is None: if raw is None:
@@ -52,13 +55,13 @@ class AttackerActivityMixin:
""" """
async with self._session() as session: async with self._session() as session:
ip_res = await session.execute( ip_res = await session.execute(
select(Attacker.ip).where(Attacker.uuid == attacker_uuid) select(col(Attacker.ip)).where(Attacker.uuid == attacker_uuid)
) )
ip = ip_res.scalar_one_or_none() ip = ip_res.scalar_one_or_none()
if not ip: if not ip:
return [] return []
rows = await session.execute( rows = await session.execute(
select(Log.service, Log.event_type) select(col(Log.service), col(Log.event_type))
.where(Log.attacker_ip == ip) .where(Log.attacker_ip == ip)
.distinct() .distinct()
) )
@@ -75,7 +78,7 @@ class AttackerActivityMixin:
rotation detection.""" rotation detection."""
async with self._session() as session: async with self._session() as session:
ip_res = await session.execute( ip_res = await session.execute(
select(Attacker.ip).where(Attacker.uuid == attacker_uuid) select(col(Attacker.ip)).where(Attacker.uuid == attacker_uuid)
) )
ip = ip_res.scalar_one_or_none() ip = ip_res.scalar_one_or_none()
if not ip: if not ip:
@@ -104,7 +107,7 @@ class AttackerActivityMixin:
"""Cheap COUNT(*) for XFF-rotation detection.""" """Cheap COUNT(*) for XFF-rotation detection."""
async with self._session() as session: async with self._session() as session:
ip_res = await session.execute( ip_res = await session.execute(
select(Attacker.ip).where(Attacker.uuid == attacker_uuid) select(col(Attacker.ip)).where(Attacker.uuid == attacker_uuid)
) )
ip = ip_res.scalar_one_or_none() ip = ip_res.scalar_one_or_none()
if not ip: if not ip:
@@ -126,7 +129,7 @@ class AttackerActivityMixin:
""" """
async with self._session() as session: async with self._session() as session:
ip_res = await session.execute( ip_res = await session.execute(
select(Attacker.ip).where(Attacker.uuid == uuid) select(col(Attacker.ip)).where(Attacker.uuid == uuid)
) )
ip = ip_res.scalar_one_or_none() ip = ip_res.scalar_one_or_none()
if not ip: if not ip:
@@ -150,7 +153,7 @@ class AttackerActivityMixin:
""" """
async with self._session() as session: async with self._session() as session:
ip_res = await session.execute( ip_res = await session.execute(
select(Attacker.ip).where(Attacker.uuid == uuid) select(col(Attacker.ip)).where(Attacker.uuid == uuid)
) )
ip = ip_res.scalar_one_or_none() ip = ip_res.scalar_one_or_none()
if not ip: if not ip:
@@ -176,7 +179,7 @@ class AttackerActivityMixin:
rows = await session.execute( rows = await session.execute(
select(Log) select(Log)
.where(Log.event_type == "session_recorded") .where(Log.event_type == "session_recorded")
.where(Log.fields.contains(needle)) .where(col(Log.fields).contains(needle))
.limit(1) .limit(1)
) )
row = rows.scalars().first() row = rows.scalars().first()
@@ -192,7 +195,7 @@ class AttackerActivityMixin:
""" """
async with self._session() as session: async with self._session() as session:
ip_res = await session.execute( ip_res = await session.execute(
select(Attacker.ip).where(Attacker.uuid == uuid) select(col(Attacker.ip)).where(Attacker.uuid == uuid)
) )
ip = ip_res.scalar_one_or_none() ip = ip_res.scalar_one_or_none()
if not ip: if not ip:

View File

@@ -7,11 +7,14 @@ from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional
from sqlalchemy import select from sqlalchemy import select
from sqlmodel import col
from decnet.web.db.models import Attacker, AttackerBehavior from decnet.web.db.models import Attacker, AttackerBehavior
class AttackerBehaviorMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class AttackerBehaviorMixin(_MixinBase):
async def upsert_attacker_behavior( async def upsert_attacker_behavior(
self, self,
attacker_uuid: str, attacker_uuid: str,
@@ -56,9 +59,9 @@ class AttackerBehaviorMixin:
return {} return {}
async with self._session() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(Attacker.ip, AttackerBehavior) select(col(Attacker.ip), AttackerBehavior)
.join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid) .join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid)
.where(Attacker.ip.in_(ips)) .where(col(Attacker.ip).in_(ips))
) )
out: dict[str, dict[str, Any]] = {} out: dict[str, dict[str, Any]] = {}
for ip, row in result.all(): for ip, row in result.all():

View File

@@ -9,7 +9,9 @@ from sqlalchemy import select
from decnet.web.db.models import SessionProfile from decnet.web.db.models import SessionProfile
class SessionProfilesMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class SessionProfilesMixin(_MixinBase):
async def upsert_session_profile( async def upsert_session_profile(
self, self,
sid: str, sid: str,

View File

@@ -10,7 +10,9 @@ from sqlalchemy import desc, func, select
from decnet.web.db.models import SmtpTarget from decnet.web.db.models import SmtpTarget
class SmtpTargetsMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class SmtpTargetsMixin(_MixinBase):
async def increment_smtp_target(self, attacker_uuid: str, domain: str) -> None: async def increment_smtp_target(self, attacker_uuid: str, domain: str) -> None:
"""Upsert an (attacker_uuid, domain) pair and bump count + last_seen. """Upsert an (attacker_uuid, domain) pair and bump count + last_seen.

View File

@@ -8,7 +8,9 @@ from sqlalchemy import select, update
from decnet.web.db.models import User from decnet.web.db.models import User
class AuthMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class AuthMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``. Expects ``self._session()``. """Mixin: composed onto ``SQLModelRepository``. Expects ``self._session()``.
``_ensure_admin_user`` stays in the package ``__init__`` so the ``_ensure_admin_user`` stays in the package ``__init__`` so the

View File

@@ -7,12 +7,15 @@ from typing import Any, List, Optional
import orjson import orjson
from sqlalchemy import asc, desc, func, or_, select, text from sqlalchemy import asc, desc, func, or_, select, text
from sqlmodel import col
from sqlmodel.sql.expression import SelectOfScalar from sqlmodel.sql.expression import SelectOfScalar
from decnet.web.db.models import Bounty from decnet.web.db.models import Bounty
class BountiesMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class BountiesMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``.""" """Mixin: composed onto ``SQLModelRepository``."""
async def purge_logs_and_bounties(self) -> dict[str, int]: async def purge_logs_and_bounties(self) -> dict[str, int]:
@@ -40,7 +43,7 @@ class BountiesMixin:
async with self._session() as session: async with self._session() as session:
dup = await session.execute( dup = await session.execute(
select(Bounty.id).where( select(col(Bounty.id)).where(
Bounty.bounty_type == data.get("bounty_type"), Bounty.bounty_type == data.get("bounty_type"),
Bounty.attacker_ip == data.get("attacker_ip"), Bounty.attacker_ip == data.get("attacker_ip"),
Bounty.payload == data.get("payload"), Bounty.payload == data.get("payload"),
@@ -63,10 +66,10 @@ class BountiesMixin:
lk = f"%{search}%" lk = f"%{search}%"
statement = statement.where( statement = statement.where(
or_( or_(
Bounty.decky.like(lk), col(Bounty.decky).like(lk),
Bounty.service.like(lk), col(Bounty.service).like(lk),
Bounty.attacker_ip.like(lk), col(Bounty.attacker_ip).like(lk),
Bounty.payload.like(lk), col(Bounty.payload).like(lk),
) )
) )
return statement return statement
@@ -126,7 +129,7 @@ class BountiesMixin:
async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]: async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]:
async with self._session() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp)) select(Bounty).where(col(Bounty.attacker_ip).in_(ips)).order_by(asc(Bounty.timestamp))
) )
grouped: dict[str, List[dict[str, Any]]] = defaultdict(list) grouped: dict[str, List[dict[str, Any]]] = defaultdict(list)
for item in result.scalars().all(): for item in result.scalars().all():

View File

@@ -11,11 +11,14 @@ from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional
from sqlalchemy import desc, func, select, update from sqlalchemy import desc, func, select, update
from sqlmodel import col
from decnet.web.db.models import AttackerIdentity, Campaign from decnet.web.db.models import AttackerIdentity, Campaign
class CampaignsMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class CampaignsMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``.""" """Mixin: composed onto ``SQLModelRepository``."""
async def get_campaign_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]: async def get_campaign_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
@@ -41,7 +44,7 @@ class CampaignsMixin:
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
statement = ( statement = (
select(Campaign) select(Campaign)
.where(Campaign.merged_into_uuid.is_(None)) .where(col(Campaign.merged_into_uuid).is_(None))
.order_by(desc(Campaign.updated_at)) .order_by(desc(Campaign.updated_at))
.offset(offset) .offset(offset)
.limit(limit) .limit(limit)
@@ -54,7 +57,7 @@ class CampaignsMixin:
statement = ( statement = (
select(func.count()) select(func.count())
.select_from(Campaign) .select_from(Campaign)
.where(Campaign.merged_into_uuid.is_(None)) .where(col(Campaign.merged_into_uuid).is_(None))
) )
async with self._session() as session: async with self._session() as session:
result = await session.execute(statement) result = await session.execute(statement)
@@ -91,7 +94,7 @@ class CampaignsMixin:
# graph reads. Narrow on purpose — future denormalised # graph reads. Narrow on purpose — future denormalised
# projections (commands_by_phase from log mining, decky-set # projections (commands_by_phase from log mining, decky-set
# aggregates) can land here without churning callers. # aggregates) can land here without churning callers.
statement = select( statement = select( # type: ignore[call-overload, misc]
AttackerIdentity.uuid, AttackerIdentity.uuid,
AttackerIdentity.campaign_id, AttackerIdentity.campaign_id,
AttackerIdentity.merged_into_uuid, AttackerIdentity.merged_into_uuid,

View File

@@ -10,7 +10,9 @@ from sqlalchemy import desc, func, select, update
from decnet.web.db.models import CanaryBlob, CanaryToken, CanaryTrigger from decnet.web.db.models import CanaryBlob, CanaryToken, CanaryTrigger
class CanaryMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class CanaryMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``.""" """Mixin: composed onto ``SQLModelRepository``."""
async def upsert_canary_blob(self, data: dict[str, Any]) -> dict[str, Any]: async def upsert_canary_blob(self, data: dict[str, Any]) -> dict[str, Any]:

View File

@@ -6,12 +6,15 @@ from datetime import datetime, timezone
from typing import Any, List, Optional from typing import Any, List, Optional
from sqlalchemy import desc, func, or_, select, update from sqlalchemy import desc, func, or_, select, update
from sqlmodel import col
from sqlmodel.sql.expression import SelectOfScalar from sqlmodel.sql.expression import SelectOfScalar
from decnet.web.db.models import Credential from decnet.web.db.models import Credential
class CredentialsCoreMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class CredentialsCoreMixin(_MixinBase):
async def upsert_credential(self, data: dict[str, Any]) -> int: async def upsert_credential(self, data: dict[str, Any]) -> int:
"""Upsert a credential attempt; returns the row id. """Upsert a credential attempt; returns the row id.
@@ -37,7 +40,7 @@ class CredentialsCoreMixin:
Credential.secret_sha256 == payload["secret_sha256"], Credential.secret_sha256 == payload["secret_sha256"],
# NULL == NULL is False under SQL — branch the predicate. # NULL == NULL is False under SQL — branch the predicate.
(Credential.principal == principal) if principal is not None (Credential.principal == principal) if principal is not None
else Credential.principal.is_(None), else col(Credential.principal).is_(None),
) )
existing = (await session.execute(stmt)).scalar_one_or_none() existing = (await session.execute(stmt)).scalar_one_or_none()
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
@@ -48,7 +51,7 @@ class CredentialsCoreMixin:
existing.outcome = payload["outcome"] existing.outcome = payload["outcome"]
session.add(existing) session.add(existing)
await session.commit() await session.commit()
return existing.id # type: ignore[return-value] return existing.id
row = Credential( row = Credential(
attacker_ip=payload["attacker_ip"], attacker_ip=payload["attacker_ip"],
decky_name=payload["decky_name"], decky_name=payload["decky_name"],
@@ -84,10 +87,10 @@ class CredentialsCoreMixin:
lk = f"%{search}%" lk = f"%{search}%"
statement = statement.where( statement = statement.where(
or_( or_(
Credential.decky_name.like(lk), col(Credential.decky_name).like(lk),
Credential.service.like(lk), col(Credential.service).like(lk),
Credential.principal.like(lk), col(Credential.principal).like(lk),
Credential.secret_printable.like(lk), col(Credential.secret_printable).like(lk),
) )
) )
return statement return statement
@@ -188,7 +191,7 @@ class CredentialsCoreMixin:
update(Credential) update(Credential)
.where( .where(
Credential.attacker_ip == attacker_ip, Credential.attacker_ip == attacker_ip,
Credential.attacker_uuid.is_(None), col(Credential.attacker_uuid).is_(None),
) )
.values(attacker_uuid=attacker_uuid) .values(attacker_uuid=attacker_uuid)
) )

View File

@@ -9,11 +9,14 @@ from datetime import datetime, timezone
from typing import Any, List, Optional from typing import Any, List, Optional
from sqlalchemy import desc, func, select from sqlalchemy import desc, func, select
from sqlmodel import col
from decnet.web.db.models import Credential, CredentialReuse from decnet.web.db.models import Credential, CredentialReuse
class CredentialReuseMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class CredentialReuseMixin(_MixinBase):
@staticmethod @staticmethod
def _merge_unique(existing_json: str, value: Optional[str]) -> tuple[str, bool]: def _merge_unique(existing_json: str, value: Optional[str]) -> tuple[str, bool]:
"""Append ``value`` to a JSON list[str] column if not present. """Append ``value`` to a JSON list[str] column if not present.
@@ -117,7 +120,7 @@ class CredentialReuseMixin:
Credential.secret_sha256 == secret_sha256, Credential.secret_sha256 == secret_sha256,
Credential.secret_kind == secret_kind, Credential.secret_kind == secret_kind,
(Credential.principal == principal) if principal is not None (Credential.principal == principal) if principal is not None
else Credential.principal.is_(None), else col(Credential.principal).is_(None),
) )
) )
target_count = (await session.execute(stmt)).scalar() or 0 target_count = (await session.execute(stmt)).scalar() or 0
@@ -150,7 +153,7 @@ class CredentialReuseMixin:
).label("target_count") ).label("target_count")
async with self._session() as session: async with self._session() as session:
group_stmt = ( group_stmt = (
select( select( # type: ignore[call-overload]
Credential.secret_sha256, Credential.secret_sha256,
Credential.secret_kind, Credential.secret_kind,
Credential.principal, Credential.principal,
@@ -171,7 +174,7 @@ class CredentialReuseMixin:
Credential.secret_kind == kind, Credential.secret_kind == kind,
(Credential.principal == principal) (Credential.principal == principal)
if principal is not None if principal is not None
else Credential.principal.is_(None), else col(Credential.principal).is_(None),
) )
rows = (await session.execute(cred_stmt)).scalars().all() rows = (await session.execute(cred_stmt)).scalars().all()
out.append({ out.append({
@@ -253,13 +256,13 @@ class CredentialReuseMixin:
sha_set = {r["secret_sha256"] for r in rows} sha_set = {r["secret_sha256"] for r in rows}
if not sha_set: if not sha_set:
return return
stmt = select( stmt = select( # type: ignore[call-overload]
Credential.secret_sha256, Credential.secret_sha256,
Credential.secret_kind, Credential.secret_kind,
Credential.principal, Credential.principal,
Credential.secret_printable, Credential.secret_printable,
Credential.secret_b64, Credential.secret_b64,
).where(Credential.secret_sha256.in_(sha_set)) ).where(col(Credential.secret_sha256).in_(sha_set))
secret_map: dict[ secret_map: dict[
tuple[str, str, Optional[str]], tuple[str, str, Optional[str]],
tuple[Optional[str], Optional[str]], tuple[Optional[str], Optional[str]],

View File

@@ -11,7 +11,9 @@ from sqlalchemy import asc, select, text
from decnet.web.db.models import DeckyShard from decnet.web.db.models import DeckyShard
class DeckiesMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class DeckiesMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``.""" """Mixin: composed onto ``SQLModelRepository``."""
async def upsert_decky_shard(self, data: dict[str, Any]) -> None: async def upsert_decky_shard(self, data: dict[str, Any]) -> None:

View File

@@ -8,10 +8,13 @@ import orjson
from sqlalchemy import asc, select, text, update from sqlalchemy import asc, select, text, update
from decnet.web.db.models import DeckyShard, FleetDecky, LOCAL_HOST_SENTINEL from decnet.web.db.models import DeckyShard, FleetDecky, LOCAL_HOST_SENTINEL
from decnet.web.db.sqlmodel_repo._helpers import _deserialize_json_fields from decnet.web.db.sqlmodel_repo._helpers import (
_MixinBase,
_deserialize_json_fields
)
class FleetMixin: class FleetMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``. """Mixin: composed onto ``SQLModelRepository``.
``list_running_deckies`` aggregates topology + fleet + swarm-shard ``list_running_deckies`` aggregates topology + fleet + swarm-shard

View File

@@ -11,11 +11,14 @@ from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional
from sqlalchemy import desc, func, select, update from sqlalchemy import desc, func, select, update
from sqlmodel import col
from decnet.web.db.models import Attacker, AttackerIdentity from decnet.web.db.models import Attacker, AttackerIdentity
class IdentitiesMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class IdentitiesMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``. """Mixin: composed onto ``SQLModelRepository``.
``self._deserialize_attacker`` resolves through ``AttackersMixin`` ``self._deserialize_attacker`` resolves through ``AttackersMixin``
@@ -51,7 +54,7 @@ class IdentitiesMixin:
# and a future "merged into" endpoint when we need it. # and a future "merged into" endpoint when we need it.
statement = ( statement = (
select(AttackerIdentity) select(AttackerIdentity)
.where(AttackerIdentity.merged_into_uuid.is_(None)) .where(col(AttackerIdentity.merged_into_uuid).is_(None))
.order_by(desc(AttackerIdentity.updated_at)) .order_by(desc(AttackerIdentity.updated_at))
.offset(offset) .offset(offset)
.limit(limit) .limit(limit)
@@ -64,7 +67,7 @@ class IdentitiesMixin:
statement = ( statement = (
select(func.count()) select(func.count())
.select_from(AttackerIdentity) .select_from(AttackerIdentity)
.where(AttackerIdentity.merged_into_uuid.is_(None)) .where(col(AttackerIdentity.merged_into_uuid).is_(None))
) )
async with self._session() as session: async with self._session() as session:
result = await session.execute(statement) result = await session.execute(statement)
@@ -105,7 +108,7 @@ class IdentitiesMixin:
# joined from logs, c2 endpoints aggregated from sessions) can # joined from logs, c2 endpoints aggregated from sessions) can
# land here without churning every caller. ``fingerprints`` is # land here without churning every caller. ``fingerprints`` is
# the raw JSON list — the clusterer parses for JA3 / HASSH. # the raw JSON list — the clusterer parses for JA3 / HASSH.
statement = select( statement = select( # type: ignore[call-overload]
Attacker.uuid, Attacker.asn, Attacker.identity_id, Attacker.fingerprints, Attacker.uuid, Attacker.asn, Attacker.identity_id, Attacker.fingerprints,
).order_by(Attacker.first_seen) ).order_by(Attacker.first_seen)
if limit is not None: if limit is not None:

View File

@@ -15,13 +15,16 @@ from typing import Any, List, Optional
import orjson import orjson
from sqlalchemy import asc, desc, func, or_, select, text from sqlalchemy import asc, desc, func, or_, select, text
from sqlmodel import col
from sqlmodel.sql.expression import SelectOfScalar from sqlmodel.sql.expression import SelectOfScalar
from decnet.config import load_state from decnet.config import load_state
from decnet.web.db.models import Log, TopologyDecky from decnet.web.db.models import Log, TopologyDecky
class LogsMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class LogsMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``.""" """Mixin: composed onto ``SQLModelRepository``."""
@staticmethod @staticmethod
@@ -61,9 +64,9 @@ class LogsMixin:
end_time: Optional[str], end_time: Optional[str],
) -> SelectOfScalar: ) -> SelectOfScalar:
if start_time: if start_time:
statement = statement.where(Log.timestamp >= start_time) statement = statement.where(col(Log.timestamp) >= start_time)
if end_time: if end_time:
statement = statement.where(Log.timestamp <= end_time) statement = statement.where(col(Log.timestamp) <= end_time)
if search: if search:
try: try:
@@ -95,10 +98,10 @@ class LogsMixin:
lk = f"%{token}%" lk = f"%{token}%"
statement = statement.where( statement = statement.where(
or_( or_(
Log.raw_line.like(lk), col(Log.raw_line).like(lk),
Log.decky.like(lk), col(Log.decky).like(lk),
Log.service.like(lk), col(Log.service).like(lk),
Log.attacker_ip.like(lk), col(Log.attacker_ip).like(lk),
) )
) )
return statement return statement
@@ -148,7 +151,7 @@ class LogsMixin:
end_time: Optional[str] = None, end_time: Optional[str] = None,
) -> List[dict]: ) -> List[dict]:
statement = ( statement = (
select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit) select(Log).where(col(Log.id) > last_id).order_by(asc(Log.id)).limit(limit)
) )
statement = self._apply_filters(statement, search, start_time, end_time) statement = self._apply_filters(statement, search, start_time, end_time)

View File

@@ -7,11 +7,14 @@ from typing import Any, Optional
from sqlalchemy import delete as sa_delete from sqlalchemy import delete as sa_delete
from sqlalchemy import desc, func, or_, select from sqlalchemy import desc, func, or_, select
from sqlmodel import col
from decnet.web.db.models import OrchestratorEmail, OrchestratorEvent from decnet.web.db.models import OrchestratorEmail, OrchestratorEvent
class OrchestratorMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class OrchestratorMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``.""" """Mixin: composed onto ``SQLModelRepository``."""
async def record_orchestrator_event(self, data: dict[str, Any]) -> str: async def record_orchestrator_event(self, data: dict[str, Any]) -> str:
@@ -62,11 +65,11 @@ class OrchestratorMixin:
deleted = 0 deleted = 0
async with self._session() as session: async with self._session() as session:
dst_rows = await session.execute( dst_rows = await session.execute(
select(OrchestratorEvent.dst_decky_uuid).distinct() select(col(OrchestratorEvent.dst_decky_uuid)).distinct()
) )
for (dst,) in dst_rows.all(): for (dst,) in dst_rows.all():
keep = await session.execute( keep = await session.execute(
select(OrchestratorEvent.uuid) select(col(OrchestratorEvent.uuid))
.where(OrchestratorEvent.dst_decky_uuid == dst) .where(OrchestratorEvent.dst_decky_uuid == dst)
.order_by(desc(OrchestratorEvent.ts)) .order_by(desc(OrchestratorEvent.ts))
.limit(per_dst_cap) .limit(per_dst_cap)
@@ -76,7 +79,7 @@ class OrchestratorMixin:
continue continue
stmt = sa_delete(OrchestratorEvent).where( stmt = sa_delete(OrchestratorEvent).where(
OrchestratorEvent.dst_decky_uuid == dst, OrchestratorEvent.dst_decky_uuid == dst,
OrchestratorEvent.uuid.notin_(keep_uuids), col(OrchestratorEvent.uuid).notin_(keep_uuids),
) )
res = await session.execute(stmt) res = await session.execute(stmt)
deleted += res.rowcount or 0 deleted += res.rowcount or 0
@@ -156,7 +159,7 @@ class OrchestratorMixin:
(OrchestratorEmail.sender_email == recipient_email) (OrchestratorEmail.sender_email == recipient_email)
& (OrchestratorEmail.recipient_email == sender_email), & (OrchestratorEmail.recipient_email == sender_email),
), ),
OrchestratorEmail.success.is_(True), col(OrchestratorEmail.success).is_(True),
) )
.order_by(desc(OrchestratorEmail.ts)) .order_by(desc(OrchestratorEmail.ts))
.limit(limit) .limit(limit)
@@ -169,11 +172,11 @@ class OrchestratorMixin:
deleted = 0 deleted = 0
async with self._session() as session: async with self._session() as session:
decky_rows = await session.execute( decky_rows = await session.execute(
select(OrchestratorEmail.mail_decky_uuid).distinct() select(col(OrchestratorEmail.mail_decky_uuid)).distinct()
) )
for (mail_uuid,) in decky_rows.all(): for (mail_uuid,) in decky_rows.all():
keep = await session.execute( keep = await session.execute(
select(OrchestratorEmail.uuid) select(col(OrchestratorEmail.uuid))
.where(OrchestratorEmail.mail_decky_uuid == mail_uuid) .where(OrchestratorEmail.mail_decky_uuid == mail_uuid)
.order_by(desc(OrchestratorEmail.ts)) .order_by(desc(OrchestratorEmail.ts))
.limit(per_decky_cap) .limit(per_decky_cap)
@@ -183,7 +186,7 @@ class OrchestratorMixin:
continue continue
stmt = sa_delete(OrchestratorEmail).where( stmt = sa_delete(OrchestratorEmail).where(
OrchestratorEmail.mail_decky_uuid == mail_uuid, OrchestratorEmail.mail_decky_uuid == mail_uuid,
OrchestratorEmail.uuid.notin_(keep_uuids), col(OrchestratorEmail.uuid).notin_(keep_uuids),
) )
res = await session.execute(stmt) res = await session.execute(stmt)
deleted += res.rowcount or 0 deleted += res.rowcount or 0

View File

@@ -10,7 +10,9 @@ from decnet.web.db.models import RealismConfig, SyntheticFile
from decnet.web.db.models.realism import SYNTHETIC_FILE_BODY_LIMIT from decnet.web.db.models.realism import SYNTHETIC_FILE_BODY_LIMIT
class RealismMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class RealismMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``.""" """Mixin: composed onto ``SQLModelRepository``."""
async def record_synthetic_file(self, data: dict[str, Any]) -> str: async def record_synthetic_file(self, data: dict[str, Any]) -> str:

View File

@@ -8,7 +8,9 @@ from sqlalchemy import asc, select, text, update
from decnet.web.db.models import SwarmHost from decnet.web.db.models import SwarmHost
class SwarmMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class SwarmMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``. Expects ``self._session()``.""" """Mixin: composed onto ``SQLModelRepository``. Expects ``self._session()``."""
async def add_swarm_host(self, data: dict[str, Any]) -> None: async def add_swarm_host(self, data: dict[str, Any]) -> None:

View File

@@ -11,7 +11,9 @@ from sqlalchemy import select
from decnet.web.db.models import TarpitRule from decnet.web.db.models import TarpitRule
class TarpitMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class TarpitMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``.""" """Mixin: composed onto ``SQLModelRepository``."""
async def set_tarpit_rule(self, data: dict[str, Any]) -> None: async def set_tarpit_rule(self, data: dict[str, Any]) -> None:

View File

@@ -9,10 +9,15 @@ from sqlalchemy import desc, func, select, text
from decnet.web.db.models import Topology, TopologyStatusEvent from decnet.web.db.models import Topology, TopologyStatusEvent
from decnet.web.db.models.topology import TopologySummary from decnet.web.db.models.topology import TopologySummary
from decnet.web.db.sqlmodel_repo._helpers import _serialize_json_fields from sqlmodel import col
from decnet.web.db.sqlmodel_repo._helpers import (
_MixinBase,
_serialize_json_fields
)
class TopologyCoreMixin: class TopologyCoreMixin(_MixinBase):
"""Topologies CRUD + ``_assert_pending`` / ``_check_and_bump_version``. """Topologies CRUD + ``_assert_pending`` / ``_check_and_bump_version``.
The two private helpers live here because every other topology The two private helpers live here because every other topology
@@ -184,8 +189,8 @@ class TopologyCoreMixin:
"""Return ids of topologies currently in ``active|degraded``.""" """Return ids of topologies currently in ``active|degraded``."""
async with self._session() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(Topology.id).where( select(col(Topology.id)).where(
Topology.status.in_(["active", "degraded"]) col(Topology.status).in_(["active", "degraded"])
) )
) )
return [r for r in result.scalars().all()] return [r for r in result.scalars().all()]

View File

@@ -10,12 +10,13 @@ from sqlalchemy import asc, select, text, update
from decnet.web.db.models import TopologyDecky from decnet.web.db.models import TopologyDecky
from decnet.web.db.models.topology import DeckyRow from decnet.web.db.models.topology import DeckyRow
from decnet.web.db.sqlmodel_repo._helpers import ( from decnet.web.db.sqlmodel_repo._helpers import (
_MixinBase,
_deserialize_json_fields, _deserialize_json_fields,
_serialize_json_fields, _serialize_json_fields,
) )
class TopologyDeckiesMixin: class TopologyDeckiesMixin(_MixinBase):
"""``self._assert_pending`` / ``self._check_and_bump_version`` resolve """``self._assert_pending`` / ``self._check_and_bump_version`` resolve
through ``TopologyCoreMixin`` via MRO.""" through ``TopologyCoreMixin`` via MRO."""

View File

@@ -9,7 +9,9 @@ from decnet.web.db.models import TopologyEdge, TopologyStatusEvent
from decnet.web.db.models.topology import EdgeRow from decnet.web.db.models.topology import EdgeRow
class TopologyEdgesMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class TopologyEdgesMixin(_MixinBase):
"""``self._assert_pending`` / ``self._check_and_bump_version`` resolve """``self._assert_pending`` / ``self._check_and_bump_version`` resolve
through ``TopologyCoreMixin`` via MRO.""" through ``TopologyCoreMixin`` via MRO."""

View File

@@ -9,7 +9,9 @@ from decnet.web.db.models import LAN, TopologyEdge
from decnet.web.db.models.topology import LANRow from decnet.web.db.models.topology import LANRow
class LansMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class LansMixin(_MixinBase):
"""``self._assert_pending`` / ``self._check_and_bump_version`` resolve """``self._assert_pending`` / ``self._check_and_bump_version`` resolve
through ``TopologyCoreMixin`` via MRO.""" through ``TopologyCoreMixin`` via MRO."""

View File

@@ -10,7 +10,9 @@ from sqlalchemy import asc, desc, select, text
from decnet.web.db.models import TopologyMutation from decnet.web.db.models import TopologyMutation
class TopologyMutationsMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class TopologyMutationsMixin(_MixinBase):
"""``self._check_and_bump_version`` resolves through """``self._check_and_bump_version`` resolves through
``TopologyCoreMixin`` via MRO.""" ``TopologyCoreMixin`` via MRO."""

View File

@@ -5,11 +5,14 @@ from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional
from sqlalchemy import select, update from sqlalchemy import select, update
from sqlmodel import col
from decnet.web.db.models import WebhookSubscription from decnet.web.db.models import WebhookSubscription
class WebhooksMixin: from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
class WebhooksMixin(_MixinBase):
"""Mixin: composed onto ``SQLModelRepository``.""" """Mixin: composed onto ``SQLModelRepository``."""
async def create_webhook_subscription(self, data: dict[str, Any]) -> None: async def create_webhook_subscription(self, data: dict[str, Any]) -> None:
@@ -43,7 +46,7 @@ class WebhooksMixin:
async with self._session() as session: async with self._session() as session:
stmt = select(WebhookSubscription) stmt = select(WebhookSubscription)
if enabled_only: if enabled_only:
stmt = stmt.where(WebhookSubscription.enabled.is_(True)) stmt = stmt.where(col(WebhookSubscription.enabled).is_(True))
stmt = stmt.order_by(WebhookSubscription.created_at) stmt = stmt.order_by(WebhookSubscription.created_at)
result = await session.execute(stmt) result = await session.execute(stmt)
return [r.model_dump() for r in result.scalars().all()] return [r.model_dump() for r in result.scalars().all()]
@@ -100,7 +103,7 @@ class WebhooksMixin:
# the counter informs the circuit-breaker heuristic, not a # the counter informs the circuit-breaker heuristic, not a
# correctness invariant. # correctness invariant.
result = await session.execute( result = await session.execute(
select(WebhookSubscription.consecutive_failures).where( select(col(WebhookSubscription.consecutive_failures)).where(
WebhookSubscription.uuid == uuid WebhookSubscription.uuid == uuid
) )
) )

View File

@@ -154,3 +154,11 @@ ignore_missing_imports = true
check_untyped_defs = true check_untyped_defs = true
warn_redundant_casts = true warn_redundant_casts = true
warn_unused_ignores = true warn_unused_ignores = true
[[tool.mypy.overrides]]
# The pydantic plugin types SQLModel class-level field descriptors as their
# Python value types (str, bool, …) instead of InstrumentedAttribute. Every
# .where(Model.field == value) then becomes where(bool) — a false positive.
# Suppressing arg-type here; genuine argument errors are caught by call-arg.
module = "decnet.web.db.sqlmodel_repo.*"
disable_error_code = ["arg-type"]