fix: persist ingester position and profiler cursor across restarts

- Ingester now loads byte-offset from DB on startup (key: ingest_worker_position)
  and saves it after each batch — prevents full re-read on every API restart
- On file truncation/rotation the saved offset is reset to 0
- Profiler worker now loads last_log_id from DB on startup — every restart
  becomes an incremental update instead of a full cold rebuild
- Updated all affected tests to mock get_state/set_state; added new tests
  covering position restore, set_state call, truncation reset, and cursor
  restore/cold-start paths
This commit is contained in:
2026-04-15 13:58:12 -04:00
parent 314e6c6388
commit 63efe6c7ba
5 changed files with 203 additions and 2 deletions

View File

@@ -50,6 +50,11 @@ async def attacker_profile_worker(repo: BaseRepository, *, interval: int = 30) -
"""Periodically updates the Attacker table incrementally. Designed to run as an asyncio Task.""" """Periodically updates the Attacker table incrementally. Designed to run as an asyncio Task."""
logger.info("attacker profile worker started interval=%ds", interval) logger.info("attacker profile worker started interval=%ds", interval)
state = _WorkerState() state = _WorkerState()
_saved_cursor = await repo.get_state(_STATE_KEY)
if _saved_cursor:
state.last_log_id = _saved_cursor.get("last_log_id", 0)
state.initialized = True
logger.info("attacker worker: resumed from cursor last_log_id=%d", state.last_log_id)
while True: while True:
await asyncio.sleep(interval) await asyncio.sleep(interval)
try: try:

View File

@@ -9,6 +9,9 @@ from decnet.web.db.repository import BaseRepository
logger = get_logger("api") logger = get_logger("api")
_INGEST_STATE_KEY = "ingest_worker_position"
async def log_ingestion_worker(repo: BaseRepository) -> None: async def log_ingestion_worker(repo: BaseRepository) -> None:
""" """
Background task that tails the DECNET_INGEST_LOG_FILE.json and Background task that tails the DECNET_INGEST_LOG_FILE.json and
@@ -20,9 +23,11 @@ async def log_ingestion_worker(repo: BaseRepository) -> None:
return return
_json_log_path: Path = Path(_base_log_file).with_suffix(".json") _json_log_path: Path = Path(_base_log_file).with_suffix(".json")
_position: int = 0
logger.info("ingest worker started path=%s", _json_log_path) _saved = await repo.get_state(_INGEST_STATE_KEY)
_position: int = _saved.get("position", 0) if _saved else 0
logger.info("ingest worker started path=%s position=%d", _json_log_path, _position)
while True: while True:
try: try:
@@ -34,6 +39,7 @@ async def log_ingestion_worker(repo: BaseRepository) -> None:
if _stat.st_size < _position: if _stat.st_size < _position:
# File rotated or truncated # File rotated or truncated
_position = 0 _position = 0
await repo.set_state(_INGEST_STATE_KEY, {"position": 0})
if _stat.st_size == _position: if _stat.st_size == _position:
# No new data # No new data
@@ -63,6 +69,8 @@ async def log_ingestion_worker(repo: BaseRepository) -> None:
# Update position after successful line read # Update position after successful line read
_position = _f.tell() _position = _f.tell()
await repo.set_state(_INGEST_STATE_KEY, {"position": _position})
except Exception as _e: except Exception as _e:
_err_str = str(_e).lower() _err_str = str(_e).lower()
if "no such table" in _err_str or "no active connection" in _err_str or "connection closed" in _err_str: if "no such table" in _err_str or "no active connection" in _err_str or "connection closed" in _err_str:

View File

@@ -614,6 +614,60 @@ class TestAttackerProfileWorker:
assert len(update_calls) >= 1 assert len(update_calls) >= 1
@pytest.mark.asyncio
async def test_cursor_restored_from_db_on_startup(self):
    """A cursor saved in the DB must seed worker state before the first update runs."""
    repo = _make_repo(saved_state={"last_log_id": 99})

    sleep_calls = 0

    async def stop_on_second_sleep(secs):
        # Allow exactly one loop iteration, then break out of the worker's while-True.
        nonlocal sleep_calls
        sleep_calls += 1
        if sleep_calls >= 2:
            raise asyncio.CancelledError()

    observed = []

    async def record_state(_repo, state):
        observed.append((state.last_log_id, state.initialized))

    with patch("decnet.profiler.worker.asyncio.sleep", side_effect=stop_on_second_sleep), \
         patch("decnet.profiler.worker._incremental_update", side_effect=record_state), \
         pytest.raises(asyncio.CancelledError):
        await attacker_profile_worker(repo)

    assert observed, "_incremental_update never called"
    first_id, first_initialized = observed[0]
    assert first_id == 99
    assert first_initialized is True
@pytest.mark.asyncio
async def test_no_saved_cursor_starts_from_zero(self):
    """With no persisted state, the worker begins uninitialized at log ID 0."""
    repo = _make_repo(saved_state=None)

    sleep_calls = 0

    async def stop_on_second_sleep(secs):
        # Allow exactly one loop iteration, then break out of the worker's while-True.
        nonlocal sleep_calls
        sleep_calls += 1
        if sleep_calls >= 2:
            raise asyncio.CancelledError()

    observed = []

    async def record_state(_repo, state):
        observed.append((state.last_log_id, state.initialized))

    with patch("decnet.profiler.worker.asyncio.sleep", side_effect=stop_on_second_sleep), \
         patch("decnet.profiler.worker._incremental_update", side_effect=record_state), \
         pytest.raises(asyncio.CancelledError):
        await attacker_profile_worker(repo)

    assert observed, "_incremental_update never called"
    first_id, first_initialized = observed[0]
    assert first_id == 0
    assert first_initialized is False
# ─── JA3 bounty extraction from ingester ───────────────────────────────────── # ─── JA3 bounty extraction from ingester ─────────────────────────────────────

View File

@@ -85,6 +85,8 @@ class TestLogIngestionWorker:
from decnet.web.ingester import log_ingestion_worker from decnet.web.ingester import log_ingestion_worker
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock()
log_file = str(tmp_path / "nonexistent.log") log_file = str(tmp_path / "nonexistent.log")
_call_count: int = 0 _call_count: int = 0
@@ -106,6 +108,8 @@ class TestLogIngestionWorker:
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.add_bounty = AsyncMock() mock_repo.add_bounty = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock()
log_file = str(tmp_path / "test.log") log_file = str(tmp_path / "test.log")
json_file = tmp_path / "test.json" json_file = tmp_path / "test.json"
@@ -135,6 +139,8 @@ class TestLogIngestionWorker:
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.add_bounty = AsyncMock() mock_repo.add_bounty = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock()
log_file = str(tmp_path / "test.log") log_file = str(tmp_path / "test.log")
json_file = tmp_path / "test.json" json_file = tmp_path / "test.json"
@@ -161,6 +167,8 @@ class TestLogIngestionWorker:
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.add_bounty = AsyncMock() mock_repo.add_bounty = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock()
log_file = str(tmp_path / "test.log") log_file = str(tmp_path / "test.log")
json_file = tmp_path / "test.json" json_file = tmp_path / "test.json"
@@ -195,6 +203,8 @@ class TestLogIngestionWorker:
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.add_bounty = AsyncMock() mock_repo.add_bounty = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock()
log_file = str(tmp_path / "test.log") log_file = str(tmp_path / "test.log")
json_file = tmp_path / "test.json" json_file = tmp_path / "test.json"
@@ -215,3 +225,117 @@ class TestLogIngestionWorker:
await log_ingestion_worker(mock_repo) await log_ingestion_worker(mock_repo)
mock_repo.add_log.assert_not_awaited() mock_repo.add_log.assert_not_awaited()
@pytest.mark.asyncio
async def test_position_restored_skips_already_seen_lines(self, tmp_path):
    """With a stored byte offset, only content written past that offset is ingested."""
    from decnet.web.ingester import log_ingestion_worker

    mock_repo = MagicMock()
    mock_repo.add_log = AsyncMock()
    mock_repo.add_bounty = AsyncMock()
    mock_repo.set_state = AsyncMock()

    log_file = str(tmp_path / "test.log")
    json_file = tmp_path / "test.json"
    line_old = json.dumps({"decky": "d1", "service": "ssh", "event_type": "auth",
                           "attacker_ip": "1.1.1.1", "fields": {}, "raw_line": "x", "msg": ""}) + "\n"
    line_new = json.dumps({"decky": "d2", "service": "ftp", "event_type": "auth",
                           "attacker_ip": "2.2.2.2", "fields": {}, "raw_line": "y", "msg": ""}) + "\n"
    json_file.write_text(line_old + line_new)

    # Saved position points to end of first line — only line_new should be ingested
    saved_position = len(line_old.encode("utf-8"))
    mock_repo.get_state = AsyncMock(return_value={"position": saved_position})

    loop_ticks = 0

    async def cancel_after_first_pass(secs):
        nonlocal loop_ticks
        loop_ticks += 1
        if loop_ticks >= 2:
            raise asyncio.CancelledError()

    with patch.dict(os.environ, {"DECNET_INGEST_LOG_FILE": log_file}), \
         patch("decnet.web.ingester.asyncio.sleep", side_effect=cancel_after_first_pass), \
         pytest.raises(asyncio.CancelledError):
        await log_ingestion_worker(mock_repo)

    # Exactly the second (unseen) record was handed to the repository.
    assert mock_repo.add_log.await_count == 1
    ingested = mock_repo.add_log.call_args[0][0]
    assert ingested["attacker_ip"] == "2.2.2.2"
@pytest.mark.asyncio
async def test_set_state_called_with_position_after_batch(self, tmp_path):
    """After processing lines, the worker persists the new byte offset via set_state."""
    from decnet.web.ingester import log_ingestion_worker, _INGEST_STATE_KEY

    mock_repo = MagicMock()
    mock_repo.add_log = AsyncMock()
    mock_repo.add_bounty = AsyncMock()
    mock_repo.get_state = AsyncMock(return_value=None)
    mock_repo.set_state = AsyncMock()

    log_file = str(tmp_path / "test.log")
    json_file = tmp_path / "test.json"
    line = json.dumps({"decky": "d1", "service": "ssh", "event_type": "auth",
                       "attacker_ip": "1.1.1.1", "fields": {}, "raw_line": "x", "msg": ""}) + "\n"
    json_file.write_text(line)

    loop_ticks = 0

    async def cancel_after_first_pass(secs):
        nonlocal loop_ticks
        loop_ticks += 1
        if loop_ticks >= 2:
            raise asyncio.CancelledError()

    with patch.dict(os.environ, {"DECNET_INGEST_LOG_FILE": log_file}), \
         patch("decnet.web.ingester.asyncio.sleep", side_effect=cancel_after_first_pass), \
         pytest.raises(asyncio.CancelledError):
        await log_ingestion_worker(mock_repo)

    # Filter set_state calls down to the ingest-position key and check the last one.
    persisted = [c for c in mock_repo.set_state.call_args_list
                 if c[0][0] == _INGEST_STATE_KEY]
    assert persisted, "set_state never called with ingest position key"
    saved_pos = persisted[-1][0][1]["position"]
    assert saved_pos == len(line.encode("utf-8"))
@pytest.mark.asyncio
async def test_truncation_resets_and_saves_zero_position(self, tmp_path):
    """A file shorter than the stored offset is treated as truncated: position=0 is saved."""
    from decnet.web.ingester import log_ingestion_worker, _INGEST_STATE_KEY

    mock_repo = MagicMock()
    mock_repo.add_log = AsyncMock()
    mock_repo.add_bounty = AsyncMock()
    mock_repo.set_state = AsyncMock()

    log_file = str(tmp_path / "test.log")
    json_file = tmp_path / "test.json"
    line = json.dumps({"decky": "d1", "service": "ssh", "event_type": "auth",
                       "attacker_ip": "1.1.1.1", "fields": {}, "raw_line": "x", "msg": ""}) + "\n"

    # Pretend the saved position is past the end (simulates prior larger file)
    big_position = len(line.encode("utf-8")) * 10
    mock_repo.get_state = AsyncMock(return_value={"position": big_position})
    json_file.write_text(line)  # file is smaller than saved position → truncation

    loop_ticks = 0

    async def cancel_after_first_pass(secs):
        nonlocal loop_ticks
        loop_ticks += 1
        if loop_ticks >= 2:
            raise asyncio.CancelledError()

    with patch.dict(os.environ, {"DECNET_INGEST_LOG_FILE": log_file}), \
         patch("decnet.web.ingester.asyncio.sleep", side_effect=cancel_after_first_pass), \
         pytest.raises(asyncio.CancelledError):
        await log_ingestion_worker(mock_repo)

    zero_resets = [
        c for c in mock_repo.set_state.call_args_list
        if c[0][0] == _INGEST_STATE_KEY and c[0][1] == {"position": 0}
    ]
    assert zero_resets, "set_state not called with position=0 after truncation"

View File

@@ -93,6 +93,8 @@ class TestIngesterIsolation:
from decnet.web.ingester import log_ingestion_worker from decnet.web.ingester import log_ingestion_worker
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock()
iterations = 0 iterations = 0
async def _controlled_sleep(seconds): async def _controlled_sleep(seconds):
@@ -133,6 +135,8 @@ class TestIngesterIsolation:
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock()
iterations = 0 iterations = 0
async def _controlled_sleep(seconds): async def _controlled_sleep(seconds):
@@ -168,6 +172,8 @@ class TestIngesterIsolation:
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock(side_effect=Exception("no such table: logs")) mock_repo.add_log = AsyncMock(side_effect=Exception("no such table: logs"))
mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock()
with patch.dict(os.environ, {"DECNET_INGEST_LOG_FILE": str(tmp_path / "test.log")}): with patch.dict(os.environ, {"DECNET_INGEST_LOG_FILE": str(tmp_path / "test.log")}):
# Worker should exit the loop on fatal DB error # Worker should exit the loop on fatal DB error
@@ -189,6 +195,7 @@ class TestAttackerWorkerIsolation:
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.get_all_logs_raw = AsyncMock(side_effect=Exception("DB is locked")) mock_repo.get_all_logs_raw = AsyncMock(side_effect=Exception("DB is locked"))
mock_repo.get_max_log_id = AsyncMock(return_value=0) mock_repo.get_max_log_id = AsyncMock(return_value=0)
mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock() mock_repo.set_state = AsyncMock()
iterations = 0 iterations = 0
@@ -412,6 +419,8 @@ class TestCascadeIsolation:
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock()
iterations = 0 iterations = 0
async def _controlled_sleep(seconds): async def _controlled_sleep(seconds):
@@ -437,6 +446,7 @@ class TestCascadeIsolation:
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.get_all_logs_raw = AsyncMock(return_value=[]) mock_repo.get_all_logs_raw = AsyncMock(return_value=[])
mock_repo.get_max_log_id = AsyncMock(return_value=0) mock_repo.get_max_log_id = AsyncMock(return_value=0)
mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock() mock_repo.set_state = AsyncMock()
mock_repo.get_logs_after_id = AsyncMock(return_value=[]) mock_repo.get_logs_after_id = AsyncMock(return_value=[])