mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-26 01:06:23 +02:00
feat: Enhance LLM configuration and routing with model profile attachment
- Added `_attach_model_profile` function to attach model context metadata to `ChatLiteLLM`. - Updated `create_chat_litellm_from_config` and `create_chat_litellm_from_agent_config` to utilize the new profile attachment. - Improved context profile caching in `llm_router_service.py` to include both minimum and maximum input tokens, along with token model names for better context management. - Introduced new methods for token counting and context trimming based on model profiles.
This commit is contained in:
parent
5571e8aa53
commit
eec4db4a3b
2 changed files with 310 additions and 7 deletions
|
|
@ -11,6 +11,7 @@ The router is initialized from global LLM configs and provides both
|
|||
synchronous ChatLiteLLM-like interface and async methods.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
|
|
@ -331,11 +332,16 @@ _router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {}
|
|||
|
||||
|
||||
def _get_cached_context_profile(router: Router) -> dict | None:
|
||||
"""Compute and cache the min context profile across all router deployments.
|
||||
"""Compute and cache context profile across all router deployments.
|
||||
|
||||
Called once on first ChatLiteLLMRouter creation; subsequent calls return
|
||||
the cached value. This avoids calling litellm.get_model_info() for every
|
||||
deployment on every request.
|
||||
|
||||
Caches both ``max_input_tokens`` (minimum across deployments, used by
|
||||
SummarizationMiddleware) and ``max_input_tokens_upper`` (maximum across
|
||||
deployments, used for context-trimming so we can target the largest
|
||||
available model — the router's fallback logic handles smaller ones).
|
||||
"""
|
||||
global _cached_context_profile, _cached_context_profile_computed
|
||||
if _cached_context_profile_computed:
|
||||
|
|
@ -344,20 +350,50 @@ def _get_cached_context_profile(router: Router) -> dict | None:
|
|||
from litellm import get_model_info
|
||||
|
||||
min_ctx: int | None = None
|
||||
max_ctx: int | None = None
|
||||
token_count_model: str | None = None
|
||||
ctx_pairs: list[tuple[int, str]] = []
|
||||
for deployment in router.model_list:
|
||||
params = deployment.get("litellm_params", {})
|
||||
base_model = params.get("base_model") or params.get("model", "")
|
||||
try:
|
||||
info = get_model_info(base_model)
|
||||
ctx = info.get("max_input_tokens")
|
||||
if isinstance(ctx, int) and ctx > 0 and (min_ctx is None or ctx < min_ctx):
|
||||
min_ctx = ctx
|
||||
if isinstance(ctx, int) and ctx > 0:
|
||||
if min_ctx is None or ctx < min_ctx:
|
||||
min_ctx = ctx
|
||||
if max_ctx is None or ctx > max_ctx:
|
||||
max_ctx = ctx
|
||||
if token_count_model is None:
|
||||
token_count_model = base_model
|
||||
ctx_pairs.append((ctx, base_model))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if min_ctx is not None:
|
||||
logger.info("ChatLiteLLMRouter profile: max_input_tokens=%d", min_ctx)
|
||||
_cached_context_profile = {"max_input_tokens": min_ctx}
|
||||
token_count_models: list[str] = []
|
||||
if token_count_model:
|
||||
token_count_models.append(token_count_model)
|
||||
if ctx_pairs:
|
||||
ctx_pairs.sort(key=lambda x: x[0])
|
||||
smallest_model = ctx_pairs[0][1]
|
||||
largest_model = ctx_pairs[-1][1]
|
||||
if smallest_model not in token_count_models:
|
||||
token_count_models.append(smallest_model)
|
||||
if largest_model not in token_count_models:
|
||||
token_count_models.append(largest_model)
|
||||
logger.info(
|
||||
"ChatLiteLLMRouter profile: max_input_tokens=%d, upper=%s, token_models=%s",
|
||||
min_ctx,
|
||||
max_ctx,
|
||||
token_count_models,
|
||||
)
|
||||
_cached_context_profile = {
|
||||
"max_input_tokens": min_ctx,
|
||||
"max_input_tokens_upper": max_ctx,
|
||||
"token_count_model": token_count_model,
|
||||
"token_count_models": token_count_models,
|
||||
}
|
||||
else:
|
||||
_cached_context_profile = None
|
||||
|
||||
|
|
@ -425,6 +461,248 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
|
||||
raise
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Context-aware trimming helpers
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
def _get_token_count_model_names(self) -> list[str]:
|
||||
"""Return concrete model names usable by ``litellm.token_counter``.
|
||||
|
||||
The router uses ``"auto"`` as the model group name but tokenizers need
|
||||
concrete model identifiers. We keep multiple candidates and take the
|
||||
most conservative count across them.
|
||||
"""
|
||||
names: list[str] = []
|
||||
profile = getattr(self, "profile", None)
|
||||
if isinstance(profile, dict):
|
||||
tcms = profile.get("token_count_models")
|
||||
if isinstance(tcms, list):
|
||||
for name in tcms:
|
||||
if isinstance(name, str) and name and name not in names:
|
||||
names.append(name)
|
||||
tcm = profile.get("token_count_model")
|
||||
if isinstance(tcm, str) and tcm and tcm not in names:
|
||||
names.append(tcm)
|
||||
|
||||
if self._router and self._router.model_list:
|
||||
for dep in self._router.model_list:
|
||||
params = dep.get("litellm_params", {})
|
||||
base = params.get("base_model") or params.get("model", "")
|
||||
if base and base not in names:
|
||||
names.append(base)
|
||||
if len(names) >= 3:
|
||||
break
|
||||
if not names:
|
||||
return ["gpt-4o"]
|
||||
return names
|
||||
|
||||
def _count_tokens(self, messages: list[dict]) -> int | None:
    """Return the most conservative (largest) token count across candidates.

    Counting is attempted with every candidate tokenizer model; failures for
    individual models are silently skipped.  ``None`` signals that no
    tokenizer produced a count at all, letting callers skip trimming.
    """
    from litellm import token_counter

    largest: int | None = None
    for candidate in self._get_token_count_model_names():
        try:
            count = token_counter(messages=messages, model=candidate)
        except Exception:
            # Unknown model / tokenizer failure: try the next candidate.
            continue
        if largest is None or count > largest:
            largest = count
    return largest
|
||||
|
||||
def _get_max_input_tokens(self) -> int:
|
||||
"""Return the max input tokens to use for context trimming.
|
||||
|
||||
Prefers the *largest* context window across all deployments so we
|
||||
maximise usable context (the router's ``context_window_fallbacks``
|
||||
handle routing to the right model). Falls back to the minimum
|
||||
profile value or a conservative default.
|
||||
"""
|
||||
profile = getattr(self, "profile", None)
|
||||
if isinstance(profile, dict):
|
||||
upper = profile.get("max_input_tokens_upper")
|
||||
if isinstance(upper, int) and upper > 0:
|
||||
return upper
|
||||
lower = profile.get("max_input_tokens")
|
||||
if isinstance(lower, int) and lower > 0:
|
||||
return lower
|
||||
return 128_000
|
||||
|
||||
def _trim_messages_to_fit_context(
    self,
    messages: list[dict],
    output_reserve_fraction: float = 0.10,
) -> list[dict]:
    """Trim message content via binary search to fit the model's context window.

    When the total token count exceeds the model's ``max_input_tokens``,
    this method identifies the largest messages (typically tool responses
    containing search results) and uses binary search on each to find the
    maximum content length that keeps the total within budget.

    Cutting prefers ``</document>`` XML boundaries so complete documents
    are preserved when possible.

    This is model-aware: it reads the context limit from
    ``litellm.get_model_info`` (cached in ``self.profile``) and counts
    tokens with ``litellm.token_counter``.

    Args:
        messages: OpenAI-format message dicts (``role``/``content`` keys).
        output_reserve_fraction: Fraction of the context window held back
            for the model's output (capped at 16 384 tokens).

    Returns:
        The original list unchanged when it already fits (or token counting
        is unavailable); otherwise a deep copy with oversized contents
        shortened in place.
    """
    max_input = self._get_max_input_tokens()
    # Reserve headroom for the completion; never more than 16k tokens.
    output_reserve = min(int(max_input * output_reserve_fraction), 16_384)
    budget = max_input - output_reserve

    total_tokens = self._count_tokens(messages)
    if total_tokens is None:
        # No tokenizer available — cannot trim safely, pass through.
        return messages

    if total_tokens <= budget:
        return messages

    perf = get_perf_logger()
    perf.warning(
        "[llm_router] context overflow detected: %d tokens > %d budget "
        "(max_input=%d, reserve=%d). Trimming messages.",
        total_tokens,
        budget,
        max_input,
        output_reserve,
    )

    # Deep copy so callers' message objects are never mutated.
    trimmed = copy.deepcopy(messages)

    # Per-message token counts for trimmable candidates.
    # Skip system messages to preserve agent instructions.
    msg_token_map: dict[int, int] = {}
    candidate_priority: dict[int, int] = {}
    for i, msg in enumerate(trimmed):
        if msg.get("role") == "system":
            continue
        role = msg.get("role")
        content = msg.get("content", "")
        # Short or non-string contents aren't worth trimming.
        if not isinstance(content, str) or len(content) < 500:
            continue
        # Prefer trimming tool/assistant outputs first.
        # User messages are only trimmed if they clearly contain injected
        # document context blobs.
        is_doc_blob = "<document>" in content or "<mentioned_documents>" in content
        if role in ("tool", "assistant"):
            candidate_priority[i] = 0
        elif role == "user" and is_doc_blob:
            candidate_priority[i] = 1
        else:
            continue
        token_count = self._count_tokens([msg])
        if token_count is not None:
            msg_token_map[i] = token_count

    if not msg_token_map:
        perf.warning("[llm_router] no trimmable messages found, returning as-is")
        return trimmed

    # Trim largest messages first (within each priority tier).
    candidates = sorted(
        msg_token_map.items(),
        key=lambda x: (candidate_priority.get(x[0], 9), -x[1]),
    )
    running_total = total_tokens

    trim_suffix = (
        "\n\n<!-- Content trimmed to fit model context window. "
        "Some documents were omitted. Refine your query or "
        "reduce top_k for different results. -->"
    )

    for idx, orig_msg_tokens in candidates:
        if running_total <= budget:
            break

        content = trimmed[idx]["content"]
        orig_len = len(content)

        # Binary search: find maximum content[:mid] that keeps total ≤ budget.
        # 200 chars is the floor so a message is never emptied entirely here.
        lo, hi = 200, orig_len - 1
        best = 200

        while lo <= hi:
            mid = (lo + hi) // 2
            trimmed[idx]["content"] = content[:mid] + trim_suffix
            new_msg_tokens = self._count_tokens([trimmed[idx]])
            if new_msg_tokens is None:
                hi = mid - 1
                continue

            projected_total = running_total - orig_msg_tokens + new_msg_tokens
            if projected_total <= budget:
                best = mid
                lo = mid + 1
            else:
                hi = mid - 1

        # Prefer cutting at a </document> boundary for cleaner output
        # (only when a meaningful amount of content survives the cut).
        last_doc_end = content[:best].rfind("</document>")
        if last_doc_end > min(200, best // 4):
            best = last_doc_end + len("</document>")

        trimmed[idx]["content"] = content[:best] + trim_suffix

        try:
            new_msg_tokens = self._count_tokens([trimmed[idx]])
            if new_msg_tokens is None:
                continue
            running_total = running_total - orig_msg_tokens + new_msg_tokens
        except Exception:
            # Best-effort accounting; a stale running_total only makes the
            # hard-guarantee pass below slightly more aggressive.
            pass

    # Hard guarantee: if still over budget, replace remaining large
    # non-system messages with compact placeholders until we fit.
    if running_total > budget:
        fallback_indices: list[int] = []
        for i, msg in enumerate(trimmed):
            if msg.get("role") == "system":
                continue
            content = msg.get("content")
            if isinstance(content, str) and len(content) > 0:
                fallback_indices.append(i)

        for idx in fallback_indices:
            if running_total <= budget:
                break
            role = trimmed[idx].get("role", "message")
            placeholder = (
                f"[content omitted to fit model context window; role={role}]"
            )
            old_tokens = self._count_tokens([trimmed[idx]]) or 0
            trimmed[idx]["content"] = placeholder
            new_tokens = self._count_tokens([trimmed[idx]]) or 0
            running_total = running_total - old_tokens + new_tokens

        if running_total > budget:
            perf.error(
                "[llm_router] unable to fit context even after aggressive trimming: "
                "tokens=%d budget=%d",
                running_total,
                budget,
            )
            # Final safety net: clear oldest non-system contents.
            for idx in fallback_indices:
                if running_total <= budget:
                    break
                old_tokens = self._count_tokens([trimmed[idx]]) or 0
                trimmed[idx]["content"] = ""
                new_tokens = self._count_tokens([trimmed[idx]]) or 0
                running_total = running_total - old_tokens + new_tokens

    perf.info(
        "[llm_router] messages trimmed: %d → %d tokens (budget=%d, max_input=%d)",
        total_tokens,
        running_total,
        budget,
        max_input,
    )

    return trimmed
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
@property
def _llm_type(self) -> str:
    """Identifier string for this chat model (LangChain BaseChatModel hook)."""
    return "litellm-router"
|
||||
|
|
@ -499,6 +777,7 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
|
||||
# Convert LangChain messages to OpenAI format
|
||||
formatted_messages = self._convert_messages(messages)
|
||||
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||
|
||||
# Add tools if bound
|
||||
call_kwargs = {**kwargs}
|
||||
|
|
@ -564,6 +843,7 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
|
||||
# Convert LangChain messages to OpenAI format
|
||||
formatted_messages = self._convert_messages(messages)
|
||||
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||
|
||||
# Add tools if bound
|
||||
call_kwargs = {**kwargs}
|
||||
|
|
@ -624,6 +904,7 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
raise ValueError("Router not initialized")
|
||||
|
||||
formatted_messages = self._convert_messages(messages)
|
||||
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||
|
||||
# Add tools if bound
|
||||
call_kwargs = {**kwargs}
|
||||
|
|
@ -673,6 +954,7 @@ class ChatLiteLLMRouter(BaseChatModel):
|
|||
msg_count = len(messages)
|
||||
|
||||
formatted_messages = self._convert_messages(messages)
|
||||
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||
|
||||
# Add tools if bound
|
||||
call_kwargs = {**kwargs}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue