perf(api): TTL-cache /stats + unfiltered pagination counts

Every /stats call ran SELECT count(*) FROM logs + SELECT count(DISTINCT
attacker_ip) FROM logs; every /logs and /attackers call ran an
unfiltered count for the paginator. At 500 concurrent users these
serialize through aiosqlite's worker threads and dominate wall time.

Cache at the router layer (repo stays dialect-agnostic):
  - /stats response: 5s TTL
  - /logs total (only when no filters): 2s TTL
  - /attackers total (only when no filters): 2s TTL

Filtered paths bypass the cache. Pattern reused from api_get_config
and api_get_health (asyncio.Lock + time.monotonic window + lazy lock).
This commit is contained in:
2026-04-17 19:09:15 -04:00
parent de4b64d857
commit 6301504c0e
6 changed files with 233 additions and 4 deletions

View File

@@ -1,3 +1,5 @@
import asyncio
import time
from typing import Any, Optional
from fastapi import APIRouter, Depends, Query
@@ -8,6 +10,36 @@ from decnet.web.db.models import AttackersResponse
router = APIRouter()
# Same pattern as /logs — cache the unfiltered total count; filtered
# counts go straight to the DB.
_TOTAL_TTL = 2.0
_total_cache: tuple[Optional[int], float] = (None, 0.0)
_total_lock: Optional[asyncio.Lock] = None
def _reset_total_cache() -> None:
global _total_cache, _total_lock
_total_cache = (None, 0.0)
_total_lock = None
async def _get_total_attackers_cached() -> int:
    """Return the unfiltered attacker count, cached for ``_TOTAL_TTL`` seconds.

    Concurrent cache misses collapse to a single repo query: the first
    coroutine through the lock refreshes, later ones hit the double-check.
    """
    global _total_cache, _total_lock

    def _fresh(entry: tuple[Optional[int], float]) -> Optional[int]:
        # A cached count is served only while younger than the TTL window.
        count, stamp = entry
        if count is not None and time.monotonic() - stamp < _TOTAL_TTL:
            return count
        return None

    hit = _fresh(_total_cache)
    if hit is not None:
        return hit  # fast path — no lock, no DB
    if _total_lock is None:
        # Lazily created; no await between check and assignment, so this is
        # race-free on a single event loop.
        _total_lock = asyncio.Lock()
    async with _total_lock:
        # Double-check: another coroutine may have refreshed while we waited.
        hit = _fresh(_total_cache)
        if hit is not None:
            return hit
        count = await repo.get_total_attackers()
        _total_cache = (count, time.monotonic())
        return count
@router.get(
"/attackers",
@@ -37,7 +69,10 @@ async def get_attackers(
s = _norm(search)
svc = _norm(service)
_data = await repo.get_attackers(limit=limit, offset=offset, search=s, sort_by=sort_by, service=svc)
_total = await repo.get_total_attackers(search=s, service=svc)
if s is None and svc is None:
_total = await _get_total_attackers_cached()
else:
_total = await repo.get_total_attackers(search=s, service=svc)
# Bulk-join behavior rows for the IPs in this page to avoid N+1 queries.
_ips = {row["ip"] for row in _data if row.get("ip")}

View File

@@ -1,3 +1,5 @@
import asyncio
import time
from typing import Any, Optional
from fastapi import APIRouter, Depends, Query
@@ -8,6 +10,37 @@ from decnet.web.db.models import LogsResponse
router = APIRouter()
# Cache the unfiltered total-logs count. Filtered counts bypass the cache
# (rare, freshness matters for search). SELECT count(*) FROM logs is a
# full scan and gets hammered by paginating clients.
_TOTAL_TTL = 2.0
_total_cache: tuple[Optional[int], float] = (None, 0.0)
_total_lock: Optional[asyncio.Lock] = None
def _reset_total_cache() -> None:
global _total_cache, _total_lock
_total_cache = (None, 0.0)
_total_lock = None
async def _get_total_logs_cached() -> int:
    """Return the unfiltered total-logs count, cached for ``_TOTAL_TTL`` seconds.

    Paginating clients hammer this count; concurrent misses serialize on the
    lock so at most one repo query runs per TTL window.
    """
    global _total_cache, _total_lock
    count, stamp = _total_cache
    now = time.monotonic()
    if count is not None and now - stamp < _TOTAL_TTL:
        return count  # fresh hit — skip the lock entirely
    if _total_lock is None:
        # Lazy creation is safe: no await between the check and assignment.
        _total_lock = asyncio.Lock()
    async with _total_lock:
        # Re-check under the lock — a concurrent caller may have refilled.
        count, stamp = _total_cache
        now = time.monotonic()
        if count is not None and now - stamp < _TOTAL_TTL:
            return count
        count = await repo.get_total_logs()
        _total_cache = (count, time.monotonic())
        return count
@router.get("/logs", response_model=LogsResponse, tags=["Logs"],
responses={401: {"description": "Could not validate credentials"}, 403: {"description": "Insufficient permissions"}, 422: {"description": "Validation error"}})
@@ -30,7 +63,10 @@ async def get_logs(
et = _norm(end_time)
_logs: list[dict[str, Any]] = await repo.get_logs(limit=limit, offset=offset, search=s, start_time=st, end_time=et)
_total: int = await repo.get_total_logs(search=s, start_time=st, end_time=et)
if s is None and st is None and et is None:
_total: int = await _get_total_logs_cached()
else:
_total = await repo.get_total_logs(search=s, start_time=st, end_time=et)
return {
"total": _total,
"limit": limit,

View File

@@ -1,4 +1,6 @@
from typing import Any
import asyncio
import time
from typing import Any, Optional
from fastapi import APIRouter, Depends
@@ -8,9 +10,41 @@ from decnet.web.db.models import StatsResponse
router = APIRouter()
# /stats is aggregate telemetry polled constantly by the UI and locust.
# A 5s window collapses thousands of concurrent calls — each of which
# runs SELECT count(*) FROM logs + SELECT count(DISTINCT attacker_ip) —
# into one DB hit per window.
_STATS_TTL = 5.0
_stats_cache: tuple[Optional[dict[str, Any]], float] = (None, 0.0)
_stats_lock: Optional[asyncio.Lock] = None
def _reset_stats_cache() -> None:
global _stats_cache, _stats_lock
_stats_cache = (None, 0.0)
_stats_lock = None
async def _get_stats_cached() -> dict[str, Any]:
    """Serve the /stats summary from a module-level cache.

    Refreshes at most once per ``_STATS_TTL`` seconds; concurrent refreshers
    are collapsed by the lock plus a double-check of the cache.
    """
    global _stats_cache, _stats_lock
    snapshot, stamp = _stats_cache
    if snapshot is not None and time.monotonic() - stamp < _STATS_TTL:
        return snapshot  # fresh hit — no lock, no DB
    if _stats_lock is None:
        # Created lazily; no await between test and assignment, so this is
        # race-free on a single event loop.
        _stats_lock = asyncio.Lock()
    async with _stats_lock:
        # Re-check: another coroutine may have refreshed while we waited.
        snapshot, stamp = _stats_cache
        if snapshot is not None and time.monotonic() - stamp < _STATS_TTL:
            return snapshot
        snapshot = await repo.get_stats_summary()
        _stats_cache = (snapshot, time.monotonic())
        return snapshot
@router.get("/stats", response_model=StatsResponse, tags=["Observability"],
            responses={401: {"description": "Could not validate credentials"}, 403: {"description": "Insufficient permissions"}, 422: {"description": "Validation error"}},)
@_traced("api.get_stats")
async def get_stats(user: dict = Depends(require_viewer)) -> dict[str, Any]:
    """Return aggregate stats for the dashboard, served via the 5s TTL cache.

    The diff rendering left both the old direct ``repo.get_stats_summary()``
    return and the new cached return in place; only the cached call belongs
    after this commit, so the dead duplicate line is removed here.
    """
    return await _get_stats_cached()

View File

@@ -56,8 +56,14 @@ async def setup_db(monkeypatch) -> AsyncGenerator[None, None]:
# Reset per-request TTL caches so they don't leak across tests
from decnet.web.router.health import api_get_health as _h
from decnet.web.router.config import api_get_config as _c
from decnet.web.router.stats import api_get_stats as _s
from decnet.web.router.logs import api_get_logs as _l
from decnet.web.router.attackers import api_get_attackers as _a
_h._reset_db_cache()
_c._reset_state_cache()
_s._reset_stats_cache()
_l._reset_total_cache()
_a._reset_total_cache()
# Create schema
async with engine.begin() as conn:

View File

@@ -15,6 +15,14 @@ import pytest
from fastapi import HTTPException
from decnet.web.auth import create_access_token
from decnet.web.router.attackers.api_get_attackers import _reset_total_cache
@pytest.fixture(autouse=True)
def _reset_attackers_cache():
    """Start and finish every test in this module with an empty attackers
    total-count cache, so TTL state never leaks between tests."""
    _reset_total_cache()
    yield
    _reset_total_cache()
# ─── Helpers ──────────────────────────────────────────────────────────────────

110
tests/test_router_cache.py Normal file
View File

@@ -0,0 +1,110 @@
"""
TTL-cache contract for /stats, /logs total count, and /attackers total count.
Under concurrent load N callers should collapse to 1 repo hit per TTL
window. Tests patch the repo — no real DB.
"""
import asyncio
from unittest.mock import AsyncMock, patch
import pytest
from decnet.web.router.stats import api_get_stats
from decnet.web.router.logs import api_get_logs
from decnet.web.router.attackers import api_get_attackers
@pytest.fixture(autouse=True)
def _reset_router_caches():
    """Reset all three router-level TTL caches before and after each test so
    cached stats/totals never leak across tests in this module."""
    api_get_stats._reset_stats_cache()
    api_get_logs._reset_total_cache()
    api_get_attackers._reset_total_cache()
    yield
    api_get_stats._reset_stats_cache()
    api_get_logs._reset_total_cache()
    api_get_attackers._reset_total_cache()
# ── /stats whole-response cache ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_stats_cache_collapses_concurrent_calls():
    """50 simultaneous callers inside one TTL window → exactly one repo hit."""
    api_get_stats._reset_stats_cache()
    expected = {"total_logs": 42, "unique_attackers": 7, "active_deckies": 3, "deployed_deckies": 3}
    with patch.object(api_get_stats, "repo") as fake_repo:
        fake_repo.get_stats_summary = AsyncMock(return_value=expected)
        calls = [api_get_stats._get_stats_cached() for _ in range(50)]
        results = await asyncio.gather(*calls)
        assert all(r == expected for r in results)
        assert fake_repo.get_stats_summary.await_count == 1
@pytest.mark.asyncio
async def test_stats_cache_expires_after_ttl(monkeypatch):
    """Once the TTL window elapses, the next call must re-query the repo."""
    api_get_stats._reset_stats_cache()
    fake_clock = {"t": 0.0}
    monkeypatch.setattr(api_get_stats.time, "monotonic", lambda: fake_clock["t"])
    with patch.object(api_get_stats, "repo") as fake_repo:
        fake_repo.get_stats_summary = AsyncMock(return_value={"total_logs": 1, "unique_attackers": 0, "active_deckies": 0, "deployed_deckies": 0})
        await api_get_stats._get_stats_cached()  # miss → first repo hit, cached at t=0
        fake_clock["t"] = 100.0  # jump well past the 5s TTL
        await api_get_stats._get_stats_cached()  # stale → second repo hit
        assert fake_repo.get_stats_summary.await_count == 2
# ── /logs total-count cache ──────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_logs_total_cache_collapses_concurrent_calls():
    """50 concurrent unfiltered total lookups → one SELECT count(*) at most."""
    api_get_logs._reset_total_cache()
    with patch.object(api_get_logs, "repo") as fake_repo:
        fake_repo.get_total_logs = AsyncMock(return_value=1234)
        calls = [api_get_logs._get_total_logs_cached() for _ in range(50)]
        results = await asyncio.gather(*calls)
        assert all(r == 1234 for r in results)
        assert fake_repo.get_total_logs.await_count == 1
@pytest.mark.asyncio
async def test_logs_filtered_count_bypasses_cache():
    """A filtered /logs request must reach the repo on every call — no caching."""
    api_get_logs._reset_total_cache()
    with patch.object(api_get_logs, "repo") as fake_repo:
        fake_repo.get_logs = AsyncMock(return_value=[])
        fake_repo.get_total_logs = AsyncMock(return_value=0)
        for _ in range(3):
            await api_get_logs.get_logs(
                limit=50, offset=0, search="needle", start_time=None, end_time=None,
                user={"uuid": "u", "role": "viewer"},
            )
        # Three filtered calls → three repo hits, each carrying the filter.
        assert fake_repo.get_total_logs.await_count == 3
        for recorded in fake_repo.get_total_logs.await_args_list:
            assert recorded.kwargs["search"] == "needle"
# ── /attackers total-count cache ─────────────────────────────────────────────
@pytest.mark.asyncio
async def test_attackers_total_cache_collapses_concurrent_calls():
    """50 concurrent unfiltered attacker-count lookups → one repo query."""
    api_get_attackers._reset_total_cache()
    with patch.object(api_get_attackers, "repo") as fake_repo:
        fake_repo.get_total_attackers = AsyncMock(return_value=99)
        calls = [api_get_attackers._get_total_attackers_cached() for _ in range(50)]
        results = await asyncio.gather(*calls)
        assert all(r == 99 for r in results)
        assert fake_repo.get_total_attackers.await_count == 1
@pytest.mark.asyncio
async def test_attackers_filtered_count_bypasses_cache():
    """A filtered /attackers request must hit the repo every time — no caching."""
    api_get_attackers._reset_total_cache()
    with patch.object(api_get_attackers, "repo") as fake_repo:
        fake_repo.get_attackers = AsyncMock(return_value=[])
        fake_repo.get_total_attackers = AsyncMock(return_value=0)
        fake_repo.get_behaviors_for_ips = AsyncMock(return_value={})
        for _ in range(3):
            await api_get_attackers.get_attackers(
                limit=50, offset=0, search="10.", sort_by="recent", service=None,
                user={"uuid": "u", "role": "viewer"},
            )
        # Three filtered calls → three repo hits, each carrying the search term.
        assert fake_repo.get_total_attackers.await_count == 3
        for recorded in fake_repo.get_total_attackers.await_args_list:
            assert recorded.kwargs["search"] == "10."