mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 16:56:22 +02:00
Add full MiniMax provider support across the entire stack: Backend: - Add MINIMAX to LiteLLMProvider enum in db.py - Add MINIMAX mapping to all provider_map dicts in llm_service.py, llm_router_service.py, and llm_config.py - Add Alembic migration (rev 106) for PostgreSQL enum - Add MiniMax M2.5 example in global_llm_config.example.yaml Frontend: - Add MiniMax to LLM_PROVIDERS enum with apiBase - Add MiniMax-M2.5 and MiniMax-M2.5-highspeed to LLM_MODELS - Add MINIMAX to Zod validation schema - Add MiniMax SVG icon and wire up in provider-icons Docs: - Add MiniMax setup guide in chinese-llm-setup.md MiniMax uses an OpenAI-compatible API (https://api.minimax.io/v1) with models supporting up to 204K context window. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1173 lines
42 KiB
Python
1173 lines
42 KiB
Python
"""
|
|
LiteLLM Router Service for Load Balancing
|
|
|
|
This module provides a singleton LiteLLM Router for automatic load balancing
|
|
across multiple LLM deployments. It handles:
|
|
- Rate limit management with automatic cooldowns
|
|
- Automatic failover and retries
|
|
- Usage-based routing to distribute load evenly
|
|
|
|
The router is initialized from global LLM configs and provides both
|
|
synchronous ChatLiteLLM-like interface and async methods.
|
|
"""
|
|
|
|
import copy
|
|
import logging
|
|
import re
|
|
import time
|
|
from typing import Any
|
|
|
|
import litellm
|
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
from langchain_core.exceptions import ContextOverflowError
|
|
from langchain_core.language_models import BaseChatModel
|
|
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
from litellm import Router
|
|
from litellm.exceptions import (
|
|
BadRequestError as LiteLLMBadRequestError,
|
|
ContextWindowExceededError,
|
|
)
|
|
|
|
from app.utils.perf import get_perf_logger
|
|
|
|
# Quiet LiteLLM's own structured/audit logging; this module logs through the
# standard `logging` module (and the perf logger) instead.
litellm.json_logs = False
litellm.store_audit_logs = False

logger = logging.getLogger(__name__)
|
|
|
|
# Heuristic patterns seen in provider error messages when the prompt exceeds
# the model's context window. Some providers raise a generic BadRequestError
# instead of ContextWindowExceededError, so we match on the message text.
_CONTEXT_OVERFLOW_PATTERNS = re.compile(
    r"(input tokens exceed|context.{0,20}(length|window|limit)|"
    r"maximum context length|token.{0,20}(limit|exceed)|"
    r"too many tokens|reduce the length)",
    re.IGNORECASE,
)


def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool:
    """Return True when *exc*'s message looks like a context-window overflow."""
    message = str(exc)
    return _CONTEXT_OVERFLOW_PATTERNS.search(message) is not None
|
|
|
|
|
|
# Special ID for Auto mode - uses router for load balancing
AUTO_MODE_ID = 0

# Provider mapping for LiteLLM model string construction.
# Keys are the uppercase provider enum names used in config; values are the
# LiteLLM route prefixes (model strings become "<prefix>/<model_name>").
# Several providers (DEEPSEEK, ALIBABA_QWEN, MOONSHOT, ZHIPU, MINIMAX) expose
# OpenAI-compatible endpoints and are therefore routed via the "openai"
# prefix with a custom api_base.
PROVIDER_MAP = {
    "OPENAI": "openai",
    "ANTHROPIC": "anthropic",
    "GROQ": "groq",
    "COHERE": "cohere",
    "GOOGLE": "gemini",
    "OLLAMA": "ollama_chat",
    "MISTRAL": "mistral",
    "AZURE_OPENAI": "azure",
    "OPENROUTER": "openrouter",
    "COMETAPI": "cometapi",
    "XAI": "xai",
    "BEDROCK": "bedrock",
    "AWS_BEDROCK": "bedrock",  # Legacy support
    "VERTEX_AI": "vertex_ai",
    "TOGETHER_AI": "together_ai",
    "FIREWORKS_AI": "fireworks_ai",
    "REPLICATE": "replicate",
    "PERPLEXITY": "perplexity",
    "ANYSCALE": "anyscale",
    "DEEPINFRA": "deepinfra",
    "CEREBRAS": "cerebras",
    "SAMBANOVA": "sambanova",
    "AI21": "ai21",
    "CLOUDFLARE": "cloudflare",
    "DATABRICKS": "databricks",
    "DEEPSEEK": "openai",
    "ALIBABA_QWEN": "openai",
    "MOONSHOT": "openai",
    "ZHIPU": "openai",
    "GITHUB_MODELS": "github",
    "HUGGINGFACE": "huggingface",
    "MINIMAX": "openai",
    "CUSTOM": "custom",
}
|
|
|
|
|
|
class LLMRouterService:
    """
    Singleton service for managing LiteLLM Router.

    The router provides automatic load balancing, failover, and rate limit
    handling across multiple LLM deployments.
    """

    # Singleton storage; all state lives on the single shared instance.
    _instance = None
    _router: Router | None = None
    _model_list: list[dict] = []
    _router_settings: dict = {}
    _initialized: bool = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def get_instance(cls) -> "LLMRouterService":
        """Get the singleton instance of the router service."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def initialize(
        cls,
        global_configs: list[dict],
        router_settings: dict | None = None,
    ) -> None:
        """
        Initialize the router with global LLM configurations.

        Idempotent: subsequent calls after a successful initialization are
        no-ops. On Router construction failure the service is left
        uninitialized (``get_router()`` returns None).

        Args:
            global_configs: List of global LLM config dictionaries from YAML
            router_settings: Optional router settings (routing_strategy, num_retries, etc.)
        """
        instance = cls.get_instance()

        if instance._initialized:
            logger.debug("LLM Router already initialized, skipping")
            return

        # Build model list from global configs; invalid configs are skipped.
        model_list = []
        for config in global_configs:
            deployment = cls._config_to_deployment(config)
            if deployment:
                model_list.append(deployment)

        if not model_list:
            logger.warning("No valid LLM configs found for router initialization")
            return

        instance._model_list = model_list
        instance._router_settings = router_settings or {}

        # Default router settings optimized for rate limit handling
        default_settings = {
            "routing_strategy": "usage-based-routing",  # Best for rate limit management
            "num_retries": 3,
            "allowed_fails": 3,
            "cooldown_time": 60,  # Cooldown for 60 seconds after failures
            "retry_after": 5,  # Wait 5 seconds between retries
        }

        # Merge with provided settings (caller-supplied values win)
        final_settings = {**default_settings, **instance._router_settings}

        # Build an "auto-large" fallback group with deployments whose context
        # window exceeds the smallest deployment. This lets the router
        # automatically fall back to a bigger-context model when gpt-4o (128K)
        # hits ContextWindowExceededError.
        full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)

        try:
            router_kwargs: dict[str, Any] = {
                "model_list": full_model_list,
                "routing_strategy": final_settings.get(
                    "routing_strategy", "usage-based-routing"
                ),
                "num_retries": final_settings.get("num_retries", 3),
                "allowed_fails": final_settings.get("allowed_fails", 3),
                "cooldown_time": final_settings.get("cooldown_time", 60),
                # BUGFIX: "retry_after" was declared in default_settings but
                # never forwarded to the Router, so the configured wait
                # between retries was silently ignored.
                "retry_after": final_settings.get("retry_after", 5),
                "set_verbose": False,
            }
            if ctx_fallbacks:
                router_kwargs["context_window_fallbacks"] = ctx_fallbacks

            instance._router = Router(**router_kwargs)
            instance._initialized = True
            logger.info(
                "LLM Router initialized with %d deployments, "
                "strategy: %s, context_window_fallbacks: %s",
                len(model_list),
                final_settings.get("routing_strategy"),
                ctx_fallbacks or "none",
            )
        except Exception as e:
            logger.error(f"Failed to initialize LLM Router: {e}")
            instance._router = None

    @classmethod
    def _build_context_fallback_groups(
        cls, model_list: list[dict]
    ) -> tuple[list[dict], list[dict[str, list[str]]] | None]:
        """Create an ``auto-large`` model group for context-window fallbacks.

        Uses ``litellm.get_model_info`` to discover the context window of each
        deployment. Deployments whose ``max_input_tokens`` exceeds the smallest
        window are duplicated into an ``auto-large`` group. The returned
        fallback config tells the Router: on ``ContextWindowExceededError`` for
        ``auto``, retry with ``auto-large``.

        Returns:
            (full_model_list, context_window_fallbacks) — ``full_model_list``
            contains the original entries plus any ``auto-large`` duplicates.
            ``context_window_fallbacks`` is ``None`` when every deployment has
            the same context size (no useful fallback).
        """
        from litellm import get_model_info

        # Map each deployment's base model to its max input tokens; models
        # unknown to litellm are silently skipped.
        ctx_map: dict[str, int] = {}
        for dep in model_list:
            params = dep.get("litellm_params", {})
            base_model = params.get("base_model") or params.get("model", "")
            try:
                info = get_model_info(base_model)
                ctx = info.get("max_input_tokens")
                if isinstance(ctx, int) and ctx > 0:
                    ctx_map[base_model] = ctx
            except Exception:
                continue

        if not ctx_map:
            return model_list, None

        min_ctx = min(ctx_map.values())

        # Duplicate every deployment strictly larger than the smallest window
        # under the "auto-large" alias.
        large_deployments: list[dict] = []
        for dep in model_list:
            params = dep.get("litellm_params", {})
            base_model = params.get("base_model") or params.get("model", "")
            if ctx_map.get(base_model, 0) > min_ctx:
                dup = {**dep, "model_name": "auto-large"}
                large_deployments.append(dup)

        if not large_deployments:
            return model_list, None

        logger.info(
            "Context-window fallback: %d large-context deployments "
            "(min_ctx=%d) added to 'auto-large' group",
            len(large_deployments),
            min_ctx,
        )
        return model_list + large_deployments, [{"auto": ["auto-large"]}]

    @classmethod
    def _config_to_deployment(cls, config: dict) -> dict | None:
        """
        Convert a global LLM config to a router deployment entry.

        Args:
            config: Global LLM config dictionary

        Returns:
            Router deployment dictionary or None if invalid
        """
        try:
            # Skip if essential fields are missing
            if not config.get("model_name") or not config.get("api_key"):
                return None

            # Build model string ("<provider_prefix>/<model_name>")
            if config.get("custom_provider"):
                model_string = f"{config['custom_provider']}/{config['model_name']}"
            else:
                provider = config.get("provider", "").upper()
                provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
                model_string = f"{provider_prefix}/{config['model_name']}"

            # Build litellm params
            litellm_params = {
                "model": model_string,
                "api_key": config.get("api_key"),
            }

            # Add optional api_base
            if config.get("api_base"):
                litellm_params["api_base"] = config["api_base"]

            # Add any additional litellm parameters (may override the above)
            if config.get("litellm_params"):
                litellm_params.update(config["litellm_params"])

            # All configs share the "auto" alias so the router load-balances
            # across every deployment as one group.
            deployment = {
                "model_name": "auto",
                "litellm_params": litellm_params,
            }

            # Add rate limits from config if available
            if config.get("rpm"):
                deployment["rpm"] = config["rpm"]
            if config.get("tpm"):
                deployment["tpm"] = config["tpm"]

            return deployment

        except Exception as e:
            logger.warning(f"Failed to convert config to deployment: {e}")
            return None

    @classmethod
    def get_router(cls) -> Router | None:
        """Get the initialized router instance (None if not initialized)."""
        instance = cls.get_instance()
        return instance._router

    @classmethod
    def is_initialized(cls) -> bool:
        """Check if the router has been initialized."""
        instance = cls.get_instance()
        return instance._initialized and instance._router is not None

    @classmethod
    def get_model_count(cls) -> int:
        """Get the number of models in the router (excluding auto-large duplicates)."""
        instance = cls.get_instance()
        return len(instance._model_list)
|
|
|
|
|
|
# Lazily-computed context profile shared by every ChatLiteLLMRouter instance.
# `_computed` distinguishes "not computed yet" from "computed but None".
_cached_context_profile: dict | None = None
_cached_context_profile_computed: bool = False

# Cached singleton instances keyed by (streaming,) to avoid re-creating on every call
_router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {}
|
|
|
|
|
|
def _get_cached_context_profile(router: Router) -> dict | None:
    """Compute and cache context profile across all router deployments.

    Called once on first ChatLiteLLMRouter creation; subsequent calls return
    the cached value. This avoids calling litellm.get_model_info() for every
    deployment on every request.

    Caches both ``max_input_tokens`` (minimum across deployments, used by
    SummarizationMiddleware) and ``max_input_tokens_upper`` (maximum across
    deployments, used for context-trimming so we can target the largest
    available model — the router's fallback logic handles smaller ones).
    """
    global _cached_context_profile, _cached_context_profile_computed
    # Cache hit covers both dict and None results (computed flag set below).
    if _cached_context_profile_computed:
        return _cached_context_profile

    from litellm import get_model_info

    min_ctx: int | None = None
    max_ctx: int | None = None
    token_count_model: str | None = None  # first deployment with known ctx
    ctx_pairs: list[tuple[int, str]] = []  # (ctx_window, base_model) pairs
    for deployment in router.model_list:
        params = deployment.get("litellm_params", {})
        base_model = params.get("base_model") or params.get("model", "")
        try:
            info = get_model_info(base_model)
            ctx = info.get("max_input_tokens")
            if isinstance(ctx, int) and ctx > 0:
                if min_ctx is None or ctx < min_ctx:
                    min_ctx = ctx
                if max_ctx is None or ctx > max_ctx:
                    max_ctx = ctx
                if token_count_model is None:
                    token_count_model = base_model
                ctx_pairs.append((ctx, base_model))
        except Exception:
            # Model unknown to litellm — skip it for profiling purposes.
            continue

    if min_ctx is not None:
        # Candidate models for token counting: first-found, smallest-ctx and
        # largest-ctx deployments (deduplicated, order preserved).
        token_count_models: list[str] = []
        if token_count_model:
            token_count_models.append(token_count_model)
        if ctx_pairs:
            ctx_pairs.sort(key=lambda x: x[0])
            smallest_model = ctx_pairs[0][1]
            largest_model = ctx_pairs[-1][1]
            if smallest_model not in token_count_models:
                token_count_models.append(smallest_model)
            if largest_model not in token_count_models:
                token_count_models.append(largest_model)
        logger.info(
            "ChatLiteLLMRouter profile: max_input_tokens=%d, upper=%s, token_models=%s",
            min_ctx,
            max_ctx,
            token_count_models,
        )
        _cached_context_profile = {
            "max_input_tokens": min_ctx,
            "max_input_tokens_upper": max_ctx,
            "token_count_model": token_count_model,
            "token_count_models": token_count_models,
        }
    else:
        # No deployment had a discoverable context window.
        _cached_context_profile = None

    _cached_context_profile_computed = True
    return _cached_context_profile
|
|
|
|
|
|
class ChatLiteLLMRouter(BaseChatModel):
|
|
"""
|
|
A LangChain-compatible chat model that uses LiteLLM Router for load balancing.
|
|
|
|
This wraps the LiteLLM Router to provide the same interface as ChatLiteLLM,
|
|
making it a drop-in replacement for auto-mode routing.
|
|
|
|
Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
|
|
window across all router deployments so that deepagents
|
|
SummarizationMiddleware can use fraction-based triggers.
|
|
|
|
**Singleton-ish**: Use ``get_auto_mode_llm()`` or call ``ChatLiteLLMRouter()``
|
|
directly — instances without bound tools are cached per streaming flag to
|
|
avoid per-request re-initialization overhead and memory growth.
|
|
"""
|
|
|
|
# Use model_config for Pydantic v2 compatibility
|
|
model_config = {"arbitrary_types_allowed": True}
|
|
|
|
# Public attributes that Pydantic will manage
|
|
model: str = "auto"
|
|
streaming: bool = True
|
|
|
|
# Bound tools and tool choice for tool calling
|
|
_bound_tools: list[dict] | None = None
|
|
_tool_choice: str | dict | None = None
|
|
_router: Router | None = None
|
|
|
|
def __init__(
    self,
    router: Router | None = None,
    bound_tools: list[dict] | None = None,
    tool_choice: str | dict | None = None,
    **kwargs,
):
    """Create a router-backed chat model.

    Args:
        router: Router to use; defaults to the service singleton's router.
        bound_tools: OpenAI-format tool dicts to attach to every call.
        tool_choice: Tool-choice strategy forwarded when tools are bound.
        **kwargs: Forwarded to BaseChatModel (e.g. ``model``, ``streaming``).

    Raises:
        ValueError: If no router is available (service not initialized).
    """
    try:
        super().__init__(**kwargs)
        resolved_router = router or LLMRouterService.get_router()
        # Private attrs are set via object.__setattr__ to bypass Pydantic's
        # field validation/assignment machinery.
        object.__setattr__(self, "_router", resolved_router)
        object.__setattr__(self, "_bound_tools", bound_tools)
        object.__setattr__(self, "_tool_choice", tool_choice)
        if not self._router:
            raise ValueError(
                "LLM Router not initialized. Call LLMRouterService.initialize() first."
            )

        # Attach the (module-cached) context profile so middleware can read
        # max_input_tokens; only set when profiling succeeded.
        computed_profile = _get_cached_context_profile(self._router)
        if computed_profile is not None:
            object.__setattr__(self, "profile", computed_profile)

        logger.debug(
            "ChatLiteLLMRouter ready (models=%d, streaming=%s, has_tools=%s)",
            LLMRouterService.get_model_count(),
            self.streaming,
            bound_tools is not None,
        )
    except Exception as e:
        logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
        raise
|
|
|
|
# -----------------------------------------------------------------
|
|
# Context-aware trimming helpers
|
|
# -----------------------------------------------------------------
|
|
|
|
def _get_token_count_model_names(self) -> list[str]:
    """Return concrete model names usable by ``litellm.token_counter``.

    The router uses ``"auto"`` as the model group name but tokenizers need
    concrete model identifiers. We keep multiple candidates and take the
    most conservative count across them.
    """
    candidates: list[str] = []

    def _add(candidate):
        # De-duplicating append; ignores empty or non-string values.
        if isinstance(candidate, str) and candidate and candidate not in candidates:
            candidates.append(candidate)

    # First, pull candidates from the cached context profile.
    profile = getattr(self, "profile", None)
    if isinstance(profile, dict):
        listed = profile.get("token_count_models")
        if isinstance(listed, list):
            for entry in listed:
                _add(entry)
        _add(profile.get("token_count_model"))

    # Then top up (max 3) from the router's deployment list.
    if self._router and self._router.model_list:
        for deployment in self._router.model_list:
            params = deployment.get("litellm_params", {})
            _add(params.get("base_model") or params.get("model", ""))
            if len(candidates) >= 3:
                break

    # Conservative default tokenizer when nothing was discovered.
    return candidates if candidates else ["gpt-4o"]
|
|
|
|
def _count_tokens(self, messages: list[dict]) -> int | None:
    """Return a conservative (maximum) token count over candidate models.

    Returns None when no candidate tokenizer succeeds.
    """
    from litellm import token_counter

    worst: int | None = None
    for candidate in self._get_token_count_model_names():
        try:
            count = token_counter(messages=messages, model=candidate)
        except Exception:
            # Tokenizer unavailable for this model — try the next one.
            continue
        if worst is None or count > worst:
            worst = count
    return worst
|
|
|
|
def _get_max_input_tokens(self) -> int:
    """Return the max input tokens to use for context trimming.

    Prefers the *largest* context window across all deployments so we
    maximise usable context (the router's ``context_window_fallbacks``
    handle routing to the right model). Falls back to the minimum
    profile value or a conservative default.
    """
    profile = getattr(self, "profile", None)
    if isinstance(profile, dict):
        # Check the upper bound first, then the lower bound.
        for key in ("max_input_tokens_upper", "max_input_tokens"):
            value = profile.get(key)
            if isinstance(value, int) and value > 0:
                return value
    # Conservative default when no profile is available (gpt-4o class).
    return 128_000
|
|
|
|
def _trim_messages_to_fit_context(
    self,
    messages: list[dict],
    output_reserve_fraction: float = 0.10,
) -> list[dict]:
    """Trim message content via binary search to fit the model's context window.

    When the total token count exceeds the model's ``max_input_tokens``,
    this method identifies the largest messages (typically tool responses
    containing search results) and uses binary search on each to find the
    maximum content length that keeps the total within budget.

    Cutting prefers ``</document>`` XML boundaries so complete documents
    are preserved when possible.

    This is model-aware: it reads the context limit from
    ``litellm.get_model_info`` (cached in ``self.profile``) and counts
    tokens with ``litellm.token_counter``.

    Args:
        messages: OpenAI-format message dicts (not mutated; a deep copy
            is returned when trimming occurs).
        output_reserve_fraction: Fraction of the context window reserved
            for the model's output (capped at 16,384 tokens).

    Returns:
        The original list when it fits (or counting fails), otherwise a
        trimmed deep copy.
    """
    max_input = self._get_max_input_tokens()
    # Reserve headroom for the completion; cap so huge-context models
    # don't reserve an excessive amount.
    output_reserve = min(int(max_input * output_reserve_fraction), 16_384)
    budget = max_input - output_reserve

    total_tokens = self._count_tokens(messages)
    if total_tokens is None:
        # No tokenizer worked — can't trim safely, pass through unchanged.
        return messages

    if total_tokens <= budget:
        return messages

    perf = get_perf_logger()
    perf.warning(
        "[llm_router] context overflow detected: %d tokens > %d budget "
        "(max_input=%d, reserve=%d). Trimming messages.",
        total_tokens,
        budget,
        max_input,
        output_reserve,
    )

    # Deep copy so the caller's messages are never mutated.
    trimmed = copy.deepcopy(messages)

    # Per-message token counts for trimmable candidates.
    # Skip system messages to preserve agent instructions.
    msg_token_map: dict[int, int] = {}
    candidate_priority: dict[int, int] = {}  # lower value = trimmed first
    for i, msg in enumerate(trimmed):
        if msg.get("role") == "system":
            continue
        role = msg.get("role")
        content = msg.get("content", "")
        # Only plain-string contents of meaningful size are worth trimming.
        if not isinstance(content, str) or len(content) < 500:
            continue
        # Prefer trimming tool/assistant outputs first.
        # User messages are only trimmed if they clearly contain injected
        # document context blobs.
        is_doc_blob = "<document>" in content or "<mentioned_documents>" in content
        if role in ("tool", "assistant"):
            candidate_priority[i] = 0
        elif role == "user" and is_doc_blob:
            candidate_priority[i] = 1
        else:
            continue
        token_count = self._count_tokens([msg])
        if token_count is not None:
            msg_token_map[i] = token_count

    if not msg_token_map:
        perf.warning("[llm_router] no trimmable messages found, returning as-is")
        return trimmed

    # Trim largest messages first (within each priority tier).
    candidates = sorted(
        msg_token_map.items(),
        key=lambda x: (candidate_priority.get(x[0], 9), -x[1]),
    )
    running_total = total_tokens

    # Appended to every trimmed message so the model knows content was cut.
    trim_suffix = (
        "\n\n<!-- Content trimmed to fit model context window. "
        "Some documents were omitted. Refine your query or "
        "reduce top_k for different results. -->"
    )

    for idx, orig_msg_tokens in candidates:
        if running_total <= budget:
            break

        content = trimmed[idx]["content"]
        orig_len = len(content)

        # Binary search: find maximum content[:mid] that keeps total ≤ budget.
        # Never cut below 200 characters so some context always survives.
        lo, hi = 200, orig_len - 1
        best = 200

        while lo <= hi:
            mid = (lo + hi) // 2
            trimmed[idx]["content"] = content[:mid] + trim_suffix
            new_msg_tokens = self._count_tokens([trimmed[idx]])
            if new_msg_tokens is None:
                # Counting failed at this size — try a smaller cut.
                hi = mid - 1
                continue

            projected_total = running_total - orig_msg_tokens + new_msg_tokens
            if projected_total <= budget:
                best = mid
                lo = mid + 1
            else:
                hi = mid - 1

        # Prefer cutting at a </document> boundary for cleaner output
        # (only when the boundary isn't unreasonably early in the message).
        last_doc_end = content[:best].rfind("</document>")
        if last_doc_end > min(200, best // 4):
            best = last_doc_end + len("</document>")

        trimmed[idx]["content"] = content[:best] + trim_suffix

        try:
            new_msg_tokens = self._count_tokens([trimmed[idx]])
            if new_msg_tokens is None:
                continue
            running_total = running_total - orig_msg_tokens + new_msg_tokens
        except Exception:
            # Best-effort accounting; a stale running_total only makes later
            # trimming more aggressive, never less safe.
            pass

    # Hard guarantee: if still over budget, replace remaining large
    # non-system messages with compact placeholders until we fit.
    if running_total > budget:
        fallback_indices: list[int] = []
        for i, msg in enumerate(trimmed):
            if msg.get("role") == "system":
                continue
            content = msg.get("content")
            if isinstance(content, str) and len(content) > 0:
                fallback_indices.append(i)

        for idx in fallback_indices:
            if running_total <= budget:
                break
            role = trimmed[idx].get("role", "message")
            placeholder = (
                f"[content omitted to fit model context window; role={role}]"
            )
            old_tokens = self._count_tokens([trimmed[idx]]) or 0
            trimmed[idx]["content"] = placeholder
            new_tokens = self._count_tokens([trimmed[idx]]) or 0
            running_total = running_total - old_tokens + new_tokens

        if running_total > budget:
            perf.error(
                "[llm_router] unable to fit context even after aggressive trimming: "
                "tokens=%d budget=%d",
                running_total,
                budget,
            )
            # Final safety net: clear oldest non-system contents.
            for idx in fallback_indices:
                if running_total <= budget:
                    break
                old_tokens = self._count_tokens([trimmed[idx]]) or 0
                trimmed[idx]["content"] = ""
                new_tokens = self._count_tokens([trimmed[idx]]) or 0
                running_total = running_total - old_tokens + new_tokens

    perf.info(
        "[llm_router] messages trimmed: %d → %d tokens (budget=%d, max_input=%d)",
        total_tokens,
        running_total,
        budget,
        max_input,
    )

    return trimmed
|
|
|
|
# -----------------------------------------------------------------
|
|
|
|
@property
def _llm_type(self) -> str:
    # LangChain identifier for this chat-model implementation
    # (used for serialization and tracing).
    return "litellm-router"
|
|
|
|
@property
def _identifying_params(self) -> dict[str, Any]:
    """Parameters that identify this model for LangChain caching/telemetry."""
    params: dict[str, Any] = {"model": self.model}
    params["model_count"] = LLMRouterService.get_model_count()
    return params
|
|
|
|
def bind_tools(
    self,
    tools: list[Any],
    *,
    tool_choice: str | dict | None = None,
    **kwargs: Any,
) -> "ChatLiteLLMRouter":
    """
    Bind tools to the model for function/tool calling.

    Args:
        tools: List of tools to bind (can be LangChain tools, Pydantic models, or dicts)
        tool_choice: Optional tool choice strategy ("auto", "required", "none", or specific tool)
        **kwargs: Additional arguments

    Returns:
        New ChatLiteLLMRouter instance with tools bound
    """
    from langchain_core.utils.function_calling import convert_to_openai_tool

    # Normalize every tool into the OpenAI function-tool dict format;
    # tools that fail conversion are logged and dropped.
    openai_tools: list[dict] = []
    for tool in tools:
        if isinstance(tool, dict):
            openai_tools.append(tool)
            continue
        try:
            openai_tools.append(convert_to_openai_tool(tool))
        except Exception as e:
            logger.warning(f"Failed to convert tool {tool}: {e}")

    # Binding is immutable: return a fresh instance carrying the tools.
    return ChatLiteLLMRouter(
        router=self._router,
        bound_tools=openai_tools or None,
        tool_choice=tool_choice,
        model=self.model,
        streaming=self.streaming,
        **kwargs,
    )
|
|
|
|
def _generate(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: CallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> ChatResult:
    """
    Generate a response using the router (synchronous).

    Raises:
        ValueError: If the router is not initialized.
        ContextOverflowError: When the provider rejects the request for
            exceeding the context window (including providers that report
            this as a generic BadRequestError).
    """
    if not self._router:
        raise ValueError("Router not initialized")

    perf = get_perf_logger()
    t0 = time.perf_counter()
    msg_count = len(messages)

    # Convert LangChain messages to OpenAI format, then trim proactively so
    # most requests never hit a provider-side context error.
    formatted_messages = self._convert_messages(messages)
    formatted_messages = self._trim_messages_to_fit_context(formatted_messages)

    # Add tools if bound
    call_kwargs = {**kwargs}
    if self._bound_tools:
        call_kwargs["tools"] = self._bound_tools
        if self._tool_choice is not None:
            call_kwargs["tool_choice"] = self._tool_choice

    try:
        response = self._router.completion(
            model=self.model,
            messages=formatted_messages,
            stop=stop,
            **call_kwargs,
        )
    except ContextWindowExceededError as e:
        perf.warning(
            "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
            msg_count,
            time.perf_counter() - t0,
        )
        # Re-raise as LangChain's ContextOverflowError so middleware
        # (e.g. summarization) can react uniformly.
        raise ContextOverflowError(str(e)) from e
    except LiteLLMBadRequestError as e:
        # Some providers surface overflow as a generic 400 — sniff the
        # message text before deciding.
        if _is_context_overflow_error(e):
            perf.warning(
                "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
                msg_count,
                time.perf_counter() - t0,
            )
            raise ContextOverflowError(str(e)) from e
        raise

    elapsed = time.perf_counter() - t0
    perf.info(
        "[llm_router] _generate completed msgs=%d tools=%d in %.3fs",
        msg_count,
        len(self._bound_tools) if self._bound_tools else 0,
        elapsed,
    )

    # Convert response to ChatResult with potential tool calls
    message = self._convert_response_to_message(response.choices[0].message)
    generation = ChatGeneration(message=message)

    return ChatResult(generations=[generation])
|
|
|
|
async def _agenerate(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: CallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> ChatResult:
    """
    Generate a response using the router (asynchronous).

    Mirrors ``_generate`` exactly, using ``Router.acompletion``.

    Raises:
        ValueError: If the router is not initialized.
        ContextOverflowError: On provider context-window rejection
            (direct or detected inside a generic BadRequestError).
    """
    if not self._router:
        raise ValueError("Router not initialized")

    perf = get_perf_logger()
    t0 = time.perf_counter()
    msg_count = len(messages)

    # Convert LangChain messages to OpenAI format, then trim proactively.
    formatted_messages = self._convert_messages(messages)
    formatted_messages = self._trim_messages_to_fit_context(formatted_messages)

    # Add tools if bound
    call_kwargs = {**kwargs}
    if self._bound_tools:
        call_kwargs["tools"] = self._bound_tools
        if self._tool_choice is not None:
            call_kwargs["tool_choice"] = self._tool_choice

    try:
        response = await self._router.acompletion(
            model=self.model,
            messages=formatted_messages,
            stop=stop,
            **call_kwargs,
        )
    except ContextWindowExceededError as e:
        perf.warning(
            "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
            msg_count,
            time.perf_counter() - t0,
        )
        raise ContextOverflowError(str(e)) from e
    except LiteLLMBadRequestError as e:
        # Some providers report overflow as a generic 400 — check message.
        if _is_context_overflow_error(e):
            perf.warning(
                "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
                msg_count,
                time.perf_counter() - t0,
            )
            raise ContextOverflowError(str(e)) from e
        raise

    elapsed = time.perf_counter() - t0
    perf.info(
        "[llm_router] _agenerate completed msgs=%d tools=%d in %.3fs",
        msg_count,
        len(self._bound_tools) if self._bound_tools else 0,
        elapsed,
    )

    # Convert response to ChatResult with potential tool calls
    message = self._convert_response_to_message(response.choices[0].message)
    generation = ChatGeneration(message=message)

    return ChatResult(generations=[generation])
|
|
|
|
def _stream(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: CallbackManagerForLLMRun | None = None,
    **kwargs: Any,
):
    """
    Stream a response using the router (synchronous).

    Yields ``ChatGenerationChunk`` objects as deltas arrive.

    NOTE(review): unlike _generate/_agenerate/_astream, this path has no
    perf-logger instrumentation — confirm whether that is intentional.

    Raises:
        ValueError: If the router is not initialized.
        ContextOverflowError: On provider context-window rejection.
    """
    if not self._router:
        raise ValueError("Router not initialized")

    # Convert to OpenAI format and trim proactively before streaming.
    formatted_messages = self._convert_messages(messages)
    formatted_messages = self._trim_messages_to_fit_context(formatted_messages)

    # Add tools if bound
    call_kwargs = {**kwargs}
    if self._bound_tools:
        call_kwargs["tools"] = self._bound_tools
        if self._tool_choice is not None:
            call_kwargs["tool_choice"] = self._tool_choice

    try:
        response = self._router.completion(
            model=self.model,
            messages=formatted_messages,
            stop=stop,
            stream=True,
            **call_kwargs,
        )
    except ContextWindowExceededError as e:
        raise ContextOverflowError(str(e)) from e
    except LiteLLMBadRequestError as e:
        # Providers may report overflow as a generic 400 — check message.
        if _is_context_overflow_error(e):
            raise ContextOverflowError(str(e)) from e
        raise

    # Yield chunks; empty deltas are skipped.
    for chunk in response:
        if hasattr(chunk, "choices") and chunk.choices:
            delta = chunk.choices[0].delta
            chunk_msg = self._convert_delta_to_chunk(delta)
            if chunk_msg:
                yield ChatGenerationChunk(message=chunk_msg)
|
|
|
|
async def _astream(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: CallbackManagerForLLMRun | None = None,
    **kwargs: Any,
):
    """
    Stream a response using the router (asynchronous).

    Yields ``ChatGenerationChunk`` objects and logs connection,
    first-chunk, and completion timings to the perf logger.

    Raises:
        ValueError: If the router is not initialized.
        ContextOverflowError: On provider context-window rejection.
    """
    if not self._router:
        raise ValueError("Router not initialized")

    perf = get_perf_logger()
    t0 = time.perf_counter()
    msg_count = len(messages)

    # Convert to OpenAI format and trim proactively before streaming.
    formatted_messages = self._convert_messages(messages)
    formatted_messages = self._trim_messages_to_fit_context(formatted_messages)

    # Add tools if bound
    call_kwargs = {**kwargs}
    if self._bound_tools:
        call_kwargs["tools"] = self._bound_tools
        if self._tool_choice is not None:
            call_kwargs["tool_choice"] = self._tool_choice

    try:
        response = await self._router.acompletion(
            model=self.model,
            messages=formatted_messages,
            stop=stop,
            stream=True,
            **call_kwargs,
        )
    except ContextWindowExceededError as e:
        perf.warning(
            "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
            msg_count,
            time.perf_counter() - t0,
        )
        raise ContextOverflowError(str(e)) from e
    except LiteLLMBadRequestError as e:
        # Providers may report overflow as a generic 400 — check message.
        if _is_context_overflow_error(e):
            perf.warning(
                "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
                msg_count,
                time.perf_counter() - t0,
            )
            raise ContextOverflowError(str(e)) from e
        raise

    # Connection established (acompletion returned the async iterator).
    t_first_chunk = time.perf_counter()
    perf.info(
        "[llm_router] _astream connection established msgs=%d in %.3fs",
        msg_count,
        t_first_chunk - t0,
    )

    chunk_count = 0
    first_chunk_logged = False
    async for chunk in response:
        if hasattr(chunk, "choices") and chunk.choices:
            delta = chunk.choices[0].delta
            chunk_msg = self._convert_delta_to_chunk(delta)
            if chunk_msg:
                chunk_count += 1
                # Log time-to-first-token once.
                if not first_chunk_logged:
                    perf.info(
                        "[llm_router] _astream first chunk in %.3fs (total %.3fs from start)",
                        time.perf_counter() - t_first_chunk,
                        time.perf_counter() - t0,
                    )
                    first_chunk_logged = True
                yield ChatGenerationChunk(message=chunk_msg)

    perf.info(
        "[llm_router] _astream completed chunks=%d total=%.3fs",
        chunk_count,
        time.perf_counter() - t0,
    )
|
|
|
|
def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]:
|
|
"""Convert LangChain messages to OpenAI format."""
|
|
from langchain_core.messages import (
|
|
AIMessage as AIMsg,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
ToolMessage,
|
|
)
|
|
|
|
result = []
|
|
for msg in messages:
|
|
if isinstance(msg, SystemMessage):
|
|
result.append({"role": "system", "content": msg.content})
|
|
elif isinstance(msg, HumanMessage):
|
|
result.append({"role": "user", "content": msg.content})
|
|
elif isinstance(msg, AIMsg):
|
|
ai_msg: dict[str, Any] = {"role": "assistant"}
|
|
if msg.content:
|
|
ai_msg["content"] = msg.content
|
|
# Handle tool calls
|
|
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
|
ai_msg["tool_calls"] = [
|
|
{
|
|
"id": tc.get("id", ""),
|
|
"type": "function",
|
|
"function": {
|
|
"name": tc.get("name", ""),
|
|
"arguments": tc.get("args", "{}")
|
|
if isinstance(tc.get("args"), str)
|
|
else __import__("json").dumps(tc.get("args", {})),
|
|
},
|
|
}
|
|
for tc in msg.tool_calls
|
|
]
|
|
result.append(ai_msg)
|
|
elif isinstance(msg, ToolMessage):
|
|
result.append(
|
|
{
|
|
"role": "tool",
|
|
"tool_call_id": msg.tool_call_id,
|
|
"content": msg.content
|
|
if isinstance(msg.content, str)
|
|
else __import__("json").dumps(msg.content),
|
|
}
|
|
)
|
|
else:
|
|
# Fallback for other message types
|
|
role = getattr(msg, "type", "user")
|
|
if role == "human":
|
|
role = "user"
|
|
elif role == "ai":
|
|
role = "assistant"
|
|
result.append({"role": role, "content": msg.content})
|
|
|
|
return result
|
|
|
|
def _convert_response_to_message(self, response_message: Any) -> AIMessage:
|
|
"""Convert a LiteLLM response message to a LangChain AIMessage."""
|
|
import json
|
|
|
|
content = getattr(response_message, "content", None) or ""
|
|
|
|
# Check for tool calls
|
|
tool_calls = []
|
|
if hasattr(response_message, "tool_calls") and response_message.tool_calls:
|
|
for tc in response_message.tool_calls:
|
|
tool_call = {
|
|
"id": tc.id if hasattr(tc, "id") else "",
|
|
"name": tc.function.name if hasattr(tc, "function") else "",
|
|
"args": {},
|
|
}
|
|
# Parse arguments
|
|
if hasattr(tc, "function") and hasattr(tc.function, "arguments"):
|
|
try:
|
|
tool_call["args"] = json.loads(tc.function.arguments)
|
|
except json.JSONDecodeError:
|
|
tool_call["args"] = tc.function.arguments
|
|
tool_calls.append(tool_call)
|
|
|
|
if tool_calls:
|
|
return AIMessage(content=content, tool_calls=tool_calls)
|
|
return AIMessage(content=content)
|
|
|
|
def _convert_delta_to_chunk(self, delta: Any) -> AIMessageChunk | None:
|
|
"""Convert a streaming delta to an AIMessageChunk."""
|
|
|
|
content = getattr(delta, "content", None) or ""
|
|
|
|
# Check for tool calls in delta
|
|
tool_call_chunks = []
|
|
if hasattr(delta, "tool_calls") and delta.tool_calls:
|
|
for tc in delta.tool_calls:
|
|
chunk = {
|
|
"index": tc.index if hasattr(tc, "index") else 0,
|
|
"id": tc.id if hasattr(tc, "id") else None,
|
|
"name": tc.function.name
|
|
if hasattr(tc, "function") and hasattr(tc.function, "name")
|
|
else None,
|
|
"args": tc.function.arguments
|
|
if hasattr(tc, "function") and hasattr(tc.function, "arguments")
|
|
else "",
|
|
}
|
|
tool_call_chunks.append(chunk)
|
|
|
|
if content or tool_call_chunks:
|
|
if tool_call_chunks:
|
|
return AIMessageChunk(
|
|
content=content, tool_call_chunks=tool_call_chunks
|
|
)
|
|
return AIMessageChunk(content=content)
|
|
|
|
return None
|
|
|
|
|
|
def get_auto_mode_llm(
    *,
    streaming: bool = True,
) -> ChatLiteLLMRouter | None:
    """Return a cached ChatLiteLLMRouter for auto mode.

    One base (tool-free) instance is kept per ``streaming`` flag so repeated
    requests reuse it instead of rebuilding. ``bind_tools()`` still hands
    back a fresh instance, since each agent binds a different tool set.

    Returns:
        The shared router-backed chat model, or ``None`` when the router
        is not initialized or construction fails.
    """
    if not LLMRouterService.is_initialized():
        logger.warning("LLM Router not initialized for auto mode")
        return None

    # Fast path: reuse the previously built instance for this flag.
    existing = _router_instance_cache.get(streaming)
    if existing is not None:
        return existing

    try:
        fresh = ChatLiteLLMRouter(streaming=streaming)
    except Exception as e:
        logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
        return None
    _router_instance_cache[streaming] = fresh
    return fresh
|
|
|
|
|
def is_auto_mode(llm_config_id: int | None) -> bool:
    """Determine whether an LLM config ID refers to the special Auto mode.

    Args:
        llm_config_id: The config ID to test; may be ``None`` when no
            explicit configuration was selected.

    Returns:
        ``True`` when the ID equals ``AUTO_MODE_ID``, ``False`` otherwise.
    """
    return AUTO_MODE_ID == llm_config_id