mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): move llm_config + prompt_caching to app/agents/shared (slice 4b)
Relocate the mutually-dependent LLM config layer and the LiteLLM prompt-caching
helper to the shared kernel as one unit, rewiring their internal cross-reference
to the shared paths. Flip 21 non-frozen importers. Re-export shims remain at
new_chat/{llm_config,prompt_caching}.py for the frozen single-agent stack
(chat_deepagent); they will be removed when that stack is retired.
This commit is contained in:
parent
8fca2753aa
commit
946f8a8c5d
23 changed files with 928 additions and 882 deletions
|
|
@ -25,8 +25,8 @@ from app.agents.new_chat.connector_searchable_types import (
|
||||||
from app.agents.shared.feature_flags import AgentFeatureFlags, get_flags
|
from app.agents.shared.feature_flags import AgentFeatureFlags, get_flags
|
||||||
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||||
from app.agents.shared.filesystem_selection import FilesystemMode, FilesystemSelection
|
from app.agents.shared.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.shared.llm_config import AgentConfig
|
||||||
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
|
from app.agents.shared.prompt_caching import apply_litellm_prompt_caching
|
||||||
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
|
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME, invalid_tool
|
||||||
from app.agents.new_chat.tools.registry import build_tools_async
|
from app.agents.new_chat.tools.registry import build_tools_async
|
||||||
from app.db import ChatVisibility
|
from app.db import ChatVisibility
|
||||||
|
|
|
||||||
|
|
@ -1,622 +1,33 @@
|
||||||
"""
|
"""Backward-compatible shim.
|
||||||
LLM configuration utilities for SurfSense agents.
|
|
||||||
|
|
||||||
This module provides functions for loading LLM configurations from:
|
The LLM configuration layer now lives in the shared agent kernel at
|
||||||
1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing
|
``app.agents.shared.llm_config``. This module re-exports it so frozen
|
||||||
2. YAML files (global configs with negative IDs)
|
single-agent code (``chat_deepagent``) keeps working until that stack is
|
||||||
3. Database NewLLMConfig table (user-created configs with positive IDs)
|
retired.
|
||||||
|
|
||||||
It also provides utilities for creating ChatLiteLLM instances and
|
|
||||||
managing prompt configurations.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import AsyncIterator
|
from __future__ import annotations
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import yaml
|
from app.agents.shared.llm_config import (
|
||||||
from langchain_core.callbacks import (
|
AgentConfig,
|
||||||
AsyncCallbackManagerForLLMRun,
|
SanitizedChatLiteLLM,
|
||||||
CallbackManagerForLLMRun,
|
create_chat_litellm_from_agent_config,
|
||||||
)
|
create_chat_litellm_from_config,
|
||||||
from langchain_core.messages import AIMessage, BaseMessage
|
load_agent_config,
|
||||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
load_agent_llm_config_for_search_space,
|
||||||
from langchain_litellm import ChatLiteLLM
|
load_global_llm_config_by_id,
|
||||||
from litellm import get_model_info
|
load_llm_config_from_yaml,
|
||||||
from sqlalchemy import select
|
load_new_llm_config_from_db,
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
|
|
||||||
from app.services.llm_router_service import (
|
|
||||||
AUTO_MODE_ID,
|
|
||||||
ChatLiteLLMRouter,
|
|
||||||
LLMRouterService,
|
|
||||||
_sanitize_content,
|
|
||||||
get_auto_mode_llm,
|
|
||||||
is_auto_mode,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
"AgentConfig",
|
||||||
"""Sanitize content on every message so it is safe for any provider.
|
"SanitizedChatLiteLLM",
|
||||||
|
"create_chat_litellm_from_agent_config",
|
||||||
Handles three cross-provider incompatibilities:
|
"create_chat_litellm_from_config",
|
||||||
- List content with provider-specific blocks (e.g. ``thinking``)
|
"load_agent_config",
|
||||||
- List content with bare strings or empty text blocks
|
"load_agent_llm_config_for_search_space",
|
||||||
- AI messages with empty content + tool calls: some providers (Bedrock)
|
"load_global_llm_config_by_id",
|
||||||
convert ``""`` to ``[{"type":"text","text":""}]`` server-side then
|
"load_llm_config_from_yaml",
|
||||||
reject the blank text. The OpenAI spec says ``content`` should be
|
"load_new_llm_config_from_db",
|
||||||
``null`` when an assistant message only carries tool calls.
|
]
|
||||||
"""
|
|
||||||
for msg in messages:
|
|
||||||
if isinstance(msg.content, list):
|
|
||||||
msg.content = _sanitize_content(msg.content)
|
|
||||||
if (
|
|
||||||
isinstance(msg, AIMessage)
|
|
||||||
and (not msg.content or msg.content == "")
|
|
||||||
and getattr(msg, "tool_calls", None)
|
|
||||||
):
|
|
||||||
msg.content = None # type: ignore[assignment]
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
class SanitizedChatLiteLLM(ChatLiteLLM):
|
|
||||||
"""ChatLiteLLM subclass that strips provider-specific content blocks
|
|
||||||
(e.g. ``thinking`` from reasoning models) and normalises bare strings
|
|
||||||
in content arrays before forwarding to the underlying provider."""
|
|
||||||
|
|
||||||
def _generate(
|
|
||||||
self,
|
|
||||||
messages: list[BaseMessage],
|
|
||||||
stop: list[str] | None = None,
|
|
||||||
run_manager: CallbackManagerForLLMRun | None = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ChatResult:
|
|
||||||
return super()._generate(
|
|
||||||
_sanitize_messages(messages), stop, run_manager, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _astream(
|
|
||||||
self,
|
|
||||||
messages: list[BaseMessage],
|
|
||||||
stop: list[str] | None = None,
|
|
||||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> AsyncIterator[ChatGenerationChunk]:
|
|
||||||
async for chunk in super()._astream(
|
|
||||||
_sanitize_messages(messages), stop, run_manager, **kwargs
|
|
||||||
):
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
|
|
||||||
# Provider mapping for LiteLLM model string construction.
|
|
||||||
#
|
|
||||||
# Single source of truth lives in
|
|
||||||
# :mod:`app.services.provider_capabilities` so the YAML loader (which
|
|
||||||
# runs during ``app.config`` class-body init) can resolve provider
|
|
||||||
# prefixes without dragging the agent / tools tree into module load
|
|
||||||
# order. Re-exported here under the historical ``PROVIDER_MAP`` name
|
|
||||||
# so existing callers (``llm_router_service``, ``image_gen_router_service``,
|
|
||||||
# tests) keep working unchanged.
|
|
||||||
from app.services.provider_capabilities import ( # noqa: E402
|
|
||||||
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
|
|
||||||
"""Attach a ``profile`` dict to ChatLiteLLM with model context metadata."""
|
|
||||||
try:
|
|
||||||
info = get_model_info(model_string)
|
|
||||||
max_input_tokens = info.get("max_input_tokens")
|
|
||||||
if isinstance(max_input_tokens, int) and max_input_tokens > 0:
|
|
||||||
llm.profile = {
|
|
||||||
"max_input_tokens": max_input_tokens,
|
|
||||||
"max_input_tokens_upper": max_input_tokens,
|
|
||||||
"token_count_model": model_string,
|
|
||||||
"token_count_models": [model_string],
|
|
||||||
}
|
|
||||||
except Exception:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AgentConfig:
|
|
||||||
"""
|
|
||||||
Complete configuration for the SurfSense agent.
|
|
||||||
|
|
||||||
This combines LLM settings with prompt configuration from NewLLMConfig.
|
|
||||||
Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# LLM Model Settings
|
|
||||||
provider: str
|
|
||||||
model_name: str
|
|
||||||
api_key: str
|
|
||||||
api_base: str | None = None
|
|
||||||
custom_provider: str | None = None
|
|
||||||
litellm_params: dict | None = None
|
|
||||||
|
|
||||||
# Prompt Configuration
|
|
||||||
system_instructions: str | None = None
|
|
||||||
use_default_system_instructions: bool = True
|
|
||||||
citations_enabled: bool = True
|
|
||||||
|
|
||||||
# Metadata
|
|
||||||
config_id: int | None = None
|
|
||||||
config_name: str | None = None
|
|
||||||
|
|
||||||
# Auto mode flag
|
|
||||||
is_auto_mode: bool = False
|
|
||||||
|
|
||||||
# Token quota and policy
|
|
||||||
billing_tier: str = "free"
|
|
||||||
is_premium: bool = False
|
|
||||||
anonymous_enabled: bool = False
|
|
||||||
quota_reserve_tokens: int | None = None
|
|
||||||
|
|
||||||
# Capability flag: best-effort True for the chat selector / catalog.
|
|
||||||
# Resolved via :func:`provider_capabilities.derive_supports_image_input`
|
|
||||||
# which prefers OpenRouter's ``architecture.input_modalities`` and
|
|
||||||
# otherwise consults LiteLLM's authoritative model map. Default True
|
|
||||||
# is the conservative-allow stance — the streaming-task safety net
|
|
||||||
# (``is_known_text_only_chat_model``) is the *only* place a False
|
|
||||||
# actually blocks a request. Setting this to False here without an
|
|
||||||
# authoritative source would silently hide vision-capable models
|
|
||||||
# (the regression we're fixing).
|
|
||||||
supports_image_input: bool = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_auto_mode(cls) -> "AgentConfig":
|
|
||||||
"""
|
|
||||||
Create an AgentConfig for Auto mode (LiteLLM Router load balancing).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AgentConfig instance configured for Auto mode
|
|
||||||
"""
|
|
||||||
return cls(
|
|
||||||
provider="AUTO",
|
|
||||||
model_name="auto",
|
|
||||||
api_key="", # Not needed for router
|
|
||||||
api_base=None,
|
|
||||||
custom_provider=None,
|
|
||||||
litellm_params=None,
|
|
||||||
system_instructions=None,
|
|
||||||
use_default_system_instructions=True,
|
|
||||||
citations_enabled=True,
|
|
||||||
config_id=AUTO_MODE_ID,
|
|
||||||
config_name="Auto (Fastest)",
|
|
||||||
is_auto_mode=True,
|
|
||||||
billing_tier="free",
|
|
||||||
is_premium=False,
|
|
||||||
anonymous_enabled=False,
|
|
||||||
quota_reserve_tokens=None,
|
|
||||||
# Auto routes across the configured pool, which usually
|
|
||||||
# contains at least one vision-capable deployment; the router
|
|
||||||
# will surface a 404 from a non-vision deployment as a normal
|
|
||||||
# ``allowed_fails`` event and fail over rather than blocking
|
|
||||||
# the request outright.
|
|
||||||
supports_image_input=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_new_llm_config(cls, config) -> "AgentConfig":
|
|
||||||
"""
|
|
||||||
Create an AgentConfig from a NewLLMConfig database model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: NewLLMConfig database model instance
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AgentConfig instance
|
|
||||||
"""
|
|
||||||
# Lazy import to avoid pulling provider_capabilities (and its
|
|
||||||
# transitive litellm import) into module-init order.
|
|
||||||
from app.services.provider_capabilities import derive_supports_image_input
|
|
||||||
|
|
||||||
provider_value = (
|
|
||||||
config.provider.value
|
|
||||||
if hasattr(config.provider, "value")
|
|
||||||
else str(config.provider)
|
|
||||||
)
|
|
||||||
litellm_params = config.litellm_params or {}
|
|
||||||
base_model = (
|
|
||||||
litellm_params.get("base_model")
|
|
||||||
if isinstance(litellm_params, dict)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
provider=provider_value,
|
|
||||||
model_name=config.model_name,
|
|
||||||
api_key=config.api_key,
|
|
||||||
api_base=config.api_base,
|
|
||||||
custom_provider=config.custom_provider,
|
|
||||||
litellm_params=config.litellm_params,
|
|
||||||
system_instructions=config.system_instructions,
|
|
||||||
use_default_system_instructions=config.use_default_system_instructions,
|
|
||||||
citations_enabled=config.citations_enabled,
|
|
||||||
config_id=config.id,
|
|
||||||
config_name=config.name,
|
|
||||||
is_auto_mode=False,
|
|
||||||
billing_tier="free",
|
|
||||||
is_premium=False,
|
|
||||||
anonymous_enabled=False,
|
|
||||||
quota_reserve_tokens=None,
|
|
||||||
# BYOK rows have no operator-curated capability flag, so we
|
|
||||||
# ask LiteLLM (default-allow on unknown). The streaming
|
|
||||||
# safety net still blocks if the model is *explicitly*
|
|
||||||
# marked text-only.
|
|
||||||
supports_image_input=derive_supports_image_input(
|
|
||||||
provider=provider_value,
|
|
||||||
model_name=config.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
custom_provider=config.custom_provider,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig":
|
|
||||||
"""
|
|
||||||
Create an AgentConfig from a YAML configuration dictionary.
|
|
||||||
|
|
||||||
YAML configs now support the same prompt configuration fields as NewLLMConfig:
|
|
||||||
- system_instructions: Custom system instructions (empty string uses defaults)
|
|
||||||
- use_default_system_instructions: Whether to use default instructions
|
|
||||||
- citations_enabled: Whether citations are enabled
|
|
||||||
|
|
||||||
Args:
|
|
||||||
yaml_config: Configuration dictionary from YAML file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AgentConfig instance
|
|
||||||
"""
|
|
||||||
# Lazy import to avoid pulling provider_capabilities (and its
|
|
||||||
# transitive litellm import) into module-init order.
|
|
||||||
from app.services.provider_capabilities import derive_supports_image_input
|
|
||||||
|
|
||||||
# Get system instructions from YAML, default to empty string
|
|
||||||
system_instructions = yaml_config.get("system_instructions", "")
|
|
||||||
|
|
||||||
provider = yaml_config.get("provider", "").upper()
|
|
||||||
model_name = yaml_config.get("model_name", "")
|
|
||||||
custom_provider = yaml_config.get("custom_provider")
|
|
||||||
litellm_params = yaml_config.get("litellm_params") or {}
|
|
||||||
base_model = (
|
|
||||||
litellm_params.get("base_model")
|
|
||||||
if isinstance(litellm_params, dict)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Explicit YAML override wins; otherwise derive from LiteLLM /
|
|
||||||
# OpenRouter modalities. The YAML loader already populates this
|
|
||||||
# field, but this method is also called from
|
|
||||||
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
|
|
||||||
# so we re-derive here for safety. The bool() coercion preserves
|
|
||||||
# the loader's behaviour for explicit ``true`` / ``false``
|
|
||||||
# strings that PyYAML may surface.
|
|
||||||
if "supports_image_input" in yaml_config:
|
|
||||||
supports_image_input = bool(yaml_config.get("supports_image_input"))
|
|
||||||
else:
|
|
||||||
supports_image_input = derive_supports_image_input(
|
|
||||||
provider=provider,
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
custom_provider=custom_provider,
|
|
||||||
)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
provider=provider,
|
|
||||||
model_name=model_name,
|
|
||||||
api_key=yaml_config.get("api_key", ""),
|
|
||||||
api_base=yaml_config.get("api_base"),
|
|
||||||
custom_provider=custom_provider,
|
|
||||||
litellm_params=yaml_config.get("litellm_params"),
|
|
||||||
# Prompt configuration from YAML (with defaults for backwards compatibility)
|
|
||||||
system_instructions=system_instructions if system_instructions else None,
|
|
||||||
use_default_system_instructions=yaml_config.get(
|
|
||||||
"use_default_system_instructions", True
|
|
||||||
),
|
|
||||||
citations_enabled=yaml_config.get("citations_enabled", True),
|
|
||||||
config_id=yaml_config.get("id"),
|
|
||||||
config_name=yaml_config.get("name"),
|
|
||||||
is_auto_mode=False,
|
|
||||||
billing_tier=yaml_config.get("billing_tier", "free"),
|
|
||||||
is_premium=yaml_config.get("billing_tier", "free") == "premium",
|
|
||||||
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
|
|
||||||
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
|
|
||||||
supports_image_input=supports_image_input,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
|
|
||||||
"""
|
|
||||||
Load a specific LLM config from global_llm_config.yaml.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
llm_config_id: The id of the config to load (default: -1)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
LLM config dict or None if not found
|
|
||||||
"""
|
|
||||||
# Get the config file path
|
|
||||||
base_dir = Path(__file__).resolve().parent.parent.parent.parent
|
|
||||||
config_file = base_dir / "app" / "config" / "global_llm_config.yaml"
|
|
||||||
|
|
||||||
# Fallback to example file if main config doesn't exist
|
|
||||||
if not config_file.exists():
|
|
||||||
config_file = base_dir / "app" / "config" / "global_llm_config.example.yaml"
|
|
||||||
if not config_file.exists():
|
|
||||||
print("Error: No global_llm_config.yaml or example file found")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(config_file, encoding="utf-8") as f:
|
|
||||||
data = yaml.safe_load(f)
|
|
||||||
configs = data.get("global_llm_configs", [])
|
|
||||||
for cfg in configs:
|
|
||||||
if isinstance(cfg, dict) and cfg.get("id") == llm_config_id:
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
print(f"Error: Global LLM config id {llm_config_id} not found")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading config: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
|
|
||||||
"""
|
|
||||||
Load a global LLM config by ID, checking in-memory configs first.
|
|
||||||
|
|
||||||
This handles both static YAML configs and dynamically injected configs
|
|
||||||
(e.g. OpenRouter integration models that only exist in memory).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
llm_config_id: The negative ID of the global config to load
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
LLM config dict or None if not found
|
|
||||||
"""
|
|
||||||
from app.config import config as app_config
|
|
||||||
|
|
||||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
|
||||||
if cfg.get("id") == llm_config_id:
|
|
||||||
return cfg
|
|
||||||
# Fallback to YAML file read (covers edge cases like hot-reload)
|
|
||||||
return load_llm_config_from_yaml(llm_config_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def load_new_llm_config_from_db(
|
|
||||||
session: AsyncSession,
|
|
||||||
config_id: int,
|
|
||||||
) -> "AgentConfig | None":
|
|
||||||
"""
|
|
||||||
Load a NewLLMConfig from the database by ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: AsyncSession for database access
|
|
||||||
config_id: The ID of the NewLLMConfig to load
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AgentConfig instance or None if not found
|
|
||||||
"""
|
|
||||||
# Import here to avoid circular imports
|
|
||||||
from app.db import NewLLMConfig
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await session.execute(
|
|
||||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
|
||||||
)
|
|
||||||
config = result.scalars().first()
|
|
||||||
|
|
||||||
if not config:
|
|
||||||
print(f"Error: NewLLMConfig with id {config_id} not found")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return AgentConfig.from_new_llm_config(config)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading NewLLMConfig from database: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def load_agent_llm_config_for_search_space(
|
|
||||||
session: AsyncSession,
|
|
||||||
search_space_id: int,
|
|
||||||
) -> "AgentConfig | None":
|
|
||||||
"""
|
|
||||||
Load the agent LLM configuration for a search space.
|
|
||||||
|
|
||||||
This loads the LLM config based on the search space's agent_llm_id setting:
|
|
||||||
- Positive ID: Load from NewLLMConfig database table
|
|
||||||
- Negative ID: Load from YAML global configs
|
|
||||||
- None: Falls back to first global config (id=-1)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: AsyncSession for database access
|
|
||||||
search_space_id: The search space ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AgentConfig instance or None if not found
|
|
||||||
"""
|
|
||||||
# Import here to avoid circular imports
|
|
||||||
from app.db import SearchSpace
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get the search space to check its agent_llm_id preference
|
|
||||||
result = await session.execute(
|
|
||||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
|
||||||
)
|
|
||||||
search_space = result.scalars().first()
|
|
||||||
|
|
||||||
if not search_space:
|
|
||||||
print(f"Error: SearchSpace with id {search_space_id} not found")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Use agent_llm_id from search space, fallback to -1 (first global config)
|
|
||||||
config_id = (
|
|
||||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load the config using the unified loader
|
|
||||||
return await load_agent_config(session, config_id, search_space_id)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading agent LLM config for search space {search_space_id}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def load_agent_config(
|
|
||||||
session: AsyncSession,
|
|
||||||
config_id: int,
|
|
||||||
search_space_id: int | None = None,
|
|
||||||
) -> "AgentConfig | None":
|
|
||||||
"""
|
|
||||||
Load an agent configuration, supporting Auto mode, YAML, and database configs.
|
|
||||||
|
|
||||||
This is the main entry point for loading configurations:
|
|
||||||
- ID 0: Auto mode (uses LiteLLM Router for load balancing)
|
|
||||||
- Negative IDs: Load from YAML file (global configs)
|
|
||||||
- Positive IDs: Load from NewLLMConfig database table
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: AsyncSession for database access
|
|
||||||
config_id: The config ID (0 for Auto, negative for YAML, positive for database)
|
|
||||||
search_space_id: Optional search space ID for context
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AgentConfig instance or None if not found
|
|
||||||
"""
|
|
||||||
# Auto mode (ID 0) - use LiteLLM Router
|
|
||||||
if is_auto_mode(config_id):
|
|
||||||
if not LLMRouterService.is_initialized():
|
|
||||||
print("Error: Auto mode requested but LLM Router not initialized")
|
|
||||||
return None
|
|
||||||
return AgentConfig.from_auto_mode()
|
|
||||||
|
|
||||||
if config_id < 0:
|
|
||||||
# Check in-memory configs first (includes static YAML + dynamic OpenRouter)
|
|
||||||
from app.config import config as app_config
|
|
||||||
|
|
||||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
|
||||||
if cfg.get("id") == config_id:
|
|
||||||
return AgentConfig.from_yaml_config(cfg)
|
|
||||||
# Fallback to YAML file read for safety
|
|
||||||
yaml_config = load_llm_config_from_yaml(config_id)
|
|
||||||
if yaml_config:
|
|
||||||
return AgentConfig.from_yaml_config(yaml_config)
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
# Load from database (NewLLMConfig)
|
|
||||||
return await load_new_llm_config_from_db(session, config_id)
|
|
||||||
|
|
||||||
|
|
||||||
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
|
||||||
"""
|
|
||||||
Create a ChatLiteLLM instance from a global LLM config dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
llm_config: LLM configuration dictionary from YAML
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ChatLiteLLM instance or None on error
|
|
||||||
"""
|
|
||||||
# Build the model string
|
|
||||||
if llm_config.get("custom_provider"):
|
|
||||||
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
|
|
||||||
else:
|
|
||||||
provider = llm_config.get("provider", "").upper()
|
|
||||||
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
|
||||||
model_string = f"{provider_prefix}/{llm_config['model_name']}"
|
|
||||||
|
|
||||||
# Create ChatLiteLLM instance with streaming enabled
|
|
||||||
litellm_kwargs = {
|
|
||||||
"model": model_string,
|
|
||||||
"api_key": llm_config.get("api_key"),
|
|
||||||
"streaming": True, # Enable streaming for real-time token streaming
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add optional parameters
|
|
||||||
if llm_config.get("api_base"):
|
|
||||||
litellm_kwargs["api_base"] = llm_config["api_base"]
|
|
||||||
|
|
||||||
# Add any additional litellm parameters
|
|
||||||
if llm_config.get("litellm_params"):
|
|
||||||
litellm_kwargs.update(llm_config["litellm_params"])
|
|
||||||
|
|
||||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
|
||||||
_attach_model_profile(llm, model_string)
|
|
||||||
# Configure LiteLLM-native prompt caching (cache_control_injection_points
|
|
||||||
# for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.).
|
|
||||||
# ``agent_config=None`` here — the YAML path doesn't have provider intent
|
|
||||||
# in a structured form, so we set only the universal injection points.
|
|
||||||
apply_litellm_prompt_caching(llm)
|
|
||||||
return llm
|
|
||||||
|
|
||||||
|
|
||||||
def create_chat_litellm_from_agent_config(
|
|
||||||
agent_config: AgentConfig,
|
|
||||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
|
||||||
"""
|
|
||||||
Create a ChatLiteLLM or ChatLiteLLMRouter instance from an AgentConfig.
|
|
||||||
|
|
||||||
For Auto mode configs, returns a ChatLiteLLMRouter that uses LiteLLM Router
|
|
||||||
for automatic load balancing across available providers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_config: AgentConfig instance
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ChatLiteLLM or ChatLiteLLMRouter instance, or None on error
|
|
||||||
"""
|
|
||||||
# Handle Auto mode - return ChatLiteLLMRouter
|
|
||||||
if agent_config.is_auto_mode:
|
|
||||||
if not LLMRouterService.is_initialized():
|
|
||||||
print("Error: Auto mode requested but LLM Router not initialized")
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
router_llm = get_auto_mode_llm()
|
|
||||||
if router_llm is not None:
|
|
||||||
# Universal cache_control_injection_points only — auto-mode
|
|
||||||
# fans out across providers, so OpenAI-only kwargs (e.g.
|
|
||||||
# ``prompt_cache_key``) are left off here. ``drop_params``
|
|
||||||
# would strip them at the provider boundary anyway, but
|
|
||||||
# there's no point setting them when we don't know the
|
|
||||||
# destination.
|
|
||||||
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
|
|
||||||
return router_llm
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error creating ChatLiteLLMRouter: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Build the model string
|
|
||||||
if agent_config.custom_provider:
|
|
||||||
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
|
|
||||||
else:
|
|
||||||
provider_prefix = PROVIDER_MAP.get(
|
|
||||||
agent_config.provider, agent_config.provider.lower()
|
|
||||||
)
|
|
||||||
model_string = f"{provider_prefix}/{agent_config.model_name}"
|
|
||||||
|
|
||||||
# Create ChatLiteLLM instance with streaming enabled
|
|
||||||
litellm_kwargs = {
|
|
||||||
"model": model_string,
|
|
||||||
"api_key": agent_config.api_key,
|
|
||||||
"streaming": True, # Enable streaming for real-time token streaming
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add optional parameters
|
|
||||||
if agent_config.api_base:
|
|
||||||
litellm_kwargs["api_base"] = agent_config.api_base
|
|
||||||
|
|
||||||
# Add any additional litellm parameters
|
|
||||||
if agent_config.litellm_params:
|
|
||||||
litellm_kwargs.update(agent_config.litellm_params)
|
|
||||||
|
|
||||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
|
||||||
_attach_model_profile(llm, model_string)
|
|
||||||
# Build-time prompt caching: sets ``cache_control_injection_points`` for
|
|
||||||
# all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``.
|
|
||||||
# Per-thread ``prompt_cache_key`` is layered on later in
|
|
||||||
# ``create_surfsense_deep_agent`` once ``thread_id`` is known.
|
|
||||||
apply_litellm_prompt_caching(llm, agent_config=agent_config)
|
|
||||||
return llm
|
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ https://github.com/BerriAI/litellm/issues/20485 — the litellm-side fix
|
||||||
in PR #15395 covers the litellm transformer but does not protect us
|
in PR #15395 covers the litellm transformer but does not protect us
|
||||||
when the OpenRouter SaaS itself does the redistribution.)
|
when the OpenRouter SaaS itself does the redistribution.)
|
||||||
|
|
||||||
A separate fix in :mod:`app.agents.new_chat.prompt_caching` (switching
|
A separate fix in :mod:`app.agents.shared.prompt_caching` (switching
|
||||||
the first injection point from ``role: system`` to ``index: 0``)
|
the first injection point from ``role: system`` to ``index: 0``)
|
||||||
neutralises the *primary* cause of the same 400 — multiple
|
neutralises the *primary* cause of the same 400 — multiple
|
||||||
``SystemMessage``\ s injected by ``before_agent`` middlewares
|
``SystemMessage``\ s injected by ``before_agent`` middlewares
|
||||||
|
|
|
||||||
|
|
@ -1,241 +1,13 @@
|
||||||
r"""LiteLLM-native prompt caching configuration for SurfSense agents.
|
"""Backward-compatible shim.
|
||||||
|
|
||||||
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
|
The LiteLLM prompt-caching helper now lives in the shared agent kernel at
|
||||||
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
|
``app.agents.shared.prompt_caching``. This module re-exports it so frozen
|
||||||
gate always failed) with LiteLLM's universal caching mechanism.
|
single-agent code (``chat_deepagent``) keeps working until that stack is
|
||||||
|
retired.
|
||||||
Coverage:
|
|
||||||
|
|
||||||
- Marker-based providers (need ``cache_control`` injection, which LiteLLM
|
|
||||||
performs automatically when ``cache_control_injection_points`` is set):
|
|
||||||
``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
|
|
||||||
``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
|
|
||||||
(Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
|
|
||||||
- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
|
|
||||||
``deepseek/``, ``xai/`` — these cache automatically for prompts ≥1024
|
|
||||||
tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
|
|
||||||
|
|
||||||
We inject **two** breakpoints per request:
|
|
||||||
|
|
||||||
- ``index: 0`` — pins the SurfSense system prompt at the head of the
|
|
||||||
request (provider variant, citation rules, tool catalog, KB tree,
|
|
||||||
skills metadata). The langchain agent factory always prepends
|
|
||||||
``request.system_message`` at index 0 (see ``factory.py``
|
|
||||||
``_execute_model_async``), so this targets exactly the main system
|
|
||||||
prompt regardless of how many other ``SystemMessage``\ s the
|
|
||||||
``before_agent`` injectors (priority, tree, memory, file-intent,
|
|
||||||
anonymous-doc) have inserted into ``state["messages"]``. Using
|
|
||||||
``role: system`` here would apply ``cache_control`` to **every**
|
|
||||||
system-role message and trip Anthropic's hard cap of 4 cache
|
|
||||||
breakpoints per request once the conversation accumulates enough
|
|
||||||
injected system messages — which surfaces as the upstream 400
|
|
||||||
``A maximum of 4 blocks with cache_control may be provided. Found N``
|
|
||||||
via OpenRouter→Anthropic.
|
|
||||||
- ``index: -1`` — pins the latest message so multi-turn savings compound:
|
|
||||||
Anthropic-family providers use longest-matching-prefix lookup, so turn
|
|
||||||
N+1 still reads turn N's cache up to the shared prefix.
|
|
||||||
|
|
||||||
For OpenAI-family configs we additionally pass:
|
|
||||||
|
|
||||||
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
|
|
||||||
raises hit rate by sending requests with a shared prefix to the same
|
|
||||||
backend. Supported by ``openai/``, ``deepseek/``, ``xai/``, and
|
|
||||||
``azure/`` (added to LiteLLM's Azure transformer in
|
|
||||||
https://github.com/BerriAI/litellm/pull/20989, Feb 2026; verified
|
|
||||||
against ``AzureOpenAIConfig.get_supported_openai_params`` in our
|
|
||||||
installed litellm 1.83.14 for ``azure/gpt-4o``, ``azure/gpt-4o-mini``,
|
|
||||||
``azure/gpt-5.4``, ``azure/gpt-5.4-mini``).
|
|
||||||
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
|
|
||||||
5-10 min in-memory cache. Set ONLY for OpenAI/DeepSeek/xAI: Azure's
|
|
||||||
server-side support landed in Microsoft's docs on 2026-05-13 but
|
|
||||||
LiteLLM 1.83.14's Azure transformer still omits it from its supported
|
|
||||||
params list, so it gets silently dropped by ``litellm.drop_params``.
|
|
||||||
Azure's default in-memory retention (5-10 min, max 1 h) already
|
|
||||||
bridges intra-conversation turns; revisit when LiteLLM bumps Azure.
|
|
||||||
|
|
||||||
Safety net: ``litellm.drop_params=True`` is set globally in
|
|
||||||
``app.services.llm_service`` at module-load time. Any kwarg the destination
|
|
||||||
provider doesn't recognise is auto-stripped at the provider transformer
|
|
||||||
layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
|
|
||||||
``prompt_cache_key`` etc.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
from app.agents.shared.prompt_caching import apply_litellm_prompt_caching
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
|
|
||||||
from langchain_core.language_models import BaseChatModel
|
__all__ = ["apply_litellm_prompt_caching"]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# Cache breakpoints injected into every request: one at the head of the
# message list and one at the most recent message (rationale in the
# module docstring). Anthropic allows at most 4 ``cache_control`` blocks
# per request; only 2 are used here, which leaves headroom for Phase-2
# tool caching.
#
# IMPORTANT: the head breakpoint is addressed by ``index: 0`` and not by
# ``role: system``. The deepagent stack's ``before_agent`` middlewares
# (priority, tree, memory, file-intent, anonymous-doc) add
# ``SystemMessage`` instances to ``state["messages"]`` and those
# accumulate across turns. A ``role: system`` selector would make the
# LiteLLM hook tag *every* one of them with ``cache_control`` and blow
# through Anthropic's 4-block limit. ``index: 0`` always lands on the
# langchain-prepended ``request.system_message`` (flattened to a single
# text block by our ``FlattenSystemMessageMiddleware``), which yields
# exactly one stable breakpoint.
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = tuple(
    {"location": "message", "index": position} for position in (0, -1)
)

# Providers (uppercase ``AgentConfig.provider`` values) that accept
# ``prompt_cache_retention="24h"``. Azure is deliberately absent: as the
# module docstring explains, LiteLLM 1.83.14's Azure transformer omits
# the param, so ``drop_params`` would silently strip it. Add Azure here
# once a LiteLLM release wires the param into
# ``AzureOpenAIConfig.get_supported_openai_params``.
_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset(
    ("OPENAI", "DEEPSEEK", "XAI")
)

# Providers that accept the OpenAI ``prompt_cache_key`` routing hint:
# everything in the retention set plus the Azure variants. Microsoft's
# Azure OpenAI docs (2026-05-13) confirm automatic prompt caching for
# every GPT-4o-or-newer Azure deployment at ≥1024 tokens with no
# configuration, and that ``prompt_cache_key`` is combined with the
# prefix hash to improve routing affinity (and therefore hit rate).
# LiteLLM's Azure transformer lists ``prompt_cache_key`` as supported as
# of https://github.com/BerriAI/litellm/pull/20989.
#
# This is a strict whitelist. Many other providers in ``PROVIDER_MAP``
# route through litellm's ``openai`` prefix without implementing the
# OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU, MINIMAX), so the
# family cannot be inferred from the litellm prefix alone.
_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = _PROMPT_CACHE_RETENTION_PROVIDERS | {
    "AZURE",
    "AZURE_OPENAI",
}
|
|
||||||
|
|
||||||
|
|
||||||
def _is_router_llm(llm: BaseChatModel) -> bool:
    """Report whether ``llm`` is the auto-mode ``ChatLiteLLMRouter``.

    The class *name* is compared rather than using ``isinstance`` so that
    ``app.services.llm_router_service`` does not have to be imported when
    this module loads; such an import would close the cycle
    ``llm_config -> prompt_caching -> llm_router_service``. The router
    class is defined in a single place, so matching on its name is
    sufficient.
    """
    class_name = type(llm).__name__
    return class_name == "ChatLiteLLMRouter"
|
|
||||||
|
|
||||||
|
|
||||||
def _provider_supports_prompt_cache_key(agent_config: AgentConfig | None) -> bool:
    """Tell whether ``agent_config`` targets a ``prompt_cache_key`` provider.

    The check is strict: the answer is True only for an explicitly
    selected OPENAI, DEEPSEEK, XAI, AZURE, or AZURE_OPENAI provider.
    A missing config, an empty provider, auto-mode, and custom providers
    all yield False, because the destination is not statically known and
    the router fans out across mixed providers.
    """
    if agent_config is None:
        return False

    provider = agent_config.provider
    # Short-circuits in the same order as a chain of guard clauses would:
    # provider present -> not auto-mode -> no custom provider.
    destination_known = (
        bool(provider)
        and not agent_config.is_auto_mode
        and not agent_config.custom_provider
    )
    return destination_known and provider.upper() in _PROMPT_CACHE_KEY_PROVIDERS
|
|
||||||
|
|
||||||
|
|
||||||
def _provider_supports_prompt_cache_retention(
    agent_config: AgentConfig | None,
) -> bool:
    """Tell whether ``agent_config`` targets a ``prompt_cache_retention`` provider.

    This is narrower than :func:`_provider_supports_prompt_cache_key`:
    Azure deployments stay excluded until LiteLLM ships the param in its
    Azure transformer (see the module docstring). A missing config, an
    empty provider, auto-mode, and custom providers all yield False.
    """
    if agent_config is None:
        return False

    provider = agent_config.provider
    # Short-circuits in the same order as a chain of guard clauses would:
    # provider present -> not auto-mode -> no custom provider.
    destination_known = (
        bool(provider)
        and not agent_config.is_auto_mode
        and not agent_config.custom_provider
    )
    return destination_known and provider.upper() in _PROMPT_CACHE_RETENTION_PROVIDERS
|
|
||||||
|
|
||||||
|
|
||||||
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
    """Fetch ``llm.model_kwargs`` as a mutable dict, creating it if needed.

    When the attribute is already a dict it is returned as-is (the same
    object, so callers can mutate it in place). Otherwise the attribute
    is assigned a fresh ``{}`` — this covers a field that is present but
    ``None`` on a Pydantic v2 model — and re-read. ``None`` is returned
    when the assignment raises or the re-read value is still not a dict,
    i.e. the LLM type exposes no writable ``model_kwargs``; callers
    should treat that as a no-op.
    """
    current = getattr(llm, "model_kwargs", None)
    if not isinstance(current, dict):
        try:
            llm.model_kwargs = {}  # type: ignore[attr-defined]
        except Exception:
            # Best-effort: any failure to assign means "not writable".
            return None
        current = getattr(llm, "model_kwargs", None)
        if not isinstance(current, dict):
            return None
    return current
|
|
||||||
|
|
||||||
|
|
||||||
def apply_litellm_prompt_caching(
    llm: BaseChatModel,
    *,
    agent_config: AgentConfig | None = None,
    thread_id: int | None = None,
) -> None:
    """Set LiteLLM prompt-caching kwargs on a ChatLiteLLM/ChatLiteLLMRouter.

    The function is idempotent: any key already present in
    ``llm.model_kwargs`` (for instance supplied through
    ``agent_config.litellm_params``) is left untouched. ``llm.model_kwargs``
    is mutated in place; from there the kwargs reach ``litellm.completion``
    through ``ChatLiteLLM._default_params`` and through the
    ``self.model_kwargs`` merge in our custom ``ChatLiteLLMRouter``.

    Args:
        llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
        agent_config: Optional ``AgentConfig`` that drives the
            provider-specific extras. Without it (or in auto-mode) only
            the universal ``cache_control_injection_points`` are set.
        thread_id: Optional thread id from which a per-thread
            ``prompt_cache_key`` is built for OpenAI-family providers.
            Caching works without it (it is automatic server-side), but
            the key improves backend routing affinity and thus hit rate.
    """
    model_kwargs = _get_or_init_model_kwargs(llm)
    if model_kwargs is None:
        logger.debug(
            "apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping",
            type(llm).__name__,
        )
        return

    # Universal part: fresh copies of the default breakpoints, unless the
    # caller already configured injection points.
    model_kwargs.setdefault(
        "cache_control_injection_points",
        [dict(point) for point in _DEFAULT_INJECTION_POINTS],
    )

    # The OpenAI-style extras are only set when the destination is
    # statically known to accept them. The auto-mode router fans out
    # across mixed providers, so destination-specific kwargs are not safe
    # there (drop_params would strip them, but setting them at all would
    # be wasteful).
    if _is_router_llm(llm):
        return

    needs_cache_key = thread_id is not None and "prompt_cache_key" not in model_kwargs
    if needs_cache_key and _provider_supports_prompt_cache_key(agent_config):
        model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"

    needs_retention = "prompt_cache_retention" not in model_kwargs
    if needs_retention and _provider_supports_prompt_cache_retention(agent_config):
        model_kwargs["prompt_cache_retention"] = "24h"
|
|
||||||
|
|
|
||||||
622
surfsense_backend/app/agents/shared/llm_config.py
Normal file
622
surfsense_backend/app/agents/shared/llm_config.py
Normal file
|
|
@ -0,0 +1,622 @@
|
||||||
|
"""
|
||||||
|
LLM configuration utilities for SurfSense agents.
|
||||||
|
|
||||||
|
This module provides functions for loading LLM configurations from:
|
||||||
|
1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing
|
||||||
|
2. YAML files (global configs with negative IDs)
|
||||||
|
3. Database NewLLMConfig table (user-created configs with positive IDs)
|
||||||
|
|
||||||
|
It also provides utilities for creating ChatLiteLLM instances and
|
||||||
|
managing prompt configurations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain_core.messages import AIMessage, BaseMessage
|
||||||
|
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||||
|
from langchain_litellm import ChatLiteLLM
|
||||||
|
from litellm import get_model_info
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.shared.prompt_caching import apply_litellm_prompt_caching
|
||||||
|
from app.services.llm_router_service import (
|
||||||
|
AUTO_MODE_ID,
|
||||||
|
ChatLiteLLMRouter,
|
||||||
|
LLMRouterService,
|
||||||
|
_sanitize_content,
|
||||||
|
get_auto_mode_llm,
|
||||||
|
is_auto_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
    """Normalise the content of each message so any provider accepts it.

    Three cross-provider incompatibilities are handled:

    - list content carrying provider-specific blocks (e.g. ``thinking``)
    - list content carrying bare strings or empty text blocks
    - AI messages with empty content plus tool calls: some providers
      (Bedrock) turn ``""`` into ``[{"type":"text","text":""}]``
      server-side and then reject the blank text. Per the OpenAI spec,
      ``content`` should be ``null`` when an assistant message carries
      only tool calls.

    The messages are modified in place and the same list is returned.
    """
    for message in messages:
        body = message.content
        if isinstance(body, list):
            message.content = _sanitize_content(body)

        # Re-read ``message.content`` here: sanitising may have emptied it.
        is_tool_only_ai_message = (
            isinstance(message, AIMessage)
            and not message.content
            and getattr(message, "tool_calls", None)
        )
        if is_tool_only_ai_message:
            message.content = None  # type: ignore[assignment]
    return messages
|
||||||
|
|
||||||
|
|
||||||
|
class SanitizedChatLiteLLM(ChatLiteLLM):
    """ChatLiteLLM variant that sanitises messages before sending them.

    Provider-specific content blocks (e.g. ``thinking`` from reasoning
    models) are stripped and bare strings inside content arrays are
    normalised before the call is forwarded to the underlying provider.
    Both overrides simply run :func:`_sanitize_messages` and then defer
    to the parent implementation.
    """

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Sanitise ``messages`` and delegate to ``ChatLiteLLM._generate``."""
        cleaned = _sanitize_messages(messages)
        return super()._generate(cleaned, stop, run_manager, **kwargs)

    async def _astream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Sanitise ``messages`` and re-yield ``ChatLiteLLM._astream`` chunks."""
        cleaned = _sanitize_messages(messages)
        parent_stream = super()._astream(cleaned, stop, run_manager, **kwargs)
        async for chunk in parent_stream:
            yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
# Provider mapping for LiteLLM model string construction.
|
||||||
|
#
|
||||||
|
# Single source of truth lives in
|
||||||
|
# :mod:`app.services.provider_capabilities` so the YAML loader (which
|
||||||
|
# runs during ``app.config`` class-body init) can resolve provider
|
||||||
|
# prefixes without dragging the agent / tools tree into module load
|
||||||
|
# order. Re-exported here under the historical ``PROVIDER_MAP`` name
|
||||||
|
# so existing callers (``llm_router_service``, ``image_gen_router_service``,
|
||||||
|
# tests) keep working unchanged.
|
||||||
|
from app.services.provider_capabilities import ( # noqa: E402
|
||||||
|
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
    """Set ``llm.profile`` to a dict of context metadata for ``model_string``.

    The profile is only attached when ``get_model_info`` reports a
    positive integer ``max_input_tokens``. This is best-effort: any
    exception raised along the way is swallowed and ``llm`` is left
    without a new profile.
    """
    try:
        token_limit = get_model_info(model_string).get("max_input_tokens")
        if not isinstance(token_limit, int) or token_limit <= 0:
            return
        llm.profile = {
            "max_input_tokens": token_limit,
            "max_input_tokens_upper": token_limit,
            "token_count_model": model_string,
            "token_count_models": [model_string],
        }
    except Exception:
        return
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AgentConfig:
    """
    Full configuration for the SurfSense agent.

    Bundles the LLM connection settings together with the prompt
    configuration taken from NewLLMConfig. Auto mode (ID 0), which uses
    the LiteLLM Router for load balancing, is supported as well.
    """

    # --- LLM model settings ---
    provider: str
    model_name: str
    api_key: str
    api_base: str | None = None
    custom_provider: str | None = None
    litellm_params: dict | None = None

    # --- Prompt configuration ---
    system_instructions: str | None = None
    use_default_system_instructions: bool = True
    citations_enabled: bool = True

    # --- Metadata ---
    config_id: int | None = None
    config_name: str | None = None

    # --- Auto mode flag ---
    is_auto_mode: bool = False

    # --- Token quota and policy ---
    billing_tier: str = "free"
    is_premium: bool = False
    anonymous_enabled: bool = False
    quota_reserve_tokens: int | None = None

    # Capability flag: best-effort True for the chat selector / catalog.
    # It is resolved via :func:`provider_capabilities.derive_supports_image_input`,
    # which prefers OpenRouter's ``architecture.input_modalities`` and
    # otherwise consults LiteLLM's authoritative model map. The default of
    # True is the conservative-allow stance: the streaming-task safety net
    # (``is_known_text_only_chat_model``) is the *only* place where a
    # False actually blocks a request. Setting False here without an
    # authoritative source would silently hide vision-capable models
    # (the regression we're fixing).
    supports_image_input: bool = True

    @classmethod
    def from_auto_mode(cls) -> "AgentConfig":
        """
        Build the AgentConfig used for Auto mode (LiteLLM Router load balancing).

        Returns:
            AgentConfig instance configured for Auto mode
        """
        # Fields not listed below keep their dataclass defaults, which are
        # the values Auto mode needs (no api_base / custom provider /
        # litellm params, default prompt settings, free tier, no quota).
        return cls(
            provider="AUTO",
            model_name="auto",
            api_key="",  # Not needed for router
            config_id=AUTO_MODE_ID,
            config_name="Auto (Fastest)",
            is_auto_mode=True,
            # Auto routes across the configured pool, which usually
            # contains at least one vision-capable deployment; the router
            # surfaces a 404 from a non-vision deployment as a normal
            # ``allowed_fails`` event and fails over instead of blocking
            # the request outright.
            supports_image_input=True,
        )

    @classmethod
    def from_new_llm_config(cls, config) -> "AgentConfig":
        """
        Build an AgentConfig from a NewLLMConfig database model.

        Args:
            config: NewLLMConfig database model instance

        Returns:
            AgentConfig instance
        """
        # Imported lazily so provider_capabilities (and its transitive
        # litellm import) stays out of module-init order.
        from app.services.provider_capabilities import derive_supports_image_input

        raw_provider = config.provider
        # Enum-like providers expose ``.value``; anything else is stringified.
        if hasattr(raw_provider, "value"):
            provider_value = raw_provider.value
        else:
            provider_value = str(raw_provider)

        params = config.litellm_params or {}
        base_model = params.get("base_model") if isinstance(params, dict) else None

        # BYOK rows have no operator-curated capability flag, so LiteLLM
        # is asked instead (default-allow on unknown). The streaming
        # safety net still blocks if the model is *explicitly* marked
        # text-only.
        image_capable = derive_supports_image_input(
            provider=provider_value,
            model_name=config.model_name,
            base_model=base_model,
            custom_provider=config.custom_provider,
        )

        return cls(
            provider=provider_value,
            model_name=config.model_name,
            api_key=config.api_key,
            api_base=config.api_base,
            custom_provider=config.custom_provider,
            litellm_params=config.litellm_params,
            system_instructions=config.system_instructions,
            use_default_system_instructions=config.use_default_system_instructions,
            citations_enabled=config.citations_enabled,
            config_id=config.id,
            config_name=config.name,
            is_auto_mode=False,
            billing_tier="free",
            is_premium=False,
            anonymous_enabled=False,
            quota_reserve_tokens=None,
            supports_image_input=image_capable,
        )

    @classmethod
    def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig":
        """
        Build an AgentConfig from a YAML configuration dictionary.

        YAML configs support the same prompt configuration fields as NewLLMConfig:
        - system_instructions: Custom system instructions (empty string uses defaults)
        - use_default_system_instructions: Whether to use default instructions
        - citations_enabled: Whether citations are enabled

        Args:
            yaml_config: Configuration dictionary from YAML file

        Returns:
            AgentConfig instance
        """
        # Imported lazily so provider_capabilities (and its transitive
        # litellm import) stays out of module-init order.
        from app.services.provider_capabilities import derive_supports_image_input

        # An absent or empty ``system_instructions`` becomes None below.
        instructions = yaml_config.get("system_instructions", "")

        provider = yaml_config.get("provider", "").upper()
        model_name = yaml_config.get("model_name", "")
        custom_provider = yaml_config.get("custom_provider")
        params = yaml_config.get("litellm_params") or {}
        base_model = params.get("base_model") if isinstance(params, dict) else None

        # An explicit YAML override wins; otherwise the flag is derived
        # from LiteLLM / OpenRouter modalities. The YAML loader already
        # populates this field, but this method is also reached from
        # ``load_global_llm_config_by_id``'s file fallback (hot reload),
        # so it is re-derived here for safety. The bool() coercion
        # mirrors what the loader does with the raw YAML value.
        supports_image_input = (
            bool(yaml_config.get("supports_image_input"))
            if "supports_image_input" in yaml_config
            else derive_supports_image_input(
                provider=provider,
                model_name=model_name,
                base_model=base_model,
                custom_provider=custom_provider,
            )
        )

        billing_tier = yaml_config.get("billing_tier", "free")

        return cls(
            provider=provider,
            model_name=model_name,
            api_key=yaml_config.get("api_key", ""),
            api_base=yaml_config.get("api_base"),
            custom_provider=custom_provider,
            litellm_params=yaml_config.get("litellm_params"),
            # Prompt configuration from YAML (with defaults for backwards compatibility)
            system_instructions=instructions or None,
            use_default_system_instructions=yaml_config.get(
                "use_default_system_instructions", True
            ),
            citations_enabled=yaml_config.get("citations_enabled", True),
            config_id=yaml_config.get("id"),
            config_name=yaml_config.get("name"),
            is_auto_mode=False,
            billing_tier=billing_tier,
            is_premium=billing_tier == "premium",
            anonymous_enabled=yaml_config.get("anonymous_enabled", False),
            quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
            supports_image_input=supports_image_input,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
    """
    Look up one LLM config in global_llm_config.yaml by its id.

    ``global_llm_config.yaml`` is read from the ``app/config`` directory
    four levels above this file; when it is missing,
    ``global_llm_config.example.yaml`` from the same directory is used.
    Problems are reported with ``print`` and never raised.

    Args:
        llm_config_id: The id of the config to load (default: -1)

    Returns:
        LLM config dict or None if not found
    """
    base_dir = Path(__file__).resolve().parent.parent.parent.parent
    config_dir = base_dir / "app" / "config"

    # Prefer the real config; fall back to the example file.
    candidates = (
        config_dir / "global_llm_config.yaml",
        config_dir / "global_llm_config.example.yaml",
    )
    config_file = next((path for path in candidates if path.exists()), None)
    if config_file is None:
        print("Error: No global_llm_config.yaml or example file found")
        return None

    try:
        with open(config_file, encoding="utf-8") as handle:
            data = yaml.safe_load(handle)
            for entry in data.get("global_llm_configs", []):
                # Non-dict entries are skipped rather than treated as errors.
                if isinstance(entry, dict) and entry.get("id") == llm_config_id:
                    return entry

        print(f"Error: Global LLM config id {llm_config_id} not found")
        return None
    except Exception as e:
        print(f"Error loading config: {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
    """
    Resolve a global LLM config by ID, preferring the in-memory configs.

    Both static YAML configs and dynamically injected configs (e.g.
    OpenRouter integration models that only exist in memory) are covered,
    because ``app_config.GLOBAL_LLM_CONFIGS`` is searched before the YAML
    file is consulted.

    Args:
        llm_config_id: The negative ID of the global config to load

    Returns:
        LLM config dict or None if not found
    """
    from app.config import config as app_config

    in_memory_match = next(
        (
            candidate
            for candidate in app_config.GLOBAL_LLM_CONFIGS
            if candidate.get("id") == llm_config_id
        ),
        None,
    )
    if in_memory_match is not None:
        return in_memory_match

    # Not in memory: read the YAML file (covers edge cases like hot-reload).
    return load_llm_config_from_yaml(llm_config_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def load_new_llm_config_from_db(
    session: AsyncSession,
    config_id: int,
) -> "AgentConfig | None":
    """
    Fetch a NewLLMConfig row by ID and convert it to an AgentConfig.

    A missing row or any exception is reported with ``print`` and results
    in ``None``; nothing is raised to the caller.

    Args:
        session: AsyncSession for database access
        config_id: The ID of the NewLLMConfig to load

    Returns:
        AgentConfig instance or None if not found
    """
    # Imported here to avoid circular imports
    from app.db import NewLLMConfig

    try:
        query = select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
        row = (await session.execute(query)).scalars().first()

        if row:
            return AgentConfig.from_new_llm_config(row)

        print(f"Error: NewLLMConfig with id {config_id} not found")
        return None
    except Exception as e:
        print(f"Error loading NewLLMConfig from database: {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
|
async def load_agent_llm_config_for_search_space(
    session: AsyncSession,
    search_space_id: int,
) -> "AgentConfig | None":
    """
    Resolve the agent LLM configuration selected by a search space.

    The search space's ``agent_llm_id`` decides where the config comes from:
    - Positive ID: Load from NewLLMConfig database table
    - Negative ID: Load from YAML global configs
    - None: Falls back to first global config (id=-1)

    A missing search space or any exception is reported with ``print``
    and results in ``None``.

    Args:
        session: AsyncSession for database access
        search_space_id: The search space ID

    Returns:
        AgentConfig instance or None if not found
    """
    # Imported here to avoid circular imports
    from app.db import SearchSpace

    try:
        # Fetch the search space so its agent_llm_id preference can be read.
        query = select(SearchSpace).filter(SearchSpace.id == search_space_id)
        search_space = (await session.execute(query)).scalars().first()

        if not search_space:
            print(f"Error: SearchSpace with id {search_space_id} not found")
            return None

        # No preference stored -> use -1, the first global config.
        config_id = search_space.agent_llm_id
        if config_id is None:
            config_id = -1

        # Hand off to the unified loader.
        return await load_agent_config(session, config_id, search_space_id)
    except Exception as e:
        print(f"Error loading agent LLM config for search space {search_space_id}: {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
|
async def load_agent_config(
    session: AsyncSession,
    config_id: int,
    search_space_id: int | None = None,
) -> "AgentConfig | None":
    """
    Unified loader for agent configurations (Auto mode, YAML, database).

    Dispatch is driven by ``config_id``:
    - ID 0: Auto mode (uses LiteLLM Router for load balancing)
    - Negative IDs: global configs (in-memory list first, YAML file second)
    - Positive IDs: NewLLMConfig database table

    Args:
        session: AsyncSession for database access
        config_id: The config ID (0 for Auto, negative for YAML, positive for database)
        search_space_id: Optional search space ID for context; not read by
            this function's body

    Returns:
        AgentConfig instance or None if not found
    """
    # Auto mode is only usable once the LiteLLM Router has been initialised
    if is_auto_mode(config_id):
        if LLMRouterService.is_initialized():
            return AgentConfig.from_auto_mode()
        print("Error: Auto mode requested but LLM Router not initialized")
        return None

    # Non-negative IDs live in the NewLLMConfig table
    if config_id >= 0:
        return await load_new_llm_config_from_db(session, config_id)

    # Negative IDs: in-memory configs first (static YAML + dynamic OpenRouter)
    from app.config import config as app_config

    in_memory = next(
        (
            entry
            for entry in app_config.GLOBAL_LLM_CONFIGS
            if entry.get("id") == config_id
        ),
        None,
    )
    if in_memory is not None:
        return AgentConfig.from_yaml_config(in_memory)

    # Fall back to reading the YAML file for safety
    yaml_config = load_llm_config_from_yaml(config_id)
    return AgentConfig.from_yaml_config(yaml_config) if yaml_config else None
|
||||||
|
|
||||||
|
|
||||||
|
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
    """
    Build a streaming chat model from a global LLM config dictionary.

    The model string is ``<prefix>/<model_name>``, where the prefix is the
    config's ``custom_provider`` when set, otherwise the ``PROVIDER_MAP``
    entry for the upper-cased ``provider`` (falling back to the lower-cased
    provider name).

    Args:
        llm_config: LLM configuration dictionary from YAML

    Returns:
        The configured chat model. No path in this function returns None;
        the optional return annotation is kept for interface compatibility.
    """
    # Resolve the provider prefix of the model string
    custom_provider = llm_config.get("custom_provider")
    if custom_provider:
        prefix = custom_provider
    else:
        provider = llm_config.get("provider", "").upper()
        prefix = PROVIDER_MAP.get(provider, provider.lower())
    model_string = f"{prefix}/{llm_config['model_name']}"

    # Streaming is always on so tokens arrive in real time
    litellm_kwargs = {
        "model": model_string,
        "api_key": llm_config.get("api_key"),
        "streaming": True,
    }

    # Optional endpoint override
    api_base = llm_config.get("api_base")
    if api_base:
        litellm_kwargs["api_base"] = api_base

    # Extra litellm parameters are merged last, so they may override the above
    extra_params = llm_config.get("litellm_params")
    if extra_params:
        litellm_kwargs.update(extra_params)

    llm = SanitizedChatLiteLLM(**litellm_kwargs)
    _attach_model_profile(llm, model_string)
    # LiteLLM-native prompt caching (cache_control_injection_points for
    # Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.).
    # No agent_config is passed: the YAML path doesn't carry provider intent
    # in a structured form, so only the universal injection points are set.
    apply_litellm_prompt_caching(llm)
    return llm
|
||||||
|
|
||||||
|
|
||||||
|
def create_chat_litellm_from_agent_config(
    agent_config: AgentConfig,
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
    """
    Create a ChatLiteLLM or ChatLiteLLMRouter instance from an AgentConfig.

    For Auto mode configs, returns a ChatLiteLLMRouter that uses LiteLLM Router
    for automatic load balancing across available providers.

    For every other config, a streaming ``SanitizedChatLiteLLM`` is built from
    the config's provider (or custom provider), model name, API key, optional
    API base and extra litellm params, then a model profile and prompt-caching
    settings are attached.

    Args:
        agent_config: AgentConfig instance

    Returns:
        ChatLiteLLM or ChatLiteLLMRouter instance. In Auto mode the result is
        None when the router is not initialised, when creating the router
        raises, or when ``get_auto_mode_llm()`` itself returns None.
    """
    # Handle Auto mode - return ChatLiteLLMRouter
    if agent_config.is_auto_mode:
        if not LLMRouterService.is_initialized():
            print("Error: Auto mode requested but LLM Router not initialized")
            return None
        try:
            router_llm = get_auto_mode_llm()
            if router_llm is not None:
                # Universal cache_control_injection_points only — auto-mode
                # fans out across providers, so OpenAI-only kwargs (e.g.
                # ``prompt_cache_key``) are left off here. ``drop_params``
                # would strip them at the provider boundary anyway, but
                # there's no point setting them when we don't know the
                # destination.
                apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
            # May be None; it is returned as-is rather than falling through
            # to the single-provider path below.
            return router_llm
        except Exception as e:
            print(f"Error creating ChatLiteLLMRouter: {e}")
            return None

    # Build the model string as "<provider prefix>/<model name>"
    if agent_config.custom_provider:
        model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
    else:
        # Unknown providers fall back to their lower-cased name as the prefix
        provider_prefix = PROVIDER_MAP.get(
            agent_config.provider, agent_config.provider.lower()
        )
        model_string = f"{provider_prefix}/{agent_config.model_name}"

    # Create ChatLiteLLM instance with streaming enabled
    litellm_kwargs = {
        "model": model_string,
        "api_key": agent_config.api_key,
        "streaming": True,  # Enable streaming for real-time token streaming
    }

    # Add optional parameters
    if agent_config.api_base:
        litellm_kwargs["api_base"] = agent_config.api_base

    # Add any additional litellm parameters (merged last, so they can
    # override the keys set above)
    if agent_config.litellm_params:
        litellm_kwargs.update(agent_config.litellm_params)

    llm = SanitizedChatLiteLLM(**litellm_kwargs)
    _attach_model_profile(llm, model_string)
    # Build-time prompt caching: sets ``cache_control_injection_points`` for
    # all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``.
    # Per-thread ``prompt_cache_key`` is layered on later in
    # ``create_surfsense_deep_agent`` once ``thread_id`` is known.
    apply_litellm_prompt_caching(llm, agent_config=agent_config)
    return llm
|
||||||
241
surfsense_backend/app/agents/shared/prompt_caching.py
Normal file
241
surfsense_backend/app/agents/shared/prompt_caching.py
Normal file
|
|
@ -0,0 +1,241 @@
|
||||||
|
r"""LiteLLM-native prompt caching configuration for SurfSense agents.
|
||||||
|
|
||||||
|
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
|
||||||
|
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
|
||||||
|
gate always failed) with LiteLLM's universal caching mechanism.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
|
||||||
|
- Marker-based providers (need ``cache_control`` injection, which LiteLLM
|
||||||
|
performs automatically when ``cache_control_injection_points`` is set):
|
||||||
|
``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
|
||||||
|
``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
|
||||||
|
(Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
|
||||||
|
- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
|
||||||
|
``deepseek/``, ``xai/`` — these cache automatically for prompts ≥1024
|
||||||
|
tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
|
||||||
|
|
||||||
|
We inject **two** breakpoints per request:
|
||||||
|
|
||||||
|
- ``index: 0`` — pins the SurfSense system prompt at the head of the
|
||||||
|
request (provider variant, citation rules, tool catalog, KB tree,
|
||||||
|
skills metadata). The langchain agent factory always prepends
|
||||||
|
``request.system_message`` at index 0 (see ``factory.py``
|
||||||
|
``_execute_model_async``), so this targets exactly the main system
|
||||||
|
prompt regardless of how many other ``SystemMessage``\ s the
|
||||||
|
``before_agent`` injectors (priority, tree, memory, file-intent,
|
||||||
|
anonymous-doc) have inserted into ``state["messages"]``. Using
|
||||||
|
``role: system`` here would apply ``cache_control`` to **every**
|
||||||
|
system-role message and trip Anthropic's hard cap of 4 cache
|
||||||
|
breakpoints per request once the conversation accumulates enough
|
||||||
|
injected system messages — which surfaces as the upstream 400
|
||||||
|
``A maximum of 4 blocks with cache_control may be provided. Found N``
|
||||||
|
via OpenRouter→Anthropic.
|
||||||
|
- ``index: -1`` — pins the latest message so multi-turn savings compound:
|
||||||
|
Anthropic-family providers use longest-matching-prefix lookup, so turn
|
||||||
|
N+1 still reads turn N's cache up to the shared prefix.
|
||||||
|
|
||||||
|
For OpenAI-family configs we additionally pass:
|
||||||
|
|
||||||
|
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
|
||||||
|
raises hit rate by sending requests with a shared prefix to the same
|
||||||
|
backend. Supported by ``openai/``, ``deepseek/``, ``xai/``, and
|
||||||
|
``azure/`` (added to LiteLLM's Azure transformer in
|
||||||
|
https://github.com/BerriAI/litellm/pull/20989, Feb 2026; verified
|
||||||
|
against ``AzureOpenAIConfig.get_supported_openai_params`` in our
|
||||||
|
installed litellm 1.83.14 for ``azure/gpt-4o``, ``azure/gpt-4o-mini``,
|
||||||
|
``azure/gpt-5.4``, ``azure/gpt-5.4-mini``).
|
||||||
|
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
|
||||||
|
5-10 min in-memory cache. Set ONLY for OpenAI/DeepSeek/xAI: Azure's
|
||||||
|
server-side support landed in Microsoft's docs on 2026-05-13 but
|
||||||
|
LiteLLM 1.83.14's Azure transformer still omits it from its supported
|
||||||
|
params list, so it gets silently dropped by ``litellm.drop_params``.
|
||||||
|
Azure's default in-memory retention (5-10 min, max 1 h) already
|
||||||
|
bridges intra-conversation turns; revisit when LiteLLM bumps Azure.
|
||||||
|
|
||||||
|
Safety net: ``litellm.drop_params=True`` is set globally in
|
||||||
|
``app.services.llm_service`` at module-load time. Any kwarg the destination
|
||||||
|
provider doesn't recognise is auto-stripped at the provider transformer
|
||||||
|
layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
|
||||||
|
``prompt_cache_key`` etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.agents.shared.llm_config import AgentConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Two-breakpoint policy: head-of-request + latest message. See module
# docstring for rationale. Anthropic caps requests at 4 ``cache_control``
# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
#
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
# ``before_agent`` middlewares (priority, tree, memory, file-intent,
# anonymous-doc) insert ``SystemMessage`` instances into
# ``state["messages"]`` that accumulate across turns. With
# ``role: system`` the LiteLLM hook would tag *every* one of them with
# ``cache_control`` and overflow Anthropic's 4-block limit. ``index: 0``
# always targets the langchain-prepended ``request.system_message``
# (which our ``FlattenSystemMessageMiddleware`` reduces to a single text
# block), giving us exactly one stable cache breakpoint.
#
# Kept as a tuple of template dicts; ``apply_litellm_prompt_caching``
# copies each entry with ``dict(point)`` before storing it on an LLM, so
# these templates are never shared with a model's ``model_kwargs``.
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
    {"location": "message", "index": 0},
    {"location": "message", "index": -1},
)

# Providers (uppercase ``AgentConfig.provider`` values) that accept the
# OpenAI ``prompt_cache_key`` routing hint. Microsoft's Azure OpenAI docs
# (2026-05-13) confirm automatic prompt caching applies to every GPT-4o
# or newer Azure deployment at ≥1024 tokens with no configuration needed,
# and that ``prompt_cache_key`` is combined with the prefix hash to
# improve routing affinity and therefore cache hit rate. LiteLLM's Azure
# transformer ships ``prompt_cache_key`` in its supported params as of
# https://github.com/BerriAI/litellm/pull/20989.
#
# Strict whitelist — many other providers in ``PROVIDER_MAP`` route
# through litellm's ``openai`` prefix without implementing the OpenAI
# prompt-cache surface (e.g. MOONSHOT, ZHIPU, MINIMAX), so we can't infer
# family from the litellm prefix alone.
_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset(
    {"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"}
)

# Subset of ``_PROMPT_CACHE_KEY_PROVIDERS`` that also accept
# ``prompt_cache_retention="24h"``. Azure is excluded: see module
# docstring — LiteLLM 1.83.14's Azure transformer omits the param so
# ``drop_params`` silently strips it. Re-add Azure once a future LiteLLM
# release wires it into ``AzureOpenAIConfig.get_supported_openai_params``.
_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset(
    {"OPENAI", "DEEPSEEK", "XAI"}
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_router_llm(llm: BaseChatModel) -> bool:
    """Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import.

    Importing ``app.services.llm_router_service`` at module-load time would
    create a cycle via ``llm_config -> prompt_caching -> llm_router_service``,
    so the check is done by class name instead of ``isinstance``.

    The whole method-resolution order is scanned rather than only the
    concrete type, so subclasses of ``ChatLiteLLMRouter`` (wrappers,
    instrumented or test variants) are recognised too — the same answer
    ``isinstance`` would give if the class could be imported here.

    Args:
        llm: The chat model instance to classify.

    Returns:
        True when ``llm``'s class, or any of its base classes, is named
        ``ChatLiteLLMRouter``; False otherwise.
    """
    return any(cls.__name__ == "ChatLiteLLMRouter" for cls in type(llm).__mro__)
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_supports_prompt_cache_key(agent_config: AgentConfig | None) -> bool:
    """Tell whether ``prompt_cache_key`` may be sent for this config.

    Only an explicitly selected provider listed in
    ``_PROMPT_CACHE_KEY_PROVIDERS`` qualifies (compared case-insensitively).
    A missing config, an empty provider, auto-mode, and custom providers all
    yield False, because the final destination isn't statically known and
    the router fans out across mixed providers.
    """
    if agent_config is None:
        return False
    provider = agent_config.provider
    # Same short-circuit order as a chain of guards: provider, auto, custom
    if not provider or agent_config.is_auto_mode or agent_config.custom_provider:
        return False
    return provider.upper() in _PROMPT_CACHE_KEY_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_supports_prompt_cache_retention(
    agent_config: AgentConfig | None,
) -> bool:
    """Tell whether ``prompt_cache_retention`` may be sent for this config.

    Stricter than :func:`_provider_supports_prompt_cache_key`: membership is
    tested against ``_PROMPT_CACHE_RETENTION_PROVIDERS``, which leaves Azure
    out until LiteLLM ships the param in its Azure transformer (see module
    docstring). A missing config, an empty provider, auto-mode, and custom
    providers all yield False.
    """
    eligible = (
        agent_config is not None
        and bool(agent_config.provider)
        and not agent_config.is_auto_mode
        and not agent_config.custom_provider
    )
    if not eligible:
        return False
    return agent_config.provider.upper() in _PROMPT_CACHE_RETENTION_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
|
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
    """Hand back a mutable ``llm.model_kwargs`` dict, creating it if needed.

    When the attribute is already a dict it is returned untouched. Otherwise
    (missing, ``None``, or another type) the function tries to assign a fresh
    ``{}`` and re-reads the attribute. ``None`` is returned when the
    assignment raises or the re-read value still isn't a dict; callers should
    treat that as "nothing to configure".
    """
    current = getattr(llm, "model_kwargs", None)
    if not isinstance(current, dict):
        try:
            llm.model_kwargs = {}  # type: ignore[attr-defined]
        except Exception:
            # The LLM type doesn't allow setting the attribute
            return None
        # Re-read instead of trusting the assignment: the model may coerce
        # or ignore the value it was given
        current = getattr(llm, "model_kwargs", None)
        if not isinstance(current, dict):
            return None
    return current
|
||||||
|
|
||||||
|
|
||||||
|
def apply_litellm_prompt_caching(
    llm: BaseChatModel,
    *,
    agent_config: AgentConfig | None = None,
    thread_id: int | None = None,
) -> None:
    """Set LiteLLM prompt-caching kwargs on a ChatLiteLLM/ChatLiteLLMRouter.

    Idempotent: a key that is already in ``llm.model_kwargs`` (for example
    from ``agent_config.litellm_params`` overrides) is never overwritten.
    ``llm.model_kwargs`` is mutated in place; the kwargs flow to
    ``litellm.completion`` via ``ChatLiteLLM._default_params`` and via the
    ``self.model_kwargs`` merge in our custom ``ChatLiteLLMRouter``.

    Args:
        llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
        agent_config: Optional ``AgentConfig`` driving provider-specific
            behaviour. When omitted (or auto-mode), only the universal
            ``cache_control_injection_points`` are set.
        thread_id: Optional thread id used to construct a per-thread
            ``prompt_cache_key`` for OpenAI-family providers. Caching still
            works without it (server-side automatic), but the key improves
            backend routing affinity and therefore hit rate.
    """
    kwargs = _get_or_init_model_kwargs(llm)
    if kwargs is None:
        logger.debug(
            "apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping",
            type(llm).__name__,
        )
        return

    # Universal breakpoints; each template dict is copied so the module-level
    # defaults are never shared with (or mutated through) a model instance.
    kwargs.setdefault(
        "cache_control_injection_points",
        [dict(point) for point in _DEFAULT_INJECTION_POINTS],
    )

    # The auto-mode router fans out across mixed providers, so
    # destination-specific kwargs are not set for it (drop_params would strip
    # them, but setting them in the first place is wasteful).
    if _is_router_llm(llm):
        return

    # OpenAI-style extras, only when the destination is statically known to
    # accept them.
    needs_cache_key = thread_id is not None and "prompt_cache_key" not in kwargs
    if needs_cache_key and _provider_supports_prompt_cache_key(agent_config):
        kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"

    if "prompt_cache_retention" in kwargs:
        return
    if _provider_supports_prompt_cache_retention(agent_config):
        kwargs["prompt_cache_retention"] = "24h"
|
||||||
|
|
@ -39,7 +39,7 @@ def _is_premium_global(kind: ModelKind, config_id: int) -> bool:
|
||||||
|
|
||||||
cfg: dict | None = None
|
cfg: dict | None = None
|
||||||
if kind == "llm":
|
if kind == "llm":
|
||||||
from app.agents.new_chat.llm_config import load_global_llm_config_by_id
|
from app.agents.shared.llm_config import load_global_llm_config_by_id
|
||||||
|
|
||||||
cfg = load_global_llm_config_by_id(config_id)
|
cfg = load_global_llm_config_by_id(config_id)
|
||||||
elif kind == "image":
|
elif kind == "image":
|
||||||
|
|
|
||||||
|
|
@ -236,7 +236,7 @@ async def stream_anonymous_chat(
|
||||||
detail="No-login mode is not enabled.",
|
detail="No-login mode is not enabled.",
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import (
|
from app.agents.shared.llm_config import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
create_chat_litellm_from_agent_config,
|
create_chat_litellm_from_agent_config,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -204,7 +204,7 @@ async def validate_llm_config(
|
||||||
if litellm_params:
|
if litellm_params:
|
||||||
litellm_kwargs.update(litellm_params)
|
litellm_kwargs.update(litellm_params)
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
from app.agents.shared.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
|
|
@ -379,7 +379,7 @@ async def get_search_space_llm_instance(
|
||||||
if disable_streaming:
|
if disable_streaming:
|
||||||
litellm_kwargs["disable_streaming"] = True
|
litellm_kwargs["disable_streaming"] = True
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
from app.agents.shared.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
|
|
@ -458,7 +458,7 @@ async def get_search_space_llm_instance(
|
||||||
if disable_streaming:
|
if disable_streaming:
|
||||||
litellm_kwargs["disable_streaming"] = True
|
litellm_kwargs["disable_streaming"] = True
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
from app.agents.shared.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
|
|
@ -580,7 +580,7 @@ async def get_vision_llm(
|
||||||
if global_cfg.get("litellm_params"):
|
if global_cfg.get("litellm_params"):
|
||||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
from app.agents.shared.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
|
|
@ -634,7 +634,7 @@ async def get_vision_llm(
|
||||||
if vision_cfg.litellm_params:
|
if vision_cfg.litellm_params:
|
||||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import SanitizedChatLiteLLM
|
from app.agents.shared.llm_config import SanitizedChatLiteLLM
|
||||||
|
|
||||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
|
|
||||||
|
|
@ -679,7 +679,7 @@ def get_planner_llm() -> ChatLiteLLM | None:
|
||||||
Callers MUST fall back to their chat LLM when this returns ``None`` so
|
Callers MUST fall back to their chat LLM when this returns ``None`` so
|
||||||
deployments without a planner config keep working unchanged.
|
deployments without a planner config keep working unchanged.
|
||||||
"""
|
"""
|
||||||
from app.agents.new_chat.llm_config import create_chat_litellm_from_config
|
from app.agents.shared.llm_config import create_chat_litellm_from_config
|
||||||
|
|
||||||
planner_cfg = next(
|
planner_cfg = next(
|
||||||
(cfg for cfg in config.GLOBAL_LLM_CONFIGS if cfg.get("is_planner") is True),
|
(cfg for cfg in config.GLOBAL_LLM_CONFIGS if cfg.get("is_planner") is True),
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ logger = logging.getLogger(__name__)
|
||||||
#
|
#
|
||||||
# Owned here because ``app.services.provider_capabilities`` is the
|
# Owned here because ``app.services.provider_capabilities`` is the
|
||||||
# only edge that's safe to call from ``app.config``'s YAML loader at
|
# only edge that's safe to call from ``app.config``'s YAML loader at
|
||||||
# class-body init time. ``app.agents.new_chat.llm_config`` re-exports
|
# class-body init time. ``app.agents.shared.llm_config`` re-exports
|
||||||
# this constant under the historical ``PROVIDER_MAP`` name; placing the
|
# this constant under the historical ``PROVIDER_MAP`` name; placing the
|
||||||
# map there directly would re-introduce the
|
# map there directly would re-introduce the
|
||||||
# ``app.config -> ... -> app.agents.new_chat.tools.generate_image ->
|
# ``app.config -> ... -> app.agents.new_chat.tools.generate_image ->
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ from app.agents.new_chat.checkpointer import get_checkpointer
|
||||||
from app.agents.shared.context import SurfSenseContextSchema
|
from app.agents.shared.context import SurfSenseContextSchema
|
||||||
from app.agents.shared.errors import BusyError
|
from app.agents.shared.errors import BusyError
|
||||||
from app.agents.shared.filesystem_selection import FilesystemMode, FilesystemSelection
|
from app.agents.shared.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||||
from app.agents.new_chat.llm_config import (
|
from app.agents.shared.llm_config import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
create_chat_litellm_from_agent_config,
|
create_chat_litellm_from_agent_config,
|
||||||
create_chat_litellm_from_config,
|
create_chat_litellm_from_config,
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from __future__ import annotations
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.agents.shared.filesystem_selection import FilesystemSelection
|
from app.agents.shared.filesystem_selection import FilesystemSelection
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.shared.llm_config import AgentConfig
|
||||||
from app.db import ChatVisibility
|
from app.db import ChatVisibility
|
||||||
from app.services.connector_service import ConnectorService
|
from app.services.connector_service import ConnectorService
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ tells the user what to change.
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.shared.llm_config import AgentConfig
|
||||||
from app.observability import otel as ot
|
from app.observability import otel as ot
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ from app.prompts import TITLE_GENERATION_PROMPT
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.shared.llm_config import AgentConfig
|
||||||
from app.services.token_tracking_service import TokenAccumulator
|
from app.services.token_tracking_service import TokenAccumulator
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from typing import Any
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import (
|
from app.agents.shared.llm_config import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
create_chat_litellm_from_agent_config,
|
create_chat_litellm_from_agent_config,
|
||||||
create_chat_litellm_from_config,
|
create_chat_litellm_from_config,
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.shared.llm_config import AgentConfig
|
||||||
from app.db import shielded_async_session
|
from app.db import shielded_async_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
|
||||||
|
|
@ -239,11 +239,11 @@ def _patch_llm_bindings() -> None:
|
||||||
|
|
||||||
chat_targets = [
|
chat_targets = [
|
||||||
(
|
(
|
||||||
"app.agents.new_chat.llm_config.create_chat_litellm_from_agent_config",
|
"app.agents.shared.llm_config.create_chat_litellm_from_agent_config",
|
||||||
fake_create_chat_litellm_from_agent_config,
|
fake_create_chat_litellm_from_agent_config,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"app.agents.new_chat.llm_config.create_chat_litellm_from_config",
|
"app.agents.shared.llm_config.create_chat_litellm_from_config",
|
||||||
fake_create_chat_litellm_from_config,
|
fake_create_chat_litellm_from_config,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
|
|
|
||||||
|
|
@ -212,11 +212,11 @@ def _patch_llm_bindings() -> None:
|
||||||
|
|
||||||
chat_targets = [
|
chat_targets = [
|
||||||
(
|
(
|
||||||
"app.agents.new_chat.llm_config.create_chat_litellm_from_agent_config",
|
"app.agents.shared.llm_config.create_chat_litellm_from_agent_config",
|
||||||
fake_create_chat_litellm_from_agent_config,
|
fake_create_chat_litellm_from_agent_config,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"app.agents.new_chat.llm_config.create_chat_litellm_from_config",
|
"app.agents.shared.llm_config.create_chat_litellm_from_config",
|
||||||
fake_create_chat_litellm_from_config,
|
fake_create_chat_litellm_from_config,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
r"""Tests for ``apply_litellm_prompt_caching`` in
|
r"""Tests for ``apply_litellm_prompt_caching`` in
|
||||||
:mod:`app.agents.new_chat.prompt_caching`.
|
:mod:`app.agents.shared.prompt_caching`.
|
||||||
|
|
||||||
The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which
|
The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which
|
||||||
never activated for our LiteLLM stack) with LiteLLM-native multi-provider
|
never activated for our LiteLLM stack) with LiteLLM-native multi-provider
|
||||||
|
|
@ -34,8 +34,8 @@ from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.shared.llm_config import AgentConfig
|
||||||
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
|
from app.agents.shared.prompt_caching import apply_litellm_prompt_caching
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from __future__ import annotations
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.agents.new_chat.chat_deepagent import _resolve_prompt_model_name
|
from app.agents.new_chat.chat_deepagent import _resolve_prompt_model_name
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.shared.llm_config import AgentConfig
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch):
|
||||||
-2: {"id": -2, "billing_tier": "free"},
|
-2: {"id": -2, "billing_tier": "free"},
|
||||||
}
|
}
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"app.agents.new_chat.llm_config.load_global_llm_config_by_id",
|
"app.agents.shared.llm_config.load_global_llm_config_by_id",
|
||||||
lambda cid: llm_configs.get(cid),
|
lambda cid: llm_configs.get(cid),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -227,7 +227,7 @@ global_llm_configs:
|
||||||
|
|
||||||
|
|
||||||
def test_agent_config_from_yaml_explicit_overrides_resolver():
|
def test_agent_config_from_yaml_explicit_overrides_resolver():
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.shared.llm_config import AgentConfig
|
||||||
|
|
||||||
cfg_text_only = AgentConfig.from_yaml_config(
|
cfg_text_only = AgentConfig.from_yaml_config(
|
||||||
{
|
{
|
||||||
|
|
@ -256,7 +256,7 @@ def test_agent_config_from_yaml_explicit_overrides_resolver():
|
||||||
def test_agent_config_from_yaml_unannotated_uses_resolver():
|
def test_agent_config_from_yaml_unannotated_uses_resolver():
|
||||||
"""Without an explicit YAML key, AgentConfig defers to the catalog
|
"""Without an explicit YAML key, AgentConfig defers to the catalog
|
||||||
resolver — for ``gpt-4o`` LiteLLM's map says supports_vision=True."""
|
resolver — for ``gpt-4o`` LiteLLM's map says supports_vision=True."""
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.shared.llm_config import AgentConfig
|
||||||
|
|
||||||
cfg = AgentConfig.from_yaml_config(
|
cfg = AgentConfig.from_yaml_config(
|
||||||
{
|
{
|
||||||
|
|
@ -275,7 +275,7 @@ def test_agent_config_auto_mode_supports_image_input():
|
||||||
so users can keep their selection on Auto with a vision-capable
|
so users can keep their selection on Auto with a vision-capable
|
||||||
deployment somewhere in the pool. The router's own `allowed_fails`
|
deployment somewhere in the pool. The router's own `allowed_fails`
|
||||||
handles non-vision deployments via fallback."""
|
handles non-vision deployments via fallback."""
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.shared.llm_config import AgentConfig
|
||||||
|
|
||||||
auto = AgentConfig.from_auto_mode()
|
auto = AgentConfig.from_auto_mode()
|
||||||
assert auto.supports_image_input is True
|
assert auto.supports_image_input is True
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ async def test_get_vision_llm_global_openrouter_sets_api_base():
|
||||||
return_value=cfg,
|
return_value=cfg,
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.agents.new_chat.llm_config.SanitizedChatLiteLLM",
|
"app.agents.shared.llm_config.SanitizedChatLiteLLM",
|
||||||
new=FakeSanitized,
|
new=FakeSanitized,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue