mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-28 02:23:53 +02:00
feat: Enhance LLM configuration and routing with model profile attachment
- Added `_attach_model_profile` function to attach model context metadata to `ChatLiteLLM`. - Updated `create_chat_litellm_from_config` and `create_chat_litellm_from_agent_config` to utilize the new profile attachment. - Improved context profile caching in `llm_router_service.py` to include both minimum and maximum input tokens, along with token model names for better context management. - Introduced new methods for token counting and context trimming based on model profiles.
This commit is contained in:
parent
5571e8aa53
commit
eec4db4a3b
2 changed files with 310 additions and 7 deletions
|
|
@ -15,6 +15,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from langchain_litellm import ChatLiteLLM
|
from langchain_litellm import ChatLiteLLM
|
||||||
|
from litellm import get_model_info
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
@ -62,6 +63,22 @@ PROVIDER_MAP = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
    """Best-effort: attach a ``profile`` dict of context metadata to *llm*.

    The profile carries the model's context-window size under both
    ``max_input_tokens`` and ``max_input_tokens_upper`` (identical for a
    single model) plus the model name(s) to use with token counting.
    Any failure — unknown model, missing metadata — leaves *llm* untouched.
    """
    try:
        info = get_model_info(model_string)
        window = info.get("max_input_tokens")
        if not (isinstance(window, int) and window > 0):
            return
        llm.profile = {
            "max_input_tokens": window,
            "max_input_tokens_upper": window,
            "token_count_model": model_string,
            "token_count_models": [model_string],
        }
    except Exception:
        # get_model_info raises for models it does not know; the profile
        # is purely advisory, so swallow and continue without one.
        return
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AgentConfig:
|
class AgentConfig:
|
||||||
"""
|
"""
|
||||||
|
|
@ -366,7 +383,9 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
||||||
if llm_config.get("litellm_params"):
|
if llm_config.get("litellm_params"):
|
||||||
litellm_kwargs.update(llm_config["litellm_params"])
|
litellm_kwargs.update(llm_config["litellm_params"])
|
||||||
|
|
||||||
return ChatLiteLLM(**litellm_kwargs)
|
llm = ChatLiteLLM(**litellm_kwargs)
|
||||||
|
_attach_model_profile(llm, model_string)
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
def create_chat_litellm_from_agent_config(
|
def create_chat_litellm_from_agent_config(
|
||||||
|
|
@ -419,4 +438,6 @@ def create_chat_litellm_from_agent_config(
|
||||||
if agent_config.litellm_params:
|
if agent_config.litellm_params:
|
||||||
litellm_kwargs.update(agent_config.litellm_params)
|
litellm_kwargs.update(agent_config.litellm_params)
|
||||||
|
|
||||||
return ChatLiteLLM(**litellm_kwargs)
|
llm = ChatLiteLLM(**litellm_kwargs)
|
||||||
|
_attach_model_profile(llm, model_string)
|
||||||
|
return llm
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ The router is initialized from global LLM configs and provides both
|
||||||
synchronous ChatLiteLLM-like interface and async methods.
|
synchronous ChatLiteLLM-like interface and async methods.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
|
@ -331,11 +332,16 @@ _router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {}
|
||||||
|
|
||||||
|
|
||||||
def _get_cached_context_profile(router: Router) -> dict | None:
|
def _get_cached_context_profile(router: Router) -> dict | None:
|
||||||
"""Compute and cache the min context profile across all router deployments.
|
"""Compute and cache context profile across all router deployments.
|
||||||
|
|
||||||
Called once on first ChatLiteLLMRouter creation; subsequent calls return
|
Called once on first ChatLiteLLMRouter creation; subsequent calls return
|
||||||
the cached value. This avoids calling litellm.get_model_info() for every
|
the cached value. This avoids calling litellm.get_model_info() for every
|
||||||
deployment on every request.
|
deployment on every request.
|
||||||
|
|
||||||
|
Caches both ``max_input_tokens`` (minimum across deployments, used by
|
||||||
|
SummarizationMiddleware) and ``max_input_tokens_upper`` (maximum across
|
||||||
|
deployments, used for context-trimming so we can target the largest
|
||||||
|
available model — the router's fallback logic handles smaller ones).
|
||||||
"""
|
"""
|
||||||
global _cached_context_profile, _cached_context_profile_computed
|
global _cached_context_profile, _cached_context_profile_computed
|
||||||
if _cached_context_profile_computed:
|
if _cached_context_profile_computed:
|
||||||
|
|
@ -344,20 +350,50 @@ def _get_cached_context_profile(router: Router) -> dict | None:
|
||||||
from litellm import get_model_info
|
from litellm import get_model_info
|
||||||
|
|
||||||
min_ctx: int | None = None
|
min_ctx: int | None = None
|
||||||
|
max_ctx: int | None = None
|
||||||
|
token_count_model: str | None = None
|
||||||
|
ctx_pairs: list[tuple[int, str]] = []
|
||||||
for deployment in router.model_list:
|
for deployment in router.model_list:
|
||||||
params = deployment.get("litellm_params", {})
|
params = deployment.get("litellm_params", {})
|
||||||
base_model = params.get("base_model") or params.get("model", "")
|
base_model = params.get("base_model") or params.get("model", "")
|
||||||
try:
|
try:
|
||||||
info = get_model_info(base_model)
|
info = get_model_info(base_model)
|
||||||
ctx = info.get("max_input_tokens")
|
ctx = info.get("max_input_tokens")
|
||||||
if isinstance(ctx, int) and ctx > 0 and (min_ctx is None or ctx < min_ctx):
|
if isinstance(ctx, int) and ctx > 0:
|
||||||
min_ctx = ctx
|
if min_ctx is None or ctx < min_ctx:
|
||||||
|
min_ctx = ctx
|
||||||
|
if max_ctx is None or ctx > max_ctx:
|
||||||
|
max_ctx = ctx
|
||||||
|
if token_count_model is None:
|
||||||
|
token_count_model = base_model
|
||||||
|
ctx_pairs.append((ctx, base_model))
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if min_ctx is not None:
|
if min_ctx is not None:
|
||||||
logger.info("ChatLiteLLMRouter profile: max_input_tokens=%d", min_ctx)
|
token_count_models: list[str] = []
|
||||||
_cached_context_profile = {"max_input_tokens": min_ctx}
|
if token_count_model:
|
||||||
|
token_count_models.append(token_count_model)
|
||||||
|
if ctx_pairs:
|
||||||
|
ctx_pairs.sort(key=lambda x: x[0])
|
||||||
|
smallest_model = ctx_pairs[0][1]
|
||||||
|
largest_model = ctx_pairs[-1][1]
|
||||||
|
if smallest_model not in token_count_models:
|
||||||
|
token_count_models.append(smallest_model)
|
||||||
|
if largest_model not in token_count_models:
|
||||||
|
token_count_models.append(largest_model)
|
||||||
|
logger.info(
|
||||||
|
"ChatLiteLLMRouter profile: max_input_tokens=%d, upper=%s, token_models=%s",
|
||||||
|
min_ctx,
|
||||||
|
max_ctx,
|
||||||
|
token_count_models,
|
||||||
|
)
|
||||||
|
_cached_context_profile = {
|
||||||
|
"max_input_tokens": min_ctx,
|
||||||
|
"max_input_tokens_upper": max_ctx,
|
||||||
|
"token_count_model": token_count_model,
|
||||||
|
"token_count_models": token_count_models,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
_cached_context_profile = None
|
_cached_context_profile = None
|
||||||
|
|
||||||
|
|
@ -425,6 +461,248 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
|
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
# Context-aware trimming helpers
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
|
||||||
|
def _get_token_count_model_names(self) -> list[str]:
|
||||||
|
"""Return concrete model names usable by ``litellm.token_counter``.
|
||||||
|
|
||||||
|
The router uses ``"auto"`` as the model group name but tokenizers need
|
||||||
|
concrete model identifiers. We keep multiple candidates and take the
|
||||||
|
most conservative count across them.
|
||||||
|
"""
|
||||||
|
names: list[str] = []
|
||||||
|
profile = getattr(self, "profile", None)
|
||||||
|
if isinstance(profile, dict):
|
||||||
|
tcms = profile.get("token_count_models")
|
||||||
|
if isinstance(tcms, list):
|
||||||
|
for name in tcms:
|
||||||
|
if isinstance(name, str) and name and name not in names:
|
||||||
|
names.append(name)
|
||||||
|
tcm = profile.get("token_count_model")
|
||||||
|
if isinstance(tcm, str) and tcm and tcm not in names:
|
||||||
|
names.append(tcm)
|
||||||
|
|
||||||
|
if self._router and self._router.model_list:
|
||||||
|
for dep in self._router.model_list:
|
||||||
|
params = dep.get("litellm_params", {})
|
||||||
|
base = params.get("base_model") or params.get("model", "")
|
||||||
|
if base and base not in names:
|
||||||
|
names.append(base)
|
||||||
|
if len(names) >= 3:
|
||||||
|
break
|
||||||
|
if not names:
|
||||||
|
return ["gpt-4o"]
|
||||||
|
return names
|
||||||
|
|
||||||
|
def _count_tokens(self, messages: list[dict]) -> int | None:
    """Return the most conservative (largest) token count for *messages*.

    Each candidate deployment model is tried in turn; models whose
    tokenizer fails are skipped. ``None`` means no candidate succeeded.
    """
    from litellm import token_counter

    highest: int | None = None
    for candidate in self._get_token_count_model_names():
        try:
            count = token_counter(messages=messages, model=candidate)
        except Exception:
            # Unknown model / tokenizer failure: try the next candidate.
            continue
        if highest is None or count > highest:
            highest = count
    return highest
|
||||||
|
|
||||||
|
def _get_max_input_tokens(self) -> int:
|
||||||
|
"""Return the max input tokens to use for context trimming.
|
||||||
|
|
||||||
|
Prefers the *largest* context window across all deployments so we
|
||||||
|
maximise usable context (the router's ``context_window_fallbacks``
|
||||||
|
handle routing to the right model). Falls back to the minimum
|
||||||
|
profile value or a conservative default.
|
||||||
|
"""
|
||||||
|
profile = getattr(self, "profile", None)
|
||||||
|
if isinstance(profile, dict):
|
||||||
|
upper = profile.get("max_input_tokens_upper")
|
||||||
|
if isinstance(upper, int) and upper > 0:
|
||||||
|
return upper
|
||||||
|
lower = profile.get("max_input_tokens")
|
||||||
|
if isinstance(lower, int) and lower > 0:
|
||||||
|
return lower
|
||||||
|
return 128_000
|
||||||
|
|
||||||
|
def _trim_messages_to_fit_context(
    self,
    messages: list[dict],
    output_reserve_fraction: float = 0.10,
) -> list[dict]:
    """Trim message content via binary search to fit the model's context window.

    When the total token count exceeds the model's ``max_input_tokens``,
    this method identifies the largest messages (typically tool responses
    containing search results) and uses binary search on each to find the
    maximum content length that keeps the total within budget.

    Cutting prefers ``</document>`` XML boundaries so complete documents
    are preserved when possible.

    This is model-aware: it reads the context limit from
    ``litellm.get_model_info`` (cached in ``self.profile``) and counts
    tokens with ``litellm.token_counter``.

    NOTE(review): messages are assumed to be OpenAI-format dicts with
    "role"/"content" keys, as produced by ``_convert_messages`` at the
    call sites. The input list is never mutated; a deep copy is returned
    when trimming occurs.
    """
    max_input = self._get_max_input_tokens()
    # Reserve a slice of the window for the model's output, capped at 16k.
    output_reserve = min(int(max_input * output_reserve_fraction), 16_384)
    budget = max_input - output_reserve

    total_tokens = self._count_tokens(messages)
    if total_tokens is None:
        # No tokenizer succeeded — cannot measure, so pass through unchanged.
        return messages

    if total_tokens <= budget:
        return messages

    perf = get_perf_logger()
    perf.warning(
        "[llm_router] context overflow detected: %d tokens > %d budget "
        "(max_input=%d, reserve=%d). Trimming messages.",
        total_tokens,
        budget,
        max_input,
        output_reserve,
    )

    # Deep copy so callers' message dicts are never mutated in place.
    trimmed = copy.deepcopy(messages)

    # Per-message token counts for trimmable candidates.
    # Skip system messages to preserve agent instructions.
    msg_token_map: dict[int, int] = {}
    candidate_priority: dict[int, int] = {}
    for i, msg in enumerate(trimmed):
        if msg.get("role") == "system":
            continue
        role = msg.get("role")
        content = msg.get("content", "")
        # Only large string contents are worth trimming (< 500 chars skipped).
        if not isinstance(content, str) or len(content) < 500:
            continue
        # Prefer trimming tool/assistant outputs first.
        # User messages are only trimmed if they clearly contain injected
        # document context blobs.
        is_doc_blob = "<document>" in content or "<mentioned_documents>" in content
        if role in ("tool", "assistant"):
            candidate_priority[i] = 0
        elif role == "user" and is_doc_blob:
            candidate_priority[i] = 1
        else:
            continue
        token_count = self._count_tokens([msg])
        if token_count is not None:
            msg_token_map[i] = token_count

    if not msg_token_map:
        perf.warning("[llm_router] no trimmable messages found, returning as-is")
        return trimmed

    # Trim largest messages first
    # (priority 0 = tool/assistant before priority 1 = user doc blobs;
    # within a priority, biggest token count first).
    candidates = sorted(
        msg_token_map.items(),
        key=lambda x: (candidate_priority.get(x[0], 9), -x[1]),
    )
    running_total = total_tokens

    trim_suffix = (
        "\n\n<!-- Content trimmed to fit model context window. "
        "Some documents were omitted. Refine your query or "
        "reduce top_k for different results. -->"
    )

    for idx, orig_msg_tokens in candidates:
        if running_total <= budget:
            break

        content = trimmed[idx]["content"]
        orig_len = len(content)

        # Binary search: find maximum content[:mid] that keeps total ≤ budget.
        # 200 chars is the floor so a stub of the message always survives.
        lo, hi = 200, orig_len - 1
        best = 200

        while lo <= hi:
            mid = (lo + hi) // 2
            # Write the candidate cut into the message so it can be counted.
            trimmed[idx]["content"] = content[:mid] + trim_suffix
            new_msg_tokens = self._count_tokens([trimmed[idx]])
            if new_msg_tokens is None:
                hi = mid - 1
                continue

            projected_total = running_total - orig_msg_tokens + new_msg_tokens
            if projected_total <= budget:
                best = mid
                lo = mid + 1
            else:
                hi = mid - 1

        # Prefer cutting at a </document> boundary for cleaner output
        # (only when the boundary keeps a meaningful fraction of the cut).
        last_doc_end = content[:best].rfind("</document>")
        if last_doc_end > min(200, best // 4):
            best = last_doc_end + len("</document>")

        # Final content for this message (overwrites any mid-search state).
        trimmed[idx]["content"] = content[:best] + trim_suffix

        try:
            new_msg_tokens = self._count_tokens([trimmed[idx]])
            if new_msg_tokens is None:
                continue
            running_total = running_total - orig_msg_tokens + new_msg_tokens
        except Exception:
            # Counting is best-effort; keep the trim even if recount fails.
            pass

    # Hard guarantee: if still over budget, replace remaining large
    # non-system messages with compact placeholders until we fit.
    if running_total > budget:
        fallback_indices: list[int] = []
        for i, msg in enumerate(trimmed):
            if msg.get("role") == "system":
                continue
            content = msg.get("content")
            if isinstance(content, str) and len(content) > 0:
                fallback_indices.append(i)

        for idx in fallback_indices:
            if running_total <= budget:
                break
            role = trimmed[idx].get("role", "message")
            placeholder = (
                f"[content omitted to fit model context window; role={role}]"
            )
            old_tokens = self._count_tokens([trimmed[idx]]) or 0
            trimmed[idx]["content"] = placeholder
            new_tokens = self._count_tokens([trimmed[idx]]) or 0
            running_total = running_total - old_tokens + new_tokens

        if running_total > budget:
            perf.error(
                "[llm_router] unable to fit context even after aggressive trimming: "
                "tokens=%d budget=%d",
                running_total,
                budget,
            )
            # Final safety net: clear oldest non-system contents.
            for idx in fallback_indices:
                if running_total <= budget:
                    break
                old_tokens = self._count_tokens([trimmed[idx]]) or 0
                trimmed[idx]["content"] = ""
                new_tokens = self._count_tokens([trimmed[idx]]) or 0
                running_total = running_total - old_tokens + new_tokens

    perf.info(
        "[llm_router] messages trimmed: %d → %d tokens (budget=%d, max_input=%d)",
        total_tokens,
        running_total,
        budget,
        max_input,
    )

    return trimmed
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "litellm-router"
|
return "litellm-router"
|
||||||
|
|
@ -499,6 +777,7 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
|
|
||||||
# Convert LangChain messages to OpenAI format
|
# Convert LangChain messages to OpenAI format
|
||||||
formatted_messages = self._convert_messages(messages)
|
formatted_messages = self._convert_messages(messages)
|
||||||
|
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||||
|
|
||||||
# Add tools if bound
|
# Add tools if bound
|
||||||
call_kwargs = {**kwargs}
|
call_kwargs = {**kwargs}
|
||||||
|
|
@ -564,6 +843,7 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
|
|
||||||
# Convert LangChain messages to OpenAI format
|
# Convert LangChain messages to OpenAI format
|
||||||
formatted_messages = self._convert_messages(messages)
|
formatted_messages = self._convert_messages(messages)
|
||||||
|
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||||
|
|
||||||
# Add tools if bound
|
# Add tools if bound
|
||||||
call_kwargs = {**kwargs}
|
call_kwargs = {**kwargs}
|
||||||
|
|
@ -624,6 +904,7 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
raise ValueError("Router not initialized")
|
raise ValueError("Router not initialized")
|
||||||
|
|
||||||
formatted_messages = self._convert_messages(messages)
|
formatted_messages = self._convert_messages(messages)
|
||||||
|
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||||
|
|
||||||
# Add tools if bound
|
# Add tools if bound
|
||||||
call_kwargs = {**kwargs}
|
call_kwargs = {**kwargs}
|
||||||
|
|
@ -673,6 +954,7 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
msg_count = len(messages)
|
msg_count = len(messages)
|
||||||
|
|
||||||
formatted_messages = self._convert_messages(messages)
|
formatted_messages = self._convert_messages(messages)
|
||||||
|
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||||
|
|
||||||
# Add tools if bound
|
# Add tools if bound
|
||||||
call_kwargs = {**kwargs}
|
call_kwargs = {**kwargs}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue