merge testing->tomerge/main #7
@@ -28,6 +28,60 @@ from decnet.web.db.repository import BaseRepository
|
||||
from decnet.web.db.models import User, Log, Bounty, State, Attacker, AttackerBehavior
|
||||
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from decnet.logging import get_logger
|
||||
|
||||
_log = get_logger("db.pool")
|
||||
|
||||
|
||||
async def _force_close(session: AsyncSession) -> None:
|
||||
"""Close a session, forcing connection invalidation if clean close fails.
|
||||
|
||||
Shielded from cancellation and catches every exception class including
|
||||
CancelledError. If session.close() fails (corrupted connection), we
|
||||
invalidate the underlying connection so the pool discards it entirely
|
||||
rather than leaving it checked-out forever.
|
||||
"""
|
||||
try:
|
||||
await asyncio.shield(session.close())
|
||||
except BaseException:
|
||||
# close() failed — connection is likely corrupted.
|
||||
# Try to invalidate the raw connection so the pool drops it.
|
||||
try:
|
||||
bind = session.get_bind()
|
||||
if hasattr(bind, "dispose"):
|
||||
pass # don't dispose the whole engine
|
||||
# The sync_session holds the connection record; invalidating
|
||||
# it tells the pool to discard rather than reuse.
|
||||
sync = session.sync_session
|
||||
if sync.is_active:
|
||||
sync.rollback()
|
||||
sync.close()
|
||||
except BaseException:
|
||||
_log.debug("force-close: fallback cleanup failed", exc_info=True)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _safe_session(factory: async_sessionmaker[AsyncSession]):
|
||||
"""Session context manager that shields cleanup from cancellation.
|
||||
|
||||
Under high concurrency, uvicorn cancels request tasks when clients
|
||||
disconnect. If a CancelledError hits during session.__aexit__,
|
||||
the underlying DB connection is orphaned — never returned to the
|
||||
pool. This wrapper ensures close() always completes, preventing
|
||||
the pool-drain death spiral.
|
||||
"""
|
||||
session = factory()
|
||||
try:
|
||||
yield session
|
||||
except BaseException:
|
||||
await _force_close(session)
|
||||
raise
|
||||
else:
|
||||
await _force_close(session)
|
||||
|
||||
|
||||
class SQLModelRepository(BaseRepository):
|
||||
"""Concrete SQLModel/SQLAlchemy-async repository.
|
||||
|
||||
@@ -38,6 +92,10 @@ class SQLModelRepository(BaseRepository):
|
||||
engine: AsyncEngine
|
||||
session_factory: async_sessionmaker[AsyncSession]
|
||||
|
||||
def _session(self):
|
||||
"""Return a cancellation-safe session context manager."""
|
||||
return _safe_session(self.session_factory)
|
||||
|
||||
# ------------------------------------------------------------ lifecycle
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@@ -56,11 +114,12 @@ class SQLModelRepository(BaseRepository):
|
||||
await self._ensure_admin_user()
|
||||
|
||||
async def _ensure_admin_user(self) -> None:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.username == DECNET_ADMIN_USER)
|
||||
)
|
||||
if not result.scalar_one_or_none():
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing is None:
|
||||
session.add(User(
|
||||
uuid=str(uuid.uuid4()),
|
||||
username=DECNET_ADMIN_USER,
|
||||
@@ -69,6 +128,14 @@ class SQLModelRepository(BaseRepository):
|
||||
must_change_password=True,
|
||||
))
|
||||
await session.commit()
|
||||
return
|
||||
# Self-heal env drift: if admin never finalized their password,
|
||||
# re-sync the hash from DECNET_ADMIN_PASSWORD. Otherwise leave
|
||||
# the user's chosen password alone.
|
||||
if existing.must_change_password:
|
||||
existing.password_hash = get_password_hash(DECNET_ADMIN_PASSWORD)
|
||||
session.add(existing)
|
||||
await session.commit()
|
||||
|
||||
async def _migrate_attackers_table(self) -> None:
|
||||
"""Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable)."""
|
||||
@@ -88,7 +155,7 @@ class SQLModelRepository(BaseRepository):
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
session.add(Log(**data))
|
||||
await session.commit()
|
||||
|
||||
@@ -171,12 +238,12 @@ class SQLModelRepository(BaseRepository):
|
||||
)
|
||||
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
results = await session.execute(statement)
|
||||
return [log.model_dump(mode="json") for log in results.scalars().all()]
|
||||
|
||||
async def get_max_log_id(self) -> int:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(select(func.max(Log.id)))
|
||||
val = result.scalar()
|
||||
return val if val is not None else 0
|
||||
@@ -194,7 +261,7 @@ class SQLModelRepository(BaseRepository):
|
||||
)
|
||||
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
results = await session.execute(statement)
|
||||
return [log.model_dump(mode="json") for log in results.scalars().all()]
|
||||
|
||||
@@ -207,7 +274,7 @@ class SQLModelRepository(BaseRepository):
|
||||
statement = select(func.count()).select_from(Log)
|
||||
statement = self._apply_filters(statement, search, start_time, end_time)
|
||||
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(statement)
|
||||
return result.scalar() or 0
|
||||
|
||||
@@ -222,7 +289,7 @@ class SQLModelRepository(BaseRepository):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_stats_summary(self) -> dict[str, Any]:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
total_logs = (
|
||||
await session.execute(select(func.count()).select_from(Log))
|
||||
).scalar() or 0
|
||||
@@ -249,7 +316,7 @@ class SQLModelRepository(BaseRepository):
|
||||
# --------------------------------------------------------------- users
|
||||
|
||||
async def get_user_by_username(self, username: str) -> Optional[dict]:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.username == username)
|
||||
)
|
||||
@@ -257,7 +324,7 @@ class SQLModelRepository(BaseRepository):
|
||||
return user.model_dump() if user else None
|
||||
|
||||
async def get_user_by_uuid(self, uuid: str) -> Optional[dict]:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.uuid == uuid)
|
||||
)
|
||||
@@ -265,14 +332,14 @@ class SQLModelRepository(BaseRepository):
|
||||
return user.model_dump() if user else None
|
||||
|
||||
async def create_user(self, user_data: dict[str, Any]) -> None:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
session.add(User(**user_data))
|
||||
await session.commit()
|
||||
|
||||
async def update_user_password(
|
||||
self, uuid: str, password_hash: str, must_change_password: bool = False
|
||||
) -> None:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
await session.execute(
|
||||
update(User)
|
||||
.where(User.uuid == uuid)
|
||||
@@ -284,12 +351,12 @@ class SQLModelRepository(BaseRepository):
|
||||
await session.commit()
|
||||
|
||||
async def list_users(self) -> list[dict]:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(select(User))
|
||||
return [u.model_dump() for u in result.scalars().all()]
|
||||
|
||||
async def delete_user(self, uuid: str) -> bool:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(select(User).where(User.uuid == uuid))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
@@ -299,14 +366,14 @@ class SQLModelRepository(BaseRepository):
|
||||
return True
|
||||
|
||||
async def update_user_role(self, uuid: str, role: str) -> None:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
await session.execute(
|
||||
update(User).where(User.uuid == uuid).values(role=role)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def purge_logs_and_bounties(self) -> dict[str, int]:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
logs_deleted = (await session.execute(text("DELETE FROM logs"))).rowcount
|
||||
bounties_deleted = (await session.execute(text("DELETE FROM bounty"))).rowcount
|
||||
# attacker_behavior has FK → attackers.uuid; delete children first.
|
||||
@@ -326,7 +393,7 @@ class SQLModelRepository(BaseRepository):
|
||||
if "payload" in data and isinstance(data["payload"], dict):
|
||||
data["payload"] = json.dumps(data["payload"])
|
||||
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
dup = await session.execute(
|
||||
select(Bounty.id).where(
|
||||
Bounty.bounty_type == data.get("bounty_type"),
|
||||
@@ -374,7 +441,7 @@ class SQLModelRepository(BaseRepository):
|
||||
)
|
||||
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
||||
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
results = await session.execute(statement)
|
||||
final = []
|
||||
for item in results.scalars().all():
|
||||
@@ -392,12 +459,12 @@ class SQLModelRepository(BaseRepository):
|
||||
statement = select(func.count()).select_from(Bounty)
|
||||
statement = self._apply_bounty_filters(statement, bounty_type, search)
|
||||
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(statement)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_state(self, key: str) -> Optional[dict[str, Any]]:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
statement = select(State).where(State.key == key)
|
||||
result = await session.execute(statement)
|
||||
state = result.scalar_one_or_none()
|
||||
@@ -406,7 +473,7 @@ class SQLModelRepository(BaseRepository):
|
||||
return None
|
||||
|
||||
async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
statement = select(State).where(State.key == key)
|
||||
result = await session.execute(statement)
|
||||
state = result.scalar_one_or_none()
|
||||
@@ -424,7 +491,7 @@ class SQLModelRepository(BaseRepository):
|
||||
|
||||
async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]:
|
||||
from collections import defaultdict
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Bounty).order_by(asc(Bounty.timestamp))
|
||||
)
|
||||
@@ -440,7 +507,7 @@ class SQLModelRepository(BaseRepository):
|
||||
|
||||
async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]:
|
||||
from collections import defaultdict
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp))
|
||||
)
|
||||
@@ -455,7 +522,7 @@ class SQLModelRepository(BaseRepository):
|
||||
return dict(grouped)
|
||||
|
||||
async def upsert_attacker(self, data: dict[str, Any]) -> str:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Attacker).where(Attacker.ip == data["ip"])
|
||||
)
|
||||
@@ -477,7 +544,7 @@ class SQLModelRepository(BaseRepository):
|
||||
attacker_uuid: str,
|
||||
data: dict[str, Any],
|
||||
) -> None:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(AttackerBehavior).where(
|
||||
AttackerBehavior.attacker_uuid == attacker_uuid
|
||||
@@ -497,7 +564,7 @@ class SQLModelRepository(BaseRepository):
|
||||
self,
|
||||
attacker_uuid: str,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(AttackerBehavior).where(
|
||||
AttackerBehavior.attacker_uuid == attacker_uuid
|
||||
@@ -514,7 +581,7 @@ class SQLModelRepository(BaseRepository):
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
if not ips:
|
||||
return {}
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Attacker.ip, AttackerBehavior)
|
||||
.join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid)
|
||||
@@ -556,7 +623,7 @@ class SQLModelRepository(BaseRepository):
|
||||
return d
|
||||
|
||||
async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Attacker).where(Attacker.uuid == uuid)
|
||||
)
|
||||
@@ -584,7 +651,7 @@ class SQLModelRepository(BaseRepository):
|
||||
if service:
|
||||
statement = statement.where(Attacker.services.like(f'%"{service}"%'))
|
||||
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(statement)
|
||||
return [
|
||||
self._deserialize_attacker(a.model_dump(mode="json"))
|
||||
@@ -600,7 +667,7 @@ class SQLModelRepository(BaseRepository):
|
||||
if service:
|
||||
statement = statement.where(Attacker.services.like(f'%"{service}"%'))
|
||||
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(statement)
|
||||
return result.scalar() or 0
|
||||
|
||||
@@ -611,7 +678,7 @@ class SQLModelRepository(BaseRepository):
|
||||
offset: int = 0,
|
||||
service: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
async with self.session_factory() as session:
|
||||
async with self._session() as session:
|
||||
result = await session.execute(
|
||||
select(Attacker.commands).where(Attacker.uuid == uuid)
|
||||
)
|
||||
|
||||
66
tests/test_admin_seed.py
Normal file
66
tests/test_admin_seed.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Tests for _ensure_admin_user env-drift self-healing.
|
||||
|
||||
Scenario: DECNET_ADMIN_PASSWORD changes between runs while the SQLite DB
|
||||
persists on disk. Previously _ensure_admin_user was strictly insert-if-missing,
|
||||
so the stale hash from the first seed locked out every subsequent login.
|
||||
|
||||
Contract: if the admin still has must_change_password=True (they never
|
||||
finalized their own password), the stored hash re-syncs from the env.
|
||||
Once the admin picks a real password, we never touch it.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from decnet.web.auth import verify_password
|
||||
from decnet.web.db.sqlite.repository import SQLiteRepository
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_seeded_on_empty_db(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "first")
|
||||
repo = SQLiteRepository(db_path=str(tmp_path / "t.db"))
|
||||
await repo.initialize()
|
||||
user = await repo.get_user_by_username("admin")
|
||||
assert user is not None
|
||||
assert verify_password("first", user["password_hash"])
|
||||
assert user["must_change_password"] is True or user["must_change_password"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_password_resyncs_when_not_finalized(tmp_path, monkeypatch):
|
||||
db = str(tmp_path / "t.db")
|
||||
|
||||
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "first")
|
||||
r1 = SQLiteRepository(db_path=db)
|
||||
await r1.initialize()
|
||||
|
||||
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "second")
|
||||
r2 = SQLiteRepository(db_path=db)
|
||||
await r2.initialize()
|
||||
|
||||
user = await r2.get_user_by_username("admin")
|
||||
assert verify_password("second", user["password_hash"])
|
||||
assert not verify_password("first", user["password_hash"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finalized_admin_password_is_preserved(tmp_path, monkeypatch):
|
||||
db = str(tmp_path / "t.db")
|
||||
|
||||
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "seed")
|
||||
r1 = SQLiteRepository(db_path=db)
|
||||
await r1.initialize()
|
||||
admin = await r1.get_user_by_username("admin")
|
||||
# Simulate the admin finalising their password via the change-password flow.
|
||||
from decnet.web.auth import get_password_hash
|
||||
await r1.update_user_password(
|
||||
admin["uuid"], get_password_hash("chosen"), must_change_password=False
|
||||
)
|
||||
|
||||
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "different")
|
||||
r2 = SQLiteRepository(db_path=db)
|
||||
await r2.initialize()
|
||||
|
||||
user = await r2.get_user_by_username("admin")
|
||||
assert verify_password("chosen", user["password_hash"])
|
||||
assert not verify_password("different", user["password_hash"])
|
||||
Reference in New Issue
Block a user