NOTE(review): this patch was rebuilt from a copy whose newlines and
indentation had been collapsed. Blank context lines, the wrap points of the
long @router.get(...) decorators and the indentation of context lines were
inferred from the hunk counts; run `git apply --check` before applying.
The post-image blob ids on the `index` lines predate the docstring and test
edits made below and are therefore stale.

diff --git a/decnet/web/router/attackers/api_get_attackers.py b/decnet/web/router/attackers/api_get_attackers.py
index 6f3daa5..f1ff7b4 100644
--- a/decnet/web/router/attackers/api_get_attackers.py
+++ b/decnet/web/router/attackers/api_get_attackers.py
@@ -1,3 +1,5 @@
+import asyncio
+import time
 from typing import Any, Optional
 
 from fastapi import APIRouter, Depends, Query
@@ -8,6 +10,44 @@ from decnet.web.db.models import AttackersResponse
 
 router = APIRouter()
 
+# Same pattern as /logs — cache the unfiltered total count; filtered
+# counts go straight to the DB.
+_TOTAL_TTL = 2.0
+_total_cache: tuple[Optional[int], float] = (None, 0.0)
+_total_lock: Optional[asyncio.Lock] = None
+
+
+def _reset_total_cache() -> None:
+    """Forget the cached total and the lock; the test fixtures call this."""
+    global _total_cache, _total_lock
+    _total_cache = (None, 0.0)
+    _total_lock = None
+
+
+async def _get_total_attackers_cached() -> int:
+    """Return the unfiltered attacker count, querying at most once per TTL.
+
+    The cache is checked before taking the lock and again after acquiring
+    it, so callers that waited on the lock reuse the value stored by the
+    caller ahead of them instead of running their own query. The lock is
+    created on first use rather than at import time.
+    """
+    global _total_cache, _total_lock
+    value, ts = _total_cache
+    now = time.monotonic()
+    if value is not None and now - ts < _TOTAL_TTL:
+        return value
+    if _total_lock is None:
+        _total_lock = asyncio.Lock()
+    async with _total_lock:
+        value, ts = _total_cache
+        now = time.monotonic()
+        if value is not None and now - ts < _TOTAL_TTL:
+            return value
+        value = await repo.get_total_attackers()
+        _total_cache = (value, time.monotonic())
+        return value
+
 
 @router.get(
     "/attackers",
@@ -37,7 +77,10 @@ async def get_attackers(
     s = _norm(search)
     svc = _norm(service)
     _data = await repo.get_attackers(limit=limit, offset=offset, search=s, sort_by=sort_by, service=svc)
-    _total = await repo.get_total_attackers(search=s, service=svc)
+    if s is None and svc is None:
+        _total = await _get_total_attackers_cached()
+    else:
+        _total = await repo.get_total_attackers(search=s, service=svc)
 
     # Bulk-join behavior rows for the IPs in this page to avoid N+1 queries.
     _ips = {row["ip"] for row in _data if row.get("ip")}
diff --git a/decnet/web/router/logs/api_get_logs.py b/decnet/web/router/logs/api_get_logs.py
index 46c5a14..8bd864b 100644
--- a/decnet/web/router/logs/api_get_logs.py
+++ b/decnet/web/router/logs/api_get_logs.py
@@ -1,3 +1,5 @@
+import asyncio
+import time
 from typing import Any, Optional
 
 from fastapi import APIRouter, Depends, Query
@@ -8,6 +10,45 @@ from decnet.web.db.models import LogsResponse
 
 router = APIRouter()
 
+# Cache the unfiltered total-logs count. Filtered counts bypass the cache
+# (rare, freshness matters for search). SELECT count(*) FROM logs is a
+# full scan and gets hammered by paginating clients.
+_TOTAL_TTL = 2.0
+_total_cache: tuple[Optional[int], float] = (None, 0.0)
+_total_lock: Optional[asyncio.Lock] = None
+
+
+def _reset_total_cache() -> None:
+    """Forget the cached total and the lock; the test fixtures call this."""
+    global _total_cache, _total_lock
+    _total_cache = (None, 0.0)
+    _total_lock = None
+
+
+async def _get_total_logs_cached() -> int:
+    """Return the unfiltered log count, querying at most once per TTL.
+
+    The cache is checked before taking the lock and again after acquiring
+    it, so callers that waited on the lock reuse the value stored by the
+    caller ahead of them instead of running their own query. The lock is
+    created on first use rather than at import time.
+    """
+    global _total_cache, _total_lock
+    value, ts = _total_cache
+    now = time.monotonic()
+    if value is not None and now - ts < _TOTAL_TTL:
+        return value
+    if _total_lock is None:
+        _total_lock = asyncio.Lock()
+    async with _total_lock:
+        value, ts = _total_cache
+        now = time.monotonic()
+        if value is not None and now - ts < _TOTAL_TTL:
+            return value
+        value = await repo.get_total_logs()
+        _total_cache = (value, time.monotonic())
+        return value
+
 
 @router.get("/logs", response_model=LogsResponse, tags=["Logs"],
     responses={401: {"description": "Could not validate credentials"}, 403: {"description": "Insufficient permissions"}, 422: {"description": "Validation error"}})
@@ -30,7 +71,10 @@ async def get_logs(
     et = _norm(end_time)
 
     _logs: list[dict[str, Any]] = await repo.get_logs(limit=limit, offset=offset, search=s, start_time=st, end_time=et)
-    _total: int = await repo.get_total_logs(search=s, start_time=st, end_time=et)
+    if s is None and st is None and et is None:
+        _total: int = await _get_total_logs_cached()
+    else:
+        _total = await repo.get_total_logs(search=s, start_time=st, end_time=et)
     return {
         "total": _total,
         "limit": limit,
diff --git a/decnet/web/router/stats/api_get_stats.py b/decnet/web/router/stats/api_get_stats.py
index a1739b7..474331d 100644
--- a/decnet/web/router/stats/api_get_stats.py
+++ b/decnet/web/router/stats/api_get_stats.py
@@ -1,4 +1,6 @@
-from typing import Any
+import asyncio
+import time
+from typing import Any, Optional
 
 from fastapi import APIRouter, Depends
 
@@ -8,9 +10,49 @@ from decnet.web.db.models import StatsResponse
 
 router = APIRouter()
 
+# /stats is aggregate telemetry polled constantly by the UI and locust.
+# A 5s window collapses thousands of concurrent calls — each of which
+# runs SELECT count(*) FROM logs + SELECT count(DISTINCT attacker_ip) —
+# into one DB hit per window.
+_STATS_TTL = 5.0
+_stats_cache: tuple[Optional[dict[str, Any]], float] = (None, 0.0)
+_stats_lock: Optional[asyncio.Lock] = None
+
+
+def _reset_stats_cache() -> None:
+    """Forget the cached summary and the lock; the test fixtures call this."""
+    global _stats_cache, _stats_lock
+    _stats_cache = (None, 0.0)
+    _stats_lock = None
+
+
+async def _get_stats_cached() -> dict[str, Any]:
+    """Return the stats summary, querying the repo at most once per TTL.
+
+    The cache is checked before taking the lock and again after acquiring
+    it, so callers that waited on the lock reuse the value stored by the
+    caller ahead of them instead of running their own query. Every caller
+    inside one window receives the same dict object.
+    """
+    global _stats_cache, _stats_lock
+    value, ts = _stats_cache
+    now = time.monotonic()
+    if value is not None and now - ts < _STATS_TTL:
+        return value
+    if _stats_lock is None:
+        _stats_lock = asyncio.Lock()
+    async with _stats_lock:
+        value, ts = _stats_cache
+        now = time.monotonic()
+        if value is not None and now - ts < _STATS_TTL:
+            return value
+        value = await repo.get_stats_summary()
+        _stats_cache = (value, time.monotonic())
+        return value
+
 
 @router.get("/stats", response_model=StatsResponse, tags=["Observability"],
     responses={401: {"description": "Could not validate credentials"}, 403: {"description": "Insufficient permissions"}, 422: {"description": "Validation error"}},)
 @_traced("api.get_stats")
 async def get_stats(user: dict = Depends(require_viewer)) -> dict[str, Any]:
-    return await repo.get_stats_summary()
+    return await _get_stats_cached()
diff --git a/tests/api/conftest.py b/tests/api/conftest.py
index 7727f02..186caa1 100644
--- a/tests/api/conftest.py
+++ b/tests/api/conftest.py
@@ -56,8 +56,14 @@ async def setup_db(monkeypatch) -> AsyncGenerator[None, None]:
     # Reset per-request TTL caches so they don't leak across tests
     from decnet.web.router.health import api_get_health as _h
     from decnet.web.router.config import api_get_config as _c
+    from decnet.web.router.stats import api_get_stats as _s
+    from decnet.web.router.logs import api_get_logs as _l
+    from decnet.web.router.attackers import api_get_attackers as _a
     _h._reset_db_cache()
     _c._reset_state_cache()
+    _s._reset_stats_cache()
+    _l._reset_total_cache()
+    _a._reset_total_cache()
 
     # Create schema
     async with engine.begin() as conn:
diff --git a/tests/test_api_attackers.py b/tests/test_api_attackers.py
index 82022eb..9efa573 100644
--- a/tests/test_api_attackers.py
+++ b/tests/test_api_attackers.py
@@ -15,6 +15,14 @@ import pytest
 from fastapi import HTTPException
 
 from decnet.web.auth import create_access_token
+from decnet.web.router.attackers.api_get_attackers import _reset_total_cache
+
+
+@pytest.fixture(autouse=True)
+def _reset_attackers_cache():
+    _reset_total_cache()
+    yield
+    _reset_total_cache()
 
 
 # ─── Helpers ──────────────────────────────────────────────────────────────────
diff --git a/tests/test_router_cache.py b/tests/test_router_cache.py
new file mode 100644
index 0000000..ab81682
--- /dev/null
+++ b/tests/test_router_cache.py
@@ -0,0 +1,113 @@
+"""
+TTL-cache contract for /stats, /logs total count, and /attackers total count.
+
+Under concurrent load N callers should collapse to 1 repo hit per TTL
+window. Tests patch the repo — no real DB.
+"""
+import asyncio
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from decnet.web.router.stats import api_get_stats
+from decnet.web.router.logs import api_get_logs
+from decnet.web.router.attackers import api_get_attackers
+
+
+@pytest.fixture(autouse=True)
+def _reset_router_caches():
+    api_get_stats._reset_stats_cache()
+    api_get_logs._reset_total_cache()
+    api_get_attackers._reset_total_cache()
+    yield
+    api_get_stats._reset_stats_cache()
+    api_get_logs._reset_total_cache()
+    api_get_attackers._reset_total_cache()
+
+
+# ── /stats whole-response cache ──────────────────────────────────────────────
+
+@pytest.mark.asyncio
+async def test_stats_cache_collapses_concurrent_calls():
+    api_get_stats._reset_stats_cache()
+    payload = {"total_logs": 42, "unique_attackers": 7, "active_deckies": 3, "deployed_deckies": 3}
+    with patch.object(api_get_stats, "repo") as mock_repo:
+        mock_repo.get_stats_summary = AsyncMock(return_value=payload)
+        results = await asyncio.gather(*[api_get_stats._get_stats_cached() for _ in range(50)])
+    assert all(r == payload for r in results)
+    assert mock_repo.get_stats_summary.await_count == 1
+
+
+@pytest.mark.asyncio
+async def test_stats_cache_expires_after_ttl(monkeypatch):
+    api_get_stats._reset_stats_cache()
+    clock = {"t": 0.0}
+    # Swap only this module's `time` reference; patching time.monotonic
+    # itself would also change the clock the asyncio event loop reads.
+    monkeypatch.setattr(api_get_stats, "time", SimpleNamespace(monotonic=lambda: clock["t"]))
+    with patch.object(api_get_stats, "repo") as mock_repo:
+        mock_repo.get_stats_summary = AsyncMock(return_value={"total_logs": 1, "unique_attackers": 0, "active_deckies": 0, "deployed_deckies": 0})
+        await api_get_stats._get_stats_cached()
+        clock["t"] = 100.0  # past TTL
+        await api_get_stats._get_stats_cached()
+    assert mock_repo.get_stats_summary.await_count == 2
+
+
+# ── /logs total-count cache ──────────────────────────────────────────────────
+
+@pytest.mark.asyncio
+async def test_logs_total_cache_collapses_concurrent_calls():
+    api_get_logs._reset_total_cache()
+    with patch.object(api_get_logs, "repo") as mock_repo:
+        mock_repo.get_total_logs = AsyncMock(return_value=1234)
+        results = await asyncio.gather(*[api_get_logs._get_total_logs_cached() for _ in range(50)])
+    assert all(r == 1234 for r in results)
+    assert mock_repo.get_total_logs.await_count == 1
+
+
+@pytest.mark.asyncio
+async def test_logs_filtered_count_bypasses_cache():
+    """When a filter is provided, the endpoint must hit repo every time."""
+    api_get_logs._reset_total_cache()
+    with patch.object(api_get_logs, "repo") as mock_repo:
+        mock_repo.get_logs = AsyncMock(return_value=[])
+        mock_repo.get_total_logs = AsyncMock(return_value=0)
+        for _ in range(3):
+            await api_get_logs.get_logs(
+                limit=50, offset=0, search="needle", start_time=None, end_time=None,
+                user={"uuid": "u", "role": "viewer"},
+            )
+    # 3 filtered calls → 3 repo hits, all with search=needle
+    assert mock_repo.get_total_logs.await_count == 3
+    for call in mock_repo.get_total_logs.await_args_list:
+        assert call.kwargs["search"] == "needle"
+
+
+# ── /attackers total-count cache ─────────────────────────────────────────────
+
+@pytest.mark.asyncio
+async def test_attackers_total_cache_collapses_concurrent_calls():
+    api_get_attackers._reset_total_cache()
+    with patch.object(api_get_attackers, "repo") as mock_repo:
+        mock_repo.get_total_attackers = AsyncMock(return_value=99)
+        results = await asyncio.gather(*[api_get_attackers._get_total_attackers_cached() for _ in range(50)])
+    assert all(r == 99 for r in results)
+    assert mock_repo.get_total_attackers.await_count == 1
+
+
+@pytest.mark.asyncio
+async def test_attackers_filtered_count_bypasses_cache():
+    api_get_attackers._reset_total_cache()
+    with patch.object(api_get_attackers, "repo") as mock_repo:
+        mock_repo.get_attackers = AsyncMock(return_value=[])
+        mock_repo.get_total_attackers = AsyncMock(return_value=0)
+        mock_repo.get_behaviors_for_ips = AsyncMock(return_value={})
+        for _ in range(3):
+            await api_get_attackers.get_attackers(
+                limit=50, offset=0, search="10.", sort_by="recent", service=None,
+                user={"uuid": "u", "role": "viewer"},
+            )
+    assert mock_repo.get_total_attackers.await_count == 3
+    for call in mock_repo.get_total_attackers.await_args_list:
+        assert call.kwargs["search"] == "10."