From bd406090a78b7afd44346fcfb32cc9bdc8a201b7 Mon Sep 17 00:00:00 2001 From: anti Date: Fri, 17 Apr 2026 14:49:13 -0400 Subject: [PATCH] fix: re-seed admin password when still unfinalized (must_change_password=True) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _ensure_admin_user was strict insert-if-missing: once a stale hash landed in decnet.db (e.g. from a deploy that used a different DECNET_ADMIN_PASSWORD), login silently 401'd because changing the env var later had no effect. Now on startup: if the admin still has must_change_password=True (they never finalized their own password), re-sync the hash from the current env var. Once the admin sets a real password, we leave it alone. Found via locustfile.py login storm — see tests/test_admin_seed.py. Note: this commit also bundles uncommitted pool-management work already present in sqlmodel_repo.py from prior sessions. --- decnet/web/db/sqlmodel_repo.py | 129 +++++++++++++++++++++++++-------- tests/test_admin_seed.py | 66 +++++++++++++++++ 2 files changed, 164 insertions(+), 31 deletions(-) create mode 100644 tests/test_admin_seed.py diff --git a/decnet/web/db/sqlmodel_repo.py b/decnet/web/db/sqlmodel_repo.py index 3b0cf86..35b5fcf 100644 --- a/decnet/web/db/sqlmodel_repo.py +++ b/decnet/web/db/sqlmodel_repo.py @@ -28,6 +28,60 @@ from decnet.web.db.repository import BaseRepository from decnet.web.db.models import User, Log, Bounty, State, Attacker, AttackerBehavior +from contextlib import asynccontextmanager + +from decnet.logging import get_logger + +_log = get_logger("db.pool") + + +async def _force_close(session: AsyncSession) -> None: + """Close a session, forcing connection invalidation if clean close fails. + + Shielded from cancellation and catches every exception class including + CancelledError. If session.close() fails (corrupted connection), we + invalidate the underlying connection so the pool discards it entirely + rather than leaving it checked-out forever. + """ + try: + await asyncio.shield(session.close()) + except BaseException: + # close() failed — connection is likely corrupted. + # Try to invalidate the raw connection so the pool drops it. + try: + bind = session.get_bind() + if hasattr(bind, "dispose"): + pass # don't dispose the whole engine + # The sync_session holds the connection record; invalidating + # it tells the pool to discard rather than reuse. + sync = session.sync_session + if sync.is_active: + sync.rollback() + sync.close() + except BaseException: + _log.debug("force-close: fallback cleanup failed", exc_info=True) + + +@asynccontextmanager +async def _safe_session(factory: async_sessionmaker[AsyncSession]): + """Session context manager that shields cleanup from cancellation. + + Under high concurrency, uvicorn cancels request tasks when clients + disconnect. If a CancelledError hits during session.__aexit__, + the underlying DB connection is orphaned — never returned to the + pool. This wrapper ensures close() always completes, preventing + the pool-drain death spiral. + """ + session = factory() + try: + yield session + except BaseException: + await _force_close(session) + raise + else: + await _force_close(session) + + class SQLModelRepository(BaseRepository): """Concrete SQLModel/SQLAlchemy-async repository. @@ -38,6 +92,10 @@ class SQLModelRepository(BaseRepository): engine: AsyncEngine session_factory: async_sessionmaker[AsyncSession] + def _session(self): + """Return a cancellation-safe session context manager.""" + return _safe_session(self.session_factory) + # ------------------------------------------------------------ lifecycle async def initialize(self) -> None: @@ -56,11 +114,12 @@ class SQLModelRepository(BaseRepository): await self._ensure_admin_user() async def _ensure_admin_user(self) -> None: - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute( select(User).where(User.username == DECNET_ADMIN_USER) ) - if not result.scalar_one_or_none(): + existing = result.scalar_one_or_none() + if existing is None: session.add(User( uuid=str(uuid.uuid4()), username=DECNET_ADMIN_USER, @@ -69,6 +128,14 @@ class SQLModelRepository(BaseRepository): must_change_password=True, )) await session.commit() + return + # Self-heal env drift: if admin never finalized their password, + # re-sync the hash from DECNET_ADMIN_PASSWORD. Otherwise leave + # the user's chosen password alone. + if existing.must_change_password: + existing.password_hash = get_password_hash(DECNET_ADMIN_PASSWORD) + session.add(existing) + await session.commit() async def _migrate_attackers_table(self) -> None: """Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable).""" @@ -88,7 +155,7 @@ class SQLModelRepository(BaseRepository): except ValueError: pass - async with self.session_factory() as session: + async with self._session() as session: session.add(Log(**data)) await session.commit() @@ -171,12 +238,12 @@ class SQLModelRepository(BaseRepository): ) statement = self._apply_filters(statement, search, start_time, end_time) - async with self.session_factory() as session: + async with self._session() as session: results = await session.execute(statement) return [log.model_dump(mode="json") for log in results.scalars().all()] async def get_max_log_id(self) -> int: - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute(select(func.max(Log.id))) val = result.scalar() return val if val is not None else 0 @@ -194,7 +261,7 @@ class SQLModelRepository(BaseRepository): ) statement = self._apply_filters(statement, search, start_time, end_time) - async with self.session_factory() as session: + async with self._session() as session: results = await session.execute(statement) return [log.model_dump(mode="json") for log in results.scalars().all()] @@ -207,7 +274,7 @@ class SQLModelRepository(BaseRepository): statement = select(func.count()).select_from(Log) statement = self._apply_filters(statement, search, start_time, end_time) - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute(statement) return result.scalar() or 0 @@ -222,7 +289,7 @@ class SQLModelRepository(BaseRepository): raise NotImplementedError async def get_stats_summary(self) -> dict[str, Any]: - async with self.session_factory() as session: + async with self._session() as session: total_logs = ( await session.execute(select(func.count()).select_from(Log)) ).scalar() or 0 @@ -249,7 +316,7 @@ class SQLModelRepository(BaseRepository): # --------------------------------------------------------------- users async def get_user_by_username(self, username: str) -> Optional[dict]: - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute( select(User).where(User.username == username) ) @@ -257,7 +324,7 @@ class SQLModelRepository(BaseRepository): return user.model_dump() if user else None async def get_user_by_uuid(self, uuid: str) -> Optional[dict]: - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute( select(User).where(User.uuid == uuid) ) @@ -265,14 +332,14 @@ class SQLModelRepository(BaseRepository): return user.model_dump() if user else None async def create_user(self, user_data: dict[str, Any]) -> None: - async with self.session_factory() as session: + async with self._session() as session: session.add(User(**user_data)) await session.commit() async def update_user_password( self, uuid: str, password_hash: str, must_change_password: bool = False ) -> None: - async with self.session_factory() as session: + async with self._session() as session: await session.execute( update(User) .where(User.uuid == uuid) @@ -284,12 +351,12 @@ class SQLModelRepository(BaseRepository): await session.commit() async def list_users(self) -> list[dict]: - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute(select(User)) return [u.model_dump() for u in result.scalars().all()] async def delete_user(self, uuid: str) -> bool: - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute(select(User).where(User.uuid == uuid)) user = result.scalar_one_or_none() if not user: @@ -299,14 +366,14 @@ class SQLModelRepository(BaseRepository): return True async def update_user_role(self, uuid: str, role: str) -> None: - async with self.session_factory() as session: + async with self._session() as session: await session.execute( update(User).where(User.uuid == uuid).values(role=role) ) await session.commit() async def purge_logs_and_bounties(self) -> dict[str, int]: - async with self.session_factory() as session: + async with self._session() as session: logs_deleted = (await session.execute(text("DELETE FROM logs"))).rowcount bounties_deleted = (await session.execute(text("DELETE FROM bounty"))).rowcount # attacker_behavior has FK → attackers.uuid; delete children first. @@ -326,7 +393,7 @@ class SQLModelRepository(BaseRepository): if "payload" in data and isinstance(data["payload"], dict): data["payload"] = json.dumps(data["payload"]) - async with self.session_factory() as session: + async with self._session() as session: dup = await session.execute( select(Bounty.id).where( Bounty.bounty_type == data.get("bounty_type"), @@ -374,7 +441,7 @@ class SQLModelRepository(BaseRepository): ) statement = self._apply_bounty_filters(statement, bounty_type, search) - async with self.session_factory() as session: + async with self._session() as session: results = await session.execute(statement) final = [] for item in results.scalars().all(): @@ -392,12 +459,12 @@ class SQLModelRepository(BaseRepository): statement = select(func.count()).select_from(Bounty) statement = self._apply_bounty_filters(statement, bounty_type, search) - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute(statement) return result.scalar() or 0 async def get_state(self, key: str) -> Optional[dict[str, Any]]: - async with self.session_factory() as session: + async with self._session() as session: statement = select(State).where(State.key == key) result = await session.execute(statement) state = result.scalar_one_or_none() @@ -406,7 +473,7 @@ class SQLModelRepository(BaseRepository): return None async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401 - async with self.session_factory() as session: + async with self._session() as session: statement = select(State).where(State.key == key) result = await session.execute(statement) state = result.scalar_one_or_none() @@ -424,7 +491,7 @@ class SQLModelRepository(BaseRepository): async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]: from collections import defaultdict - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute( select(Bounty).order_by(asc(Bounty.timestamp)) ) @@ -440,7 +507,7 @@ class SQLModelRepository(BaseRepository): async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]: from collections import defaultdict - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute( select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp)) ) @@ -455,7 +522,7 @@ class SQLModelRepository(BaseRepository): return dict(grouped) async def upsert_attacker(self, data: dict[str, Any]) -> str: - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute( select(Attacker).where(Attacker.ip == data["ip"]) ) @@ -477,7 +544,7 @@ class SQLModelRepository(BaseRepository): attacker_uuid: str, data: dict[str, Any], ) -> None: - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute( select(AttackerBehavior).where( AttackerBehavior.attacker_uuid == attacker_uuid @@ -497,7 +564,7 @@ class SQLModelRepository(BaseRepository): self, attacker_uuid: str, ) -> Optional[dict[str, Any]]: - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute( select(AttackerBehavior).where( AttackerBehavior.attacker_uuid == attacker_uuid @@ -514,7 +581,7 @@ class SQLModelRepository(BaseRepository): ) -> dict[str, dict[str, Any]]: if not ips: return {} - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute( select(Attacker.ip, AttackerBehavior) .join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid) @@ -556,7 +623,7 @@ class SQLModelRepository(BaseRepository): return d async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]: - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute( select(Attacker).where(Attacker.uuid == uuid) ) @@ -584,7 +651,7 @@ class SQLModelRepository(BaseRepository): if service: statement = statement.where(Attacker.services.like(f'%"{service}"%')) - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute(statement) return [ self._deserialize_attacker(a.model_dump(mode="json")) @@ -600,7 +667,7 @@ class SQLModelRepository(BaseRepository): if service: statement = statement.where(Attacker.services.like(f'%"{service}"%')) - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute(statement) return result.scalar() or 0 @@ -611,7 +678,7 @@ class SQLModelRepository(BaseRepository): offset: int = 0, service: Optional[str] = None, ) -> dict[str, Any]: - async with self.session_factory() as session: + async with self._session() as session: result = await session.execute( select(Attacker.commands).where(Attacker.uuid == uuid) ) diff --git a/tests/test_admin_seed.py b/tests/test_admin_seed.py new file mode 100644 index 0000000..4c91d5b --- /dev/null +++ b/tests/test_admin_seed.py @@ -0,0 +1,66 @@ +""" +Tests for _ensure_admin_user env-drift self-healing. + +Scenario: DECNET_ADMIN_PASSWORD changes between runs while the SQLite DB +persists on disk. Previously _ensure_admin_user was strictly insert-if-missing, +so the stale hash from the first seed locked out every subsequent login. + +Contract: if the admin still has must_change_password=True (they never +finalized their own password), the stored hash re-syncs from the env. +Once the admin picks a real password, we never touch it. +""" +import pytest + +from decnet.web.auth import verify_password +from decnet.web.db.sqlite.repository import SQLiteRepository + + +@pytest.mark.asyncio +async def test_admin_seeded_on_empty_db(tmp_path, monkeypatch): + monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "first") + repo = SQLiteRepository(db_path=str(tmp_path / "t.db")) + await repo.initialize() + user = await repo.get_user_by_username("admin") + assert user is not None + assert verify_password("first", user["password_hash"]) + assert user["must_change_password"] is True or user["must_change_password"] == 1 + + +@pytest.mark.asyncio +async def test_admin_password_resyncs_when_not_finalized(tmp_path, monkeypatch): + db = str(tmp_path / "t.db") + + monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "first") + r1 = SQLiteRepository(db_path=db) + await r1.initialize() + + monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "second") + r2 = SQLiteRepository(db_path=db) + await r2.initialize() + + user = await r2.get_user_by_username("admin") + assert verify_password("second", user["password_hash"]) + assert not verify_password("first", user["password_hash"]) + + +@pytest.mark.asyncio +async def test_finalized_admin_password_is_preserved(tmp_path, monkeypatch): + db = str(tmp_path / "t.db") + + monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "seed") + r1 = SQLiteRepository(db_path=db) + await r1.initialize() + admin = await r1.get_user_by_username("admin") + # Simulate the admin finalising their password via the change-password flow. + from decnet.web.auth import get_password_hash + await r1.update_user_password( + admin["uuid"], get_password_hash("chosen"), must_change_password=False + ) + + monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "different") + r2 = SQLiteRepository(db_path=db) + await r2.initialize() + + user = await r2.get_user_by_username("admin") + assert verify_password("chosen", user["password_hash"]) + assert not verify_password("different", user["password_hash"])