mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-01 20:03:30 +02:00
feat: enhance document formatting and context management for LLM tools
- Introduced dynamic character budget calculation for document formatting based on the model's context window. - Updated `format_documents_for_context` to respect character limits and improve output quality. - Added a `max_input_tokens` parameter to various functions to facilitate context-aware processing. - Enhanced error handling for context overflow in the LLM router service.
This commit is contained in:
parent
a4dc84d1ab
commit
1e4b8d3e89
4 changed files with 178 additions and 24 deletions
|
|
@ -241,6 +241,15 @@ async def create_surfsense_deep_agent(
|
||||||
|
|
||||||
# Build dependencies dict for the tools registry
|
# Build dependencies dict for the tools registry
|
||||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
|
|
||||||
|
# Extract the model's context window so tools can size their output.
|
||||||
|
_model_profile = getattr(llm, "profile", None)
|
||||||
|
_max_input_tokens: int | None = (
|
||||||
|
_model_profile.get("max_input_tokens")
|
||||||
|
if isinstance(_model_profile, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
dependencies = {
|
dependencies = {
|
||||||
"search_space_id": search_space_id,
|
"search_space_id": search_space_id,
|
||||||
"db_session": db_session,
|
"db_session": db_session,
|
||||||
|
|
@ -251,6 +260,7 @@ async def create_surfsense_deep_agent(
|
||||||
"thread_visibility": visibility,
|
"thread_visibility": visibility,
|
||||||
"available_connectors": available_connectors,
|
"available_connectors": available_connectors,
|
||||||
"available_document_types": available_document_types,
|
"available_document_types": available_document_types,
|
||||||
|
"max_input_tokens": _max_input_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Disable Notion action tools if no Notion connector is configured
|
# Disable Notion action tools if no Notion connector is configured
|
||||||
|
|
|
||||||
|
|
@ -172,12 +172,52 @@ def _normalize_connectors(
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
|
# Fraction of the model's context window (in characters) that a single tool
|
||||||
|
# result is allowed to occupy. The remainder is reserved for system prompt,
|
||||||
|
# conversation history, and model output. With ~4 chars/token this gives a
|
||||||
|
# tool result ≈ 25 % of the context budget in tokens.
|
||||||
|
_TOOL_OUTPUT_CONTEXT_FRACTION = 0.25
|
||||||
|
_CHARS_PER_TOKEN = 4
|
||||||
|
|
||||||
|
# Hard-floor / ceiling so the budget is always sensible regardless of what
|
||||||
|
# the model reports.
|
||||||
|
_MIN_TOOL_OUTPUT_CHARS = 20_000 # ~5K tokens
|
||||||
|
_MAX_TOOL_OUTPUT_CHARS = 400_000 # ~100K tokens
|
||||||
|
_MAX_CHUNK_CHARS = 8_000
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
|
||||||
|
"""Derive a character budget from the model's context window.
|
||||||
|
|
||||||
|
Uses ``litellm.get_model_info`` via the value already resolved by
|
||||||
|
``ChatLiteLLMRouter`` / ``ChatLiteLLM`` and passed through the dependency
|
||||||
|
chain as ``max_input_tokens``. Falls back to a conservative default when
|
||||||
|
the value is unavailable.
|
||||||
|
"""
|
||||||
|
if max_input_tokens is None or max_input_tokens <= 0:
|
||||||
|
return _MIN_TOOL_OUTPUT_CHARS # conservative fallback
|
||||||
|
|
||||||
|
budget = int(max_input_tokens * _CHARS_PER_TOKEN * _TOOL_OUTPUT_CONTEXT_FRACTION)
|
||||||
|
return max(_MIN_TOOL_OUTPUT_CHARS, min(budget, _MAX_TOOL_OUTPUT_CHARS))
|
||||||
|
|
||||||
|
|
||||||
|
def format_documents_for_context(
|
||||||
|
documents: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
|
||||||
|
max_chunk_chars: int = _MAX_CHUNK_CHARS,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format retrieved documents into a readable context string for the LLM.
|
Format retrieved documents into a readable context string for the LLM.
|
||||||
|
|
||||||
|
Documents are added in order (highest relevance first) until the character
|
||||||
|
budget is reached. Individual chunks are capped at ``max_chunk_chars`` so
|
||||||
|
a single oversized chunk cannot monopolize the output.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
documents: List of document dictionaries from connector search
|
documents: List of document dictionaries from connector search
|
||||||
|
max_chars: Approximate character budget for the entire output.
|
||||||
|
max_chunk_chars: Per-chunk character cap (content is tail-truncated).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Formatted string with document contents and metadata
|
Formatted string with document contents and metadata
|
||||||
|
|
@ -278,37 +318,57 @@ def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
|
||||||
"BAIDU_SEARCH_API",
|
"BAIDU_SEARCH_API",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Render XML expected by citation instructions
|
# Render XML expected by citation instructions, respecting the char budget.
|
||||||
parts: list[str] = []
|
parts: list[str] = []
|
||||||
for g in grouped.values():
|
total_chars = 0
|
||||||
|
total_docs = len(grouped)
|
||||||
|
|
||||||
|
for doc_idx, g in enumerate(grouped.values()):
|
||||||
metadata_json = json.dumps(g["metadata"], ensure_ascii=False)
|
metadata_json = json.dumps(g["metadata"], ensure_ascii=False)
|
||||||
is_live_search = g["document_type"] in live_search_connectors
|
is_live_search = g["document_type"] in live_search_connectors
|
||||||
|
|
||||||
parts.append("<document>")
|
doc_lines: list[str] = [
|
||||||
parts.append("<document_metadata>")
|
"<document>",
|
||||||
parts.append(f" <document_id>{g['document_id']}</document_id>")
|
"<document_metadata>",
|
||||||
parts.append(f" <document_type>{g['document_type']}</document_type>")
|
f" <document_id>{g['document_id']}</document_id>",
|
||||||
parts.append(f" <title><![CDATA[{g['title']}]]></title>")
|
f" <document_type>{g['document_type']}</document_type>",
|
||||||
parts.append(f" <url><![CDATA[{g['url']}]]></url>")
|
f" <title><![CDATA[{g['title']}]]></title>",
|
||||||
parts.append(f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>")
|
f" <url><![CDATA[{g['url']}]]></url>",
|
||||||
parts.append("</document_metadata>")
|
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
|
||||||
parts.append("")
|
"</document_metadata>",
|
||||||
parts.append("<document_content>")
|
"",
|
||||||
|
"<document_content>",
|
||||||
|
]
|
||||||
|
|
||||||
for ch in g["chunks"]:
|
for ch in g["chunks"]:
|
||||||
ch_content = ch["content"]
|
ch_content = ch["content"]
|
||||||
# For live search connectors, use the document URL as the chunk id
|
if max_chunk_chars and len(ch_content) > max_chunk_chars:
|
||||||
# so the LLM outputs [citation:https://...] which the frontend
|
ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
|
||||||
# renders as a clickable link.
|
|
||||||
ch_id = g["url"] if (is_live_search and g["url"]) else ch["chunk_id"]
|
ch_id = g["url"] if (is_live_search and g["url"]) else ch["chunk_id"]
|
||||||
if ch_id is None:
|
if ch_id is None:
|
||||||
parts.append(f" <chunk><![CDATA[{ch_content}]]></chunk>")
|
doc_lines.append(f" <chunk><![CDATA[{ch_content}]]></chunk>")
|
||||||
else:
|
else:
|
||||||
parts.append(f" <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>")
|
doc_lines.append(
|
||||||
|
f" <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>"
|
||||||
|
)
|
||||||
|
|
||||||
parts.append("</document_content>")
|
doc_lines.extend(["</document_content>", "</document>", ""])
|
||||||
parts.append("</document>")
|
|
||||||
parts.append("")
|
doc_xml = "\n".join(doc_lines)
|
||||||
|
doc_len = len(doc_xml)
|
||||||
|
|
||||||
|
# Always include at least the first document; afterwards enforce budget.
|
||||||
|
if doc_idx > 0 and total_chars + doc_len > max_chars:
|
||||||
|
remaining = total_docs - doc_idx
|
||||||
|
parts.append(
|
||||||
|
f"<!-- Output truncated: {remaining} more document(s) omitted "
|
||||||
|
f"(budget {max_chars} chars). Refine your query or reduce top_k "
|
||||||
|
f"to retrieve different results. -->"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
parts.append(doc_xml)
|
||||||
|
total_chars += doc_len
|
||||||
|
|
||||||
return "\n".join(parts).strip()
|
return "\n".join(parts).strip()
|
||||||
|
|
||||||
|
|
@ -328,6 +388,7 @@ async def search_knowledge_base_async(
|
||||||
start_date: datetime | None = None,
|
start_date: datetime | None = None,
|
||||||
end_date: datetime | None = None,
|
end_date: datetime | None = None,
|
||||||
available_connectors: list[str] | None = None,
|
available_connectors: list[str] | None = None,
|
||||||
|
max_input_tokens: int | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Search the user's knowledge base for relevant documents.
|
Search the user's knowledge base for relevant documents.
|
||||||
|
|
@ -345,6 +406,8 @@ async def search_knowledge_base_async(
|
||||||
end_date: Optional end datetime (UTC) for filtering documents
|
end_date: Optional end datetime (UTC) for filtering documents
|
||||||
available_connectors: Optional list of connectors actually available in the search space.
|
available_connectors: Optional list of connectors actually available in the search space.
|
||||||
If provided, only these connectors will be searched.
|
If provided, only these connectors will be searched.
|
||||||
|
max_input_tokens: Model context window size (tokens). Used to dynamically
|
||||||
|
size the output so it fits within the model's limits.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Formatted string with search results
|
Formatted string with search results
|
||||||
|
|
@ -488,7 +551,8 @@ async def search_knowledge_base_async(
|
||||||
|
|
||||||
deduplicated.append(doc)
|
deduplicated.append(doc)
|
||||||
|
|
||||||
return format_documents_for_context(deduplicated)
|
output_budget = _compute_tool_output_budget(max_input_tokens)
|
||||||
|
return format_documents_for_context(deduplicated, max_chars=output_budget)
|
||||||
|
|
||||||
|
|
||||||
def _build_connector_docstring(available_connectors: list[str] | None) -> str:
|
def _build_connector_docstring(available_connectors: list[str] | None) -> str:
|
||||||
|
|
@ -552,6 +616,7 @@ def create_search_knowledge_base_tool(
|
||||||
connector_service: ConnectorService,
|
connector_service: ConnectorService,
|
||||||
available_connectors: list[str] | None = None,
|
available_connectors: list[str] | None = None,
|
||||||
available_document_types: list[str] | None = None,
|
available_document_types: list[str] | None = None,
|
||||||
|
max_input_tokens: int | None = None,
|
||||||
) -> StructuredTool:
|
) -> StructuredTool:
|
||||||
"""
|
"""
|
||||||
Factory function to create the search_knowledge_base tool with injected dependencies.
|
Factory function to create the search_knowledge_base tool with injected dependencies.
|
||||||
|
|
@ -564,6 +629,8 @@ def create_search_knowledge_base_tool(
|
||||||
Used to dynamically generate the tool docstring.
|
Used to dynamically generate the tool docstring.
|
||||||
available_document_types: Optional list of document types that have data in the search space.
|
available_document_types: Optional list of document types that have data in the search space.
|
||||||
Used to inform the LLM about what data exists.
|
Used to inform the LLM about what data exists.
|
||||||
|
max_input_tokens: Model context window (tokens) from litellm model info.
|
||||||
|
Used to dynamically size tool output.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A configured StructuredTool instance
|
A configured StructuredTool instance
|
||||||
|
|
@ -634,6 +701,7 @@ NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type
|
||||||
start_date=parsed_start,
|
start_date=parsed_start,
|
||||||
end_date=parsed_end,
|
end_date=parsed_end,
|
||||||
available_connectors=_available_connectors,
|
available_connectors=_available_connectors,
|
||||||
|
max_input_tokens=max_input_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create StructuredTool with dynamic description
|
# Create StructuredTool with dynamic description
|
||||||
|
|
|
||||||
|
|
@ -118,6 +118,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
# Optional: dynamically discovered connectors/document types
|
# Optional: dynamically discovered connectors/document types
|
||||||
available_connectors=deps.get("available_connectors"),
|
available_connectors=deps.get("available_connectors"),
|
||||||
available_document_types=deps.get("available_document_types"),
|
available_document_types=deps.get("available_document_types"),
|
||||||
|
max_input_tokens=deps.get("max_input_tokens"),
|
||||||
),
|
),
|
||||||
requires=["search_space_id", "db_session", "connector_service"],
|
requires=["search_space_id", "db_session", "connector_service"],
|
||||||
# Note: available_connectors and available_document_types are optional
|
# Note: available_connectors and available_document_types are optional
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ synchronous ChatLiteLLM-like interface and async methods.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
|
|
@ -20,10 +21,26 @@ from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from litellm import Router
|
from litellm import Router
|
||||||
from litellm.exceptions import ContextWindowExceededError
|
from litellm.exceptions import (
|
||||||
|
BadRequestError as LiteLLMBadRequestError,
|
||||||
|
ContextWindowExceededError,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Case-insensitive phrases searched for in an error's message text to decide
# whether it describes a context-window overflow.
_CONTEXT_OVERFLOW_PATTERNS = re.compile(
    r"(input tokens exceed|context.{0,20}(length|window|limit)|"
    r"maximum context length|token.{0,20}(limit|exceed)|"
    r"too many tokens|reduce the length)",
    flags=re.IGNORECASE,
)


def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool:
    """Return True when the error's message matches a context-overflow phrase.

    Only ``str(exc)`` is inspected; the match is a heuristic search against
    ``_CONTEXT_OVERFLOW_PATTERNS``.
    """
    message = str(exc)
    return _CONTEXT_OVERFLOW_PATTERNS.search(message) is not None
|
||||||
|
|
||||||
|
|
||||||
# Special ID for Auto mode - uses router for load balancing
|
# Special ID for Auto mode - uses router for load balancing
|
||||||
AUTO_MODE_ID = 0
|
AUTO_MODE_ID = 0
|
||||||
|
|
||||||
|
|
@ -236,6 +253,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
|
|
||||||
This wraps the LiteLLM Router to provide the same interface as ChatLiteLLM,
|
This wraps the LiteLLM Router to provide the same interface as ChatLiteLLM,
|
||||||
making it a drop-in replacement for auto-mode routing.
|
making it a drop-in replacement for auto-mode routing.
|
||||||
|
|
||||||
|
Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
|
||||||
|
window across all router deployments so that deepagents
|
||||||
|
SummarizationMiddleware can use fraction-based triggers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Use model_config for Pydantic v2 compatibility
|
# Use model_config for Pydantic v2 compatibility
|
||||||
|
|
@ -267,7 +288,6 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
# Store router and tools as private attributes
|
|
||||||
resolved_router = router or LLMRouterService.get_router()
|
resolved_router = router or LLMRouterService.get_router()
|
||||||
object.__setattr__(self, "_router", resolved_router)
|
object.__setattr__(self, "_router", resolved_router)
|
||||||
object.__setattr__(self, "_bound_tools", bound_tools)
|
object.__setattr__(self, "_bound_tools", bound_tools)
|
||||||
|
|
@ -276,6 +296,12 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"LLM Router not initialized. Call LLMRouterService.initialize() first."
|
"LLM Router not initialized. Call LLMRouterService.initialize() first."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Set profile so deepagents SummarizationMiddleware gets fraction-based triggers
|
||||||
|
computed_profile = self._compute_min_context_profile()
|
||||||
|
if computed_profile is not None:
|
||||||
|
object.__setattr__(self, "profile", computed_profile)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
|
f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
|
||||||
)
|
)
|
||||||
|
|
@ -283,6 +309,39 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
|
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def _compute_min_context_profile(self) -> dict | None:
|
||||||
|
"""Derive a profile dict with max_input_tokens from router deployments.
|
||||||
|
|
||||||
|
Uses litellm.get_model_info to look up each deployment's context window
|
||||||
|
and picks the *minimum* so that summarization triggers before ANY model
|
||||||
|
in the pool overflows.
|
||||||
|
"""
|
||||||
|
from litellm import get_model_info
|
||||||
|
|
||||||
|
if not self._router:
|
||||||
|
return None
|
||||||
|
|
||||||
|
min_ctx: int | None = None
|
||||||
|
for deployment in self._router.model_list:
|
||||||
|
params = deployment.get("litellm_params", {})
|
||||||
|
base_model = params.get("base_model") or params.get("model", "")
|
||||||
|
try:
|
||||||
|
info = get_model_info(base_model)
|
||||||
|
ctx = info.get("max_input_tokens")
|
||||||
|
if (
|
||||||
|
isinstance(ctx, int)
|
||||||
|
and ctx > 0
|
||||||
|
and (min_ctx is None or ctx < min_ctx)
|
||||||
|
):
|
||||||
|
min_ctx = ctx
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if min_ctx is not None:
|
||||||
|
logger.info(f"ChatLiteLLMRouter profile: max_input_tokens={min_ctx}")
|
||||||
|
return {"max_input_tokens": min_ctx}
|
||||||
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "litellm-router"
|
return "litellm-router"
|
||||||
|
|
@ -370,6 +429,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
)
|
)
|
||||||
except ContextWindowExceededError as e:
|
except ContextWindowExceededError as e:
|
||||||
raise ContextOverflowError(str(e)) from e
|
raise ContextOverflowError(str(e)) from e
|
||||||
|
except LiteLLMBadRequestError as e:
|
||||||
|
if _is_context_overflow_error(e):
|
||||||
|
raise ContextOverflowError(str(e)) from e
|
||||||
|
raise
|
||||||
|
|
||||||
# Convert response to ChatResult with potential tool calls
|
# Convert response to ChatResult with potential tool calls
|
||||||
message = self._convert_response_to_message(response.choices[0].message)
|
message = self._convert_response_to_message(response.choices[0].message)
|
||||||
|
|
@ -409,6 +472,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
)
|
)
|
||||||
except ContextWindowExceededError as e:
|
except ContextWindowExceededError as e:
|
||||||
raise ContextOverflowError(str(e)) from e
|
raise ContextOverflowError(str(e)) from e
|
||||||
|
except LiteLLMBadRequestError as e:
|
||||||
|
if _is_context_overflow_error(e):
|
||||||
|
raise ContextOverflowError(str(e)) from e
|
||||||
|
raise
|
||||||
|
|
||||||
# Convert response to ChatResult with potential tool calls
|
# Convert response to ChatResult with potential tool calls
|
||||||
message = self._convert_response_to_message(response.choices[0].message)
|
message = self._convert_response_to_message(response.choices[0].message)
|
||||||
|
|
@ -448,6 +515,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
)
|
)
|
||||||
except ContextWindowExceededError as e:
|
except ContextWindowExceededError as e:
|
||||||
raise ContextOverflowError(str(e)) from e
|
raise ContextOverflowError(str(e)) from e
|
||||||
|
except LiteLLMBadRequestError as e:
|
||||||
|
if _is_context_overflow_error(e):
|
||||||
|
raise ContextOverflowError(str(e)) from e
|
||||||
|
raise
|
||||||
|
|
||||||
# Yield chunks
|
# Yield chunks
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
|
|
@ -489,6 +560,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
)
|
)
|
||||||
except ContextWindowExceededError as e:
|
except ContextWindowExceededError as e:
|
||||||
raise ContextOverflowError(str(e)) from e
|
raise ContextOverflowError(str(e)) from e
|
||||||
|
except LiteLLMBadRequestError as e:
|
||||||
|
if _is_context_overflow_error(e):
|
||||||
|
raise ContextOverflowError(str(e)) from e
|
||||||
|
raise
|
||||||
|
|
||||||
# Yield chunks asynchronously
|
# Yield chunks asynchronously
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue