diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py
index 241c4f343..3843b1687 100644
--- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py
+++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py
@@ -241,6 +241,15 @@ async def create_surfsense_deep_agent(
     # Build dependencies dict for the tools registry
     visibility = thread_visibility or ChatVisibility.PRIVATE
+
+    # Extract the model's context window so tools can size their output.
+    _model_profile = getattr(llm, "profile", None)
+    _max_input_tokens: int | None = (
+        _model_profile.get("max_input_tokens")
+        if isinstance(_model_profile, dict)
+        else None
+    )
+
     dependencies = {
         "search_space_id": search_space_id,
         "db_session": db_session,
@@ -251,6 +260,7 @@ async def create_surfsense_deep_agent(
         "thread_visibility": visibility,
         "available_connectors": available_connectors,
         "available_document_types": available_document_types,
+        "max_input_tokens": _max_input_tokens,
     }
 
     # Disable Notion action tools if no Notion connector is configured
diff --git a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py
index cf34e3e85..6989a1aa2 100644
--- a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py
+++ b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py
@@ -172,12 +172,52 @@ def _normalize_connectors(
 # =============================================================================
 
 
-def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
+# Fraction of the model's context window that a single tool result is allowed
+# to occupy, expressed as a character budget. The remainder is reserved for
+# the system prompt, conversation history, and model output. At ~4 chars/token
+# this gives a tool result roughly 25% of the context window in tokens.
+_TOOL_OUTPUT_CONTEXT_FRACTION = 0.25
+_CHARS_PER_TOKEN = 4
+
+# Hard floor and ceiling so the budget is always sensible regardless of what
+# the model reports.
+_MIN_TOOL_OUTPUT_CHARS = 20_000  # ~5K tokens
+_MAX_TOOL_OUTPUT_CHARS = 400_000  # ~100K tokens
+_MAX_CHUNK_CHARS = 8_000
+
+
+def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
+    """Derive a character budget from the model's context window.
+
+    The value originates from ``litellm.get_model_info``, already resolved by
+    ``ChatLiteLLMRouter`` / ``ChatLiteLLM`` and passed through the dependency
+    chain as ``max_input_tokens``. Falls back to a conservative default when
+    the value is unavailable.
+    """
+    if max_input_tokens is None or max_input_tokens <= 0:
+        return _MIN_TOOL_OUTPUT_CHARS  # conservative fallback
+
+    budget = int(max_input_tokens * _CHARS_PER_TOKEN * _TOOL_OUTPUT_CONTEXT_FRACTION)
+    return max(_MIN_TOOL_OUTPUT_CHARS, min(budget, _MAX_TOOL_OUTPUT_CHARS))
+
+
+def format_documents_for_context(
+    documents: list[dict[str, Any]],
+    *,
+    max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
+    max_chunk_chars: int = _MAX_CHUNK_CHARS,
+) -> str:
     """
     Format retrieved documents into a readable context string for the LLM.
 
+    Documents are added in order (highest relevance first) until the character
+    budget is reached. Individual chunks are capped at ``max_chunk_chars`` so
+    a single oversized chunk cannot monopolize the output.
+
     Args:
         documents: List of document dictionaries from connector search
+        max_chars: Approximate character budget for the entire output.
+        max_chunk_chars: Per-chunk character cap (content is tail-truncated).
 
     Returns:
         Formatted string with document contents and metadata
@@ -278,37 +318,57 @@ def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
         "BAIDU_SEARCH_API",
     }
 
-    # Render XML expected by citation instructions
+    # Render XML expected by citation instructions, respecting the char budget.
     parts: list[str] = []
-    for g in grouped.values():
+    total_chars = 0
+    total_docs = len(grouped)
+
+    for doc_idx, g in enumerate(grouped.values()):
         metadata_json = json.dumps(g["metadata"], ensure_ascii=False)
         is_live_search = g["document_type"] in live_search_connectors
 
-        parts.append("")
-        parts.append("")
-        parts.append(f" {g['document_id']}")
-        parts.append(f" {g['document_type']}")
-        parts.append(f" <![CDATA[{g['title']}]]>")
-        parts.append(f" ")
-        parts.append(f" ")
-        parts.append("")
-        parts.append("")
-        parts.append("")
+        doc_lines: list[str] = [
+            "",
+            "",
+            f" {g['document_id']}",
+            f" {g['document_type']}",
+            f" <![CDATA[{g['title']}]]>",
+            f" ",
+            f" ",
+            "",
+            "",
+            "",
+        ]
 
         for ch in g["chunks"]:
             ch_content = ch["content"]
-            # For live search connectors, use the document URL as the chunk id
-            # so the LLM outputs [citation:https://...] which the frontend
-            # renders as a clickable link.
+            if max_chunk_chars and len(ch_content) > max_chunk_chars:
+                ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
             ch_id = g["url"] if (is_live_search and g["url"]) else ch["chunk_id"]
             if ch_id is None:
-                parts.append(f" ")
+                doc_lines.append(f" ")
             else:
-                parts.append(f" ")
+                doc_lines.append(
+                    f" "
+                )
 
-        parts.append("")
-        parts.append("")
-        parts.append("")
+        doc_lines.extend(["", "", ""])
+
+        doc_xml = "\n".join(doc_lines)
+        doc_len = len(doc_xml)
+
+        # Always include at least the first document; afterwards enforce budget.
+        if doc_idx > 0 and total_chars + doc_len > max_chars:
+            remaining = total_docs - doc_idx
+            parts.append(
+                f""
+            )
+            break
+
+        parts.append(doc_xml)
+        total_chars += doc_len
 
     return "\n".join(parts).strip()
 
@@ -328,6 +388,7 @@ async def search_knowledge_base_async(
     start_date: datetime | None = None,
     end_date: datetime | None = None,
     available_connectors: list[str] | None = None,
+    max_input_tokens: int | None = None,
 ) -> str:
     """
     Search the user's knowledge base for relevant documents.
@@ -345,6 +406,8 @@ async def search_knowledge_base_async(
         end_date: Optional end datetime (UTC) for filtering documents
         available_connectors: Optional list of connectors actually available in the
             search space. If provided, only these connectors will be searched.
+        max_input_tokens: Model context window size (tokens). Used to dynamically
+            size the output so it fits within the model's limits.
 
     Returns:
         Formatted string with search results
@@ -488,7 +551,8 @@ async def search_knowledge_base_async(
 
         deduplicated.append(doc)
 
-    return format_documents_for_context(deduplicated)
+    output_budget = _compute_tool_output_budget(max_input_tokens)
+    return format_documents_for_context(deduplicated, max_chars=output_budget)
 
 
 def _build_connector_docstring(available_connectors: list[str] | None) -> str:
@@ -552,6 +616,7 @@ def create_search_knowledge_base_tool(
     connector_service: ConnectorService,
     available_connectors: list[str] | None = None,
     available_document_types: list[str] | None = None,
+    max_input_tokens: int | None = None,
 ) -> StructuredTool:
     """
     Factory function to create the search_knowledge_base tool with injected dependencies.
@@ -564,6 +629,8 @@ def create_search_knowledge_base_tool(
             Used to dynamically generate the tool docstring.
         available_document_types: Optional list of document types that have data
             in the search space. Used to inform the LLM about what data exists.
+        max_input_tokens: Model context window (tokens) from litellm model info.
+            Used to dynamically size tool output.
 
     Returns:
         A configured StructuredTool instance
@@ -634,6 +701,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
             start_date=parsed_start,
             end_date=parsed_end,
             available_connectors=_available_connectors,
+            max_input_tokens=max_input_tokens,
         )
 
     # Create StructuredTool with dynamic description
diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py
index 59efc2efb..f36f0de13 100644
--- a/surfsense_backend/app/agents/new_chat/tools/registry.py
+++ b/surfsense_backend/app/agents/new_chat/tools/registry.py
@@ -118,6 +118,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
             # Optional: dynamically discovered connectors/document types
             available_connectors=deps.get("available_connectors"),
             available_document_types=deps.get("available_document_types"),
+            max_input_tokens=deps.get("max_input_tokens"),
         ),
         requires=["search_space_id", "db_session", "connector_service"],
         # Note: available_connectors and available_document_types are optional
diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py
index 23fcad69d..2e517f0ba 100644
--- a/surfsense_backend/app/services/llm_router_service.py
+++ b/surfsense_backend/app/services/llm_router_service.py
@@ -12,6 +12,7 @@ synchronous ChatLiteLLM-like interface and async methods.
 """
 
 import logging
+import re
 from typing import Any
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -20,10 +21,26 @@ from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from litellm import Router
-from litellm.exceptions import ContextWindowExceededError
+from litellm.exceptions import (
+    BadRequestError as LiteLLMBadRequestError,
+    ContextWindowExceededError,
+)
 
 logger = logging.getLogger(__name__)
 
+_CONTEXT_OVERFLOW_PATTERNS = re.compile(
+    r"(input tokens exceed|context.{0,20}(length|window|limit)|"
+    r"maximum context length|token.{0,20}(limit|exceed)|"
+    r"too many tokens|reduce the length)",
+    re.IGNORECASE,
+)
+
+
+def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool:
+    """Check if a BadRequestError is actually a context window overflow."""
+    return bool(_CONTEXT_OVERFLOW_PATTERNS.search(str(exc)))
+
+
 # Special ID for Auto mode - uses router for load balancing
 AUTO_MODE_ID = 0
 
@@ -236,6 +253,10 @@ class ChatLiteLLMRouter(BaseChatModel):
 
     This wraps the LiteLLM Router to provide the same interface as ChatLiteLLM,
     making it a drop-in replacement for auto-mode routing.
+
+    Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
+    window across all router deployments so that deepagents
+    SummarizationMiddleware can use fraction-based triggers.
""" # Use model_config for Pydantic v2 compatibility @@ -267,7 +288,6 @@ class ChatLiteLLMRouter(BaseChatModel): """ try: super().__init__(**kwargs) - # Store router and tools as private attributes resolved_router = router or LLMRouterService.get_router() object.__setattr__(self, "_router", resolved_router) object.__setattr__(self, "_bound_tools", bound_tools) @@ -276,6 +296,12 @@ class ChatLiteLLMRouter(BaseChatModel): raise ValueError( "LLM Router not initialized. Call LLMRouterService.initialize() first." ) + + # Set profile so deepagents SummarizationMiddleware gets fraction-based triggers + computed_profile = self._compute_min_context_profile() + if computed_profile is not None: + object.__setattr__(self, "profile", computed_profile) + logger.info( f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models" ) @@ -283,6 +309,39 @@ class ChatLiteLLMRouter(BaseChatModel): logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}") raise + def _compute_min_context_profile(self) -> dict | None: + """Derive a profile dict with max_input_tokens from router deployments. + + Uses litellm.get_model_info to look up each deployment's context window + and picks the *minimum* so that summarization triggers before ANY model + in the pool overflows. + """ + from litellm import get_model_info + + if not self._router: + return None + + min_ctx: int | None = None + for deployment in self._router.model_list: + params = deployment.get("litellm_params", {}) + base_model = params.get("base_model") or params.get("model", "") + try: + info = get_model_info(base_model) + ctx = info.get("max_input_tokens") + if ( + isinstance(ctx, int) + and ctx > 0 + and (min_ctx is None or ctx < min_ctx) + ): + min_ctx = ctx + except Exception: + continue + + if min_ctx is not None: + logger.info(f"ChatLiteLLMRouter profile: max_input_tokens={min_ctx}") + return {"max_input_tokens": min_ctx} + return None + @property def _llm_type(self) -> str: return "litellm-router" @@ -370,6 +429,10 @@ class ChatLiteLLMRouter(BaseChatModel): ) except ContextWindowExceededError as e: raise ContextOverflowError(str(e)) from e + except LiteLLMBadRequestError as e: + if _is_context_overflow_error(e): + raise ContextOverflowError(str(e)) from e + raise # Convert response to ChatResult with potential tool calls message = self._convert_response_to_message(response.choices[0].message) @@ -409,6 +472,10 @@ class ChatLiteLLMRouter(BaseChatModel): ) except ContextWindowExceededError as e: raise ContextOverflowError(str(e)) from e + except LiteLLMBadRequestError as e: + if _is_context_overflow_error(e): + raise ContextOverflowError(str(e)) from e + raise # Convert response to ChatResult with potential tool calls message = self._convert_response_to_message(response.choices[0].message) @@ -448,6 +515,10 @@ class ChatLiteLLMRouter(BaseChatModel): ) except ContextWindowExceededError as e: raise ContextOverflowError(str(e)) from e + except LiteLLMBadRequestError as e: + if _is_context_overflow_error(e): + raise ContextOverflowError(str(e)) from e + raise # Yield chunks for chunk in response: @@ -489,6 +560,10 @@ class ChatLiteLLMRouter(BaseChatModel): ) except ContextWindowExceededError as e: raise ContextOverflowError(str(e)) from e + except LiteLLMBadRequestError as e: + if _is_context_overflow_error(e): + raise ContextOverflowError(str(e)) from e + raise # Yield chunks asynchronously async for chunk in response: