feat: implement agent caches and fix invalid prompt cache configs

- Added a new function `_warm_agent_jit_caches` to pre-warm agent caches at startup, reducing cold-invocation costs (sketched below).
- Updated the `SurfSenseContextSchema` to include per-invocation fields for better state management during agent execution.
- Refactored tools to open a fresh database session on each invocation, so compiled agents can be cached and reused safely across requests.
- Enhanced middleware to support the new context fields and to improve error handling during connector and document-type discovery.
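
A hedged sketch of what the startup pre-warm hook could look like. Only the name `_warm_agent_jit_caches` comes from this commit; the accessor, builder, and cache key below are illustrative assumptions, not the shipped code:

async def _warm_agent_jit_caches() -> None:
    # Pre-build the hot agent variant so the first user request skips the
    # cold compile. Concurrent misses on the same cache key coalesce onto
    # a single build (see the agent_cache tests below), so warming is
    # idempotent and race-free.
    cache = get_agent_cache()  # hypothetical accessor for the process-wide cache
    await cache.get_or_build(
        "default-agent",  # illustrative key; real keys come from stable_hash(...)
        builder=_build_default_agent,  # hypothetical async factory
    )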
Author: DESKTOP-RTLN3BA\$punk
Date: 2026-05-03 06:03:40 -07:00
parent 90a653c8c7
commit a34f1fb25c
60 changed files with 8477 additions and 5381 deletions

@@ -0,0 +1,268 @@
"""Regression tests for the compiled-agent cache.
Covers the cache primitive itself (TTL, LRU, in-flight de-duplication,
build-failure non-caching) and the cache-key signature helpers that
``create_surfsense_deep_agent`` relies on. The integration with
``create_surfsense_deep_agent`` is covered separately by the streaming
contract tests; this module focuses on the primitives so a regression
in the cache implementation is caught before it reaches the agent
factory.
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
import pytest
from app.agents.new_chat.agent_cache import (
flags_signature,
reload_for_tests,
stable_hash,
system_prompt_hash,
tools_signature,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# stable_hash + signature helpers
# ---------------------------------------------------------------------------
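# Illustrative reference (not part of this diff): a plausible shape for
# ``stable_hash``, consistent with the behaviour pinned below: deterministic
# across calls and sensitive to every part. The real helper lives in
# ``app.agents.new_chat.agent_cache``; this sketch canonicalises the parts
# as JSON and hashes that.
#
#   import hashlib
#   import json
#
#   def stable_hash(*parts: object) -> str:
#       payload = json.dumps(parts, sort_keys=True, default=repr)
#       return hashlib.sha256(payload.encode("utf-8")).hexdigest()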
def test_stable_hash_is_deterministic_across_calls() -> None:
a = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
b = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
assert a == b
def test_stable_hash_changes_when_any_part_changes() -> None:
base = stable_hash("v1", 42, "thread-9")
assert stable_hash("v1", 42, "thread-10") != base
assert stable_hash("v2", 42, "thread-9") != base
assert stable_hash("v1", 43, "thread-9") != base
def test_tools_signature_keys_on_name_and_description_not_identity() -> None:
"""Two tool lists with the same surface must hash identically.
The cache key MUST NOT change when the underlying ``BaseTool``
instances are different Python objects (a fresh request constructs
fresh tool instances every time). Hashing on ``(name, description)``
keeps the cache hot across requests with identical tool surfaces.
"""
@dataclass
class FakeTool:
name: str
description: str
tools_a = [FakeTool("alpha", "does alpha"), FakeTool("beta", "does beta")]
tools_b = [FakeTool("beta", "does beta"), FakeTool("alpha", "does alpha")]
sig_a = tools_signature(
tools_a, available_connectors=["NOTION"], available_document_types=["FILE"]
)
sig_b = tools_signature(
tools_b, available_connectors=["NOTION"], available_document_types=["FILE"]
)
assert sig_a == sig_b, "tool order must not affect the signature"
# Adding a tool rotates the key.
tools_c = [*tools_a, FakeTool("gamma", "does gamma")]
sig_c = tools_signature(
tools_c, available_connectors=["NOTION"], available_document_types=["FILE"]
)
assert sig_c != sig_a
def test_tools_signature_rotates_when_connector_set_changes() -> None:
@dataclass
class FakeTool:
name: str
description: str
tools = [FakeTool("a", "x")]
base = tools_signature(
tools, available_connectors=["NOTION"], available_document_types=["FILE"]
)
added = tools_signature(
tools,
available_connectors=["NOTION", "SLACK"],
available_document_types=["FILE"],
)
assert base != added, "adding a connector must rotate the cache key"
def test_flags_signature_changes_when_flag_flips() -> None:
@dataclass(frozen=True)
class Flags:
a: bool = True
b: bool = False
base = flags_signature(Flags())
flipped = flags_signature(Flags(b=True))
assert base != flipped
def test_system_prompt_hash_is_stable_and_distinct() -> None:
p1 = "You are a helpful assistant."
p2 = "You are a helpful assistant!" # one-character delta
assert system_prompt_hash(p1) == system_prompt_hash(p1)
assert system_prompt_hash(p1) != system_prompt_hash(p2)
# ---------------------------------------------------------------------------
# _AgentCache: hit / miss / TTL / LRU / coalescing / failure-not-cached
# ---------------------------------------------------------------------------
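# Illustrative reference (not part of this diff): the approximate core of the
# ``get_or_build`` path these tests pin, assuming asyncio, time.monotonic()
# timestamps, and an OrderedDict for LRU bookkeeping. The real class lives in
# ``app.agents.new_chat.agent_cache``.
#
#   async def get_or_build(self, key, *, builder):
#       async with self._lock:
#           entry = self._entries.get(key)
#           if entry is not None and not entry.is_stale(self._ttl):
#               self._entries.move_to_end(key)       # LRU touch on hit
#               return entry.value
#           task = self._inflight.get(key)
#           if task is None:                         # first cold miss
#               task = asyncio.ensure_future(builder())
#               self._inflight[key] = task
#       try:
#           value = await task                       # concurrent misses coalesce here
#       finally:
#           self._inflight.pop(key, None)            # failures are never cached
#       async with self._lock:
#           self._entries[key] = _Entry(value, time.monotonic())
#           while len(self._entries) > self._maxsize:
#               self._entries.popitem(last=False)    # evict least-recently-used
#       return value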
@pytest.mark.asyncio
async def test_cache_hit_returns_same_instance_on_second_call() -> None:
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
builds = 0
async def builder() -> object:
nonlocal builds
builds += 1
return object()
a = await cache.get_or_build("k", builder=builder)
b = await cache.get_or_build("k", builder=builder)
assert a is b, "cache must return the SAME object across hits"
assert builds == 1, "builder must run exactly once"
@pytest.mark.asyncio
async def test_cache_different_keys_get_different_instances() -> None:
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
async def builder() -> object:
return object()
a = await cache.get_or_build("k1", builder=builder)
b = await cache.get_or_build("k2", builder=builder)
assert a is not b
@pytest.mark.asyncio
async def test_cache_stale_entries_get_rebuilt() -> None:
# ttl=0 means every read sees the entry as immediately stale.
cache = reload_for_tests(maxsize=8, ttl_seconds=0.0)
builds = 0
async def builder() -> object:
nonlocal builds
builds += 1
return object()
a = await cache.get_or_build("k", builder=builder)
b = await cache.get_or_build("k", builder=builder)
assert a is not b, "stale entry must rebuild a fresh instance"
assert builds == 2
@pytest.mark.asyncio
async def test_cache_evicts_lru_when_full() -> None:
cache = reload_for_tests(maxsize=2, ttl_seconds=60.0)
async def builder() -> object:
return object()
a = await cache.get_or_build("a", builder=builder)
b = await cache.get_or_build("b", builder=builder)
# Re-touch "a" so "b" is now the LRU victim.
a_again = await cache.get_or_build("a", builder=builder)
assert a_again is a
# Inserting "c" should evict "b" (LRU), not "a".
_ = await cache.get_or_build("c", builder=builder)
assert cache.stats()["size"] == 2
# Confirm "a" is still hot (no rebuild) and "b" is gone (rebuilt fresh).
a_hit = await cache.get_or_build("a", builder=builder)
assert a_hit is a, "LRU must keep the most-recently-used 'a' entry"
b_hit = await cache.get_or_build("b", builder=builder)
assert b_hit is not b, "evicted 'b' must be rebuilt as a fresh instance"
@pytest.mark.asyncio
async def test_cache_concurrent_misses_coalesce_to_single_build() -> None:
"""Two concurrent get_or_build calls on the same key must share one builder."""
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
build_started = asyncio.Event()
builds = 0
async def slow_builder() -> object:
nonlocal builds
builds += 1
build_started.set()
# Yield control so the second waiter can race against us.
await asyncio.sleep(0.05)
return object()
task_a = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
# Wait until the first builder has started, then race a second waiter.
await build_started.wait()
task_b = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
a, b = await asyncio.gather(task_a, task_b)
assert a is b, "coalesced waiters must observe the same value"
assert builds == 1, "concurrent cold misses must collapse to ONE build"
@pytest.mark.asyncio
async def test_cache_does_not_store_failed_builds() -> None:
"""A builder that raises must NOT poison the cache.
The next caller for the same key must run the builder again (not
re-raise the cached exception).
"""
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
attempts = 0
async def flaky_builder() -> object:
nonlocal attempts
attempts += 1
if attempts == 1:
raise RuntimeError("transient")
return object()
with pytest.raises(RuntimeError, match="transient"):
await cache.get_or_build("k", builder=flaky_builder)
# Second call must retry — not re-raise the cached exception.
value = await cache.get_or_build("k", builder=flaky_builder)
assert value is not None
assert attempts == 2
@pytest.mark.asyncio
async def test_cache_invalidate_drops_entry() -> None:
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
async def builder() -> object:
return object()
a = await cache.get_or_build("k", builder=builder)
assert cache.invalidate("k") is True
b = await cache.get_or_build("k", builder=builder)
assert a is not b, "post-invalidation lookup must rebuild"
@pytest.mark.asyncio
async def test_cache_invalidate_prefix_drops_matching_entries() -> None:
cache = reload_for_tests(maxsize=16, ttl_seconds=60.0)
async def builder() -> object:
return object()
await cache.get_or_build("user:1:thread:1", builder=builder)
await cache.get_or_build("user:1:thread:2", builder=builder)
await cache.get_or_build("user:2:thread:1", builder=builder)
removed = cache.invalidate_prefix("user:1:")
assert removed == 2
assert cache.stats()["size"] == 1
# The user:2 entry must still be hot (no rebuild).
survivor_value = await cache.get_or_build("user:2:thread:1", builder=builder)
assert survivor_value is not None

@@ -34,6 +34,8 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
"SURFSENSE_ENABLE_STREAM_PARITY_V2",
"SURFSENSE_ENABLE_PLUGIN_LOADER",
"SURFSENSE_ENABLE_OTEL",
"SURFSENSE_ENABLE_AGENT_CACHE",
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT",
]:
monkeypatch.delenv(name, raising=False)
@@ -62,6 +64,11 @@ def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) ->
assert flags.enable_stream_parity_v2 is True
assert flags.enable_plugin_loader is False
assert flags.enable_otel is False
# Phase 2: agent cache is now default-on (the prerequisite tool
# ``db_session`` refactor landed). The companion gp-subagent share
# flag stays default-off pending data on cold-miss frequency.
assert flags.enable_agent_cache is True
assert flags.enable_agent_cache_share_gp_subagent is False
assert flags.any_new_middleware_enabled() is True
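# Illustrative usage (assumption: the loader parses the env vars cleared in
# ``_clear_all`` above, and ``get_flags`` is a hypothetical accessor name):
#
#   monkeypatch.setenv("SURFSENSE_ENABLE_AGENT_CACHE", "false")
#   assert get_flags().enable_agent_cache is False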

@@ -0,0 +1,344 @@
"""Tests for ``FlattenSystemMessageMiddleware``.
The middleware exists to defend against Anthropic's "Found 5 cache_control
blocks" 400, which fires when our deepagent middleware stack piles 5+ text
blocks onto the system message and the OpenRouterAnthropic adapter redistributes
``cache_control`` across all of them. The flattening collapses every
all-text system content list to a single string before the LLM call.
"""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import HumanMessage, SystemMessage
from app.agents.new_chat.middleware.flatten_system import (
FlattenSystemMessageMiddleware,
_flatten_text_blocks,
_flattened_request,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# _flatten_text_blocks — pure helper, the heart of the middleware.
# ---------------------------------------------------------------------------
class TestFlattenTextBlocks:
def test_joins_text_blocks_with_double_newline(self) -> None:
blocks = [
{"type": "text", "text": "<surfsense base>"},
{"type": "text", "text": "<filesystem section>"},
{"type": "text", "text": "<skills section>"},
]
assert (
_flatten_text_blocks(blocks)
== "<surfsense base>\n\n<filesystem section>\n\n<skills section>"
)
def test_handles_single_text_block(self) -> None:
blocks = [{"type": "text", "text": "only one"}]
assert _flatten_text_blocks(blocks) == "only one"
def test_handles_empty_list(self) -> None:
assert _flatten_text_blocks([]) == ""
def test_passes_through_bare_string_blocks(self) -> None:
# LangChain content can mix bare strings and dict blocks.
blocks = ["raw string", {"type": "text", "text": "dict block"}]
assert _flatten_text_blocks(blocks) == "raw string\n\ndict block"
def test_returns_none_for_image_block(self) -> None:
# System messages with images are rare — but we never want to
# silently lose the image payload by joining as text.
blocks = [
{"type": "text", "text": "look at this"},
{"type": "image_url", "image_url": {"url": "data:image/png..."}},
]
assert _flatten_text_blocks(blocks) is None
def test_returns_none_for_non_dict_non_str_block(self) -> None:
blocks = [{"type": "text", "text": "hi"}, 42] # type: ignore[list-item]
assert _flatten_text_blocks(blocks) is None
def test_returns_none_when_text_field_missing(self) -> None:
blocks = [{"type": "text"}] # no ``text`` key
assert _flatten_text_blocks(blocks) is None
def test_returns_none_when_text_is_not_string(self) -> None:
blocks = [{"type": "text", "text": ["nested", "list"]}]
assert _flatten_text_blocks(blocks) is None
def test_drops_cache_control_from_inner_blocks(self) -> None:
# The whole point: existing cache_control on inner blocks is
# discarded so LiteLLM's ``cache_control_injection_points`` can
# re-attach exactly one breakpoint after flattening.
blocks = [
{"type": "text", "text": "first"},
{
"type": "text",
"text": "second",
"cache_control": {"type": "ephemeral"},
},
]
flattened = _flatten_text_blocks(blocks)
assert flattened == "first\n\nsecond"
assert "cache_control" not in flattened # type: ignore[operator]
# ---------------------------------------------------------------------------
# _flattened_request — decides when to override and when to no-op.
# ---------------------------------------------------------------------------
def _make_request(system_message: SystemMessage | None) -> Any:
"""Build a minimal ModelRequest stub. We only need ``.system_message``
and ``.override(system_message=...)``; the middleware never touches
other fields.
"""
request = MagicMock()
request.system_message = system_message
def override(**kwargs: Any) -> Any:
new_request = MagicMock()
new_request.system_message = kwargs.get(
"system_message", request.system_message
)
new_request.messages = kwargs.get("messages", getattr(request, "messages", []))
new_request.tools = kwargs.get("tools", getattr(request, "tools", []))
return new_request
request.override = override
return request
class TestFlattenedRequest:
def test_collapses_multi_block_system_to_string(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "<base>"},
{"type": "text", "text": "<todo>"},
{"type": "text", "text": "<filesystem>"},
{"type": "text", "text": "<skills>"},
{"type": "text", "text": "<subagents>"},
]
)
request = _make_request(sys)
flattened = _flattened_request(request)
assert flattened is not None
assert isinstance(flattened.system_message, SystemMessage)
assert flattened.system_message.content == (
"<base>\n\n<todo>\n\n<filesystem>\n\n<skills>\n\n<subagents>"
)
def test_no_op_for_string_content(self) -> None:
sys = SystemMessage(content="already a string")
request = _make_request(sys)
assert _flattened_request(request) is None
def test_no_op_for_single_block_list(self) -> None:
# One block already produces one breakpoint — no need to flatten.
sys = SystemMessage(content=[{"type": "text", "text": "single"}])
request = _make_request(sys)
assert _flattened_request(request) is None
def test_no_op_when_system_message_missing(self) -> None:
request = _make_request(None)
assert _flattened_request(request) is None
def test_no_op_when_list_contains_non_text_block(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "look"},
{"type": "image_url", "image_url": {"url": "data:..."}},
]
)
request = _make_request(sys)
assert _flattened_request(request) is None
def test_preserves_additional_kwargs_and_metadata(self) -> None:
# Defensive: nothing in the current chain sets these on a system
# message, but losing them silently when something does in the
# future would be a regression. ``name`` in particular is the only
# ``additional_kwargs`` field that ChatLiteLLM's
# ``_convert_message_to_dict`` propagates onto the wire.
sys = SystemMessage(
content=[
{"type": "text", "text": "a"},
{"type": "text", "text": "b"},
],
additional_kwargs={"name": "surfsense_system", "x": 1},
response_metadata={"tokens": 42},
)
sys.id = "sys-msg-1"
request = _make_request(sys)
flattened = _flattened_request(request)
assert flattened is not None
assert flattened.system_message.content == "a\n\nb"
assert flattened.system_message.additional_kwargs == {
"name": "surfsense_system",
"x": 1,
}
assert flattened.system_message.response_metadata == {"tokens": 42}
assert flattened.system_message.id == "sys-msg-1"
def test_idempotent_when_run_twice(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "a"},
{"type": "text", "text": "b"},
]
)
request = _make_request(sys)
first = _flattened_request(request)
assert first is not None
# Second pass on the already-flattened request should be a no-op.
# We re-wrap in a request stub since the helper inspects
# ``request.system_message.content``.
second_request = _make_request(first.system_message)
assert _flattened_request(second_request) is None
# ---------------------------------------------------------------------------
# Middleware integration — verify the handler sees a flattened request.
# ---------------------------------------------------------------------------
class TestMiddlewareWrap:
@pytest.mark.asyncio
async def test_async_passes_flattened_request_to_handler(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "alpha"},
{"type": "text", "text": "beta"},
]
)
request = _make_request(sys)
captured: dict[str, Any] = {}
async def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
result = await mw.awrap_model_call(request, handler)
assert result == "ok"
assert isinstance(captured["request"].system_message, SystemMessage)
assert captured["request"].system_message.content == "alpha\n\nbeta"
@pytest.mark.asyncio
async def test_async_passes_through_when_already_string(self) -> None:
sys = SystemMessage(content="just a string")
request = _make_request(sys)
captured: dict[str, Any] = {}
async def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
await mw.awrap_model_call(request, handler)
# Same request object: no override happened.
assert captured["request"] is request
def test_sync_passes_flattened_request_to_handler(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "alpha"},
{"type": "text", "text": "beta"},
]
)
request = _make_request(sys)
captured: dict[str, Any] = {}
def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
result = mw.wrap_model_call(request, handler)
assert result == "ok"
assert captured["request"].system_message.content == "alpha\n\nbeta"
def test_sync_passes_through_when_no_system_message(self) -> None:
request = _make_request(None)
captured: dict[str, Any] = {}
def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
mw.wrap_model_call(request, handler)
assert captured["request"] is request
# ---------------------------------------------------------------------------
# Regression guard — pin the worst-case shape that triggered the
# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the
# downstream cache_control_injection_points can only place 1 breakpoint
# on the system message regardless of provider redistribution quirks.
# ---------------------------------------------------------------------------
def test_regression_five_block_system_collapses_to_one_block() -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "<surfsense base + BASE_AGENT_PROMPT>"},
{"type": "text", "text": "<TodoListMiddleware section>"},
{"type": "text", "text": "<SurfSenseFilesystemMiddleware section>"},
{"type": "text", "text": "<SkillsMiddleware section>"},
{"type": "text", "text": "<SubAgentMiddleware section>"},
]
)
request = _make_request(sys)
flattened = _flattened_request(request)
assert flattened is not None
assert isinstance(flattened.system_message.content, str)
# The exact join doesn't matter for the cache_control accounting —
# only that there is exactly ONE content block when LiteLLM's
# AnthropicCacheControlHook later targets ``role: system``.
assert "<surfsense base" in flattened.system_message.content
assert "<SubAgentMiddleware" in flattened.system_message.content
def test_regression_human_message_not_modified() -> None:
# Sanity: the middleware MUST NOT touch user messages — only the
# system message. Multi-block user content is the path that carries
# image attachments and would lose its image_url block on
# accidental flatten.
sys = SystemMessage(
content=[
{"type": "text", "text": "a"},
{"type": "text", "text": "b"},
]
)
user = HumanMessage(
content=[
{"type": "text", "text": "look at this"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
]
)
request = _make_request(sys)
request.messages = [user]
flattened = _flattened_request(request)
assert flattened is not None
# System flattened to string …
assert isinstance(flattened.system_message.content, str)
# … user message is untouched (the helper does not even look at it).
assert flattened.messages == [user]
assert isinstance(user.content, list)
assert len(user.content) == 2

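The suite above pins `_flatten_text_blocks` tightly enough to reconstruct its shape. A minimal sketch consistent with these tests (an assumption, not the shipped implementation):

from __future__ import annotations

def _flatten_text_blocks(blocks: list) -> str | None:
    """Join an all-text content list with blank lines; None means "don't flatten"."""
    parts: list[str] = []
    for block in blocks:
        if isinstance(block, str):  # bare strings pass straight through
            parts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            text = block.get("text")
            if not isinstance(text, str):
                return None  # missing or non-string text: refuse to flatten
            parts.append(text)  # any cache_control on the block is dropped here
        else:
            return None  # image/unknown blocks must never be joined as text
    return "\n\n".join(parts)

Returning None instead of raising keeps the middleware a strict no-op whenever any block is not plain text.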
@@ -1,4 +1,4 @@
-"""Tests for ``apply_litellm_prompt_caching`` in
+r"""Tests for ``apply_litellm_prompt_caching`` in
:mod:`app.agents.new_chat.prompt_caching`.
The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which
@@ -6,9 +6,12 @@ never activated for our LiteLLM stack) with LiteLLM-native multi-provider
prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
``litellm.completion(...)``. The tests below pin its public contract:
-1. Always sets BOTH ``role: system`` and ``index: -1`` injection points so
+1. Always sets BOTH ``index: 0`` and ``index: -1`` injection points so
savings compound across multi-turn conversations on Anthropic-family
-providers.
+providers. ``index: 0`` is used (rather than ``role: system``) because
+the deepagent stack accumulates multiple ``SystemMessage``\ s in
+``state["messages"]`` and ``role: system`` would tag every one of
+them, blowing past Anthropic's 4-block ``cache_control`` cap.
2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for
single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic
prompt-cache surface is available).
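
Concretely, after `apply_litellm_prompt_caching(llm)` the kwargs that reach `litellm.completion(...)` carry exactly two breakpoints; a sketch of the resulting call shape (the inline comments are my reading of the docstring above, not shipped code):

litellm.completion(
    model=model,
    messages=messages,
    cache_control_injection_points=[
        {"location": "message", "index": 0},   # first message: the flattened system prompt
        {"location": "message", "index": -1},  # last message: extends the cached prefix each turn
    ],
)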
@@ -92,11 +95,28 @@ def test_sets_both_cache_control_injection_points_with_no_config() -> None:
apply_litellm_prompt_caching(llm)
points = llm.model_kwargs["cache_control_injection_points"]
-assert {"location": "message", "role": "system"} in points
+assert {"location": "message", "index": 0} in points
assert {"location": "message", "index": -1} in points
assert len(points) == 2
def test_does_not_inject_role_system_breakpoint() -> None:
"""Regression: deliberately AVOID ``role: system`` so we don't tag
every SystemMessage the deepagent ``before_agent`` injectors push
into ``state["messages"]`` (priority, tree, memory, file-intent,
anonymous-doc). Tagging all of them overflows Anthropic's 4-block
``cache_control`` cap and surfaces as
``OpenrouterException: A maximum of 4 blocks with cache_control may
be provided. Found N`` 400s.
"""
llm = _FakeLLM()
apply_litellm_prompt_caching(llm)
points = llm.model_kwargs["cache_control_injection_points"]
assert all(p.get("role") != "system" for p in points), (
f"Expected no role=system breakpoint, got: {points}"
)
def test_injection_points_set_for_anthropic_config() -> None:
"""Anthropic-family configs need the marker — verify it lands."""
cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet")

@@ -475,3 +475,190 @@ class TestKBSearchPlanSchema:
)
)
assert plan.is_recency_query is False
# ── mentioned_document_ids cross-turn drain ────────────────────────────
class TestKnowledgePriorityMentionDrain:
"""Regression tests for the cross-turn ``mentioned_document_ids`` drain.
The compiled-agent cache reuses a single :class:`KnowledgeBaseSearchMiddleware`
instance across turns of the same thread. ``mentioned_document_ids``
can therefore enter the middleware via two paths:
1. The constructor closure (``__init__(mentioned_document_ids=...)``)
seeded by the cache-miss build on turn 1.
2. ``runtime.context.mentioned_document_ids`` supplied freshly per
turn by the streaming task.
Without the drain fix, an empty ``runtime.context.mentioned_document_ids``
on turn 2 would fall through to the closure (because ``[]`` is falsy in
Python) and replay turn 1's mentions. This class pins down the
correct behaviour: the runtime path is authoritative even when empty,
and the closure is drained the first time the runtime path fires so
no later turn can ever resurrect stale state.
"""
@staticmethod
def _make_runtime(mention_ids: list[int]):
"""Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``."""
from types import SimpleNamespace
return SimpleNamespace(
context=SimpleNamespace(mentioned_document_ids=mention_ids),
)
@staticmethod
def _planner_llm() -> "FakeLLM":
# Planner returns a stable, non-recency plan so we always land in
# the hybrid-search branch (where ``fetch_mentioned_documents`` is
# invoked alongside the main search).
return FakeLLM(
json.dumps(
{
"optimized_query": "follow up question",
"start_date": None,
"end_date": None,
"is_recency_query": False,
}
)
)
async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch):
"""Turn 1 with mentions in BOTH closure and runtime context: the
runtime path wins AND the closure is drained so a future turn
cannot replay it.
"""
fetched_ids: list[list[int]] = []
async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
fetched_ids.append(list(document_ids))
return []
async def fake_search_knowledge_base(**_kwargs):
return []
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
fake_fetch_mentioned_documents,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
middleware = KnowledgeBaseSearchMiddleware(
llm=self._planner_llm(),
search_space_id=42,
mentioned_document_ids=[1, 2, 3],
)
await middleware.abefore_agent(
{"messages": [HumanMessage(content="what is in those docs?")]},
runtime=self._make_runtime([1, 2, 3]),
)
assert fetched_ids == [[1, 2, 3]], (
"runtime.context mentions must be the source of truth on turn 1"
)
assert middleware.mentioned_document_ids == [], (
"closure must be drained the first time the runtime path fires "
"so no later turn can replay stale mentions"
)
async def test_empty_runtime_context_does_not_replay_closure_mentions(
self, monkeypatch
):
"""Regression: turn 2 with NO mentions must not surface turn 1's
mentions from the constructor closure.
Before the fix, ``if ctx_mentions:`` treated an empty list as
absent and fell through to ``elif self.mentioned_document_ids:``,
replaying turn 1's mentions. This test pins down the corrected
behaviour.
"""
fetched_ids: list[list[int]] = []
async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
fetched_ids.append(list(document_ids))
return []
async def fake_search_knowledge_base(**_kwargs):
return []
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
fake_fetch_mentioned_documents,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
# Simulate a cached middleware instance whose closure was seeded
# by a previous turn's cache-miss build (mentions=[1,2,3]).
middleware = KnowledgeBaseSearchMiddleware(
llm=self._planner_llm(),
search_space_id=42,
mentioned_document_ids=[1, 2, 3],
)
# Turn 2: streaming task supplies an EMPTY mention list (no
# mentions on this follow-up turn).
await middleware.abefore_agent(
{"messages": [HumanMessage(content="what about the next steps?")]},
runtime=self._make_runtime([]),
)
assert fetched_ids == [], (
"fetch_mentioned_documents must NOT be called when the runtime "
"context says there are no mentions for this turn"
)
async def test_legacy_path_fires_only_when_runtime_context_absent(
self, monkeypatch
):
"""Backward-compat: if a caller doesn't supply runtime.context (old
non-streaming code path), the closure-injected mentions are still
honoured exactly once and then drained.
"""
fetched_ids: list[list[int]] = []
async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
fetched_ids.append(list(document_ids))
return []
async def fake_search_knowledge_base(**_kwargs):
return []
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
fake_fetch_mentioned_documents,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
middleware = KnowledgeBaseSearchMiddleware(
llm=self._planner_llm(),
search_space_id=42,
mentioned_document_ids=[7, 8],
)
# First call: no runtime → legacy path uses the closure.
await middleware.abefore_agent(
{"messages": [HumanMessage(content="initial question")]},
runtime=None,
)
# Second call: still no runtime — closure already drained, so no replay.
await middleware.abefore_agent(
{"messages": [HumanMessage(content="follow up")]},
runtime=None,
)
assert fetched_ids == [[7, 8]], (
"legacy path must honour the closure exactly once and then drain it"
)
assert middleware.mentioned_document_ids == []

@@ -271,6 +271,66 @@ async def test_preflight_skipped_for_auto_router_model():
await _preflight_llm(fake_llm)
@pytest.mark.asyncio
async def test_settle_speculative_agent_build_swallows_exceptions():
"""``_settle_speculative_agent_build`` MUST always return cleanly so the
caller can safely re-touch the request-scoped session afterwards.
The helper guards the parallel preflight + agent-build path: when the
speculative build is being discarded (429 or non-429 preflight failure),
we await it solely to release any in-flight ``AsyncSession`` usage;
the build's outcome is irrelevant. Any exception (including
``CancelledError``) leaking out would skip the caller's recovery flow
and re-introduce the very session-concurrency hazard the helper exists
to prevent.
"""
import asyncio
from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build
async def _raises() -> None:
raise RuntimeError("speculative build crashed")
async def _succeeds() -> str:
return "agent"
async def _slow() -> None:
await asyncio.sleep(0.05)
for coro in (_raises(), _succeeds(), _slow()):
task = asyncio.create_task(coro)
await _settle_speculative_agent_build(task)
assert task.done()
@pytest.mark.asyncio
async def test_settle_speculative_agent_build_handles_already_done_task():
"""Done tasks (success or failure) must still be settled without raising."""
import asyncio
from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build
async def _ok() -> str:
return "ok"
async def _bad() -> None:
raise ValueError("nope")
ok_task = asyncio.create_task(_ok())
bad_task = asyncio.create_task(_bad())
# Drive both to completion before settling.
await asyncio.sleep(0)
await asyncio.sleep(0)
await _settle_speculative_agent_build(ok_task)
await _settle_speculative_agent_build(bad_task)
assert ok_task.result() == "ok"
# ``bad_task`` exception was consumed by the settle helper; calling
# ``.exception()`` after the fact must still return the original error
# (the helper observes it but doesn't clear it).
assert isinstance(bad_task.exception(), ValueError)
def test_stream_exception_classifies_thread_busy():
exc = BusyError(request_id="thread-123")
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(