diff --git a/decnet/profiler/worker.py b/decnet/profiler/worker.py index ebd1ed0..3d05418 100644 --- a/decnet/profiler/worker.py +++ b/decnet/profiler/worker.py @@ -59,10 +59,7 @@ async def attacker_profile_worker(repo: BaseRepository, *, interval: int = 30) - async def _incremental_update(repo: BaseRepository, state: _WorkerState) -> None: - if not state.initialized: - await _cold_start(repo, state) - return - + was_cold = not state.initialized affected_ips: set[str] = set() while True: @@ -76,9 +73,13 @@ async def _incremental_update(repo: BaseRepository, state: _WorkerState) -> None affected_ips.add(event.attacker_ip) state.last_log_id = row["id"] + await asyncio.sleep(0) # yield to event loop after each batch + if len(batch) < _BATCH_SIZE: break + state.initialized = True + if not affected_ips: await repo.set_state(_STATE_KEY, {"last_log_id": state.last_log_id}) return @@ -86,27 +87,10 @@ async def _incremental_update(repo: BaseRepository, state: _WorkerState) -> None await _update_profiles(repo, state, affected_ips) await repo.set_state(_STATE_KEY, {"last_log_id": state.last_log_id}) - logger.info("attacker worker: updated %d profiles (incremental)", len(affected_ips)) - - -async def _cold_start(repo: BaseRepository, state: _WorkerState) -> None: - all_logs = await repo.get_all_logs_raw() - if not all_logs: - state.last_log_id = await repo.get_max_log_id() - state.initialized = True - await repo.set_state(_STATE_KEY, {"last_log_id": state.last_log_id}) - return - - for row in all_logs: - state.engine.ingest(row["raw_line"]) - state.last_log_id = max(state.last_log_id, row["id"]) - - all_ips = set(state.engine._events.keys()) - await _update_profiles(repo, state, all_ips) - await repo.set_state(_STATE_KEY, {"last_log_id": state.last_log_id}) - - state.initialized = True - logger.info("attacker worker: cold start rebuilt %d profiles", len(all_ips)) + if was_cold: + logger.info("attacker worker: cold start rebuilt %d profiles", len(affected_ips)) + else: + logger.info("attacker worker: updated %d profiles (incremental)", len(affected_ips)) async def _update_profiles( diff --git a/decnet/web/db/repository.py b/decnet/web/db/repository.py index 97ba167..118c289 100644 --- a/decnet/web/db/repository.py +++ b/decnet/web/db/repository.py @@ -111,11 +111,6 @@ class BaseRepository(ABC): """Store a specific state entry by key.""" pass - @abstractmethod - async def get_all_logs_raw(self) -> list[dict[str, Any]]: - """Retrieve all log rows with fields needed by the attacker profile worker.""" - pass - @abstractmethod async def get_max_log_id(self) -> int: """Return the highest log ID, or 0 if the table is empty.""" diff --git a/decnet/web/db/sqlmodel_repo.py b/decnet/web/db/sqlmodel_repo.py index e50b652..7185f69 100644 --- a/decnet/web/db/sqlmodel_repo.py +++ b/decnet/web/db/sqlmodel_repo.py @@ -413,34 +413,6 @@ class SQLModelRepository(BaseRepository): # ----------------------------------------------------------- attackers - async def get_all_logs_raw(self) -> List[dict[str, Any]]: - async with self.session_factory() as session: - result = await session.execute( - select( - Log.id, - Log.raw_line, - Log.attacker_ip, - Log.service, - Log.event_type, - Log.decky, - Log.timestamp, - Log.fields, - ) - ) - return [ - { - "id": r.id, - "raw_line": r.raw_line, - "attacker_ip": r.attacker_ip, - "service": r.service, - "event_type": r.event_type, - "decky": r.decky, - "timestamp": r.timestamp, - "fields": r.fields, - } - for r in result.all() - ] - async def get_all_bounties_by_ip(self) -> dict[str, List[dict[str, Any]]]: from collections import defaultdict async with self.session_factory() as session: diff --git a/decnet/web/router/stream/api_stream_events.py b/decnet/web/router/stream/api_stream_events.py index 01f3e20..823a322 100644 --- a/decnet/web/router/stream/api_stream_events.py +++ b/decnet/web/router/stream/api_stream_events.py @@ -40,6 +40,8 @@ async def stream_events( loops_since_stats = 0 emitted_chunks = 0 try: + yield ": keepalive\n\n" # flush headers immediately; helps diagnose pre-yield hangs + if last_id == 0: last_id = await repo.get_max_log_id() @@ -90,4 +92,11 @@ async def stream_events( log.exception("SSE stream error for user %s", last_event_id) yield f"event: error\ndata: {json.dumps({'type': 'error', 'message': 'Stream interrupted'})}\n\n" - return StreamingResponse(event_generator(), media_type="text/event-stream") + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + }, + ) diff --git a/tests/test_attacker_worker.py b/tests/test_attacker_worker.py index bdc7502..c65dd4d 100644 --- a/tests/test_attacker_worker.py +++ b/tests/test_attacker_worker.py @@ -27,7 +27,6 @@ from decnet.profiler.worker import ( _STATE_KEY, _WorkerState, _build_record, - _cold_start, _extract_commands_from_events, _first_contact_deckies, _incremental_update, @@ -97,11 +96,12 @@ def _make_log_row( def _make_repo(logs=None, bounties=None, bounties_for_ips=None, max_log_id=0, saved_state=None): repo = MagicMock() - repo.get_all_logs_raw = AsyncMock(return_value=logs or []) repo.get_all_bounties_by_ip = AsyncMock(return_value=bounties or {}) repo.get_bounties_for_ips = AsyncMock(return_value=bounties_for_ips or {}) repo.get_max_log_id = AsyncMock(return_value=max_log_id) - repo.get_logs_after_id = AsyncMock(return_value=[]) + # Return provided logs on first call (simulating a single page < BATCH_SIZE), then [] to end loop + _log_pages = [logs or [], []] + repo.get_logs_after_id = AsyncMock(side_effect=_log_pages) repo.get_state = AsyncMock(return_value=saved_state) repo.set_state = AsyncMock() repo.upsert_attacker = AsyncMock(return_value="mock-uuid") @@ -283,7 +283,7 @@ class TestBuildRecord: assert record["updated_at"].tzinfo is not None -# ─── _cold_start ───────────────────────────────────────────────────────────── +# ─── cold start via _incremental_update (uninitialized state) ──────────────── class TestColdStart: @pytest.mark.asyncio @@ -299,7 +299,7 @@ class TestColdStart: repo = _make_repo(logs=rows, max_log_id=3) state = _WorkerState() - await _cold_start(repo, state) + await _incremental_update(repo, state) assert state.initialized is True assert state.last_log_id == 3 @@ -313,7 +313,7 @@ class TestColdStart: repo = _make_repo(logs=[], max_log_id=0) state = _WorkerState() - await _cold_start(repo, state) + await _incremental_update(repo, state) assert state.initialized is True assert state.last_log_id == 0 @@ -337,7 +337,7 @@ class TestColdStart: repo = _make_repo(logs=rows, max_log_id=2) state = _WorkerState() - await _cold_start(repo, state) + await _incremental_update(repo, state) record = repo.upsert_attacker.call_args[0][0] assert record["is_traversal"] is True @@ -357,7 +357,7 @@ class TestColdStart: ) state = _WorkerState() - await _cold_start(repo, state) + await _incremental_update(repo, state) record = repo.upsert_attacker.call_args[0][0] assert record["bounty_count"] == 2 @@ -376,7 +376,7 @@ class TestColdStart: repo = _make_repo(logs=[row], max_log_id=1) state = _WorkerState() - await _cold_start(repo, state) + await _incremental_update(repo, state) record = repo.upsert_attacker.call_args[0][0] commands = json.loads(record["commands"]) @@ -542,7 +542,7 @@ class TestIncrementalUpdate: assert called_ips == {"1.1.1.1", "2.2.2.2"} @pytest.mark.asyncio - async def test_uninitialized_state_triggers_cold_start(self): + async def test_uninitialized_state_runs_full_cursor_sweep(self): rows = [ _make_log_row( row_id=1, @@ -556,7 +556,8 @@ class TestIncrementalUpdate: await _incremental_update(repo, state) assert state.initialized is True - repo.get_all_logs_raw.assert_awaited_once() + assert state.last_log_id == 1 + repo.upsert_attacker.assert_awaited_once() # ─── attacker_profile_worker ──────────────────────────────────────────────── diff --git a/tests/test_base_repo.py b/tests/test_base_repo.py index cb04ac9..dd7531e 100644 --- a/tests/test_base_repo.py +++ b/tests/test_base_repo.py @@ -21,7 +21,6 @@ class DummyRepo(BaseRepository): async def get_total_bounties(self, **kw): await super().get_total_bounties(**kw) async def get_state(self, k): await super().get_state(k) async def set_state(self, k, v): await super().set_state(k, v) - async def get_all_logs_raw(self): await super().get_all_logs_raw() async def get_max_log_id(self): await super().get_max_log_id() async def get_logs_after_id(self, last_id, limit=500): await super().get_logs_after_id(last_id, limit) async def get_all_bounties_by_ip(self): await super().get_all_bounties_by_ip() @@ -58,7 +57,6 @@ async def test_base_repo_coverage(): await dr.get_total_bounties() await dr.get_state("k") await dr.set_state("k", "v") - await dr.get_all_logs_raw() await dr.get_max_log_id() await dr.get_logs_after_id(0) await dr.get_all_bounties_by_ip() diff --git a/tests/test_service_isolation.py b/tests/test_service_isolation.py index 42133a1..45734ec 100644 --- a/tests/test_service_isolation.py +++ b/tests/test_service_isolation.py @@ -212,8 +212,7 @@ class TestAttackerWorkerIsolation: from decnet.profiler.worker import _WorkerState, _incremental_update mock_repo = MagicMock() - mock_repo.get_all_logs_raw = AsyncMock(return_value=[]) - mock_repo.get_max_log_id = AsyncMock(return_value=0) + mock_repo.get_logs_after_id = AsyncMock(return_value=[]) mock_repo.set_state = AsyncMock() state = _WorkerState()