SurfSense/surfsense_backend/app/services/llm_router_service.py
PR Bot 760aa38225 feat: complete MiniMax LLM provider integration
Add full MiniMax provider support across the entire stack:

Backend:
- Add MINIMAX to LiteLLMProvider enum in db.py
- Add MINIMAX mapping to all provider_map dicts in llm_service.py,
  llm_router_service.py, and llm_config.py
- Add Alembic migration (rev 106) for PostgreSQL enum
- Add MiniMax M2.5 example in global_llm_config.example.yaml

Frontend:
- Add MiniMax to LLM_PROVIDERS enum with apiBase
- Add MiniMax-M2.5 and MiniMax-M2.5-highspeed to LLM_MODELS
- Add MINIMAX to Zod validation schema
- Add MiniMax SVG icon and wire up in provider-icons

Docs:
- Add MiniMax setup guide in chinese-llm-setup.md

MiniMax uses an OpenAI-compatible API (https://api.minimax.io/v1)
with models supporting up to 204K context window.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-13 07:27:47 +08:00

1173 lines
42 KiB
Python

"""
LiteLLM Router Service for Load Balancing
This module provides a singleton LiteLLM Router for automatic load balancing
across multiple LLM deployments. It handles:
- Rate limit management with automatic cooldowns
- Automatic failover and retries
- Usage-based routing to distribute load evenly
The router is initialized from global LLM configs and provides both
synchronous ChatLiteLLM-like interface and async methods.
"""
import copy
import logging
import re
import time
from typing import Any
import litellm
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.exceptions import ContextOverflowError
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from litellm import Router
from litellm.exceptions import (
BadRequestError as LiteLLMBadRequestError,
ContextWindowExceededError,
)
from app.utils.perf import get_perf_logger
# Switch off LiteLLM's module-level ``json_logs`` and ``store_audit_logs`` flags.
litellm.json_logs = False
litellm.store_audit_logs = False
# Module logger; named after this module's import path.
logger = logging.getLogger(__name__)
# Phrases that this module treats as evidence that a BadRequestError is
# really a context-window overflow. Compiled once and matched
# case-insensitively against the stringified exception (see
# _is_context_overflow_error).
_CONTEXT_OVERFLOW_PATTERNS = re.compile(
    r"(input tokens exceed|context.{0,20}(length|window|limit)|"
    r"maximum context length|token.{0,20}(limit|exceed)|"
    r"too many tokens|reduce the length)",
    re.IGNORECASE,
)
def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool:
    """Tell whether a BadRequestError's text looks like a context overflow.

    The exception is stringified and searched with
    ``_CONTEXT_OVERFLOW_PATTERNS``; any match counts as an overflow.
    """
    error_text = str(exc)
    match = _CONTEXT_OVERFLOW_PATTERNS.search(error_text)
    return match is not None
# Special ID for Auto mode - uses router for load balancing
AUTO_MODE_ID = 0
# Provider mapping for LiteLLM model string construction.
# Keys are upper-case provider names as they appear in config; values are the
# prefix placed before "/<model_name>" when the LiteLLM model string is built
# (see LLMRouterService._config_to_deployment). Several providers map to
# "openai", i.e. they are addressed through the OpenAI-style prefix.
PROVIDER_MAP = {
    "OPENAI": "openai",
    "ANTHROPIC": "anthropic",
    "GROQ": "groq",
    "COHERE": "cohere",
    "GOOGLE": "gemini",
    "OLLAMA": "ollama_chat",
    "MISTRAL": "mistral",
    "AZURE_OPENAI": "azure",
    "OPENROUTER": "openrouter",
    "COMETAPI": "cometapi",
    "XAI": "xai",
    "BEDROCK": "bedrock",
    "AWS_BEDROCK": "bedrock",  # Legacy support
    "VERTEX_AI": "vertex_ai",
    "TOGETHER_AI": "together_ai",
    "FIREWORKS_AI": "fireworks_ai",
    "REPLICATE": "replicate",
    "PERPLEXITY": "perplexity",
    "ANYSCALE": "anyscale",
    "DEEPINFRA": "deepinfra",
    "CEREBRAS": "cerebras",
    "SAMBANOVA": "sambanova",
    "AI21": "ai21",
    "CLOUDFLARE": "cloudflare",
    "DATABRICKS": "databricks",
    "DEEPSEEK": "openai",
    "ALIBABA_QWEN": "openai",
    "MOONSHOT": "openai",
    "ZHIPU": "openai",
    "GITHUB_MODELS": "github",
    "HUGGINGFACE": "huggingface",
    "MINIMAX": "openai",
    "CUSTOM": "custom",
}
class LLMRouterService:
    """
    Singleton service for managing the LiteLLM Router.

    The router provides automatic load balancing, failover, and rate limit
    handling across multiple LLM deployments. Every valid global config is
    registered under the shared ``"auto"`` model group; deployments with a
    larger-than-minimum context window are additionally registered as
    ``"auto-large"`` so the router can fall back to them when a request
    overflows the context window.
    """

    # Process-wide singleton, created lazily by __new__ / get_instance().
    _instance = None
    # The LiteLLM Router; stays None until initialize() succeeds.
    _router: Router | None = None
    # Deployments built from the global configs (without "auto-large" copies).
    _model_list: list[dict] = []
    # Caller-supplied router settings, as passed to initialize().
    _router_settings: dict = {}
    # True once a Router has been constructed successfully.
    _initialized: bool = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def get_instance(cls) -> "LLMRouterService":
        """Get the singleton instance of the router service."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def initialize(
        cls,
        global_configs: list[dict],
        router_settings: dict | None = None,
    ) -> None:
        """
        Initialize the router with global LLM configurations.

        Does nothing if the router is already initialized or if no config
        converts to a valid deployment. A failure while constructing the
        Router is logged and leaves the service uninitialized.

        Args:
            global_configs: List of global LLM config dictionaries from YAML
            router_settings: Optional router settings (routing_strategy,
                num_retries, allowed_fails, cooldown_time, retry_after).
                Provided keys override the defaults below.
        """
        instance = cls.get_instance()
        if instance._initialized:
            logger.debug("LLM Router already initialized, skipping")
            return

        # Build model list from global configs, dropping invalid entries.
        model_list = [
            deployment
            for deployment in map(cls._config_to_deployment, global_configs)
            if deployment
        ]
        if not model_list:
            logger.warning("No valid LLM configs found for router initialization")
            return

        instance._model_list = model_list
        instance._router_settings = router_settings or {}

        # Default router settings optimized for rate limit handling
        default_settings = {
            "routing_strategy": "usage-based-routing",  # Best for rate limit management
            "num_retries": 3,
            "allowed_fails": 3,
            "cooldown_time": 60,  # Cooldown for 60 seconds after failures
            "retry_after": 5,  # Wait 5 seconds between retries
        }
        # Merge with provided settings
        final_settings = {**default_settings, **instance._router_settings}

        # Build an "auto-large" fallback group with deployments whose context
        # window exceeds the smallest deployment. This lets the router
        # automatically fall back to a bigger-context model when a smaller
        # one hits ContextWindowExceededError.
        full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list)

        try:
            router_kwargs: dict[str, Any] = {
                "model_list": full_model_list,
                "routing_strategy": final_settings.get(
                    "routing_strategy", "usage-based-routing"
                ),
                "num_retries": final_settings.get("num_retries", 3),
                "allowed_fails": final_settings.get("allowed_fails", 3),
                "cooldown_time": final_settings.get("cooldown_time", 60),
                # Previously merged into final_settings but never forwarded,
                # so the configured retry delay had no effect.
                "retry_after": final_settings.get("retry_after", 5),
                "set_verbose": False,
            }
            if ctx_fallbacks:
                router_kwargs["context_window_fallbacks"] = ctx_fallbacks

            instance._router = Router(**router_kwargs)
            instance._initialized = True
            logger.info(
                "LLM Router initialized with %d deployments, "
                "strategy: %s, context_window_fallbacks: %s",
                len(model_list),
                final_settings.get("routing_strategy"),
                ctx_fallbacks or "none",
            )
        except Exception as e:
            # Keep the service usable-but-uninitialized; include the traceback
            # so misconfiguration can actually be diagnosed.
            logger.exception("Failed to initialize LLM Router: %s", e)
            instance._router = None

    @classmethod
    def _build_context_fallback_groups(
        cls, model_list: list[dict]
    ) -> tuple[list[dict], list[dict[str, list[str]]] | None]:
        """Create an ``auto-large`` model group for context-window fallbacks.

        Uses ``litellm.get_model_info`` to discover the context window of each
        deployment. Deployments whose ``max_input_tokens`` exceeds the smallest
        window are duplicated into an ``auto-large`` group. The returned
        fallback config tells the Router: on ``ContextWindowExceededError`` for
        ``auto``, retry with ``auto-large``.

        Deployments for which model info cannot be obtained are never placed
        in the ``auto-large`` group.

        Returns:
            (full_model_list, context_window_fallbacks) — ``full_model_list``
            contains the original entries plus any ``auto-large`` duplicates.
            ``context_window_fallbacks`` is ``None`` when no deployment has a
            larger window than the smallest known one (no useful fallback).
        """
        from litellm import get_model_info

        def _base_model(dep: dict) -> str:
            # Prefer an explicit base_model; otherwise use the model string.
            params = dep.get("litellm_params", {})
            return params.get("base_model") or params.get("model", "")

        ctx_map: dict[str, int] = {}
        for dep in model_list:
            base_model = _base_model(dep)
            try:
                ctx = get_model_info(base_model).get("max_input_tokens")
            except Exception:
                # Model info unavailable for this deployment; skip it.
                continue
            if isinstance(ctx, int) and ctx > 0:
                ctx_map[base_model] = ctx

        if not ctx_map:
            return model_list, None

        min_ctx = min(ctx_map.values())
        large_deployments: list[dict] = [
            {**dep, "model_name": "auto-large"}
            for dep in model_list
            if ctx_map.get(_base_model(dep), 0) > min_ctx
        ]
        if not large_deployments:
            return model_list, None

        logger.info(
            "Context-window fallback: %d large-context deployments "
            "(min_ctx=%d) added to 'auto-large' group",
            len(large_deployments),
            min_ctx,
        )
        return model_list + large_deployments, [{"auto": ["auto-large"]}]

    @classmethod
    def _config_to_deployment(cls, config: dict) -> dict | None:
        """
        Convert a global LLM config to a router deployment entry.

        Args:
            config: Global LLM config dictionary. ``model_name`` and
                ``api_key`` are required; ``custom_provider``, ``provider``,
                ``api_base``, ``litellm_params``, ``rpm`` and ``tpm`` are
                optional.

        Returns:
            Router deployment dictionary or None if invalid
        """
        try:
            # Skip if essential fields are missing
            if not config.get("model_name") or not config.get("api_key"):
                return None

            model_name = config["model_name"]

            # Build model string
            if config.get("custom_provider"):
                provider_prefix = config["custom_provider"]
            else:
                # ``provider`` may be absent or explicitly None.
                provider = (config.get("provider") or "").upper()
                provider_prefix = PROVIDER_MAP.get(provider, provider.lower())

            # Without any provider, pass the bare model name instead of
            # producing a malformed "/<model>" string.
            model_string = (
                f"{provider_prefix}/{model_name}" if provider_prefix else model_name
            )

            # Build litellm params
            litellm_params = {
                "model": model_string,
                "api_key": config.get("api_key"),
            }
            # Add optional api_base
            if config.get("api_base"):
                litellm_params["api_base"] = config["api_base"]
            # Add any additional litellm parameters (these may override the above)
            if config.get("litellm_params"):
                litellm_params.update(config["litellm_params"])

            deployment = {
                "model_name": "auto",  # All configs use same alias for unified routing
                "litellm_params": litellm_params,
            }
            # Add rate limits from config if available
            if config.get("rpm"):
                deployment["rpm"] = config["rpm"]
            if config.get("tpm"):
                deployment["tpm"] = config["tpm"]
            return deployment
        except Exception as e:
            logger.warning(f"Failed to convert config to deployment: {e}")
            return None

    @classmethod
    def get_router(cls) -> Router | None:
        """Get the initialized router instance."""
        instance = cls.get_instance()
        return instance._router

    @classmethod
    def is_initialized(cls) -> bool:
        """Check if the router has been initialized."""
        instance = cls.get_instance()
        return instance._initialized and instance._router is not None

    @classmethod
    def get_model_count(cls) -> int:
        """Get the number of models in the router."""
        instance = cls.get_instance()
        return len(instance._model_list)
# Result of the one-time context-profile computation (see
# _get_cached_context_profile). None is a valid cached result, which is why a
# separate "computed" flag is kept alongside it.
_cached_context_profile: dict | None = None
_cached_context_profile_computed: bool = False
# Cached singleton instances keyed by (streaming,) to avoid re-creating on every call
_router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {}
def _get_cached_context_profile(router: Router) -> dict | None:
    """Build the router-wide context profile once and reuse it afterwards.

    The first call walks ``router.model_list`` and asks
    ``litellm.get_model_info`` for each deployment's ``max_input_tokens``;
    later calls return the stored result without touching the router again.

    The profile holds:

    - ``max_input_tokens``: the smallest known context window.
    - ``max_input_tokens_upper``: the largest known context window.
    - ``token_count_model``: the first deployment model with a known window.
    - ``token_count_models``: that model plus the smallest- and
      largest-window models, without duplicates.

    Returns ``None`` (also cached) when no deployment reports a usable window.
    """
    global _cached_context_profile, _cached_context_profile_computed

    if _cached_context_profile_computed:
        return _cached_context_profile

    from litellm import get_model_info

    # (window, model) for every deployment with a positive integer window,
    # kept in router order.
    known_windows: list[tuple[int, str]] = []
    for entry in router.model_list:
        llm_params = entry.get("litellm_params", {})
        model_name = llm_params.get("base_model") or llm_params.get("model", "")
        try:
            window = get_model_info(model_name).get("max_input_tokens")
        except Exception:
            continue
        if isinstance(window, int) and window > 0:
            known_windows.append((window, model_name))

    profile: dict | None = None
    if known_windows:
        first_model = known_windows[0][1]
        # Stable sort: ties keep router order.
        by_window = sorted(known_windows, key=lambda pair: pair[0])
        smallest_window, smallest_model = by_window[0]
        largest_window, largest_model = by_window[-1]

        tokenizer_models: list[str] = [first_model] if first_model else []
        for candidate in (smallest_model, largest_model):
            if candidate not in tokenizer_models:
                tokenizer_models.append(candidate)

        logger.info(
            "ChatLiteLLMRouter profile: max_input_tokens=%d, upper=%s, token_models=%s",
            smallest_window,
            largest_window,
            tokenizer_models,
        )
        profile = {
            "max_input_tokens": smallest_window,
            "max_input_tokens_upper": largest_window,
            "token_count_model": first_model,
            "token_count_models": tokenizer_models,
        }

    _cached_context_profile = profile
    _cached_context_profile_computed = True
    return _cached_context_profile
class ChatLiteLLMRouter(BaseChatModel):
    """
    A LangChain-compatible chat model that uses LiteLLM Router for load balancing.

    This wraps the LiteLLM Router to provide the same interface as ChatLiteLLM,
    making it a drop-in replacement for auto-mode routing.

    Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
    window across all router deployments so that deepagents
    SummarizationMiddleware can use fraction-based triggers.

    Instances are not cached by the constructor itself; use
    ``get_auto_mode_llm()`` to obtain the shared tool-less instance for a
    given ``streaming`` flag. ``bind_tools()`` always returns a new instance.
    """

    # Use model_config for Pydantic v2 compatibility
    model_config = {"arbitrary_types_allowed": True}

    # Public attributes that Pydantic will manage
    model: str = "auto"
    streaming: bool = True

    # Bound tools and tool choice for tool calling
    _bound_tools: list[dict] | None = None
    _tool_choice: str | dict | None = None
    _router: Router | None = None

    def __init__(
        self,
        router: Router | None = None,
        bound_tools: list[dict] | None = None,
        tool_choice: str | dict | None = None,
        **kwargs,
    ):
        """Create a router-backed chat model.

        Args:
            router: Router to use; defaults to ``LLMRouterService.get_router()``.
            bound_tools: OpenAI-format tool definitions sent with every call.
            tool_choice: Tool choice forwarded together with ``bound_tools``.
            **kwargs: Passed to ``BaseChatModel`` (e.g. ``model``, ``streaming``).

        Raises:
            ValueError: If no router is supplied and none has been initialized.
        """
        try:
            super().__init__(**kwargs)
            resolved_router = router or LLMRouterService.get_router()
            # Private attributes are set via object.__setattr__ to bypass
            # Pydantic's attribute handling.
            object.__setattr__(self, "_router", resolved_router)
            object.__setattr__(self, "_bound_tools", bound_tools)
            object.__setattr__(self, "_tool_choice", tool_choice)
            if not self._router:
                raise ValueError(
                    "LLM Router not initialized. Call LLMRouterService.initialize() first."
                )
            computed_profile = _get_cached_context_profile(self._router)
            if computed_profile is not None:
                object.__setattr__(self, "profile", computed_profile)
            logger.debug(
                "ChatLiteLLMRouter ready (models=%d, streaming=%s, has_tools=%s)",
                LLMRouterService.get_model_count(),
                self.streaming,
                bound_tools is not None,
            )
        except Exception as e:
            logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
            raise

    # -----------------------------------------------------------------
    # Context-aware trimming helpers
    # -----------------------------------------------------------------
    def _get_token_count_model_names(self) -> list[str]:
        """Return concrete model names usable by ``litellm.token_counter``.

        The router uses ``"auto"`` as the model group name but tokenizers need
        concrete model identifiers. We keep multiple candidates and take the
        most conservative count across them. Falls back to ``["gpt-4o"]``
        when neither the profile nor the router yields a name.
        """
        names: list[str] = []
        profile = getattr(self, "profile", None)
        if isinstance(profile, dict):
            tcms = profile.get("token_count_models")
            if isinstance(tcms, list):
                for name in tcms:
                    if isinstance(name, str) and name and name not in names:
                        names.append(name)
            tcm = profile.get("token_count_model")
            if isinstance(tcm, str) and tcm and tcm not in names:
                names.append(tcm)

        if self._router and self._router.model_list:
            for dep in self._router.model_list:
                params = dep.get("litellm_params", {})
                base = params.get("base_model") or params.get("model", "")
                if base and base not in names:
                    names.append(base)
                # A few candidates are enough; stop scanning deployments.
                if len(names) >= 3:
                    break

        if not names:
            return ["gpt-4o"]
        return names

    def _count_tokens(self, messages: list[dict]) -> int | None:
        """Return conservative token count across candidate deployment models.

        Returns the maximum count over all candidate models, or ``None`` when
        counting failed for every candidate.
        """
        from litellm import token_counter as _tc

        counts: list[int] = []
        for model_name in self._get_token_count_model_names():
            try:
                counts.append(_tc(messages=messages, model=model_name))
            except Exception:
                # Counting is best-effort per model; try the next candidate.
                continue
        return max(counts) if counts else None

    def _get_max_input_tokens(self) -> int:
        """Return the max input tokens to use for context trimming.

        Prefers the *largest* context window across all deployments so we
        maximise usable context (the router's ``context_window_fallbacks``
        handle routing to the right model). Falls back to the minimum
        profile value or a conservative default of 128,000.
        """
        profile = getattr(self, "profile", None)
        if isinstance(profile, dict):
            for key in ("max_input_tokens_upper", "max_input_tokens"):
                value = profile.get(key)
                if isinstance(value, int) and value > 0:
                    return value
        return 128_000

    def _trim_messages_to_fit_context(
        self,
        messages: list[dict],
        output_reserve_fraction: float = 0.10,
    ) -> list[dict]:
        """Trim message content via binary search to fit the model's context window.

        When the total token count exceeds the budget, this method picks the
        largest trimmable messages (tool/assistant outputs first, then user
        messages that contain document blobs) and binary-searches each one
        for the longest prefix that keeps the total within budget.
        Cutting prefers ``</document>`` XML boundaries so complete documents
        are preserved when possible. System messages are never trimmed.

        If that is not enough, remaining non-system messages are replaced by
        short placeholders and, as a last resort, emptied.

        Args:
            messages: OpenAI-format message dicts. Not mutated; a deep copy
                is trimmed when trimming is needed.
            output_reserve_fraction: Fraction of the context window held back
                for the model's output (capped at 16,384 tokens).

        Returns:
            The original list when it already fits (or tokens cannot be
            counted), otherwise a trimmed deep copy.
        """
        max_input = self._get_max_input_tokens()
        output_reserve = min(int(max_input * output_reserve_fraction), 16_384)
        budget = max_input - output_reserve

        total_tokens = self._count_tokens(messages)
        if total_tokens is None or total_tokens <= budget:
            return messages

        perf = get_perf_logger()
        perf.warning(
            "[llm_router] context overflow detected: %d tokens > %d budget "
            "(max_input=%d, reserve=%d). Trimming messages.",
            total_tokens,
            budget,
            max_input,
            output_reserve,
        )

        trimmed = copy.deepcopy(messages)

        # Per-message token counts for trimmable candidates.
        # Skip system messages to preserve agent instructions.
        msg_token_map: dict[int, int] = {}
        candidate_priority: dict[int, int] = {}
        for i, msg in enumerate(trimmed):
            role = msg.get("role")
            if role == "system":
                continue
            content = msg.get("content", "")
            if not isinstance(content, str) or len(content) < 500:
                continue
            # Prefer trimming tool/assistant outputs first.
            # User messages are only trimmed if they clearly contain injected
            # document context blobs.
            is_doc_blob = "<document>" in content or "<mentioned_documents>" in content
            if role in ("tool", "assistant"):
                candidate_priority[i] = 0
            elif role == "user" and is_doc_blob:
                candidate_priority[i] = 1
            else:
                continue
            token_count = self._count_tokens([msg])
            if token_count is not None:
                msg_token_map[i] = token_count

        if not msg_token_map:
            perf.warning("[llm_router] no trimmable messages found, returning as-is")
            return trimmed

        # Lowest priority number first, then largest messages first.
        candidates = sorted(
            msg_token_map.items(),
            key=lambda x: (candidate_priority.get(x[0], 9), -x[1]),
        )

        running_total = total_tokens
        trim_suffix = (
            "\n\n<!-- Content trimmed to fit model context window. "
            "Some documents were omitted. Refine your query or "
            "reduce top_k for different results. -->"
        )
        doc_end_tag = "</document>"

        for idx, orig_msg_tokens in candidates:
            if running_total <= budget:
                break
            content = trimmed[idx]["content"]

            # Binary search: find maximum content[:mid] that keeps total ≤ budget.
            # At least 200 characters are always kept.
            lo, hi = 200, len(content) - 1
            best = 200
            while lo <= hi:
                mid = (lo + hi) // 2
                trimmed[idx]["content"] = content[:mid] + trim_suffix
                new_msg_tokens = self._count_tokens([trimmed[idx]])
                if new_msg_tokens is None:
                    # Cannot measure this length; treat it as too long.
                    hi = mid - 1
                    continue
                projected_total = running_total - orig_msg_tokens + new_msg_tokens
                if projected_total <= budget:
                    best = mid
                    lo = mid + 1
                else:
                    hi = mid - 1

            # Prefer cutting at a </document> boundary for cleaner output
            last_doc_end = content[:best].rfind(doc_end_tag)
            if last_doc_end > min(200, best // 4):
                best = last_doc_end + len(doc_end_tag)
            trimmed[idx]["content"] = content[:best] + trim_suffix

            new_msg_tokens = self._count_tokens([trimmed[idx]])
            if new_msg_tokens is not None:
                running_total = running_total - orig_msg_tokens + new_msg_tokens

        # Hard guarantee: if still over budget, replace remaining large
        # non-system messages with compact placeholders until we fit.
        if running_total > budget:
            fallback_indices: list[int] = [
                i
                for i, msg in enumerate(trimmed)
                if msg.get("role") != "system"
                and isinstance(msg.get("content"), str)
                and len(msg.get("content")) > 0
            ]
            for idx in fallback_indices:
                if running_total <= budget:
                    break
                role = trimmed[idx].get("role", "message")
                placeholder = (
                    f"[content omitted to fit model context window; role={role}]"
                )
                old_tokens = self._count_tokens([trimmed[idx]]) or 0
                trimmed[idx]["content"] = placeholder
                new_tokens = self._count_tokens([trimmed[idx]]) or 0
                running_total = running_total - old_tokens + new_tokens

            if running_total > budget:
                perf.error(
                    "[llm_router] unable to fit context even after aggressive trimming: "
                    "tokens=%d budget=%d",
                    running_total,
                    budget,
                )
                # Final safety net: clear oldest non-system contents.
                for idx in fallback_indices:
                    if running_total <= budget:
                        break
                    old_tokens = self._count_tokens([trimmed[idx]]) or 0
                    trimmed[idx]["content"] = ""
                    new_tokens = self._count_tokens([trimmed[idx]]) or 0
                    running_total = running_total - old_tokens + new_tokens

        # The two counts used to be printed fused together ("%d%d").
        perf.info(
            "[llm_router] messages trimmed: %d -> %d tokens (budget=%d, max_input=%d)",
            total_tokens,
            running_total,
            budget,
            max_input,
        )
        return trimmed

    # -----------------------------------------------------------------
    @property
    def _llm_type(self) -> str:
        return "litellm-router"

    @property
    def _identifying_params(self) -> dict[str, Any]:
        return {
            "model": self.model,
            "model_count": LLMRouterService.get_model_count(),
        }

    def bind_tools(
        self,
        tools: list[Any],
        *,
        tool_choice: str | dict | None = None,
        **kwargs: Any,
    ) -> "ChatLiteLLMRouter":
        """
        Bind tools to the model for function/tool calling.

        Tools that cannot be converted are logged and skipped.

        Args:
            tools: List of tools to bind (can be LangChain tools, Pydantic models, or dicts)
            tool_choice: Optional tool choice strategy ("auto", "required", "none", or specific tool)
            **kwargs: Additional arguments passed to the new instance's constructor

        Returns:
            New ChatLiteLLMRouter instance with tools bound
        """
        from langchain_core.utils.function_calling import convert_to_openai_tool

        # Convert tools to OpenAI format
        formatted_tools = []
        for tool in tools:
            if isinstance(tool, dict):
                # Already in dict format
                formatted_tools.append(tool)
                continue
            # Convert using LangChain utility
            try:
                formatted_tools.append(convert_to_openai_tool(tool))
            except Exception as e:
                logger.warning(f"Failed to convert tool {tool}: {e}")

        # Create a new instance with tools bound
        return ChatLiteLLMRouter(
            router=self._router,
            bound_tools=formatted_tools if formatted_tools else None,
            tool_choice=tool_choice,
            model=self.model,
            streaming=self.streaming,
            **kwargs,
        )

    def _build_call_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of ``kwargs`` with bound tools and tool choice added.

        ``tool_choice`` is only forwarded when tools are bound, so a tool
        choice is never sent without accompanying tool definitions.
        """
        call_kwargs = {**kwargs}
        if self._bound_tools:
            call_kwargs["tools"] = self._bound_tools
            if self._tool_choice is not None:
                call_kwargs["tool_choice"] = self._tool_choice
        return call_kwargs

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Generate a response using the router (synchronous).

        Raises:
            ValueError: If the router is not initialized.
            ContextOverflowError: If the provider rejects the request for
                exceeding the context window.
        """
        if not self._router:
            raise ValueError("Router not initialized")

        perf = get_perf_logger()
        t0 = time.perf_counter()
        msg_count = len(messages)

        # Convert LangChain messages to OpenAI format
        formatted_messages = self._convert_messages(messages)
        formatted_messages = self._trim_messages_to_fit_context(formatted_messages)

        # Add tools if bound
        call_kwargs = self._build_call_kwargs(kwargs)

        try:
            response = self._router.completion(
                model=self.model,
                messages=formatted_messages,
                stop=stop,
                **call_kwargs,
            )
        except ContextWindowExceededError as e:
            perf.warning(
                "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
                msg_count,
                time.perf_counter() - t0,
            )
            raise ContextOverflowError(str(e)) from e
        except LiteLLMBadRequestError as e:
            # Some providers report overflow as a generic bad request.
            if _is_context_overflow_error(e):
                perf.warning(
                    "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs",
                    msg_count,
                    time.perf_counter() - t0,
                )
                raise ContextOverflowError(str(e)) from e
            raise

        elapsed = time.perf_counter() - t0
        perf.info(
            "[llm_router] _generate completed msgs=%d tools=%d in %.3fs",
            msg_count,
            len(self._bound_tools) if self._bound_tools else 0,
            elapsed,
        )

        # Convert response to ChatResult with potential tool calls
        message = self._convert_response_to_message(response.choices[0].message)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Generate a response using the router (asynchronous).

        Raises:
            ValueError: If the router is not initialized.
            ContextOverflowError: If the provider rejects the request for
                exceeding the context window.
        """
        if not self._router:
            raise ValueError("Router not initialized")

        perf = get_perf_logger()
        t0 = time.perf_counter()
        msg_count = len(messages)

        # Convert LangChain messages to OpenAI format
        formatted_messages = self._convert_messages(messages)
        formatted_messages = self._trim_messages_to_fit_context(formatted_messages)

        # Add tools if bound
        call_kwargs = self._build_call_kwargs(kwargs)

        try:
            response = await self._router.acompletion(
                model=self.model,
                messages=formatted_messages,
                stop=stop,
                **call_kwargs,
            )
        except ContextWindowExceededError as e:
            perf.warning(
                "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
                msg_count,
                time.perf_counter() - t0,
            )
            raise ContextOverflowError(str(e)) from e
        except LiteLLMBadRequestError as e:
            # Some providers report overflow as a generic bad request.
            if _is_context_overflow_error(e):
                perf.warning(
                    "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs",
                    msg_count,
                    time.perf_counter() - t0,
                )
                raise ContextOverflowError(str(e)) from e
            raise

        elapsed = time.perf_counter() - t0
        perf.info(
            "[llm_router] _agenerate completed msgs=%d tools=%d in %.3fs",
            msg_count,
            len(self._bound_tools) if self._bound_tools else 0,
            elapsed,
        )

        # Convert response to ChatResult with potential tool calls
        message = self._convert_response_to_message(response.choices[0].message)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ):
        """
        Stream a response using the router (synchronous).

        Yields ``ChatGenerationChunk`` objects for every delta that carries
        content or tool-call fragments.

        Raises:
            ValueError: If the router is not initialized.
            ContextOverflowError: If the provider rejects the request for
                exceeding the context window.
        """
        if not self._router:
            raise ValueError("Router not initialized")

        formatted_messages = self._convert_messages(messages)
        formatted_messages = self._trim_messages_to_fit_context(formatted_messages)

        # Add tools if bound
        call_kwargs = self._build_call_kwargs(kwargs)

        try:
            response = self._router.completion(
                model=self.model,
                messages=formatted_messages,
                stop=stop,
                stream=True,
                **call_kwargs,
            )
        except ContextWindowExceededError as e:
            raise ContextOverflowError(str(e)) from e
        except LiteLLMBadRequestError as e:
            if _is_context_overflow_error(e):
                raise ContextOverflowError(str(e)) from e
            raise

        # Yield chunks
        for chunk in response:
            if hasattr(chunk, "choices") and chunk.choices:
                delta = chunk.choices[0].delta
                chunk_msg = self._convert_delta_to_chunk(delta)
                if chunk_msg:
                    yield ChatGenerationChunk(message=chunk_msg)

    async def _astream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ):
        """
        Stream a response using the router (asynchronous).

        Yields ``ChatGenerationChunk`` objects for every delta that carries
        content or tool-call fragments, and logs connection, first-chunk and
        completion timings to the perf logger.

        Raises:
            ValueError: If the router is not initialized.
            ContextOverflowError: If the provider rejects the request for
                exceeding the context window.
        """
        if not self._router:
            raise ValueError("Router not initialized")

        perf = get_perf_logger()
        t0 = time.perf_counter()
        msg_count = len(messages)

        formatted_messages = self._convert_messages(messages)
        formatted_messages = self._trim_messages_to_fit_context(formatted_messages)

        # Add tools if bound
        call_kwargs = self._build_call_kwargs(kwargs)

        try:
            response = await self._router.acompletion(
                model=self.model,
                messages=formatted_messages,
                stop=stop,
                stream=True,
                **call_kwargs,
            )
        except ContextWindowExceededError as e:
            perf.warning(
                "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
                msg_count,
                time.perf_counter() - t0,
            )
            raise ContextOverflowError(str(e)) from e
        except LiteLLMBadRequestError as e:
            # Some providers report overflow as a generic bad request.
            if _is_context_overflow_error(e):
                perf.warning(
                    "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs",
                    msg_count,
                    time.perf_counter() - t0,
                )
                raise ContextOverflowError(str(e)) from e
            raise

        t_first_chunk = time.perf_counter()
        perf.info(
            "[llm_router] _astream connection established msgs=%d in %.3fs",
            msg_count,
            t_first_chunk - t0,
        )

        chunk_count = 0
        first_chunk_logged = False
        async for chunk in response:
            if hasattr(chunk, "choices") and chunk.choices:
                delta = chunk.choices[0].delta
                chunk_msg = self._convert_delta_to_chunk(delta)
                if chunk_msg:
                    chunk_count += 1
                    if not first_chunk_logged:
                        perf.info(
                            "[llm_router] _astream first chunk in %.3fs (total %.3fs from start)",
                            time.perf_counter() - t_first_chunk,
                            time.perf_counter() - t0,
                        )
                        first_chunk_logged = True
                    yield ChatGenerationChunk(message=chunk_msg)

        perf.info(
            "[llm_router] _astream completed chunks=%d total=%.3fs",
            chunk_count,
            time.perf_counter() - t0,
        )

    def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]:
        """Convert LangChain messages to OpenAI format.

        System, human, AI and tool messages map to the ``system``, ``user``,
        ``assistant`` and ``tool`` roles. Any other message type uses its
        ``type`` attribute as the role (``human``/``ai`` are normalised).
        """
        import json

        from langchain_core.messages import (
            AIMessage as AIMsg,
            HumanMessage,
            SystemMessage,
            ToolMessage,
        )

        result = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                result.append({"role": "system", "content": msg.content})
            elif isinstance(msg, HumanMessage):
                result.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMsg):
                ai_msg: dict[str, Any] = {"role": "assistant"}
                if msg.content:
                    ai_msg["content"] = msg.content
                # Handle tool calls
                if hasattr(msg, "tool_calls") and msg.tool_calls:
                    converted_calls = []
                    for tc in msg.tool_calls:
                        raw_args = tc.get("args", {})
                        # The OpenAI format carries arguments as a JSON string.
                        arguments = (
                            raw_args
                            if isinstance(raw_args, str)
                            else json.dumps(raw_args)
                        )
                        converted_calls.append(
                            {
                                "id": tc.get("id", ""),
                                "type": "function",
                                "function": {
                                    "name": tc.get("name", ""),
                                    "arguments": arguments,
                                },
                            }
                        )
                    ai_msg["tool_calls"] = converted_calls
                result.append(ai_msg)
            elif isinstance(msg, ToolMessage):
                result.append(
                    {
                        "role": "tool",
                        "tool_call_id": msg.tool_call_id,
                        "content": msg.content
                        if isinstance(msg.content, str)
                        else json.dumps(msg.content),
                    }
                )
            else:
                # Fallback for other message types
                role = getattr(msg, "type", "user")
                if role == "human":
                    role = "user"
                elif role == "ai":
                    role = "assistant"
                result.append({"role": role, "content": msg.content})
        return result

    def _convert_response_to_message(self, response_message: Any) -> AIMessage:
        """Convert a LiteLLM response message to a LangChain AIMessage.

        Tool calls whose arguments parse to a JSON object become
        ``tool_calls``. Calls whose arguments cannot be parsed (or do not
        decode to an object) are reported as ``invalid_tool_calls`` instead
        of being passed on with a non-dict ``args`` value. Missing or empty
        arguments are treated as an empty argument dict.
        """
        import json

        content = getattr(response_message, "content", None) or ""

        tool_calls: list[dict[str, Any]] = []
        invalid_tool_calls: list[dict[str, Any]] = []
        for tc in getattr(response_message, "tool_calls", None) or []:
            function = getattr(tc, "function", None)
            call_id = getattr(tc, "id", "")
            name = getattr(function, "name", None) or ""
            raw_args = getattr(function, "arguments", None)

            error: str | None = None
            if raw_args is None or raw_args == "":
                parsed: Any = {}
            elif isinstance(raw_args, dict):
                parsed = raw_args
            else:
                try:
                    parsed = json.loads(raw_args)
                except (json.JSONDecodeError, TypeError) as exc:
                    parsed = None
                    error = str(exc)
            if error is None and not isinstance(parsed, dict):
                error = "Tool call arguments did not decode to a JSON object"

            if error is not None:
                invalid_tool_calls.append(
                    {
                        "type": "invalid_tool_call",
                        "id": call_id,
                        "name": name,
                        "args": raw_args if isinstance(raw_args, str) else str(raw_args),
                        "error": error,
                    }
                )
                continue

            tool_calls.append({"id": call_id, "name": name, "args": parsed})

        if invalid_tool_calls:
            return AIMessage(
                content=content,
                tool_calls=tool_calls,
                invalid_tool_calls=invalid_tool_calls,
            )
        if tool_calls:
            return AIMessage(content=content, tool_calls=tool_calls)
        return AIMessage(content=content)

    def _convert_delta_to_chunk(self, delta: Any) -> AIMessageChunk | None:
        """Convert a streaming delta to an AIMessageChunk.

        Returns ``None`` when the delta carries neither content nor
        tool-call fragments.
        """
        content = getattr(delta, "content", None) or ""

        # Check for tool calls in delta
        tool_call_chunks = []
        for tc in getattr(delta, "tool_calls", None) or []:
            function = getattr(tc, "function", None)
            tool_call_chunks.append(
                {
                    "index": getattr(tc, "index", 0),
                    "id": getattr(tc, "id", None),
                    "name": getattr(function, "name", None),
                    "args": getattr(function, "arguments", ""),
                }
            )

        if tool_call_chunks:
            return AIMessageChunk(content=content, tool_call_chunks=tool_call_chunks)
        if content:
            return AIMessageChunk(content=content)
        return None
def get_auto_mode_llm(
    *,
    streaming: bool = True,
) -> ChatLiteLLMRouter | None:
    """Fetch the shared auto-mode ChatLiteLLMRouter, creating it on first use.

    One tool-less instance is kept per ``streaming`` value in
    ``_router_instance_cache``. Returns ``None`` when the router service is
    not initialized or when constructing the instance fails (the failure is
    logged).
    """
    if not LLMRouterService.is_initialized():
        logger.warning("LLM Router not initialized for auto mode")
        return None

    llm = _router_instance_cache.get(streaming)
    if llm is None:
        try:
            llm = ChatLiteLLMRouter(streaming=streaming)
        except Exception as e:
            logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
            return None
        _router_instance_cache[streaming] = llm
    return llm
def is_auto_mode(llm_config_id: int | None) -> bool:
    """
    Report whether an LLM config ID selects Auto mode.

    Args:
        llm_config_id: The LLM config ID to check (may be None)

    Returns:
        True when the ID equals ``AUTO_MODE_ID``, False otherwise
    """
    selects_auto = llm_config_id == AUTO_MODE_ID
    return selects_auto