merge testing->tomerge/main #7

Open
anti wants to merge 242 commits from testing into tomerge/main
6 changed files with 233 additions and 4 deletions
Showing only changes of commit 6301504c0e - Show all commits

View File

@@ -1,3 +1,5 @@
import asyncio
import time
from typing import Any, Optional
from fastapi import APIRouter, Depends, Query
@@ -8,6 +10,36 @@ from decnet.web.db.models import AttackersResponse
router = APIRouter()
# Same pattern as /logs — cache the unfiltered total count; filtered
# counts go straight to the DB.
# How long (seconds) a cached unfiltered total stays fresh.
_TOTAL_TTL = 2.0
# (count, monotonic timestamp of last refresh); (None, 0.0) means cold.
_total_cache: tuple[Optional[int], float] = (None, 0.0)
# Created on first cache miss; dropped (set to None) by _reset_total_cache().
_total_lock: Optional[asyncio.Lock] = None
def _reset_total_cache() -> None:
    """Drop the cached attacker total and its lock (test-isolation helper)."""
    global _total_cache, _total_lock
    _total_lock = None
    _total_cache = (None, 0.0)
async def _get_total_attackers_cached() -> int:
    """Return the unfiltered attacker count from a short TTL cache.

    Concurrent cache misses are collapsed behind an asyncio.Lock so the DB
    sees at most one count query per TTL window.
    """
    global _total_cache, _total_lock
    cached, stamp = _total_cache
    if cached is not None and time.monotonic() - stamp < _TOTAL_TTL:
        return cached
    if _total_lock is None:
        _total_lock = asyncio.Lock()
    async with _total_lock:
        # Re-check under the lock: another coroutine may have refreshed
        # the cache while we were waiting.
        cached, stamp = _total_cache
        if cached is not None and time.monotonic() - stamp < _TOTAL_TTL:
            return cached
        fresh = await repo.get_total_attackers()
        _total_cache = (fresh, time.monotonic())
        return fresh
@router.get(
"/attackers",
@@ -37,7 +69,10 @@ async def get_attackers(
s = _norm(search)
svc = _norm(service)
_data = await repo.get_attackers(limit=limit, offset=offset, search=s, sort_by=sort_by, service=svc)
_total = await repo.get_total_attackers(search=s, service=svc)
if s is None and svc is None:
_total = await _get_total_attackers_cached()
else:
_total = await repo.get_total_attackers(search=s, service=svc)
# Bulk-join behavior rows for the IPs in this page to avoid N+1 queries.
_ips = {row["ip"] for row in _data if row.get("ip")}

View File

@@ -1,3 +1,5 @@
import asyncio
import time
from typing import Any, Optional
from fastapi import APIRouter, Depends, Query
@@ -8,6 +10,37 @@ from decnet.web.db.models import LogsResponse
router = APIRouter()
# Cache the unfiltered total-logs count. Filtered counts bypass the cache
# (rare, freshness matters for search). SELECT count(*) FROM logs is a
# full scan and gets hammered by paginating clients.
# How long (seconds) a cached unfiltered total stays fresh.
_TOTAL_TTL = 2.0
# (count, monotonic timestamp of last refresh); (None, 0.0) means cold.
_total_cache: tuple[Optional[int], float] = (None, 0.0)
# Created on first cache miss; dropped (set to None) by _reset_total_cache().
_total_lock: Optional[asyncio.Lock] = None
def _reset_total_cache() -> None:
    """Drop the cached log total and its lock (test-isolation helper)."""
    global _total_cache, _total_lock
    _total_lock = None
    _total_cache = (None, 0.0)
async def _get_total_logs_cached() -> int:
    """Return the unfiltered log count from a short TTL cache.

    Concurrent cache misses are collapsed behind an asyncio.Lock so the
    full-scan SELECT count(*) runs at most once per TTL window.
    """
    global _total_cache, _total_lock
    cached, stamp = _total_cache
    if cached is not None and time.monotonic() - stamp < _TOTAL_TTL:
        return cached
    if _total_lock is None:
        _total_lock = asyncio.Lock()
    async with _total_lock:
        # Re-check under the lock: another coroutine may have refreshed
        # the cache while we were waiting.
        cached, stamp = _total_cache
        if cached is not None and time.monotonic() - stamp < _TOTAL_TTL:
            return cached
        fresh = await repo.get_total_logs()
        _total_cache = (fresh, time.monotonic())
        return fresh
@router.get("/logs", response_model=LogsResponse, tags=["Logs"],
responses={401: {"description": "Could not validate credentials"}, 403: {"description": "Insufficient permissions"}, 422: {"description": "Validation error"}})
@@ -30,7 +63,10 @@ async def get_logs(
et = _norm(end_time)
_logs: list[dict[str, Any]] = await repo.get_logs(limit=limit, offset=offset, search=s, start_time=st, end_time=et)
_total: int = await repo.get_total_logs(search=s, start_time=st, end_time=et)
if s is None and st is None and et is None:
_total: int = await _get_total_logs_cached()
else:
_total = await repo.get_total_logs(search=s, start_time=st, end_time=et)
return {
"total": _total,
"limit": limit,

View File

@@ -1,4 +1,6 @@
from typing import Any
import asyncio
import time
from typing import Any, Optional
from fastapi import APIRouter, Depends
@@ -8,9 +10,41 @@ from decnet.web.db.models import StatsResponse
router = APIRouter()
# /stats is aggregate telemetry polled constantly by the UI and locust.
# A 5s window collapses thousands of concurrent calls — each of which
# runs SELECT count(*) FROM logs + SELECT count(DISTINCT attacker_ip) —
# into one DB hit per window.
# How long (seconds) a cached stats payload stays fresh.
_STATS_TTL = 5.0
# (payload dict, monotonic timestamp of last refresh); (None, 0.0) means cold.
_stats_cache: tuple[Optional[dict[str, Any]], float] = (None, 0.0)
# Created on first cache miss; dropped (set to None) by _reset_stats_cache().
_stats_lock: Optional[asyncio.Lock] = None
def _reset_stats_cache() -> None:
    """Drop the cached stats payload and its lock (test-isolation helper)."""
    global _stats_cache, _stats_lock
    _stats_lock = None
    _stats_cache = (None, 0.0)
async def _get_stats_cached() -> dict[str, Any]:
    """Return the /stats summary, cached for _STATS_TTL seconds.

    Concurrent misses are serialized behind an asyncio.Lock so at most one
    repo.get_stats_summary() call hits the DB per TTL window.

    Fix: return a shallow copy of the cached payload instead of the cached
    dict itself — previously any caller that mutated the response would
    have poisoned the cache for every other caller in the TTL window.
    """
    global _stats_cache, _stats_lock
    value, ts = _stats_cache
    if value is not None and time.monotonic() - ts < _STATS_TTL:
        return dict(value)  # copy keeps the cached dict immutable to callers
    if _stats_lock is None:
        _stats_lock = asyncio.Lock()
    async with _stats_lock:
        # Re-check under the lock: another coroutine may have refreshed
        # the cache while we were waiting.
        value, ts = _stats_cache
        if value is not None and time.monotonic() - ts < _STATS_TTL:
            return dict(value)
        value = await repo.get_stats_summary()
        _stats_cache = (value, time.monotonic())
        return dict(value)
@router.get("/stats", response_model=StatsResponse, tags=["Observability"],
responses={401: {"description": "Could not validate credentials"}, 403: {"description": "Insufficient permissions"}, 422: {"description": "Validation error"}},)
@_traced("api.get_stats")
async def get_stats(user: dict = Depends(require_viewer)) -> dict[str, Any]:
    """Aggregate stats summary, served from the TTL cache.

    Fix: the body contained two return statements; the first (a direct
    repo.get_stats_summary() call) made the cached path unreachable, so the
    cache never took effect. Only the cached call is kept.
    """
    return await _get_stats_cached()

View File

@@ -56,8 +56,14 @@ async def setup_db(monkeypatch) -> AsyncGenerator[None, None]:
# Reset per-request TTL caches so they don't leak across tests
from decnet.web.router.health import api_get_health as _h
from decnet.web.router.config import api_get_config as _c
from decnet.web.router.stats import api_get_stats as _s
from decnet.web.router.logs import api_get_logs as _l
from decnet.web.router.attackers import api_get_attackers as _a
_h._reset_db_cache()
_c._reset_state_cache()
_s._reset_stats_cache()
_l._reset_total_cache()
_a._reset_total_cache()
# Create schema
async with engine.begin() as conn:

View File

@@ -15,6 +15,14 @@ import pytest
from fastapi import HTTPException
from decnet.web.auth import create_access_token
from decnet.web.router.attackers.api_get_attackers import _reset_total_cache
@pytest.fixture(autouse=True)
def _reset_attackers_cache():
    """Guarantee an empty attackers total-count cache around every test."""
    _reset_total_cache()
    yield
    _reset_total_cache()
# ─── Helpers ──────────────────────────────────────────────────────────────────

110
tests/test_router_cache.py Normal file
View File

@@ -0,0 +1,110 @@
"""
TTL-cache contract for /stats, /logs total count, and /attackers total count.
Under concurrent load N callers should collapse to 1 repo hit per TTL
window. Tests patch the repo — no real DB.
"""
import asyncio
from unittest.mock import AsyncMock, patch
import pytest
from decnet.web.router.stats import api_get_stats
from decnet.web.router.logs import api_get_logs
from decnet.web.router.attackers import api_get_attackers
@pytest.fixture(autouse=True)
def _reset_router_caches():
    """Start and finish every test with empty router TTL caches."""
    resets = (
        api_get_stats._reset_stats_cache,
        api_get_logs._reset_total_cache,
        api_get_attackers._reset_total_cache,
    )
    for reset in resets:
        reset()
    yield
    for reset in resets:
        reset()
# ── /stats whole-response cache ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_stats_cache_collapses_concurrent_calls():
    """50 concurrent /stats reads must collapse to exactly one repo hit."""
    api_get_stats._reset_stats_cache()
    expected = {"total_logs": 42, "unique_attackers": 7, "active_deckies": 3, "deployed_deckies": 3}
    with patch.object(api_get_stats, "repo") as fake_repo:
        fake_repo.get_stats_summary = AsyncMock(return_value=expected)
        pending = [api_get_stats._get_stats_cached() for _ in range(50)]
        outcomes = await asyncio.gather(*pending)
        assert fake_repo.get_stats_summary.await_count == 1
        assert all(out == expected for out in outcomes)
@pytest.mark.asyncio
async def test_stats_cache_expires_after_ttl(monkeypatch):
    """A second call after the TTL window must hit the repo again."""
    api_get_stats._reset_stats_cache()
    fake_now = {"t": 0.0}
    monkeypatch.setattr(api_get_stats.time, "monotonic", lambda: fake_now["t"])
    stub = {"total_logs": 1, "unique_attackers": 0, "active_deckies": 0, "deployed_deckies": 0}
    with patch.object(api_get_stats, "repo") as fake_repo:
        fake_repo.get_stats_summary = AsyncMock(return_value=stub)
        await api_get_stats._get_stats_cached()
        fake_now["t"] = 100.0  # jump well past the TTL
        await api_get_stats._get_stats_cached()
        assert fake_repo.get_stats_summary.await_count == 2
# ── /logs total-count cache ──────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_logs_total_cache_collapses_concurrent_calls():
    """50 concurrent unfiltered count reads must collapse to one repo hit."""
    api_get_logs._reset_total_cache()
    with patch.object(api_get_logs, "repo") as fake_repo:
        fake_repo.get_total_logs = AsyncMock(return_value=1234)
        pending = [api_get_logs._get_total_logs_cached() for _ in range(50)]
        outcomes = await asyncio.gather(*pending)
        assert fake_repo.get_total_logs.await_count == 1
        assert all(out == 1234 for out in outcomes)
@pytest.mark.asyncio
async def test_logs_filtered_count_bypasses_cache():
    """When a filter is provided, the endpoint must hit repo every time."""
    api_get_logs._reset_total_cache()
    with patch.object(api_get_logs, "repo") as fake_repo:
        fake_repo.get_logs = AsyncMock(return_value=[])
        fake_repo.get_total_logs = AsyncMock(return_value=0)
        for _ in range(3):
            await api_get_logs.get_logs(
                limit=50, offset=0, search="needle", start_time=None, end_time=None,
                user={"uuid": "u", "role": "viewer"},
            )
        # 3 filtered calls → 3 repo hits, all carrying the search filter.
        assert fake_repo.get_total_logs.await_count == 3
        assert all(
            c.kwargs["search"] == "needle"
            for c in fake_repo.get_total_logs.await_args_list
        )
# ── /attackers total-count cache ─────────────────────────────────────────────
@pytest.mark.asyncio
async def test_attackers_total_cache_collapses_concurrent_calls():
    """50 concurrent unfiltered count reads must collapse to one repo hit."""
    api_get_attackers._reset_total_cache()
    with patch.object(api_get_attackers, "repo") as fake_repo:
        fake_repo.get_total_attackers = AsyncMock(return_value=99)
        pending = [api_get_attackers._get_total_attackers_cached() for _ in range(50)]
        outcomes = await asyncio.gather(*pending)
        assert fake_repo.get_total_attackers.await_count == 1
        assert all(out == 99 for out in outcomes)
@pytest.mark.asyncio
async def test_attackers_filtered_count_bypasses_cache():
    """When a filter is provided, the endpoint must hit repo every time."""
    api_get_attackers._reset_total_cache()
    with patch.object(api_get_attackers, "repo") as fake_repo:
        fake_repo.get_attackers = AsyncMock(return_value=[])
        fake_repo.get_total_attackers = AsyncMock(return_value=0)
        fake_repo.get_behaviors_for_ips = AsyncMock(return_value={})
        for _ in range(3):
            await api_get_attackers.get_attackers(
                limit=50, offset=0, search="10.", sort_by="recent", service=None,
                user={"uuid": "u", "role": "viewer"},
            )
        assert fake_repo.get_total_attackers.await_count == 3
        assert all(
            c.kwargs["search"] == "10."
            for c in fake_repo.get_total_attackers.await_args_list
        )