feat: enhance document formatting and context management for LLM tools

- Introduced dynamic character-budget calculation for document formatting, based on the model's context window.
- Updated `format_documents_for_context` to respect character limits and improve output quality.
- Threaded a `max_input_tokens` parameter through the tool dependency chain to enable context-aware sizing of tool output.
- Enhanced error handling for context overflow in the LLM router service: provider `BadRequestError`s whose message indicates overflow are now re-raised as `ContextOverflowError`.
DESKTOP-RTLN3BA\$punk 2026-02-26 20:47:19 -08:00
parent a4dc84d1ab
commit 1e4b8d3e89
4 changed files with 178 additions and 24 deletions

View file

@@ -241,6 +241,15 @@ async def create_surfsense_deep_agent(
    # Build dependencies dict for the tools registry
    visibility = thread_visibility or ChatVisibility.PRIVATE

    # Extract the model's context window so tools can size their output.
    _model_profile = getattr(llm, "profile", None)
    _max_input_tokens: int | None = (
        _model_profile.get("max_input_tokens")
        if isinstance(_model_profile, dict)
        else None
    )

    dependencies = {
        "search_space_id": search_space_id,
        "db_session": db_session,
@@ -251,6 +260,7 @@
        "thread_visibility": visibility,
        "available_connectors": available_connectors,
        "available_document_types": available_document_types,
        "max_input_tokens": _max_input_tokens,
    }

    # Disable Notion action tools if no Notion connector is configured
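Illustrative aside (not part of the commit): the `profile` lookup above degrades gracefully for models that expose no profile at all. A minimal sketch, with a stand-in model class assumed purely for illustration:

# `StubModel` is hypothetical; real models expose `profile` the way
# ChatLiteLLMRouter does in the last file of this commit.
from typing import Any

class StubModel:
    profile: dict[str, Any] = {"max_input_tokens": 128_000}

for llm in (StubModel(), object()):
    _model_profile = getattr(llm, "profile", None)
    _max_input_tokens = (
        _model_profile.get("max_input_tokens")
        if isinstance(_model_profile, dict)
        else None
    )
    print(type(llm).__name__, "->", _max_input_tokens)
# StubModel -> 128000, object -> None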

View file

@@ -172,12 +172,52 @@ def _normalize_connectors(
# =============================================================================


# Fraction of the model's context window (in characters) that a single tool
# result is allowed to occupy. The remainder is reserved for the system prompt,
# conversation history, and model output. With ~4 chars/token this gives a
# tool result ≈25% of the context budget in tokens.
_TOOL_OUTPUT_CONTEXT_FRACTION = 0.25
_CHARS_PER_TOKEN = 4

# Hard floor / ceiling so the budget is always sensible regardless of what
# the model reports.
_MIN_TOOL_OUTPUT_CHARS = 20_000  # ~5K tokens
_MAX_TOOL_OUTPUT_CHARS = 400_000  # ~100K tokens
_MAX_CHUNK_CHARS = 8_000


def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
    """Derive a character budget from the model's context window.

    Uses ``litellm.get_model_info`` via the value already resolved by
    ``ChatLiteLLMRouter`` / ``ChatLiteLLM`` and passed through the dependency
    chain as ``max_input_tokens``. Falls back to a conservative default when
    the value is unavailable.
    """
    if max_input_tokens is None or max_input_tokens <= 0:
        return _MIN_TOOL_OUTPUT_CHARS  # conservative fallback
    budget = int(max_input_tokens * _CHARS_PER_TOKEN * _TOOL_OUTPUT_CONTEXT_FRACTION)
    return max(_MIN_TOOL_OUTPUT_CHARS, min(budget, _MAX_TOOL_OUTPUT_CHARS))


def format_documents_for_context(
    documents: list[dict[str, Any]],
    *,
    max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
    max_chunk_chars: int = _MAX_CHUNK_CHARS,
) -> str:
""" """
Format retrieved documents into a readable context string for the LLM. Format retrieved documents into a readable context string for the LLM.
Documents are added in order (highest relevance first) until the character
budget is reached. Individual chunks are capped at ``max_chunk_chars`` so
a single oversized chunk cannot monopolize the output.
Args: Args:
documents: List of document dictionaries from connector search documents: List of document dictionaries from connector search
max_chars: Approximate character budget for the entire output.
max_chunk_chars: Per-chunk character cap (content is tail-truncated).
Returns: Returns:
Formatted string with document contents and metadata Formatted string with document contents and metadata
@@ -278,37 +318,57 @@ def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
        "BAIDU_SEARCH_API",
    }

    # Render XML expected by citation instructions, respecting the char budget.
    parts: list[str] = []
    total_chars = 0
    total_docs = len(grouped)

    for doc_idx, g in enumerate(grouped.values()):
        metadata_json = json.dumps(g["metadata"], ensure_ascii=False)
        is_live_search = g["document_type"] in live_search_connectors

        doc_lines: list[str] = [
            "<document>",
            "<document_metadata>",
            f" <document_id>{g['document_id']}</document_id>",
            f" <document_type>{g['document_type']}</document_type>",
            f" <title><![CDATA[{g['title']}]]></title>",
            f" <url><![CDATA[{g['url']}]]></url>",
            f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
            "</document_metadata>",
            "",
            "<document_content>",
        ]

        for ch in g["chunks"]:
            ch_content = ch["content"]
            if max_chunk_chars and len(ch_content) > max_chunk_chars:
                ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
            ch_id = g["url"] if (is_live_search and g["url"]) else ch["chunk_id"]
            if ch_id is None:
                doc_lines.append(f" <chunk><![CDATA[{ch_content}]]></chunk>")
            else:
                doc_lines.append(
                    f" <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>"
                )

        doc_lines.extend(["</document_content>", "</document>", ""])

        doc_xml = "\n".join(doc_lines)
        doc_len = len(doc_xml)

        # Always include at least the first document; afterwards enforce budget.
        if doc_idx > 0 and total_chars + doc_len > max_chars:
            remaining = total_docs - doc_idx
            parts.append(
                f"<!-- Output truncated: {remaining} more document(s) omitted "
                f"(budget {max_chars} chars). Refine your query or reduce top_k "
                f"to retrieve different results. -->"
            )
            break

        parts.append(doc_xml)
        total_chars += doc_len

    return "\n".join(parts).strip()
@@ -328,6 +388,7 @@ async def search_knowledge_base_async(
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    available_connectors: list[str] | None = None,
    max_input_tokens: int | None = None,
) -> str:
    """
    Search the user's knowledge base for relevant documents.
@@ -345,6 +406,8 @@
        end_date: Optional end datetime (UTC) for filtering documents
        available_connectors: Optional list of connectors actually available in the search space.
            If provided, only these connectors will be searched.
        max_input_tokens: Model context window size (tokens). Used to dynamically
            size the output so it fits within the model's limits.

    Returns:
        Formatted string with search results
@@ -488,7 +551,8 @@
            deduplicated.append(doc)

    output_budget = _compute_tool_output_budget(max_input_tokens)
    return format_documents_for_context(deduplicated, max_chars=output_budget)


def _build_connector_docstring(available_connectors: list[str] | None) -> str:
@@ -552,6 +616,7 @@ def create_search_knowledge_base_tool(
    connector_service: ConnectorService,
    available_connectors: list[str] | None = None,
    available_document_types: list[str] | None = None,
    max_input_tokens: int | None = None,
) -> StructuredTool:
    """
    Factory function to create the search_knowledge_base tool with injected dependencies.
@@ -564,6 +629,8 @@
            Used to dynamically generate the tool docstring.
        available_document_types: Optional list of document types that have data in the search space.
            Used to inform the LLM about what data exists.
        max_input_tokens: Model context window (tokens) from litellm model info.
            Used to dynamically size tool output.

    Returns:
        A configured StructuredTool instance
@@ -634,6 +701,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
            start_date=parsed_start,
            end_date=parsed_end,
            available_connectors=_available_connectors,
            max_input_tokens=max_input_tokens,
        )

    # Create StructuredTool with dynamic description
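The always-include-the-first-document rule in `format_documents_for_context` can be seen in a self-contained toy (hypothetical `toy_format`, not the real formatter):

# Toy version of the budget loop: the first document always fits, later
# ones are dropped once the running total would exceed the budget.
def toy_format(docs: list[str], max_chars: int) -> str:
    parts: list[str] = []
    total = 0
    for i, d in enumerate(docs):
        if i > 0 and total + len(d) > max_chars:
            parts.append(f"<!-- {len(docs) - i} more document(s) omitted -->")
            break
        parts.append(d)
        total += len(d)
    return "\n".join(parts)

print(toy_format(["a" * 90, "b" * 90, "c" * 90], max_chars=150))
# Prints the first document followed by: <!-- 2 more document(s) omitted -->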

View file

@@ -118,6 +118,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
            # Optional: dynamically discovered connectors/document types
            available_connectors=deps.get("available_connectors"),
            available_document_types=deps.get("available_document_types"),
            max_input_tokens=deps.get("max_input_tokens"),
        ),
        requires=["search_space_id", "db_session", "connector_service"],
        # Note: available_connectors and available_document_types are optional

View file

@@ -12,6 +12,7 @@ synchronous ChatLiteLLM-like interface and async methods.
"""

import logging
import re
from typing import Any

from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -20,10 +21,26 @@ from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from litellm import Router
from litellm.exceptions import (
    BadRequestError as LiteLLMBadRequestError,
    ContextWindowExceededError,
)

logger = logging.getLogger(__name__)

_CONTEXT_OVERFLOW_PATTERNS = re.compile(
    r"(input tokens exceed|context.{0,20}(length|window|limit)|"
    r"maximum context length|token.{0,20}(limit|exceed)|"
    r"too many tokens|reduce the length)",
    re.IGNORECASE,
)


def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool:
    """Check if a BadRequestError is actually a context window overflow."""
    return bool(_CONTEXT_OVERFLOW_PATTERNS.search(str(exc)))


# Special ID for Auto mode - uses router for load balancing
AUTO_MODE_ID = 0
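A quick smoke test of the heuristic against provider-style messages (the sample strings are invented for illustration, not captured provider output; the pattern is repeated so the snippet runs standalone):

import re

_CONTEXT_OVERFLOW_PATTERNS = re.compile(  # same pattern as above
    r"(input tokens exceed|context.{0,20}(length|window|limit)|"
    r"maximum context length|token.{0,20}(limit|exceed)|"
    r"too many tokens|reduce the length)",
    re.IGNORECASE,
)

for msg in (
    "This model's maximum context length is 128000 tokens",
    "Please reduce the length of the messages",
    "invalid api key",  # a genuine bad request: should NOT match
):
    print(bool(_CONTEXT_OVERFLOW_PATTERNS.search(msg)), "<-", msg)
# True, True, False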
@@ -236,6 +253,10 @@ class ChatLiteLLMRouter(BaseChatModel):
    This wraps the LiteLLM Router to provide the same interface as ChatLiteLLM,
    making it a drop-in replacement for auto-mode routing.

    Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
    window across all router deployments so that the deepagents
    SummarizationMiddleware can use fraction-based triggers.
    """

    # Use model_config for Pydantic v2 compatibility
@@ -267,7 +288,6 @@
        """
        try:
            super().__init__(**kwargs)
            resolved_router = router or LLMRouterService.get_router()
            object.__setattr__(self, "_router", resolved_router)
            object.__setattr__(self, "_bound_tools", bound_tools)
@@ -276,6 +296,12 @@
                raise ValueError(
                    "LLM Router not initialized. Call LLMRouterService.initialize() first."
                )

            # Set profile so deepagents SummarizationMiddleware gets fraction-based triggers
            computed_profile = self._compute_min_context_profile()
            if computed_profile is not None:
                object.__setattr__(self, "profile", computed_profile)

            logger.info(
                f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
            )
@@ -283,6 +309,39 @@
            logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
            raise

    def _compute_min_context_profile(self) -> dict | None:
        """Derive a profile dict with max_input_tokens from router deployments.

        Uses litellm.get_model_info to look up each deployment's context window
        and picks the *minimum* so that summarization triggers before ANY model
        in the pool overflows.
        """
        from litellm import get_model_info

        if not self._router:
            return None

        min_ctx: int | None = None
        for deployment in self._router.model_list:
            params = deployment.get("litellm_params", {})
            base_model = params.get("base_model") or params.get("model", "")
            try:
                info = get_model_info(base_model)
                ctx = info.get("max_input_tokens")
                if (
                    isinstance(ctx, int)
                    and ctx > 0
                    and (min_ctx is None or ctx < min_ctx)
                ):
                    min_ctx = ctx
            except Exception:
                continue

        if min_ctx is not None:
            logger.info(f"ChatLiteLLMRouter profile: max_input_tokens={min_ctx}")
            return {"max_input_tokens": min_ctx}
        return None

    @property
    def _llm_type(self) -> str:
        return "litellm-router"
@@ -370,6 +429,10 @@
            )
        except ContextWindowExceededError as e:
            raise ContextOverflowError(str(e)) from e
        except LiteLLMBadRequestError as e:
            if _is_context_overflow_error(e):
                raise ContextOverflowError(str(e)) from e
            raise

        # Convert response to ChatResult with potential tool calls
        message = self._convert_response_to_message(response.choices[0].message)
@@ -409,6 +472,10 @@
            )
        except ContextWindowExceededError as e:
            raise ContextOverflowError(str(e)) from e
        except LiteLLMBadRequestError as e:
            if _is_context_overflow_error(e):
                raise ContextOverflowError(str(e)) from e
            raise

        # Convert response to ChatResult with potential tool calls
        message = self._convert_response_to_message(response.choices[0].message)
@@ -448,6 +515,10 @@
            )
        except ContextWindowExceededError as e:
            raise ContextOverflowError(str(e)) from e
        except LiteLLMBadRequestError as e:
            if _is_context_overflow_error(e):
                raise ContextOverflowError(str(e)) from e
            raise

        # Yield chunks
        for chunk in response:
@@ -489,6 +560,10 @@
            )
        except ContextWindowExceededError as e:
            raise ContextOverflowError(str(e)) from e
        except LiteLLMBadRequestError as e:
            if _is_context_overflow_error(e):
                raise ContextOverflowError(str(e)) from e
            raise

        # Yield chunks asynchronously
        async for chunk in response:
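The same try/except translation appears in all four call paths (sync, async, and both streaming variants). In isolation, with the litellm exception class stubbed so the sketch is self-contained:

import re

# Stand-ins so this runs without litellm; the real code catches
# litellm.exceptions.BadRequestError and uses _CONTEXT_OVERFLOW_PATTERNS.
class ContextOverflowError(Exception): ...
class StubBadRequestError(Exception): ...

OVERFLOW = re.compile(r"maximum context length|reduce the length", re.IGNORECASE)

def call_model() -> None:
    raise StubBadRequestError("This model's maximum context length is 8192 tokens")

try:
    try:
        call_model()
    except StubBadRequestError as e:
        if OVERFLOW.search(str(e)):
            raise ContextOverflowError(str(e)) from e
        raise  # genuine bad requests propagate unchanged
except ContextOverflowError as e:
    print("surfaced as ContextOverflowError:", e)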