SurfSense/surfsense_backend/app/agents/new_chat/prompt_caching.py
CREDO23 db8bffab38 perf(prompt-cache): enable Azure prompt_cache_key routing hint
Splits the OpenAI-family gate into per-param predicates so AZURE and
AZURE_OPENAI configs now receive prompt_cache_key for backend routing
affinity (Microsoft auto-caches GPT-4o+ deployments at >=1024 tokens;
the key clusters same-prefix requests on the same GPU pool and raises
hit rate on turn 2+). prompt_cache_retention stays opted out for Azure
because litellm 1.83.14's Azure transformer would drop it silently;
revisit when Azure's supported params list is updated.
2026-05-20 11:58:15 +02:00


r"""LiteLLM-native prompt caching configuration for SurfSense agents.

Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
activated for our LiteLLM-based stack because its
``isinstance(model, ChatAnthropic)`` gate always failed) with LiteLLM's
universal caching mechanism.

Coverage:

- Marker-based providers (need ``cache_control`` injection, which LiteLLM
  performs automatically when ``cache_control_injection_points`` is set):
  ``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
  ``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
  (Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
  ``deepseek/``, ``xai/``. These cache automatically for prompts ≥1024
  tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.

We inject **two** breakpoints per request:

- ``index: 0`` pins the SurfSense system prompt at the head of the
  request (provider variant, citation rules, tool catalog, KB tree,
  skills metadata). The langchain agent factory always prepends
  ``request.system_message`` at index 0 (see ``factory.py``
  ``_execute_model_async``), so this targets exactly the main system
  prompt regardless of how many other ``SystemMessage``\ s the
  ``before_agent`` injectors (priority, tree, memory, file-intent,
  anonymous-doc) have inserted into ``state["messages"]``. Using
  ``role: system`` here would apply ``cache_control`` to **every**
  system-role message and trip Anthropic's hard cap of 4 cache
  breakpoints per request once the conversation accumulates enough
  injected system messages, which surfaces as the upstream 400
  ``A maximum of 4 blocks with cache_control may be provided. Found N``
  via OpenRouter→Anthropic.
- ``index: -1`` pins the latest message so multi-turn savings compound:
  Anthropic-family providers use longest-matching-prefix lookup, so turn
  N+1 still reads turn N's cache up to the shared prefix.

For OpenAI-family configs we additionally pass:

- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` is a routing hint
  that raises hit rate by sending requests with a shared prefix to the
  same backend. Supported by ``openai/``, ``deepseek/``, ``xai/``, and
  ``azure/`` (added to LiteLLM's Azure transformer in
  https://github.com/BerriAI/litellm/pull/20989, Feb 2026; verified
  against ``AzureOpenAIConfig.get_supported_openai_params`` in our
  installed litellm 1.83.14 for ``azure/gpt-4o``, ``azure/gpt-4o-mini``,
  ``azure/gpt-5.4``, ``azure/gpt-5.4-mini``).
- ``prompt_cache_retention="24h"`` extends the cache TTL beyond the
  default 5-10 min in-memory cache. Set ONLY for OpenAI/DeepSeek/xAI:
  Azure's server-side support landed in Microsoft's docs on 2026-05-13
  but LiteLLM 1.83.14's Azure transformer still omits it from its
  supported params list, so it gets silently dropped by
  ``litellm.drop_params``. Azure's default in-memory retention
  (5-10 min, max 1 h) already bridges intra-conversation turns; revisit
  when LiteLLM bumps Azure.

Safety net: ``litellm.drop_params=True`` is set globally in
``app.services.llm_service`` at module-load time. Any kwarg the destination
provider doesn't recognise is auto-stripped at the provider transformer
layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
``prompt_cache_key`` etc.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

from langchain_core.language_models import BaseChatModel

if TYPE_CHECKING:
    from app.agents.new_chat.llm_config import AgentConfig

logger = logging.getLogger(__name__)

# Two-breakpoint policy: head-of-request + latest message. See module
# docstring for rationale. Anthropic caps requests at 4 ``cache_control``
# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
#
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
# ``before_agent`` middlewares (priority, tree, memory, file-intent,
# anonymous-doc) insert ``SystemMessage`` instances into
# ``state["messages"]`` that accumulate across turns. With
# ``role: system`` the LiteLLM hook would tag *every* one of them with
# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0``
# always targets the langchain-prepended ``request.system_message``
# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text
# block), giving us exactly one stable cache breakpoint.
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
    {"location": "message", "index": 0},
    {"location": "message", "index": -1},
)
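
# Illustration only (not part of this module): a minimal sketch of why
# ``index`` points stay within Anthropic's 4-breakpoint cap while a
# ``role: system`` point may not. ``select_targets`` is a hypothetical
# stand-in for LiteLLM's injection hook, written under the assumption that
# an ``index`` point selects one message and a ``role`` point selects every
# message with that role, as the module docstring describes.

```python
from typing import Any

# Cap quoted in the module docstring for Anthropic requests.
MAX_CACHE_BREAKPOINTS = 4


def select_targets(
    messages: list[dict[str, Any]], points: list[dict[str, Any]]
) -> list[int]:
    """Return the sorted message positions that would get ``cache_control``.

    Hypothetical simulation, not LiteLLM's actual implementation.
    """
    if not messages:
        return []
    selected: set[int] = set()
    for point in points:
        if "index" in point:
            # Negative indexes count from the end, like Python lists.
            selected.add(point["index"] % len(messages))
        elif "role" in point:
            selected.update(
                i for i, msg in enumerate(messages) if msg["role"] == point["role"]
            )
    return sorted(selected)


# One prepended system prompt, five injected system messages, one user turn.
messages = (
    [{"role": "system", "content": "main prompt"}]
    + [{"role": "system", "content": f"injected {i}"} for i in range(5)]
    + [{"role": "user", "content": "hello"}]
)

by_index = select_targets(
    messages,
    [{"location": "message", "index": 0}, {"location": "message", "index": -1}],
)
by_role = select_targets(messages, [{"location": "message", "role": "system"}])

print(by_index, len(by_index) <= MAX_CACHE_BREAKPOINTS)
print(by_role, len(by_role) <= MAX_CACHE_BREAKPOINTS)
```

# With the sample conversation above, the index-based policy tags two
# messages regardless of how many system messages accumulate, whereas the
# role-based policy tags all six system messages and exceeds the cap.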
# Providers (uppercase ``AgentConfig.provider`` values) that accept the
# OpenAI ``prompt_cache_key`` routing hint. Microsoft's Azure OpenAI docs
# (2026-05-13) confirm automatic prompt caching applies to every GPT-4o
# or newer Azure deployment at ≥1024 tokens with no configuration needed,
# and that ``prompt_cache_key`` is combined with the prefix hash to
# improve routing affinity and therefore cache hit rate. LiteLLM's Azure
# transformer ships ``prompt_cache_key`` in its supported params as of
# https://github.com/BerriAI/litellm/pull/20989.
#
# Strict whitelist — many other providers in ``PROVIDER_MAP`` route
# through litellm's ``openai`` prefix without implementing the OpenAI
# prompt-cache surface (e.g. MOONSHOT, ZHIPU, MINIMAX), so we can't infer
# family from the litellm prefix alone.
_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset(
    {"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"}
)

# Subset of ``_PROMPT_CACHE_KEY_PROVIDERS`` that also accept
# ``prompt_cache_retention="24h"``. Azure is excluded: see module
# docstring — LiteLLM 1.83.14's Azure transformer omits the param so
# ``drop_params`` silently strips it. Re-add Azure once a future LiteLLM
# release wires it into ``AzureOpenAIConfig.get_supported_openai_params``.
_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset(
    {"OPENAI", "DEEPSEEK", "XAI"}
)


def _is_router_llm(llm: BaseChatModel) -> bool:
    """Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import.

    Importing ``app.services.llm_router_service`` at module-load time would
    create a cycle via ``llm_config -> prompt_caching -> llm_router_service``.
    Class-name comparison is sufficient since the class is defined in a
    single place.
    """
    return type(llm).__name__ == "ChatLiteLLMRouter"


def _provider_supports_prompt_cache_key(agent_config: AgentConfig | None) -> bool:
    """Whether the config targets a provider that accepts ``prompt_cache_key``.

    Strict: returns True only for explicitly chosen OPENAI, DEEPSEEK,
    XAI, AZURE, or AZURE_OPENAI providers. Auto-mode and custom
    providers return False because we can't statically know the
    destination and the router fans out across mixed providers.
    """
    if agent_config is None or not agent_config.provider:
        return False
    if agent_config.is_auto_mode:
        return False
    if agent_config.custom_provider:
        return False
    return agent_config.provider.upper() in _PROMPT_CACHE_KEY_PROVIDERS


def _provider_supports_prompt_cache_retention(
    agent_config: AgentConfig | None,
) -> bool:
    """Whether the config targets a provider that accepts ``prompt_cache_retention``.

    Tighter than :func:`_provider_supports_prompt_cache_key`: Azure
    deployments are excluded until LiteLLM ships the param in its Azure
    transformer (see module docstring).
    """
    if agent_config is None or not agent_config.provider:
        return False
    if agent_config.is_auto_mode:
        return False
    if agent_config.custom_provider:
        return False
    return agent_config.provider.upper() in _PROMPT_CACHE_RETENTION_PROVIDERS
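
# Illustration only (not part of this module): a self-contained sketch of
# how the two gates diverge for Azure. ``StubConfig`` and ``gates`` are
# hypothetical names; the stub carries only the three attributes the
# predicates above read, and the provider sets are copied from this file.

```python
from __future__ import annotations

from dataclasses import dataclass

KEY_PROVIDERS = frozenset({"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"})
RETENTION_PROVIDERS = frozenset({"OPENAI", "DEEPSEEK", "XAI"})


@dataclass
class StubConfig:
    """Hypothetical stand-in for ``AgentConfig``."""

    provider: str | None
    is_auto_mode: bool = False
    custom_provider: str | None = None


def gates(config: StubConfig | None) -> tuple[bool, bool]:
    """Return (accepts prompt_cache_key, accepts prompt_cache_retention)."""
    if config is None or not config.provider:
        return (False, False)
    if config.is_auto_mode or config.custom_provider:
        return (False, False)
    name = config.provider.upper()
    return (name in KEY_PROVIDERS, name in RETENTION_PROVIDERS)


results = {
    "openai": gates(StubConfig("openai")),
    "azure": gates(StubConfig("azure")),
    "moonshot": gates(StubConfig("MOONSHOT")),
    "auto": gates(StubConfig("OPENAI", is_auto_mode=True)),
}
print(results)
```

# Under these assumptions Azure receives the routing key but not the
# retention hint, providers outside the whitelist receive neither, and
# auto-mode is rejected even when the nominal provider is whitelisted.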


def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
    """Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail.

    Initialises the field to ``{}`` when present-but-None on a Pydantic v2
    model. Returns ``None`` if the LLM type doesn't expose a writable
    ``model_kwargs`` attribute (caller should treat as no-op).
    """
    model_kwargs = getattr(llm, "model_kwargs", None)
    if isinstance(model_kwargs, dict):
        return model_kwargs
    try:
        llm.model_kwargs = {}  # type: ignore[attr-defined]
    except Exception:
        return None
    refreshed = getattr(llm, "model_kwargs", None)
    return refreshed if isinstance(refreshed, dict) else None


def apply_litellm_prompt_caching(
    llm: BaseChatModel,
    *,
    agent_config: AgentConfig | None = None,
    thread_id: int | None = None,
) -> None:
    """Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.

    Idempotent: values already present in ``llm.model_kwargs`` (e.g. from
    ``agent_config.litellm_params`` overrides) are preserved. Mutates
    ``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion``
    via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge
    in our custom ``ChatLiteLLMRouter``.

    Args:
        llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
        agent_config: Optional ``AgentConfig`` driving provider-specific
            behaviour. When omitted (or auto-mode), only the universal
            ``cache_control_injection_points`` are set.
        thread_id: Optional thread id used to construct a per-thread
            ``prompt_cache_key`` for OpenAI-family providers. Caching still
            works without it (server-side automatic), but the key improves
            backend routing affinity and therefore hit rate.
    """
    model_kwargs = _get_or_init_model_kwargs(llm)
    if model_kwargs is None:
        logger.debug(
            "apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping",
            type(llm).__name__,
        )
        return

    if "cache_control_injection_points" not in model_kwargs:
        model_kwargs["cache_control_injection_points"] = [
            dict(point) for point in _DEFAULT_INJECTION_POINTS
        ]

    # OpenAI-style extras only when we statically know the destination
    # accepts them. Auto-mode router fans out across mixed providers so
    # we can't safely set destination-specific kwargs there (drop_params
    # would strip them but it's wasteful to set them in the first
    # place).
    if _is_router_llm(llm):
        return

    if (
        thread_id is not None
        and "prompt_cache_key" not in model_kwargs
        and _provider_supports_prompt_cache_key(agent_config)
    ):
        model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"

    if (
        "prompt_cache_retention" not in model_kwargs
        and _provider_supports_prompt_cache_retention(agent_config)
    ):
        model_kwargs["prompt_cache_retention"] = "24h"
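

# Illustration only (not part of this module): a simplified sketch of the
# set-if-absent behaviour that makes ``apply_litellm_prompt_caching``
# idempotent. ``StubLLM`` and ``apply_defaults`` are hypothetical names, and
# the sketch assumes an explicitly chosen OpenAI-style config, so both
# provider gates are treated as passing.

```python
from __future__ import annotations

from typing import Any


class StubLLM:
    """Hypothetical stand-in exposing only ``model_kwargs``."""

    def __init__(self, model_kwargs: dict[str, Any] | None = None) -> None:
        self.model_kwargs = model_kwargs


def apply_defaults(llm: StubLLM, thread_id: int | None) -> None:
    """Fill in caching kwargs without overwriting caller-supplied values."""
    if llm.model_kwargs is None:
        llm.model_kwargs = {}
    kwargs = llm.model_kwargs
    kwargs.setdefault(
        "cache_control_injection_points",
        [
            {"location": "message", "index": 0},
            {"location": "message", "index": -1},
        ],
    )
    if thread_id is not None:
        kwargs.setdefault("prompt_cache_key", f"surfsense-thread-{thread_id}")
    kwargs.setdefault("prompt_cache_retention", "24h")


# The caller already chose a retention policy; it must survive both calls.
llm = StubLLM({"prompt_cache_retention": "in_memory"})
apply_defaults(llm, thread_id=42)
first = dict(llm.model_kwargs)
apply_defaults(llm, thread_id=99)  # second call finds every key present
print(llm.model_kwargs == first)
```

# The second call leaves ``model_kwargs`` unchanged because every key is
# already present, including the thread key from the first call, so a
# pre-existing override is never replaced by the defaults.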