merge testing->tomerge/main #7

Open
anti wants to merge 242 commits from testing into tomerge/main
2 changed files with 164 additions and 31 deletions
Showing only changes of commit bd406090a7 - Show all commits

View File

@@ -28,6 +28,60 @@ from decnet.web.db.repository import BaseRepository
from decnet.web.db.models import User, Log, Bounty, State, Attacker, AttackerBehavior
from contextlib import asynccontextmanager
from decnet.logging import get_logger
_log = get_logger("db.pool")
async def _force_close(session: AsyncSession) -> None:
"""Close a session, forcing connection invalidation if clean close fails.
Shielded from cancellation and catches every exception class including
CancelledError. If session.close() fails (corrupted connection), we
invalidate the underlying connection so the pool discards it entirely
rather than leaving it checked-out forever.
"""
try:
await asyncio.shield(session.close())
except BaseException:
# close() failed — connection is likely corrupted.
# Try to invalidate the raw connection so the pool drops it.
try:
bind = session.get_bind()
if hasattr(bind, "dispose"):
pass # don't dispose the whole engine
# The sync_session holds the connection record; invalidating
# it tells the pool to discard rather than reuse.
sync = session.sync_session
if sync.is_active:
sync.rollback()
sync.close()
except BaseException:
_log.debug("force-close: fallback cleanup failed", exc_info=True)
@asynccontextmanager
async def _safe_session(factory: async_sessionmaker[AsyncSession]):
"""Session context manager that shields cleanup from cancellation.
Under high concurrency, uvicorn cancels request tasks when clients
disconnect. If a CancelledError hits during session.__aexit__,
the underlying DB connection is orphaned — never returned to the
pool. This wrapper ensures close() always completes, preventing
the pool-drain death spiral.
"""
session = factory()
try:
yield session
except BaseException:
await _force_close(session)
raise
else:
await _force_close(session)
class SQLModelRepository(BaseRepository):
"""Concrete SQLModel/SQLAlchemy-async repository.
@@ -38,6 +92,10 @@ class SQLModelRepository(BaseRepository):
engine: AsyncEngine
session_factory: async_sessionmaker[AsyncSession]
def _session(self):
"""Return a cancellation-safe session context manager."""
return _safe_session(self.session_factory)
# ------------------------------------------------------------ lifecycle
async def initialize(self) -> None:
@@ -56,11 +114,12 @@ class SQLModelRepository(BaseRepository):
await self._ensure_admin_user()
async def _ensure_admin_user(self) -> None:
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(
select(User).where(User.username == DECNET_ADMIN_USER)
)
if not result.scalar_one_or_none():
existing = result.scalar_one_or_none()
if existing is None:
session.add(User(
uuid=str(uuid.uuid4()),
username=DECNET_ADMIN_USER,
@@ -69,6 +128,14 @@ class SQLModelRepository(BaseRepository):
must_change_password=True,
))
await session.commit()
return
# Self-heal env drift: if admin never finalized their password,
# re-sync the hash from DECNET_ADMIN_PASSWORD. Otherwise leave
# the user's chosen password alone.
if existing.must_change_password:
existing.password_hash = get_password_hash(DECNET_ADMIN_PASSWORD)
session.add(existing)
await session.commit()
async def _migrate_attackers_table(self) -> None:
"""Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable)."""
@@ -88,7 +155,7 @@ class SQLModelRepository(BaseRepository):
except ValueError:
pass
async with self.session_factory() as session:
async with self._session() as session:
session.add(Log(**data))
await session.commit()
@@ -171,12 +238,12 @@ class SQLModelRepository(BaseRepository):
)
statement = self._apply_filters(statement, search, start_time, end_time)
async with self.session_factory() as session:
async with self._session() as session:
results = await session.execute(statement)
return [log.model_dump(mode="json") for log in results.scalars().all()]
async def get_max_log_id(self) -> int:
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(select(func.max(Log.id)))
val = result.scalar()
return val if val is not None else 0
@@ -194,7 +261,7 @@ class SQLModelRepository(BaseRepository):
)
statement = self._apply_filters(statement, search, start_time, end_time)
async with self.session_factory() as session:
async with self._session() as session:
results = await session.execute(statement)
return [log.model_dump(mode="json") for log in results.scalars().all()]
@@ -207,7 +274,7 @@ class SQLModelRepository(BaseRepository):
statement = select(func.count()).select_from(Log)
statement = self._apply_filters(statement, search, start_time, end_time)
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(statement)
return result.scalar() or 0
@@ -222,7 +289,7 @@ class SQLModelRepository(BaseRepository):
raise NotImplementedError
async def get_stats_summary(self) -> dict[str, Any]:
async with self.session_factory() as session:
async with self._session() as session:
total_logs = (
await session.execute(select(func.count()).select_from(Log))
).scalar() or 0
@@ -249,7 +316,7 @@ class SQLModelRepository(BaseRepository):
# --------------------------------------------------------------- users
async def get_user_by_username(self, username: str) -> Optional[dict]:
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(
select(User).where(User.username == username)
)
@@ -257,7 +324,7 @@ class SQLModelRepository(BaseRepository):
return user.model_dump() if user else None
async def get_user_by_uuid(self, uuid: str) -> Optional[dict]:
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(
select(User).where(User.uuid == uuid)
)
@@ -265,14 +332,14 @@ class SQLModelRepository(BaseRepository):
return user.model_dump() if user else None
async def create_user(self, user_data: dict[str, Any]) -> None:
async with self.session_factory() as session:
async with self._session() as session:
session.add(User(**user_data))
await session.commit()
async def update_user_password(
self, uuid: str, password_hash: str, must_change_password: bool = False
) -> None:
async with self.session_factory() as session:
async with self._session() as session:
await session.execute(
update(User)
.where(User.uuid == uuid)
@@ -284,12 +351,12 @@ class SQLModelRepository(BaseRepository):
await session.commit()
async def list_users(self) -> list[dict]:
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(select(User))
return [u.model_dump() for u in result.scalars().all()]
async def delete_user(self, uuid: str) -> bool:
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(select(User).where(User.uuid == uuid))
user = result.scalar_one_or_none()
if not user:
@@ -299,14 +366,14 @@ class SQLModelRepository(BaseRepository):
return True
async def update_user_role(self, uuid: str, role: str) -> None:
async with self.session_factory() as session:
async with self._session() as session:
await session.execute(
update(User).where(User.uuid == uuid).values(role=role)
)
await session.commit()
async def purge_logs_and_bounties(self) -> dict[str, int]:
async with self.session_factory() as session:
async with self._session() as session:
logs_deleted = (await session.execute(text("DELETE FROM logs"))).rowcount
bounties_deleted = (await session.execute(text("DELETE FROM bounty"))).rowcount
# attacker_behavior has FK → attackers.uuid; delete children first.
@@ -326,7 +393,7 @@ class SQLModelRepository(BaseRepository):
if "payload" in data and isinstance(data["payload"], dict):
data["payload"] = json.dumps(data["payload"])
async with self.session_factory() as session:
async with self._session() as session:
dup = await session.execute(
select(Bounty.id).where(
Bounty.bounty_type == data.get("bounty_type"),
@@ -374,7 +441,7 @@ class SQLModelRepository(BaseRepository):
)
statement = self._apply_bounty_filters(statement, bounty_type, search)
async with self.session_factory() as session:
async with self._session() as session:
results = await session.execute(statement)
final = []
for item in results.scalars().all():
@@ -392,12 +459,12 @@ class SQLModelRepository(BaseRepository):
statement = select(func.count()).select_from(Bounty)
statement = self._apply_bounty_filters(statement, bounty_type, search)
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(statement)
return result.scalar() or 0
async def get_state(self, key: str) -> Optional[dict[str, Any]]:
async with self.session_factory() as session:
async with self._session() as session:
statement = select(State).where(State.key == key)
result = await session.execute(statement)
state = result.scalar_one_or_none()
@@ -406,7 +473,7 @@ class SQLModelRepository(BaseRepository):
return None
async def set_state(self, key: str, value: Any) -> None: # noqa: ANN401
async with self.session_factory() as session:
async with self._session() as session:
statement = select(State).where(State.key == key)
result = await session.execute(statement)
state = result.scalar_one_or_none()
@@ -424,7 +491,7 @@ class SQLModelRepository(BaseRepository):
async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]:
from collections import defaultdict
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(
select(Bounty).order_by(asc(Bounty.timestamp))
)
@@ -440,7 +507,7 @@ class SQLModelRepository(BaseRepository):
async def get_bounties_for_ips(self, ips: set[str]) -> dict[str, List[dict[str, Any]]]:
from collections import defaultdict
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(
select(Bounty).where(Bounty.attacker_ip.in_(ips)).order_by(asc(Bounty.timestamp))
)
@@ -455,7 +522,7 @@ class SQLModelRepository(BaseRepository):
return dict(grouped)
async def upsert_attacker(self, data: dict[str, Any]) -> str:
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(
select(Attacker).where(Attacker.ip == data["ip"])
)
@@ -477,7 +544,7 @@ class SQLModelRepository(BaseRepository):
attacker_uuid: str,
data: dict[str, Any],
) -> None:
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(
select(AttackerBehavior).where(
AttackerBehavior.attacker_uuid == attacker_uuid
@@ -497,7 +564,7 @@ class SQLModelRepository(BaseRepository):
self,
attacker_uuid: str,
) -> Optional[dict[str, Any]]:
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(
select(AttackerBehavior).where(
AttackerBehavior.attacker_uuid == attacker_uuid
@@ -514,7 +581,7 @@ class SQLModelRepository(BaseRepository):
) -> dict[str, dict[str, Any]]:
if not ips:
return {}
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(
select(Attacker.ip, AttackerBehavior)
.join(AttackerBehavior, Attacker.uuid == AttackerBehavior.attacker_uuid)
@@ -556,7 +623,7 @@ class SQLModelRepository(BaseRepository):
return d
async def get_attacker_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(
select(Attacker).where(Attacker.uuid == uuid)
)
@@ -584,7 +651,7 @@ class SQLModelRepository(BaseRepository):
if service:
statement = statement.where(Attacker.services.like(f'%"{service}"%'))
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(statement)
return [
self._deserialize_attacker(a.model_dump(mode="json"))
@@ -600,7 +667,7 @@ class SQLModelRepository(BaseRepository):
if service:
statement = statement.where(Attacker.services.like(f'%"{service}"%'))
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(statement)
return result.scalar() or 0
@@ -611,7 +678,7 @@ class SQLModelRepository(BaseRepository):
offset: int = 0,
service: Optional[str] = None,
) -> dict[str, Any]:
async with self.session_factory() as session:
async with self._session() as session:
result = await session.execute(
select(Attacker.commands).where(Attacker.uuid == uuid)
)

66
tests/test_admin_seed.py Normal file
View File

@@ -0,0 +1,66 @@
"""
Tests for _ensure_admin_user env-drift self-healing.
Scenario: DECNET_ADMIN_PASSWORD changes between runs while the SQLite DB
persists on disk. Previously _ensure_admin_user was strictly insert-if-missing,
so the stale hash from the first seed locked out every subsequent login.
Contract: if the admin still has must_change_password=True (they never
finalized their own password), the stored hash re-syncs from the env.
Once the admin picks a real password, we never touch it.
"""
import pytest
from decnet.web.auth import verify_password
from decnet.web.db.sqlite.repository import SQLiteRepository
@pytest.mark.asyncio
async def test_admin_seeded_on_empty_db(tmp_path, monkeypatch):
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "first")
repo = SQLiteRepository(db_path=str(tmp_path / "t.db"))
await repo.initialize()
user = await repo.get_user_by_username("admin")
assert user is not None
assert verify_password("first", user["password_hash"])
assert user["must_change_password"] is True or user["must_change_password"] == 1
@pytest.mark.asyncio
async def test_admin_password_resyncs_when_not_finalized(tmp_path, monkeypatch):
db = str(tmp_path / "t.db")
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "first")
r1 = SQLiteRepository(db_path=db)
await r1.initialize()
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "second")
r2 = SQLiteRepository(db_path=db)
await r2.initialize()
user = await r2.get_user_by_username("admin")
assert verify_password("second", user["password_hash"])
assert not verify_password("first", user["password_hash"])
@pytest.mark.asyncio
async def test_finalized_admin_password_is_preserved(tmp_path, monkeypatch):
db = str(tmp_path / "t.db")
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "seed")
r1 = SQLiteRepository(db_path=db)
await r1.initialize()
admin = await r1.get_user_by_username("admin")
# Simulate the admin finalising their password via the change-password flow.
from decnet.web.auth import get_password_hash
await r1.update_user_password(
admin["uuid"], get_password_hash("chosen"), must_change_password=False
)
monkeypatch.setattr("decnet.web.db.sqlmodel_repo.DECNET_ADMIN_PASSWORD", "different")
r2 = SQLiteRepository(db_path=db)
await r2.initialize()
user = await r2.get_user_by_username("admin")
assert verify_password("chosen", user["password_hash"])
assert not verify_password("different", user["password_hash"])