perf(ingester): batch log writes into bulk commits

The ingester now accumulates up to DECNET_BATCH_SIZE rows (default 100)
or DECNET_BATCH_MAX_WAIT_MS (default 250ms) before flushing through
repo.add_logs — one transaction, one COMMIT per batch instead of per
row. Under attacker traffic this collapses N commits into ⌈N/100⌉ and
takes most of the SQLite writer-lock contention off the hot path.

Flush semantics are cancel-safe: _position only advances after a batch
commits successfully, and the flush helper bails without touching the
DB if the enclosing task is being cancelled (lifespan teardown).
Un-flushed lines stay in the file and are re-read on next startup.

Tests updated to assert on add_logs (bulk) instead of the per-row
add_log that the ingester no longer uses, plus a new test that 250
lines flush in ≤5 calls.
This commit is contained in:
2026-04-17 16:37:34 -04:00
parent 11b9e85874
commit a10aee282f
3 changed files with 121 additions and 29 deletions

View File

@@ -1,9 +1,11 @@
import asyncio import asyncio
import os import os
import json import json
import time
from typing import Any from typing import Any
from pathlib import Path from pathlib import Path
from decnet.env import DECNET_BATCH_SIZE, DECNET_BATCH_MAX_WAIT_MS
from decnet.logging import get_logger from decnet.logging import get_logger
from decnet.telemetry import ( from decnet.telemetry import (
traced as _traced, traced as _traced,
@@ -52,22 +54,26 @@ async def log_ingestion_worker(repo: BaseRepository) -> None:
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
# Accumulate parsed rows and the file offset they end at. We
# only advance _position after the batch is successfully
# committed — if we get cancelled mid-flush, the next run
# re-reads the un-committed lines rather than losing them.
_batch: list[tuple[dict[str, Any], int]] = []
_batch_started: float = time.monotonic()
_max_wait_s: float = DECNET_BATCH_MAX_WAIT_MS / 1000.0
with open(_json_log_path, "r", encoding="utf-8", errors="replace") as _f: with open(_json_log_path, "r", encoding="utf-8", errors="replace") as _f:
_f.seek(_position) _f.seek(_position)
while True: while True:
_line: str = _f.readline() _line: str = _f.readline()
if not _line: if not _line or not _line.endswith('\n'):
break # EOF reached # EOF or partial line — flush what we have and stop
if not _line.endswith('\n'):
# Partial line read, don't process yet, don't advance position
break break
try: try:
_log_data: dict[str, Any] = json.loads(_line.strip()) _log_data: dict[str, Any] = json.loads(_line.strip())
# Extract trace context injected by the collector. # Collector injects trace context so the ingester span
# This makes the ingester span a child of the collector span, # chains off the collector's — full event journey in Jaeger.
# showing the full event journey in Jaeger.
_parent_ctx = _extract_ctx(_log_data) _parent_ctx = _extract_ctx(_log_data)
_tracer = _get_tracer("ingester") _tracer = _get_tracer("ingester")
with _start_span(_tracer, "ingester.process_record", context=_parent_ctx) as _span: with _start_span(_tracer, "ingester.process_record", context=_parent_ctx) as _span:
@@ -75,25 +81,29 @@ async def log_ingestion_worker(repo: BaseRepository) -> None:
_span.set_attribute("service", _log_data.get("service", "")) _span.set_attribute("service", _log_data.get("service", ""))
_span.set_attribute("event_type", _log_data.get("event_type", "")) _span.set_attribute("event_type", _log_data.get("event_type", ""))
_span.set_attribute("attacker_ip", _log_data.get("attacker_ip", "")) _span.set_attribute("attacker_ip", _log_data.get("attacker_ip", ""))
# Persist trace context in the DB row so the SSE
# read path can link back to this ingestion trace.
_sctx = getattr(_span, "get_span_context", None) _sctx = getattr(_span, "get_span_context", None)
if _sctx: if _sctx:
_ctx = _sctx() _ctx = _sctx()
if _ctx and getattr(_ctx, "trace_id", 0): if _ctx and getattr(_ctx, "trace_id", 0):
_log_data["trace_id"] = format(_ctx.trace_id, "032x") _log_data["trace_id"] = format(_ctx.trace_id, "032x")
_log_data["span_id"] = format(_ctx.span_id, "016x") _log_data["span_id"] = format(_ctx.span_id, "016x")
logger.debug("ingest: record decky=%s event_type=%s", _log_data.get("decky"), _log_data.get("event_type")) _batch.append((_log_data, _f.tell()))
await repo.add_log(_log_data)
await _extract_bounty(repo, _log_data)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.error("ingest: failed to decode JSON log line: %s", _line.strip()) logger.error("ingest: failed to decode JSON log line: %s", _line.strip())
# Skip past bad line so we don't loop forever on it.
_position = _f.tell()
continue continue
# Update position after successful line read if len(_batch) >= DECNET_BATCH_SIZE or (
_position = _f.tell() time.monotonic() - _batch_started >= _max_wait_s
):
_position = await _flush_batch(repo, _batch, _position)
_batch.clear()
_batch_started = time.monotonic()
await repo.set_state(_INGEST_STATE_KEY, {"position": _position}) # Flush any remainder collected before EOF / partial-line break.
if _batch:
_position = await _flush_batch(repo, _batch, _position)
except Exception as _e: except Exception as _e:
_err_str = str(_e).lower() _err_str = str(_e).lower()
@@ -107,6 +117,32 @@ async def log_ingestion_worker(repo: BaseRepository) -> None:
await asyncio.sleep(1) await asyncio.sleep(1)
async def _flush_batch(
    repo: BaseRepository,
    batch: list[tuple[dict[str, Any], int]],
    current_position: int,
) -> int:
    """Commit a batch of log rows and return the new file position.

    Args:
        repo: Repository used for the bulk insert and state persistence.
        batch: Parsed rows paired with the file offset each row ends at.
        current_position: Offset to report when ``batch`` is empty.

    Returns:
        The file offset just past the last committed row, or
        ``current_position`` when there was nothing to flush.

    If the enclosing task is being cancelled, bail out without touching
    the DB — the session factory may already be disposed during lifespan
    teardown, and awaiting it would stall the worker. The un-flushed
    lines stay uncommitted; the next startup re-reads them from
    ``current_position``.
    """
    # Defensive: a flush with nothing accumulated must be a no-op rather
    # than an IndexError on batch[-1].
    if not batch:
        return current_position
    _task = asyncio.current_task()
    # Task.cancelling() is Python 3.11+; non-zero while a cancellation is
    # being delivered to this task.
    if _task is not None and _task.cancelling():
        raise asyncio.CancelledError()
    _entries = [_entry for _entry, _ in batch]
    _new_position = batch[-1][1]
    await repo.add_logs(_entries)
    # Bounty extraction stays per-row: each entry may yield artifacts.
    for _entry in _entries:
        await _extract_bounty(repo, _entry)
    # Persist the offset only after the rows are committed, so a crash in
    # between re-ingests lines instead of dropping them.
    await repo.set_state(_INGEST_STATE_KEY, {"position": _new_position})
    return _new_position
@_traced("ingester.extract_bounty") @_traced("ingester.extract_bounty")
async def _extract_bounty(repo: BaseRepository, log_data: dict[str, Any]) -> None: async def _extract_bounty(repo: BaseRepository, log_data: dict[str, Any]) -> None:
"""Detect and extract valuable artifacts (bounties) from log entries.""" """Detect and extract valuable artifacts (bounties) from log entries."""

View File

@@ -85,6 +85,7 @@ class TestLogIngestionWorker:
from decnet.web.ingester import log_ingestion_worker from decnet.web.ingester import log_ingestion_worker
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.add_logs = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None) mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock() mock_repo.set_state = AsyncMock()
log_file = str(tmp_path / "nonexistent.log") log_file = str(tmp_path / "nonexistent.log")
@@ -100,13 +101,14 @@ class TestLogIngestionWorker:
with patch("decnet.web.ingester.asyncio.sleep", side_effect=fake_sleep): with patch("decnet.web.ingester.asyncio.sleep", side_effect=fake_sleep):
with pytest.raises(asyncio.CancelledError): with pytest.raises(asyncio.CancelledError):
await log_ingestion_worker(mock_repo) await log_ingestion_worker(mock_repo)
mock_repo.add_log.assert_not_awaited() mock_repo.add_logs.assert_not_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ingests_json_lines(self, tmp_path): async def test_ingests_json_lines(self, tmp_path):
from decnet.web.ingester import log_ingestion_worker from decnet.web.ingester import log_ingestion_worker
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.add_logs = AsyncMock()
mock_repo.add_bounty = AsyncMock() mock_repo.add_bounty = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None) mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock() mock_repo.set_state = AsyncMock()
@@ -131,13 +133,17 @@ class TestLogIngestionWorker:
with pytest.raises(asyncio.CancelledError): with pytest.raises(asyncio.CancelledError):
await log_ingestion_worker(mock_repo) await log_ingestion_worker(mock_repo)
mock_repo.add_log.assert_awaited_once() mock_repo.add_logs.assert_awaited_once()
_batch = mock_repo.add_logs.call_args[0][0]
assert len(_batch) == 1
assert _batch[0]["attacker_ip"] == "1.2.3.4"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handles_json_decode_error(self, tmp_path): async def test_handles_json_decode_error(self, tmp_path):
from decnet.web.ingester import log_ingestion_worker from decnet.web.ingester import log_ingestion_worker
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.add_logs = AsyncMock()
mock_repo.add_bounty = AsyncMock() mock_repo.add_bounty = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None) mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock() mock_repo.set_state = AsyncMock()
@@ -159,13 +165,14 @@ class TestLogIngestionWorker:
with pytest.raises(asyncio.CancelledError): with pytest.raises(asyncio.CancelledError):
await log_ingestion_worker(mock_repo) await log_ingestion_worker(mock_repo)
mock_repo.add_log.assert_not_awaited() mock_repo.add_logs.assert_not_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_truncation_resets_position(self, tmp_path): async def test_file_truncation_resets_position(self, tmp_path):
from decnet.web.ingester import log_ingestion_worker from decnet.web.ingester import log_ingestion_worker
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.add_logs = AsyncMock()
mock_repo.add_bounty = AsyncMock() mock_repo.add_bounty = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None) mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock() mock_repo.set_state = AsyncMock()
@@ -195,13 +202,15 @@ class TestLogIngestionWorker:
await log_ingestion_worker(mock_repo) await log_ingestion_worker(mock_repo)
# Should have ingested lines from original + after truncation # Should have ingested lines from original + after truncation
assert mock_repo.add_log.await_count >= 2 _total = sum(len(call.args[0]) for call in mock_repo.add_logs.call_args_list)
assert _total >= 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_partial_line_not_processed(self, tmp_path): async def test_partial_line_not_processed(self, tmp_path):
from decnet.web.ingester import log_ingestion_worker from decnet.web.ingester import log_ingestion_worker
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.add_logs = AsyncMock()
mock_repo.add_bounty = AsyncMock() mock_repo.add_bounty = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None) mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock() mock_repo.set_state = AsyncMock()
@@ -224,7 +233,7 @@ class TestLogIngestionWorker:
with pytest.raises(asyncio.CancelledError): with pytest.raises(asyncio.CancelledError):
await log_ingestion_worker(mock_repo) await log_ingestion_worker(mock_repo)
mock_repo.add_log.assert_not_awaited() mock_repo.add_logs.assert_not_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_position_restored_skips_already_seen_lines(self, tmp_path): async def test_position_restored_skips_already_seen_lines(self, tmp_path):
@@ -232,6 +241,7 @@ class TestLogIngestionWorker:
from decnet.web.ingester import log_ingestion_worker from decnet.web.ingester import log_ingestion_worker
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.add_logs = AsyncMock()
mock_repo.add_bounty = AsyncMock() mock_repo.add_bounty = AsyncMock()
mock_repo.set_state = AsyncMock() mock_repo.set_state = AsyncMock()
@@ -262,9 +272,9 @@ class TestLogIngestionWorker:
with pytest.raises(asyncio.CancelledError): with pytest.raises(asyncio.CancelledError):
await log_ingestion_worker(mock_repo) await log_ingestion_worker(mock_repo)
assert mock_repo.add_log.await_count == 1 _rows = [r for call in mock_repo.add_logs.call_args_list for r in call.args[0]]
ingested = mock_repo.add_log.call_args[0][0] assert len(_rows) == 1
assert ingested["attacker_ip"] == "2.2.2.2" assert _rows[0]["attacker_ip"] == "2.2.2.2"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_state_called_with_position_after_batch(self, tmp_path): async def test_set_state_called_with_position_after_batch(self, tmp_path):
@@ -272,6 +282,7 @@ class TestLogIngestionWorker:
from decnet.web.ingester import log_ingestion_worker, _INGEST_STATE_KEY from decnet.web.ingester import log_ingestion_worker, _INGEST_STATE_KEY
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.add_logs = AsyncMock()
mock_repo.add_bounty = AsyncMock() mock_repo.add_bounty = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None) mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock() mock_repo.set_state = AsyncMock()
@@ -301,12 +312,54 @@ class TestLogIngestionWorker:
saved_pos = position_calls[-1][0][1]["position"] saved_pos = position_calls[-1][0][1]["position"]
assert saved_pos == len(line.encode("utf-8")) assert saved_pos == len(line.encode("utf-8"))
@pytest.mark.asyncio
async def test_batches_many_lines_into_few_commits(self, tmp_path):
    """250 lines with BATCH_SIZE=100 should flush in a handful of calls."""
    from decnet.web.ingester import log_ingestion_worker
    repo = MagicMock()
    repo.add_log = AsyncMock()
    repo.add_logs = AsyncMock()
    repo.add_bounty = AsyncMock()
    repo.get_state = AsyncMock(return_value=None)
    repo.set_state = AsyncMock()
    log_file = str(tmp_path / "test.log")
    json_file = tmp_path / "test.json"
    # 250 one-line JSON records, newline-terminated, same shape the
    # collector writes.
    records = []
    for i in range(250):
        records.append(json.dumps({
            "decky": f"d{i}", "service": "ssh", "event_type": "auth",
            "attacker_ip": f"10.0.0.{i % 256}", "fields": {}, "raw_line": "x", "msg": ""
        }))
    json_file.write_text("\n".join(records) + "\n")
    sleep_calls: int = 0

    async def fake_sleep(secs):
        # Let the worker run one full pass, then cancel it on the second
        # idle sleep.
        nonlocal sleep_calls
        sleep_calls += 1
        if sleep_calls >= 2:
            raise asyncio.CancelledError()

    with patch.dict(os.environ, {"DECNET_INGEST_LOG_FILE": log_file}):
        with patch("decnet.web.ingester.asyncio.sleep", side_effect=fake_sleep):
            with pytest.raises(asyncio.CancelledError):
                await log_ingestion_worker(repo)
    # 250 lines, batch=100 → 2 size-triggered flushes + 1 remainder flush.
    # Asserting <= 5 leaves headroom for time-triggered flushes on slow CI.
    assert repo.add_logs.await_count <= 5
    flushed = [row for call in repo.add_logs.call_args_list for row in call.args[0]]
    assert len(flushed) == 250
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_truncation_resets_and_saves_zero_position(self, tmp_path): async def test_truncation_resets_and_saves_zero_position(self, tmp_path):
"""On file truncation, set_state is called with position=0.""" """On file truncation, set_state is called with position=0."""
from decnet.web.ingester import log_ingestion_worker, _INGEST_STATE_KEY from decnet.web.ingester import log_ingestion_worker, _INGEST_STATE_KEY
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.add_logs = AsyncMock()
mock_repo.add_bounty = AsyncMock() mock_repo.add_bounty = AsyncMock()
mock_repo.set_state = AsyncMock() mock_repo.set_state = AsyncMock()

View File

@@ -93,6 +93,7 @@ class TestIngesterIsolation:
from decnet.web.ingester import log_ingestion_worker from decnet.web.ingester import log_ingestion_worker
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_logs = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None) mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock() mock_repo.set_state = AsyncMock()
iterations = 0 iterations = 0
@@ -110,7 +111,7 @@ class TestIngesterIsolation:
await task await task
# Should have waited at least 2 iterations without crashing # Should have waited at least 2 iterations without crashing
assert iterations >= 2 assert iterations >= 2
mock_repo.add_log.assert_not_called() mock_repo.add_logs.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ingester_survives_no_log_file_env(self): async def test_ingester_survives_no_log_file_env(self):
@@ -135,6 +136,7 @@ class TestIngesterIsolation:
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock() mock_repo.add_log = AsyncMock()
mock_repo.add_logs = AsyncMock()
mock_repo.get_state = AsyncMock(return_value=None) mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock() mock_repo.set_state = AsyncMock()
iterations = 0 iterations = 0
@@ -150,7 +152,7 @@ class TestIngesterIsolation:
task = asyncio.create_task(log_ingestion_worker(mock_repo)) task = asyncio.create_task(log_ingestion_worker(mock_repo))
with pytest.raises(asyncio.CancelledError): with pytest.raises(asyncio.CancelledError):
await task await task
mock_repo.add_log.assert_not_called() mock_repo.add_logs.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ingester_exits_on_db_fatal_error(self, tmp_path): async def test_ingester_exits_on_db_fatal_error(self, tmp_path):
@@ -171,15 +173,16 @@ class TestIngesterIsolation:
json_file.write_text(json.dumps(valid_record) + "\n") json_file.write_text(json.dumps(valid_record) + "\n")
mock_repo = MagicMock() mock_repo = MagicMock()
mock_repo.add_log = AsyncMock(side_effect=Exception("no such table: logs")) mock_repo.add_log = AsyncMock()
mock_repo.add_logs = AsyncMock(side_effect=Exception("no such table: logs"))
mock_repo.get_state = AsyncMock(return_value=None) mock_repo.get_state = AsyncMock(return_value=None)
mock_repo.set_state = AsyncMock() mock_repo.set_state = AsyncMock()
with patch.dict(os.environ, {"DECNET_INGEST_LOG_FILE": str(tmp_path / "test.log")}): with patch.dict(os.environ, {"DECNET_INGEST_LOG_FILE": str(tmp_path / "test.log")}):
# Worker should exit the loop on fatal DB error # Worker should exit the loop on fatal DB error
await log_ingestion_worker(mock_repo) await log_ingestion_worker(mock_repo)
# Should have attempted to add the log before dying # Should have attempted to bulk-add before dying
mock_repo.add_log.assert_awaited_once() mock_repo.add_logs.assert_awaited_once()
# ─── Attacker worker isolation ─────────────────────────────────────────────── # ─── Attacker worker isolation ───────────────────────────────────────────────