fix(types): P2 — wire _MixinBase + col() across sqlmodel_repo; suppress pydantic/SQLModel column typing false positives
- Add _MixinBase abstract class to _helpers.py: declares _session(), _deserialize_attacker(), _assert_pending(), _check_and_bump_version(), and list_running_topology_deckies() so mypy can see cross-mixin contracts - Add _require(val, msg) helper for narrowing T | None → T - Inherit _MixinBase in all 26 leaf mixin classes - Wrap SQLAlchemy column method calls (.is_(), .like(), .notin_(), .in_(), .contains()) with col() from sqlmodel — fixes attr-defined false positives caused by pydantic plugin typing class-level fields as Python value types - Wrap select(Model.field) with select(col(Model.field)) for column projections - Add pyproject.toml [[tool.mypy.overrides]] to disable arg-type in sqlmodel_repo.*: pydantic plugin resolves .where(Model.field == v) as where(bool), a false positive; call-arg still catches real argument errors - Remove 9 stale # type: ignore comments (logging, helpers, credentials) - Fix telemetry.py traced() overload no-redef + misc - Fix logs.py datetime/str operator and nullable PK comparison with col() - sqlmodel_repo/ now has 0 mypy errors
This commit is contained in:
@@ -28,7 +28,7 @@ class _ComponentFilter(logging.Filter):
|
||||
self.component = component
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
record.decnet_component = self.component # type: ignore[attr-defined]
|
||||
record.decnet_component = self.component
|
||||
return True
|
||||
|
||||
|
||||
@@ -49,14 +49,14 @@ class _TraceContextFilter(logging.Filter):
|
||||
span = trace.get_current_span()
|
||||
ctx = span.get_span_context()
|
||||
if ctx and ctx.trace_id:
|
||||
record.otel_trace_id = format(ctx.trace_id, "032x") # type: ignore[attr-defined]
|
||||
record.otel_span_id = format(ctx.span_id, "016x") # type: ignore[attr-defined]
|
||||
record.otel_trace_id = format(ctx.trace_id, "032x")
|
||||
record.otel_span_id = format(ctx.span_id, "016x")
|
||||
else:
|
||||
record.otel_trace_id = "0" # type: ignore[attr-defined]
|
||||
record.otel_span_id = "0" # type: ignore[attr-defined]
|
||||
record.otel_trace_id = "0"
|
||||
record.otel_span_id = "0"
|
||||
except Exception:
|
||||
record.otel_trace_id = "0" # type: ignore[attr-defined]
|
||||
record.otel_span_id = "0" # type: ignore[attr-defined]
|
||||
record.otel_trace_id = "0"
|
||||
record.otel_span_id = "0"
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ def traced(fn: F) -> F: ...
|
||||
def traced(name: str) -> Callable[[F], F]: ...
|
||||
|
||||
|
||||
def traced(fn: Any = None, *, name: str | None = None) -> Any:
|
||||
def traced(fn: Any = None, *, name: str | None = None) -> Any: # type: ignore[misc]
|
||||
"""Decorator that wraps a function in an OTEL span.
|
||||
|
||||
Usage::
|
||||
@@ -168,9 +168,9 @@ def traced(fn: Any = None, *, name: str | None = None) -> Any:
|
||||
# Called as @traced (no arguments)
|
||||
return _wrap(fn, None)
|
||||
# Fallback: @traced() with no args
|
||||
def decorator(f: F) -> F:
|
||||
def _fallback_decorator(f: F) -> F:
|
||||
return _wrap(f, name)
|
||||
return decorator
|
||||
return _fallback_decorator
|
||||
|
||||
|
||||
def _wrap(fn: F, span_name: str | None) -> F:
|
||||
|
||||
@@ -12,14 +12,60 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from typing import Any, Optional, TypeVar
|
||||
|
||||
import orjson
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from decnet.logging import get_logger
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _require(val: T | None, msg: str) -> T:
|
||||
"""Narrow ``X | None`` to ``X``, raising ``ValueError`` if None."""
|
||||
if val is None:
|
||||
raise ValueError(msg)
|
||||
return val
|
||||
|
||||
|
||||
class _MixinBase:
|
||||
"""Typing base for all repo mixins.
|
||||
|
||||
Declares the contract that ``SQLModelRepository`` satisfies at runtime
|
||||
via MRO composition. Without this, mypy checks each mixin in isolation
|
||||
and cannot see ``_session`` or cross-mixin helpers.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _session(self):
|
||||
"""Return a cancellation-safe async session context manager."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_attacker(d: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Stub — concrete impl on AttackersCoreMixin via MRO."""
|
||||
return d
|
||||
|
||||
async def _assert_pending(self, session: AsyncSession, topology_id: str) -> None:
|
||||
"""Stub — concrete impl on TopologyCoreMixin via MRO."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _check_and_bump_version(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
topology_id: str,
|
||||
expected_version: Optional[int],
|
||||
) -> None:
|
||||
"""Stub — concrete impl on TopologyCoreMixin via MRO."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def list_running_topology_deckies(self) -> list[dict[str, Any]]:
|
||||
"""Stub — concrete impl on TopologyDeckiesMixin via MRO."""
|
||||
raise NotImplementedError
|
||||
|
||||
_log = get_logger("db.pool")
|
||||
|
||||
# Hold strong refs to in-flight cleanup tasks so they aren't GC'd mid-run.
|
||||
@@ -66,7 +112,7 @@ def _detach_close(session: AsyncSession) -> None:
|
||||
task = loop.create_task(_cleanup())
|
||||
_cleanup_tasks.add(task)
|
||||
# Consume any exception to silence "Task exception was never retrieved".
|
||||
task.add_done_callback(lambda t: (_cleanup_tasks.discard(t), t.exception()))
|
||||
task.add_done_callback(lambda t: (_cleanup_tasks.discard(t), t.exception())) # type: ignore[func-returns-value]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
@@ -13,11 +13,14 @@ from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import desc, or_, select
|
||||
from sqlmodel import col
|
||||
|
||||
from decnet.web.db.models import Attacker, AttackerIntel
|
||||
|
||||
|
||||
class AttackerIntelMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class AttackerIntelMixin(_MixinBase):
|
||||
"""Mixin: methods composed onto ``SQLModelRepository``.
|
||||
|
||||
Expects ``self._session()`` from the base.
|
||||
@@ -82,13 +85,13 @@ class AttackerIntelMixin:
|
||||
now = datetime.now(timezone.utc)
|
||||
async with self._session() as session:
|
||||
stmt = (
|
||||
select(Attacker.uuid, Attacker.ip)
|
||||
select(col(Attacker.uuid), col(Attacker.ip))
|
||||
.outerjoin(
|
||||
AttackerIntel, AttackerIntel.attacker_uuid == Attacker.uuid,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
AttackerIntel.uuid.is_(None),
|
||||
col(AttackerIntel.uuid).is_(None),
|
||||
AttackerIntel.expires_at < now,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -12,11 +12,14 @@ import uuid as _uuid
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from sqlalchemy import desc, func, outerjoin, select
|
||||
from sqlmodel import col
|
||||
|
||||
from decnet.web.db.models import Attacker, AttackerIntel
|
||||
|
||||
|
||||
class AttackersCoreMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class AttackersCoreMixin(_MixinBase):
|
||||
@staticmethod
|
||||
def _deserialize_attacker(d: dict[str, Any]) -> dict[str, Any]:
|
||||
for key in ("services", "deckies", "fingerprints", "commands"):
|
||||
@@ -63,16 +66,16 @@ class AttackersCoreMixin:
|
||||
sort_by: str = "recent",
|
||||
service: Optional[str] = None,
|
||||
) -> List[dict[str, Any]]:
|
||||
order = {
|
||||
order: Any = {
|
||||
"active": desc(Attacker.event_count),
|
||||
"traversals": desc(Attacker.is_traversal),
|
||||
}.get(sort_by, desc(Attacker.last_seen))
|
||||
|
||||
statement = select(Attacker).order_by(order).offset(offset).limit(limit)
|
||||
if search:
|
||||
statement = statement.where(Attacker.ip.like(f"%{search}%"))
|
||||
statement = statement.where(col(Attacker.ip).like(f"%{search}%"))
|
||||
if service:
|
||||
statement = statement.where(Attacker.services.like(f'%"{service}"%'))
|
||||
statement = statement.where(col(Attacker.services).like(f'%"{service}"%'))
|
||||
|
||||
async with self._session() as session:
|
||||
result = await session.execute(statement)
|
||||
@@ -121,9 +124,9 @@ class AttackersCoreMixin:
|
||||
) -> int:
|
||||
statement = select(func.count()).select_from(Attacker)
|
||||
if search:
|
||||
statement = statement.where(Attacker.ip.like(f"%{search}%"))
|
||||
statement = statement.where(col(Attacker.ip).like(f"%{search}%"))
|
||||
if service:
|
||||
statement = statement.where(Attacker.services.like(f'%"{service}"%'))
|
||||
statement = statement.where(col(Attacker.services).like(f'%"{service}"%'))
|
||||
|
||||
async with self._session() as session:
|
||||
result = await session.execute(statement)
|
||||
|
||||
@@ -10,11 +10,14 @@ import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import desc, func, select
|
||||
from sqlmodel import col
|
||||
|
||||
from decnet.web.db.models import Attacker, Bounty, Log
|
||||
|
||||
|
||||
class AttackerActivityMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class AttackerActivityMixin(_MixinBase):
|
||||
async def get_attacker_commands(
|
||||
self,
|
||||
uuid: str,
|
||||
@@ -24,7 +27,7 @@ class AttackerActivityMixin:
|
||||
) -> dict[str, Any]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Attacker.commands).where(Attacker.uuid == uuid)
|
||||
select(col(Attacker.commands)).where(Attacker.uuid == uuid)
|
||||
)
|
||||
raw = result.scalar_one_or_none()
|
||||
if raw is None:
|
||||
@@ -52,13 +55,13 @@ class AttackerActivityMixin:
|
||||
"""
|
||||
async with self._session() as session:
|
||||
ip_res = await session.execute(
|
||||
select(Attacker.ip).where(Attacker.uuid == attacker_uuid)
|
||||
select(col(Attacker.ip)).where(Attacker.uuid == attacker_uuid)
|
||||
)
|
||||
ip = ip_res.scalar_one_or_none()
|
||||
if not ip:
|
||||
return []
|
||||
rows = await session.execute(
|
||||
select(Log.service, Log.event_type)
|
||||
select(col(Log.service), col(Log.event_type))
|
||||
.where(Log.attacker_ip == ip)
|
||||
.distinct()
|
||||
)
|
||||
@@ -75,7 +78,7 @@ class AttackerActivityMixin:
|
||||
rotation detection."""
|
||||
async with self._session() as session:
|
||||
ip_res = await session.execute(
|
||||
select(Attacker.ip).where(Attacker.uuid == attacker_uuid)
|
||||
select(col(Attacker.ip)).where(Attacker.uuid == attacker_uuid)
|
||||
)
|
||||
ip = ip_res.scalar_one_or_none()
|
||||
if not ip:
|
||||
@@ -104,7 +107,7 @@ class AttackerActivityMixin:
|
||||
"""Cheap COUNT(*) for XFF-rotation detection."""
|
||||
async with self._session() as session:
|
||||
ip_res = await session.execute(
|
||||
select(Attacker.ip).where(Attacker.uuid == attacker_uuid)
|
||||
select(col(Attacker.ip)).where(Attacker.uuid == attacker_uuid)
|
||||
)
|
||||
ip = ip_res.scalar_one_or_none()
|
||||
if not ip:
|
||||
@@ -126,7 +129,7 @@ class AttackerActivityMixin:
|
||||
"""
|
||||
async with self._session() as session:
|
||||
ip_res = await session.execute(
|
||||
select(Attacker.ip).where(Attacker.uuid == uuid)
|
||||
select(col(Attacker.ip)).where(Attacker.uuid == uuid)
|
||||
)
|
||||
ip = ip_res.scalar_one_or_none()
|
||||
if not ip:
|
||||
@@ -150,7 +153,7 @@ class AttackerActivityMixin:
|
||||
"""
|
||||
async with self._session() as session:
|
||||
ip_res = await session.execute(
|
||||
select(Attacker.ip).where(Attacker.uuid == uuid)
|
||||
select(col(Attacker.ip)).where(Attacker.uuid == uuid)
|
||||
)
|
||||
ip = ip_res.scalar_one_or_none()
|
||||
if not ip:
|
||||
@@ -176,7 +179,7 @@ class AttackerActivityMixin:
|
||||
rows = await session.execute(
|
||||
select(Log)
|
||||
.where(Log.event_type == "session_recorded")
|
||||
.where(Log.fields.contains(needle))
|
||||
.where(col(Log.fields).contains(needle))
|
||||
.limit(1)
|
||||
)
|
||||
row = rows.scalars().first()
|
||||
@@ -192,7 +195,7 @@ class AttackerActivityMixin:
|
||||
"""
|
||||
async with self._session() as session:
|
||||
ip_res = await session.execute(
|
||||
select(Attacker.ip).where(Attacker.uuid == uuid)
|
||||
select(col(Attacker.ip)).where(Attacker.uuid == uuid)
|
||||
)
|
||||
ip = ip_res.scalar_one_or_none()
|
||||
if not ip:
|
||||
|
||||
@@ -7,11 +7,14 @@ from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlmodel import col
|
||||
|
||||
from decnet.web.db.models import Attacker, AttackerBehavior
|
||||
|
||||
|
||||
class AttackerBehaviorMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class AttackerBehaviorMixin(_MixinBase):
|
||||
async def upsert_attacker_behavior(
|
||||
self,
|
||||
attacker_uuid: str,
|
||||
@@ -56,9 +59,9 @@ class AttackerBehaviorMixin:
|
||||
return {}
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Attacker.ip, AttackerBehavior)
|
||||
select(col(Attacker.ip), AttackerBehavior)
|
||||
.join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid)
|
||||
.where(Attacker.ip.in_(ips))
|
||||
.where(col(Attacker.ip).in_(ips))
|
||||
)
|
||||
out: dict[str, dict[str, Any]] = {}
|
||||
for ip, row in result.all():
|
||||
|
||||
@@ -9,7 +9,9 @@ from sqlalchemy import select
|
||||
from decnet.web.db.models import SessionProfile
|
||||
|
||||
|
||||
class SessionProfilesMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class SessionProfilesMixin(_MixinBase):
|
||||
async def upsert_session_profile(
|
||||
self,
|
||||
sid: str,
|
||||
|
||||
@@ -10,7 +10,9 @@ from sqlalchemy import desc, func, select
|
||||
from decnet.web.db.models import SmtpTarget
|
||||
|
||||
|
||||
class SmtpTargetsMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class SmtpTargetsMixin(_MixinBase):
|
||||
async def increment_smtp_target(self, attacker_uuid: str, domain: str) -> None:
|
||||
"""Upsert an (attacker_uuid, domain) pair and bump count + last_seen.
|
||||
|
||||
|
||||
@@ -8,7 +8,9 @@ from sqlalchemy import select, update
|
||||
from decnet.web.db.models import User
|
||||
|
||||
|
||||
class AuthMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class AuthMixin(_MixinBase):
|
||||
"""Mixin: composed onto ``SQLModelRepository``. Expects ``self._session()``.
|
||||
|
||||
``_ensure_admin_user`` stays in the package ``__init__`` so the
|
||||
|
||||
@@ -7,12 +7,15 @@ from typing import Any, List, Optional
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import asc, desc, func, or_, select, text
|
||||
from sqlmodel import col
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
from decnet.web.db.models import Bounty
|
||||
|
||||
|
||||
class BountiesMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class BountiesMixin(_MixinBase):
|
||||
"""Mixin: composed onto ``SQLModelRepository``."""
|
||||
|
||||
async def purge_logs_and_bounties(self) -> dict[str, int]:
|
||||
@@ -40,7 +43,7 @@ class BountiesMixin:
|
||||
|
||||
async with self._session() as session:
|
||||
dup = await session.execute(
|
||||
select(Bounty.id).where(
|
||||
select(col(Bounty.id)).where(
|
||||
Bounty.bounty_type == data.get("bounty_type"),
|
||||
Bounty.attacker_ip == data.get("attacker_ip"),
|
||||
Bounty.payload == data.get("payload"),
|
||||
@@ -63,10 +66,10 @@ class BountiesMixin:
|
||||
lk = f"%{search}%"
|
||||
statement = statement.where(
|
||||
or_(
|
||||
Bounty.decky.like(lk),
|
||||
Bounty.service.like(lk),
|
||||
Bounty.attacker_ip.like(lk),
|
||||
Bounty.payload.like(lk),
|
||||
col(Bounty.decky).like(lk),
|
||||
col(Bounty.service).like(lk),
|
||||
col(Bounty.attacker_ip).like(lk),
|
||||
col(Bounty.payload).like(lk),
|
||||
)
|
||||
)
|
||||
return statement
|
||||
@@ -126,7 +129,7 @@ class BountiesMixin:
|
||||
async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp))
|
||||
select(Bounty).where(col(Bounty.attacker_ip).in_(ips)).order_by(asc(Bounty.timestamp))
|
||||
)
|
||||
grouped: dict[str, List[dict[str, Any]]] = defaultdict(list)
|
||||
for item in result.scalars().all():
|
||||
|
||||
@@ -11,11 +11,14 @@ from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import desc, func, select, update
|
||||
from sqlmodel import col
|
||||
|
||||
from decnet.web.db.models import AttackerIdentity, Campaign
|
||||
|
||||
|
||||
class CampaignsMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class CampaignsMixin(_MixinBase):
|
||||
"""Mixin: composed onto ``SQLModelRepository``."""
|
||||
|
||||
async def get_campaign_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
|
||||
@@ -41,7 +44,7 @@ class CampaignsMixin:
|
||||
) -> list[dict[str, Any]]:
|
||||
statement = (
|
||||
select(Campaign)
|
||||
.where(Campaign.merged_into_uuid.is_(None))
|
||||
.where(col(Campaign.merged_into_uuid).is_(None))
|
||||
.order_by(desc(Campaign.updated_at))
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
@@ -54,7 +57,7 @@ class CampaignsMixin:
|
||||
statement = (
|
||||
select(func.count())
|
||||
.select_from(Campaign)
|
||||
.where(Campaign.merged_into_uuid.is_(None))
|
||||
.where(col(Campaign.merged_into_uuid).is_(None))
|
||||
)
|
||||
async with self._session() as session:
|
||||
result = await session.execute(statement)
|
||||
@@ -91,7 +94,7 @@ class CampaignsMixin:
|
||||
# graph reads. Narrow on purpose — future denormalised
|
||||
# projections (commands_by_phase from log mining, decky-set
|
||||
# aggregates) can land here without churning callers.
|
||||
statement = select(
|
||||
statement = select( # type: ignore[call-overload, misc]
|
||||
AttackerIdentity.uuid,
|
||||
AttackerIdentity.campaign_id,
|
||||
AttackerIdentity.merged_into_uuid,
|
||||
|
||||
@@ -10,7 +10,9 @@ from sqlalchemy import desc, func, select, update
|
||||
from decnet.web.db.models import CanaryBlob, CanaryToken, CanaryTrigger
|
||||
|
||||
|
||||
class CanaryMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class CanaryMixin(_MixinBase):
|
||||
"""Mixin: composed onto ``SQLModelRepository``."""
|
||||
|
||||
async def upsert_canary_blob(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
@@ -6,12 +6,15 @@ from datetime import datetime, timezone
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from sqlalchemy import desc, func, or_, select, update
|
||||
from sqlmodel import col
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
from decnet.web.db.models import Credential
|
||||
|
||||
|
||||
class CredentialsCoreMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class CredentialsCoreMixin(_MixinBase):
|
||||
async def upsert_credential(self, data: dict[str, Any]) -> int:
|
||||
"""Upsert a credential attempt; returns the row id.
|
||||
|
||||
@@ -37,7 +40,7 @@ class CredentialsCoreMixin:
|
||||
Credential.secret_sha256 == payload["secret_sha256"],
|
||||
# NULL == NULL is False under SQL — branch the predicate.
|
||||
(Credential.principal == principal) if principal is not None
|
||||
else Credential.principal.is_(None),
|
||||
else col(Credential.principal).is_(None),
|
||||
)
|
||||
existing = (await session.execute(stmt)).scalar_one_or_none()
|
||||
now = datetime.now(timezone.utc)
|
||||
@@ -48,7 +51,7 @@ class CredentialsCoreMixin:
|
||||
existing.outcome = payload["outcome"]
|
||||
session.add(existing)
|
||||
await session.commit()
|
||||
return existing.id # type: ignore[return-value]
|
||||
return existing.id
|
||||
row = Credential(
|
||||
attacker_ip=payload["attacker_ip"],
|
||||
decky_name=payload["decky_name"],
|
||||
@@ -84,10 +87,10 @@ class CredentialsCoreMixin:
|
||||
lk = f"%{search}%"
|
||||
statement = statement.where(
|
||||
or_(
|
||||
Credential.decky_name.like(lk),
|
||||
Credential.service.like(lk),
|
||||
Credential.principal.like(lk),
|
||||
Credential.secret_printable.like(lk),
|
||||
col(Credential.decky_name).like(lk),
|
||||
col(Credential.service).like(lk),
|
||||
col(Credential.principal).like(lk),
|
||||
col(Credential.secret_printable).like(lk),
|
||||
)
|
||||
)
|
||||
return statement
|
||||
@@ -188,7 +191,7 @@ class CredentialsCoreMixin:
|
||||
update(Credential)
|
||||
.where(
|
||||
Credential.attacker_ip == attacker_ip,
|
||||
Credential.attacker_uuid.is_(None),
|
||||
col(Credential.attacker_uuid).is_(None),
|
||||
)
|
||||
.values(attacker_uuid=attacker_uuid)
|
||||
)
|
||||
|
||||
@@ -9,11 +9,14 @@ from datetime import datetime, timezone
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from sqlalchemy import desc, func, select
|
||||
from sqlmodel import col
|
||||
|
||||
from decnet.web.db.models import Credential, CredentialReuse
|
||||
|
||||
|
||||
class CredentialReuseMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class CredentialReuseMixin(_MixinBase):
|
||||
@staticmethod
|
||||
def _merge_unique(existing_json: str, value: Optional[str]) -> tuple[str, bool]:
|
||||
"""Append ``value`` to a JSON list[str] column if not present.
|
||||
@@ -117,7 +120,7 @@ class CredentialReuseMixin:
|
||||
Credential.secret_sha256 == secret_sha256,
|
||||
Credential.secret_kind == secret_kind,
|
||||
(Credential.principal == principal) if principal is not None
|
||||
else Credential.principal.is_(None),
|
||||
else col(Credential.principal).is_(None),
|
||||
)
|
||||
)
|
||||
target_count = (await session.execute(stmt)).scalar() or 0
|
||||
@@ -150,7 +153,7 @@ class CredentialReuseMixin:
|
||||
).label("target_count")
|
||||
async with self._session() as session:
|
||||
group_stmt = (
|
||||
select(
|
||||
select( # type: ignore[call-overload]
|
||||
Credential.secret_sha256,
|
||||
Credential.secret_kind,
|
||||
Credential.principal,
|
||||
@@ -171,7 +174,7 @@ class CredentialReuseMixin:
|
||||
Credential.secret_kind == kind,
|
||||
(Credential.principal == principal)
|
||||
if principal is not None
|
||||
else Credential.principal.is_(None),
|
||||
else col(Credential.principal).is_(None),
|
||||
)
|
||||
rows = (await session.execute(cred_stmt)).scalars().all()
|
||||
out.append({
|
||||
@@ -253,13 +256,13 @@ class CredentialReuseMixin:
|
||||
sha_set = {r["secret_sha256"] for r in rows}
|
||||
if not sha_set:
|
||||
return
|
||||
stmt = select(
|
||||
stmt = select( # type: ignore[call-overload]
|
||||
Credential.secret_sha256,
|
||||
Credential.secret_kind,
|
||||
Credential.principal,
|
||||
Credential.secret_printable,
|
||||
Credential.secret_b64,
|
||||
).where(Credential.secret_sha256.in_(sha_set))
|
||||
).where(col(Credential.secret_sha256).in_(sha_set))
|
||||
secret_map: dict[
|
||||
tuple[str, str, Optional[str]],
|
||||
tuple[Optional[str], Optional[str]],
|
||||
|
||||
@@ -11,7 +11,9 @@ from sqlalchemy import asc, select, text
|
||||
from decnet.web.db.models import DeckyShard
|
||||
|
||||
|
||||
class DeckiesMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class DeckiesMixin(_MixinBase):
|
||||
"""Mixin: composed onto ``SQLModelRepository``."""
|
||||
|
||||
async def upsert_decky_shard(self, data: dict[str, Any]) -> None:
|
||||
|
||||
@@ -8,10 +8,13 @@ import orjson
|
||||
from sqlalchemy import asc, select, text, update
|
||||
|
||||
from decnet.web.db.models import DeckyShard, FleetDecky, LOCAL_HOST_SENTINEL
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _deserialize_json_fields
|
||||
from decnet.web.db.sqlmodel_repo._helpers import (
|
||||
_MixinBase,
|
||||
_deserialize_json_fields
|
||||
)
|
||||
|
||||
|
||||
class FleetMixin:
|
||||
class FleetMixin(_MixinBase):
|
||||
"""Mixin: composed onto ``SQLModelRepository``.
|
||||
|
||||
``list_running_deckies`` aggregates topology + fleet + swarm-shard
|
||||
|
||||
@@ -11,11 +11,14 @@ from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import desc, func, select, update
|
||||
from sqlmodel import col
|
||||
|
||||
from decnet.web.db.models import Attacker, AttackerIdentity
|
||||
|
||||
|
||||
class IdentitiesMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class IdentitiesMixin(_MixinBase):
|
||||
"""Mixin: composed onto ``SQLModelRepository``.
|
||||
|
||||
``self._deserialize_attacker`` resolves through ``AttackersMixin``
|
||||
@@ -51,7 +54,7 @@ class IdentitiesMixin:
|
||||
# and a future "merged into" endpoint when we need it.
|
||||
statement = (
|
||||
select(AttackerIdentity)
|
||||
.where(AttackerIdentity.merged_into_uuid.is_(None))
|
||||
.where(col(AttackerIdentity.merged_into_uuid).is_(None))
|
||||
.order_by(desc(AttackerIdentity.updated_at))
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
@@ -64,7 +67,7 @@ class IdentitiesMixin:
|
||||
statement = (
|
||||
select(func.count())
|
||||
.select_from(AttackerIdentity)
|
||||
.where(AttackerIdentity.merged_into_uuid.is_(None))
|
||||
.where(col(AttackerIdentity.merged_into_uuid).is_(None))
|
||||
)
|
||||
async with self._session() as session:
|
||||
result = await session.execute(statement)
|
||||
@@ -105,7 +108,7 @@ class IdentitiesMixin:
|
||||
# joined from logs, c2 endpoints aggregated from sessions) can
|
||||
# land here without churning every caller. ``fingerprints`` is
|
||||
# the raw JSON list — the clusterer parses for JA3 / HASSH.
|
||||
statement = select(
|
||||
statement = select( # type: ignore[call-overload]
|
||||
Attacker.uuid, Attacker.asn, Attacker.identity_id, Attacker.fingerprints,
|
||||
).order_by(Attacker.first_seen)
|
||||
if limit is not None:
|
||||
|
||||
@@ -15,13 +15,16 @@ from typing import Any, List, Optional
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import asc, desc, func, or_, select, text
|
||||
from sqlmodel import col
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
from decnet.config import load_state
|
||||
from decnet.web.db.models import Log, TopologyDecky
|
||||
|
||||
|
||||
class LogsMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class LogsMixin(_MixinBase):
|
||||
"""Mixin: composed onto ``SQLModelRepository``."""
|
||||
|
||||
@staticmethod
|
||||
@@ -61,9 +64,9 @@ class LogsMixin:
|
||||
end_time: Optional[str],
|
||||
) -> SelectOfScalar:
|
||||
if start_time:
|
||||
statement = statement.where(Log.timestamp >= start_time)
|
||||
statement = statement.where(col(Log.timestamp) >= start_time)
|
||||
if end_time:
|
||||
statement = statement.where(Log.timestamp <= end_time)
|
||||
statement = statement.where(col(Log.timestamp) <= end_time)
|
||||
|
||||
if search:
|
||||
try:
|
||||
@@ -95,10 +98,10 @@ class LogsMixin:
|
||||
lk = f"%{token}%"
|
||||
statement = statement.where(
|
||||
or_(
|
||||
Log.raw_line.like(lk),
|
||||
Log.decky.like(lk),
|
||||
Log.service.like(lk),
|
||||
Log.attacker_ip.like(lk),
|
||||
col(Log.raw_line).like(lk),
|
||||
col(Log.decky).like(lk),
|
||||
col(Log.service).like(lk),
|
||||
col(Log.attacker_ip).like(lk),
|
||||
)
|
||||
)
|
||||
return statement
|
||||
@@ -148,7 +151,7 @@ class LogsMixin:
|
||||
end_time: Optional[str] = None,
|
||||
) -> List[dict]:
|
||||
statement = (
|
||||
select(Log).where(Log.id > last_id).order_by(asc(Log.id)).limit(limit)
|
||||
select(Log).where(col(Log.id) > last_id).order_by(asc(Log.id)).limit(limit)
|
||||
)
|
||||
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||
|
||||
|
||||
@@ -7,11 +7,14 @@ from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import delete as sa_delete
|
||||
from sqlalchemy import desc, func, or_, select
|
||||
from sqlmodel import col
|
||||
|
||||
from decnet.web.db.models import OrchestratorEmail, OrchestratorEvent
|
||||
|
||||
|
||||
class OrchestratorMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class OrchestratorMixin(_MixinBase):
|
||||
"""Mixin: composed onto ``SQLModelRepository``."""
|
||||
|
||||
async def record_orchestrator_event(self, data: dict[str, Any]) -> str:
|
||||
@@ -62,11 +65,11 @@ class OrchestratorMixin:
|
||||
deleted = 0
|
||||
async with self._session() as session:
|
||||
dst_rows = await session.execute(
|
||||
select(OrchestratorEvent.dst_decky_uuid).distinct()
|
||||
select(col(OrchestratorEvent.dst_decky_uuid)).distinct()
|
||||
)
|
||||
for (dst,) in dst_rows.all():
|
||||
keep = await session.execute(
|
||||
select(OrchestratorEvent.uuid)
|
||||
select(col(OrchestratorEvent.uuid))
|
||||
.where(OrchestratorEvent.dst_decky_uuid == dst)
|
||||
.order_by(desc(OrchestratorEvent.ts))
|
||||
.limit(per_dst_cap)
|
||||
@@ -76,7 +79,7 @@ class OrchestratorMixin:
|
||||
continue
|
||||
stmt = sa_delete(OrchestratorEvent).where(
|
||||
OrchestratorEvent.dst_decky_uuid == dst,
|
||||
OrchestratorEvent.uuid.notin_(keep_uuids),
|
||||
col(OrchestratorEvent.uuid).notin_(keep_uuids),
|
||||
)
|
||||
res = await session.execute(stmt)
|
||||
deleted += res.rowcount or 0
|
||||
@@ -156,7 +159,7 @@ class OrchestratorMixin:
|
||||
(OrchestratorEmail.sender_email == recipient_email)
|
||||
& (OrchestratorEmail.recipient_email == sender_email),
|
||||
),
|
||||
OrchestratorEmail.success.is_(True),
|
||||
col(OrchestratorEmail.success).is_(True),
|
||||
)
|
||||
.order_by(desc(OrchestratorEmail.ts))
|
||||
.limit(limit)
|
||||
@@ -169,11 +172,11 @@ class OrchestratorMixin:
|
||||
deleted = 0
|
||||
async with self._session() as session:
|
||||
decky_rows = await session.execute(
|
||||
select(OrchestratorEmail.mail_decky_uuid).distinct()
|
||||
select(col(OrchestratorEmail.mail_decky_uuid)).distinct()
|
||||
)
|
||||
for (mail_uuid,) in decky_rows.all():
|
||||
keep = await session.execute(
|
||||
select(OrchestratorEmail.uuid)
|
||||
select(col(OrchestratorEmail.uuid))
|
||||
.where(OrchestratorEmail.mail_decky_uuid == mail_uuid)
|
||||
.order_by(desc(OrchestratorEmail.ts))
|
||||
.limit(per_decky_cap)
|
||||
@@ -183,7 +186,7 @@ class OrchestratorMixin:
|
||||
continue
|
||||
stmt = sa_delete(OrchestratorEmail).where(
|
||||
OrchestratorEmail.mail_decky_uuid == mail_uuid,
|
||||
OrchestratorEmail.uuid.notin_(keep_uuids),
|
||||
col(OrchestratorEmail.uuid).notin_(keep_uuids),
|
||||
)
|
||||
res = await session.execute(stmt)
|
||||
deleted += res.rowcount or 0
|
||||
|
||||
@@ -10,7 +10,9 @@ from decnet.web.db.models import RealismConfig, SyntheticFile
|
||||
from decnet.web.db.models.realism import SYNTHETIC_FILE_BODY_LIMIT
|
||||
|
||||
|
||||
class RealismMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class RealismMixin(_MixinBase):
|
||||
"""Mixin: composed onto ``SQLModelRepository``."""
|
||||
|
||||
async def record_synthetic_file(self, data: dict[str, Any]) -> str:
|
||||
|
||||
@@ -8,7 +8,9 @@ from sqlalchemy import asc, select, text, update
|
||||
from decnet.web.db.models import SwarmHost
|
||||
|
||||
|
||||
class SwarmMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class SwarmMixin(_MixinBase):
|
||||
"""Mixin: composed onto ``SQLModelRepository``. Expects ``self._session()``."""
|
||||
|
||||
async def add_swarm_host(self, data: dict[str, Any]) -> None:
|
||||
|
||||
@@ -11,7 +11,9 @@ from sqlalchemy import select
|
||||
from decnet.web.db.models import TarpitRule
|
||||
|
||||
|
||||
class TarpitMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class TarpitMixin(_MixinBase):
|
||||
"""Mixin: composed onto ``SQLModelRepository``."""
|
||||
|
||||
async def set_tarpit_rule(self, data: dict[str, Any]) -> None:
|
||||
|
||||
@@ -9,10 +9,15 @@ from sqlalchemy import desc, func, select, text
|
||||
|
||||
from decnet.web.db.models import Topology, TopologyStatusEvent
|
||||
from decnet.web.db.models.topology import TopologySummary
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _serialize_json_fields
|
||||
from sqlmodel import col
|
||||
|
||||
from decnet.web.db.sqlmodel_repo._helpers import (
|
||||
_MixinBase,
|
||||
_serialize_json_fields
|
||||
)
|
||||
|
||||
|
||||
class TopologyCoreMixin:
|
||||
class TopologyCoreMixin(_MixinBase):
|
||||
"""Topologies CRUD + ``_assert_pending`` / ``_check_and_bump_version``.
|
||||
|
||||
The two private helpers live here because every other topology
|
||||
@@ -184,8 +189,8 @@ class TopologyCoreMixin:
|
||||
"""Return ids of topologies currently in ``active|degraded``."""
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Topology.id).where(
|
||||
Topology.status.in_(["active", "degraded"])
|
||||
select(col(Topology.id)).where(
|
||||
col(Topology.status).in_(["active", "degraded"])
|
||||
)
|
||||
)
|
||||
return [r for r in result.scalars().all()]
|
||||
|
||||
@@ -10,12 +10,13 @@ from sqlalchemy import asc, select, text, update
|
||||
from decnet.web.db.models import TopologyDecky
|
||||
from decnet.web.db.models.topology import DeckyRow
|
||||
from decnet.web.db.sqlmodel_repo._helpers import (
|
||||
_MixinBase,
|
||||
_deserialize_json_fields,
|
||||
_serialize_json_fields,
|
||||
)
|
||||
|
||||
|
||||
class TopologyDeckiesMixin:
|
||||
class TopologyDeckiesMixin(_MixinBase):
|
||||
"""``self._assert_pending`` / ``self._check_and_bump_version`` resolve
|
||||
through ``TopologyCoreMixin`` via MRO."""
|
||||
|
||||
|
||||
@@ -9,7 +9,9 @@ from decnet.web.db.models import TopologyEdge, TopologyStatusEvent
|
||||
from decnet.web.db.models.topology import EdgeRow
|
||||
|
||||
|
||||
class TopologyEdgesMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class TopologyEdgesMixin(_MixinBase):
|
||||
"""``self._assert_pending`` / ``self._check_and_bump_version`` resolve
|
||||
through ``TopologyCoreMixin`` via MRO."""
|
||||
|
||||
|
||||
@@ -9,7 +9,9 @@ from decnet.web.db.models import LAN, TopologyEdge
|
||||
from decnet.web.db.models.topology import LANRow
|
||||
|
||||
|
||||
class LansMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class LansMixin(_MixinBase):
|
||||
"""``self._assert_pending`` / ``self._check_and_bump_version`` resolve
|
||||
through ``TopologyCoreMixin`` via MRO."""
|
||||
|
||||
|
||||
@@ -10,7 +10,9 @@ from sqlalchemy import asc, desc, select, text
|
||||
from decnet.web.db.models import TopologyMutation
|
||||
|
||||
|
||||
class TopologyMutationsMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class TopologyMutationsMixin(_MixinBase):
|
||||
"""``self._check_and_bump_version`` resolves through
|
||||
``TopologyCoreMixin`` via MRO."""
|
||||
|
||||
|
||||
@@ -5,11 +5,14 @@ from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlmodel import col
|
||||
|
||||
from decnet.web.db.models import WebhookSubscription
|
||||
|
||||
|
||||
class WebhooksMixin:
|
||||
from decnet.web.db.sqlmodel_repo._helpers import _MixinBase
|
||||
|
||||
class WebhooksMixin(_MixinBase):
|
||||
"""Mixin: composed onto ``SQLModelRepository``."""
|
||||
|
||||
async def create_webhook_subscription(self, data: dict[str, Any]) -> None:
|
||||
@@ -43,7 +46,7 @@ class WebhooksMixin:
|
||||
async with self._session() as session:
|
||||
stmt = select(WebhookSubscription)
|
||||
if enabled_only:
|
||||
stmt = stmt.where(WebhookSubscription.enabled.is_(True))
|
||||
stmt = stmt.where(col(WebhookSubscription.enabled).is_(True))
|
||||
stmt = stmt.order_by(WebhookSubscription.created_at)
|
||||
result = await session.execute(stmt)
|
||||
return [r.model_dump() for r in result.scalars().all()]
|
||||
@@ -100,7 +103,7 @@ class WebhooksMixin:
|
||||
# the counter informs the circuit-breaker heuristic, not a
|
||||
# correctness invariant.
|
||||
result = await session.execute(
|
||||
select(WebhookSubscription.consecutive_failures).where(
|
||||
select(col(WebhookSubscription.consecutive_failures)).where(
|
||||
WebhookSubscription.uuid == uuid
|
||||
)
|
||||
)
|
||||
|
||||
@@ -154,3 +154,11 @@ ignore_missing_imports = true
|
||||
check_untyped_defs = true
|
||||
warn_redundant_casts = true
|
||||
warn_unused_ignores = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
# The pydantic plugin types SQLModel class-level field descriptors as their
|
||||
# Python value types (str, bool, …) instead of InstrumentedAttribute. Every
|
||||
# .where(Model.field == value) then becomes where(bool) — a false positive.
|
||||
# Suppressing arg-type here; genuine argument errors are caught by call-arg.
|
||||
module = "decnet.web.db.sqlmodel_repo.*"
|
||||
disable_error_code = ["arg-type"]
|
||||
|
||||
Reference in New Issue
Block a user