diff --git a/decnet/profiler/worker.py b/decnet/profiler/worker.py index 3d05418..0cabec6 100644 --- a/decnet/profiler/worker.py +++ b/decnet/profiler/worker.py @@ -50,6 +50,11 @@ async def attacker_profile_worker(repo: BaseRepository, *, interval: int = 30) - """Periodically updates the Attacker table incrementally. Designed to run as an asyncio Task.""" logger.info("attacker profile worker started interval=%ds", interval) state = _WorkerState() + _saved_cursor = await repo.get_state(_STATE_KEY) + if _saved_cursor: + state.last_log_id = _saved_cursor.get("last_log_id", 0) + state.initialized = True + logger.info("attacker worker: resumed from cursor last_log_id=%d", state.last_log_id) while True: await asyncio.sleep(interval) try: diff --git a/decnet/web/ingester.py b/decnet/web/ingester.py index 513e958..188a833 100644 --- a/decnet/web/ingester.py +++ b/decnet/web/ingester.py @@ -9,6 +9,9 @@ from decnet.web.db.repository import BaseRepository logger = get_logger("api") +_INGEST_STATE_KEY = "ingest_worker_position" + + async def log_ingestion_worker(repo: BaseRepository) -> None: """ Background task that tails the DECNET_INGEST_LOG_FILE.json and @@ -20,9 +23,11 @@ async def log_ingestion_worker(repo: BaseRepository) -> None: return _json_log_path: Path = Path(_base_log_file).with_suffix(".json") - _position: int = 0 - logger.info("ingest worker started path=%s", _json_log_path) + _saved = await repo.get_state(_INGEST_STATE_KEY) + _position: int = _saved.get("position", 0) if _saved else 0 + + logger.info("ingest worker started path=%s position=%d", _json_log_path, _position) while True: try: @@ -34,6 +39,7 @@ async def log_ingestion_worker(repo: BaseRepository) -> None: if _stat.st_size < _position: # File rotated or truncated _position = 0 + await repo.set_state(_INGEST_STATE_KEY, {"position": 0}) if _stat.st_size == _position: # No new data @@ -63,6 +69,8 @@ async def log_ingestion_worker(repo: BaseRepository) -> None: # Update position after successful line read _position = _f.tell() + await repo.set_state(_INGEST_STATE_KEY, {"position": _position}) + except Exception as _e: _err_str = str(_e).lower() if "no such table" in _err_str or "no active connection" in _err_str or "connection closed" in _err_str: diff --git a/tests/test_attacker_worker.py b/tests/test_attacker_worker.py index c65dd4d..8049258 100644 --- a/tests/test_attacker_worker.py +++ b/tests/test_attacker_worker.py @@ -614,6 +614,60 @@ class TestAttackerProfileWorker: assert len(update_calls) >= 1 + @pytest.mark.asyncio + async def test_cursor_restored_from_db_on_startup(self): + """Worker loads saved last_log_id from DB and passes it to _incremental_update.""" + repo = _make_repo(saved_state={"last_log_id": 99}) + _call_count = 0 + + async def fake_sleep(secs): + nonlocal _call_count + _call_count += 1 + if _call_count >= 2: + raise asyncio.CancelledError() + + captured_states = [] + + async def mock_update(_repo, state): + captured_states.append((state.last_log_id, state.initialized)) + + with patch("decnet.profiler.worker.asyncio.sleep", side_effect=fake_sleep): + with patch("decnet.profiler.worker._incremental_update", side_effect=mock_update): + with pytest.raises(asyncio.CancelledError): + await attacker_profile_worker(repo) + + assert captured_states, "_incremental_update never called" + restored_id, initialized = captured_states[0] + assert restored_id == 99 + assert initialized is True + + @pytest.mark.asyncio + async def test_no_saved_cursor_starts_from_zero(self): + """When get_state returns None, worker starts fresh from log ID 0.""" + repo = _make_repo(saved_state=None) + _call_count = 0 + + async def fake_sleep(secs): + nonlocal _call_count + _call_count += 1 + if _call_count >= 2: + raise asyncio.CancelledError() + + captured_states = [] + + async def mock_update(_repo, state): + captured_states.append((state.last_log_id, state.initialized)) + + with patch("decnet.profiler.worker.asyncio.sleep", side_effect=fake_sleep): + with patch("decnet.profiler.worker._incremental_update", side_effect=mock_update): + with pytest.raises(asyncio.CancelledError): + await attacker_profile_worker(repo) + + assert captured_states, "_incremental_update never called" + restored_id, initialized = captured_states[0] + assert restored_id == 0 + assert initialized is False + # ─── JA3 bounty extraction from ingester ───────────────────────────────────── diff --git a/tests/test_ingester.py b/tests/test_ingester.py index bb3ae8a..3ad1d55 100644 --- a/tests/test_ingester.py +++ b/tests/test_ingester.py @@ -85,6 +85,8 @@ class TestLogIngestionWorker: from decnet.web.ingester import log_ingestion_worker mock_repo = MagicMock() mock_repo.add_log = AsyncMock() + mock_repo.get_state = AsyncMock(return_value=None) + mock_repo.set_state = AsyncMock() log_file = str(tmp_path / "nonexistent.log") _call_count: int = 0 @@ -106,6 +108,8 @@ class TestLogIngestionWorker: mock_repo = MagicMock() mock_repo.add_log = AsyncMock() mock_repo.add_bounty = AsyncMock() + mock_repo.get_state = AsyncMock(return_value=None) + mock_repo.set_state = AsyncMock() log_file = str(tmp_path / "test.log") json_file = tmp_path / "test.json" @@ -135,6 +139,8 @@ class TestLogIngestionWorker: mock_repo = MagicMock() mock_repo.add_log = AsyncMock() mock_repo.add_bounty = AsyncMock() + mock_repo.get_state = AsyncMock(return_value=None) + mock_repo.set_state = AsyncMock() log_file = str(tmp_path / "test.log") json_file = tmp_path / "test.json" @@ -161,6 +167,8 @@ class TestLogIngestionWorker: mock_repo = MagicMock() mock_repo.add_log = AsyncMock() mock_repo.add_bounty = AsyncMock() + mock_repo.get_state = AsyncMock(return_value=None) + mock_repo.set_state = AsyncMock() log_file = str(tmp_path / "test.log") json_file = tmp_path / "test.json" @@ -195,6 +203,8 @@ class TestLogIngestionWorker: mock_repo = MagicMock() mock_repo.add_log = AsyncMock() mock_repo.add_bounty = AsyncMock() + mock_repo.get_state = AsyncMock(return_value=None) + mock_repo.set_state = AsyncMock() log_file = str(tmp_path / "test.log") json_file = tmp_path / "test.json" @@ -215,3 +225,117 @@ class TestLogIngestionWorker: await log_ingestion_worker(mock_repo) mock_repo.add_log.assert_not_awaited() + + @pytest.mark.asyncio + async def test_position_restored_skips_already_seen_lines(self, tmp_path): + """Worker resumes from saved position and skips already-ingested content.""" + from decnet.web.ingester import log_ingestion_worker + mock_repo = MagicMock() + mock_repo.add_log = AsyncMock() + mock_repo.add_bounty = AsyncMock() + mock_repo.set_state = AsyncMock() + + log_file = str(tmp_path / "test.log") + json_file = tmp_path / "test.json" + + line_old = json.dumps({"decky": "d1", "service": "ssh", "event_type": "auth", + "attacker_ip": "1.1.1.1", "fields": {}, "raw_line": "x", "msg": ""}) + "\n" + line_new = json.dumps({"decky": "d2", "service": "ftp", "event_type": "auth", + "attacker_ip": "2.2.2.2", "fields": {}, "raw_line": "y", "msg": ""}) + "\n" + + json_file.write_text(line_old + line_new) + + # Saved position points to end of first line — only line_new should be ingested + saved_position = len(line_old.encode("utf-8")) + mock_repo.get_state = AsyncMock(return_value={"position": saved_position}) + + _call_count: int = 0 + + async def fake_sleep(secs): + nonlocal _call_count + _call_count += 1 + if _call_count >= 2: + raise asyncio.CancelledError() + + with patch.dict(os.environ, {"DECNET_INGEST_LOG_FILE": log_file}): + with patch("decnet.web.ingester.asyncio.sleep", side_effect=fake_sleep): + with pytest.raises(asyncio.CancelledError): + await log_ingestion_worker(mock_repo) + + assert mock_repo.add_log.await_count == 1 + ingested = mock_repo.add_log.call_args[0][0] + assert ingested["attacker_ip"] == "2.2.2.2" + + @pytest.mark.asyncio + async def test_set_state_called_with_position_after_batch(self, tmp_path): + """set_state is called with the updated byte position after processing lines.""" + from decnet.web.ingester import log_ingestion_worker, _INGEST_STATE_KEY + mock_repo = MagicMock() + mock_repo.add_log = AsyncMock() + mock_repo.add_bounty = AsyncMock() + mock_repo.get_state = AsyncMock(return_value=None) + mock_repo.set_state = AsyncMock() + + log_file = str(tmp_path / "test.log") + json_file = tmp_path / "test.json" + line = json.dumps({"decky": "d1", "service": "ssh", "event_type": "auth", + "attacker_ip": "1.1.1.1", "fields": {}, "raw_line": "x", "msg": ""}) + "\n" + json_file.write_text(line) + + _call_count: int = 0 + + async def fake_sleep(secs): + nonlocal _call_count + _call_count += 1 + if _call_count >= 2: + raise asyncio.CancelledError() + + with patch.dict(os.environ, {"DECNET_INGEST_LOG_FILE": log_file}): + with patch("decnet.web.ingester.asyncio.sleep", side_effect=fake_sleep): + with pytest.raises(asyncio.CancelledError): + await log_ingestion_worker(mock_repo) + + set_state_calls = mock_repo.set_state.call_args_list + position_calls = [c for c in set_state_calls if c[0][0] == _INGEST_STATE_KEY] + assert position_calls, "set_state never called with ingest position key" + saved_pos = position_calls[-1][0][1]["position"] + assert saved_pos == len(line.encode("utf-8")) + + @pytest.mark.asyncio + async def test_truncation_resets_and_saves_zero_position(self, tmp_path): + """On file truncation, set_state is called with position=0.""" + from decnet.web.ingester import log_ingestion_worker, _INGEST_STATE_KEY + mock_repo = MagicMock() + mock_repo.add_log = AsyncMock() + mock_repo.add_bounty = AsyncMock() + mock_repo.set_state = AsyncMock() + + log_file = str(tmp_path / "test.log") + json_file = tmp_path / "test.json" + + line = json.dumps({"decky": "d1", "service": "ssh", "event_type": "auth", + "attacker_ip": "1.1.1.1", "fields": {}, "raw_line": "x", "msg": ""}) + "\n" + # Pretend the saved position is past the end (simulates prior larger file) + big_position = len(line.encode("utf-8")) * 10 + mock_repo.get_state = AsyncMock(return_value={"position": big_position}) + + json_file.write_text(line) # file is smaller than saved position → truncation + + _call_count: int = 0 + + async def fake_sleep(secs): + nonlocal _call_count + _call_count += 1 + if _call_count >= 2: + raise asyncio.CancelledError() + + with patch.dict(os.environ, {"DECNET_INGEST_LOG_FILE": log_file}): + with patch("decnet.web.ingester.asyncio.sleep", side_effect=fake_sleep): + with pytest.raises(asyncio.CancelledError): + await log_ingestion_worker(mock_repo) + + reset_calls = [ + c for c in mock_repo.set_state.call_args_list + if c[0][0] == _INGEST_STATE_KEY and c[0][1] == {"position": 0} + ] + assert reset_calls, "set_state not called with position=0 after truncation" diff --git a/tests/test_service_isolation.py b/tests/test_service_isolation.py index 45734ec..2eeee58 100644 --- a/tests/test_service_isolation.py +++ b/tests/test_service_isolation.py @@ -93,6 +93,8 @@ class TestIngesterIsolation: from decnet.web.ingester import log_ingestion_worker mock_repo = MagicMock() + mock_repo.get_state = AsyncMock(return_value=None) + mock_repo.set_state = AsyncMock() iterations = 0 async def _controlled_sleep(seconds): @@ -133,6 +135,8 @@ class TestIngesterIsolation: mock_repo = MagicMock() mock_repo.add_log = AsyncMock() + mock_repo.get_state = AsyncMock(return_value=None) + mock_repo.set_state = AsyncMock() iterations = 0 async def _controlled_sleep(seconds): @@ -168,6 +172,8 @@ class TestIngesterIsolation: mock_repo = MagicMock() mock_repo.add_log = AsyncMock(side_effect=Exception("no such table: logs")) + mock_repo.get_state = AsyncMock(return_value=None) + mock_repo.set_state = AsyncMock() with patch.dict(os.environ, {"DECNET_INGEST_LOG_FILE": str(tmp_path / "test.log")}): # Worker should exit the loop on fatal DB error @@ -189,6 +195,7 @@ class TestAttackerWorkerIsolation: mock_repo = MagicMock() mock_repo.get_all_logs_raw = AsyncMock(side_effect=Exception("DB is locked")) mock_repo.get_max_log_id = AsyncMock(return_value=0) + mock_repo.get_state = AsyncMock(return_value=None) mock_repo.set_state = AsyncMock() iterations = 0 @@ -412,6 +419,8 @@ class TestCascadeIsolation: mock_repo = MagicMock() mock_repo.add_log = AsyncMock() + mock_repo.get_state = AsyncMock(return_value=None) + mock_repo.set_state = AsyncMock() iterations = 0 async def _controlled_sleep(seconds): @@ -437,6 +446,7 @@ class TestCascadeIsolation: mock_repo = MagicMock() mock_repo.get_all_logs_raw = AsyncMock(return_value=[]) mock_repo.get_max_log_id = AsyncMock(return_value=0) + mock_repo.get_state = AsyncMock(return_value=None) mock_repo.set_state = AsyncMock() mock_repo.get_logs_after_id = AsyncMock(return_value=[])