""" LiteLLM Router Service for Load Balancing This module provides a singleton LiteLLM Router for automatic load balancing across multiple LLM deployments. It handles: - Rate limit management with automatic cooldowns - Automatic failover and retries - Usage-based routing to distribute load evenly The router is initialized from global LLM configs and provides both synchronous ChatLiteLLM-like interface and async methods. """ import copy import logging import re import time from typing import Any import litellm from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.exceptions import ContextOverflowError from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from litellm import Router from litellm.exceptions import ( BadRequestError as LiteLLMBadRequestError, ContextWindowExceededError, ) from app.utils.perf import get_perf_logger litellm.json_logs = False litellm.store_audit_logs = False logger = logging.getLogger(__name__) _CONTEXT_OVERFLOW_PATTERNS = re.compile( r"(input tokens exceed|context.{0,20}(length|window|limit)|" r"maximum context length|token.{0,20}(limit|exceed)|" r"too many tokens|reduce the length)", re.IGNORECASE, ) def _is_context_overflow_error(exc: LiteLLMBadRequestError) -> bool: """Check if a BadRequestError is actually a context window overflow.""" return bool(_CONTEXT_OVERFLOW_PATTERNS.search(str(exc))) # Special ID for Auto mode - uses router for load balancing AUTO_MODE_ID = 0 # Provider mapping for LiteLLM model string construction PROVIDER_MAP = { "OPENAI": "openai", "ANTHROPIC": "anthropic", "GROQ": "groq", "COHERE": "cohere", "GOOGLE": "gemini", "OLLAMA": "ollama_chat", "MISTRAL": "mistral", "AZURE_OPENAI": "azure", "OPENROUTER": "openrouter", "COMETAPI": "cometapi", "XAI": "xai", "BEDROCK": "bedrock", "AWS_BEDROCK": "bedrock", # Legacy support "VERTEX_AI": "vertex_ai", "TOGETHER_AI": "together_ai", "FIREWORKS_AI": "fireworks_ai", "REPLICATE": "replicate", "PERPLEXITY": "perplexity", "ANYSCALE": "anyscale", "DEEPINFRA": "deepinfra", "CEREBRAS": "cerebras", "SAMBANOVA": "sambanova", "AI21": "ai21", "CLOUDFLARE": "cloudflare", "DATABRICKS": "databricks", "DEEPSEEK": "openai", "ALIBABA_QWEN": "openai", "MOONSHOT": "openai", "ZHIPU": "openai", "GITHUB_MODELS": "github", "HUGGINGFACE": "huggingface", "MINIMAX": "openai", "CUSTOM": "custom", } class LLMRouterService: """ Singleton service for managing LiteLLM Router. The router provides automatic load balancing, failover, and rate limit handling across multiple LLM deployments. """ _instance = None _router: Router | None = None _model_list: list[dict] = [] _router_settings: dict = {} _initialized: bool = False def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance @classmethod def get_instance(cls) -> "LLMRouterService": """Get the singleton instance of the router service.""" if cls._instance is None: cls._instance = cls() return cls._instance @classmethod def initialize( cls, global_configs: list[dict], router_settings: dict | None = None, ) -> None: """ Initialize the router with global LLM configurations. Args: global_configs: List of global LLM config dictionaries from YAML router_settings: Optional router settings (routing_strategy, num_retries, etc.) """ instance = cls.get_instance() if instance._initialized: logger.debug("LLM Router already initialized, skipping") return # Build model list from global configs model_list = [] for config in global_configs: deployment = cls._config_to_deployment(config) if deployment: model_list.append(deployment) if not model_list: logger.warning("No valid LLM configs found for router initialization") return instance._model_list = model_list instance._router_settings = router_settings or {} # Default router settings optimized for rate limit handling default_settings = { "routing_strategy": "usage-based-routing", # Best for rate limit management "num_retries": 3, "allowed_fails": 3, "cooldown_time": 60, # Cooldown for 60 seconds after failures "retry_after": 5, # Wait 5 seconds between retries } # Merge with provided settings final_settings = {**default_settings, **instance._router_settings} # Build a "auto-large" fallback group with deployments whose context # window exceeds the smallest deployment. This lets the router # automatically fall back to a bigger-context model when gpt-4o (128K) # hits ContextWindowExceededError. full_model_list, ctx_fallbacks = cls._build_context_fallback_groups(model_list) try: router_kwargs: dict[str, Any] = { "model_list": full_model_list, "routing_strategy": final_settings.get( "routing_strategy", "usage-based-routing" ), "num_retries": final_settings.get("num_retries", 3), "allowed_fails": final_settings.get("allowed_fails", 3), "cooldown_time": final_settings.get("cooldown_time", 60), "set_verbose": False, } if ctx_fallbacks: router_kwargs["context_window_fallbacks"] = ctx_fallbacks instance._router = Router(**router_kwargs) instance._initialized = True logger.info( "LLM Router initialized with %d deployments, " "strategy: %s, context_window_fallbacks: %s", len(model_list), final_settings.get("routing_strategy"), ctx_fallbacks or "none", ) except Exception as e: logger.error(f"Failed to initialize LLM Router: {e}") instance._router = None @classmethod def _build_context_fallback_groups( cls, model_list: list[dict] ) -> tuple[list[dict], list[dict[str, list[str]]] | None]: """Create an ``auto-large`` model group for context-window fallbacks. Uses ``litellm.get_model_info`` to discover the context window of each deployment. Deployments whose ``max_input_tokens`` exceeds the smallest window are duplicated into an ``auto-large`` group. The returned fallback config tells the Router: on ``ContextWindowExceededError`` for ``auto``, retry with ``auto-large``. Returns: (full_model_list, context_window_fallbacks) — ``full_model_list`` contains the original entries plus any ``auto-large`` duplicates. ``context_window_fallbacks`` is ``None`` when every deployment has the same context size (no useful fallback). """ from litellm import get_model_info ctx_map: dict[str, int] = {} for dep in model_list: params = dep.get("litellm_params", {}) base_model = params.get("base_model") or params.get("model", "") try: info = get_model_info(base_model) ctx = info.get("max_input_tokens") if isinstance(ctx, int) and ctx > 0: ctx_map[base_model] = ctx except Exception: continue if not ctx_map: return model_list, None min_ctx = min(ctx_map.values()) large_deployments: list[dict] = [] for dep in model_list: params = dep.get("litellm_params", {}) base_model = params.get("base_model") or params.get("model", "") if ctx_map.get(base_model, 0) > min_ctx: dup = {**dep, "model_name": "auto-large"} large_deployments.append(dup) if not large_deployments: return model_list, None logger.info( "Context-window fallback: %d large-context deployments " "(min_ctx=%d) added to 'auto-large' group", len(large_deployments), min_ctx, ) return model_list + large_deployments, [{"auto": ["auto-large"]}] @classmethod def _config_to_deployment(cls, config: dict) -> dict | None: """ Convert a global LLM config to a router deployment entry. Args: config: Global LLM config dictionary Returns: Router deployment dictionary or None if invalid """ try: # Skip if essential fields are missing if not config.get("model_name") or not config.get("api_key"): return None # Build model string if config.get("custom_provider"): model_string = f"{config['custom_provider']}/{config['model_name']}" else: provider = config.get("provider", "").upper() provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) model_string = f"{provider_prefix}/{config['model_name']}" # Build litellm params litellm_params = { "model": model_string, "api_key": config.get("api_key"), } # Add optional api_base if config.get("api_base"): litellm_params["api_base"] = config["api_base"] # Add any additional litellm parameters if config.get("litellm_params"): litellm_params.update(config["litellm_params"]) # Extract rate limits if provided deployment = { "model_name": "auto", # All configs use same alias for unified routing "litellm_params": litellm_params, } # Add rate limits from config if available if config.get("rpm"): deployment["rpm"] = config["rpm"] if config.get("tpm"): deployment["tpm"] = config["tpm"] return deployment except Exception as e: logger.warning(f"Failed to convert config to deployment: {e}") return None @classmethod def get_router(cls) -> Router | None: """Get the initialized router instance.""" instance = cls.get_instance() return instance._router @classmethod def is_initialized(cls) -> bool: """Check if the router has been initialized.""" instance = cls.get_instance() return instance._initialized and instance._router is not None @classmethod def get_model_count(cls) -> int: """Get the number of models in the router.""" instance = cls.get_instance() return len(instance._model_list) _cached_context_profile: dict | None = None _cached_context_profile_computed: bool = False # Cached singleton instances keyed by (streaming,) to avoid re-creating on every call _router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {} def _get_cached_context_profile(router: Router) -> dict | None: """Compute and cache context profile across all router deployments. Called once on first ChatLiteLLMRouter creation; subsequent calls return the cached value. This avoids calling litellm.get_model_info() for every deployment on every request. Caches both ``max_input_tokens`` (minimum across deployments, used by SummarizationMiddleware) and ``max_input_tokens_upper`` (maximum across deployments, used for context-trimming so we can target the largest available model — the router's fallback logic handles smaller ones). """ global _cached_context_profile, _cached_context_profile_computed if _cached_context_profile_computed: return _cached_context_profile from litellm import get_model_info min_ctx: int | None = None max_ctx: int | None = None token_count_model: str | None = None ctx_pairs: list[tuple[int, str]] = [] for deployment in router.model_list: params = deployment.get("litellm_params", {}) base_model = params.get("base_model") or params.get("model", "") try: info = get_model_info(base_model) ctx = info.get("max_input_tokens") if isinstance(ctx, int) and ctx > 0: if min_ctx is None or ctx < min_ctx: min_ctx = ctx if max_ctx is None or ctx > max_ctx: max_ctx = ctx if token_count_model is None: token_count_model = base_model ctx_pairs.append((ctx, base_model)) except Exception: continue if min_ctx is not None: token_count_models: list[str] = [] if token_count_model: token_count_models.append(token_count_model) if ctx_pairs: ctx_pairs.sort(key=lambda x: x[0]) smallest_model = ctx_pairs[0][1] largest_model = ctx_pairs[-1][1] if smallest_model not in token_count_models: token_count_models.append(smallest_model) if largest_model not in token_count_models: token_count_models.append(largest_model) logger.info( "ChatLiteLLMRouter profile: max_input_tokens=%d, upper=%s, token_models=%s", min_ctx, max_ctx, token_count_models, ) _cached_context_profile = { "max_input_tokens": min_ctx, "max_input_tokens_upper": max_ctx, "token_count_model": token_count_model, "token_count_models": token_count_models, } else: _cached_context_profile = None _cached_context_profile_computed = True return _cached_context_profile class ChatLiteLLMRouter(BaseChatModel): """ A LangChain-compatible chat model that uses LiteLLM Router for load balancing. This wraps the LiteLLM Router to provide the same interface as ChatLiteLLM, making it a drop-in replacement for auto-mode routing. Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context window across all router deployments so that deepagents SummarizationMiddleware can use fraction-based triggers. **Singleton-ish**: Use ``get_auto_mode_llm()`` or call ``ChatLiteLLMRouter()`` directly — instances without bound tools are cached per streaming flag to avoid per-request re-initialization overhead and memory growth. """ # Use model_config for Pydantic v2 compatibility model_config = {"arbitrary_types_allowed": True} # Public attributes that Pydantic will manage model: str = "auto" streaming: bool = True # Bound tools and tool choice for tool calling _bound_tools: list[dict] | None = None _tool_choice: str | dict | None = None _router: Router | None = None def __init__( self, router: Router | None = None, bound_tools: list[dict] | None = None, tool_choice: str | dict | None = None, **kwargs, ): try: super().__init__(**kwargs) resolved_router = router or LLMRouterService.get_router() object.__setattr__(self, "_router", resolved_router) object.__setattr__(self, "_bound_tools", bound_tools) object.__setattr__(self, "_tool_choice", tool_choice) if not self._router: raise ValueError( "LLM Router not initialized. Call LLMRouterService.initialize() first." ) computed_profile = _get_cached_context_profile(self._router) if computed_profile is not None: object.__setattr__(self, "profile", computed_profile) logger.debug( "ChatLiteLLMRouter ready (models=%d, streaming=%s, has_tools=%s)", LLMRouterService.get_model_count(), self.streaming, bound_tools is not None, ) except Exception as e: logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}") raise # ----------------------------------------------------------------- # Context-aware trimming helpers # ----------------------------------------------------------------- def _get_token_count_model_names(self) -> list[str]: """Return concrete model names usable by ``litellm.token_counter``. The router uses ``"auto"`` as the model group name but tokenizers need concrete model identifiers. We keep multiple candidates and take the most conservative count across them. """ names: list[str] = [] profile = getattr(self, "profile", None) if isinstance(profile, dict): tcms = profile.get("token_count_models") if isinstance(tcms, list): for name in tcms: if isinstance(name, str) and name and name not in names: names.append(name) tcm = profile.get("token_count_model") if isinstance(tcm, str) and tcm and tcm not in names: names.append(tcm) if self._router and self._router.model_list: for dep in self._router.model_list: params = dep.get("litellm_params", {}) base = params.get("base_model") or params.get("model", "") if base and base not in names: names.append(base) if len(names) >= 3: break if not names: return ["gpt-4o"] return names def _count_tokens(self, messages: list[dict]) -> int | None: """Return conservative token count across candidate deployment models.""" from litellm import token_counter as _tc models = self._get_token_count_model_names() counts: list[int] = [] for model_name in models: try: counts.append(_tc(messages=messages, model=model_name)) except Exception: continue return max(counts) if counts else None def _get_max_input_tokens(self) -> int: """Return the max input tokens to use for context trimming. Prefers the *largest* context window across all deployments so we maximise usable context (the router's ``context_window_fallbacks`` handle routing to the right model). Falls back to the minimum profile value or a conservative default. """ profile = getattr(self, "profile", None) if isinstance(profile, dict): upper = profile.get("max_input_tokens_upper") if isinstance(upper, int) and upper > 0: return upper lower = profile.get("max_input_tokens") if isinstance(lower, int) and lower > 0: return lower return 128_000 def _trim_messages_to_fit_context( self, messages: list[dict], output_reserve_fraction: float = 0.10, ) -> list[dict]: """Trim message content via binary search to fit the model's context window. When the total token count exceeds the model's ``max_input_tokens``, this method identifies the largest messages (typically tool responses containing search results) and uses binary search on each to find the maximum content length that keeps the total within budget. Cutting prefers ```` XML boundaries so complete documents are preserved when possible. This is model-aware: it reads the context limit from ``litellm.get_model_info`` (cached in ``self.profile``) and counts tokens with ``litellm.token_counter``. """ max_input = self._get_max_input_tokens() output_reserve = min(int(max_input * output_reserve_fraction), 16_384) budget = max_input - output_reserve total_tokens = self._count_tokens(messages) if total_tokens is None: return messages if total_tokens <= budget: return messages perf = get_perf_logger() perf.warning( "[llm_router] context overflow detected: %d tokens > %d budget " "(max_input=%d, reserve=%d). Trimming messages.", total_tokens, budget, max_input, output_reserve, ) trimmed = copy.deepcopy(messages) # Per-message token counts for trimmable candidates. # Skip system messages to preserve agent instructions. msg_token_map: dict[int, int] = {} candidate_priority: dict[int, int] = {} for i, msg in enumerate(trimmed): if msg.get("role") == "system": continue role = msg.get("role") content = msg.get("content", "") if not isinstance(content, str) or len(content) < 500: continue # Prefer trimming tool/assistant outputs first. # User messages are only trimmed if they clearly contain injected # document context blobs. is_doc_blob = "" in content or "" in content if role in ("tool", "assistant"): candidate_priority[i] = 0 elif role == "user" and is_doc_blob: candidate_priority[i] = 1 else: continue token_count = self._count_tokens([msg]) if token_count is not None: msg_token_map[i] = token_count if not msg_token_map: perf.warning("[llm_router] no trimmable messages found, returning as-is") return trimmed # Trim largest messages first candidates = sorted( msg_token_map.items(), key=lambda x: (candidate_priority.get(x[0], 9), -x[1]), ) running_total = total_tokens trim_suffix = ( "\n\n" ) for idx, orig_msg_tokens in candidates: if running_total <= budget: break content = trimmed[idx]["content"] orig_len = len(content) # Binary search: find maximum content[:mid] that keeps total ≤ budget. lo, hi = 200, orig_len - 1 best = 200 while lo <= hi: mid = (lo + hi) // 2 trimmed[idx]["content"] = content[:mid] + trim_suffix new_msg_tokens = self._count_tokens([trimmed[idx]]) if new_msg_tokens is None: hi = mid - 1 continue projected_total = running_total - orig_msg_tokens + new_msg_tokens if projected_total <= budget: best = mid lo = mid + 1 else: hi = mid - 1 # Prefer cutting at a boundary for cleaner output last_doc_end = content[:best].rfind("") if last_doc_end > min(200, best // 4): best = last_doc_end + len("") trimmed[idx]["content"] = content[:best] + trim_suffix try: new_msg_tokens = self._count_tokens([trimmed[idx]]) if new_msg_tokens is None: continue running_total = running_total - orig_msg_tokens + new_msg_tokens except Exception: pass # Hard guarantee: if still over budget, replace remaining large # non-system messages with compact placeholders until we fit. if running_total > budget: fallback_indices: list[int] = [] for i, msg in enumerate(trimmed): if msg.get("role") == "system": continue content = msg.get("content") if isinstance(content, str) and len(content) > 0: fallback_indices.append(i) for idx in fallback_indices: if running_total <= budget: break role = trimmed[idx].get("role", "message") placeholder = ( f"[content omitted to fit model context window; role={role}]" ) old_tokens = self._count_tokens([trimmed[idx]]) or 0 trimmed[idx]["content"] = placeholder new_tokens = self._count_tokens([trimmed[idx]]) or 0 running_total = running_total - old_tokens + new_tokens if running_total > budget: perf.error( "[llm_router] unable to fit context even after aggressive trimming: " "tokens=%d budget=%d", running_total, budget, ) # Final safety net: clear oldest non-system contents. for idx in fallback_indices: if running_total <= budget: break old_tokens = self._count_tokens([trimmed[idx]]) or 0 trimmed[idx]["content"] = "" new_tokens = self._count_tokens([trimmed[idx]]) or 0 running_total = running_total - old_tokens + new_tokens perf.info( "[llm_router] messages trimmed: %d → %d tokens (budget=%d, max_input=%d)", total_tokens, running_total, budget, max_input, ) return trimmed # ----------------------------------------------------------------- @property def _llm_type(self) -> str: return "litellm-router" @property def _identifying_params(self) -> dict[str, Any]: return { "model": self.model, "model_count": LLMRouterService.get_model_count(), } def bind_tools( self, tools: list[Any], *, tool_choice: str | dict | None = None, **kwargs: Any, ) -> "ChatLiteLLMRouter": """ Bind tools to the model for function/tool calling. Args: tools: List of tools to bind (can be LangChain tools, Pydantic models, or dicts) tool_choice: Optional tool choice strategy ("auto", "required", "none", or specific tool) **kwargs: Additional arguments Returns: New ChatLiteLLMRouter instance with tools bound """ from langchain_core.utils.function_calling import convert_to_openai_tool # Convert tools to OpenAI format formatted_tools = [] for tool in tools: if isinstance(tool, dict): # Already in dict format formatted_tools.append(tool) else: # Convert using LangChain utility try: formatted_tools.append(convert_to_openai_tool(tool)) except Exception as e: logger.warning(f"Failed to convert tool {tool}: {e}") continue # Create a new instance with tools bound return ChatLiteLLMRouter( router=self._router, bound_tools=formatted_tools if formatted_tools else None, tool_choice=tool_choice, model=self.model, streaming=self.streaming, **kwargs, ) def _generate( self, messages: list[BaseMessage], stop: list[str] | None = None, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: """ Generate a response using the router (synchronous). """ if not self._router: raise ValueError("Router not initialized") perf = get_perf_logger() t0 = time.perf_counter() msg_count = len(messages) # Convert LangChain messages to OpenAI format formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) # Add tools if bound call_kwargs = {**kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: call_kwargs["tool_choice"] = self._tool_choice try: response = self._router.completion( model=self.model, messages=formatted_messages, stop=stop, **call_kwargs, ) except ContextWindowExceededError as e: perf.warning( "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs", msg_count, time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e except LiteLLMBadRequestError as e: if _is_context_overflow_error(e): perf.warning( "[llm_router] _generate CONTEXT_OVERFLOW msgs=%d in %.3fs", msg_count, time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e raise elapsed = time.perf_counter() - t0 perf.info( "[llm_router] _generate completed msgs=%d tools=%d in %.3fs", msg_count, len(self._bound_tools) if self._bound_tools else 0, elapsed, ) # Convert response to ChatResult with potential tool calls message = self._convert_response_to_message(response.choices[0].message) generation = ChatGeneration(message=message) return ChatResult(generations=[generation]) async def _agenerate( self, messages: list[BaseMessage], stop: list[str] | None = None, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: """ Generate a response using the router (asynchronous). """ if not self._router: raise ValueError("Router not initialized") perf = get_perf_logger() t0 = time.perf_counter() msg_count = len(messages) # Convert LangChain messages to OpenAI format formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) # Add tools if bound call_kwargs = {**kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: call_kwargs["tool_choice"] = self._tool_choice try: response = await self._router.acompletion( model=self.model, messages=formatted_messages, stop=stop, **call_kwargs, ) except ContextWindowExceededError as e: perf.warning( "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs", msg_count, time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e except LiteLLMBadRequestError as e: if _is_context_overflow_error(e): perf.warning( "[llm_router] _agenerate CONTEXT_OVERFLOW msgs=%d in %.3fs", msg_count, time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e raise elapsed = time.perf_counter() - t0 perf.info( "[llm_router] _agenerate completed msgs=%d tools=%d in %.3fs", msg_count, len(self._bound_tools) if self._bound_tools else 0, elapsed, ) # Convert response to ChatResult with potential tool calls message = self._convert_response_to_message(response.choices[0].message) generation = ChatGeneration(message=message) return ChatResult(generations=[generation]) def _stream( self, messages: list[BaseMessage], stop: list[str] | None = None, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ): """ Stream a response using the router (synchronous). """ if not self._router: raise ValueError("Router not initialized") formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) # Add tools if bound call_kwargs = {**kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: call_kwargs["tool_choice"] = self._tool_choice try: response = self._router.completion( model=self.model, messages=formatted_messages, stop=stop, stream=True, **call_kwargs, ) except ContextWindowExceededError as e: raise ContextOverflowError(str(e)) from e except LiteLLMBadRequestError as e: if _is_context_overflow_error(e): raise ContextOverflowError(str(e)) from e raise # Yield chunks for chunk in response: if hasattr(chunk, "choices") and chunk.choices: delta = chunk.choices[0].delta chunk_msg = self._convert_delta_to_chunk(delta) if chunk_msg: yield ChatGenerationChunk(message=chunk_msg) async def _astream( self, messages: list[BaseMessage], stop: list[str] | None = None, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ): """ Stream a response using the router (asynchronous). """ if not self._router: raise ValueError("Router not initialized") perf = get_perf_logger() t0 = time.perf_counter() msg_count = len(messages) formatted_messages = self._convert_messages(messages) formatted_messages = self._trim_messages_to_fit_context(formatted_messages) # Add tools if bound call_kwargs = {**kwargs} if self._bound_tools: call_kwargs["tools"] = self._bound_tools if self._tool_choice is not None: call_kwargs["tool_choice"] = self._tool_choice try: response = await self._router.acompletion( model=self.model, messages=formatted_messages, stop=stop, stream=True, **call_kwargs, ) except ContextWindowExceededError as e: perf.warning( "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs", msg_count, time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e except LiteLLMBadRequestError as e: if _is_context_overflow_error(e): perf.warning( "[llm_router] _astream CONTEXT_OVERFLOW msgs=%d in %.3fs", msg_count, time.perf_counter() - t0, ) raise ContextOverflowError(str(e)) from e raise t_first_chunk = time.perf_counter() perf.info( "[llm_router] _astream connection established msgs=%d in %.3fs", msg_count, t_first_chunk - t0, ) chunk_count = 0 first_chunk_logged = False async for chunk in response: if hasattr(chunk, "choices") and chunk.choices: delta = chunk.choices[0].delta chunk_msg = self._convert_delta_to_chunk(delta) if chunk_msg: chunk_count += 1 if not first_chunk_logged: perf.info( "[llm_router] _astream first chunk in %.3fs (total %.3fs from start)", time.perf_counter() - t_first_chunk, time.perf_counter() - t0, ) first_chunk_logged = True yield ChatGenerationChunk(message=chunk_msg) perf.info( "[llm_router] _astream completed chunks=%d total=%.3fs", chunk_count, time.perf_counter() - t0, ) def _convert_messages(self, messages: list[BaseMessage]) -> list[dict]: """Convert LangChain messages to OpenAI format.""" from langchain_core.messages import ( AIMessage as AIMsg, HumanMessage, SystemMessage, ToolMessage, ) result = [] for msg in messages: if isinstance(msg, SystemMessage): result.append({"role": "system", "content": msg.content}) elif isinstance(msg, HumanMessage): result.append({"role": "user", "content": msg.content}) elif isinstance(msg, AIMsg): ai_msg: dict[str, Any] = {"role": "assistant"} if msg.content: ai_msg["content"] = msg.content # Handle tool calls if hasattr(msg, "tool_calls") and msg.tool_calls: ai_msg["tool_calls"] = [ { "id": tc.get("id", ""), "type": "function", "function": { "name": tc.get("name", ""), "arguments": tc.get("args", "{}") if isinstance(tc.get("args"), str) else __import__("json").dumps(tc.get("args", {})), }, } for tc in msg.tool_calls ] result.append(ai_msg) elif isinstance(msg, ToolMessage): result.append( { "role": "tool", "tool_call_id": msg.tool_call_id, "content": msg.content if isinstance(msg.content, str) else __import__("json").dumps(msg.content), } ) else: # Fallback for other message types role = getattr(msg, "type", "user") if role == "human": role = "user" elif role == "ai": role = "assistant" result.append({"role": role, "content": msg.content}) return result def _convert_response_to_message(self, response_message: Any) -> AIMessage: """Convert a LiteLLM response message to a LangChain AIMessage.""" import json content = getattr(response_message, "content", None) or "" # Check for tool calls tool_calls = [] if hasattr(response_message, "tool_calls") and response_message.tool_calls: for tc in response_message.tool_calls: tool_call = { "id": tc.id if hasattr(tc, "id") else "", "name": tc.function.name if hasattr(tc, "function") else "", "args": {}, } # Parse arguments if hasattr(tc, "function") and hasattr(tc.function, "arguments"): try: tool_call["args"] = json.loads(tc.function.arguments) except json.JSONDecodeError: tool_call["args"] = tc.function.arguments tool_calls.append(tool_call) if tool_calls: return AIMessage(content=content, tool_calls=tool_calls) return AIMessage(content=content) def _convert_delta_to_chunk(self, delta: Any) -> AIMessageChunk | None: """Convert a streaming delta to an AIMessageChunk.""" content = getattr(delta, "content", None) or "" # Check for tool calls in delta tool_call_chunks = [] if hasattr(delta, "tool_calls") and delta.tool_calls: for tc in delta.tool_calls: chunk = { "index": tc.index if hasattr(tc, "index") else 0, "id": tc.id if hasattr(tc, "id") else None, "name": tc.function.name if hasattr(tc, "function") and hasattr(tc.function, "name") else None, "args": tc.function.arguments if hasattr(tc, "function") and hasattr(tc.function, "arguments") else "", } tool_call_chunks.append(chunk) if content or tool_call_chunks: if tool_call_chunks: return AIMessageChunk( content=content, tool_call_chunks=tool_call_chunks ) return AIMessageChunk(content=content) return None def get_auto_mode_llm( *, streaming: bool = True, ) -> ChatLiteLLMRouter | None: """Return a cached ChatLiteLLMRouter for auto mode. Base (no tools) instances are cached per ``streaming`` flag so we avoid re-constructing them on every request. ``bind_tools()`` still returns a fresh instance because bound tools differ per agent. """ if not LLMRouterService.is_initialized(): logger.warning("LLM Router not initialized for auto mode") return None cached = _router_instance_cache.get(streaming) if cached is not None: return cached try: instance = ChatLiteLLMRouter(streaming=streaming) _router_instance_cache[streaming] = instance return instance except Exception as e: logger.error(f"Failed to create ChatLiteLLMRouter: {e}") return None def is_auto_mode(llm_config_id: int | None) -> bool: """ Check if the given LLM config ID represents Auto mode. Args: llm_config_id: The LLM config ID to check Returns: True if this is Auto mode, False otherwise """ return llm_config_id == AUTO_MODE_ID