mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 21:32:39 +02:00
feat: implement agent caches and fix invalid prompt cache configs
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
- Added a new function `_warm_agent_jit_caches` to pre-warm agent caches at startup, reducing cold invocation costs. - Updated the `SurfSenseContextSchema` to include per-invocation fields for better state management during agent execution. - Introduced caching mechanisms in various tools to ensure fresh database sessions are used, improving performance and reliability. - Enhanced middleware to support new context features and improve error handling during connector and document type discovery.
This commit is contained in:
parent
90a653c8c7
commit
a34f1fb25c
60 changed files with 8477 additions and 5381 deletions
268
surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py
Normal file
268
surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
"""Regression tests for the compiled-agent cache.
|
||||
|
||||
Covers the cache primitive itself (TTL, LRU, in-flight de-duplication,
|
||||
build-failure non-caching) and the cache-key signature helpers that
|
||||
``create_surfsense_deep_agent`` relies on. The integration with
|
||||
``create_surfsense_deep_agent`` is covered separately by the streaming
|
||||
contract tests; this module focuses on the primitives so a regression
|
||||
in the cache implementation is caught before it reaches the agent
|
||||
factory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.agent_cache import (
|
||||
flags_signature,
|
||||
reload_for_tests,
|
||||
stable_hash,
|
||||
system_prompt_hash,
|
||||
tools_signature,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stable_hash + signature helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_stable_hash_is_deterministic_across_calls() -> None:
|
||||
a = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
|
||||
b = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
|
||||
assert a == b
|
||||
|
||||
|
||||
def test_stable_hash_changes_when_any_part_changes() -> None:
|
||||
base = stable_hash("v1", 42, "thread-9")
|
||||
assert stable_hash("v1", 42, "thread-10") != base
|
||||
assert stable_hash("v2", 42, "thread-9") != base
|
||||
assert stable_hash("v1", 43, "thread-9") != base
|
||||
|
||||
|
||||
def test_tools_signature_keys_on_name_and_description_not_identity() -> None:
|
||||
"""Two tool lists with the same surface must hash identically.
|
||||
|
||||
The cache key MUST NOT change when the underlying ``BaseTool``
|
||||
instances are different Python objects (a fresh request constructs
|
||||
fresh tool instances every time). Hashing on ``(name, description)``
|
||||
keeps the cache hot across requests with identical tool surfaces.
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class FakeTool:
|
||||
name: str
|
||||
description: str
|
||||
|
||||
tools_a = [FakeTool("alpha", "does alpha"), FakeTool("beta", "does beta")]
|
||||
tools_b = [FakeTool("beta", "does beta"), FakeTool("alpha", "does alpha")]
|
||||
sig_a = tools_signature(
|
||||
tools_a, available_connectors=["NOTION"], available_document_types=["FILE"]
|
||||
)
|
||||
sig_b = tools_signature(
|
||||
tools_b, available_connectors=["NOTION"], available_document_types=["FILE"]
|
||||
)
|
||||
assert sig_a == sig_b, "tool order must not affect the signature"
|
||||
|
||||
# Adding a tool rotates the key.
|
||||
tools_c = [*tools_a, FakeTool("gamma", "does gamma")]
|
||||
sig_c = tools_signature(
|
||||
tools_c, available_connectors=["NOTION"], available_document_types=["FILE"]
|
||||
)
|
||||
assert sig_c != sig_a
|
||||
|
||||
|
||||
def test_tools_signature_rotates_when_connector_set_changes() -> None:
|
||||
@dataclass
|
||||
class FakeTool:
|
||||
name: str
|
||||
description: str
|
||||
|
||||
tools = [FakeTool("a", "x")]
|
||||
base = tools_signature(
|
||||
tools, available_connectors=["NOTION"], available_document_types=["FILE"]
|
||||
)
|
||||
added = tools_signature(
|
||||
tools,
|
||||
available_connectors=["NOTION", "SLACK"],
|
||||
available_document_types=["FILE"],
|
||||
)
|
||||
assert base != added, "adding a connector must rotate the cache key"
|
||||
|
||||
|
||||
def test_flags_signature_changes_when_flag_flips() -> None:
|
||||
@dataclass(frozen=True)
|
||||
class Flags:
|
||||
a: bool = True
|
||||
b: bool = False
|
||||
|
||||
base = flags_signature(Flags())
|
||||
flipped = flags_signature(Flags(b=True))
|
||||
assert base != flipped
|
||||
|
||||
|
||||
def test_system_prompt_hash_is_stable_and_distinct() -> None:
|
||||
p1 = "You are a helpful assistant."
|
||||
p2 = "You are a helpful assistant!" # one-character delta
|
||||
assert system_prompt_hash(p1) == system_prompt_hash(p1)
|
||||
assert system_prompt_hash(p1) != system_prompt_hash(p2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _AgentCache: hit / miss / TTL / LRU / coalescing / failure-not-cached
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit_returns_same_instance_on_second_call() -> None:
|
||||
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
|
||||
builds = 0
|
||||
|
||||
async def builder() -> object:
|
||||
nonlocal builds
|
||||
builds += 1
|
||||
return object()
|
||||
|
||||
a = await cache.get_or_build("k", builder=builder)
|
||||
b = await cache.get_or_build("k", builder=builder)
|
||||
assert a is b, "cache must return the SAME object across hits"
|
||||
assert builds == 1, "builder must run exactly once"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_different_keys_get_different_instances() -> None:
|
||||
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
|
||||
|
||||
async def builder() -> object:
|
||||
return object()
|
||||
|
||||
a = await cache.get_or_build("k1", builder=builder)
|
||||
b = await cache.get_or_build("k2", builder=builder)
|
||||
assert a is not b
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_stale_entries_get_rebuilt() -> None:
|
||||
# ttl=0 means every read sees the entry as immediately stale.
|
||||
cache = reload_for_tests(maxsize=8, ttl_seconds=0.0)
|
||||
builds = 0
|
||||
|
||||
async def builder() -> object:
|
||||
nonlocal builds
|
||||
builds += 1
|
||||
return object()
|
||||
|
||||
a = await cache.get_or_build("k", builder=builder)
|
||||
b = await cache.get_or_build("k", builder=builder)
|
||||
assert a is not b, "stale entry must rebuild a fresh instance"
|
||||
assert builds == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_evicts_lru_when_full() -> None:
|
||||
cache = reload_for_tests(maxsize=2, ttl_seconds=60.0)
|
||||
|
||||
async def builder() -> object:
|
||||
return object()
|
||||
|
||||
a = await cache.get_or_build("a", builder=builder)
|
||||
_ = await cache.get_or_build("b", builder=builder)
|
||||
# Re-touch "a" so "b" is now the LRU victim.
|
||||
a_again = await cache.get_or_build("a", builder=builder)
|
||||
assert a_again is a
|
||||
# Inserting "c" should evict "b" (LRU), not "a".
|
||||
_ = await cache.get_or_build("c", builder=builder)
|
||||
assert cache.stats()["size"] == 2
|
||||
|
||||
# Confirm "a" is still hot (no rebuild) and "b" is gone (rebuild).
|
||||
a_hit = await cache.get_or_build("a", builder=builder)
|
||||
assert a_hit is a, "LRU must keep the most-recently-used 'a' entry"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_concurrent_misses_coalesce_to_single_build() -> None:
|
||||
"""Two concurrent get_or_build calls on the same key must share one builder."""
|
||||
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
|
||||
build_started = asyncio.Event()
|
||||
builds = 0
|
||||
|
||||
async def slow_builder() -> object:
|
||||
nonlocal builds
|
||||
builds += 1
|
||||
build_started.set()
|
||||
# Yield control so the second waiter can race against us.
|
||||
await asyncio.sleep(0.05)
|
||||
return object()
|
||||
|
||||
task_a = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
|
||||
# Wait until the first builder has started, then race a second waiter.
|
||||
await build_started.wait()
|
||||
task_b = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
|
||||
|
||||
a, b = await asyncio.gather(task_a, task_b)
|
||||
assert a is b, "coalesced waiters must observe the same value"
|
||||
assert builds == 1, "concurrent cold misses must collapse to ONE build"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_does_not_store_failed_builds() -> None:
|
||||
"""A builder that raises must NOT poison the cache.
|
||||
|
||||
The next caller for the same key must run the builder again (not
|
||||
re-raise the cached exception).
|
||||
"""
|
||||
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
|
||||
attempts = 0
|
||||
|
||||
async def flaky_builder() -> object:
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
if attempts == 1:
|
||||
raise RuntimeError("transient")
|
||||
return object()
|
||||
|
||||
with pytest.raises(RuntimeError, match="transient"):
|
||||
await cache.get_or_build("k", builder=flaky_builder)
|
||||
|
||||
# Second call must retry — not re-raise the cached exception.
|
||||
value = await cache.get_or_build("k", builder=flaky_builder)
|
||||
assert value is not None
|
||||
assert attempts == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_invalidate_drops_entry() -> None:
|
||||
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
|
||||
|
||||
async def builder() -> object:
|
||||
return object()
|
||||
|
||||
a = await cache.get_or_build("k", builder=builder)
|
||||
assert cache.invalidate("k") is True
|
||||
b = await cache.get_or_build("k", builder=builder)
|
||||
assert a is not b, "post-invalidation lookup must rebuild"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_invalidate_prefix_drops_matching_entries() -> None:
|
||||
cache = reload_for_tests(maxsize=16, ttl_seconds=60.0)
|
||||
|
||||
async def builder() -> object:
|
||||
return object()
|
||||
|
||||
await cache.get_or_build("user:1:thread:1", builder=builder)
|
||||
await cache.get_or_build("user:1:thread:2", builder=builder)
|
||||
await cache.get_or_build("user:2:thread:1", builder=builder)
|
||||
|
||||
removed = cache.invalidate_prefix("user:1:")
|
||||
assert removed == 2
|
||||
assert cache.stats()["size"] == 1
|
||||
|
||||
# The user:2 entry must still be hot (no rebuild).
|
||||
survivor_value = await cache.get_or_build("user:2:thread:1", builder=builder)
|
||||
assert survivor_value is not None
|
||||
|
|
@ -34,6 +34,8 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
"SURFSENSE_ENABLE_STREAM_PARITY_V2",
|
||||
"SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||
"SURFSENSE_ENABLE_OTEL",
|
||||
"SURFSENSE_ENABLE_AGENT_CACHE",
|
||||
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT",
|
||||
]:
|
||||
monkeypatch.delenv(name, raising=False)
|
||||
|
||||
|
|
@ -62,6 +64,11 @@ def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) ->
|
|||
assert flags.enable_stream_parity_v2 is True
|
||||
assert flags.enable_plugin_loader is False
|
||||
assert flags.enable_otel is False
|
||||
# Phase 2: agent cache is now default-on (the prerequisite tool
|
||||
# ``db_session`` refactor landed). The companion gp-subagent share
|
||||
# flag stays default-off pending data on cold-miss frequency.
|
||||
assert flags.enable_agent_cache is True
|
||||
assert flags.enable_agent_cache_share_gp_subagent is False
|
||||
assert flags.any_new_middleware_enabled() is True
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,344 @@
|
|||
"""Tests for ``FlattenSystemMessageMiddleware``.
|
||||
|
||||
The middleware exists to defend against Anthropic's "Found 5 cache_control
|
||||
blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on
|
||||
the system message and the OpenRouter→Anthropic adapter redistributes
|
||||
``cache_control`` across all of them. The flattening collapses every
|
||||
all-text system content list to a single string before the LLM call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from app.agents.new_chat.middleware.flatten_system import (
|
||||
FlattenSystemMessageMiddleware,
|
||||
_flatten_text_blocks,
|
||||
_flattened_request,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_text_blocks — pure helper, the heart of the middleware.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenTextBlocks:
|
||||
def test_joins_text_blocks_with_double_newline(self) -> None:
|
||||
blocks = [
|
||||
{"type": "text", "text": "<surfsense base>"},
|
||||
{"type": "text", "text": "<filesystem section>"},
|
||||
{"type": "text", "text": "<skills section>"},
|
||||
]
|
||||
assert (
|
||||
_flatten_text_blocks(blocks)
|
||||
== "<surfsense base>\n\n<filesystem section>\n\n<skills section>"
|
||||
)
|
||||
|
||||
def test_handles_single_text_block(self) -> None:
|
||||
blocks = [{"type": "text", "text": "only one"}]
|
||||
assert _flatten_text_blocks(blocks) == "only one"
|
||||
|
||||
def test_handles_empty_list(self) -> None:
|
||||
assert _flatten_text_blocks([]) == ""
|
||||
|
||||
def test_passes_through_bare_string_blocks(self) -> None:
|
||||
# LangChain content can mix bare strings and dict blocks.
|
||||
blocks = ["raw string", {"type": "text", "text": "dict block"}]
|
||||
assert _flatten_text_blocks(blocks) == "raw string\n\ndict block"
|
||||
|
||||
def test_returns_none_for_image_block(self) -> None:
|
||||
# System messages with images are rare — but we never want to
|
||||
# silently lose the image payload by joining as text.
|
||||
blocks = [
|
||||
{"type": "text", "text": "look at this"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png..."}},
|
||||
]
|
||||
assert _flatten_text_blocks(blocks) is None
|
||||
|
||||
def test_returns_none_for_non_dict_non_str_block(self) -> None:
|
||||
blocks = [{"type": "text", "text": "hi"}, 42] # type: ignore[list-item]
|
||||
assert _flatten_text_blocks(blocks) is None
|
||||
|
||||
def test_returns_none_when_text_field_missing(self) -> None:
|
||||
blocks = [{"type": "text"}] # no ``text`` key
|
||||
assert _flatten_text_blocks(blocks) is None
|
||||
|
||||
def test_returns_none_when_text_is_not_string(self) -> None:
|
||||
blocks = [{"type": "text", "text": ["nested", "list"]}]
|
||||
assert _flatten_text_blocks(blocks) is None
|
||||
|
||||
def test_drops_cache_control_from_inner_blocks(self) -> None:
|
||||
# The whole point: existing cache_control on inner blocks is
|
||||
# discarded so LiteLLM's ``cache_control_injection_points`` can
|
||||
# re-attach exactly one breakpoint after flattening.
|
||||
blocks = [
|
||||
{"type": "text", "text": "first"},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "second",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
},
|
||||
]
|
||||
flattened = _flatten_text_blocks(blocks)
|
||||
assert flattened == "first\n\nsecond"
|
||||
assert "cache_control" not in flattened # type: ignore[operator]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flattened_request — decides when to override and when to no-op.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(system_message: SystemMessage | None) -> Any:
|
||||
"""Build a minimal ModelRequest stub. We only need .system_message
|
||||
and .override(system_message=...) — the middleware never touches
|
||||
other fields.
|
||||
"""
|
||||
request = MagicMock()
|
||||
request.system_message = system_message
|
||||
|
||||
def override(**kwargs: Any) -> Any:
|
||||
new_request = MagicMock()
|
||||
new_request.system_message = kwargs.get(
|
||||
"system_message", request.system_message
|
||||
)
|
||||
new_request.messages = kwargs.get("messages", getattr(request, "messages", []))
|
||||
new_request.tools = kwargs.get("tools", getattr(request, "tools", []))
|
||||
return new_request
|
||||
|
||||
request.override = override
|
||||
return request
|
||||
|
||||
|
||||
class TestFlattenedRequest:
|
||||
def test_collapses_multi_block_system_to_string(self) -> None:
|
||||
sys = SystemMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "<base>"},
|
||||
{"type": "text", "text": "<todo>"},
|
||||
{"type": "text", "text": "<filesystem>"},
|
||||
{"type": "text", "text": "<skills>"},
|
||||
{"type": "text", "text": "<subagents>"},
|
||||
]
|
||||
)
|
||||
request = _make_request(sys)
|
||||
flattened = _flattened_request(request)
|
||||
|
||||
assert flattened is not None
|
||||
assert isinstance(flattened.system_message, SystemMessage)
|
||||
assert flattened.system_message.content == (
|
||||
"<base>\n\n<todo>\n\n<filesystem>\n\n<skills>\n\n<subagents>"
|
||||
)
|
||||
|
||||
def test_no_op_for_string_content(self) -> None:
|
||||
sys = SystemMessage(content="already a string")
|
||||
request = _make_request(sys)
|
||||
assert _flattened_request(request) is None
|
||||
|
||||
def test_no_op_for_single_block_list(self) -> None:
|
||||
# One block already produces one breakpoint — no need to flatten.
|
||||
sys = SystemMessage(content=[{"type": "text", "text": "single"}])
|
||||
request = _make_request(sys)
|
||||
assert _flattened_request(request) is None
|
||||
|
||||
def test_no_op_when_system_message_missing(self) -> None:
|
||||
request = _make_request(None)
|
||||
assert _flattened_request(request) is None
|
||||
|
||||
def test_no_op_when_list_contains_non_text_block(self) -> None:
|
||||
sys = SystemMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "look"},
|
||||
{"type": "image_url", "image_url": {"url": "data:..."}},
|
||||
]
|
||||
)
|
||||
request = _make_request(sys)
|
||||
assert _flattened_request(request) is None
|
||||
|
||||
def test_preserves_additional_kwargs_and_metadata(self) -> None:
|
||||
# Defensive: nothing in the current chain sets these on a system
|
||||
# message, but losing them silently when something does in the
|
||||
# future would be a regression. ``name`` in particular is the only
|
||||
# ``additional_kwargs`` field that ChatLiteLLM's
|
||||
# ``_convert_message_to_dict`` propagates onto the wire.
|
||||
sys = SystemMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "a"},
|
||||
{"type": "text", "text": "b"},
|
||||
],
|
||||
additional_kwargs={"name": "surfsense_system", "x": 1},
|
||||
response_metadata={"tokens": 42},
|
||||
)
|
||||
sys.id = "sys-msg-1"
|
||||
request = _make_request(sys)
|
||||
|
||||
flattened = _flattened_request(request)
|
||||
assert flattened is not None
|
||||
assert flattened.system_message.content == "a\n\nb"
|
||||
assert flattened.system_message.additional_kwargs == {
|
||||
"name": "surfsense_system",
|
||||
"x": 1,
|
||||
}
|
||||
assert flattened.system_message.response_metadata == {"tokens": 42}
|
||||
assert flattened.system_message.id == "sys-msg-1"
|
||||
|
||||
def test_idempotent_when_run_twice(self) -> None:
|
||||
sys = SystemMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "a"},
|
||||
{"type": "text", "text": "b"},
|
||||
]
|
||||
)
|
||||
request = _make_request(sys)
|
||||
first = _flattened_request(request)
|
||||
assert first is not None
|
||||
|
||||
# Second pass on the already-flattened request should be a no-op.
|
||||
# We re-wrap in a request stub since the helper inspects
|
||||
# ``request.system_message.content``.
|
||||
second_request = _make_request(first.system_message)
|
||||
assert _flattened_request(second_request) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Middleware integration — verify the handler sees a flattened request.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMiddlewareWrap:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_passes_flattened_request_to_handler(self) -> None:
|
||||
sys = SystemMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "alpha"},
|
||||
{"type": "text", "text": "beta"},
|
||||
]
|
||||
)
|
||||
request = _make_request(sys)
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def handler(req: Any) -> str:
|
||||
captured["request"] = req
|
||||
return "ok"
|
||||
|
||||
mw = FlattenSystemMessageMiddleware()
|
||||
result = await mw.awrap_model_call(request, handler)
|
||||
|
||||
assert result == "ok"
|
||||
assert isinstance(captured["request"].system_message, SystemMessage)
|
||||
assert captured["request"].system_message.content == "alpha\n\nbeta"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_passes_through_when_already_string(self) -> None:
|
||||
sys = SystemMessage(content="just a string")
|
||||
request = _make_request(sys)
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def handler(req: Any) -> str:
|
||||
captured["request"] = req
|
||||
return "ok"
|
||||
|
||||
mw = FlattenSystemMessageMiddleware()
|
||||
await mw.awrap_model_call(request, handler)
|
||||
|
||||
# Same request object: no override happened.
|
||||
assert captured["request"] is request
|
||||
|
||||
def test_sync_passes_flattened_request_to_handler(self) -> None:
|
||||
sys = SystemMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "alpha"},
|
||||
{"type": "text", "text": "beta"},
|
||||
]
|
||||
)
|
||||
request = _make_request(sys)
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
def handler(req: Any) -> str:
|
||||
captured["request"] = req
|
||||
return "ok"
|
||||
|
||||
mw = FlattenSystemMessageMiddleware()
|
||||
result = mw.wrap_model_call(request, handler)
|
||||
|
||||
assert result == "ok"
|
||||
assert captured["request"].system_message.content == "alpha\n\nbeta"
|
||||
|
||||
def test_sync_passes_through_when_no_system_message(self) -> None:
|
||||
request = _make_request(None)
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
def handler(req: Any) -> str:
|
||||
captured["request"] = req
|
||||
return "ok"
|
||||
|
||||
mw = FlattenSystemMessageMiddleware()
|
||||
mw.wrap_model_call(request, handler)
|
||||
assert captured["request"] is request
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression guard — pin the worst-case shape that triggered the
|
||||
# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the
|
||||
# downstream cache_control_injection_points can only place 1 breakpoint
|
||||
# on the system message regardless of provider redistribution quirks.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_regression_five_block_system_collapses_to_one_block() -> None:
|
||||
sys = SystemMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "<surfsense base + BASE_AGENT_PROMPT>"},
|
||||
{"type": "text", "text": "<TodoListMiddleware section>"},
|
||||
{"type": "text", "text": "<SurfSenseFilesystemMiddleware section>"},
|
||||
{"type": "text", "text": "<SkillsMiddleware section>"},
|
||||
{"type": "text", "text": "<SubAgentMiddleware section>"},
|
||||
]
|
||||
)
|
||||
request = _make_request(sys)
|
||||
flattened = _flattened_request(request)
|
||||
|
||||
assert flattened is not None
|
||||
assert isinstance(flattened.system_message.content, str)
|
||||
# The exact join doesn't matter for the cache_control accounting —
|
||||
# only that there is exactly ONE content block when LiteLLM's
|
||||
# AnthropicCacheControlHook later targets ``role: system``.
|
||||
assert "<surfsense base" in flattened.system_message.content
|
||||
assert "<SubAgentMiddleware" in flattened.system_message.content
|
||||
|
||||
|
||||
def test_regression_human_message_not_modified() -> None:
|
||||
# Sanity: the middleware MUST NOT touch user messages — only the
|
||||
# system message. Multi-block user content is the path that carries
|
||||
# image attachments and would lose its image_url block on
|
||||
# accidental flatten.
|
||||
sys = SystemMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "a"},
|
||||
{"type": "text", "text": "b"},
|
||||
]
|
||||
)
|
||||
user = HumanMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "look at this"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
|
||||
]
|
||||
)
|
||||
request = _make_request(sys)
|
||||
request.messages = [user]
|
||||
|
||||
flattened = _flattened_request(request)
|
||||
assert flattened is not None
|
||||
# System flattened to string …
|
||||
assert isinstance(flattened.system_message.content, str)
|
||||
# … user message is untouched (the helper does not even look at it).
|
||||
assert flattened.messages == [user]
|
||||
assert isinstance(user.content, list)
|
||||
assert len(user.content) == 2
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Tests for ``apply_litellm_prompt_caching`` in
|
||||
r"""Tests for ``apply_litellm_prompt_caching`` in
|
||||
:mod:`app.agents.new_chat.prompt_caching`.
|
||||
|
||||
The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which
|
||||
|
|
@ -6,9 +6,12 @@ never activated for our LiteLLM stack) with LiteLLM-native multi-provider
|
|||
prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
|
||||
``litellm.completion(...)``. The tests below pin its public contract:
|
||||
|
||||
1. Always sets BOTH ``role: system`` and ``index: -1`` injection points so
|
||||
1. Always sets BOTH ``index: 0`` and ``index: -1`` injection points so
|
||||
savings compound across multi-turn conversations on Anthropic-family
|
||||
providers.
|
||||
providers. ``index: 0`` is used (rather than ``role: system``) because
|
||||
the deepagent stack accumulates multiple ``SystemMessage``\ s in
|
||||
``state["messages"]`` and ``role: system`` would tag every one of
|
||||
them, blowing past Anthropic's 4-block ``cache_control`` cap.
|
||||
2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for
|
||||
single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic
|
||||
prompt-cache surface is available).
|
||||
|
|
@ -92,11 +95,28 @@ def test_sets_both_cache_control_injection_points_with_no_config() -> None:
|
|||
apply_litellm_prompt_caching(llm)
|
||||
|
||||
points = llm.model_kwargs["cache_control_injection_points"]
|
||||
assert {"location": "message", "role": "system"} in points
|
||||
assert {"location": "message", "index": 0} in points
|
||||
assert {"location": "message", "index": -1} in points
|
||||
assert len(points) == 2
|
||||
|
||||
|
||||
def test_does_not_inject_role_system_breakpoint() -> None:
|
||||
"""Regression: deliberately AVOID ``role: system`` so we don't tag
|
||||
every SystemMessage the deepagent ``before_agent`` injectors push
|
||||
into ``state["messages"]`` (priority, tree, memory, file-intent,
|
||||
anonymous-doc). Tagging all of them overflows Anthropic's 4-block
|
||||
``cache_control`` cap and surfaces as
|
||||
``OpenrouterException: A maximum of 4 blocks with cache_control may
|
||||
be provided. Found N`` 400s.
|
||||
"""
|
||||
llm = _FakeLLM()
|
||||
apply_litellm_prompt_caching(llm)
|
||||
points = llm.model_kwargs["cache_control_injection_points"]
|
||||
assert all(p.get("role") != "system" for p in points), (
|
||||
f"Expected no role=system breakpoint, got: {points}"
|
||||
)
|
||||
|
||||
|
||||
def test_injection_points_set_for_anthropic_config() -> None:
|
||||
"""Anthropic-family configs need the marker — verify it lands."""
|
||||
cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet")
|
||||
|
|
|
|||
|
|
@ -475,3 +475,190 @@ class TestKBSearchPlanSchema:
|
|||
)
|
||||
)
|
||||
assert plan.is_recency_query is False
|
||||
|
||||
|
||||
# ── mentioned_document_ids cross-turn drain ────────────────────────────


class TestKnowledgePriorityMentionDrain:
    """Regression tests for the cross-turn ``mentioned_document_ids`` drain.

    The compiled-agent cache reuses a single :class:`KnowledgeBaseSearchMiddleware`
    instance across turns of the same thread. ``mentioned_document_ids``
    can therefore enter the middleware via two paths:

    1. The constructor closure (``__init__(mentioned_document_ids=...)``) —
       seeded by the cache-miss build on turn 1.
    2. ``runtime.context.mentioned_document_ids`` — supplied freshly per
       turn by the streaming task.

    Without the drain fix, an empty ``runtime.context.mentioned_document_ids``
    on turn 2 would fall through to the closure (because ``[]`` is falsy in
    Python) and replay turn 1's mentions. This class pins down the
    correct behaviour: the runtime path is authoritative even when empty,
    and the closure is drained the first time the runtime path fires so
    no later turn can ever resurrect stale state.
    """

    @staticmethod
    def _make_runtime(mention_ids: list[int]):
        """Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``."""
        from types import SimpleNamespace

        return SimpleNamespace(
            context=SimpleNamespace(mentioned_document_ids=mention_ids),
        )

    @staticmethod
    def _planner_llm() -> "FakeLLM":
        # Planner returns a stable, non-recency plan so we always land in
        # the hybrid-search branch (where ``fetch_mentioned_documents`` is
        # invoked alongside the main search).
        return FakeLLM(
            json.dumps(
                {
                    "optimized_query": "follow up question",
                    "start_date": None,
                    "end_date": None,
                    "is_recency_query": False,
                }
            )
        )

    @staticmethod
    def _install_search_stubs(monkeypatch) -> list[list[int]]:
        """Patch both knowledge-search entry points; return the fetch recorder.

        Replaces the identical monkeypatch boilerplate previously repeated in
        every test of this class. The returned list accumulates the
        ``document_ids`` of every ``fetch_mentioned_documents`` call.
        """
        fetched_ids: list[list[int]] = []

        async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
            fetched_ids.append(list(document_ids))
            return []

        async def fake_search_knowledge_base(**_kwargs):
            return []

        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
            fake_fetch_mentioned_documents,
        )
        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
            fake_search_knowledge_base,
        )
        return fetched_ids

    @pytest.mark.asyncio
    async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch):
        """Turn 1 with mentions in BOTH closure and runtime context: the
        runtime path wins AND the closure is drained so a future turn
        cannot replay it.
        """
        fetched_ids = self._install_search_stubs(monkeypatch)

        middleware = KnowledgeBaseSearchMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[1, 2, 3],
        )

        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="what is in those docs?")]},
            runtime=self._make_runtime([1, 2, 3]),
        )

        assert fetched_ids == [[1, 2, 3]], (
            "runtime.context mentions must be the source of truth on turn 1"
        )
        assert middleware.mentioned_document_ids == [], (
            "closure must be drained the first time the runtime path fires "
            "so no later turn can replay stale mentions"
        )

    @pytest.mark.asyncio
    async def test_empty_runtime_context_does_not_replay_closure_mentions(
        self, monkeypatch
    ):
        """Regression: turn 2 with NO mentions must not surface turn 1's
        mentions from the constructor closure.

        Before the fix, ``if ctx_mentions:`` treated an empty list as
        absent and fell through to ``elif self.mentioned_document_ids:``,
        replaying turn 1's mentions. This test pins down the corrected
        behaviour.
        """
        fetched_ids = self._install_search_stubs(monkeypatch)

        # Simulate a cached middleware instance whose closure was seeded
        # by a previous turn's cache-miss build (mentions=[1,2,3]).
        middleware = KnowledgeBaseSearchMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[1, 2, 3],
        )

        # Turn 2: streaming task supplies an EMPTY mention list (no
        # mentions on this follow-up turn).
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="what about the next steps?")]},
            runtime=self._make_runtime([]),
        )

        assert fetched_ids == [], (
            "fetch_mentioned_documents must NOT be called when the runtime "
            "context says there are no mentions for this turn"
        )

    @pytest.mark.asyncio
    async def test_legacy_path_fires_only_when_runtime_context_absent(
        self, monkeypatch
    ):
        """Backward-compat: if a caller doesn't supply runtime.context (old
        non-streaming code path), the closure-injected mentions are still
        honoured exactly once and then drained.
        """
        fetched_ids = self._install_search_stubs(monkeypatch)

        middleware = KnowledgeBaseSearchMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[7, 8],
        )

        # First call: no runtime → legacy path uses the closure.
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="initial question")]},
            runtime=None,
        )
        # Second call: still no runtime — closure already drained, so no replay.
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="follow up")]},
            runtime=None,
        )

        assert fetched_ids == [[7, 8]], (
            "legacy path must honour the closure exactly once and then drain it"
        )
        assert middleware.mentioned_document_ids == []
|
||||
|
|
|
|||
|
|
@ -271,6 +271,66 @@ async def test_preflight_skipped_for_auto_router_model():
|
|||
await _preflight_llm(fake_llm)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_settle_speculative_agent_build_swallows_exceptions():
    """``_settle_speculative_agent_build`` must never raise, whatever the task did.

    The helper exists to drain a speculative agent-build task that is being
    discarded (e.g. after a preflight failure) purely so its in-flight
    ``AsyncSession`` usage is released. If any exception — including
    ``CancelledError`` — escaped the helper, the caller's recovery flow
    would be skipped and the session-concurrency hazard the helper guards
    against would reappear. Settling a crashing, succeeding, and
    still-running task must all complete cleanly and leave the task done.
    """
    import asyncio

    from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build

    async def _crashing_build() -> None:
        raise RuntimeError("speculative build crashed")

    async def _successful_build() -> str:
        return "agent"

    async def _in_flight_build() -> None:
        await asyncio.sleep(0.05)

    # Cover all three terminal scenarios: raised, returned, still running.
    for build in (_crashing_build, _successful_build, _in_flight_build):
        pending = asyncio.create_task(build())
        await _settle_speculative_agent_build(pending)
        assert pending.done()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_settle_speculative_agent_build_handles_already_done_task():
    """Settling a task that already finished (ok or failed) must not raise."""
    import asyncio

    from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build

    async def _ok() -> str:
        return "ok"

    async def _bad() -> None:
        raise ValueError("nope")

    ok_task = asyncio.create_task(_ok())
    bad_task = asyncio.create_task(_bad())
    # Yield to the event loop twice so both tasks reach a terminal state
    # before the settle helper ever sees them.
    for _ in range(2):
        await asyncio.sleep(0)

    await _settle_speculative_agent_build(ok_task)
    await _settle_speculative_agent_build(bad_task)
    assert ok_task.result() == "ok"
    # The helper merely observes the failure — it does not clear it — so
    # the original error must still be retrievable via ``.exception()``.
    assert isinstance(bad_task.exception(), ValueError)
|
||||
|
||||
|
||||
def test_stream_exception_classifies_thread_busy():
|
||||
exc = BusyError(request_id="thread-123")
|
||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue