fix: re-seed admin password when still unfinalized (must_change_password=True)

_ensure_admin_user was strict insert-if-missing: once a stale hash landed
in decnet.db (e.g. from a deploy that used a different DECNET_ADMIN_PASSWORD),
login silently 401'd because changing the env var later had no effect.

Now on startup: if the admin still has must_change_password=True (they
never finalized their own password), re-sync the hash from the current
env var. Once the admin sets a real password, we leave it alone.

Found via locustfile.py login storm — see tests/test_admin_seed.py.

Note: this commit also bundles uncommitted pool-management work already
present in sqlmodel_repo.py from prior sessions.
This commit is contained in:
2026-04-17 14:49:13 -04:00
parent e22d057e68
commit bd406090a7
2 changed files with 164 additions and 31 deletions

View File

@@ -28,6 +28,60 @@ from decnet.web.db.repository import BaseRepository
from decnet.web.db.models import User, Log, Bounty, State, Attacker, AttackerBehavior from decnet.web.db.models import User, Log, Bounty, State, Attacker, AttackerBehavior
from contextlib import asynccontextmanager
from decnet.logging import get_logger
_log = get_logger("db.pool")
async def _force_close(session: AsyncSession) -> None:
"""Close a session, forcing connection invalidation if clean close fails.
Shielded from cancellation and catches every exception class including
CancelledError. If session.close() fails (corrupted connection), we
invalidate the underlying connection so the pool discards it entirely
rather than leaving it checked-out forever.
"""
try:
await asyncio.shield(session.close())
except BaseException:
# close() failed — connection is likely corrupted.
# Try to invalidate the raw connection so the pool drops it.
try:
bind = session.get_bind()
if hasattr(bind, "dispose"):
pass # don't dispose the whole engine
# The sync_session holds the connection record; invalidating
# it tells the pool to discard rather than reuse.
sync = session.sync_session
if sync.is_active:
sync.rollback()
sync.close()
except BaseException:
_log.debug("force-close: fallback cleanup failed", exc_info=True)
@asynccontextmanager
async def _safe_session(factory: async_sessionmaker[AsyncSession]):
"""Session context manager that shields cleanup from cancellation.
Under high concurrency, uvicorn cancels request tasks when clients
disconnect. If a CancelledError hits during session.__aexit__,
the underlying DB connection is orphaned — never returned to the
pool. This wrapper ensures close() always completes, preventing
the pool-drain death spiral.
"""
session = factory()
try:
yield session
except BaseException:
await _force_close(session)
raise
else:
await _force_close(session)
class SQLModelRepository(BaseRepository): class SQLModelRepository(BaseRepository):
"""Concrete SQLModel/SQLAlchemy-async repository. """Concrete SQLModel/SQLAlchemy-async repository.
@@ -38,6 +92,10 @@ class SQLModelRepository(BaseRepository):
engine: AsyncEngine engine: AsyncEngine
session_factory: async_sessionmaker[AsyncSession] session_factory: async_sessionmaker[AsyncSession]
def _session(self):
"""Return a cancellation-safe session context manager."""
return _safe_session(self.session_factory)
# ------------------------------------------------------------ lifecycle # ------------------------------------------------------------ lifecycle
async def initialize(self) -> None: async def initialize(self) -> None:
@@ -56,11 +114,12 @@ class SQLModelRepository(BaseRepository):
await self._ensure_admin_user() await self._ensure_admin_user()
async def _ensure_admin_user(self) -> None: async def _ensure_admin_user(self) -> None:
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(User).where(User.username == DECNET_ADMIN_USER) select(User).where(User.username == DECNET_ADMIN_USER)
) )
if not result.scalar_one_or_none(): existing = result.scalar_one_or_none()
if existing is None:
session.add(User( session.add(User(
uuid=str(uuid.uuid4()), uuid=str(uuid.uuid4()),
username=DECNET_ADMIN_USER, username=DECNET_ADMIN_USER,
@@ -69,6 +128,14 @@ class SQLModelRepository(BaseRepository):
must_change_password=True, must_change_password=True,
)) ))
await session.commit() await session.commit()
return
# Self-heal env drift: if admin never finalized their password,
# re-sync the hash from DECNET_ADMIN_PASSWORD. Otherwise leave
# the user's chosen password alone.
if existing.must_change_password:
existing.password_hash = get_password_hash(DECNET_ADMIN_PASSWORD)
session.add(existing)
await session.commit()
async def _migrate_attackers_table(self) -> None: async def _migrate_attackers_table(self) -> None:
"""Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable).""" """Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable)."""
@@ -88,7 +155,7 @@ class SQLModelRepository(BaseRepository):
except ValueError: except ValueError:
pass pass
async with self.session_factory() as session: async with self._session() as session:
session.add(Log(**data)) session.add(Log(**data))
await session.commit() await session.commit()
@@ -171,12 +238,12 @@ class SQLModelRepository(BaseRepository):
) )
statement = self._apply_filters(statement, search, start_time, end_time) statement = self._apply_filters(statement, search, start_time, end_time)
async with self.session_factory() as session: async with self._session() as session:
results = await session.execute(statement) results = await session.execute(statement)
return [log.model_dump(mode="json") for log in results.scalars().all()] return [log.model_dump(mode="json") for log in results.scalars().all()]
async def get_max_log_id(self) -> int: async def get_max_log_id(self) -> int:
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute(select(func.max(Log.id))) result = await session.execute(select(func.max(Log.id)))
val = result.scalar() val = result.scalar()
return val if val is not None else 0 return val if val is not None else 0
@@ -194,7 +261,7 @@ class SQLModelRepository(BaseRepository):
) )
statement = self._apply_filters(statement, search, start_time, end_time) statement = self._apply_filters(statement, search, start_time, end_time)
async with self.session_factory() as session: async with self._session() as session:
results = await session.execute(statement) results = await session.execute(statement)
return [log.model_dump(mode="json") for log in results.scalars().all()] return [log.model_dump(mode="json") for log in results.scalars().all()]
@@ -207,7 +274,7 @@ class SQLModelRepository(BaseRepository):
statement = select(func.count()).select_from(Log) statement = select(func.count()).select_from(Log)
statement = self._apply_filters(statement, search, start_time, end_time) statement = self._apply_filters(statement, search, start_time, end_time)
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute(statement) result = await session.execute(statement)
return result.scalar() or 0 return result.scalar() or 0
@@ -222,7 +289,7 @@ class SQLModelRepository(BaseRepository):
raise NotImplementedError raise NotImplementedError
async def get_stats_summary(self) -> dict[str, Any]: async def get_stats_summary(self) -> dict[str, Any]:
async with self.session_factory() as session: async with self._session() as session:
total_logs = ( total_logs = (
await session.execute(select(func.count()).select_from(Log)) await session.execute(select(func.count()).select_from(Log))
).scalar() or 0 ).scalar() or 0
@@ -249,7 +316,7 @@ class SQLModelRepository(BaseRepository):
# --------------------------------------------------------------- users # --------------------------------------------------------------- users
async def get_user_by_username(self, username: str) -> Optional[dict]: async def get_user_by_username(self, username: str) -> Optional[dict]:
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(User).where(User.username == username) select(User).where(User.username == username)
) )
@@ -257,7 +324,7 @@ class SQLModelRepository(BaseRepository):
return user.model_dump() if user else None return user.model_dump() if user else None
async def get_user_by_uuid(self, uuid: str) -> Optional[dict]: async def get_user_by_uuid(self, uuid: str) -> Optional[dict]:
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(User).where(User.uuid == uuid) select(User).where(User.uuid == uuid)
) )
@@ -265,14 +332,14 @@ class SQLModelRepository(BaseRepository):
return user.model_dump() if user else None return user.model_dump() if user else None
async def create_user(self, user_data: dict[str, Any]) -> None: async def create_user(self, user_data: dict[str, Any]) -> None:
async with self.session_factory() as session: async with self._session() as session:
session.add(User(**user_data)) session.add(User(**user_data))
await session.commit() await session.commit()
async def update_user_password( async def update_user_password(
self, uuid: str, password_hash: str, must_change_password: bool = False self, uuid: str, password_hash: str, must_change_password: bool = False
) -> None: ) -> None:
async with self.session_factory() as session: async with self._session() as session:
await session.execute( await session.execute(
update(User) update(User)
.where(User.uuid == uuid) .where(User.uuid == uuid)
@@ -284,12 +351,12 @@ class SQLModelRepository(BaseRepository):
await session.commit() await session.commit()
async def list_users(self) -> list[dict]: async def list_users(self) -> list[dict]:
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute(select(User)) result = await session.execute(select(User))
return [u.model_dump() for u in result.scalars().all()] return [u.model_dump() for u in result.scalars().all()]
async def delete_user(self, uuid: str) -> bool: async def delete_user(self, uuid: str) -> bool:
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute(select(User).where(User.uuid == uuid)) result = await session.execute(select(User).where(User.uuid == uuid))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user: if not user:
@@ -299,14 +366,14 @@ class SQLModelRepository(BaseRepository):
return True return True
async def update_user_role(self, uuid: str, role: str) -> None: async def update_user_role(self, uuid: str, role: str) -> None:
async with self.session_factory() as session: async with self._session() as session:
await session.execute( await session.execute(
update(User).where(User.uuid == uuid).values(role=role) update(User).where(User.uuid == uuid).values(role=role)
) )
await session.commit() await session.commit()
async def purge_logs_and_bounties(self) -> dict[str, int]: async def purge_logs_and_bounties(self) -> dict[str, int]:
async with self.session_factory() as session: async with self._session() as session:
logs_deleted = (await session.execute(text("DELETE FROM logs"))).rowcount logs_deleted = (await session.execute(text("DELETE FROM logs"))).rowcount
bounties_deleted = (await session.execute(text("DELETE FROM bounty"))).rowcount bounties_deleted = (await session.execute(text("DELETE FROM bounty"))).rowcount
# attacker_behavior has FK → attackers.uuid; delete children first. # attacker_behavior has FK → attackers.uuid; delete children first.
@@ -326,7 +393,7 @@ class SQLModelRepository(BaseRepository):
if "payload" in data and isinstance(data["payload"], dict): if "payload" in data and isinstance(data["payload"], dict):
data["payload"] = json.dumps(data["payload"]) data["payload"] = json.dumps(data["payload"])
async with self.session_factory() as session: async with self._session() as session:
dup = await session.execute( dup = await session.execute(
select(Bounty.id).where( select(Bounty.id).where(
Bounty.bounty_type == data.get("bounty_type"), Bounty.bounty_type == data.get("bounty_type"),
@@ -374,7 +441,7 @@ class SQLModelRepository(BaseRepository):
) )
statement = self._apply_bounty_filters(statement, bounty_type, search) statement = self._apply_bounty_filters(statement, bounty_type, search)
async with self.session_factory() as session: async with self._session() as session:
results = await session.execute(statement) results = await session.execute(statement)
final = [] final = []
for item in results.scalars().all(): for item in results.scalars().all():
@@ -392,12 +459,12 @@ class SQLModelRepository(BaseRepository):
statement = select(func.count()).select_from(Bounty) statement = select(func.count()).select_from(Bounty)
statement = self._apply_bounty_filters(statement, bounty_type, search) statement = self._apply_bounty_filters(statement, bounty_type, search)
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute(statement) result = await session.execute(statement)
return result.scalar() or 0 return result.scalar() or 0
async def get_state(self, key: str) -> Optional[dict[str, Any]]: async def get_state(self, key: str) -> Optional[dict[str, Any]]:
async with self.session_factory() as session: async with self._session() as session:
statement = select(State).where(State.key == key) statement = select(State).where(State.key == key)
result = await session.execute(statement) result = await session.execute(statement)
state = result.scalar_one_or_none() state = result.scalar_one_or_none()
@@ -406,7 +473,7 @@ class SQLModelRepository(BaseRepository):
return None return None
async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401 async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401
async with self.session_factory() as session: async with self._session() as session:
statement = select(State).where(State.key == key) statement = select(State).where(State.key == key)
result = await session.execute(statement) result = await session.execute(statement)
state = result.scalar_one_or_none() state = result.scalar_one_or_none()
@@ -424,7 +491,7 @@ class SQLModelRepository(BaseRepository):
async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]: async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]:
from collections import defaultdict from collections import defaultdict
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(Bounty).order_by(asc(Bounty.timestamp)) select(Bounty).order_by(asc(Bounty.timestamp))
) )
@@ -440,7 +507,7 @@ class SQLModelRepository(BaseRepository):
async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]: async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]:
from collections import defaultdict from collections import defaultdict
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp)) select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp))
) )
@@ -455,7 +522,7 @@ class SQLModelRepository(BaseRepository):
return dict(grouped) return dict(grouped)
async def upsert_attacker(self, data: dict[str, Any]) -> str: async def upsert_attacker(self, data: dict[str, Any]) -> str:
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(Attacker).where(Attacker.ip == data["ip"]) select(Attacker).where(Attacker.ip == data["ip"])
) )
@@ -477,7 +544,7 @@ class SQLModelRepository(BaseRepository):
attacker_uuid: str, attacker_uuid: str,
data: dict[str, Any], data: dict[str, Any],
) -> None: ) -> None:
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(AttackerBehavior).where( select(AttackerBehavior).where(
AttackerBehavior.attacker_uuid == attacker_uuid AttackerBehavior.attacker_uuid == attacker_uuid
@@ -497,7 +564,7 @@ class SQLModelRepository(BaseRepository):
self, self,
attacker_uuid: str, attacker_uuid: str,
) -> Optional[dict[str, Any]]: ) -> Optional[dict[str, Any]]:
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(AttackerBehavior).where( select(AttackerBehavior).where(
AttackerBehavior.attacker_uuid == attacker_uuid AttackerBehavior.attacker_uuid == attacker_uuid
@@ -514,7 +581,7 @@ class SQLModelRepository(BaseRepository):
) -> dict[str, dict[str, Any]]: ) -> dict[str, dict[str, Any]]:
if not ips: if not ips:
return {} return {}
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(Attacker.ip, AttackerBehavior) select(Attacker.ip, AttackerBehavior)
.join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid) .join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid)
@@ -556,7 +623,7 @@ class SQLModelRepository(BaseRepository):
return d return d
async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]: async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(Attacker).where(Attacker.uuid == uuid) select(Attacker).where(Attacker.uuid == uuid)
) )
@@ -584,7 +651,7 @@ class SQLModelRepository(BaseRepository):
if service: if service:
statement = statement.where(Attacker.services.like(f'%"{service}"%')) statement = statement.where(Attacker.services.like(f'%"{service}"%'))
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute(statement) result = await session.execute(statement)
return [ return [
self._deserialize_attacker(a.model_dump(mode="json")) self._deserialize_attacker(a.model_dump(mode="json"))
@@ -600,7 +667,7 @@ class SQLModelRepository(BaseRepository):
if service: if service:
statement = statement.where(Attacker.services.like(f'%"{service}"%')) statement = statement.where(Attacker.services.like(f'%"{service}"%'))
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute(statement) result = await session.execute(statement)
return result.scalar() or 0 return result.scalar() or 0
@@ -611,7 +678,7 @@ class SQLModelRepository(BaseRepository):
offset: int = 0, offset: int = 0,
service: Optional[str] = None, service: Optional[str] = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
async with self.session_factory() as session: async with self._session() as session:
result = await session.execute( result = await session.execute(
select(Attacker.commands).where(Attacker.uuid == uuid) select(Attacker.commands).where(Attacker.uuid == uuid)
) )

66
tests/test_admin_seed.py Normal file
View File

@@ -0,0 +1,66 @@
"""
Tests for _ensure_admin_user env-drift self-healing.
Scenario: DECNET_ADMIN_PASSWORD changes between runs while the SQLite DB
persists on disk. Previously _ensure_admin_user was strictly insert-if-missing,
so the stale hash from the first seed locked out every subsequent login.
Contract: if the admin still has must_change_password=True (they never
finalized their own password), the stored hash re-syncs from the env.
Once the admin picks a real password, we never touch it.
"""
import pytest
from decnet.web.auth import verify_password
from decnet.web.db.sqlite.repository import SQLiteRepository
@pytest.mark.asyncio
async def test_admin_seeded_on_empty_db(tmp_path, monkeypatch):
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "first")
repo = SQLiteRepository(db_path=str(tmp_path / "t.db"))
await repo.initialize()
user = await repo.get_user_by_username("admin")
assert user is not None
assert verify_password("first", user["password_hash"])
assert user["must_change_password"] is True or user["must_change_password"] == 1
@pytest.mark.asyncio
async def test_admin_password_resyncs_when_not_finalized(tmp_path, monkeypatch):
db = str(tmp_path / "t.db")
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "first")
r1 = SQLiteRepository(db_path=db)
await r1.initialize()
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "second")
r2 = SQLiteRepository(db_path=db)
await r2.initialize()
user = await r2.get_user_by_username("admin")
assert verify_password("second", user["password_hash"])
assert not verify_password("first", user["password_hash"])
@pytest.mark.asyncio
async def test_finalized_admin_password_is_preserved(tmp_path, monkeypatch):
db = str(tmp_path / "t.db")
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "seed")
r1 = SQLiteRepository(db_path=db)
await r1.initialize()
admin = await r1.get_user_by_username("admin")
# Simulate the admin finalising their password via the change-password flow.
from decnet.web.auth import get_password_hash
await r1.update_user_password(
admin["uuid"], get_password_hash("chosen"), must_change_password=False
)
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "different")
r2 = SQLiteRepository(db_path=db)
await r2.initialize()
user = await r2.get_user_by_username("admin")
assert verify_password("chosen", user["password_hash"])
assert not verify_password("different", user["password_hash"])