merge testing->tomerge/main #7
@@ -1,3 +1,5 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
@@ -8,6 +10,36 @@ from decnet.web.db.models import AttackersResponse
|
|||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
# Same pattern as /logs — cache the unfiltered total count; filtered
|
||||||
|
# counts go straight to the DB.
|
||||||
|
_TOTAL_TTL = 2.0
|
||||||
|
_total_cache: tuple[Optional[int], float] = (None, 0.0)
|
||||||
|
_total_lock: Optional[asyncio.Lock] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_total_cache() -> None:
|
||||||
|
global _total_cache, _total_lock
|
||||||
|
_total_cache = (None, 0.0)
|
||||||
|
_total_lock = None
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_total_attackers_cached() -> int:
|
||||||
|
global _total_cache, _total_lock
|
||||||
|
value, ts = _total_cache
|
||||||
|
now = time.monotonic()
|
||||||
|
if value is not None and now - ts < _TOTAL_TTL:
|
||||||
|
return value
|
||||||
|
if _total_lock is None:
|
||||||
|
_total_lock = asyncio.Lock()
|
||||||
|
async with _total_lock:
|
||||||
|
value, ts = _total_cache
|
||||||
|
now = time.monotonic()
|
||||||
|
if value is not None and now - ts < _TOTAL_TTL:
|
||||||
|
return value
|
||||||
|
value = await repo.get_total_attackers()
|
||||||
|
_total_cache = (value, time.monotonic())
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/attackers",
|
"/attackers",
|
||||||
@@ -37,7 +69,10 @@ async def get_attackers(
|
|||||||
s = _norm(search)
|
s = _norm(search)
|
||||||
svc = _norm(service)
|
svc = _norm(service)
|
||||||
_data = await repo.get_attackers(limit=limit, offset=offset, search=s, sort_by=sort_by, service=svc)
|
_data = await repo.get_attackers(limit=limit, offset=offset, search=s, sort_by=sort_by, service=svc)
|
||||||
_total = await repo.get_total_attackers(search=s, service=svc)
|
if s is None and svc is None:
|
||||||
|
_total = await _get_total_attackers_cached()
|
||||||
|
else:
|
||||||
|
_total = await repo.get_total_attackers(search=s, service=svc)
|
||||||
|
|
||||||
# Bulk-join behavior rows for the IPs in this page to avoid N+1 queries.
|
# Bulk-join behavior rows for the IPs in this page to avoid N+1 queries.
|
||||||
_ips = {row["ip"] for row in _data if row.get("ip")}
|
_ips = {row["ip"] for row in _data if row.get("ip")}
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
@@ -8,6 +10,37 @@ from decnet.web.db.models import LogsResponse
|
|||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
# Cache the unfiltered total-logs count. Filtered counts bypass the cache
|
||||||
|
# (rare, freshness matters for search). SELECT count(*) FROM logs is a
|
||||||
|
# full scan and gets hammered by paginating clients.
|
||||||
|
_TOTAL_TTL = 2.0
|
||||||
|
_total_cache: tuple[Optional[int], float] = (None, 0.0)
|
||||||
|
_total_lock: Optional[asyncio.Lock] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_total_cache() -> None:
|
||||||
|
global _total_cache, _total_lock
|
||||||
|
_total_cache = (None, 0.0)
|
||||||
|
_total_lock = None
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_total_logs_cached() -> int:
|
||||||
|
global _total_cache, _total_lock
|
||||||
|
value, ts = _total_cache
|
||||||
|
now = time.monotonic()
|
||||||
|
if value is not None and now - ts < _TOTAL_TTL:
|
||||||
|
return value
|
||||||
|
if _total_lock is None:
|
||||||
|
_total_lock = asyncio.Lock()
|
||||||
|
async with _total_lock:
|
||||||
|
value, ts = _total_cache
|
||||||
|
now = time.monotonic()
|
||||||
|
if value is not None and now - ts < _TOTAL_TTL:
|
||||||
|
return value
|
||||||
|
value = await repo.get_total_logs()
|
||||||
|
_total_cache = (value, time.monotonic())
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
@router.get("/logs", response_model=LogsResponse, tags=["Logs"],
|
@router.get("/logs", response_model=LogsResponse, tags=["Logs"],
|
||||||
responses={401: {"description": "Could not validate credentials"}, 403: {"description": "Insufficient permissions"}, 422: {"description": "Validation error"}})
|
responses={401: {"description": "Could not validate credentials"}, 403: {"description": "Insufficient permissions"}, 422: {"description": "Validation error"}})
|
||||||
@@ -30,7 +63,10 @@ async def get_logs(
|
|||||||
et = _norm(end_time)
|
et = _norm(end_time)
|
||||||
|
|
||||||
_logs: list[dict[str, Any]] = await repo.get_logs(limit=limit, offset=offset, search=s, start_time=st, end_time=et)
|
_logs: list[dict[str, Any]] = await repo.get_logs(limit=limit, offset=offset, search=s, start_time=st, end_time=et)
|
||||||
_total: int = await repo.get_total_logs(search=s, start_time=st, end_time=et)
|
if s is None and st is None and et is None:
|
||||||
|
_total: int = await _get_total_logs_cached()
|
||||||
|
else:
|
||||||
|
_total = await repo.get_total_logs(search=s, start_time=st, end_time=et)
|
||||||
return {
|
return {
|
||||||
"total": _total,
|
"total": _total,
|
||||||
"limit": limit,
|
"limit": limit,
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
from typing import Any
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
@@ -8,9 +10,41 @@ from decnet.web.db.models import StatsResponse
|
|||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
# /stats is aggregate telemetry polled constantly by the UI and locust.
|
||||||
|
# A 5s window collapses thousands of concurrent calls — each of which
|
||||||
|
# runs SELECT count(*) FROM logs + SELECT count(DISTINCT attacker_ip) —
|
||||||
|
# into one DB hit per window.
|
||||||
|
_STATS_TTL = 5.0
|
||||||
|
_stats_cache: tuple[Optional[dict[str, Any]], float] = (None, 0.0)
|
||||||
|
_stats_lock: Optional[asyncio.Lock] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_stats_cache() -> None:
|
||||||
|
global _stats_cache, _stats_lock
|
||||||
|
_stats_cache = (None, 0.0)
|
||||||
|
_stats_lock = None
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_stats_cached() -> dict[str, Any]:
|
||||||
|
global _stats_cache, _stats_lock
|
||||||
|
value, ts = _stats_cache
|
||||||
|
now = time.monotonic()
|
||||||
|
if value is not None and now - ts < _STATS_TTL:
|
||||||
|
return value
|
||||||
|
if _stats_lock is None:
|
||||||
|
_stats_lock = asyncio.Lock()
|
||||||
|
async with _stats_lock:
|
||||||
|
value, ts = _stats_cache
|
||||||
|
now = time.monotonic()
|
||||||
|
if value is not None and now - ts < _STATS_TTL:
|
||||||
|
return value
|
||||||
|
value = await repo.get_stats_summary()
|
||||||
|
_stats_cache = (value, time.monotonic())
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stats", response_model=StatsResponse, tags=["Observability"],
|
@router.get("/stats", response_model=StatsResponse, tags=["Observability"],
|
||||||
responses={401: {"description": "Could not validate credentials"}, 403: {"description": "Insufficient permissions"}, 422: {"description": "Validation error"}},)
|
responses={401: {"description": "Could not validate credentials"}, 403: {"description": "Insufficient permissions"}, 422: {"description": "Validation error"}},)
|
||||||
@_traced("api.get_stats")
|
@_traced("api.get_stats")
|
||||||
async def get_stats(user: dict = Depends(require_viewer)) -> dict[str, Any]:
|
async def get_stats(user: dict = Depends(require_viewer)) -> dict[str, Any]:
|
||||||
return await repo.get_stats_summary()
|
return await _get_stats_cached()
|
||||||
|
|||||||
@@ -56,8 +56,14 @@ async def setup_db(monkeypatch) -> AsyncGenerator[None, None]:
|
|||||||
# Reset per-request TTL caches so they don't leak across tests
|
# Reset per-request TTL caches so they don't leak across tests
|
||||||
from decnet.web.router.health import api_get_health as _h
|
from decnet.web.router.health import api_get_health as _h
|
||||||
from decnet.web.router.config import api_get_config as _c
|
from decnet.web.router.config import api_get_config as _c
|
||||||
|
from decnet.web.router.stats import api_get_stats as _s
|
||||||
|
from decnet.web.router.logs import api_get_logs as _l
|
||||||
|
from decnet.web.router.attackers import api_get_attackers as _a
|
||||||
_h._reset_db_cache()
|
_h._reset_db_cache()
|
||||||
_c._reset_state_cache()
|
_c._reset_state_cache()
|
||||||
|
_s._reset_stats_cache()
|
||||||
|
_l._reset_total_cache()
|
||||||
|
_a._reset_total_cache()
|
||||||
|
|
||||||
# Create schema
|
# Create schema
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
|
|||||||
@@ -15,6 +15,14 @@ import pytest
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from decnet.web.auth import create_access_token
|
from decnet.web.auth import create_access_token
|
||||||
|
from decnet.web.router.attackers.api_get_attackers import _reset_total_cache
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_attackers_cache():
|
||||||
|
_reset_total_cache()
|
||||||
|
yield
|
||||||
|
_reset_total_cache()
|
||||||
|
|
||||||
|
|
||||||
# ─── Helpers ──────────────────────────────────────────────────────────────────
|
# ─── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|||||||
110
tests/test_router_cache.py
Normal file
110
tests/test_router_cache.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""
|
||||||
|
TTL-cache contract for /stats, /logs total count, and /attackers total count.
|
||||||
|
|
||||||
|
Under concurrent load N callers should collapse to 1 repo hit per TTL
|
||||||
|
window. Tests patch the repo — no real DB.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from decnet.web.router.stats import api_get_stats
|
||||||
|
from decnet.web.router.logs import api_get_logs
|
||||||
|
from decnet.web.router.attackers import api_get_attackers
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_router_caches():
|
||||||
|
api_get_stats._reset_stats_cache()
|
||||||
|
api_get_logs._reset_total_cache()
|
||||||
|
api_get_attackers._reset_total_cache()
|
||||||
|
yield
|
||||||
|
api_get_stats._reset_stats_cache()
|
||||||
|
api_get_logs._reset_total_cache()
|
||||||
|
api_get_attackers._reset_total_cache()
|
||||||
|
|
||||||
|
|
||||||
|
# ── /stats whole-response cache ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stats_cache_collapses_concurrent_calls():
|
||||||
|
api_get_stats._reset_stats_cache()
|
||||||
|
payload = {"total_logs": 42, "unique_attackers": 7, "active_deckies": 3, "deployed_deckies": 3}
|
||||||
|
with patch.object(api_get_stats, "repo") as mock_repo:
|
||||||
|
mock_repo.get_stats_summary = AsyncMock(return_value=payload)
|
||||||
|
results = await asyncio.gather(*[api_get_stats._get_stats_cached() for _ in range(50)])
|
||||||
|
assert all(r == payload for r in results)
|
||||||
|
assert mock_repo.get_stats_summary.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stats_cache_expires_after_ttl(monkeypatch):
|
||||||
|
api_get_stats._reset_stats_cache()
|
||||||
|
clock = {"t": 0.0}
|
||||||
|
monkeypatch.setattr(api_get_stats.time, "monotonic", lambda: clock["t"])
|
||||||
|
with patch.object(api_get_stats, "repo") as mock_repo:
|
||||||
|
mock_repo.get_stats_summary = AsyncMock(return_value={"total_logs": 1, "unique_attackers": 0, "active_deckies": 0, "deployed_deckies": 0})
|
||||||
|
await api_get_stats._get_stats_cached()
|
||||||
|
clock["t"] = 100.0 # past TTL
|
||||||
|
await api_get_stats._get_stats_cached()
|
||||||
|
assert mock_repo.get_stats_summary.await_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ── /logs total-count cache ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_logs_total_cache_collapses_concurrent_calls():
|
||||||
|
api_get_logs._reset_total_cache()
|
||||||
|
with patch.object(api_get_logs, "repo") as mock_repo:
|
||||||
|
mock_repo.get_total_logs = AsyncMock(return_value=1234)
|
||||||
|
results = await asyncio.gather(*[api_get_logs._get_total_logs_cached() for _ in range(50)])
|
||||||
|
assert all(r == 1234 for r in results)
|
||||||
|
assert mock_repo.get_total_logs.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_logs_filtered_count_bypasses_cache():
|
||||||
|
"""When a filter is provided, the endpoint must hit repo every time."""
|
||||||
|
api_get_logs._reset_total_cache()
|
||||||
|
with patch.object(api_get_logs, "repo") as mock_repo:
|
||||||
|
mock_repo.get_logs = AsyncMock(return_value=[])
|
||||||
|
mock_repo.get_total_logs = AsyncMock(return_value=0)
|
||||||
|
for _ in range(3):
|
||||||
|
await api_get_logs.get_logs(
|
||||||
|
limit=50, offset=0, search="needle", start_time=None, end_time=None,
|
||||||
|
user={"uuid": "u", "role": "viewer"},
|
||||||
|
)
|
||||||
|
# 3 filtered calls → 3 repo hits, all with search=needle
|
||||||
|
assert mock_repo.get_total_logs.await_count == 3
|
||||||
|
for call in mock_repo.get_total_logs.await_args_list:
|
||||||
|
assert call.kwargs["search"] == "needle"
|
||||||
|
|
||||||
|
|
||||||
|
# ── /attackers total-count cache ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_attackers_total_cache_collapses_concurrent_calls():
|
||||||
|
api_get_attackers._reset_total_cache()
|
||||||
|
with patch.object(api_get_attackers, "repo") as mock_repo:
|
||||||
|
mock_repo.get_total_attackers = AsyncMock(return_value=99)
|
||||||
|
results = await asyncio.gather(*[api_get_attackers._get_total_attackers_cached() for _ in range(50)])
|
||||||
|
assert all(r == 99 for r in results)
|
||||||
|
assert mock_repo.get_total_attackers.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_attackers_filtered_count_bypasses_cache():
|
||||||
|
api_get_attackers._reset_total_cache()
|
||||||
|
with patch.object(api_get_attackers, "repo") as mock_repo:
|
||||||
|
mock_repo.get_attackers = AsyncMock(return_value=[])
|
||||||
|
mock_repo.get_total_attackers = AsyncMock(return_value=0)
|
||||||
|
mock_repo.get_behaviors_for_ips = AsyncMock(return_value={})
|
||||||
|
for _ in range(3):
|
||||||
|
await api_get_attackers.get_attackers(
|
||||||
|
limit=50, offset=0, search="10.", sort_by="recent", service=None,
|
||||||
|
user={"uuid": "u", "role": "viewer"},
|
||||||
|
)
|
||||||
|
assert mock_repo.get_total_attackers.await_count == 3
|
||||||
|
for call in mock_repo.get_total_attackers.await_args_list:
|
||||||
|
assert call.kwargs["search"] == "10."
|
||||||
Reference in New Issue
Block a user