Mirror of https://github.com/MODSetter/SurfSense.git, synced 2026-04-27 09:46:25 +02:00
feat: enhance document formatting and context management for LLM tools
- Introduced dynamic character budget calculation for document formatting based on the model's context window (an illustrative sketch of the idea follows the change summary below).
- Updated `format_documents_for_context` to respect character limits and improve output quality.
- Added a `max_input_tokens` parameter to various functions to facilitate context-aware processing.
- Enhanced error handling for context overflow in the LLM router service.
Parent: a4dc84d1ab
Commit: 1e4b8d3e89
4 changed files with 178 additions and 24 deletions
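To make the budget idea above concrete, here is a minimal sketch assuming a simple characters-per-token heuristic. The constants, the fallback value, and the body of `format_documents_for_context` shown here are illustrative guesses, not the repository's actual implementation:

# Illustrative sketch only: constants and fallback are assumptions,
# not SurfSense's actual code.
CHARS_PER_TOKEN = 4        # rough heuristic for English text
CONTEXT_FRACTION = 0.6     # leave headroom for the prompt and the answer
FALLBACK_BUDGET = 32_000   # used when the context window is unknown

def calculate_char_budget(max_input_tokens: int | None) -> int:
    """Derive a character budget from the model's context window."""
    if not max_input_tokens or max_input_tokens <= 0:
        return FALLBACK_BUDGET
    return int(max_input_tokens * CONTEXT_FRACTION * CHARS_PER_TOKEN)

def format_documents_for_context(documents: list[str], max_input_tokens: int | None = None) -> str:
    """Join documents, truncating once the character budget is spent."""
    budget = calculate_char_budget(max_input_tokens)
    parts: list[str] = []
    used = 0
    for doc in documents:
        remaining = budget - used
        if remaining <= 0:
            break
        parts.append(doc[:remaining])
        used += min(len(doc), remaining)
    return "\n\n".join(parts)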
@@ -12,6 +12,7 @@ synchronous ChatLiteLLM-like interface and async methods.
 """
 
 import logging
+import re
 from typing import Any
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -20,10 +21,26 @@ from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from litellm import Router
-from litellm.exceptions import ContextWindowExceededError
+from litellm.exceptions import (
+    BadRequestError as LiteLLMBadRequestError,
+    ContextWindowExceededError,
+)
 
 logger = logging.getLogger(__name__)
 
+_CONTEXT_OVERFLOW_PATTERNS = re.compile(
+    r"(input tokens exceed|context.{0,20}(length|window|limit)|"
+    r"maximum context length|token.{0,20}(limit|exceed)|"
+    r"too many tokens|reduce the length)",
+    re.IGNORECASE,
+)
+
+
+def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool:
+    """Check if a BadRequestError is actually a context window overflow."""
+    return bool(_CONTEXT_OVERFLOW_PATTERNS.search(str(exc)))
+
+
 # Special ID for Auto mode - uses router for load balancing
 AUTO_MODE_ID = 0
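Because the detection above is a regex heuristic, a quick sanity check against the kind of messages providers tend to return can be useful. The sample strings below are made up for illustration, not actual provider output:

# Sample error strings are illustrative, not actual provider output.
samples = [
    "This model's maximum context length is 128000 tokens",  # matches
    "Input tokens exceed the configured limit",              # matches
    "Invalid API key",                                       # does not match
]
for text in samples:
    print(bool(_CONTEXT_OVERFLOW_PATTERNS.search(text)), "-", text)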
@@ -236,6 +253,10 @@ class ChatLiteLLMRouter(BaseChatModel):
 
     This wraps the LiteLLM Router to provide the same interface as ChatLiteLLM,
     making it a drop-in replacement for auto-mode routing.
+
+    Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
+    window across all router deployments so that deepagents
+    SummarizationMiddleware can use fraction-based triggers.
     """
 
     # Use model_config for Pydantic v2 compatibility
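To make the docstring concrete, here is a sketch of how a fraction-based trigger might consume such a profile. The middleware wiring is assumed for illustration and is not taken from deepagents' actual source:

# Assumed consumer of profile["max_input_tokens"]; not deepagents' real code.
def should_summarize(prompt_tokens: int, profile: dict, fraction: float = 0.8) -> bool:
    """Trigger summarization once the prompt uses `fraction` of the
    smallest context window advertised by the model profile."""
    max_input = profile.get("max_input_tokens")
    if not max_input:
        return False  # no window information; fraction triggers cannot fire
    return prompt_tokens >= int(max_input * fraction)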
@@ -267,7 +288,6 @@
         """
         try:
             super().__init__(**kwargs)
-
             # Store router and tools as private attributes
             resolved_router = router or LLMRouterService.get_router()
             object.__setattr__(self, "_router", resolved_router)
             object.__setattr__(self, "_bound_tools", bound_tools)
@@ -276,6 +296,12 @@
                 raise ValueError(
                     "LLM Router not initialized. Call LLMRouterService.initialize() first."
                 )
 
+            # Set profile so deepagents SummarizationMiddleware gets fraction-based triggers
+            computed_profile = self._compute_min_context_profile()
+            if computed_profile is not None:
+                object.__setattr__(self, "profile", computed_profile)
+
             logger.info(
                 f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
             )
@@ -283,6 +309,39 @@
             logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
             raise
 
+    def _compute_min_context_profile(self) -> dict | None:
+        """Derive a profile dict with max_input_tokens from router deployments.
+
+        Uses litellm.get_model_info to look up each deployment's context window
+        and picks the *minimum* so that summarization triggers before ANY model
+        in the pool overflows.
+        """
+        from litellm import get_model_info
+
+        if not self._router:
+            return None
+
+        min_ctx: int | None = None
+        for deployment in self._router.model_list:
+            params = deployment.get("litellm_params", {})
+            base_model = params.get("base_model") or params.get("model", "")
+            try:
+                info = get_model_info(base_model)
+                ctx = info.get("max_input_tokens")
+                if (
+                    isinstance(ctx, int)
+                    and ctx > 0
+                    and (min_ctx is None or ctx < min_ctx)
+                ):
+                    min_ctx = ctx
+            except Exception:
+                continue
+
+        if min_ctx is not None:
+            logger.info(f"ChatLiteLLMRouter profile: max_input_tokens={min_ctx}")
+            return {"max_input_tokens": min_ctx}
+        return None
+
     @property
     def _llm_type(self) -> str:
         return "litellm-router"
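For reference, `litellm.get_model_info` returns a metadata dict that includes `max_input_tokens`. A small usage example; the model name is arbitrary and the printed value depends on litellm's model-cost map:

from litellm import get_model_info

info = get_model_info("gpt-4o-mini")  # any model litellm knows about
print(info.get("max_input_tokens"))   # e.g. 128000 for many current models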
@@ -370,6 +429,10 @@
             )
         except ContextWindowExceededError as e:
             raise ContextOverflowError(str(e)) from e
+        except LiteLLMBadRequestError as e:
+            if _is_context_overflow_error(e):
+                raise ContextOverflowError(str(e)) from e
+            raise
 
         # Convert response to ChatResult with potential tool calls
         message = self._convert_response_to_message(response.choices[0].message)
@@ -409,6 +472,10 @@
             )
         except ContextWindowExceededError as e:
             raise ContextOverflowError(str(e)) from e
+        except LiteLLMBadRequestError as e:
+            if _is_context_overflow_error(e):
+                raise ContextOverflowError(str(e)) from e
+            raise
 
         # Convert response to ChatResult with potential tool calls
         message = self._convert_response_to_message(response.choices[0].message)
@@ -448,6 +515,10 @@
             )
         except ContextWindowExceededError as e:
             raise ContextOverflowError(str(e)) from e
+        except LiteLLMBadRequestError as e:
+            if _is_context_overflow_error(e):
+                raise ContextOverflowError(str(e)) from e
+            raise
 
         # Yield chunks
         for chunk in response:
@@ -489,6 +560,10 @@
             )
         except ContextWindowExceededError as e:
             raise ContextOverflowError(str(e)) from e
+        except LiteLLMBadRequestError as e:
+            if _is_context_overflow_error(e):
+                raise ContextOverflowError(str(e)) from e
+            raise
 
         # Yield chunks asynchronously
         async for chunk in response:
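All four code paths (sync/async generate and stream) normalize provider-specific overflow failures into a single `ContextOverflowError`, so callers can handle overflow in one place. A caller-side sketch; the trimming helper is hypothetical, and `ContextOverflowError`'s import path is internal to SurfSense and not shown in this diff:

# Hypothetical caller; trim_oldest_messages is an assumed helper,
# not part of this diff.
try:
    result = llm.invoke(messages)
except ContextOverflowError:
    messages = trim_oldest_messages(messages)  # shrink the context...
    result = llm.invoke(messages)              # ...and retry once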