diff --git a/decnet.db-shm b/decnet.db-shm deleted file mode 100644 index 9a6b1e3..0000000 Binary files a/decnet.db-shm and /dev/null differ diff --git a/decnet.db-wal b/decnet.db-wal deleted file mode 100644 index 19a5035..0000000 Binary files a/decnet.db-wal and /dev/null differ diff --git a/decnet/cli.py b/decnet/cli.py index e2001d8..91415e5 100644 --- a/decnet/cli.py +++ b/decnet/cli.py @@ -305,16 +305,18 @@ def mutate( from decnet.mutator import mutate_decky, mutate_all, run_watch_loop from decnet.web.dependencies import repo - if watch: - asyncio.run(run_watch_loop(repo)) - return + async def _run() -> None: + await repo.initialize() + if watch: + await run_watch_loop(repo) + elif decky_name: + await mutate_decky(decky_name, repo) + elif force_all: + await mutate_all(force=True, repo=repo) + else: + await mutate_all(force=False, repo=repo) - if decky_name: - asyncio.run(mutate_decky(decky_name, repo)) - elif force_all: - asyncio.run(mutate_all(force=True, repo=repo)) - else: - asyncio.run(mutate_all(force=False, repo=repo)) + asyncio.run(_run()) @app.command() diff --git a/decnet/web/db/sqlite/repository.py b/decnet/web/db/sqlite/repository.py index 92f6721..db4e73b 100644 --- a/decnet/web/db/sqlite/repository.py +++ b/decnet/web/db/sqlite/repository.py @@ -25,34 +25,27 @@ class SQLiteRepository(BaseRepository): self.session_factory = async_sessionmaker( self.engine, class_=AsyncSession, expire_on_commit=False ) - self._initialize_sync() - - def _initialize_sync(self) -> None: - """Initialize the database schema synchronously.""" - init_db(self.db_path) - - from decnet.web.db.sqlite.database import get_sync_engine - engine = get_sync_engine(self.db_path) - with engine.connect() as conn: - conn.execute( - text( - "INSERT OR IGNORE INTO users (uuid, username, password_hash, role, must_change_password) " - "VALUES (:uuid, :u, :p, :r, :m)" - ), - { - "uuid": str(uuid.uuid4()), - "u": DECNET_ADMIN_USER, - "p": get_password_hash(DECNET_ADMIN_PASSWORD), - "r": "admin", - "m": 1, - }, - ) - conn.commit() async def initialize(self) -> None: - """Async warm-up / verification.""" + """Async warm-up / verification. Creates tables if they don't exist.""" + from sqlmodel import SQLModel + async with self.engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + async with self.session_factory() as session: - await session.execute(text("SELECT 1")) + # Check if admin exists + result = await session.execute( + select(User).where(User.username == DECNET_ADMIN_USER) + ) + if not result.scalar_one_or_none(): + session.add(User( + uuid=str(uuid.uuid4()), + username=DECNET_ADMIN_USER, + password_hash=get_password_hash(DECNET_ADMIN_PASSWORD), + role="admin", + must_change_password=True, + )) + await session.commit() async def reinitialize(self) -> None: """Initialize the database schema asynchronously (useful for tests).""" diff --git a/decnet/web/dependencies.py b/decnet/web/dependencies.py index 4c31648..a1f9c4a 100644 --- a/decnet/web/dependencies.py +++ b/decnet/web/dependencies.py @@ -9,11 +9,16 @@ from decnet.web.db.repository import BaseRepository from decnet.web.db.factory import get_repository # Shared repository singleton -repo: BaseRepository = get_repository() +_repo: Optional[BaseRepository] = None def get_repo() -> BaseRepository: """FastAPI dependency to inject the configured repository.""" - return repo + global _repo + if _repo is None: + _repo = get_repository() + return _repo + +repo = get_repo() oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") diff --git a/tests/api/logs/test_histogram.py b/tests/api/logs/test_histogram.py index 0f999e7..863913d 100644 --- a/tests/api/logs/test_histogram.py +++ b/tests/api/logs/test_histogram.py @@ -14,8 +14,10 @@ from ..conftest import _FUZZ_SETTINGS @pytest.fixture -def repo(tmp_path): - return get_repository(db_path=str(tmp_path / "histogram_test.db")) +async def repo(tmp_path): + r = get_repository(db_path=str(tmp_path / "histogram_test.db")) + await r.initialize() + return r def _log(decky="d", service="ssh", ip="1.2.3.4", timestamp=None): diff --git a/tests/api/test_repository.py b/tests/api/test_repository.py index 69ae689..2337882 100644 --- a/tests/api/test_repository.py +++ b/tests/api/test_repository.py @@ -10,8 +10,10 @@ from .conftest import _FUZZ_SETTINGS @pytest.fixture -def repo(tmp_path): - return get_repository(db_path=str(tmp_path / "test.db")) +async def repo(tmp_path): + r = get_repository(db_path=str(tmp_path / "test.db")) + await r.initialize() + return r @pytest.mark.anyio