Merge upstream/dev into feature/multi-agent

CREDO23 2026-05-05 01:44:46 +02:00
commit 5119915f4f
278 changed files with 34669 additions and 8970 deletions

View file

@@ -226,6 +226,31 @@ class TestCompose:
# Default block should NOT be present
assert "<knowledge_base_only_policy>" not in prompt
def test_provider_hints_render_with_custom_system_instructions(
self, fixed_today: datetime
) -> None:
"""Regression guard for the always-append decision: provider hints
append AFTER a custom system prompt.
Provider hints are stylistic nudges (parallel tool-call rules,
formatting guidance, etc.) that help the model regardless of
what the system instructions say. Suppressing them when a
custom prompt is set would partially defeat the per-family
prompt machinery.
"""
prompt = compose_system_prompt(
today=fixed_today,
custom_system_instructions="You are a custom assistant.",
model_name="anthropic/claude-3-5-sonnet",
)
assert "You are a custom assistant." in prompt
assert "<provider_hints>" in prompt
# The custom prompt must come BEFORE the provider hints so the
# user's framing isn't drowned out by the stylistic nudges.
assert prompt.index("You are a custom assistant.") < prompt.index(
"<provider_hints>"
)
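# Illustrative sketch (assumed names, not the shipped compose_system_prompt):
# the only thing the test above pins is that the custom prompt lands BEFORE the
# always-appended provider hints. The separator and helper name are guesses.
def _compose_order_sketch(custom_system_instructions: str | None, provider_hints: str) -> str:
    sections: list[str] = []
    if custom_system_instructions:
        sections.append(custom_system_instructions)  # user's framing first
    sections.append(provider_hints)  # e.g. "<provider_hints>...</provider_hints>"
    return "\n\n".join(sections)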
def test_use_default_false_with_no_custom_yields_no_system_block(
self, fixed_today: datetime
) -> None:

View file

@@ -0,0 +1,268 @@
"""Regression tests for the compiled-agent cache.
Covers the cache primitive itself (TTL, LRU, in-flight de-duplication,
build-failure non-caching) and the cache-key signature helpers that
``create_surfsense_deep_agent`` relies on. The integration with
``create_surfsense_deep_agent`` is covered separately by the streaming
contract tests; this module focuses on the primitives so a regression
in the cache implementation is caught before it reaches the agent
factory.
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
import pytest
from app.agents.new_chat.agent_cache import (
flags_signature,
reload_for_tests,
stable_hash,
system_prompt_hash,
tools_signature,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# stable_hash + signature helpers
# ---------------------------------------------------------------------------
def test_stable_hash_is_deterministic_across_calls() -> None:
a = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
b = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
assert a == b
def test_stable_hash_changes_when_any_part_changes() -> None:
base = stable_hash("v1", 42, "thread-9")
assert stable_hash("v1", 42, "thread-10") != base
assert stable_hash("v2", 42, "thread-9") != base
assert stable_hash("v1", 43, "thread-9") != base
def test_tools_signature_keys_on_name_and_description_not_identity() -> None:
"""Two tool lists with the same surface must hash identically.
The cache key MUST NOT change when the underlying ``BaseTool``
instances are different Python objects (a fresh request constructs
fresh tool instances every time). Hashing on ``(name, description)``
keeps the cache hot across requests with identical tool surfaces.
"""
@dataclass
class FakeTool:
name: str
description: str
tools_a = [FakeTool("alpha", "does alpha"), FakeTool("beta", "does beta")]
tools_b = [FakeTool("beta", "does beta"), FakeTool("alpha", "does alpha")]
sig_a = tools_signature(
tools_a, available_connectors=["NOTION"], available_document_types=["FILE"]
)
sig_b = tools_signature(
tools_b, available_connectors=["NOTION"], available_document_types=["FILE"]
)
assert sig_a == sig_b, "tool order must not affect the signature"
# Adding a tool rotates the key.
tools_c = [*tools_a, FakeTool("gamma", "does gamma")]
sig_c = tools_signature(
tools_c, available_connectors=["NOTION"], available_document_types=["FILE"]
)
assert sig_c != sig_a
def test_tools_signature_rotates_when_connector_set_changes() -> None:
@dataclass
class FakeTool:
name: str
description: str
tools = [FakeTool("a", "x")]
base = tools_signature(
tools, available_connectors=["NOTION"], available_document_types=["FILE"]
)
added = tools_signature(
tools,
available_connectors=["NOTION", "SLACK"],
available_document_types=["FILE"],
)
assert base != added, "adding a connector must rotate the cache key"
def test_flags_signature_changes_when_flag_flips() -> None:
@dataclass(frozen=True)
class Flags:
a: bool = True
b: bool = False
base = flags_signature(Flags())
flipped = flags_signature(Flags(b=True))
assert base != flipped
def test_system_prompt_hash_is_stable_and_distinct() -> None:
p1 = "You are a helpful assistant."
p2 = "You are a helpful assistant!" # one-character delta
assert system_prompt_hash(p1) == system_prompt_hash(p1)
assert system_prompt_hash(p1) != system_prompt_hash(p2)
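# ---------------------------------------------------------------------------
# Illustrative sketch of the signature helpers, consistent with the properties
# asserted above (deterministic, order-insensitive over tools, sensitive to
# every part). The "_sketch" names and the JSON + sha256 normalisation are
# assumptions; the shipped helpers in agent_cache.py may differ in detail.
# ---------------------------------------------------------------------------
import hashlib
import json
from dataclasses import asdict, is_dataclass
from typing import Any, Iterable


def stable_hash_sketch(*parts: Any) -> str:
    # sort_keys + default=str gives a stable textual form for strings, ints,
    # None and lists; sha256 keeps the resulting key short and fixed-width.
    payload = json.dumps(parts, sort_keys=True, default=str)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def tools_signature_sketch(
    tools: Iterable[Any],
    *,
    available_connectors: list[str],
    available_document_types: list[str],
) -> str:
    # Key on the (name, description) surface, sorted so neither object
    # identity nor list order rotates the cache key.
    surface = sorted((t.name, t.description) for t in tools)
    return stable_hash_sketch(
        surface, sorted(available_connectors), sorted(available_document_types)
    )


def flags_signature_sketch(flags: Any) -> str:
    data = asdict(flags) if is_dataclass(flags) else vars(flags)
    return stable_hash_sketch(sorted(data.items()))


def system_prompt_hash_sketch(prompt: str) -> str:
    return hashlib.sha256(prompt.encode("utf-8")).hexdigest()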
# ---------------------------------------------------------------------------
# _AgentCache: hit / miss / TTL / LRU / coalescing / failure-not-cached
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cache_hit_returns_same_instance_on_second_call() -> None:
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
builds = 0
async def builder() -> object:
nonlocal builds
builds += 1
return object()
a = await cache.get_or_build("k", builder=builder)
b = await cache.get_or_build("k", builder=builder)
assert a is b, "cache must return the SAME object across hits"
assert builds == 1, "builder must run exactly once"
@pytest.mark.asyncio
async def test_cache_different_keys_get_different_instances() -> None:
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
async def builder() -> object:
return object()
a = await cache.get_or_build("k1", builder=builder)
b = await cache.get_or_build("k2", builder=builder)
assert a is not b
@pytest.mark.asyncio
async def test_cache_stale_entries_get_rebuilt() -> None:
# ttl=0 means every read sees the entry as immediately stale.
cache = reload_for_tests(maxsize=8, ttl_seconds=0.0)
builds = 0
async def builder() -> object:
nonlocal builds
builds += 1
return object()
a = await cache.get_or_build("k", builder=builder)
b = await cache.get_or_build("k", builder=builder)
assert a is not b, "stale entry must rebuild a fresh instance"
assert builds == 2
@pytest.mark.asyncio
async def test_cache_evicts_lru_when_full() -> None:
cache = reload_for_tests(maxsize=2, ttl_seconds=60.0)
async def builder() -> object:
return object()
a = await cache.get_or_build("a", builder=builder)
_ = await cache.get_or_build("b", builder=builder)
# Re-touch "a" so "b" is now the LRU victim.
a_again = await cache.get_or_build("a", builder=builder)
assert a_again is a
# Inserting "c" should evict "b" (LRU), not "a".
_ = await cache.get_or_build("c", builder=builder)
assert cache.stats()["size"] == 2
# Confirm "a" is still hot (no rebuild) and "b" is gone (rebuild).
a_hit = await cache.get_or_build("a", builder=builder)
assert a_hit is a, "LRU must keep the most-recently-used 'a' entry"
@pytest.mark.asyncio
async def test_cache_concurrent_misses_coalesce_to_single_build() -> None:
"""Two concurrent get_or_build calls on the same key must share one builder."""
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
build_started = asyncio.Event()
builds = 0
async def slow_builder() -> object:
nonlocal builds
builds += 1
build_started.set()
# Yield control so the second waiter can race against us.
await asyncio.sleep(0.05)
return object()
task_a = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
# Wait until the first builder has started, then race a second waiter.
await build_started.wait()
task_b = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
a, b = await asyncio.gather(task_a, task_b)
assert a is b, "coalesced waiters must observe the same value"
assert builds == 1, "concurrent cold misses must collapse to ONE build"
@pytest.mark.asyncio
async def test_cache_does_not_store_failed_builds() -> None:
"""A builder that raises must NOT poison the cache.
The next caller for the same key must run the builder again (not
re-raise the cached exception).
"""
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
attempts = 0
async def flaky_builder() -> object:
nonlocal attempts
attempts += 1
if attempts == 1:
raise RuntimeError("transient")
return object()
with pytest.raises(RuntimeError, match="transient"):
await cache.get_or_build("k", builder=flaky_builder)
# Second call must retry — not re-raise the cached exception.
value = await cache.get_or_build("k", builder=flaky_builder)
assert value is not None
assert attempts == 2
@pytest.mark.asyncio
async def test_cache_invalidate_drops_entry() -> None:
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
async def builder() -> object:
return object()
a = await cache.get_or_build("k", builder=builder)
assert cache.invalidate("k") is True
b = await cache.get_or_build("k", builder=builder)
assert a is not b, "post-invalidation lookup must rebuild"
@pytest.mark.asyncio
async def test_cache_invalidate_prefix_drops_matching_entries() -> None:
cache = reload_for_tests(maxsize=16, ttl_seconds=60.0)
async def builder() -> object:
return object()
await cache.get_or_build("user:1:thread:1", builder=builder)
await cache.get_or_build("user:1:thread:2", builder=builder)
await cache.get_or_build("user:2:thread:1", builder=builder)
removed = cache.invalidate_prefix("user:1:")
assert removed == 2
assert cache.stats()["size"] == 1
# The user:2 entry must still be hot (no rebuild).
survivor_value = await cache.get_or_build("user:2:thread:1", builder=builder)
assert survivor_value is not None
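# ---------------------------------------------------------------------------
# Illustrative cache sketch satisfying the contract pinned above: TTL-stale
# entries rebuild, LRU eviction on overflow, concurrent misses on one key
# coalesce behind a per-key lock, and failed builds are never stored. The
# shipped _AgentCache may differ in structure; this is a reference shape only.
# ---------------------------------------------------------------------------
import time
from collections import OrderedDict
from typing import Awaitable, Callable


class _AgentCacheSketch:
    def __init__(self, maxsize: int, ttl_seconds: float) -> None:
        self._maxsize = maxsize
        self._ttl = ttl_seconds
        self._entries: OrderedDict[str, tuple[float, object]] = OrderedDict()
        self._locks: dict[str, asyncio.Lock] = {}

    async def get_or_build(
        self, key: str, *, builder: Callable[[], Awaitable[object]]
    ) -> object:
        # One lock per key: a second caller on the same key waits for the
        # first build instead of starting its own (miss coalescing).
        lock = self._locks.setdefault(key, asyncio.Lock())
        async with lock:
            entry = self._entries.get(key)
            if entry is not None:
                built_at, value = entry
                if time.monotonic() - built_at < self._ttl:
                    self._entries.move_to_end(key)  # refresh LRU position
                    return value
                del self._entries[key]  # stale: fall through and rebuild
            value = await builder()  # an exception propagates; nothing is cached
            self._entries[key] = (time.monotonic(), value)
            while len(self._entries) > self._maxsize:
                self._entries.popitem(last=False)  # evict least-recently-used
            return value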

View file

@@ -7,7 +7,9 @@ import pytest
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import (
BusyMutexMiddleware,
end_turn,
get_cancel_event,
is_cancel_requested,
manager,
request_cancel,
reset_cancel,
@@ -88,3 +90,65 @@ async def test_no_thread_id_skipped_when_not_required() -> None:
def test_reset_cancel_idempotent() -> None:
# Should not raise even if event was never created
reset_cancel("never-seen")
def test_request_cancel_creates_event_for_unseen_thread() -> None:
thread_id = "never-seen-cancel"
reset_cancel(thread_id)
assert request_cancel(thread_id) is True
assert get_cancel_event(thread_id).is_set()
assert is_cancel_requested(thread_id) is True
@pytest.mark.asyncio
async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
thread_id = "forced-end-turn"
mw = BusyMutexMiddleware()
runtime = _Runtime(thread_id)
await mw.abefore_agent({}, runtime)
assert manager.lock_for(thread_id).locked()
request_cancel(thread_id)
assert is_cancel_requested(thread_id) is True
end_turn(thread_id)
assert not manager.lock_for(thread_id).locked()
assert not get_cancel_event(thread_id).is_set()
assert is_cancel_requested(thread_id) is False
@pytest.mark.asyncio
async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None:
"""A stale aafter call from attempt A must not unlock attempt B.
Repro flow:
1) attempt A acquires thread lock
2) forced end_turn clears A so retry can proceed
3) attempt B acquires same thread lock
4) stale attempt-A aafter runs late
Expected: B lock remains held.
"""
thread_id = "stale-aafter-lock"
runtime = _Runtime(thread_id)
attempt_a = BusyMutexMiddleware()
attempt_b = BusyMutexMiddleware()
await attempt_a.abefore_agent({}, runtime)
lock = manager.lock_for(thread_id)
assert lock.locked()
end_turn(thread_id)
assert not lock.locked()
await attempt_b.abefore_agent({}, runtime)
assert lock.locked()
# Stale cleanup from attempt A must not release attempt B's lock.
await attempt_a.aafter_agent({}, runtime)
assert lock.locked()
await attempt_b.aafter_agent({}, runtime)
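# Illustrative sketch of the stale-release guard exercised above (the shipped
# middleware splits this across ``manager`` and the module-level ``end_turn``;
# the generation-counter idea and all names here are assumptions). Each
# acquisition bumps a per-thread generation; aafter only releases when the
# generation it captured is still current, so a late cleanup from a superseded
# attempt becomes a no-op.
import asyncio


class _BusyMutexSketch:
    _locks: dict[str, asyncio.Lock] = {}
    _generations: dict[str, int] = {}

    async def abefore_agent_sketch(self, thread_id: str) -> None:
        lock = self._locks.setdefault(thread_id, asyncio.Lock())
        await lock.acquire()
        self._generations[thread_id] = self._generations.get(thread_id, 0) + 1
        self._my_generation = self._generations[thread_id]

    async def aafter_agent_sketch(self, thread_id: str) -> None:
        lock = self._locks.get(thread_id)
        # Stale guard: only the attempt owning the current generation releases.
        if lock and lock.locked() and self._generations.get(thread_id) == self._my_generation:
            lock.release()

    @classmethod
    def end_turn_sketch(cls, thread_id: str) -> None:
        # Forced cleanup: bump the generation so the superseded attempt's late
        # aafter cannot release the next attempt's lock, then free the lock.
        cls._generations[thread_id] = cls._generations.get(thread_id, 0) + 1
        lock = cls._locks.get(thread_id)
        if lock and lock.locked():
            lock.release()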

View file

@@ -31,18 +31,45 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"SURFSENSE_ENABLE_ACTION_LOG",
"SURFSENSE_ENABLE_REVERT_ROUTE",
"SURFSENSE_ENABLE_STREAM_PARITY_V2",
"SURFSENSE_ENABLE_PLUGIN_LOADER",
"SURFSENSE_ENABLE_OTEL",
"SURFSENSE_ENABLE_AGENT_CACHE",
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT",
]:
monkeypatch.delenv(name, raising=False)
def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None:
def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> None:
_clear_all(monkeypatch)
flags = reload_for_tests()
assert isinstance(flags, AgentFeatureFlags)
assert flags.disable_new_agent_stack is False
assert flags.any_new_middleware_enabled() is False
assert flags.enable_context_editing is True
assert flags.enable_compaction_v2 is True
assert flags.enable_retry_after is True
assert flags.enable_model_fallback is False
assert flags.enable_model_call_limit is True
assert flags.enable_tool_call_limit is True
assert flags.enable_tool_call_repair is True
assert flags.enable_doom_loop is True
assert flags.enable_permission is True
assert flags.enable_busy_mutex is True
assert flags.enable_llm_tool_selector is False
assert flags.enable_skills is True
assert flags.enable_specialized_subagents is True
assert flags.enable_kb_planner_runnable is True
assert flags.enable_action_log is True
assert flags.enable_revert_route is True
assert flags.enable_stream_parity_v2 is True
assert flags.enable_plugin_loader is False
assert flags.enable_otel is False
# Phase 2: agent cache is now default-on (the prerequisite tool
# ``db_session`` refactor landed). The companion gp-subagent share
# flag stays default-off pending data on cold-miss frequency.
assert flags.enable_agent_cache is True
assert flags.enable_agent_cache_share_gp_subagent is False
assert flags.any_new_middleware_enabled() is True
def test_master_kill_switch_overrides_individual_flags(
@@ -100,21 +127,13 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
"enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2",
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
"enable_otel": "SURFSENSE_ENABLE_OTEL",
}
# `enable_otel` is intentionally orthogonal — it does NOT count toward
# ``any_new_middleware_enabled`` because OTel is observability-only and
# ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement.
counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"}
for attr, env_name in flag_to_env.items():
_clear_all(monkeypatch)
monkeypatch.setenv(env_name, "true")
monkeypatch.setenv(env_name, "false")
flags = reload_for_tests()
assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}"
if attr in counts_toward_middleware:
assert flags.any_new_middleware_enabled() is True
else:
assert flags.any_new_middleware_enabled() is False
assert getattr(flags, attr) is False, f"{attr} did not flip off for {env_name}"
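# Illustrative sketch of how one flag maps to its SURFSENSE_ENABLE_* variable,
# consistent with the flip-on / flip-off behaviour exercised above. The names,
# defaults shown, and the parsing rule are assumptions; the real
# AgentFeatureFlags carries the full field list asserted in the defaults test.
import os
from dataclasses import dataclass, field


def _env_bool_sketch(name: str, default: bool) -> bool:
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


@dataclass(frozen=True)
class _FlagsSketch:
    enable_agent_cache: bool = field(
        default_factory=lambda: _env_bool_sketch("SURFSENSE_ENABLE_AGENT_CACHE", True)
    )
    enable_otel: bool = field(
        default_factory=lambda: _env_bool_sketch("SURFSENSE_ENABLE_OTEL", False)
    )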

View file

@@ -0,0 +1,344 @@
"""Tests for ``FlattenSystemMessageMiddleware``.
The middleware exists to defend against Anthropic's "Found 5 cache_control
blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on
the system message and the OpenRouterAnthropic adapter redistributes
``cache_control`` across all of them. The flattening collapses every
all-text system content list to a single string before the LLM call.
"""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import HumanMessage, SystemMessage
from app.agents.new_chat.middleware.flatten_system import (
FlattenSystemMessageMiddleware,
_flatten_text_blocks,
_flattened_request,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# _flatten_text_blocks — pure helper, the heart of the middleware.
# ---------------------------------------------------------------------------
class TestFlattenTextBlocks:
def test_joins_text_blocks_with_double_newline(self) -> None:
blocks = [
{"type": "text", "text": "<surfsense base>"},
{"type": "text", "text": "<filesystem section>"},
{"type": "text", "text": "<skills section>"},
]
assert (
_flatten_text_blocks(blocks)
== "<surfsense base>\n\n<filesystem section>\n\n<skills section>"
)
def test_handles_single_text_block(self) -> None:
blocks = [{"type": "text", "text": "only one"}]
assert _flatten_text_blocks(blocks) == "only one"
def test_handles_empty_list(self) -> None:
assert _flatten_text_blocks([]) == ""
def test_passes_through_bare_string_blocks(self) -> None:
# LangChain content can mix bare strings and dict blocks.
blocks = ["raw string", {"type": "text", "text": "dict block"}]
assert _flatten_text_blocks(blocks) == "raw string\n\ndict block"
def test_returns_none_for_image_block(self) -> None:
# System messages with images are rare — but we never want to
# silently lose the image payload by joining as text.
blocks = [
{"type": "text", "text": "look at this"},
{"type": "image_url", "image_url": {"url": "data:image/png..."}},
]
assert _flatten_text_blocks(blocks) is None
def test_returns_none_for_non_dict_non_str_block(self) -> None:
blocks = [{"type": "text", "text": "hi"}, 42] # type: ignore[list-item]
assert _flatten_text_blocks(blocks) is None
def test_returns_none_when_text_field_missing(self) -> None:
blocks = [{"type": "text"}] # no ``text`` key
assert _flatten_text_blocks(blocks) is None
def test_returns_none_when_text_is_not_string(self) -> None:
blocks = [{"type": "text", "text": ["nested", "list"]}]
assert _flatten_text_blocks(blocks) is None
def test_drops_cache_control_from_inner_blocks(self) -> None:
# The whole point: existing cache_control on inner blocks is
# discarded so LiteLLM's ``cache_control_injection_points`` can
# re-attach exactly one breakpoint after flattening.
blocks = [
{"type": "text", "text": "first"},
{
"type": "text",
"text": "second",
"cache_control": {"type": "ephemeral"},
},
]
flattened = _flatten_text_blocks(blocks)
assert flattened == "first\n\nsecond"
assert "cache_control" not in flattened # type: ignore[operator]
# ---------------------------------------------------------------------------
# _flattened_request — decides when to override and when to no-op.
# ---------------------------------------------------------------------------
def _make_request(system_message: SystemMessage | None) -> Any:
"""Build a minimal ModelRequest stub. We only need .system_message
and .override(system_message=...) the middleware never touches
other fields.
"""
request = MagicMock()
request.system_message = system_message
def override(**kwargs: Any) -> Any:
new_request = MagicMock()
new_request.system_message = kwargs.get(
"system_message", request.system_message
)
new_request.messages = kwargs.get("messages", getattr(request, "messages", []))
new_request.tools = kwargs.get("tools", getattr(request, "tools", []))
return new_request
request.override = override
return request
class TestFlattenedRequest:
def test_collapses_multi_block_system_to_string(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "<base>"},
{"type": "text", "text": "<todo>"},
{"type": "text", "text": "<filesystem>"},
{"type": "text", "text": "<skills>"},
{"type": "text", "text": "<subagents>"},
]
)
request = _make_request(sys)
flattened = _flattened_request(request)
assert flattened is not None
assert isinstance(flattened.system_message, SystemMessage)
assert flattened.system_message.content == (
"<base>\n\n<todo>\n\n<filesystem>\n\n<skills>\n\n<subagents>"
)
def test_no_op_for_string_content(self) -> None:
sys = SystemMessage(content="already a string")
request = _make_request(sys)
assert _flattened_request(request) is None
def test_no_op_for_single_block_list(self) -> None:
# One block already produces one breakpoint — no need to flatten.
sys = SystemMessage(content=[{"type": "text", "text": "single"}])
request = _make_request(sys)
assert _flattened_request(request) is None
def test_no_op_when_system_message_missing(self) -> None:
request = _make_request(None)
assert _flattened_request(request) is None
def test_no_op_when_list_contains_non_text_block(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "look"},
{"type": "image_url", "image_url": {"url": "data:..."}},
]
)
request = _make_request(sys)
assert _flattened_request(request) is None
def test_preserves_additional_kwargs_and_metadata(self) -> None:
# Defensive: nothing in the current chain sets these on a system
# message, but losing them silently when something does in the
# future would be a regression. ``name`` in particular is the only
# ``additional_kwargs`` field that ChatLiteLLM's
# ``_convert_message_to_dict`` propagates onto the wire.
sys = SystemMessage(
content=[
{"type": "text", "text": "a"},
{"type": "text", "text": "b"},
],
additional_kwargs={"name": "surfsense_system", "x": 1},
response_metadata={"tokens": 42},
)
sys.id = "sys-msg-1"
request = _make_request(sys)
flattened = _flattened_request(request)
assert flattened is not None
assert flattened.system_message.content == "a\n\nb"
assert flattened.system_message.additional_kwargs == {
"name": "surfsense_system",
"x": 1,
}
assert flattened.system_message.response_metadata == {"tokens": 42}
assert flattened.system_message.id == "sys-msg-1"
def test_idempotent_when_run_twice(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "a"},
{"type": "text", "text": "b"},
]
)
request = _make_request(sys)
first = _flattened_request(request)
assert first is not None
# Second pass on the already-flattened request should be a no-op.
# We re-wrap in a request stub since the helper inspects
# ``request.system_message.content``.
second_request = _make_request(first.system_message)
assert _flattened_request(second_request) is None
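# A sketch of the override decision pinned above, reusing the
# _flatten_text_blocks_sketch shown earlier: only a system message whose
# content is a multi-element, all-text block list gets collapsed; everything
# else returns None so the original request passes through untouched. The
# metadata carry-over mirrors what the preservation test asserts; the real
# helper may differ in detail.
def _flattened_request_sketch(request: Any) -> Any | None:
    system = getattr(request, "system_message", None)
    if system is None or not isinstance(system.content, list) or len(system.content) < 2:
        return None  # string content, single block, or no system message
    flattened = _flatten_text_blocks_sketch(system.content)
    if flattened is None:
        return None  # non-text block present: refuse to touch it
    new_system = SystemMessage(
        content=flattened,
        additional_kwargs=dict(system.additional_kwargs),
        response_metadata=dict(system.response_metadata),
        id=system.id,
    )
    return request.override(system_message=new_system)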
# ---------------------------------------------------------------------------
# Middleware integration — verify the handler sees a flattened request.
# ---------------------------------------------------------------------------
class TestMiddlewareWrap:
@pytest.mark.asyncio
async def test_async_passes_flattened_request_to_handler(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "alpha"},
{"type": "text", "text": "beta"},
]
)
request = _make_request(sys)
captured: dict[str, Any] = {}
async def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
result = await mw.awrap_model_call(request, handler)
assert result == "ok"
assert isinstance(captured["request"].system_message, SystemMessage)
assert captured["request"].system_message.content == "alpha\n\nbeta"
@pytest.mark.asyncio
async def test_async_passes_through_when_already_string(self) -> None:
sys = SystemMessage(content="just a string")
request = _make_request(sys)
captured: dict[str, Any] = {}
async def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
await mw.awrap_model_call(request, handler)
# Same request object: no override happened.
assert captured["request"] is request
def test_sync_passes_flattened_request_to_handler(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "alpha"},
{"type": "text", "text": "beta"},
]
)
request = _make_request(sys)
captured: dict[str, Any] = {}
def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
result = mw.wrap_model_call(request, handler)
assert result == "ok"
assert captured["request"].system_message.content == "alpha\n\nbeta"
def test_sync_passes_through_when_no_system_message(self) -> None:
request = _make_request(None)
captured: dict[str, Any] = {}
def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
mw.wrap_model_call(request, handler)
assert captured["request"] is request
# ---------------------------------------------------------------------------
# Regression guard — pin the worst-case shape that triggered the
# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the
# downstream cache_control_injection_points can only place 1 breakpoint
# on the system message regardless of provider redistribution quirks.
# ---------------------------------------------------------------------------
def test_regression_five_block_system_collapses_to_one_block() -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "<surfsense base + BASE_AGENT_PROMPT>"},
{"type": "text", "text": "<TodoListMiddleware section>"},
{"type": "text", "text": "<SurfSenseFilesystemMiddleware section>"},
{"type": "text", "text": "<SkillsMiddleware section>"},
{"type": "text", "text": "<SubAgentMiddleware section>"},
]
)
request = _make_request(sys)
flattened = _flattened_request(request)
assert flattened is not None
assert isinstance(flattened.system_message.content, str)
# The exact join doesn't matter for the cache_control accounting —
# only that there is exactly ONE content block when LiteLLM's
# AnthropicCacheControlHook later targets ``role: system``.
assert "<surfsense base" in flattened.system_message.content
assert "<SubAgentMiddleware" in flattened.system_message.content
def test_regression_human_message_not_modified() -> None:
# Sanity: the middleware MUST NOT touch user messages — only the
# system message. Multi-block user content is the path that carries
# image attachments and would lose its image_url block on
# accidental flatten.
sys = SystemMessage(
content=[
{"type": "text", "text": "a"},
{"type": "text", "text": "b"},
]
)
user = HumanMessage(
content=[
{"type": "text", "text": "look at this"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
]
)
request = _make_request(sys)
request.messages = [user]
flattened = _flattened_request(request)
assert flattened is not None
# System flattened to string …
assert isinstance(flattened.system_message.content, str)
# … user message is untouched (the helper does not even look at it).
assert flattened.messages == [user]
assert isinstance(user.content, list)
assert len(user.content) == 2
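# The wire shapes the regression guards above protect, sketched with
# illustrative values: before flattening, each of the five system text blocks
# can end up carrying its own cache_control; after flattening there is exactly
# one block for LiteLLM's injection point to target.
_before_flatten_sketch = {
    "role": "system",
    "content": [
        {"type": "text", "text": "<surfsense base>", "cache_control": {"type": "ephemeral"}},
        {"type": "text", "text": "<todo>", "cache_control": {"type": "ephemeral"}},
        # three more blocks, each risking its own cache_control breakpoint
    ],
}
_after_flatten_sketch = {
    "role": "system",
    # Single string content: at most one cache_control breakpoint can attach.
    "content": "<surfsense base>\n\n<todo>\n\n<filesystem>\n\n<skills>\n\n<subagents>",
}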

View file

@@ -27,6 +27,7 @@ class TestDefaultAutoApprovedToolsList:
expected = {
"create_gmail_draft",
"update_gmail_draft",
"create_calendar_event",
"create_notion_page",
"create_confluence_page",
"create_google_drive_file",
@@ -41,13 +42,12 @@ class TestDefaultAutoApprovedToolsList:
assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset)
def test_send_tools_are_not_auto_approved(self) -> None:
# External-broadcast tools must always prompt.
# External-broadcast / destructive tools must always prompt.
for tool_name in (
"send_gmail_email",
"send_discord_message",
"send_teams_message",
"delete_notion_page",
"create_calendar_event",
"delete_calendar_event",
):
assert tool_name not in DEFAULT_AUTO_APPROVED_TOOLS, (

View file

@@ -0,0 +1,370 @@
r"""Tests for ``apply_litellm_prompt_caching`` in
:mod:`app.agents.new_chat.prompt_caching`.
The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which
never activated for our LiteLLM stack) with LiteLLM-native multi-provider
prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
``litellm.completion(...)``. The tests below pin its public contract:
1. Always sets BOTH ``index: 0`` and ``index: -1`` injection points so
savings compound across multi-turn conversations on Anthropic-family
providers. ``index: 0`` is used (rather than ``role: system``) because
the deepagent stack accumulates multiple ``SystemMessage``\ s in
``state["messages"]`` and ``role: system`` would tag every one of
them, blowing past Anthropic's 4-block ``cache_control`` cap.
2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for
single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic
prompt-cache surface is available).
3. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only: no
OpenAI-only kwargs, because the router fans out across providers.
4. Idempotent: user-supplied values in ``model_kwargs`` are preserved.
5. Defensive: LLMs without a writable ``model_kwargs`` are silently
skipped rather than raising.
"""
from __future__ import annotations
from typing import Any
import pytest
from app.agents.new_chat.llm_config import AgentConfig
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Test doubles
# ---------------------------------------------------------------------------
class _FakeLLM:
"""Stand-in for ``ChatLiteLLM``/``SanitizedChatLiteLLM``.
The helper only inspects ``getattr(llm, "model_kwargs", None)``,
``getattr(llm, "model", None)``, and ``type(llm).__name__``. A simple
object suffices; we don't need to spin up real LangChain/LiteLLM
machinery for unit tests of the helper's logic.
"""
def __init__(
self,
model: str = "openai/gpt-4o",
model_kwargs: dict[str, Any] | None = None,
) -> None:
self.model = model
self.model_kwargs: dict[str, Any] = dict(model_kwargs) if model_kwargs else {}
class ChatLiteLLMRouter:
"""Class-name-only impostor of the real router.
The helper's router gate is ``type(llm).__name__ == "ChatLiteLLMRouter"``
(a deliberate stringly-typed check to avoid an import cycle with
``app.services.llm_router_service``). Reusing the same class name here
triggers the same code path without instantiating a real ``Router``.
"""
def __init__(self) -> None:
self.model = "auto"
self.model_kwargs: dict[str, Any] = {}
def _make_cfg(**overrides: Any) -> AgentConfig:
"""Build an ``AgentConfig`` with sensible defaults for the helper test."""
defaults: dict[str, Any] = {
"provider": "OPENAI",
"model_name": "gpt-4o",
"api_key": "k",
}
return AgentConfig(**{**defaults, **overrides})
# ---------------------------------------------------------------------------
# (a) Universal injection points
# ---------------------------------------------------------------------------
def test_sets_both_cache_control_injection_points_with_no_config() -> None:
"""Bare call (no agent_config, no thread_id) still sets the two
universal breakpoints; these cost nothing on providers that don't
consume them and unlock caching on every supported provider."""
llm = _FakeLLM()
apply_litellm_prompt_caching(llm)
points = llm.model_kwargs["cache_control_injection_points"]
assert {"location": "message", "index": 0} in points
assert {"location": "message", "index": -1} in points
assert len(points) == 2
def test_does_not_inject_role_system_breakpoint() -> None:
"""Regression: deliberately AVOID ``role: system`` so we don't tag
every SystemMessage the deepagent ``before_agent`` injectors push
into ``state["messages"]`` (priority, tree, memory, file-intent,
anonymous-doc). Tagging all of them overflows Anthropic's 4-block
``cache_control`` cap and surfaces as
``OpenrouterException: A maximum of 4 blocks with cache_control may
be provided. Found N`` 400s.
"""
llm = _FakeLLM()
apply_litellm_prompt_caching(llm)
points = llm.model_kwargs["cache_control_injection_points"]
assert all(p.get("role") != "system" for p in points), (
f"Expected no role=system breakpoint, got: {points}"
)
def test_injection_points_set_for_anthropic_config() -> None:
"""Anthropic-family configs need the marker — verify it lands."""
cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet")
llm = _FakeLLM(model="anthropic/claude-3-5-sonnet")
apply_litellm_prompt_caching(llm, agent_config=cfg)
assert "cache_control_injection_points" in llm.model_kwargs
# ---------------------------------------------------------------------------
# (b) Idempotency / user override wins
# ---------------------------------------------------------------------------
def test_does_not_overwrite_user_supplied_cache_control_injection_points() -> None:
"""Users who set their own injection points (e.g. with ``ttl: "1h"``
via ``litellm_params``) keep them; the helper merges, never
clobbers."""
user_points = [
{"location": "message", "role": "system", "ttl": "1h"},
]
llm = _FakeLLM(
model_kwargs={"cache_control_injection_points": user_points},
)
apply_litellm_prompt_caching(llm)
assert llm.model_kwargs["cache_control_injection_points"] is user_points
def test_idempotent_when_called_multiple_times() -> None:
"""Build-time + thread-time double-call must be a no-op the second time."""
cfg = _make_cfg(provider="OPENAI")
llm = _FakeLLM()
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1)
snapshot = {
"cache_control_injection_points": list(
llm.model_kwargs["cache_control_injection_points"]
),
"prompt_cache_key": llm.model_kwargs["prompt_cache_key"],
"prompt_cache_retention": llm.model_kwargs["prompt_cache_retention"],
}
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1)
assert (
llm.model_kwargs["cache_control_injection_points"]
== snapshot["cache_control_injection_points"]
)
assert llm.model_kwargs["prompt_cache_key"] == snapshot["prompt_cache_key"]
assert (
llm.model_kwargs["prompt_cache_retention"] == snapshot["prompt_cache_retention"]
)
def test_does_not_overwrite_user_supplied_prompt_cache_key() -> None:
"""A pre-set ``prompt_cache_key`` (e.g. tenant-aware override via
``litellm_params``) wins over our default per-thread key."""
cfg = _make_cfg(provider="OPENAI")
llm = _FakeLLM(model_kwargs={"prompt_cache_key": "tenant-abc"})
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
assert llm.model_kwargs["prompt_cache_key"] == "tenant-abc"
# ---------------------------------------------------------------------------
# (c) OpenAI-family extras (OPENAI / DEEPSEEK / XAI)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"])
def test_sets_openai_family_extras(provider: str) -> None:
"""OpenAI-style providers gain ``prompt_cache_key`` (raises hit rate
via routing affinity) and ``prompt_cache_retention="24h"`` (extends
cache TTL beyond the default 5-10 min)."""
cfg = _make_cfg(provider=provider)
llm = _FakeLLM()
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
assert llm.model_kwargs["prompt_cache_key"] == "surfsense-thread-42"
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
def test_skips_prompt_cache_key_when_no_thread_id() -> None:
"""Without a thread id we can't construct a per-thread key. Retention
is still useful so we set it (it's free)."""
cfg = _make_cfg(provider="OPENAI")
llm = _FakeLLM()
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=None)
assert "prompt_cache_key" not in llm.model_kwargs
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
@pytest.mark.parametrize(
"provider",
["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"],
)
def test_no_openai_extras_for_other_providers(provider: str) -> None:
"""Non-OpenAI-family providers don't expose ``prompt_cache_key`` —
skip it. ``cache_control_injection_points`` is still set (universal)."""
cfg = _make_cfg(provider=provider)
llm = _FakeLLM()
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
assert "prompt_cache_key" not in llm.model_kwargs
assert "prompt_cache_retention" not in llm.model_kwargs
assert "cache_control_injection_points" in llm.model_kwargs
def test_no_openai_extras_in_auto_mode() -> None:
"""Auto-mode fans out across mixed providers — we can't statically
target OpenAI-only kwargs."""
cfg = AgentConfig.from_auto_mode()
llm = _FakeLLM()
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
assert "prompt_cache_key" not in llm.model_kwargs
assert "prompt_cache_retention" not in llm.model_kwargs
assert "cache_control_injection_points" in llm.model_kwargs
def test_no_openai_extras_for_custom_provider() -> None:
"""Custom providers route through arbitrary user-supplied prefixes —
we don't try to infer OpenAI-family compatibility."""
cfg = _make_cfg(provider="OPENAI", custom_provider="my_proxy")
llm = _FakeLLM()
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
assert "prompt_cache_key" not in llm.model_kwargs
assert "prompt_cache_retention" not in llm.model_kwargs
# ---------------------------------------------------------------------------
# (d) ChatLiteLLMRouter — universal injection points only
# ---------------------------------------------------------------------------
def test_router_llm_gets_only_universal_injection_points() -> None:
"""Even with an OpenAI-flavoured config, a ``ChatLiteLLMRouter`` must
receive only the universal injection points; its requests dispatch
across provider deployments, and OpenAI-only kwargs would be wasted
(or stripped by ``drop_params``) on non-OpenAI legs."""
router = ChatLiteLLMRouter()
cfg = _make_cfg(provider="OPENAI")
apply_litellm_prompt_caching(router, agent_config=cfg, thread_id=42)
assert "cache_control_injection_points" in router.model_kwargs
assert "prompt_cache_key" not in router.model_kwargs
assert "prompt_cache_retention" not in router.model_kwargs
# ---------------------------------------------------------------------------
# (e) Defensive paths
# ---------------------------------------------------------------------------
def test_handles_llm_with_no_writable_model_kwargs() -> None:
"""Some LLM implementations (e.g. fakes / minimal subclasses) don't
expose a writable ``model_kwargs``. The helper must skip silently
raising would crash the entire LLM build path on a non-critical
optimisation."""
class _ImmutableLLM:
# ``__slots__`` blocks attribute creation, so ``setattr`` raises.
__slots__ = ("model",)
def __init__(self) -> None:
self.model = "openai/gpt-4o"
llm = _ImmutableLLM()
apply_litellm_prompt_caching(llm)
def test_initialises_missing_model_kwargs_dict() -> None:
"""When ``model_kwargs`` is present-but-None (Pydantic v2 default
pattern when no factory is set), the helper initialises it to an
empty dict before mutating."""
class _LazyLLM:
def __init__(self) -> None:
self.model = "openai/gpt-4o"
self.model_kwargs: dict[str, Any] | None = None
llm = _LazyLLM()
apply_litellm_prompt_caching(llm)
assert isinstance(llm.model_kwargs, dict)
assert "cache_control_injection_points" in llm.model_kwargs
def test_falls_back_to_llm_model_prefix_when_no_agent_config() -> None:
"""Direct caller path (e.g. ``create_chat_litellm_from_config`` for
YAML configs without a structured ``AgentConfig``): without
``agent_config`` the helper sets only the universal injection points:
no OpenAI-family extras, even if the prefix says ``openai/``.
Conservative: we'd rather miss the speedup than silently misroute."""
llm = _FakeLLM(model="openai/gpt-4o")
apply_litellm_prompt_caching(llm, agent_config=None, thread_id=99)
assert "cache_control_injection_points" in llm.model_kwargs
assert "prompt_cache_key" not in llm.model_kwargs
assert "prompt_cache_retention" not in llm.model_kwargs
# ---------------------------------------------------------------------------
# (f) drop_params safety net (regression guard for #19346)
# ---------------------------------------------------------------------------
def test_litellm_drop_params_is_globally_enabled() -> None:
"""``litellm.drop_params=True`` is set globally in
:mod:`app.services.llm_service` so any ``prompt_cache_key`` /
``prompt_cache_retention`` we set on an OpenAI-family config is
auto-stripped if the request later routes to a non-supporting
provider (e.g. via auto-mode router fallback). This test pins that
invariant; losing it would mean Bedrock/Vertex 400s on ``prompt_cache_key``.
"""
import litellm
import app.services.llm_service # noqa: F401 (side-effect: sets globals)
assert litellm.drop_params is True
# ---------------------------------------------------------------------------
# Regression note: LiteLLM #15696 (multi-content-block last message)
# ---------------------------------------------------------------------------
#
# Before LiteLLM 1.81 a list-form last message ``[block_a, block_b]``
# would get ``cache_control`` applied to *every* content block instead
# of only the last one — wasting cache breakpoints and triggering 400s
# on Anthropic when it exceeded the 4-breakpoint limit. Fixed in
# https://github.com/BerriAI/litellm/pull/15699.
#
# We pin ``litellm>=1.83.7`` in ``pyproject.toml`` (well past the fix).
# An end-to-end behavioural test would need to run ``litellm.completion``
# through the Anthropic transformer, which is integration territory and
# better covered by LiteLLM's own test suite. The unit guard here is the
# version pin plus the build-time ``model_kwargs`` shape we verify above.
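# ---------------------------------------------------------------------------
# Illustrative sketch of apply_litellm_prompt_caching, reconstructed from the
# contract the tests above pin. The shipped helper in prompt_caching.py may be
# structured differently; the _OPENAI_FAMILY_SKETCH set and the auto-mode
# handling here are assumptions.
# ---------------------------------------------------------------------------
_OPENAI_FAMILY_SKETCH = {"OPENAI", "DEEPSEEK", "XAI"}


def apply_litellm_prompt_caching_sketch(
    llm: Any, agent_config: Any = None, thread_id: int | None = None
) -> None:
    kwargs = getattr(llm, "model_kwargs", None)
    if kwargs is None:
        try:
            llm.model_kwargs = kwargs = {}
        except Exception:
            return  # defensive: no writable model_kwargs, silently skip

    # Universal breakpoints: first and last message, never role=system.
    kwargs.setdefault(
        "cache_control_injection_points",
        [{"location": "message", "index": 0}, {"location": "message", "index": -1}],
    )

    is_router = type(llm).__name__ == "ChatLiteLLMRouter"
    if (
        agent_config is None
        or is_router
        or getattr(agent_config, "provider", None) not in _OPENAI_FAMILY_SKETCH
        or getattr(agent_config, "custom_provider", None)
    ):
        return  # universal-only path (no config, router, non-OpenAI, custom)

    # OpenAI-family extras; setdefault keeps any user-supplied values intact.
    if thread_id is not None:
        kwargs.setdefault("prompt_cache_key", f"surfsense-thread-{thread_id}")
    kwargs.setdefault("prompt_cache_retention", "24h")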

View file

@@ -0,0 +1,117 @@
"""Tests for ``_resolve_prompt_model_name`` in :mod:`app.agents.new_chat.chat_deepagent`.
The helper picks the model id fed to ``detect_provider_variant`` so the
right ``<provider_hints>`` block lands in the system prompt. The tests
below pin its preference order:
1. ``agent_config.litellm_params["base_model"]`` (Azure-correct).
2. ``agent_config.model_name``.
3. ``getattr(llm, "model", None)``.
Without (1), an Azure deployment named e.g. ``"prod-chat-001"`` would
silently miss every provider regex.
"""
from __future__ import annotations
import pytest
from app.agents.new_chat.chat_deepagent import _resolve_prompt_model_name
from app.agents.new_chat.llm_config import AgentConfig
pytestmark = pytest.mark.unit
def _make_cfg(**overrides) -> AgentConfig:
"""Build an ``AgentConfig`` with sensible defaults for the helper test."""
defaults = {
"provider": "OPENAI",
"model_name": "x",
"api_key": "k",
}
return AgentConfig(**{**defaults, **overrides})
class _FakeLLM:
"""Stand-in for a ``ChatLiteLLM`` / ``ChatLiteLLMRouter`` instance.
The resolver only reads the ``.model`` attribute via ``getattr``,
matching the established idiom in ``knowledge_search.py`` /
``stream_new_chat.py`` / ``document_summarizer.py``.
"""
def __init__(self, model: str | None) -> None:
self.model = model
def test_prefers_litellm_params_base_model_over_deployment_name() -> None:
"""Azure deployment slug must NOT shadow the underlying model family.
This is the failure mode the helper exists to prevent: a deployment
named ``"azure/prod-chat-001"`` would not match any provider regex
on its own, but the family ``"gpt-4o"`` lives in
``litellm_params["base_model"]`` and routes to ``openai_classic``.
"""
cfg = _make_cfg(
model_name="azure/prod-chat-001",
litellm_params={"base_model": "gpt-4o"},
)
assert _resolve_prompt_model_name(cfg, _FakeLLM("azure/prod-chat-001")) == "gpt-4o"
def test_falls_back_to_model_name_when_litellm_params_is_none() -> None:
cfg = _make_cfg(
model_name="anthropic/claude-3-5-sonnet",
litellm_params=None,
)
got = _resolve_prompt_model_name(cfg, _FakeLLM("anthropic/claude-3-5-sonnet"))
assert got == "anthropic/claude-3-5-sonnet"
def test_handles_litellm_params_without_base_model_key() -> None:
cfg = _make_cfg(
model_name="openai/gpt-4o",
litellm_params={"temperature": 0.5},
)
assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"
def test_ignores_blank_base_model() -> None:
"""Whitespace-only ``base_model`` must not shadow ``model_name``."""
cfg = _make_cfg(
model_name="openai/gpt-4o",
litellm_params={"base_model": " "},
)
assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"
def test_ignores_non_string_base_model() -> None:
"""Defensive: a non-string ``base_model`` should not crash the resolver."""
cfg = _make_cfg(
model_name="openai/gpt-4o",
litellm_params={"base_model": 42},
)
assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"
def test_falls_back_to_llm_model_when_no_agent_config() -> None:
"""No ``agent_config`` -> use ``llm.model`` directly. Defensive path
for direct callers; production callers always supply a config."""
assert (
_resolve_prompt_model_name(None, _FakeLLM("openai/gpt-4o-mini"))
== "openai/gpt-4o-mini"
)
def test_returns_none_when_nothing_available() -> None:
"""``compose_system_prompt`` treats ``None`` as the ``"default"``
variant and emits no provider block."""
assert _resolve_prompt_model_name(None, _FakeLLM(None)) is None
def test_auto_mode_resolves_to_auto_string() -> None:
"""Auto mode -> ``"auto"``. ``detect_provider_variant("auto")``
returns ``"default"``, which is correct: the child model isn't
known until the LiteLLM Router dispatches."""
cfg = AgentConfig.from_auto_mode()
assert _resolve_prompt_model_name(cfg, _FakeLLM("auto")) == "auto"
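# Illustrative sketch of the preference order pinned above: a non-blank string
# litellm_params["base_model"] wins, then agent_config.model_name, then
# llm.model. Not the shipped resolver; the auto-mode case works here on the
# assumption that from_auto_mode() carries "auto" as its model_name.
def _resolve_prompt_model_name_sketch(agent_config, llm) -> str | None:
    if agent_config is not None:
        params = getattr(agent_config, "litellm_params", None) or {}
        base_model = params.get("base_model")
        if isinstance(base_model, str) and base_model.strip():
            return base_model  # Azure-correct: model family, not deployment slug
        model_name = getattr(agent_config, "model_name", None)
        if isinstance(model_name, str) and model_name.strip():
            return model_name
    return getattr(llm, "model", None)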