diff --git a/surfsense_backend/app/tasks/chat/streaming/runtime.py b/surfsense_backend/app/tasks/chat/streaming/runtime.py
new file mode 100644
index 000000000..b45da2789
--- /dev/null
+++ b/surfsense_backend/app/tasks/chat/streaming/runtime.py
@@ -0,0 +1,96 @@
+"""Runtime setup helpers for orchestrated chat streaming."""
+
+from __future__ import annotations
+
+import contextlib
+import logging
+from collections.abc import Callable
+from typing import Any
+
+_PREFLIGHT_TIMEOUT_SEC: float = 2.5
+_PREFLIGHT_MAX_TOKENS: int = 1
+
+
+async def preflight_llm(
+    llm: Any,
+    *,
+    is_provider_rate_limited: Callable[[BaseException], bool],
+) -> None:
+    """Issue a minimal completion probe to catch immediate provider 429s."""
+    # Imported lazily so tests can inject a stub module via sys.modules.
+    from litellm import acompletion
+
+    model = getattr(llm, "model", None)
+    if not model or model == "auto":
+        return
+
+    try:
+        await acompletion(
+            model=model,
+            messages=[{"role": "user", "content": "ping"}],
+            api_key=getattr(llm, "api_key", None),
+            api_base=getattr(llm, "api_base", None),
+            max_tokens=_PREFLIGHT_MAX_TOKENS,
+            timeout=_PREFLIGHT_TIMEOUT_SEC,
+            stream=False,
+            metadata={"tags": ["surfsense:internal", "auto-pin-preflight"]},
+        )
+    except Exception as exc:
+        if is_provider_rate_limited(exc):
+            raise
+        # Anything else is non-fatal here; the real completion call will surface it.
+        logging.getLogger(__name__).debug(
+            "auto_pin_preflight non_rate_limit_error model=%s err=%s",
+            model,
+            exc,
+        )
+
+
+async def build_main_agent_for_thread(
+    agent_factory: Any,
+    *,
+    llm: Any,
+    search_space_id: int,
+    db_session: Any,
+    connector_service: Any,
+    checkpointer: Any,
+    user_id: str | None,
+    thread_id: int | None,
+    agent_config: Any,
+    firecrawl_api_key: str | None,
+    thread_visibility: Any,
+    filesystem_selection: Any,
+    disabled_tools: list[str] | None = None,
+    mentioned_document_ids: list[int] | None = None,
+) -> Any:
+    """Run one canonical agent-build call for a single thread."""
+    return await agent_factory(
+        llm=llm,
+        search_space_id=search_space_id,
+        db_session=db_session,
+        connector_service=connector_service,
+        checkpointer=checkpointer,
+        user_id=user_id,
+        thread_id=thread_id,
+        agent_config=agent_config,
+        firecrawl_api_key=firecrawl_api_key,
+        thread_visibility=thread_visibility,
+        filesystem_selection=filesystem_selection,
+        disabled_tools=disabled_tools,
+        mentioned_document_ids=mentioned_document_ids,
+    )
+
+
+async def settle_speculative_agent_build(task: Any) -> None:
+    """Wait for a discarded speculative build and swallow its outcome."""
+    # BaseException (not Exception) so a cancelled task's CancelledError
+    # is swallowed as well.
+    with contextlib.suppress(BaseException):
+        await task
+
+
+__all__ = [
+    "build_main_agent_for_thread",
+    "preflight_llm",
+    "settle_speculative_agent_build",
+]
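Taken together, the three helpers give the orchestrator a fail-fast path: probe the provider, build the agent once per thread, and settle any speculative build that loses the race. A minimal wiring sketch, not part of this diff; `start_turn`, `is_rate_limited`, and the `build_kwargs` mapping (which would carry the remaining keyword-only arguments such as `search_space_id` and `db_session`) are illustrative names:

```python
import asyncio

from app.tasks.chat.streaming import runtime


async def start_turn(llm, agent_factory, is_rate_limited, **build_kwargs):
    # Fail fast: surface an immediate provider 429 before any expensive build.
    await runtime.preflight_llm(llm, is_provider_rate_limited=is_rate_limited)

    # Build speculatively so a superseded turn can be abandoned cheaply.
    build = asyncio.create_task(
        runtime.build_main_agent_for_thread(agent_factory, llm=llm, **build_kwargs)
    )
    try:
        return await build
    except asyncio.CancelledError:
        # The turn was superseded: cancel the discarded build and drain it
        # so its outcome is never left unretrieved.
        build.cancel()
        await runtime.settle_speculative_agent_build(build)
        raise
```

The tests below exercise the same three behaviors: the probe fires with the configured limits, rate-limit errors are re-raised, the `auto` sentinel skips the probe, build arguments are forwarded verbatim, and a settled task's failure is swallowed.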
diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_runtime.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_runtime.py
new file mode 100644
index 000000000..edb05edfa
--- /dev/null
+++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_orchestrator_runtime.py
@@ -0,0 +1,122 @@
+"""Behavior tests for streaming runtime helpers."""
+
+from __future__ import annotations
+
+import asyncio
+import sys
+import types
+from typing import Any
+
+import pytest
+
+from app.tasks.chat.streaming import runtime
+
+pytestmark = pytest.mark.unit
+
+
+async def test_preflight_llm_calls_litellm_when_model_present(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    calls: dict[str, Any] = {}
+
+    async def _fake_acompletion(**kwargs: Any):
+        calls.update(kwargs)
+        return {"ok": True}
+
+    # preflight_llm imports litellm lazily, so patching sys.modules is enough.
+    monkeypatch.setitem(
+        sys.modules,
+        "litellm",
+        types.SimpleNamespace(acompletion=_fake_acompletion),
+    )
+
+    llm = types.SimpleNamespace(model="openai/test", api_key="k", api_base="b")
+    await runtime.preflight_llm(llm, is_provider_rate_limited=lambda _: False)
+
+    assert calls["model"] == "openai/test"
+    assert calls["max_tokens"] == 1
+    assert calls["timeout"] == 2.5
+    assert calls["stream"] is False
+
+
+async def test_preflight_llm_rethrows_rate_limited(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    class _RateLimitedError(Exception):
+        pass
+
+    async def _fake_acompletion(**kwargs: Any):
+        del kwargs
+        raise _RateLimitedError("rl")
+
+    monkeypatch.setitem(
+        sys.modules,
+        "litellm",
+        types.SimpleNamespace(acompletion=_fake_acompletion),
+    )
+
+    with pytest.raises(_RateLimitedError):
+        await runtime.preflight_llm(
+            types.SimpleNamespace(model="openai/test"),
+            is_provider_rate_limited=lambda exc: isinstance(exc, _RateLimitedError),
+        )
+
+
+async def test_preflight_llm_skips_probe_for_auto_model(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    called = {"count": 0}
+
+    async def _fake_acompletion(**kwargs: Any):
+        del kwargs
+        called["count"] += 1
+        return {"ok": True}
+
+    monkeypatch.setitem(
+        sys.modules,
+        "litellm",
+        types.SimpleNamespace(acompletion=_fake_acompletion),
+    )
+
+    await runtime.preflight_llm(
+        types.SimpleNamespace(model="auto"),
+        is_provider_rate_limited=lambda _: False,
+    )
+    assert called["count"] == 0
+
+
+async def test_build_main_agent_for_thread_forwards_arguments() -> None:
+    seen: dict[str, Any] = {}
+
+    async def _factory(**kwargs: Any):
+        seen.update(kwargs)
+        return "agent"
+
+    out = await runtime.build_main_agent_for_thread(
+        _factory,
+        llm="llm",
+        search_space_id=1,
+        db_session="db",
+        connector_service="connector",
+        checkpointer="cp",
+        user_id="u",
+        thread_id=10,
+        agent_config="cfg",
+        firecrawl_api_key="key",
+        thread_visibility="vis",
+        filesystem_selection="fs",
+        disabled_tools=["a"],
+        mentioned_document_ids=[5],
+    )
+    assert out == "agent"
+    assert seen["thread_id"] == 10
+    assert seen["mentioned_document_ids"] == [5]
+
+
+async def test_settle_speculative_agent_build_swallows_exceptions() -> None:
+    async def _boom() -> None:
+        raise RuntimeError("ignore")
+
+    task = asyncio.create_task(_boom())
+    await runtime.settle_speculative_agent_build(task)
+    assert task.done()
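The suite pins down the exception path; the cancellation path, which is the reason `settle_speculative_agent_build` suppresses `BaseException` rather than `Exception`, could be covered the same way. A sketch of such a companion test, not part of this diff, assuming the suite's async tests are collected via the project's asyncio mode:

```python
import asyncio

from app.tasks.chat.streaming import runtime


async def test_settle_speculative_agent_build_swallows_cancellation() -> None:
    async def _hang() -> None:
        await asyncio.sleep(3600)

    task = asyncio.create_task(_hang())
    await asyncio.sleep(0)  # give the task one tick so cancellation lands mid-sleep
    task.cancel()
    # Awaiting a cancelled task raises CancelledError, which is a BaseException
    # (not an Exception) on Python 3.8+; the helper must swallow it.
    await runtime.settle_speculative_agent_build(task)
    assert task.cancelled()
```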