feat(chat): route models by provider capabilities

This commit is contained in:
Anish Sarkar 2026-06-11 18:22:23 +05:30
parent 8f20a32571
commit c28c4f5785
18 changed files with 429 additions and 319 deletions

View file

@ -40,7 +40,7 @@ def check_image_input_capability(
else None
)
if not is_known_text_only_chat_model(
provider=agent_config.provider,
litellm_provider=agent_config.provider,
model_name=agent_config.model_name,
base_model=agent_base_model,
custom_provider=agent_config.custom_provider,

View file

@ -80,7 +80,6 @@ async def _generate_title(
from litellm import acompletion
from app.services.llm_router_service import LLMRouterService
from app.services.provider_api_base import resolve_api_base
from app.services.token_tracking_service import _turn_accumulator
# Excludes this turn's own assistant row (pre-written by
@ -125,26 +124,12 @@ async def _generate_title(
router = LLMRouterService.get_router()
response = await router.acompletion(model="auto", messages=messages)
else:
# Apply the same ``api_base`` cascade chat / vision / image-gen
# call sites use so we never inherit ``litellm.api_base``
# (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat
# config itself ships an empty ``api_base``. Without this the
# title-gen on an OpenRouter chat config would 404 against the
# inherited Azure endpoint — see ``provider_api_base`` for the
# same bug repro on the image-gen / vision paths.
raw_model = getattr(llm, "model", "") or ""
provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None
provider_value = agent_config.provider if agent_config is not None else None
title_api_base = resolve_api_base(
provider=provider_value,
provider_prefix=provider_prefix,
config_api_base=getattr(llm, "api_base", None),
)
response = await acompletion(
model=raw_model,
messages=messages,
api_key=getattr(llm, "api_key", None),
api_base=title_api_base,
api_base=getattr(llm, "api_base", None),
)
usage_info = None

View file

@ -1,8 +1,8 @@
"""Load an LLM + AgentConfig bundle for a given config id.
Handles both code paths uniformly:
- ``config_id >= 0`` database-backed ``NewLLMConfig`` row (per-user/per-space).
- ``config_id < 0`` YAML-defined global LLM config (built-in defaults).
- ``config_id > 0`` database-backed model-connection ``Model`` row.
- ``config_id < 0`` virtual global model materialized from YAML/OpenRouter.
Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is
``None``. The caller emits the friendly SSE error frame.
@ -12,15 +12,72 @@ from __future__ import annotations
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.agents.chat.runtime.llm_config import (
AgentConfig,
create_chat_litellm_from_agent_config,
create_chat_litellm_from_config,
load_agent_config,
load_global_llm_config_by_id,
SanitizedChatLiteLLM,
)
from app.config import config
from app.db import Model, SearchSpace
from app.services.model_resolver import to_litellm
def _agent_config_from_resolved(
*,
config_id: int,
config_name: str | None,
provider: str,
model_name: str,
api_key: str | None,
api_base: str | None,
litellm_params: dict | None,
supports_image_input: bool,
billing_tier: str = "free",
) -> AgentConfig:
return AgentConfig(
provider=provider,
model_name=model_name,
api_key=api_key or "",
api_base=api_base,
custom_provider=None,
litellm_params=litellm_params,
config_id=config_id,
config_name=config_name,
is_auto_mode=False,
billing_tier=billing_tier,
is_premium=billing_tier == "premium",
supports_image_input=supports_image_input,
)
async def _load_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace | None:
result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id))
return result.scalars().first()
async def _load_db_model(
session: AsyncSession,
*,
model_id: int,
search_space: SearchSpace,
) -> Model | None:
result = await session.execute(
select(Model)
.options(selectinload(Model.connection))
.where(Model.id == model_id, Model.enabled.is_(True))
)
model = result.scalars().first()
if not model or not model.connection or not model.connection.enabled:
return None
conn = model.connection
if conn.search_space_id is not None and conn.search_space_id != search_space.id:
return None
if conn.user_id is not None and conn.user_id != search_space.user_id:
return None
return model
async def load_llm_bundle(
@ -29,29 +86,67 @@ async def load_llm_bundle(
config_id: int,
search_space_id: int,
) -> tuple[Any, AgentConfig | None, str | None]:
if config_id >= 0:
loaded_agent_config = await load_agent_config(
session=session,
config_id=config_id,
search_space_id=search_space_id,
search_space = await _load_search_space(session, search_space_id)
if not search_space:
return None, None, f"Search space {search_space_id} not found"
if config_id > 0:
model = await _load_db_model(
session,
model_id=config_id,
search_space=search_space,
)
if not loaded_agent_config:
if not model or not (model.capabilities or {}).get("chat"):
return (
None,
None,
f"Failed to load NewLLMConfig with id {config_id}",
f"Failed to load chat model with id {config_id}",
)
model_string, litellm_kwargs = to_litellm(model.connection, model.model_id)
agent_config = _agent_config_from_resolved(
config_id=config_id,
config_name=model.display_name or model.model_id,
provider=model.connection.litellm_provider or "",
model_name=model.model_id,
api_key=model.connection.api_key,
api_base=model.connection.base_url,
litellm_params=(model.connection.extra or {}).get("litellm_params"),
supports_image_input=bool((model.capabilities or {}).get("vision")),
billing_tier="free",
)
return (
create_chat_litellm_from_agent_config(loaded_agent_config),
loaded_agent_config,
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
agent_config,
None,
)
loaded_llm_config = load_global_llm_config_by_id(config_id)
if not loaded_llm_config:
return None, None, f"Failed to load LLM config with id {config_id}"
return (
create_chat_litellm_from_config(loaded_llm_config),
AgentConfig.from_yaml_config(loaded_llm_config),
global_model = next((m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None)
if not global_model or not (global_model.get("capabilities") or {}).get("chat"):
return None, None, f"Failed to load global chat model with id {config_id}"
global_connection = next(
(
c
for c in config.GLOBAL_CONNECTIONS
if c.get("id") == global_model.get("connection_id")
),
None,
)
if not global_connection:
return None, None, f"Failed to load global connection for model {config_id}"
model_string, litellm_kwargs = to_litellm(global_connection, global_model["model_id"])
agent_config = _agent_config_from_resolved(
config_id=config_id,
config_name=global_model.get("display_name") or global_model.get("model_id"),
provider=global_connection.get("litellm_provider") or "",
model_name=global_model["model_id"],
api_key=global_connection.get("api_key"),
api_base=global_connection.get("base_url"),
litellm_params=(global_connection.get("extra") or {}).get("litellm_params"),
supports_image_input=bool((global_model.get("capabilities") or {}).get("vision")),
billing_tier=str(global_model.get("billing_tier", "free")).lower(),
)
return (
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
agent_config,
None,
)