mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-05 05:42:39 +02:00
feat: prompt caching
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
- Updated `litellm` dependency version from `1.83.4` to `1.83.7`.
- Adjusted `aiohttp` version from `3.13.5` to `3.13.4` in the lock file.
- Implemented `apply_litellm_prompt_caching` in `chat_deepagent.py` to improve prompt caching.
- Added model name resolution logic in `chat_deepagent.py` to ensure correct provider-variant dispatch.
- Enhanced `llm_config.py` to configure prompt caching for various LLM providers.
- Updated tests to verify correct model name forwarding and prompt caching behavior.
This commit is contained in:
parent
360b5f8e3a
commit
e57c3a7d0c
12 changed files with 877 additions and 156 deletions
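For orientation before the diff, a minimal sketch of the two-phase wiring this commit introduces. The `SimpleNamespace` stubs stand in for the real `SanitizedChatLiteLLM` and `AgentConfig` objects, and the values are illustrative only:

from types import SimpleNamespace

from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching

llm = SimpleNamespace(model_kwargs={})  # stand-in for SanitizedChatLiteLLM
cfg = SimpleNamespace(  # stand-in for AgentConfig; only these fields are read
    provider="OPENAI", is_auto_mode=False, custom_provider=None
)

# Phase 1, LLM build time (llm_config.py): universal injection points plus,
# for OpenAI-family providers, prompt_cache_retention.
apply_litellm_prompt_caching(llm, agent_config=cfg)

# Phase 2, agent creation (chat_deepagent.py): idempotent re-run that layers
# the per-thread cache key on top once thread_id is known.
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=123)
assert llm.model_kwargs["prompt_cache_key"] == "surfsense-thread-123"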
surfsense_backend/app/agents/new_chat/chat_deepagent.py

@@ -10,7 +10,9 @@ We use ``create_agent`` (from langchain) rather than ``create_deep_agent``
 This lets us swap in ``SurfSenseFilesystemMiddleware`` — a customisable
 subclass of the default ``FilesystemMiddleware`` — while preserving every
 other behaviour that ``create_deep_agent`` provides (todo-list, subagents,
-summarisation, prompt-caching, etc.).
+summarisation, etc.). Prompt caching is configured at LLM-build time via
+``apply_litellm_prompt_caching`` (LiteLLM-native, multi-provider) rather
+than as a middleware.
 """
 
 import asyncio
@@ -33,7 +35,6 @@ from langchain.agents.middleware import (
     TodoListMiddleware,
     ToolCallLimitMiddleware,
 )
-from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
 from langchain_core.language_models import BaseChatModel
 from langchain_core.tools import BaseTool
 from langgraph.types import Checkpointer
@@ -74,6 +75,7 @@ from app.agents.new_chat.plugin_loader import (
     load_allowed_plugin_names_from_env,
     load_plugin_middlewares,
 )
+from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
 from app.agents.new_chat.subagents import build_specialized_subagents
 from app.agents.new_chat.system_prompt import (
     build_configurable_system_prompt,
@@ -94,6 +96,39 @@ from app.utils.perf import get_perf_logger
 
 _perf_log = get_perf_logger()
 
+
+def _resolve_prompt_model_name(
+    agent_config: AgentConfig | None,
+    llm: BaseChatModel,
+) -> str | None:
+    """Resolve the model id to feed to provider-variant detection.
+
+    Preference order (matches the established idiom in
+    ``llm_router_service.py`` — see ``params.get("base_model") or
+    params.get("model", "")`` usages there):
+
+    1. ``agent_config.litellm_params["base_model"]`` — required for Azure
+       deployments where ``model_name`` is the deployment slug, not the
+       underlying family. Without this, a deployment named e.g.
+       ``"prod-chat-001"`` would silently miss every provider regex.
+    2. ``agent_config.model_name`` — the user's configured model id.
+    3. ``getattr(llm, "model", None)`` — fallback for direct callers that
+       don't supply an ``AgentConfig`` (currently a defensive path; all
+       production callers pass ``agent_config``).
+
+    Returns ``None`` when nothing is available; ``compose_system_prompt``
+    treats that as the ``"default"`` variant (no provider block emitted).
+    """
+    if agent_config is not None:
+        params = agent_config.litellm_params or {}
+        base_model = params.get("base_model")
+        if isinstance(base_model, str) and base_model.strip():
+            return base_model
+        if agent_config.model_name:
+            return agent_config.model_name
+    return getattr(llm, "model", None)
+
+
 # =============================================================================
 # Connector Type Mapping
 # =============================================================================
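A quick illustration of that preference order, using stub objects only (`azure/gpt-4o` and `claude-sonnet-4` are made-up example ids, and `SimpleNamespace` stands in for the real `AgentConfig` and LLM classes):

from types import SimpleNamespace

azure_cfg = SimpleNamespace(
    litellm_params={"base_model": "azure/gpt-4o"},
    model_name="prod-chat-001",  # deployment slug; useless for regex dispatch
)
plain_cfg = SimpleNamespace(litellm_params=None, model_name="claude-sonnet-4")
llm = SimpleNamespace(model="fallback-model")

assert _resolve_prompt_model_name(azure_cfg, llm) == "azure/gpt-4o"      # 1. base_model wins
assert _resolve_prompt_model_name(plain_cfg, llm) == "claude-sonnet-4"   # 2. configured model id
assert _resolve_prompt_model_name(None, llm) == "fallback-model"         # 3. llm.model fallback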
@@ -279,6 +314,14 @@ async def create_surfsense_deep_agent(
     )
     """
     _t_agent_total = time.perf_counter()
+
+    # Layer thread-aware prompt caching onto the LLM. Idempotent with the
+    # build-time call in ``llm_config.py``; this run merely adds
+    # ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` for OpenAI-family
+    # configs now that ``thread_id`` is known. No-op when ``thread_id`` is
+    # None or the provider is non-OpenAI-family.
+    apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
+
     filesystem_selection = filesystem_selection or FilesystemSelection()
     backend_resolver = build_backend_resolver(
         filesystem_selection,
@@ -398,6 +441,7 @@ async def create_surfsense_deep_agent(
             enabled_tool_names=_enabled_tool_names,
             disabled_tool_names=_user_disabled_tool_names,
             mcp_connector_tools=_mcp_connector_tools,
+            model_name=_resolve_prompt_model_name(agent_config, llm),
         )
     else:
         system_prompt = build_surfsense_system_prompt(
@@ -405,6 +449,7 @@ async def create_surfsense_deep_agent(
             enabled_tool_names=_enabled_tool_names,
             disabled_tool_names=_user_disabled_tool_names,
             mcp_connector_tools=_mcp_connector_tools,
+            model_name=_resolve_prompt_model_name(agent_config, llm),
         )
     _perf_log.info(
         "[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
@@ -568,7 +613,6 @@ def _build_compiled_agent_blocking(
         ),
         create_surfsense_compaction_middleware(llm, StateBackend),
         PatchToolCallsMiddleware(),
-        AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
     ]
 
     general_purpose_spec: SubAgent = {  # type: ignore[typeddict-unknown-key]
@@ -1006,12 +1050,12 @@ def _build_compiled_agent_blocking(
         action_log_mw,
         PatchToolCallsMiddleware(),
         DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
-        # Plugin slot — sits just before AnthropicCache so plugin-side
-        # transforms see the final tool result and run before any
-        # caching heuristics. Multiple plugins in declared order; loader
-        # filtered by the admin allowlist already.
+        # Plugin slot — sits at the tail so plugin-side transforms see the
+        # final tool result. Prompt caching is now applied at LLM build time
+        # via ``apply_litellm_prompt_caching`` (see prompt_caching.py), so no
+        # caching middleware is needed here. Multiple plugins run in declared
+        # order; loader filtered by the admin allowlist already.
         *plugin_middlewares,
-        AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
     ]
     deepagent_middleware = [m for m in deepagent_middleware if m is not None]
 
surfsense_backend/app/agents/new_chat/llm_config.py

@@ -27,6 +27,7 @@ from litellm import get_model_info
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 
+from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
 from app.services.llm_router_service import (
     AUTO_MODE_ID,
     ChatLiteLLMRouter,
@@ -494,6 +495,11 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
 
     llm = SanitizedChatLiteLLM(**litellm_kwargs)
     _attach_model_profile(llm, model_string)
+    # Configure LiteLLM-native prompt caching (cache_control_injection_points
+    # for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.).
+    # ``agent_config=None`` here — the YAML path doesn't have provider intent
+    # in a structured form, so we set only the universal injection points.
+    apply_litellm_prompt_caching(llm)
     return llm
@@ -518,7 +524,16 @@ def create_chat_litellm_from_agent_config(
             print("Error: Auto mode requested but LLM Router not initialized")
             return None
         try:
-            return get_auto_mode_llm()
+            router_llm = get_auto_mode_llm()
+            if router_llm is not None:
+                # Universal cache_control_injection_points only — auto-mode
+                # fans out across providers, so OpenAI-only kwargs (e.g.
+                # ``prompt_cache_key``) are left off here. ``drop_params``
+                # would strip them at the provider boundary anyway, but
+                # there's no point setting them when we don't know the
+                # destination.
+                apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
+            return router_llm
         except Exception as e:
             print(f"Error creating ChatLiteLLMRouter: {e}")
             return None
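The `drop_params` safety net referenced in that comment is LiteLLM's global toggle; per the new module's docstring below, SurfSense sets it once in `app.services.llm_service`. In brief:

import litellm

# When enabled, each provider transformer silently drops request kwargs the
# destination does not support (e.g. prompt_cache_key on a Bedrock route)
# instead of letting the API return a 400.
litellm.drop_params = True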
@@ -549,4 +564,9 @@ def create_chat_litellm_from_agent_config(
 
     llm = SanitizedChatLiteLLM(**litellm_kwargs)
     _attach_model_profile(llm, model_string)
+    # Build-time prompt caching: sets ``cache_control_injection_points`` for
+    # all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``.
+    # Per-thread ``prompt_cache_key`` is layered on later in
+    # ``create_surfsense_deep_agent`` once ``thread_id`` is known.
+    apply_litellm_prompt_caching(llm, agent_config=agent_config)
     return llm
surfsense_backend/app/agents/new_chat/prompt_caching.py (new file, 166 additions)
@@ -0,0 +1,166 @@
"""LiteLLM-native prompt caching configuration for SurfSense agents.

Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
gate always failed) with LiteLLM's universal caching mechanism.

Coverage:

- Marker-based providers (need ``cache_control`` injection, which LiteLLM
  performs automatically when ``cache_control_injection_points`` is set):
  ``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
  ``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
  (Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
  ``deepseek/``, ``xai/`` — these cache automatically for prompts ≥1024
  tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.

We inject **two** breakpoints per request:

- ``role: system`` — pins the SurfSense system prompt (provider variant,
  citation rules, tool catalog, KB tree, skills metadata) into the cache.
- ``index: -1`` — pins the latest message so multi-turn savings compound:
  Anthropic-family providers use longest-matching-prefix lookup, so turn
  N+1 still reads turn N's cache up to the shared prefix.

For OpenAI-family configs we additionally pass:

- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
  raises hit rate by sending requests with a shared prefix to the same
  backend.
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
  5-10 min in-memory cache.

Safety net: ``litellm.drop_params=True`` is set globally in
``app.services.llm_service`` at module-load time. Any kwarg the destination
provider doesn't recognise is auto-stripped at the provider transformer
layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
``prompt_cache_key`` etc.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

from langchain_core.language_models import BaseChatModel

if TYPE_CHECKING:
    from app.agents.new_chat.llm_config import AgentConfig

logger = logging.getLogger(__name__)


# Two-breakpoint policy: system + latest message. See module docstring for
# rationale. Anthropic limits requests to 4 ``cache_control`` blocks; we
# use 2 here, leaving headroom for Phase-2 tool caching.
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
    {"location": "message", "role": "system"},
    {"location": "message", "index": -1},
)

# Providers (uppercase ``AgentConfig.provider`` values) that natively expose
# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and
# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers
# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without
# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU,
# MINIMAX), so we can't infer family from the litellm prefix alone.
_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"})


def _is_router_llm(llm: BaseChatModel) -> bool:
    """Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import.

    Importing ``app.services.llm_router_service`` at module-load time would
    create a cycle via ``llm_config -> prompt_caching -> llm_router_service``.
    Class-name comparison is sufficient since the class is defined in a
    single place.
    """
    return type(llm).__name__ == "ChatLiteLLMRouter"


def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
    """Whether the config targets an OpenAI-style prompt-cache surface.

    Strict — only returns True when the user explicitly chose OPENAI,
    DEEPSEEK, or XAI as the provider in their ``NewLLMConfig`` /
    ``YAMLConfig``. Auto-mode and custom providers return False because
    we can't statically know the destination.
    """
    if agent_config is None or not agent_config.provider:
        return False
    if agent_config.is_auto_mode:
        return False
    if agent_config.custom_provider:
        return False
    return agent_config.provider.upper() in _OPENAI_FAMILY_PROVIDERS


def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
    """Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail.

    Initialises the field to ``{}`` when present-but-None on a Pydantic v2
    model. Returns ``None`` if the LLM type doesn't expose a writable
    ``model_kwargs`` attribute (caller should treat as no-op).
    """
    model_kwargs = getattr(llm, "model_kwargs", None)
    if isinstance(model_kwargs, dict):
        return model_kwargs
    try:
        llm.model_kwargs = {}  # type: ignore[attr-defined]
    except Exception:
        return None
    refreshed = getattr(llm, "model_kwargs", None)
    return refreshed if isinstance(refreshed, dict) else None


def apply_litellm_prompt_caching(
    llm: BaseChatModel,
    *,
    agent_config: AgentConfig | None = None,
    thread_id: int | None = None,
) -> None:
    """Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.

    Idempotent — values already present in ``llm.model_kwargs`` (e.g. from
    ``agent_config.litellm_params`` overrides) are preserved. Mutates
    ``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion``
    via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge
    in our custom ``ChatLiteLLMRouter``.

    Args:
        llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
        agent_config: Optional ``AgentConfig`` driving provider-specific
            behaviour. When omitted (or auto-mode), only the universal
            ``cache_control_injection_points`` are set.
        thread_id: Optional thread id used to construct a per-thread
            ``prompt_cache_key`` for OpenAI-family providers. Caching still
            works without it (server-side automatic), but the key improves
            backend routing affinity and therefore hit rate.
    """
    model_kwargs = _get_or_init_model_kwargs(llm)
    if model_kwargs is None:
        logger.debug(
            "apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping",
            type(llm).__name__,
        )
        return

    if "cache_control_injection_points" not in model_kwargs:
        model_kwargs["cache_control_injection_points"] = [
            dict(point) for point in _DEFAULT_INJECTION_POINTS
        ]

    # OpenAI-family extras only when we statically know the destination is
    # OpenAI / DeepSeek / xAI. Auto-mode router fans out across providers
    # so we can't safely set OpenAI-only kwargs there (drop_params would
    # strip them but it's wasteful to set them in the first place).
    if _is_router_llm(llm):
        return
    if not _is_openai_family_config(agent_config):
        return

    if thread_id is not None and "prompt_cache_key" not in model_kwargs:
        model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"
    if "prompt_cache_retention" not in model_kwargs:
        model_kwargs["prompt_cache_retention"] = "24h"
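The commit message mentions updated tests; a hedged sketch of how the prompt-caching side can be exercised (stub fixtures only, the real suite's setup may differ):

from types import SimpleNamespace

from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching


def test_openai_family_gets_thread_key_and_retention():
    # Stubs; the real tests presumably build proper AgentConfig and
    # SanitizedChatLiteLLM objects.
    cfg = SimpleNamespace(provider="OPENAI", is_auto_mode=False, custom_provider=None)
    llm = SimpleNamespace(model_kwargs={})

    apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=7)

    assert llm.model_kwargs["cache_control_injection_points"] == [
        {"location": "message", "role": "system"},
        {"location": "message", "index": -1},
    ]
    assert llm.model_kwargs["prompt_cache_key"] == "surfsense-thread-7"
    assert llm.model_kwargs["prompt_cache_retention"] == "24h"


def test_existing_values_are_preserved():
    # Idempotence: pre-set kwargs (e.g. litellm_params overrides) win.
    llm = SimpleNamespace(model_kwargs={"cache_control_injection_points": []})
    apply_litellm_prompt_caching(llm)
    assert llm.model_kwargs["cache_control_injection_points"] == []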