mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): consolidate chat runtime infra under chat/runtime
Move the lower-level runtime/infra modules out of multi_agent_chat/shared/ (they were never used by subagents, so they failed the shared-by-all-siblings rule) and unify them with the already-relocated checkpointer: agents/runtime/ -> agents/chat/runtime/ mac/shared/errors.py -> chat/runtime/errors.py mac/shared/llm_config.py -> chat/runtime/llm_config.py mac/shared/prompt_caching.py -> chat/runtime/prompt_caching.py mac/shared/mention_resolver.py -> chat/runtime/mention_resolver.py mac/shared/path_resolver.py -> chat/runtime/path_resolver.py These sit below the agent packages: the boundary + agent factory + shared middleware depend on them, and they import no agent code (acyclic).
This commit is contained in:
parent
7d866a2279
commit
f2a61bc0ef
52 changed files with 97 additions and 87 deletions
16
surfsense_backend/app/agents/chat/runtime/__init__.py
Normal file
16
surfsense_backend/app/agents/chat/runtime/__init__.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
"""Lower-level runtime infrastructure for the chat agents.
|
||||
|
||||
Modules here are the foundation layer used to *run* chat agents: wired by the
|
||||
boundary (routes/tasks) and/or imported by the agent factory + shared
|
||||
middleware, but never part of any single agent's domain logic. Because they sit
|
||||
below the agent packages, both the boundary and the agents may depend on them
|
||||
(forward dependency), while they never import agent code.
|
||||
|
||||
Contents:
|
||||
- ``checkpointer`` LangGraph Postgres checkpoint saver (boundary lifespan)
|
||||
- ``llm_config`` LLM provider/model configuration resolution
|
||||
- ``prompt_caching`` LiteLLM prompt-caching configuration
|
||||
- ``errors`` agent-runtime error contracts (raised by MW, caught at boundary)
|
||||
- ``path_resolver`` filesystem path resolution helpers
|
||||
- ``mention_resolver`` @-mention resolution helpers
|
||||
"""
|
||||
144
surfsense_backend/app/agents/chat/runtime/checkpointer.py
Normal file
144
surfsense_backend/app/agents/chat/runtime/checkpointer.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
"""
|
||||
PostgreSQL-based checkpointer for LangGraph agents.
|
||||
|
||||
This module provides a persistent checkpointer using AsyncPostgresSaver
|
||||
that stores conversation state in the PostgreSQL database.
|
||||
|
||||
Uses a connection pool (psycopg_pool.AsyncConnectionPool) to handle
|
||||
connection lifecycle, health checks, and automatic reconnection,
|
||||
preventing 'the connection is closed' errors in long-running deployments.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from psycopg.rows import dict_row
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
from app.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global checkpointer instance (initialized lazily)
|
||||
_checkpointer: AsyncPostgresSaver | None = None
|
||||
_connection_pool: AsyncConnectionPool | None = None
|
||||
_checkpointer_initialized: bool = False
|
||||
|
||||
|
||||
def get_postgres_connection_string() -> str:
|
||||
"""
|
||||
Convert the async DATABASE_URL to a sync postgres connection string for psycopg3.
|
||||
|
||||
The DATABASE_URL is typically in format:
|
||||
postgresql+asyncpg://user:pass@host:port/dbname
|
||||
|
||||
We need to convert it to:
|
||||
postgresql://user:pass@host:port/dbname
|
||||
"""
|
||||
db_url = config.DATABASE_URL
|
||||
|
||||
# Handle asyncpg driver prefix
|
||||
if db_url.startswith("postgresql+asyncpg://"):
|
||||
return db_url.replace("postgresql+asyncpg://", "postgresql://")
|
||||
|
||||
# Handle other async prefixes
|
||||
if "+asyncpg" in db_url:
|
||||
return db_url.replace("+asyncpg", "")
|
||||
|
||||
return db_url
|
||||
|
||||
|
||||
async def _create_checkpointer() -> AsyncPostgresSaver:
|
||||
"""
|
||||
Create a new AsyncPostgresSaver backed by a connection pool.
|
||||
|
||||
The connection pool automatically handles:
|
||||
- Connection health checks before use
|
||||
- Reconnection when connections die (idle timeout, DB restart, etc.)
|
||||
- Connection lifecycle management (max_lifetime, max_idle)
|
||||
"""
|
||||
global _connection_pool
|
||||
|
||||
conn_string = get_postgres_connection_string()
|
||||
|
||||
_connection_pool = AsyncConnectionPool(
|
||||
conninfo=conn_string,
|
||||
min_size=2,
|
||||
max_size=10,
|
||||
# Connections are recycled after 30 minutes to avoid stale connections
|
||||
max_lifetime=1800,
|
||||
# Idle connections are closed after 5 minutes
|
||||
max_idle=300,
|
||||
open=False,
|
||||
# Connection kwargs required by AsyncPostgresSaver:
|
||||
# - autocommit: required for .setup() to commit checkpoint tables
|
||||
# - prepare_threshold: disable prepared statements for compatibility
|
||||
# - row_factory: checkpointer accesses rows as dicts (row["column"])
|
||||
kwargs={
|
||||
"autocommit": True,
|
||||
"prepare_threshold": 0,
|
||||
"row_factory": dict_row,
|
||||
},
|
||||
)
|
||||
await _connection_pool.open(wait=True)
|
||||
|
||||
checkpointer = AsyncPostgresSaver(conn=_connection_pool)
|
||||
logger.info("[Checkpointer] Created AsyncPostgresSaver with connection pool")
|
||||
return checkpointer
|
||||
|
||||
|
||||
async def get_checkpointer() -> AsyncPostgresSaver:
|
||||
"""
|
||||
Get or create the global AsyncPostgresSaver instance.
|
||||
|
||||
This function:
|
||||
1. Creates the checkpointer with a connection pool if it doesn't exist
|
||||
2. Sets up the required database tables on first call
|
||||
3. Returns the cached instance on subsequent calls
|
||||
|
||||
The underlying connection pool handles reconnection automatically,
|
||||
so a stale/closed connection will not cause OperationalError.
|
||||
|
||||
Returns:
|
||||
AsyncPostgresSaver: The configured checkpointer instance
|
||||
"""
|
||||
global _checkpointer, _checkpointer_initialized
|
||||
|
||||
if _checkpointer is None:
|
||||
_checkpointer = await _create_checkpointer()
|
||||
_checkpointer_initialized = False
|
||||
|
||||
# Setup tables on first call (idempotent)
|
||||
if not _checkpointer_initialized:
|
||||
await _checkpointer.setup()
|
||||
_checkpointer_initialized = True
|
||||
|
||||
return _checkpointer
|
||||
|
||||
|
||||
async def setup_checkpointer_tables() -> None:
|
||||
"""
|
||||
Explicitly setup the checkpointer tables.
|
||||
|
||||
This can be called during application startup to ensure
|
||||
tables exist before any agent calls.
|
||||
"""
|
||||
await get_checkpointer()
|
||||
logger.info("[Checkpointer] PostgreSQL checkpoint tables ready")
|
||||
|
||||
|
||||
async def close_checkpointer() -> None:
|
||||
"""
|
||||
Close the checkpointer connection pool.
|
||||
|
||||
This should be called during application shutdown.
|
||||
"""
|
||||
global _checkpointer, _connection_pool, _checkpointer_initialized
|
||||
|
||||
if _connection_pool is not None:
|
||||
await _connection_pool.close()
|
||||
logger.info("[Checkpointer] PostgreSQL connection pool closed")
|
||||
|
||||
_checkpointer = None
|
||||
_connection_pool = None
|
||||
_checkpointer_initialized = False
|
||||
95
surfsense_backend/app/agents/chat/runtime/errors.py
Normal file
95
surfsense_backend/app/agents/chat/runtime/errors.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
"""
|
||||
Typed error taxonomy for the SurfSense agent stack.
|
||||
|
||||
Used by:
|
||||
- :class:`RetryAfterMiddleware` — its ``retry_on`` callable consults
|
||||
the error code to decide whether a retry is appropriate.
|
||||
- :class:`PermissionMiddleware` — emits ``code="permission_denied"``
|
||||
errors when a deny rule trips.
|
||||
- All tools — return :class:`StreamingError` payloads in
|
||||
``ToolMessage.additional_kwargs["error"]`` so the model and the
|
||||
retry/permission layers share a contract.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
ErrorCode = Literal[
|
||||
"rate_limit",
|
||||
"auth",
|
||||
"tool_validation",
|
||||
"tool_runtime",
|
||||
"context_overflow",
|
||||
"provider",
|
||||
"permission_denied",
|
||||
"doom_loop",
|
||||
"busy",
|
||||
"cancelled",
|
||||
]
|
||||
|
||||
|
||||
class StreamingError(BaseModel):
|
||||
"""Structured error payload attached to ``ToolMessage.additional_kwargs["error"]``.
|
||||
|
||||
Tools and middleware emit this so retry, permission, and routing
|
||||
layers can decide what to do without parsing free-form strings.
|
||||
"""
|
||||
|
||||
code: ErrorCode
|
||||
retryable: bool = False
|
||||
suggestion: str | None = None
|
||||
correlation_id: str | None = None
|
||||
detail: str | None = Field(
|
||||
default=None,
|
||||
description="Free-form additional context. Not surfaced to the model.",
|
||||
)
|
||||
|
||||
class Config:
|
||||
frozen = True
|
||||
|
||||
|
||||
class RejectedError(Exception):
|
||||
"""Raised when the user rejects a permission ask without feedback.
|
||||
|
||||
Caught by :class:`PermissionMiddleware`; the agent stops the current
|
||||
tool fan-out and surfaces a user-facing rejection.
|
||||
"""
|
||||
|
||||
def __init__(self, *, tool: str | None = None, pattern: str | None = None) -> None:
|
||||
super().__init__(f"Permission rejected for tool {tool!r}, pattern {pattern!r}")
|
||||
self.tool = tool
|
||||
self.pattern = pattern
|
||||
|
||||
|
||||
class CorrectedError(Exception):
|
||||
"""Raised when the user rejects a permission ask *with* feedback.
|
||||
|
||||
The :class:`PermissionMiddleware` translates the feedback into a
|
||||
synthetic ``ToolMessage`` so the model sees the user's correction
|
||||
and can retry the request differently.
|
||||
"""
|
||||
|
||||
def __init__(self, feedback: str, *, tool: str | None = None) -> None:
|
||||
super().__init__(feedback)
|
||||
self.feedback = feedback
|
||||
self.tool = tool
|
||||
|
||||
|
||||
class BusyError(Exception):
|
||||
"""Raised when a second prompt arrives while the same thread is mid-stream."""
|
||||
|
||||
def __init__(self, request_id: str | None = None) -> None:
|
||||
super().__init__("Thread is busy with another request")
|
||||
self.request_id = request_id
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BusyError",
|
||||
"CorrectedError",
|
||||
"ErrorCode",
|
||||
"RejectedError",
|
||||
"StreamingError",
|
||||
]
|
||||
624
surfsense_backend/app/agents/chat/runtime/llm_config.py
Normal file
624
surfsense_backend/app/agents/chat/runtime/llm_config.py
Normal file
|
|
@ -0,0 +1,624 @@
|
|||
"""
|
||||
LLM configuration utilities for SurfSense agents.
|
||||
|
||||
This module provides functions for loading LLM configurations from:
|
||||
1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing
|
||||
2. YAML files (global configs with negative IDs)
|
||||
3. Database NewLLMConfig table (user-created configs with positive IDs)
|
||||
|
||||
It also provides utilities for creating ChatLiteLLM instances and
|
||||
managing prompt configurations.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from litellm import get_model_info
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.chat.runtime.prompt_caching import (
|
||||
apply_litellm_prompt_caching,
|
||||
)
|
||||
from app.services.llm_router_service import (
|
||||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
_sanitize_content,
|
||||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
"""Sanitize content on every message so it is safe for any provider.
|
||||
|
||||
Handles three cross-provider incompatibilities:
|
||||
- List content with provider-specific blocks (e.g. ``thinking``)
|
||||
- List content with bare strings or empty text blocks
|
||||
- AI messages with empty content + tool calls: some providers (Bedrock)
|
||||
convert ``""`` to ``[{"type":"text","text":""}]`` server-side then
|
||||
reject the blank text. The OpenAI spec says ``content`` should be
|
||||
``null`` when an assistant message only carries tool calls.
|
||||
"""
|
||||
for msg in messages:
|
||||
if isinstance(msg.content, list):
|
||||
msg.content = _sanitize_content(msg.content)
|
||||
if (
|
||||
isinstance(msg, AIMessage)
|
||||
and (not msg.content or msg.content == "")
|
||||
and getattr(msg, "tool_calls", None)
|
||||
):
|
||||
msg.content = None # type: ignore[assignment]
|
||||
return messages
|
||||
|
||||
|
||||
class SanitizedChatLiteLLM(ChatLiteLLM):
|
||||
"""ChatLiteLLM subclass that strips provider-specific content blocks
|
||||
(e.g. ``thinking`` from reasoning models) and normalises bare strings
|
||||
in content arrays before forwarding to the underlying provider."""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return super()._generate(
|
||||
_sanitize_messages(messages), stop, run_manager, **kwargs
|
||||
)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
async for chunk in super()._astream(
|
||||
_sanitize_messages(messages), stop, run_manager, **kwargs
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
# Provider mapping for LiteLLM model string construction.
|
||||
#
|
||||
# Single source of truth lives in
|
||||
# :mod:`app.services.provider_capabilities` so the YAML loader (which
|
||||
# runs during ``app.config`` class-body init) can resolve provider
|
||||
# prefixes without dragging the agent / tools tree into module load
|
||||
# order. Re-exported here under the historical ``PROVIDER_MAP`` name
|
||||
# so existing callers (``llm_router_service``, ``image_gen_router_service``,
|
||||
# tests) keep working unchanged.
|
||||
from app.services.provider_capabilities import ( # noqa: E402
|
||||
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
|
||||
)
|
||||
|
||||
|
||||
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
|
||||
"""Attach a ``profile`` dict to ChatLiteLLM with model context metadata."""
|
||||
try:
|
||||
info = get_model_info(model_string)
|
||||
max_input_tokens = info.get("max_input_tokens")
|
||||
if isinstance(max_input_tokens, int) and max_input_tokens > 0:
|
||||
llm.profile = {
|
||||
"max_input_tokens": max_input_tokens,
|
||||
"max_input_tokens_upper": max_input_tokens,
|
||||
"token_count_model": model_string,
|
||||
"token_count_models": [model_string],
|
||||
}
|
||||
except Exception:
|
||||
return
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentConfig:
|
||||
"""
|
||||
Complete configuration for the SurfSense agent.
|
||||
|
||||
This combines LLM settings with prompt configuration from NewLLMConfig.
|
||||
Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing.
|
||||
"""
|
||||
|
||||
# LLM Model Settings
|
||||
provider: str
|
||||
model_name: str
|
||||
api_key: str
|
||||
api_base: str | None = None
|
||||
custom_provider: str | None = None
|
||||
litellm_params: dict | None = None
|
||||
|
||||
# Prompt Configuration
|
||||
system_instructions: str | None = None
|
||||
use_default_system_instructions: bool = True
|
||||
citations_enabled: bool = True
|
||||
|
||||
# Metadata
|
||||
config_id: int | None = None
|
||||
config_name: str | None = None
|
||||
|
||||
# Auto mode flag
|
||||
is_auto_mode: bool = False
|
||||
|
||||
# Token quota and policy
|
||||
billing_tier: str = "free"
|
||||
is_premium: bool = False
|
||||
anonymous_enabled: bool = False
|
||||
quota_reserve_tokens: int | None = None
|
||||
|
||||
# Capability flag: best-effort True for the chat selector / catalog.
|
||||
# Resolved via :func:`provider_capabilities.derive_supports_image_input`
|
||||
# which prefers OpenRouter's ``architecture.input_modalities`` and
|
||||
# otherwise consults LiteLLM's authoritative model map. Default True
|
||||
# is the conservative-allow stance — the streaming-task safety net
|
||||
# (``is_known_text_only_chat_model``) is the *only* place a False
|
||||
# actually blocks a request. Setting this to False here without an
|
||||
# authoritative source would silently hide vision-capable models
|
||||
# (the regression we're fixing).
|
||||
supports_image_input: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_auto_mode(cls) -> "AgentConfig":
|
||||
"""
|
||||
Create an AgentConfig for Auto mode (LiteLLM Router load balancing).
|
||||
|
||||
Returns:
|
||||
AgentConfig instance configured for Auto mode
|
||||
"""
|
||||
return cls(
|
||||
provider="AUTO",
|
||||
model_name="auto",
|
||||
api_key="", # Not needed for router
|
||||
api_base=None,
|
||||
custom_provider=None,
|
||||
litellm_params=None,
|
||||
system_instructions=None,
|
||||
use_default_system_instructions=True,
|
||||
citations_enabled=True,
|
||||
config_id=AUTO_MODE_ID,
|
||||
config_name="Auto (Fastest)",
|
||||
is_auto_mode=True,
|
||||
billing_tier="free",
|
||||
is_premium=False,
|
||||
anonymous_enabled=False,
|
||||
quota_reserve_tokens=None,
|
||||
# Auto routes across the configured pool, which usually
|
||||
# contains at least one vision-capable deployment; the router
|
||||
# will surface a 404 from a non-vision deployment as a normal
|
||||
# ``allowed_fails`` event and fail over rather than blocking
|
||||
# the request outright.
|
||||
supports_image_input=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_new_llm_config(cls, config) -> "AgentConfig":
|
||||
"""
|
||||
Create an AgentConfig from a NewLLMConfig database model.
|
||||
|
||||
Args:
|
||||
config: NewLLMConfig database model instance
|
||||
|
||||
Returns:
|
||||
AgentConfig instance
|
||||
"""
|
||||
# Lazy import to avoid pulling provider_capabilities (and its
|
||||
# transitive litellm import) into module-init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
provider_value = (
|
||||
config.provider.value
|
||||
if hasattr(config.provider, "value")
|
||||
else str(config.provider)
|
||||
)
|
||||
litellm_params = config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
return cls(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
api_key=config.api_key,
|
||||
api_base=config.api_base,
|
||||
custom_provider=config.custom_provider,
|
||||
litellm_params=config.litellm_params,
|
||||
system_instructions=config.system_instructions,
|
||||
use_default_system_instructions=config.use_default_system_instructions,
|
||||
citations_enabled=config.citations_enabled,
|
||||
config_id=config.id,
|
||||
config_name=config.name,
|
||||
is_auto_mode=False,
|
||||
billing_tier="free",
|
||||
is_premium=False,
|
||||
anonymous_enabled=False,
|
||||
quota_reserve_tokens=None,
|
||||
# BYOK rows have no operator-curated capability flag, so we
|
||||
# ask LiteLLM (default-allow on unknown). The streaming
|
||||
# safety net still blocks if the model is *explicitly*
|
||||
# marked text-only.
|
||||
supports_image_input=derive_supports_image_input(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig":
|
||||
"""
|
||||
Create an AgentConfig from a YAML configuration dictionary.
|
||||
|
||||
YAML configs now support the same prompt configuration fields as NewLLMConfig:
|
||||
- system_instructions: Custom system instructions (empty string uses defaults)
|
||||
- use_default_system_instructions: Whether to use default instructions
|
||||
- citations_enabled: Whether citations are enabled
|
||||
|
||||
Args:
|
||||
yaml_config: Configuration dictionary from YAML file
|
||||
|
||||
Returns:
|
||||
AgentConfig instance
|
||||
"""
|
||||
# Lazy import to avoid pulling provider_capabilities (and its
|
||||
# transitive litellm import) into module-init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
# Get system instructions from YAML, default to empty string
|
||||
system_instructions = yaml_config.get("system_instructions", "")
|
||||
|
||||
provider = yaml_config.get("provider", "").upper()
|
||||
model_name = yaml_config.get("model_name", "")
|
||||
custom_provider = yaml_config.get("custom_provider")
|
||||
litellm_params = yaml_config.get("litellm_params") or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
# Explicit YAML override wins; otherwise derive from LiteLLM /
|
||||
# OpenRouter modalities. The YAML loader already populates this
|
||||
# field, but this method is also called from
|
||||
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
|
||||
# so we re-derive here for safety. The bool() coercion preserves
|
||||
# the loader's behaviour for explicit ``true`` / ``false``
|
||||
# strings that PyYAML may surface.
|
||||
if "supports_image_input" in yaml_config:
|
||||
supports_image_input = bool(yaml_config.get("supports_image_input"))
|
||||
else:
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=custom_provider,
|
||||
)
|
||||
|
||||
return cls(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
api_key=yaml_config.get("api_key", ""),
|
||||
api_base=yaml_config.get("api_base"),
|
||||
custom_provider=custom_provider,
|
||||
litellm_params=yaml_config.get("litellm_params"),
|
||||
# Prompt configuration from YAML (with defaults for backwards compatibility)
|
||||
system_instructions=system_instructions if system_instructions else None,
|
||||
use_default_system_instructions=yaml_config.get(
|
||||
"use_default_system_instructions", True
|
||||
),
|
||||
citations_enabled=yaml_config.get("citations_enabled", True),
|
||||
config_id=yaml_config.get("id"),
|
||||
config_name=yaml_config.get("name"),
|
||||
is_auto_mode=False,
|
||||
billing_tier=yaml_config.get("billing_tier", "free"),
|
||||
is_premium=yaml_config.get("billing_tier", "free") == "premium",
|
||||
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
|
||||
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
|
||||
supports_image_input=supports_image_input,
|
||||
)
|
||||
|
||||
|
||||
def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
|
||||
"""
|
||||
Load a specific LLM config from global_llm_config.yaml.
|
||||
|
||||
Args:
|
||||
llm_config_id: The id of the config to load (default: -1)
|
||||
|
||||
Returns:
|
||||
LLM config dict or None if not found
|
||||
"""
|
||||
# Get the config file path
|
||||
base_dir = Path(__file__).resolve().parent.parent.parent.parent
|
||||
config_file = base_dir / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
# Fallback to example file if main config doesn't exist
|
||||
if not config_file.exists():
|
||||
config_file = base_dir / "app" / "config" / "global_llm_config.example.yaml"
|
||||
if not config_file.exists():
|
||||
print("Error: No global_llm_config.yaml or example file found")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
configs = data.get("global_llm_configs", [])
|
||||
for cfg in configs:
|
||||
if isinstance(cfg, dict) and cfg.get("id") == llm_config_id:
|
||||
return cfg
|
||||
|
||||
print(f"Error: Global LLM config id {llm_config_id} not found")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error loading config: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
|
||||
"""
|
||||
Load a global LLM config by ID, checking in-memory configs first.
|
||||
|
||||
This handles both static YAML configs and dynamically injected configs
|
||||
(e.g. OpenRouter integration models that only exist in memory).
|
||||
|
||||
Args:
|
||||
llm_config_id: The negative ID of the global config to load
|
||||
|
||||
Returns:
|
||||
LLM config dict or None if not found
|
||||
"""
|
||||
from app.config import config as app_config
|
||||
|
||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("id") == llm_config_id:
|
||||
return cfg
|
||||
# Fallback to YAML file read (covers edge cases like hot-reload)
|
||||
return load_llm_config_from_yaml(llm_config_id)
|
||||
|
||||
|
||||
async def load_new_llm_config_from_db(
|
||||
session: AsyncSession,
|
||||
config_id: int,
|
||||
) -> "AgentConfig | None":
|
||||
"""
|
||||
Load a NewLLMConfig from the database by ID.
|
||||
|
||||
Args:
|
||||
session: AsyncSession for database access
|
||||
config_id: The ID of the NewLLMConfig to load
|
||||
|
||||
Returns:
|
||||
AgentConfig instance or None if not found
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from app.db import NewLLMConfig
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
print(f"Error: NewLLMConfig with id {config_id} not found")
|
||||
return None
|
||||
|
||||
return AgentConfig.from_new_llm_config(config)
|
||||
except Exception as e:
|
||||
print(f"Error loading NewLLMConfig from database: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def load_agent_llm_config_for_search_space(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> "AgentConfig | None":
|
||||
"""
|
||||
Load the agent LLM configuration for a search space.
|
||||
|
||||
This loads the LLM config based on the search space's agent_llm_id setting:
|
||||
- Positive ID: Load from NewLLMConfig database table
|
||||
- Negative ID: Load from YAML global configs
|
||||
- None: Falls back to first global config (id=-1)
|
||||
|
||||
Args:
|
||||
session: AsyncSession for database access
|
||||
search_space_id: The search space ID
|
||||
|
||||
Returns:
|
||||
AgentConfig instance or None if not found
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from app.db import SearchSpace
|
||||
|
||||
try:
|
||||
# Get the search space to check its agent_llm_id preference
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
|
||||
if not search_space:
|
||||
print(f"Error: SearchSpace with id {search_space_id} not found")
|
||||
return None
|
||||
|
||||
# Use agent_llm_id from search space, fallback to -1 (first global config)
|
||||
config_id = (
|
||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
)
|
||||
|
||||
# Load the config using the unified loader
|
||||
return await load_agent_config(session, config_id, search_space_id)
|
||||
except Exception as e:
|
||||
print(f"Error loading agent LLM config for search space {search_space_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def load_agent_config(
|
||||
session: AsyncSession,
|
||||
config_id: int,
|
||||
search_space_id: int | None = None,
|
||||
) -> "AgentConfig | None":
|
||||
"""
|
||||
Load an agent configuration, supporting Auto mode, YAML, and database configs.
|
||||
|
||||
This is the main entry point for loading configurations:
|
||||
- ID 0: Auto mode (uses LiteLLM Router for load balancing)
|
||||
- Negative IDs: Load from YAML file (global configs)
|
||||
- Positive IDs: Load from NewLLMConfig database table
|
||||
|
||||
Args:
|
||||
session: AsyncSession for database access
|
||||
config_id: The config ID (0 for Auto, negative for YAML, positive for database)
|
||||
search_space_id: Optional search space ID for context
|
||||
|
||||
Returns:
|
||||
AgentConfig instance or None if not found
|
||||
"""
|
||||
# Auto mode (ID 0) - use LiteLLM Router
|
||||
if is_auto_mode(config_id):
|
||||
if not LLMRouterService.is_initialized():
|
||||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
return AgentConfig.from_auto_mode()
|
||||
|
||||
if config_id < 0:
|
||||
# Check in-memory configs first (includes static YAML + dynamic OpenRouter)
|
||||
from app.config import config as app_config
|
||||
|
||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return AgentConfig.from_yaml_config(cfg)
|
||||
# Fallback to YAML file read for safety
|
||||
yaml_config = load_llm_config_from_yaml(config_id)
|
||||
if yaml_config:
|
||||
return AgentConfig.from_yaml_config(yaml_config)
|
||||
return None
|
||||
else:
|
||||
# Load from database (NewLLMConfig)
|
||||
return await load_new_llm_config_from_db(session, config_id)
|
||||
|
||||
|
||||
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
||||
"""
|
||||
Create a ChatLiteLLM instance from a global LLM config dictionary.
|
||||
|
||||
Args:
|
||||
llm_config: LLM configuration dictionary from YAML
|
||||
|
||||
Returns:
|
||||
ChatLiteLLM instance or None on error
|
||||
"""
|
||||
# Build the model string
|
||||
if llm_config.get("custom_provider"):
|
||||
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
|
||||
else:
|
||||
provider = llm_config.get("provider", "").upper()
|
||||
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{llm_config['model_name']}"
|
||||
|
||||
# Create ChatLiteLLM instance with streaming enabled
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": llm_config.get("api_key"),
|
||||
"streaming": True, # Enable streaming for real-time token streaming
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if llm_config.get("api_base"):
|
||||
litellm_kwargs["api_base"] = llm_config["api_base"]
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if llm_config.get("litellm_params"):
|
||||
litellm_kwargs.update(llm_config["litellm_params"])
|
||||
|
||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
_attach_model_profile(llm, model_string)
|
||||
# Configure LiteLLM-native prompt caching (cache_control_injection_points
|
||||
# for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.).
|
||||
# ``agent_config=None`` here — the YAML path doesn't have provider intent
|
||||
# in a structured form, so we set only the universal injection points.
|
||||
apply_litellm_prompt_caching(llm)
|
||||
return llm
|
||||
|
||||
|
||||
def create_chat_litellm_from_agent_config(
|
||||
agent_config: AgentConfig,
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""
|
||||
Create a ChatLiteLLM or ChatLiteLLMRouter instance from an AgentConfig.
|
||||
|
||||
For Auto mode configs, returns a ChatLiteLLMRouter that uses LiteLLM Router
|
||||
for automatic load balancing across available providers.
|
||||
|
||||
Args:
|
||||
agent_config: AgentConfig instance
|
||||
|
||||
Returns:
|
||||
ChatLiteLLM or ChatLiteLLMRouter instance, or None on error
|
||||
"""
|
||||
# Handle Auto mode - return ChatLiteLLMRouter
|
||||
if agent_config.is_auto_mode:
|
||||
if not LLMRouterService.is_initialized():
|
||||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
try:
|
||||
router_llm = get_auto_mode_llm()
|
||||
if router_llm is not None:
|
||||
# Universal cache_control_injection_points only — auto-mode
|
||||
# fans out across providers, so OpenAI-only kwargs (e.g.
|
||||
# ``prompt_cache_key``) are left off here. ``drop_params``
|
||||
# would strip them at the provider boundary anyway, but
|
||||
# there's no point setting them when we don't know the
|
||||
# destination.
|
||||
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
|
||||
return router_llm
|
||||
except Exception as e:
|
||||
print(f"Error creating ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
|
||||
# Build the model string
|
||||
if agent_config.custom_provider:
|
||||
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
|
||||
else:
|
||||
provider_prefix = PROVIDER_MAP.get(
|
||||
agent_config.provider, agent_config.provider.lower()
|
||||
)
|
||||
model_string = f"{provider_prefix}/{agent_config.model_name}"
|
||||
|
||||
# Create ChatLiteLLM instance with streaming enabled
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": agent_config.api_key,
|
||||
"streaming": True, # Enable streaming for real-time token streaming
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if agent_config.api_base:
|
||||
litellm_kwargs["api_base"] = agent_config.api_base
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if agent_config.litellm_params:
|
||||
litellm_kwargs.update(agent_config.litellm_params)
|
||||
|
||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
_attach_model_profile(llm, model_string)
|
||||
# Build-time prompt caching: sets ``cache_control_injection_points`` for
|
||||
# all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``.
|
||||
# Per-thread ``prompt_cache_key`` is layered on later in
|
||||
# ``create_surfsense_deep_agent`` once ``thread_id`` is known.
|
||||
apply_litellm_prompt_caching(llm, agent_config=agent_config)
|
||||
return llm
|
||||
277
surfsense_backend/app/agents/chat/runtime/mention_resolver.py
Normal file
277
surfsense_backend/app/agents/chat/runtime/mention_resolver.py
Normal file
|
|
@ -0,0 +1,277 @@
|
|||
"""Resolve @-mention chips to canonical virtual paths and substitute the
|
||||
user-visible ``@title`` tokens with backtick-wrapped paths in the prompt
|
||||
the agent sees.
|
||||
|
||||
The frontend's mention seam is a single discriminated-union list of
|
||||
``{kind: "doc" | "folder", id, title, document_type?}`` chips (see
|
||||
``surfsense_web/atoms/chat/mentioned-documents.atom.ts``). When a turn
|
||||
reaches the backend stream task we have three needs that this module
|
||||
centralises:
|
||||
|
||||
1. Map each chip to its canonical virtual path
|
||||
(``/documents/.../file.xml`` for docs, ``/documents/MyFolder/`` for
|
||||
folders) so the agent sees concrete filesystem locations instead of
|
||||
ambiguous ``@``-titles.
|
||||
2. Substitute ``@title`` tokens in the user-typed text with backtick-
|
||||
wrapped paths so the path becomes part of the ``HumanMessage`` body
|
||||
the LLM consumes — without rewriting the persisted user message
|
||||
text (which keeps ``@title`` so chip rendering on reload is
|
||||
unchanged).
|
||||
3. Surface the resolved id sets (docs + folders) to the priority
|
||||
middleware so it can render ``[USER-MENTIONED]`` priority entries
|
||||
without re-doing path resolution.
|
||||
|
||||
This is intentionally one module — see the architectural note in
|
||||
``mention-paths-and-folders`` plan: previously the doc-resolution lived
|
||||
inline in ``stream_new_chat`` and the folder mention had no resolution
|
||||
at all. Centralising both behind a single ``resolve_mentions`` call
|
||||
turns a leaky multi-field seam into a single deeper interface.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.chat.runtime.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
)
|
||||
from app.db import Document, Folder
|
||||
from app.schemas.new_chat import MentionedDocumentInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedMention:
|
||||
"""Canonical view of a single @-mention chip.
|
||||
|
||||
``virtual_path`` is the path the agent will see (no trailing slash
|
||||
for documents, trailing ``/`` for folders to match the convention
|
||||
used by ``KnowledgeTreeMiddleware``).
|
||||
"""
|
||||
|
||||
kind: str # "doc" | "folder"
|
||||
id: int
|
||||
title: str
|
||||
virtual_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolvedMentionSet:
|
||||
"""Aggregate result of resolving a turn's mention chips.
|
||||
|
||||
``token_to_path`` maps ``@title`` (the literal token the user typed
|
||||
and the editor emitted) to the canonical virtual path for that
|
||||
chip. It is produced longest-token-first so substitution mirrors
|
||||
``parseMentionSegments`` on the frontend (a longer title like
|
||||
``@Project Roadmap`` is never shadowed by a shorter prefix
|
||||
``@Project``).
|
||||
|
||||
``mentioned_document_ids`` is an ordered, deduped list consumed by
|
||||
the priority middleware downstream — see
|
||||
``KnowledgePriorityMiddleware._compute_priority_paths``.
|
||||
"""
|
||||
|
||||
mentions: list[ResolvedMention] = field(default_factory=list)
|
||||
token_to_path: list[tuple[str, str]] = field(default_factory=list)
|
||||
mentioned_document_ids: list[int] = field(default_factory=list)
|
||||
mentioned_folder_ids: list[int] = field(default_factory=list)
|
||||
|
||||
|
||||
def _folder_virtual_path(folder_id: int, folder_paths: dict[int, str]) -> str:
|
||||
"""Return ``/documents/Folder/Sub/`` for a folder id.
|
||||
|
||||
Falls back to the documents root when the folder is missing from
|
||||
the index (deleted or in a different search space). Trailing slash
|
||||
matches ``KnowledgeTreeMiddleware`` (``/documents/MyFolder/``) so
|
||||
the agent's ``ls`` can dispatch on it as a directory.
|
||||
"""
|
||||
base = folder_paths.get(folder_id, DOCUMENTS_ROOT)
|
||||
return f"{base}/" if not base.endswith("/") else base
|
||||
|
||||
|
||||
async def resolve_mentions(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
mentioned_documents: list[MentionedDocumentInfo] | None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
mentioned_folder_ids: list[int] | None = None,
|
||||
) -> ResolvedMentionSet:
|
||||
"""Resolve every @-mention chip on a turn into virtual paths.
|
||||
|
||||
The function takes both the ``mentioned_documents`` discriminated
|
||||
list (chip metadata used for substitution + persistence) and the
|
||||
parallel id arrays (``mentioned_document_ids``,
|
||||
``mentioned_folder_ids``) for two reasons:
|
||||
|
||||
* Legacy clients that haven't migrated to the unified chip list
|
||||
still send the id arrays — we treat the union as authoritative.
|
||||
* The id arrays are the canonical input to
|
||||
``KnowledgePriorityMiddleware`` (via ``SurfSenseContextSchema``);
|
||||
returning the deduped, validated lists lets the route forward
|
||||
them unchanged.
|
||||
|
||||
Resolution is best-effort: a chip whose id no longer exists (e.g.
|
||||
document was deleted between mention and submit) is silently
|
||||
dropped. The agent still sees the user's original text, just
|
||||
without a backtick-path substitution for that chip.
|
||||
"""
|
||||
chip_doc_ids: list[int] = []
|
||||
chip_folder_ids: list[int] = []
|
||||
chip_titles_by_id: dict[tuple[str, int], str] = {}
|
||||
if mentioned_documents:
|
||||
for chip in mentioned_documents:
|
||||
kind = chip.kind
|
||||
if kind == "folder":
|
||||
chip_folder_ids.append(chip.id)
|
||||
elif kind == "doc":
|
||||
chip_doc_ids.append(chip.id)
|
||||
chip_titles_by_id[(kind, chip.id)] = chip.title
|
||||
|
||||
doc_id_pool: list[int] = list(
|
||||
dict.fromkeys(
|
||||
[
|
||||
*(mentioned_document_ids or []),
|
||||
*chip_doc_ids,
|
||||
]
|
||||
)
|
||||
)
|
||||
folder_id_pool: list[int] = list(
|
||||
dict.fromkeys([*(mentioned_folder_ids or []), *chip_folder_ids])
|
||||
)
|
||||
|
||||
if not doc_id_pool and not folder_id_pool:
|
||||
return ResolvedMentionSet()
|
||||
|
||||
index = await build_path_index(session, search_space_id)
|
||||
|
||||
doc_rows: dict[int, Document] = {}
|
||||
if doc_id_pool:
|
||||
result = await session.execute(
|
||||
select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.id.in_(doc_id_pool),
|
||||
)
|
||||
)
|
||||
for row in result.scalars().all():
|
||||
doc_rows[row.id] = row
|
||||
|
||||
folder_rows: dict[int, Folder] = {}
|
||||
if folder_id_pool:
|
||||
result = await session.execute(
|
||||
select(Folder).where(
|
||||
Folder.search_space_id == search_space_id,
|
||||
Folder.id.in_(folder_id_pool),
|
||||
)
|
||||
)
|
||||
for row in result.scalars().all():
|
||||
folder_rows[row.id] = row
|
||||
|
||||
resolved: list[ResolvedMention] = []
|
||||
accepted_doc_ids: list[int] = []
|
||||
accepted_folder_ids: list[int] = []
|
||||
|
||||
for doc_id in doc_id_pool:
|
||||
row = doc_rows.get(doc_id)
|
||||
if row is None:
|
||||
logger.debug(
|
||||
"mention_resolver: dropping doc id=%s (not found in space=%s)",
|
||||
doc_id,
|
||||
search_space_id,
|
||||
)
|
||||
continue
|
||||
title = chip_titles_by_id.get(("doc", doc_id), str(row.title or ""))
|
||||
path = doc_to_virtual_path(
|
||||
doc_id=row.id,
|
||||
title=str(row.title or "untitled"),
|
||||
folder_id=row.folder_id,
|
||||
index=index,
|
||||
)
|
||||
resolved.append(
|
||||
ResolvedMention(kind="doc", id=row.id, title=title, virtual_path=path)
|
||||
)
|
||||
accepted_doc_ids.append(row.id)
|
||||
|
||||
for folder_id in folder_id_pool:
|
||||
row = folder_rows.get(folder_id)
|
||||
if row is None:
|
||||
logger.debug(
|
||||
"mention_resolver: dropping folder id=%s (not found in space=%s)",
|
||||
folder_id,
|
||||
search_space_id,
|
||||
)
|
||||
continue
|
||||
title = chip_titles_by_id.get(("folder", folder_id), str(row.name or ""))
|
||||
path = _folder_virtual_path(row.id, index.folder_paths)
|
||||
resolved.append(
|
||||
ResolvedMention(kind="folder", id=row.id, title=title, virtual_path=path)
|
||||
)
|
||||
accepted_folder_ids.append(row.id)
|
||||
|
||||
token_to_path: list[tuple[str, str]] = []
|
||||
seen_tokens: set[str] = set()
|
||||
for mention in resolved:
|
||||
if not mention.title:
|
||||
continue
|
||||
token = f"@{mention.title}"
|
||||
if token in seen_tokens:
|
||||
continue
|
||||
seen_tokens.add(token)
|
||||
token_to_path.append((token, mention.virtual_path))
|
||||
token_to_path.sort(key=lambda pair: len(pair[0]), reverse=True)
|
||||
|
||||
return ResolvedMentionSet(
|
||||
mentions=resolved,
|
||||
token_to_path=token_to_path,
|
||||
mentioned_document_ids=accepted_doc_ids,
|
||||
mentioned_folder_ids=accepted_folder_ids,
|
||||
)
|
||||
|
||||
|
||||
def substitute_in_text(text: str, token_to_path: list[tuple[str, str]]) -> str:
|
||||
"""Replace each ``@title`` token with a backtick-wrapped virtual path.
|
||||
|
||||
Mirrors ``parseMentionSegments`` on the frontend: longest token
|
||||
first, single forward pass, no regex (titles can contain regex
|
||||
metacharacters). The substitution is idempotent for already-
|
||||
substituted text because the backtick-wrapped path no longer
|
||||
starts with ``@``.
|
||||
|
||||
Empty / no-op cases short-circuit so callers can pass this through
|
||||
unconditionally without paying for a scan.
|
||||
"""
|
||||
if not text or not token_to_path:
|
||||
return text
|
||||
|
||||
out: list[str] = []
|
||||
i = 0
|
||||
n = len(text)
|
||||
while i < n:
|
||||
matched: tuple[str, str] | None = None
|
||||
for token, path in token_to_path:
|
||||
if text.startswith(token, i):
|
||||
matched = (token, path)
|
||||
break
|
||||
if matched is None:
|
||||
out.append(text[i])
|
||||
i += 1
|
||||
continue
|
||||
token, path = matched
|
||||
out.append(f"`{path}`")
|
||||
i += len(token)
|
||||
return "".join(out)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ResolvedMention",
|
||||
"ResolvedMentionSet",
|
||||
"resolve_mentions",
|
||||
"substitute_in_text",
|
||||
]
|
||||
351
surfsense_backend/app/agents/chat/runtime/path_resolver.py
Normal file
351
surfsense_backend/app/agents/chat/runtime/path_resolver.py
Normal file
|
|
@ -0,0 +1,351 @@
|
|||
"""Canonical virtual-path resolver for SurfSense knowledge-base documents.
|
||||
|
||||
This module is the single source of truth for mapping ``Document`` rows to
|
||||
virtual paths under ``/documents/`` and back. It is used by:
|
||||
|
||||
* :class:`KnowledgeTreeMiddleware` (rendering the workspace tree)
|
||||
* :class:`KnowledgePriorityMiddleware` (computing priority paths)
|
||||
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / move operations)
|
||||
* :class:`KnowledgeBasePersistenceMiddleware` (resolving moves and creates)
|
||||
|
||||
Centralising the logic ensures that title-collision suffixes, folder paths,
|
||||
and ``unique_identifier_hash`` lookups never drift between renders and
|
||||
commits.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType, Folder
|
||||
from app.utils.document_converters import generate_unique_identifier_hash
|
||||
|
||||
DOCUMENTS_ROOT = "/documents"
|
||||
"""Root virtual folder for all KB documents."""
|
||||
|
||||
_INVALID_FILENAME_CHARS = re.compile(r"[\\/:*?\"<>|]+")
|
||||
_WHITESPACE_RUN = re.compile(r"\s+")
|
||||
|
||||
|
||||
def safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
|
||||
"""Convert arbitrary text into a filesystem-safe ``.xml`` filename."""
|
||||
name = _INVALID_FILENAME_CHARS.sub("_", value).strip()
|
||||
name = _WHITESPACE_RUN.sub(" ", name)
|
||||
if not name:
|
||||
name = fallback
|
||||
if len(name) > 180:
|
||||
name = name[:180].rstrip()
|
||||
if not name.lower().endswith(".xml"):
|
||||
name = f"{name}.xml"
|
||||
return name
|
||||
|
||||
|
||||
def safe_folder_segment(value: str, *, fallback: str = "folder") -> str:
|
||||
"""Sanitize a single folder name into a path-safe segment."""
|
||||
name = _INVALID_FILENAME_CHARS.sub("_", value).strip()
|
||||
name = _WHITESPACE_RUN.sub(" ", name)
|
||||
if not name:
|
||||
return fallback
|
||||
if len(name) > 180:
|
||||
name = name[:180].rstrip()
|
||||
return name
|
||||
|
||||
|
||||
def _suffix_with_doc_id(filename: str, doc_id: int | None) -> str:
|
||||
if doc_id is None:
|
||||
return filename
|
||||
if not filename.lower().endswith(".xml"):
|
||||
return f"{filename} ({doc_id}).xml"
|
||||
stem = filename[:-4]
|
||||
return f"{stem} ({doc_id}).xml"
|
||||
|
||||
|
||||
_SUFFIX_PATTERN = re.compile(r"\s\((\d+)\)\.xml$", re.IGNORECASE)
|
||||
|
||||
|
||||
def parse_doc_id_suffix(filename: str) -> tuple[str, int | None]:
|
||||
"""Strip a trailing ``" (<doc_id>).xml"`` suffix; return ``(stem, doc_id)``.
|
||||
|
||||
If no suffix is present, returns ``(stem_without_xml_extension, None)``.
|
||||
"""
|
||||
match = _SUFFIX_PATTERN.search(filename)
|
||||
if match:
|
||||
doc_id = int(match.group(1))
|
||||
stem = filename[: match.start()]
|
||||
return stem, doc_id
|
||||
if filename.lower().endswith(".xml"):
|
||||
return filename[:-4], None
|
||||
return filename, None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PathIndex:
|
||||
"""In-memory occupancy snapshot used by :func:`doc_to_virtual_path`.
|
||||
|
||||
Built once per call site so collision handling is deterministic and so
|
||||
we don't perform N folder lookups per render.
|
||||
"""
|
||||
|
||||
folder_paths: dict[int, str] = field(default_factory=dict)
|
||||
"""``Folder.id`` -> absolute virtual folder path under ``/documents``."""
|
||||
|
||||
occupants: dict[str, int] = field(default_factory=dict)
|
||||
"""virtual path -> ``Document.id`` already occupying that path (this render)."""
|
||||
|
||||
|
||||
async def _build_folder_paths(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> dict[int, str]:
|
||||
"""Compute ``Folder.id`` -> absolute virtual path under ``/documents``."""
|
||||
result = await session.execute(
|
||||
select(Folder.id, Folder.name, Folder.parent_id).where(
|
||||
Folder.search_space_id == search_space_id
|
||||
)
|
||||
)
|
||||
rows = result.all()
|
||||
by_id = {row.id: {"name": row.name, "parent_id": row.parent_id} for row in rows}
|
||||
cache: dict[int, str] = {}
|
||||
|
||||
def resolve(folder_id: int) -> str:
|
||||
if folder_id in cache:
|
||||
return cache[folder_id]
|
||||
parts: list[str] = []
|
||||
cursor: int | None = folder_id
|
||||
visited: set[int] = set()
|
||||
while cursor is not None and cursor in by_id and cursor not in visited:
|
||||
visited.add(cursor)
|
||||
entry = by_id[cursor]
|
||||
parts.append(safe_folder_segment(str(entry["name"])))
|
||||
cursor = entry["parent_id"]
|
||||
parts.reverse()
|
||||
path = f"{DOCUMENTS_ROOT}/" + "/".join(parts) if parts else DOCUMENTS_ROOT
|
||||
cache[folder_id] = path
|
||||
return path
|
||||
|
||||
for folder_id in by_id:
|
||||
resolve(folder_id)
|
||||
return cache
|
||||
|
||||
|
||||
async def build_path_index(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
*,
|
||||
populate_occupants: bool = True,
|
||||
) -> PathIndex:
|
||||
"""Build a :class:`PathIndex` for a search space.
|
||||
|
||||
``populate_occupants`` controls whether the occupancy map is pre-seeded
|
||||
from existing ``Document`` rows. Most callers want this so that
|
||||
:func:`doc_to_virtual_path` can detect collisions across the whole space;
|
||||
the persistence middleware sets this to ``False`` when it is iterating to
|
||||
decide where to place fresh documents.
|
||||
"""
|
||||
folder_paths = await _build_folder_paths(session, search_space_id)
|
||||
occupants: dict[str, int] = {}
|
||||
if populate_occupants:
|
||||
rows = await session.execute(
|
||||
select(Document.id, Document.title, Document.folder_id).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
for row in rows.all():
|
||||
base = folder_paths.get(row.folder_id, DOCUMENTS_ROOT)
|
||||
filename = safe_filename(str(row.title or "untitled"))
|
||||
path = f"{base}/{filename}"
|
||||
if path in occupants and occupants[path] != row.id:
|
||||
path = f"{base}/{_suffix_with_doc_id(filename, row.id)}"
|
||||
occupants[path] = row.id
|
||||
return PathIndex(folder_paths=folder_paths, occupants=occupants)
|
||||
|
||||
|
||||
def doc_to_virtual_path(
|
||||
*,
|
||||
doc_id: int | None,
|
||||
title: str,
|
||||
folder_id: int | None,
|
||||
index: PathIndex,
|
||||
) -> str:
|
||||
"""Return the canonical virtual path for a document.
|
||||
|
||||
Mutates ``index.occupants`` so subsequent calls see this assignment and
|
||||
deterministically pick a different suffix for the next colliding doc.
|
||||
"""
|
||||
base = index.folder_paths.get(folder_id, DOCUMENTS_ROOT)
|
||||
filename = safe_filename(str(title or "untitled"))
|
||||
path = f"{base}/{filename}"
|
||||
occupant = index.occupants.get(path)
|
||||
if occupant is not None and occupant != doc_id:
|
||||
path = f"{base}/{_suffix_with_doc_id(filename, doc_id)}"
|
||||
if doc_id is not None:
|
||||
index.occupants[path] = doc_id
|
||||
return path
|
||||
|
||||
|
||||
async def virtual_path_to_doc(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
virtual_path: str,
|
||||
) -> Document | None:
|
||||
"""Resolve a virtual path back to a ``Document`` row.
|
||||
|
||||
Resolution order:
|
||||
1. ``Document.unique_identifier_hash`` lookup (fast path for paths created
|
||||
by SurfSense itself — every NOTE write goes through this hash).
|
||||
2. If the basename carries a ``" (<doc_id>).xml"`` disambiguation suffix,
|
||||
try a direct id lookup constrained to the search space.
|
||||
3. Title-from-basename + folder-resolution lookup as a last resort.
|
||||
"""
|
||||
if not virtual_path or not virtual_path.startswith(DOCUMENTS_ROOT):
|
||||
return None
|
||||
|
||||
unique_hash = generate_unique_identifier_hash(
|
||||
DocumentType.NOTE,
|
||||
virtual_path,
|
||||
search_space_id,
|
||||
)
|
||||
result = await session.execute(
|
||||
select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.unique_identifier_hash == unique_hash,
|
||||
)
|
||||
)
|
||||
document = result.scalar_one_or_none()
|
||||
if document is not None:
|
||||
return document
|
||||
|
||||
rel = virtual_path[len(DOCUMENTS_ROOT) :].lstrip("/")
|
||||
if not rel:
|
||||
return None
|
||||
parts = [p for p in rel.split("/") if p]
|
||||
if not parts:
|
||||
return None
|
||||
basename = parts[-1]
|
||||
folder_parts = parts[:-1]
|
||||
|
||||
stem, suffix_doc_id = parse_doc_id_suffix(basename)
|
||||
if suffix_doc_id is not None:
|
||||
result = await session.execute(
|
||||
select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.id == suffix_doc_id,
|
||||
)
|
||||
)
|
||||
document = result.scalar_one_or_none()
|
||||
if document is not None:
|
||||
return document
|
||||
|
||||
folder_id = await _resolve_folder_id(
|
||||
session, search_space_id=search_space_id, folder_parts=folder_parts
|
||||
)
|
||||
title_candidates: list[str] = []
|
||||
raw_title = stem
|
||||
title_candidates.append(raw_title)
|
||||
if raw_title.endswith(".xml"):
|
||||
title_candidates.append(raw_title[:-4])
|
||||
|
||||
for candidate in dict.fromkeys(title_candidates):
|
||||
if not candidate:
|
||||
continue
|
||||
query = select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.title == candidate,
|
||||
)
|
||||
if folder_id is None:
|
||||
query = query.where(Document.folder_id.is_(None))
|
||||
else:
|
||||
query = query.where(Document.folder_id == folder_id)
|
||||
result = await session.execute(query)
|
||||
document = result.scalars().first()
|
||||
if document is not None:
|
||||
return document
|
||||
|
||||
# Fallback: title-as-string lookup misses when the real DB title contains
|
||||
# characters that ``safe_filename`` lossily replaces (``:``, ``/``, ``*``,
|
||||
# etc.) — common for connector-imported docs (Google Calendar/Drive etc.).
|
||||
# The workspace tree shows the lossy filename, so the agent passes that
|
||||
# filename back here. Scan all documents in the resolved folder and match
|
||||
# by ``safe_filename(title)`` to recover the original document.
|
||||
folder_scan = select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
)
|
||||
if folder_id is None:
|
||||
folder_scan = folder_scan.where(Document.folder_id.is_(None))
|
||||
else:
|
||||
folder_scan = folder_scan.where(Document.folder_id == folder_id)
|
||||
result = await session.execute(folder_scan)
|
||||
for candidate_doc in result.scalars().all():
|
||||
encoded = safe_filename(str(candidate_doc.title or "untitled"))
|
||||
if encoded == basename:
|
||||
return candidate_doc
|
||||
return None
|
||||
|
||||
|
||||
async def _resolve_folder_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
folder_parts: list[str],
|
||||
) -> int | None:
|
||||
"""Look up the leaf folder id for a chain of folder names; return ``None`` if missing."""
|
||||
if not folder_parts:
|
||||
return None
|
||||
parent_id: int | None = None
|
||||
for raw in folder_parts:
|
||||
name = safe_folder_segment(raw)
|
||||
query = select(Folder.id).where(
|
||||
Folder.search_space_id == search_space_id,
|
||||
Folder.name == name,
|
||||
)
|
||||
if parent_id is None:
|
||||
query = query.where(Folder.parent_id.is_(None))
|
||||
else:
|
||||
query = query.where(Folder.parent_id == parent_id)
|
||||
result = await session.execute(query)
|
||||
row = result.first()
|
||||
if row is None:
|
||||
return None
|
||||
parent_id = row[0]
|
||||
return parent_id
|
||||
|
||||
|
||||
def parse_documents_path(virtual_path: str) -> tuple[list[str], str]:
|
||||
"""Parse a ``/documents/...`` path into ``(folder_parts, document_title)``.
|
||||
|
||||
The title has any ``.xml`` extension and trailing ``" (<doc_id>)"``
|
||||
disambiguation suffix stripped.
|
||||
"""
|
||||
if not virtual_path or not virtual_path.startswith(DOCUMENTS_ROOT):
|
||||
return [], ""
|
||||
rel = virtual_path[len(DOCUMENTS_ROOT) :].strip("/")
|
||||
if not rel:
|
||||
return [], ""
|
||||
parts = [p for p in rel.split("/") if p]
|
||||
if not parts:
|
||||
return [], ""
|
||||
folder_parts = parts[:-1]
|
||||
basename = parts[-1]
|
||||
stem, _ = parse_doc_id_suffix(basename)
|
||||
title = stem
|
||||
if title.endswith(".xml"):
|
||||
title = title[:-4]
|
||||
return folder_parts, title
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DOCUMENTS_ROOT",
|
||||
"PathIndex",
|
||||
"build_path_index",
|
||||
"doc_to_virtual_path",
|
||||
"parse_doc_id_suffix",
|
||||
"parse_documents_path",
|
||||
"safe_filename",
|
||||
"safe_folder_segment",
|
||||
"virtual_path_to_doc",
|
||||
]
|
||||
239
surfsense_backend/app/agents/chat/runtime/prompt_caching.py
Normal file
239
surfsense_backend/app/agents/chat/runtime/prompt_caching.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
r"""LiteLLM-native prompt caching configuration for SurfSense agents.
|
||||
|
||||
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
|
||||
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
|
||||
gate always failed) with LiteLLM's universal caching mechanism.
|
||||
|
||||
Coverage:
|
||||
|
||||
- Marker-based providers (need ``cache_control`` injection, which LiteLLM
|
||||
performs automatically when ``cache_control_injection_points`` is set):
|
||||
``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
|
||||
``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
|
||||
(Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
|
||||
- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
|
||||
``deepseek/``, ``xai/`` — these caches automatically for prompts ≥1024
|
||||
tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
|
||||
|
||||
We inject **two** breakpoints per request:
|
||||
|
||||
- ``index: 0`` — pins the SurfSense system prompt at the head of the
|
||||
request (provider variant, citation rules, tool catalog, KB tree,
|
||||
skills metadata). The langchain agent factory always prepends
|
||||
``request.system_message`` at index 0 (see ``factory.py``
|
||||
``_execute_model_async``), so this targets exactly the main system
|
||||
prompt regardless of how many other ``SystemMessage``\ s the
|
||||
``before_agent`` injectors (priority, tree, memory, file-intent,
|
||||
anonymous-doc) have inserted into ``state["messages"]``. Using
|
||||
``role: system`` here would apply ``cache_control`` to **every**
|
||||
system-role message and trip Anthropic's hard cap of 4 cache
|
||||
breakpoints per request once the conversation accumulates enough
|
||||
injected system messages — which surfaces as the upstream 400
|
||||
``A maximum of 4 blocks with cache_control may be provided. Found N``
|
||||
via OpenRouter→Anthropic.
|
||||
- ``index: -1`` — pins the latest message so multi-turn savings compound:
|
||||
Anthropic-family providers use longest-matching-prefix lookup, so turn
|
||||
N+1 still reads turn N's cache up to the shared prefix.
|
||||
|
||||
For OpenAI-family configs we additionally pass:
|
||||
|
||||
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
|
||||
raises hit rate by sending requests with a shared prefix to the same
|
||||
backend. Supported by ``openai/``, ``deepseek/``, ``xai/``, and
|
||||
``azure/`` (added to LiteLLM's Azure transformer in
|
||||
https://github.com/BerriAI/litellm/pull/20989, Feb 2026; verified
|
||||
against ``AzureOpenAIConfig.get_supported_openai_params`` in our
|
||||
installed litellm 1.83.14 for ``azure/gpt-4o``, ``azure/gpt-4o-mini``,
|
||||
``azure/gpt-5.4``, ``azure/gpt-5.4-mini``).
|
||||
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
|
||||
5-10 min in-memory cache. Set ONLY for OpenAI/DeepSeek/xAI: Azure's
|
||||
server-side support landed in Microsoft's docs on 2026-05-13 but
|
||||
LiteLLM 1.83.14's Azure transformer still omits it from its supported
|
||||
params list, so it gets silently dropped by ``litellm.drop_params``.
|
||||
Azure's default in-memory retention (5-10 min, max 1 h) already
|
||||
bridges intra-conversation turns; revisit when LiteLLM bumps Azure.
|
||||
|
||||
Safety net: ``litellm.drop_params=True`` is set globally in
|
||||
``app.services.llm_service`` at module-load time. Any kwarg the destination
|
||||
provider doesn't recognise is auto-stripped at the provider transformer
|
||||
layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
|
||||
``prompt_cache_key`` etc.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.agents.chat.runtime.llm_config import AgentConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Two-breakpoint policy: head-of-request + latest message. See module
|
||||
# docstring for rationale. Anthropic caps requests at 4 ``cache_control``
|
||||
# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
|
||||
#
|
||||
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
|
||||
# ``before_agent`` middlewares (priority, tree, memory, anonymous-doc)
|
||||
# insert ``SystemMessage`` instances into ``state["messages"]`` that
|
||||
# accumulate across turns. With ``role: system`` the LiteLLM hook would
|
||||
# tag *every* one of them with ``cache_control`` and overflow Anthropic's
|
||||
# 4-block limit. ``index: 0`` always targets the langchain-prepended
|
||||
# ``request.system_message``, giving us exactly one stable cache breakpoint.
|
||||
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
||||
{"location": "message", "index": 0},
|
||||
{"location": "message", "index": -1},
|
||||
)
|
||||
|
||||
# Providers (uppercase ``AgentConfig.provider`` values) that accept the
|
||||
# OpenAI ``prompt_cache_key`` routing hint. Microsoft's Azure OpenAI docs
|
||||
# (2026-05-13) confirm automatic prompt caching applies to every GPT-4o
|
||||
# or newer Azure deployment at ≥1024 tokens with no configuration needed,
|
||||
# and that ``prompt_cache_key`` is combined with the prefix hash to
|
||||
# improve routing affinity and therefore cache hit rate. LiteLLM's Azure
|
||||
# transformer ships ``prompt_cache_key`` in its supported params as of
|
||||
# https://github.com/BerriAI/litellm/pull/20989.
|
||||
#
|
||||
# Strict whitelist — many other providers in ``PROVIDER_MAP`` route
|
||||
# through litellm's ``openai`` prefix without implementing the OpenAI
|
||||
# prompt-cache surface (e.g. MOONSHOT, ZHIPU, MINIMAX), so we can't infer
|
||||
# family from the litellm prefix alone.
|
||||
_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset(
|
||||
{"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"}
|
||||
)
|
||||
|
||||
# Subset of ``_PROMPT_CACHE_KEY_PROVIDERS`` that also accept
|
||||
# ``prompt_cache_retention="24h"``. Azure is excluded: see module
|
||||
# docstring — LiteLLM 1.83.14's Azure transformer omits the param so
|
||||
# ``drop_params`` silently strips it. Re-add Azure once a future LiteLLM
|
||||
# release wires it into ``AzureOpenAIConfig.get_supported_openai_params``.
|
||||
_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset(
|
||||
{"OPENAI", "DEEPSEEK", "XAI"}
|
||||
)
|
||||
|
||||
|
||||
def _is_router_llm(llm: BaseChatModel) -> bool:
|
||||
"""Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import.
|
||||
|
||||
Importing ``app.services.llm_router_service`` at module-load time would
|
||||
create a cycle via ``llm_config -> prompt_caching -> llm_router_service``.
|
||||
Class-name comparison is sufficient since the class is defined in a
|
||||
single place.
|
||||
"""
|
||||
return type(llm).__name__ == "ChatLiteLLMRouter"
|
||||
|
||||
|
||||
def _provider_supports_prompt_cache_key(agent_config: AgentConfig | None) -> bool:
|
||||
"""Whether the config targets a provider that accepts ``prompt_cache_key``.
|
||||
|
||||
Strict — only returns True for explicitly chosen OPENAI, DEEPSEEK,
|
||||
XAI, AZURE, or AZURE_OPENAI providers. Auto-mode and custom
|
||||
providers return False because we can't statically know the
|
||||
destination and the router fans out across mixed providers.
|
||||
"""
|
||||
if agent_config is None or not agent_config.provider:
|
||||
return False
|
||||
if agent_config.is_auto_mode:
|
||||
return False
|
||||
if agent_config.custom_provider:
|
||||
return False
|
||||
return agent_config.provider.upper() in _PROMPT_CACHE_KEY_PROVIDERS
|
||||
|
||||
|
||||
def _provider_supports_prompt_cache_retention(
|
||||
agent_config: AgentConfig | None,
|
||||
) -> bool:
|
||||
"""Whether the config targets a provider that accepts ``prompt_cache_retention``.
|
||||
|
||||
Tighter than :func:`_provider_supports_prompt_cache_key` — Azure
|
||||
deployments are excluded until LiteLLM ships the param in its Azure
|
||||
transformer (see module docstring).
|
||||
"""
|
||||
if agent_config is None or not agent_config.provider:
|
||||
return False
|
||||
if agent_config.is_auto_mode:
|
||||
return False
|
||||
if agent_config.custom_provider:
|
||||
return False
|
||||
return agent_config.provider.upper() in _PROMPT_CACHE_RETENTION_PROVIDERS
|
||||
|
||||
|
||||
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
|
||||
"""Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail.
|
||||
|
||||
Initialises the field to ``{}`` when present-but-None on a Pydantic v2
|
||||
model. Returns ``None`` if the LLM type doesn't expose a writable
|
||||
``model_kwargs`` attribute (caller should treat as no-op).
|
||||
"""
|
||||
model_kwargs = getattr(llm, "model_kwargs", None)
|
||||
if isinstance(model_kwargs, dict):
|
||||
return model_kwargs
|
||||
try:
|
||||
llm.model_kwargs = {} # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return None
|
||||
refreshed = getattr(llm, "model_kwargs", None)
|
||||
return refreshed if isinstance(refreshed, dict) else None
|
||||
|
||||
|
||||
def apply_litellm_prompt_caching(
|
||||
llm: BaseChatModel,
|
||||
*,
|
||||
agent_config: AgentConfig | None = None,
|
||||
thread_id: int | None = None,
|
||||
) -> None:
|
||||
"""Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.
|
||||
|
||||
Idempotent — values already present in ``llm.model_kwargs`` (e.g. from
|
||||
``agent_config.litellm_params`` overrides) are preserved. Mutates
|
||||
``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion``
|
||||
via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge
|
||||
in our custom ``ChatLiteLLMRouter``.
|
||||
|
||||
Args:
|
||||
llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
|
||||
agent_config: Optional ``AgentConfig`` driving provider-specific
|
||||
behaviour. When omitted (or auto-mode), only the universal
|
||||
``cache_control_injection_points`` are set.
|
||||
thread_id: Optional thread id used to construct a per-thread
|
||||
``prompt_cache_key`` for OpenAI-family providers. Caching still
|
||||
works without it (server-side automatic), but the key improves
|
||||
backend routing affinity and therefore hit rate.
|
||||
"""
|
||||
model_kwargs = _get_or_init_model_kwargs(llm)
|
||||
if model_kwargs is None:
|
||||
logger.debug(
|
||||
"apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping",
|
||||
type(llm).__name__,
|
||||
)
|
||||
return
|
||||
|
||||
if "cache_control_injection_points" not in model_kwargs:
|
||||
model_kwargs["cache_control_injection_points"] = [
|
||||
dict(point) for point in _DEFAULT_INJECTION_POINTS
|
||||
]
|
||||
|
||||
# OpenAI-style extras only when we statically know the destination
|
||||
# accepts them. Auto-mode router fans out across mixed providers so
|
||||
# we can't safely set destination-specific kwargs there (drop_params
|
||||
# would strip them but it's wasteful to set them in the first
|
||||
# place).
|
||||
if _is_router_llm(llm):
|
||||
return
|
||||
|
||||
if (
|
||||
thread_id is not None
|
||||
and "prompt_cache_key" not in model_kwargs
|
||||
and _provider_supports_prompt_cache_key(agent_config)
|
||||
):
|
||||
model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"
|
||||
|
||||
if (
|
||||
"prompt_cache_retention" not in model_kwargs
|
||||
and _provider_supports_prompt_cache_retention(agent_config)
|
||||
):
|
||||
model_kwargs["prompt_cache_retention"] = "24h"
|
||||
Loading…
Add table
Add a link
Reference in a new issue