fix: re-seed admin password when still unfinalized (must_change_password=True)
_ensure_admin_user was strict insert-if-missing: once a stale hash landed in decnet.db (e.g. from a deploy that used a different DECNET_ADMIN_PASSWORD), login silently 401'd because changing the env var later had no effect. Now on startup: if the admin still has must_change_password=True (they never finalized their own password), re-sync the hash from the current env var. Once the admin sets a real password, we leave it alone. Found via locustfile.py login storm — see tests/test_admin_seed.py. Note: this commit also bundles uncommitted pool-management work already present in sqlmodel_repo.py from prior sessions.
This commit is contained in:
@@ -28,6 +28,60 @@ from decnet.web.db.repository import BaseRepository
|
|||||||
from decnet.web.db.models import User, Log, Bounty, State, Attacker, AttackerBehavior
|
from decnet.web.db.models import User, Log, Bounty, State, Attacker, AttackerBehavior
|
||||||
|
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from decnet.logging import get_logger
|
||||||
|
|
||||||
|
_log = get_logger("db.pool")
|
||||||
|
|
||||||
|
|
||||||
|
async def _force_close(session: AsyncSession) -> None:
|
||||||
|
"""Close a session, forcing connection invalidation if clean close fails.
|
||||||
|
|
||||||
|
Shielded from cancellation and catches every exception class including
|
||||||
|
CancelledError. If session.close() fails (corrupted connection), we
|
||||||
|
invalidate the underlying connection so the pool discards it entirely
|
||||||
|
rather than leaving it checked-out forever.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await asyncio.shield(session.close())
|
||||||
|
except BaseException:
|
||||||
|
# close() failed — connection is likely corrupted.
|
||||||
|
# Try to invalidate the raw connection so the pool drops it.
|
||||||
|
try:
|
||||||
|
bind = session.get_bind()
|
||||||
|
if hasattr(bind, "dispose"):
|
||||||
|
pass # don't dispose the whole engine
|
||||||
|
# The sync_session holds the connection record; invalidating
|
||||||
|
# it tells the pool to discard rather than reuse.
|
||||||
|
sync = session.sync_session
|
||||||
|
if sync.is_active:
|
||||||
|
sync.rollback()
|
||||||
|
sync.close()
|
||||||
|
except BaseException:
|
||||||
|
_log.debug("force-close: fallback cleanup failed", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _safe_session(factory: async_sessionmaker[AsyncSession]):
|
||||||
|
"""Session context manager that shields cleanup from cancellation.
|
||||||
|
|
||||||
|
Under high concurrency, uvicorn cancels request tasks when clients
|
||||||
|
disconnect. If a CancelledError hits during session.__aexit__,
|
||||||
|
the underlying DB connection is orphaned — never returned to the
|
||||||
|
pool. This wrapper ensures close() always completes, preventing
|
||||||
|
the pool-drain death spiral.
|
||||||
|
"""
|
||||||
|
session = factory()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
except BaseException:
|
||||||
|
await _force_close(session)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
await _force_close(session)
|
||||||
|
|
||||||
|
|
||||||
class SQLModelRepository(BaseRepository):
|
class SQLModelRepository(BaseRepository):
|
||||||
"""Concrete SQLModel/SQLAlchemy-async repository.
|
"""Concrete SQLModel/SQLAlchemy-async repository.
|
||||||
|
|
||||||
@@ -38,6 +92,10 @@ class SQLModelRepository(BaseRepository):
|
|||||||
engine: AsyncEngine
|
engine: AsyncEngine
|
||||||
session_factory: async_sessionmaker[AsyncSession]
|
session_factory: async_sessionmaker[AsyncSession]
|
||||||
|
|
||||||
|
def _session(self):
|
||||||
|
"""Return a cancellation-safe session context manager."""
|
||||||
|
return _safe_session(self.session_factory)
|
||||||
|
|
||||||
# ------------------------------------------------------------ lifecycle
|
# ------------------------------------------------------------ lifecycle
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
@@ -56,11 +114,12 @@ class SQLModelRepository(BaseRepository):
|
|||||||
await self._ensure_admin_user()
|
await self._ensure_admin_user()
|
||||||
|
|
||||||
async def _ensure_admin_user(self) -> None:
|
async def _ensure_admin_user(self) -> None:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(User).where(User.username == DECNET_ADMIN_USER)
|
select(User).where(User.username == DECNET_ADMIN_USER)
|
||||||
)
|
)
|
||||||
if not result.scalar_one_or_none():
|
existing = result.scalar_one_or_none()
|
||||||
|
if existing is None:
|
||||||
session.add(User(
|
session.add(User(
|
||||||
uuid=str(uuid.uuid4()),
|
uuid=str(uuid.uuid4()),
|
||||||
username=DECNET_ADMIN_USER,
|
username=DECNET_ADMIN_USER,
|
||||||
@@ -69,6 +128,14 @@ class SQLModelRepository(BaseRepository):
|
|||||||
must_change_password=True,
|
must_change_password=True,
|
||||||
))
|
))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
return
|
||||||
|
# Self-heal env drift: if admin never finalized their password,
|
||||||
|
# re-sync the hash from DECNET_ADMIN_PASSWORD. Otherwise leave
|
||||||
|
# the user's chosen password alone.
|
||||||
|
if existing.must_change_password:
|
||||||
|
existing.password_hash = get_password_hash(DECNET_ADMIN_PASSWORD)
|
||||||
|
session.add(existing)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
async def _migrate_attackers_table(self) -> None:
|
async def _migrate_attackers_table(self) -> None:
|
||||||
"""Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable)."""
|
"""Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable)."""
|
||||||
@@ -88,7 +155,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
session.add(Log(**data))
|
session.add(Log(**data))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
@@ -171,12 +238,12 @@ class SQLModelRepository(BaseRepository):
|
|||||||
)
|
)
|
||||||
statement = self._apply_filters(statement, search, start_time, end_time)
|
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||||
|
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
results = await session.execute(statement)
|
results = await session.execute(statement)
|
||||||
return [log.model_dump(mode="json") for log in results.scalars().all()]
|
return [log.model_dump(mode="json") for log in results.scalars().all()]
|
||||||
|
|
||||||
async def get_max_log_id(self) -> int:
|
async def get_max_log_id(self) -> int:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(select(func.max(Log.id)))
|
result = await session.execute(select(func.max(Log.id)))
|
||||||
val = result.scalar()
|
val = result.scalar()
|
||||||
return val if val is not None else 0
|
return val if val is not None else 0
|
||||||
@@ -194,7 +261,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
)
|
)
|
||||||
statement = self._apply_filters(statement, search, start_time, end_time)
|
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||||
|
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
results = await session.execute(statement)
|
results = await session.execute(statement)
|
||||||
return [log.model_dump(mode="json") for log in results.scalars().all()]
|
return [log.model_dump(mode="json") for log in results.scalars().all()]
|
||||||
|
|
||||||
@@ -207,7 +274,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
statement = select(func.count()).select_from(Log)
|
statement = select(func.count()).select_from(Log)
|
||||||
statement = self._apply_filters(statement, search, start_time, end_time)
|
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||||
|
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(statement)
|
result = await session.execute(statement)
|
||||||
return result.scalar() or 0
|
return result.scalar() or 0
|
||||||
|
|
||||||
@@ -222,7 +289,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_stats_summary(self) -> dict[str, Any]:
|
async def get_stats_summary(self) -> dict[str, Any]:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
total_logs = (
|
total_logs = (
|
||||||
await session.execute(select(func.count()).select_from(Log))
|
await session.execute(select(func.count()).select_from(Log))
|
||||||
).scalar() or 0
|
).scalar() or 0
|
||||||
@@ -249,7 +316,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
# --------------------------------------------------------------- users
|
# --------------------------------------------------------------- users
|
||||||
|
|
||||||
async def get_user_by_username(self, username: str) -> Optional[dict]:
|
async def get_user_by_username(self, username: str) -> Optional[dict]:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(User).where(User.username == username)
|
select(User).where(User.username == username)
|
||||||
)
|
)
|
||||||
@@ -257,7 +324,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
return user.model_dump() if user else None
|
return user.model_dump() if user else None
|
||||||
|
|
||||||
async def get_user_by_uuid(self, uuid: str) -> Optional[dict]:
|
async def get_user_by_uuid(self, uuid: str) -> Optional[dict]:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(User).where(User.uuid == uuid)
|
select(User).where(User.uuid == uuid)
|
||||||
)
|
)
|
||||||
@@ -265,14 +332,14 @@ class SQLModelRepository(BaseRepository):
|
|||||||
return user.model_dump() if user else None
|
return user.model_dump() if user else None
|
||||||
|
|
||||||
async def create_user(self, user_data: dict[str, Any]) -> None:
|
async def create_user(self, user_data: dict[str, Any]) -> None:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
session.add(User(**user_data))
|
session.add(User(**user_data))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def update_user_password(
|
async def update_user_password(
|
||||||
self, uuid: str, password_hash: str, must_change_password: bool = False
|
self, uuid: str, password_hash: str, must_change_password: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
await session.execute(
|
await session.execute(
|
||||||
update(User)
|
update(User)
|
||||||
.where(User.uuid == uuid)
|
.where(User.uuid == uuid)
|
||||||
@@ -284,12 +351,12 @@ class SQLModelRepository(BaseRepository):
|
|||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def list_users(self) -> list[dict]:
|
async def list_users(self) -> list[dict]:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(select(User))
|
result = await session.execute(select(User))
|
||||||
return [u.model_dump() for u in result.scalars().all()]
|
return [u.model_dump() for u in result.scalars().all()]
|
||||||
|
|
||||||
async def delete_user(self, uuid: str) -> bool:
|
async def delete_user(self, uuid: str) -> bool:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(select(User).where(User.uuid == uuid))
|
result = await session.execute(select(User).where(User.uuid == uuid))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
if not user:
|
if not user:
|
||||||
@@ -299,14 +366,14 @@ class SQLModelRepository(BaseRepository):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
async def update_user_role(self, uuid: str, role: str) -> None:
|
async def update_user_role(self, uuid: str, role: str) -> None:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
await session.execute(
|
await session.execute(
|
||||||
update(User).where(User.uuid == uuid).values(role=role)
|
update(User).where(User.uuid == uuid).values(role=role)
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def purge_logs_and_bounties(self) -> dict[str, int]:
|
async def purge_logs_and_bounties(self) -> dict[str, int]:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
logs_deleted = (await session.execute(text("DELETE FROM logs"))).rowcount
|
logs_deleted = (await session.execute(text("DELETE FROM logs"))).rowcount
|
||||||
bounties_deleted = (await session.execute(text("DELETE FROM bounty"))).rowcount
|
bounties_deleted = (await session.execute(text("DELETE FROM bounty"))).rowcount
|
||||||
# attacker_behavior has FK → attackers.uuid; delete children first.
|
# attacker_behavior has FK → attackers.uuid; delete children first.
|
||||||
@@ -326,7 +393,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
if "payload" in data and isinstance(data["payload"], dict):
|
if "payload" in data and isinstance(data["payload"], dict):
|
||||||
data["payload"] = json.dumps(data["payload"])
|
data["payload"] = json.dumps(data["payload"])
|
||||||
|
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
dup = await session.execute(
|
dup = await session.execute(
|
||||||
select(Bounty.id).where(
|
select(Bounty.id).where(
|
||||||
Bounty.bounty_type == data.get("bounty_type"),
|
Bounty.bounty_type == data.get("bounty_type"),
|
||||||
@@ -374,7 +441,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
)
|
)
|
||||||
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
||||||
|
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
results = await session.execute(statement)
|
results = await session.execute(statement)
|
||||||
final = []
|
final = []
|
||||||
for item in results.scalars().all():
|
for item in results.scalars().all():
|
||||||
@@ -392,12 +459,12 @@ class SQLModelRepository(BaseRepository):
|
|||||||
statement = select(func.count()).select_from(Bounty)
|
statement = select(func.count()).select_from(Bounty)
|
||||||
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
||||||
|
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(statement)
|
result = await session.execute(statement)
|
||||||
return result.scalar() or 0
|
return result.scalar() or 0
|
||||||
|
|
||||||
async def get_state(self, key: str) -> Optional[dict[str, Any]]:
|
async def get_state(self, key: str) -> Optional[dict[str, Any]]:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
statement = select(State).where(State.key == key)
|
statement = select(State).where(State.key == key)
|
||||||
result = await session.execute(statement)
|
result = await session.execute(statement)
|
||||||
state = result.scalar_one_or_none()
|
state = result.scalar_one_or_none()
|
||||||
@@ -406,7 +473,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401
|
async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
statement = select(State).where(State.key == key)
|
statement = select(State).where(State.key == key)
|
||||||
result = await session.execute(statement)
|
result = await session.execute(statement)
|
||||||
state = result.scalar_one_or_none()
|
state = result.scalar_one_or_none()
|
||||||
@@ -424,7 +491,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
|
|
||||||
async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]:
|
async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]:
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Bounty).order_by(asc(Bounty.timestamp))
|
select(Bounty).order_by(asc(Bounty.timestamp))
|
||||||
)
|
)
|
||||||
@@ -440,7 +507,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
|
|
||||||
async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]:
|
async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]:
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp))
|
select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp))
|
||||||
)
|
)
|
||||||
@@ -455,7 +522,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
return dict(grouped)
|
return dict(grouped)
|
||||||
|
|
||||||
async def upsert_attacker(self, data: dict[str, Any]) -> str:
|
async def upsert_attacker(self, data: dict[str, Any]) -> str:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Attacker).where(Attacker.ip == data["ip"])
|
select(Attacker).where(Attacker.ip == data["ip"])
|
||||||
)
|
)
|
||||||
@@ -477,7 +544,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
attacker_uuid: str,
|
attacker_uuid: str,
|
||||||
data: dict[str, Any],
|
data: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(AttackerBehavior).where(
|
select(AttackerBehavior).where(
|
||||||
AttackerBehavior.attacker_uuid == attacker_uuid
|
AttackerBehavior.attacker_uuid == attacker_uuid
|
||||||
@@ -497,7 +564,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
self,
|
self,
|
||||||
attacker_uuid: str,
|
attacker_uuid: str,
|
||||||
) -> Optional[dict[str, Any]]:
|
) -> Optional[dict[str, Any]]:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(AttackerBehavior).where(
|
select(AttackerBehavior).where(
|
||||||
AttackerBehavior.attacker_uuid == attacker_uuid
|
AttackerBehavior.attacker_uuid == attacker_uuid
|
||||||
@@ -514,7 +581,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
) -> dict[str, dict[str, Any]]:
|
) -> dict[str, dict[str, Any]]:
|
||||||
if not ips:
|
if not ips:
|
||||||
return {}
|
return {}
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Attacker.ip, AttackerBehavior)
|
select(Attacker.ip, AttackerBehavior)
|
||||||
.join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid)
|
.join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid)
|
||||||
@@ -556,7 +623,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
return d
|
return d
|
||||||
|
|
||||||
async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
|
async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Attacker).where(Attacker.uuid == uuid)
|
select(Attacker).where(Attacker.uuid == uuid)
|
||||||
)
|
)
|
||||||
@@ -584,7 +651,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
if service:
|
if service:
|
||||||
statement = statement.where(Attacker.services.like(f'%"{service}"%'))
|
statement = statement.where(Attacker.services.like(f'%"{service}"%'))
|
||||||
|
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(statement)
|
result = await session.execute(statement)
|
||||||
return [
|
return [
|
||||||
self._deserialize_attacker(a.model_dump(mode="json"))
|
self._deserialize_attacker(a.model_dump(mode="json"))
|
||||||
@@ -600,7 +667,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
if service:
|
if service:
|
||||||
statement = statement.where(Attacker.services.like(f'%"{service}"%'))
|
statement = statement.where(Attacker.services.like(f'%"{service}"%'))
|
||||||
|
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(statement)
|
result = await session.execute(statement)
|
||||||
return result.scalar() or 0
|
return result.scalar() or 0
|
||||||
|
|
||||||
@@ -611,7 +678,7 @@ class SQLModelRepository(BaseRepository):
|
|||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
service: Optional[str] = None,
|
service: Optional[str] = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
async with self.session_factory() as session:
|
async with self._session() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Attacker.commands).where(Attacker.uuid == uuid)
|
select(Attacker.commands).where(Attacker.uuid == uuid)
|
||||||
)
|
)
|
||||||
|
|||||||
66
tests/test_admin_seed.py
Normal file
66
tests/test_admin_seed.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""
|
||||||
|
Tests for _ensure_admin_user env-drift self-healing.
|
||||||
|
|
||||||
|
Scenario: DECNET_ADMIN_PASSWORD changes between runs while the SQLite DB
|
||||||
|
persists on disk. Previously _ensure_admin_user was strictly insert-if-missing,
|
||||||
|
so the stale hash from the first seed locked out every subsequent login.
|
||||||
|
|
||||||
|
Contract: if the admin still has must_change_password=True (they never
|
||||||
|
finalized their own password), the stored hash re-syncs from the env.
|
||||||
|
Once the admin picks a real password, we never touch it.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from decnet.web.auth import verify_password
|
||||||
|
from decnet.web.db.sqlite.repository import SQLiteRepository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_admin_seeded_on_empty_db(tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "first")
|
||||||
|
repo = SQLiteRepository(db_path=str(tmp_path / "t.db"))
|
||||||
|
await repo.initialize()
|
||||||
|
user = await repo.get_user_by_username("admin")
|
||||||
|
assert user is not None
|
||||||
|
assert verify_password("first", user["password_hash"])
|
||||||
|
assert user["must_change_password"] is True or user["must_change_password"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_admin_password_resyncs_when_not_finalized(tmp_path, monkeypatch):
|
||||||
|
db = str(tmp_path / "t.db")
|
||||||
|
|
||||||
|
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "first")
|
||||||
|
r1 = SQLiteRepository(db_path=db)
|
||||||
|
await r1.initialize()
|
||||||
|
|
||||||
|
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "second")
|
||||||
|
r2 = SQLiteRepository(db_path=db)
|
||||||
|
await r2.initialize()
|
||||||
|
|
||||||
|
user = await r2.get_user_by_username("admin")
|
||||||
|
assert verify_password("second", user["password_hash"])
|
||||||
|
assert not verify_password("first", user["password_hash"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_finalized_admin_password_is_preserved(tmp_path, monkeypatch):
|
||||||
|
db = str(tmp_path / "t.db")
|
||||||
|
|
||||||
|
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "seed")
|
||||||
|
r1 = SQLiteRepository(db_path=db)
|
||||||
|
await r1.initialize()
|
||||||
|
admin = await r1.get_user_by_username("admin")
|
||||||
|
# Simulate the admin finalising their password via the change-password flow.
|
||||||
|
from decnet.web.auth import get_password_hash
|
||||||
|
await r1.update_user_password(
|
||||||
|
admin["uuid"], get_password_hash("chosen"), must_change_password=False
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "different")
|
||||||
|
r2 = SQLiteRepository(db_path=db)
|
||||||
|
await r2.initialize()
|
||||||
|
|
||||||
|
user = await r2.get_user_by_username("admin")
|
||||||
|
assert verify_password("chosen", user["password_hash"])
|
||||||
|
assert not verify_password("different", user["password_hash"])
|
||||||
Reference in New Issue
Block a user