feat: prompt caching
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions

- Updated `litellm` dependency version from `1.83.4` to `1.83.7`.
- Adjusted `aiohttp` version from `3.13.5` to `3.13.4` in the lock file.
- Implemented `apply_litellm_prompt_caching` in `chat_deepagent.py` to improve prompt caching.
- Added model name resolution logic in `chat_deepagent.py` to ensure correct provider-variant dispatch.
- Enhanced `llm_config.py` to configure prompt caching for various LLM providers.
- Updated tests to verify correct model name forwarding and prompt caching behavior.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-05-01 05:10:53 -07:00
parent 360b5f8e3a
commit e57c3a7d0c
12 changed files with 877 additions and 156 deletions

View file

@ -10,7 +10,9 @@ We use ``create_agent`` (from langchain) rather than ``create_deep_agent``
This lets us swap in ``SurfSenseFilesystemMiddleware`` a customisable
subclass of the default ``FilesystemMiddleware`` while preserving every
other behaviour that ``create_deep_agent`` provides (todo-list, subagents,
summarisation, prompt-caching, etc.).
summarisation, etc.). Prompt caching is configured at LLM-build time via
``apply_litellm_prompt_caching`` (LiteLLM-native, multi-provider) rather
than as a middleware.
"""
import asyncio
@ -33,7 +35,6 @@ from langchain.agents.middleware import (
TodoListMiddleware,
ToolCallLimitMiddleware,
)
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
@ -74,6 +75,7 @@ from app.agents.new_chat.plugin_loader import (
load_allowed_plugin_names_from_env,
load_plugin_middlewares,
)
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
from app.agents.new_chat.subagents import build_specialized_subagents
from app.agents.new_chat.system_prompt import (
build_configurable_system_prompt,
@ -94,6 +96,39 @@ from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
def _resolve_prompt_model_name(
agent_config: AgentConfig | None,
llm: BaseChatModel,
) -> str | None:
"""Resolve the model id to feed to provider-variant detection.
Preference order (matches the established idiom in
``llm_router_service.py`` see ``params.get("base_model") or
params.get("model", "")`` usages there):
1. ``agent_config.litellm_params["base_model"]`` required for Azure
deployments where ``model_name`` is the deployment slug, not the
underlying family. Without this, a deployment named e.g.
``"prod-chat-001"`` would silently miss every provider regex.
2. ``agent_config.model_name`` the user's configured model id.
3. ``getattr(llm, "model", None)`` fallback for direct callers that
don't supply an ``AgentConfig`` (currently a defensive path; all
production callers pass ``agent_config``).
Returns ``None`` when nothing is available; ``compose_system_prompt``
treats that as the ``"default"`` variant (no provider block emitted).
"""
if agent_config is not None:
params = agent_config.litellm_params or {}
base_model = params.get("base_model")
if isinstance(base_model, str) and base_model.strip():
return base_model
if agent_config.model_name:
return agent_config.model_name
return getattr(llm, "model", None)
# =============================================================================
# Connector Type Mapping
# =============================================================================
@ -279,6 +314,14 @@ async def create_surfsense_deep_agent(
)
"""
_t_agent_total = time.perf_counter()
# Layer thread-aware prompt caching onto the LLM. Idempotent with the
# build-time call in ``llm_config.py``; this run merely adds
# ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` for OpenAI-family
# configs now that ``thread_id`` is known. No-op when ``thread_id`` is
# None or the provider is non-OpenAI-family.
apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
filesystem_selection = filesystem_selection or FilesystemSelection()
backend_resolver = build_backend_resolver(
filesystem_selection,
@ -398,6 +441,7 @@ async def create_surfsense_deep_agent(
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
mcp_connector_tools=_mcp_connector_tools,
model_name=_resolve_prompt_model_name(agent_config, llm),
)
else:
system_prompt = build_surfsense_system_prompt(
@ -405,6 +449,7 @@ async def create_surfsense_deep_agent(
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
mcp_connector_tools=_mcp_connector_tools,
model_name=_resolve_prompt_model_name(agent_config, llm),
)
_perf_log.info(
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
@ -568,7 +613,6 @@ def _build_compiled_agent_blocking(
),
create_surfsense_compaction_middleware(llm, StateBackend),
PatchToolCallsMiddleware(),
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
@ -1006,12 +1050,12 @@ def _build_compiled_agent_blocking(
action_log_mw,
PatchToolCallsMiddleware(),
DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
# Plugin slot — sits just before AnthropicCache so plugin-side
# transforms see the final tool result and run before any
# caching heuristics. Multiple plugins in declared order; loader
# filtered by the admin allowlist already.
# Plugin slot — sits at the tail so plugin-side transforms see the
# final tool result. Prompt caching is now applied at LLM build time
# via ``apply_litellm_prompt_caching`` (see prompt_caching.py), so no
# caching middleware is needed here. Multiple plugins run in declared
# order; loader filtered by the admin allowlist already.
*plugin_middlewares,
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
deepagent_middleware = [m for m in deepagent_middleware if m is not None]

View file

@ -27,6 +27,7 @@ from litellm import get_model_info
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
@ -494,6 +495,11 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
llm = SanitizedChatLiteLLM(**litellm_kwargs)
_attach_model_profile(llm, model_string)
# Configure LiteLLM-native prompt caching (cache_control_injection_points
# for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.).
# ``agent_config=None`` here — the YAML path doesn't have provider intent
# in a structured form, so we set only the universal injection points.
apply_litellm_prompt_caching(llm)
return llm
@ -518,7 +524,16 @@ def create_chat_litellm_from_agent_config(
print("Error: Auto mode requested but LLM Router not initialized")
return None
try:
return get_auto_mode_llm()
router_llm = get_auto_mode_llm()
if router_llm is not None:
# Universal cache_control_injection_points only — auto-mode
# fans out across providers, so OpenAI-only kwargs (e.g.
# ``prompt_cache_key``) are left off here. ``drop_params``
# would strip them at the provider boundary anyway, but
# there's no point setting them when we don't know the
# destination.
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
return router_llm
except Exception as e:
print(f"Error creating ChatLiteLLMRouter: {e}")
return None
@ -549,4 +564,9 @@ def create_chat_litellm_from_agent_config(
llm = SanitizedChatLiteLLM(**litellm_kwargs)
_attach_model_profile(llm, model_string)
# Build-time prompt caching: sets ``cache_control_injection_points`` for
# all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``.
# Per-thread ``prompt_cache_key`` is layered on later in
# ``create_surfsense_deep_agent`` once ``thread_id`` is known.
apply_litellm_prompt_caching(llm, agent_config=agent_config)
return llm

View file

@ -0,0 +1,166 @@
"""LiteLLM-native prompt caching configuration for SurfSense agents.
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
gate always failed) with LiteLLM's universal caching mechanism.
Coverage:
- Marker-based providers (need ``cache_control`` injection, which LiteLLM
performs automatically when ``cache_control_injection_points`` is set):
``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
(Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
``deepseek/``, ``xai/`` — these cache automatically for prompts ≥ 1024
tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
We inject **two** breakpoints per request:
- ``role: system`` pins the SurfSense system prompt (provider variant,
citation rules, tool catalog, KB tree, skills metadata) into the cache.
- ``index: -1`` pins the latest message so multi-turn savings compound:
Anthropic-family providers use longest-matching-prefix lookup, so turn
N+1 still reads turn N's cache up to the shared prefix.
For OpenAI-family configs we additionally pass:
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` routing hint that
raises hit rate by sending requests with a shared prefix to the same
backend.
- ``prompt_cache_retention="24h"`` extends cache TTL beyond the default
5-10 min in-memory cache.
Safety net: ``litellm.drop_params=True`` is set globally in
``app.services.llm_service`` at module-load time. Any kwarg the destination
provider doesn't recognise is auto-stripped at the provider transformer
layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
``prompt_cache_key`` etc.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from langchain_core.language_models import BaseChatModel
if TYPE_CHECKING:
from app.agents.new_chat.llm_config import AgentConfig
logger = logging.getLogger(__name__)
# Two-breakpoint policy: system + latest message. See module docstring for
# rationale. Anthropic limits requests to 4 ``cache_control`` blocks; we
# use 2 here, leaving headroom for Phase-2 tool caching.
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
{"location": "message", "role": "system"},
{"location": "message", "index": -1},
)
# Providers (uppercase ``AgentConfig.provider`` values) that natively expose
# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and
# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers
# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without
# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU,
# MINIMAX), so we can't infer family from the litellm prefix alone.
_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"})
def _is_router_llm(llm: BaseChatModel) -> bool:
"""Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import.
Importing ``app.services.llm_router_service`` at module-load time would
create a cycle via ``llm_config -> prompt_caching -> llm_router_service``.
Class-name comparison is sufficient since the class is defined in a
single place.
"""
return type(llm).__name__ == "ChatLiteLLMRouter"
def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
    """Decide whether the config targets an OpenAI-style prompt-cache surface.

    Deliberately strict — True only when the user explicitly selected
    OPENAI, DEEPSEEK, or XAI as the provider in their ``NewLLMConfig`` /
    ``YAMLConfig``. Auto-mode and custom providers always yield False,
    since the final destination cannot be known statically.
    """
    if agent_config is None:
        return False
    provider = agent_config.provider
    if not provider or agent_config.is_auto_mode or agent_config.custom_provider:
        return False
    return provider.upper() in _OPENAI_FAMILY_PROVIDERS
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
"""Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail.
Initialises the field to ``{}`` when present-but-None on a Pydantic v2
model. Returns ``None`` if the LLM type doesn't expose a writable
``model_kwargs`` attribute (caller should treat as no-op).
"""
model_kwargs = getattr(llm, "model_kwargs", None)
if isinstance(model_kwargs, dict):
return model_kwargs
try:
llm.model_kwargs = {} # type: ignore[attr-defined]
except Exception:
return None
refreshed = getattr(llm, "model_kwargs", None)
return refreshed if isinstance(refreshed, dict) else None
def apply_litellm_prompt_caching(
    llm: BaseChatModel,
    *,
    agent_config: AgentConfig | None = None,
    thread_id: int | None = None,
) -> None:
    """Enable LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.

    Mutates ``llm.model_kwargs`` in place and never overwrites keys that are
    already present (so overrides from ``agent_config.litellm_params`` win),
    which also makes repeated calls idempotent. The kwargs flow to
    ``litellm.completion`` via ``ChatLiteLLM._default_params`` and via the
    ``self.model_kwargs`` merge in our custom ``ChatLiteLLMRouter``.

    Args:
        llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
        agent_config: Optional ``AgentConfig`` driving provider-specific
            behaviour. When omitted (or in auto-mode), only the universal
            ``cache_control_injection_points`` are set.
        thread_id: Optional thread id used to build a per-thread
            ``prompt_cache_key`` for OpenAI-family providers. Caching works
            without it (server-side automatic); the key only improves backend
            routing affinity and therefore hit rate.
    """
    model_kwargs = _get_or_init_model_kwargs(llm)
    if model_kwargs is None:
        # Nothing writable to configure on this LLM type — log and bail.
        logger.debug(
            "apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping",
            type(llm).__name__,
        )
        return

    # Universal marker injection — applies across provider families.
    if "cache_control_injection_points" not in model_kwargs:
        model_kwargs["cache_control_injection_points"] = [
            dict(point) for point in _DEFAULT_INJECTION_POINTS
        ]

    # OpenAI-only extras are added solely when the destination is statically
    # known to be OpenAI / DeepSeek / xAI. The auto-mode router fans out
    # across providers, so those kwargs are skipped for it (``drop_params``
    # would strip them anyway, but setting them would be pointless).
    if _is_router_llm(llm) or not _is_openai_family_config(agent_config):
        return

    if thread_id is not None and "prompt_cache_key" not in model_kwargs:
        model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"
    if "prompt_cache_retention" not in model_kwargs:
        model_kwargs["prompt_cache_retention"] = "24h"