feat: Enhance LLM configuration and routing with model profile attachment

- Added `_attach_model_profile` function to attach model context metadata to `ChatLiteLLM`.
- Updated `create_chat_litellm_from_config` and `create_chat_litellm_from_agent_config` to utilize the new profile attachment.
- Improved context profile caching in `llm_router_service.py` to include both minimum and maximum input tokens, along with token model names for better context management.
- Introduced new methods for token counting and context trimming based on model profiles.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-03-10 18:18:59 -07:00
parent 5571e8aa53
commit eec4db4a3b
2 changed files with 310 additions and 7 deletions

View file

@ -11,6 +11,7 @@ The router is initialized from global LLM configs and provides both
synchronous ChatLiteLLM-like interface and async methods.
"""
import copy
import logging
import re
import time
@ -331,11 +332,16 @@ _router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {}
def _get_cached_context_profile(router: Router) -> dict | None:
"""Compute and cache the min context profile across all router deployments.
"""Compute and cache context profile across all router deployments.
Called once on first ChatLiteLLMRouter creation; subsequent calls return
the cached value. This avoids calling litellm.get_model_info() for every
deployment on every request.
Caches both ``max_input_tokens`` (minimum across deployments, used by
SummarizationMiddleware) and ``max_input_tokens_upper`` (maximum across
deployments, used for context-trimming so we can target the largest
available model — the router's fallback logic handles smaller ones).
"""
global _cached_context_profile, _cached_context_profile_computed
if _cached_context_profile_computed:
@ -344,20 +350,50 @@ def _get_cached_context_profile(router: Router) -> dict | None:
from litellm import get_model_info
min_ctx: int | None = None
max_ctx: int | None = None
token_count_model: str | None = None
ctx_pairs: list[tuple[int, str]] = []
for deployment in router.model_list:
params = deployment.get("litellm_params", {})
base_model = params.get("base_model") or params.get("model", "")
try:
info = get_model_info(base_model)
ctx = info.get("max_input_tokens")
if isinstance(ctx, int) and ctx > 0 and (min_ctx is None or ctx < min_ctx):
min_ctx = ctx
if isinstance(ctx, int) and ctx > 0:
if min_ctx is None or ctx < min_ctx:
min_ctx = ctx
if max_ctx is None or ctx > max_ctx:
max_ctx = ctx
if token_count_model is None:
token_count_model = base_model
ctx_pairs.append((ctx, base_model))
except Exception:
continue
if min_ctx is not None:
logger.info("ChatLiteLLMRouter profile: max_input_tokens=%d", min_ctx)
_cached_context_profile = {"max_input_tokens": min_ctx}
token_count_models: list[str] = []
if token_count_model:
token_count_models.append(token_count_model)
if ctx_pairs:
ctx_pairs.sort(key=lambda x: x[0])
smallest_model = ctx_pairs[0][1]
largest_model = ctx_pairs[-1][1]
if smallest_model not in token_count_models:
token_count_models.append(smallest_model)
if largest_model not in token_count_models:
token_count_models.append(largest_model)
logger.info(
"ChatLiteLLMRouter profile: max_input_tokens=%d, upper=%s, token_models=%s",
min_ctx,
max_ctx,
token_count_models,
)
_cached_context_profile = {
"max_input_tokens": min_ctx,
"max_input_tokens_upper": max_ctx,
"token_count_model": token_count_model,
"token_count_models": token_count_models,
}
else:
_cached_context_profile = None
@ -425,6 +461,248 @@ class ChatLiteLLMRouter(BaseChatModel):
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
raise
# -----------------------------------------------------------------
# Context-aware trimming helpers
# -----------------------------------------------------------------
def _get_token_count_model_names(self) -> list[str]:
"""Return concrete model names usable by ``litellm.token_counter``.
The router uses ``"auto"`` as the model group name but tokenizers need
concrete model identifiers. We keep multiple candidates and take the
most conservative count across them.
"""
names: list[str] = []
profile = getattr(self, "profile", None)
if isinstance(profile, dict):
tcms = profile.get("token_count_models")
if isinstance(tcms, list):
for name in tcms:
if isinstance(name, str) and name and name not in names:
names.append(name)
tcm = profile.get("token_count_model")
if isinstance(tcm, str) and tcm and tcm not in names:
names.append(tcm)
if self._router and self._router.model_list:
for dep in self._router.model_list:
params = dep.get("litellm_params", {})
base = params.get("base_model") or params.get("model", "")
if base and base not in names:
names.append(base)
if len(names) >= 3:
break
if not names:
return ["gpt-4o"]
return names
def _count_tokens(self, messages: list[dict]) -> int | None:
    """Return the most conservative token count for *messages*.

    Each candidate deployment model is tried in turn; the largest count
    wins, so an overflow is never under-reported.  Returns ``None`` when
    no candidate tokenizer could produce a count.
    """
    from litellm import token_counter

    best: int | None = None
    for candidate in self._get_token_count_model_names():
        try:
            count = token_counter(messages=messages, model=candidate)
        except Exception:
            # Tokenizer doesn't recognise this model name — try the next.
            continue
        best = count if best is None else max(best, count)
    return best
def _get_max_input_tokens(self) -> int:
"""Return the max input tokens to use for context trimming.
Prefers the *largest* context window across all deployments so we
maximise usable context (the router's ``context_window_fallbacks``
handle routing to the right model). Falls back to the minimum
profile value or a conservative default.
"""
profile = getattr(self, "profile", None)
if isinstance(profile, dict):
upper = profile.get("max_input_tokens_upper")
if isinstance(upper, int) and upper > 0:
return upper
lower = profile.get("max_input_tokens")
if isinstance(lower, int) and lower > 0:
return lower
return 128_000
def _trim_messages_to_fit_context(
    self,
    messages: list[dict],
    output_reserve_fraction: float = 0.10,
) -> list[dict]:
    """Trim message content via binary search to fit the model's context window.

    When the total token count exceeds the model's ``max_input_tokens``,
    this method identifies the largest messages (typically tool responses
    containing search results) and uses binary search on each to find the
    maximum content length that keeps the total within budget.
    Cutting prefers ``</document>`` XML boundaries so complete documents
    are preserved when possible.

    This is model-aware: it reads the context limit from
    ``litellm.get_model_info`` (cached in ``self.profile``) and counts
    tokens with ``litellm.token_counter``.

    Args:
        messages: OpenAI-format message dicts (``role``/``content``).
        output_reserve_fraction: Fraction of the context window reserved
            for the model's output (capped at 16,384 tokens).

    Returns:
        The original list when it already fits (or cannot be measured);
        otherwise a deep copy with oversized contents shortened.
    """
    max_input = self._get_max_input_tokens()
    # Reserve headroom for the model's response, capped at 16k tokens.
    output_reserve = min(int(max_input * output_reserve_fraction), 16_384)
    budget = max_input - output_reserve

    total_tokens = self._count_tokens(messages)
    if total_tokens is None:
        # No tokenizer could measure the payload; nothing safe to do.
        return messages
    if total_tokens <= budget:
        return messages

    perf = get_perf_logger()
    perf.warning(
        "[llm_router] context overflow detected: %d tokens > %d budget "
        "(max_input=%d, reserve=%d). Trimming messages.",
        total_tokens,
        budget,
        max_input,
        output_reserve,
    )

    # Work on a copy so callers' message lists are never mutated.
    trimmed = copy.deepcopy(messages)

    # Per-message token counts for trimmable candidates.
    # Skip system messages to preserve agent instructions.
    msg_token_map: dict[int, int] = {}
    candidate_priority: dict[int, int] = {}
    for i, msg in enumerate(trimmed):
        role = msg.get("role")
        if role == "system":
            continue
        content = msg.get("content", "")
        # Short or non-string contents aren't worth trimming.
        if not isinstance(content, str) or len(content) < 500:
            continue
        # Prefer trimming tool/assistant outputs first.
        # User messages are only trimmed if they clearly contain injected
        # document context blobs.
        is_doc_blob = "<document>" in content or "<mentioned_documents>" in content
        if role in ("tool", "assistant"):
            candidate_priority[i] = 0
        elif role == "user" and is_doc_blob:
            candidate_priority[i] = 1
        else:
            continue
        token_count = self._count_tokens([msg])
        if token_count is not None:
            msg_token_map[i] = token_count

    if not msg_token_map:
        perf.warning("[llm_router] no trimmable messages found, returning as-is")
        return trimmed

    # Trim largest messages first, within each priority tier.
    candidates = sorted(
        msg_token_map.items(),
        key=lambda x: (candidate_priority.get(x[0], 9), -x[1]),
    )
    running_total = total_tokens
    trim_suffix = (
        "\n\n<!-- Content trimmed to fit model context window. "
        "Some documents were omitted. Refine your query or "
        "reduce top_k for different results. -->"
    )
    for idx, orig_msg_tokens in candidates:
        if running_total <= budget:
            break
        content = trimmed[idx]["content"]
        orig_len = len(content)
        # Binary search: find maximum content[:mid] that keeps total <= budget.
        # Always keep at least the first 200 characters for context.
        lo, hi = 200, orig_len - 1
        best = 200
        while lo <= hi:
            mid = (lo + hi) // 2
            trimmed[idx]["content"] = content[:mid] + trim_suffix
            new_msg_tokens = self._count_tokens([trimmed[idx]])
            if new_msg_tokens is None:
                hi = mid - 1
                continue
            projected_total = running_total - orig_msg_tokens + new_msg_tokens
            if projected_total <= budget:
                best = mid
                lo = mid + 1
            else:
                hi = mid - 1
        # Prefer cutting at a </document> boundary for cleaner output,
        # unless that would discard almost everything we kept.
        last_doc_end = content[:best].rfind("</document>")
        if last_doc_end > min(200, best // 4):
            best = last_doc_end + len("</document>")
        trimmed[idx]["content"] = content[:best] + trim_suffix
        try:
            new_msg_tokens = self._count_tokens([trimmed[idx]])
            if new_msg_tokens is None:
                # Unmeasurable after trim — keep the (over)estimate.
                continue
            running_total = running_total - orig_msg_tokens + new_msg_tokens
        except Exception:
            pass

    # Hard guarantee: if still over budget, replace remaining large
    # non-system messages with compact placeholders until we fit.
    if running_total > budget:
        fallback_indices: list[int] = []
        for i, msg in enumerate(trimmed):
            if msg.get("role") == "system":
                continue
            content = msg.get("content")
            if isinstance(content, str) and len(content) > 0:
                fallback_indices.append(i)
        for idx in fallback_indices:
            if running_total <= budget:
                break
            role = trimmed[idx].get("role", "message")
            placeholder = (
                f"[content omitted to fit model context window; role={role}]"
            )
            old_tokens = self._count_tokens([trimmed[idx]]) or 0
            trimmed[idx]["content"] = placeholder
            new_tokens = self._count_tokens([trimmed[idx]]) or 0
            running_total = running_total - old_tokens + new_tokens
        if running_total > budget:
            perf.error(
                "[llm_router] unable to fit context even after aggressive trimming: "
                "tokens=%d budget=%d",
                running_total,
                budget,
            )
            # Final safety net: clear oldest non-system contents.
            for idx in fallback_indices:
                if running_total <= budget:
                    break
                old_tokens = self._count_tokens([trimmed[idx]]) or 0
                trimmed[idx]["content"] = ""
                new_tokens = self._count_tokens([trimmed[idx]]) or 0
                running_total = running_total - old_tokens + new_tokens

    # Fixed format string: the original "%d%d" ran both counts together
    # (e.g. "5000048000 tokens"); render before/after counts separately.
    perf.info(
        "[llm_router] messages trimmed: %d -> %d tokens (budget=%d, max_input=%d)",
        total_tokens,
        running_total,
        budget,
        max_input,
    )
    return trimmed
# -----------------------------------------------------------------
@property
def _llm_type(self) -> str:
    # Identifier LangChain uses to tag this chat-model implementation.
    return "litellm-router"
@ -499,6 +777,7 @@ class ChatLiteLLMRouter(BaseChatModel):
# Convert LangChain messages to OpenAI format
formatted_messages = self._convert_messages(messages)
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
# Add tools if bound
call_kwargs = {**kwargs}
@ -564,6 +843,7 @@ class ChatLiteLLMRouter(BaseChatModel):
# Convert LangChain messages to OpenAI format
formatted_messages = self._convert_messages(messages)
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
# Add tools if bound
call_kwargs = {**kwargs}
@ -624,6 +904,7 @@ class ChatLiteLLMRouter(BaseChatModel):
raise ValueError("Router not initialized")
formatted_messages = self._convert_messages(messages)
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
# Add tools if bound
call_kwargs = {**kwargs}
@ -673,6 +954,7 @@ class ChatLiteLLMRouter(BaseChatModel):
msg_count = len(messages)
formatted_messages = self._convert_messages(messages)
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
# Add tools if bound
call_kwargs = {**kwargs}