merge testing->tomerge/main #7
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
@@ -8,6 +10,36 @@ from decnet.web.db.models import AttackersResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Same pattern as /logs — cache the unfiltered total count; filtered
|
||||
# counts go straight to the DB.
|
||||
_TOTAL_TTL = 2.0
|
||||
_total_cache: tuple[Optional[int], float] = (None, 0.0)
|
||||
_total_lock: Optional[asyncio.Lock] = None
|
||||
|
||||
|
||||
def _reset_total_cache() -> None:
|
||||
global _total_cache, _total_lock
|
||||
_total_cache = (None, 0.0)
|
||||
_total_lock = None
|
||||
|
||||
|
||||
async def _get_total_attackers_cached() -> int:
|
||||
global _total_cache, _total_lock
|
||||
value, ts = _total_cache
|
||||
now = time.monotonic()
|
||||
if value is not None and now - ts < _TOTAL_TTL:
|
||||
return value
|
||||
if _total_lock is None:
|
||||
_total_lock = asyncio.Lock()
|
||||
async with _total_lock:
|
||||
value, ts = _total_cache
|
||||
now = time.monotonic()
|
||||
if value is not None and now - ts < _TOTAL_TTL:
|
||||
return value
|
||||
value = await repo.get_total_attackers()
|
||||
_total_cache = (value, time.monotonic())
|
||||
return value
|
||||
|
||||
|
||||
@router.get(
|
||||
"/attackers",
|
||||
@@ -37,7 +69,10 @@ async def get_attackers(
|
||||
s = _norm(search)
|
||||
svc = _norm(service)
|
||||
_data = await repo.get_attackers(limit=limit, offset=offset, search=s, sort_by=sort_by, service=svc)
|
||||
_total = await repo.get_total_attackers(search=s, service=svc)
|
||||
if s is None and svc is None:
|
||||
_total = await _get_total_attackers_cached()
|
||||
else:
|
||||
_total = await repo.get_total_attackers(search=s, service=svc)
|
||||
|
||||
# Bulk-join behavior rows for the IPs in this page to avoid N+1 queries.
|
||||
_ips = {row["ip"] for row in _data if row.get("ip")}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
@@ -8,6 +10,37 @@ from decnet.web.db.models import LogsResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Cache the unfiltered total-logs count. Filtered counts bypass the cache
|
||||
# (rare, freshness matters for search). SELECT count(*) FROM logs is a
|
||||
# full scan and gets hammered by paginating clients.
|
||||
_TOTAL_TTL = 2.0
|
||||
_total_cache: tuple[Optional[int], float] = (None, 0.0)
|
||||
_total_lock: Optional[asyncio.Lock] = None
|
||||
|
||||
|
||||
def _reset_total_cache() -> None:
|
||||
global _total_cache, _total_lock
|
||||
_total_cache = (None, 0.0)
|
||||
_total_lock = None
|
||||
|
||||
|
||||
async def _get_total_logs_cached() -> int:
|
||||
global _total_cache, _total_lock
|
||||
value, ts = _total_cache
|
||||
now = time.monotonic()
|
||||
if value is not None and now - ts < _TOTAL_TTL:
|
||||
return value
|
||||
if _total_lock is None:
|
||||
_total_lock = asyncio.Lock()
|
||||
async with _total_lock:
|
||||
value, ts = _total_cache
|
||||
now = time.monotonic()
|
||||
if value is not None and now - ts < _TOTAL_TTL:
|
||||
return value
|
||||
value = await repo.get_total_logs()
|
||||
_total_cache = (value, time.monotonic())
|
||||
return value
|
||||
|
||||
|
||||
@router.get("/logs", response_model=LogsResponse, tags=["Logs"],
|
||||
responses={401: {"description": "Could not validate credentials"}, 403: {"description": "Insufficient permissions"}, 422: {"description": "Validation error"}})
|
||||
@@ -30,7 +63,10 @@ async def get_logs(
|
||||
et = _norm(end_time)
|
||||
|
||||
_logs: list[dict[str, Any]] = await repo.get_logs(limit=limit, offset=offset, search=s, start_time=st, end_time=et)
|
||||
_total: int = await repo.get_total_logs(search=s, start_time=st, end_time=et)
|
||||
if s is None and st is None and et is None:
|
||||
_total: int = await _get_total_logs_cached()
|
||||
else:
|
||||
_total = await repo.get_total_logs(search=s, start_time=st, end_time=et)
|
||||
return {
|
||||
"total": _total,
|
||||
"limit": limit,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import Any
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
@@ -8,9 +10,41 @@ from decnet.web.db.models import StatsResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# /stats is aggregate telemetry polled constantly by the UI and locust.
|
||||
# A 5s window collapses thousands of concurrent calls — each of which
|
||||
# runs SELECT count(*) FROM logs + SELECT count(DISTINCT attacker_ip) —
|
||||
# into one DB hit per window.
|
||||
_STATS_TTL = 5.0
|
||||
_stats_cache: tuple[Optional[dict[str, Any]], float] = (None, 0.0)
|
||||
_stats_lock: Optional[asyncio.Lock] = None
|
||||
|
||||
|
||||
def _reset_stats_cache() -> None:
|
||||
global _stats_cache, _stats_lock
|
||||
_stats_cache = (None, 0.0)
|
||||
_stats_lock = None
|
||||
|
||||
|
||||
async def _get_stats_cached() -> dict[str, Any]:
|
||||
global _stats_cache, _stats_lock
|
||||
value, ts = _stats_cache
|
||||
now = time.monotonic()
|
||||
if value is not None and now - ts < _STATS_TTL:
|
||||
return value
|
||||
if _stats_lock is None:
|
||||
_stats_lock = asyncio.Lock()
|
||||
async with _stats_lock:
|
||||
value, ts = _stats_cache
|
||||
now = time.monotonic()
|
||||
if value is not None and now - ts < _STATS_TTL:
|
||||
return value
|
||||
value = await repo.get_stats_summary()
|
||||
_stats_cache = (value, time.monotonic())
|
||||
return value
|
||||
|
||||
|
||||
@router.get("/stats", response_model=StatsResponse, tags=["Observability"],
|
||||
responses={401: {"description": "Could not validate credentials"}, 403: {"description": "Insufficient permissions"}, 422: {"description": "Validation error"}},)
|
||||
@_traced("api.get_stats")
|
||||
async def get_stats(user: dict = Depends(require_viewer)) -> dict[str, Any]:
|
||||
return await repo.get_stats_summary()
|
||||
return await _get_stats_cached()
|
||||
|
||||
@@ -56,8 +56,14 @@ async def setup_db(monkeypatch) -> AsyncGenerator[None, None]:
|
||||
# Reset per-request TTL caches so they don't leak across tests
|
||||
from decnet.web.router.health import api_get_health as _h
|
||||
from decnet.web.router.config import api_get_config as _c
|
||||
from decnet.web.router.stats import api_get_stats as _s
|
||||
from decnet.web.router.logs import api_get_logs as _l
|
||||
from decnet.web.router.attackers import api_get_attackers as _a
|
||||
_h._reset_db_cache()
|
||||
_c._reset_state_cache()
|
||||
_s._reset_stats_cache()
|
||||
_l._reset_total_cache()
|
||||
_a._reset_total_cache()
|
||||
|
||||
# Create schema
|
||||
async with engine.begin() as conn:
|
||||
|
||||
@@ -15,6 +15,14 @@ import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from decnet.web.auth import create_access_token
|
||||
from decnet.web.router.attackers.api_get_attackers import _reset_total_cache
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_attackers_cache():
|
||||
_reset_total_cache()
|
||||
yield
|
||||
_reset_total_cache()
|
||||
|
||||
|
||||
# ─── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
110
tests/test_router_cache.py
Normal file
110
tests/test_router_cache.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
TTL-cache contract for /stats, /logs total count, and /attackers total count.
|
||||
|
||||
Under concurrent load N callers should collapse to 1 repo hit per TTL
|
||||
window. Tests patch the repo — no real DB.
|
||||
"""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from decnet.web.router.stats import api_get_stats
|
||||
from decnet.web.router.logs import api_get_logs
|
||||
from decnet.web.router.attackers import api_get_attackers
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_router_caches():
|
||||
api_get_stats._reset_stats_cache()
|
||||
api_get_logs._reset_total_cache()
|
||||
api_get_attackers._reset_total_cache()
|
||||
yield
|
||||
api_get_stats._reset_stats_cache()
|
||||
api_get_logs._reset_total_cache()
|
||||
api_get_attackers._reset_total_cache()
|
||||
|
||||
|
||||
# ── /stats whole-response cache ──────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stats_cache_collapses_concurrent_calls():
|
||||
api_get_stats._reset_stats_cache()
|
||||
payload = {"total_logs": 42, "unique_attackers": 7, "active_deckies": 3, "deployed_deckies": 3}
|
||||
with patch.object(api_get_stats, "repo") as mock_repo:
|
||||
mock_repo.get_stats_summary = AsyncMock(return_value=payload)
|
||||
results = await asyncio.gather(*[api_get_stats._get_stats_cached() for _ in range(50)])
|
||||
assert all(r == payload for r in results)
|
||||
assert mock_repo.get_stats_summary.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stats_cache_expires_after_ttl(monkeypatch):
|
||||
api_get_stats._reset_stats_cache()
|
||||
clock = {"t": 0.0}
|
||||
monkeypatch.setattr(api_get_stats.time, "monotonic", lambda: clock["t"])
|
||||
with patch.object(api_get_stats, "repo") as mock_repo:
|
||||
mock_repo.get_stats_summary = AsyncMock(return_value={"total_logs": 1, "unique_attackers": 0, "active_deckies": 0, "deployed_deckies": 0})
|
||||
await api_get_stats._get_stats_cached()
|
||||
clock["t"] = 100.0 # past TTL
|
||||
await api_get_stats._get_stats_cached()
|
||||
assert mock_repo.get_stats_summary.await_count == 2
|
||||
|
||||
|
||||
# ── /logs total-count cache ──────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_total_cache_collapses_concurrent_calls():
|
||||
api_get_logs._reset_total_cache()
|
||||
with patch.object(api_get_logs, "repo") as mock_repo:
|
||||
mock_repo.get_total_logs = AsyncMock(return_value=1234)
|
||||
results = await asyncio.gather(*[api_get_logs._get_total_logs_cached() for _ in range(50)])
|
||||
assert all(r == 1234 for r in results)
|
||||
assert mock_repo.get_total_logs.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_filtered_count_bypasses_cache():
|
||||
"""When a filter is provided, the endpoint must hit repo every time."""
|
||||
api_get_logs._reset_total_cache()
|
||||
with patch.object(api_get_logs, "repo") as mock_repo:
|
||||
mock_repo.get_logs = AsyncMock(return_value=[])
|
||||
mock_repo.get_total_logs = AsyncMock(return_value=0)
|
||||
for _ in range(3):
|
||||
await api_get_logs.get_logs(
|
||||
limit=50, offset=0, search="needle", start_time=None, end_time=None,
|
||||
user={"uuid": "u", "role": "viewer"},
|
||||
)
|
||||
# 3 filtered calls → 3 repo hits, all with search=needle
|
||||
assert mock_repo.get_total_logs.await_count == 3
|
||||
for call in mock_repo.get_total_logs.await_args_list:
|
||||
assert call.kwargs["search"] == "needle"
|
||||
|
||||
|
||||
# ── /attackers total-count cache ─────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attackers_total_cache_collapses_concurrent_calls():
|
||||
api_get_attackers._reset_total_cache()
|
||||
with patch.object(api_get_attackers, "repo") as mock_repo:
|
||||
mock_repo.get_total_attackers = AsyncMock(return_value=99)
|
||||
results = await asyncio.gather(*[api_get_attackers._get_total_attackers_cached() for _ in range(50)])
|
||||
assert all(r == 99 for r in results)
|
||||
assert mock_repo.get_total_attackers.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attackers_filtered_count_bypasses_cache():
|
||||
api_get_attackers._reset_total_cache()
|
||||
with patch.object(api_get_attackers, "repo") as mock_repo:
|
||||
mock_repo.get_attackers = AsyncMock(return_value=[])
|
||||
mock_repo.get_total_attackers = AsyncMock(return_value=0)
|
||||
mock_repo.get_behaviors_for_ips = AsyncMock(return_value={})
|
||||
for _ in range(3):
|
||||
await api_get_attackers.get_attackers(
|
||||
limit=50, offset=0, search="10.", sort_by="recent", service=None,
|
||||
user={"uuid": "u", "role": "viewer"},
|
||||
)
|
||||
assert mock_repo.get_total_attackers.await_count == 3
|
||||
for call in mock_repo.get_total_attackers.await_args_list:
|
||||
assert call.kwargs["search"] == "10."
|
||||
Reference in New Issue
Block a user