refactor(db): run Alembic at boot, retire ad-hoc _migrate_* helpers
initialize() now delegates to _apply_schema(): real boots run 'alembic upgrade head' (schema owned by the migration history); tests (DECNET_TESTING=1) keep create_all, which is faster and needs no upgrade path. MySQL wraps the upgrade in the existing GET_LOCK advisory lock so concurrent uvicorn workers don't race on DDL. Deletes the three _migrate_* crimes (attackers-table legacy drop + GeoIP backfill, TEXT->MEDIUMTEXT widening) — all now handled by the baseline migration and the _BIG_TEXT model variants. Drops the test file that only exercised the deleted helpers; adds tests pinning the alembic-vs-create_all gate and guarding that every model table is in the migration head.
This commit is contained in:
38
decnet/web/db/migrate.py
Normal file
38
decnet/web/db/migrate.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
"""Programmatic Alembic upgrade, run at app boot for managed databases.
|
||||||
|
|
||||||
|
Real boots run ``alembic upgrade head`` so the schema is owned by the
|
||||||
|
versioned migration history. Test/ephemeral DBs skip this and use
|
||||||
|
``SQLModel.metadata.create_all`` instead (see
|
||||||
|
:meth:`SQLModelRepository._apply_schema`) — faster, and a throwaway DB never
|
||||||
|
needs an upgrade path.
|
||||||
|
|
||||||
|
The migration scripts live inside the package (``db/migrations``), so this
|
||||||
|
works from an installed wheel without depending on the repo-root
|
||||||
|
``alembic.ini`` (that file exists only for the ``alembic`` CLI).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from alembic import command
|
||||||
|
from alembic.config import Config
|
||||||
|
from sqlalchemy.engine import Connection
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||||
|
|
||||||
|
_MIGRATIONS_DIR = Path(__file__).resolve().parent / "migrations"
|
||||||
|
|
||||||
|
|
||||||
|
def _upgrade(connection: Connection) -> None:
|
||||||
|
# No ini file: env.py skips fileConfig and reuses this connection
|
||||||
|
# (passed via attributes) instead of building its own engine.
|
||||||
|
cfg = Config()
|
||||||
|
cfg.set_main_option("script_location", str(_MIGRATIONS_DIR))
|
||||||
|
cfg.attributes["connection"] = connection
|
||||||
|
command.upgrade(cfg, "head")
|
||||||
|
|
||||||
|
|
||||||
|
async def run_migrations(engine: AsyncEngine) -> None:
|
||||||
|
"""Upgrade ``engine``'s database to the latest revision (alembic head)."""
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(_upgrade)
|
||||||
@@ -3,13 +3,12 @@
|
|||||||
MySQL implementation of :class:`BaseRepository`.
|
MySQL implementation of :class:`BaseRepository`.
|
||||||
|
|
||||||
Inherits the portable SQLModel query code from :class:`SQLModelRepository`
|
Inherits the portable SQLModel query code from :class:`SQLModelRepository`
|
||||||
and only overrides the two places where MySQL's SQL dialect differs from
|
and only overrides where MySQL's SQL dialect differs from SQLite's:
|
||||||
SQLite's:
|
|
||||||
|
|
||||||
* :meth:`_migrate_attackers_table` — uses ``information_schema`` (MySQL
|
* :meth:`_apply_schema` — wraps the Alembic upgrade in a MySQL advisory
|
||||||
has no ``PRAGMA``).
|
lock to serialize DDL across concurrent workers.
|
||||||
* :meth:`get_log_histogram` — uses ``FROM_UNIXTIME`` /
|
* :meth:`get_log_histogram` — uses ``FROM_UNIXTIME`` / ``UNIX_TIMESTAMP`` +
|
||||||
``UNIX_TIMESTAMP`` + integer division for bucketing.
|
integer division for bucketing.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -34,86 +33,24 @@ class MySQLRepository(SQLModelRepository):
|
|||||||
self.engine, class_=AsyncSession, expire_on_commit=False
|
self.engine, class_=AsyncSession, expire_on_commit=False
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _migrate_attackers_table(self) -> None:
|
async def _apply_schema(self) -> None:
|
||||||
"""Drop the legacy (pre-UUID) ``attackers`` table if it exists without a ``uuid`` column.
|
"""Run the Alembic upgrade under a MySQL advisory lock.
|
||||||
|
|
||||||
Also adds the GeoIP columns (``country_code``, ``country_source``)
|
The lock serializes DDL across concurrent uvicorn workers — Alembic
|
||||||
to existing tables that predate them. MySQL exposes column
|
does not lock MySQL DDL itself, so without it parallel workers race
|
||||||
metadata via ``information_schema.COLUMNS``; ``DATABASE()`` scopes
|
('Table was skipped since its definition is being modified by
|
||||||
the lookup to the currently connected schema.
|
concurrent DDL'). Tests (``DECNET_TESTING=1``) take the base
|
||||||
|
``create_all`` path, which is single-process and needs no lock.
|
||||||
"""
|
"""
|
||||||
async with self.engine.begin() as conn:
|
import os
|
||||||
rows = (await conn.execute(text(
|
if os.environ.get("DECNET_TESTING") == "1":
|
||||||
"SELECT COLUMN_NAME FROM information_schema.COLUMNS "
|
await super()._apply_schema()
|
||||||
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'attackers'"
|
return
|
||||||
))).fetchall()
|
from decnet.web.db.migrate import run_migrations
|
||||||
if not rows:
|
|
||||||
return # table absent; create_all() handles it.
|
|
||||||
if not any(r[0] == "uuid" for r in rows):
|
|
||||||
await conn.execute(text("DROP TABLE attackers"))
|
|
||||||
return
|
|
||||||
existing_cols = {r[0] for r in rows}
|
|
||||||
if "country_code" not in existing_cols:
|
|
||||||
await conn.execute(text(
|
|
||||||
"ALTER TABLE attackers "
|
|
||||||
"ADD COLUMN country_code VARCHAR(2) NULL, "
|
|
||||||
"ADD INDEX ix_attackers_country_code (country_code)"
|
|
||||||
))
|
|
||||||
if "country_source" not in existing_cols:
|
|
||||||
await conn.execute(text(
|
|
||||||
"ALTER TABLE attackers ADD COLUMN country_source VARCHAR(16) NULL"
|
|
||||||
))
|
|
||||||
|
|
||||||
async def _migrate_column_types(self) -> None:
|
|
||||||
"""Upgrade TEXT → MEDIUMTEXT for columns that accumulate large JSON blobs.
|
|
||||||
|
|
||||||
``create_all()`` never alters existing columns, so tables created before
|
|
||||||
``_BIG_TEXT`` was introduced keep their 64 KiB ``TEXT`` cap. This method
|
|
||||||
inspects ``information_schema`` and issues ``ALTER TABLE … MODIFY COLUMN``
|
|
||||||
for each offending column found.
|
|
||||||
"""
|
|
||||||
targets: dict[str, dict[str, str]] = {
|
|
||||||
"attackers": {
|
|
||||||
"commands": "MEDIUMTEXT NOT NULL DEFAULT '[]'",
|
|
||||||
"fingerprints": "MEDIUMTEXT NOT NULL DEFAULT '[]'",
|
|
||||||
"services": "MEDIUMTEXT NOT NULL DEFAULT '[]'",
|
|
||||||
"deckies": "MEDIUMTEXT NOT NULL DEFAULT '[]'",
|
|
||||||
},
|
|
||||||
"state": {
|
|
||||||
"value": "MEDIUMTEXT NOT NULL",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
async with self.engine.begin() as conn:
|
|
||||||
rows = (await conn.execute(text(
|
|
||||||
"SELECT TABLE_NAME, COLUMN_NAME FROM information_schema.COLUMNS "
|
|
||||||
"WHERE TABLE_SCHEMA = DATABASE() "
|
|
||||||
" AND TABLE_NAME IN ('attackers', 'state') "
|
|
||||||
" AND COLUMN_NAME IN ('commands','fingerprints','services','deckies','value') "
|
|
||||||
" AND DATA_TYPE = 'text'"
|
|
||||||
))).fetchall()
|
|
||||||
for table_name, col_name in rows:
|
|
||||||
spec = targets.get(table_name, {}).get(col_name)
|
|
||||||
if spec:
|
|
||||||
await conn.execute(text(
|
|
||||||
f"ALTER TABLE `{table_name}` MODIFY COLUMN `{col_name}` {spec}"
|
|
||||||
))
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
"""Create tables and run all MySQL-specific migrations.
|
|
||||||
|
|
||||||
Uses a MySQL advisory lock to serialize DDL across concurrent
|
|
||||||
uvicorn workers — prevents the 'Table was skipped since its
|
|
||||||
definition is being modified by concurrent DDL' race.
|
|
||||||
"""
|
|
||||||
from sqlmodel import SQLModel
|
|
||||||
async with self.engine.connect() as lock_conn:
|
async with self.engine.connect() as lock_conn:
|
||||||
await lock_conn.execute(text("SELECT GET_LOCK('decnet_schema_init', 30)"))
|
await lock_conn.execute(text("SELECT GET_LOCK('decnet_schema_init', 30)"))
|
||||||
try:
|
try:
|
||||||
await self._migrate_attackers_table()
|
await run_migrations(self.engine)
|
||||||
await self._migrate_column_types()
|
|
||||||
async with self.engine.begin() as conn:
|
|
||||||
await conn.run_sync(SQLModel.metadata.create_all)
|
|
||||||
await self._ensure_admin_user()
|
|
||||||
finally:
|
finally:
|
||||||
await lock_conn.execute(text("SELECT RELEASE_LOCK('decnet_schema_init')"))
|
await lock_conn.execute(text("SELECT RELEASE_LOCK('decnet_schema_init')"))
|
||||||
await lock_conn.close()
|
await lock_conn.close()
|
||||||
|
|||||||
@@ -15,9 +15,9 @@ from decnet.web.db.sqlmodel_repo import SQLModelRepository
|
|||||||
class SQLiteRepository(SQLModelRepository):
|
class SQLiteRepository(SQLModelRepository):
|
||||||
"""SQLite backend — uses ``aiosqlite``.
|
"""SQLite backend — uses ``aiosqlite``.
|
||||||
|
|
||||||
Overrides the two places where SQLite's SQL dialect differs from
|
Overrides the one place where SQLite's SQL dialect differs from
|
||||||
MySQL/PostgreSQL: legacy-schema migration (via ``PRAGMA table_info``)
|
MySQL/PostgreSQL: the log-histogram bucket expression (via ``strftime``
|
||||||
and the log-histogram bucket expression (via ``strftime`` + ``unixepoch``).
|
+ ``unixepoch``). Schema is managed by Alembic (see db/migrate.py).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, db_path: str = str(_ROOT / "decnet.db")) -> None:
|
def __init__(self, db_path: str = str(_ROOT / "decnet.db")) -> None:
|
||||||
@@ -27,35 +27,6 @@ class SQLiteRepository(SQLModelRepository):
|
|||||||
self.engine, class_=AsyncSession, expire_on_commit=False
|
self.engine, class_=AsyncSession, expire_on_commit=False
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _migrate_attackers_table(self) -> None:
|
|
||||||
"""Drop the old attackers table if it lacks the uuid column (pre-UUID schema).
|
|
||||||
|
|
||||||
Also adds the GeoIP columns (``country_code``, ``country_source``)
|
|
||||||
to existing tables that predate them. SQLite's
|
|
||||||
``ALTER TABLE ADD COLUMN`` is idempotent only if we gate on
|
|
||||||
``PRAGMA table_info`` first — re-adding raises.
|
|
||||||
"""
|
|
||||||
async with self.engine.begin() as conn:
|
|
||||||
rows = (await conn.execute(text("PRAGMA table_info(attackers)"))).fetchall()
|
|
||||||
if rows and not any(r[1] == "uuid" for r in rows):
|
|
||||||
await conn.execute(text("DROP TABLE attackers"))
|
|
||||||
return # create_all() rebuilds fresh — no need to patch columns.
|
|
||||||
if not rows:
|
|
||||||
return # table absent; create_all() handles it.
|
|
||||||
existing_cols = {r[1] for r in rows}
|
|
||||||
if "country_code" not in existing_cols:
|
|
||||||
await conn.execute(text(
|
|
||||||
"ALTER TABLE attackers ADD COLUMN country_code VARCHAR(2)"
|
|
||||||
))
|
|
||||||
await conn.execute(text(
|
|
||||||
"CREATE INDEX IF NOT EXISTS ix_attackers_country_code "
|
|
||||||
"ON attackers (country_code)"
|
|
||||||
))
|
|
||||||
if "country_source" not in existing_cols:
|
|
||||||
await conn.execute(text(
|
|
||||||
"ALTER TABLE attackers ADD COLUMN country_source VARCHAR(16)"
|
|
||||||
))
|
|
||||||
|
|
||||||
def _json_field_equals(self, key: str, param_name: str = "val"):
|
def _json_field_equals(self, key: str, param_name: str = "val"):
|
||||||
# SQLite stores JSON as text; json_extract is the canonical accessor.
|
# SQLite stores JSON as text; json_extract is the canonical accessor.
|
||||||
return text(f"json_extract(fields, '$.{key}') = :{param_name}")
|
return text(f"json_extract(fields, '$.{key}') = :{param_name}")
|
||||||
|
|||||||
@@ -6,8 +6,7 @@ Contains all dialect-portable query code used by the SQLite and MySQL
|
|||||||
backends. Dialect-specific behavior lives in subclasses:
|
backends. Dialect-specific behavior lives in subclasses:
|
||||||
|
|
||||||
* engine/session construction (``__init__``)
|
* engine/session construction (``__init__``)
|
||||||
* ``_migrate_attackers_table`` (legacy schema check; DDL introspection
|
* ``_apply_schema`` (MySQL wraps the Alembic upgrade in an advisory lock)
|
||||||
is not portable)
|
|
||||||
* ``get_log_histogram`` (date-bucket expression differs per dialect)
|
* ``get_log_histogram`` (date-bucket expression differs per dialect)
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -103,14 +102,27 @@ class SQLModelRepository(
|
|||||||
# ------------------------------------------------------------ lifecycle
|
# ------------------------------------------------------------ lifecycle
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
"""Create tables if absent and seed the admin user."""
|
"""Bring the schema up to date and seed the admin user."""
|
||||||
from sqlmodel import SQLModel
|
await self._apply_schema()
|
||||||
await self._migrate_attackers_table()
|
|
||||||
async with self.engine.begin() as conn:
|
|
||||||
await conn.run_sync(SQLModel.metadata.create_all)
|
|
||||||
await self._ensure_admin_user()
|
await self._ensure_admin_user()
|
||||||
await self._ensure_contract_user()
|
await self._ensure_contract_user()
|
||||||
|
|
||||||
|
async def _apply_schema(self) -> None:
|
||||||
|
"""Create/upgrade tables.
|
||||||
|
|
||||||
|
Real boots run Alembic migrations — the schema is owned by the
|
||||||
|
versioned migration history. Test/ephemeral DBs (``DECNET_TESTING=1``)
|
||||||
|
skip Alembic and use ``create_all``: faster, and an in-memory/throwaway
|
||||||
|
DB never needs an upgrade path.
|
||||||
|
"""
|
||||||
|
from sqlmodel import SQLModel
|
||||||
|
if os.environ.get("DECNET_TESTING") == "1":
|
||||||
|
async with self.engine.begin() as conn:
|
||||||
|
await conn.run_sync(SQLModel.metadata.create_all)
|
||||||
|
return
|
||||||
|
from decnet.web.db.migrate import run_migrations
|
||||||
|
await run_migrations(self.engine)
|
||||||
|
|
||||||
async def reinitialize(self) -> None:
|
async def reinitialize(self) -> None:
|
||||||
"""Re-create schema (for tests / reset flows). Does NOT drop existing tables."""
|
"""Re-create schema (for tests / reset flows). Does NOT drop existing tables."""
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
@@ -165,10 +177,6 @@ class SQLModelRepository(
|
|||||||
))
|
))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def _migrate_attackers_table(self) -> None:
|
|
||||||
"""Legacy-schema cleanup. Override per dialect (DDL introspection is non-portable)."""
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_deckies(self) -> List[dict]:
|
async def get_deckies(self) -> List[dict]:
|
||||||
# The fleet inventory the UI/API sees is fleet_deckies — the
|
# The fleet inventory the UI/API sees is fleet_deckies — the
|
||||||
# engine-mirrored table written on EVERY deploy/teardown (CLI or web),
|
# engine-mirrored table written on EVERY deploy/teardown (CLI or web),
|
||||||
|
|||||||
@@ -1,234 +0,0 @@
|
|||||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
|
||||||
"""
|
|
||||||
Tests for MySQLRepository._migrate_column_types().
|
|
||||||
|
|
||||||
No live MySQL server required — uses an in-memory SQLite engine that exposes
|
|
||||||
the same information_schema-style query surface via a mocked connection, plus
|
|
||||||
an integration-style test using a real async engine over aiosqlite (which
|
|
||||||
ignores the TEXT/MEDIUMTEXT distinction but verifies the ALTER path is called
|
|
||||||
and idempotent).
|
|
||||||
|
|
||||||
The ALTER TABLE branch is tested via unittest.mock: we intercept the
|
|
||||||
information_schema query result and assert the correct MODIFY COLUMN
|
|
||||||
statements are issued.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch, call
|
|
||||||
|
|
||||||
from decnet.web.db.mysql.repository import MySQLRepository
|
|
||||||
|
|
||||||
|
|
||||||
# ── helpers ──────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _make_repo() -> MySQLRepository:
|
|
||||||
"""Construct a MySQLRepository without touching any real DB."""
|
|
||||||
return MySQLRepository.__new__(MySQLRepository)
|
|
||||||
|
|
||||||
|
|
||||||
# ── _migrate_column_types ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_migrate_column_types_issues_alter_for_text_columns():
|
|
||||||
"""When information_schema reports TEXT columns, ALTER TABLE is called for each."""
|
|
||||||
repo = _make_repo()
|
|
||||||
|
|
||||||
# Rows returned by the information_schema query: two TEXT columns found
|
|
||||||
fake_rows = [
|
|
||||||
("attackers", "commands"),
|
|
||||||
("attackers", "fingerprints"),
|
|
||||||
("state", "value"),
|
|
||||||
]
|
|
||||||
|
|
||||||
exec_results: list[str] = []
|
|
||||||
|
|
||||||
async def fake_execute(stmt):
|
|
||||||
sql = str(stmt)
|
|
||||||
if "information_schema" in sql:
|
|
||||||
result = MagicMock()
|
|
||||||
result.fetchall.return_value = fake_rows
|
|
||||||
return result
|
|
||||||
# Capture ALTER TABLE calls
|
|
||||||
exec_results.append(sql)
|
|
||||||
return MagicMock()
|
|
||||||
|
|
||||||
fake_conn = AsyncMock()
|
|
||||||
fake_conn.execute.side_effect = fake_execute
|
|
||||||
|
|
||||||
fake_ctx = AsyncMock()
|
|
||||||
fake_ctx.__aenter__ = AsyncMock(return_value=fake_conn)
|
|
||||||
fake_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
|
|
||||||
repo.engine = MagicMock()
|
|
||||||
repo.engine.begin.return_value = fake_ctx
|
|
||||||
|
|
||||||
await repo._migrate_column_types()
|
|
||||||
|
|
||||||
# Three ALTER TABLE statements expected, one per TEXT column returned
|
|
||||||
assert len(exec_results) == 3
|
|
||||||
assert any("`commands` MEDIUMTEXT" in s for s in exec_results)
|
|
||||||
assert any("`fingerprints` MEDIUMTEXT" in s for s in exec_results)
|
|
||||||
assert any("`value` MEDIUMTEXT" in s for s in exec_results)
|
|
||||||
# Verify NOT NULL is preserved
|
|
||||||
assert all("NOT NULL" in s for s in exec_results)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_migrate_column_types_no_alter_when_already_mediumtext():
|
|
||||||
"""When information_schema returns no TEXT rows, no ALTER is issued."""
|
|
||||||
repo = _make_repo()
|
|
||||||
|
|
||||||
exec_results: list[str] = []
|
|
||||||
|
|
||||||
async def fake_execute(stmt):
|
|
||||||
sql = str(stmt)
|
|
||||||
if "information_schema" in sql:
|
|
||||||
result = MagicMock()
|
|
||||||
result.fetchall.return_value = [] # nothing to migrate
|
|
||||||
return result
|
|
||||||
exec_results.append(sql)
|
|
||||||
return MagicMock()
|
|
||||||
|
|
||||||
fake_conn = AsyncMock()
|
|
||||||
fake_conn.execute.side_effect = fake_execute
|
|
||||||
|
|
||||||
fake_ctx = AsyncMock()
|
|
||||||
fake_ctx.__aenter__ = AsyncMock(return_value=fake_conn)
|
|
||||||
fake_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
|
|
||||||
repo.engine = MagicMock()
|
|
||||||
repo.engine.begin.return_value = fake_ctx
|
|
||||||
|
|
||||||
await repo._migrate_column_types()
|
|
||||||
|
|
||||||
assert exec_results == [], "No ALTER TABLE should be issued when columns are already MEDIUMTEXT"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_migrate_column_types_idempotent_on_repeated_calls():
|
|
||||||
"""Calling _migrate_column_types twice is safe: second call is a no-op."""
|
|
||||||
repo = _make_repo()
|
|
||||||
call_count = 0
|
|
||||||
|
|
||||||
async def fake_execute(stmt):
|
|
||||||
nonlocal call_count
|
|
||||||
sql = str(stmt)
|
|
||||||
if "information_schema" in sql:
|
|
||||||
result = MagicMock()
|
|
||||||
# First call: two TEXT columns; second call: zero (already migrated)
|
|
||||||
call_count += 1
|
|
||||||
result.fetchall.return_value = (
|
|
||||||
[("attackers", "commands")] if call_count == 1 else []
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
return MagicMock()
|
|
||||||
|
|
||||||
def _make_ctx():
|
|
||||||
fake_conn = AsyncMock()
|
|
||||||
fake_conn.execute.side_effect = fake_execute
|
|
||||||
ctx = AsyncMock()
|
|
||||||
ctx.__aenter__ = AsyncMock(return_value=fake_conn)
|
|
||||||
ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
return ctx
|
|
||||||
|
|
||||||
repo.engine = MagicMock()
|
|
||||||
repo.engine.begin.side_effect = _make_ctx
|
|
||||||
|
|
||||||
await repo._migrate_column_types()
|
|
||||||
await repo._migrate_column_types() # second call must not raise
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_migrate_column_types_default_clause_per_column():
|
|
||||||
"""Each attacker column gets DEFAULT '[]'; state.value gets no DEFAULT."""
|
|
||||||
repo = _make_repo()
|
|
||||||
|
|
||||||
all_text_rows = [
|
|
||||||
("attackers", "commands"),
|
|
||||||
("attackers", "fingerprints"),
|
|
||||||
("attackers", "services"),
|
|
||||||
("attackers", "deckies"),
|
|
||||||
("state", "value"),
|
|
||||||
]
|
|
||||||
alter_stmts: list[str] = []
|
|
||||||
|
|
||||||
async def fake_execute(stmt):
|
|
||||||
sql = str(stmt)
|
|
||||||
if "information_schema" in sql:
|
|
||||||
result = MagicMock()
|
|
||||||
result.fetchall.return_value = all_text_rows
|
|
||||||
return result
|
|
||||||
alter_stmts.append(sql)
|
|
||||||
return MagicMock()
|
|
||||||
|
|
||||||
fake_conn = AsyncMock()
|
|
||||||
fake_conn.execute.side_effect = fake_execute
|
|
||||||
|
|
||||||
fake_ctx = AsyncMock()
|
|
||||||
fake_ctx.__aenter__ = AsyncMock(return_value=fake_conn)
|
|
||||||
fake_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
|
|
||||||
repo.engine = MagicMock()
|
|
||||||
repo.engine.begin.return_value = fake_ctx
|
|
||||||
|
|
||||||
await repo._migrate_column_types()
|
|
||||||
|
|
||||||
attacker_alters = [s for s in alter_stmts if "`attackers`" in s]
|
|
||||||
state_alters = [s for s in alter_stmts if "`state`" in s]
|
|
||||||
|
|
||||||
assert len(attacker_alters) == 4
|
|
||||||
assert len(state_alters) == 1
|
|
||||||
|
|
||||||
for stmt in attacker_alters:
|
|
||||||
assert "DEFAULT '[]'" in stmt, f"Missing DEFAULT '[]' in: {stmt}"
|
|
||||||
|
|
||||||
# state.value has no DEFAULT in the schema
|
|
||||||
assert "DEFAULT" not in state_alters[0], \
|
|
||||||
f"Unexpected DEFAULT in state.value alter: {state_alters[0]}"
|
|
||||||
|
|
||||||
|
|
||||||
# ── initialize override ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_mysql_initialize_calls_migrate_column_types():
|
|
||||||
"""MySQLRepository.initialize() must invoke every migration helper
|
|
||||||
in the right order: attackers first, then column types, then seed
|
|
||||||
the admin user.
|
|
||||||
|
|
||||||
The legacy ``_migrate_session_profile_table`` step (DEBT-036) was
|
|
||||||
dropped when SessionProfile was deleted in favour of the
|
|
||||||
``observations`` table — see DEBT-050 / BEHAVE-INTEGRATION.md."""
|
|
||||||
repo = _make_repo()
|
|
||||||
|
|
||||||
call_order: list[str] = []
|
|
||||||
|
|
||||||
async def fake_migrate_attackers():
|
|
||||||
call_order.append("migrate_attackers")
|
|
||||||
|
|
||||||
async def fake_migrate_column_types():
|
|
||||||
call_order.append("migrate_column_types")
|
|
||||||
|
|
||||||
async def fake_ensure_admin():
|
|
||||||
call_order.append("ensure_admin")
|
|
||||||
|
|
||||||
repo._migrate_attackers_table = fake_migrate_attackers
|
|
||||||
repo._migrate_column_types = fake_migrate_column_types
|
|
||||||
repo._ensure_admin_user = fake_ensure_admin
|
|
||||||
|
|
||||||
# Stub engine.begin() so create_all is a no-op
|
|
||||||
fake_conn = AsyncMock()
|
|
||||||
fake_conn.run_sync = AsyncMock()
|
|
||||||
fake_ctx = AsyncMock()
|
|
||||||
fake_ctx.__aenter__ = AsyncMock(return_value=fake_conn)
|
|
||||||
fake_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
repo.engine = MagicMock()
|
|
||||||
repo.engine.begin.return_value = fake_ctx
|
|
||||||
|
|
||||||
await repo.initialize()
|
|
||||||
|
|
||||||
assert call_order == [
|
|
||||||
"migrate_attackers",
|
|
||||||
"migrate_column_types",
|
|
||||||
"ensure_admin",
|
|
||||||
], f"Unexpected call order: {call_order}"
|
|
||||||
79
tests/db/test_alembic_migrations.py
Normal file
79
tests/db/test_alembic_migrations.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
"""Alembic wiring guards.
|
||||||
|
|
||||||
|
These pin the two halves of SQLModelRepository._apply_schema:
|
||||||
|
* real boots run `alembic upgrade head` (schema owned by migration history),
|
||||||
|
* tests (DECNET_TESTING=1) take the faster create_all path.
|
||||||
|
|
||||||
|
The first test also doubles as a drift guard: if someone adds a model table
|
||||||
|
but forgets to autogenerate a migration, `alembic upgrade head` won't create
|
||||||
|
it and this fails.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
from sqlmodel import SQLModel
|
||||||
|
|
||||||
|
import decnet.web.db.models # noqa: F401 (registers every table on metadata)
|
||||||
|
from decnet.web.db.migrate import run_migrations
|
||||||
|
from decnet.web.db.sqlite.repository import SQLiteRepository
|
||||||
|
|
||||||
|
|
||||||
|
def _table_names(db_path: str) -> set[str]:
|
||||||
|
con = sqlite3.connect(db_path)
|
||||||
|
try:
|
||||||
|
rows = con.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||||
|
).fetchall()
|
||||||
|
finally:
|
||||||
|
con.close()
|
||||||
|
return {r[0] for r in rows}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_migrations_create_every_model_table(tmp_path):
|
||||||
|
"""`alembic upgrade head` must materialise every SQLModel table —
|
||||||
|
catches a model added without a corresponding migration."""
|
||||||
|
db_path = str(tmp_path / "mig.db")
|
||||||
|
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
|
||||||
|
try:
|
||||||
|
await run_migrations(engine)
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
created = _table_names(db_path)
|
||||||
|
expected = set(SQLModel.metadata.tables)
|
||||||
|
missing = expected - created
|
||||||
|
assert not missing, f"migration head is missing tables: {sorted(missing)}"
|
||||||
|
assert "alembic_version" in created
|
||||||
|
|
||||||
|
|
||||||
|
async def test_real_boot_runs_alembic(tmp_path, monkeypatch):
|
||||||
|
"""With DECNET_TESTING unset, initialize() runs migrations and stamps
|
||||||
|
the alembic_version table."""
|
||||||
|
monkeypatch.delenv("DECNET_TESTING", raising=False)
|
||||||
|
repo = SQLiteRepository(db_path=str(tmp_path / "boot.db"))
|
||||||
|
try:
|
||||||
|
await repo._apply_schema()
|
||||||
|
async with repo.engine.begin() as conn:
|
||||||
|
ver = (await conn.execute(text("SELECT version_num FROM alembic_version"))).fetchall()
|
||||||
|
finally:
|
||||||
|
await repo.engine.dispose()
|
||||||
|
assert ver, "alembic_version not stamped — migrations did not run"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_testing_mode_uses_create_all(tmp_path, monkeypatch):
|
||||||
|
"""Under DECNET_TESTING=1 the schema comes from create_all, so there is
|
||||||
|
no alembic_version table (Alembic was skipped)."""
|
||||||
|
monkeypatch.setenv("DECNET_TESTING", "1")
|
||||||
|
db_path = str(tmp_path / "test.db")
|
||||||
|
repo = SQLiteRepository(db_path=db_path)
|
||||||
|
try:
|
||||||
|
await repo._apply_schema()
|
||||||
|
finally:
|
||||||
|
await repo.engine.dispose()
|
||||||
|
tables = _table_names(db_path)
|
||||||
|
assert "attackers" in tables # schema was created…
|
||||||
|
assert "alembic_version" not in tables # …but not via Alembic
|
||||||
Reference in New Issue
Block a user