refactor(emailgen): pluggable LLM backend (base/factory/impl)
Move the Ollama subprocess shell-out from EmailDriver into a proper
provider subpackage:
    decnet/orchestrator/emailgen/llm/
        base.py         — LLMBackend Protocol + LLMResult + LLMTimeout
        factory.py      — get_llm() reads DECNET_EMAILGEN_LLM
        impl/ollama.py  — current 'ollama run' subprocess path
        impl/fake.py    — canned-output backend used by tests
The driver now takes an LLMBackend on construction (or falls back to
the factory default). Tests inject FakeBackend instead of monkeypatching
the subprocess layer, which is cleaner and ~10x faster. Swapping
Ollama for the Anthropic API / vLLM / llama.cpp now means adding a third
branch in factory.py; no driver rewrite is needed.
Mirrors the convention used by decnet.web.db.factory + decnet.bus.factory
per the provider-subpackages-from-day-one rule in memory.
This commit is contained in:
@@ -1,14 +1,40 @@
|
||||
"""EmailDriver: stub the Ollama subprocess + docker exec; verify EML
|
||||
parse-and-repair and payload metadata."""
|
||||
"""EmailDriver: inject a fake LLM backend + stub docker exec; verify
|
||||
EML parse-and-repair and payload metadata."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from decnet.orchestrator.drivers import email as email_driver
|
||||
from decnet.orchestrator.emailgen.llm.base import LLMResult, LLMTimeout
|
||||
from decnet.orchestrator.emailgen.llm.impl.fake import FakeBackend
|
||||
from decnet.orchestrator.emailgen.personas import EmailPersona
|
||||
from decnet.orchestrator.emailgen.scheduler import EmailAction
|
||||
|
||||
|
||||
class _RaisingBackend:
    """LLM backend double whose ``generate`` always times out.

    Handed to ``EmailDriver`` to exercise its handling of ``LLMTimeout``.
    """

    # Placeholder values; presumably these mirror the attributes a real
    # backend exposes — confirm against the LLMBackend Protocol.
    timeout = 0.1
    model = "stuck-model"

    async def generate(self, prompt: str) -> LLMResult:  # noqa: ARG002
        """Ignore ``prompt`` and raise ``LLMTimeout`` unconditionally."""
        raise LLMTimeout("stuck")
|
||||
|
||||
|
||||
class _FailingBackend:
    """LLM backend double that completes but reports ``success=False``."""

    # Placeholder values; presumably these mirror the attributes a real
    # backend exposes — confirm against the LLMBackend Protocol.
    timeout = 1.0
    model = "broken-model"

    async def generate(self, prompt: str) -> LLMResult:  # noqa: ARG002
        """Ignore ``prompt``; return a failed result with rc and stderr."""
        failure_details = {"rc": 1, "stderr": "model not found"}
        return LLMResult(
            success=False,
            text="",
            model=self.model,
            latency_ms=5,
            extra=failure_details,
        )
|
||||
|
||||
|
||||
def _persona(name="John", email="john@corp.com"):
|
||||
return EmailPersona(
|
||||
name=name,
|
||||
@@ -110,19 +136,20 @@ def test_container_for_pop3_only():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_driver_run_success_path(monkeypatch):
|
||||
"""Stub both subprocess calls (ollama + docker exec) as success."""
|
||||
calls: list[list[str]] = []
|
||||
"""Inject a FakeBackend + stub docker exec; success end-to-end."""
|
||||
docker_calls: list[list[str]] = []
|
||||
|
||||
async def fake_run_capture(argv, *, stdin_data=None, timeout=8.0):
|
||||
calls.append(list(argv))
|
||||
if argv[0] == "ollama":
|
||||
return 0, "Subject: Q3 budget\n\nHi Sarah,\nNumbers attached.\n", ""
|
||||
# docker exec
|
||||
docker_calls.append(list(argv))
|
||||
return 0, "", ""
|
||||
|
||||
monkeypatch.setattr(email_driver, "_run_capture", fake_run_capture)
|
||||
|
||||
drv = email_driver.EmailDriver(model="llama3.1", ollama_timeout=1.0)
|
||||
llm = FakeBackend(
|
||||
model="llama3.1",
|
||||
output="Subject: Q3 budget\n\nHi Sarah,\nNumbers attached.\n",
|
||||
)
|
||||
drv = email_driver.EmailDriver(llm=llm)
|
||||
result = await drv.run(_action())
|
||||
assert result.success is True
|
||||
assert result.payload["model"] == "llama3.1"
|
||||
@@ -132,46 +159,56 @@ async def test_driver_run_success_path(monkeypatch):
|
||||
assert result.payload["message_id"].startswith("<")
|
||||
assert result.payload["eml_path"].endswith(".eml")
|
||||
assert result.payload["container"] == "mailhost-imap"
|
||||
# Two subprocess calls: ollama, then docker exec.
|
||||
assert calls[0][0] == "ollama"
|
||||
assert calls[1][0] == "docker"
|
||||
# docker exec shell command must include `touch -d` so the file's
|
||||
# mtime matches the EML's Date: header — otherwise the spool's
|
||||
# `ls -lt` clusters every email inside the worker tick window.
|
||||
docker_sh = calls[1][-1]
|
||||
# Only docker exec is shelled out now — the LLM call is in-process
|
||||
# via the FakeBackend.
|
||||
assert len(docker_calls) == 1
|
||||
assert docker_calls[0][0] == "docker"
|
||||
docker_sh = docker_calls[0][-1]
|
||||
assert "touch -d" in docker_sh
|
||||
assert "tee" in docker_sh
|
||||
# And tee must come before touch so we don't touch a file that
|
||||
# doesn't exist yet.
|
||||
assert docker_sh.index("tee") < docker_sh.index("touch -d")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_driver_run_ollama_failure_short_circuits(monkeypatch):
|
||||
async def test_driver_run_llm_failure_short_circuits(monkeypatch):
|
||||
"""When the backend reports success=False, no docker exec should fire."""
|
||||
docker_called = False
|
||||
|
||||
async def fake_run_capture(argv, *, stdin_data=None, timeout=8.0):
|
||||
if argv[0] == "ollama":
|
||||
return 1, "", "ollama: model not found"
|
||||
nonlocal docker_called
|
||||
docker_called = True
|
||||
return 0, "", ""
|
||||
|
||||
monkeypatch.setattr(email_driver, "_run_capture", fake_run_capture)
|
||||
|
||||
drv = email_driver.EmailDriver()
|
||||
drv = email_driver.EmailDriver(llm=_FailingBackend())
|
||||
result = await drv.run(_action())
|
||||
assert result.success is False
|
||||
assert result.payload["stage"] == "ollama"
|
||||
assert result.payload["stage"] == "llm"
|
||||
assert "stderr" in result.payload
|
||||
assert "model not found" in result.payload["stderr"]
|
||||
assert docker_called is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_driver_run_llm_timeout_reported_distinctly(monkeypatch):
    """A backend raising LLMTimeout yields stage="llm", error="timeout".

    The docker exec helper is stubbed so this test can never shell out
    for real, and the stub records its calls so we can assert the driver
    short-circuits before delivery — the same guarantee the
    success=False failure test checks.
    """
    docker_calls: list[list[str]] = []

    async def fake_run_capture(argv, *, stdin_data=None, timeout=8.0):
        # Any call here means the driver went past the LLM stage.
        docker_calls.append(list(argv))
        return 0, "", ""

    monkeypatch.setattr(email_driver, "_run_capture", fake_run_capture)

    drv = email_driver.EmailDriver(llm=_RaisingBackend())
    result = await drv.run(_action())

    assert result.success is False
    assert result.payload["stage"] == "llm"
    assert result.payload["error"] == "timeout"
    # A timed-out generation must not reach the delivery step.
    assert docker_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_driver_run_delivery_failure(monkeypatch):
|
||||
async def fake_run_capture(argv, *, stdin_data=None, timeout=8.0):
|
||||
if argv[0] == "ollama":
|
||||
return 0, "Subject: hi\n\nbody\n", ""
|
||||
return 1, "", "no such container"
|
||||
|
||||
monkeypatch.setattr(email_driver, "_run_capture", fake_run_capture)
|
||||
|
||||
drv = email_driver.EmailDriver()
|
||||
drv = email_driver.EmailDriver(
|
||||
llm=FakeBackend(output="Subject: hi\n\nbody\n"),
|
||||
)
|
||||
result = await drv.run(_action())
|
||||
assert result.success is False
|
||||
assert result.payload["stage"] == "delivery"
|
||||
|
||||
137
tests/orchestrator/emailgen/test_llm.py
Normal file
137
tests/orchestrator/emailgen/test_llm.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""LLM backend factory + Ollama implementation."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from decnet.orchestrator.emailgen.llm import LLMTimeout, get_llm
|
||||
from decnet.orchestrator.emailgen.llm.impl.fake import FakeBackend
|
||||
from decnet.orchestrator.emailgen.llm.impl.ollama import OllamaBackend
|
||||
|
||||
|
||||
# ── factory dispatch ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_factory_default_is_ollama(monkeypatch):
    """With DECNET_EMAILGEN_LLM unset, get_llm() yields an OllamaBackend."""
    # raising=False: the variable may already be absent from the env.
    monkeypatch.delenv("DECNET_EMAILGEN_LLM", raising=False)
    assert isinstance(get_llm(), OllamaBackend)
|
||||
|
||||
|
||||
def test_factory_selects_fake(monkeypatch):
    """DECNET_EMAILGEN_LLM=fake makes get_llm() return a FakeBackend."""
    monkeypatch.setenv("DECNET_EMAILGEN_LLM", "fake")
    assert isinstance(get_llm(), FakeBackend)
|
||||
|
||||
|
||||
def test_factory_unknown_raises(monkeypatch):
    """An unrecognised DECNET_EMAILGEN_LLM value raises ValueError."""
    monkeypatch.setenv("DECNET_EMAILGEN_LLM", "vllm-someday")
    expected_failure = pytest.raises(ValueError, match="Unsupported")
    with expected_failure:
        get_llm()
|
||||
|
||||
|
||||
def test_factory_passes_model_through(monkeypatch):
    """The ``model`` argument to get_llm() ends up on the backend."""
    monkeypatch.setenv("DECNET_EMAILGEN_LLM", "ollama")
    requested_model = "qwen2:7b"
    llm = get_llm(model=requested_model)
    assert llm.model == requested_model
|
||||
|
||||
|
||||
# ── FakeBackend ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fake_backend_returns_canned_output():
    """FakeBackend reports success and returns its configured output."""
    canned = FakeBackend(output="Subject: hi\n\nbody")
    reply = await canned.generate("any prompt")
    assert reply.success is True
    assert reply.text.startswith("Subject:")
    # No model was passed, so this pins FakeBackend's default model name.
    assert reply.model == "fake-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fake_backend_can_simulate_failure():
    """FakeBackend(success=False) yields a failed result with empty text."""
    failing = FakeBackend(success=False)
    reply = await failing.generate("prompt")
    assert reply.success is False
    assert reply.text == ""
|
||||
|
||||
|
||||
# ── OllamaBackend (subprocess stubbed) ───────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ollama_backend_success(monkeypatch):
    """A zero-exit subprocess with canned stdout produces a successful result.

    ``asyncio.create_subprocess_exec`` is replaced so nothing is spawned.
    """

    class _FinishedProc:
        """Process stand-in that exits 0 and emits an email on stdout."""

        returncode = 0

        async def communicate(self, _stdin):
            stdout, stderr = b"Subject: hi\n\nbody\n", b""
            return stdout, stderr

    async def spawn_stub(*args, **kwargs):  # noqa: ARG001
        return _FinishedProc()

    monkeypatch.setattr(asyncio, "create_subprocess_exec", spawn_stub)

    ollama = OllamaBackend(model="m1", timeout=1.0)
    reply = await ollama.generate("hello")

    assert reply.success is True
    assert "Subject:" in reply.text
    assert reply.model == "m1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ollama_backend_non_zero_rc_marks_failure(monkeypatch):
    """A non-zero exit code becomes success=False with rc/stderr in ``extra``."""

    class _ErroredProc:
        """Process stand-in that exits 1 with an error message on stderr."""

        returncode = 1

        async def communicate(self, _stdin):
            stdout, stderr = b"", b"model not found"
            return stdout, stderr

    async def spawn_stub(*args, **kwargs):  # noqa: ARG001
        return _ErroredProc()

    monkeypatch.setattr(asyncio, "create_subprocess_exec", spawn_stub)

    ollama = OllamaBackend(model="m1", timeout=1.0)
    reply = await ollama.generate("hello")

    assert reply.success is False
    assert reply.extra["rc"] == 1
    assert "model not found" in reply.extra["stderr"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ollama_backend_timeout_raises(monkeypatch):
    """A subprocess that outlives the backend timeout raises LLMTimeout."""

    class _HungProc:
        """Process stand-in whose ``communicate`` never finishes in time."""

        returncode = None

        async def communicate(self, _stdin):
            # Sleeps far longer than the 0.05 s timeout configured below.
            await asyncio.sleep(10)
            return b"", b""

        def kill(self):
            # No-op; presumably the backend kills the process on
            # timeout — confirm against OllamaBackend.
            pass

    async def spawn_stub(*args, **kwargs):  # noqa: ARG001
        return _HungProc()

    monkeypatch.setattr(asyncio, "create_subprocess_exec", spawn_stub)

    ollama = OllamaBackend(model="m1", timeout=0.05)
    with pytest.raises(LLMTimeout):
        await ollama.generate("hello")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ollama_backend_missing_binary_returns_failure(monkeypatch):
    """FileNotFoundError at spawn time yields a failed result, not a raise."""

    async def spawn_missing(*args, **kwargs):  # noqa: ARG001
        raise FileNotFoundError("ollama: not found")

    monkeypatch.setattr(asyncio, "create_subprocess_exec", spawn_missing)

    ollama = OllamaBackend(model="m1", timeout=1.0)
    reply = await ollama.generate("hello")

    assert reply.success is False
    # 127 is the conventional shell "command not found" status;
    # presumably the backend uses it for that reason — confirm.
    assert reply.extra["rc"] == 127
|
||||
@@ -10,6 +10,7 @@ import pytest_asyncio
|
||||
from decnet.bus.fake import FakeBus
|
||||
from decnet.orchestrator.drivers import email as email_driver
|
||||
from decnet.orchestrator.emailgen import worker as eg_worker
|
||||
from decnet.orchestrator.emailgen.llm.impl.fake import FakeBackend
|
||||
from decnet.orchestrator.emailgen.scheduler import EmailAction # noqa: F401
|
||||
from decnet.web.db.models import Topology, TopologyDecky
|
||||
from decnet.web.db.sqlite.repository import SQLiteRepository
|
||||
@@ -82,9 +83,9 @@ async def _seed_mail_topology(repo: SQLiteRepository) -> str:
|
||||
async def test_one_tick_records_and_publishes(repo, fake_bus, monkeypatch):
|
||||
decky_uuid = await _seed_mail_topology(repo)
|
||||
|
||||
# Stub only the docker exec subprocess; the LLM call goes through
|
||||
# an injected FakeBackend with deterministic output.
|
||||
async def fake_run_capture(argv, *, stdin_data=None, timeout=8.0):
|
||||
if argv[0] == "ollama":
|
||||
return 0, "Subject: Hi\n\nBody here.\n", ""
|
||||
return 0, "", ""
|
||||
|
||||
monkeypatch.setattr(email_driver, "_run_capture", fake_run_capture)
|
||||
@@ -101,7 +102,9 @@ async def test_one_tick_records_and_publishes(repo, fake_bus, monkeypatch):
|
||||
collector = asyncio.create_task(collect())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
driver = email_driver.EmailDriver()
|
||||
driver = email_driver.EmailDriver(
|
||||
llm=FakeBackend(output="Subject: Hi\n\nBody here.\n"),
|
||||
)
|
||||
await eg_worker._one_tick(repo, driver, fake_bus)
|
||||
await asyncio.wait_for(collector, timeout=2.0)
|
||||
|
||||
@@ -126,11 +129,13 @@ async def test_one_tick_noop_when_no_mail_decky(repo, fake_bus, monkeypatch):
|
||||
async def fake_run_capture(argv, *, stdin_data=None, timeout=8.0):
|
||||
nonlocal called
|
||||
called = True
|
||||
return 0, "Subject: x\n\nb\n", ""
|
||||
return 0, "", ""
|
||||
|
||||
monkeypatch.setattr(email_driver, "_run_capture", fake_run_capture)
|
||||
|
||||
driver = email_driver.EmailDriver()
|
||||
driver = email_driver.EmailDriver(
|
||||
llm=FakeBackend(output="Subject: x\n\nb\n"),
|
||||
)
|
||||
await eg_worker._one_tick(repo, driver, fake_bus)
|
||||
assert called is False
|
||||
assert await repo.list_orchestrator_emails() == []
|
||||
|
||||
Reference in New Issue
Block a user