fix: remove event-loop-blocking cold start; unify profiler to cursor-based incremental
Cold start fetched all logs in one bulk query and then processed them in a tight synchronous loop with no yields, blocking the asyncio event loop for seconds on datasets of 30K+ rows. This stalled every concurrent await — including the SSE stream generator's initial DB calls — causing the dashboard to show INITIALIZING SENSORS indefinitely.

Changes:
- Drop _cold_start() and get_all_logs_raw(); an uninitialized state now runs the same cursor loop as the incremental update, starting from last_log_id=0.
- Yield to the event loop after every batch of _BATCH_SIZE rows (asyncio.sleep(0)).
- Add an SSE keepalive comment as the first yield so the connection flushes before any DB work begins.
- Add Cache-Control/X-Accel-Buffering headers to the StreamingResponse.
This commit is contained in:
@@ -59,10 +59,7 @@ async def attacker_profile_worker(repo: BaseRepository, *, interval: int = 30) -
|
|||||||
|
|
||||||
|
|
||||||
async def _incremental_update(repo: BaseRepository, state: _WorkerState) -> None:
|
async def _incremental_update(repo: BaseRepository, state: _WorkerState) -> None:
|
||||||
if not state.initialized:
|
was_cold = not state.initialized
|
||||||
await _cold_start(repo, state)
|
|
||||||
return
|
|
||||||
|
|
||||||
affected_ips: set[str] = set()
|
affected_ips: set[str] = set()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -76,9 +73,13 @@ async def _incremental_update(repo: BaseRepository, state: _WorkerState) -> None
|
|||||||
affected_ips.add(event.attacker_ip)
|
affected_ips.add(event.attacker_ip)
|
||||||
state.last_log_id = row["id"]
|
state.last_log_id = row["id"]
|
||||||
|
|
||||||
|
await asyncio.sleep(0) # yield to event loop after each batch
|
||||||
|
|
||||||
if len(batch) < _BATCH_SIZE:
|
if len(batch) < _BATCH_SIZE:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
state.initialized = True
|
||||||
|
|
||||||
if not affected_ips:
|
if not affected_ips:
|
||||||
await repo.set_state(_STATE_KEY, {"last_log_id": state.last_log_id})
|
await repo.set_state(_STATE_KEY, {"last_log_id": state.last_log_id})
|
||||||
return
|
return
|
||||||
@@ -86,29 +87,12 @@ async def _incremental_update(repo: BaseRepository, state: _WorkerState) -> None
|
|||||||
await _update_profiles(repo, state, affected_ips)
|
await _update_profiles(repo, state, affected_ips)
|
||||||
await repo.set_state(_STATE_KEY, {"last_log_id": state.last_log_id})
|
await repo.set_state(_STATE_KEY, {"last_log_id": state.last_log_id})
|
||||||
|
|
||||||
|
if was_cold:
|
||||||
|
logger.info("attacker worker: cold start rebuilt %d profiles", len(affected_ips))
|
||||||
|
else:
|
||||||
logger.info("attacker worker: updated %d profiles (incremental)", len(affected_ips))
|
logger.info("attacker worker: updated %d profiles (incremental)", len(affected_ips))
|
||||||
|
|
||||||
|
|
||||||
async def _cold_start(repo: BaseRepository, state: _WorkerState) -> None:
|
|
||||||
all_logs = await repo.get_all_logs_raw()
|
|
||||||
if not all_logs:
|
|
||||||
state.last_log_id = await repo.get_max_log_id()
|
|
||||||
state.initialized = True
|
|
||||||
await repo.set_state(_STATE_KEY, {"last_log_id": state.last_log_id})
|
|
||||||
return
|
|
||||||
|
|
||||||
for row in all_logs:
|
|
||||||
state.engine.ingest(row["raw_line"])
|
|
||||||
state.last_log_id = max(state.last_log_id, row["id"])
|
|
||||||
|
|
||||||
all_ips = set(state.engine._events.keys())
|
|
||||||
await _update_profiles(repo, state, all_ips)
|
|
||||||
await repo.set_state(_STATE_KEY, {"last_log_id": state.last_log_id})
|
|
||||||
|
|
||||||
state.initialized = True
|
|
||||||
logger.info("attacker worker: cold start rebuilt %d profiles", len(all_ips))
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_profiles(
|
async def _update_profiles(
|
||||||
repo: BaseRepository,
|
repo: BaseRepository,
|
||||||
state: _WorkerState,
|
state: _WorkerState,
|
||||||
|
|||||||
@@ -111,11 +111,6 @@ class BaseRepository(ABC):
|
|||||||
"""Store a specific state entry by key."""
|
"""Store a specific state entry by key."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def get_all_logs_raw(self) -> list[dict[str, Any]]:
|
|
||||||
"""Retrieve all log rows with fields needed by the attacker profile worker."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_max_log_id(self) -> int:
|
async def get_max_log_id(self) -> int:
|
||||||
"""Return the highest log ID, or 0 if the table is empty."""
|
"""Return the highest log ID, or 0 if the table is empty."""
|
||||||
|
|||||||
@@ -413,34 +413,6 @@ class SQLModelRepository(BaseRepository):
|
|||||||
|
|
||||||
# ----------------------------------------------------------- attackers
|
# ----------------------------------------------------------- attackers
|
||||||
|
|
||||||
async def get_all_logs_raw(self) -> List[dict[str, Any]]:
|
|
||||||
async with self.session_factory() as session:
|
|
||||||
result = await session.execute(
|
|
||||||
select(
|
|
||||||
Log.id,
|
|
||||||
Log.raw_line,
|
|
||||||
Log.attacker_ip,
|
|
||||||
Log.service,
|
|
||||||
Log.event_type,
|
|
||||||
Log.decky,
|
|
||||||
Log.timestamp,
|
|
||||||
Log.fields,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"id": r.id,
|
|
||||||
"raw_line": r.raw_line,
|
|
||||||
"attacker_ip": r.attacker_ip,
|
|
||||||
"service": r.service,
|
|
||||||
"event_type": r.event_type,
|
|
||||||
"decky": r.decky,
|
|
||||||
"timestamp": r.timestamp,
|
|
||||||
"fields": r.fields,
|
|
||||||
}
|
|
||||||
for r in result.all()
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]:
|
async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]:
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ async def stream_events(
|
|||||||
loops_since_stats = 0
|
loops_since_stats = 0
|
||||||
emitted_chunks = 0
|
emitted_chunks = 0
|
||||||
try:
|
try:
|
||||||
|
yield ": keepalive\n\n" # flush headers immediately; helps diagnose pre-yield hangs
|
||||||
|
|
||||||
if last_id == 0:
|
if last_id == 0:
|
||||||
last_id = await repo.get_max_log_id()
|
last_id = await repo.get_max_log_id()
|
||||||
|
|
||||||
@@ -90,4 +92,11 @@ async def stream_events(
|
|||||||
log.exception("SSE stream error for user %s", last_event_id)
|
log.exception("SSE stream error for user %s", last_event_id)
|
||||||
yield f"event: error\ndata: {json.dumps({'type': 'error', 'message': 'Stream interrupted'})}\n\n"
|
yield f"event: error\ndata: {json.dumps({'type': 'error', 'message': 'Stream interrupted'})}\n\n"
|
||||||
|
|
||||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ from decnet.profiler.worker import (
|
|||||||
_STATE_KEY,
|
_STATE_KEY,
|
||||||
_WorkerState,
|
_WorkerState,
|
||||||
_build_record,
|
_build_record,
|
||||||
_cold_start,
|
|
||||||
_extract_commands_from_events,
|
_extract_commands_from_events,
|
||||||
_first_contact_deckies,
|
_first_contact_deckies,
|
||||||
_incremental_update,
|
_incremental_update,
|
||||||
@@ -97,11 +96,12 @@ def _make_log_row(
|
|||||||
|
|
||||||
def _make_repo(logs=None, bounties=None, bounties_for_ips=None, max_log_id=0, saved_state=None):
|
def _make_repo(logs=None, bounties=None, bounties_for_ips=None, max_log_id=0, saved_state=None):
|
||||||
repo = MagicMock()
|
repo = MagicMock()
|
||||||
repo.get_all_logs_raw = AsyncMock(return_value=logs or [])
|
|
||||||
repo.get_all_bounties_by_ip = AsyncMock(return_value=bounties or {})
|
repo.get_all_bounties_by_ip = AsyncMock(return_value=bounties or {})
|
||||||
repo.get_bounties_for_ips = AsyncMock(return_value=bounties_for_ips or {})
|
repo.get_bounties_for_ips = AsyncMock(return_value=bounties_for_ips or {})
|
||||||
repo.get_max_log_id = AsyncMock(return_value=max_log_id)
|
repo.get_max_log_id = AsyncMock(return_value=max_log_id)
|
||||||
repo.get_logs_after_id = AsyncMock(return_value=[])
|
# Return provided logs on first call (simulating a single page < BATCH_SIZE), then [] to end loop
|
||||||
|
_log_pages = [logs or [], []]
|
||||||
|
repo.get_logs_after_id = AsyncMock(side_effect=_log_pages)
|
||||||
repo.get_state = AsyncMock(return_value=saved_state)
|
repo.get_state = AsyncMock(return_value=saved_state)
|
||||||
repo.set_state = AsyncMock()
|
repo.set_state = AsyncMock()
|
||||||
repo.upsert_attacker = AsyncMock(return_value="mock-uuid")
|
repo.upsert_attacker = AsyncMock(return_value="mock-uuid")
|
||||||
@@ -283,7 +283,7 @@ class TestBuildRecord:
|
|||||||
assert record["updated_at"].tzinfo is not None
|
assert record["updated_at"].tzinfo is not None
|
||||||
|
|
||||||
|
|
||||||
# ─── _cold_start ─────────────────────────────────────────────────────────────
|
# ─── cold start via _incremental_update (uninitialized state) ────────────────
|
||||||
|
|
||||||
class TestColdStart:
|
class TestColdStart:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -299,7 +299,7 @@ class TestColdStart:
|
|||||||
repo = _make_repo(logs=rows, max_log_id=3)
|
repo = _make_repo(logs=rows, max_log_id=3)
|
||||||
state = _WorkerState()
|
state = _WorkerState()
|
||||||
|
|
||||||
await _cold_start(repo, state)
|
await _incremental_update(repo, state)
|
||||||
|
|
||||||
assert state.initialized is True
|
assert state.initialized is True
|
||||||
assert state.last_log_id == 3
|
assert state.last_log_id == 3
|
||||||
@@ -313,7 +313,7 @@ class TestColdStart:
|
|||||||
repo = _make_repo(logs=[], max_log_id=0)
|
repo = _make_repo(logs=[], max_log_id=0)
|
||||||
state = _WorkerState()
|
state = _WorkerState()
|
||||||
|
|
||||||
await _cold_start(repo, state)
|
await _incremental_update(repo, state)
|
||||||
|
|
||||||
assert state.initialized is True
|
assert state.initialized is True
|
||||||
assert state.last_log_id == 0
|
assert state.last_log_id == 0
|
||||||
@@ -337,7 +337,7 @@ class TestColdStart:
|
|||||||
repo = _make_repo(logs=rows, max_log_id=2)
|
repo = _make_repo(logs=rows, max_log_id=2)
|
||||||
state = _WorkerState()
|
state = _WorkerState()
|
||||||
|
|
||||||
await _cold_start(repo, state)
|
await _incremental_update(repo, state)
|
||||||
|
|
||||||
record = repo.upsert_attacker.call_args[0][0]
|
record = repo.upsert_attacker.call_args[0][0]
|
||||||
assert record["is_traversal"] is True
|
assert record["is_traversal"] is True
|
||||||
@@ -357,7 +357,7 @@ class TestColdStart:
|
|||||||
)
|
)
|
||||||
state = _WorkerState()
|
state = _WorkerState()
|
||||||
|
|
||||||
await _cold_start(repo, state)
|
await _incremental_update(repo, state)
|
||||||
|
|
||||||
record = repo.upsert_attacker.call_args[0][0]
|
record = repo.upsert_attacker.call_args[0][0]
|
||||||
assert record["bounty_count"] == 2
|
assert record["bounty_count"] == 2
|
||||||
@@ -376,7 +376,7 @@ class TestColdStart:
|
|||||||
repo = _make_repo(logs=[row], max_log_id=1)
|
repo = _make_repo(logs=[row], max_log_id=1)
|
||||||
state = _WorkerState()
|
state = _WorkerState()
|
||||||
|
|
||||||
await _cold_start(repo, state)
|
await _incremental_update(repo, state)
|
||||||
|
|
||||||
record = repo.upsert_attacker.call_args[0][0]
|
record = repo.upsert_attacker.call_args[0][0]
|
||||||
commands = json.loads(record["commands"])
|
commands = json.loads(record["commands"])
|
||||||
@@ -542,7 +542,7 @@ class TestIncrementalUpdate:
|
|||||||
assert called_ips == {"1.1.1.1", "2.2.2.2"}
|
assert called_ips == {"1.1.1.1", "2.2.2.2"}
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_uninitialized_state_triggers_cold_start(self):
|
async def test_uninitialized_state_runs_full_cursor_sweep(self):
|
||||||
rows = [
|
rows = [
|
||||||
_make_log_row(
|
_make_log_row(
|
||||||
row_id=1,
|
row_id=1,
|
||||||
@@ -556,7 +556,8 @@ class TestIncrementalUpdate:
|
|||||||
await _incremental_update(repo, state)
|
await _incremental_update(repo, state)
|
||||||
|
|
||||||
assert state.initialized is True
|
assert state.initialized is True
|
||||||
repo.get_all_logs_raw.assert_awaited_once()
|
assert state.last_log_id == 1
|
||||||
|
repo.upsert_attacker.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
# ─── attacker_profile_worker ────────────────────────────────────────────────
|
# ─── attacker_profile_worker ────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ class DummyRepo(BaseRepository):
|
|||||||
async def get_total_bounties(self, **kw): await super().get_total_bounties(**kw)
|
async def get_total_bounties(self, **kw): await super().get_total_bounties(**kw)
|
||||||
async def get_state(self, k): await super().get_state(k)
|
async def get_state(self, k): await super().get_state(k)
|
||||||
async def set_state(self, k, v): await super().set_state(k, v)
|
async def set_state(self, k, v): await super().set_state(k, v)
|
||||||
async def get_all_logs_raw(self): await super().get_all_logs_raw()
|
|
||||||
async def get_max_log_id(self): await super().get_max_log_id()
|
async def get_max_log_id(self): await super().get_max_log_id()
|
||||||
async def get_logs_after_id(self, last_id, limit=500): await super().get_logs_after_id(last_id, limit)
|
async def get_logs_after_id(self, last_id, limit=500): await super().get_logs_after_id(last_id, limit)
|
||||||
async def get_all_bounties_by_ip(self): await super().get_all_bounties_by_ip()
|
async def get_all_bounties_by_ip(self): await super().get_all_bounties_by_ip()
|
||||||
@@ -58,7 +57,6 @@ async def test_base_repo_coverage():
|
|||||||
await dr.get_total_bounties()
|
await dr.get_total_bounties()
|
||||||
await dr.get_state("k")
|
await dr.get_state("k")
|
||||||
await dr.set_state("k", "v")
|
await dr.set_state("k", "v")
|
||||||
await dr.get_all_logs_raw()
|
|
||||||
await dr.get_max_log_id()
|
await dr.get_max_log_id()
|
||||||
await dr.get_logs_after_id(0)
|
await dr.get_logs_after_id(0)
|
||||||
await dr.get_all_bounties_by_ip()
|
await dr.get_all_bounties_by_ip()
|
||||||
|
|||||||
@@ -212,8 +212,7 @@ class TestAttackerWorkerIsolation:
|
|||||||
from decnet.profiler.worker import _WorkerState, _incremental_update
|
from decnet.profiler.worker import _WorkerState, _incremental_update
|
||||||
|
|
||||||
mock_repo = MagicMock()
|
mock_repo = MagicMock()
|
||||||
mock_repo.get_all_logs_raw = AsyncMock(return_value=[])
|
mock_repo.get_logs_after_id = AsyncMock(return_value=[])
|
||||||
mock_repo.get_max_log_id = AsyncMock(return_value=0)
|
|
||||||
mock_repo.set_state = AsyncMock()
|
mock_repo.set_state = AsyncMock()
|
||||||
|
|
||||||
state = _WorkerState()
|
state = _WorkerState()
|
||||||
|
|||||||
Reference in New Issue
Block a user